diff --git a/kernel/include/sys/Syscall.h b/kernel/include/sys/Syscall.h index a0a8e10a..e17137ed 100644 --- a/kernel/include/sys/Syscall.h +++ b/kernel/include/sys/Syscall.h @@ -17,6 +17,7 @@ #define SYS_seek 12 #define SYS_exec 13 #define SYS_fcntl 14 +#define SYS_mprotect 15 namespace Syscall { @@ -32,11 +33,12 @@ void sys_write(Context* context, int fd, size_t size, const char* addr); void sys_paint(Context* context, uint64_t x, uint64_t y, uint64_t w, uint64_t h, uint64_t col); void sys_rand(Context* context); void sys_gettid(Context* context); -void sys_mmap(Context* context, void* address, size_t size, int flags); +void sys_mmap(Context* context, void* address, size_t size, int prot); void sys_munmap(Context* context, void* address, size_t size); void sys_open(Context* context, const char* filename, int flags); void sys_read(Context* context, int fd, size_t size, char* buffer); void sys_close(Context* context, int fd); void sys_seek(Context* context, int fd, long offset, int whence); void sys_exec(Context* context, const char* pathname); -void sys_fcntl(Context* context, int fd, int command, uintptr_t arg); \ No newline at end of file +void sys_fcntl(Context* context, int fd, int command, uintptr_t arg); +void sys_mprotect(Context* context, void* address, size_t size, int prot); \ No newline at end of file diff --git a/kernel/src/sys/Syscall.cpp b/kernel/src/sys/Syscall.cpp index d4ac73b7..a1988d15 100644 --- a/kernel/src/sys/Syscall.cpp +++ b/kernel/src/sys/Syscall.cpp @@ -26,6 +26,7 @@ void Syscall::entry(Context* context) case SYS_seek: sys_seek(context, (int)context->rdi, (long)context->rsi, (int)context->rdx); break; case SYS_exec: sys_exec(context, (const char*)context->rdi); break; case SYS_fcntl: sys_fcntl(context, (int)context->rdi, (int)context->rsi, context->rdx); break; + case SYS_mprotect: sys_mprotect(context, (void*)context->rdi, context->rsi, (int)context->rdx); break; default: context->rax = -ENOSYS; break; } VMM::exit_syscall_context(); diff --git a/kernel/src/sys/mem.cpp b/kernel/src/sys/mem.cpp index d0744b74..ff406d8a 100644 --- a/kernel/src/sys/mem.cpp +++ b/kernel/src/sys/mem.cpp @@ -8,9 +8,30 @@ #include "misc/utils.h" #include +#define MAP_READ 1 +#define MAP_WRITE 2 +#define MAP_NONE 0 + #define MAP_FAIL(errno) 0xffffffffffffff00 | (unsigned char)(errno) -void sys_mmap(Context* context, void* address, size_t size, int flags) +static const char* format_prot(int prot) +{ + static char prot_string[3]; + prot_string[2] = 0; + prot_string[0] = ((prot & MAP_READ) > 0) ? 'r' : '-'; + prot_string[1] = ((prot & MAP_WRITE) > 0) ? 'w' : '-'; + return prot_string; +} + +static int mman_flags_from_prot(int prot) +{ + prot &= 0b11; + if (prot == MAP_NONE) return 0; + if ((prot & MAP_WRITE) > 0) return MAP_USER | MAP_READ_WRITE; + return MAP_USER; +} + +void sys_mmap(Context* context, void* address, size_t size, int prot) { if (size < PAGE_SIZE) { @@ -24,12 +45,10 @@ void sys_mmap(Context* context, void* address, size_t size, int flags) context->rax = MAP_FAIL(EINVAL); return; } - int real_flags = MAP_USER; - if (flags & MAP_READ_WRITE) real_flags |= MAP_READ_WRITE; + int real_flags = mman_flags_from_prot(prot); if (address) { - kdbgln("mmap(): %ld pages at address %p, %s", size / PAGE_SIZE, address, - real_flags & MAP_READ_WRITE ? "rw" : "ro"); + kdbgln("mmap(): %ld pages at address %p, %s", size / PAGE_SIZE, address, format_prot(prot)); if (VMM::get_physical((uint64_t)address) != (uint64_t)-1) // Address is already used. { kwarnln("attempt to map an already mapped address"); @@ -52,8 +71,7 @@ void sys_mmap(Context* context, void* address, size_t size, int flags) return; } } - kdbgln("mmap(): %ld pages at any address, %s", Utilities::get_blocks_from_size(PAGE_SIZE, size), - real_flags & MAP_READ_WRITE ? "rw" : "ro"); + kdbgln("mmap(): %ld pages at any address, %s", Utilities::get_blocks_from_size(PAGE_SIZE, size), format_prot(prot)); void* result = MemoryManager::get_pages(Utilities::get_blocks_from_size(PAGE_SIZE, size), real_flags); if (result) { @@ -102,4 +120,43 @@ void sys_munmap(Context* context, void* address, size_t size) kdbgln("munmap() succeeded"); context->rax = 0; return; +} + +void sys_mprotect(Context* context, void* address, size_t size, int prot) +{ + kdbgln("mprotect(): attempting to protect %p with %s", address, format_prot(prot)); + + if (size < PAGE_SIZE) + { + kwarnln("mprotect() failed: size is too small"); + context->rax = -EINVAL; + return; + } + if (size % PAGE_SIZE) + { + kwarnln("mprotect() failed: size is not a multiple of PAGE_SIZE"); + context->rax = -EINVAL; + return; + } + if (!address) + { + kwarnln("mprotect() failed: attempted to unmap page 0"); + context->rax = -EINVAL; + return; + } + + uint64_t flags = VMM::get_flags((uint64_t)address); + if (!(flags & MAP_USER)) + { + kwarnln("mprotect() failed: attempted to protect a non-existent or kernel page"); + context->rax = -EINVAL; + return; + } + + uint64_t offset = (uint64_t)address % PAGE_SIZE; + MemoryManager::protect((void*)((uint64_t)address - offset), Utilities::get_blocks_from_size(PAGE_SIZE, size), + mman_flags_from_prot(prot)); + kdbgln("mprotect() succeeded"); + context->rax = 0; + return; } \ No newline at end of file diff --git a/libs/libc/include/sys/mman.h b/libs/libc/include/sys/mman.h index fbb6d3c5..63bfc921 100644 --- a/libs/libc/include/sys/mman.h +++ b/libs/libc/include/sys/mman.h @@ -7,7 +7,9 @@ /* Address returned by mmap when it fails. */ #define MAP_FAILED (void*)-1 -#define PROT_READ_WRITE 1 +#define PROT_NONE 0 +#define PROT_READ 1 +#define PROT_WRITE 2 #define PAGE_SIZE 4096 @@ -24,6 +26,9 @@ extern "C" * address space. */ int munmap(void* addr, size_t size); + /* Protects size bytes of memory according to the prot argument. */ + int mprotect(void* addr, size_t size, int prot); + #ifdef __cplusplus } #endif diff --git a/libs/libc/src/bits/bindings.c b/libs/libc/src/bits/bindings.c index d884575b..5c58ef0b 100644 --- a/libs/libc/src/bits/bindings.c +++ b/libs/libc/src/bits/bindings.c @@ -13,7 +13,7 @@ int liballoc_unlock() void* liballoc_alloc(size_t size) { - void* result = mmap(NULL, size * PAGE_SIZE, PROT_READ_WRITE, 0, 0, 0); + void* result = mmap(NULL, size * PAGE_SIZE, PROT_READ | PROT_WRITE, 0, 0, 0); if (result == MAP_FAILED) return 0; return (void*)result; }