#include "memory/MemoryManager.h"
#include "Log.h"
#include "arch/MMU.h"
#include "memory/KernelVM.h"
#include "memory/MemoryMap.h"
#include <luna/Alignment.h>
#include <luna/Bitmap.h>
#include <luna/ScopeGuard.h>
#include <luna/Spinlock.h>
#include <luna/SystemError.h>
#include <luna/Types.h>
#include <luna/Vector.h>

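// Kernel section boundaries, defined by the linker script; only their addresses are meaningful.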
extern const u8 start_of_kernel_rodata[1];
extern const u8 end_of_kernel_rodata[1];
extern const u8 start_of_kernel_data[1];
extern const u8 end_of_kernel_data[1];

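// Global memory accounting, in bytes.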
static Atomic<usize> free_mem;
static Atomic<usize> used_mem;
static Atomic<usize> reserved_mem;

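// Search hint for alloc_frame(): the lowest bitmap index that might still be free.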
static Atomic<u64> start_index;

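// One bit per physical frame: set means used or reserved, clear means free.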
static LockedValue<Bitmap> g_frame_bitmap;

#define CHECK_PAGE_ALIGNED(address) expect(is_aligned<ARCH_PAGE_SIZE>(address), "Address is not page-aligned")

static usize get_physical_address_space_size()
{
    MemoryMapIterator iter;
    const MemoryMapEntry entry = iter.highest();

    return entry.address() + entry.size(); // This is the address at the end of the last (highest) entry, thus the whole
                                           // address space that was passed to us.
}

namespace MemoryManager
{
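    // Tighten the permissions on the kernel's own sections: rodata becomes non-executable, data becomes
    // non-executable and writable.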
    Result<void> protect_kernel_sections()
    {
        const usize rodata_size = (usize)(end_of_kernel_rodata - start_of_kernel_rodata);
        const usize rodata_pages = get_blocks_from_size(rodata_size, ARCH_PAGE_SIZE);
        TRY(remap((u64)start_of_kernel_rodata, rodata_pages, MMU::NoExecute));

        const usize data_size = (usize)(end_of_kernel_data - start_of_kernel_data);
        const usize data_pages = get_blocks_from_size(data_size, ARCH_PAGE_SIZE);
        TRY(remap((u64)start_of_kernel_data, data_pages, MMU::NoExecute | MMU::ReadWrite));

        return {};
    }

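    // Set up the physical frame bitmap from the memory map we were given: mark everything as reserved by default,
    // then clear the regions reported as free.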
    void init_physical_frame_allocator()
    {
        MemoryMapIterator iter;
        MemoryMapEntry entry;

        const auto largest_free_entry = iter.largest_free();

        expect(largest_free_entry.is_free(), "We were given a largest free memory region that isn't even free!");

        // The size of the entire physical address space. It may contain holes where no memory exists, so it can be
        // larger than the amount of memory that is actually present. The bitmap must cover the whole address space,
        // since the usable frames are scattered across it.
        const usize physical_address_space_size = get_physical_address_space_size();

        // We store our frame bitmap at the beginning of the largest free memory block.
        char* const frame_bitmap_addr = (char*)largest_free_entry.ptr();

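        // One bit per frame, so divide the frame count into blocks of 8 to get the bitmap size in bytes.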
        const usize frame_bitmap_size = get_blocks_from_size(physical_address_space_size / ARCH_PAGE_SIZE, 8UL);

        // This should never happen, unless memory is very fragmented. Usually there is always a very big block of
        // usable memory and then some tiny blocks around it.
        expect(frame_bitmap_size < largest_free_entry.size(),
               "No single memory region is enough to hold the frame bitmap");

        {
            auto frame_bitmap = g_frame_bitmap.lock();

            frame_bitmap->initialize(frame_bitmap_addr, frame_bitmap_size);

            frame_bitmap->clear(true); // Set all pages to used/reserved by default, then clear out the free ones

            iter.rewind();
            while (iter.next().try_set_value(entry))
            {
                const u64 index = entry.address() / ARCH_PAGE_SIZE;
                const usize pages = entry.size() / ARCH_PAGE_SIZE;
                if (!entry.is_free()) { reserved_mem += entry.size(); }
                else
                {
                    free_mem += entry.size();
                    frame_bitmap->clear_region(index, pages, false);
                }
            }
        }

        // Make sure that the physical frames used by the bitmap aren't handed out to anyone else.
        lock_frames(largest_free_entry.address(), get_blocks_from_size(frame_bitmap_size, ARCH_PAGE_SIZE));
    }

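    // Bring up the physical frame allocator and the initial page directory, then re-point the frame bitmap at its
    // virtual mapping and initialize the kernel VM allocator.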
    void init()
    {
        init_physical_frame_allocator();

        MMU::setup_initial_page_directory();

        auto frame_bitmap = g_frame_bitmap.lock();
        u64 phys = (u64)frame_bitmap->location();

        auto virtual_bitmap_base = MMU::translate_physical_address(phys);
        frame_bitmap->initialize((void*)virtual_bitmap_base, frame_bitmap->size_in_bytes());

        KernelVM::init();
    }

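    // Mark a single frame as used and update the accounting counters; does nothing if it was already marked.
    // The caller must hold the bitmap lock.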
    void do_lock_frame(u64 index, Bitmap& bitmap)
    {
        if (bitmap.get(index)) return;
        bitmap.set(index, true);
        used_mem += ARCH_PAGE_SIZE;
        free_mem -= ARCH_PAGE_SIZE;
    }

    void lock_frame(u64 frame)
    {
        const u64 index = frame / ARCH_PAGE_SIZE;
        auto frame_bitmap = g_frame_bitmap.lock();
        do_lock_frame(index, *frame_bitmap);
    }

    void lock_frames(u64 frames, usize count)
    {
        auto frame_bitmap = g_frame_bitmap.lock();
        const u64 frame_index = frames / ARCH_PAGE_SIZE;
        for (usize index = 0; index < count; index++) { do_lock_frame(frame_index + index, *frame_bitmap); }
    }

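    // Allocate a single physical frame and return its physical address. start_index remembers where the last
    // search ended so we don't rescan the beginning of the bitmap every time.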
    Result<u64> alloc_frame()
    {
        auto frame_bitmap = g_frame_bitmap.lock();

        usize index;
        bool ok = frame_bitmap->find_and_toggle(false, start_index).try_set_value(index);
        if (!ok) return err(ENOMEM);

        start_index = index + 1;

        used_mem += ARCH_PAGE_SIZE;
        free_mem -= ARCH_PAGE_SIZE;

        return index * ARCH_PAGE_SIZE;
    }

    Result<u64> alloc_zeroed_frame()
    {
        const u64 frame = TRY(alloc_frame());

        const u64 address = MMU::translate_physical_address(frame);
        memset((void*)address, 0, ARCH_PAGE_SIZE);

        return frame;
    }

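    // Return a frame to the allocator. Fails with EFAULT if the address is outside the bitmap or the frame was
    // not allocated. Also moves the search hint back so the freed frame can be found again.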
    Result<void> free_frame(u64 frame)
    {
        const u64 index = frame / ARCH_PAGE_SIZE;

        auto frame_bitmap = g_frame_bitmap.lock();

        if (index >= frame_bitmap->size()) return err(EFAULT);
        if (!frame_bitmap->get(index)) return err(EFAULT);

        frame_bitmap->set(index, false);

        used_mem -= ARCH_PAGE_SIZE;
        free_mem += ARCH_PAGE_SIZE;

        if (start_index > index) start_index = index;
        return {};
    }

    Result<void> free_frames(u64 address, usize count)
    {
        while (count--)
        {
            TRY(free_frame(address));
            address += ARCH_PAGE_SIZE;
        }

        return {};
    }

    Result<void> remap(u64 address, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(address);

        while (count--)
        {
            TRY(MMU::remap(address, flags));
            address += ARCH_PAGE_SIZE;
        }

        return {};
    }

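    // Map count already-allocated frames starting at phys to consecutive virtual pages starting at virt. On
    // failure, the pages mapped so far are unmapped again (but the frames are not freed, since we don't own them).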
    Result<void> map_frames_at(u64 virt, u64 phys, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);
        CHECK_PAGE_ALIGNED(phys);

        usize pages_mapped = 0;

        // Let's clean up after ourselves if we fail.
        auto guard = make_scope_guard([=, &pages_mapped] { unmap_weak(virt, pages_mapped); });

        while (pages_mapped < count)
        {
            TRY(MMU::map(virt, phys, flags, MMU::UseHugePages::No));
            virt += ARCH_PAGE_SIZE;
            phys += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return {};
    }

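    // Copy the mappings for a virtual region from oldpd into newpd. The underlying frames are shared, not
    // duplicated.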
    Result<void> copy_region(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd)
    {
        CHECK_PAGE_ALIGNED(virt);

        usize pages_mapped = 0;

        // Let's clean up after ourselves if we fail.
        auto guard = make_scope_guard(
            [=, &pages_mapped] { kwarnln("copy_region failed, sorry! cannot reclaim already copied pages"); });

        while (pages_mapped < count)
        {
            u64 phys = TRY(MMU::get_physical(virt, oldpd));
            int flags = TRY(MMU::get_flags(virt, oldpd));
            TRY(MMU::map(virt, phys, flags, MMU::UseHugePages::No, newpd));
            virt += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return {};
    }

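    // Copy a virtual region from oldpd into newpd, backing it with freshly allocated frames and copying the data
    // page by page.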
    Result<void> copy_region_data(u64 virt, usize count, PageDirectory* oldpd, PageDirectory* newpd)
    {
        CHECK_PAGE_ALIGNED(virt);

        usize pages_mapped = 0;

        // Let's clean up after ourselves if we fail.
        auto guard = make_scope_guard(
            [=, &pages_mapped] { kwarnln("copy_region_data failed, sorry! cannot reclaim already copied pages"); });

        while (pages_mapped < count)
        {
            u64 frame = TRY(alloc_frame());
            u64 phys = TRY(MMU::get_physical(virt, oldpd));
            int flags = TRY(MMU::get_flags(virt, oldpd));
            memcpy((void*)MMU::translate_physical_address(frame), (void*)MMU::translate_physical_address(phys),
                   ARCH_PAGE_SIZE);
            TRY(MMU::map(virt, frame, flags, MMU::UseHugePages::No, newpd));
            virt += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return {};
    }

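    // Like map_frames_at(), but maps huge pages; virt and phys should be aligned to ARCH_HUGE_PAGE_SIZE as well.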
    Result<void> map_huge_frames_at(u64 virt, u64 phys, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);
        CHECK_PAGE_ALIGNED(phys);

        usize pages_mapped = 0;

        // Let's clean up after ourselves if we fail.
        auto guard = make_scope_guard([=, &pages_mapped] { unmap_weak_huge(virt, pages_mapped); });

        while (pages_mapped < count)
        {
            TRY(MMU::map(virt, phys, flags, MMU::UseHugePages::Yes));
            virt += ARCH_HUGE_PAGE_SIZE;
            phys += ARCH_HUGE_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return {};
    }

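    // Allocate count fresh frames and map them at consecutive virtual pages starting at virt. On failure, the
    // pages mapped so far are unmapped and their frames freed.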
    Result<u64> alloc_at(u64 virt, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);

        u64 start = virt;
        usize pages_mapped = 0;

        auto guard = make_scope_guard([=, &pages_mapped] { unmap_owned(start, pages_mapped); });

        while (pages_mapped < count)
        {
            const u64 frame = TRY(alloc_frame());
            TRY(MMU::map(virt, frame, flags, MMU::UseHugePages::No));
            virt += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return start;
    }

    Result<u64> alloc_at_zeroed(u64 virt, usize count, int flags)
    {
        CHECK_PAGE_ALIGNED(virt);

        u64 start = virt;
        usize pages_mapped = 0;

        auto guard = make_scope_guard([=, &pages_mapped] { unmap_owned(start, pages_mapped); });

        while (pages_mapped < count)
        {
            const u64 frame = TRY(alloc_zeroed_frame());
            TRY(MMU::map(virt, frame, flags, MMU::UseHugePages::No));
            virt += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return start;
    }

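    // Allocate count pages of kernel virtual address space and back them with freshly allocated frames.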
    Result<u64> alloc_for_kernel(usize count, int flags)
    {
        const u64 start = TRY(KernelVM::alloc_several_pages(count));
        usize pages_mapped = 0;

        auto guard = make_scope_guard([=, &pages_mapped] {
            KernelVM::free_several_pages(start, pages_mapped);
            unmap_owned(start, pages_mapped);
        });

        u64 virt = start;

        while (pages_mapped < count)
        {
            const u64 frame = TRY(alloc_frame());
            TRY(MMU::map(virt, frame, flags, MMU::UseHugePages::No));
            virt += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return start;
    }

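    // Map count existing physical frames into kernel virtual address space without taking ownership of them.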
    Result<u64> get_kernel_mapping_for_frames(u64 phys, usize count, int flags)
    {
        const u64 start = TRY(KernelVM::alloc_several_pages(count));

        usize pages_mapped = 0;

        auto guard = make_scope_guard([=, &pages_mapped] {
            KernelVM::free_several_pages(start, pages_mapped);
            unmap_weak(start, pages_mapped);
        });

        u64 virt = start;

        while (pages_mapped < count)
        {
            TRY(MMU::map(virt, phys, flags, MMU::UseHugePages::No));
            virt += ARCH_PAGE_SIZE;
            phys += ARCH_PAGE_SIZE;
            pages_mapped++;
        }

        guard.deactivate();

        return start;
    }

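    // The "owned" unmap functions free the underlying frames; the "weak" ones only remove the mappings.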
    Result<void> unmap_owned(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            const u64 frame = TRY(MMU::unmap(virt));
            TRY(free_frame(frame));
            virt += ARCH_PAGE_SIZE;
        }

        return {};
    }

    Result<void> unmap_owned_if_possible(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            const auto frame = MMU::unmap(virt);
            if (frame.has_value()) TRY(free_frame(frame.value()));
            virt += ARCH_PAGE_SIZE;
        }

        return {};
    }

    Result<void> unmap_owned_and_free_vm(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        KernelVM::free_several_pages(virt, count);

        return unmap_owned(virt, count);
    }

    Result<void> unmap_weak(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            TRY(MMU::unmap(virt));
            virt += ARCH_PAGE_SIZE;
        }

        return {};
    }

    Result<void> unmap_weak_huge(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        while (count--)
        {
            TRY(MMU::unmap(virt));
            virt += ARCH_HUGE_PAGE_SIZE;
        }

        return {};
    }

    Result<void> unmap_weak_and_free_vm(u64 virt, usize count)
    {
        CHECK_PAGE_ALIGNED(virt);

        KernelVM::free_several_pages(virt, count);

        return unmap_weak(virt, count);
    }

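    // Like remap(), but accepts an address that is not page-aligned, expanding the range to cover the extra page.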
    Result<void> remap_unaligned(u64 address, usize count, int flags)
    {
        if (!is_aligned<ARCH_PAGE_SIZE>(address)) count++;
        address = align_down<ARCH_PAGE_SIZE>(address);

        while (count--)
        {
            TRY(MMU::remap(address, flags));
            address += ARCH_PAGE_SIZE;
        }

        return {};
    }

    bool validate_page_default_access(u64 address)
    {
        auto rc = MMU::get_flags(address);
        if (rc.has_error()) return false;
        return true;
    }

    bool validate_page_access(u64 address, int flags)
    {
        auto rc = MMU::get_flags(address);
        if (rc.has_error()) return false;
        // Require all of the requested flags to be present, not just any of them.
        if ((rc.value() & flags) == flags) return true;
        return false;
    }

    // FIXME: Make this more efficient.
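    // Copy a null-terminated string from userspace, validating every page it touches along the way.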
    Result<String> strdup_from_user(u64 address)
    {
        if (!validate_page_access(address, MMU::User)) return err(EFAULT);

        Vector<char> result;

        while (*(char*)address != 0)
        {
            TRY(result.try_append(*(char*)address));
            address++;
            // Crossed into a new page; validate it before reading from it.
            if (!(address % ARCH_PAGE_SIZE))
            {
                if (!validate_page_access(address, MMU::User)) return err(EFAULT);
            }
        }

        TRY(result.try_append(0)); // null terminator

        return String { result.release_data() };
    }

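    // Check that the whole [mem, mem + size) range is mapped with the given flags (or simply mapped at all, if no
    // flags are given).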
    bool validate_access(const void* mem, usize size, int flags)
    {
        uintptr_t address = (uintptr_t)mem;
        uintptr_t page = align_down<ARCH_PAGE_SIZE>(address);

        uintptr_t diff = address - page;

        usize pages = get_blocks_from_size(size + diff, ARCH_PAGE_SIZE);

        while (pages--)
        {
            if (flags > 0)
            {
                if (!validate_page_access(page, flags)) return false;
            }
            else
            {
                if (!validate_page_default_access(page)) return false;
            }
            page += ARCH_PAGE_SIZE;
        }

        return true;
    }

    // FIXME: Use memcpy() in both copy_to_user() and copy_from_user().

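    // copy_to_user() and copy_from_user() copy byte by byte, re-validating the page tables whenever the userspace
    // pointer crosses into a new page.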
    bool copy_to_user(void* user, const void* kernel, usize size)
    {
        uintptr_t user_ptr = (uintptr_t)user;
        uintptr_t user_page = align_down<ARCH_PAGE_SIZE>(user_ptr);

        const u8* kernel_ptr = (const u8*)kernel;

        // Userspace pointer not aligned on page boundary
        if (user_ptr != user_page)
        {
            if (!validate_page_access(user_page, MMU::ReadWrite | MMU::User)) return false;
        }

        while (size--)
        {
            // Crossed a page boundary, gotta check the page tables again before touching any memory!!
            if (!(user_ptr % ARCH_PAGE_SIZE))
            {
                if (!validate_page_access(user_ptr, MMU::ReadWrite | MMU::User)) return false;
            }

            *(u8*)user_ptr = *kernel_ptr++;
            user_ptr++;
        }

        return true;
    }

    bool copy_from_user(const void* user, void* kernel, usize size)
    {
        uintptr_t user_ptr = (uintptr_t)user;
        uintptr_t user_page = align_down<ARCH_PAGE_SIZE>(user_ptr);

        u8* kernel_ptr = (u8*)kernel;

        // Userspace pointer not aligned on page boundary
        if (user_ptr != user_page)
        {
            if (!validate_page_access(user_page, MMU::User)) return false;
        }

        while (size--)
        {
            // Crossed a page boundary, gotta check the page tables again before touching any memory!!
            if (!(user_ptr % ARCH_PAGE_SIZE))
            {
                if (!validate_page_access(user_ptr, MMU::User)) return false;
            }

            *kernel_ptr++ = *(const u8*)user_ptr;
            user_ptr++;
        }

        return true;
    }

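    // Memory statistics, in bytes.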
    usize free()
    {
        return free_mem;
    }

    usize used()
    {
        return used_mem;
    }

    usize reserved()
    {
        return reserved_mem;
    }

    usize total()
    {
        return free_mem + used_mem + reserved_mem;
    }
}