diff --git a/libluna/CMakeLists.txt b/libluna/CMakeLists.txt index 7c110a14..a775b140 100644 --- a/libluna/CMakeLists.txt +++ b/libluna/CMakeLists.txt @@ -23,6 +23,7 @@ set(FREESTANDING_SOURCES src/Spinlock.cpp src/PathParser.cpp src/UBSAN.cpp + src/Base64.cpp ) set(SOURCES diff --git a/libluna/include/luna/Base64.h b/libluna/include/luna/Base64.h new file mode 100644 index 00000000..3efe79b8 --- /dev/null +++ b/libluna/include/luna/Base64.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace Base64 +{ + Result encode(StringView data); + Result encode(Slice data); + Result encode(const Buffer& data); + + Result decode(StringView data, bool allow_garbage_chars = false); + Result decode_string(StringView data, bool allow_garbage_chars = false); +} diff --git a/libluna/src/Base64.cpp b/libluna/src/Base64.cpp new file mode 100644 index 00000000..7af1b0e5 --- /dev/null +++ b/libluna/src/Base64.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include +#include + +static const char g_base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +static Option value_from_base64_character(char letter) +{ + if (_isupper(letter)) return (u8)(letter - 'A'); + if (_islower(letter)) return (u8)(26 + (letter - 'a')); + if (_isdigit(letter)) return (u8)(52 + (letter - '0')); + if (letter == '+') return 62; + if (letter == '/') return 63; + if (letter == '=') return {}; + unreachable(); +} + +static Result decode_base64_buffer(const Vector& input, Buffer& output) +{ + const usize bytes_to_decode = 4; + + Option base64_data[bytes_to_decode]; + + for (usize i = 0; i < 4; i++) { base64_data[i] = value_from_base64_character(input[i]); } + + u8 decoded_bytes[3]; + usize total_decoded_bytes = 0; + + decoded_bytes[0] = (base64_data[0].value() << 2) | (base64_data[1].value() >> 4); + total_decoded_bytes++; + + if (base64_data[2].has_value()) + { + decoded_bytes[1] = ((base64_data[1].value() & 0b1111) << 4) | (base64_data[2].value() >> 2); + total_decoded_bytes++; + } + + if (base64_data[3].has_value()) + { + decoded_bytes[2] = ((base64_data[2].value() & 0b11) << 6) | base64_data[3].value(); + total_decoded_bytes++; + } + + return output.append_data(decoded_bytes, total_decoded_bytes); +} + +namespace Base64 +{ + Result encode(StringView data) + { + return encode(Slice { (const u8*)data.chars(), data.length() }); + } + + Result encode(const Buffer& data) + { + return encode(Slice { data.data(), data.size() }); + } + + Result encode(Slice data) + { + StringBuilder sb; + + for (usize i = 0; i < data.size(); i += 3) + { + usize bytes_to_encode = data.size() - i; + if (bytes_to_encode > 3) bytes_to_encode = 3; + + TRY(sb.add(g_base64_alphabet[data[i] >> 2])); + if (bytes_to_encode > 1) TRY(sb.add(g_base64_alphabet[((data[i] & 0b11) << 4) | (data[i + 1] >> 4)])); + else + { + TRY(sb.add(g_base64_alphabet[(data[i] & 0b11) << 6])); + TRY(sb.add("=="_sv)); + break; + } + + if (bytes_to_encode > 2) TRY(sb.add(g_base64_alphabet[((data[i + 1] & 0b1111) << 2) | (data[i + 2] >> 6)])); + else + { + TRY(sb.add(g_base64_alphabet[(data[i + 1] & 0b1111) << 2])); + TRY(sb.add("="_sv)); + break; + } + + TRY(sb.add(g_base64_alphabet[data[i + 2] & 0b111111])); + } + + return sb.string(); + } + + Result decode(StringView data, bool allow_garbage_chars) + { + Buffer buf; + + char* padding = strchr(data.chars(), '='); + if (padding) + { + // If the string ends with padding, it must be either one or two equals signs. + if (padding[1] != '=' && padding[1] != '\0') return err(EINVAL); + if (padding[1]) + { + if (padding[strspn(&padding[2], "\n") + 2]) return err(EINVAL); + } + } + + Vector chars_read; + + for (const auto& c : data) + { + if (c == '\n') continue; + + if (!_isalnum(c) && c != '+' && c != '/' && c != '=') + { + if (allow_garbage_chars) continue; + return err(EINVAL); + } + + TRY(chars_read.try_append(c)); + + if (chars_read.size() == 4) + { + TRY(decode_base64_buffer(chars_read, buf)); + chars_read.clear(); + } + } + + // Unterminated input + if (chars_read.size() > 0) return err(EINVAL); + + return buf; + } + + Result decode_string(StringView data, bool allow_garbage_chars) + { + auto buf = TRY(decode(data, allow_garbage_chars)); + + u8 nul_byte = '\0'; + TRY(buf.append_data(&nul_byte, 1)); + + return String::from_cstring((char*)buf.data()); + } +}