diff --git a/kernel/src/net/Socket.h b/kernel/src/net/Socket.h index 1ac16c92..5b980958 100644 --- a/kernel/src/net/Socket.h +++ b/kernel/src/net/Socket.h @@ -65,6 +65,10 @@ class Socket : public VFS::FileInode m_metadata.nlinks--; } + virtual bool can_accept_connections() const = 0; + + virtual bool can_read_data() const = 0; + virtual ~Socket() = default; protected: diff --git a/kernel/src/net/UnixSocket.h b/kernel/src/net/UnixSocket.h index 69b23cfa..9ec88000 100644 --- a/kernel/src/net/UnixSocket.h +++ b/kernel/src/net/UnixSocket.h @@ -17,6 +17,16 @@ class UnixSocket : public Socket return (m_state == Connected || m_state == Reset) && !m_data.size(); } + bool can_read_data() const override + { + return (m_state == Connected || m_state == Reset) && m_data.size(); + } + + bool can_accept_connections() const override + { + return !m_listen_queue.is_empty(); + } + Result send(const u8*, usize, int) override; Result recv(u8*, usize, int) const override; diff --git a/kernel/src/sys/poll.cpp b/kernel/src/sys/poll.cpp index 42d99fff..bb05710e 100644 --- a/kernel/src/sys/poll.cpp +++ b/kernel/src/sys/poll.cpp @@ -1,6 +1,7 @@ #include "Log.h" #include "fs/VFS.h" #include "memory/MemoryManager.h" +#include "net/Socket.h" #include "sys/Syscall.h" #include "thread/Scheduler.h" #include @@ -43,10 +44,25 @@ Result sys_poll(Registers* regs, SyscallArgs args) auto& inode = inodes[i]; if (!inode) continue; - if (kfds[i].events & POLLIN && !inode->will_block_if_read()) + if (kfds[i].events & POLLIN) { - fds_with_events++; - kfds[i].revents |= POLLIN; + if (inode->type() == VFS::InodeType::Socket) + { + auto socket = (Socket*)inode.ptr(); + if (socket->can_read_data() || socket->can_accept_connections()) + { + fds_with_events++; + kfds[i].revents |= POLLIN; + } + } + else + { + if (!inode->will_block_if_read()) + { + fds_with_events++; + kfds[i].revents |= POLLIN; + } + } } } diff --git a/libluna/include/luna/CircularQueue.h b/libluna/include/luna/CircularQueue.h index 50317feb..0bee00fd 100644 --- a/libluna/include/luna/CircularQueue.h +++ b/libluna/include/luna/CircularQueue.h @@ -16,7 +16,7 @@ template class CircularQueue { } - bool is_empty() + bool is_empty() const { return m_tail.load() == m_head.load(); } @@ -76,7 +76,7 @@ template class DynamicCircularQueue if (m_data) free_impl(m_data); } - bool is_empty() + bool is_empty() const { return m_tail.load() == m_head.load(); }