From be5903c0c88a7ec3ae17073d8007a0d0bf6a88c2 Mon Sep 17 00:00:00 2001 From: apio Date: Fri, 28 Jul 2023 17:30:20 +0200 Subject: [PATCH] kernel: Implement listen(), connect() and accept() --- kernel/src/net/Socket.h | 10 +++- kernel/src/net/UnixSocket.cpp | 102 +++++++++++++++++++++++++++++++-- kernel/src/net/UnixSocket.h | 15 ++++- kernel/src/sys/socket.cpp | 80 ++++++++++++++++++++++++++ libluna/include/luna/Syscall.h | 2 +- 5 files changed, 199 insertions(+), 10 deletions(-) diff --git a/kernel/src/net/Socket.h b/kernel/src/net/Socket.h index 766a7b58..fa7eb101 100644 --- a/kernel/src/net/Socket.h +++ b/kernel/src/net/Socket.h @@ -1,5 +1,7 @@ #pragma once +#include "arch/CPU.h" #include "fs/VFS.h" +#include "thread/Thread.h" #include class Socket : public VFS::FileInode @@ -47,7 +49,11 @@ class Socket : public VFS::FileInode virtual Result recv(u8*, usize, int) const = 0; virtual Result bind(SharedPtr, struct sockaddr*, socklen_t) = 0; - virtual Result connect(struct sockaddr*, socklen_t) = 0; + virtual Result connect(Registers*, int, struct sockaddr*, socklen_t) = 0; + + virtual Result> accept(Registers*, int, struct sockaddr**, socklen_t*) = 0; + + virtual Result listen(int backlog) = 0; Result truncate(usize) override { @@ -105,7 +111,7 @@ class Socket : public VFS::FileInode virtual ~Socket() = default; protected: - VFS::FileSystem* m_fs; + VFS::FileSystem* m_fs { nullptr }; usize m_inode_number { 0 }; mode_t m_mode; u32 m_uid { 0 }; diff --git a/kernel/src/net/UnixSocket.cpp b/kernel/src/net/UnixSocket.cpp index e389d50e..8530c3d3 100644 --- a/kernel/src/net/UnixSocket.cpp +++ b/kernel/src/net/UnixSocket.cpp @@ -1,4 +1,5 @@ #include "net/UnixSocket.h" +#include #include #include @@ -45,8 +46,8 @@ Result UnixSocket::send(const u8* buf, usize length, int) Result UnixSocket::recv(u8* buf, usize length, int) const { - if (m_state == State::Reset) return err(ECONNRESET); - if (m_state != State::Connected) return err(ENOTCONN); + if (m_state == State::Reset && !m_data.size()) return err(ECONNRESET); + if (m_state != State::Connected && m_state != State::Reset) return err(ENOTCONN); return m_data.dequeue_data(buf, length); } @@ -104,7 +105,100 @@ Result UnixSocket::bind(SharedPtr socket, struct sockaddr* addr, s return {}; } -Result UnixSocket::connect(struct sockaddr*, socklen_t) +Result UnixSocket::connect(Registers* regs, int flags, struct sockaddr* addr, socklen_t addrlen) { - return err(ENOSYS); + if (!addr) return err(EINVAL); + if (addr->sa_family != AF_UNIX) return err(EAFNOSUPPORT); + if ((usize)addrlen > sizeof(sockaddr_un)) return err(EINVAL); + + if (m_state == State::Connected) return err(EISCONN); + if (m_state == State::Connecting) return err(EALREADY); + if (m_state != State::Inactive) return err(EINVAL); + + struct sockaddr_un* un_address = (struct sockaddr_un*)addr; + + String path = TRY(String::from_string_view( + StringView::from_fixed_size_cstring(un_address->sun_path, addrlen - sizeof(sa_family_t)))); + + auto* current = Scheduler::current(); + + auto inode = TRY(VFS::resolve_path(path.chars(), current->auth, current->current_directory)); + if (inode->type() != VFS::InodeType::Socket) + return err(ENOTSOCK); // FIXME: POSIX doesn't say what error to return here? + if (!VFS::can_write(inode, current->auth)) return err(EACCES); + + auto socket = (SharedPtr)inode; + if (socket->m_state != State::Listening) return err(ECONNREFUSED); + if (!socket->m_listen_queue.try_push(this)) return err(ECONNREFUSED); + if (socket->m_blocked_thread) socket->m_blocked_thread->wake_up(); + + m_state = Connecting; + if (flags & O_NONBLOCK) return err(EINPROGRESS); + + while (1) + { + m_blocked_thread = current; + kernel_wait_for_event(); + m_blocked_thread = nullptr; + if (current->interrupted) + { + if (current->will_invoke_signal_handler()) return err(EINTR); + current->process_pending_signals(regs); + continue; + } + break; + } + + check(m_state == Connected); + check(m_peer); + + return {}; +} + +Result UnixSocket::listen(int backlog) +{ + if (backlog < 0) backlog = 0; + if (m_state == State::Listening || m_state == State::Connected) return err(EINVAL); + if (m_state != State::Bound) return err(EDESTADDRREQ); + TRY(m_listen_queue.set_size(backlog)); + m_state = State::Listening; + return {}; +} + +Result> UnixSocket::accept(Registers* regs, int flags, struct sockaddr** addr, + socklen_t* addrlen) +{ + if (m_state != State::Listening) return err(EINVAL); + + auto* current = Scheduler::current(); + + UnixSocket* peer = nullptr; + while (!m_listen_queue.try_pop(peer)) + { + if (flags & O_NONBLOCK) return err(EAGAIN); + m_blocked_thread = current; + kernel_wait_for_event(); + m_blocked_thread = nullptr; + if (current->interrupted) + { + if (current->will_invoke_signal_handler()) return err(EINTR); + current->process_pending_signals(regs); + continue; + } + } + + check(peer); + + auto socket = TRY(make_shared(peer)); + auto description = TRY(make_shared(socket, O_RDWR)); + + peer->m_peer = socket.ptr(); + peer->m_state = State::Connected; + + if (peer->m_blocked_thread) peer->m_blocked_thread->wake_up(); + + *addr = (struct sockaddr*)&peer->m_addr; + *addrlen = peer->m_addrlen; + + return description; } diff --git a/kernel/src/net/UnixSocket.h b/kernel/src/net/UnixSocket.h index 0f376cbf..88c8150e 100644 --- a/kernel/src/net/UnixSocket.h +++ b/kernel/src/net/UnixSocket.h @@ -1,6 +1,8 @@ #pragma once #include "net/Socket.h" +#include "thread/Thread.h" #include +#include #include #include @@ -12,14 +14,17 @@ class UnixSocket : public Socket bool blocking() const override { - return !m_data.size(); + return (m_state == Connected || m_state == Reset) && !m_data.size(); } Result send(const u8*, usize, int) override; Result recv(u8*, usize, int) const override; Result bind(SharedPtr, struct sockaddr*, socklen_t) override; - Result connect(struct sockaddr*, socklen_t) override; + Result connect(Registers*, int, struct sockaddr*, socklen_t) override; + Result> accept(Registers*, int, struct sockaddr**, socklen_t*) override; + + Result listen(int backlog) override; void did_close() override; @@ -39,10 +44,14 @@ class UnixSocket : public Socket }; State m_state = State::Inactive; - UnixSocket* m_peer; + UnixSocket* m_peer = nullptr; mutable Buffer m_data; + Thread* m_blocked_thread { nullptr }; + + DynamicCircularQueue m_listen_queue; + struct sockaddr_un m_addr = { .sun_family = AF_UNIX, .sun_path = {} }; socklen_t m_addrlen = sizeof(sa_family_t); }; diff --git a/kernel/src/sys/socket.cpp b/kernel/src/sys/socket.cpp index d6d3ade7..eed20db2 100644 --- a/kernel/src/sys/socket.cpp +++ b/kernel/src/sys/socket.cpp @@ -47,3 +47,83 @@ Result sys_bind(Registers*, SyscallArgs args) return 0; } + +Result sys_connect(Registers* regs, SyscallArgs args) +{ + int sockfd = (int)args[0]; + struct sockaddr* addr = (struct sockaddr*)args[1]; + socklen_t addrlen = (socklen_t)args[2]; + + struct sockaddr_storage storage; + if ((usize)addrlen > sizeof(storage)) return err(EINVAL); + if (!MemoryManager::copy_from_user(addr, &storage, addrlen)) return err(EFAULT); + + auto* current = Scheduler::current(); + + auto description = TRY(current->resolve_fd(sockfd))->description; + + if (description->inode->type() != VFS::InodeType::Socket) return err(ENOTSOCK); + + auto socket = (SharedPtr)description->inode; + + TRY(socket->connect(regs, description->flags, (struct sockaddr*)&storage, addrlen)); + + return 0; +} + +Result sys_listen(Registers*, SyscallArgs args) +{ + int sockfd = (int)args[0]; + int backlog = (int)args[1]; + + auto* current = Scheduler::current(); + + auto inode = TRY(current->resolve_fd(sockfd))->inode(); + + if (inode->type() != VFS::InodeType::Socket) return err(ENOTSOCK); + + auto socket = (SharedPtr)inode; + + TRY(socket->listen(backlog)); + + return 0; +} + +Result sys_accept(Registers* regs, SyscallArgs args) +{ + int sockfd = (int)args[0]; + struct sockaddr* addr = (struct sockaddr*)args[1]; + socklen_t* addrlen = (socklen_t*)args[2]; + + if (addr && !addrlen) return err(EINVAL); + + socklen_t len; + if (addr) + { + if (!MemoryManager::copy_from_user_typed(addrlen, &len)) return err(EFAULT); + } + + auto* current = Scheduler::current(); + + auto description = TRY(current->resolve_fd(sockfd))->description; + + if (description->inode->type() != VFS::InodeType::Socket) return err(ENOTSOCK); + + auto socket = (SharedPtr)description->inode; + + struct sockaddr* client; + socklen_t client_len; + auto new_description = TRY(socket->accept(regs, description->flags, &client, &client_len)); + + int fd = TRY(current->allocate_fd(0)); + current->fd_table[fd] = FileDescriptor { new_description, 0 }; + + if (client_len < len) len = client_len; + if (addr) + { + MemoryManager::copy_to_user(addr, client, len); + MemoryManager::copy_to_user_typed(addrlen, &client_len); + } + + return fd; +} diff --git a/libluna/include/luna/Syscall.h b/libluna/include/luna/Syscall.h index f82a9c72..b663a02d 100644 --- a/libluna/include/luna/Syscall.h +++ b/libluna/include/luna/Syscall.h @@ -7,7 +7,7 @@ _e(fstatat) _e(chdir) _e(getcwd) _e(unlinkat) _e(uname) _e(sethostname) _e(dup2) _e(pipe) _e(mount) \ _e(umount) _e(pstat) _e(getrusage) _e(symlinkat) _e(readlinkat) _e(umask) _e(linkat) _e(faccessat) \ _e(pivot_root) _e(sigreturn) _e(sigaction) _e(kill) _e(sigprocmask) _e(setpgid) _e(isatty) \ - _e(getpgid) _e(socket) _e(bind) + _e(getpgid) _e(socket) _e(bind) _e(connect) _e(listen) _e(accept) enum Syscalls {