#pragma once
#include <luna/Option.h>
#include <luna/TypeTraits.h>

template <typename T> class LinkedList;

// Intrusive link hook for LinkedList<T>. A type T becomes storable in a LinkedList<T> by deriving from
// LinkedListNode<T>; the list then chains elements together through the two pointers embedded here instead of
// allocating separate nodes. All link manipulation is private and only reachable from LinkedList<T> (friend).
template <typename T> class LinkedListNode
{
    using SelfType = LinkedListNode<T>;

  private:
    // Both links start out null so that a node which has never been inserted into a list does not carry
    // indeterminate pointers; reading uninitialized pointers would be undefined behavior.
    SelfType* m_next_node = nullptr;
    SelfType* m_last_node = nullptr;

    // Overwrites the forward link only; the node at 'next' is not touched.
    void set_next(SelfType* next)
    {
        m_next_node = next;
    }

    // Overwrites the backward link only; the node at 'last' is not touched.
    void set_last(SelfType* last)
    {
        m_last_node = last;
    }

    // Returns the following node, or nullptr if there is none.
    SelfType* get_next() const
    {
        return m_next_node;
    }

    // Returns the preceding node, or nullptr if there is none.
    SelfType* get_last() const
    {
        return m_last_node;
    }

    // Unlinks this node by making its neighbours point at each other. This node's own m_next_node/m_last_node
    // are left unchanged, and no list's start/end pointers are updated here.
    void detach_from_list()
    {
        if (m_next_node) m_next_node->m_last_node = m_last_node;
        if (m_last_node) m_last_node->m_next_node = m_next_node;
    }

    // Links this node directly after end_node and makes it the tail of the chain (next = nullptr).
    // end_node's previous forward link is overwritten, so end_node is expected to be the current tail.
    void append_to_list(SelfType* end_node)
    {
        end_node->m_next_node = this;
        this->m_last_node = end_node;
        this->m_next_node = nullptr;
    }

    // Links this node directly before start_node and makes it the head of the chain (last = nullptr).
    // start_node's previous backward link is overwritten, so start_node is expected to be the current head.
    void prepend_to_list(SelfType* start_node)
    {
        start_node->m_last_node = this;
        this->m_next_node = start_node;
        this->m_last_node = nullptr;
    }

    friend class LinkedList<T>;
};

// Intrusive doubly-linked list of T. T must derive from LinkedListNode<T>: the list threads elements together
// through the link pointers embedded in each element, so it never allocates and does not own (or free) its
// elements. The list itself only stores its first and last nodes plus an element count.
template <typename T> class LinkedList
{
    using Node = LinkedListNode<T>;

    static_assert(IsBaseOf<LinkedListNode<T>, T>);

  public:
    // Inserts ptr at the end of the list. ptr's existing links are overwritten rather than detached, so ptr is
    // presumably expected not to be linked into any list already.
    void append(T* ptr)
    {
        Node* const node = extract_node(ptr);
        if (!m_start_node) m_start_node = node;
        if (m_end_node) node->append_to_list(m_end_node);
        else
        {
            // First element: it has no neighbours.
            node->set_next(nullptr);
            node->set_last(nullptr);
        }
        m_end_node = node;

        m_count++;
    }

    // Inserts ptr at the beginning of the list. ptr's existing links are overwritten rather than detached, so
    // ptr is presumably expected not to be linked into any list already.
    void prepend(T* ptr)
    {
        Node* const node = extract_node(ptr);
        if (!m_end_node) m_end_node = node;
        if (m_start_node) node->prepend_to_list(m_start_node);
        else
        {
            // First element: it has no neighbours.
            node->set_next(nullptr);
            node->set_last(nullptr);
        }
        m_start_node = node;

        m_count++;
    }

    // Inserts ptr immediately after base. base is expected to already be in this list (not checked); if base is
    // the last element, ptr becomes the new last element.
    void add_after(T* base, T* ptr)
    {
        Node* const new_node = extract_node(ptr);
        Node* const base_node = extract_node(base);

        if (m_end_node == base_node) m_end_node = new_node;

        // Read base's old successor before base's forward link is overwritten below.
        if (base_node->get_next()) base_node->get_next()->set_last(new_node);

        new_node->set_next(base_node->get_next());
        base_node->set_next(new_node);
        new_node->set_last(base_node);

        m_count++;
    }

    // Unlinks ptr from the list and returns it. Membership in this list is not verified, and m_count is
    // decremented unconditionally.
    T* remove(T* ptr)
    {
        Node* const node = extract_node(ptr);

        if (node == m_end_node) m_end_node = node->get_last();
        if (node == m_start_node) m_start_node = node->get_next();

        node->detach_from_list();

        m_count--;

        return ptr;
    }

    // Returns the first element wrapped by nonnull_or_empty_option (expected to be empty when the list is empty).
    Option<T*> first()
    {
        return nonnull_or_empty_option((T*)m_start_node);
    }

    // Returns the first element; check()s that the list is not empty.
    T* expect_first()
    {
        check(m_start_node);
        return (T*)m_start_node;
    }

    // Returns the last element wrapped by nonnull_or_empty_option (expected to be empty when the list is empty).
    Option<T*> last()
    {
        return nonnull_or_empty_option((T*)m_end_node);
    }

    // Returns the last element; check()s that the list is not empty.
    T* expect_last()
    {
        check(m_end_node);
        return (T*)m_end_node;
    }

    // Returns the element following item, wrapped by nonnull_or_empty_option (expected to be empty at the end).
    Option<T*> next(T* item)
    {
        return nonnull_or_empty_option((T*)extract_node(item)->get_next());
    }

    // Returns the element preceding item, wrapped by nonnull_or_empty_option (expected to be empty at the start).
    Option<T*> previous(T* item)
    {
        return nonnull_or_empty_option((T*)extract_node(item)->get_last());
    }

    // Iterates over the elements of the LinkedList from start to end, calling callback for every element.
    template <typename Callback> void for_each(Callback callback)
    {
        for (Node* node = m_start_node; node; node = node->get_next()) { callback((T*)node); }
    }

    // Iterates over the elements of the LinkedList from start to end, calling callback for every element. This
    // for_each is implemented in such a way that elements can be removed while iterating over it.
    template <typename Callback> void delayed_for_each(Callback callback)
    {
        for (Node* node = m_start_node; node;)
        {
            T* current = (T*)node;
            // Advance before invoking the callback, so the callback may remove 'current'.
            node = node->get_next();
            callback(current);
        }
    }

    // Iterates over the elements of the LinkedList from end to start, calling callback for every element.
    template <typename Callback> void for_each_reversed(Callback callback)
    {
        for (Node* node = m_end_node; node; node = node->get_last()) { callback((T*)node); }
    }

    // Iterates over the elements of the LinkedList from the element after 'start' to end, calling callback for
    // every element.
    template <typename Callback> void for_each_after(T* start, Callback callback)
    {
        for (Node* node = extract_node(start)->get_next(); node; node = node->get_next()) { callback((T*)node); }
    }

    // Iterates over the elements of the LinkedList from the element before 'end' to start, calling callback for
    // every element.
    template <typename Callback> void for_each_before(T* end, Callback callback)
    {
        for (Node* node = extract_node(end)->get_last(); node; node = node->get_last()) { callback((T*)node); }
    }

    // Iterates over the elements of the LinkedList from start to end, removing each element before passing it to
    // the callback.
    template <typename Callback> void consume(Callback callback)
    {
        for (Node* node = m_start_node; node;)
        {
            T* current = (T*)node;
            node = node->get_next();
            remove(current);
            callback(current);
        }
    }

    // Number of elements currently in the list, as tracked by append/prepend/add_after/remove.
    usize count() const
    {
        return m_count;
    }

    // Forward iterator over the list (used by range-based for). It holds the current element pointer, which is
    // nullptr once past the last element, and the list it walks.
    struct LinkedListIterator
    {
        using PtrT = T*;

      private:
        LinkedListIterator(PtrT ptr, LinkedList<T>& list) : m_ptr(ptr), m_list(list)
        {
        }

        PtrT m_ptr;
        LinkedList<T>& m_list;

      public:
        // Returns a reference to the iterator's current element pointer.
        PtrT& operator*()
        {
            return m_ptr;
        }

        // Advances to the next element (nullptr after the last one) and returns *this, as standard iterators do.
        LinkedListIterator& operator++()
        {
            m_ptr = m_list.next(m_ptr).value_or(nullptr);
            return *this;
        }

        // Two iterators are equal when they point at the same element of the same list.
        bool operator==(const LinkedListIterator& other) const
        {
            return m_ptr == other.m_ptr && &m_list == &other.m_list;
        }

        // Takes a const reference so that comparisons against a temporary such as list.end() compile.
        bool operator!=(const LinkedListIterator& other) const
        {
            return !(*this == other);
        }

        friend class LinkedList<T>;
    };

    // Iterator positioned at the first element (equal to end() for an empty list).
    LinkedListIterator begin()
    {
        return { (T*)m_start_node, *this };
    }

    // Past-the-end iterator: its element pointer is nullptr.
    LinkedListIterator end()
    {
        return { nullptr, *this };
    }

  private:
    Node* m_start_node = nullptr;
    Node* m_end_node = nullptr;
    usize m_count = 0;

    // Converts an element pointer to its embedded link hook (T derives from LinkedListNode<T>).
    static Node* extract_node(T* item)
    {
        return (Node*)item;
    }
};