From c779ef84ef464e80fecc65369f55589d17246464 Mon Sep 17 00:00:00 2001 From: apio Date: Thu, 3 Aug 2023 16:39:30 +0200 Subject: [PATCH] libos: Add a new LocalServer class for local domain sockets --- libos/CMakeLists.txt | 1 + libos/include/os/LocalServer.h | 132 +++++++++++++++++++++++++++++++++ libos/src/LocalServer.cpp | 87 ++++++++++++++++++++++ 3 files changed, 220 insertions(+) create mode 100644 libos/include/os/LocalServer.h create mode 100644 libos/src/LocalServer.cpp diff --git a/libos/CMakeLists.txt b/libos/CMakeLists.txt index 0771eb28..b2c532eb 100644 --- a/libos/CMakeLists.txt +++ b/libos/CMakeLists.txt @@ -13,6 +13,7 @@ set(SOURCES src/Path.cpp src/Mode.cpp src/Prompt.cpp + src/LocalServer.cpp ) add_library(os ${SOURCES}) diff --git a/libos/include/os/LocalServer.h b/libos/include/os/LocalServer.h new file mode 100644 index 00000000..9a1c9a7e --- /dev/null +++ b/libos/include/os/LocalServer.h @@ -0,0 +1,132 @@ +#pragma once +#include +#include +#include + +namespace os +{ + /** + * @brief A local domain server, used to communicate between processes on the same machine. + */ + class LocalServer : public Shareable + { + public: + /** + * @brief Create a new server object and bind it to a local address. + * + * @param path The path to use for the server socket. + * @param blocking Whether the server should block if no connections are available when calling accept(). + * @return Result> An error, or a new server object. + */ + static Result> create(StringView path, bool blocking); + + /** + * @brief Activate the server and start listening for connections. + * + * @param backlog The number of unaccepted connections to keep. + * @return Result Whether the operation succeded. + */ + Result listen(int backlog); + + /** + * @brief Return the underlying socket file descriptor used by this object. + * + * @return int The file descriptor. + */ + int fd() const + { + return m_fd; + } + + /** + * @brief An interface to communicate with clients connected to a local server. + */ + class Client : public Shareable + { + public: + /** + * @brief Read arbitrary data from the client. The call will block if there is no data and the parent server + * object has not been created as non-blocking. + * + * @param buf The buffer to read data into. + * @param length The maximum amount of bytes to read. + * @return Result An error, or the number of bytes read. + */ + Result recv(u8* buf, usize length); + + /** + * @brief Read an object from the client. The call will block if there is no data and the parent server + * object has not been created as non-blocking. + * + * @tparam T The type of the object. + * @param out A reference to the object to read data into. + * @return Result Whether the operation succeded. + */ + template Result recv_typed(T& out) + { + TRY(recv((u8*)&out, sizeof(T))); + return {}; + } + + /** + * @brief Send arbitrary data to the client. + * + * @param buf The buffer to send data from. + * @param length The amount of bytes to send. + * @return Result An error, or the number of bytes actually sent. + */ + Result send(const u8* buf, usize length); + + /** + * @brief Send an object to the client. + * + * @tparam T The type of the object. + * @param out A reference to the object to send data from. + * @return Result Whether the operation succeded. + */ + template Result send_typed(const T& out) + { + TRY(send((const u8*)&out, sizeof(T))); + return {}; + } + + /** + * @brief Disconnect from the attached client. + * + * This will make any further reads on the client return ECONNRESET, and will make this object invalid. + */ + void disconnect(); + + /** + * @brief Return the underlying socket file descriptor used by this object. + * + * @return int The file descriptor. + */ + int fd() const + { + return m_fd; + } + + Client(int fd); + ~Client(); + + private: + int m_fd; + }; + + /** + * @brief Accept a new incoming connection and return a handle to it. If there are no incoming connections, + * accept() either blocks until there is one (if the object was created with blocking=true), or returns EAGAIN + * (if the object was created with blocking=false). + * + * @return Result> An error, or a handle to the new connection. + */ + Result> accept(); + + ~LocalServer(); + + private: + int m_fd; + bool m_blocking; + }; +} diff --git a/libos/src/LocalServer.cpp b/libos/src/LocalServer.cpp new file mode 100644 index 00000000..fba62c57 --- /dev/null +++ b/libos/src/LocalServer.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace os +{ + Result> LocalServer::create(StringView path, bool blocking) + { + auto server = TRY(make_shared()); + + (void)os::FileSystem::remove(path); // We explicitly ignore any error here, either it doesn't exist (which is + // fine), or it cannot be removed, which will make bind() fail later. + + int sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sockfd < 0) return err(errno); + + struct sockaddr_un un; + un.sun_family = AF_UNIX; + strncpy(un.sun_path, path.chars(), sizeof(un.sun_path)); + + if (bind(sockfd, (struct sockaddr*)&un, sizeof(un)) < 0) + { + close(sockfd); + return err(errno); + } + + if (!blocking) { fcntl(sockfd, F_SETFL, O_NONBLOCK); } + server->m_blocking = blocking; + + fcntl(sockfd, F_SETFD, FD_CLOEXEC); + + server->m_fd = sockfd; + return server; + } + + Result LocalServer::listen(int backlog) + { + if (::listen(m_fd, backlog) < 0) return err(errno); + return {}; + } + + Result> LocalServer::accept() + { + int fd = ::accept(m_fd, nullptr, nullptr); + if (fd < 0) return err(errno); + if (!m_blocking) fcntl(fd, F_SETFL, O_NONBLOCK); + return make_shared(fd); + } + + LocalServer::~LocalServer() + { + close(m_fd); + } + + LocalServer::Client::Client(int fd) : m_fd(fd) + { + } + + LocalServer::Client::~Client() + { + if (m_fd >= 0) close(m_fd); + } + + Result LocalServer::Client::recv(u8* buf, usize length) + { + ssize_t nread = read(m_fd, buf, length); + if (nread < 0) return err(errno); + return nread; + } + + Result LocalServer::Client::send(const u8* buf, usize length) + { + ssize_t nwrite = write(m_fd, buf, length); + if (nwrite < 0) return err(errno); + return nwrite; + } + + void LocalServer::Client::disconnect() + { + close(m_fd); + m_fd = -1; + } +}