#pragma once
#include "arch/PCI.h"
#include "fs/devices/DeviceRegistry.h"
#include "lib/KMutex.h"
#include <luna/Atomic.h>
#include <luna/SharedPtr.h>
#include <luna/StaticString.h>

namespace ATA
{
    // Command-block register offsets, presumably added to the channel's I/O base (m_io_base).
    // Several names share one offset: in the ATA register file a read-side and a write-side
    // register live at the same port (Error/Features, Status/Command), and CHS vs. LBA naming
    // refers to the same registers. Aliases are spelled in terms of the first name so the
    // sharing is explicit.
    enum class Register : u16
    {
        Data = 0,
        Error = 1,
        Features = Error,
        SectorCount = 2,
        SectorNumber = 3,
        LBALow = SectorNumber,
        CylinderLow = 4,
        LBAMiddle = CylinderLow,
        CylinderHigh = 5,
        LBAHigh = CylinderHigh,
        DriveSelect = 6,
        Status = 7,
        Command = Status,
    };

    // Control-block register offsets, presumably relative to the channel's control base
    // (m_control_base). AltStatus and DeviceControl share offset 0.
    enum class ControlRegister : u16
    {
        AltStatus = 0,
        DeviceControl = AltStatus,
        DriveAddress = 1,
    };

    // Bus-master IDE register offsets, presumably relative to the channel's bus-master base
    // (m_busmaster_base). PRDTAddress is accessed as a 32-bit value through
    // Channel::read_prdt_address / Channel::write_prdt_address.
    enum class BusmasterRegister : u16
    {
        Command = 0x0,
        Status = 0x2,
        PRDTAddress = 0x4,
    };

    // Bit masks for the ATA Status (and Alternate Status) register, listed from bit 0 upward.
    enum StatusRegister : u8
    {
        SR_Error = 1 << 0,
        SR_Index = 1 << 1,
        SR_CorrectedData = 1 << 2,
        SR_DataRequestReady = 1 << 3,
        SR_SeekComplete = 1 << 4,
        SR_WriteFault = 1 << 5,
        SR_DriveReady = 1 << 6,
        SR_Busy = 1 << 7,
    };

    // Opcodes written to Register::Command, ordered by value.
    enum CommandRegister : u8
    {
        CMD_Packet = 0xA0,          // PACKET (send an ATAPI command packet)
        CMD_Identify_Packet = 0xA1, // IDENTIFY PACKET DEVICE
        CMD_Identify = 0xEC,        // IDENTIFY DEVICE
    };

    // Bit masks for the bus-master Status register (BusmasterRegister::Status), from bit 0 upward.
    // Bits 3-4 have no name here.
    enum BusMasterStatus : u8
    {
        BMS_DMAMode = 1 << 0,
        BMS_DMAFailure = 1 << 1,
        BMS_IRQPending = 1 << 2,
        BMS_MasterInit = 1 << 5,
        BMS_SlaveInit = 1 << 6,
        BMS_SimplexOnly = 1 << 7,
    };

    // Bit masks for the bus-master Command register (BusmasterRegister::Command).
    enum BusMasterCommand : u8
    {
        BMC_StartStop = 1 << 0,
        BMC_ReadWrite = 1 << 3,
    };

    struct ATAIdentify
    {
        u16 flags;
        u16 unused1[9];
        char serial[20];
        u16 unused2[3];
        char firmware[8];
        char model[40];
        u16 sectors_per_int;
        u16 unused3;
        u16 capabilities[2];
        u16 unused4[2];
        u16 valid_ext_data;
        u16 unused5[5];
        u16 size_of_rw_mult;
        u32 sectors_28;
        u16 unused6[21];
        u16 unused7 : 10;
        u16 big_lba : 1;
        u16 unused8 : 5;
        u16 unused9[17];
        u64 sectors_48;
        u16 unused10[152];
    };

    // SCSI/ATAPI opcodes placed in the first byte of an atapi_packet.
    enum ATAPICommand : u8
    {
        ATAPI_ReadCapacity = 0x25, // READ CAPACITY(10)
        ATAPI_Read = 0xA8,         // READ(12)
    };

    class Controller;
    class Channel;

    // One entry of the bus-master Physical Region Descriptor Table (PRDT).
    // The fields total 8 bytes with no padding (4 + 2 + 2). The table is presumably read directly
    // by the bus-master hardware (its physical address goes through Channel::write_prdt_address),
    // so do not reorder or resize fields.
    struct prdt_entry
    {
        u32 address; // physical address of the memory region (32-bit, so it must lie below 4 GiB)
        u16 count;   // byte count of the region; NOTE(review): confirm how 0 is interpreted by the controller
        u16 flags;   // presumably END_OF_PRDT (bit 15) is set on the last entry of the table
    };

    // A 12-byte ATAPI command packet. It can be filled in byte-wise (command_bytes) and read back
    // as six 16-bit words (command_words), presumably so it can be pushed through the 16-bit data
    // register (Channel::write_data) during PIO.
    struct atapi_packet
    {
        union {
            u16 command_words[6];
            u8 command_bytes[12];
        };
    };

    // Reply buffer for ATAPI_ReadCapacity (8 bytes).
    // NOTE(review): READ CAPACITY(10) returns both fields big-endian; confirm the caller
    // byte-swaps them before use.
    struct atapi_read_capacity_reply
    {
        u32 last_lba;    // address of the last logical block (so block count is last_lba + 1)
        u32 sector_size; // logical block length in bytes
    };

    // Bit 15 of prdt_entry::flags; presumably marks the final descriptor in the PRDT.
    static constexpr u16 END_OF_PRDT { 0x8000 };

    // A single drive on an ATA Channel (a channel holds up to two of these). It covers both plain
    // ATA disks and ATAPI packet devices; m_is_atapi records which one this is.
    class Drive
    {
      public:
        // The Badge<Channel> parameter presumably restricts construction to Channel.
        Drive(Channel* channel, u8 drive_index, Badge<Channel>);

        bool initialize();

        bool post_initialize();

        void irq_handler();

        // Size of one block. Zero-initialized; presumably filled in during initialize().
        usize block_size() const
        {
            return m_block_size;
        }

        // Number of addressable blocks. Zero-initialized; presumably filled in during initialize().
        usize block_count() const
        {
            return m_block_count;
        }

        // Total capacity: block_count() * block_size().
        usize capacity() const
        {
            return m_block_count * m_block_size;
        }

        // Presumably reads `nblocks` blocks starting at `lba` into `out`.
        Result<void> read_lba(u64 lba, void* out, usize nblocks);

      private:
        bool identify_ata();

        Result<void> send_packet_atapi_pio(const atapi_packet* packet, void* out, u16 response_size);
#if 0
        Result<void> send_packet_atapi_dma(const atapi_packet* packet, void* out, u16 response_size);

        Result<void> do_dma_command(u8 command, u16 count, bool write);
        Result<void> do_dma_transfer();
#endif

        Result<void> atapi_read_pio(u64 lba, void* out, usize size);

        Channel* m_channel { nullptr };

        u8 m_drive_index { 0 };

        // Raw IDENTIFY response (256 words) and the same storage viewed through ATAIdentify.
        union {
            u16 m_identify_words[256];
            ATAIdentify m_identify_data;
        };

        bool m_is_atapi { false };
        bool m_uses_dma { true };

        // Geometry is zero-initialized so block_size()/block_count()/capacity() never read
        // indeterminate values if called before the drive has been identified.
        bool m_is_lba48 { false };
        u64 m_block_count { 0 };
        u64 m_block_size { 0 };

        // DMA bookkeeping: virtual/physical addresses of the PRDT and of the DMA memory region.
        prdt_entry* m_dma_prdt { nullptr };
        u64 m_dma_prdt_phys { 0 };
        volatile void* m_dma_mem { nullptr };
        u64 m_dma_mem_phys { 0 };

        constexpr static usize SERIAL_LEN = 20;
        constexpr static usize REVISION_LEN = 8;
        constexpr static usize MODEL_LEN = 40;

        StaticString<SERIAL_LEN> m_serial;
        StaticString<REVISION_LEN> m_revision;
        StaticString<MODEL_LEN> m_model;
    };

    // One ATA channel of a Controller (the controller owns a primary and a secondary one).
    // It holds up to two drives in m_drives.
    class Channel
    {
      public:
        // The Badge<Controller> parameter presumably restricts construction to Controller.
        Channel(Controller* controller, u8 channel_index, Badge<Controller>);

        // Command-block, control-block and bus-master register access.
        u8 read_register(Register reg);
        u16 read_data();
        void write_data(u16 value);
        void write_register(Register reg, u8 value);
        u8 read_control(ControlRegister reg);
        void write_control(ControlRegister reg, u8 value);

        u8 read_bm(BusmasterRegister reg);
        void write_bm(BusmasterRegister reg, u8 value);
        u32 read_prdt_address();
        void write_prdt_address(u32 value);

        // Wait until the bits in `value` are set / cleared in `reg`, giving up after `timeout`.
        // NOTE(review): timeout units and return semantics are defined by the implementation;
        // presumably false means the timeout expired.
        bool wait_for_reg_set(Register reg, u8 value, u64 timeout);
        bool wait_for_reg_clear(Register reg, u8 value, u64 timeout);

        Result<void> wait_until_ready();

        void delay_400ns();

        void prepare_for_irq();

        void wait_for_irq();
        bool wait_for_irq_or_timeout(u64 timeout);
        void irq_handler(Registers*);

        void select(u8 drive);

        bool initialize();

      private:
        Controller* m_controller { nullptr };
        u8 m_channel_index { 0 };
        bool m_is_pci_native_mode { false };

        u8 m_interrupt_line { 0 };

        KMutex<100> m_lock {};

        Thread* m_thread { nullptr };

        // Port bases; zero-initialized until presumably configured by initialize().
        u16 m_io_base { 0 };
        u16 m_control_base { 0 };
        u16 m_busmaster_base { 0 };

        // NOTE(review): plain bool, presumably set from irq_handler() and consumed by
        // wait_for_irq*(). If those can run concurrently, consider an atomic
        // (<luna/Atomic.h> is already included). TODO confirm.
        bool m_irq_called { false };

        // All-ones presumably means "no drive selected yet".
        u8 m_current_drive = static_cast<u8>(-1);

        SharedPtr<Drive> m_drives[2];
    };

    // A PCI IDE controller together with its primary and secondary channels.
    class Controller
    {
      public:
        static Result<void> scan();

        const PCI::Device& device() const
        {
            return m_device;
        }

        bool initialize();

      private:
        // Private: only Controller's own members (e.g. scan()) can create one.
        // explicit: a PCI::Device should never silently convert into a Controller.
        explicit Controller(const PCI::Device& device);

        PCI::Device m_device;

        // NOTE(review): Channel takes a Controller* at construction and keeps it in m_controller.
        // The implicitly generated copy/move operations copy those pointers verbatim. A copied or
        // moved Controller therefore has channels still pointing at the source object. Construct
        // the Controller in its final storage, or confirm it is never copied/moved afterwards.
        Channel m_primary_channel;
        Channel m_secondary_channel;
    };

}

// Read-only block device exposing one ATA::Drive through the generic Device interface.
class ATADevice : public Device
{
  public:
    // Initializer for DeviceRegistry.
    static Result<void> create(SharedPtr<ATA::Drive> drive);

    virtual ~ATADevice() = default;

    Result<usize> read(u8*, usize, usize) const override;

    // This device is read-only: every write fails with ENOTSUP.
    Result<usize> write(const u8*, usize, usize) override
    {
        return err(ENOTSUP);
    }

    bool blocking() const override { return false; }

    bool is_block_device() const override { return true; }

    // Size reported to the VFS: the drive's capacity (block_count * block_size).
    usize size() const override { return m_drive->capacity(); }

  private:
    // Only create() constructs instances.
    ATADevice() = default;

    SharedPtr<ATA::Drive> m_drive;
};