diff --git a/include/nyaio/io.hpp b/include/nyaio/io.hpp index 4dee161..a12cfd1 100644 --- a/include/nyaio/io.hpp +++ b/include/nyaio/io.hpp @@ -5,6 +5,7 @@ #include #include +#include namespace nyaio { @@ -651,4 +652,363 @@ class InetAddress { } m_addr; }; +/// @class TcpStream +/// @brief +/// Wrapper class for TCP connection. This class is used for TCP socket IO. +class TcpStream { +public: + using Self = TcpStream; + + /// @class ConnectAwaiter + /// @brief + /// Customized connect awaitable for @c TcpStream. Using awaitable to avoid memory allocation + /// by @c Task. + class ConnectAwaiter { + public: + using Self = ConnectAwaiter; + + /// @brief + /// Create a new @c ConnectAwaiter for establishing a TCP connection. + /// @param[in, out] stream + /// The @c TcpStream to establish connection. + /// @param address + /// The peer address to connect. + ConnectAwaiter(TcpStream &stream, const InetAddress &address) noexcept + : m_promise(nullptr), m_socket(-1), m_address(&address), m_stream(&stream) {} + + /// @brief + /// C++20 coroutine API method. Always execute @c await_suspend(). + /// @return + /// This function always returns @c false. + [[nodiscard]] + static constexpr auto await_ready() noexcept -> bool { + return false; + } + + /// @brief + /// Prepare for async connect operation and suspend the coroutine. + /// @tparam T + /// Type of promise of current coroutine. + /// @param coro + /// Current coroutine handle. + template + requires(std::is_base_of_v) + auto await_suspend(std::coroutine_handle coro) noexcept -> bool { + auto &promise = static_cast(coro.promise()); + m_promise = &promise; + + // Create a new socket to establish connection. + auto *addr = reinterpret_cast(m_address); + m_socket = ::socket(addr->sa_family, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP); + if (m_socket == -1) [[unlikely]] { + promise.ioResult = -errno; + return false; + } + + auto *worker = promise.worker; + io_uring_sqe *sqe = worker->pollSubmissionQueueEntry(); + while (sqe == nullptr) [[unlikely]] { + worker->submit(); + sqe = worker->pollSubmissionQueueEntry(); + } + + sqe->opcode = IORING_OP_CONNECT; + sqe->fd = m_socket; + sqe->addr = reinterpret_cast(m_address); + sqe->off = m_address->size(); + sqe->user_data = reinterpret_cast(&promise); + + worker->flushSubmissionQueue(); + return true; + } + + /// @brief + /// Resume this coroutine and get result of the async connect operation. + /// @return + /// Error code of the connect operation. The error code is 0 if succeeded. + auto await_resume() const noexcept -> std::errc { + if (m_promise->ioResult < 0) [[unlikely]] { + if (m_socket != -1) + ::close(m_socket); + return static_cast(-m_promise->ioResult); + } + + if (m_stream->m_socket != -1) + ::close(m_stream->m_socket); + + m_stream->m_socket = m_socket; + m_stream->m_address = *m_address; + + return {}; + } + + private: + detail::PromiseBase *m_promise; + int m_socket; + const InetAddress *m_address; + TcpStream *m_stream; + }; + +public: + /// @brief + /// Create an empty @c TcpStream. Empty @c TcpStream object cannot be used for IO operations. + TcpStream() noexcept : m_socket(-1), m_address() {} + + /// @brief + /// @c TcpStream is not copyable. + TcpStream(const Self &other) = delete; + + /// @brief + /// Move constructor of @c TcpStream. + /// @param[in, out] other + /// The @c TcpStream object to be moved. The moved object will be empty and can not be used + /// for IO operations. + TcpStream(Self &&other) noexcept : m_socket(other.m_socket), m_address(other.m_address) { + other.m_socket = -1; + } + + /// @brief + /// Close the TCP connection and destroy this object. + ~TcpStream() { + if (m_socket != -1) + ::close(m_socket); + } + + /// @brief + /// @c TcpStream is not copyable. + auto operator=(const Self &other) = delete; + + /// @brief + /// Move assignment of @c TcpStream. + /// @param[in, out] other + /// The @c TcpStream object to be moved. The moved object will be empty and can not be used + /// for IO operations. + /// @return + /// Reference to this @c TcpStream. + auto operator=(Self &&other) noexcept -> Self & { + if (this == &other) [[unlikely]] + return *this; + + if (m_socket != -1) + ::close(m_socket); + + m_socket = other.m_socket; + m_address = other.m_address; + other.m_socket = -1; + + return *this; + } + + /// @brief + /// Get remote address of this TCP stream. It is undefined behavior to get remote address from + /// an empty @c TcpStream. + /// @return + /// Remote address of this TCP stream. + [[nodiscard]] + auto remoteAddress() const noexcept -> const InetAddress & { + return m_address; + } + + /// @brief + /// Connect to the specified peer address. This method will block current thread until the + /// connection is established or any error occurs. If this TCP stream is currently not empty, + /// the old connection will be closed once the new connection is established. The old + /// connection will not be closed if the new connection fails. + /// @param address + /// Peer Internet address to connect. + /// @return + /// Error code of the connect operation. The error code is 0 if succeeded. + NYAIO_API auto connect(const InetAddress &address) noexcept -> std::errc; + + /// @brief + /// Connect to the specified peer address. This method will suspend this coroutine until the + /// new connection is established or any error occurs. If this TCP stream is currently not + /// empty, the old connection will be closed once the new connection is established. The old + /// connection will not be closed if the new connection fails. + /// @param address + /// Peer Internet address to connect. + /// @return + /// Error code of the connect operation. The error code is 0 if succeeded. + [[nodiscard]] + auto connectAsync(const InetAddress &address) noexcept -> ConnectAwaiter { + return {*this, address}; + } + + /// @brief + /// Send data to peer TCP endpoint. This method will block current thread until all data is + /// sent or any error occurs. + /// @param data + /// Pointer to start of data to be sent. + /// @param size + /// Expected size in byte of data to be sent. + /// @return + /// A struct that contains number of bytes sent and an error code. The error code is + /// @c std::errc{} if succeeded and the number of bytes sent is valid. + auto send(const void *data, std::uint32_t size) const noexcept -> SystemIoResult { + ssize_t result = ::send(m_socket, data, size, MSG_NOSIGNAL); + if (result == -1) [[unlikely]] + return {0, std::errc{-errno}}; + return {static_cast(result), std::errc{}}; + } + + /// @brief + /// Async send data to peer TCP endpoint. This method will suspend this coroutine until any + /// data is sent or any error occurs. + /// @param data + /// Pointer to start of data to be sent. + /// @param size + /// Expected size in byte of data to be sent. + /// @return + /// A struct that contains number of bytes sent and an error code. The error code is + /// @c std::errc{} if succeeded and the number of bytes sent is valid. + [[nodiscard]] + auto sendAsync(const void *data, std::uint32_t size) const noexcept -> SendAwaitable { + return {m_socket, data, size, MSG_NOSIGNAL}; + } + + /// @brief + /// Receive data from peer TCP endpoint. This method will block current thread until any data + /// is received or error occurs. + /// @param[out] buffer + /// Pointer to start of buffer to store the received data. + /// @param size + /// Maximum available size to be received. + /// @return + /// A struct that contains number of bytes received and an error code. The error code is + /// @c std::errc{} if succeeded and the number of bytes received is valid. + auto receive(void *buffer, std::uint32_t size) noexcept -> SystemIoResult { + ssize_t result = ::recv(m_socket, buffer, size, 0); + if (result == -1) [[unlikely]] + return {0, std::errc{-errno}}; + return {static_cast(result), std::errc{}}; + } + + /// @brief + /// Async receive data from peer TCP endpoint. This method will suspend this coroutine until + /// any data is received or any error occurs. + /// @param[out] buffer + /// Pointer to start of buffer to store the received data. + /// @param size + /// Maximum available size to be received. + /// @return + /// A struct that contains number of bytes received and an error code. The error code is + /// @c std::errc{} if succeeded and the number of bytes received is valid. + [[nodiscard]] + auto receiveAsync(void *buffer, std::uint32_t size) noexcept -> ReceiveAwaitable { + return {m_socket, buffer, size, 0}; + } + + /// @brief + /// Enable or disable keepalive for this TCP connection. + /// @param enable + /// Specifies whether to enable or disable keepalive for this TCP stream. + /// @return + /// An error code that indicates whether succeeded to enable or disable keepalive for this TCP + /// stream. The error code is 0 if succeeded to set keepalive attribute for this TCP stream. + auto setKeepAlive(bool enable) noexcept -> std::errc { + const int value = enable ? 1 : 0; + int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value)); + if (ret == -1) [[unlikely]] + return std::errc{errno}; + return {}; + } + + /// @brief + /// Enable or disable nodelay for this TCP stream. + /// @param enable + /// Specifies whether to enable or disable nodelay for this TCP stream. + /// @return + /// An error code that indicates whether succeeded to enable or disable nodelay for this TCP + /// stream. The error code is 0 if succeeded to set nodelay attribute for this TCP stream. + auto setNoDelay(bool enable) noexcept -> std::errc { + const int value = enable ? 1 : 0; + int ret = ::setsockopt(m_socket, SOL_TCP, TCP_NODELAY, &value, sizeof(value)); + if (ret == -1) [[unlikely]] + return std::errc{errno}; + return {}; + } + + /// @brief + /// Set timeout event for send operation. @c TcpStream::send and @c TcpStream::sendAsync may + /// generate an error that indicates the timeout event if timeout event occurs. The TCP + /// connection may be in an undefined state and should be closed if send timeout event occurs. + /// @tparam Rep + /// Type representation of duration type. See @c std::chrono::duration for details. + /// @tparam Period + /// Ratio type that is used to measure how to do conversion between different duration types. + /// See @c std::chrono::duration for details. + /// @param duration + /// Timeout duration of send operation. Ratios less than microseconds are not allowed. + /// @return + /// An error code that indicates whether succeeded to set the timeout event. The error code is + /// 0 if succeeded to set or remove send timeout event for this TCP connection. + template + requires(std::ratio_less_equal_v) + auto setSendTimeout(std::chrono::duration duration) noexcept -> std::errc { + auto microsec = std::chrono::duration_cast(duration).count(); + if (microsec < 0) [[unlikely]] + return std::errc::invalid_argument; + + const struct timeval timeout{ + .tv_sec = static_cast(microsec / 1000000), + .tv_usec = static_cast(microsec % 1000000), + }; + + int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)); + if (ret == -1) [[unlikely]] + return std::errc{errno}; + + return {}; + } + + /// @brief + /// Set timeout event for receive operation. @c TcpStream::receive and @c + /// TcpStream::receiveAsync may generate an error that indicates the timeout event if timeout + /// event occurs. The TCP connection may be in an undefined state and should be closed if + /// receive timeout event occurs. + /// @tparam Rep + /// Type representation of duration type. See @c std::chrono::duration for details. + /// @tparam Period + /// Ratio type that is used to measure how to do conversion between different duration types. + /// See @c std::chrono::duration for details. + /// @param duration + /// Timeout duration of receive operation. Ratios less than microseconds are not allowed. + /// @return + /// An error code that indicates whether succeeded to set the timeout event. The error code is + /// 0 if succeeded to set or remove receive timeout event for this TCP connection. + template + requires(std::ratio_less_equal_v) + auto setReceiveTimeout(std::chrono::duration duration) noexcept -> std::errc { + auto microsec = std::chrono::duration_cast(duration).count(); + if (microsec < 0) [[unlikely]] + return std::errc::invalid_argument; + + const struct timeval timeout{ + .tv_sec = static_cast(microsec / 1000000), + .tv_usec = static_cast(microsec % 1000000), + }; + + int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); + if (ret == -1) [[unlikely]] + return std::errc{errno}; + + return {}; + } + + /// @brief + /// Close this TCP stream and release all resources. Closing a TCP stream with pending IO + /// requirements may cause errors for the IO results. This method does nothing if current TCP + /// stream is empty. + auto close() noexcept -> void { + if (m_socket != -1) { + ::close(m_socket); + m_socket = -1; + } + } + +private: + int m_socket; + InetAddress m_address; +}; + } // namespace nyaio diff --git a/src/nyaio/io.cpp b/src/nyaio/io.cpp index 783257e..441612c 100644 --- a/src/nyaio/io.cpp +++ b/src/nyaio/io.cpp @@ -27,3 +27,25 @@ IpAddress::IpAddress(std::string_view addr) : m_isV6(), m_addr() { if (ret != 1) throw std::invalid_argument("Invalid IP address format."); } + +auto TcpStream::connect(const InetAddress &address) noexcept -> std::errc { + // Create a new socket to establish connection. + auto *addr = reinterpret_cast(&address); + int s = ::socket(addr->sa_family, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP); + if (s == -1) [[unlikely]] + return static_cast(errno); + + if (::connect(s, addr, address.size()) == -1) [[unlikely]] { + int error = errno; + ::close(s); + return static_cast(error); + } + + if (m_socket != -1) + ::close(m_socket); + + m_socket = s; + m_address = address; + + return {}; +}