Skip to content

Commit

Permalink
Fix bi stream bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fantasy-peak committed Apr 11, 2024
1 parent 020e196 commit 28b71f1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
20 changes: 12 additions & 8 deletions out/frpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,16 +1340,16 @@ class StreamClient final {
};

struct StreamServerHandler {
virtual void hello_world(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>&,
virtual void hello_world(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>,
std::shared_ptr<Stream<void(std::string)>>) = 0;
};

struct CoroStreamServerHandler {
#ifdef __cpp_impl_coroutine
virtual asio::awaitable<void> hello_world(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>&,
virtual asio::awaitable<void> hello_world(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>,
std::shared_ptr<Stream<void(std::string)>>) = 0;
#else
virtual void hello_world(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>&,
virtual void hello_world(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>,
std::shared_ptr<Stream<void(std::string)>>) = 0;
#endif
};
Expand Down Expand Up @@ -1442,22 +1442,26 @@ class StreamServer final {
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(snd_bufs);
},
[ptr, this] {
[ptr, this, req_id, channel_ptr]() mutable {
auto& recv_bufs = *ptr;
auto close = pack<bool>(true);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t(close.data(), close.size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(snd_bufs);
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
std::visit([&](auto&& arg) mutable {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::shared_ptr<StreamServerHandler>>) {
arg->hello_world(channel_ptr, std::move(out));
arg->hello_world(std::move(channel_ptr), std::move(out));
} else {
asio::co_spawn(m_pool_ptr->getIoContext(), [](auto& arg, auto channel_ptr, auto out) mutable -> asio::awaitable<void> {
co_await arg->hello_world(channel_ptr, std::move(out));
co_await arg->hello_world(std::move(channel_ptr), std::move(out));
co_return;
}(arg, std::move(channel_ptr), std::move(out)),
asio::detached);
Expand Down
20 changes: 12 additions & 8 deletions template/cpp/bi_stream.inja
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,18 @@ private:

struct {{value.callee}}Handler {
{% for func in value.definitions %}
virtual void {{func.func_name}}(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>&,
virtual void {{func.func_name}}(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>,
std::shared_ptr<Stream<void({{_format_args_type(func.outputs)}})>>) = 0;
{% endfor %}
};

struct Coro{{value.callee}}Handler {
{% for func in value.definitions %}
#ifdef __cpp_impl_coroutine
virtual asio::awaitable<void> {{func.func_name}}(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>&,
virtual asio::awaitable<void> {{func.func_name}}(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>,
std::shared_ptr<Stream<void({{_format_args_type(func.outputs)}})>>) = 0;
#else
virtual void {{func.func_name}}(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>&,
virtual void {{func.func_name}}(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, {{_format_args_type(func.inputs)}})>>,
std::shared_ptr<Stream<void({{_format_args_type(func.outputs)}})>>) = 0;
#endif
{% endfor %}
Expand Down Expand Up @@ -240,23 +240,27 @@ private:
snd_bufs.emplace_back(zmq::message_t(recv_bufs[1].data(), recv_bufs[1].size()));
snd_bufs.emplace_back(zmq::message_t(packet.data(), packet.size()));
m_channel->send(snd_bufs);
}, [ptr, this] {
}, [ptr, this, req_id, channel_ptr] () mutable {
auto& recv_bufs = *ptr;
auto close = pack<bool>(true);
std::vector<zmq::message_t> snd_bufs;
snd_bufs.emplace_back(zmq::message_t(recv_bufs[0].data(), recv_bufs[0].size()));
snd_bufs.emplace_back(zmq::message_t(recv_bufs[2].data(), recv_bufs[2].size()));
snd_bufs.emplace_back(zmq::message_t(close.data(), close.size()));
snd_bufs.emplace_back(zmq::message_t("C", 1));
m_channel->send(snd_bufs);
channel_ptr->close();
{
std::lock_guard lk(m_mtx);
m_channel_mapping.erase(req_id);
}
});
std::visit([&](auto&& arg) mutable {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::shared_ptr<{{value.callee}}Handler>>) {
arg->{{func.func_name}}(channel_ptr, std::move(out));
arg->{{func.func_name}}(std::move(channel_ptr), std::move(out));
} else {
asio::co_spawn(m_pool_ptr->getIoContext(),
[] (auto& arg, auto channel_ptr, auto out) mutable -> asio::awaitable<void> {
co_await arg->{{func.func_name}}(channel_ptr, std::move(out));
co_await arg->{{func.func_name}}(std::move(channel_ptr), std::move(out));
co_return;
}(arg, std::move(channel_ptr), std::move(out)),
asio::detached);
Expand Down
6 changes: 5 additions & 1 deletion test/cpp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void test_coro_bi(auto& pool) {
#ifdef __cpp_impl_coroutine

struct CoroStreamServerHandler : public frpc::CoroStreamServerHandler {
virtual asio::awaitable<void> hello_world(const std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>>& ins,
virtual asio::awaitable<void> hello_world(std::shared_ptr<asio::experimental::concurrent_channel<void(asio::error_code, std::string)>> ins,
std::shared_ptr<frpc::Stream<void(std::string)>> outs) {
start([outs = std::move(outs)]() mutable {
for (int i = 0; i < 5; i++) {
Expand All @@ -267,6 +267,10 @@ struct CoroStreamServerHandler : public frpc::CoroStreamServerHandler {
});
for (;;) {
auto [ec, str] = co_await ins->async_receive(asio::as_tuple(asio::use_awaitable));
if (ec) {
spdlog::info("stream server: {}", ec.message());
break;
}
spdlog::info("stream server recv: {}", str);
}
co_return;
Expand Down

0 comments on commit 28b71f1

Please sign in to comment.