#define MODULE "mem"

#include "interrupts/Context.h"
#include "log/Log.h"
#include "memory/Memory.h"
#include "memory/MemoryManager.h"
#include "memory/VMM.h"
#include "misc/utils.h"
#include "std/errno.h"
#include "thread/Scheduler.h"
#include <stddef.h>

// Protection bits accepted by sys_mmap() and sys_mprotect().
#define MAP_READ 1
#define MAP_WRITE 2
#define MAP_NONE 0

// Failure value placed in rax by sys_mmap(): every bit set except the low byte, which carries the
// errno code. Fully parenthesized so the expansion stays a single operand inside larger expressions.
#define MAP_FAIL(code) (0xffffffffffffff00 | (unsigned char)(code))

// Render the MAP_READ / MAP_WRITE bits of `prot` as a two-character string ("r"/"w" or '-') for
// log messages. The result points to a string literal, so it is never overwritten by a later call
// and concurrent callers do not share a writable buffer.
static const char* format_prot(int prot)
{
    // Indexed by the low two bits: bit 0 is MAP_READ (1), bit 1 is MAP_WRITE (2).
    static const char* const prot_names[4] = {"--", "r-", "-w", "rw"};
    return prot_names[prot & 0b11];
}

static int mman_flags_from_prot(int prot)
{
    prot &= 0b11;
    if (prot == MAP_NONE) return 0;
    if ((prot & MAP_WRITE) > 0) return MAP_USER | MAP_READ_WRITE;
    return MAP_USER;
}

// mmap() system call: allocate `size` bytes of fresh memory in the current task's address space.
//
// context: saved register state; the result is written to context->rax.
// address: requested start (rounded down to its page), or null to let the task's virtual
//          allocator choose the range.
// size:    length in bytes; must be a non-zero multiple of PAGE_SIZE.
// prot:    MAP_READ / MAP_WRITE bits, translated by mman_flags_from_prot().
//
// On success rax holds the mapped address. On failure rax holds MAP_FAIL(EINVAL) for a bad size,
// or MAP_FAIL(ENOMEM) when the range is unusable or memory cannot be allocated.
void sys_mmap(Context* context, void* address, size_t size, int prot)
{
    if (size < PAGE_SIZE)
    {
        kwarnln("mmap(): size too small");
        context->rax = MAP_FAIL(EINVAL);
        return;
    }
    if (size % PAGE_SIZE)
    {
        kwarnln("mmap(): size not a multiple of PAGE_SIZE");
        context->rax = MAP_FAIL(EINVAL);
        return;
    }

    const int real_flags = mman_flags_from_prot(prot);
    const uint64_t pages = Utilities::get_blocks_from_size(PAGE_SIZE, size);
    void* result;

    if (address)
    {
        kdbgln("mmap(): %ld pages at address %p, %s", pages, address, format_prot(prot));
        const uint64_t base = (uint64_t)address - ((uint64_t)address % PAGE_SIZE);
        const uint64_t end = base + (size - 1); // last byte of the requested range
        // Check the end of the range (and wraparound) as well as the start, so a range that
        // begins in user space cannot extend into kernel pages.
        if (end < base || Memory::is_kernel_address((uintptr_t)address) ||
            Memory::is_kernel_address((uintptr_t)end))
        {
            kwarnln("mmap() failed: attempted to map a kernel page");
            context->rax = MAP_FAIL(ENOMEM);
            return;
        }
        // Every page in the range must be unmapped, not only the first one; otherwise
        // get_pages_at() would be asked to map over existing pages.
        for (uint64_t page = 0; page < pages; page++)
        {
            if (VMM::get_physical(base + page * PAGE_SIZE) != (uint64_t)-1)
            {
                kwarnln("attempt to map an already mapped address");
                context->rax = MAP_FAIL(ENOMEM);
                return;
            }
        }
        // NOTE(review): this path does not reserve the range in the task's virtual allocator,
        // although sys_munmap() calls free_virtual_pages() on it — confirm that is intended.
        result = MemoryManager::get_pages_at(base, pages, real_flags);
    }
    else
    {
        kdbgln("mmap(): %ld pages at any address, %s", pages, format_prot(prot));
        auto& allocator = Scheduler::current_task()->allocator;
        const uint64_t ptr = allocator.request_virtual_pages(pages);
        if (!ptr)
        {
            kwarnln("mmap() failed: failed to allocate virtual address");
            context->rax = MAP_FAIL(ENOMEM);
            return;
        }
        result = MemoryManager::get_pages_at(ptr, pages, real_flags);
        // Hand the virtual range back on failure instead of leaking it.
        // NOTE(review): assumes get_pages_at() leaves nothing mapped when it fails — confirm.
        if (!result) allocator.free_virtual_pages(ptr, pages);
    }

    if (!result)
    {
        kwarnln("mmap() failed: failed to allocate physical memory");
        context->rax = MAP_FAIL(ENOMEM);
        return;
    }
    kdbgln("mmap() succeeded: %p", result);
    context->rax = (uint64_t)result;
}

// munmap() system call: unmap `size` bytes of the current task's memory, starting at the page
// containing `address`.
//
// context: saved register state; the result is written to context->rax.
// address: start of the region (rounded down to its page); page 0 is rejected.
// size:    length in bytes; must be a non-zero multiple of PAGE_SIZE.
//
// The virtual range is returned to the task's allocator and the pages are released through
// MemoryManager::release_pages(). rax is 0 on success and -EINVAL on any failure.
void sys_munmap(Context* context, void* address, size_t size)
{
    kdbgln("munmap(): attempting to unmap %p", address);
    if (size < PAGE_SIZE)
    {
        kwarnln("munmap() failed: size is too small");
        context->rax = -EINVAL;
        return;
    }
    if (size % PAGE_SIZE)
    {
        kwarnln("munmap() failed: size is not a multiple of PAGE_SIZE");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t base = (uint64_t)address - ((uint64_t)address % PAGE_SIZE);
    if (!base)
    {
        kwarnln("munmap() failed: attempted to unmap page 0");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t end = base + (size - 1); // last byte of the range
    // Check the end of the range (and wraparound) as well as the start; otherwise a large size
    // starting in user space would release kernel pages.
    if (end < base || Memory::is_kernel_address((uintptr_t)address) || Memory::is_kernel_address((uintptr_t)end))
    {
        kwarnln("munmap() failed: attempted to unmap a kernel page");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t pages = Utilities::get_blocks_from_size(PAGE_SIZE, size);
    // Every page that will be released must actually be mapped, not only the first one.
    for (uint64_t page = 0; page < pages; page++)
    {
        if (VMM::get_physical(base + page * PAGE_SIZE) == (uint64_t)-1)
        {
            kwarnln("munmap() failed: attempted to unmap a non-existent page");
            context->rax = -EINVAL;
            return;
        }
    }
    Scheduler::current_task()->allocator.free_virtual_pages(base, pages);
    MemoryManager::release_pages((void*)base, pages);
    kdbgln("munmap() succeeded");
    context->rax = 0;
}

// mprotect() system call: change the page flags of `size` bytes of the current task's memory,
// starting at the page containing `address`.
//
// context: saved register state; the result is written to context->rax.
// address: start of the region (rounded down to its page); page 0 is rejected.
// size:    length in bytes; must be a non-zero multiple of PAGE_SIZE.
// prot:    MAP_READ / MAP_WRITE bits, translated by mman_flags_from_prot().
//
// rax is 0 on success and -EINVAL on any failure.
void sys_mprotect(Context* context, void* address, size_t size, int prot)
{
    kdbgln("mprotect(): attempting to protect %p with %s", address, format_prot(prot));

    if (size < PAGE_SIZE)
    {
        kwarnln("mprotect() failed: size is too small");
        context->rax = -EINVAL;
        return;
    }
    if (size % PAGE_SIZE)
    {
        kwarnln("mprotect() failed: size is not a multiple of PAGE_SIZE");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t base = (uint64_t)address - ((uint64_t)address % PAGE_SIZE);
    if (!base)
    {
        kwarnln("mprotect() failed: attempted to protect page 0");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t end = base + (size - 1); // last byte of the range
    // Check the end of the range (and wraparound) as well as the start. The new flags include
    // MAP_USER, so a range reaching into kernel space would expose kernel pages to user code.
    if (end < base || Memory::is_kernel_address((uintptr_t)address) || Memory::is_kernel_address((uintptr_t)end))
    {
        kwarnln("mprotect() failed: attempted to protect a kernel page");
        context->rax = -EINVAL;
        return;
    }
    const uint64_t pages = Utilities::get_blocks_from_size(PAGE_SIZE, size);
    // Every page being re-protected must actually be mapped, not only the first one.
    for (uint64_t page = 0; page < pages; page++)
    {
        if (VMM::get_physical(base + page * PAGE_SIZE) == (uint64_t)-1)
        {
            kwarnln("mprotect() failed: attempted to protect a non-existent page");
            context->rax = -EINVAL;
            return;
        }
    }

    MemoryManager::protect((void*)base, pages, mman_flags_from_prot(prot));
    kdbgln("mprotect() succeeded");
    context->rax = 0;
}