Skip to content

Commit

Permalink
Add TcpStream
Browse files Browse the repository at this point in the history
  • Loading branch information
longhao-li committed Sep 10, 2024
1 parent d9454a6 commit 201cf16
Show file tree
Hide file tree
Showing 2 changed files with 382 additions and 0 deletions.
360 changes: 360 additions & 0 deletions include/nyaio/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <bit>

#include <netinet/in.h>
#include <netinet/tcp.h>

namespace nyaio {

Expand Down Expand Up @@ -651,4 +652,363 @@ class InetAddress {
} m_addr;
};

/// @class TcpStream
/// @brief
/// Wrapper class for TCP connection. This class is used for TCP socket IO.
class TcpStream {
public:
using Self = TcpStream;

/// @class ConnectAwaiter
/// @brief
/// Customized connect awaitable for @c TcpStream. Using awaitable to avoid memory allocation
/// by @c Task.
class ConnectAwaiter {
public:
using Self = ConnectAwaiter;

/// @brief
/// Create a new @c ConnectAwaiter for establishing a TCP connection.
/// @param[in, out] stream
/// The @c TcpStream to establish connection.
/// @param address
/// The peer address to connect.
ConnectAwaiter(TcpStream &stream, const InetAddress &address) noexcept
: m_promise(nullptr), m_socket(-1), m_address(&address), m_stream(&stream) {}

/// @brief
/// C++20 coroutine API method. Always execute @c await_suspend().
/// @return
/// This function always returns @c false.
[[nodiscard]]
static constexpr auto await_ready() noexcept -> bool {
return false;
}

/// @brief
/// Prepare for async connect operation and suspend the coroutine.
/// @tparam T
/// Type of promise of current coroutine.
/// @param coro
/// Current coroutine handle.
template <class T>
requires(std::is_base_of_v<detail::PromiseBase, T>)
auto await_suspend(std::coroutine_handle<T> coro) noexcept -> bool {
auto &promise = static_cast<detail::PromiseBase &>(coro.promise());
m_promise = &promise;

// Create a new socket to establish connection.
auto *addr = reinterpret_cast<const struct sockaddr *>(m_address);
m_socket = ::socket(addr->sa_family, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
if (m_socket == -1) [[unlikely]] {
promise.ioResult = -errno;
return false;
}

auto *worker = promise.worker;
io_uring_sqe *sqe = worker->pollSubmissionQueueEntry();
while (sqe == nullptr) [[unlikely]] {
worker->submit();
sqe = worker->pollSubmissionQueueEntry();
}

sqe->opcode = IORING_OP_CONNECT;
sqe->fd = m_socket;
sqe->addr = reinterpret_cast<std::uintptr_t>(m_address);
sqe->off = m_address->size();
sqe->user_data = reinterpret_cast<std::uintptr_t>(&promise);

worker->flushSubmissionQueue();
return true;
}

/// @brief
/// Resume this coroutine and get result of the async connect operation.
/// @return
/// Error code of the connect operation. The error code is 0 if succeeded.
auto await_resume() const noexcept -> std::errc {
if (m_promise->ioResult < 0) [[unlikely]] {
if (m_socket != -1)
::close(m_socket);
return static_cast<std::errc>(-m_promise->ioResult);
}

if (m_stream->m_socket != -1)
::close(m_stream->m_socket);

m_stream->m_socket = m_socket;
m_stream->m_address = *m_address;

return {};
}

private:
detail::PromiseBase *m_promise;
int m_socket;
const InetAddress *m_address;
TcpStream *m_stream;
};

public:
/// @brief
/// Create an empty @c TcpStream. Empty @c TcpStream object cannot be used for IO operations.
TcpStream() noexcept : m_socket(-1), m_address() {}

/// @brief
/// @c TcpStream is not copyable.
TcpStream(const Self &other) = delete;

/// @brief
/// Move constructor of @c TcpStream.
/// @param[in, out] other
/// The @c TcpStream object to be moved. The moved object will be empty and can not be used
/// for IO operations.
TcpStream(Self &&other) noexcept : m_socket(other.m_socket), m_address(other.m_address) {
other.m_socket = -1;
}

/// @brief
/// Close the TCP connection and destroy this object.
~TcpStream() {
if (m_socket != -1)
::close(m_socket);
}

/// @brief
/// @c TcpStream is not copyable.
auto operator=(const Self &other) = delete;

/// @brief
/// Move assignment of @c TcpStream.
/// @param[in, out] other
/// The @c TcpStream object to be moved. The moved object will be empty and can not be used
/// for IO operations.
/// @return
/// Reference to this @c TcpStream.
auto operator=(Self &&other) noexcept -> Self & {
if (this == &other) [[unlikely]]
return *this;

if (m_socket != -1)
::close(m_socket);

m_socket = other.m_socket;
m_address = other.m_address;
other.m_socket = -1;

return *this;
}

/// @brief
/// Get remote address of this TCP stream. It is undefined behavior to get remote address from
/// an empty @c TcpStream.
/// @return
/// Remote address of this TCP stream.
[[nodiscard]]
auto remoteAddress() const noexcept -> const InetAddress & {
return m_address;
}

/// @brief
/// Connect to the specified peer address. This method will block current thread until the
/// connection is established or any error occurs. If this TCP stream is currently not empty,
/// the old connection will be closed once the new connection is established. The old
/// connection will not be closed if the new connection fails.
/// @param address
/// Peer Internet address to connect.
/// @return
/// Error code of the connect operation. The error code is 0 if succeeded.
NYAIO_API auto connect(const InetAddress &address) noexcept -> std::errc;

/// @brief
/// Connect to the specified peer address. This method will suspend this coroutine until the
/// new connection is established or any error occurs. If this TCP stream is currently not
/// empty, the old connection will be closed once the new connection is established. The old
/// connection will not be closed if the new connection fails.
/// @param address
/// Peer Internet address to connect.
/// @return
/// Error code of the connect operation. The error code is 0 if succeeded.
[[nodiscard]]
auto connectAsync(const InetAddress &address) noexcept -> ConnectAwaiter {
return {*this, address};
}

/// @brief
/// Send data to peer TCP endpoint. This method will block current thread until all data is
/// sent or any error occurs.
/// @param data
/// Pointer to start of data to be sent.
/// @param size
/// Expected size in byte of data to be sent.
/// @return
/// A struct that contains number of bytes sent and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes sent is valid.
auto send(const void *data, std::uint32_t size) const noexcept -> SystemIoResult {
ssize_t result = ::send(m_socket, data, size, MSG_NOSIGNAL);
if (result == -1) [[unlikely]]
return {0, std::errc{-errno}};
return {static_cast<std::uint32_t>(result), std::errc{}};
}

/// @brief
/// Async send data to peer TCP endpoint. This method will suspend this coroutine until any
/// data is sent or any error occurs.
/// @param data
/// Pointer to start of data to be sent.
/// @param size
/// Expected size in byte of data to be sent.
/// @return
/// A struct that contains number of bytes sent and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes sent is valid.
[[nodiscard]]
auto sendAsync(const void *data, std::uint32_t size) const noexcept -> SendAwaitable {
return {m_socket, data, size, MSG_NOSIGNAL};
}

/// @brief
/// Receive data from peer TCP endpoint. This method will block current thread until any data
/// is received or error occurs.
/// @param[out] buffer
/// Pointer to start of buffer to store the received data.
/// @param size
/// Maximum available size to be received.
/// @return
/// A struct that contains number of bytes received and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes received is valid.
auto receive(void *buffer, std::uint32_t size) noexcept -> SystemIoResult {
ssize_t result = ::recv(m_socket, buffer, size, 0);
if (result == -1) [[unlikely]]
return {0, std::errc{-errno}};
return {static_cast<std::uint32_t>(result), std::errc{}};
}

/// @brief
/// Async receive data from peer TCP endpoint. This method will suspend this coroutine until
/// any data is received or any error occurs.
/// @param[out] buffer
/// Pointer to start of buffer to store the received data.
/// @param size
/// Maximum available size to be received.
/// @return
/// A struct that contains number of bytes received and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes received is valid.
[[nodiscard]]
auto receiveAsync(void *buffer, std::uint32_t size) noexcept -> ReceiveAwaitable {
return {m_socket, buffer, size, 0};
}

/// @brief
/// Enable or disable keepalive for this TCP connection.
/// @param enable
/// Specifies whether to enable or disable keepalive for this TCP stream.
/// @return
/// An error code that indicates whether succeeded to enable or disable keepalive for this TCP
/// stream. The error code is 0 if succeeded to set keepalive attribute for this TCP stream.
auto setKeepAlive(bool enable) noexcept -> std::errc {
const int value = enable ? 1 : 0;
int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_KEEPALIVE, &value, sizeof(value));
if (ret == -1) [[unlikely]]
return std::errc{errno};
return {};
}

/// @brief
/// Enable or disable nodelay for this TCP stream.
/// @param enable
/// Specifies whether to enable or disable nodelay for this TCP stream.
/// @return
/// An error code that indicates whether succeeded to enable or disable nodelay for this TCP
/// stream. The error code is 0 if succeeded to set nodelay attribute for this TCP stream.
auto setNoDelay(bool enable) noexcept -> std::errc {
const int value = enable ? 1 : 0;
int ret = ::setsockopt(m_socket, SOL_TCP, TCP_NODELAY, &value, sizeof(value));
if (ret == -1) [[unlikely]]
return std::errc{errno};
return {};
}

/// @brief
/// Set timeout event for send operation. @c TcpStream::send and @c TcpStream::sendAsync may
/// generate an error that indicates the timeout event if timeout event occurs. The TCP
/// connection may be in an undefined state and should be closed if send timeout event occurs.
/// @tparam Rep
/// Type representation of duration type. See @c std::chrono::duration for details.
/// @tparam Period
/// Ratio type that is used to measure how to do conversion between different duration types.
/// See @c std::chrono::duration for details.
/// @param duration
/// Timeout duration of send operation. Ratios less than microseconds are not allowed.
/// @return
/// An error code that indicates whether succeeded to set the timeout event. The error code is
/// 0 if succeeded to set or remove send timeout event for this TCP connection.
template <class Rep, class Period>
requires(std::ratio_less_equal_v<std::micro, Period>)
auto setSendTimeout(std::chrono::duration<Rep, Period> duration) noexcept -> std::errc {
auto microsec = std::chrono::duration_cast<std::chrono::microseconds>(duration).count();
if (microsec < 0) [[unlikely]]
return std::errc::invalid_argument;

const struct timeval timeout{
.tv_sec = static_cast<std::uint32_t>(microsec / 1000000),
.tv_usec = static_cast<std::uint32_t>(microsec % 1000000),
};

int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));
if (ret == -1) [[unlikely]]
return std::errc{errno};

return {};
}

/// @brief
/// Set timeout event for receive operation. @c TcpStream::receive and @c
/// TcpStream::receiveAsync may generate an error that indicates the timeout event if timeout
/// event occurs. The TCP connection may be in an undefined state and should be closed if
/// receive timeout event occurs.
/// @tparam Rep
/// Type representation of duration type. See @c std::chrono::duration for details.
/// @tparam Period
/// Ratio type that is used to measure how to do conversion between different duration types.
/// See @c std::chrono::duration for details.
/// @param duration
/// Timeout duration of receive operation. Ratios less than microseconds are not allowed.
/// @return
/// An error code that indicates whether succeeded to set the timeout event. The error code is
/// 0 if succeeded to set or remove receive timeout event for this TCP connection.
template <class Rep, class Period>
requires(std::ratio_less_equal_v<std::micro, Period>)
auto setReceiveTimeout(std::chrono::duration<Rep, Period> duration) noexcept -> std::errc {
auto microsec = std::chrono::duration_cast<std::chrono::microseconds>(duration).count();
if (microsec < 0) [[unlikely]]
return std::errc::invalid_argument;

const struct timeval timeout{
.tv_sec = static_cast<std::uint32_t>(microsec / 1000000),
.tv_usec = static_cast<std::uint32_t>(microsec % 1000000),
};

int ret = ::setsockopt(m_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
if (ret == -1) [[unlikely]]
return std::errc{errno};

return {};
}

/// @brief
/// Close this TCP stream and release all resources. Closing a TCP stream with pending IO
/// requirements may cause errors for the IO results. This method does nothing if current TCP
/// stream is empty.
auto close() noexcept -> void {
if (m_socket != -1) {
::close(m_socket);
m_socket = -1;
}
}

private:
int m_socket;
InetAddress m_address;
};

} // namespace nyaio
22 changes: 22 additions & 0 deletions src/nyaio/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,25 @@ IpAddress::IpAddress(std::string_view addr) : m_isV6(), m_addr() {
if (ret != 1)
throw std::invalid_argument("Invalid IP address format.");
}

auto TcpStream::connect(const InetAddress &address) noexcept -> std::errc {
// Create a new socket to establish connection.
auto *addr = reinterpret_cast<const struct sockaddr *>(&address);
int s = ::socket(addr->sa_family, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
if (s == -1) [[unlikely]]
return static_cast<std::errc>(errno);

if (::connect(s, addr, address.size()) == -1) [[unlikely]] {
int error = errno;
::close(s);
return static_cast<std::errc>(error);
}

if (m_socket != -1)
::close(m_socket);

m_socket = s;
m_address = address;

return {};
}

0 comments on commit 201cf16

Please sign in to comment.