#define MODULE "mem"

#include "memory/PMM.h"
#include "bootboot.h"
#include "log/Log.h"
#include "memory/Memory.h"
#include "memory/MemoryManager.h"
#include "misc/utils.h"
#include "std/ensure.h"
#include "std/string.h"

// Loader-provided information structure; PMM::init() walks its memory map.
extern BOOTBOOT bootboot;

static bool bitmap_read(uint64_t index);
static void bitmap_set(uint64_t index, bool value);

// Memory accounting, in bytes.
static uint64_t free_mem = 0;
static uint64_t used_mem = 0;
static uint64_t reserved_mem = 0; // total size of non-free memory-map entries, summed once in PMM::init()

// Page bitmap: one bit per page (set = in use / unavailable, clear = free).
// bitmap_addr is the address taken from the memory map in PMM::init(); virtual_bitmap_addr is the
// pointer every bitmap access goes through (equal to bitmap_addr until map_bitmap_to_virtual()).
static char* bitmap_addr;
static char* virtual_bitmap_addr;
static uint64_t bitmap_size; // size of the bitmap in bytes (covers bitmap_size * 8 pages)

// Page index where request_page()/request_pages() begin scanning; request_page() advances it past
// the page it hands out and free_page() lowers it when a lower page is released.
static uint64_t start_index = 0;

void PMM::init()
{
    uint64_t total_mem = Memory::get_system();

    void* biggest_chunk = nullptr;
    uint64_t biggest_chunk_size = 0;

    MMapEnt* ptr = &bootboot.mmap;
    uint64_t mmap_entries = (bootboot.size - 128) / 16;
    for (uint64_t i = 0; i < mmap_entries; i++)
    {
        if (!MMapEnt_IsFree(ptr))
        {
            ptr++;
            continue;
        }
        if (MMapEnt_Size(ptr) > biggest_chunk_size)
        {
            biggest_chunk = (void*)MMapEnt_Ptr(ptr);
            biggest_chunk_size = MMapEnt_Size(ptr);
        }
        ptr++;
    }

    bitmap_addr = (char*)biggest_chunk;
    virtual_bitmap_addr = bitmap_addr;
    ensure((total_mem / PAGE_SIZE / 8) < biggest_chunk_size);
    bitmap_size = total_mem / PAGE_SIZE / 8 + 1;
    memset(bitmap_addr, 0xFF, bitmap_size);

    ptr = &bootboot.mmap;
    for (uint64_t i = 0; i < mmap_entries; i++)
    {
        uint64_t index = MMapEnt_Ptr(ptr) / PAGE_SIZE;
        if (!MMapEnt_IsFree(ptr)) { reserved_mem += MMapEnt_Size(ptr); }
        else
        {
            free_mem += MMapEnt_Size(ptr);
            for (uint64_t j = 0; j < (MMapEnt_Size(ptr) / PAGE_SIZE); j++) { bitmap_set(index + j, false); }
        }
        ptr++;
    }

    lock_pages(bitmap_addr, bitmap_size / PAGE_SIZE + 1);
}

/// Returns true when the bit for page `index` is set (page in use / unavailable).
/// Bits are stored most-significant-first: page 8*k lives in bit 7 of byte k.
static bool bitmap_read(uint64_t index)
{
    // Read the byte as unsigned so the test does not depend on the signedness of `char`.
    const uint8_t byte = static_cast<uint8_t>(virtual_bitmap_addr[index / 8]);
    const uint8_t mask = static_cast<uint8_t>(0x80u >> (index % 8));
    return (byte & mask) != 0;
}

/// Sets (value == true) or clears (value == false) the bit for page `index`.
/// Uses the same most-significant-first bit order as bitmap_read().
static void bitmap_set(uint64_t index, bool value)
{
    const uint8_t mask = static_cast<uint8_t>(0x80u >> (index % 8));
    char& byte = virtual_bitmap_addr[index / 8];
    if (value)
        byte |= mask;
    else
        byte &= static_cast<uint8_t>(~mask);
}

/// Allocates one page.
///
/// Scans the bitmap from start_index for the first clear bit, marks it used, moves start_index just
/// past it and updates the accounting. Returns the page's address (index * PAGE_SIZE), or PMM_FAILED
/// when no free page exists at or after start_index.
void* PMM::request_page()
{
    const uint64_t page_limit = bitmap_size * 8;

    uint64_t index = start_index;
    while (index < page_limit && bitmap_read(index)) index++;

    if (index >= page_limit) return PMM_FAILED;

    bitmap_set(index, true);
    start_index = index + 1;
    free_mem -= PAGE_SIZE;
    used_mem += PAGE_SIZE;
    return (void*)(index * PAGE_SIZE);
}

/// Allocates `count` contiguous pages.
///
/// Scans the bitmap from start_index for the first run of `count` clear bits, marks the whole run
/// used and updates the accounting. Returns the address of the first page of the run, or PMM_FAILED
/// if no such run exists (or if `count` is 0). start_index is not changed here.
void* PMM::request_pages(uint64_t count)
{
    // A run of zero pages can never be matched by the scan below (the run length is always >= 1
    // when compared), so fail immediately instead of walking the entire bitmap first.
    if (count == 0) return PMM_FAILED;

    uint64_t contiguous = 0;
    uint64_t contiguous_start = 0;
    for (uint64_t index = start_index; index < (bitmap_size * 8); index++)
    {
        if (bitmap_read(index))
        {
            // A used page breaks the current run.
            contiguous = 0;
            continue;
        }
        if (contiguous == 0) contiguous_start = index;
        contiguous++;
        if (contiguous == count)
        {
            for (uint64_t i = 0; i < count; i++) bitmap_set(contiguous_start + i, true);
            free_mem -= (count * PAGE_SIZE);
            used_mem += (count * PAGE_SIZE);
            return (void*)(contiguous_start * PAGE_SIZE);
        }
    }

    return PMM_FAILED;
}

/// Releases the page containing `address`.
///
/// Addresses outside the range covered by the bitmap are logged and ignored. Pages that are already
/// free are ignored, so the accounting is not changed twice. Otherwise the page's bit is cleared, the
/// accounting is updated and start_index is lowered so the page can be found again.
void PMM::free_page(void* address)
{
    uint64_t index = (uint64_t)address / PAGE_SIZE;
    // The bitmap holds bits for pages [0, bitmap_size * 8). index == bitmap_size * 8 is already one
    // past the last byte, so the comparison must be >= to avoid touching memory after the bitmap.
    if (index >= (bitmap_size * 8))
    {
        kinfoln("attempt to free out-of-range address %p", address);
        return;
    }
    if (!bitmap_read(index)) return;
    bitmap_set(index, false);
    used_mem -= PAGE_SIZE;
    free_mem += PAGE_SIZE;
    if (start_index > index) start_index = index;
}

void PMM::free_pages(void* address, uint64_t count)
{
    for (uint64_t index = 0; index < count; index++) { free_page((void*)((uint64_t)address + index)); }
}

void PMM::lock_page(void* address)
{
    uint64_t index = ((uint64_t)address) / PAGE_SIZE;
    if (bitmap_read(index)) return;
    bitmap_set(index, true);
    used_mem += PAGE_SIZE;
    free_mem -= PAGE_SIZE;
}

void PMM::lock_pages(void* address, uint64_t count)
{
    for (uint64_t index = 0; index < count; index++) { lock_page((void*)((uint64_t)address + index)); }
}

/// Bytes currently counted as free (free memory-map regions minus pages requested or locked).
uint64_t PMM::get_free() { return free_mem; }

/// Bytes currently counted as in use (pages requested or locked and not yet freed).
uint64_t PMM::get_used() { return used_mem; }

/// Total size, in bytes, of the memory-map regions that were not free when PMM::init() ran.
uint64_t PMM::get_reserved() { return reserved_mem; }

/// Size of the page bitmap in bytes (one bit per page).
uint64_t PMM::get_bitmap_size() { return bitmap_size; }

/// Maps the bitmap through MemoryManager::get_unaligned_mappings() and redirects all subsequent
/// bitmap accesses to the returned address. bitmap_addr keeps the original address.
void PMM::map_bitmap_to_virtual()
{
    const auto bitmap_blocks = Utilities::get_blocks_from_size(PAGE_SIZE, bitmap_size);
    void* const mapped = MemoryManager::get_unaligned_mappings(bitmap_addr, bitmap_blocks);
    virtual_bitmap_addr = (char*)mapped;
}