diff --git a/libluna/CMakeLists.txt b/libluna/CMakeLists.txt index e0bac9cf..51190b11 100644 --- a/libluna/CMakeLists.txt +++ b/libluna/CMakeLists.txt @@ -25,6 +25,7 @@ set(FREESTANDING_SOURCES src/PathParser.cpp src/UBSAN.cpp src/Base64.cpp + src/Hash.cpp ) set(SOURCES diff --git a/libluna/include/luna/Hash.h b/libluna/include/luna/Hash.h new file mode 100644 index 00000000..e17ee359 --- /dev/null +++ b/libluna/include/luna/Hash.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include + +u64 hash_memory(const void* mem, usize size, u64 salt); + +template u64 hash(const T& value, u64 salt) +{ + return hash_memory(&value, sizeof(value), salt); +} + +template <> u64 hash(const char* const& value, u64 salt); + +template static void swap(T* a, T* b) +{ + char* x = (char*)a; + char* y = (char*)b; + + usize size = sizeof(T); + + while (size--) + { + char t = *x; + *x = *y; + *y = t; + x += 1; + y += 1; + } +} diff --git a/libluna/include/luna/HashTable.h b/libluna/include/luna/HashTable.h new file mode 100644 index 00000000..df97aa06 --- /dev/null +++ b/libluna/include/luna/HashTable.h @@ -0,0 +1,146 @@ +#pragma once +#include +#include +#include + +template class HashTable +{ + static constexpr usize GROW_RATE = 2; + static constexpr usize GROW_FACTOR = 16; + + public: + Result try_set(const T& value) + { + T copy { value }; + return try_set(move(copy)); + } + + Result try_set(T&& value) + { + if (should_grow()) TRY(rehash(m_capacity + GROW_FACTOR)); + + u64 index = hash(value, m_salt) % m_capacity; + + while (true) + { + auto& bucket = m_buckets[index]; + if (bucket.has_value()) + { + if (*bucket == value) return false; + index++; + continue; + } + bucket = { move(value) }; + m_size++; + break; + } + + return true; + } + + T* try_find(const T& value) + { + if (!m_size) return nullptr; + + check(m_capacity); + + const u64 index = hash(value, m_salt) % m_capacity; + usize i = index; + + do { + auto& bucket = m_buckets[index]; + if (bucket.has_value()) + { + if (*bucket == value) return bucket.value_ptr(); + i++; + } + return nullptr; + } while (i != index); + + return nullptr; + } + + bool try_remove(const T& value) + { + if (!m_size) return false; + + check(m_capacity); + + const u64 index = hash(value, m_salt) % m_capacity; + usize i = index; + + do { + auto& bucket = m_buckets[index]; + if (bucket.has_value()) + { + if (*bucket == value) + { + bucket = {}; + m_size--; + if (i != index) rehash(m_capacity); + return true; + } + i++; + } + return false; + } while (i != index); + + return false; + } + + void clear() + { + for (usize i = 0; i < m_capacity; i++) m_buckets[i].~Option(); + + free_impl(m_buckets); + m_capacity = m_size = 0; + } + + ~HashTable() + { + clear(); + } + + private: + bool should_grow() + { + return (m_capacity == 0) || ((m_size * GROW_RATE) >= m_capacity); + } + + Result rehash(usize new_capacity) + { + HashTable new_table; + TRY(new_table.initialize(new_capacity)); + + if (m_capacity != 0) + { + for (usize i = 0; i < m_capacity; i++) + { + auto& opt = m_buckets[i]; + if (opt.has_value()) + { + auto value = opt.release_value(); + TRY(new_table.try_set(move(value))); + } + } + } + + swap(this, &new_table); + + return {}; + } + + Result initialize(usize initial_capacity) + { + check(m_buckets == nullptr); + m_capacity = initial_capacity; + m_buckets = (Option*)TRY(calloc_impl(initial_capacity, sizeof(Option), false)); + return {}; + } + + Option* m_buckets { nullptr }; + usize m_capacity { 0 }; + usize m_size { 0 }; + // FIXME: Randomize this to protect against hash table attacks. + u64 m_salt { 0 }; +}; diff --git a/libluna/src/Hash.cpp b/libluna/src/Hash.cpp new file mode 100644 index 00000000..cc528251 --- /dev/null +++ b/libluna/src/Hash.cpp @@ -0,0 +1,14 @@ +#include + +u64 hash_memory(const void* mem, usize size, u64 salt) +{ + const char* p = (const char*)mem; + u64 h = salt; + while (--size) h = h * 101 + (u64)*p++; + return h; +} + +template <> u64 hash(const char* const& value, u64 salt) +{ + return hash_memory(value, strlen(value), salt); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9e593cfc..eec0fb3a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,6 +17,7 @@ luna_test(libluna/TestVector.cpp TestVector) luna_test(libluna/TestBase64.cpp TestBase64) luna_test(libluna/TestUtf8.cpp TestUtf8) luna_test(libluna/TestFormat.cpp TestFormat) +luna_test(libluna/TestHashTable.cpp TestHashTable) luna_app(run-tests.cpp run-tests) endif() diff --git a/tests/libluna/TestHashTable.cpp b/tests/libluna/TestHashTable.cpp new file mode 100644 index 00000000..27f56d28 --- /dev/null +++ b/tests/libluna/TestHashTable.cpp @@ -0,0 +1,79 @@ +#include +#include +#include + +struct TwoInts +{ + int a; + int b; + + bool operator==(TwoInts other) const + { + return other.a == a; + } +}; + +template <> u64 hash(const TwoInts& value, u64 salt) +{ + return hash(value.a, salt); +} + +TestResult test_empty_hash_table() +{ + HashTable table; + + validate(table.try_find(0) == nullptr); + + test_success; +} + +TestResult test_hash_table_find() +{ + HashTable table; + + validate(TRY(table.try_set(0))); + + validate(table.try_find(0)); + validate(table.try_find(1) == nullptr); + + test_success; +} + +TestResult test_hash_table_remove() +{ + HashTable table; + + validate(TRY(table.try_set(0))); + + validate(table.try_find(0)); + + validate(table.try_remove(0)); + + validate(table.try_find(0) == nullptr); + + test_success; +} + +TestResult test_hash_table_duplicates() +{ + HashTable table; + + validate(TRY(table.try_set(TwoInts { 1, 5 }))); + validate(!TRY(table.try_set(TwoInts { 1, 3 }))); + + validate(table.try_find(TwoInts { 1, 0 })->b == 5); + + test_success; +} + +Result test_main() +{ + test_prelude; + + run_test(test_empty_hash_table); + run_test(test_hash_table_find); + run_test(test_hash_table_remove); + run_test(test_hash_table_duplicates); + + return {}; +} diff --git a/tests/run-tests.cpp b/tests/run-tests.cpp index 9bf75e30..6b90dc10 100644 --- a/tests/run-tests.cpp +++ b/tests/run-tests.cpp @@ -17,7 +17,7 @@ Result luna_main(int argc, char** argv) auto dir = TRY(os::Directory::open(test_dir)); - auto files = TRY(dir->list(os::Directory::Filter::Hidden)); + auto files = TRY(dir->list_names(os::Directory::Filter::Hidden)); for (const auto& program : files) {