#include "memory/KernelMemoryManager.h"
#include "assert.h"
#include "memory/KernelHeap.h"
#include "memory/RangeAllocator.h"
#include "memory/VMM.h"

/// Maps a physical address into kernel virtual address space.
///
/// Requests a single virtual page from KernelHeap and asks kernelVMM to map
/// it to `physicalAddress`. The physical address is forwarded unchanged
/// (presumably it is expected to be page-aligned; get_unaligned_mapping is
/// the variant that strips the in-page offset first).
///
/// @param physicalAddress Physical address to map.
/// @return The virtual address obtained from KernelHeap for the mapping.
void* KernelMemoryManager::get_mapping(void* physicalAddress)
{
    const uint64_t physical = reinterpret_cast<uint64_t>(physicalAddress);
    const uint64_t mapped = KernelHeap::request_virtual_page();
    kernelVMM.map(mapped, physical);
    return reinterpret_cast<void*>(mapped);
}

/// Maps the page containing an arbitrary physical address.
///
/// The offset of `physicalAddress` within its 4096-byte block is removed
/// before mapping and added back onto the returned virtual address, so the
/// result points at the same byte inside the mapped page.
///
/// @param physicalAddress Physical address, not required to be a multiple
///        of 4096.
/// @return Virtual address of the new mapping plus the original offset.
void* KernelMemoryManager::get_unaligned_mapping(void* physicalAddress)
{
    const uint64_t physical = reinterpret_cast<uint64_t>(physicalAddress);
    const uint64_t pageOffset = physical % 4096;
    const uint64_t mapped = KernelHeap::request_virtual_page();
    kernelVMM.map(mapped, physical - pageOffset);
    return reinterpret_cast<void*>(mapped + pageOffset);
}

/// Maps `count` consecutive 4096-byte pages starting at the page that
/// contains an arbitrary physical address.
///
/// A range of `count` virtual pages is requested from KernelHeap and each
/// one is mapped to the corresponding physical page, beginning at
/// `physicalAddress` rounded down to a multiple of 4096.
///
/// NOTE(review): exactly `count` pages are mapped regardless of the offset;
/// callers whose data starts mid-page presumably need to account for the
/// extra trailing page themselves — confirm against call sites.
///
/// @param physicalAddress Physical address, not required to be a multiple
///        of 4096.
/// @param count Number of pages to map.
/// @return Virtual address of the first mapped page plus the original
///         in-page offset.
void* KernelMemoryManager::get_unaligned_mappings(void* physicalAddress, uint64_t count)
{
    const uint64_t physical = reinterpret_cast<uint64_t>(physicalAddress);
    const uint64_t pageOffset = physical % 4096;
    const uint64_t physicalBase = physical - pageOffset;
    const uint64_t virtualBase = KernelHeap::request_virtual_pages(count);

    uint64_t byteOffset = 0;
    for (uint64_t page = 0; page < count; ++page, byteOffset += 4096)
    {
        kernelVMM.map(virtualBase + byteOffset, physicalBase + byteOffset);
    }

    return reinterpret_cast<void*>(virtualBase + pageOffset);
}

void KernelMemoryManager::release_unaligned_mapping(void* mapping)
{
    uint64_t offset = (uint64_t)mapping % 4096;
    kernelVMM.unmap((uint64_t)mapping - offset);
}

void KernelMemoryManager::release_unaligned_mappings(void* mapping, uint64_t count)
{
    uint64_t offset = (uint64_t)mapping % 4096;
    for (uint64_t i = 0; i < count; i++) { kernelVMM.unmap(((uint64_t)mapping - offset) + (i * 4096)); }
}

void KernelMemoryManager::release_mapping(void* mapping)
{
    kernelVMM.unmap((uint64_t)mapping);
}

/// Allocates one page of kernel memory.
///
/// Requests a physical page from kernelPMM and a virtual page from
/// KernelHeap, then maps the virtual page onto the physical one.
///
/// NOTE(review): the result of kernelPMM.request_page() is not checked
/// here — confirm how the PMM reports exhaustion.
///
/// @return Virtual address of the newly mapped page.
void* KernelMemoryManager::get_page()
{
    void* const frame = kernelPMM.request_page();
    const uint64_t mapped = KernelHeap::request_virtual_page();
    kernelVMM.map(mapped, reinterpret_cast<uint64_t>(frame));
    return reinterpret_cast<void*>(mapped);
}

/// Frees a page obtained from get_page.
///
/// Looks up the physical address behind `page`, asserts that the lookup did
/// not yield UINT64_MAX (presumably the VMM's "not mapped" sentinel), then
/// unmaps the page, returns the physical page to kernelPMM and the virtual
/// page to KernelHeap — in that order.
///
/// @param page Virtual address previously returned by get_page.
void KernelMemoryManager::release_page(void* page)
{
    const uint64_t virtualAddress = reinterpret_cast<uint64_t>(page);

    // The physical address must be resolved before the mapping is removed.
    const uint64_t frame = kernelVMM.getPhysical(virtualAddress);
    ASSERT(frame != UINT64_MAX);

    kernelVMM.unmap(virtualAddress);
    kernelPMM.free_page(reinterpret_cast<void*>(frame));
    KernelHeap::free_virtual_page(virtualAddress);
}

/// Allocates `count` virtually contiguous pages of kernel memory.
///
/// A range of `count` virtual pages is requested from KernelHeap; for each
/// 4096-byte page in that range a physical page is requested from kernelPMM
/// and mapped. The physical pages are requested one at a time, so nothing
/// here makes them physically contiguous.
///
/// @param count Number of pages to allocate.
/// @return Virtual address of the first page in the range.
void* KernelMemoryManager::get_pages(uint64_t count)
{
    const uint64_t virtualBase = KernelHeap::request_virtual_pages(count);

    uint64_t byteOffset = 0;
    for (uint64_t page = 0; page < count; ++page, byteOffset += 4096)
    {
        void* const frame = kernelPMM.request_page();
        kernelVMM.map(virtualBase + byteOffset, reinterpret_cast<uint64_t>(frame));
    }

    return reinterpret_cast<void*>(virtualBase);
}

/// Frees a range of pages obtained from get_pages.
///
/// For each of the `count` 4096-byte pages starting at `pages`: resolves the
/// physical address, asserts that the lookup did not yield UINT64_MAX
/// (presumably the VMM's "not mapped" sentinel), unmaps the page and returns
/// the physical page to kernelPMM. Once every page has been processed, the
/// whole virtual range is returned to KernelHeap.
///
/// @param pages Virtual address previously returned by get_pages.
/// @param count Number of pages in the range; should match the value passed
///        to get_pages.
void KernelMemoryManager::release_pages(void* pages, uint64_t count)
{
    const uint64_t virtualBase = reinterpret_cast<uint64_t>(pages);

    uint64_t byteOffset = 0;
    for (uint64_t index = 0; index < count; ++index, byteOffset += 4096)
    {
        const uint64_t current = virtualBase + byteOffset;

        // The physical address must be resolved before the mapping is removed.
        const uint64_t frame = kernelVMM.getPhysical(current);
        ASSERT(frame != UINT64_MAX);

        kernelVMM.unmap(current);
        kernelPMM.free_page(reinterpret_cast<void*>(frame));
    }

    KernelHeap::free_virtual_pages(virtualBase, count);
}