#define MODULE "vmm"

#include "memory/VMM.h"
#include "assert.h"
#include "log/Log.h"
#include "memory/PMM.h"
#include "misc/utils.h"
#include "std/string.h"

// FIXME: There is a lot of duplicate code in this file. This should probably be refactored.

// Root (top-level) paging structure that every lookup/mutation in this file walks from.
// Filled in by VMM::init() from CR3; null until then.
static PageTable* PML4 = nullptr;

void VMM::init()
{
    asm volatile("mov %%cr3, %0" : "=r"(PML4));
}

void VMM::unmap(uint64_t virtualAddress)
{
    virtualAddress = Utilities::round_down_to_nearest_page(virtualAddress);

    PageDirectoryEntry* pde = find_pde(PML4, virtualAddress);
    if (!pde) return; // Already unmapped

    memset(pde, 0, sizeof(PageDirectoryEntry));
    flush_tlb(virtualAddress);
}

uint64_t VMM::getPhysical(uint64_t virtualAddress)
{
    PageDirectoryEntry* pde = find_pde(PML4, Utilities::round_down_to_nearest_page(virtualAddress));
    if (!pde) return UINT64_MAX; // Not mapped

    return pde->Address << 12 | (virtualAddress % PAGE_SIZE);
}

uint64_t VMM::getFlags(uint64_t virtualAddress)
{
    PageDirectoryEntry* pde = find_pde(PML4, Utilities::round_down_to_nearest_page(virtualAddress));
    if (!pde) return 0; // Not mapped

    uint64_t flags = 0;
    if (pde->UserSuper) flags |= User;
    if (pde->ReadWrite) flags |= ReadWrite;
    return flags;
}

// Map the 4 KiB page containing `virtualAddress` to the page containing `physicalAddress`.
//
// `flags` may contain User and/or ReadWrite. The leaf entry's User/ReadWrite bits are set to
// exactly what `flags` requests, so remapping a page with fewer permissions really revokes
// them. Intermediate levels are shared with other mappings and are therefore only ever
// widened (via propagate_user / propagate_read_write), never narrowed.
void VMM::map(uint64_t virtualAddress, uint64_t physicalAddress, int flags)
{
    virtualAddress = Utilities::round_down_to_nearest_page(virtualAddress);
    PageDirectoryEntry* pde = find_pde(PML4, virtualAddress);
    bool will_flush_tlb = true;
    if (!pde)
    {
        // The address was not mapped, so there is no old translation to invalidate
        // (x86 does not cache translations through not-present entries).
        pde = create_pde_if_not_exists(PML4, virtualAddress);
        will_flush_tlb = false;
    }
    else if (pde->LargerPages)
    {
        // NOTE(review): unmap() clears the entire large-page entry, so every other address it
        // covered becomes unmapped; only this 4 KiB page is rebuilt below. unmap() already
        // flushed the TLB for it.
        unmap(virtualAddress);
        pde = create_pde_if_not_exists(PML4, virtualAddress);
        will_flush_tlb = false;
    }

    pde->set_address(Utilities::round_down_to_nearest_page(physicalAddress));

    // Assign (not OR) the leaf permissions so stale bits from a previous mapping are dropped.
    pde->UserSuper = (flags & User) != 0;
    pde->ReadWrite = (flags & ReadWrite) != 0;

    // Upper levels must allow whatever the leaf allows.
    if (flags & User) propagate_user(PML4, virtualAddress);
    if (flags & ReadWrite) propagate_read_write(PML4, virtualAddress);
    if (will_flush_tlb) flush_tlb(virtualAddress);
}

// Walk the four-level hierarchy rooted at `root` and return the entry translating `virtualAddress`.
//
// Returns the last-level (PT) entry if the walk reaches it. If the walk meets an entry with
// LargerPages set first, that large-page entry is returned instead. Returns nullptr as soon as
// any entry on the path is not present. Tables are dereferenced through their physical
// address, so this presumes physical memory is directly addressable — TODO confirm.
PageDirectoryEntry* VMM::find_pde(PageTable* root, uint64_t virtualAddress)
{
    // decompose_vaddr's outputs, in hardware terms: PT, PD, PDPT and PML4 indices.
    uint64_t pt_slot, pd_slot, pdpt_slot, pml4_slot;
    decompose_vaddr(virtualAddress, pt_slot, pd_slot, pdpt_slot, pml4_slot);

    const uint64_t walk[] = {pml4_slot, pdpt_slot, pd_slot};
    PageTable* table = root;
    for (uint64_t slot : walk)
    {
        PageDirectoryEntry* entry = &table->entries[slot];
        if (!entry->Present) return nullptr;
        if (entry->LargerPages) return entry; // walk ends early on a large page
        table = (PageTable*)((uint64_t)entry->Address << 12);
    }

    PageDirectoryEntry* leaf = &table->entries[pt_slot];
    return leaf->Present ? leaf : nullptr;
}

// Like find_pde(), but builds any missing intermediate tables instead of giving up.
//
// Each missing level gets a fresh PMM page that is zeroed and linked in with its address and
// Present set. ReadWrite/User are not set here. A failed allocation trips ASSERT. An existing
// large-page entry met on the way down is returned as-is. Otherwise the last-level entry is
// returned with Present set; its address and permission bits are left for the caller.
PageDirectoryEntry* VMM::create_pde_if_not_exists(PageTable* root, uint64_t virtualAddress)
{
    // decompose_vaddr's outputs, in hardware terms: PT, PD, PDPT and PML4 indices.
    uint64_t pt_slot, pd_slot, pdpt_slot, pml4_slot;
    decompose_vaddr(virtualAddress, pt_slot, pd_slot, pdpt_slot, pml4_slot);

    const uint64_t walk[] = {pml4_slot, pdpt_slot, pd_slot};
    PageTable* table = root;
    for (uint64_t slot : walk)
    {
        PageDirectoryEntry* entry = &table->entries[slot];
        if (entry->Present)
        {
            if (entry->LargerPages) return entry;
            table = (PageTable*)((uint64_t)entry->Address << 12);
            continue;
        }

        // Missing level: allocate an empty table and hook it into this entry.
        PageTable* fresh = (PageTable*)PMM::request_page();
        ASSERT(!(PMM_DID_FAIL(fresh)));
        memset(fresh, 0, PAGE_SIZE);
        entry->set_address((uint64_t)fresh);
        entry->Present = true;
        table = fresh;
    }

    PageDirectoryEntry* leaf = &table->entries[pt_slot];
    // Only write when needed, to avoid rewriting a live entry.
    if (!leaf->Present) leaf->Present = true;
    return leaf;
}

// Set ReadWrite on every present entry along the walk for `virtualAddress`, leaf included.
//
// Stops silently at the first non-present entry. Stops after marking a large-page entry,
// since that entry is the final one for the address. Never clears the bit anywhere.
void VMM::propagate_read_write(PageTable* root, uint64_t virtualAddress)
{
    // decompose_vaddr's outputs, in hardware terms: PT, PD, PDPT and PML4 indices.
    uint64_t pt_slot, pd_slot, pdpt_slot, pml4_slot;
    decompose_vaddr(virtualAddress, pt_slot, pd_slot, pdpt_slot, pml4_slot);

    const uint64_t walk[] = {pml4_slot, pdpt_slot, pd_slot};
    PageTable* table = root;
    for (uint64_t slot : walk)
    {
        PageDirectoryEntry* entry = &table->entries[slot];
        if (!entry->Present) return;
        entry->ReadWrite = true;
        if (entry->LargerPages) return; // the large page is the last level for this address
        table = (PageTable*)((uint64_t)entry->Address << 12);
    }

    PageDirectoryEntry* leaf = &table->entries[pt_slot];
    if (leaf->Present) leaf->ReadWrite = true;
}

// Set UserSuper on every present entry along the walk for `virtualAddress`, leaf included.
//
// Stops silently at the first non-present entry. Stops after marking a large-page entry,
// since that entry is the final one for the address. Never clears the bit anywhere.
void VMM::propagate_user(PageTable* root, uint64_t virtualAddress)
{
    // decompose_vaddr's outputs, in hardware terms: PT, PD, PDPT and PML4 indices.
    uint64_t pt_slot, pd_slot, pdpt_slot, pml4_slot;
    decompose_vaddr(virtualAddress, pt_slot, pd_slot, pdpt_slot, pml4_slot);

    const uint64_t walk[] = {pml4_slot, pdpt_slot, pd_slot};
    PageTable* table = root;
    for (uint64_t slot : walk)
    {
        PageDirectoryEntry* entry = &table->entries[slot];
        if (!entry->Present) return;
        entry->UserSuper = true;
        if (entry->LargerPages) return; // the large page is the last level for this address
        table = (PageTable*)((uint64_t)entry->Address << 12);
    }

    PageDirectoryEntry* leaf = &table->entries[pt_slot];
    if (leaf->Present) leaf->UserSuper = true;
}

// Invalidate cached translations for the page containing `addr` on the executing CPU (INVLPG).
// No other processors are notified here. The "memory" clobber stops the compiler from moving
// page-table stores across the invalidation.
void VMM::flush_tlb(uint64_t addr)
{
    asm volatile("invlpg (%0)" : : "r"(addr) : "memory");
}

// Split a virtual address into its four 9-bit paging-structure indices.
//
// The output names are shifted by one level relative to hardware terminology:
//   page_index = bits 20:12 -> index into the page table (PT)
//   pt_index   = bits 29:21 -> index into the page directory (PD)
//   pd_index   = bits 38:30 -> index into the page-directory-pointer table (PDPT)
//   pdp_index  = bits 47:39 -> index into the PML4
void VMM::decompose_vaddr(uint64_t vaddr, uint64_t& page_index, uint64_t& pt_index, uint64_t& pd_index,
                          uint64_t& pdp_index)
{
    const uint64_t index_mask = 0x1ff; // 512 entries per table

    pdp_index = (vaddr >> 39) & index_mask;
    pd_index = (vaddr >> 30) & index_mask;
    pt_index = (vaddr >> 21) & index_mask;
    page_index = (vaddr >> 12) & index_mask;
}