diff --git a/include/nyaio/io.hpp b/include/nyaio/io.hpp index a12cfd1..ef8686d 100644 --- a/include/nyaio/io.hpp +++ b/include/nyaio/io.hpp @@ -754,6 +754,15 @@ class TcpStream { /// Create an empty @c TcpStream. Empty @c TcpStream object cannot be used for IO operations. TcpStream() noexcept : m_socket(-1), m_address() {} + /// @brief + /// For internal usage. Wrap a raw socket and address into a @c TcpStream object. + /// @param socket + /// Raw socket file descriptor. + /// @param address + /// Remote address of this TCP stream. + TcpStream(int socket, const InetAddress &address) noexcept + : m_socket(socket), m_address(address) {} + /// @brief /// @c TcpStream is not copyable. TcpStream(const Self &other) = delete; @@ -1011,4 +1020,201 @@ class TcpStream { InetAddress m_address; }; +/// @class TcpServer +/// @brief +/// Wrapper class for TCP server. This class is used for accepting incoming TCP connections. +class TcpServer { +public: + using Self = TcpServer; + + /// @struct AcceptResult + /// @brief + /// Result for accept operation. This struct contains the accepted TCP stream and error code. + /// The TCP stream is valid only if the error code is 0. + struct AcceptResult { + TcpStream stream; + std::errc error; + }; + + /// @class AcceptAwaiter + /// @brief + /// Customized accept awaitable for @c TcpServer. Using awaitable to avoid memory allocation + /// by @c Task. + class AcceptAwaiter { + public: + using Self = AcceptAwaiter; + + /// @brief + /// Create a new @c AcceptAwaiter for accepting a new TCP connection. + /// @param server + /// The server TCP socket. + AcceptAwaiter(int server) noexcept + : m_promise(), m_socket(server), m_addrLen(sizeof(InetAddress)), m_address() {} + + /// @brief + /// C++20 coroutine API method. Always execute @c await_suspend(). + /// @return + /// This function always returns @c false. + [[nodiscard]] + static constexpr auto await_ready() noexcept -> bool { + return false; + } + + /// @brief + /// Prepare for async accept operation and suspend the coroutine. + /// @tparam T + /// Type of promise of current coroutine. + /// @param coro + /// Current coroutine handle. + template + requires(std::is_base_of_v) + auto await_suspend(std::coroutine_handle coro) noexcept -> void { + auto &promise = static_cast(coro.promise()); + m_promise = &promise; + auto *worker = promise.worker; + + io_uring_sqe *sqe = worker->pollSubmissionQueueEntry(); + while (sqe == nullptr) [[unlikely]] { + worker->submit(); + sqe = worker->pollSubmissionQueueEntry(); + } + + sqe->opcode = IORING_OP_ACCEPT; + sqe->fd = m_socket; + sqe->addr = reinterpret_cast(&m_address); + sqe->off = reinterpret_cast(&m_addrLen); + sqe->accept_flags = SOCK_CLOEXEC; + sqe->user_data = reinterpret_cast(&promise); + + worker->flushSubmissionQueue(); + } + + /// @brief + /// Resume this coroutine and get result of the async connect operation. + /// @return + /// A struct of accepted TCP stream and error code. The error code is 0 if succeeded to + /// accept a new connection. + [[nodiscard]] + auto await_resume() noexcept -> AcceptResult { + if (m_promise->ioResult < 0) [[unlikely]] + return {{}, std::errc{-m_promise->ioResult}}; + return {{m_promise->ioResult, m_address}, std::errc{}}; + } + + private: + detail::PromiseBase *m_promise; + int m_socket; + socklen_t m_addrLen; + InetAddress m_address; + }; + +public: + /// @brief + /// Create an empty @c TcpServer. Empty @c TcpServer object cannot be used for accepting new + /// TCP connections before binding. + TcpServer() noexcept : m_socket(-1), m_address() {} + + /// @brief + /// @c TcpServer is not copyable. + TcpServer(const Self &other) = delete; + + /// @brief + /// Move constructor of @c TcpServer. + /// @param[in, out] other + /// The @c TcpServer object to be moved from. The moved object will be empty and can not be + /// used for accepting new TCP connections. + TcpServer(Self &&other) noexcept : m_socket(other.m_socket), m_address(other.m_address) { + other.m_socket = -1; + } + + /// @brief + /// Stop listening to incoming TCP connections and destroy this object. + ~TcpServer() { + if (m_socket != -1) + ::close(m_socket); + } + + /// @brief + /// @c TcpServer is not copyable. + auto operator=(const Self &other) = delete; + + /// @brief + /// Move assignment of @c TcpServer. + /// @param[in, out] other + /// The @c TcpServer object to be moved from. The moved object will be empty and can not be + /// used for accepting new TCP connections. + /// @return + /// Reference to this @c TcpServer. + auto operator=(Self &&other) noexcept -> Self & { + if (this == &other) [[unlikely]] + return *this; + + if (m_socket != -1) + ::close(m_socket); + + m_socket = other.m_socket; + m_address = other.m_address; + other.m_socket = -1; + + return *this; + } + + /// @brief + /// Get local listening address. Get local listening address from an empty @c TcpServer is + /// undefined behavior. + /// @return + /// Local listening address of this TCP server. The return value is undefined if this TCP + /// server is empty. + [[nodiscard]] + auto address() const noexcept -> const InetAddress & { + return m_address; + } + + /// @brief + /// Start listening to incoming TCP connections. This method will create a new TCP server + /// socket and bind to the specified address. The old TCP server socket will be closed if + /// succeeded to listen to the new address. + /// @param address + /// The local address to bind for listening incoming TCP connections. + /// @return + /// A system error code that indicates whether succeeded to start listening to incoming TCP + /// connections. The error code is 0 if succeeded to start listening. The original TCP server + /// socket will not be affected if any error occurs. + NYAIO_API auto bind(const InetAddress &address) noexcept -> std::errc; + + /// @brief + /// Accept a new incoming TCP connection. This method will block current thread until a new + /// TCP connection is established or any error occurs. + /// @return + /// A struct of accepted TCP stream and error code. The error code is 0 if succeeded to accept + /// a new connection. + NYAIO_API auto accept() noexcept -> AcceptResult; + + /// @brief + /// Async accept a new incoming TCP connection. This method will suspend this coroutine until + /// a new TCP connection is established or any error occurs. + /// @return + /// A struct of accepted TCP stream and error code. The error code is 0 if succeeded to accept + /// a new connection. + [[nodiscard]] + auto acceptAsync() noexcept -> AcceptAwaiter { + return {m_socket}; + } + + /// @brief + /// Stop listening to incoming TCP connections and close this TCP server. This TCP server will + /// be set to empty after this call. Call @c TcpServer::bind() to start listening to incoming + /// again. + auto close() noexcept -> void { + if (m_socket != -1) { + ::close(m_socket); + m_socket = -1; + } + } + +private: + int m_socket; + InetAddress m_address; +}; + } // namespace nyaio diff --git a/src/nyaio/io.cpp b/src/nyaio/io.cpp index 441612c..29c8380 100644 --- a/src/nyaio/io.cpp +++ b/src/nyaio/io.cpp @@ -49,3 +49,48 @@ auto TcpStream::connect(const InetAddress &address) noexcept -> std::errc { return {}; } + +auto TcpServer::bind(const InetAddress &address) noexcept -> std::errc { + auto *addr = reinterpret_cast(&address); + int s = ::socket(addr->sa_family, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP); + if (s == -1) [[unlikely]] + return std::errc{errno}; + + { // Enable reuse address and reuse port. + const int value = 1; + ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &value, sizeof(value)); + ::setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &value, sizeof(value)); + } + + if (::bind(s, addr, address.size()) == -1) [[unlikely]] { + int error = errno; + ::close(s); + return std::errc{error}; + } + + if (::listen(s, SOMAXCONN) == -1) [[unlikely]] { + int error = errno; + ::close(s); + return std::errc{error}; + } + + if (m_socket != -1) + ::close(m_socket); + + m_socket = s; + m_address = address; + + return {}; +} + +auto TcpServer::accept() noexcept -> AcceptResult { + InetAddress address; + socklen_t length = sizeof(address); + + int s = + ::accept4(m_socket, reinterpret_cast(&address), &length, SOCK_CLOEXEC); + if (s == -1) [[unlikely]] + return {TcpStream{}, std::errc{errno}}; + + return {TcpStream{s, address}, std::errc{}}; +}