#include "arch/x86_64/disk/ATA.h"
#include "Log.h"
#include "arch/Serial.h"
#include "arch/Timer.h"
#include "arch/x86_64/IO.h"
#include "fs/MBR.h"
#include "memory/MemoryManager.h"
#include <luna/Alignment.h>
#include <luna/Buffer.h>
#include <luna/CType.h>
#include <luna/SafeArithmetic.h>
#include <luna/Vector.h>

// The ATA controller adopted by ATA::Controller::scan(). It is only assigned once a controller's
// initialize() call has succeeded, and scan() treats a null value as "no controller found yet".
SharedPtr<ATA::Controller> g_controller;

// Interrupt trampoline with the plain-function signature used by CPU::register_interrupt().
// 'ctx' is the ATA::Controller that was supplied at registration time; the interrupt is forwarded to it.
static void irq_handler(Registers* regs, void* ctx)
{
    auto* controller = static_cast<ATA::Controller*>(ctx);
    controller->irq_handler(regs);
}

// Copies an ATA identification string into 'out', swapping the two bytes of every 16-bit word
// (the high byte of each word comes first in the output).
//
// out:  destination buffer; must have room for at least size + 1 bytes.
// in:   source words; (size + 1) / 2 words are read.
// size: number of characters to copy. Callers pass even sizes; for an odd size the byte
//       written at out[size] by the loop is overwritten by the terminator.
//
// Returns 'size', i.e. the length of the resulting string without the terminator.
static usize copy_ata_string(char* out, u16* in, usize size)
{
    for (usize i = 0; i < size; i += 2)
    {
        const u16 word = in[i / 2];
        out[i] = static_cast<char>(static_cast<u8>(word >> 8));
        out[i + 1] = static_cast<char>(static_cast<u8>(word & 0xff));
    }

    // The terminator goes directly after the last character. Writing it at size + 1 would leave
    // out[size] uninitialized and store one byte past a (size + 1)-byte buffer.
    out[size] = '\0';

    return size;
}

namespace ATA
{
    // Scans the PCI bus for devices matching class 1, subclass 1 and adopts the first one whose
    // initialize() succeeds as the global controller (g_controller).
    //
    // Always returns success; a missing controller is only reported through a warning.
    Result<void> Controller::scan()
    {
        // FIXME: Propagate errors.
        PCI::scan(
            [](const PCI::Device& device) {
                // Once a controller has been adopted, further matches are ignored.
                if (g_controller) return;

                // The allocation can fail; skip the device instead of releasing the value of a failed Result.
                auto allocation = adopt_shared_if_nonnull(new (std::nothrow) Controller(device));
                if (allocation.has_error())
                {
                    kwarnln("ata: Failed to allocate memory for the ATA controller");
                    return;
                }

                auto controller = allocation.release_value();
                kinfoln("ata: Found ATA controller on PCI bus (%x:%x:%x)", device.address.bus,
                        device.address.function, device.address.slot);

                // Only keep the controller if at least one of its channels came up.
                if (controller->initialize()) g_controller = controller;
            },
            { .klass = 1, .subclass = 1 });

        if (!g_controller) kwarnln("ata: No ATA controller found.");

        return {};
    }

    bool Controller::initialize()
    {
        u16 command_old = PCI::read16(m_device.address, PCI::Command);
        u16 command_new = command_old;

        command_new &= ~PCI::CMD_INTERRUPT_DISABLE;
        command_new |= PCI::CMD_IO_SPACE;
        command_new |= PCI::CMD_BUS_MASTER;

        if (command_new != command_old) PCI::write16(m_device.address, PCI::Command, command_new);

        bool success = false;

        if (m_primary_channel.initialize()) success = true;
        if (m_secondary_channel.initialize()) success = true;

        return success;
    }

    // Dispatches an interrupt to every channel whose IRQ line equals the one that fired.
    // The primary channel is checked first, then the secondary one.
    void Controller::irq_handler(Registers* regs)
    {
        Channel* const channels[] = { &m_primary_channel, &m_secondary_channel };

        for (Channel* channel : channels)
        {
            if (regs->irq == channel->irq_line()) channel->irq_handler(regs);
        }
    }

    // Remembers the PCI device and constructs the two channels: index 0 (primary) and index 1 (secondary).
    // No hardware access happens here; that is left to initialize().
    Controller::Controller(const PCI::Device& device)
        : m_device(device),
          m_primary_channel(this, 0, {}),
          m_secondary_channel(this, 1, {})
    {
    }

    // Stores the owning controller and this channel's index. The Badge parameter restricts
    // construction to Controller.
    Channel::Channel(Controller* controller, u8 channel_index, Badge<Controller>)
        : m_controller(controller),
          m_channel_index(channel_index)
    {
    }

    // Reads one byte from the port at (I/O base + register offset).
    u8 Channel::read_register(Register reg)
    {
        const auto port_offset = static_cast<u16>(reg);
        return IO::inb(m_io_base + port_offset);
    }

    // Reads one 16-bit word from the data register (I/O base + Register::Data).
    u16 Channel::read_data()
    {
        const auto data_offset = static_cast<u16>(Register::Data);
        return IO::inw(m_io_base + data_offset);
    }

    // Writes one 16-bit word to the data register (I/O base + Register::Data).
    void Channel::write_data(u16 value)
    {
        const auto data_offset = static_cast<u16>(Register::Data);
        IO::outw(m_io_base + data_offset, value);
    }

    // Writes one byte to the port at (I/O base + register offset).
    void Channel::write_register(Register reg, u8 value)
    {
        const auto port_offset = static_cast<u16>(reg);
        IO::outb(m_io_base + port_offset, value);
    }

    // Reads one byte from the port at (control base + control register offset).
    u8 Channel::read_control(ControlRegister reg)
    {
        const auto port_offset = static_cast<u16>(reg);
        return IO::inb(m_control_base + port_offset);
    }

    // Writes one byte to the port at (control base + control register offset).
    void Channel::write_control(ControlRegister reg, u8 value)
    {
        const auto port_offset = static_cast<u16>(reg);
        IO::outb(m_control_base + port_offset, value);
    }

    // Reads one byte from the port at (busmaster base + busmaster register offset).
    u8 Channel::read_bm(BusmasterRegister reg)
    {
        const auto port_offset = static_cast<u16>(reg);
        return IO::inb(m_busmaster_base + port_offset);
    }

    // Writes one byte to the port at (busmaster base + busmaster register offset).
    void Channel::write_bm(BusmasterRegister reg, u8 value)
    {
        const auto port_offset = static_cast<u16>(reg);
        IO::outb(m_busmaster_base + port_offset, value);
    }

    // Reads the 32-bit value of the busmaster PRDTAddress register.
    u32 Channel::read_prdt_address()
    {
        const auto prdt_offset = static_cast<u16>(BusmasterRegister::PRDTAddress);
        return IO::inl(m_busmaster_base + prdt_offset);
    }

    // Writes a 32-bit value to the busmaster PRDTAddress register.
    void Channel::write_prdt_address(u32 value)
    {
        const auto prdt_offset = static_cast<u16>(BusmasterRegister::PRDTAddress);
        IO::outl(m_busmaster_base + prdt_offset, value);
    }

    // Busy-waits by reading the alternate status register 14 times and discarding the result.
    // The reads go through a volatile local so that they are not optimized away.
    void Channel::delay_400ns()
    {
        // FIXME: We should use kernel_sleep(), but it doesn't support nanosecond granularity.
        int reads_left = 14;
        while (reads_left-- > 0)
        {
            [[maybe_unused]] volatile u8 discarded = read_control(ControlRegister::AltStatus);
        }
    }

    // Makes 'drive' the selected drive on this channel. Does nothing if it is already the current one;
    // otherwise writes the DriveSelect register, waits via delay_400ns() and records the new selection.
    void Channel::select(u8 drive)
    {
        if (m_current_drive != drive)
        {
            // The drive number goes into bit 4; 0xa0 is OR-ed in unconditionally.
            const u8 selector = (u8)(drive << 4) | 0xa0;
            write_register(Register::DriveSelect, selector);

            delay_400ns();

            m_current_drive = drive;
        }
    }

    // Handles an interrupt on this channel's IRQ line. Interrupts for which the busmaster status
    // register does not report BMS_IRQPending are ignored. Otherwise the currently selected drive
    // (if any) gets to handle it, the IRQ flag is raised and the waiting thread (if any) is woken.
    void Channel::irq_handler(Registers*)
    {
        const u8 bm_status = read_bm(BusmasterRegister::Status);
        if ((bm_status & BMS_IRQPending) == 0) return;

        const bool has_selected_drive = m_current_drive < 2 && m_drives[m_current_drive].has_value();
        if (has_selected_drive) { m_drives[m_current_drive]->irq_handler(); }

        m_irq_called = true;

        if (m_thread) { m_thread->wake_up(); }
    }

    // Arms the channel for a later wait: records the current thread as the one irq_handler() should
    // wake up and clears the "IRQ seen" flag.
    void Channel::prepare_for_irq()
    {
        // The thread is recorded before the flag is reset.
        // NOTE(review): presumably this must run before the command that triggers the IRQ is issued - confirm.
        m_thread = Scheduler::current();
        m_irq_called = false;
    }

    // Waits until irq_handler() has flagged an interrupt, then consumes the flag.
    // If the flag is already set, no waiting happens.
    void Channel::wait_for_irq()
    {
        // NOTE(review): kernel_wait_for_event() is called at most once; whether it can return
        // without the IRQ having fired is not visible here.
        if (!m_irq_called) kernel_wait_for_event();

        // Reset the flag for the next transfer.
        m_irq_called = false;
    }

    // Like wait_for_irq(), but sleeps for at most 'timeout' (passed to kernel_sleep()) instead of
    // waiting indefinitely. The IRQ flag is consumed in every case.
    //
    // Returns true if the IRQ had already been flagged; otherwise returns whether the waiting
    // thread's sleep_ticks_left is non-zero after the sleep (presumably: woken up early).
    bool Channel::wait_for_irq_or_timeout(u64 timeout)
    {
        if (m_irq_called)
        {
            m_irq_called = false;
            return true;
        }

        kernel_sleep(timeout);
        m_irq_called = false;

        return m_thread->sleep_ticks_left;
    }

    // Polls 'reg' until at least one of the bits in 'value' is set, sleeping between polls.
    // Register::Status is polled through the alternate status register instead.
    //
    // Returns true as soon as a bit is set, or false once 'timeout' milliseconds
    // (measured with Timer::ticks_ms()) have elapsed.
    bool Channel::wait_for_reg_set(Register reg, u8 value, u64 timeout)
    {
        const u64 start = Timer::ticks_ms();

        for (;;)
        {
            u8 current;
            if (reg == Register::Status) { current = read_control(ControlRegister::AltStatus); }
            else { current = read_register(reg); }

            if ((current & value) != 0) return true;

            const u64 elapsed = Timer::ticks_ms() - start;
            if (elapsed >= timeout) return false;

            kernel_sleep(1);
        }
    }

    // Polls 'reg' until all of the bits in 'value' are clear, sleeping between polls.
    // Register::Status is polled through the alternate status register instead.
    //
    // Returns true as soon as the bits are clear, or false once 'timeout' milliseconds
    // (measured with Timer::ticks_ms()) have elapsed.
    bool Channel::wait_for_reg_clear(Register reg, u8 value, u64 timeout)
    {
        const u64 start = Timer::ticks_ms();

        for (;;)
        {
            u8 current;
            if (reg == Register::Status) { current = read_control(ControlRegister::AltStatus); }
            else { current = read_register(reg); }

            if (!(current & value)) return true;

            const u64 elapsed = Timer::ticks_ms() - start;
            if (elapsed >= timeout) return false;

            kernel_sleep(1);
        }
    }

    // Waits for the currently selected drive to become ready for a data transfer:
    // first for SR_Busy to clear, then for SR_DataRequestReady or SR_Error to be set
    // (each with a 1000 ms timeout).
    //
    // Returns EIO on either timeout, or if the status register reports SR_Error afterwards.
    Result<void> Channel::wait_until_ready()
    {
        const bool no_longer_busy = wait_for_reg_clear(Register::Status, SR_Busy, 1000);
        if (!no_longer_busy)
        {
            kwarnln("ata: Drive %d:%d timed out (BSY)", m_channel_index, m_current_drive);
            return err(EIO);
        }

        const bool data_or_error = wait_for_reg_set(Register::Status, SR_DataRequestReady | SR_Error, 1000);
        if (!data_or_error)
        {
            kwarnln("ata: Drive %d:%d timed out (DRQ)", m_channel_index, m_current_drive);
            return err(EIO);
        }

        // The previous wait also ends on SR_Error, so tell the two outcomes apart here.
        const u8 final_status = read_control(ControlRegister::AltStatus);
        if ((final_status & SR_Error) != 0)
        {
            kwarnln("ata: An error occurred in drive %d:%d while waiting for data to become available", m_channel_index,
                    m_current_drive);
            return err(EIO);
        }

        return {};
    }

    // Brings up this channel: determines its port bases and interrupt line, probes both drive
    // slots, registers the interrupt handler and finally registers every detected drive as a device.
    //
    // Returns false if a required BAR is not in I/O space, or if initialize()/post_initialize()
    // fails for any detected drive (the remaining drives are then not processed).
    // Returns true otherwise, including when no drive is present on the channel.
    bool Channel::initialize()
    {
        // Bit 0 (primary channel) or bit 2 (secondary channel) of the PCI programming interface
        // byte decides whether the channel is treated as being in PCI native mode.
        int offset = m_channel_index ? 2 : 0;
        m_is_pci_native_mode = m_controller->device().type.prog_if & (1 << offset);

        u16 control_port_base_address;
        u16 io_base_address;

        if (m_is_pci_native_mode)
        {
            // Native mode: the I/O base comes from BAR0 (primary) or BAR2 (secondary)...
            auto io_base = m_controller->device().getBAR(m_channel_index ? 2 : 0);
            if (!io_base.is_iospace())
            {
                kwarnln("ata: Channel %d's IO base BAR is not in IO space", m_channel_index);
                return false;
            }
            io_base_address = io_base.port();

            // ...and the control base from BAR1 (primary) or BAR3 (secondary).
            auto io_control = m_controller->device().getBAR(m_channel_index ? 3 : 1);
            if (!io_control.is_iospace())
            {
                kwarnln("ata: Channel %d's control base BAR is not in IO space", m_channel_index);
                return false;
            }

            // The control registers are addressed starting 2 ports past the BAR's port.
            control_port_base_address = io_control.port() + 2;
        }
        else
        {
            // Otherwise fixed port numbers are used for each channel.
            io_base_address = m_channel_index ? 0x170 : 0x1f0;
            control_port_base_address = m_channel_index ? 0x376 : 0x3f6;
        }

        m_io_base = io_base_address;
        m_control_base = control_port_base_address;

        // The busmaster registers come from BAR4; each channel gets its own 8-port window.
        auto io_busmaster = m_controller->device().getBAR(4);
        if (!io_busmaster.is_iospace())
        {
            kwarnln("ata: Channel %d's busmaster base BAR is not in IO space", m_channel_index);
            return false;
        }
        m_busmaster_base = io_busmaster.port() + (u16)(m_channel_index * 8u);

        // Native mode reads the interrupt line from PCI configuration space; otherwise
        // IRQ 14 (primary) or IRQ 15 (secondary) is used.
        if (m_is_pci_native_mode) m_interrupt_line = PCI::read8(m_controller->device().address, PCI::InterruptLine);
        else
            m_interrupt_line = m_channel_index ? 15 : 14;

        // NOTE(review): writing 0 to DeviceControl presumably enables drive interrupts - confirm.
        write_control(ControlRegister::DeviceControl, 0);

        // First pass: probe both drive slots and run the per-drive initialization.
        for (u8 drive = 0; drive < 2; drive++)
        {
            // The channel lock is taken anew for every iteration and released at its end.
            ScopedKMutexLock<100> lock(m_lock);

            select(drive);

            if (read_register(Register::Status) == 0)
            {
                // No drive on this slot.
                continue;
            }

            kinfoln("ata: Channel %d has a drive on slot %d!", m_channel_index, drive);

            m_drives[drive] = Drive { this, drive, {} };

            if (!m_drives[drive]->initialize())
            {
                // Drop the drive again and give up on the whole channel.
                m_drives[drive] = {};
                return false;
            }
        }

        // The handler receives the controller as its context, which dispatches back to the channels.
        CPU::register_interrupt(m_interrupt_line, ::irq_handler, m_controller);

        // Second pass: finish initialization of the detected drives and register them as devices.
        for (u8 drive = 0; drive < 2; drive++)
        {
            if (m_drives[drive].has_value())
            {
                if (!m_drives[drive]->post_initialize())
                {
                    m_drives[drive] = {};
                    return false;
                }

                auto rc = ATADevice::create(m_drives[drive].value_ptr());

                if (rc.has_error())
                {
                    // A registration failure is not fatal for the channel; move on to the next drive.
                    kwarnln("ata: Failed to register ATA drive %d:%d in DeviceRegistry", m_channel_index, drive);
                    continue;
                }

                // Hand the new device to MBR::identify(); its result is not checked here.
                auto device = rc.release_value();
                MBR::identify(device);
            }
        }

        return true;
    }

    // Stores the owning channel and the drive's slot index on it. The Badge parameter restricts
    // construction to Channel.
    Drive::Drive(Channel* channel, u8 drive_index, Badge<Channel>)
        : m_channel(channel),
          m_drive_index(drive_index)
    {
    }

    bool Drive::identify_ata()
    {
        m_channel->write_register(Register::Command, m_is_atapi ? CMD_Identify_Packet : CMD_Identify);

        m_channel->delay_400ns();

        if (!m_channel->wait_for_reg_clear(Register::Status, SR_Busy, 1000))
        {
            kwarnln("ata: Drive %d timed out clearing SR_Busy (waited for 1000 ms)", m_drive_index);
            return false;
        }

        if (m_channel->read_register(Register::Status) & SR_Error)
        {
            u8 lbam = m_channel->read_register(Register::LBAMiddle);
            u8 lbah = m_channel->read_register(Register::LBAHigh);

            if ((lbam == 0x14 && lbah == 0xeb) || (lbam == 0x69 && lbah == 0x96))
            {
                if (!m_is_atapi)
                {
                    kinfoln("ata: Drive %d is ATAPI, sending IDENTIFY_PACKET command", m_drive_index);
                    m_is_atapi = true;
                    return identify_ata();
                }
            }

            kwarnln("ata: IDENTIFY command for drive %d returned error", m_drive_index);

            return false;
        }

        if (!m_channel->wait_for_reg_set(Register::Status, SR_DataRequestReady | SR_Error, 1000))
        {
            kwarnln("ata: Drive %d timed out setting SR_DataRequestReady (waited for 1000 ms)", m_drive_index);
            return false;
        }

        u8 status = m_channel->read_register(Register::Status);
        if (status & SR_Error)
        {
            kwarnln("ata: IDENTIFY command for drive %d returned error", m_drive_index);
            return false;
        }

        for (usize i = 0; i < 256; i++)
        {
            u16 data = m_channel->read_data();
            m_identify_words[i] = data;
        }

        return true;
    }

    bool Drive::initialize()
    {
        m_channel->select(m_drive_index);

        m_channel->write_register(Register::SectorCount, 0);
        m_channel->write_register(Register::LBALow, 0);
        m_channel->write_register(Register::LBAMiddle, 0);
        m_channel->write_register(Register::LBAHigh, 0);

        if (!identify_ata()) return false;

        m_serial.set_length(copy_ata_string(m_serial.data(), &m_identify_words[10], SERIAL_LEN));
        m_revision.set_length(copy_ata_string(m_revision.data(), &m_identify_words[23], REVISION_LEN));
        m_model.set_length(copy_ata_string(m_model.data(), &m_identify_words[27], MODEL_LEN));

        m_serial.trim(" ");
        m_revision.trim(" ");
        m_model.trim(" ");

        kinfoln("ata: Drive IDENTIFY returned serial='%s', revision='%s' and model='%s'", m_serial.chars(),
                m_revision.chars(), m_model.chars());

        auto status = m_channel->read_bm(BusmasterRegister::Status);
        if (status & BMS_SimplexOnly)
        {
            kwarnln("ata: Drive %d will not use DMA because of simplex shenanigans", m_drive_index);
            m_uses_dma = false;
        }

        if (m_drive_index == 0 && !(status & BMS_MasterInit))
        {
            kwarnln("ata: Drive %d does not have DMA support", m_drive_index);
            m_uses_dma = false;
        }

        if (m_drive_index == 1 && !(status & BMS_SlaveInit))
        {
            kwarnln("ata: Drive %d does not have DMA support", m_drive_index);
            m_uses_dma = false;
        }

        auto frame = MemoryManager::alloc_frame();
        if (frame.has_error() || frame.value() > 0xffffffff)
        {
            kwarnln("ata: Failed to allocate memory below the 32-bit limit for the PRDT");
            return false;
        }
        m_dma_prdt_phys = frame.release_value();
        m_dma_prdt = (prdt_entry*)MMU::translate_physical_address(m_dma_prdt_phys);

        memset(m_dma_prdt, 0, ARCH_PAGE_SIZE);

        frame = MemoryManager::alloc_frame();
        if (frame.has_error() || frame.value() > 0xffffffff)
        {
            kwarnln("ata: Failed to allocate memory below the 32-bit limit for DMA memory");
            return false;
        }
        m_dma_mem_phys = frame.release_value();
        m_dma_mem = (void*)MMU::translate_physical_address(m_dma_mem_phys);

        memset(const_cast<void*>(m_dma_mem), 0, ARCH_PAGE_SIZE);

        if (m_uses_dma)
        {
            auto cmd = m_channel->read_bm(BusmasterRegister::Command);
            cmd &= ~BMC_StartStop;
            m_channel->write_bm(BusmasterRegister::Command, cmd);
        }

        return true;
    }

    bool Drive::post_initialize()
    {
        if (m_is_atapi)
        {
            atapi_packet packet;
            memset(&packet, 0, sizeof(packet));
            packet.command_bytes[0] = ATAPI_ReadCapacity;

            atapi_read_capacity_reply reply;

            if (send_packet_atapi_pio(&packet, &reply, sizeof(reply)).has_error())
            {
                kwarnln("ata: Failed to send Read Capacity command to ATAPI drive");
                return false;
            }

            m_is_lba48 = true;

            // FIXME: This assumes the host machine is little-endian.
            u32 last_lba = __builtin_bswap32(reply.last_lba);
            u32 sector_size = __builtin_bswap32(reply.sector_size);

            m_block_count = last_lba + 1;
            m_block_size = sector_size;
        }
        else
        {
            u8 buf[8];
            memcpy(buf, &m_identify_words[100], 8);

            m_block_count = *reinterpret_cast<u64*>(buf);

            if (!m_block_count)
            {
                memcpy(buf, &m_identify_words[60], 4);
                m_block_count = *reinterpret_cast<u32*>(buf);
            }
            else { m_is_lba48 = true; }

            // FIXME: Should we check for CHS?

            // FIXME: Maybe a different block size is in use? Detect that.
            m_block_size = 512;
        }

        u64 total_capacity;
        if (!safe_mul(m_block_count, m_block_size).try_set_value(total_capacity))
        {
            kwarnln("ata: Drive %d's total capacity is too large", m_drive_index);
            return false;
        }

        kinfoln("ata: Drive %d capacity information: Block Count=%lu, Block Size=%lu, Total Capacity=%lu",
                m_drive_index, m_block_count, m_block_size, total_capacity);

        return true;
    }

    // Sends a 12-byte ATAPI packet to the drive using PIO and reads the response into 'out'.
    //
    // packet:        the command packet; its six command words are written to the data register.
    // out:           destination buffer with room for at least 'response_size' bytes.
    // response_size: number of response bytes expected; also announced to the drive through the
    //                LBA middle/high registers.
    //
    // Returns EIO if the drive times out, reports an error, or announces an empty data transfer
    // while more data is still expected. Never writes more than 'response_size' bytes to 'out'.
    Result<void> Drive::send_packet_atapi_pio(const atapi_packet* packet, void* out, u16 response_size)
    {
        u8* ptr = (u8*)out;

        m_channel->select(m_drive_index);

        // We use PIO here.
        m_channel->write_register(Register::Features, 0x00);

        m_channel->write_register(Register::LBAMiddle, (u8)(response_size & 0xff));
        m_channel->write_register(Register::LBAHigh, (u8)(response_size >> 8));

        m_channel->write_register(Register::Command, CMD_Packet);

        m_channel->delay_400ns();

        TRY(m_channel->wait_until_ready());

        for (int j = 0; j < 6; j++) m_channel->write_data(packet->command_words[j]);

        usize transferred = 0; // Bytes the drive has handed over so far.
        usize stored = 0;      // Bytes actually written to 'out'.

        while (transferred < response_size)
        {
            TRY(m_channel->wait_until_ready());

            // The drive reports the size of this data block in the LBA high/middle registers.
            const usize count_high = m_channel->read_register(Register::LBAHigh);
            const usize count_low = m_channel->read_register(Register::LBAMiddle);
            const usize byte_count = (count_high << 8) | count_low;

            // Without this check an empty block would never advance 'transferred' and loop forever.
            if (byte_count == 0)
            {
                kwarnln("ata: Drive %d announced an empty ATAPI data transfer", m_drive_index);
                return err(EIO);
            }

            for (usize words_left = byte_count / 2; words_left > 0; words_left--)
            {
                // Always drain the data register, but never store past the caller's buffer even if
                // the drive announces more data than was requested.
                const u16 value = m_channel->read_data();
                if (stored < response_size) ptr[stored++] = (u8)(value & 0xff);
                if (stored < response_size) ptr[stored++] = (u8)(value >> 8);
            }

            transferred += byte_count;
        }

        return {};
    }

#if 0

    // DMA variant of send_packet_atapi_pio(). Currently compiled out by the surrounding '#if 0'.
    //
    // Sends the packet with the DMA feature bit set, lets do_dma_command()/do_dma_transfer() run the
    // transfer into the drive's DMA buffer and then copies 'response_size' bytes from it into 'out'.
    // Requires m_uses_dma to be set.
    Result<void> Drive::send_packet_atapi_dma(const atapi_packet* packet, void* out, u16 response_size)
    {
        check(m_uses_dma);

        m_channel->select(m_drive_index);

        kdbgln("have selected");

        // We use DMA here.
        m_channel->write_register(Register::Features, 0x01);

        m_channel->write_register(Register::LBAMiddle, 0);
        m_channel->write_register(Register::LBAHigh, 0);

        kdbgln("will do_dma_command");

        // Program the busmaster and issue the PACKET command (read direction).
        TRY(do_dma_command(CMD_Packet, response_size, false));

        TRY(m_channel->wait_until_ready());

        kdbgln("send atapi packet data");

        for (int j = 0; j < 6; j++) m_channel->write_data(packet->command_words[j]);

        kdbgln("do dma transfer");

        TRY(do_dma_transfer());

        // The data ends up in the DMA buffer; hand it over to the caller.
        memcpy(out, const_cast<void*>(m_dma_mem), response_size);

        return {};
    }

    // Programs the busmaster for a single-entry DMA transfer of 'count' bytes to/from the drive's
    // DMA buffer and issues 'command'. Currently compiled out by the surrounding '#if 0'.
    //
    // write: false sets BMC_ReadWrite in the busmaster command register, true clears it.
    //
    // Always returns success; completion is awaited separately in do_dma_transfer().
    Result<void> Drive::do_dma_command(u8 command, u16 count, bool write)
    {
        // A single PRDT entry describing the whole DMA buffer, marked as the last entry.
        m_dma_prdt->address = (u32)m_dma_mem_phys;
        m_dma_prdt->count = count;
        m_dma_prdt->flags = END_OF_PRDT;

        kdbgln("ata: do_dma_command: phys=%x, command=%x, count=%u, write=%d", m_dma_prdt->address, command, count,
               write);

        m_channel->write_prdt_address((u32)m_dma_prdt_phys);

        // Write the status back with the DMA failure and IRQ pending bits masked out.
        auto status = m_channel->read_bm(BusmasterRegister::Status);
        status &= ~(BMS_DMAFailure | BMS_IRQPending);
        m_channel->write_bm(BusmasterRegister::Status, status);

        // Select the transfer direction.
        auto cmd = m_channel->read_bm(BusmasterRegister::Command);
        if (!write) cmd |= BMC_ReadWrite;
        else
            cmd &= ~BMC_ReadWrite;
        m_channel->write_bm(BusmasterRegister::Command, cmd);

        // Arm the IRQ bookkeeping before the command that will raise the interrupt is issued.
        m_channel->prepare_for_irq();

        m_channel->write_register(Register::Command, command);

        // Set the start/stop bit to start the busmaster.
        cmd = m_channel->read_bm(BusmasterRegister::Command);
        cmd |= BMC_StartStop;
        m_channel->write_bm(BusmasterRegister::Command, cmd);

        m_channel->delay_400ns();

        return {};
    }

    // Waits for the DMA transfer started by do_dma_command() to complete and stops the busmaster.
    // Currently compiled out by the surrounding '#if 0'.
    //
    // Returns EIO if no interrupt arrives within the timeout passed to wait_for_irq_or_timeout(),
    // or if the busmaster status reports BMS_DMAFailure.
    Result<void> Drive::do_dma_transfer()
    {
        if (!m_channel->wait_for_irq_or_timeout(2000))
        {
            kwarnln("ata: Drive %d timed out (DMA)", m_drive_index);
            return err(EIO);
        }

        // The drive status is only logged here, not acted upon.
        u8 status = m_channel->read_control(ControlRegister::AltStatus);
        kdbgln("ata: status after irq: %#x", status);

        m_channel->delay_400ns();

        // Clear the start/stop bit to stop the busmaster.
        auto cmd = m_channel->read_bm(BusmasterRegister::Command);
        cmd &= ~BMC_StartStop;
        m_channel->write_bm(BusmasterRegister::Command, cmd);

        // Write the status back with the failure/IRQ bits masked out, but keep the original value
        // around for the failure check below.
        status = m_channel->read_bm(BusmasterRegister::Status);
        m_channel->write_bm(BusmasterRegister::Status, status & ~(BMS_DMAFailure | BMS_IRQPending));

        if (status & BMS_DMAFailure)
        {
            kwarnln("ata: DMA failure while trying to read drive %d", m_drive_index);
            return err(EIO);
        }

        return {};
    }

#endif

    // Reads 'size' bytes starting at block 'lba' from an ATAPI drive into 'out' using PIO.
    //
    // 'lba' must be below the drive's block count and 'size' must not exceed one page;
    // both are enforced with check().
    Result<void> Drive::atapi_read_pio(u64 lba, void* out, usize size)
    {
        check(lba < m_block_count);
        check(size <= ARCH_PAGE_SIZE);

        atapi_packet read_packet;
        memset(&read_packet, 0, sizeof(read_packet));

        read_packet.command_bytes[0] = ATAPI_Read;

        // Bytes 2..5 carry the low 32 bits of the LBA, most significant byte first.
        for (int byte = 0; byte < 4; byte++)
        {
            const int shift = 8 * (3 - byte);
            read_packet.command_bytes[2 + byte] = (lba >> shift) & 0xff;
        }

        // Byte 9 carries the number of blocks to transfer.
        read_packet.command_bytes[9] = (u8)(size / m_block_size);

        const u16 response_size = (u16)size;
        return send_packet_atapi_pio(&read_packet, out, response_size);
    }

    // Reads 'nblocks' blocks starting at 'lba' into 'out'.
    //
    // ATAPI drives are read in chunks of at most one page each; the last chunk carries whatever
    // remains. Reading from non-ATAPI drives is not implemented yet (todo()).
    Result<void> Drive::read_lba(u64 lba, void* out, usize nblocks)
    {
        const usize blocks_per_page = ARCH_PAGE_SIZE / m_block_size;
        if (m_is_atapi)
        {
            u64 next_lba = lba;
            void* cursor = out;
            usize remaining = nblocks;

            // Full-page reads while more than one page worth of blocks is left.
            for (; remaining > blocks_per_page; remaining -= blocks_per_page)
            {
                TRY(atapi_read_pio(next_lba, cursor, ARCH_PAGE_SIZE));
                next_lba += blocks_per_page;
                cursor = offset_ptr(cursor, ARCH_PAGE_SIZE);
            }

            return atapi_read_pio(next_lba, cursor, remaining * m_block_size);
        }
        else
            todo();
    }

    void Drive::irq_handler()
    {
        // Clear the IRQ flag.
        u8 status = m_channel->read_register(Register::Status);

        if (status & SR_Error)
        {
            u8 error = m_channel->read_register(Register::Error);
            (void)error;
        }

        if (m_uses_dma)
        {
            status = m_channel->read_bm(BusmasterRegister::Status);
            if (status & BMS_DMAFailure) { kwarnln("ata: DMA failure in irq"); }
            m_channel->write_bm(BusmasterRegister::Status, 4);
        }
    }
}

// Minor number passed to DeviceRegistry for the next ATADevice. It is post-incremented on every
// registration attempt in ATADevice::create(), whether or not the registration succeeds.
static u32 next_minor = 0;

// Builds a device name for 'drive': "cd<N>" for ATAPI drives and "sd<N>" for all others.
// Each kind has its own counter, which is incremented on every call.
Result<String> ATA::Drive::create_drive_name(ATA::Drive* drive)
{
    static u32 cd_index = 0;
    static u32 sd_index = 0;

    if (drive->m_is_atapi) { return String::format("%s%d"_sv, "cd", cd_index++); }

    return String::format("%s%d"_sv, "sd", sd_index++);
}

// Creates an ATADevice for 'drive', assigns it a device path and registers it with the
// DeviceRegistry as a disk (mode 0400) under the next free minor number.
//
// Propagates allocation, naming and registration errors to the caller.
Result<SharedPtr<Device>> ATADevice::create(ATA::Drive* drive)
{
    auto* raw_device = new (std::nothrow) ATADevice(drive);
    auto device = TRY(adopt_shared_if_nonnull(raw_device));

    device->m_device_path = TRY(ATA::Drive::create_drive_name(drive));

    // The minor number is consumed even if the registration below fails.
    const u32 minor = next_minor++;
    TRY(DeviceRegistry::register_special_device(DeviceRegistry::Disk, minor, device, 0400));

    return (SharedPtr<Device>)device;
}

// Wraps 'drive' as a BlockDevice, taking the block size and block count from the drive.
ATADevice::ATADevice(ATA::Drive* drive)
    : BlockDevice(drive->block_size(), drive->block_count()),
      m_drive(drive)
{
}

// Reads block number 'block' from the drive into 'buf' while holding the channel's lock.
// 'buf' must be exactly one block in size; a mismatch is logged and then reported through fail().
Result<void> ATADevice::read_block(Buffer& buf, u64 block) const
{
    ScopedKMutexLock<100> lock(m_drive->channel()->lock());

    const bool size_matches = buf.size() == m_drive->block_size();
    if (!size_matches)
    {
        kwarnln("ata: error while reading block %lu: cache entry size mismatch (%lu), data=%p", block, buf.size(),
                buf.data());
        fail("Cache entry size mismatch");
    }

    auto* destination = buf.data();
    return m_drive->read_lba(block, destination, 1);
}