#include <luna/Alignment.h>
#include <luna/Alloc.h>
#include <luna/CString.h>
#include <luna/DebugLog.h>
#include <luna/Heap.h>
#include <luna/LinkedList.h>
#include <luna/SafeArithmetic.h>
#include <luna/ScopeGuard.h>
#include <luna/Spinlock.h>
#include <luna/SystemError.h>

#ifdef USE_FREESTANDING
#include "arch/MMU.h"
#define PAGE_SIZE ARCH_PAGE_SIZE
#else
#include <sys/mman.h>
#endif

namespace std
{
    const nothrow_t nothrow;
}

static constexpr int BLOCK_USED = 1 << 0;      // The block is currently allocated
static constexpr int BLOCK_START_MEM = 1 << 1; // The block is the first one in its mapped memory range
static constexpr int BLOCK_END_MEM = 1 << 2;   // The block is the last one in its mapped memory range

static constexpr usize BLOCK_MAGIC = 0x6d616c6c6f63210a; // echo 'malloc!' | hexdump -C (includes a newline)
static constexpr usize BLOCK_DEAD = 0xdeaddeaddeaddead;

static constexpr u8 MALLOC_SCRUB_BYTE = 0xac;
static constexpr u8 FREE_SCRUB_BYTE = 0xde;

static constexpr usize MINIMUM_PAGES_PER_ALLOCATION = 4;

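// Every allocation is preceded by a HeapBlock header. All blocks, used and free, are linked together in a
// doubly-linked list (the "heap"), and each block remembers whether it starts and/or ends the contiguous page range
// it was carved out of, so that whole ranges can be returned to the system once they become free again.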
struct HeapBlock : LinkedListNode<HeapBlock>
{
    usize req_size;  // Size requested by the caller, aligned up to 16 bytes
    usize full_size; // Total usable size of the block's data area
    int status;      // Bitmask of the BLOCK_* flags above
    usize magic;     // BLOCK_MAGIC while the block is valid, BLOCK_DEAD after it has been merged away
};

static_assert(sizeof(HeapBlock) == 48UL);

static const isize HEAP_BLOCK_SIZE = 48; // Signed copy of sizeof(HeapBlock), used for negative pointer offsets

static LinkedList<HeapBlock> heap;
static Spinlock g_heap_lock;

// If we're allocating a large amount of memory, map enough pages to fit it; otherwise, just map the minimum number of
// pages per allocation.
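// For example, assuming 4 KiB pages: passing 20 KiB here yields 5 pages, while passing 100 bytes still yields the
// 4-page minimum.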
static usize get_pages_for_allocation(usize bytes)
{
    usize pages = get_blocks_from_size(bytes, PAGE_SIZE);
    if (pages < MINIMUM_PAGES_PER_ALLOCATION) pages = MINIMUM_PAGES_PER_ALLOCATION;
    return pages;
}

static bool is_block_free(HeapBlock* block)
{
    return !(block->status & BLOCK_USED);
}

static usize space_available(HeapBlock* block)
{
    expect(!is_block_free(block), "Attempting to split a free block");
    return block->full_size - block->req_size;
}

// The HeapBlock header is stored immediately before the memory block that is handed out to the caller.
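//
//   ...| HeapBlock header | data handed out to the caller |...
//      ^                  ^
//      block              ptr = block + sizeof(HeapBlock)
//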
static HeapBlock* get_heap_block_for_pointer(void* ptr)
{
    return (HeapBlock*)offset_ptr(ptr, -HEAP_BLOCK_SIZE);
}

static void* get_pointer_from_heap_block(HeapBlock* block)
{
    return (void*)offset_ptr(block, HEAP_BLOCK_SIZE);
}

// Used when the caller tells us this block may be realloc-ed. In this case, we split the available space roughly
// equally between both blocks.
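// For instance, with sizeof(HeapBlock) == 48: a used block with req_size == 32 and full_size == 256, asked to make
// room for min == 96 (a 48-byte request plus its header), keeps 64 spare bytes for itself and splits at offset 96,
// leaving the new block a full_size of 112.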
static usize get_fair_offset_to_split_at(HeapBlock* block, usize min)
{
    usize available = space_available(block);

    available -= min; // reserve at least min size for the new block.

    // Reserve half of the rest for the new block, while still leaving the other half for the old one.
    available -= available / 2;

    available = align_down<16>(available); // Everything has to be aligned on a 16-byte boundary

    return available + block->req_size;
}

// Used when the caller tells us this block will not be realloc-ed. In this case, we make the new block as small as
// possible.
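// With the same numbers as the example above, this splits at offset 160 instead, leaving the new block a full_size of
// exactly 48: the size that was actually requested.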
static usize get_small_offset_to_split_at(HeapBlock* block, usize min)
{
    usize available = space_available(block);

    available -= min; // reserve only min size for the new block.

    available = align_down<16>(available); // Everything has to be aligned on a 16-byte boundary

    return available + block->req_size;
}

static Option<HeapBlock*> split(HeapBlock* block, usize size, bool may_realloc)
{
    const usize available = space_available(block); // How much space can we steal from this block?
    // Save the old size; we modify block->full_size below but still need the original value afterwards.
    const usize old_size = block->full_size;

    if (available <= (size + sizeof(HeapBlock)))
        return {}; // This block hasn't got enough free space to hold the requested size.

    const usize offset = may_realloc ? get_fair_offset_to_split_at(block, size + sizeof(HeapBlock))
                                     : get_small_offset_to_split_at(block, size + sizeof(HeapBlock));
    block->full_size = offset; // shrink the old block to fit this offset

    HeapBlock* const new_block = offset_ptr(block, offset + sizeof(HeapBlock));

    memset(new_block, 0, sizeof(*new_block));

    new_block->magic = BLOCK_MAGIC;
    new_block->status = (block->status & BLOCK_END_MEM) ? BLOCK_END_MEM : 0;
    new_block->full_size = old_size - (offset + sizeof(HeapBlock));
    heap.add_after(block, new_block);

    block->status &= ~BLOCK_END_MEM; // this block is no longer the last block in its memory range

    return new_block;
}

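// Merge the next block into this one. The caller must ensure that a next block exists and that it is free.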
static Result<void> combine_forward(HeapBlock* block)
{
    // This block ends a memory range; it cannot be combined with blocks outside its range.
    if (block->status & BLOCK_END_MEM) return {};

    // The caller needs to ensure there is a next block.
    HeapBlock* const next = heap.next(block).value();
    // The next block starts a memory range; it cannot be combined with blocks outside its range.
    if (next->status & BLOCK_START_MEM) return {};

    heap.remove(next);
    next->magic = BLOCK_DEAD;

    block->full_size += next->full_size + sizeof(HeapBlock);

    if (next->status & BLOCK_END_MEM)
    {
        if (next->status & BLOCK_START_MEM)
        {
            const usize pages = get_blocks_from_size(next->full_size + sizeof(HeapBlock), PAGE_SIZE);
            TRY(release_pages_impl(next, pages));
            return {};
        }
        else
            block->status |= BLOCK_END_MEM;
    }

    return {};
}

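// Merge this block into the previous one and return the surviving block. The caller must ensure that a previous block
// exists and that it is free.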
static Result<HeapBlock*> combine_backward(HeapBlock* block)
{
    // This block starts a memory range; it cannot be combined with blocks outside its range.
    if (block->status & BLOCK_START_MEM) return block;

    // The caller needs to ensure there is a previous block.
    HeapBlock* const last = heap.previous(block).value();
    // The previous block ends a memory range; it cannot be combined with blocks outside its range.
    if (last->status & BLOCK_END_MEM) return block;
    heap.remove(block);
    block->magic = BLOCK_DEAD;

    last->full_size += block->full_size + sizeof(HeapBlock);

    if (block->status & BLOCK_END_MEM)
    {
        if (block->status & BLOCK_START_MEM)
        {
            const usize pages = get_blocks_from_size(block->full_size + sizeof(HeapBlock), PAGE_SIZE);
            TRY(release_pages_impl(block, pages));
            return last;
        }
        else
            last->status |= BLOCK_END_MEM;
    }

    return last;
}

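// Allocation strategy: first-fit. Walk the heap list looking for a free block large enough for the request; if a used
// block has enough spare space behind its data, carve a new free block out of it; otherwise, map a fresh page range.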
Result<void*> malloc_impl(usize size, bool may_realloc, bool should_scrub)
{
    if (!size) return (void*)BLOCK_MAGIC;

    ScopeLock lock(g_heap_lock);

    size = align_up<16>(size);

    Option<HeapBlock*> block = heap.first();
    while (block.has_value())
    {
        HeapBlock* const current = block.value();
        // Trying to find a free block...
        if (is_block_free(current))
        {
            if (current->full_size < size)
            {
                block = heap.next(current);
                continue;
            }
            break; // We found a free block that's big enough!!
        }
        // The block is in use, but it may have enough spare space behind its data to carve a new free block out of.
        auto rc = split(current, size, may_realloc);
        if (rc.has_value())
        {
            block = rc.value(); // We managed to get a free block from a larger used block!!
            break;
        }
        block = heap.next(current);
    }

    if (!block.has_value()) // No free blocks, let's allocate a new one
    {
        usize pages = get_pages_for_allocation(size + sizeof(HeapBlock));
        HeapBlock* const current = (HeapBlock*)TRY(allocate_pages_impl(pages));

        memset(current, 0, sizeof(*current));

        current->full_size = (pages * PAGE_SIZE) - sizeof(HeapBlock);
        current->magic = BLOCK_MAGIC;
        current->status = BLOCK_START_MEM | BLOCK_END_MEM;
        heap.append(current);

        block = current;
    }

    HeapBlock* const current = block.value();

    current->req_size = size;
    current->status |= BLOCK_USED;

    if (should_scrub) { memset(get_pointer_from_heap_block(current), MALLOC_SCRUB_BYTE, size); }

    return get_pointer_from_heap_block(current);
}

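// Mark the block as free, scrub its contents, merge it with any free neighbors, and give the underlying pages back to
// the system if the resulting block ends up covering an entire memory range on its own.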
Result<void> free_impl(void* ptr)
{
    if (ptr == (void*)BLOCK_MAGIC) return {}; // This pointer was returned from a call to malloc(0)
    if (!ptr) return {};

    ScopeLock lock(g_heap_lock);

    HeapBlock* block = get_heap_block_for_pointer(ptr);

    if (block->magic != BLOCK_MAGIC)
    {
        if (block->magic == BLOCK_DEAD) { dbgln("ERROR: Attempt to free memory at %p, which was already freed", ptr); }
        else
            dbgln("ERROR: Attempt to free memory at %p, which wasn't allocated with malloc", ptr);

#ifdef USE_FREESTANDING
        fail("Call to free_impl() with an invalid argument (double-free or erroneous deallocation)");
#else
        return err(EFAULT);
#endif
    }

    if (is_block_free(block))
    {
        dbgln("ERROR: Attempt to free memory at %p, which was already freed", ptr);
#ifdef USE_FREESTANDING
        fail("Call to free_impl() with a pointer to freed memory (probably double-free)");
#else
        return err(EFAULT);
#endif
    }
    else
        block->status &= ~BLOCK_USED;

    memset(ptr, FREE_SCRUB_BYTE, block->req_size);

    auto maybe_next = heap.next(block);
    if (maybe_next.has_value() && is_block_free(maybe_next.value()))
    {
        // The next block is also free, thus we can merge!
        TRY(combine_forward(block));
    }

    auto maybe_last = heap.previous(block);
    if (maybe_last.has_value() && is_block_free(maybe_last.value()))
    {
        // The previous block is also free, thus we can merge!
        block = TRY(combine_backward(block));
    }

    // If the block now covers an entire memory range on its own, give its pages back to the system.
    if ((block->status & BLOCK_START_MEM) && (block->status & BLOCK_END_MEM))
    {
        heap.remove(block);
        const usize pages = get_blocks_from_size(block->full_size + sizeof(HeapBlock), PAGE_SIZE);
        TRY(release_pages_impl(block, pages));
    }

    return {};
}

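// Resize the allocation in place if its block already has enough space; otherwise, allocate a new block, copy the old
// contents over, and free the old one.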
Result<void*> realloc_impl(void* ptr, usize size, bool may_realloc_again)
{
    if (!ptr) return malloc_impl(size, may_realloc_again);
    if (ptr == (void*)BLOCK_MAGIC) return malloc_impl(size, may_realloc_again);
    if (!size)
    {
        TRY(free_impl(ptr));
        return (void*)BLOCK_MAGIC;
    }

    ScopeLock lock(g_heap_lock);

    HeapBlock* const block = get_heap_block_for_pointer(ptr);

    if (block->magic != BLOCK_MAGIC)
    {
        if (block->magic == BLOCK_DEAD)
        {
            dbgln("ERROR: Attempt to realloc memory at %p, which was already freed", ptr);
        }
        else
            dbgln("ERROR: Attempt to realloc memory at %p, which wasn't allocated with malloc", ptr);

        return err(EFAULT);
    }

    size = align_up<16>(size);

    if (is_block_free(block))
    {
        dbgln("ERROR: Attempt to realloc memory at %p, which was already freed", ptr);
        return err(EFAULT);
    }

    if (block->full_size >= size)
    {
        // This block is already large enough!
        if (size > block->req_size)
        {
            // If the new size is larger, scrub the newly allocated space.
            memset(offset_ptr(ptr, block->req_size), MALLOC_SCRUB_BYTE, size - block->req_size);
        }
        else if (size < block->req_size)
        {
            // If the new size is smaller, scrub the removed space as if it had been freed.
            memset(offset_ptr(ptr, size), FREE_SCRUB_BYTE, block->req_size - size);
        }
        block->req_size = size;
        return ptr;
    }

    usize old_size = block->req_size;

    // Manually release the heap lock; malloc_impl() and free_impl() below will take it again themselves.
    lock.take_over().unlock();

    void* const new_ptr = TRY(malloc_impl(size, may_realloc_again, false));
    memcpy(new_ptr, ptr, old_size > size ? size : old_size);
    TRY(free_impl(ptr));

    if (old_size < size) { memset(offset_ptr(new_ptr, old_size), MALLOC_SCRUB_BYTE, size - old_size); }

    return new_ptr;
}

Result<void*> calloc_impl(usize nmemb, usize size, bool may_realloc)
{
    const usize realsize = TRY(safe_mul(nmemb, size));
    void* const ptr = TRY(malloc_impl(realsize, may_realloc, false));
    return memset(ptr, 0, realsize);
}

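// Walk the heap list and log every block; the "b" and "e" flags show whether a block starts or ends its memory range.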
void dump_heap_usage()
{
    dbgln("-- Dumping usage stats for heap:");
    if (!heap.count())
    {
        dbgln("- Heap is not currently being used");
        return;
    }
    usize alloc_total = 0;
    usize alloc_used = 0;
    auto block = heap.first();
    while (block.has_value())
    {
        HeapBlock* current = block.value();
        if (is_block_free(current))
        {
            dbgln("- Available block (%p), of size %zu (%s%s)", (void*)current, current->full_size,
                  current->status & BLOCK_START_MEM ? "b" : "-", current->status & BLOCK_END_MEM ? "e" : "-");
            alloc_total += current->full_size + sizeof(HeapBlock);
        }
        else
        {
            dbgln("- Used block (%p), of size %zu, of which %zu bytes are being used (%s%s)", (void*)current,
                  current->full_size, current->req_size, current->status & BLOCK_START_MEM ? "b" : "-",
                  current->status & BLOCK_END_MEM ? "e" : "-");
            alloc_total += current->full_size + sizeof(HeapBlock);
            alloc_used += current->req_size;
        }
        block = heap.next(current);
    }

    dbgln("-- Total memory allocated for heap: %zu bytes", alloc_total);
    dbgln("-- Heap memory in use: %zu bytes", alloc_used);
}

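// Non-throwing operator new overloads and the matching operator delete overloads, backed by the allocator above;
// allocation failure is reported by returning nullptr.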
void* operator new(usize size, const std::nothrow_t&) noexcept
{
    return malloc_impl(size).value_or(nullptr);
}

void* operator new[](usize size, const std::nothrow_t&) noexcept
{
    return malloc_impl(size).value_or(nullptr);
}

void operator delete(void* p) noexcept
{
    free_impl(p);
}

void operator delete[](void* p) noexcept
{
    free_impl(p);
}

void operator delete(void* p, usize) noexcept
{
    free_impl(p);
}

void operator delete[](void* p, usize) noexcept
{
    free_impl(p);
}