/**
 * @file Format.cpp
 * @author apio (cloudapio.eu)
 * @brief C-style string formatting.
 *
 * @copyright Copyright (c) 2022-2023, the Luna authors.
 *
 */

#include <luna/CType.h>
#include <luna/Format.h>
#include <luna/NumberParsing.h>
#include <luna/Result.h>
#include <stdarg.h>
#include <stddef.h>

// strlen is declared here directly, with C linkage, rather than through a header.
// It is used by cstyle_format() to measure %s arguments.
extern "C" usize strlen(const char*);

// Bit set describing a single conversion: the printf flags that were seen, whether a
// precision was supplied, and which length modifier applies.
typedef int flags_t;
// '0' flag.
#define FLAG_ZERO_PAD (1 << 0)
// '-' flag; also set by parse_width() when a '*' width argument is negative.
#define FLAG_LEFT_ALIGN (1 << 1)
// ' ' flag.
#define FLAG_BLANK_SIGNED (1 << 2)
// '#' flag.
#define FLAG_ALTERNATE (1 << 3)
// '+' flag.
#define FLAG_SIGN (1 << 4)
// Set by parse_precision() when a '.' was present (and cleared again for a negative '*' precision).
#define FLAG_USE_PRECISION (1 << 5)
// 'l' length modifier (also used for 't'/'z' when those types have the size of long).
#define FLAG_LONG (1 << 6)
// 'll' length modifier (set in addition to FLAG_LONG).
#define FLAG_LONG_LONG (1 << 7)
// 'h' length modifier.
#define FLAG_SHORT (1 << 8)
// 'hh' length modifier (set in addition to FLAG_SHORT).
#define FLAG_CHAR (1 << 9)

// Output-side state for one cstyle_format() call, shared by all conversions.
struct format_state
{
    usize count;             // incremented by format_putchar() for every character it is asked to emit
    FormatCallback callback; // sink invoked as callback(c, arg) for each output character
    void* arg;               // opaque pointer passed unchanged to every callback invocation
};

// Parsed description of a single conversion specification.
// NOTE: cstyle_format() initialises this positionally ({ flags, width, precision }),
// so the field order must not change.
struct conv_state
{
    flags_t flags;   // combination of the FLAG_* bits
    usize width;     // minimum field width; 0 when none was given
    usize precision; // precision value; only meaningful when FLAG_USE_PRECISION is set
};

/**
 * @brief Emits one character through the output callback.
 *
 * The running character count is bumped before the callback runs, so it is
 * incremented whether or not the callback reports an error.
 *
 * @param c The character to emit.
 * @param state Output state holding the callback, its argument and the counter.
 * @return Whatever the callback returns.
 */
static Result<void> format_putchar(char c, format_state& state)
{
    state.count += 1;

    return state.callback(c, state.arg);
}

/**
 * @brief Emits every character of a NUL-terminated string (the terminator itself is not emitted).
 *
 * @param s The string to emit.
 * @param state Output state to write through.
 * @return An error as soon as one character fails to be emitted, otherwise success.
 */
static Result<void> format_puts(const char* s, format_state& state)
{
    for (const char* cursor = s; *cursor != '\0'; cursor++) { TRY(format_putchar(*cursor, state)); }

    return {};
}

/**
 * @brief Emits the leading space padding of a right-aligned field.
 *
 * Does nothing when the conversion is left-aligned; otherwise emits one space for
 * every column between the content length and the field width.
 *
 * @param vstate The conversion being formatted (flags and width are read).
 * @param state Output state to write through.
 * @param start Number of columns the content itself will occupy.
 */
static Result<void> start_pad(const conv_state& vstate, format_state& state, usize start)
{
    const bool left_aligned = (vstate.flags & FLAG_LEFT_ALIGN) != 0;
    if (left_aligned) return {};

    for (usize column = start; column < vstate.width; column++) { TRY(format_putchar(' ', state)); }

    return {};
}

/**
 * @brief Emits the trailing space padding of a left-aligned field.
 *
 * Does nothing unless the conversion is left-aligned; otherwise emits one space for
 * every column between the content length and the field width.
 *
 * @param vstate The conversion being formatted (flags and width are read).
 * @param state Output state to write through.
 * @param start Number of columns the content itself occupied.
 */
static Result<void> end_pad(const conv_state& vstate, format_state& state, usize start)
{
    const bool left_aligned = (vstate.flags & FLAG_LEFT_ALIGN) != 0;
    if (!left_aligned) return {};

    for (usize column = start; column < vstate.width; column++) { TRY(format_putchar(' ', state)); }

    return {};
}

/**
 * @brief Consumes the run of flag characters ('#', '0', ' ', '-', '+') at *format.
 *
 * @param format In/out cursor into the format string; left on the first character
 *               that is not a flag.
 * @return The FLAG_* bits corresponding to the flags that were consumed.
 */
static flags_t parse_flags(const char** format)
{
    flags_t collected = 0;

    for (;; (*format)++)
    {
        flags_t bit;

        switch (**format)
        {
        case '#': bit = FLAG_ALTERNATE; break;
        case '0': bit = FLAG_ZERO_PAD; break;
        case ' ': bit = FLAG_BLANK_SIGNED; break;
        case '-': bit = FLAG_LEFT_ALIGN; break;
        case '+': bit = FLAG_SIGN; break;
        default: return collected; // not a flag character: leave it for the next parsing stage
        }

        collected |= bit;
    }
}

/**
 * @brief Parses the optional field width of a conversion.
 *
 * The width is either a decimal number in the format string or '*', in which case it
 * is read from the argument list as an int. A negative '*' width means "left-align
 * with the absolute value as the width".
 *
 * @param format In/out cursor into the format string.
 * @param flags Conversion flags; FLAG_LEFT_ALIGN is added for a negative '*' width.
 * @param ap Argument list a '*' width is read from.
 * @return The field width, or 0 if none was given.
 */
static usize parse_width(const char** format, flags_t& flags, va_list ap)
{
    // scan_unsigned_integer() is presumably responsible for advancing *format past the digits.
    if (_isdigit(**format)) return scan_unsigned_integer(format);

    if (**format != '*') return 0;

    (*format)++;

    const int width = va_arg(ap, int);
    if (width >= 0) return (usize)width;

    flags |= FLAG_LEFT_ALIGN;

    // Compute the magnitude in unsigned arithmetic: negating the int directly
    // would be undefined behaviour for INT_MIN.
    return (usize)0 - (usize)width;
}

/**
 * @brief Parses the optional ".precision" part of a conversion.
 *
 * A lone '.' counts as an explicit precision of 0. A '*' precision is read from the
 * argument list as an int; a negative value is treated as if no precision was given.
 *
 * @param format In/out cursor into the format string.
 * @param flags Conversion flags; FLAG_USE_PRECISION is set when a precision applies.
 * @param ap Argument list a '*' precision is read from.
 * @return The precision, or 0 if none applies.
 */
static usize parse_precision(const char** format, flags_t& flags, va_list ap)
{
    if (**format != '.') return 0;

    (*format)++;
    flags |= FLAG_USE_PRECISION;

    // scan_unsigned_integer() is presumably responsible for advancing *format past the digits.
    if (_isdigit(**format)) return scan_unsigned_integer(format);

    if (**format != '*') return 0;

    const int from_args = va_arg(ap, int);
    (*format)++;

    if (from_args < 0)
    {
        // Negative precision: behave as though the precision had been omitted.
        flags &= ~FLAG_USE_PRECISION;
        return 0;
    }

    return (usize)from_args;
}

/**
 * @brief Parses the optional length modifier (h, hh, l, ll, t, z) of a conversion.
 *
 * 'hh' sets both FLAG_SHORT and FLAG_CHAR, and 'll' sets both FLAG_LONG and
 * FLAG_LONG_LONG. 't' and 'z' are mapped onto FLAG_LONG or FLAG_LONG_LONG depending on
 * whether ptrdiff_t / usize have the same size as long.
 *
 * @param format In/out cursor into the format string; advanced past the modifier, if any.
 * @param flags Conversion flags the modifier bits are added to.
 */
static void parse_type(const char** format, flags_t& flags)
{
    // FIXME: Support %j (intmax_t/uintmax_t)
    const char modifier = **format;

    if (modifier == 'h')
    {
        (*format)++;
        flags |= FLAG_SHORT;

        if (**format != 'h') return;

        (*format)++;
        flags |= FLAG_CHAR;
    }
    else if (modifier == 'l')
    {
        (*format)++;
        flags |= FLAG_LONG;

        if (**format != 'l') return;

        (*format)++;
        flags |= FLAG_LONG_LONG;
    }
    else if (modifier == 't')
    {
        (*format)++;
        flags |= (sizeof(ptrdiff_t) == sizeof(long)) ? FLAG_LONG : FLAG_LONG_LONG;
    }
    else if (modifier == 'z')
    {
        (*format)++;
        flags |= (sizeof(usize) == sizeof(long)) ? FLAG_LONG : FLAG_LONG_LONG;
    }
}

/**
 * @brief Tells whether a conversion character is one of the integer conversions
 *        (d, i, u, x, X, o, b).
 *
 * @param c The conversion character.
 * @return true for an integer conversion, false for anything else.
 */
static bool is_integer_format_specifier(char c)
{
    switch (c)
    {
    case 'd':
    case 'i':
    case 'u':
    case 'x':
    case 'X':
    case 'o':
    case 'b': return true;
    default: return false;
    }
}

/**
 * @brief Converts an unsigned value to its digits in the given base.
 *
 * The digits are stored least-significant first (i.e. reversed) and the buffer is not
 * NUL-terminated. Conversion stops early once max characters have been stored.
 *
 * @param value The value to convert.
 * @param base The numeric base.
 * @param buf Destination buffer.
 * @param max Capacity of buf, in characters.
 * @param uppercase Whether digits above 9 use 'A'.. instead of 'a'..
 * @return The number of characters stored in buf.
 */
static usize to_string(usize value, usize base, char* buf, usize max, bool uppercase)
{
    if (max == 0) return 0;

    if (value == 0)
    {
        buf[0] = '0';
        return 1;
    }

    const char letter_zero = uppercase ? 'A' : 'a';
    usize written = 0;

    for (; value != 0 && written < max; value /= base)
    {
        const int digit = (int)(value % base);

        if (digit < 10) buf[written++] = (char)('0' + digit);
        else
            buf[written++] = (char)(letter_zero + (digit - 10));
    }

    return written;
}

/**
 * @brief Emits an already-formatted integer that is stored reversed in buf.
 *
 * Leading space padding is skipped when FLAG_ZERO_PAD is set; trailing padding is
 * delegated to end_pad().
 *
 * @param vstate The conversion being formatted.
 * @param state Output state to write through.
 * @param buf The characters to emit, last output character first.
 * @param len Number of characters in buf.
 */
static Result<void> output_integer_data(const conv_state& vstate, format_state& state, char* buf, usize len)
{
    const bool zero_padded = (vstate.flags & FLAG_ZERO_PAD) != 0;
    if (!zero_padded) { TRY(start_pad(vstate, state, len)); }

    // buf holds the text back to front, so walk it from the end.
    for (usize remaining = len; remaining > 0; remaining--) { TRY(format_putchar(buf[remaining - 1], state)); }

    TRY(end_pad(vstate, state, len));

    return {};
}

/**
 * @brief Formats and emits one integer (or pointer) value.
 *
 * The text is assembled back to front in a local buffer (digits, then zero fill, then
 * the alternate-form prefix, then the sign) and handed to output_integer_data().
 *
 * @param specifier The conversion character (d, i, u, x, X, o, b or p).
 * @param vstate The conversion being formatted. It is modified: flags that do not apply
 *               are cleared and the width is reduced to make room for sign/prefix when
 *               zero-padding.
 * @param state Output state to write through.
 * @param value Magnitude of the value to print.
 * @param negative Whether a '-' sign must be printed.
 */
static Result<void> output_integer(char specifier, conv_state& vstate, format_state& state, usize value, bool negative)
{
    usize base = 10;
    bool uppercase = false;

    switch (specifier)
    {
    case 'p':
    case 'x':
    case 'X': base = 16; break;
    case 'o': base = 8; break;
    case 'b': base = 2; break;
    default: break;
    }
    if (specifier == 'X') uppercase = true;

    if (base == 10) vstate.flags &= ~FLAG_ALTERNATE; // decimal doesn't have an alternate form

    const bool has_precision = (vstate.flags & FLAG_USE_PRECISION) != 0;

    // For integer conversions an explicit precision makes the '0' flag a no-op.
    // %p is left alone because cstyle_format() relies on zero padding for pointers.
    if (has_precision && specifier != 'p') vstate.flags &= ~FLAG_ZERO_PAD;

    // Zero printed with an explicit precision of 0 produces no digits. Width padding and
    // the sign still apply, so this must not return early.
    const bool no_digits = has_precision && vstate.precision == 0 && value == 0;

    // A "0x"/"0b" prefix in front of no digits would be meaningless; the octal
    // alternate form keeps its single leading '0'.
    if (no_digits && base != 8) vstate.flags &= ~FLAG_ALTERNATE;

    char buf[1024];
    usize buflen = no_digits ? 0 : to_string(value, base, buf, sizeof(buf), uppercase);

    if (!(vstate.flags & FLAG_LEFT_ALIGN) &&
        (vstate.flags & FLAG_ZERO_PAD)) // we're padding with zeroes from the beginning
    {
        const bool extra_char =
            negative || ((vstate.flags & FLAG_SIGN) ||
                         (vstate.flags & FLAG_BLANK_SIGNED)); // are we adding an extra character after the buffer?
        if (vstate.width && extra_char) vstate.width--;

        if (vstate.width && (vstate.flags & FLAG_ALTERNATE)) // fit in the characters we're using for the alternate form
        {
            vstate.width--;
            if (vstate.width && (base == 2 || base == 16)) vstate.width--;
        }

        while (buflen < vstate.width && buflen < sizeof(buf)) buf[buflen++] = '0';
    }

    // The precision is the minimum number of digits.
    while (buflen < vstate.precision && buflen < sizeof(buf)) buf[buflen++] = '0';

    // The buffer is reversed, so the prefix is appended letter first, '0' last.
    if (vstate.flags & FLAG_ALTERNATE)
    {
        if (base == 16 && !uppercase && buflen < sizeof(buf)) buf[buflen++] = 'x';
        if (base == 16 && uppercase && buflen < sizeof(buf)) buf[buflen++] = 'X';
        if (base == 2 && buflen < sizeof(buf)) buf[buflen++] = 'b';
        if (buflen < sizeof(buf)) buf[buflen++] = '0';
    }

    if (buflen < sizeof(buf))
    {
        if (negative) buf[buflen++] = '-';
        else if (vstate.flags & FLAG_SIGN)
            buf[buflen++] = '+';
        else if (vstate.flags & FLAG_BLANK_SIGNED)
            buf[buflen++] = ' ';
    }

    return output_integer_data(vstate, state, buf, buflen);
}

/**
 * @brief Reads the next integer argument, sized by the length modifier, and prints it.
 *
 * Signed conversions (d, i) are split into a magnitude and a "negative" flag before
 * being handed to output_integer(). The magnitude is always computed in unsigned
 * arithmetic, because negating the most negative value of a signed type overflows.
 *
 * @param specifier The conversion character (d, i, u, x, X, o or b).
 * @param vstate The conversion being formatted; sign flags are cleared for unsigned
 *               conversions.
 * @param state Output state to write through.
 * @param ap Argument list the value is read from.
 */
static Result<void> va_output_integer(char specifier, conv_state& vstate, format_state& state, va_list ap)
{
    const bool is_signed = (specifier == 'd' || specifier == 'i');

    // '+' and ' ' only apply to signed conversions.
    if (!is_signed) vstate.flags &= ~(FLAG_SIGN | FLAG_BLANK_SIGNED);

    // FLAG_CHAR / FLAG_LONG_LONG are checked before FLAG_SHORT / FLAG_LONG because
    // 'hh' and 'll' set both bits of their pair.
    if (vstate.flags & FLAG_CHAR)
    {
        if (is_signed)
        {
            // Explicitly signed: plain char may be unsigned, which would hide negative values.
            const signed char v = (signed char)va_arg(ap, int);
            const bool negative = v < 0;
            const unsigned char magnitude = (unsigned char)(negative ? 0u - (unsigned int)v : (unsigned int)v);
            return output_integer(specifier, vstate, state, magnitude, negative);
        }

        const unsigned char v = (unsigned char)va_arg(ap, unsigned int);
        return output_integer(specifier, vstate, state, v, false);
    }

    if (vstate.flags & FLAG_SHORT)
    {
        if (is_signed)
        {
            const short v = (short)va_arg(ap, int);
            const bool negative = v < 0;
            const unsigned short magnitude = (unsigned short)(negative ? 0u - (unsigned int)v : (unsigned int)v);
            return output_integer(specifier, vstate, state, magnitude, negative);
        }

        const unsigned short v = (unsigned short)va_arg(ap, unsigned int);
        return output_integer(specifier, vstate, state, v, false);
    }

    if (vstate.flags & FLAG_LONG_LONG)
    {
        if (is_signed)
        {
            const long long v = va_arg(ap, long long);
            const bool negative = v < 0;
            const unsigned long long magnitude = negative ? 0ULL - (unsigned long long)v : (unsigned long long)v;
            return output_integer(specifier, vstate, state, magnitude, negative);
        }

        const unsigned long long v = va_arg(ap, unsigned long long);
        return output_integer(specifier, vstate, state, v, false);
    }

    if (vstate.flags & FLAG_LONG)
    {
        if (is_signed)
        {
            const long v = va_arg(ap, long);
            const bool negative = v < 0;
            const unsigned long magnitude = negative ? 0UL - (unsigned long)v : (unsigned long)v;
            return output_integer(specifier, vstate, state, magnitude, negative);
        }

        const unsigned long v = va_arg(ap, unsigned long);
        return output_integer(specifier, vstate, state, v, false);
    }

    if (is_signed)
    {
        const int v = va_arg(ap, int);
        const bool negative = v < 0;
        const unsigned int magnitude = negative ? 0u - (unsigned int)v : (unsigned int)v;
        return output_integer(specifier, vstate, state, magnitude, negative);
    }

    const unsigned int v = va_arg(ap, unsigned int);
    return output_integer(specifier, vstate, state, v, false);
}

/**
 * @brief Formats a printf-style format string, passing every output character to a callback.
 *
 * Supported conversions: %d %i %u %x %X %o %b (integers), %p, %c, %s and %%.
 * A conversion with an unrecognised conversion character produces no output (and
 * consumes no argument). A format string that ends in the middle of a conversion
 * simply stops there.
 *
 * @param format The NUL-terminated format string.
 * @param callback Invoked as callback(c, arg) for each output character; an error
 *                 returned by it aborts formatting and is propagated.
 * @param arg Opaque pointer passed to every callback invocation.
 * @param ap The arguments referenced by the format string.
 * @return The number of characters passed to the callback, or the callback's error.
 */
Result<usize> cstyle_format(const char* format, FormatCallback callback, void* arg, va_list ap)
{
    format_state state;
    state.callback = callback;
    state.arg = arg;
    state.count = 0;

    while (*format)
    {
        if (*format != '%')
        {
            TRY(format_putchar(*format, state));
            format++;
            continue;
        }

        format++;

        if (*format == '%')
        {
            TRY(format_putchar('%', state));
            format++;
            continue;
        }

        // %[flags][width][.precision][length]conversion

        // NOTE(review): the helpers below receive `ap` by value and call va_arg on it. ISO C
        // leaves the caller's `ap` indeterminate after that; this relies on ABIs where
        // va_list is an array type (so the callee advances the caller's state). Confirm for
        // every target architecture.
        flags_t flags = parse_flags(&format);
        const usize width = parse_width(&format, flags, ap);
        const usize precision = parse_precision(&format, flags, ap);
        parse_type(&format, flags);

        conv_state vstate = { flags, width, precision };

        const char specifier = *format;

        // The format string ended inside a conversion (e.g. "100%" or "%-5l"). Stop here:
        // advancing past the terminator would make the loop read out of bounds.
        if (specifier == '\0') break;

        format++;

        if (is_integer_format_specifier(specifier))
        {
            TRY(va_output_integer(specifier, vstate, state, ap));
            continue;
        }
        else if (specifier == 'p')
        {
            const void* ptr = va_arg(ap, void*);
            if (ptr == nullptr)
            {
                TRY(start_pad(vstate, state, 5));
                TRY(format_puts("(nil)", state));
                TRY(end_pad(vstate, state, 5));
                continue;
            }
            // Non-null pointers are always printed as "0x" plus two hex digits per byte,
            // overriding any width given in the format string.
            vstate.width = (sizeof(void*) * 2) + 2;
            vstate.flags |= (FLAG_ZERO_PAD | FLAG_ALTERNATE);
            TRY(output_integer('p', vstate, state, (usize)ptr, false));
            continue;
        }
        else if (specifier == 'c')
        {
            // FIXME: If FLAG_LONG is set, we should use a wint_t.
            const char c = (char)va_arg(ap, int);

            TRY(start_pad(vstate, state, 1));
            TRY(format_putchar(c, state));
            TRY(end_pad(vstate, state, 1));

            continue;
        }
        else if (specifier == 's')
        {
            // FIXME: If FLAG_LONG is set, we should use a wide string.
            const char* str = va_arg(ap, const char*);
            if (str == nullptr)
            {
                TRY(start_pad(vstate, state, 6));
                TRY(format_puts("(null)", state));
                TRY(end_pad(vstate, state, 6));
                continue;
            }

            const bool use_precision = (flags & FLAG_USE_PRECISION);

            usize len = 0;
            if (use_precision)
            {
                // With a precision the argument does not have to be NUL-terminated, so
                // never look at more than `precision` bytes (strlen could run off the end).
                while (len < precision && str[len]) len++;
            }
            else
                len = strlen(str);

            TRY(start_pad(vstate, state, len));
            for (usize i = 0; i < len; i++) { TRY(format_putchar(str[i], state)); }
            TRY(end_pad(vstate, state, len));
            continue;
        }
        else { continue; }
    }

    return state.count;
}

// Destination state used by vstring_format()'s output callback.
struct StringFormatInfo
{
    char* buffer;    // where the next character is stored; may be null, in which case nothing is stored
    usize remaining; // how many more characters may still be consumed before output is discarded
};

/**
 * @brief Formats into a character buffer, writing at most max bytes including the
 *        NUL terminator.
 *
 * Output that does not fit is discarded but still counted. When max is 0 nothing is
 * written at all (not even the terminator), and buf may be null to only compute the
 * length.
 *
 * @param buf Destination buffer (may be null).
 * @param max Size of buf in bytes.
 * @param format printf-style format string.
 * @param ap Arguments referenced by the format string.
 * @return The number of characters the full output consists of, excluding the terminator.
 */
usize vstring_format(char* buf, usize max, const char* format, va_list ap)
{
    // One byte is reserved for the terminator. Guard max == 0 explicitly: `max - 1` would
    // wrap around to the largest usize and make the buffer look unbounded.
    StringFormatInfo info = { .buffer = buf, .remaining = max ? max - 1 : 0 };

    usize result = cstyle_format(
                       format,
                       [](char c, void* arg) -> Result<void> {
                           StringFormatInfo* info_arg = (StringFormatInfo*)arg;
                           if (!info_arg->remaining) return {};
                           if (info_arg->buffer)
                           {
                               *(info_arg->buffer) = c;
                               info_arg->buffer++;
                           }
                           info_arg->remaining--;
                           return {};
                       },
                       &info, ap)
                       .value();

    // Only terminate when the buffer has room for at least the terminator.
    if (info.buffer && max) *(info.buffer) = 0;

    return result;
}

/**
 * @brief Variadic front end for vstring_format().
 *
 * @param buf Destination buffer.
 * @param max Size of buf in bytes.
 * @param format printf-style format string, followed by its arguments.
 * @return Whatever vstring_format() returns for the same arguments.
 */
usize string_format(char* buf, usize max, const char* format, ...)
{
    va_list args;

    va_start(args, format);
    const usize written = vstring_format(buf, max, format, args);
    va_end(args);

    return written;
}