#include <luna/Base64.h>
#include <test.h>

// Encoding the three-byte input "abc" must produce "YWJj", an output that
// carries no '=' padding. A failed encode is propagated to the caller by TRY.
TestResult test_base64_encode_unpadded_message()
{
    auto input = "abc"_sv;
    auto output = TRY(Base64::encode(input));

    validate(output.view() == "YWJj");

    test_success;
}

// Decoding the unpadded text "YWJj" must yield "abc".
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_decode_unpadded_message()
{
    auto result = Base64::decode_string("YWJj"_sv);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abc"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// Encoding the four-byte input "abcd" must produce "YWJjZA==", i.e. an output
// ending in two '=' padding characters. A failed encode is propagated by TRY.
TestResult test_base64_encode_padded_message()
{
    auto input = "abcd"_sv;
    auto output = TRY(Base64::encode(input));

    validate(output.view() == "YWJjZA==");

    test_success;
}

// Encoding the five-byte input "abcde" must produce "YWJjZGU=", i.e. an output
// ending in a single '=' padding character. A failed encode is propagated by TRY.
TestResult test_base64_encode_padded_message_2()
{
    auto input = "abcde"_sv;
    auto output = TRY(Base64::encode(input));

    validate(output.view() == "YWJjZGU=");

    test_success;
}

// Decoding "YWJjZA==" (two padding characters) must yield "abcd".
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_decode_padded_message()
{
    auto result = Base64::decode_string("YWJjZA=="_sv);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abcd"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// Decoding "YWJjZGU=" (one padding character) must yield "abcde".
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_decode_padded_message_2()
{
    auto result = Base64::decode_string("YWJjZGU="_sv);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abcde"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// Decoding padded input that contains '\n' characters (one in the middle and
// one at the very end) must still yield "abcd".
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_decode_padded_message_with_newlines()
{
    auto result = Base64::decode_string("YWJj\nZA==\n"_sv);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abcd"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// Input with extra characters ("bd") following the '=' padding must be
// rejected: the decode call has to report an error, and that error has to be
// EINVAL for the test to succeed. Any other error is returned to the caller.
TestResult test_base64_disallow_characters_after_padding()
{
    auto result = Base64::decode_string("YWJjZA==bd"_sv);

    validate(result.has_error());

    if (result.error() == EINVAL)
    {
        test_success;
    }

    return result.release_error();
}

// Input containing the characters "?-" in the middle must be rejected when
// decode_string is called with its default arguments: the call has to report
// an error, and that error has to be EINVAL for the test to succeed.
// Any other error is returned to the caller.
TestResult test_base64_disallow_garbage_chars_by_default()
{
    auto result = Base64::decode_string("YWJj?-ZA=="_sv);

    validate(result.has_error());

    if (result.error() == EINVAL)
    {
        test_success;
    }

    return result.release_error();
}

// With the second argument set to true, input containing "?-" in the middle
// must decode to "abcd".
// NOTE(review): the flag presumably tells the decoder to skip characters
// outside the Base64 alphabet (per the test name) - confirm against Base64.h.
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_skip_garbage_chars_if_allowed()
{
    auto result = Base64::decode_string("YWJj?-ZA=="_sv, true);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abcd"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// With the second argument set to true, input that has "\n?-" trailing the
// '=' padding must still decode to "abcd".
// NOTE(review): the flag presumably tells the decoder to skip characters
// outside the Base64 alphabet (per the test name) - confirm against Base64.h.
// If decoding fails, an EINVAL error trips the validate check; any other
// error is handed back to the caller unchanged.
TestResult test_base64_skip_garbage_chars_if_allowed_after_padding()
{
    auto result = Base64::decode_string("YWJjZA==\n?-"_sv, true);

    if (!result.has_error())
    {
        auto text = result.release_value();

        validate(text.view() == "abcd"_sv);

        test_success;
    }

    // Error path: EINVAL counts as a failed check, everything else is propagated.
    validate(result.error() != EINVAL);

    return result.release_error();
}

// Entry point of this test file: invokes test_prelude, passes every Base64
// test to run_test in the order listed, and returns a default-constructed
// Result<void>.
// NOTE(review): test_prelude and run_test are macros from <test.h>; their exact
// behavior (setup, reporting, whether a failing test stops the run) is not
// visible in this file - confirm there before relying on it.
Result<void> test_main()
{
    test_prelude;

    // Encoding and decoding of unpadded and padded messages.
    run_test(test_base64_encode_unpadded_message);
    run_test(test_base64_decode_unpadded_message);
    run_test(test_base64_encode_padded_message);
    run_test(test_base64_encode_padded_message_2);
    run_test(test_base64_decode_padded_message);
    run_test(test_base64_decode_padded_message_2);
    run_test(test_base64_decode_padded_message_with_newlines);
    // Inputs that must be rejected with EINVAL.
    run_test(test_base64_disallow_characters_after_padding);
    run_test(test_base64_disallow_garbage_chars_by_default);
    // Inputs with stray characters that must decode when the extra flag is true.
    run_test(test_base64_skip_garbage_chars_if_allowed);
    run_test(test_base64_skip_garbage_chars_if_allowed_after_padding);

    return {};
}