#define MODULE "pci"

#include "io/PCI.h"
#include "io/IO.h"
#include "log/Log.h"
#include "thread/Spinlock.h"

// I/O ports used for configuration-space access: a config address is written to
// PCI_ADDRESS, then the selected dword is accessed through PCI_VALUE..PCI_VALUE+3.
#define PCI_ADDRESS 0xCF8
#define PCI_VALUE 0xCFC

// Taken by PCI::scan while it walks the buses; pci_scan_bus drops it around each
// callback invocation. The raw_* accessors in this file do not take it themselves,
// so their address-write-then-data-access sequence is not serialized by this lock.
Spinlock pci_lock;

// Build the value written to PCI_ADDRESS to select one configuration dword.
// Layout: bit 31 set, bus << 16, slot << 11, function << 8, and the register
// offset rounded down to a multiple of 4 (bits 1:0 forced to zero).
// Inputs are not range-checked; callers must keep slot < 32 and function < 8,
// otherwise the fields overlap.
uint32_t PCI::raw_address(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset)
{
    const uint32_t enable = 0x80000000u;
    const uint32_t reg = static_cast<uint32_t>(offset) & 0xFCu;

    uint32_t address = enable;
    address |= bus << 16;
    address |= slot << 11;
    address |= function << 8;
    address |= reg;
    return address;
}

// Write one byte of configuration space for (bus, slot, function) at `offset`.
// The address port selects the dword containing `offset`; the byte is then written
// through the matching byte lane of the data window (PCI_VALUE + (offset & 3)).
// A 32-bit write to PCI_VALUE would instead zero the three neighbouring bytes of
// that dword and place the value at the wrong byte when offset is not aligned.
void PCI::raw_write8(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset, uint8_t value)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    IO::outb(PCI_VALUE + (offset & 3), value);
}

// Write one 16-bit word of configuration space for (bus, slot, function) at `offset`.
// The address port selects the dword containing `offset`; the word is then written
// through the matching half of the data window (PCI_VALUE + (offset & 2)).
// A 32-bit write to PCI_VALUE would instead clobber the other half of the dword,
// e.g. a write aimed at offset 0x06 would land on the register at 0x04.
void PCI::raw_write16(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset, uint16_t value)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    IO::outw(PCI_VALUE + (offset & 2), value);
}

void PCI::raw_write32(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset, uint32_t value)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    IO::outl(PCI_VALUE, value);
}

// Read one byte of configuration space for (bus, slot, function) at `offset`.
// The aligned dword is always read from the base data port and the requested byte
// is shifted out. A 32-bit read at PCI_VALUE + 1..3 would be an unaligned dword
// access that runs past the end of the data window, which is not a defined
// configuration access.
uint8_t PCI::raw_read8(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    const uint32_t dword = IO::inl(PCI_VALUE);
    return static_cast<uint8_t>(dword >> ((offset & 3) * 8));
}

// Read one 16-bit word of configuration space for (bus, slot, function) at `offset`.
// The aligned dword is always read from the base data port and the requested half
// is shifted out (offset bit 1 selects the half; bit 0 is ignored). A 32-bit read at
// PCI_VALUE + 2 would be an unaligned access running past the data window.
uint16_t PCI::raw_read16(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    const uint32_t dword = IO::inl(PCI_VALUE);
    return static_cast<uint16_t>(dword >> ((offset & 2) * 8));
}

uint32_t PCI::raw_read32(uint32_t bus, uint32_t slot, uint32_t function, int32_t offset)
{
    IO::outl(PCI_ADDRESS, raw_address(bus, slot, function, offset));
    return IO::inl(PCI_VALUE);
}

// Read the vendor and device IDs of (bus, slot, function).
// Elements of a braced initializer list are evaluated left to right, so the vendor
// field is read before the device field.
PCI::DeviceID PCI::get_device_id(uint32_t bus, uint32_t slot, uint32_t function)
{
    return {PCI::raw_read16(bus, slot, function, PCI_VENDOR_FIELD),
            PCI::raw_read16(bus, slot, function, PCI_DEVICE_FIELD)};
}

// Read the class code, subclass, programming interface and revision ID of
// (bus, slot, function) and return them as a DeviceType.
PCI::DeviceType PCI::get_device_type(uint32_t bus, uint32_t slot, uint32_t function)
{
    auto read_field = [bus, slot, function](int32_t field) -> uint8_t {
        return PCI::raw_read8(bus, slot, function, field);
    };

    // Registers are read in this order: subclass, class, prog IF, revision.
    const uint8_t subclass_code = read_field(PCI_SUBCLASS_FIELD);
    const uint8_t class_code = read_field(PCI_CLASS_FIELD);
    const uint8_t prog_interface = read_field(PCI_PROG_IF_FIELD);
    const uint8_t revision_id = read_field(PCI_REVISION_ID_FIELD);
    return {class_code, subclass_code, prog_interface, revision_id};
}

// Enumerate every slot/function on `bus` and invoke `callback` once for each function
// whose vendor/device ID is not 0xFFFF. For bridges (header type 1), the secondary bus
// is scanned before the bridge itself is reported.
// Expects pci_lock to be held on entry. The lock is released for the duration of each
// callback and re-acquired afterwards.
static void pci_scan_bus(uint8_t bus, void (*callback)(PCI::Device&))
{
    for (uint8_t slot = 0; slot < 32; slot++)
    {
        // Only function 0 is probed unless it reports itself as multi-function.
        uint8_t num_functions = 1;
        for (uint8_t function = 0; function < num_functions; function++)
        {
            PCI::DeviceID device_id = PCI::get_device_id(bus, slot, function);
            // An ID of 0xFFFF is treated as "nothing present at this address".
            if (device_id.vendor == 0xFFFF || device_id.device == 0xFFFF) continue;
            PCI::DeviceType device_type = PCI::get_device_type(bus, slot, function);
            uint8_t header = PCI::raw_read8(bus, slot, function, PCI_HEADER_TYPE_FIELD);
            if (function == 0 && (header & 0x80)) // multi function device
            {
                num_functions = 8;
            }
            if ((header & 0x7F) == 1) // PCI-to-PCI bridge
            {
                uint8_t sub_bus = PCI::raw_read8(bus, slot, function, PCI_SECONDARY_BUS_NUMBER_FIELD);
                // Only descend to a higher-numbered bus. A configured bridge's secondary bus
                // is numbered above its primary bus; an unconfigured bridge reports 0.
                // Following a secondary bus <= `bus` would rescan this bus (or an ancestor)
                // without end and exhaust the stack. Because bus numbers strictly increase
                // along the recursion, the depth is bounded.
                if (sub_bus > bus) pci_scan_bus(sub_bus, callback);
            }
            PCI::Device device{device_id, device_type, bus, slot, function};
            pci_lock.release();
            callback(device);
            pci_lock.acquire();
        }
    }
}

// Enumerate PCI functions starting at bus 0; pci_scan_bus follows bridges to the
// buses behind them. `callback` is invoked once per function found. pci_lock is
// held for the whole walk except while `callback` itself runs.
void PCI::scan(void (*callback)(PCI::Device&))
{
    const uint8_t root_bus = 0;

    pci_lock.acquire();
    pci_scan_bus(root_bus, callback);
    pci_lock.release();
}

// Member-wise copy of another Device's IDs, type and location; no configuration-space
// access is performed.
PCI::Device::Device(const Device& source)
    : m_id(source.m_id),
      m_type(source.m_type),
      m_bus(source.m_bus),
      m_slot(source.m_slot),
      m_function(source.m_function)
{
}

// Describe the function at (bus, slot, function), reading its vendor/device IDs and
// class information from configuration space.
PCI::Device::Device(uint8_t bus, uint8_t slot, uint8_t function)
    : m_id(get_device_id(bus, slot, function)),
      m_type(get_device_type(bus, slot, function)),
      m_bus(bus),
      m_slot(slot),
      m_function(function)
{
}

// Build a Device from identification data that was already read (as pci_scan_bus
// does); no configuration-space access is performed.
PCI::Device::Device(DeviceID id, DeviceType type, uint8_t bus, uint8_t slot, uint8_t function)
    : m_id(id),
      m_type(type),
      m_bus(bus),
      m_slot(slot),
      m_function(function)
{
}

void PCI::Device::write8(int32_t offset, uint8_t value)
{
    PCI::raw_write8(m_bus, m_slot, m_function, offset, value);
}

void PCI::Device::write16(int32_t offset, uint16_t value)
{
    PCI::raw_write16(m_bus, m_slot, m_function, offset, value);
}

void PCI::Device::write32(int32_t offset, uint32_t value)
{
    PCI::raw_write32(m_bus, m_slot, m_function, offset, value);
}

uint8_t PCI::Device::read8(int32_t offset)
{
    return PCI::raw_read8(m_bus, m_slot, m_function, offset);
}

uint16_t PCI::Device::read16(int32_t offset)
{
    return PCI::raw_read16(m_bus, m_slot, m_function, offset);
}

uint32_t PCI::Device::read32(int32_t offset)
{
    return PCI::raw_read32(m_bus, m_slot, m_function, offset);
}

uint32_t PCI::Device::getBAR0()
{
    return read32(PCI_BAR0_FIELD);
}

uint32_t PCI::Device::getBAR1()
{
    return read32(PCI_BAR1_FIELD);
}

uint32_t PCI::Device::getBAR2()
{
    return read32(PCI_BAR2_FIELD);
}

uint32_t PCI::Device::getBAR3()
{
    return read32(PCI_BAR3_FIELD);
}

uint32_t PCI::Device::getBAR4()
{
    return read32(PCI_BAR4_FIELD);
}

uint32_t PCI::Device::getBAR5()
{
    return read32(PCI_BAR5_FIELD);
}