#include "memory/VMM.h"
#include "assert.h"
#include "memory/PMM.h"
#include "std/string.h"

// Global VirtualMemoryManager instance. Nothing in this file calls init() on it.
// NOTE(review): going by the name, this presumably manages the kernel's own
// address space and is initialized during early boot elsewhere - confirm.
Paging::VirtualMemoryManager kernelVMM;

namespace Paging
{
    void VirtualMemoryManager::init()
    {
        asm volatile("mov %%cr3, %0" : "=r"(PML4));
    }

    // Initializes this manager to operate on a caller-supplied top-level table.
    // The pointer is stored as-is: nothing is copied or validated here, and the
    // CPU's CR3 register is not touched.
    void VirtualMemoryManager::init(PageTable* PML4)
    {
        // The parameter shadows the member of the same name, so the member has
        // to be named with its class qualifier.
        VirtualMemoryManager::PML4 = PML4;
    }

    void VirtualMemoryManager::unmap(uint64_t virtualAddress)
    {
        virtualAddress >>= 12;
        uint64_t P_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PT_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PD_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PDP_i = virtualAddress & 0x1ff;

        PageDirectoryEntry PDE;

        PDE = PML4->entries[PDP_i];
        PageTable* PDP;
        if (!PDE.Present)
        {
            return; // Already unmapped
        }
        else { PDP = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PDP->entries[PD_i];
        PageTable* PD;
        if (!PDE.Present)
        {
            return; // Already unmapped
        }
        else { PD = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PD->entries[PT_i];
        PageTable* PT;
        if (!PDE.Present)
        {
            return; // Already unmapped
        }
        else { PT = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PT->entries[P_i];
        PDE.Present = false;
        PT->entries[P_i] = PDE;
    }

    uint64_t VirtualMemoryManager::getPhysical(uint64_t virtualAddress)
    {
        virtualAddress >>= 12;
        uint64_t P_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PT_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PD_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PDP_i = virtualAddress & 0x1ff;

        PageDirectoryEntry PDE;

        PDE = PML4->entries[PDP_i];
        PageTable* PDP;
        if (!PDE.Present)
        {
            return UINT64_MAX; // Not mapped
        }
        else { PDP = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PDP->entries[PD_i];
        PageTable* PD;
        if (!PDE.Present)
        {
            return UINT64_MAX; // Not mapped
        }
        else { PD = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PD->entries[PT_i];
        PageTable* PT;
        if (!PDE.Present)
        {
            return UINT64_MAX; // Not mapped
        }
        else { PT = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PT->entries[P_i];
        return PDE.Address << 12;
    }

    uint64_t VirtualMemoryManager::getFlags(uint64_t virtualAddress)
    {
        virtualAddress >>= 12;
        uint64_t P_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PT_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PD_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PDP_i = virtualAddress & 0x1ff;

        PageDirectoryEntry PDE;

        PDE = PML4->entries[PDP_i];
        PageTable* PDP;
        if (!PDE.Present)
        {
            return 0; // Not mapped
        }
        else { PDP = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PDP->entries[PD_i];
        PageTable* PD;
        if (!PDE.Present)
        {
            return 0; // Not mapped
        }
        else { PD = (PageTable*)((uint64_t)PDE.Address << 12); }

        PDE = PD->entries[PT_i];
        PageTable* PT;
        if (!PDE.Present)
        {
            return 0; // Not mapped
        }
        else { PT = (PageTable*)((uint64_t)PDE.Address << 12); }

        uint64_t flags = 0;

        PDE = PT->entries[P_i];
        if (PDE.UserSuper) flags |= User;
        if (PDE.ReadWrite) flags |= ReadWrite;
        return flags;
    }

    void VirtualMemoryManager::map(uint64_t virtualAddress, uint64_t physicalAddress, int flags)
    {
        virtualAddress >>= 12;
        uint64_t P_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PT_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PD_i = virtualAddress & 0x1ff;
        virtualAddress >>= 9;
        uint64_t PDP_i = virtualAddress & 0x1ff;

        PageDirectoryEntry PDE;

        PDE = PML4->entries[PDP_i];
        PageTable* PDP;
        if (!PDE.Present)
        {
            PDP = (PageTable*)PMM::request_page();
            ASSERT(!(PMM_DID_FAIL(PDP)));
            memset(PDP, 0, 0x1000);
            PDE.Address = (uint64_t)PDP >> 12;
            PDE.Present = true;
            PDE.ReadWrite = true;
            if (flags & User) PDE.UserSuper = true;
            PML4->entries[PDP_i] = PDE;
        }
        else { PDP = (PageTable*)((uint64_t)PDE.Address << 12); }
        if ((flags & User) && !PDE.UserSuper)
        {
            PDE.UserSuper = true;
            PML4->entries[PDP_i] = PDE;
        }

        PDE = PDP->entries[PD_i];
        PageTable* PD;
        if (!PDE.Present)
        {
            PD = (PageTable*)PMM::request_page();
            ASSERT(!(PMM_DID_FAIL(PD)));
            memset(PD, 0, 0x1000);
            PDE.Address = (uint64_t)PD >> 12;
            PDE.Present = true;
            PDE.ReadWrite = true;
            if (flags & User) PDE.UserSuper = true;
            PDP->entries[PD_i] = PDE;
        }
        else { PD = (PageTable*)((uint64_t)PDE.Address << 12); }
        if ((flags & User) && !PDE.UserSuper)
        {
            PDE.UserSuper = true;
            PDP->entries[PD_i] = PDE;
        }

        PDE = PD->entries[PT_i];
        PageTable* PT;
        if (!PDE.Present)
        {
            PT = (PageTable*)PMM::request_page();
            ASSERT(!(PMM_DID_FAIL(PT)));
            memset(PT, 0, 0x1000);
            PDE.Address = (uint64_t)PT >> 12;
            PDE.Present = true;
            PDE.ReadWrite = true;
            if (flags & User) PDE.UserSuper = true;
            PD->entries[PT_i] = PDE;
        }
        else { PT = (PageTable*)((uint64_t)PDE.Address << 12); }
        if ((flags & User) && !PDE.UserSuper)
        {
            PDE.UserSuper = true;
            PD->entries[PT_i] = PDE;
        }

        PDE = PT->entries[P_i];
        PDE.Present = true;
        PDE.ReadWrite = flags & ReadWrite;
        PDE.UserSuper = flags & User;
        PDE.Address = physicalAddress >> 12;
        PT->entries[P_i] = PDE;
    }
}