From a3d0fa7d0a7576c8d6e9f59e7ae515cff10f013e Mon Sep 17 00:00:00 2001 From: apio Date: Mon, 16 Jan 2023 21:16:38 +0100 Subject: [PATCH] UserVM: Validate the entire range when freeing multiple VM pages --- kernel/src/memory/UserVM.cpp | 3 +-- luna/include/luna/Bitmap.h | 2 ++ luna/src/Bitmap.cpp | 37 ++++++++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/kernel/src/memory/UserVM.cpp b/kernel/src/memory/UserVM.cpp index 56c19e63..c6f39f55 100644 --- a/kernel/src/memory/UserVM.cpp +++ b/kernel/src/memory/UserVM.cpp @@ -100,8 +100,7 @@ Result UserVM::free_several_pages(u64 address, usize count) const u64 index = (address - VM_BASE) / ARCH_PAGE_SIZE; if ((index + count) > (MAX_VM_SIZE * 8)) return err(EINVAL); - // FIXME: Is it necessary to check all pages? - if (!m_bitmap.get(index)) return err(EFAULT); + if (!m_bitmap.match_region(index, count, true)) return err(EFAULT); m_bitmap.clear_region(index, count, false); diff --git a/luna/include/luna/Bitmap.h b/luna/include/luna/Bitmap.h index 1302a399..1e35d655 100644 --- a/luna/include/luna/Bitmap.h +++ b/luna/include/luna/Bitmap.h @@ -43,6 +43,8 @@ class Bitmap Option find_and_toggle_region(bool value, usize count, usize begin = 0); + bool match_region(usize start, usize bits, bool value); + void clear(bool value); void clear_region(usize start, usize bits, bool value); diff --git a/luna/src/Bitmap.cpp b/luna/src/Bitmap.cpp index 9ee2d099..ad400024 100644 --- a/luna/src/Bitmap.cpp +++ b/luna/src/Bitmap.cpp @@ -167,3 +167,40 @@ Option Bitmap::find_and_toggle_region(bool value, usize count, usize begi clear_region(index, count, !value); return index; } + +bool Bitmap::match_region(usize start, usize bits, bool value) +{ + expect(initialized(), "Bitmap was never initialized"); + expect((start + bits) <= size(), "Bitmap match out of range"); + + if (!bits) return true; + + // Match individual bits while not on a byte boundary. + while ((start % 8) && bits) + { + if (get(start) != value) return false; + start++; + bits--; + } + + // Match the rest in bytes. + usize bytes = bits / 8; + const u8 byte_that_contains_only_value = value_byte(value); + + for (usize i = start; i < start + bytes; i += 8) + { + if (m_location[i / 8] != byte_that_contains_only_value) return false; + } + + start += bytes * 8; + bits -= bytes * 8; + + // Match the remaining individual bits. + while (bits--) + { + if (get(start) != value) return false; + start++; + } + + return true; +}