From 11d843b37d3dbdb39d564910ffcdc87dd0c52b91 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 20 Dec 2024 13:22:59 +0200 Subject: [PATCH 1/2] try to impl try_bidi_streaming pattern --- src/pattern/mod.rs | 1 + src/pattern/try_bidi_streaming.rs | 211 ++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 src/pattern/try_bidi_streaming.rs diff --git a/src/pattern/mod.rs b/src/pattern/mod.rs index da2b879..a1a58ef 100644 --- a/src/pattern/mod.rs +++ b/src/pattern/mod.rs @@ -8,4 +8,5 @@ pub mod bidi_streaming; pub mod client_streaming; pub mod rpc; pub mod server_streaming; +pub mod try_bidi_streaming; pub mod try_server_streaming; diff --git a/src/pattern/try_bidi_streaming.rs b/src/pattern/try_bidi_streaming.rs new file mode 100644 index 0000000..6391d5d --- /dev/null +++ b/src/pattern/try_bidi_streaming.rs @@ -0,0 +1,211 @@ +//! Fallible server streaming interaction pattern. + +use std::{ + error, + fmt::{self, Debug}, + result, +}; + +use futures_lite::{Future, Stream, StreamExt}; +use futures_util::{FutureExt, SinkExt, TryFutureExt}; +use serde::{Deserialize, Serialize}; + +use crate::{ + client::{BoxStreamSync, UpdateSink}, + message::{InteractionPattern, Msg}, + server::{race2, RpcChannel, RpcServerError, UpdateStream}, + transport::{self, ConnectionErrors, StreamTypes}, + Connector, RpcClient, Service, +}; + +/// A guard message to indicate that the stream has been created. +/// +/// This is so we can dinstinguish between an error creating the stream and +/// an error in the first item produced by the stream. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct StreamCreated; + +/// Fallible server streaming interaction pattern. +#[derive(Debug, Clone, Copy)] +pub struct TryBidiStreaming; + +impl InteractionPattern for TryBidiStreaming {} + +/// Same as ServerStreamingMsg, but with lazy stream creation and the error type explicitly defined. +pub trait TryBidiStreamingMsg: Msg +where + result::Result: Into + TryFrom, + result::Result: Into + TryFrom, +{ + /// Error when creating the stream + type CreateError: Debug + Send + 'static; + + /// Update type + type Update: Into + TryFrom + Send + 'static; + + /// Successful response item + type Item: Send + 'static; + + /// Error for stream items + type ItemError: Debug + Send + 'static; +} + +/// Server error when accepting a server streaming request +/// +/// This combines network errors with application errors. Usually you don't +/// care about the exact nature of the error, but if you want to handle +/// application errors differently, you can match on this enum. +#[derive(Debug)] +pub enum Error { + /// Unable to open a substream at all + Open(C::OpenError), + /// Unable to send the request to the server + Send(C::SendError), + /// Error received when creating the stream + Recv(C::RecvError), + /// Connection was closed before receiving the first message + EarlyClose, + /// Unexpected response from the server + Downcast, + /// Application error + Application(E), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for Error {} + +/// Client error when handling responses from a server streaming request. +/// +/// This combines network errors with application errors. +#[derive(Debug)] +pub enum ItemError { + /// Unable to receive the response from the server + Recv(S::RecvError), + /// Unexpected response from the server + Downcast, + /// Application error + Application(E), +} + +impl fmt::Display for ItemError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for ItemError {} + +impl RpcChannel +where + C: StreamTypes, + S: Service, +{ + /// handle the message M using the given function on the target object + /// + /// If you want to support concurrent requests, you need to spawn this on a tokio task yourself. + /// + /// Compared to [RpcChannel::server_streaming], with this method the stream creation is via + /// a function that returns a future that resolves to a stream. + pub async fn try_bidi_streaming( + self, + req: M, + target: T, + f: F, + ) -> result::Result<(), RpcServerError> + where + M: TryBidiStreamingMsg, + Result: Into + TryFrom, + Result: Into + TryFrom, + F: FnOnce(T, M, UpdateStream) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + Str: Stream> + Send + 'static, + T: Send + 'static, + { + let Self { mut send, recv, .. } = self; + let (updates, read_error) = UpdateStream::new(recv); + race2(read_error.map(Err), async move { + // get the response + let responses = match f(target, req, updates).await { + Ok(responses) => { + // turn into a S::Res so we can send it + let response = Ok(StreamCreated).into(); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + responses + } + Err(cause) => { + // turn into a S::Res so we can send it + let response = Err(cause).into(); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + return Ok(()); + } + }; + tokio::pin!(responses); + while let Some(response) = responses.next().await { + // turn into a S::Res so we can send it + let response = response.into(); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + } + Ok(()) + }) + .await + } +} + +impl RpcClient +where + C: Connector, + S: Service, +{ + /// Bidi call to the server, request opens a stream, response is a stream + pub async fn try_bidi_streaming( + &self, + msg: M, + ) -> result::Result< + ( + BoxStreamSync<'static, Result>>, + UpdateSink, + ), + Error, + > + where + M: TryBidiStreamingMsg, + Result: Into + TryFrom, + Result: Into + TryFrom, + { + let msg = msg.into(); + let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?; + send.send(msg).map_err(Error::Send).await?; + let Some(initial) = recv.next().await else { + return Err(Error::EarlyClose); + }; + let initial = initial.map_err(Error::Recv)?; // initial response + let initial = >::try_from(initial) + .map_err(|_| Error::Downcast)?; + let _ = initial.map_err(Error::Application)?; + let recv = recv.map(move |x| { + let x = x.map_err(ItemError::Recv)?; + let x = >::try_from(x) + .map_err(|_| ItemError::Downcast)?; + let x = x.map_err(ItemError::Application)?; + Ok(x) + }); + // keep send alive so the request on the server side does not get cancelled + let us = UpdateSink::new(send); + let recv = Box::pin(recv); + Ok((recv, us)) + } +} From dda76e44cd324c051b7f2227a36379132671fab7 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 20 Dec 2024 14:02:21 +0200 Subject: [PATCH 2/2] push test for trt bidi --- tests/try_bidi.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 tests/try_bidi.rs diff --git a/tests/try_bidi.rs b/tests/try_bidi.rs new file mode 100644 index 0000000..d5a2813 --- /dev/null +++ b/tests/try_bidi.rs @@ -0,0 +1,109 @@ +#![cfg(feature = "flume-transport")] +use derive_more::{From, TryInto}; +use futures_lite::{Stream, StreamExt}; +use quic_rpc::{ + message::Msg, + pattern::try_bidi_streaming::{StreamCreated, TryBidiStreaming, TryBidiStreamingMsg}, + server::RpcServerError, + transport::flume, + RpcClient, RpcServer, Service, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +struct TryService; + +impl Service for TryService { + type Req = TryRequest; + type Res = TryResponse; +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamN { + n: u64, +} + +impl Msg for StreamN { + type Pattern = TryBidiStreaming; +} + +impl TryBidiStreamingMsg for StreamN { + type Item = u64; + type ItemError = String; + type CreateError = String; + type Update = u64; +} + +/// request enum +#[derive(Debug, Serialize, Deserialize, From, TryInto)] +pub enum TryRequest { + StreamN(StreamN), + StreamNUpdate(u64), +} + +#[derive(Debug, Serialize, Deserialize, From, TryInto, Clone)] +pub enum TryResponse { + StreamN(std::result::Result), + StreamNError(std::result::Result), +} + +#[derive(Clone)] +struct Handler; + +impl Handler { + async fn try_stream_n( + self, + req: StreamN, + updates: impl Stream, + ) -> std::result::Result>, String> { + if req.n % 2 != 0 { + return Err("odd n not allowed".to_string()); + } + let stream = async_stream::stream! { + for i in 0..req.n { + if i > 5 { + yield Err("n too large".to_string()); + return; + } + yield Ok(i); + } + }; + Ok(stream) + } +} + +#[tokio::test] +async fn try_bidi_streaming() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = flume::channel(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(async move { + loop { + let (req, chan) = server.accept().await?.read_first().await?; + let handler = Handler; + match req { + TryRequest::StreamN(req) => { + chan.try_bidi_streaming(req, handler, Handler::try_stream_n) + .await?; + } + TryRequest::StreamNUpdate(_) => { + return Err(RpcServerError::UnexpectedUpdateMessage); + } + } + } + #[allow(unreachable_code)] + Ok(()) + }); + let client = RpcClient::::new(client); + let (stream_n, update_sink) = client.try_bidi_streaming(StreamN { n: 10 }).await?; + let items: Vec<_> = stream_n.collect().await; + println!("{:?}", items); + drop(client); + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +}