#pragma once #include #include template class LinkedList; template class LinkedListNode { using SelfType = LinkedListNode; private: SelfType* m_next_node; SelfType* m_last_node; void set_next(SelfType* next) { m_next_node = next; } void set_last(SelfType* last) { m_last_node = last; } SelfType* get_next() { return m_next_node; } SelfType* get_last() { return m_last_node; } void detach_from_list() { if (m_next_node) m_next_node->m_last_node = m_last_node; if (m_last_node) m_last_node->m_next_node = m_next_node; } void append_to_list(SelfType* end_node) { end_node->m_next_node = this; this->m_last_node = end_node; this->m_next_node = nullptr; } void prepend_to_list(SelfType* start_node) { start_node->m_last_node = this; this->m_next_node = start_node; this->m_last_node = nullptr; } friend class LinkedList; }; template class LinkedList { using Node = LinkedListNode; static_assert(IsBaseOf, T>); public: void append(T* ptr) { Node* const node = extract_node(ptr); if (!m_start_node) m_start_node = node; if (m_end_node) node->append_to_list(m_end_node); else { node->set_next(nullptr); node->set_last(nullptr); } m_end_node = node; m_count++; } void prepend(T* ptr) { Node* const node = extract_node(ptr); if (!m_end_node) m_end_node = node; if (m_start_node) node->prepend_to_list(m_start_node); else { node->set_next(nullptr); node->set_last(nullptr); } m_start_node = node; m_count++; } void add_after(T* base, T* ptr) { Node* const new_node = extract_node(ptr); Node* const base_node = extract_node(base); if (m_end_node == base_node) m_end_node = new_node; if (base_node->get_next()) base_node->get_next()->set_last(new_node); new_node->set_next(base_node->get_next()); base_node->set_next(new_node); new_node->set_last(base_node); m_count++; } T* remove(T* ptr) { Node* const node = extract_node(ptr); if (node == m_end_node) m_end_node = node->get_last(); if (node == m_start_node) m_start_node = node->get_next(); node->detach_from_list(); m_count--; return ptr; } Option first() { return nonnull_or_empty_option((T*)m_start_node); } T* expect_first() { check(m_start_node); return (T*)m_start_node; } Option last() { return nonnull_or_empty_option((T*)m_end_node); } T* expect_last() { check(m_end_node); return (T*)m_end_node; } Option next(T* item) { return nonnull_or_empty_option((T*)extract_node(item)->get_next()); } Option previous(T* item) { return nonnull_or_empty_option((T*)extract_node(item)->get_last()); } // Iterates over the elements of the LinkedList from start to end, calling callback for every element. template void for_each(Callback callback) { for (Node* node = m_start_node; node; node = node->get_next()) { callback((T*)node); } } // Iterates over the elements of the LinkedList from start to end, calling callback for every element. This // for_each is implemented in such a way that elements can be removed while iterating over it. template void delayed_for_each(Callback callback) { for (Node* node = m_start_node; node;) { T* current = (T*)node; node = node->get_next(); callback(current); } } // Iterates over the elements of the LinkedList from end to start, calling callback for every element. template void for_each_reversed(Callback callback) { for (Node* node = m_end_node; node; node = node->get_last()) { callback((T*)node); } } // Iterates over the elements of the LinkedList from the element after 'start' to end, calling callback for // every element. template void for_each_after(T* start, Callback callback) { for (Node* node = extract_node(start)->m_next_node; node; node = node->get_next()) { callback((T*)node); } } // Iterates over the elements of the LinkedList from the element before 'end' to start, calling callback for // every element. template void for_each_before(T* end, Callback callback) { for (Node* node = extract_node(end)->m_last_node; node; node = node->get_last()) { callback((T*)node); } } // Iterates over the elements of the LinkedList from start to end, removing each element before passing it to // the callback. template void consume(Callback callback) { for (Node* node = m_start_node; node;) { T* current = (T*)node; node = node->get_next(); remove(current); callback(current); } } usize count() { return m_count; } struct LinkedListIterator { typedef T* PtrT; private: LinkedListIterator(PtrT ptr, LinkedList& list) : m_ptr(ptr), m_list(list) { } PtrT m_ptr; LinkedList& m_list; public: PtrT& operator*() { return m_ptr; } void operator++() { m_ptr = m_list.next(m_ptr).value_or(nullptr); } bool operator!=(LinkedListIterator& other) { return m_ptr != other.m_ptr || &m_list != &other.m_list; } friend class LinkedList; }; LinkedListIterator begin() { return { (T*)m_start_node, *this }; } LinkedListIterator end() { return { nullptr, *this }; } private: Node* m_start_node = nullptr; Node* m_end_node = nullptr; Node* extract_node(T* item) { return (Node*)item; } usize m_count = 0; };