diff --git a/Cargo.lock b/Cargo.lock
index b948723..f4bca56 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -102,7 +102,6 @@ dependencies = [
  "cobs",
  "futures",
  "heapless",
- "maitake",
  "maitake-sync",
  "postcard",
  "proptest",
diff --git a/source/calliope/Cargo.toml b/source/calliope/Cargo.toml
index f140a75..ee75993 100644
--- a/source/calliope/Cargo.toml
+++ b/source/calliope/Cargo.toml
@@ -11,8 +11,8 @@ alloc = ["postcard/alloc", "tricky-pipe/alloc"]
 [dependencies]
 cobs = "0.2"
 heapless = "0.7.16"
-tracing = { version = "0.1.21", default-features = false }
 postcard = { version = "1.0.7", default-features = false }
+maitake-sync = { version = "0.1", default-features = false }
 
 [dependencies.futures]
 version = "0.3.17"
@@ -36,8 +36,10 @@ rev = "416b7d59fbc7fa889a774f54133786a584eb8732"
 [dependencies.tricky-pipe]
 path = "../tricky-pipe"
 
-[dependencies.maitake]
-version = "0.1"
+[dependencies.tracing]
+version = "0.1.21"
+default-features = false
+features = ["attributes"]
 
 [dev-dependencies]
 postcard = { version = "1.0.7", features = ["alloc"] }
@@ -45,7 +47,7 @@ proptest = "1.3.1"
 proptest-derive = "0.4"
 tokio = { version = "1.31.0", features = ["macros", "rt", "sync", "io-util"] }
 tokio-stream = "0.1.14"
-tracing = { version = "0.1.21", features = ["attributes", "std"] }
+tracing = { version = "0.1.21", features = ["std"] }
 tracing-subscriber = { version = "0.3.17", default-features = false, features = ["fmt", "ansi", "env-filter"] }
 tricky-pipe = { path = "../tricky-pipe", features = ["alloc"] }
diff --git a/source/calliope/src/client.rs b/source/calliope/src/client.rs
index 6c183bc..9ff46ee 100644
--- a/source/calliope/src/client.rs
+++ b/source/calliope/src/client.rs
@@ -1,6 +1,6 @@
 use crate::{
     message::{Rejection, Reset},
-    service, Service,
+    req_rsp, service, Service,
 };
 
 use tricky_pipe::{bidi, mpsc, oneshot, serbox};
@@ -91,3 +91,16 @@ impl<S: Service> Connector<S> {
         }
     }
 }
+
+impl<S: ReqRspService> Connector<S> {
+    pub async fn connect_req_rsp(
+        &mut self,
+        identity: impl Into<service::Identity>,
+        hello: S::Hello,
+        channels: Channels<S>,
+    ) -> Result<req_rsp::Client<S>, ConnectError<S>> {
+        self.connect(identity, hello, channels)
+            .await
+            .map(req_rsp::Client::new)
+    }
+}
diff --git a/source/calliope/src/lib.rs b/source/calliope/src/lib.rs
index 816f4af..3ed5aa4 100644
--- a/source/calliope/src/lib.rs
+++ b/source/calliope/src/lib.rs
@@ -8,6 +8,7 @@ use core::fmt;
 pub mod client;
 mod conn_table;
 pub mod message;
+pub mod req_rsp;
 pub mod service;
 
 pub use client::Connector;
@@ -23,7 +24,7 @@ use message::{InboundFrame, OutboundFrame, Rejection};
 use tricky_pipe::{mpsc, oneshot, serbox};
 
 #[cfg(test)]
-mod tests;
+pub(crate) mod tests;
 
 /// A wire-level transport for [Calliope frames](Frame).
 ///
diff --git a/source/calliope/src/req_rsp.rs b/source/calliope/src/req_rsp.rs
new file mode 100644
index 0000000..99cd474
--- /dev/null
+++ b/source/calliope/src/req_rsp.rs
@@ -0,0 +1,272 @@
+use crate::{client, message, Service};
+use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use futures::{future::FutureExt, pin_mut, select_biased};
+use maitake_sync::{
+    wait_map::{WaitError, WaitMap, WakeOutcome},
+    WaitCell,
+};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use tracing::Level;
+use tricky_pipe::mpsc::error::{RecvError, SendError, TrySendError};
+
+#[cfg(test)]
+mod tests;
+
+#[derive(Debug, Eq, PartialEq)]
+pub struct Seq(usize);
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+#[must_use]
+pub struct Request<T> {
+    seq: usize,
+    body: T,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+#[must_use]
+pub struct Response<T> {
+    seq: usize,
+    body: T,
+}
+
+pub struct Client<S>
+where
+    S: ReqRspService,
+{
+    seq: AtomicUsize,
+    channel: client::Connection<S>,
+    shutdown: WaitCell,
+    dispatcher: WaitMap<usize, S::Response>,
+    has_dispatcher: AtomicBool,
+}
+
+pub trait ReqRspService:
+    Service<ClientMsg = Request<Self::Request>, ServerMsg = Response<Self::Response>>
+{
+    type Request: Serialize + DeserializeOwned + Send + Sync + 'static;
+    type Response: Serialize + DeserializeOwned + Send + Sync + 'static;
+}
+
+// === impl Seq ===
+
+impl Seq {
+    pub fn respond<T>(self, body: T) -> Response<T> {
+        Response { seq: self.0, body }
+    }
+}
+
+// === impl Request ===
+
+impl<T> Request<T> {
+    pub fn body(&self) -> &T {
+        &self.body
+    }
+
+    pub fn into_parts(self) -> (Seq, T) {
+        (Seq(self.seq), self.body)
+    }
+
+    pub fn respond<U>(self, body: U) -> Response<U> {
+        Response {
+            seq: self.seq,
+            body,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum DispatchError {
+    AlreadyRunning,
+    ConnectionReset(message::Reset),
+}
+
+enum RequestError {
+    Reset(message::Reset),
+    SeqInUse,
+}
+
+// === impl Client ===
+
+impl<S> Client<S>
+where
+    S: ReqRspService,
+{
+    pub fn new(client: client::Connection<S>) -> Self {
+        Self {
+            seq: AtomicUsize::new(0),
+            channel: client,
+            dispatcher: WaitMap::new(),
+            has_dispatcher: AtomicBool::new(false),
+            shutdown: WaitCell::new(),
+        }
+    }
+
+    pub async fn request(&self, body: S::Request) -> Result<S::Response, message::Reset> {
+        #[cfg_attr(debug_assertions, allow(unreachable_code))]
+        let handle_wait_error = |err: WaitError| match err {
+            WaitError::Closed => {
+                let error = self.channel.tx().try_reserve().expect_err(
+                    "if the waitmap was closed, then the channel should \
+                     have been closed with an error!",
+                );
+                if let TrySendError::Error { error, .. } = error {
+                    return RequestError::Reset(error);
+                }
+
+                #[cfg(debug_assertions)]
+                unreachable!(
+                    "closing the channel with an error should have priority \
+                     over full/disconnected errors."
+                );
+
+                RequestError::Reset(message::Reset::BecauseISaidSo)
+            }
+            WaitError::Duplicate => RequestError::SeqInUse,
+            WaitError::AlreadyConsumed => {
+                unreachable!("data should not already be consumed, this is a bug")
+            }
+            WaitError::NeverAdded => {
+                unreachable!("we ensured the waiter was added, this is a bug!")
+            }
+            error => {
+                #[cfg(debug_assertions)]
+                todo!(
+                    "james added a new WaitError variant that we don't \
+                     know how to handle: {error:?}"
+                );
+
+                #[cfg_attr(debug_assertions, allow(unreachable_code))]
+                RequestError::Reset(message::Reset::BecauseISaidSo)
+            }
+        };
+
+        // acquire a send permit first --- this way, we don't increment the
+        // sequence number until we actually have a channel reservation.
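+        // (and, if the connection has already been closed, `reserve()` fails
+        // here, so we return the reset error without consuming a sequence
+        // number for a request that could never be sent.)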
+        let permit = self.channel.tx().reserve().await.map_err(|e| match e {
+            SendError::Disconnected(()) => message::Reset::BecauseISaidSo,
+            SendError::Error { error, .. } => error,
+        })?;
+
+        loop {
+            let seq = self.seq.fetch_add(1, Ordering::Relaxed);
+            // ensure the waiter is enqueued before sending the request.
+            let wait = self.dispatcher.wait(seq);
+            pin_mut!(wait);
+            match wait.as_mut().enqueue().await.map_err(handle_wait_error) {
+                Ok(_) => {}
+                Err(RequestError::Reset(reset)) => return Err(reset),
+                Err(RequestError::SeqInUse) => {
+                    // NOTE: yes, in theory, this loop *could* never terminate,
+                    // if *all* sequence numbers have a currently-in-flight
+                    // request. but, if you've somehow managed to spawn
+                    // `usize::MAX` request tasks at the same time, and none of
+                    // them have completed, you probably have worse problems...
+                    tracing::trace!(seq, "sequence number in use, retrying...");
+                    continue;
+                }
+            };
+
+            let req = Request { seq, body };
+            // actually send the message...
+            permit.send(req);
+
+            return match wait.await.map_err(handle_wait_error) {
+                Ok(rsp) => Ok(rsp),
+                Err(RequestError::Reset(reset)) => Err(reset),
+                Err(RequestError::SeqInUse) => unreachable!(
+                    "we should have already enqueued the waiter, so its \
+                     sequence number should be okay. this is a bug!"
+                ),
+            };
+        }
+    }
+
+    /// Shut down the client dispatcher for this `Client`.
+    ///
+    /// This will fail any outstanding `Request` futures, and reset the
+    /// connection.
+    pub fn shutdown(&self) {
+        tracing::debug!("shutting down client...");
+        self.shutdown.close();
+        self.channel
+            .close_with_error(message::Reset::BecauseISaidSo);
+        self.dispatcher.close();
+    }
+
+    /// Run the client's dispatcher in the background until cancelled or the
+    /// connection is reset.
+    #[tracing::instrument(
+        level = Level::DEBUG,
+        name = "Client::dispatcher",
+        skip(self),
+        fields(svc = %core::any::type_name::<S>()),
+        ret(Debug),
+        err(Debug),
+    )]
+    pub async fn dispatch(&self) -> Result<(), DispatchError> {
+        #[cfg_attr(debug_assertions, allow(unreachable_code))]
+        if self.has_dispatcher.swap(true, Ordering::AcqRel) {
+            #[cfg(debug_assertions)]
+            panic!(
+                "a client connection may only have one running dispatcher \
+                 task! a second call to `Client::dispatch` is likely a bug. \
+                 this is a panic in debug mode."
+            );
+
+            tracing::warn!("a client connection may only have one running dispatcher task!");
+            return Err(DispatchError::AlreadyRunning);
+        }
+
+        loop {
+            // wait for the next server message, or for the client to trigger a
+            // shutdown.
+            let msg = select_biased! {
+                _ = self.shutdown.wait().fuse() => {
+                    tracing::debug!("client dispatcher shutting down...");
+                    return Ok(());
+                }
+                msg = self.channel.rx().recv().fuse() => msg,
+            };
+
+            let Response { seq, body } = match msg {
+                Ok(msg) => msg,
+                Err(reset) => {
+                    let reset = match reset {
+                        RecvError::Error(e) => e,
+                        _ => message::Reset::BecauseISaidSo,
+                    };
+
+                    tracing::debug!(%reset, "client connection reset, shutting down...");
+                    self.channel.close_with_error(reset);
+                    self.dispatcher.close();
+                    return Err(DispatchError::ConnectionReset(reset));
+                }
+            };
+
+            tracing::trace!(seq, "dispatching response...");
+
+            match self.dispatcher.wake(&seq, body) {
+                WakeOutcome::Woke => {
+                    tracing::trace!(seq, "dispatched response");
+                }
+                WakeOutcome::Closed(_) => {
+                    #[cfg(debug_assertions)]
+                    unreachable!("the dispatcher should not be closed if it is still running...");
+                }
+                WakeOutcome::NoMatch(_) => {
+                    tracing::debug!(seq, "client no longer interested in request");
+                }
+            };
+        }
+    }
+}
+
+impl<S, Req, Rsp> ReqRspService for S
+where
+    S: Service<ClientMsg = Request<Req>, ServerMsg = Response<Rsp>>,
+    Req: Serialize + DeserializeOwned + Send + Sync + 'static,
+    Rsp: Serialize + DeserializeOwned + Send + Sync + 'static,
+{
+    type Request = Req;
+    type Response = Rsp;
+}
diff --git a/source/calliope/src/req_rsp/tests.rs b/source/calliope/src/req_rsp/tests.rs
new file mode 100644
index 0000000..3c26949
--- /dev/null
+++ b/source/calliope/src/req_rsp/tests.rs
@@ -0,0 +1,51 @@
+use super::*;
+use crate::tests::*;
+use std::sync::Arc;
+
+#[tokio::test]
+async fn req_rsp() {
+    let remote_registry: TestRegistry = TestRegistry::default();
+    let conns = remote_registry.add_service(svcs::EchoDelayService::identity());
+    tokio::spawn(svcs::EchoDelayService::serve(conns));
+
+    let fixture = Fixture::new()
+        .spawn_local(Default::default())
+        .spawn_remote(remote_registry);
+
+    let mut connector = fixture.local_iface().connector::<svcs::EchoDelayService>();
+
+    let chan = connect(&mut connector, "echo-delay", ()).await;
+    let client = Arc::new(Client::<svcs::EchoDelayService>::new(chan));
+
+    let dispatcher = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client.dispatch().await.expect("dispatcher died");
+        }
+    });
+
+    // spawn all the request tasks before awaiting any of them, so that
+    // requests with longer delays are in flight concurrently with shorter
+    // ones, and responses arrive out of order.
+    let rsp_futs = (10..100).rev().map(|val| {
+        let client = client.clone();
+        tokio::spawn(
+            async move {
+                tracing::info!(val, "sending request...");
+                let rsp = client.request(svcs::Echo { val }).await;
+                tracing::info!(?rsp, "received response");
+                assert_eq!(rsp, Ok(svcs::Echo { val }));
+            }
+            .instrument(tracing::info_span!("request", val)),
+        )
+    }).collect::<Vec<_>>();
+
+    for fut in rsp_futs {
+        fut.await.expect("response task should not have panicked!");
+    }
+
+    client.shutdown();
+
+    dispatcher
+        .await
+        .expect("dispatcher task should not have panicked");
+
+    fixture.finish_test().await
+}
diff --git a/source/calliope/src/tests/mod.rs b/source/calliope/src/tests/mod.rs
index 49b1bc9..1e52e27 100644
--- a/source/calliope/src/tests/mod.rs
+++ b/source/calliope/src/tests/mod.rs
@@ -21,16 +21,24 @@ mod integration;
 pub(crate) mod svcs {
     use super::*;
+    use crate::req_rsp;
     use crate::service;
     use uuid::{uuid, Uuid};
 
     pub struct HelloWorld;
 
+    pub struct EchoDelayService;
+
     #[derive(Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
     pub struct HelloWorldRequest {
         pub hello: String,
     }
 
+    #[derive(Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
+    pub struct Echo {
+        pub val: usize,
+    }
+
     #[derive(Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
     pub struct HelloHello {
         pub hello: String,
@@ -59,6 +67,57 @@
uuid!("9442b293-93d8-48b9-bbf7-52f636462bfe"); } + impl service::Service for EchoDelayService { + type ClientMsg = req_rsp::Request; + type ServerMsg = req_rsp::Response; + type ConnectError = (); + type Hello = (); + const UUID: Uuid = uuid!("5562a756-4d1d-4077-b8b9-7305f4a6b6a0"); + } + + impl EchoDelayService { + pub fn identity() -> service::Identity { + service::Identity::from_name::("echo-delay") + } + + #[tracing::instrument(level = tracing::Level::INFO, name = "EchoDelayService::serve", skip(conns))] + pub async fn serve(mut conns: mpsc::Receiver) { + let mut worker = 1; + while let Some(req) = conns.recv().await { + let InboundConnect { hello, rsp } = req; + tracing::info!(?hello, "hello world service received connection"); + let (their_chan, my_chan) = make_bidis(8); + tokio::spawn(Self::worker(worker, my_chan)); + worker += 1; + let sent = rsp.send(Ok(their_chan)).is_ok(); + tracing::debug!(?sent); + } + } + + #[tracing::instrument(level = tracing::Level::INFO, name = "EchoDelayService::worker", skip(chan))] + async fn worker( + worker: usize, + chan: BiDi, req_rsp::Response, Reset>, + ) { + while let Ok(req) = chan.rx().recv().await { + tracing::info!(?req, "echo-delay worker {worker} received request"); + let (seq, Echo { val }) = req.into_parts(); + + let tx = chan.tx().clone(); + let span = tracing::info_span!("respond", ?seq, val); + tokio::spawn( + async move { + tracing::debug!("responding in {val} ms..."); + tokio::time::sleep(tokio::time::Duration::from_millis(val as u64)).await; + tx.send(seq.respond(Echo { val })).await.unwrap(); + tracing::info!("responded after {val} ms"); + } + .instrument(span), + ); + } + } + } + pub fn hello_with_hello_id() -> service::Identity { service::Identity::from_name::("hello-hello") }