#include "arch/CPU.h"
#include "Log.h"
#include "arch/Keyboard.h"
#include "arch/Timer.h"
#include "arch/x86_64/CPU.h"
#include "arch/x86_64/IO.h"
#include "fs/devices/ConsoleDevice.h"
#include "memory/MemoryManager.h"
#include "sys/Syscall.h"
#include "thread/Scheduler.h"
#include "video/TextConsole.h"
#include <cpuid.h>
#include <luna/CString.h>
#include <luna/CircularQueue.h>
#include <luna/Result.h>
#include <luna/SystemError.h>
#include <luna/Types.h>

extern "C" void enable_sse();
extern "C" void enable_write_protect();
extern "C" void enable_nx();

extern void setup_gdt();
extern void remap_pic();
extern void pic_eoi(unsigned char irq);
extern void pic_eoi(Registers* regs);
extern void setup_idt();

// Saves the x87/MMX/SSE register state into m_data with fxsave and records that a valid
// image now exists, so FPData::restore() will load it.
// NOTE(review): fxsave requires a 16-byte aligned, 512-byte destination — confirm m_data's declaration.
void FPData::save()
{
    // The "memory" clobber tells the compiler this asm writes memory it cannot see through the
    // register operand, so it must not cache or reorder accesses to m_data across it.
    asm volatile("fxsave (%0)" : : "r"(m_data) : "memory");
    m_already_saved = true;
}

// Loads the x87/MMX/SSE register state from m_data with fxrstor. Does nothing if save() has
// never run, since the buffer would not hold a valid image yet.
void FPData::restore()
{
    if (!m_already_saved) return;
    // The "memory" clobber forces any pending stores to m_data to be emitted before fxrstor
    // reads the buffer.
    asm volatile("fxrstor (%0)" : : "r"(m_data) : "memory");
}

// Interrupt handling

// Logs that an exception has no real handler yet, then halts the CPU permanently.
// Wrapped in do { } while (0) so the expansion is a single statement and stays correct in any
// statement context (e.g. after an unbraced if); the caller supplies the trailing ';'.
#define FIXME_UNHANDLED_INTERRUPT(name)                                                                                \
    do                                                                                                                 \
    {                                                                                                                  \
        kerrorln("FIXME(interrupt): %s", name);                                                                        \
        CPU::efficient_halt();                                                                                         \
    } while (0)

// Bits of the page-fault (#PF) error code pushed by the CPU. Each body is parenthesized so the
// macro acts as a single operand in any expression (e.g. ~PF_WRITE, PF_USER * 2).
#define PF_PRESENT (1 << 0)      // set: page was present (protection violation); clear: not present
#define PF_WRITE (1 << 1)        // set: write access; clear: read access
#define PF_USER (1 << 2)         // set: access from user mode; clear: kernel mode
#define PF_RESERVED (1 << 3)     // set: reserved bit set in a paging-structure entry
#define PF_NX_VIOLATION (1 << 4) // set: fault caused by an instruction fetch (NX violation)

// Logs a human-readable breakdown of a page-fault error code: whether the page was present,
// the access kind, the privilege mode, and the optional reserved-bit / NX-violation flags.
void decode_page_fault_error_code(u64 code)
{
    const char* const presence = (code & PF_PRESENT) ? "Present" : "Not present";
    const char* const access = (code & PF_WRITE) ? "Write access" : "Read access";
    const char* const mode = (code & PF_USER) ? "User mode" : "Kernel mode";
    // The last two fields are optional suffixes and carry their own " | " separator.
    const char* const reserved = (code & PF_RESERVED) ? " | Reserved bits set" : "";
    const char* const nx = (code & PF_NX_VIOLATION) ? " | NX violation" : "";

    kwarnln("Fault details: %s | %s | %s%s%s", presence, access, mode, reserved, nx);
}

// Handler for exception vector 14 (#PF), dispatched from handle_x86_exception.
// Logs the faulting RIP and the faulting address, then decodes the error code.
// If the fault did not happen in the kernel (per is_in_kernel), it terminates the current
// thread. Otherwise it prints a stack trace and halts the machine. Never returns.
[[noreturn]] void handle_page_fault(Registers* regs)
{
    CPU::disable_interrupts();

    // CR2 holds the linear address whose access caused the fault. Read it before anything else
    // could trigger another page fault and overwrite it.
    u64 cr2;
    asm volatile("mov %%cr2, %0" : "=r"(cr2));
    kerrorln("Page fault at RIP %lx while accessing %lx!", regs->rip, cr2);

    decode_page_fault_error_code(regs->error);

    if (!is_in_kernel(regs))
    {
        // FIXME: Kill this process with SIGSEGV once we have signals and all that.
        kerrorln("Current task %zu was terminated because of a page fault", Scheduler::current()->id);
        // Mark the thread Exited with status 127, then give up the CPU. The scheduler is
        // presumably expected never to resume an Exited thread, hence unreachable() below.
        Scheduler::current()->state = ThreadState::Exited;
        Scheduler::current()->status = 127;
        kernel_yield();
        unreachable();
    }

    // Fault inside the kernel: print where it happened and stop the machine.
    CPU::print_stack_trace_at(regs);

    CPU::efficient_halt();
}

// Handler for exception vector 13 (#GP), dispatched from handle_x86_exception.
// Logs the faulting RIP and error code, prints a stack trace starting at the faulting frame,
// and halts the machine. Never returns.
// NOTE(review): unlike handle_page_fault, this does not check is_in_kernel(); a #GP raised by a
// user thread halts the whole system instead of terminating only that thread.
[[noreturn]] void handle_general_protection_fault(Registers* regs)
{
    CPU::disable_interrupts();

    kerrorln("General protection fault at RIP %lx, error code %lx!", regs->rip, regs->error);

    CPU::print_stack_trace_at(regs);

    CPU::efficient_halt();
}

// Dispatches CPU exceptions (vectors 0-31, see arch_interrupt_entry) by vector number.
// Only #GP (13) and #PF (14) have real handlers. Every other vector logs a FIXME and halts.
// No case needs a 'break': every branch ends in a [[noreturn]] call (efficient_halt via the
// macro, or one of the two [[noreturn]] handlers).
// Double fault (8) and machine check (18) fall into 'default'; per that message they are
// expected to go through arch_double_fault / arch_machine_check instead.
extern "C" void handle_x86_exception(Registers* regs)
{
    switch (regs->isr)
    {
    case 0: FIXME_UNHANDLED_INTERRUPT("Division by zero");
    case 1: FIXME_UNHANDLED_INTERRUPT("Debug interrupt");
    case 2: FIXME_UNHANDLED_INTERRUPT("NMI (Non-maskable interrupt)");
    case 3: FIXME_UNHANDLED_INTERRUPT("Breakpoint");
    case 4: FIXME_UNHANDLED_INTERRUPT("Overflow");
    case 5: FIXME_UNHANDLED_INTERRUPT("Bound range exceeded");
    case 6: FIXME_UNHANDLED_INTERRUPT("Invalid opcode");
    case 7: FIXME_UNHANDLED_INTERRUPT("Device not available");
    case 10: FIXME_UNHANDLED_INTERRUPT("Invalid TSS");
    case 11: FIXME_UNHANDLED_INTERRUPT("Segment not present");
    case 12: FIXME_UNHANDLED_INTERRUPT("Stack-segment fault");
    case 13: handle_general_protection_fault(regs);
    case 14: handle_page_fault(regs);
    case 16: FIXME_UNHANDLED_INTERRUPT("x87 floating-point exception");
    case 17: FIXME_UNHANDLED_INTERRUPT("Alignment check");
    case 19: FIXME_UNHANDLED_INTERRUPT("SIMD floating-point exception");
    case 20: FIXME_UNHANDLED_INTERRUPT("Virtualization exception");
    case 21: FIXME_UNHANDLED_INTERRUPT("Control-protection exception");
    default: FIXME_UNHANDLED_INTERRUPT("Reserved exception or #DF/#MC, which shouldn't call handle_x86_exception");
    }
}

// Raw scancodes handed from the keyboard interrupt (vector 33 in arch_interrupt_entry, which
// reads them from port 0x60) to io_thread, which decodes them outside interrupt context.
// The interrupt path ignores try_push's result, so scancodes are presumably dropped when the
// queue is full — confirm against CircularQueue's semantics.
CircularQueue<u8, 60> scancode_queue;

// Body of the "[x86_64-io]" kernel thread. It drains scancodes queued by the keyboard interrupt,
// decodes each one, and forwards any resulting character to the console device. Never returns.
void io_thread()
{
    for (;;)
    {
        u8 scancode;

        // Poll the queue, sleeping between empty checks instead of spinning.
        for (;;)
        {
            if (scancode_queue.try_pop(scancode)) break;
            kernel_sleep(10);
        }

        // Some scancodes (e.g. key releases) decode to nothing; only forward real keys.
        char key;
        const bool decoded = Keyboard::decode_scancode(scancode).try_set_value(key);
        if (decoded) ConsoleDevice::did_press_key(key);
    }
}

// Called from _asm_interrupt_entry
extern "C" void arch_interrupt_entry(Registers* regs)
{
    if (regs->isr < 32) handle_x86_exception(regs);
    else if (regs->isr == 32) // Timer interrupt
    {
        Timer::tick();
        if (should_invoke_scheduler()) Scheduler::invoke(regs);
        pic_eoi(regs);
    }
    else if (regs->isr == 33) // Keyboard interrupt
    {
        u8 scancode = IO::inb(0x60);
        scancode_queue.try_push(scancode);
        pic_eoi(regs);
    }
    else if (regs->isr == 66) // System call
    {
        SyscallArgs args = { regs->rdi, regs->rsi, regs->rdx, regs->r10, regs->r8, regs->r9 };
        regs->rax = (u64)invoke_syscall(regs, args, regs->rax);
    }
    else
    {
        kwarnln("IRQ catched! Halting.");
        CPU::efficient_halt();
    }
}

// Entry for a double fault (#DF, vector 8). handle_x86_exception's default message says #DF
// should not be routed through it; presumably a dedicated stub calls this instead (not visible here).
// Logs the error and halts the machine permanently.
extern "C" [[noreturn]] void arch_double_fault()
{
    kerrorln("ERROR: Catched double fault");
    CPU::efficient_halt();
}

// Entry for a machine-check exception (#MC, vector 18). handle_x86_exception's default message
// says #MC should not be routed through it; presumably a dedicated stub calls this instead.
// Logs the error and halts the machine permanently.
extern "C" [[noreturn]] void arch_machine_check()
{
    kerrorln("ERROR: Machine check failed");
    CPU::efficient_halt();
}

// Generic CPU code

// Returns true if the CPU reports support for the NX (execute-disable) bit:
// CPUID leaf 0x80000001, EDX bit 20. Returns false if that leaf is unavailable.
static bool test_nx()
{
    // Distinct outputs with non-reserved names. '__unused' is a reserved identifier and is
    // defined as a macro by some system headers.
    unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
    if (!__get_cpuid(0x80000001, &eax, &ebx, &ecx, &edx)) return false;
    return (edx & (1u << 20)) != 0;
}

namespace CPU
{
    Result<const char*> identify()
    {
        static char brand_string[49];

        u32 buf[4];
        if (!__get_cpuid(0x80000002, &buf[0], &buf[1], &buf[2], &buf[3])) return err(ENOTSUP);
        memcpy(brand_string, buf, 16);
        if (!__get_cpuid(0x80000003, &buf[0], &buf[1], &buf[2], &buf[3])) return err(ENOTSUP);
        memcpy(&brand_string[16], buf, 16);
        if (!__get_cpuid(0x80000004, &buf[0], &buf[1], &buf[2], &buf[3])) return err(ENOTSUP);
        memcpy(&brand_string[32], buf, 16);

        brand_string[48] = 0; // null-terminate it :)

        return brand_string;
    }

    const char* platform_string()
    {
        return "x86_64";
    }

    void platform_init()
    {
        enable_sse();
        // enable_write_protect();
        if (test_nx()) enable_nx();
        else
            kwarnln("not setting the NX bit as it is unsupported");
        setup_gdt();
        setup_idt();
    }

    void platform_finish_init()
    {
        Scheduler::new_kernel_thread(io_thread, "[x86_64-io]")
            .expect_value("Could not create the IO background thread!");

        remap_pic();
    }

    void enable_interrupts()
    {
        asm volatile("sti");
    }

    void disable_interrupts()
    {
        asm volatile("cli");
    }

    void wait_for_interrupt()
    {
        asm volatile("hlt");
    }

    [[noreturn]] void efficient_halt() // Halt the CPU, using the lowest power possible. On x86-64 we do this using the
                                       // "hlt" instruction, which puts the CPU into a low-power idle state until the
                                       // next interrupt arrives... and we disable interrupts beforehand.
    {
        asm volatile("cli"); // Disable interrupts
    loop:
        asm volatile("hlt"); // Let the cpu rest and pause until the next interrupt arrives... which in this case should
                             // be never (unless an NMI arrives) :)
        goto loop;           // Safeguard: if we ever wake up, start our low-power rest again
    }

    [[noreturn]] void idle_loop()
    {
        asm volatile("sti");
    loop:
        asm volatile("hlt");
        goto loop;
    }

    void switch_kernel_stack(u64 top)
    {
        task_state_segment.rsp[0] = top;
    }

    struct StackFrame
    {
        StackFrame* next;
        u64 instruction;
    };

    static void backtrace_impl(u64 base_pointer, void (*callback)(u64, void*), void* arg)
    {
        StackFrame* current_frame = (StackFrame*)base_pointer;
        // FIXME: Validate that the frame itself is readable, might span across multiple pages
        while (current_frame && MemoryManager::validate_readable_page((u64)current_frame) && current_frame->instruction)
        {
            callback(current_frame->instruction, arg);
            current_frame = current_frame->next;
        }
    }

    void get_stack_trace(void (*callback)(u64, void*), void* arg)
    {
        u64 rbp;
        asm volatile("mov %%rbp, %0" : "=r"(rbp));
        return backtrace_impl(rbp, callback, arg);
    }

    void print_stack_trace()
    {
        u64 rbp;
        int frame_index = 0;
        asm volatile("mov %%rbp, %0" : "=r"(rbp));
        return backtrace_impl(
            rbp,
            [](u64 instruction, void* arg) {
                int* ptr = (int*)arg;
                kinfoln("#%d at %p", *ptr, (void*)instruction);
                (*ptr)++;
            },
            &frame_index);
    }

    void get_stack_trace_at(Registers* regs, void (*callback)(u64, void*), void* arg)
    {
        callback(regs->rip, arg);
        return backtrace_impl(regs->rbp, callback, arg);
    }

    void print_stack_trace_at(Registers* regs)
    {
        int frame_index = 0;
        get_stack_trace_at(
            regs,
            [](u64 instruction, void* arg) {
                int* ptr = (int*)arg;
                kinfoln("#%d at %p", *ptr, (void*)instruction);
                (*ptr)++;
            },
            &frame_index);
    }

    [[noreturn]] void bootstrap_switch_stack(u64 stack, void* function)
    {
        asm volatile("mov %0, %%rsp\n"
                     "jmp *%1"
                     :
                     : "r"(stack), "r"(function));
        __builtin_unreachable();
    }

    void pause()
    {
        asm volatile("pause");
    }

    u16 get_processor_id()
    {
        unsigned int unused;
        unsigned int ebx = 0;
        __get_cpuid(1, &unused, &ebx, &unused, &unused);
        return (u16)(ebx >> 24);
    }
}

// called by kernel_yield
// called by kernel_yield
// C-linkage trampoline so the (presumably assembly-side) caller can reach the scheduler.
// Forwards the saved register state to Scheduler::switch_task.
extern "C" void switch_task(Registers* regs)
{
    Scheduler::switch_task(regs);
}