#include "arch/PCI.h"
#include "Log.h"

// State carried unchanged through the recursive bus -> slot -> function scan.
// Built positionally as `{ callback, match }` in PCI::scan, so member order matters.
struct ScanInfo
{
    PCI::Callback callback; // Invoked with { id, type, address } for every function whose type satisfies `match`.
    PCI::Match match;       // Class/subclass/prog-IF filter; a field equal to -1 matches any value.
};

namespace PCI
{
    BAR::BAR(u32 raw) : m_raw(raw)
    {
    }

    // Reads the vendor/device identifier pair from the configuration header of `address`.
    Device::ID read_id(const Device::Address& address)
    {
        // Elements of a braced initializer are evaluated left to right, so VendorID is still read first.
        return Device::ID { read16(address, Field::VendorID), read16(address, Field::DeviceID) };
    }

    // Reads the class code, subclass and programming interface bytes of the function at `address`.
    Device::Type read_type(const Device::Address& address)
    {
        // Braced-initializer elements are evaluated in order: Class, then Subclass, then ProgIF.
        return Device::Type { read8(address, Field::Class), read8(address, Field::Subclass),
                              read8(address, Field::ProgIF) };
    }

    // Class code 0x06 ("bridge") combined with subclass 0x04 identifies a PCI-to-PCI bridge.
    constexpr u8 PCI_BRIDGE_CLASS = 0x06;
    constexpr u8 PCI_TO_PCI_BRIDGE_SUBCLASS = 0x04;

    // Returns true when the function at `address` reports itself as a PCI-to-PCI bridge.
    // The Subclass byte is only read if the Class byte already matched.
    static bool is_pci_to_pci_bridge(const Device::Address& address)
    {
        if (read8(address, Field::Class) != PCI_BRIDGE_CLASS) return false;

        return read8(address, Field::Subclass) == PCI_TO_PCI_BRIDGE_SUBCLASS;
    }

    // Checks `type` against the filter in `match`. Each filter field set to -1 is a wildcard;
    // every other field must equal the corresponding byte of `type`.
    static bool matches(const Device::Type& type, const Match& match)
    {
        const bool klass_ok = match.klass == -1 || type.klass == match.klass;
        const bool subclass_ok = match.subclass == -1 || type.subclass == match.subclass;
        const bool prog_if_ok = match.prog_if == -1 || type.prog_if == match.prog_if;

        return klass_ok && subclass_ok && prog_if_ok;
    }

    void scan_bus(u32, ScanInfo);

    /**
     * Reports one PCI function to the scan callback (if its type passes the filter) and, if the
     * function is a PCI-to-PCI bridge, recursively scans the bus behind it.
     *
     * @param address Bus/slot/function to inspect. The callers in this file check the VendorID
     *                before calling, so the function is expected to be present.
     * @param info    Callback and match filter shared by the whole scan.
     */
    void scan_function(const Device::Address& address, ScanInfo info)
    {
        const Device::ID id = read_id(address);
        const Device::Type type = read_type(address);

        if (matches(type, info.match)) { info.callback({ id, type, address }); }

        if (is_pci_to_pci_bridge(address))
        {
            const u8 secondary_bus = read8(address, Field::SecondaryBus);

            // In a properly enumerated hierarchy a bridge's secondary bus number is always greater
            // than the bus the bridge itself sits on. An unconfigured bridge (secondary bus 0) or a
            // bogus value would otherwise make us rescan this bus or an ancestor without end,
            // overflowing the kernel stack. Requiring a strictly increasing bus number rules out cycles.
            if (secondary_bus <= address.bus)
            {
                kdbgln("PCI-to-PCI bridge reports invalid secondary bus %u, skipping it",
                       static_cast<u32>(secondary_bus));
                return;
            }

            kdbgln("PCI-to-PCI bridge detected, performing another bus scan");
            scan_bus(secondary_bus, info);
        }
    }

    // Scans every present function of the device in `slot` on `bus`.
    // Function 0 is always probed; functions 1-7 are probed only when function 0's
    // header type has bit 7 (the multi-function flag) set.
    void scan_slot(u32 bus, u32 slot, ScanInfo info)
    {
        Device::Address address { bus, slot, 0 };

        // An all-ones VendorID on function 0 means the slot is empty.
        if (read16(address, Field::VendorID) == PCI::INVALID_ID) return;

        scan_function(address, info);

        // The header type is read from function 0, before the loop below changes `address`.
        const bool multi_function = (read8(address, Field::HeaderType) & 0x80) != 0;
        if (!multi_function) return;

        for (address.function = 1; address.function < 8; ++address.function)
        {
            if (read16(address, Field::VendorID) != PCI::INVALID_ID) { scan_function(address, info); }
        }
    }

    // Probes each of the 32 device slots on `bus`, in ascending order.
    void scan_bus(u32 bus, ScanInfo info)
    {
        u32 slot = 0;
        while (slot < 32)
        {
            scan_slot(bus, slot, info);
            ++slot;
        }
    }

    /**
     * Entry point for PCI enumeration: invokes `callback` for every function whose type satisfies `match`.
     *
     * If 00:00.0 is not multi-function, only bus 0 is scanned. Otherwise each present
     * function N of 00:00 causes bus N to be scanned.
     */
    void scan(Callback callback, Match match)
    {
        const u8 header_type = read8({ 0, 0, 0 }, Field::HeaderType);
        const ScanInfo info { callback, match };

        if (!(header_type & 0x80))
        {
#ifdef PCI_DEBUG
            kdbgln("PCI bus is single-function");
#endif
            scan_bus(0, info);
            return;
        }

#ifdef PCI_DEBUG
        kdbgln("PCI bus is multiple-function");
#endif
        for (u32 function = 0; function < 8; ++function)
        {
            if (read16({ 0, 0, function }, Field::VendorID) == PCI::INVALID_ID) continue;

            kdbgln("PCI bus has function %u", function);
            scan_bus(function, info);
        }
    }

    // Returns Base Address Register `index` (0-5) of this device, wrapped as a BAR.
    // NOTE(review): type-1 (PCI-to-PCI bridge) headers only define BAR0/BAR1, and their offsets
    // 0x18+ hold bus-number registers. This check assumes a type-0 header — confirm callers
    // only use it on endpoints.
    BAR Device::getBAR(u8 index) const
    {
        check(index < 6);

        // BAR0 is at configuration offset 0x10; each following BAR is one dword further.
        const u32 bar_offset = 0x10 + index * 4;
        return BAR { read32(address, bar_offset) };
    }
}