#include "memory/MemoryManager.h"
#include "Log.h"
#include "arch/CPU.h"
#include "arch/MMU.h"
#include "boot/bootboot.h"
#include "memory/MemoryMap.h"
#include <luna/Alignment.h>
#include <luna/Bitmap.h>
#include <luna/String.h>
#include <luna/SystemError.h>
#include <luna/Types.h>
#include <luna/Units.h>

extern BOOTBOOT bootboot;

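// Kernel section boundary symbols, provided by the linker script.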
extern u8 start_of_kernel_rodata[1];
extern u8 end_of_kernel_rodata[1];
extern u8 start_of_kernel_data[1];
extern u8 end_of_kernel_data[1];

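// Running totals of free, used and reserved physical memory, in bytes.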
static u64 free_mem = 0;
static u64 used_mem = 0;
static u64 reserved_mem = 0;

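// Lowest frame bitmap index that might still be free; alloc_frame() starts scanning from here.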
static u64 start_index = 0;

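// One bit per physical page frame: set means used or reserved, clear means free.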
static Bitmap g_frame_bitmap;

#define CHECK_PAGE_ALIGNED(address) check(is_aligned(address, ARCH_PAGE_SIZE))

static usize get_physical_address_space_size()
{
    MemoryMapIterator iter;
    MemoryMapEntry entry = iter.highest();

    return entry.ptr + entry.size; // The end address of the last (highest) entry, which is the size of the whole
                                   // physical address space that was passed to us.
}

namespace MemoryManager
{
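    // Tighten the page protections on the kernel's own sections: rodata is remapped as non-executable, and the data
    // section as non-executable but writable.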
    Result<void> protect_kernel_sections()
    {
        const u64 rodata_size = (u64)(end_of_kernel_rodata - start_of_kernel_rodata);
        const u64 rodata_pages = get_blocks_from_size(rodata_size, ARCH_PAGE_SIZE);
        TRY(remap((u64)start_of_kernel_rodata, rodata_pages, MMU::NoExecute));

        const u64 data_size = (u64)(end_of_kernel_data - start_of_kernel_data);
        const u64 data_pages = get_blocks_from_size(data_size, ARCH_PAGE_SIZE);
        TRY(remap((u64)start_of_kernel_data, data_pages, MMU::NoExecute | MMU::ReadWrite));

        return {};
    }

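    // Set up the physical frame bitmap: place it in the largest free block of the memory map, mark every frame as
    // used, clear the bits for the free regions, and finally reserve the frames the bitmap itself occupies.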
    void init_physical_frame_allocator()
    {
        MemoryMapIterator iter;
        MemoryMapEntry entry;

        auto largest_free = iter.largest_free();

        expect(largest_free.free, "We were given a largest free block that isn't even free!");

        // The entire physical address space. It may contain holes of nonexistent memory, so it differs from the
        // total amount of memory that actually exists. Our bitmap needs to have space for the whole physical address
        // space, since usable addresses will be scattered across it.
        usize physical_address_space_size = get_physical_address_space_size();

        // We store our frame bitmap at the beginning of the largest free memory block.
        char* frame_bitmap_addr = (char*)largest_free.ptr;

        // One bit per page, plus an extra byte to cover the remainder of the division.
        usize frame_bitmap_size = physical_address_space_size / ARCH_PAGE_SIZE / 8 + 1;

        // This should never happen unless memory is extremely fragmented. Usually there is one very large block of
        // usable memory with some tiny blocks around it.
        if (frame_bitmap_size >= largest_free.size) [[unlikely]]
        {
            kerrorln("ERROR: No single memory block is enough to hold the frame bitmap");
            CPU::efficient_halt();
        }

        g_frame_bitmap.initialize(frame_bitmap_addr, frame_bitmap_size);

        g_frame_bitmap.clear(true); // Set all pages to used/reserved by default, then clear out the free ones

        iter.rewind();
        while (iter.next().try_set_value(entry))
        {
            u64 index = entry.ptr / ARCH_PAGE_SIZE;
            u64 pages = entry.size / ARCH_PAGE_SIZE;
            if (!entry.free) { reserved_mem += entry.size; }
            else
            {
                free_mem += entry.size;
                g_frame_bitmap.clear_region(index, pages, false);
            }
        }

        // Reserve the frames occupied by the bitmap itself.
        lock_frames((u64)frame_bitmap_addr, frame_bitmap_size / ARCH_PAGE_SIZE + 1);
    }

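    // Initialize physical memory management, then the kernel's initial page directory.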
    void init()
    {
        init_physical_frame_allocator();
        MMU::setup_initial_page_directory();
    }

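    // Mark the frame containing this physical address as used; does nothing if it is already used.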
    void lock_frame(u64 frame)
    {
        const u64 index = frame / ARCH_PAGE_SIZE;
        if (g_frame_bitmap.get(index)) return;
        g_frame_bitmap.set(index, true);
        used_mem += ARCH_PAGE_SIZE;
        free_mem -= ARCH_PAGE_SIZE;
    }

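    // Mark a contiguous range of physical frames as used.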
    void lock_frames(u64 frames, u64 count)
    {
        for (u64 index = 0; index < count; index++) { lock_frame(frames + (index * ARCH_PAGE_SIZE)); }
    }

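    // Find the first free frame in the bitmap, mark it as used and return its physical address.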
    Result<u64> alloc_frame()
    {
        for (u64 index = start_index; index < g_frame_bitmap.size(); index++)
        {
            if (g_frame_bitmap.get(index)) continue;
            g_frame_bitmap.set(index, true);
            start_index = index + 1;
            free_mem -= ARCH_PAGE_SIZE;
            used_mem += ARCH_PAGE_SIZE;
            return index * ARCH_PAGE_SIZE;
        }

        return err(ENOMEM);
    }

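    // Return a frame to the allocator. Fails with EFAULT if the address is outside the bitmap or the frame was not
    // allocated in the first place.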
    Result<void> free_frame(u64 frame)
    {
        const u64 index = frame / ARCH_PAGE_SIZE;
        if (index >= g_frame_bitmap.size()) return err(EFAULT);
        if (!g_frame_bitmap.get(index)) return err(EFAULT);
        g_frame_bitmap.set(index, false);
        used_mem -= ARCH_PAGE_SIZE;
        free_mem += ARCH_PAGE_SIZE;
        if (start_index > index) start_index = index;
        return {};
    }

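    // Change the protection flags of count already-mapped pages starting at a page-aligned virtual address.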
    Result<void> remap(u64 address, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(address);

        while (count--)
        {
            TRY(MMU::remap(address, flags));
            address += ARCH_PAGE_SIZE;
        }

        return {};
    }

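    // Map count consecutive physical frames starting at phys to the virtual addresses starting at virt.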
    Result<void> map_frames_at(u64 virt, u64 phys, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);
        CHECK_PAGE_ALIGNED(phys);

        while (count--)
        {
            TRY(MMU::map(virt, phys, flags));
            virt += ARCH_PAGE_SIZE;
            phys += ARCH_PAGE_SIZE;
        }

        return {};
    }

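    // Allocate count physical frames (not necessarily contiguous) and map them at the given virtual address range.
    // Returns the start of the range.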
    Result<u64> alloc_at(u64 virt, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);

        u64 start = virt;

        while (count--)
        {
            u64 frame = TRY(alloc_frame());
            TRY(MMU::map(virt, frame, flags));
            virt += ARCH_PAGE_SIZE;
        }

        return start;
    }

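    // Unmap count pages and free the underlying physical frames (for memory whose frames we own).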
    Result<void> unmap_owned(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            u64 frame = TRY(MMU::unmap(virt));
            TRY(free_frame(frame));
            virt += ARCH_PAGE_SIZE;
        }

        return {};
    }

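    // Unmap count pages without freeing the underlying physical frames (for memory whose frames we do not own).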
    Result<void> unmap_weak(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            TRY(MMU::unmap(virt));
            virt += ARCH_PAGE_SIZE;
        }

        return {};
    }

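    // Like remap(), but accepts an address that is not page-aligned.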
    Result<void> remap_unaligned(u64 address, usize count, int flags)
    {
        // An unaligned range straddles one extra page, so remap one more page starting at the aligned-down address.
        if (!is_aligned(address, ARCH_PAGE_SIZE)) count++;
        address = align_down(address, ARCH_PAGE_SIZE);

        while (count--)
        {
            TRY(MMU::remap(address, flags));
            address += ARCH_PAGE_SIZE;
        }

        return {};
    }

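    // A page is considered readable as long as it is currently mapped.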
    bool validate_readable_page(u64 address)
    {
        auto rc = MMU::get_flags(address);
        if (rc.has_error()) return false;
        return true;
    }

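    // A page is considered writable if it is mapped with the ReadWrite flag.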
    bool validate_writable_page(u64 address)
    {
        auto rc = MMU::get_flags(address);
        if (rc.has_error()) return false;
        if (rc.release_value() & MMU::ReadWrite) return true;
        return false;
    }

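    // Accessors for the memory statistics above, all in bytes.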
    u64 free()
    {
        return free_mem;
    }

    u64 used()
    {
        return used_mem;
    }

    u64 reserved()
    {
        return reserved_mem;
    }

    u64 total()
    {
        return free_mem + used_mem + reserved_mem;
    }
}