#include "net/UnixSocket.h"
#include <bits/open-flags.h>
#include <luna/PathParser.h>
#include <thread/Scheduler.h>

// Creates an unconnected socket. All initial state comes from the member initializers in the class declaration.
UnixSocket::UnixSocket() = default;

UnixSocket::UnixSocket(UnixSocket* peer) : m_state(State::Connected), m_peer(peer)
{
}

// Runs the same teardown as did_close(): detaches the peer (if any) and marks this socket Inactive.
UnixSocket::~UnixSocket()
{
    this->did_close();
}

// Tears down this end of the connection.
//
// If a peer is attached, it is detached from us and moved to Reset, so its later send() fails with ECONNRESET and its
// recv() does so once its buffer is drained. Our own peer pointer is cleared as well. This matters because did_close()
// can run more than once on the same object (for example an explicit close, then again from the destructor). The peer
// may have been destroyed in between, and a stale m_peer would then be written through after it was freed.
void UnixSocket::did_close()
{
    if (m_peer)
    {
        m_peer->m_peer = nullptr;
        m_peer->m_state = State::Reset;
        m_peer = nullptr;
    }
    m_state = State::Inactive;
}

// Marks this socket Connected and records `peer` (non-owning) as the other end.
// Only this side is updated; the caller is responsible for pointing `peer` back at us.
void UnixSocket::connect_to_peer(UnixSocket* peer)
{
    m_state = State::Connected;
    m_peer = peer;
}

// Sends `length` bytes from `buf` by appending them to the connected peer's receive buffer. Flags are ignored.
//
// Returns `length` on success.
// Errors:
//   ECONNRESET if the peer has gone away (state Reset).
//   ENOTCONN if the socket is not connected.
//   Any error reported while appending to the peer's buffer.
Result<usize> UnixSocket::send(const u8* buf, usize length, int)
{
    switch (m_state)
    {
    case State::Connected: break;
    case State::Reset: return err(ECONNRESET);
    default: return err(ENOTCONN);
    }

    // A Connected socket must always have a peer.
    check(m_peer);

    auto& peer_buffer = m_peer->m_data;
    TRY(peer_buffer.append_data(buf, length));

    return length;
}

// Receives up to `length` bytes into `buf` by dequeuing from this socket's own receive buffer. Flags are ignored.
//
// After the peer has gone away (state Reset), data that was already buffered can still be read. ECONNRESET is reported
// only once the buffer is empty. Any state other than Connected or Reset yields ENOTCONN.
Result<usize> UnixSocket::recv(u8* buf, usize length, int) const
{
    switch (m_state)
    {
    case State::Connected: break;
    case State::Reset:
        if (m_data.size() == 0) return err(ECONNRESET);
        break;
    default: return err(ENOTCONN);
    }

    return m_data.dequeue_data(buf, length);
}

// Creates a directory entry named after the last component of `path` that points at `socket`.
//
// Steps:
//   1. Resolve the parent directory (relative to `working_directory`) and require write permission on it for `auth`.
//   2. Validate the entry name.
//   3. Give the socket an inode number and filesystem from the parent's filesystem.
//   4. Add the entry.
// Errors from any step are propagated as-is (EACCES for the permission check).
static Result<void> bind_socket_to_fs(const char* path, Credentials auth, SharedPtr<VFS::Inode> working_directory,
                                      SharedPtr<UnixSocket> socket)
{
    auto dir_path = TRY(PathParser::dirname(path));
    auto dir_inode = TRY(VFS::resolve_path(dir_path.chars(), auth, working_directory));
    if (!VFS::can_write(dir_inode, auth)) return err(EACCES);

    auto entry_name = TRY(PathParser::basename(path));
    TRY(VFS::validate_filename(entry_name.view()));

    auto fs = dir_inode->fs();
    auto inode_number = TRY(fs->allocate_inode_number());
    socket->set_inode_number(inode_number);
    socket->set_fs(fs);

    return dir_inode->add_entry(socket, entry_name.chars());
}

// Binds this socket to the filesystem path in the sockaddr_un pointed to by `addr`.
//
// The socket's metadata is set as follows:
//   - mode: 0777 masked by the caller's umask.
//   - owner: the caller's effective uid and gid.
// A directory entry is then created for the socket. On success the address is saved (for accept() on the other side)
// and the state becomes Bound.
//
// Errors:
//   EDESTADDRREQ if `addr` is null.
//   EINVAL for an out-of-range `addrlen`, or if the socket is not Inactive.
//   EAFNOSUPPORT if the family is not AF_UNIX.
//   EISCONN if the socket is already connected.
//   EADDRINUSE if the path already exists.
//   Any other path-resolution or filesystem error.
Result<void> UnixSocket::bind(struct sockaddr* addr, socklen_t addrlen)
{
    if (!addr) return err(EDESTADDRREQ);
    // addrlen must at least cover the address family. Otherwise reading sa_family is out of bounds, and the path
    // length computed below (addrlen - sizeof(sa_family_t)) would underflow to a huge value.
    if ((usize)addrlen < sizeof(sa_family_t)) return err(EINVAL);
    if (addr->sa_family != AF_UNIX) return err(EAFNOSUPPORT);
    if ((usize)addrlen > sizeof(sockaddr_un)) return err(EINVAL);

    if (m_state == State::Connected) return err(EISCONN);
    if (m_state != State::Inactive) return err(EINVAL);

    struct sockaddr_un* un_address = (struct sockaddr_un*)addr;

    String path = TRY(String::from_string_view(
        StringView::from_fixed_size_cstring(un_address->sun_path, addrlen - sizeof(sa_family_t))));

    auto* current = Scheduler::current();

    m_metadata.mode = 0777 & ~current->umask;
    m_metadata.uid = current->auth.euid;
    m_metadata.gid = current->auth.egid;

    auto rc = bind_socket_to_fs(path.chars(), current->auth, current->current_directory, SharedPtr<Socket> { this });
    if (rc.has_error())
    {
        // An existing path means the address is taken.
        if (rc.error() == EEXIST) return err(EADDRINUSE);
        return rc.release_error();
    }

    memcpy(&m_addr, un_address, addrlen);
    m_addrlen = addrlen;

    m_state = State::Bound;

    return {};
}

// Connects this socket to the listening UNIX socket bound at the path in `addr`.
//
// This socket is queued on the listener's accept queue and the thread blocked in accept() (if any) is woken.
// - If O_NONBLOCK is set in `flags`, returns EINPROGRESS immediately and stays Connecting.
// - Otherwise blocks until accept() on the other side moves this socket out of Connecting.
//
// Errors:
//   EINVAL for a null `addr`, an out-of-range `addrlen`, or a socket that is not Inactive.
//   EAFNOSUPPORT if the family is not AF_UNIX.
//   EISCONN if the socket is already connected.
//   EALREADY if a connect is already in progress.
//   ENOTSOCK if the path is not a socket.
//   EACCES if the caller cannot write the socket node.
//   ECONNREFUSED if nobody is listening or the backlog is full.
//   EINTR if interrupted by a signal that is not ignored.
Result<void> UnixSocket::connect(Registers* regs, int flags, struct sockaddr* addr, socklen_t addrlen)
{
    if (!addr) return err(EINVAL);
    // addrlen must at least cover the address family. Otherwise reading sa_family is out of bounds, and the path
    // length computed below (addrlen - sizeof(sa_family_t)) would underflow to a huge value.
    if ((usize)addrlen < sizeof(sa_family_t)) return err(EINVAL);
    if (addr->sa_family != AF_UNIX) return err(EAFNOSUPPORT);
    if ((usize)addrlen > sizeof(sockaddr_un)) return err(EINVAL);

    if (m_state == State::Connected) return err(EISCONN);
    if (m_state == State::Connecting) return err(EALREADY);
    if (m_state != State::Inactive) return err(EINVAL);

    struct sockaddr_un* un_address = (struct sockaddr_un*)addr;

    String path = TRY(String::from_string_view(
        StringView::from_fixed_size_cstring(un_address->sun_path, addrlen - sizeof(sa_family_t))));

    auto* current = Scheduler::current();

    auto inode = TRY(VFS::resolve_path(path.chars(), current->auth, current->current_directory));
    if (inode->type() != VFS::InodeType::Socket)
        return err(ENOTSOCK); // FIXME: POSIX doesn't say what error to return here?
    if (!VFS::can_write(inode, current->auth)) return err(EACCES);

    auto socket = (SharedPtr<UnixSocket>)inode;
    if (socket->m_state != State::Listening) return err(ECONNREFUSED);

    // Enter Connecting *before* becoming visible in the listen queue. That way an accept() that marks us Connected can
    // never have its update overwritten by us afterwards.
    m_state = State::Connecting;
    if (!socket->m_listen_queue.try_push(this))
    {
        m_state = State::Inactive;
        return err(ECONNREFUSED);
    }
    if (socket->m_blocked_thread) socket->m_blocked_thread->wake_up();

    if (flags & O_NONBLOCK) return err(EINPROGRESS);

    // accept() on the listening side moves us out of Connecting. It may already have done so by now, in which case we
    // don't wait at all. A wakeup that leaves us still Connecting simply waits again.
    while (m_state == State::Connecting)
    {
        m_blocked_thread = current;
        kernel_wait_for_event();
        m_blocked_thread = nullptr;
        if (current->interrupted)
        {
            if (current->will_ignore_pending_signal())
            {
                current->process_pending_signals(regs);
                continue;
            }
            return err(EINTR);
        }
    }

    // The accepted server-side socket may already have been closed before we got to run. In that case did_close()
    // moved us to Reset and cleared m_peer. The connection was still established, so report success; later
    // send()/recv() calls report ECONNRESET.
    if (m_state == State::Reset) return {};

    check(m_state == State::Connected);
    check(m_peer);

    return {};
}

// Marks a Bound socket as Listening and sizes its accept queue to `backlog` (negative values are treated as 0).
//
// Errors:
//   EINVAL if the socket is already Listening or Connected.
//   EDESTADDRREQ if it is in any other non-Bound state.
//   Any error from resizing the queue.
// NOTE(review): with backlog == 0 the queue may have no capacity, so every connect() would be refused.
// Confirm against the semantics of m_listen_queue.set_size().
Result<void> UnixSocket::listen(int backlog)
{
    switch (m_state)
    {
    case State::Bound: break;
    case State::Listening:
    case State::Connected: return err(EINVAL);
    default: return err(EDESTADDRREQ);
    }

    TRY(m_listen_queue.set_size(backlog < 0 ? 0 : backlog));
    m_state = State::Listening;

    return {};
}

// Accepts the next pending connection on this listening socket.
//
// Waits until a connecting socket is queued. With O_NONBLOCK in `flags` it fails with EAGAIN instead of waiting.
// It then:
//   - wires a new server-side UnixSocket to that peer,
//   - wraps the new socket in an O_RDWR OpenFileDescription,
//   - wakes the thread blocked in the peer's connect() (if any).
// On success *addr and *addrlen are pointed at the peer's stored address (its m_addr/m_addrlen).
//
// Errors:
//   EINVAL if the socket is not listening.
//   EAGAIN if non-blocking and nothing is queued.
//   EINTR if interrupted by a signal that is not ignored.
//   Any allocation failure.
Result<SharedPtr<OpenFileDescription>> UnixSocket::accept(Registers* regs, int flags, struct sockaddr** addr,
                                                          socklen_t* addrlen)
{
    if (m_state != State::Listening) return err(EINVAL);

    auto* current = Scheduler::current();

    // Allocate everything that can fail *before* dequeuing a peer. Once a peer has been popped from the listen queue,
    // a failed allocation would leave it stuck in Connecting forever, with nobody left to wake its connect().
    auto socket = TRY(make_shared<UnixSocket>());
    auto description = TRY(make_shared<OpenFileDescription>(socket, O_RDWR));

    // NOTE(review): m_listen_queue stores raw UnixSocket pointers. Confirm elsewhere that a connecting socket closed
    // before being accepted cannot leave a dangling entry here.
    UnixSocket* peer = nullptr;
    while (!m_listen_queue.try_pop(peer))
    {
        if (flags & O_NONBLOCK) return err(EAGAIN);
        m_blocked_thread = current;
        kernel_wait_for_event();
        m_blocked_thread = nullptr;
        if (current->interrupted)
        {
            if (current->will_ignore_pending_signal())
            {
                current->process_pending_signals(regs);
                continue;
            }
            return err(EINTR);
        }
    }

    check(peer);

    // Wire both ends together (equivalent to constructing the server socket with UnixSocket(peer)), then release the
    // thread waiting in connect(), if any.
    socket->connect_to_peer(peer);
    peer->m_peer = socket.ptr();
    peer->m_state = State::Connected;

    if (peer->m_blocked_thread) peer->m_blocked_thread->wake_up();

    *addr = (struct sockaddr*)&peer->m_addr;
    *addrlen = peer->m_addrlen;

    return description;
}