From 2572695c8d3b79da3a7de5a54b54b384db3453ba Mon Sep 17 00:00:00 2001 From: apio Date: Wed, 2 Aug 2023 22:19:06 +0200 Subject: [PATCH] kernel: Support mapping shared memory using mmap() --- kernel/src/binfmt/ELF.cpp | 8 ++- kernel/src/memory/AddressSpace.cpp | 105 ++++++++++++++++++++++------ kernel/src/memory/AddressSpace.h | 16 +++-- kernel/src/memory/MemoryManager.cpp | 66 +++++++++++++++++ kernel/src/memory/MemoryManager.h | 4 ++ kernel/src/sys/mmap.cpp | 86 +++++++++++++++-------- kernel/src/thread/ThreadImage.cpp | 5 +- libc/include/bits/mmap.h | 13 ++++ 8 files changed, 244 insertions(+), 59 deletions(-) create mode 100644 libc/include/bits/mmap.h diff --git a/kernel/src/binfmt/ELF.cpp b/kernel/src/binfmt/ELF.cpp index c4dc19cd..d5c0dc34 100644 --- a/kernel/src/binfmt/ELF.cpp +++ b/kernel/src/binfmt/ELF.cpp @@ -3,6 +3,7 @@ #include "arch/CPU.h" #include "arch/MMU.h" #include "memory/MemoryManager.h" +#include #include #include #include @@ -108,8 +109,13 @@ Result ELFLoader::load(AddressSpace* space) if (can_write_segment(program_header.p_flags)) flags |= MMU::ReadWrite; if (can_execute_segment(program_header.p_flags)) flags &= ~MMU::NoExecute; + int prot = PROT_READ; + if (can_write_segment(program_header.p_flags)) prot |= PROT_WRITE; + if (can_execute_segment(program_header.p_flags)) prot |= PROT_EXEC; + if (!TRY(space->test_and_alloc_region( - base_vaddr, get_blocks_from_size(program_header.p_memsz + vaddr_diff, ARCH_PAGE_SIZE), true))) + base_vaddr, get_blocks_from_size(program_header.p_memsz + vaddr_diff, ARCH_PAGE_SIZE), prot, + MAP_ANONYMOUS | MAP_PRIVATE, 0, true))) return err(ENOMEM); // Allocate physical memory for the segment diff --git a/kernel/src/memory/AddressSpace.cpp b/kernel/src/memory/AddressSpace.cpp index a485c627..5ffefe8d 100644 --- a/kernel/src/memory/AddressSpace.cpp +++ b/kernel/src/memory/AddressSpace.cpp @@ -1,7 +1,9 @@ #include "memory/AddressSpace.h" -#include "Log.h" #include "arch/MMU.h" #include "memory/Heap.h" +#include "memory/MemoryManager.h" +#include "memory/SharedMemory.h" +#include #include #include @@ -49,14 +51,25 @@ Result> AddressSpace::clone() { OwnedPtr ptr = TRY(make_owned()); + ptr->m_directory = TRY(MMU::clone_userspace_page_directory(m_directory)); + for (const auto* region : m_regions) { auto* new_region = TRY(make()); memcpy(new_region, region, sizeof(*region)); ptr->m_regions.append(new_region); - } - ptr->m_directory = TRY(MMU::clone_userspace_page_directory(m_directory)); + if (new_region->used && new_region->prot != 0 && new_region->flags & MAP_SHARED) + { + TRY(MemoryManager::copy_region(new_region->start, new_region->count, m_directory, ptr->m_directory)); + auto* shm = g_shared_memory_map.try_get_ref(new_region->shmid); + if (shm) shm->refs++; + } + else if (new_region->used && new_region->prot != 0) + { + TRY(MemoryManager::copy_region_data(new_region->start, new_region->count, m_directory, ptr->m_directory)); + } + } return move(ptr); } @@ -81,8 +94,16 @@ AddressSpace& AddressSpace::operator=(AddressSpace&& other) return *this; } -Result AddressSpace::alloc_region(usize count, bool persistent) +Result AddressSpace::alloc_region(usize count, int prot, int flags, u64 shmid, bool persistent) { + auto update_region = [=](VMRegion* region) { + region->used = true; + region->persistent = persistent; + region->prot = prot; + region->flags = flags; + region->shmid = shmid; + }; + for (auto* region = m_regions.expect_last(); region; region = m_regions.previous(region).value_or(nullptr)) { if (!region->used) @@ -90,8 +111,7 @@ Result AddressSpace::alloc_region(usize count, bool persistent) if (region->count < count) continue; if (region->count == count) { - region->used = true; - region->persistent = persistent; + update_region(region); u64 address = region->start; try_merge_region_with_neighbors(region); return address; @@ -100,8 +120,7 @@ Result AddressSpace::alloc_region(usize count, bool persistent) u64 boundary = region->end - (count * ARCH_PAGE_SIZE); auto* new_region = TRY(split_region(region, boundary)); - new_region->used = true; - new_region->persistent = persistent; + update_region(new_region); try_merge_region_with_neighbors(new_region); return boundary; @@ -111,12 +130,22 @@ Result AddressSpace::alloc_region(usize count, bool persistent) return err(ENOMEM); } -Result AddressSpace::set_region(u64 address, usize count, bool used, bool persistent) +Result AddressSpace::set_region(u64 address, usize count, bool used, int prot, int flags, u64 shmid, + bool persistent) { if (address >= VM_END) return err(EINVAL); u64 end = address + (count * ARCH_PAGE_SIZE); + auto update_region = [=](VMRegion* region) { + if (!used) region->cleanup_shared(); + region->used = used; + region->persistent = persistent; + region->prot = prot; + region->flags = flags; + region->shmid = shmid; + }; + for (auto* region : m_regions) { if (region->end < address) continue; @@ -131,13 +160,13 @@ Result AddressSpace::set_region(u64 address, usize count, bool used, bool if (region->start >= address && region->end <= end) { - region->used = used; - region->persistent = persistent; + update_region(region); if (region->start == address && region->end == end) { try_merge_region_with_neighbors(region); return true; } + try_merge_region_with_neighbors(region); continue; } @@ -145,8 +174,7 @@ Result AddressSpace::set_region(u64 address, usize count, bool used, bool { auto* middle_region = TRY(split_region(region, address)); TRY(split_region(middle_region, end)); - middle_region->used = used; - middle_region->persistent = persistent; + update_region(middle_region); return true; } @@ -154,8 +182,7 @@ Result AddressSpace::set_region(u64 address, usize count, bool used, bool { bool finished = region->end == end; auto* split = TRY(split_region(region, address)); - split->used = used; - split->persistent = persistent; + update_region(split); try_merge_region_with_neighbors(split); if (!finished) continue; return true; @@ -164,8 +191,7 @@ Result AddressSpace::set_region(u64 address, usize count, bool used, bool if (region->end > end) { TRY(split_region(region, end)); - region->used = used; - region->persistent = persistent; + update_region(region); try_merge_region_with_neighbors(region); return true; } @@ -184,18 +210,24 @@ void AddressSpace::merge_contiguous_regions(VMRegion* a, VMRegion* b) void AddressSpace::try_merge_region_with_neighbors(VMRegion* region) { + auto equals = [](VMRegion* a, VMRegion* b) { + if (a->used != b->used) return false; + if (a->persistent != b->persistent) return false; + if (a->prot != b->prot) return false; + if (a->flags != b->flags) return false; + if (a->shmid != b->shmid) return false; + return true; + }; + auto prev = m_regions.previous(region); - if (prev.has_value() && (*prev)->used == region->used && (*prev)->persistent == region->persistent) + if (prev.has_value() && equals(*prev, region)) { merge_contiguous_regions(*prev, region); region = *prev; } auto next = m_regions.next(region); - if (next.has_value() && (*next)->used == region->used && (*next)->persistent == region->persistent) - { - merge_contiguous_regions(region, *next); - } + if (next.has_value() && equals(*next, region)) { merge_contiguous_regions(region, *next); } } Result AddressSpace::split_region(VMRegion* parent, u64 boundary) @@ -207,6 +239,9 @@ Result AddressSpace::split_region(VMRegion* parent, u64 boundary) region->count = (region->end - region->start) / ARCH_PAGE_SIZE; region->used = parent->used; region->persistent = parent->persistent; + region->prot = parent->prot; + region->flags = parent->flags; + region->shmid = parent->shmid; m_regions.add_after(parent, region); parent->end = boundary; @@ -217,6 +252,30 @@ Result AddressSpace::split_region(VMRegion* parent, u64 boundary) AddressSpace::~AddressSpace() { - m_regions.consume([](VMRegion* region) { delete region; }); + auto* directory = MMU::get_page_directory(); + MMU::switch_page_directory(this->m_directory); + m_regions.consume([this](VMRegion* region) { + region->cleanup_shared(); + delete region; + }); + MMU::switch_page_directory(directory); + if (m_directory) MMU::delete_userspace_page_directory(m_directory); } + +void VMRegion::cleanup_shared() +{ + if (used && (flags & MAP_SHARED)) + { + SharedMemory* shmem = g_shared_memory_map.try_get_ref(shmid); + if (shmem) + { + for (u64 addr = start; addr < end; addr += ARCH_PAGE_SIZE) { MMU::unmap(addr); } + if (--shmem->refs == 0) + { + shmem->free(); + g_shared_memory_map.try_remove(shmid); + } + } + } +} diff --git a/kernel/src/memory/AddressSpace.h b/kernel/src/memory/AddressSpace.h index e096a41e..9a96f51d 100644 --- a/kernel/src/memory/AddressSpace.h +++ b/kernel/src/memory/AddressSpace.h @@ -12,6 +12,11 @@ class VMRegion : LinkedListNode usize count; bool used { true }; bool persistent { false }; + int flags { 0 }; + int prot { 0 }; + u64 shmid; + + void cleanup_shared(); }; class AddressSpace @@ -22,16 +27,17 @@ class AddressSpace AddressSpace& operator=(AddressSpace&& other); - Result alloc_region(usize count, bool persistent = false); + Result alloc_region(usize count, int prot, int flags, u64 shmid = 0, bool persistent = false); - Result test_and_alloc_region(u64 address, usize count, bool persistent = false) + Result test_and_alloc_region(u64 address, usize count, int prot, int flags, u64 shmid = 0, + bool persistent = false) { - return set_region(address, count, true, persistent); + return set_region(address, count, true, prot, flags, shmid, persistent); } Result free_region(u64 address, usize count) { - return set_region(address, count, false, false); + return set_region(address, count, false, 0, 0, 0, false); } static Result> try_create(); @@ -44,7 +50,7 @@ class AddressSpace } private: - Result set_region(u64 address, usize count, bool used, bool persistent); + Result set_region(u64 address, usize count, bool used, int prot, int flags, u64 shmid, bool persistent); Result create_default_region(); Result create_null_region(); void try_merge_region_with_neighbors(VMRegion* region); diff --git a/kernel/src/memory/MemoryManager.cpp b/kernel/src/memory/MemoryManager.cpp index 7763d4b6..2c86a8b8 100644 --- a/kernel/src/memory/MemoryManager.cpp +++ b/kernel/src/memory/MemoryManager.cpp @@ -1,4 +1,5 @@ #include "memory/MemoryManager.h" +#include "Log.h" #include "arch/MMU.h" #include "memory/KernelVM.h" #include "memory/MemoryMap.h" @@ -226,6 +227,57 @@ namespace MemoryManager return {}; } + Result copy_region(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd) + { + CHECK_PAGE_ALIGNED(virt); + + usize pages_mapped = 0; + + // Let's clean up after ourselves if we fail. + auto guard = make_scope_guard( + [=, &pages_mapped] { kwarnln("copy_region failed, sorry! cannot reclaim already copied pages"); }); + + while (pages_mapped < count) + { + u64 phys = TRY(MMU::get_physical(virt, oldpd)); + int flags = TRY(MMU::get_flags(virt, oldpd)); + TRY(MMU::map(virt, phys, flags, MMU::UseHugePages::No, newpd)); + virt += ARCH_PAGE_SIZE; + pages_mapped++; + } + + guard.deactivate(); + + return {}; + } + + Result copy_region_data(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd) + { + CHECK_PAGE_ALIGNED(virt); + + usize pages_mapped = 0; + + // Let's clean up after ourselves if we fail. + auto guard = make_scope_guard( + [=, &pages_mapped] { kwarnln("copy_region_data failed, sorry! cannot reclaim already copied pages"); }); + + while (pages_mapped < count) + { + u64 frame = TRY(alloc_frame()); + u64 phys = TRY(MMU::get_physical(virt, oldpd)); + int flags = TRY(MMU::get_flags(virt, oldpd)); + memcpy((void*)MMU::translate_physical_address(frame), (void*)MMU::translate_physical_address(phys), + ARCH_PAGE_SIZE); + TRY(MMU::map(virt, frame, flags, MMU::UseHugePages::No, newpd)); + virt += ARCH_PAGE_SIZE; + pages_mapped++; + } + + guard.deactivate(); + + return {}; + } + Result map_huge_frames_at(u64 virt, u64 phys, usize count, int flags) { CHECK_PAGE_ALIGNED(virt); @@ -358,6 +410,20 @@ namespace MemoryManager return {}; } + Result unmap_owned_if_possible(u64 virt, usize count) + { + CHECK_PAGE_ALIGNED(virt); + + while (count--) + { + const auto frame = MMU::unmap(virt); + if (frame.has_value()) TRY(free_frame(frame.value())); + virt += ARCH_PAGE_SIZE; + } + + return {}; + } + Result unmap_owned_and_free_vm(u64 virt, usize count) { CHECK_PAGE_ALIGNED(virt); diff --git a/kernel/src/memory/MemoryManager.h b/kernel/src/memory/MemoryManager.h index e16130bd..23884712 100644 --- a/kernel/src/memory/MemoryManager.h +++ b/kernel/src/memory/MemoryManager.h @@ -67,6 +67,9 @@ namespace MemoryManager Result map_frames_at(u64 virt, u64 phys, usize count, int flags); Result map_huge_frames_at(u64 virt, u64 phys, usize count, int flags); + Result copy_region(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd); + Result copy_region_data(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd); + Result alloc_at(u64 virt, usize count, int flags); Result alloc_at_zeroed(u64, usize count, int flags); Result alloc_for_kernel(usize count, int flags); @@ -75,6 +78,7 @@ namespace MemoryManager Result unmap_owned(u64 virt, usize count); Result unmap_owned_and_free_vm(u64 virt, usize count); + Result unmap_owned_if_possible(u64 virt, usize count); Result unmap_weak(u64 virt, usize count); Result unmap_weak_and_free_vm(u64 virt, usize count); diff --git a/kernel/src/sys/mmap.cpp b/kernel/src/sys/mmap.cpp index 8f02b278..3bde30ac 100644 --- a/kernel/src/sys/mmap.cpp +++ b/kernel/src/sys/mmap.cpp @@ -1,62 +1,90 @@ #include "Log.h" #include "arch/MMU.h" #include "memory/MemoryManager.h" +#include "memory/SharedMemory.h" #include "sys/Syscall.h" #include "thread/Scheduler.h" #include +#include +#include #include constexpr uintptr_t USERSPACE_HEAP_BASE = 0x3000000; Result sys_mmap(Registers*, SyscallArgs args) { - void* addr = (void*)args[0]; - usize len = (usize)args[1]; - int prot = (int)args[2]; - int flags = (int)args[3]; + mmap_params params; + if (!MemoryManager::copy_from_user_typed((const mmap_params*)args[0], ¶ms)) return err(EFAULT); - if (len == 0) return err(EINVAL); + kdbgln("mmap: address=%p, len=%zu, prot=%d, flags=%d, fd=%d, offset=%lu", params.addr, params.len, params.prot, + params.flags, params.fd, params.offset); - if (flags < 0) return err(EINVAL); + if (params.len == 0) return err(EINVAL); - // We support only anonymous mappings for now. - if ((flags & MAP_ANONYMOUS) != MAP_ANONYMOUS) - { - kwarnln("mmap: FIXME: attempt to mmap file instead of anonymous memory"); - return err(ENOTSUP); - } - - if (flags & MAP_SHARED) - { - kwarnln("mmap: FIXME: attempt to mmap shared memory"); - return err(ENOTSUP); - } - - len = align_up(len); + if (params.flags < 0) return err(EINVAL); Thread* current = Scheduler::current(); + SharedPtr description; + if ((params.flags & MAP_ANONYMOUS) != MAP_ANONYMOUS) + { + description = TRY(current->resolve_fd(params.fd))->description; + if (!(description->flags & O_RDONLY)) return err(EACCES); + if (description->flags & O_APPEND) return err(EACCES); + } + + if (!is_aligned(params.offset)) return err(EINVAL); + + params.len = align_up(params.len); + + SharedMemory* shmem = nullptr; + u64 shmid = 0; + if (params.flags & MAP_SHARED) + { + if (!description) + { + params.offset = 0; + shmid = TRY(SharedMemory::create(nullptr, 0, params.len / ARCH_PAGE_SIZE)); + } + else + { + if ((params.prot & PROT_WRITE) && !(description->flags & O_WRONLY)) return err(EACCES); + shmid = TRY(description->inode->query_shared_memory(params.offset, params.len)); + } + shmem = g_shared_memory_map.try_get_ref(shmid); + shmem->refs++; + } + u64 address; - if (!addr) address = TRY(current->address_space->alloc_region(get_blocks_from_size(len, ARCH_PAGE_SIZE))); + if (!params.addr) + address = TRY(current->address_space->alloc_region(get_blocks_from_size(params.len, ARCH_PAGE_SIZE), + params.prot, params.flags, shmid)); else { // FIXME: We should be more flexible if MAP_FIXED was not specified. - address = align_down((u64)addr); - if (!TRY(current->address_space->test_and_alloc_region(address, get_blocks_from_size(len, ARCH_PAGE_SIZE)))) + address = align_down((u64)params.addr); + if (!TRY(current->address_space->test_and_alloc_region( + address, get_blocks_from_size(params.len, ARCH_PAGE_SIZE), params.prot, params.flags, shmid))) return err(ENOMEM); } int mmu_flags = MMU::User | MMU::NoExecute; - if (prot & PROT_WRITE) mmu_flags |= MMU::ReadWrite; - if (prot & PROT_EXEC) mmu_flags &= ~MMU::NoExecute; - if (prot == PROT_NONE) mmu_flags = MMU::NoExecute; + if (params.prot & PROT_WRITE) mmu_flags |= MMU::ReadWrite; + if (params.prot & PROT_EXEC) mmu_flags &= ~MMU::NoExecute; + if (params.prot == PROT_NONE) mmu_flags = MMU::NoExecute; #ifdef MMAP_DEBUG kdbgln("mmap: mapping memory at %#lx, size=%zu", address, len); #endif - // FIXME: This leaks VM if it fails. - return MemoryManager::alloc_at_zeroed(address, get_blocks_from_size(len, ARCH_PAGE_SIZE), mmu_flags); + if (shmem) { TRY(shmem->map(address, mmu_flags, params.offset, get_blocks_from_size(params.len, ARCH_PAGE_SIZE))); } + else + { + TRY(MemoryManager::alloc_at_zeroed(address, get_blocks_from_size(params.len, ARCH_PAGE_SIZE), mmu_flags)); + if (description) { TRY(description->inode->read((u8*)address, params.offset, params.len)); } + } + + return address; } Result sys_munmap(Registers*, SyscallArgs args) @@ -78,7 +106,7 @@ Result sys_munmap(Registers*, SyscallArgs args) kdbgln("munmap: unmapping memory at %#lx, size=%zu", address, size); #endif - TRY(MemoryManager::unmap_owned(address, get_blocks_from_size(size, ARCH_PAGE_SIZE))); + TRY(MemoryManager::unmap_owned_if_possible(address, get_blocks_from_size(size, ARCH_PAGE_SIZE))); return { 0 }; } diff --git a/kernel/src/thread/ThreadImage.cpp b/kernel/src/thread/ThreadImage.cpp index 43d1c9ac..9b79e0c1 100644 --- a/kernel/src/thread/ThreadImage.cpp +++ b/kernel/src/thread/ThreadImage.cpp @@ -1,6 +1,7 @@ #include "thread/ThreadImage.h" #include "memory/MemoryManager.h" #include "thread/Thread.h" +#include #include #include @@ -15,7 +16,9 @@ static Result create_user_stack(Stack& user_stack, AddressSpace* space) auto guard = make_scope_guard([] { MemoryManager::unmap_owned(THREAD_STACK_BASE, DEFAULT_USER_STACK_PAGES); }); - if (!TRY(space->test_and_alloc_region(THREAD_STACK_BASE, DEFAULT_USER_STACK_PAGES, true))) return err(ENOMEM); + if (!TRY(space->test_and_alloc_region(THREAD_STACK_BASE, DEFAULT_USER_STACK_PAGES, PROT_READ | PROT_WRITE, + MAP_ANONYMOUS | MAP_PRIVATE, 0, true))) + return err(ENOMEM); guard.deactivate(); diff --git a/libc/include/bits/mmap.h b/libc/include/bits/mmap.h new file mode 100644 index 00000000..1882b549 --- /dev/null +++ b/libc/include/bits/mmap.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +struct mmap_params +{ + void* addr; + size_t len; + int prot; + int flags; + int fd; + off_t offset; +};