diff --git a/libos/CMakeLists.txt b/libos/CMakeLists.txt index db5748a8..3bee1165 100644 --- a/libos/CMakeLists.txt +++ b/libos/CMakeLists.txt @@ -15,6 +15,7 @@ set(SOURCES src/Prompt.cpp src/Security.cpp src/LocalServer.cpp + src/LocalClient.cpp ) add_library(os ${SOURCES}) diff --git a/libos/include/os/LocalClient.h b/libos/include/os/LocalClient.h new file mode 100644 index 00000000..46314053 --- /dev/null +++ b/libos/include/os/LocalClient.h @@ -0,0 +1,90 @@ +#pragma once +#include +#include + +namespace os +{ + /** + * @brief A client used to connect to a local server socket. + */ + class LocalClient : public Shareable + { + public: + /** + * @brief Create a new client object and connect it to a local server. + * + * @param path The path of the server socket to connect to. + * @param blocking Whether the client should block if no data is available and recv() is called. + * @return Result> An error, or a new client object. + */ + static Result> connect(StringView path, bool blocking); + + /** + * @brief Return the underlying socket file descriptor used by this object. + * + * @return int The file descriptor. + */ + int fd() const + { + return m_fd; + } + + /** + * @brief Read arbitrary data from the server. The call will block if there is no data and this object has not + * been created as non-blocking. + * + * @param buf The buffer to read data into. + * @param length The maximum amount of bytes to read. + * @return Result An error, or the number of bytes read. + */ + Result recv(u8* buf, usize length); + + /** + * @brief Read an object from the server. The call will block if there is no data and this object has not been + * created as non-blocking. + * + * @tparam T The type of the object. + * @param out A reference to the object to read data into. + * @return Result Whether the operation succeded. + */ + template Result recv_typed(T& out) + { + TRY(recv((u8*)&out, sizeof(T))); + return {}; + } + + /** + * @brief Send arbitrary data to the server. + * + * @param buf The buffer to send data from. + * @param length The amount of bytes to send. + * @return Result An error, or the number of bytes actually sent. + */ + Result send(const u8* buf, usize length); + + /** + * @brief Send an object to the server. + * + * @tparam T The type of the object. + * @param out A reference to the object to send data from. + * @return Result Whether the operation succeded. + */ + template Result send_typed(const T& out) + { + TRY(send((const u8*)&out, sizeof(T))); + return {}; + } + + /** + * @brief Disconnect from the attached server. + * + * This will make any further reads on this connection return ECONNRESET, and will make this object invalid. + */ + void disconnect(); + + ~LocalClient(); + + private: + int m_fd; + }; +} diff --git a/libos/src/LocalClient.cpp b/libos/src/LocalClient.cpp new file mode 100644 index 00000000..42d61a9c --- /dev/null +++ b/libos/src/LocalClient.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include + +namespace os +{ + Result> LocalClient::connect(StringView path, bool blocking) + { + auto client = TRY(make_shared()); + + int sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sockfd < 0) return err(errno); + + struct sockaddr_un un; + un.sun_family = AF_UNIX; + strncpy(un.sun_path, path.chars(), sizeof(un.sun_path)); + + if (::connect(sockfd, (struct sockaddr*)&un, sizeof(un)) < 0) + { + close(sockfd); + return err(errno); + } + + if (!blocking) { fcntl(sockfd, F_SETFL, O_NONBLOCK); } + + fcntl(sockfd, F_SETFD, FD_CLOEXEC); + + client->m_fd = sockfd; + return client; + } + + LocalClient::~LocalClient() + { + close(m_fd); + } + + Result LocalClient::recv(u8* buf, usize length) + { + ssize_t nread = read(m_fd, buf, length); + if (nread < 0) return err(errno); + return nread; + } + + Result LocalClient::send(const u8* buf, usize length) + { + ssize_t nwrite = write(m_fd, buf, length); + if (nwrite < 0) return err(errno); + return nwrite; + } + + void LocalClient::disconnect() + { + close(m_fd); + m_fd = -1; + } +}