Skip to content

Commit

Permalink
Add TimedSendAwaitable and TimedReceiveAwaitable for timed async send…
Browse files Browse the repository at this point in the history
…/recv support
  • Loading branch information
longhao-li committed Sep 25, 2024
1 parent 3daaf7c commit f08ae08
Show file tree
Hide file tree
Showing 2 changed files with 307 additions and 0 deletions.
225 changes: 225 additions & 0 deletions include/nyaio/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,118 @@ class SendAwaitable {
int m_flags;
};

/// @class TimedSendAwaitable
/// @brief
/// Awaitable object for async send operation with timeout support.
class TimedSendAwaitable {
public:
using Self = TimedSendAwaitable;

/// @brief
/// Create a new @c TimedSendAwaitable for async send operation.
/// @param socket
/// The socket to send data to.
/// @param data
/// Pointer to start of data to be sent.
/// @param size
/// Expected size in byte of data to be sent.
/// @param flags
/// Flags for this async send operation. See linux manual for @c send for details.
/// @param timeout
/// Timeout for this send operation. Negative values and zero will be considered as never
/// timeout.
template <class Rep, class Period>
requires(std::ratio_less_equal_v<std::nano, Period>)
TimedSendAwaitable(int socket, const void *data, std::uint32_t size, int flags,
std::chrono::duration<Rep, Period> timeout) noexcept
: m_promise(), m_socket(socket), m_data(data), m_size(size), m_flags(flags), m_timeout() {
auto nano = std::chrono::duration_cast<std::chrono::nanoseconds>(timeout).count();
if (nano <= 0)
return;

m_timeout.tv_sec = static_cast<std::uint64_t>(nano) / 1000000000ULL;
m_timeout.tv_nsec = static_cast<std::uint64_t>(nano) % 1000000000ULL;
}

/// @brief
/// C++20 coroutine API method. Always execute @c await_suspend().
/// @return
/// This function always returns @c false.
[[nodiscard]]
static constexpr auto await_ready() noexcept -> bool {
return false;
}

/// @brief
/// Prepare for async send operation and suspend the coroutine.
/// @tparam T
/// Type of promise of current coroutine.
/// @param coro
/// Current coroutine handle.
template <class T>
requires(std::is_base_of_v<detail::PromiseBase, T>)
auto await_suspend(std::coroutine_handle<T> coro) noexcept -> void {
auto &promise = static_cast<detail::PromiseBase &>(coro.promise());
m_promise = &promise;
auto *worker = promise.worker;

io_uring_sqe *sqe = worker->pollSubmissionQueueEntry();
while (sqe == nullptr) [[unlikely]] {
worker->submit();
sqe = worker->pollSubmissionQueueEntry();
}

sqe->opcode = IORING_OP_SEND;
sqe->fd = m_socket;
sqe->addr = reinterpret_cast<std::uintptr_t>(m_data);
sqe->len = m_size;
sqe->msg_flags = static_cast<std::uint32_t>(m_flags);
sqe->user_data = reinterpret_cast<std::uintptr_t>(&promise);

// Prepare for timeout.
if (m_timeout.tv_sec != 0 || m_timeout.tv_nsec != 0) {
io_uring_sqe *timeSqe = worker->pollSubmissionQueueEntry();
while (timeSqe == nullptr) [[unlikely]] {
worker->submit();
timeSqe = worker->pollSubmissionQueueEntry();
}

sqe->flags = IOSQE_IO_LINK;

// Prepare timeout event.
timeSqe->opcode = IORING_OP_LINK_TIMEOUT;
timeSqe->fd = -1;
timeSqe->addr = reinterpret_cast<std::uintptr_t>(&m_timeout);
timeSqe->len = 1;
timeSqe->user_data = 0;
timeSqe->timeout_flags = 0;
}

worker->flushSubmissionQueue();
}

/// @brief
/// Resume this coroutine and get result of the async send operation.
/// @return
/// A struct that contains number of bytes sent and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes sent is valid. The return value is
/// @c std::errc::operation_canceled if timeout occured.
[[nodiscard]]
auto await_resume() const noexcept -> SystemIoResult {
if (m_promise->ioResult < 0) [[unlikely]]
return {0, std::errc{-m_promise->ioResult}};
return {static_cast<std::uint32_t>(m_promise->ioResult), std::errc{}};
}

private:
detail::PromiseBase *m_promise;
int m_socket;
const void *m_data;
std::uint32_t m_size;
int m_flags;
__kernel_timespec m_timeout;
};

/// @class ReceiveAwaitable
/// @brief
/// Awaitable object for async recv operation.
Expand Down Expand Up @@ -1941,6 +2053,119 @@ class ReceiveAwaitable {
int m_flags;
};

/// @class TimedReceiveAwaitable
/// @brief
/// Awaitable object for async recv operation with timeout support.
class TimedReceiveAwaitable {
public:
using Self = TimedReceiveAwaitable;

/// @brief
/// Create a new @c TimedReceiveAwaitable for async recv operation.
/// @param socket
/// The socket to receive data from.
/// @param[out] buffer
/// Pointer to start of buffer to store data received from the socket.
/// @param size
/// Maximum available size in byte of @c buffer.
/// @param flags
/// Flags for this async recv operation. See linux manual for @c recv for details.
/// @param timeout
/// Timeout for this receive operation. Negative values and zero will be considered as never
/// timeout.
template <class Rep, class Period>
requires(std::ratio_less_equal_v<std::nano, Period>)
TimedReceiveAwaitable(int socket, void *buffer, std::uint32_t size, int flags,
std::chrono::duration<Rep, Period> timeout) noexcept
: m_promise(), m_socket(socket), m_buffer(buffer), m_size(size), m_flags(flags),
m_timeout() {
auto nano = std::chrono::duration_cast<std::chrono::nanoseconds>(timeout).count();
if (nano <= 0)
return;

m_timeout.tv_sec = static_cast<std::uint64_t>(nano) / 1000000000ULL;
m_timeout.tv_nsec = static_cast<std::uint64_t>(nano) % 1000000000ULL;
}

/// @brief
/// C++20 coroutine API method. Always execute @c await_suspend().
/// @return
/// This function always returns @c false.
[[nodiscard]]
static constexpr auto await_ready() noexcept -> bool {
return false;
}

/// @brief
/// Prepare for async recv operation and suspend the coroutine.
/// @tparam T
/// Type of promise of current coroutine.
/// @param coro
/// Current coroutine handle.
template <class T>
requires(std::is_base_of_v<detail::PromiseBase, T>)
auto await_suspend(std::coroutine_handle<T> coro) noexcept -> void {
auto &promise = static_cast<detail::PromiseBase &>(coro.promise());
m_promise = &promise;
auto *worker = promise.worker;

io_uring_sqe *sqe = worker->pollSubmissionQueueEntry();
while (sqe == nullptr) [[unlikely]] {
worker->submit();
sqe = worker->pollSubmissionQueueEntry();
}

sqe->opcode = IORING_OP_RECV;
sqe->fd = m_socket;
sqe->addr = reinterpret_cast<std::uintptr_t>(m_buffer);
sqe->len = m_size;
sqe->msg_flags = static_cast<std::uint32_t>(m_flags);
sqe->user_data = reinterpret_cast<std::uintptr_t>(&promise);

// Prepare for timeout.
if (m_timeout.tv_sec != 0 || m_timeout.tv_nsec != 0) {
io_uring_sqe *timeSqe = worker->pollSubmissionQueueEntry();
while (timeSqe == nullptr) [[unlikely]] {
worker->submit();
timeSqe = worker->pollSubmissionQueueEntry();
}

sqe->flags = IOSQE_IO_LINK;

// Prepare timeout event.
timeSqe->opcode = IORING_OP_LINK_TIMEOUT;
timeSqe->fd = -1;
timeSqe->addr = reinterpret_cast<std::uintptr_t>(&m_timeout);
timeSqe->len = 1;
timeSqe->user_data = 0;
timeSqe->timeout_flags = 0;
}

worker->flushSubmissionQueue();
}

/// @brief
/// Resume this coroutine and get result of the async recv operation.
/// @return
/// A struct that contains number of bytes received and an error code. The error code is
/// @c std::errc{} if succeeded and the number of bytes received is valid. The return value is
/// @c std::errc::operation_canceled if timeout occured.
[[nodiscard]]
auto await_resume() const noexcept -> SystemIoResult {
if (m_promise->ioResult < 0) [[unlikely]]
return {0, std::errc{-m_promise->ioResult}};
return {static_cast<std::uint32_t>(m_promise->ioResult), std::errc{}};
}

private:
detail::PromiseBase *m_promise;
int m_socket;
void *m_buffer;
std::uint32_t m_size;
int m_flags;
__kernel_timespec m_timeout;
};

/// @class SendToAwaitable
/// @brief
/// Awaitable object for async sendto operation.
Expand Down
82 changes: 82 additions & 0 deletions test/task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,88 @@ TEST_CASE("[task] Send/ReceiveAwaitable") {

namespace {

auto timedSendAwaitableTask(IoContext &ctx, const char *address,
std::atomic_bool &couldConnect) noexcept -> Task<> {
struct sockaddr_un addr{};
addr.sun_family = AF_UNIX;
std::strncpy(addr.sun_path, address, std::size(addr.sun_path));

while (!couldConnect.load(std::memory_order_relaxed))
co_await YieldAwaitable();

int s = ::socket(AF_UNIX, SOCK_STREAM, 0);
CHECK(s >= 0);

std::errc e =
co_await ConnectAwaitable(s, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr));
CHECK(e == std::errc{});

for (std::size_t i = 0; i < 1024; ++i) {
auto [bytes, error] = co_await TimedSendAwaitable(s, &i, sizeof(i), MSG_NOSIGNAL, 1s);
CHECK(error == std::errc{});
CHECK(bytes == sizeof(i));
}

co_await sleep(1s);
::close(s);

ctx.stop();
}

auto timedRecvAwaitableTask(const char *address, std::atomic_bool &couldConnect) noexcept
-> Task<> {
struct sockaddr_un addr{};
addr.sun_family = AF_UNIX;
std::strncpy(addr.sun_path, address, std::size(addr.sun_path));

int s = ::socket(AF_UNIX, SOCK_STREAM, 0);
CHECK(s >= 0);

int ret = ::bind(s, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr));
CHECK(ret >= 0);

ret = ::listen(s, 1);
CHECK(ret >= 0);

couldConnect.store(true, std::memory_order_relaxed);

auto [client, e] = co_await AcceptAwaitable(s, nullptr, nullptr, SOCK_CLOEXEC);
CHECK(e == std::errc{});
CHECK(client >= 0);

for (std::size_t i = 0; i < 1024; ++i) {
std::size_t buffer;
auto [bytes, error] =
co_await TimedReceiveAwaitable(client, &buffer, sizeof(buffer), 0, 1s);
CHECK(error == std::errc{});
CHECK(bytes == sizeof(buffer));
CHECK(buffer == i);
}

std::size_t buffer;
auto [bytes, error] = co_await TimedReceiveAwaitable(client, &buffer, sizeof(buffer), 0, 100ms);
CHECK(error == std::errc::operation_canceled);

::close(client);
::close(s);
}

} // namespace

TEST_CASE("[task] Timed Send/ReceiveAwaitable") {
IoContext ctx(1);
std::atomic_bool couldConnect = false;

constexpr const char *address = "nyaio-timed-send-recv.sock";
ctx.schedule(timedSendAwaitableTask(ctx, address, couldConnect));
ctx.schedule(timedRecvAwaitableTask(address, couldConnect));

ctx.run();
::unlink(address);
}

namespace {

auto sendtoAwaitableTask(IoContext &ctx, const char *address,
std::atomic_bool &couldConnect) noexcept -> Task<> {
struct sockaddr_un addr{};
Expand Down

0 comments on commit f08ae08

Please sign in to comment.