#pragma once
#include <luna/Alloc.h>
#include <luna/Atomic.h>
#include <luna/Hash.h>
#include <luna/OwnedPtr.h>
#include <luna/Result.h>
#include <luna/ScopeGuard.h>

namespace __detail
{
    /// Reference counter shared by every SharedPtr copy that owns the same object.
    /// A new counter starts at 1, representing the owner that created it.
    struct RefCount
    {
        /// Registers one more owner.
        void ref()
        {
            m_ref_count++;
        }

        /// Drops one owner. Returns true if that was the last owner, in which case the
        /// caller is responsible for deleting both the object and this counter.
        bool unref()
        {
            // Decide using the value produced by this very decrement. Decrementing and then
            // re-reading the counter lets two threads that drop the last two references
            // concurrently both observe 0, and both would free the object.
            // NOTE(review): relies on Atomic's prefix operator-- returning the updated value,
            // as std::atomic's does — confirm in luna/Atomic.h.
            return --m_ref_count == 0;
        }

      private:
        Atomic<int> m_ref_count { 1 };
    };
}

template <typename T> class SharedPtr
{
    using RefCount = __detail::RefCount;

  public:
    SharedPtr()
    {
        m_ptr = nullptr;
        m_ref_count = nullptr;
    }

    SharedPtr(T* ptr, RefCount* ref_count) : m_ptr(ptr), m_ref_count(ref_count)
    {
    }

    SharedPtr(const SharedPtr<T>& other) : m_ptr(other.m_ptr), m_ref_count(other.m_ref_count)
    {
        if (m_ref_count) m_ref_count->ref();
    }

    SharedPtr(SharedPtr<T>&& other) : m_ptr(other.m_ptr), m_ref_count(other.m_ref_count)
    {
        other.m_ptr = nullptr;
        other.m_ref_count = nullptr;
    }

    template <typename Tp> operator SharedPtr<Tp>()
    {
        if (m_ref_count) m_ref_count->ref();
        return { (Tp*)m_ptr, m_ref_count };
    }

    ~SharedPtr()
    {
        if (m_ref_count && m_ref_count->unref())
        {
            delete m_ref_count;
            delete m_ptr;
        }
    }

    SharedPtr<T>& operator=(const SharedPtr<T>& other)
    {
        if (&other == this) return *this;

        if (m_ref_count && m_ref_count->unref())
        {
            delete m_ref_count;
            delete m_ptr;
        }

        m_ptr = other.m_ptr;
        m_ref_count = other.m_ref_count;

        if (m_ref_count) m_ref_count->ref();

        return *this;
    }

    bool operator==(const SharedPtr<T>& other)
    {
        return m_ptr == other.m_ptr && m_ref_count == other.m_ref_count;
    }

    T* ptr() const
    {
        return m_ptr;
    }

    T* operator->() const
    {
        return m_ptr;
    }

    T& operator*() const
    {
        return *m_ptr;
    }

    operator bool() const
    {
        return m_ptr != nullptr;
    }

  private:
    T* m_ptr;
    RefCount* m_ref_count;
};

// NOTE: ptr is deleted if any of the adopt_shared* functions fail to construct a SharedPtr.
template <typename T> Result<SharedPtr<T>> adopt_shared(T* ptr)
{
    // Until a counter exists nobody else owns ptr, so free it if allocating one fails.
    auto cleanup = make_scope_guard([ptr] { delete ptr; });

    auto* const counter = TRY(make<__detail::RefCount>());

    // From here on the returned SharedPtr is responsible for ptr.
    cleanup.deactivate();

    return SharedPtr<T> { ptr, counter };
}

template <typename T, class... Args> Result<SharedPtr<T>> make_shared(Args... args)
{
    // Allocate and construct the object first; adopt_shared() then takes ownership and
    // deletes it again if the reference counter cannot be allocated.
    T* const object = TRY(make<T>(args...));

    return adopt_shared<T>(object);
}

template <typename T> Result<SharedPtr<T>> adopt_shared_if_nonnull(T* ptr)
{
    // A null pointer (presumably from a failed allocation) is reported as ENOMEM.
    if (!ptr) return err(ENOMEM);

    return adopt_shared(ptr);
}

/// Moves the object owned by `other` into a new SharedPtr. `other` is left empty in all
/// cases; if the SharedPtr cannot be created, adopt_shared() deletes the object.
template <typename T> Result<SharedPtr<T>> adopt_shared_from_owned(OwnedPtr<T>&& other)
{
    // Detach the object so the OwnedPtr no longer owns it.
    T* const ptr = other.m_ptr;
    other.m_ptr = nullptr;

    // Forward adopt_shared()'s Result as-is. The previous `const SharedPtr` local could not
    // be moved from, so returning it copied the SharedPtr (an extra atomic ref + unref).
    return adopt_shared(ptr);
}