#include <luna/Base64.h>
#include <luna/CType.h>
#include <luna/DebugLog.h>
#include <luna/Slice.h>
#include <luna/StringBuilder.h>

// Standard Base64 alphabet: index i (0..63) maps a 6-bit value to its encoded character.
// The trailing NUL of the literal makes the array 65 bytes long; only indices 0..63 are valid lookups.
static constexpr char g_base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

// Converts one Base64 character into the 6-bit value it encodes.
// Returns an empty Option for the padding character '='. Any character outside the Base64 alphabet
// hits unreachable(), so callers must filter their input before calling this.
static Option<u8> value_from_base64_character(char letter)
{
    switch (letter)
    {
    case '=': return {};
    case '+': return 62;
    case '/': return 63;
    default: break;
    }

    // The three letter/digit ranges occupy consecutive slots of the alphabet: A-Z, a-z, 0-9.
    if (_isdigit(letter)) return (u8)(letter - '0' + 52);
    if (_islower(letter)) return (u8)(letter - 'a' + 26);
    if (_isupper(letter)) return (u8)(letter - 'A');

    unreachable();
}

// Decodes one group of four Base64 characters and appends the resulting 1-3 bytes to `output`.
// The last one or two characters of the group may be '=' padding, which shortens the output.
// Returns EINVAL if `input` does not hold exactly four characters or if the padding is misplaced
// (padding in one of the first two positions, or a data character following a padding character).
// The caller filters out non-alphabet characters; these checks ensure value() is only ever called on
// Options that actually hold a value, even if a malformed group slips through.
static Result<void> decode_base64_buffer(const Vector<char>& input, Buffer& output)
{
    const usize bytes_to_decode = 4;
    if (input.size() != bytes_to_decode) return err(EINVAL);

    Option<u8> base64_data[bytes_to_decode];
    for (usize i = 0; i < bytes_to_decode; i++) { base64_data[i] = value_from_base64_character(input[i]); }

    // The first two characters together carry the first decoded byte, so neither may be padding.
    if (!base64_data[0].has_value() || !base64_data[1].has_value()) return err(EINVAL);
    // Padding may only appear at the end of the group: "xx=y" is malformed.
    if (!base64_data[2].has_value() && base64_data[3].has_value()) return err(EINVAL);

    u8 decoded_bytes[3];
    usize total_decoded_bytes = 0;

    // All 6 bits of the first character + the high 2 bits of the second.
    decoded_bytes[total_decoded_bytes++] = (u8)((base64_data[0].value() << 2) | (base64_data[1].value() >> 4));

    if (base64_data[2].has_value())
    {
        // Low 4 bits of the second character + high 4 bits of the third.
        decoded_bytes[total_decoded_bytes++] =
            (u8)(((base64_data[1].value() & 0b1111) << 4) | (base64_data[2].value() >> 2));
    }

    if (base64_data[3].has_value())
    {
        // Low 2 bits of the third character + all 6 bits of the fourth.
        decoded_bytes[total_decoded_bytes++] = (u8)(((base64_data[2].value() & 0b11) << 6) | base64_data[3].value());
    }

    return output.append_data(decoded_bytes, total_decoded_bytes);
}

namespace Base64
{
    Result<String> encode(StringView data)
    {
        return encode(Slice<const u8> { (const u8*)data.chars(), data.length() });
    }

    Result<String> encode(const Buffer& data)
    {
        return encode(Slice<const u8> { data.data(), data.size() });
    }

    Result<String> encode(Slice<const u8> data)
    {
        StringBuilder sb;

        for (usize i = 0; i < data.size(); i += 3)
        {
            usize bytes_to_encode = data.size() - i;
            if (bytes_to_encode > 3) bytes_to_encode = 3;

            TRY(sb.add(g_base64_alphabet[data[i] >> 2]));
            if (bytes_to_encode > 1) TRY(sb.add(g_base64_alphabet[((data[i] & 0b11) << 4) | (data[i + 1] >> 4)]));
            else
            {
                TRY(sb.add(g_base64_alphabet[(data[i] & 0b11) << 6]));
                TRY(sb.add("=="_sv));
                break;
            }

            if (bytes_to_encode > 2) TRY(sb.add(g_base64_alphabet[((data[i + 1] & 0b1111) << 2) | (data[i + 2] >> 6)]));
            else
            {
                TRY(sb.add(g_base64_alphabet[(data[i + 1] & 0b1111) << 2]));
                TRY(sb.add("="_sv));
                break;
            }

            TRY(sb.add(g_base64_alphabet[data[i + 2] & 0b111111]));
        }

        return sb.string();
    }

    Result<Buffer> decode(StringView data, bool allow_garbage_chars)
    {
        Buffer buf;

        char* padding = strchr(data.chars(), '=');
        if (padding)
        {
            padding++;
            // If the string ends with padding, it must be either one or two equals signs.
            if (*padding != '=' && *padding != '\0') return err(EINVAL);

            if (*padding) padding++;

            // After that, only thing allowed is newline (and garbage characters if those are permitted)
            while (*padding)
            {
                char c = *padding;
                padding++;

                if (c == '\n') continue;

                if (_isalnum(c) || c == '+' || c == '/' || c == '=') return err(EINVAL);

                if (!allow_garbage_chars) return err(EINVAL);
            }
        }

        Vector<char> chars_read;
        TRY(chars_read.try_reserve(4));

        for (const auto& c : data)
        {
            if (c == '\n') continue;

            if (!_isalnum(c) && c != '+' && c != '/' && c != '=')
            {
                if (allow_garbage_chars) continue;
                return err(EINVAL);
            }

            TRY(chars_read.try_append(c));

            if (chars_read.size() == 4)
            {
                TRY(decode_base64_buffer(chars_read, buf));
                chars_read.clear_data();
            }
        }

        // Unterminated input
        if (chars_read.size() > 0) return err(EINVAL);

        return buf;
    }

    Result<String> decode_string(StringView data, bool allow_garbage_chars)
    {
        auto buf = TRY(decode(data, allow_garbage_chars));

        u8 nul_byte = '\0';
        TRY(buf.append_data(&nul_byte, 1));

        return String { (char*)buf.release_data() };
    }
}