From 345227eb47d1fed23fd743c85be609b56664bd5c Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 12:42:23 -0800 Subject: [PATCH 01/21] start on good resets --- source/mgnp/src/conn_table.rs | 68 +++++++++++++++++++++++------------ source/mgnp/src/message.rs | 45 +++++++++++++++++++++-- 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index d6f963d..f174830 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -1,6 +1,6 @@ use crate::{ client::OutboundConnect, - message::{Header, InboundFrame, OutboundFrame, Rejection}, + message::{Header, InboundFrame, OutboundFrame, Rejection, Reset}, registry, }; use core::{fmt, mem, num::NonZeroU16, task::Poll}; @@ -106,7 +106,10 @@ impl ConnTable { // closed. Poll::Ready(None) => { self.dead_index = Some(Id::from_index(idx)); - return Poll::Ready(OutboundFrame::reset(*remote_id)); + return Poll::Ready(OutboundFrame::reset( + *remote_id, + Reset::BecauseISaidSo, + )); } // nothing to do, move on to the next socket. @@ -154,28 +157,45 @@ impl ConnTable { // the remote peer's remote ID is our local ID. let id = remote_id; let Some(socket) = self.conns.get_mut(id) else { - tracing::debug!("process_inbound: no socket for data frame, resetting..."); - return Some(OutboundFrame::reset(local_id)); + tracing::debug!( + id.remote = %local_id, + id.local = %remote_id, + "process_inbound: recieved a DATA frame on a connection that does not exist, resetting...", + ); + return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); }; + // try to reserve send capacity on this socket. - let error = match socket.reserve_send().await { + let reset = match socket.reserve_send().await { Ok(permit) => match permit.send(frame.body) { Ok(_) => return None, Err(error) => { - tracing::debug!(%error, "process_inbound: failed to deserialize data"); - // TODO(eliza): we should probably tell the peer - // that they sent us something bad... - return None; + tracing::debug!( + id.remote = %local_id, + id.local = %remote_id, + %error, + "process_inbound: failed to deserialize DATA frame; resetting...", + ); + Reset::bad_frame(error) } }, - Err(error) => error, + Err(InboundError::ChannelClosed) => { + // the channel has closed locally + tracing::trace!("process_inbound: recieved a DATA frame on a connection closed locally; resetting..."); + Reset::BecauseISaidSo + } + Err(InboundError::NoSocket) => { + tracing::debug!( + id.remote = %local_id, + id.local = %remote_id, + "process_inbound: recieved a DATA frame on a connection that does not exist, resetting...", + ); + Reset::NoSuchConn + } }; - // otherwise, we couldn't reserve a send permit because the - // channel has closed locally. - tracing::trace!("process_inbound: local error: {error}; resetting..."); self.close(id, None); - Some(OutboundFrame::reset(local_id)) + Some(OutboundFrame::reset(local_id, reset)) } Header::Ack { local_id, @@ -195,8 +215,8 @@ impl ConnTable { // frame that it was bad... None } - Header::Reset { remote_id } => { - tracing::trace!(id.local = %remote_id, "process_inbound: RESET"); + Header::Reset { remote_id, reason } => { + tracing::trace!(id.local = %remote_id, %reason, "process_inbound: RESET"); let _closed = self.close(remote_id, None); tracing::trace!(id.local = %remote_id, closed = _closed, "process_inbound: RESET ->"); None @@ -250,7 +270,7 @@ impl ConnTable { fn process_ack(&mut self, local_id: Id, remote_id: Id) -> Option> { let Some(Entry::Occupied(ref mut sock)) = self.conns.get_mut(local_id) else { tracing::debug!(id.local = %local_id, id.remote = %remote_id, "process_ack: no such socket"); - return Some(OutboundFrame::reset(remote_id)); + return Some(OutboundFrame::reset(remote_id, Reset::NoSuchConn)); }; match sock.state { @@ -264,7 +284,7 @@ impl ConnTable { id.actual_remote = %real_remote_id, "process_ack: socket is not connecting" ); - Some(OutboundFrame::reset(remote_id)) + Some(OutboundFrame::reset(remote_id, Reset::ConnAlreadyExists)) } ref mut state @ State::Connecting(_) => { let State::Connecting(rsp) = mem::replace(state, State::Open { remote_id }) else { @@ -277,14 +297,18 @@ impl ConnTable { if rsp.send(Ok(())).is_err() { // local initiator is no longer there, reset! tracing::debug!( - ?local_id, - ?remote_id, + id.remote = %local_id, + id.local = %remote_id, "process_ack: local initiator is no longer there; resetting" ); - return Some(OutboundFrame::reset(remote_id)); + return Some(OutboundFrame::reset(remote_id, Reset::BecauseISaidSo)); } - tracing::trace!(?local_id, ?remote_id, "process_ack: connection established"); + tracing::trace!( + id.remote = %local_id, + id.local = %remote_id, + "process_ack: connection established", + ); None } } diff --git a/source/mgnp/src/message.rs b/source/mgnp/src/message.rs index 1e0c031..368cc57 100644 --- a/source/mgnp/src/message.rs +++ b/source/mgnp/src/message.rs @@ -90,6 +90,7 @@ pub enum Header { Reset { remote_id: Id, + reason: Reset, }, } @@ -125,6 +126,24 @@ pub enum Rejection { DecodeError(DecodeError), } +/// Describes why a connection was reset. +#[derive(Copy, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum Reset { + /// The connection was reset because the peer could not decode a received + /// frame. + BadFrame(DecodeError), + /// A `DATA` frame or `ACK` was recieved on a connection that did not exist. + NoSuchConn, + /// A `CONNECT` or `ACK` frame was recieved on a connection that was already + /// established. + ConnAlreadyExists, + /// The connection was reset because the peer is shutting down its MGNP + /// interface on this wire. + ShuttingDown, + /// The peer "just wanted to" reset the connection. + BecauseISaidSo, +} + #[derive(Copy, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub enum Nak { /// A frame was rejected because it could not be decoded successfully. @@ -204,7 +223,7 @@ impl Header { local: Some(local_id), remote: None, }, - Self::Reset { remote_id } => LinkId { + Self::Reset { remote_id, .. } => LinkId { remote: Some(remote_id), local: None, }, @@ -296,9 +315,9 @@ impl<'data> Frame> { } } - pub fn reset(remote_id: Id) -> Self { + pub fn reset(remote_id: Id, reason: Reset) -> Self { Self { - header: Header::Reset { remote_id }, + header: Header::Reset { remote_id, reason }, body: OutboundData::Empty, } } @@ -350,6 +369,26 @@ impl OutboundData<'_> { } } +// === impl Reset === + +impl Reset { + pub(crate) fn bad_frame(error: postcard::Error) -> Self { + Self::BadFrame(DecodeError::body(error)) + } +} + +impl fmt::Display for Reset { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::BadFrame(err) => write!(f, "received a bad frame: {err}"), + Self::NoSuchConn => f.write_str("no such connection exists"), + Self::ConnAlreadyExists => f.write_str("connection already exists"), + Self::ShuttingDown => f.write_str("the peer is shutting down this interface"), + Self::BecauseISaidSo => f.write_str("because i said so"), + } + } +} + // === impl DecodeError === impl DecodeError { From e709c8e9706545c4e4393be2ab96459d1142cc90 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 13:04:01 -0800 Subject: [PATCH 02/21] better reason codes --- source/mgnp/src/conn_table.rs | 8 +----- source/mgnp/src/message.rs | 54 +++++++++++++++++++---------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index f174830..b754784 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -209,12 +209,6 @@ impl ConnTable { self.close(remote_id, Some(reason)); None } - Header::Nak { remote_id, reason } => { - tracing::warn!(id.local = %remote_id, ?reason, "process_inbound: NAK"); - // TODO(eliza): if applicable, tell the local peer that sent the - // frame that it was bad... - None - } Header::Reset { remote_id, reason } => { tracing::trace!(id.local = %remote_id, %reason, "process_inbound: RESET"); let _closed = self.close(remote_id, None); @@ -284,7 +278,7 @@ impl ConnTable { id.actual_remote = %real_remote_id, "process_ack: socket is not connecting" ); - Some(OutboundFrame::reset(remote_id, Reset::ConnAlreadyExists)) + Some(OutboundFrame::reset(remote_id, Reset::YesSuchConn)) } ref mut state @ State::Connecting(_) => { let State::Connecting(rsp) = mem::replace(state, State::Open { remote_id }) else { diff --git a/source/mgnp/src/message.rs b/source/mgnp/src/message.rs index 368cc57..ce9192e 100644 --- a/source/mgnp/src/message.rs +++ b/source/mgnp/src/message.rs @@ -81,13 +81,6 @@ pub enum Header { reason: Rejection, }, - /// A frame (other than `CONNECT`) was not acknowledged. - Nak { - /// The remote ID of the NAKed frame. - remote_id: Id, - reason: Nak, - }, - Reset { remote_id: Id, reason: Reset, @@ -129,19 +122,35 @@ pub enum Rejection { /// Describes why a connection was reset. #[derive(Copy, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub enum Reset { - /// The connection was reset because the peer could not decode a received - /// frame. - BadFrame(DecodeError), - /// A `DATA` frame or `ACK` was recieved on a connection that did not exist. + /// The peer "just wanted to" reset the connection. + /// + /// This may because of an application-layer protocol error, or because the + /// connection simply was no longer needed. + BecauseISaidSo, + /// The connection was reset because this peer could not decode a [`DATA`] + /// frame received from the remote peer. + /// + /// Cyber police have been alerted. + /// + /// [`DATA`]: Header::Data + YouDoneGoofed(DecodeError), + /// A [`DATA`] frame or [`ACK`] was recieved on a connection that did not + /// exist. + /// + /// [`DATA`]: Header::Data + /// [`ACK`]: Header::Ack NoSuchConn, - /// A `CONNECT` or `ACK` frame was recieved on a connection that was already + /// A [`CONNECT`] or [`ACK`] frame was recieved on a connection that was already /// established. - ConnAlreadyExists, + /// + /// [`CONNECT`]: Header::Connect + /// [`ACK`]: Header::Ack + YesSuchConn, /// The connection was reset because the peer is shutting down its MGNP /// interface on this wire. - ShuttingDown, - /// The peer "just wanted to" reset the connection. - BecauseISaidSo, + /// + /// No further connections will be accepted, and you should go away forever. + GoAway, } #[derive(Copy, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] @@ -211,10 +220,7 @@ impl Header { local: Some(local_id), remote: Some(remote_id), }, - Self::Nak { remote_id, .. } => LinkId { - local: None, - remote: Some(remote_id), - }, + Self::Reject { remote_id, .. } => LinkId { local: None, remote: Some(remote_id), @@ -373,17 +379,17 @@ impl OutboundData<'_> { impl Reset { pub(crate) fn bad_frame(error: postcard::Error) -> Self { - Self::BadFrame(DecodeError::body(error)) + Self::YouDoneGoofed(DecodeError::body(error)) } } impl fmt::Display for Reset { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::BadFrame(err) => write!(f, "received a bad frame: {err}"), + Self::YouDoneGoofed(err) => write!(f, "you done goofed! received a bad frame: {err}"), Self::NoSuchConn => f.write_str("no such connection exists"), - Self::ConnAlreadyExists => f.write_str("connection already exists"), - Self::ShuttingDown => f.write_str("the peer is shutting down this interface"), + Self::YesSuchConn => f.write_str("connection already exists"), + Self::GoAway => f.write_str("the peer is shutting down this interface, go away!"), Self::BecauseISaidSo => f.write_str("because i said so"), } } From 24c43709b99343ff6caa3e8d4821c25484e1d4e3 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 13:11:31 -0800 Subject: [PATCH 03/21] nicer maybe --- source/mgnp/src/conn_table.rs | 47 ++++++++++++++++------------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index b754784..d2fb1fc 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -152,15 +152,15 @@ impl ConnTable { id.remote = %local_id, id.local = %remote_id, len = frame.body.len(), - "process_inbound: data", + "process_inbound: DATA", ); // the remote peer's remote ID is our local ID. let id = remote_id; let Some(socket) = self.conns.get_mut(id) else { tracing::debug!( id.remote = %local_id, - id.local = %remote_id, - "process_inbound: recieved a DATA frame on a connection that does not exist, resetting...", + id.local = %id, + "process_inbound(DATA): connection does not exist, resetting...", ); return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); }; @@ -170,30 +170,27 @@ impl ConnTable { Ok(permit) => match permit.send(frame.body) { Ok(_) => return None, Err(error) => { + // TODO(eliza): possibly it would be better if we + // just sent the deserialize error to the local peer + // and let it decide whether this should kill the + // connection or not? but that's annoying... tracing::debug!( - id.remote = %local_id, - id.local = %remote_id, - %error, - "process_inbound: failed to deserialize DATA frame; resetting...", + id.remote = %local_id, + id.local = %id, + %error, + "process_inbound(DATA): failed to deserialize; resetting...", ); Reset::bad_frame(error) } }, - Err(InboundError::ChannelClosed) => { - // the channel has closed locally - tracing::trace!("process_inbound: recieved a DATA frame on a connection closed locally; resetting..."); - Reset::BecauseISaidSo - } - Err(InboundError::NoSocket) => { - tracing::debug!( - id.remote = %local_id, - id.local = %remote_id, - "process_inbound: recieved a DATA frame on a connection that does not exist, resetting...", - ); - Reset::NoSuchConn - } + Err(reset) => reset, }; - + tracing::trace!( + id.remote = %local_id, + id.local = %id, + reason = %reset, + "process_inbound(DATA): connection reset", + ); self.close(id, None); Some(OutboundFrame::reset(local_id, reset)) } @@ -212,7 +209,7 @@ impl ConnTable { Header::Reset { remote_id, reason } => { tracing::trace!(id.local = %remote_id, %reason, "process_inbound: RESET"); let _closed = self.close(remote_id, None); - tracing::trace!(id.local = %remote_id, closed = _closed, "process_inbound: RESET ->"); + tracing::trace!(id.local = %remote_id, closed = _closed, "process_inbound(RESET): connection closed"); None } Header::Connect { local_id, identity } => { @@ -514,13 +511,13 @@ impl Entries { // === impl Entry === impl Entry { - async fn reserve_send(&self) -> Result, InboundError> { + async fn reserve_send(&self) -> Result, Reset> { self.channel() - .ok_or(InboundError::NoSocket)? + .ok_or(Reset::NoSuchConn)? .tx() .reserve() .await - .map_err(|_| InboundError::ChannelClosed) + .map_err(|_| Reset::BecauseISaidSo) } fn socket(&self) -> Option<&Socket> { From d7be82f6dd128360df0ed8c08a3c0abfc78c447c Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 13:22:12 -0800 Subject: [PATCH 04/21] moar --- source/mgnp/src/conn_table.rs | 4 +++- source/mgnp/src/message.rs | 10 ---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index d2fb1fc..ec65141 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -173,7 +173,9 @@ impl ConnTable { // TODO(eliza): possibly it would be better if we // just sent the deserialize error to the local peer // and let it decide whether this should kill the - // connection or not? but that's annoying... + // connection or not? maybe by turning the server's + // client-to-server stream into `Result`s? tracing::debug!( id.remote = %local_id, id.local = %id, diff --git a/source/mgnp/src/message.rs b/source/mgnp/src/message.rs index ce9192e..68aabb0 100644 --- a/source/mgnp/src/message.rs +++ b/source/mgnp/src/message.rs @@ -153,16 +153,6 @@ pub enum Reset { GoAway, } -#[derive(Copy, Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum Nak { - /// A frame was rejected because it could not be decoded successfully. - DecodeError(DecodeError), - /// The local ID sent by the remote does not exist. - UnknownLocalId(Id), - /// The remote ID sent by the remote does not correspond to an existing stream. - UnknownRemoteId(Id), -} - pub type InboundFrame<'data> = Frame<&'data [u8]>; pub type OutboundFrame<'data> = Frame>; From 0022203270f2180058379388305870812b4d086c Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 13:41:38 -0800 Subject: [PATCH 05/21] send resets to client conns --- source/mgnp/src/client.rs | 14 ++-- source/mgnp/src/conn_table.rs | 122 ++++++++++++++++++++++++------- source/mgnp/tests/integration.rs | 44 +++++------ 3 files changed, 128 insertions(+), 52 deletions(-) diff --git a/source/mgnp/src/client.rs b/source/mgnp/src/client.rs index 1ca0e8b..3e10f2b 100644 --- a/source/mgnp/src/client.rs +++ b/source/mgnp/src/client.rs @@ -1,4 +1,7 @@ -use crate::{message::Rejection, registry}; +use crate::{ + message::{Rejection, Reset}, + registry, +}; use tricky_pipe::{bidi, mpsc, oneshot, serbox}; pub struct Connector { @@ -25,16 +28,17 @@ pub enum ConnectError { Nak(Rejection), } -pub type ClientChannel = - bidi::BiDi<::ServerMsg, ::ClientMsg>; +pub type ClientChannel = bidi::BiDi, ::ClientMsg>; + +type ServerResult = Result<::ServerMsg, Reset>; pub struct Channels { srv_chan: bidi::SerBiDi, - client_chan: bidi::BiDi, + client_chan: bidi::BiDi, S::ClientMsg>, } pub struct StaticChannels { - s2c: mpsc::StaticTrickyPipe, + s2c: mpsc::StaticTrickyPipe, CAPACITY>, c2s: mpsc::StaticTrickyPipe, } diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index ec65141..c4b2969 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -130,7 +130,7 @@ impl ConnTable { // weird... if let Some(local_id) = self.dead_index.take() { tracing::debug!(id.local = %local_id, "removing closed stream from dead index"); - self.close(local_id, None); + self.remove(local_id); } } @@ -193,7 +193,7 @@ impl ConnTable { reason = %reset, "process_inbound(DATA): connection reset", ); - self.close(id, None); + self.remove(id); Some(OutboundFrame::reset(local_id, reset)) } Header::Ack { @@ -205,12 +205,12 @@ impl ConnTable { } Header::Reject { remote_id, reason } => { tracing::trace!(id.local = %remote_id, ?reason, "process_inbound: REJECT"); - self.close(remote_id, Some(reason)); + self.reject(remote_id, reason); None } Header::Reset { remote_id, reason } => { tracing::trace!(id.local = %remote_id, %reason, "process_inbound: RESET"); - let _closed = self.close(remote_id, None); + let _closed = self.reset(remote_id, reason).await; tracing::trace!(id.local = %remote_id, closed = _closed, "process_inbound(RESET): connection closed"); None } @@ -323,40 +323,112 @@ impl ConnTable { } } + fn reject(&mut self, local_id: Id, rejection: Rejection) -> bool { + match self.remove(local_id) { + Some(Socket { + state: State::Open { .. }, + .. + }) => { + tracing::warn!( + iid.local = %local_id, + ?rejection, + "reject: tried to REJECT an established connection. the remote *should* have sent a RESET instead", + ); + false + } + Some(Socket { + state: State::Connecting(rsp), + .. + }) => rsp.send(Err(rejection)).is_ok(), + None => { + tracing::warn!( + id.local = %local_id, + ?rejection, + "reject: tried to REJECT a non-existent connection", + ); + false + } + } + } + + async fn reset(&mut self, local_id: Id, reset: Reset) -> bool { + match self.remove(local_id) { + Some(Socket { + state: State::Open { .. }, + channel, + }) => { + let mut bytes = [0; 32]; + match postcard::to_slice(&reset, &mut bytes) { + Err(error) => { + debug_assert!(false, "failed to serialize RESET, what the fuck! this should not happen! {error:?}"); + tracing::error!( + ?error, + "failed to serialize RESET, what the fuck! this should not happen!" + ); + false + } + Ok(bytes) => channel.tx().send(bytes).await.is_ok(), + } + } + Some(Socket { + state: State::Connecting(_rsp), + .. + }) => { + tracing::warn!( + id.local = %local_id, + ?reset, + "reset: tried to RESET an establishing connection. the remote *should* have sent a REJECT instead", + ); + // TODO(eliza): send some kinda rejection? + false + } + None => { + tracing::warn!( + id.local = %local_id, + ?reset, + "reset: tried to RESET a non-existent connection", + ); + false + } + } + } + /// Returns `true` if a connection with the provided ID was closed, `false` if /// no conn existed for that ID. - fn close(&mut self, local_id: Id, nak: Option) -> bool { + fn remove(&mut self, local_id: Id) -> Option { match self.conns.get_mut(local_id) { None => { tracing::trace!(?local_id, "close: ID greater than max conns ({CAPACITY})"); - false + None } Some(entry @ Entry::Occupied(_)) => { - tracing::trace!(?local_id, self.len, ?nak, "close: closing connection"); - let entry = mem::replace(entry, Entry::Closed(self.next_id)); - if let Some(nak) = nak { - match entry { - Entry::Occupied(Socket { - state: State::Connecting(rsp), - .. - }) => { - let nacked = rsp.send(Err(nak)).is_ok(); - tracing::trace!(?local_id, ?nak, nacked, "close: sent nak"); - } - Entry::Occupied(..) => { - tracing::warn!(?local_id, ?nak, "close: tried to NAK an established connection. the remote *should* have sent a RESET instead"); - } - _ => unreachable!("we just matched an occupied entry!"), - } - } + tracing::trace!(?local_id, self.len, "close: closing connection"); + let Entry::Occupied(sock) = mem::replace(entry, Entry::Closed(self.next_id)) else { + unreachable!("what the fuck, we just matched this as an occupied entry!"); + }; + // if let Some(nak) = nak { + // match entry { + // Entry::Occupied(Socket { + // state: State::Connecting(rsp), + // .. + // }) => { + // let nacked = rsp.send(Err(nak)).is_ok(); + // tracing::trace!(?local_id, ?nak, nacked, "close: sent nak"); + // } + // Entry::Occupied(..) => { + // + // } + // _ => unreachable!("we just matched an occupied entry!"), + // } + // } self.next_id = local_id; self.len -= 1; - true + Some(sock) } Some(_) => { tracing::trace!(?local_id, "close: no connection for ID"); - false + None } } } diff --git a/source/mgnp/tests/integration.rs b/source/mgnp/tests/integration.rs index 6977c57..3b34e83 100644 --- a/source/mgnp/tests/integration.rs +++ b/source/mgnp/tests/integration.rs @@ -25,9 +25,9 @@ async fn basically_works() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); fixture.finish_test().await; @@ -62,9 +62,9 @@ async fn hellos_work() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); fixture.finish_test().await; @@ -112,9 +112,9 @@ async fn nak_bad_hello() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); fixture.finish_test().await; @@ -152,15 +152,15 @@ async fn mux_single_service() { assert_eq!( rsp1, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); assert_eq!( rsp2, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); fixture.finish_test().await; @@ -228,9 +228,9 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); // add the other service @@ -264,17 +264,17 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); let rsp = hellohello_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "world".to_string() - }) + })) ); } @@ -318,9 +318,9 @@ async fn service_identity_routing() { let rsp = sf_conn.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "san francisco".to_string() - }) + })) ); // add the 'hello-universe' service @@ -346,14 +346,14 @@ async fn service_identity_routing() { let (sf_rsp, uni_rsp) = tokio::join! { sf_conn.rx().recv(), uni_conn.rx().recv() }; assert_eq!( sf_rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "san francisco".to_string() - }) + })) ); assert_eq!( uni_rsp, - Some(HelloWorldResponse { + Some(Ok(HelloWorldResponse { world: "universe".to_string() - }) + })) ); } From 3eef17d768b3dcbc2797b029cd15d79e7ef228fd Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Sat, 11 Nov 2023 14:48:59 -0800 Subject: [PATCH 06/21] okay actually that doesn't work because msgs don't serialize --- source/mgnp/src/client.rs | 14 ++++------ source/mgnp/tests/integration.rs | 44 ++++++++++++++++---------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/source/mgnp/src/client.rs b/source/mgnp/src/client.rs index 3e10f2b..1ca0e8b 100644 --- a/source/mgnp/src/client.rs +++ b/source/mgnp/src/client.rs @@ -1,7 +1,4 @@ -use crate::{ - message::{Rejection, Reset}, - registry, -}; +use crate::{message::Rejection, registry}; use tricky_pipe::{bidi, mpsc, oneshot, serbox}; pub struct Connector { @@ -28,17 +25,16 @@ pub enum ConnectError { Nak(Rejection), } -pub type ClientChannel = bidi::BiDi, ::ClientMsg>; - -type ServerResult = Result<::ServerMsg, Reset>; +pub type ClientChannel = + bidi::BiDi<::ServerMsg, ::ClientMsg>; pub struct Channels { srv_chan: bidi::SerBiDi, - client_chan: bidi::BiDi, S::ClientMsg>, + client_chan: bidi::BiDi, } pub struct StaticChannels { - s2c: mpsc::StaticTrickyPipe, CAPACITY>, + s2c: mpsc::StaticTrickyPipe, c2s: mpsc::StaticTrickyPipe, } diff --git a/source/mgnp/tests/integration.rs b/source/mgnp/tests/integration.rs index 3b34e83..6977c57 100644 --- a/source/mgnp/tests/integration.rs +++ b/source/mgnp/tests/integration.rs @@ -25,9 +25,9 @@ async fn basically_works() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); fixture.finish_test().await; @@ -62,9 +62,9 @@ async fn hellos_work() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); fixture.finish_test().await; @@ -112,9 +112,9 @@ async fn nak_bad_hello() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); fixture.finish_test().await; @@ -152,15 +152,15 @@ async fn mux_single_service() { assert_eq!( rsp1, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); assert_eq!( rsp2, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); fixture.finish_test().await; @@ -228,9 +228,9 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); // add the other service @@ -264,17 +264,17 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); let rsp = hellohello_chan.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "world".to_string() - })) + }) ); } @@ -318,9 +318,9 @@ async fn service_identity_routing() { let rsp = sf_conn.rx().recv().await; assert_eq!( rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "san francisco".to_string() - })) + }) ); // add the 'hello-universe' service @@ -346,14 +346,14 @@ async fn service_identity_routing() { let (sf_rsp, uni_rsp) = tokio::join! { sf_conn.rx().recv(), uni_conn.rx().recv() }; assert_eq!( sf_rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "san francisco".to_string() - })) + }) ); assert_eq!( uni_rsp, - Some(Ok(HelloWorldResponse { + Some(HelloWorldResponse { world: "universe".to_string() - })) + }) ); } From 110d02a0fd37268037d96d7994bc3d199013654a Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Mon, 13 Nov 2023 09:16:07 -0800 Subject: [PATCH 07/21] wip --- source/mgnp/src/client.rs | 24 +++++--- source/mgnp/src/conn_table.rs | 65 +++++++++------------ source/tricky-pipe/src/mpsc/arc_impl.rs | 8 +-- source/tricky-pipe/src/mpsc/channel_core.rs | 16 ++--- 4 files changed, 56 insertions(+), 57 deletions(-) diff --git a/source/mgnp/src/client.rs b/source/mgnp/src/client.rs index 1ca0e8b..0f6f872 100644 --- a/source/mgnp/src/client.rs +++ b/source/mgnp/src/client.rs @@ -1,4 +1,7 @@ -use crate::{message::Rejection, registry}; +use crate::{ + message::{Rejection, Reset}, + registry, Service, +}; use tricky_pipe::{bidi, mpsc, oneshot, serbox}; pub struct Connector { @@ -17,6 +20,7 @@ pub struct OutboundConnect { pub(crate) channel: bidi::SerBiDi, /// Sender for the response from the remote service. pub(crate) rsp: oneshot::Sender>, + pub(crate) reset: oneshot::Sender, } #[derive(Debug, Eq, PartialEq)] @@ -25,20 +29,23 @@ pub enum ConnectError { Nak(Rejection), } -pub type ClientChannel = - bidi::BiDi<::ServerMsg, ::ClientMsg>; +pub struct Connection { + chan: bidi::BiDi, + reset: oneshot::Receiver, +} pub struct Channels { srv_chan: bidi::SerBiDi, client_chan: bidi::BiDi, + reset: oneshot::Receiver, } -pub struct StaticChannels { +pub struct StaticChannels { s2c: mpsc::StaticTrickyPipe, c2s: mpsc::StaticTrickyPipe, } -impl Channels { +impl Channels { pub fn from_static( storage: &'static StaticChannels, ) -> Self { @@ -54,11 +61,12 @@ impl Channels { Self { srv_chan, client_chan, + reset: oneshot::Receiver::new(), } } } -impl Connector { +impl Connector { pub async fn connect( &mut self, identity: impl Into, @@ -66,8 +74,9 @@ impl Connector { Channels { srv_chan, client_chan, + reset, }: Channels, - ) -> Result, ConnectError> { + ) -> Result, ConnectError> { let permit = self .tx .reserve() @@ -80,6 +89,7 @@ impl Connector { hello, channel: srv_chan, rsp, + reset, }; permit.send(connect); match self.rsp.recv().await { diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index c4b2969..be0e1f8 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -46,6 +46,7 @@ enum InboundError { struct Socket { state: State, channel: SerBiDi, + reset: oneshot::Sender, } enum Entry { @@ -87,6 +88,7 @@ impl ConnTable { let Some(Socket { state: State::Open { remote_id }, channel, + .. }) = entry.socket() else { continue; @@ -236,6 +238,7 @@ impl ConnTable { identity, channel, rsp, + reset, } = connect; let local_id = match self.reserve_id() { @@ -253,6 +256,7 @@ impl ConnTable { Socket { state: State::Connecting(rsp), channel, + reset, }, ); @@ -313,6 +317,7 @@ impl ConnTable { let sock = Socket { state: State::Open { remote_id }, channel, + .. }; match self.insert(sock) { @@ -351,46 +356,28 @@ impl ConnTable { } } - async fn reset(&mut self, local_id: Id, reset: Reset) -> bool { - match self.remove(local_id) { - Some(Socket { - state: State::Open { .. }, - channel, - }) => { - let mut bytes = [0; 32]; - match postcard::to_slice(&reset, &mut bytes) { - Err(error) => { - debug_assert!(false, "failed to serialize RESET, what the fuck! this should not happen! {error:?}"); - tracing::error!( - ?error, - "failed to serialize RESET, what the fuck! this should not happen!" - ); - false - } - Ok(bytes) => channel.tx().send(bytes).await.is_ok(), - } - } - Some(Socket { - state: State::Connecting(_rsp), - .. - }) => { - tracing::warn!( - id.local = %local_id, - ?reset, - "reset: tried to RESET an establishing connection. the remote *should* have sent a REJECT instead", - ); - // TODO(eliza): send some kinda rejection? - false - } - None => { - tracing::warn!( - id.local = %local_id, - ?reset, - "reset: tried to RESET a non-existent connection", - ); - false - } + async fn reset(&mut self, local_id: Id, reason: Reset) -> bool { + tracing::trace!(id.local = %local_id, %reason, "reset: resetting connection..."); + let Some(Socket { state, reset, .. }) = self.remove(local_id) else { + tracing::warn!( + id.local = %local_id, + %reason, + "reset: tried to RESET a non-existent connection", + ); + return false; + }; + + if matches!(state, State::Connecting(..)) { + tracing::warn!( + id.local = %local_id, + %reason, + "reset: tried to RESET an establishing connection. the remote *should* have sent a REJECT instead", + ); + // TODO(eliza): send some kinda rejection? + return false; } + + reset.send(reason).is_ok() } /// Returns `true` if a connection with the provided ID was closed, `false` if diff --git a/source/tricky-pipe/src/mpsc/arc_impl.rs b/source/tricky-pipe/src/mpsc/arc_impl.rs index eb2ce53..376348e 100644 --- a/source/tricky-pipe/src/mpsc/arc_impl.rs +++ b/source/tricky-pipe/src/mpsc/arc_impl.rs @@ -11,10 +11,10 @@ use super::channel_core::{Core, CoreVtable}; // TODO(eliza): we should probably replace the use of `Arc` here with manual ref // counting, since the `Core` tracks the number of senders and receivers // already. But, I was in a hurry to get a prototype working... -pub struct TrickyPipe(Arc>); +pub struct TrickyPipe(Arc>); -struct Inner { - core: Core, +struct Inner { + core: Core, // TODO(eliza): instead of boxing the elements array, we should probably // manually allocate a `Layout`. This works for now, though. // @@ -41,7 +41,7 @@ impl TrickyPipe { })) } - const CORE_VTABLE: &'static CoreVtable = &CoreVtable { + const CORE_VTABLE: &'static CoreVtable = &CoreVtable { get_core: Self::get_core, get_elems: Self::get_elems, clone: Self::erased_clone, diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index 89b2568..5662c41 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -18,7 +18,7 @@ use serde::{de::DeserializeOwned, Serialize}; #[cfg(feature = "alloc")] use alloc::vec::Vec; -pub(super) struct Core { +pub(super) struct Core { // === receiver-only state ==== /// The head of the queue (i.e. the position at which elements are popped by /// the receiver). @@ -76,17 +76,19 @@ pub(super) struct Core { /// This is the length of the actual queue elements array (which is not part /// of this struct). pub(super) capacity: u8, + /// If the channel closed with an error, this is the error. + error: UnsafeCell>, } -pub(super) struct Reservation<'core> { - core: &'core Core, +pub(super) struct Reservation<'core, E> { + core: &'core Core, pub(super) idx: u8, } /// Erases both a pipe and its element type. -pub(super) struct ErasedPipe { +pub(super) struct ErasedPipe { ptr: *const (), - vtable: &'static CoreVtable, + vtable: &'static CoreVtable, } pub(super) struct TypedPipe { @@ -103,8 +105,8 @@ pub(super) struct ErasedSlice { typ: core::any::TypeId, } -pub(super) struct CoreVtable { - pub(super) get_core: unsafe fn(*const ()) -> *const Core, +pub(super) struct CoreVtable { + pub(super) get_core: unsafe fn(*const ()) -> *const Core, pub(super) get_elems: unsafe fn(*const ()) -> ErasedSlice, pub(super) clone: unsafe fn(*const ()), pub(super) drop: unsafe fn(*const ()), From 13c1dfb7d12c7be354cb076030d20a5febd4f664 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 09:45:29 -0800 Subject: [PATCH 08/21] redo channels to have errors --- source/tricky-pipe/src/bidi.rs | 64 ++-- source/tricky-pipe/src/mpsc/arc_impl.rs | 53 ++-- source/tricky-pipe/src/mpsc/channel_core.rs | 325 +++++++++++--------- source/tricky-pipe/src/mpsc/error.rs | 78 ++++- source/tricky-pipe/src/mpsc/mod.rs | 172 ++++++----- source/tricky-pipe/src/mpsc/static_impl.rs | 45 ++- source/tricky-pipe/src/mpsc/tests.rs | 18 +- 7 files changed, 440 insertions(+), 315 deletions(-) diff --git a/source/tricky-pipe/src/bidi.rs b/source/tricky-pipe/src/bidi.rs index 8946a62..257fbf7 100644 --- a/source/tricky-pipe/src/bidi.rs +++ b/source/tricky-pipe/src/bidi.rs @@ -55,22 +55,22 @@ where (self.tx, self.rx) } - /// Wait until the channel is either ready to send a message *or* a new - /// incoming message is received, whichever occurs first. - #[must_use] - pub async fn wait(&self) -> Option>> { - futures::select_biased! { - res = self.tx.reserve().fuse() => { - match res { - Ok(permit) => Some(Event::SendReady(permit)), - Err(_) => self.rx.recv().await.map(Event::Recv), - } - } - recv = self.rx.recv().fuse() => { - recv.map(Event::Recv) - } - } - } + // /// Wait until the channel is either ready to send a message *or* a new + // /// incoming message is received, whichever occurs first. + // #[must_use] + // pub async fn wait(&self) -> Option>> { + // futures::select_biased! { + // res = self.tx.reserve().fuse() => { + // match res { + // Ok(permit) => Some(Event::SendReady(permit)), + // Err(_) => self.rx.recv().await.map(Event::Recv), + // } + // } + // recv = self.rx.recv().fuse() => { + // recv.map(Event::Recv) + // } + // } + // } /// Borrows the **send half** of this bidirectional channel. /// @@ -149,22 +149,22 @@ impl SerBiDi { (self.tx, self.rx) } - /// Wait until the channel is either ready to send a message *or* a new - /// incoming message is received, whichever occurs first. - #[must_use] - pub async fn wait(&self) -> Option, SerPermit<'_>>> { - futures::select_biased! { - res = self.tx.reserve().fuse() => { - match res { - Ok(permit) => Some(Event::SendReady(permit)), - Err(_) => self.rx.recv().await.map(Event::Recv), - } - } - recv = self.rx.recv().fuse() => { - recv.map(Event::Recv) - } - } - } + // /// Wait until the channel is either ready to send a message *or* a new + // /// incoming message is received, whichever occurs first. + // #[must_use] + // pub async fn wait(&self) -> Option, SerPermit<'_, ()>>> { + // futures::select_biased! { + // res = self.tx.reserve().fuse() => { + // match res { + // Ok(permit) => Some(Event::SendReady(permit)), + // Err(_) => self.rx.recv().await.map(Event::Recv), + // } + // } + // recv = self.rx.recv().fuse() => { + // recv.map(Event::Recv) + // } + // } + // } /// Borrows the **send half** of this bidirectional channel. /// diff --git a/source/tricky-pipe/src/mpsc/arc_impl.rs b/source/tricky-pipe/src/mpsc/arc_impl.rs index 376348e..733205c 100644 --- a/source/tricky-pipe/src/mpsc/arc_impl.rs +++ b/source/tricky-pipe/src/mpsc/arc_impl.rs @@ -11,10 +11,17 @@ use super::channel_core::{Core, CoreVtable}; // TODO(eliza): we should probably replace the use of `Arc` here with manual ref // counting, since the `Core` tracks the number of senders and receivers // already. But, I was in a hurry to get a prototype working... -pub struct TrickyPipe(Arc>); +pub struct TrickyPipe(Arc>) +where + T: 'static, + E: Clone + 'static; -struct Inner { - core: Core, +struct Inner +where + T: 'static, + E: Clone + 'static, +{ + core: Core, // TODO(eliza): instead of boxing the elements array, we should probably // manually allocate a `Layout`. This works for now, though. // @@ -26,7 +33,7 @@ struct Inner { elements: Box<[Cell]>, } -impl TrickyPipe { +impl TrickyPipe { /// Create a new [`TrickyPipe`] allocated on the heap. /// /// NOTE: `CAPACITY` MUST be a power of two, and must also be <= the number of bits @@ -49,19 +56,19 @@ impl TrickyPipe { type_name: core::any::type_name::, }; - fn erased(&self) -> ErasedPipe { + fn erased(&self) -> ErasedPipe { let ptr = Arc::into_raw(self.0.clone()) as *const _; unsafe { ErasedPipe::new(ptr, Self::CORE_VTABLE) } } - fn typed(&self) -> TypedPipe { + fn typed(&self) -> TypedPipe { unsafe { self.erased().typed() } } /// Try to obtain a [`Receiver`] capable of receiving `T`-typed data /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn receiver(&self) -> Option> { + pub fn receiver(&self) -> Option> { self.0.core.try_claim_rx()?; Some(Receiver { pipe: self.typed() }) @@ -70,45 +77,46 @@ impl TrickyPipe { /// Obtain a [`Sender`] capable of sending `T`-typed data /// /// This function may be called multiple times. - pub fn sender(&self) -> Sender { + pub fn sender(&self) -> Sender { self.0.core.add_tx(); Sender { pipe: self.typed() } } - unsafe fn get_core(ptr: *const ()) -> *const Core { + unsafe fn get_core(ptr: *const ()) -> *const Core { unsafe { - let ptr = ptr.cast::>(); + let ptr = ptr.cast::>(); ptr::addr_of!((*ptr).core) } } unsafe fn get_elems(ptr: *const ()) -> ErasedSlice { - let ptr = ptr.cast::>(); + let ptr = ptr.cast::>(); ErasedSlice::erase(&(*ptr).elements) } unsafe fn erased_clone(ptr: *const ()) { test_println!("erased_clone({ptr:p})"); - Arc::increment_strong_count(ptr.cast::>()) + Arc::increment_strong_count(ptr.cast::>()) } unsafe fn erased_drop(ptr: *const ()) { - let arc = Arc::from_raw(ptr.cast::>()); + let arc = Arc::from_raw(ptr.cast::>()); test_println!(refs = Arc::strong_count(&arc), "erased_drop({ptr:p})"); drop(arc) } } -impl TrickyPipe +impl TrickyPipe where T: Serialize + Send + 'static, + E: Clone + Send + Sync, { /// Try to obtain a [`SerReceiver`] capable of receiving bytes containing /// a serialized instance of `T`. /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn ser_receiver(&self) -> Option { + pub fn ser_receiver(&self) -> Option> { self.0.core.try_claim_rx()?; Some(SerReceiver { @@ -128,16 +136,17 @@ where }; } -impl TrickyPipe +impl TrickyPipe where T: DeserializeOwned + Send + 'static, + E: Clone + Send + Sync, { /// Try to obtain a [`DeserSender`] capable of sending bytes containing /// a serialized instance of `T`. /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn deser_sender(&self) -> DeserSender { + pub fn deser_sender(&self) -> DeserSender { self.0.core.add_tx(); DeserSender { pipe: self.erased(), @@ -148,7 +157,7 @@ where const DESER_VTABLE: &'static DeserVtable = &DeserVtable::new::(); } -impl Clone for TrickyPipe { +impl Clone for TrickyPipe { fn clone(&self) -> Self { test_span!("TrickyPipe::clone"); // Since the `TrickyPipe` type can construct new `Sender`s, this @@ -160,7 +169,7 @@ impl Clone for TrickyPipe { } } -impl Drop for TrickyPipe { +impl Drop for TrickyPipe { fn drop(&mut self) { test_span!("TrickyPipe::drop"); // Since the `TrickyPipe` type can construct new `Sender`s, this @@ -171,12 +180,12 @@ impl Drop for TrickyPipe { } } -unsafe impl Send for TrickyPipe {} -unsafe impl Sync for TrickyPipe {} +unsafe impl Send for TrickyPipe {} +unsafe impl Sync for TrickyPipe {} // === impl Inner === -impl Drop for Inner { +impl Drop for Inner { fn drop(&mut self) { test_span!("Inner::drop"); diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index 5662c41..278e429 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -86,13 +86,13 @@ pub(super) struct Reservation<'core, E> { } /// Erases both a pipe and its element type. -pub(super) struct ErasedPipe { +pub(super) struct ErasedPipe { ptr: *const (), vtable: &'static CoreVtable, } -pub(super) struct TypedPipe { - pipe: ErasedPipe, +pub(super) struct TypedPipe { + pipe: ErasedPipe, _t: PhantomData, } @@ -156,18 +156,21 @@ pub(super) const MAX_CAPACITY: usize = IndexAllocWord::MAX_CAPACITY as usize; /// /// This is the first bit of the pos word, so that it is not clobbered if /// incrementing the actual position in the queue wraps around (which is fine). -const CLOSED: u16 = 0b1; +const CLOSED: u16 = 1 << 0; +const HAS_ERROR: u16 = 1 << 1; +const CLOSED_ERROR: u16 = CLOSED | HAS_ERROR; +const POS_SHIFT: u16 = CLOSED_ERROR.trailing_ones() as u16; /// The value by which `enqueue_pos` and `dequeue_pos` are incremented. This is -/// shifted left by one to account for the lowest bit being used for -/// [`CLOSED_BIT`]. -const POS_ONE: u16 = 1 << 1; +/// shifted left by two to account for the lowest bits being used for `CLOSED` +/// and `HAS_ERROR` +const POS_ONE: u16 = 1 << POS_SHIFT; const MASK: u16 = MAX_CAPACITY as u16 - 1; const SEQ_SHIFT: u16 = MASK.trailing_ones() as u16; const SEQ_ONE: u16 = 1 << SEQ_SHIFT; // === impl Core === -impl Core { +impl Core { #[cfg(not(loom))] pub(super) const fn new(capacity: u8) -> Self { #[allow(clippy::declare_interior_mutable_const)] @@ -195,6 +198,7 @@ impl Core { // dropped. state: AtomicUsize::new(state::TX_ONE), capacity, + error: UnsafeCell::new(MaybeUninit::uninit()), } } @@ -222,123 +226,8 @@ impl Core { queue, state: AtomicUsize::new(state::TX_ONE), capacity, - } - } - - pub(super) fn try_reserve(&self) -> Result, TrySendError> { - test_span!("Core::try_reserve"); - if test_dbg!(self.enqueue_pos.load(Acquire)) & CLOSED != 0 { - return Err(TrySendError::Closed(())); - } - test_dbg!(self.indices.allocate()) - .ok_or(TrySendError::Full(())) - .map(|idx| Reservation { core: self, idx }) - } - pub(super) async fn reserve(&self) -> Result { - loop { - match self.try_reserve() { - Ok(res) => return Ok(res), - Err(TrySendError::Closed(())) => return Err(SendError(())), - Err(TrySendError::Full(())) => { - self.prod_wait.wait().await.map_err(|_| SendError(()))? - } - } - } - } - - pub(super) fn poll_dequeue(&self, cx: &mut Context<'_>) -> Poll>> { - loop { - match self.try_dequeue() { - Ok(res) => return Poll::Ready(Some(res)), - Err(TryRecvError::Closed) => return Poll::Ready(None), - Err(TryRecvError::Empty) => { - // we never close the rx waitcell, because the - // rx is responsible for determining if the channel is - // closed by the tx: there may be messages in the channel to - // consume before the rx considers it properly closed. - let _ = task::ready!(test_dbg!(self.cons_wait.poll_wait(cx))); - // if the poll_wait returns ready, then another thread just - // enqueued something. sticking a spin loop hint here tells - // `loom` that we're waiting for that thread before we can - // make progress. in real life, the `PAUSE` instruction or - // similar may also help us actually see the other thread's - // change...if it takes a single cycle of delay for it to - // reflect? idk lol ¯\_(ツ)_/¯ - hint::spin_loop(); - } - } - } - } - - pub(super) fn try_dequeue(&self) -> Result, TryRecvError> { - test_span!("Core::try_dequeue"); - let mut head = test_dbg!(self.dequeue_pos.load(Acquire)); - loop { - // Shift one bit to the right to extract the actual position, and - // discard the `CLOSED` bit. - let pos = head >> 1; - let slot = &self.queue[(pos & MASK) as usize]; - // Load the slot's current value, and extract its sequence number. - let val = slot.load(Acquire); - let seq = val >> SEQ_SHIFT; - let dif = test_dbg!(seq as i8).wrapping_sub(test_dbg!(pos).wrapping_add(1) as i8); - - match test_dbg!(dif).cmp(&0) { - cmp::Ordering::Less if test_dbg!(head & CLOSED) != 0 => { - return Err(TryRecvError::Closed) - } - cmp::Ordering::Less => return Err(TryRecvError::Empty), - cmp::Ordering::Equal => match test_dbg!(self.dequeue_pos.compare_exchange_weak( - head, - head.wrapping_add(POS_ONE), - AcqRel, - Acquire, - )) { - Ok(_) => { - slot.store(val.wrapping_add(SEQ_ONE), Release); - return Ok(Reservation { - core: self, - idx: (val & MASK) as u8, - }); - } - Err(actual) => head = actual, - }, - cmp::Ordering::Greater => head = test_dbg!(self.dequeue_pos.load(Acquire)), - } - } - } - - fn commit_send(&self, idx: u8) { - test_span!("Core::commit_send", idx); - debug_assert!(idx as u16 <= MASK); - let mut tail = test_dbg!(self.enqueue_pos.load(Acquire)); - loop { - // Shift one bit to the right to extract the actual position, and - // discard the `CLOSED` bit. - let pos = tail >> 1; - let slot = &self.queue[test_dbg!(pos & MASK) as usize]; - let seq = slot.load(Acquire) >> SEQ_SHIFT; - let dif = test_dbg!(seq as i8).wrapping_sub(test_dbg!(pos as i8)); - - match test_dbg!(dif).cmp(&0) { - cmp::Ordering::Less => unreachable!(), - cmp::Ordering::Equal => match test_dbg!(self.enqueue_pos.compare_exchange_weak( - tail, - tail.wrapping_add(POS_ONE), - AcqRel, - Acquire, - )) { - Ok(_) => { - let new = test_dbg!(test_dbg!((pos) << SEQ_SHIFT).wrapping_add(SEQ_ONE)); - slot.store(test_dbg!(idx as u16 | new), Release); - test_dbg!(self.cons_wait.wake()); - return; - } - Err(actual) => tail = actual, - }, - cmp::Ordering::Greater => tail = test_dbg!(self.enqueue_pos.load(Acquire)), - } + error: UnsafeCell::new(MaybeUninit::uninit()), } } @@ -460,25 +349,177 @@ impl Core { } } +impl Core { + pub(super) fn try_reserve(&self) -> Result, TrySendError> { + test_span!("Core::try_reserve"); + let enqueue_pos = self.enqueue_pos.load(Acquire); + if test_dbg!(enqueue_pos & CLOSED) == CLOSED { + return Err(self + .send_closed_error() + .map(|error| TrySendError::Error { error, message: () }) + .unwrap_or(TrySendError::Closed(()))); + } + + test_dbg!(self.indices.allocate()) + .ok_or(TrySendError::Full(())) + .map(|idx| Reservation { core: self, idx }) + } + + pub(super) async fn reserve(&self) -> Result, SendError> { + loop { + match self.try_reserve() { + Ok(res) => return Ok(res), + Err(TrySendError::Closed(())) => return Err(SendError::Closed(())), + Err(TrySendError::Error { error, .. }) => { + return Err(SendError::Error { error, message: () }) + } + Err(TrySendError::Full(())) => self.prod_wait.wait().await.map_err(|_| { + self.send_closed_error() + .map(|error| SendError::Error { error, message: () }) + .unwrap_or(SendError::Closed(())) + })?, + } + } + } + + pub(super) fn poll_dequeue( + &self, + cx: &mut Context<'_>, + ) -> Poll, RecvError>> { + loop { + match self.try_dequeue() { + Ok(res) => return Poll::Ready(Ok(res)), + Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + Err(TryRecvError::Error(error)) => { + return Poll::Ready(Err(RecvError::Error(error))) + } + Err(TryRecvError::Empty) => { + // we never close the rx waitcell, because the + // rx is responsible for determining if the channel is + // closed by the tx: there may be messages in the channel to + // consume before the rx considers it properly closed. + let _ = task::ready!(test_dbg!(self.cons_wait.poll_wait(cx))); + // if the poll_wait returns ready, then another thread just + // enqueued something. sticking a spin loop hint here tells + // `loom` that we're waiting for that thread before we can + // make progress. in real life, the `PAUSE` instruction or + // similar may also help us actually see the other thread's + // change...if it takes a single cycle of delay for it to + // reflect? idk lol ¯\_(ツ)_/¯ + hint::spin_loop(); + } + } + } + } + + pub(super) fn try_dequeue(&self) -> Result, TryRecvError> { + test_span!("Core::try_dequeue"); + let mut head = test_dbg!(self.dequeue_pos.load(Acquire)); + loop { + if head & CLOSED_ERROR == CLOSED_ERROR { + return Err(TryRecvError::Error(unsafe { self.close_error() })); + } + // Shift to the right to extract the actual position, and + // discard the `CLOSED` and `HAS_ERROR` bits. + let pos = head >> POS_SHIFT; + let slot = &self.queue[(pos & MASK) as usize]; + // Load the slot's current value, and extract its sequence number. + let val = slot.load(Acquire); + let seq = val >> SEQ_SHIFT; + let dif = test_dbg!(seq as i8).wrapping_sub(test_dbg!(pos).wrapping_add(1) as i8); + + match test_dbg!(dif).cmp(&0) { + cmp::Ordering::Less if test_dbg!(head & CLOSED) != 0 => { + return Err(TryRecvError::Closed) + } + cmp::Ordering::Less => return Err(TryRecvError::Empty), + cmp::Ordering::Equal => match test_dbg!(self.dequeue_pos.compare_exchange_weak( + head, + head.wrapping_add(POS_ONE), + AcqRel, + Acquire, + )) { + Ok(_) => { + slot.store(val.wrapping_add(SEQ_ONE), Release); + return Ok(Reservation { + core: self, + idx: (val & MASK) as u8, + }); + } + Err(actual) => head = actual, + }, + cmp::Ordering::Greater => head = test_dbg!(self.dequeue_pos.load(Acquire)), + } + } + } + + fn commit_send(&self, idx: u8) -> Result<(), SendError<(), E>> { + test_span!("Core::commit_send", idx); + debug_assert!(idx as u16 <= MASK); + let mut tail = test_dbg!(self.enqueue_pos.load(Acquire)); + loop { + // Shift one bit to the right to extract the actual position, and + // discard the `CLOSED` bit. + let pos = tail >> 1; + let slot = &self.queue[test_dbg!(pos & MASK) as usize]; + let seq = slot.load(Acquire) >> SEQ_SHIFT; + let dif = test_dbg!(seq as i8).wrapping_sub(test_dbg!(pos as i8)); + + match test_dbg!(dif).cmp(&0) { + cmp::Ordering::Less => unreachable!(), + cmp::Ordering::Equal => match test_dbg!(self.enqueue_pos.compare_exchange_weak( + tail, + tail.wrapping_add(POS_ONE), + AcqRel, + Acquire, + )) { + Ok(_) => { + let new = test_dbg!(test_dbg!((pos) << SEQ_SHIFT).wrapping_add(SEQ_ONE)); + slot.store(test_dbg!(idx as u16 | new), Release); + test_dbg!(self.cons_wait.wake()); + return Ok(()); + } + Err(actual) => tail = actual, + }, + cmp::Ordering::Greater => tail = test_dbg!(self.enqueue_pos.load(Acquire)), + } + } + } + + fn send_closed_error(&self) -> Option { + if test_dbg!(self.enqueue_pos.load(Acquire) & CLOSED_ERROR) == CLOSED_ERROR { + Some(unsafe { self.close_error() }) + } else { + None + } + } + + unsafe fn close_error(&self) -> E { + // debug_assert!(self.enqueue_pos.load(Acquire) & CLOSED_ERROR == CLOSED_ERROR); + self.error + .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }) + } +} + // === impl Reservation === -impl Reservation<'_> { - pub(super) fn commit_send(self) { +impl Reservation<'_, E> { + pub(super) fn commit_send(self) -> Result<(), SendError<(), E>> { // don't run the destructor that frees the index, since we are dropping // the cell... let this = ManuallyDrop::new(self); // ...and commit to the queue. - this.core.commit_send(this.idx); + this.core.commit_send(this.idx) } } -impl Drop for Reservation<'_> { +impl Drop for Reservation<'_, E> { fn drop(&mut self) { unsafe { self.core.uncommit(self.idx) } } } -impl fmt::Debug for Reservation<'_> { +impl fmt::Debug for Reservation<'_, E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let Self { core, idx } = self; f.debug_struct("Reservation") @@ -514,8 +555,8 @@ impl ErasedSlice { // == impl ErasedPipe === -impl ErasedPipe { - pub(super) unsafe fn new(ptr: *const (), vtable: &'static CoreVtable) -> Self { +impl ErasedPipe { + pub(super) unsafe fn new(ptr: *const (), vtable: &'static CoreVtable) -> Self { Self { ptr, vtable } } @@ -523,14 +564,14 @@ impl ErasedPipe { /// /// This `ErasedPipe` must have been type-erased from a tricky-pipe with /// elements of type `T`! - pub(super) unsafe fn typed(self) -> TypedPipe { + pub(super) unsafe fn typed(self) -> TypedPipe { TypedPipe { pipe: self, _t: PhantomData, } } - pub(super) fn core(&self) -> &Core { + pub(super) fn core(&self) -> &Core { unsafe { &*(self.vtable.get_core)(self.ptr) } } @@ -549,7 +590,7 @@ impl ErasedPipe { } } -impl Clone for ErasedPipe { +impl Clone for ErasedPipe { fn clone(&self) -> Self { unsafe { (self.vtable.clone)(self.ptr) } Self { @@ -559,21 +600,21 @@ impl Clone for ErasedPipe { } } -impl Drop for ErasedPipe { +impl Drop for ErasedPipe { fn drop(&mut self) { unsafe { (self.vtable.drop)(self.ptr) } } } // Safety: a pipe's element type must be `Send` in order to be erased. -unsafe impl Send for ErasedPipe {} +unsafe impl Send for ErasedPipe {} // Safety: a pipe's element type must be `Send` in order to be erased. -unsafe impl Sync for ErasedPipe {} +unsafe impl Sync for ErasedPipe {} // === impl TypedPipe === -impl TypedPipe { - pub(super) fn core(&self) -> &Core { +impl TypedPipe { + pub(super) fn core(&self) -> &Core { self.pipe.core() } @@ -586,7 +627,7 @@ impl TypedPipe { } } -impl Clone for TypedPipe { +impl Clone for TypedPipe { fn clone(&self) -> Self { Self { pipe: self.pipe.clone(), @@ -595,8 +636,8 @@ impl Clone for TypedPipe { } } -unsafe impl Send for TypedPipe {} -unsafe impl Sync for TypedPipe {} +unsafe impl Send for TypedPipe {} +unsafe impl Sync for TypedPipe {} // === impl SerVtable === diff --git a/source/tricky-pipe/src/mpsc/error.rs b/source/tricky-pipe/src/mpsc/error.rs index 2337068..15ecb1a 100644 --- a/source/tricky-pipe/src/mpsc/error.rs +++ b/source/tricky-pipe/src/mpsc/error.rs @@ -21,7 +21,18 @@ use core::fmt; /// [`DeserSender::reserve`]: super::DeserSender::reserve /// [`Sender::send`]: super::Sender::send #[derive(Eq, PartialEq)] -pub struct SendError(pub(crate) T); +pub enum SendError { + /// A message cannot be sent because the channel is closed (no [`Receiver`] + /// or [`SerReceiver`] exists). + /// + /// [`Receiver`]: super::Receiver + /// [`SerReceiver`]: super::SerReceiver + Closed(T), + Error { + message: T, + error: E, + }, +} /// Error returned by [`Sender::try_reserve`], [`Sender::try_send`], and /// [`DeserSender::try_reserve`]. @@ -35,7 +46,7 @@ pub struct SendError(pub(crate) T); /// [`DeserSender::try_reserve`]: super::DeserSender::try_reserve /// [`Sender::try_send`]: super::Sender::try_send #[derive(Eq, PartialEq)] -pub enum TrySendError { +pub enum TrySendError { /// The channel is currently full, and a message cannot be sent without /// waiting for a slot to become available. Full(T), @@ -45,6 +56,10 @@ pub enum TrySendError { /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver Closed(T), + Error { + message: T, + error: E, + }, } /// Errors returned by [`Receiver::try_recv`] and [`SerReceiver::try_recv`]. @@ -52,7 +67,7 @@ pub enum TrySendError { /// [`Receiver::try_recv`]: super::Receiver::try_recv /// [`SerReceiver::try_recv`]: super::SerReceiver::try_recv #[derive(Debug, Eq, PartialEq)] -pub enum TryRecvError { +pub enum TryRecvError { /// No messages are currently present in the channel. The receiver must wait /// for an additional message to be sent. Empty, @@ -64,6 +79,7 @@ pub enum TryRecvError { /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender Closed, + Error(E), } /// Errors returned by [`Receiver::recv`] and [`SerReceiver::recv`]. @@ -71,7 +87,7 @@ pub enum TryRecvError { /// [`Receiver::recv`]: super::Receiver::recv /// [`SerReceiver::recv`]: super::SerReceiver::recv #[derive(Debug, Eq, PartialEq)] -pub enum RecvError { +pub enum RecvError { /// A message cannot be received because channel is closed. /// /// This indicates that no [`Sender`]s or [`DeserSender`]s exist, and all @@ -80,6 +96,7 @@ pub enum RecvError { /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender Closed, + Error(E), } /// Errors returned by [`DeserSender::send`] and [`DeserSender::send_framed`]. @@ -105,9 +122,9 @@ pub enum SerSendError { /// [`DeserSender::try_send`]: super::DeserSender::try_send /// [`DeserSender::try_send_framed`]: super::DeserSender::send_framed #[derive(Debug, Eq, PartialEq)] -pub enum SerTrySendError { +pub enum SerTrySendError { /// The channel is [`Closed`](TrySendError::Closed) or [`Full`](TrySendError::Full). - Send(TrySendError), + Send(TrySendError), /// The sent bytes could not be deserialized to a value of this channel's /// message type. Deserialize(postcard::Error), @@ -115,40 +132,69 @@ pub enum SerTrySendError { // === impl SendError === -impl SendError { +impl SendError { /// Obtain the `T` that failed to send, discarding the "kind" of [`SendError`]. #[inline] #[must_use] pub fn into_inner(self) -> T { - self.0 + match self { + Self::Closed(msg) => msg, + Self::Error { message, .. } => message, + } + } + + pub(crate) fn with_message(self, message: M) -> SendError { + match self { + Self::Closed(_) => SendError::Closed(message), + Self::Error { error, .. } => SendError::Error { message, error }, + } } } -impl fmt::Debug for SendError { +impl fmt::Debug for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SendError").finish_non_exhaustive() + match self { + Self::Closed(_) => f.debug_tuple("SendError::Closed").finish(), + Self::Error { error, .. } => f + .debug_struct("SendError::Error") + .field("error", &error) + .finish_non_exhaustive(), + } } } // === impl TrySendError === -impl TrySendError { +impl TrySendError { /// Obtain the `T` that failed to send, discarding the "kind" of [`TrySendError`]. #[inline] #[must_use] pub fn into_inner(self) -> T { match self { - TrySendError::Closed(t) => t, - TrySendError::Full(t) => t, + Self::Closed(inner) => inner, + Self::Full(t) => t, + Self::Error { message, .. } => message, + } + } + + pub(crate) fn with_message(self, message: M) -> TrySendError { + match self { + Self::Closed(_) => TrySendError::Closed(message), + Self::Full(_) => TrySendError::Full(message), + Self::Error { error, .. } => TrySendError::Error { message, error }, } } } -impl fmt::Debug for TrySendError { +impl fmt::Debug for TrySendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Closed(_) => f.write_str("Closed(..)"), - Self::Full(_) => f.write_str("Full(..)"), + Self::Closed(_) => f.debug_tuple("TrySendError::Closed").finish(), + Self::Full(_) => f.debug_tuple("TrySendError::Full").finish(), + Self::Error { error, .. } => f + .debug_struct("TrySendError::Error") + .field("error", &error) + .finish_non_exhaustive(), } } } diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index 40b57a5..5a7d982 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -37,16 +37,16 @@ pub use self::arc_impl::*; /// /// A `Receiver` for a channel can be obtained using the /// [`StaticTrickyPipe::receiver`] and [`TrickyPipe::receiver`] methods. -pub struct Receiver { - pipe: TypedPipe, +pub struct Receiver { + pipe: TypedPipe, } /// Sends `T`-typed values to an associated [`Receiver`]s or [`SerReceiver`]. /// /// A `Sender` for a channel can be obtained using the /// [`StaticTrickyPipe::sender`] and [`TrickyPipe::sender`] methods. -pub struct Sender { - pipe: TypedPipe, +pub struct Sender { + pipe: TypedPipe, } /// Receives serialized values from associated [`Sender`]s or [`DeserSender`]s. @@ -55,8 +55,8 @@ pub struct Sender { /// [`StaticTrickyPipe::ser_receiver`] and [`TrickyPipe::ser_receiver`] methods, /// when the channel's message type implements [`Serialize`]. Messages may be /// sent as typed values by a [`Sender`], or as serialized bytes by a [`DeserSender`]. -pub struct SerReceiver { - pipe: ErasedPipe, +pub struct SerReceiver { + pipe: ErasedPipe, vtable: &'static SerVtable, } @@ -67,8 +67,8 @@ pub struct SerReceiver { /// when the channel's message type implements [`DeserializeOwned`]. Messages may be /// received as deserialized typed values by a [`Receiver`], or as serialized /// bytes by a [`SerReceiver`]. -pub struct DeserSender { - pipe: ErasedPipe, +pub struct DeserSender { + pipe: ErasedPipe, vtable: &'static DeserVtable, } @@ -77,8 +77,8 @@ pub struct DeserSender { /// See [the method documentation for `recv`](Receiver::recv) for details. #[must_use = "futures do nothing unless `.await`ed or `poll`ed"] #[derive(Debug)] -pub struct Recv<'rx, T: 'static> { - rx: &'rx Receiver, +pub struct Recv<'rx, T: 'static, E: 'static = ()> { + rx: &'rx Receiver, } /// Future returned by [`SerReceiver::recv`]. @@ -86,8 +86,8 @@ pub struct Recv<'rx, T: 'static> { /// See [the method documentation for `recv`](SerReceiver::recv) for details. #[must_use = "futures do nothing unless `.await`ed or `poll`ed"] #[derive(Debug)] -pub struct SerRecv<'rx> { - rx: &'rx SerReceiver, +pub struct SerRecv<'rx, E: 'static = ()> { + rx: &'rx SerReceiver, } /// A reference to a type-erased, serializable message received from a @@ -108,8 +108,8 @@ pub struct SerRecv<'rx> { /// [`to_vec_framed`]: Self::to_vec_framed #[must_use = "a `SerRecvRef` does nothing unless the `to_slice`, \ `to_slice_framed`, `to_vec`, or `to_vec_framed` methods are called"] -pub struct SerRecvRef<'pipe> { - res: Reservation<'pipe>, +pub struct SerRecvRef<'pipe, E: 'static = ()> { + res: Reservation<'pipe, E>, elems: ErasedSlice, vtable: &'static SerVtable, } @@ -135,10 +135,10 @@ pub struct SerRecvRef<'pipe> { /// [`commit`]: Self::commit #[must_use = "a `Permit` does nothing unless the `send` or `commit` \ methods are called"] -pub struct Permit<'core, T> { +pub struct Permit<'core, T, E> { // load bearing drop ordering lol lmao cell: cell::MutPtr>, - pipe: Reservation<'core>, + pipe: Reservation<'core, E>, } /// A permit to send a single serialized value to a channel. @@ -156,8 +156,8 @@ pub struct Permit<'core, T> { /// [`send_framed`]: Self::send_framed #[must_use = "a `SerPermit` does nothing unless the `send` or `send_framed` ' methods are called"] -pub struct SerPermit<'core> { - res: Reservation<'core>, +pub struct SerPermit<'core, E> { + res: Reservation<'core, E>, elems: ErasedSlice, vtable: &'static DeserVtable, } @@ -166,7 +166,10 @@ type Cell = UnsafeCell>; // === impl Receiver === -impl Receiver { +impl Receiver +where + E: Clone, +{ /// Attempts to receive the next message from the channel, without waiting /// for a new message to be sent. /// @@ -186,7 +189,7 @@ impl Receiver { /// messages sent before the channel closed have already been received. /// - [`Err`]`(`[`TryRecvError::Empty`]`)` if there are currently no /// messages in the queue, but the channel has not been closed. - pub fn try_recv(&self) -> Result { + pub fn try_recv(&self) -> Result> { self.pipe .core() .try_dequeue() @@ -222,22 +225,22 @@ impl Receiver { /// complete, and another future completes first, it is guaranteed that no /// message will be received from the channel. #[inline] - pub fn recv(&self) -> Recv<'_, T> { + pub fn recv(&self) -> Recv<'_, T, E> { Recv { rx: self } } /// Polls to receive a message from the channel, returning [`Poll::Ready`] /// if a message has been recieved, or [`Poll::Pending`] if there are /// currently no messages in the channel. - pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll>> { self.pipe .core() .poll_dequeue(cx) - .map(|res| Some(self.take_value(res?))) + .map(|res| Ok(self.take_value(res?))) } #[inline(always)] - fn take_value(&self, res: Reservation<'_>) -> T { + fn take_value(&self, res: Reservation<'_, E>) -> T { self.pipe.elems()[res.idx as usize].with(|ptr| unsafe { (*ptr).assume_init_read() }) } @@ -300,39 +303,47 @@ impl Receiver { } } -impl Drop for Receiver { +impl Drop for Receiver { fn drop(&mut self) { self.pipe.core().close_rx(); } } -impl fmt::Debug for Receiver { +impl fmt::Debug for Receiver { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pipe.fmt_into(&mut f.debug_struct("Receiver")) } } -impl futures::Stream for &'_ Receiver { - type Item = T; +impl futures::Stream for &'_ Receiver { + type Item = Result; #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_ref().get_ref().poll_recv(cx) + self.as_ref().get_ref().poll_recv(cx).map(|res| match res { + Ok(res) => Some(Ok(res)), + Err(RecvError::Closed) => None, + Err(RecvError::Error(error)) => Some(Err(error)), + }) } } -impl futures::Stream for Receiver { - type Item = T; +impl futures::Stream for Receiver { + type Item = Result; #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_ref().get_ref().poll_recv(cx) + self.as_ref().get_ref().poll_recv(cx).map(|res| match res { + Ok(res) => Some(Ok(res)), + Err(RecvError::Closed) => None, + Err(RecvError::Error(error)) => Some(Err(error)), + }) } } // === impl SerReceiver === -impl SerReceiver { +impl SerReceiver { /// Attempts to receive the serialized representation of next message from /// the channel, without waiting for a new message to be sent if none are /// available. @@ -361,7 +372,7 @@ impl SerReceiver { /// messages in the queue, but the channel has not been closed. /// /// [`Vec`]: alloc::vec::Vec - pub fn try_recv(&self) -> Result, TryRecvError> { + pub fn try_recv(&self) -> Result, TryRecvError> { let res = self.pipe.core().try_dequeue()?; Ok(SerRecvRef { res, @@ -414,16 +425,16 @@ impl SerReceiver { /// message will be received from the channel. /// /// [`Vec`]: alloc::vec::Vec - pub fn recv(&self) -> SerRecv<'_> { + pub fn recv(&self) -> SerRecv<'_, E> { SerRecv { rx: self } } /// Polls to receive a serialized message from the channel, returning /// [`Poll::Ready`] if a message has been recieved, or [`Poll::Pending`] if /// there are currently no messages in the channel. - pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll>> { + pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll, RecvError>> { self.pipe.core().poll_dequeue(cx).map(|res| { - Some(SerRecvRef { + Ok(SerRecvRef { res: res?, elems: self.pipe.elems(), vtable: self.vtable, @@ -490,31 +501,35 @@ impl SerReceiver { } } -impl Drop for SerReceiver { +impl Drop for SerReceiver { fn drop(&mut self) { self.pipe.core().close_rx(); } } -impl fmt::Debug for SerReceiver { +impl fmt::Debug for SerReceiver { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pipe.fmt_into(&mut f.debug_struct("SerReceiver")) } } -impl<'rx> futures::Stream for &'rx SerReceiver { - type Item = SerRecvRef<'rx>; +impl<'rx, E: Clone> futures::Stream for &'rx SerReceiver { + type Item = Result, E>; #[inline] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_ref().get_ref().poll_recv(cx) + self.as_ref().get_ref().poll_recv(cx).map(|res| match res { + Ok(res) => Some(Ok(res)), + Err(RecvError::Closed) => None, + Err(RecvError::Error(error)) => Some(Err(error)), + }) } } // === impl Recv === -impl Future for Recv<'_, T> { - type Output = Option; +impl Future for Recv<'_, T, E> { + type Output = Result>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -524,8 +539,8 @@ impl Future for Recv<'_, T> { // === impl SerRecv === -impl<'rx> Future for SerRecv<'rx> { - type Output = Option>; +impl<'rx, E: Clone> Future for SerRecv<'rx, E> { + type Output = Result, RecvError>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -535,7 +550,7 @@ impl<'rx> Future for SerRecv<'rx> { // === impl SerRecvRef === -impl SerRecvRef<'_> { +impl SerRecvRef<'_, E> { /// Attempt to serialize the received item into the provided buffer /// /// This function will fail if the provided buffer was too small. @@ -564,7 +579,7 @@ impl SerRecvRef<'_> { } } -impl fmt::Debug for SerRecvRef<'_> { +impl fmt::Debug for SerRecvRef<'_, E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let Self { res, @@ -578,7 +593,7 @@ impl fmt::Debug for SerRecvRef<'_> { } } -impl Drop for SerRecvRef<'_> { +impl Drop for SerRecvRef<'_, E> { fn drop(&mut self) { let Self { res, elems, vtable } = self; unsafe { @@ -593,15 +608,15 @@ impl Drop for SerRecvRef<'_> { // Safety: this is safe, because a `SerRecvRef` can only be constructed by a // `SerReceiver`, and `SerReceiver`s may only be constructed for a pipe whose // messages are `Send`. -unsafe impl Send for SerRecvRef<'_> {} +unsafe impl Send for SerRecvRef<'_, E> {} // Safety: this is safe, because a `SerRecvRef` can only be constructed by a // `SerReceiver`, and `SerReceiver`s may only be constructed for a pipe whose // messages are `Send`. -unsafe impl Sync for SerRecvRef<'_> {} +unsafe impl Sync for SerRecvRef<'_, E> {} // === impl DeserSender === -impl DeserSender { +impl DeserSender { /// Reserve capacity to send a serialized message to the channel. /// /// If the channel is currently at capacity, this method waits until @@ -636,7 +651,7 @@ impl DeserSender { /// This channel uses a queue to ensure that calls to `send` and `reserve` /// complete in the order they were requested. Cancelling a call to /// `reserve` causes the caller to lose its place in that queue. - pub async fn reserve(&self) -> Result, SendError<()>> { + pub async fn reserve(&self) -> Result, SendError> { self.pipe.core().reserve().await.map(|res| SerPermit { res, elems: self.pipe.elems(), @@ -677,7 +692,7 @@ impl DeserSender { /// have capacity to send another message without waiting. A subsequent /// call to `try_reserve` may complete successfully, once capacity has /// become available again. - pub fn try_reserve(&self) -> Result, TrySendError> { + pub fn try_reserve(&self) -> Result, TrySendError> { self.pipe.core().try_reserve().map(|res| SerPermit { res, elems: self.pipe.elems(), @@ -694,7 +709,7 @@ impl DeserSender { /// /// This is equivalent to calling [DeserSender::try_reserve] followed by /// [SerPermit::send]. - pub fn try_send(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { + pub fn try_send(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { self.try_reserve() .map_err(SerTrySendError::Send)? .send(bytes) @@ -710,7 +725,7 @@ impl DeserSender { /// /// This is equivalent to calling [DeserSender::try_reserve] followed by /// [SerPermit::send_framed]. - pub fn try_send_framed(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { + pub fn try_send_framed(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { self.try_reserve() .map_err(SerTrySendError::Send)? .send_framed(bytes) @@ -809,7 +824,7 @@ impl DeserSender { } } -impl Clone for DeserSender { +impl Clone for DeserSender { fn clone(&self) -> Self { self.pipe.core().add_tx(); Self { @@ -819,20 +834,20 @@ impl Clone for DeserSender { } } -impl Drop for DeserSender { +impl Drop for DeserSender { fn drop(&mut self) { self.pipe.core().drop_tx(); } } -impl fmt::Debug for DeserSender { +impl fmt::Debug for DeserSender { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pipe.fmt_into(&mut f.debug_struct("DeserSender")) } } // === impl SerPermit === -impl SerPermit<'_> { +impl SerPermit<'_, E> { /// Attempt to send the given bytes /// /// This will attempt to deserialize the bytes into the reservation, consuming @@ -865,15 +880,15 @@ impl SerPermit<'_> { // Safety: this is safe, because a `SerPermit` can only be constructed by a // `SerSender`, and `SerSender`s may only be constructed for a pipe whose // messages are `Send`. -unsafe impl Send for SerPermit<'_> {} +unsafe impl Send for SerPermit<'_, E> {} // Safety: this is safe, because a `SerPermit` can only be constructed by a // `SerSender`, and `SerSender`s may only be constructed for a pipe whose // messages are `Send`. -unsafe impl Sync for SerPermit<'_> {} +unsafe impl Sync for SerPermit<'_, E> {} // === impl Sender === -impl Sender { +impl Sender { /// Send a `T`-typed message to the channel. /// /// If the channel is currently at capacity, this method waits until @@ -899,13 +914,13 @@ impl Sender { /// This channel uses a queue to ensure that calls to `send` and `reserve` /// complete in the order they were requested. Cancelling a call to /// `send` causes the caller to lose its place in that queue. - pub async fn send(&self, message: T) -> Result<(), SendError> { + pub async fn send(&self, message: T) -> Result<(), SendError> { match self.reserve().await { Ok(permit) => { permit.send(message); Ok(()) } - Err(_) => Err(SendError(message)), + Err(err) => Err(err.with_message(message)), } } @@ -934,14 +949,13 @@ impl Sender { /// [`send`]: Self::send /// [`reserve`]: Self::reserve /// [`try_reserve`]: Self::try_reserve - pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { match self.try_reserve() { Ok(permit) => { permit.send(message); Ok(()) } - Err(TrySendError::Closed(())) => Err(TrySendError::Closed(message)), - Err(TrySendError::Full(())) => Err(TrySendError::Full(message)), + Err(e) => Err(e.with_message(message)), } } @@ -979,7 +993,7 @@ impl Sender { /// This channel uses a queue to ensure that calls to `send` and `reserve` /// complete in the order they were requested. Cancelling a call to /// `reserve` causes the caller to lose its place in that queue. - pub async fn reserve(&self) -> Result, SendError> { + pub async fn reserve(&self) -> Result, SendError> { let pipe = self.pipe.core().reserve().await?; let cell = self.pipe.elems()[pipe.idx as usize].get_mut(); Ok(Permit { cell, pipe }) @@ -1018,7 +1032,7 @@ impl Sender { /// have capacity to send another message without waiting. A subsequent /// call to `try_reserve` may complete successfully, once capacity has /// become available again. - pub fn try_reserve(&self) -> Result, TrySendError> { + pub fn try_reserve(&self) -> Result, TrySendError> { let pipe = self.pipe.core().try_reserve()?; let cell = self.pipe.elems()[pipe.idx as usize].get_mut(); Ok(Permit { cell, pipe }) @@ -1082,7 +1096,7 @@ impl Sender { } } -impl Clone for Sender { +impl Clone for Sender { fn clone(&self) -> Self { self.pipe.core().add_tx(); Self { @@ -1091,13 +1105,13 @@ impl Clone for Sender { } } -impl Drop for Sender { +impl Drop for Sender { fn drop(&mut self) { self.pipe.core().drop_tx(); } } -impl fmt::Debug for Sender { +impl fmt::Debug for Sender { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.pipe.fmt_into(&mut f.debug_struct("Sender")) } @@ -1105,7 +1119,7 @@ impl fmt::Debug for Sender { // === impl Permit === -impl Permit<'_, T> { +impl Permit<'_, T, E> { /// Write the given value into the [Permit], and send it /// /// This makes the data available to the [Receiver]. @@ -1145,20 +1159,20 @@ impl Permit<'_, T> { } } -impl fmt::Debug for Permit<'_, T> { +impl fmt::Debug for Permit<'_, T, E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("Permit").field(&self.pipe).finish() } } -impl Deref for Permit<'_, T> { +impl Deref for Permit<'_, T, E> { type Target = MaybeUninit; fn deref(&self) -> &Self::Target { unsafe { &*self.cell.deref() } } } -impl DerefMut for Permit<'_, T> { +impl DerefMut for Permit<'_, T, E> { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.cell.deref() } } @@ -1166,8 +1180,8 @@ impl DerefMut for Permit<'_, T> { // Safety: a `Permit` allows referencing a `T`, so it's morally equivalent to a // reference: a `Permit` is `Send` if `T` is `Send + Sync`. -unsafe impl Send for Permit<'_, T> {} +unsafe impl Send for Permit<'_, T, E> {} // Safety: a `Permit` allows referencing a `T`, so it's morally equivalent to a // reference: a `Permit` is `Sync` if `T` is `Sync`. -unsafe impl Sync for Permit<'_, T> {} +unsafe impl Sync for Permit<'_, T, E> {} diff --git a/source/tricky-pipe/src/mpsc/static_impl.rs b/source/tricky-pipe/src/mpsc/static_impl.rs index 246810b..0be7e7f 100644 --- a/source/tricky-pipe/src/mpsc/static_impl.rs +++ b/source/tricky-pipe/src/mpsc/static_impl.rs @@ -7,12 +7,16 @@ use super::{ /// /// This variant is intended to be used in static storage on targets /// such as embedded systems, where channels are pre-allocated at compile time. -pub struct StaticTrickyPipe { +pub struct StaticTrickyPipe { elements: [Cell; CAPACITY], - core: Core, + core: Core, } -impl StaticTrickyPipe { +impl StaticTrickyPipe +where + T: 'static, + E: 'static, +{ const EMPTY_CELL: Cell = UnsafeCell::new(MaybeUninit::uninit()); /// Create a new [`StaticTrickyPipe`]. @@ -33,7 +37,7 @@ impl StaticTrickyPipe { /// The maximum possible capacity of a [`StaticTrickyPipe`] on this platform pub const MAX_CAPACITY: usize = channel_core::MAX_CAPACITY; - const CORE_VTABLE: &'static CoreVtable = &CoreVtable { + const CORE_VTABLE: &'static CoreVtable = &CoreVtable { get_core: Self::get_core, get_elems: Self::get_elems, clone: Self::erased_clone, @@ -41,11 +45,11 @@ impl StaticTrickyPipe { type_name: core::any::type_name::, }; - fn erased(&'static self) -> ErasedPipe { + fn erased(&'static self) -> ErasedPipe { unsafe { ErasedPipe::new(self as *const _ as *const (), Self::CORE_VTABLE) } } - fn typed(&'static self) -> TypedPipe { + fn typed(&'static self) -> TypedPipe { unsafe { self.erased().typed() } } @@ -53,7 +57,7 @@ impl StaticTrickyPipe { /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn receiver(&'static self) -> Option> { + pub fn receiver(&'static self) -> Option> { self.core.try_claim_rx()?; Some(Receiver { pipe: self.typed() }) @@ -62,12 +66,12 @@ impl StaticTrickyPipe { /// Obtain a [`Sender`] capable of sending `T`-typed data /// /// This function may be called multiple times. - pub fn sender(&'static self) -> Sender { + pub fn sender(&'static self) -> Sender { self.core.add_tx(); Sender { pipe: self.typed() } } - fn get_core(ptr: *const ()) -> *const Core { + fn get_core(ptr: *const ()) -> *const Core { unsafe { let ptr = ptr.cast::(); ptr::addr_of!((*ptr).core) @@ -88,7 +92,7 @@ impl StaticTrickyPipe { fn erased_drop(_: *const ()) {} } -impl StaticTrickyPipe +impl StaticTrickyPipe where T: Serialize + Send + 'static, { @@ -97,7 +101,7 @@ where /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn ser_receiver(&'static self) -> Option { + pub fn ser_receiver(&'static self) -> Option> { self.core.try_claim_rx()?; Some(SerReceiver { @@ -117,16 +121,17 @@ where }; } -impl StaticTrickyPipe +impl StaticTrickyPipe where T: DeserializeOwned + Send + 'static, + E:, { /// Try to obtain a [`DeserSender`] capable of sending bytes containing /// a serialized instance of `T`. /// /// This method will only return [`Some`] on the first call. All subsequent calls /// will return [`None`]. - pub fn deser_sender(&'static self) -> DeserSender { + pub fn deser_sender(&'static self) -> DeserSender { self.core.add_tx(); DeserSender { pipe: self.erased(), @@ -137,8 +142,18 @@ where const DESER_VTABLE: &'static DeserVtable = &DeserVtable::new::(); } -unsafe impl Send for StaticTrickyPipe {} -unsafe impl Sync for StaticTrickyPipe {} +unsafe impl Send for StaticTrickyPipe +where + T: Send, + E: Send, +{ +} +unsafe impl Sync for StaticTrickyPipe +where + T: Send, + E: Send, +{ +} #[cfg(all(test, not(loom)))] mod tests { diff --git a/source/tricky-pipe/src/mpsc/tests.rs b/source/tricky-pipe/src/mpsc/tests.rs index 9cefb92..5d0c7df 100644 --- a/source/tricky-pipe/src/mpsc/tests.rs +++ b/source/tricky-pipe/src/mpsc/tests.rs @@ -92,16 +92,16 @@ mod trait_impls { #[test] fn permit() { - assert_send::>(); - assert_sync::>(); - assert_unpin::>(); + assert_send::>(); + assert_sync::>(); + assert_unpin::>(); } #[test] fn ser_permit() { - assert_send::>(); - assert_sync::>(); - assert_unpin::>(); + assert_send::>(); + assert_sync::>(); + assert_unpin::>(); } } @@ -505,7 +505,7 @@ fn spsc_try_send_in_capacity() { future::block_on(async move { let mut i = 0; - while let Some(msg) = test_dbg!(rx.recv().await) { + while let Ok(msg) = test_dbg!(rx.recv().await) { assert_eq!(msg.get_ref(), &i); i += 1; } @@ -532,7 +532,7 @@ fn spsc_send() { future::block_on(async move { let mut i = 0; - while let Some(msg) = rx.recv().await { + while let Ok(msg) = rx.recv().await { assert_eq!(msg.get_ref(), &i); i += 1; } @@ -565,7 +565,7 @@ fn mpsc_send() { let recvs = future::block_on(async move { let mut recvs = std::collections::BTreeSet::new(); - while let Some(msg) = rx.recv().await { + while let Ok(msg) = rx.recv().await { let msg = msg.into_inner(); tracing::info!(received = msg); assert!( From 8a8c8e03354bfab565e2d04d996cbd10919515a9 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 10:02:09 -0800 Subject: [PATCH 09/21] more --- source/tricky-pipe/src/mpsc/arc_impl.rs | 6 +++- source/tricky-pipe/src/mpsc/channel_core.rs | 17 +++++++++++ source/tricky-pipe/src/mpsc/mod.rs | 34 +++++++++++++++++++++ source/tricky-pipe/src/mpsc/static_impl.rs | 6 +++- source/tricky-pipe/src/mpsc/tests.rs | 26 ++++++++++++++++ 5 files changed, 87 insertions(+), 2 deletions(-) diff --git a/source/tricky-pipe/src/mpsc/arc_impl.rs b/source/tricky-pipe/src/mpsc/arc_impl.rs index 733205c..8a4fb99 100644 --- a/source/tricky-pipe/src/mpsc/arc_impl.rs +++ b/source/tricky-pipe/src/mpsc/arc_impl.rs @@ -71,7 +71,10 @@ impl TrickyPipe { pub fn receiver(&self) -> Option> { self.0.core.try_claim_rx()?; - Some(Receiver { pipe: self.typed() }) + Some(Receiver { + pipe: self.typed(), + closed_error: false, + }) } /// Obtain a [`Sender`] capable of sending `T`-typed data @@ -122,6 +125,7 @@ where Some(SerReceiver { pipe: self.erased(), vtable: Self::SER_VTABLE, + closed_error: false, }) } diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index 278e429..42260f6 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -256,6 +256,23 @@ impl Core { test_println!("Core::close_rx: -> closed"); } + pub(super) fn close_rx_error(&self, error: E) { + // store the error in the channel. + self.error.with_mut(|ptr| unsafe { + // Safety: this is okay, because there is only one receiver, and the + // senders will not attempt to access the error until the receiver + // has set the `CLOSED_ERROR` bits. + // + // The receiver will not close the channel more than once. + (*ptr).write(error); + }); + // set the state to indicate that the receiver closed the channel. + test_dbg!(self.enqueue_pos.fetch_or(CLOSED_ERROR, Release)); + // notify any waiting senders that the channel is closed. + self.prod_wait.close(); + test_println!("Core::close_rx_error: -> closed"); + } + #[inline] pub(super) fn add_tx(&self) { // Using a relaxed ordering is alright here, as knowledge of the diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index 5a7d982..f015853 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -39,6 +39,7 @@ pub use self::arc_impl::*; /// [`StaticTrickyPipe::receiver`] and [`TrickyPipe::receiver`] methods. pub struct Receiver { pipe: TypedPipe, + closed_error: bool, } /// Sends `T`-typed values to an associated [`Receiver`]s or [`SerReceiver`]. @@ -58,6 +59,7 @@ pub struct Sender { pub struct SerReceiver { pipe: ErasedPipe, vtable: &'static SerVtable, + closed_error: bool, } /// Sends serialized values to an associated [`Receiver`] or [`SerReceiver`]. @@ -244,6 +246,22 @@ where self.pipe.elems()[res.idx as usize].with(|ptr| unsafe { (*ptr).assume_init_read() }) } + /// Close this channel with an error. Any subsequent attempts to send + /// messages to this channel will fail with `error`. + /// + /// This method returns `true` if the channel was successfully closed. If + /// this channel has already been closed with an error, this method does + /// nothing and returns `false`. + pub fn close_with_error(&mut self, error: E) -> bool { + if self.closed_error { + return false; + } + + self.pipe.core().close_rx_error(error); + + true + } + /// Returns `true` if this channel is empty. /// /// If this method returns `true`, calling [`Receiver::recv`] or @@ -442,6 +460,22 @@ impl SerReceiver { }) } + /// Close this channel with an error. Any subsequent attempts to send + /// messages to this channel will fail with `error`. + /// + /// This method returns `true` if the channel was successfully closed. If + /// this channel has already been closed with an error, this method does + /// nothing and returns `false`. + pub fn close_with_error(&mut self, error: E) -> bool { + if self.closed_error { + return false; + } + + self.pipe.core().close_rx_error(error); + + true + } + /// Returns `true` if this channel is empty. /// /// If this method returns `true`, calling [`Receiver::recv`] or diff --git a/source/tricky-pipe/src/mpsc/static_impl.rs b/source/tricky-pipe/src/mpsc/static_impl.rs index 0be7e7f..d99f7d3 100644 --- a/source/tricky-pipe/src/mpsc/static_impl.rs +++ b/source/tricky-pipe/src/mpsc/static_impl.rs @@ -60,7 +60,10 @@ where pub fn receiver(&'static self) -> Option> { self.core.try_claim_rx()?; - Some(Receiver { pipe: self.typed() }) + Some(Receiver { + pipe: self.typed(), + closed_error: false, + }) } /// Obtain a [`Sender`] capable of sending `T`-typed data @@ -107,6 +110,7 @@ where Some(SerReceiver { pipe: self.erased(), vtable: Self::SER_VTABLE, + closed_error: false, }) } diff --git a/source/tricky-pipe/src/mpsc/tests.rs b/source/tricky-pipe/src/mpsc/tests.rs index 5d0c7df..c0819f8 100644 --- a/source/tricky-pipe/src/mpsc/tests.rs +++ b/source/tricky-pipe/src/mpsc/tests.rs @@ -589,6 +589,32 @@ fn mpsc_send() { }) } +#[test] +fn close_error_simple() { + const CAPACITY: u8 = 2; + + loom::model(|| { + let chan = TrickyPipe::, &'static str>::new(CAPACITY); + + let mut rx = test_dbg!(chan.receiver()).expect("can't get rx"); + let tx = chan.sender(); + + rx.close_with_error("fake rx error"); + + let t1 = thread::spawn(move || { + future::block_on(async move { + let err = test_dbg!(tx.send(loom::alloc::Track::new(1)).await.unwrap_err()); + match err { + SendError::Error { error, .. } => assert_eq!(error, "fake rx error"), + err => panic!("expected SendError::Error, got {:?}", err), + } + }) + }); + + t1.join().unwrap(); + }) +} + fn do_tx( sends: usize, offset: usize, From 971e25bd9fceb31e8c3268ecd7c9878d3a4bd1cc Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 10:55:05 -0800 Subject: [PATCH 10/21] unbreak tricky pipe --- source/tricky-pipe/src/mpsc/channel_core.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index 42260f6..13865d3 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -433,9 +433,6 @@ impl Core { test_span!("Core::try_dequeue"); let mut head = test_dbg!(self.dequeue_pos.load(Acquire)); loop { - if head & CLOSED_ERROR == CLOSED_ERROR { - return Err(TryRecvError::Error(unsafe { self.close_error() })); - } // Shift to the right to extract the actual position, and // discard the `CLOSED` and `HAS_ERROR` bits. let pos = head >> POS_SHIFT; @@ -447,7 +444,11 @@ impl Core { match test_dbg!(dif).cmp(&0) { cmp::Ordering::Less if test_dbg!(head & CLOSED) != 0 => { - return Err(TryRecvError::Closed) + if head & CLOSED_ERROR == CLOSED_ERROR { + return Err(TryRecvError::Error(unsafe { self.close_error() })); + } else { + return Err(TryRecvError::Closed); + } } cmp::Ordering::Less => return Err(TryRecvError::Empty), cmp::Ordering::Equal => match test_dbg!(self.dequeue_pos.compare_exchange_weak( @@ -477,7 +478,7 @@ impl Core { loop { // Shift one bit to the right to extract the actual position, and // discard the `CLOSED` bit. - let pos = tail >> 1; + let pos = tail >> POS_SHIFT; let slot = &self.queue[test_dbg!(pos & MASK) as usize]; let seq = slot.load(Acquire) >> SEQ_SHIFT; let dif = test_dbg!(seq as i8).wrapping_sub(test_dbg!(pos as i8)); @@ -518,6 +519,9 @@ impl Core { } } +unsafe impl Send for Core {} +unsafe impl Sync for Core {} + // === impl Reservation === impl Reservation<'_, E> { From b676a362b5963a302783f4daa98f43679914f4d0 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 11:08:01 -0800 Subject: [PATCH 11/21] resets worky-ish --- source/mgnp/src/client.rs | 16 +- source/mgnp/src/conn_table.rs | 120 +++++------ source/mgnp/src/lib.rs | 7 +- source/mgnp/src/message.rs | 4 +- source/mgnp/src/registry.rs | 4 +- .../integration.rs => src/tests/e2e.rs} | 28 ++- source/mgnp/src/tests/integration.rs | 186 ++++++++++++++++++ .../{tests/support.rs => src/tests/mod.rs} | 125 +++++++----- source/tricky-pipe/src/bidi.rs | 38 ++-- 9 files changed, 367 insertions(+), 161 deletions(-) rename source/mgnp/{tests/integration.rs => src/tests/e2e.rs} (95%) create mode 100644 source/mgnp/src/tests/integration.rs rename source/mgnp/{tests/support.rs => src/tests/mod.rs} (83%) diff --git a/source/mgnp/src/client.rs b/source/mgnp/src/client.rs index 0f6f872..5ab5762 100644 --- a/source/mgnp/src/client.rs +++ b/source/mgnp/src/client.rs @@ -17,10 +17,9 @@ pub struct OutboundConnect { /// The "hello" message to send to the remote service. pub(crate) hello: serbox::Consumer, /// The local bidirectional channel to bind to the remote service. - pub(crate) channel: bidi::SerBiDi, + pub(crate) channel: bidi::SerBiDi, /// Sender for the response from the remote service. pub(crate) rsp: oneshot::Sender>, - pub(crate) reset: oneshot::Sender, } #[derive(Debug, Eq, PartialEq)] @@ -29,15 +28,11 @@ pub enum ConnectError { Nak(Rejection), } -pub struct Connection { - chan: bidi::BiDi, - reset: oneshot::Receiver, -} +pub type Connection = bidi::BiDi<::ServerMsg, ::ClientMsg, Reset>; pub struct Channels { - srv_chan: bidi::SerBiDi, - client_chan: bidi::BiDi, - reset: oneshot::Receiver, + srv_chan: bidi::SerBiDi, + client_chan: bidi::BiDi, } pub struct StaticChannels { @@ -61,7 +56,6 @@ impl Channels { Self { srv_chan, client_chan, - reset: oneshot::Receiver::new(), } } } @@ -74,7 +68,6 @@ impl Connector { Channels { srv_chan, client_chan, - reset, }: Channels, ) -> Result, ConnectError> { let permit = self @@ -89,7 +82,6 @@ impl Connector { hello, channel: srv_chan, rsp, - reset, }; permit.send(connect); match self.rsp.recv().await { diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index be0e1f8..8b1b8ea 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -4,7 +4,14 @@ use crate::{ registry, }; use core::{fmt, mem, num::NonZeroU16, task::Poll}; -use tricky_pipe::{bidi::SerBiDi, mpsc::SerPermit, oneshot}; +use tricky_pipe::{ + bidi::SerBiDi, + mpsc::{ + error::{RecvError, SendError}, + SerPermit, + }, + oneshot, +}; #[derive( Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize, @@ -32,21 +39,10 @@ enum State { Connecting(oneshot::Sender>), } -#[derive(Debug)] -#[non_exhaustive] -enum InboundError { - /// The connection tracking table doesn't have a connection for the provided ID. - NoSocket, - - /// The local channel for this socket has closed. - ChannelClosed, -} - #[derive(Debug)] struct Socket { state: State, - channel: SerBiDi, - reset: oneshot::Sender, + channel: SerBiDi, } enum Entry { @@ -98,7 +94,7 @@ impl ConnTable { // data or has closed. match channel.rx().poll_recv(cx) { // a local data frame is ready to send! - Poll::Ready(Some(data)) => { + Poll::Ready(Ok(data)) => { let local_id = Id::from_index(idx); return Poll::Ready(OutboundFrame::data(*remote_id, local_id, data)); } @@ -106,12 +102,13 @@ impl ConnTable { // the local stream has closed, so mark the socket as dead // and generate a reset frame to tell the remote that it's // closed. - Poll::Ready(None) => { + Poll::Ready(Err(error)) => { self.dead_index = Some(Id::from_index(idx)); - return Poll::Ready(OutboundFrame::reset( - *remote_id, - Reset::BecauseISaidSo, - )); + let reason = match error { + RecvError::Closed => Reset::BecauseISaidSo, + RecvError::Error(reason) => reason, + }; + return Poll::Ready(OutboundFrame::reset(*remote_id, reason)); } // nothing to do, move on to the next socket. @@ -158,17 +155,43 @@ impl ConnTable { ); // the remote peer's remote ID is our local ID. let id = remote_id; - let Some(socket) = self.conns.get_mut(id) else { - tracing::debug!( - id.remote = %local_id, - id.local = %id, - "process_inbound(DATA): connection does not exist, resetting...", - ); - return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); + let socket = match self.conns.get_mut(id).and_then(|entry| entry.socket()) { + None => { + tracing::debug!( + id.remote = %local_id, + id.local = %id, + "process_inbound(DATA): connection does not exist, resetting...", + ); + return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); + } + Some(Socket { + state: State::Open { remote_id }, + .. + }) if &local_id != remote_id => { + tracing::warn!( + id.remote = %local_id, + id.remote.actual = %remote_id, + id.local = %id, + "process_inbound(DATA): wrong remote ID, resetting...", + ); + return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); + } + Some(Socket { + state: State::Connecting(..), + .. + }) => { + tracing::warn!( + id.remote = %local_id, + id.local = %id, + "process_inbound(DATA): recieved DATA on a socket that was not ACKed", + ); + return Some(OutboundFrame::reset(local_id, Reset::NoSuchConn)); + } + Some(socket) => socket, }; // try to reserve send capacity on this socket. - let reset = match socket.reserve_send().await { + let reset = match socket.channel.tx().reserve().await { Ok(permit) => match permit.send(frame.body) { Ok(_) => return None, Err(error) => { @@ -187,7 +210,8 @@ impl ConnTable { Reset::bad_frame(error) } }, - Err(reset) => reset, + Err(SendError::Closed(_)) => Reset::BecauseISaidSo, + Err(SendError::Error { error, .. }) => error, }; tracing::trace!( id.remote = %local_id, @@ -238,7 +262,6 @@ impl ConnTable { identity, channel, rsp, - reset, } = connect; let local_id = match self.reserve_id() { @@ -256,7 +279,6 @@ impl ConnTable { Socket { state: State::Connecting(rsp), channel, - reset, }, ); @@ -313,11 +335,10 @@ impl ConnTable { /// Accept a remote initiated connection with the provided `remote_id`. #[must_use] - fn accept(&mut self, remote_id: Id, channel: SerBiDi) -> OutboundFrame<'_> { + fn accept(&mut self, remote_id: Id, channel: SerBiDi) -> OutboundFrame<'_> { let sock = Socket { state: State::Open { remote_id }, channel, - .. }; match self.insert(sock) { @@ -358,7 +379,7 @@ impl ConnTable { async fn reset(&mut self, local_id: Id, reason: Reset) -> bool { tracing::trace!(id.local = %local_id, %reason, "reset: resetting connection..."); - let Some(Socket { state, reset, .. }) = self.remove(local_id) else { + let Some(Socket { state, channel }) = self.remove(local_id) else { tracing::warn!( id.local = %local_id, %reason, @@ -377,7 +398,8 @@ impl ConnTable { return false; } - reset.send(reason).is_ok() + let (_, mut rx) = channel.split(); + rx.close_with_error(reason) } /// Returns `true` if a connection with the provided ID was closed, `false` if @@ -457,10 +479,10 @@ impl ConnTable { // === impl Id === impl Id { - // #[cfg(test)] - // pub(crate) fn new(n: u16) -> Self { - // Self(NonZeroU16::new(n).expect("IDs must be non-zero")) - // } + #[cfg(test)] + pub(crate) fn new(n: u16) -> Self { + Self(NonZeroU16::new(n).expect("IDs must be non-zero")) + } #[cfg(not(debug_assertions))] #[must_use] @@ -572,34 +594,12 @@ impl Entries { // === impl Entry === impl Entry { - async fn reserve_send(&self) -> Result, Reset> { - self.channel() - .ok_or(Reset::NoSuchConn)? - .tx() - .reserve() - .await - .map_err(|_| Reset::BecauseISaidSo) - } - fn socket(&self) -> Option<&Socket> { match self { Entry::Occupied(ref sock) => Some(sock), _ => None, } } - - fn channel(&self) -> Option<&SerBiDi> { - self.socket().map(|sock| &sock.channel) - } -} - -impl fmt::Display for InboundError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::NoSocket => f.write_str("no socket exists for this ID"), - Self::ChannelClosed => f.write_str("local channel has closed"), - } - } } // #[cfg(test)] diff --git a/source/mgnp/src/lib.rs b/source/mgnp/src/lib.rs index a6319ec..2ddf3e5 100644 --- a/source/mgnp/src/lib.rs +++ b/source/mgnp/src/lib.rs @@ -22,6 +22,9 @@ use futures::FutureExt; use message::{InboundFrame, OutboundFrame, Rejection}; use tricky_pipe::{mpsc, oneshot, serbox}; +#[cfg(test)] +mod tests; + /// A wire-level transport for [MGNP frames](Frame). /// /// A `Wire` represents a point-to-point link between a local MGNP [`Interface`] @@ -133,7 +136,7 @@ impl Interface { /// [`serbox::Sharer`] and [`oneshot::Receiver`], which may be heap- or /// statically-allocated, use the [`Interface::connector_with`] method /// instead. - #[cfg(feature = "alloc")] + #[cfg(any(test, feature = "alloc"))] pub fn connector(&self) -> client::Connector { self.connector_with(serbox::Sharer::new(), oneshot::Receiver::new()) } @@ -207,7 +210,7 @@ where // locally-initiated connect request conn = next_conn.fuse() => { - if let Some(conn) = conn { + if let Ok(conn) = conn { out_conn = Some(conn); } else { tracing::info!("connection stream has terminated"); diff --git a/source/mgnp/src/message.rs b/source/mgnp/src/message.rs index 68aabb0..26048b2 100644 --- a/source/mgnp/src/message.rs +++ b/source/mgnp/src/message.rs @@ -159,7 +159,7 @@ pub type OutboundFrame<'data> = Frame>; #[derive(Debug)] pub enum OutboundData<'recv> { Empty, - Data(SerRecvRef<'recv>), + Data(SerRecvRef<'recv, Reset>), Rejected(serbox::Consumer), Hello(serbox::Consumer), } @@ -284,7 +284,7 @@ impl<'bytes, T: Deserialize<'bytes>> Frame { } impl<'data> Frame> { - pub fn data(remote_id: Id, local_id: Id, data: SerRecvRef<'data>) -> Self { + pub fn data(remote_id: Id, local_id: Id, data: SerRecvRef<'data, Reset>) -> Self { Self { header: Header::Data { local_id, diff --git a/source/mgnp/src/registry.rs b/source/mgnp/src/registry.rs index 41cd5de..5bb85fb 100644 --- a/source/mgnp/src/registry.rs +++ b/source/mgnp/src/registry.rs @@ -1,4 +1,4 @@ -use super::Rejection; +use super::{message::Reset, Rejection}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tricky_pipe::bidi::SerBiDi; use uuid::Uuid; @@ -16,7 +16,7 @@ pub enum IdentityKind { /// Represents a mechanism for discovering services on the local node. pub trait Registry { - async fn connect(&self, identity: Identity, hello: &[u8]) -> Result; + async fn connect(&self, identity: Identity, hello: &[u8]) -> Result, Rejection>; } /// A service definition. diff --git a/source/mgnp/tests/integration.rs b/source/mgnp/src/tests/e2e.rs similarity index 95% rename from source/mgnp/tests/integration.rs rename to source/mgnp/src/tests/e2e.rs index 6977c57..be151a8 100644 --- a/source/mgnp/tests/integration.rs +++ b/source/mgnp/src/tests/e2e.rs @@ -1,7 +1,5 @@ -#![cfg(feature = "alloc")] -mod support; -use mgnp::{message::Rejection, registry::Identity}; -use support::*; +use super::*; +use crate::{message::Rejection, registry::Identity}; use svcs::{HelloWorldRequest, HelloWorldResponse}; #[tokio::test] @@ -25,7 +23,7 @@ async fn basically_works() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -62,7 +60,7 @@ async fn hellos_work() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -112,7 +110,7 @@ async fn nak_bad_hello() { let rsp = chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -152,13 +150,13 @@ async fn mux_single_service() { assert_eq!( rsp1, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); assert_eq!( rsp2, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -228,7 +226,7 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -264,7 +262,7 @@ async fn service_type_routing() { let rsp = helloworld_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -272,7 +270,7 @@ async fn service_type_routing() { let rsp = hellohello_chan.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "world".to_string() }) ); @@ -318,7 +316,7 @@ async fn service_identity_routing() { let rsp = sf_conn.rx().recv().await; assert_eq!( rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "san francisco".to_string() }) ); @@ -346,13 +344,13 @@ async fn service_identity_routing() { let (sf_rsp, uni_rsp) = tokio::join! { sf_conn.rx().recv(), uni_conn.rx().recv() }; assert_eq!( sf_rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "san francisco".to_string() }) ); assert_eq!( uni_rsp, - Some(HelloWorldResponse { + Ok(HelloWorldResponse { world: "universe".to_string() }) ); diff --git a/source/mgnp/src/tests/integration.rs b/source/mgnp/src/tests/integration.rs new file mode 100644 index 0000000..1a7e81e --- /dev/null +++ b/source/mgnp/src/tests/integration.rs @@ -0,0 +1,186 @@ +use super::*; +use crate::{ + message::{self, InboundFrame, OutboundFrame}, + Wire, +}; +use tricky_pipe::serbox; + +#[tokio::test] +async fn reset_decode_error() { + let remote_registry: TestRegistry = TestRegistry::default(); + remote_registry.spawn_hello_world(); + + let mut fixture = Fixture::new().spawn_remote(remote_registry); + let mut wire = fixture.take_local_wire(); + let mut hellobox = serbox::Sharer::new(); + let hello = hellobox.share(()).await; + + wire.send(OutboundFrame::connect( + crate::Id::new(1), + svcs::hello_world_id(), + hello, + )) + .await + .unwrap(); + + let frame = wire.recv().await.unwrap(); + let msg = InboundFrame::from_bytes(&frame[..]); + assert_eq!( + msg, + Ok(InboundFrame { + header: message::Header::Ack { + local_id: crate::Id::new(1), + remote_id: crate::Id::new(1), + }, + body: &[] + }) + ); + + let mut out_frame = postcard::to_allocvec(&message::Header::Data { + local_id: crate::Id::new(1), + remote_id: crate::Id::new(1), + }) + .unwrap(); + out_frame.extend(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]); + + wire.send_bytes(out_frame).await.unwrap(); + + let frame = wire.recv().await.unwrap(); + let msg = InboundFrame::from_bytes(&frame[..]); + assert_eq!( + msg, + Ok(InboundFrame { + header: message::Header::Reset { + remote_id: crate::Id::new(1), + reason: message::Reset::YouDoneGoofed(message::DecodeError::Body( + message::DecodeErrorKind::UnexpectedEnd + )) + }, + body: &[] + }) + ); +} + +#[tokio::test] +async fn reset_no_such_conn() { + let remote_registry: TestRegistry = TestRegistry::default(); + remote_registry.spawn_hello_world(); + + let mut fixture = Fixture::new().spawn_remote(remote_registry); + let mut wire = fixture.take_local_wire(); + let mut hellobox = serbox::Sharer::new(); + let hello = hellobox.share(()).await; + + wire.send(OutboundFrame::connect( + crate::Id::new(1), + svcs::hello_world_id(), + hello, + )) + .await + .unwrap(); + + let frame = wire.recv().await.unwrap(); + let msg = InboundFrame::from_bytes(&frame[..]); + assert_eq!( + msg, + Ok(InboundFrame { + header: message::Header::Ack { + local_id: crate::Id::new(1), + remote_id: crate::Id::new(1), + }, + body: &[] + }) + ); + + let chan = tricky_pipe::mpsc::TrickyPipe::new(8); + let rx = chan.ser_receiver().unwrap(); + let tx = chan.sender(); + tx.try_send(svcs::HelloWorldRequest { + hello: "hello".into(), + }) + .unwrap(); + + let body = rx.try_recv().unwrap(); + + let out_frame = { + let frame = OutboundFrame { + header: message::Header::Data { + local_id: crate::Id::new(1), + remote_id: crate::Id::new(1), // good conn ID + }, + body: message::OutboundData::Data(body), + }; + frame.to_vec().unwrap() + }; + + wire.send_bytes(out_frame).await.unwrap(); + + let frame = wire.recv().await.unwrap(); + let msg = InboundFrame::from_bytes(&frame[..]).unwrap(); + assert_eq!( + postcard::from_bytes(msg.body), + Ok(svcs::HelloWorldResponse { + world: "world".into() + }) + ); + + // another message, with a bad conn ID + tx.try_send(svcs::HelloWorldRequest { + hello: "hello".into(), + }) + .unwrap(); + let body = rx.try_recv().unwrap(); + let out_frame = OutboundFrame { + header: message::Header::Data { + remote_id: crate::Id::new(666), // bad conn ID + local_id: crate::Id::new(1), + }, + body: message::OutboundData::Data(body), + } + .to_vec() + .unwrap(); + + wire.send_bytes(out_frame).await.unwrap(); + let frame = wire.recv().await.unwrap(); + let msg = dbg!(InboundFrame::from_bytes(&frame[..])); + assert_eq!( + msg, + Ok(InboundFrame { + header: message::Header::Reset { + remote_id: crate::Id::new(1), + reason: message::Reset::NoSuchConn, + }, + body: &[] + }) + ); + + // another message, with a differently conn ID + tx.try_send(svcs::HelloWorldRequest { + hello: "hello".into(), + }) + .unwrap(); + let body = rx.try_recv().unwrap(); + let out_frame = OutboundFrame { + header: message::Header::Data { + remote_id: crate::Id::new(1), + local_id: crate::Id::new(666), // bad conn ID + }, + body: message::OutboundData::Data(body), + } + .to_vec() + .unwrap(); + + wire.send_bytes(out_frame).await.unwrap(); + let frame = wire.recv().await.unwrap(); + let msg = dbg!(InboundFrame::from_bytes(&frame[..])); + assert_eq!( + msg, + Ok(InboundFrame { + header: message::Header::Reset { + remote_id: crate::Id::new(666), + reason: message::Reset::NoSuchConn, + }, + body: &[] + }) + ); +} diff --git a/source/mgnp/tests/support.rs b/source/mgnp/src/tests/mod.rs similarity index 83% rename from source/mgnp/tests/support.rs rename to source/mgnp/src/tests/mod.rs index a1b0b47..c087d12 100644 --- a/source/mgnp/tests/support.rs +++ b/source/mgnp/src/tests/mod.rs @@ -1,6 +1,5 @@ -#![cfg(feature = "alloc")] -use mgnp::{ - message::{OutboundFrame, Rejection}, +use crate::{ + message::{OutboundFrame, Rejection, Reset}, registry::{self, Registry}, tricky_pipe::{ bidi::{BiDi, SerBiDi}, @@ -16,9 +15,12 @@ use std::{ use tokio::sync::{mpsc, oneshot, Notify}; pub use tracing::Instrument; -pub mod svcs { +mod e2e; +mod integration; + +pub(crate) mod svcs { use super::*; - use mgnp::registry; + use crate::registry; use uuid::{uuid, Uuid}; pub struct HelloWorld; @@ -88,10 +90,10 @@ pub mod svcs { worker: usize, req_msg: &'static str, rsp_msg: &'static str, - chan: BiDi, + chan: BiDi, ) { tracing::debug!("hello world worker {worker} running..."); - while let Some(req) = chan.rx().recv().await { + while let Ok(req) = chan.rx().recv().await { tracing::info!(?req); assert_eq!(req.hello, req_msg); chan.tx() @@ -110,19 +112,19 @@ pub struct Fixture { test_done: Arc, } -impl Fixture { +impl Fixture, Option> { pub fn new() -> Self { Self::default() } } -impl Default for Fixture { +impl Default for Fixture, Option> { fn default() -> Self { trace_init(); let (local, remote) = TestWire::new(); Self { - local, - remote, + local: Some(local), + remote: Some(remote), test_done: Arc::new(Notify::new()), } } @@ -130,52 +132,70 @@ impl Default for Fixture { type Running = (Interface, tokio::task::JoinHandle<()>); -impl Fixture { +impl Fixture { + fn spawn_peer( + name: &'static str, + wire: TestWire, + registry: TestRegistry, + test_done: &Arc, + ) -> Running { + let (iface, machine) = Interface::new::<_, _, { crate::DEFAULT_MAX_CONNS }>( + wire, + registry, + TrickyPipe::new(8), + ); + let task = tokio::spawn(interface("name", machine, test_done.clone())); + (iface, task) + } +} + +impl Fixture, R> { pub fn spawn_local(self, registry: TestRegistry) -> Fixture { let Fixture { - local, + mut local, remote, test_done, } = self; - - let (iface, machine) = Interface::new::<_, _, { mgnp::DEFAULT_MAX_CONNS }>( - local, - registry, - TrickyPipe::new(8), - ); + let wire = local + .take() + .expect("attempted to take the local end of the wire twice!"); Fixture { - local: ( - iface, - tokio::spawn(interface("local", machine, test_done.clone())), - ), + local: Self::spawn_peer("local", wire, registry, &test_done), remote, test_done, } } + + pub fn take_local_wire(&mut self) -> TestWire { + self.local + .take() + .expect("attempted to take the local end of the wire twice!") + } } -impl Fixture { +impl Fixture> { pub fn spawn_remote(self, registry: TestRegistry) -> Fixture { let Fixture { local, - remote, + mut remote, test_done, } = self; - let (iface, machine) = Interface::new::<_, _, { mgnp::DEFAULT_MAX_CONNS }>( - remote, - registry, - TrickyPipe::new(8), - ); + let wire = remote + .take() + .expect("attempted to take the remote end of the wire twice!"); Fixture { local, - remote: ( - iface, - tokio::spawn(interface("remote", machine, test_done.clone())), - ), + remote: Self::spawn_peer("local", wire, registry, &test_done), test_done, } } + + pub fn take_remote_wire(&mut self) -> TestWire { + self.remote + .take() + .expect("attempted to take the remote end of the wire twice!") + } } impl Fixture { @@ -211,7 +231,7 @@ impl Fixture { #[tracing::instrument(level = tracing::Level::INFO, skip(machine, test_done))] async fn interface( peer: &'static str, - mut machine: mgnp::Machine, + mut machine: crate::Machine, test_done: Arc, ) { tokio::select! { @@ -240,7 +260,7 @@ pub fn trace_init() { .try_init(); } -pub fn make_bidis(cap: u8) -> (SerBiDi, BiDi) +pub fn make_bidis(cap: u8) -> (SerBiDi, BiDi) where In: serde::Serialize + serde::de::DeserializeOwned + Send + 'static, Out: serde::Serialize + serde::de::DeserializeOwned + Send + 'static, @@ -277,7 +297,7 @@ pub struct TestFrame(Vec); pub struct InboundConnect { pub hello: Vec, - pub rsp: oneshot::Sender>, + pub rsp: oneshot::Sender, Rejection>>, } // === impl TestRegistry === @@ -288,7 +308,7 @@ impl Registry for TestRegistry { &self, identity: registry::Identity, hello: &[u8], - ) -> Result { + ) -> Result, Rejection> { let Some(svc) = self.svcs.read().unwrap().get(&identity).cloned() else { tracing::info!("REGISTRY: service not found!"); return Err(Rejection::NotFound); @@ -372,6 +392,15 @@ impl TestWire { let (tx2, rx2) = mpsc::channel(8); (Self { tx: tx1, rx: rx2 }, Self { tx: tx2, rx: rx1 }) } + + pub async fn send_bytes(&mut self, frame: impl Into>) -> Result<(), &'static str> { + let frame = frame.into(); + tracing::info!(frame = ?HexSlice::new(&frame), "SEND"); + self.tx + .send(frame) + .await + .map_err(|_| "the recv end of this wire has been dropped") + } } impl Wire for TestWire { @@ -387,11 +416,7 @@ impl Wire for TestWire { async fn send(&mut self, msg: OutboundFrame<'_>) -> Result<(), &'static str> { tracing::info!(?msg, "sending message"); let frame = msg.to_vec().expect("message should serialize"); - tracing::info!(frame = ?HexSlice::new(&frame), "SEND"); - self.tx - .send(frame) - .await - .map_err(|_| "the recv end of this wire has been dropped") + self.send_bytes(frame).await } } @@ -424,18 +449,18 @@ impl fmt::Debug for HexSlice<'_> { #[tracing::instrument(level = tracing::Level::INFO, skip(connector, hello))] pub async fn connect_should_nak( - connector: &mut mgnp::Connector, + connector: &mut crate::Connector, name: &'static str, hello: S::Hello, - nak: mgnp::message::Rejection, + nak: crate::message::Rejection, ) { tracing::info!("connecting to {name} (should NAK)..."); let res = connector - .connect(name, hello, mgnp::client::Channels::new(8)) + .connect(name, hello, crate::client::Channels::new(8)) .await; tracing::info!(?res, "connect result"); match res { - Err(mgnp::client::ConnectError::Nak(actual)) => assert_eq!( + Err(crate::client::ConnectError::Nak(actual)) => assert_eq!( actual, nak, "expected connection to {name} to be NAK'd with {nak:?}, but it was NAK'd with {actual:?}!" ), @@ -450,13 +475,13 @@ pub async fn connect_should_nak( #[tracing::instrument(level = tracing::Level::INFO, skip(connector, hello))] pub async fn connect( - connector: &mut mgnp::Connector, + connector: &mut crate::Connector, name: &'static str, hello: S::Hello, -) -> mgnp::client::ClientChannel { +) -> crate::client::Connection { tracing::info!("connecting to {name} (should SUCCEED)..."); let res = connector - .connect(name, hello, mgnp::client::Channels::new(8)) + .connect(name, hello, crate::client::Channels::new(8)) .await; tracing::info!(?res, "connect result"); match res { diff --git a/source/tricky-pipe/src/bidi.rs b/source/tricky-pipe/src/bidi.rs index 257fbf7..f46f584 100644 --- a/source/tricky-pipe/src/bidi.rs +++ b/source/tricky-pipe/src/bidi.rs @@ -13,9 +13,9 @@ use futures::FutureExt; /// This channel consists of a [`Sender`] paired with a [`Receiver`], and can be /// used to both send and receive typed messages to and from a remote peer. #[must_use] -pub struct BiDi { - tx: Sender, - rx: Receiver, +pub struct BiDi { + tx: Sender, + rx: Receiver, } /// A bidirectional type-erased serializing channel. @@ -24,9 +24,9 @@ pub struct BiDi { /// and can be used to both send and receive serialized messages to and from a /// remote peer. #[must_use] -pub struct SerBiDi { - tx: DeserSender, - rx: SerReceiver, +pub struct SerBiDi { + tx: DeserSender, + rx: SerReceiver, } /// Events returned by [`BiDi::wait`] and [`SerBiDi::wait`]. @@ -39,19 +39,20 @@ pub enum Event { SendReady(Out), } -impl BiDi +impl BiDi where In: 'static, Out: 'static, + E: Clone + 'static, { /// Constructs a new `BiDi` from a [`Sender`] and a [`Receiver`]. - pub fn from_pair(tx: Sender, rx: Receiver) -> Self { + pub fn from_pair(tx: Sender, rx: Receiver) -> Self { Self { tx, rx } } /// Consumes `self`, extracting the inner [`Sender`] and [`Receiver`]. #[must_use] - pub fn split(self) -> (Sender, Receiver) { + pub fn split(self) -> (Sender, Receiver) { (self.tx, self.rx) } @@ -78,7 +79,7 @@ where /// [`Sender::try_reserve`], [`Sender::capacity`], et cetera, on the send /// half of the channel. #[must_use] - pub fn tx(&self) -> &Sender { + pub fn tx(&self) -> &Sender { &self.tx } @@ -88,7 +89,7 @@ where /// [`Receiver::try_recv`], [`Receiver::capacity`], et cetera, on the /// receive half of the channel. #[must_use] - pub fn rx(&self) -> &Receiver { + pub fn rx(&self) -> &Receiver { &self.rx } @@ -121,10 +122,11 @@ where } } -impl fmt::Debug for BiDi +impl fmt::Debug for BiDi where In: 'static, Out: 'static, + E: 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let Self { tx, rx } = self; @@ -137,15 +139,15 @@ where // === impl SerBiDi === -impl SerBiDi { +impl SerBiDi { /// Constructs a new `SerBiDi` from a [`DeserSender`] and a [`SerReceiver`]. - pub fn from_pair(tx: DeserSender, rx: SerReceiver) -> Self { + pub fn from_pair(tx: DeserSender, rx: SerReceiver) -> Self { Self { tx, rx } } /// Consumes `self`, extracting the inner [`DeserSender`] and [`SerReceiver`]. #[must_use] - pub fn split(self) -> (DeserSender, SerReceiver) { + pub fn split(self) -> (DeserSender, SerReceiver) { (self.tx, self.rx) } @@ -172,7 +174,7 @@ impl SerBiDi { /// [`DeserSender::try_reserve`], [`DeserSender::capacity`], et cetera, on /// the send half of the channel. #[must_use] - pub fn tx(&self) -> &DeserSender { + pub fn tx(&self) -> &DeserSender { &self.tx } @@ -182,7 +184,7 @@ impl SerBiDi { /// [`SerReceiver::try_recv`], [`SerReceiver::capacity`], et cetera, on the /// receive half of the channel. #[must_use] - pub fn rx(&self) -> &SerReceiver { + pub fn rx(&self) -> &SerReceiver { &self.rx } @@ -215,7 +217,7 @@ impl SerBiDi { } } -impl fmt::Debug for SerBiDi { +impl fmt::Debug for SerBiDi { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let Self { tx, rx } = self; f.debug_struct("SerBiDi") From fdd2927c56d2ef4a297d64d51b3504059d45f9e6 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 15:52:41 -0800 Subject: [PATCH 12/21] moar --- source/mgnp/src/conn_table.rs | 46 ++++----- source/tricky-pipe/src/mpsc/channel_core.rs | 88 ++++++++++++---- source/tricky-pipe/src/mpsc/error.rs | 107 +++++++++++++++++--- source/tricky-pipe/src/mpsc/mod.rs | 44 ++++---- source/tricky-pipe/src/mpsc/static_impl.rs | 27 ++--- source/tricky-pipe/src/mpsc/tests.rs | 27 ++--- 6 files changed, 236 insertions(+), 103 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index 8b1b8ea..03533b9 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -7,7 +7,7 @@ use core::{fmt, mem, num::NonZeroU16, task::Poll}; use tricky_pipe::{ bidi::SerBiDi, mpsc::{ - error::{RecvError, SendError}, + error::{RecvError, SendError, SerSendError}, SerPermit, }, oneshot, @@ -105,7 +105,7 @@ impl ConnTable { Poll::Ready(Err(error)) => { self.dead_index = Some(Id::from_index(idx)); let reason = match error { - RecvError::Closed => Reset::BecauseISaidSo, + RecvError::Disconnected => Reset::BecauseISaidSo, RecvError::Error(reason) => reason, }; return Poll::Ready(OutboundFrame::reset(*remote_id, reason)); @@ -191,27 +191,27 @@ impl ConnTable { }; // try to reserve send capacity on this socket. - let reset = match socket.channel.tx().reserve().await { - Ok(permit) => match permit.send(frame.body) { - Ok(_) => return None, - Err(error) => { - // TODO(eliza): possibly it would be better if we - // just sent the deserialize error to the local peer - // and let it decide whether this should kill the - // connection or not? maybe by turning the server's - // client-to-server stream into `Result`s? - tracing::debug!( - id.remote = %local_id, - id.local = %id, - %error, - "process_inbound(DATA): failed to deserialize; resetting...", - ); - Reset::bad_frame(error) - } - }, - Err(SendError::Closed(_)) => Reset::BecauseISaidSo, - Err(SendError::Error { error, .. }) => error, + let reset = match socket.channel.tx().send(frame.body).await { + Ok(_) => return None, + Err(SerSendError::Deserialize(error)) => { + // TODO(eliza): possibly it would be better if we + // just sent the deserialize error to the local peer + // and let it decide whether this should kill the + // connection or not? maybe by turning the server's + // client-to-server stream into `Result`s? + tracing::debug!( + id.remote = %local_id, + id.local = %id, + %error, + "process_inbound(DATA): failed to deserialize; resetting...", + ); + Reset::bad_frame(error) + } + // channel closed gracefully. + Err(SerSendError::Disconnected) => Reset::BecauseISaidSo, + // remote server reset the connection. + Err(SerSendError::Error(error)) => error, }; tracing::trace!( id.remote = %local_id, diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index 13865d3..db4917c 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -76,8 +76,12 @@ pub(super) struct Core { /// This is the length of the actual queue elements array (which is not part /// of this struct). pub(super) capacity: u8, - /// If the channel closed with an error, this is the error. - error: UnsafeCell>, + /// If the receiver side of the channel closed with an error, this is the + /// error. + send_error: UnsafeCell>, + /// If the sender side of the channel closed with an error, this is the + /// error. + recv_error: UnsafeCell>, } pub(super) struct Reservation<'core, E> { @@ -198,7 +202,8 @@ impl Core { // dropped. state: AtomicUsize::new(state::TX_ONE), capacity, - error: UnsafeCell::new(MaybeUninit::uninit()), + send_error: UnsafeCell::new(MaybeUninit::uninit()), + recv_error: UnsafeCell::new(MaybeUninit::uninit()), } } @@ -258,7 +263,7 @@ impl Core { pub(super) fn close_rx_error(&self, error: E) { // store the error in the channel. - self.error.with_mut(|ptr| unsafe { + self.send_error.with_mut(|ptr| unsafe { // Safety: this is okay, because there is only one receiver, and the // senders will not attempt to access the error until the receiver // has set the `CLOSED_ERROR` bits. @@ -273,6 +278,27 @@ impl Core { test_println!("Core::close_rx_error: -> closed"); } + pub(super) fn close_tx_error(&self, error: E) -> bool { + if test_dbg!(self.dequeue_pos.fetch_or(HAS_ERROR, AcqRel) & HAS_ERROR) == HAS_ERROR { + // someone else is setting the close error! + return false; + } + // store the error in the channel. + self.recv_error.with_mut(|ptr| unsafe { + // Safety: this is okay, because the HAS_ERROR bit guards against + // any other sender setting the error, but the receiver will not + // read the error until the CLOSED bit is also set. For now, we have + // exclusive access to the error field. + (*ptr).write(error); + }); + // set the state to indicate that the sender closed the channel. + test_dbg!(self.dequeue_pos.fetch_or(CLOSED, Release)); + // notify any waiting senders that the channel is closed. + self.cons_wait.wake(); + test_println!("Core::close_tx_error: -> closed"); + true + } + #[inline] pub(super) fn add_tx(&self) { // Using a relaxed ordering is alright here, as knowledge of the @@ -374,7 +400,7 @@ impl Core { return Err(self .send_closed_error() .map(|error| TrySendError::Error { error, message: () }) - .unwrap_or(TrySendError::Closed(()))); + .unwrap_or(TrySendError::Disconnected(()))); } test_dbg!(self.indices.allocate()) @@ -386,14 +412,14 @@ impl Core { loop { match self.try_reserve() { Ok(res) => return Ok(res), - Err(TrySendError::Closed(())) => return Err(SendError::Closed(())), + Err(TrySendError::Disconnected(())) => return Err(SendError::Disconnected(())), Err(TrySendError::Error { error, .. }) => { return Err(SendError::Error { error, message: () }) } Err(TrySendError::Full(())) => self.prod_wait.wait().await.map_err(|_| { self.send_closed_error() .map(|error| SendError::Error { error, message: () }) - .unwrap_or(SendError::Closed(())) + .unwrap_or(SendError::Disconnected(())) })?, } } @@ -406,7 +432,9 @@ impl Core { loop { match self.try_dequeue() { Ok(res) => return Poll::Ready(Ok(res)), - Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)), + Err(TryRecvError::Disconnected) => { + return Poll::Ready(Err(RecvError::Disconnected)) + } Err(TryRecvError::Error(error)) => { return Poll::Ready(Err(RecvError::Error(error))) } @@ -444,11 +472,10 @@ impl Core { match test_dbg!(dif).cmp(&0) { cmp::Ordering::Less if test_dbg!(head & CLOSED) != 0 => { - if head & CLOSED_ERROR == CLOSED_ERROR { - return Err(TryRecvError::Error(unsafe { self.close_error() })); - } else { - return Err(TryRecvError::Closed); - } + return Err(self + .recv_close_error() + .map(TryRecvError::Error) + .unwrap_or(TryRecvError::Disconnected)); } cmp::Ordering::Less => return Err(TryRecvError::Empty), cmp::Ordering::Equal => match test_dbg!(self.dequeue_pos.compare_exchange_weak( @@ -471,11 +498,18 @@ impl Core { } } - fn commit_send(&self, idx: u8) -> Result<(), SendError<(), E>> { + fn commit_send(&self, idx: u8) -> Result<(), SendError> { test_span!("Core::commit_send", idx); debug_assert!(idx as u16 <= MASK); let mut tail = test_dbg!(self.enqueue_pos.load(Acquire)); loop { + if test_dbg!(tail & CLOSED) == CLOSED { + return Err(self + .send_closed_error() + .map(|error| SendError::Error { error, message: () }) + .unwrap_or(SendError::Disconnected(()))); + } + // Shift one bit to the right to extract the actual position, and // discard the `CLOSED` bit. let pos = tail >> POS_SHIFT; @@ -505,17 +539,29 @@ impl Core { } fn send_closed_error(&self) -> Option { - if test_dbg!(self.enqueue_pos.load(Acquire) & CLOSED_ERROR) == CLOSED_ERROR { - Some(unsafe { self.close_error() }) + let pos = self.enqueue_pos.load(Acquire); + debug_assert_eq!(pos & CLOSED, CLOSED); + if test_dbg!(pos & CLOSED_ERROR) == CLOSED_ERROR { + Some( + self.send_error + .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }), + ) } else { None } } - unsafe fn close_error(&self) -> E { - // debug_assert!(self.enqueue_pos.load(Acquire) & CLOSED_ERROR == CLOSED_ERROR); - self.error - .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }) + fn recv_close_error(&self) -> Option { + let pos = self.dequeue_pos.load(Acquire); + debug_assert_eq!(pos & CLOSED, CLOSED); + if test_dbg!(pos & CLOSED_ERROR) == CLOSED_ERROR { + Some( + self.recv_error + .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }), + ) + } else { + None + } } } @@ -525,7 +571,7 @@ unsafe impl Sync for Core {} // === impl Reservation === impl Reservation<'_, E> { - pub(super) fn commit_send(self) -> Result<(), SendError<(), E>> { + pub(super) fn commit_send(self) -> Result<(), SendError> { // don't run the destructor that frees the index, since we are dropping // the cell... let this = ManuallyDrop::new(self); diff --git a/source/tricky-pipe/src/mpsc/error.rs b/source/tricky-pipe/src/mpsc/error.rs index 15ecb1a..e95c997 100644 --- a/source/tricky-pipe/src/mpsc/error.rs +++ b/source/tricky-pipe/src/mpsc/error.rs @@ -27,7 +27,7 @@ pub enum SendError { /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver - Closed(T), + Disconnected(T), Error { message: T, error: E, @@ -55,7 +55,7 @@ pub enum TrySendError { /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver - Closed(T), + Disconnected(T), Error { message: T, error: E, @@ -78,7 +78,7 @@ pub enum TryRecvError { /// /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender - Closed, + Disconnected, Error(E), } @@ -95,7 +95,7 @@ pub enum RecvError { /// /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender - Closed, + Disconnected, Error(E), } @@ -104,16 +104,17 @@ pub enum RecvError { /// [`DeserSender::send`]: super::DeserSender::send /// [`DeserSender::send_framed`]: super::DeserSender::send_framed #[derive(Debug, Eq, PartialEq)] -pub enum SerSendError { +pub enum SerSendError { /// A message cannot be sent because the channel is closed (no [`Receiver`] /// or [`SerReceiver`] exists). /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver - Closed, + Disconnected, /// The sent bytes could not be deserialized to a value of this channel's /// message type. Deserialize(postcard::Error), + Error(E), } /// Errors returned by [`DeserSender::try_send`] and @@ -123,11 +124,18 @@ pub enum SerSendError { /// [`DeserSender::try_send_framed`]: super::DeserSender::send_framed #[derive(Debug, Eq, PartialEq)] pub enum SerTrySendError { - /// The channel is [`Closed`](TrySendError::Closed) or [`Full`](TrySendError::Full). - Send(TrySendError), + /// A message cannot be sent because the channel is closed (no [`Receiver`] + /// or [`SerReceiver`] exists). + /// + /// [`Receiver`]: super::Receiver + /// [`SerReceiver`]: super::SerReceiver + Disconnected, + Full, /// The sent bytes could not be deserialized to a value of this channel's /// message type. Deserialize(postcard::Error), + + Error(E), } // === impl SendError === @@ -138,14 +146,14 @@ impl SendError { #[must_use] pub fn into_inner(self) -> T { match self { - Self::Closed(msg) => msg, + Self::Disconnected(msg) => msg, Self::Error { message, .. } => message, } } pub(crate) fn with_message(self, message: M) -> SendError { match self { - Self::Closed(_) => SendError::Closed(message), + Self::Disconnected(_) => SendError::Disconnected(message), Self::Error { error, .. } => SendError::Error { message, error }, } } @@ -154,7 +162,7 @@ impl SendError { impl fmt::Debug for SendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Closed(_) => f.debug_tuple("SendError::Closed").finish(), + Self::Disconnected(_) => f.debug_tuple("SendError::Closed").finish(), Self::Error { error, .. } => f .debug_struct("SendError::Error") .field("error", &error) @@ -163,6 +171,15 @@ impl fmt::Debug for SendError { } } +impl fmt::Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Disconnected(_) => f.write_str("the receiver has been dropped"), + Self::Error { error, .. } => error.fmt(f), + } + } +} + // === impl TrySendError === impl TrySendError { @@ -171,7 +188,7 @@ impl TrySendError { #[must_use] pub fn into_inner(self) -> T { match self { - Self::Closed(inner) => inner, + Self::Disconnected(inner) => inner, Self::Full(t) => t, Self::Error { message, .. } => message, } @@ -179,7 +196,7 @@ impl TrySendError { pub(crate) fn with_message(self, message: M) -> TrySendError { match self { - Self::Closed(_) => TrySendError::Closed(message), + Self::Disconnected(_) => TrySendError::Disconnected(message), Self::Full(_) => TrySendError::Full(message), Self::Error { error, .. } => TrySendError::Error { message, error }, } @@ -189,7 +206,7 @@ impl TrySendError { impl fmt::Debug for TrySendError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Closed(_) => f.debug_tuple("TrySendError::Closed").finish(), + Self::Disconnected(_) => f.debug_tuple("TrySendError::Closed").finish(), Self::Full(_) => f.debug_tuple("TrySendError::Full").finish(), Self::Error { error, .. } => f .debug_struct("TrySendError::Error") @@ -198,3 +215,65 @@ impl fmt::Debug for TrySendError { } } } + +impl fmt::Display for TrySendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Disconnected(_) => f.write_str("the receiver has been dropped"), + Self::Full(_) => f.write_str("the channel is currently at capacity"), + Self::Error { error, .. } => error.fmt(f), + } + } +} + +// === impl SerSendError === + +impl SerSendError { + pub(crate) fn from_send_error(err: SendError) -> Self { + match err { + SendError::Disconnected(_) => Self::Disconnected, + SendError::Error { error, .. } => Self::Error(error), + } + } +} + +impl fmt::Display for SerSendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Disconnected => f.write_str("the receiver has been dropped"), + Self::Error(error) => error.fmt(f), + Self::Deserialize(error) => write!(f, "error deserializing message: {error}"), + } + } +} + +// === impl SerTrySendError === + +impl SerTrySendError { + pub(crate) fn from_try_send_error(err: TrySendError) -> Self { + match err { + TrySendError::Disconnected(_) => Self::Disconnected, + TrySendError::Error { error, .. } => Self::Error(error), + TrySendError::Full(_) => Self::Full, + } + } + + pub(crate) fn from_ser_send_error(err: SerSendError) -> Self { + match err { + SerSendError::Disconnected => Self::Disconnected, + SerSendError::Error(e) => Self::Error(e), + SerSendError::Deserialize(e) => Self::Deserialize(e), + } + } +} + +impl fmt::Display for SerTrySendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Disconnected => f.write_str("the receiver has been dropped"), + Self::Error(error) => error.fmt(f), + Self::Full => f.write_str("the channel is currently at capacity"), + Self::Deserialize(error) => write!(f, "error deserializing message: {error}"), + } + } +} diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index f015853..295ded1 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -340,7 +340,7 @@ impl futures::Stream for &'_ Receiver { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_ref().get_ref().poll_recv(cx).map(|res| match res { Ok(res) => Some(Ok(res)), - Err(RecvError::Closed) => None, + Err(RecvError::Disconnected) => None, Err(RecvError::Error(error)) => Some(Err(error)), }) } @@ -353,7 +353,7 @@ impl futures::Stream for Receiver { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_ref().get_ref().poll_recv(cx).map(|res| match res { Ok(res) => Some(Ok(res)), - Err(RecvError::Closed) => None, + Err(RecvError::Disconnected) => None, Err(RecvError::Error(error)) => Some(Err(error)), }) } @@ -554,7 +554,7 @@ impl<'rx, E: Clone> futures::Stream for &'rx SerReceiver { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.as_ref().get_ref().poll_recv(cx).map(|res| match res { Ok(res) => Some(Ok(res)), - Err(RecvError::Closed) => None, + Err(RecvError::Disconnected) => None, Err(RecvError::Error(error)) => Some(Err(error)), }) } @@ -745,9 +745,9 @@ impl DeserSender { /// [SerPermit::send]. pub fn try_send(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { self.try_reserve() - .map_err(SerTrySendError::Send)? + .map_err(SerTrySendError::from_try_send_error)? .send(bytes) - .map_err(SerTrySendError::Deserialize) + .map_err(SerTrySendError::from_ser_send_error) } /// Attempt to immediately send the given framed bytes @@ -761,9 +761,9 @@ impl DeserSender { /// [SerPermit::send_framed]. pub fn try_send_framed(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerTrySendError> { self.try_reserve() - .map_err(SerTrySendError::Send)? + .map_err(SerTrySendError::from_try_send_error)? .send_framed(bytes) - .map_err(SerTrySendError::Deserialize) + .map_err(SerTrySendError::from_ser_send_error) } /// Attempt to send the given bytes @@ -775,12 +775,11 @@ impl DeserSender { /// /// This is equivalent to calling [DeserSender::reserve] followed by /// [SerPermit::send]. - pub async fn send(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { + pub async fn send(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { self.reserve() .await - .map_err(|_| SerSendError::Closed)? + .map_err(SerSendError::from_send_error)? .send(bytes) - .map_err(SerSendError::Deserialize) } /// Attempt to send the given framed bytes @@ -792,12 +791,11 @@ impl DeserSender { /// /// This is equivalent to calling [DeserSender::reserve] followed by /// [SerPermit::send_framed]. - pub async fn send_framed(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { + pub async fn send_framed(&self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { self.reserve() .await - .map_err(|_| SerSendError::Closed)? + .map_err(SerSendError::from_send_error)? .send_framed(bytes) - .map_err(SerSendError::Deserialize) } /// Returns `true` if this channel is empty. @@ -886,28 +884,32 @@ impl SerPermit<'_, E> { /// /// This will attempt to deserialize the bytes into the reservation, consuming /// it. If the deserialization fails, the [SerPermit] is still consumed. - pub fn send(self, bytes: impl AsRef<[u8]>) -> postcard::Result<()> { + pub fn send(self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { // try to deserialize the bytes into the reserved pipe slot. - (self.vtable.from_bytes)(self.elems, self.res.idx, bytes.as_ref())?; + (self.vtable.from_bytes)(self.elems, self.res.idx, bytes.as_ref()) + .map_err(SerSendError::Deserialize)?; // if we successfully deserialized the bytes, commit the send. // otherwise, we'll release the send index when we drop the reservation. - self.res.commit_send(); - Ok(()) + self.res + .commit_send() + .map_err(SerSendError::from_send_error) } /// Attempt to send the given bytes /// /// This will attempt to deserialize the COBS-encoded bytes into the reservation, consuming /// it. If the deserialization fails, the [SerPermit] is still consumed. - pub fn send_framed(self, bytes: impl AsRef<[u8]>) -> postcard::Result<()> { + pub fn send_framed(self, bytes: impl AsRef<[u8]>) -> Result<(), SerSendError> { // try to deserialize the bytes into the reserved pipe slot. - (self.vtable.from_bytes_framed)(self.elems, self.res.idx, bytes.as_ref())?; + (self.vtable.from_bytes_framed)(self.elems, self.res.idx, bytes.as_ref()) + .map_err(SerSendError::Deserialize)?; // if we successfully deserialized the bytes, commit the send. // otherwise, we'll release the send index when we drop the reservation. - self.res.commit_send(); - Ok(()) + self.res + .commit_send() + .map_err(SerSendError::from_send_error) } } diff --git a/source/tricky-pipe/src/mpsc/static_impl.rs b/source/tricky-pipe/src/mpsc/static_impl.rs index d99f7d3..9a1d873 100644 --- a/source/tricky-pipe/src/mpsc/static_impl.rs +++ b/source/tricky-pipe/src/mpsc/static_impl.rs @@ -238,7 +238,10 @@ mod tests { let tx = CHAN.sender(); let rx = CHAN.receiver().unwrap(); drop(rx); - assert_eq!(tx.try_reserve().unwrap_err(), TrySendError::Closed(()),); + assert_eq!( + tx.try_reserve().unwrap_err(), + TrySendError::Disconnected(()), + ); } #[test] @@ -252,7 +255,7 @@ mod tests { let tx = CHAN.sender(); let rx = CHAN.receiver().unwrap(); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed,); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected,); } #[test] @@ -268,7 +271,7 @@ mod tests { let rx = CHAN.receiver().unwrap(); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed,); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected,); } #[test] @@ -321,7 +324,10 @@ mod tests { let tx = CHAN.sender(); let rx = CHAN.ser_receiver().unwrap(); drop(rx); - assert_eq!(tx.try_reserve().unwrap_err(), TrySendError::Closed(())); + assert_eq!( + tx.try_reserve().unwrap_err(), + TrySendError::Disconnected(()) + ); } #[test] @@ -335,7 +341,7 @@ mod tests { let tx = CHAN.sender(); let rx = CHAN.ser_receiver().unwrap(); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); } #[test] @@ -351,7 +357,7 @@ mod tests { let rx = CHAN.ser_receiver().unwrap(); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); } #[test] @@ -450,10 +456,7 @@ mod tests { let rx = CHAN.receiver().unwrap(); drop(rx); let res = tx.try_send([240, 223, 93, 160, 141, 6, 5, 104, 101, 108, 108, 111]); - assert_eq!( - res.unwrap_err(), - SerTrySendError::Send(TrySendError::Closed(())), - ); + assert_eq!(res.unwrap_err(), SerTrySendError::Disconnected,); } #[test] @@ -467,7 +470,7 @@ mod tests { let tx = CHAN.deser_sender(); let rx = CHAN.receiver().unwrap(); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); } #[test] @@ -483,7 +486,7 @@ mod tests { let rx = CHAN.receiver().unwrap(); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); } #[test] diff --git a/source/tricky-pipe/src/mpsc/tests.rs b/source/tricky-pipe/src/mpsc/tests.rs index c0819f8..3e3c0dd 100644 --- a/source/tricky-pipe/src/mpsc/tests.rs +++ b/source/tricky-pipe/src/mpsc/tests.rs @@ -174,7 +174,10 @@ mod single_threaded { let tx = test_dbg!(chan.sender()); let rx = test_dbg!(chan.receiver().unwrap()); drop(rx); - assert_eq!(tx.try_reserve().unwrap_err(), TrySendError::Closed(()),); + assert_eq!( + tx.try_reserve().unwrap_err(), + TrySendError::Disconnected(()), + ); }); } @@ -186,7 +189,7 @@ mod single_threaded { let rx = test_dbg!(chan.receiver().unwrap()); drop(chan); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed,); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected,); }); } @@ -200,7 +203,7 @@ mod single_threaded { drop(chan); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed,); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected,); }); } @@ -257,7 +260,10 @@ mod single_threaded { let tx = test_dbg!(chan.sender()); let rx = test_dbg!(chan.ser_receiver()).unwrap(); drop(rx); - assert_eq!(tx.try_reserve().unwrap_err(), TrySendError::Closed(())); + assert_eq!( + tx.try_reserve().unwrap_err(), + TrySendError::Disconnected(()) + ); }); } @@ -269,7 +275,7 @@ mod single_threaded { let rx = test_dbg!(chan.ser_receiver()).unwrap(); drop(chan); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); }); } @@ -283,7 +289,7 @@ mod single_threaded { drop(chan); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); }); } @@ -389,10 +395,7 @@ mod single_threaded { drop(chan); drop(rx); let res = tx.try_send([240, 223, 93, 160, 141, 6, 5, 104, 101, 108, 108, 111]); - assert_eq!( - res.unwrap_err(), - SerTrySendError::Send(TrySendError::Closed(())), - ); + assert_eq!(res.unwrap_err(), SerTrySendError::Disconnected,); }); } @@ -404,7 +407,7 @@ mod single_threaded { let rx = test_dbg!(chan.receiver()).unwrap(); drop(chan); drop(tx); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); }); } @@ -418,7 +421,7 @@ mod single_threaded { drop(chan); drop(tx1); drop(tx2); - assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Closed); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Disconnected); }); } From 47f20cad7726f665d2d81da4afb89b834ad4e00e Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 21 Nov 2023 17:31:36 -0800 Subject: [PATCH 13/21] simplify erroring and allow both sides to error --- Cargo.lock | 1 + source/tricky-pipe/Cargo.toml | 1 + source/tricky-pipe/src/mpsc/arc_impl.rs | 6 +- source/tricky-pipe/src/mpsc/channel_core.rs | 127 ++++++++------------ source/tricky-pipe/src/mpsc/mod.rs | 42 ++++--- source/tricky-pipe/src/mpsc/static_impl.rs | 6 +- source/tricky-pipe/src/mpsc/tests.rs | 36 +++++- 7 files changed, 113 insertions(+), 106 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2ae7823..30a2ba3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1063,6 +1063,7 @@ dependencies = [ "loom 0.7.1", "maitake-sync", "mnemos-bitslab", + "mycelium-bitfield 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "portable-atomic", "postcard", "serde", diff --git a/source/tricky-pipe/Cargo.toml b/source/tricky-pipe/Cargo.toml index 6d95263..a119ddd 100644 --- a/source/tricky-pipe/Cargo.toml +++ b/source/tricky-pipe/Cargo.toml @@ -25,6 +25,7 @@ alloc = ["postcard/alloc"] [dependencies] maitake-sync = "0.1.0" +mycelium-bitfield = { version = "0.1.3", default-features = false} portable-atomic = "1.4.3" mnemos-bitslab = { git = "https://github.com/tosc-rs/mnemos", branch = "eliza/bitslab-loom" } futures = { version = "0.3", features = ["async-await"], default-features = false } diff --git a/source/tricky-pipe/src/mpsc/arc_impl.rs b/source/tricky-pipe/src/mpsc/arc_impl.rs index 8a4fb99..733205c 100644 --- a/source/tricky-pipe/src/mpsc/arc_impl.rs +++ b/source/tricky-pipe/src/mpsc/arc_impl.rs @@ -71,10 +71,7 @@ impl TrickyPipe { pub fn receiver(&self) -> Option> { self.0.core.try_claim_rx()?; - Some(Receiver { - pipe: self.typed(), - closed_error: false, - }) + Some(Receiver { pipe: self.typed() }) } /// Obtain a [`Sender`] capable of sending `T`-typed data @@ -125,7 +122,6 @@ where Some(SerReceiver { pipe: self.erased(), vtable: Self::SER_VTABLE, - closed_error: false, }) } diff --git a/source/tricky-pipe/src/mpsc/channel_core.rs b/source/tricky-pipe/src/mpsc/channel_core.rs index db4917c..d53b6a0 100644 --- a/source/tricky-pipe/src/mpsc/channel_core.rs +++ b/source/tricky-pipe/src/mpsc/channel_core.rs @@ -76,12 +76,8 @@ pub(super) struct Core { /// This is the length of the actual queue elements array (which is not part /// of this struct). pub(super) capacity: u8, - /// If the receiver side of the channel closed with an error, this is the - /// error. - send_error: UnsafeCell>, - /// If the sender side of the channel closed with an error, this is the - /// error. - recv_error: UnsafeCell>, + /// If the channel has been closed with an error, this is that error. + error: UnsafeCell>, } pub(super) struct Reservation<'core, E> { @@ -141,15 +137,20 @@ pub(super) type DeserFn = fn(ErasedSlice, u8, &[u8]) -> postcard::Result<()>; /// Values for the `core.state` bitfield. mod state { + use mycelium_bitfield::PackUsize; /// If set, the channel's receiver has been claimed, indicating that no /// additional receivers can be claimed. - pub(super) const RX_CLAIMED: usize = 1 << 0; + pub(super) const RX_CLAIMED: PackUsize = PackUsize::least_significant(1); - /// Sender reference count; value of one sender. - pub(super) const TX_ONE: usize = 1 << TX_SHIFT; + pub(super) const ERRORING: PackUsize = RX_CLAIMED.next(1); + pub(super) const ERRORED: PackUsize = ERRORING.next(1); - /// Offset of TX count, in bits - pub(super) const TX_SHIFT: usize = 1; + /// Sender reference count. + pub(super) const TX_CNT: PackUsize = ERRORED.remaining(); + /// Sender reference count, one sender. + pub(super) const TX_ONE: usize = TX_CNT.first_bit(); + /// Sender reference count; bit offset. + pub(super) const TX_SHIFT: u32 = TX_CNT.least_significant_index(); } pub(super) const MAX_CAPACITY: usize = IndexAllocWord::MAX_CAPACITY as usize; @@ -161,9 +162,7 @@ pub(super) const MAX_CAPACITY: usize = IndexAllocWord::MAX_CAPACITY as usize; /// This is the first bit of the pos word, so that it is not clobbered if /// incrementing the actual position in the queue wraps around (which is fine). const CLOSED: u16 = 1 << 0; -const HAS_ERROR: u16 = 1 << 1; -const CLOSED_ERROR: u16 = CLOSED | HAS_ERROR; -const POS_SHIFT: u16 = CLOSED_ERROR.trailing_ones() as u16; +const POS_SHIFT: u16 = CLOSED.trailing_ones() as u16; /// The value by which `enqueue_pos` and `dequeue_pos` are incremented. This is /// shifted left by two to account for the lowest bits being used for `CLOSED` /// and `HAS_ERROR` @@ -202,8 +201,7 @@ impl Core { // dropped. state: AtomicUsize::new(state::TX_ONE), capacity, - send_error: UnsafeCell::new(MaybeUninit::uninit()), - recv_error: UnsafeCell::new(MaybeUninit::uninit()), + error: UnsafeCell::new(MaybeUninit::uninit()), } } @@ -231,7 +229,6 @@ impl Core { queue, state: AtomicUsize::new(state::TX_ONE), capacity, - error: UnsafeCell::new(MaybeUninit::uninit()), } } @@ -244,10 +241,10 @@ impl Core { pub(super) fn try_claim_rx(&self) -> Option<()> { // set `RX_CLAIMED`. - let state = test_dbg!(self.state.fetch_or(state::RX_CLAIMED, AcqRel)); + let state = test_dbg!(self.state.fetch_or(state::RX_CLAIMED.first_bit(), AcqRel)); // if the `RX_CLAIMED` bit was not set, we successfully claimed the // receiver. - let claimed = test_dbg!(state & state::RX_CLAIMED) == 0; + let claimed = test_dbg!(!state::RX_CLAIMED.contained_in_any(state)); test_println!(claimed, "Core::try_claim_rx"); claimed.then_some(()) } @@ -261,41 +258,36 @@ impl Core { test_println!("Core::close_rx: -> closed"); } - pub(super) fn close_rx_error(&self, error: E) { - // store the error in the channel. - self.send_error.with_mut(|ptr| unsafe { - // Safety: this is okay, because there is only one receiver, and the - // senders will not attempt to access the error until the receiver - // has set the `CLOSED_ERROR` bits. - // - // The receiver will not close the channel more than once. - (*ptr).write(error); - }); - // set the state to indicate that the receiver closed the channel. - test_dbg!(self.enqueue_pos.fetch_or(CLOSED_ERROR, Release)); - // notify any waiting senders that the channel is closed. - self.prod_wait.close(); - test_println!("Core::close_rx_error: -> closed"); + /// Close the channel from the sender side. + fn close_tx(&self) { + test_dbg!(self.dequeue_pos.fetch_or(CLOSED, Release)); + self.cons_wait.close(); } - pub(super) fn close_tx_error(&self, error: E) -> bool { - if test_dbg!(self.dequeue_pos.fetch_or(HAS_ERROR, AcqRel) & HAS_ERROR) == HAS_ERROR { - // someone else is setting the close error! + pub(super) fn close_with_error(&self, error: E) -> bool { + test_span!("Core::close_with_error()"); + // If `ERRORING` _or_ `ERRORED` are set, we can't set the error... + const CANT_ERROR: usize = state::ERRORING.first_bit() | state::ERRORED.first_bit(); + let state = test_dbg!(self.state.fetch_or(state::ERRORING.first_bit(), AcqRel)); + + if test_dbg!(state & CANT_ERROR != 0) { return false; } - // store the error in the channel. - self.recv_error.with_mut(|ptr| unsafe { - // Safety: this is okay, because the HAS_ERROR bit guards against - // any other sender setting the error, but the receiver will not - // read the error until the CLOSED bit is also set. For now, we have - // exclusive access to the error field. + + self.error.with_mut(|ptr| unsafe { + // Safety: this is okay, because access to the error field is + // guarded by the `ERRORING` bit, and if we were the first thread to + // successfully set it, then we have exclusive access to the error + // field. Readers won't try to access the error until we set the + // `ERRORED` bit, which hasn't been set yet. (*ptr).write(error); }); - // set the state to indicate that the sender closed the channel. - test_dbg!(self.dequeue_pos.fetch_or(CLOSED, Release)); - // notify any waiting senders that the channel is closed. - self.cons_wait.wake(); - test_println!("Core::close_tx_error: -> closed"); + + // set the ERRORED bit. + test_dbg!(self.state.fetch_or(state::ERRORED.first_bit(), Release)); + self.close_rx(); + self.close_tx(); + true } @@ -334,9 +326,7 @@ impl Core { debug_assert_eq!(_val >> state::TX_SHIFT, 0); // Now that we're after all other ref count ops, we can close the // channel itself. - test_dbg!(self.dequeue_pos.fetch_or(CLOSED, Release)); - self.cons_wait.close(); - + self.close_tx(); test_println!("Core::drop_tx -> closed"); } else { test_println!("Core::drop_tx -> tx refs remaining"); @@ -398,7 +388,7 @@ impl Core { let enqueue_pos = self.enqueue_pos.load(Acquire); if test_dbg!(enqueue_pos & CLOSED) == CLOSED { return Err(self - .send_closed_error() + .close_reason() .map(|error| TrySendError::Error { error, message: () }) .unwrap_or(TrySendError::Disconnected(()))); } @@ -417,7 +407,7 @@ impl Core { return Err(SendError::Error { error, message: () }) } Err(TrySendError::Full(())) => self.prod_wait.wait().await.map_err(|_| { - self.send_closed_error() + self.close_reason() .map(|error| SendError::Error { error, message: () }) .unwrap_or(SendError::Disconnected(())) })?, @@ -473,7 +463,7 @@ impl Core { match test_dbg!(dif).cmp(&0) { cmp::Ordering::Less if test_dbg!(head & CLOSED) != 0 => { return Err(self - .recv_close_error() + .close_reason() .map(TryRecvError::Error) .unwrap_or(TryRecvError::Disconnected)); } @@ -505,7 +495,7 @@ impl Core { loop { if test_dbg!(tail & CLOSED) == CLOSED { return Err(self - .send_closed_error() + .close_reason() .map(|error| SendError::Error { error, message: () }) .unwrap_or(SendError::Disconnected(()))); } @@ -538,27 +528,12 @@ impl Core { } } - fn send_closed_error(&self) -> Option { - let pos = self.enqueue_pos.load(Acquire); - debug_assert_eq!(pos & CLOSED, CLOSED); - if test_dbg!(pos & CLOSED_ERROR) == CLOSED_ERROR { - Some( - self.send_error - .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }), - ) - } else { - None - } - } - - fn recv_close_error(&self) -> Option { - let pos = self.dequeue_pos.load(Acquire); - debug_assert_eq!(pos & CLOSED, CLOSED); - if test_dbg!(pos & CLOSED_ERROR) == CLOSED_ERROR { - Some( - self.recv_error - .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }), - ) + fn close_reason(&self) -> Option { + if test_dbg!(state::ERRORED.contained_in_any(self.state.load(Acquire))) { + let error = self + .error + .with(|ptr| unsafe { (*ptr).assume_init_ref().clone() }); + Some(error) } else { None } diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index 295ded1..f935245 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -39,7 +39,6 @@ pub use self::arc_impl::*; /// [`StaticTrickyPipe::receiver`] and [`TrickyPipe::receiver`] methods. pub struct Receiver { pipe: TypedPipe, - closed_error: bool, } /// Sends `T`-typed values to an associated [`Receiver`]s or [`SerReceiver`]. @@ -59,7 +58,6 @@ pub struct Sender { pub struct SerReceiver { pipe: ErasedPipe, vtable: &'static SerVtable, - closed_error: bool, } /// Sends serialized values to an associated [`Receiver`] or [`SerReceiver`]. @@ -252,14 +250,8 @@ where /// This method returns `true` if the channel was successfully closed. If /// this channel has already been closed with an error, this method does /// nothing and returns `false`. - pub fn close_with_error(&mut self, error: E) -> bool { - if self.closed_error { - return false; - } - - self.pipe.core().close_rx_error(error); - - true + pub fn close_with_error(&self, error: E) -> bool { + self.pipe.core().close_with_error(error) } /// Returns `true` if this channel is empty. @@ -466,14 +458,8 @@ impl SerReceiver { /// This method returns `true` if the channel was successfully closed. If /// this channel has already been closed with an error, this method does /// nothing and returns `false`. - pub fn close_with_error(&mut self, error: E) -> bool { - if self.closed_error { - return false; - } - - self.pipe.core().close_rx_error(error); - - true + pub fn close_with_error(&self, error: E) -> bool { + self.pipe.core().close_with_error(error) } /// Returns `true` if this channel is empty. @@ -798,6 +784,16 @@ impl DeserSender { .send_framed(bytes) } + /// Close this channel with an error. Any subsequent attempts to send + /// messages to this channel will fail with `error`. + /// + /// This method returns `true` if the channel was successfully closed. If + /// this channel has already been closed with an error, this method does + /// nothing and returns `false`. + pub fn close_with_error(&self, error: E) -> bool { + self.pipe.core().close_with_error(error) + } + /// Returns `true` if this channel is empty. /// /// If this method returns `true`, calling [`Receiver::recv`] or @@ -1074,6 +1070,16 @@ impl Sender { Ok(Permit { cell, pipe }) } + /// Close this channel with an error. Any subsequent attempts to send + /// messages to this channel will fail with `error`. + /// + /// This method returns `true` if the channel was successfully closed. If + /// this channel has already been closed with an error, this method does + /// nothing and returns `false`. + pub fn close_with_error(&self, error: E) -> bool { + self.pipe.core().close_with_error(error) + } + /// Returns `true` if this channel is empty. /// /// If this method returns `true`, calling [`Receiver::recv`] or diff --git a/source/tricky-pipe/src/mpsc/static_impl.rs b/source/tricky-pipe/src/mpsc/static_impl.rs index 9a1d873..336e637 100644 --- a/source/tricky-pipe/src/mpsc/static_impl.rs +++ b/source/tricky-pipe/src/mpsc/static_impl.rs @@ -60,10 +60,7 @@ where pub fn receiver(&'static self) -> Option> { self.core.try_claim_rx()?; - Some(Receiver { - pipe: self.typed(), - closed_error: false, - }) + Some(Receiver { pipe: self.typed() }) } /// Obtain a [`Sender`] capable of sending `T`-typed data @@ -110,7 +107,6 @@ where Some(SerReceiver { pipe: self.erased(), vtable: Self::SER_VTABLE, - closed_error: false, }) } diff --git a/source/tricky-pipe/src/mpsc/tests.rs b/source/tricky-pipe/src/mpsc/tests.rs index 3e3c0dd..6aa9896 100644 --- a/source/tricky-pipe/src/mpsc/tests.rs +++ b/source/tricky-pipe/src/mpsc/tests.rs @@ -593,13 +593,13 @@ fn mpsc_send() { } #[test] -fn close_error_simple() { +fn rx_closes_error() { const CAPACITY: u8 = 2; loom::model(|| { let chan = TrickyPipe::, &'static str>::new(CAPACITY); - let mut rx = test_dbg!(chan.receiver()).expect("can't get rx"); + let rx = test_dbg!(chan.receiver()).expect("can't get rx"); let tx = chan.sender(); rx.close_with_error("fake rx error"); @@ -618,6 +618,38 @@ fn close_error_simple() { }) } +#[test] +fn tx_closes_error() { + const CAPACITY: u8 = 2; + + loom::model(|| { + let chan = TrickyPipe::, &'static str>::new(CAPACITY); + + let rx = test_dbg!(chan.receiver()).expect("can't get rx"); + let tx1 = chan.sender(); + let tx2 = chan.sender(); + + let t1 = thread::spawn(move || { + tx1.close_with_error("fake tx1 error"); + }); + + let t2 = thread::spawn(move || { + tx2.close_with_error("fake tx2 error"); + }); + + future::block_on(async move { + let err = test_dbg!(rx.recv().await).unwrap_err(); + assert!(matches!( + err, + RecvError::Error("fake tx1 error") | RecvError::Error("fake tx2 error") + )) + }); + + t1.join().unwrap(); + t2.join().unwrap(); + }) +} + fn do_tx( sends: usize, offset: usize, From 3638747e7f2c4c6e6be54d0c7f3800a71035b4ee Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 09:15:49 -0800 Subject: [PATCH 14/21] document error types --- source/tricky-pipe/src/mpsc/error.rs | 115 ++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 12 deletions(-) diff --git a/source/tricky-pipe/src/mpsc/error.rs b/source/tricky-pipe/src/mpsc/error.rs index e95c997..9cdc283 100644 --- a/source/tricky-pipe/src/mpsc/error.rs +++ b/source/tricky-pipe/src/mpsc/error.rs @@ -7,29 +7,47 @@ //! [`SerReceiver`]: super::SerReceiver use core::fmt; -/// A message cannot be sent because the channel is closed (no [`Receiver`] -/// or [`SerReceiver`] exists). +/// Errors returned by [`Sender::send`], [`Sender::reserve`], and +/// [`DeserSender::reserve`], indicating that a message cannot be sent to the +/// channel. /// /// A `SendError<()>` is returned by the [`Sender::reserve`] and /// [`DeserSender::reserve`] methods. The [`Sender::send`] method instead returns /// a `SendError`, from which the original message can be recovered using /// [`SendError::into_inner`]. /// +/// Both the [`Disconnected`](Self::Disconnected) and [`Error`](Self::Error) +/// variants of this error indicate that the channel is *closed*: once a +/// `SendError` is returned, no future attempts to send a message to this +/// channel (i.e. [`Sender::send`], [`Sender::reserve`], +/// [`DeserSender::reserve`], [`DeserSender::send`], etc.) will ever +/// succeed. +/// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver +/// [`Sender::send`]: super::Sender::send /// [`Sender::reserve`]: super::Sender::reserve /// [`DeserSender::reserve`]: super::DeserSender::reserve -/// [`Sender::send`]: super::Sender::send +/// [`DeserSender::send`]: super::DeserSender::send #[derive(Eq, PartialEq)] pub enum SendError { - /// A message cannot be sent because the channel is closed (no [`Receiver`] - /// or [`SerReceiver`] exists). + /// A message cannot be sent because no [`Receiver`] or [`SerReceiver`] + /// exists to receive the message. /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver Disconnected(T), + + /// A message could not be sent because this channel was closed with an + /// error, by the [`Receiver::close_with_error`] or + /// [`SerReceiver::close_with_error`] methods. + /// + /// [`Receiver::close_with_error`]: super::Receiver::close_with_error + /// [`SerReceiver::close_with_error`]: super::Receiver::close_with_error Error { + /// The message that the sender was attempting to send. message: T, + /// The error set when the channel closed. error: E, }, } @@ -50,14 +68,24 @@ pub enum TrySendError { /// The channel is currently full, and a message cannot be sent without /// waiting for a slot to become available. Full(T), - /// A message cannot be sent because the channel is closed (no [`Receiver`] - /// or [`SerReceiver`] exists). + + /// A message cannot be sent because no [`Receiver`] or [`SerReceiver`] + /// exists. /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver Disconnected(T), + + /// A message could not be sent because this channel was closed with an + /// error, by the [`Receiver::close_with_error`] or + /// [`SerReceiver::close_with_error`] methods. + /// + /// [`Receiver::close_with_error`]: super::Receiver::close_with_error + /// [`SerReceiver::close_with_error`]: super::Receiver::close_with_error Error { + /// The message that the sender was attempting to send. message: T, + /// The error set when the channel closed. error: E, }, } @@ -71,14 +99,35 @@ pub enum TryRecvError { /// No messages are currently present in the channel. The receiver must wait /// for an additional message to be sent. Empty, - /// A message cannot be received because the channel is closed. + + /// A message cannot be received because tno [`Sender`] or [`DeserSender`] + /// exists. /// /// This indicates that no [`Sender`]s or [`DeserSender`]s exist, and all /// previously sent messages have already been received. /// + /// If this variant is returned, the channel is permanently closed, and no + /// subsequent calls to [`Receiver::try_recv`] or [`SerReceiver::try_recv`] + /// will succeed on this channel. + /// /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender + /// [`Receiver::try_recv`]: Receiver::try_recv + /// [`SerReceiver::try_recv`]: SerReceiver::try_recv Disconnected, + + /// A message could not be sent because this channel was closed with an + /// error, by the [`Sender::close_with_error`] or + /// [`DeserSender::close_with_error`] methods. + /// + /// If this variant is returned, the channel is permanently closed, and no + /// subsequent calls to [`Receiver::try_recv`] or [`SerReceiver::try_recv`] + /// will succeed on this channel. + /// + /// [`Sender::close_with_error`]: super::Sender::close_with_error + /// [`DeserSender::close_with_error`]: super::DeserSender::close_with_error + /// [`Receiver::try_recv`]: Receiver::try_recv + /// [`SerReceiver::try_recv`]: SerReceiver::try_recv Error(E), } @@ -96,6 +145,13 @@ pub enum RecvError { /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender Disconnected, + + /// A message could not be sent because this channel was closed with an + /// error, by the [`Receiver::close_with_error`] or + /// [`SerReceiver::close_with_error`] methods. + /// + /// [`Receiver::close_with_error`]: super::Receiver::close_with_error + /// [`SerReceiver::close_with_error`]: super::Receiver::close_with_error Error(E), } @@ -105,15 +161,28 @@ pub enum RecvError { /// [`DeserSender::send_framed`]: super::DeserSender::send_framed #[derive(Debug, Eq, PartialEq)] pub enum SerSendError { - /// A message cannot be sent because the channel is closed (no [`Receiver`] - /// or [`SerReceiver`] exists). + /// A message cannot be sent because no [`Receiver`] or [`SerReceiver`] exists. /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver Disconnected, + /// The sent bytes could not be deserialized to a value of this channel's /// message type. Deserialize(postcard::Error), + + /// A message could not be sent because this channel was closed with an + /// error, by the [`Receiver::close_with_error`] or + /// [`SerReceiver::close_with_error`] methods. + /// + /// If this variant is returned, the channel is permanently closed, and no + /// subsequent calls to [`DeserSender::send`] and + /// [`DeserSender::send_framed`] will succeed on this channel. + /// + /// [`Receiver::close_with_error`]: super::Receiver::close_with_error + /// [`SerReceiver::close_with_error`]: super::Receiver::close_with_error + /// [`DeserSender::send`]: super::DeserSender::send + /// [`DeserSender::send_framed`]: super::DeserSender::send_framed Error(E), } @@ -124,17 +193,39 @@ pub enum SerSendError { /// [`DeserSender::try_send_framed`]: super::DeserSender::send_framed #[derive(Debug, Eq, PartialEq)] pub enum SerTrySendError { - /// A message cannot be sent because the channel is closed (no [`Receiver`] - /// or [`SerReceiver`] exists). + /// A message cannot be sent because no [`Receiver`] or [`SerReceiver`] + /// exists. + /// + /// If this variant is returned, the channel is permanently closed, and no + /// subsequent calls to [`DeserSender::try_send`] and + /// [`DeserSender::try_send_framed`] will succeed on this channel. /// /// [`Receiver`]: super::Receiver /// [`SerReceiver`]: super::SerReceiver + /// [`DeserSender::try_send`]: super::DeserSender::try_send + /// [`DeserSender::try_send_framed`]: super::DeserSender::try_send_framed Disconnected, + + /// The channel is currently full, and a message cannot be sent without + /// waiting for a slot to become available. Full, + /// The sent bytes could not be deserialized to a value of this channel's /// message type. Deserialize(postcard::Error), + /// A message could not be sent because this channel was closed with an + /// error, by the [`Receiver::close_with_error`] or + /// [`SerReceiver::close_with_error`] methods. + /// + /// If this variant is returned, the channel is permanently closed, and no + /// subsequent calls to [`DeserSender::try_send`] and + /// [`DeserSender::try_send_framed`] will succeed on this channel. + /// + /// [`Receiver::close_with_error`]: super::Receiver::close_with_error + /// [`SerReceiver::close_with_error`]: super::Receiver::close_with_error + /// [`DeserSender::try_send`]: super::DeserSender::try_send + /// [`DeserSender::try_send_framed`]: super::DeserSender::try_send_framed Error(E), } From dc208f9d75cd7ca658afcc31cb262c47bd52a46b Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 09:39:24 -0800 Subject: [PATCH 15/21] bidi re-engoodening --- source/mgnp/src/conn_table.rs | 3 +- source/tricky-pipe/src/bidi.rs | 208 ++++++++++++++++++++++++----- source/tricky-pipe/src/mpsc/mod.rs | 11 +- 3 files changed, 180 insertions(+), 42 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index 03533b9..5ddc98c 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -398,8 +398,7 @@ impl ConnTable { return false; } - let (_, mut rx) = channel.split(); - rx.close_with_error(reason) + channel.close_with_error(reason) } /// Returns `true` if a connection with the provided ID was closed, `false` if diff --git a/source/tricky-pipe/src/bidi.rs b/source/tricky-pipe/src/bidi.rs index f46f584..576e9b0 100644 --- a/source/tricky-pipe/src/bidi.rs +++ b/source/tricky-pipe/src/bidi.rs @@ -4,7 +4,10 @@ //! [`Sender`] and [`Receiver`] or a [`DeserSender`] and [`SerReceiver`] //! (respectively) into a single bidirectional channel which can both send and //! receive messages to/from a remote peer. -use crate::mpsc::*; +use crate::mpsc::{ + error::{RecvError, SendError}, + *, +}; use core::fmt; use futures::FutureExt; @@ -16,6 +19,8 @@ use futures::FutureExt; pub struct BiDi { tx: Sender, rx: Receiver, + seen_rx_error: bool, + seen_tx_error: bool, } /// A bidirectional type-erased serializing channel. @@ -27,6 +32,8 @@ pub struct BiDi { pub struct SerBiDi { tx: DeserSender, rx: SerReceiver, + seen_rx_error: bool, + seen_tx_error: bool, } /// Events returned by [`BiDi::wait`] and [`SerBiDi::wait`]. @@ -39,6 +46,22 @@ pub enum Event { SendReady(Out), } +/// [`Result`]s returned by [`BiDi::wait`] and [`SerBiDi::wait`]. +pub type WaitResult = Result>; + +/// Errors returned by [`BiDi::wait`] and [`SerBiDi::wait`]. +#[derive(Debug)] +pub enum WaitError { + /// The receive side of the channel has been closed with an error. + Recv(E), + /// The send side of the channel has been closed with an error. + Send(E), + /// Both the send and receive sides are disconnected (all corresponding + /// [`Sender`]/[`DeserSender`]s and the corresponding ([`Receiver`] or + /// [`SerReceiver`] have been dropped). + Disconnected, +} + impl BiDi where In: 'static, @@ -47,7 +70,12 @@ where { /// Constructs a new `BiDi` from a [`Sender`] and a [`Receiver`]. pub fn from_pair(tx: Sender, rx: Receiver) -> Self { - Self { tx, rx } + Self { + tx, + rx, + seen_rx_error: false, + seen_tx_error: false, + } } /// Consumes `self`, extracting the inner [`Sender`] and [`Receiver`]. @@ -56,22 +84,55 @@ where (self.tx, self.rx) } - // /// Wait until the channel is either ready to send a message *or* a new - // /// incoming message is received, whichever occurs first. - // #[must_use] - // pub async fn wait(&self) -> Option>> { - // futures::select_biased! { - // res = self.tx.reserve().fuse() => { - // match res { - // Ok(permit) => Some(Event::SendReady(permit)), - // Err(_) => self.rx.recv().await.map(Event::Recv), - // } - // } - // recv = self.rx.recv().fuse() => { - // recv.map(Event::Recv) - // } - // } - // } + /// Wait until the channel is either ready to send a message *or* a new + /// incoming message is received, whichever occurs first. + #[must_use] + pub async fn wait(&mut self) -> WaitResult>, E> { + futures::select_biased! { + reserve = self.tx.reserve().fuse() => { + match reserve { + Ok(permit) => Ok(Event::SendReady(permit)), + // If the send channel has closed with an error, return it + // immediately *if we haven't returned it already*. If we + // *have* returned that error previously, fall through and + // try a recv. + Err(SendError::Error{ error, .. }) if !self.seen_tx_error => { + self.seen_tx_error = true; + Err(WaitError::Send(error)) + } + Err(_) => self.rx.recv().await.map(Event::Recv).map_err(|error| match error { + // both sides have disconnected + RecvError::Disconnected => WaitError::Disconnected, + RecvError::Error(e) => { + self.seen_rx_error = true; + WaitError::Recv(e) + } + }), + } + } + recv = self.rx.recv().fuse() => { + match recv { + Ok(msg) => Ok(Event::Recv(msg)), + // If the recv channel has closed with an error, return it + // immediately *if we haven't returned it already*. If we + // *have* returned that error previously, fall through and + // try a send. + Err(RecvError::Error(e)) if !self.seen_rx_error => { + self.seen_rx_error = true; + Err(WaitError::Recv(e)) + } + Err(_) => self.tx.reserve().await.map(Event::SendReady).map_err(|error| match error { + // both sides have disconnected + SendError::Disconnected(()) => WaitError::Disconnected, + SendError::Error { error, .. } => { + self.seen_tx_error = true; + WaitError::Send(error) + } + }), + } + } + } + } /// Borrows the **send half** of this bidirectional channel. /// @@ -93,6 +154,17 @@ where &self.rx } + /// Closes both sides of this channel with an error. + /// + /// Returns `true` if *either* side of the channel was closed by this error. + /// If both sides of the channel have already closed, this method returns + /// `false`. + pub fn close_with_error(&self, error: E) -> bool { + let tx_closed = self.tx.close_with_error(error.clone()); + let rx_closed = self.rx.close_with_error(error); + rx_closed || tx_closed + } + /// Returns `true` if **both halves** of this bidirectional channel are /// empty. /// @@ -129,10 +201,17 @@ where E: 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { tx, rx } = self; + let Self { + tx, + rx, + seen_rx_error, + seen_tx_error, + } = self; f.debug_struct("BiDi") .field("tx", tx) .field("rx", rx) + .field("seen_rx_error", seen_rx_error) + .field("seen_tx_error", seen_tx_error) .finish() } } @@ -142,7 +221,12 @@ where impl SerBiDi { /// Constructs a new `SerBiDi` from a [`DeserSender`] and a [`SerReceiver`]. pub fn from_pair(tx: DeserSender, rx: SerReceiver) -> Self { - Self { tx, rx } + Self { + tx, + rx, + seen_rx_error: false, + seen_tx_error: false, + } } /// Consumes `self`, extracting the inner [`DeserSender`] and [`SerReceiver`]. @@ -151,22 +235,65 @@ impl SerBiDi { (self.tx, self.rx) } - // /// Wait until the channel is either ready to send a message *or* a new - // /// incoming message is received, whichever occurs first. - // #[must_use] - // pub async fn wait(&self) -> Option, SerPermit<'_, ()>>> { - // futures::select_biased! { - // res = self.tx.reserve().fuse() => { - // match res { - // Ok(permit) => Some(Event::SendReady(permit)), - // Err(_) => self.rx.recv().await.map(Event::Recv), - // } - // } - // recv = self.rx.recv().fuse() => { - // recv.map(Event::Recv) - // } - // } - // } + /// Wait until the channel is either ready to send a message *or* a new + /// incoming message is received, whichever occurs first. + pub async fn wait(&mut self) -> WaitResult, SerPermit<'_, E>>, E> { + futures::select_biased! { + reserve = self.tx.reserve().fuse() => { + match reserve { + Ok(permit) => Ok(Event::SendReady(permit)), + // If the send channel has closed with an error, return it + // immediately *if we haven't returned it already*. If we + // *have* returned that error previously, fall through and + // try a recv. + Err(SendError::Error{ error, .. }) if !self.seen_tx_error => { + self.seen_tx_error = true; + Err(WaitError::Send(error)) + } + Err(_) => self.rx.recv().await.map(Event::Recv).map_err(|error| match error { + // both sides have disconnected + RecvError::Disconnected => WaitError::Disconnected, + RecvError::Error(e) => { + self.seen_rx_error = true; + WaitError::Recv(e) + } + }), + } + } + recv = self.rx.recv().fuse() => { + match recv { + Ok(msg) => Ok(Event::Recv(msg)), + // If the recv channel has closed with an error, return it + // immediately *if we haven't returned it already*. If we + // *have* returned that error previously, fall through and + // try a send. + Err(RecvError::Error(e)) if !self.seen_rx_error => { + self.seen_rx_error = true; + Err(WaitError::Recv(e)) + } + Err(_) => self.tx.reserve().await.map(Event::SendReady).map_err(|error| match error { + // both sides have disconnected + SendError::Disconnected(()) => WaitError::Disconnected, + SendError::Error { error, .. } => { + self.seen_tx_error = true; + WaitError::Send(error) + } + }), + } + } + } + } + + /// Closes both sides of this channel with an error. + /// + /// Returns `true` if *either* side of the channel was closed by this error. + /// If both sides of the channel have already closed, this method returns + /// `false`. + pub fn close_with_error(&self, error: E) -> bool { + let tx_closed = self.tx.close_with_error(error.clone()); + let rx_closed = self.rx.close_with_error(error); + rx_closed || tx_closed + } /// Borrows the **send half** of this bidirectional channel. /// @@ -219,10 +346,17 @@ impl SerBiDi { impl fmt::Debug for SerBiDi { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { tx, rx } = self; + let Self { + tx, + rx, + seen_rx_error, + seen_tx_error, + } = self; f.debug_struct("SerBiDi") .field("tx", tx) .field("rx", rx) + .field("seen_rx_error", seen_rx_error) + .field("seen_tx_error", seen_tx_error) .finish() } } diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index f935245..b61ec28 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -1162,9 +1162,13 @@ impl fmt::Debug for Sender { // === impl Permit === impl Permit<'_, T, E> { - /// Write the given value into the [Permit], and send it + /// Write the given value into the [Permit], and send it. /// /// This makes the data available to the [Receiver]. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. pub fn send(self, val: T) { // write the value... unsafe { @@ -1173,7 +1177,7 @@ impl Permit<'_, T, E> { self.cell.deref().write(val); // ...and commit. - self.commit(); + self.commit() } } @@ -1197,7 +1201,8 @@ impl Permit<'_, T, E> { #[cfg_attr(not(loom), allow(clippy::drop_non_drop))] drop(self.cell); - self.pipe.commit_send(); + // ignore errors here because capacity is already reserved. + let _ = self.pipe.commit_send(); } } From 0e766e641d32161172b5a833b0632f78c6671e09 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 10:11:02 -0800 Subject: [PATCH 16/21] test cleanup --- source/mgnp/src/tests/integration.rs | 152 ++++++++++++--------------- 1 file changed, 69 insertions(+), 83 deletions(-) diff --git a/source/mgnp/src/tests/integration.rs b/source/mgnp/src/tests/integration.rs index 1a7e81e..dc5d0e6 100644 --- a/source/mgnp/src/tests/integration.rs +++ b/source/mgnp/src/tests/integration.rs @@ -1,7 +1,7 @@ use super::*; use crate::{ - message::{self, InboundFrame, OutboundFrame}, - Wire, + message::{self, DecodeError, DecodeErrorKind, Header, InboundFrame, OutboundFrame, Reset}, + Id, Wire, }; use tricky_pipe::serbox; @@ -16,7 +16,7 @@ async fn reset_decode_error() { let hello = hellobox.share(()).await; wire.send(OutboundFrame::connect( - crate::Id::new(1), + Id::new(1), svcs::hello_world_id(), hello, )) @@ -28,17 +28,17 @@ async fn reset_decode_error() { assert_eq!( msg, Ok(InboundFrame { - header: message::Header::Ack { - local_id: crate::Id::new(1), - remote_id: crate::Id::new(1), + header: Header::Ack { + local_id: Id::new(1), + remote_id: Id::new(1), }, body: &[] }) ); - let mut out_frame = postcard::to_allocvec(&message::Header::Data { - local_id: crate::Id::new(1), - remote_id: crate::Id::new(1), + let mut out_frame = postcard::to_allocvec(&Header::Data { + local_id: Id::new(1), + remote_id: Id::new(1), }) .unwrap(); out_frame.extend(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]); @@ -46,18 +46,15 @@ async fn reset_decode_error() { wire.send_bytes(out_frame).await.unwrap(); let frame = wire.recv().await.unwrap(); - let msg = InboundFrame::from_bytes(&frame[..]); - assert_eq!( - msg, - Ok(InboundFrame { - header: message::Header::Reset { - remote_id: crate::Id::new(1), - reason: message::Reset::YouDoneGoofed(message::DecodeError::Body( - message::DecodeErrorKind::UnexpectedEnd - )) + expect_inbound_frame( + frame, + InboundFrame { + header: Header::Reset { + remote_id: Id::new(1), + reason: Reset::YouDoneGoofed(DecodeError::Body(DecodeErrorKind::UnexpectedEnd)), }, - body: &[] - }) + body: &[], + }, ); } @@ -72,7 +69,7 @@ async fn reset_no_such_conn() { let hello = hellobox.share(()).await; wire.send(OutboundFrame::connect( - crate::Id::new(1), + Id::new(1), svcs::hello_world_id(), hello, )) @@ -84,9 +81,9 @@ async fn reset_no_such_conn() { assert_eq!( msg, Ok(InboundFrame { - header: message::Header::Ack { - local_id: crate::Id::new(1), - remote_id: crate::Id::new(1), + header: Header::Ack { + local_id: Id::new(1), + remote_id: Id::new(1), }, body: &[] }) @@ -95,25 +92,27 @@ async fn reset_no_such_conn() { let chan = tricky_pipe::mpsc::TrickyPipe::new(8); let rx = chan.ser_receiver().unwrap(); let tx = chan.sender(); - tx.try_send(svcs::HelloWorldRequest { - hello: "hello".into(), - }) - .unwrap(); - - let body = rx.try_recv().unwrap(); - let out_frame = { + let data_frame = |header: Header| { + tx.try_send(svcs::HelloWorldRequest { + hello: "hello".into(), + }) + .expect("send should just work"); + let body = rx.try_recv().expect("recv should just work"); let frame = OutboundFrame { - header: message::Header::Data { - local_id: crate::Id::new(1), - remote_id: crate::Id::new(1), // good conn ID - }, + header, body: message::OutboundData::Data(body), }; - frame.to_vec().unwrap() + tracing::info!(frame = %format_args!("{frame:#?}"), "OUTBOUND FRAME"); + frame.to_vec().expect("frame must serialize") }; - wire.send_bytes(out_frame).await.unwrap(); + wire.send_bytes(data_frame(Header::Data { + local_id: Id::new(1), // known good ID + remote_id: Id::new(1), + })) + .await + .unwrap(); let frame = wire.recv().await.unwrap(); let msg = InboundFrame::from_bytes(&frame[..]).unwrap(); @@ -125,62 +124,49 @@ async fn reset_no_such_conn() { ); // another message, with a bad conn ID - tx.try_send(svcs::HelloWorldRequest { - hello: "hello".into(), - }) - .unwrap(); - let body = rx.try_recv().unwrap(); - let out_frame = OutboundFrame { - header: message::Header::Data { - remote_id: crate::Id::new(666), // bad conn ID - local_id: crate::Id::new(1), - }, - body: message::OutboundData::Data(body), - } - .to_vec() + wire.send_bytes(data_frame(Header::Data { + local_id: Id::new(1), + remote_id: Id::new(666), // bad conn ID + })) + .await .unwrap(); - wire.send_bytes(out_frame).await.unwrap(); let frame = wire.recv().await.unwrap(); - let msg = dbg!(InboundFrame::from_bytes(&frame[..])); - assert_eq!( - msg, - Ok(InboundFrame { - header: message::Header::Reset { - remote_id: crate::Id::new(1), - reason: message::Reset::NoSuchConn, + expect_inbound_frame( + frame, + InboundFrame { + header: Header::Reset { + remote_id: Id::new(1), + reason: Reset::NoSuchConn, }, - body: &[] - }) + body: &[], + }, ); // another message, with a differently conn ID - tx.try_send(svcs::HelloWorldRequest { - hello: "hello".into(), - }) - .unwrap(); - let body = rx.try_recv().unwrap(); - let out_frame = OutboundFrame { - header: message::Header::Data { - remote_id: crate::Id::new(1), - local_id: crate::Id::new(666), // bad conn ID - }, - body: message::OutboundData::Data(body), - } - .to_vec() + wire.send_bytes(data_frame(Header::Data { + local_id: Id::new(666), // bad conn ID + remote_id: Id::new(1), + })) + .await .unwrap(); - wire.send_bytes(out_frame).await.unwrap(); let frame = wire.recv().await.unwrap(); - let msg = dbg!(InboundFrame::from_bytes(&frame[..])); - assert_eq!( - msg, - Ok(InboundFrame { - header: message::Header::Reset { - remote_id: crate::Id::new(666), - reason: message::Reset::NoSuchConn, + expect_inbound_frame( + frame, + InboundFrame { + header: Header::Reset { + remote_id: Id::new(666), + reason: Reset::NoSuchConn, }, - body: &[] - }) + body: &[], + }, ); } + +#[track_caller] +fn expect_inbound_frame(frame: impl AsRef<[u8]>, expected: InboundFrame<'_>) { + let decoded = InboundFrame::from_bytes(frame.as_ref()); + tracing::info!(frame = %format_args!("{decoded:#?}"), "INBOUND FRAME"); + assert_eq!(decoded, Ok(expected)); +} From 5a2da381a276e34476eee4acb7b5d9eeaea30ea3 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 10:11:42 -0800 Subject: [PATCH 17/21] cleanup --- source/mgnp/src/conn_table.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index 5ddc98c..722f83f 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -6,10 +6,7 @@ use crate::{ use core::{fmt, mem, num::NonZeroU16, task::Poll}; use tricky_pipe::{ bidi::SerBiDi, - mpsc::{ - error::{RecvError, SendError, SerSendError}, - SerPermit, - }, + mpsc::error::{RecvError, SerSendError}, oneshot, }; From 9e6f47c3cf01a6bd9a559a6f6642d409b5d1cf75 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 10:55:37 -0800 Subject: [PATCH 18/21] add test that a service shutting down sends reset --- source/mgnp/src/tests/e2e.rs | 114 +++++++++++++++++++++++++++-------- source/mgnp/src/tests/mod.rs | 70 ++++++++++++++++----- 2 files changed, 146 insertions(+), 38 deletions(-) diff --git a/source/mgnp/src/tests/e2e.rs b/source/mgnp/src/tests/e2e.rs index be151a8..7882e8e 100644 --- a/source/mgnp/src/tests/e2e.rs +++ b/source/mgnp/src/tests/e2e.rs @@ -15,9 +15,7 @@ async fn basically_works() { let chan = connect(&mut connector, "hello-world", ()).await; chan.tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); let rsp = chan.rx().recv().await; @@ -52,9 +50,7 @@ async fn hellos_work() { .await; chan.tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); let rsp = chan.rx().recv().await; @@ -102,9 +98,7 @@ async fn nak_bad_hello() { // the good connection should stil lwork chan.tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); let rsp = chan.rx().recv().await; @@ -134,12 +128,8 @@ async fn mux_single_service() { let chan2 = connect(&mut connector, "hello-world", ()).await; tokio::try_join! { - chan1.tx().send(HelloWorldRequest { - hello: "hello".to_string(), - }), - chan2.tx().send(HelloWorldRequest { - hello: "hello".to_string(), - }) + chan1.tx().send(svcs::hello_req("hello")), + chan2.tx().send(svcs::hello_req("hello")) } .expect("send should work"); @@ -218,9 +208,7 @@ async fn service_type_routing() { helloworld_chan .tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); let rsp = helloworld_chan.rx().recv().await; @@ -245,17 +233,13 @@ async fn service_type_routing() { hellohello_chan .tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); helloworld_chan .tx() - .send(HelloWorldRequest { - hello: "hello".to_string(), - }) + .send(svcs::hello_req("hello")) .await .expect("send request"); @@ -355,3 +339,85 @@ async fn service_identity_routing() { }) ); } + +#[tokio::test] +async fn reset_closed() { + let remote_registry: TestRegistry = TestRegistry::default(); + let conns = remote_registry.add_service(svcs::hello_world_id()); + let shutdown = Arc::new(tokio::sync::Notify::new()); + + tokio::spawn(svcs::serve_hello_with_shutdown( + "hello", + "world", + conns, + shutdown.clone(), + )); + + let fixture = Fixture::new() + .spawn_local(Default::default()) + .spawn_remote(remote_registry); + + let mut connector = fixture.local_iface().connector::(); + + let chan1 = connect(&mut connector, "hello-world", ()).await; + + let chan2 = connect(&mut connector, "hello-world", ()).await; + + tokio::try_join! { + chan1.tx().send(svcs::hello_req("hello")), + chan2.tx().send(svcs::hello_req("hello")) + } + .expect("send should work"); + + let (rsp1, rsp2) = tokio::join! { + chan1.rx().recv(), + chan2.rx().recv(), + }; + + assert_eq!( + rsp1, + Ok(HelloWorldResponse { + world: "world".to_string() + }) + ); + assert_eq!( + rsp2, + Ok(HelloWorldResponse { + world: "world".to_string() + }) + ); + + // now shut down the remote service + tracing::info!(""); + tracing::info!("!!! shutting down remote service !!!"); + tracing::info!(""); + shutdown.notify_waiters(); + + let send2 = tokio::join! { + chan1.tx().send(svcs::hello_req("hello")), + chan2.tx().send(svcs::hello_req("hello")) + }; + + let _ = dbg!(send2); + + let (rsp1, rsp2) = tokio::join! { + chan1.rx().recv(), + chan2.rx().recv(), + }; + + assert_eq!( + dbg!(rsp1), + Err(tricky_pipe::mpsc::error::RecvError::Error( + Reset::BecauseISaidSo + )) + ); + + assert_eq!( + dbg!(rsp2), + Err(tricky_pipe::mpsc::error::RecvError::Error( + Reset::BecauseISaidSo + )) + ); + + fixture.finish_test().await; +} diff --git a/source/mgnp/src/tests/mod.rs b/source/mgnp/src/tests/mod.rs index c087d12..f7332b9 100644 --- a/source/mgnp/src/tests/mod.rs +++ b/source/mgnp/src/tests/mod.rs @@ -10,6 +10,7 @@ use crate::{ use std::{ collections::HashMap, fmt, + future::Future, sync::{Arc, RwLock}, }; use tokio::sync::{mpsc, oneshot, Notify}; @@ -66,11 +67,28 @@ pub(crate) mod svcs { registry::Identity::from_name::("hello-world") } + pub fn hello_req(hello: impl ToString) -> HelloWorldRequest { + HelloWorldRequest { + hello: hello.to_string(), + } + } + #[tracing::instrument(level = tracing::Level::INFO, skip(conns))] pub async fn serve_hello( + req_msg: &'static str, + rsp_msg: &'static str, + conns: mpsc::Receiver, + ) { + let shutdown = Arc::new(tokio::sync::Notify::new()); + serve_hello_with_shutdown(req_msg, rsp_msg, conns, shutdown).await + } + + #[tracing::instrument(level = tracing::Level::INFO, skip(conns, shutdown))] + pub async fn serve_hello_with_shutdown( req_msg: &'static str, rsp_msg: &'static str, mut conns: mpsc::Receiver, + shutdown: Arc, ) { let mut worker = 1; while let Some(req) = conns.recv().await { @@ -78,30 +96,47 @@ pub(crate) mod svcs { tracing::info!(?hello, "hello world service received connection"); let (their_chan, my_chan) = make_bidis::(8); - tokio::spawn(hello_worker(worker, req_msg, rsp_msg, my_chan)); + tokio::spawn(hello_worker( + worker, + req_msg, + rsp_msg, + my_chan, + shutdown.clone(), + )); worker += 1; let sent = rsp.send(Ok(their_chan)).is_ok(); tracing::debug!(?sent); } } - #[tracing::instrument(level = tracing::Level::INFO, skip(chan))] - pub(super) async fn hello_worker( + #[tracing::instrument(level = tracing::Level::INFO, skip(chan, shutdown))] + pub async fn hello_worker( worker: usize, req_msg: &'static str, rsp_msg: &'static str, chan: BiDi, + shutdown: Arc, ) { tracing::debug!("hello world worker {worker} running..."); - while let Ok(req) = chan.rx().recv().await { - tracing::info!(?req); - assert_eq!(req.hello, req_msg); - chan.tx() - .send(svcs::HelloWorldResponse { - world: rsp_msg.into(), - }) - .await - .unwrap(); + loop { + tokio::select! { + biased; + _ = shutdown.notified() => { + tracing::info!("worker shutting down!"); + return; + } + req = chan.rx().recv() => { + tracing::info!(?req); + assert_eq!(req.expect("request should be Ok").hello, req_msg); + chan.tx() + .send(svcs::HelloWorldResponse { + world: rsp_msg.into(), + }) + .await + .unwrap(); + } + + } } } } @@ -144,7 +179,7 @@ impl Fixture { registry, TrickyPipe::new(8), ); - let task = tokio::spawn(interface("name", machine, test_done.clone())); + let task = tokio::spawn(interface(name, machine, test_done.clone())); (iface, task) } } @@ -354,6 +389,7 @@ impl TestRegistry { let mut chan = self.add_service(svcs::hello_with_hello_id()); tokio::spawn( async move { + let shutdown = Arc::new(Notify::new()); let mut worker = 1; while let Some(req) = chan.recv().await { let InboundConnect { hello, rsp } = req; @@ -364,7 +400,13 @@ impl TestRegistry { tracing::info!(?hello, "hellohello service received hello"); let (their_chan, my_chan) = make_bidis::(8); - tokio::spawn(svcs::hello_worker(worker, "hello", "world", my_chan)); + tokio::spawn(svcs::hello_worker( + worker, + "hello", + "world", + my_chan, + shutdown.clone(), + )); worker += 1; Ok(their_chan) } else { From efe051edb55ce8d645eb067fb4c374f149ff83b8 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 12:07:22 -0800 Subject: [PATCH 19/21] docs fixy-uppy --- source/tricky-pipe/src/mpsc/error.rs | 8 +- source/tricky-pipe/src/mpsc/mod.rs | 151 ++++++++++++++++++++------- 2 files changed, 118 insertions(+), 41 deletions(-) diff --git a/source/tricky-pipe/src/mpsc/error.rs b/source/tricky-pipe/src/mpsc/error.rs index 9cdc283..5ca04e8 100644 --- a/source/tricky-pipe/src/mpsc/error.rs +++ b/source/tricky-pipe/src/mpsc/error.rs @@ -112,8 +112,8 @@ pub enum TryRecvError { /// /// [`Sender`]: super::Sender /// [`DeserSender`]: super::DeserSender - /// [`Receiver::try_recv`]: Receiver::try_recv - /// [`SerReceiver::try_recv`]: SerReceiver::try_recv + /// [`Receiver::try_recv`]: super::Receiver::try_recv + /// [`SerReceiver::try_recv`]: super::SerReceiver::try_recv Disconnected, /// A message could not be sent because this channel was closed with an @@ -126,8 +126,8 @@ pub enum TryRecvError { /// /// [`Sender::close_with_error`]: super::Sender::close_with_error /// [`DeserSender::close_with_error`]: super::DeserSender::close_with_error - /// [`Receiver::try_recv`]: Receiver::try_recv - /// [`SerReceiver::try_recv`]: SerReceiver::try_recv + /// [`Receiver::try_recv`]: super::Receiver::try_recv + /// [`SerReceiver::try_recv`]: super::SerReceiver::try_recv Error(E), } diff --git a/source/tricky-pipe/src/mpsc/mod.rs b/source/tricky-pipe/src/mpsc/mod.rs index b61ec28..c7696ad 100644 --- a/source/tricky-pipe/src/mpsc/mod.rs +++ b/source/tricky-pipe/src/mpsc/mod.rs @@ -184,9 +184,12 @@ where /// # Returns /// /// - [`Ok`]`(T)` if a message was received from the channel. - /// - [`Err`]`(`[`TryRecvError::Closed`]`) if the channel has been closed - /// (all [`Sender`]s and [`DeserSender`]s have been dropped) *and* all - /// messages sent before the channel closed have already been received. + /// - [`Err`]`(`[`TryRecvError::Disconnected`]`)` if all [`Sender`]s and + /// [`DeserSender`]s have been dropped *and* all messages sent before the + /// channel closed have already been received. + /// - [`Err`]`(`[`TryRecvError::Error`]`(E)` if the channel has been closed + /// with an error using the [`Receiver::close_with_error`] or + /// [`Sender::close_with_error`] methods. /// - [`Err`]`(`[`TryRecvError::Empty`]`)` if there are currently no /// messages in the queue, but the channel has not been closed. pub fn try_recv(&self) -> Result> { @@ -207,7 +210,7 @@ where /// # } /// ``` /// - /// The [`Future`] returned by this method outputs [`None`] if the channel + /// The [`Future`] returned by this method outputs an error if the channel /// has been closed (all [`Sender`]s and [`DeserSender`]s have been dropped) /// *and* all messages sent before the channel closed have been received. /// @@ -218,6 +221,16 @@ where /// To return an error rather than waiting, use the /// [`try_recv`](Self::try_recv) method, instead. /// + /// # Returns + /// + /// - [`Ok`]`(T)` if a message was received from the channel. + /// - [`Err`]`(`[`RecvError::Disconnected`]`)` if all [`Sender`]s and + /// [`DeserSender`]s have been dropped *and* all messages sent before the + /// channel closed have already been received. + /// - [`Err`]`(`[`RecvError::Error`]`(E)` if the channel has been closed + /// with an error using the [`Receiver::close_with_error`] or + /// [`Sender::close_with_error`] methods. + /// /// # Cancellation Safety /// /// This method is cancel-safe. If `recv` is used as part of a `select!` or @@ -232,6 +245,18 @@ where /// Polls to receive a message from the channel, returning [`Poll::Ready`] /// if a message has been recieved, or [`Poll::Pending`] if there are /// currently no messages in the channel. + /// + /// # Returns + /// + /// - [`Poll::Ready`]`(`[`Ok`]`(T))` if a message was received from the channel. + /// - [`Poll::Ready`]`(`[`Err`]`(`[`RecvError::Disconnected`]`))` if all + /// [`Sender`]s and [`DeserSender`]s have been dropped *and* all messages + /// sent before the channel closed have already been received. + /// - [`Poll::Ready`]`(`[`Err`]`(`[`TryRecvError::Error`]`(E))` if the channel + /// has been closed with an error using the [`Receiver::close_with_error`] + /// or [`Sender::close_with_error`] methods. + /// - [`Poll::Pending`] if there are currently no messages in the queue and + /// the calling task should wait for additional messages to be sent. pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll>> { self.pipe .core() @@ -375,10 +400,13 @@ impl SerReceiver { /// serialize the message as a COBS frame, use the /// [`SerRecvRef::to_slice_framed`] or [`SerRecvRef::to_vec_framed`] /// methods, instead. - /// - [`Err`]`([`TryRecvError::Closed`]) if the channel has been closed - /// (all [`Sender`]s and [`DeserSender`]s have been dropped) *and* all - /// messages sent before the channel closed have already been received. - /// - [`Err`]`([`TryRecvError::Empty`]) if there are currently no + /// - [`Err`]`(`[`TryRecvError::Disconnected`]`)` if all [`Sender`]s and + /// [`DeserSender`]s have been dropped) *and* all messages sent before the + /// channel closed have already been received. + /// - [`Err`]`(`[`TryRecvError::Error`]`(E)` if the channel has been closed + /// with an error using the [`Receiver::close_with_error`] or + /// [`Sender::close_with_error`] methods. + /// - [`Err`]`(`[`TryRecvError::Empty`]`)` if there are currently no /// messages in the queue, but the channel has not been closed. /// /// [`Vec`]: alloc::vec::Vec @@ -404,10 +432,11 @@ impl SerReceiver { /// ``` /// /// This method returns a [`SerRecv`] [`Future`] that outputs an - /// [`Option`]`<`[`SerRecvRef`]`>`. The future will complete with [`None`] - /// if the channel has been closed (all [`Sender`]s and [`DeserSender`]s - /// have been dropped) *and* all messages sent before the channel closed - /// have been received. If the channel has not yet been closed, but there + /// [`Result`]`<`[`SerRecvRef`]`, `[`RecvError`]`>`. The future will + /// complete with an error if the channel has been closed (all [`Sender`]s + /// and [`DeserSender`]s have been dropped) *and* all messages sent before + /// the channel closed have been received, or if the channel is closed with + /// an error by the user. If the channel has not yet been closed, but there /// are no messages currently available in the queue, the [`SerRecv`] future /// yields and waits for a new message to be sent, or for the channel to /// close. @@ -417,15 +446,19 @@ impl SerReceiver { /// /// # Returns /// - /// - [`Some`]`(`[`SerRecvRef`]`)` if a message was received from the + /// - [`Ok`]`(`[`SerRecvRef`]`)` if a message was received from the /// channel. The [`SerRecvRef::to_slice`] and [`SerRecvRef::to_vec`] /// methods can be used to serialize the binary representation of the /// message to a `&mut [u8]` or to a [`Vec`]``, respectively. To /// serialize the message as a COBS frame, use the /// [`SerRecvRef::to_slice_framed`] or [`SerRecvRef::to_vec_framed`] /// methods, instead. - /// - [`None`] if the channel is closed *and* all messages have been - /// received. + /// - [`Err`]`(`[`RecvError::Disconnected`]`)` if all [`Sender`]s and + /// [`DeserSender`]s have been dropped *and* all messages sent before the + /// channel closed have already been received. + /// - [`Err`]`(`[`RecvError::Error`]`(E)` if the channel has been closed + /// with an error using the [`Receiver::close_with_error`] or + /// [`Sender::close_with_error`] methods. /// /// # Cancellation Safety /// @@ -442,6 +475,19 @@ impl SerReceiver { /// Polls to receive a serialized message from the channel, returning /// [`Poll::Ready`] if a message has been recieved, or [`Poll::Pending`] if /// there are currently no messages in the channel. + /// + /// # Returns + /// + /// - [`Poll::Ready`]`(`[`Ok`]`(`[`SerRecvRef`]`<'_, E>))` if a message was + /// received from the channel. + /// - [`Poll::Ready`]`(`[`Err`]`(`[`RecvError::Disconnected`]`))` if all + /// [`Sender`]s and [`DeserSender`]s have been dropped *and* all messages + /// sent before the channel closed have already been received. + /// - [`Poll::Ready`]`(`[`Err`]`(`[`RecvError::Error`]`(E))` if the channel + /// has been closed with an error using the [`Receiver::close_with_error`] + /// or [`Sender::close_with_error`] methods. + /// - [`Poll::Pending`] if there are currently no messages in the queue and + /// the calling task should wait for additional messages to be sent. pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll, RecvError>> { self.pipe.core().poll_dequeue(cx).map(|res| { Ok(SerRecvRef { @@ -660,8 +706,12 @@ impl DeserSender { /// # Returns /// /// - [`Ok`]`(`[`SerPermit`]`)` if the channel is not closed. - /// - [`Err`]`(`[SendError`]`<()>)` if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. + /// - [`Err`]`(`[`SendError::Disconnected`]`<()>)` if the [`Receiver`] or + /// [`SerReceiver`]) has been dropped. + /// - [`Err`]`(`[`SendError::Error`]`)` if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to [`try_reserve`] or `reserve` on this channel will always fail. /// /// # Cancellation Safety /// @@ -704,10 +754,13 @@ impl DeserSender { /// /// - [`Ok`]`(`[`SerPermit`]`)` if the channel has capacity available and /// has not closed. - /// - [`Err`]`(`[TrySendError::Closed`]`)` if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. This indicates that - /// subsequent calls to `try_reserve` or [`reserve`] on this channel will - /// always fail. + /// - [`Err`]`(`[`TrySendError::Disconnected`]`<()>)` if the [`Receiver`] or + /// [`SerReceiver`] has been dropped. This indicates that subsequent calls + /// to `try_reserve` or [`reserve`] on this channel will always fail. + /// - [`Err`]`(`[`TrySendError::Error`]`)` if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to `try_reserve` or [`reserve`] on this channel will always fail. /// - [`Err`]`(`[`TrySendError::Full`]`)` if the channel does not currently /// have capacity to send another message without waiting. A subsequent /// call to `try_reserve` may complete successfully, once capacity has @@ -935,8 +988,12 @@ impl Sender { /// # Returns /// /// - [`Ok`]`(`[`()`]`)` if the channel is not closed. - /// - [`Err`]([1SendError`]``) if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. + /// - [`Err`]([`SendError::Disconnected`]``) if the [`Receiver`] or + /// [`SerReceiver`]) has been dropped. + /// - [`Err`]`(`[`SendError::Error`]`)` if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to `send` or [`try_send`] on this channel will always fail. /// /// # Cancellation Safety /// @@ -969,10 +1026,15 @@ impl Sender { /// # Returns /// /// - [`Ok`]`(())` if the message was sent successfully. - /// - [`Err`]([`TrySendError::Closed`]``) if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. This indicates that - /// subsequent calls to [`send`], `try_send`, [`try_reserve`], or - /// [`reserve`] on this channel will always fail. + /// - [`Err`]([`TrySendError::Disconnected`]``) if the [`Receiver`] or + /// [`SerReceiver`]) has been dropped. This indicates that subsequent + /// calls to [`send`], `try_send`, [`try_reserve`], or [`reserve`] on this + /// channel will always fail. + /// - [`Err`]([`TrySendError::Error`]``) if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to [`send`], `try_send`, [`try_reserve`], or [`reserve`] on this + /// channel will always fail. /// - [`Err`]`(`[`TrySendError::Full`]`)` if the channel does not currently /// have capacity to send another message without waiting. A subsequent /// call to `try_reserve` may complete successfully, once capacity has @@ -1006,16 +1068,17 @@ impl Sender { /// To attempt to reserve capacity *without* waiting if the channel is full, /// use the [`try_reserve`] method, instead. /// - /// [`Permit`]: Permit - /// [`send`]: Permit::send - /// [`commit`]: Permit::commit - /// [`try_reserve`]: Self::try_reserve /// /// # Returns /// /// - [`Ok`]`(`[`Permit`]`)` if the channel is not closed. - /// - [`Err`]`(`[SendError::Closed`]`)` if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. + /// - [`Err`]([`SendError::Disconnected`]`<()>`) if the [`Receiver`] or + /// [`SerReceiver`]) has been dropped. + /// - [`Err`]`(`[`SendError::Error`]`)` if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to `reserve`, [`try_reserve`], [`send`](Self::send), + /// [`try_send`] on this channel will always fail. /// /// # Cancellation Safety /// @@ -1025,6 +1088,12 @@ impl Sender { /// This channel uses a queue to ensure that calls to `send` and `reserve` /// complete in the order they were requested. Cancelling a call to /// `reserve` causes the caller to lose its place in that queue. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// [`try_send`]: Self::try_send + /// [`commit`]: Permit::commit + /// [`try_reserve`]: Self::try_reserve pub async fn reserve(&self) -> Result, SendError> { let pipe = self.pipe.core().reserve().await?; let cell = self.pipe.elems()[pipe.idx as usize].get_mut(); @@ -1056,14 +1125,22 @@ impl Sender { /// /// - [`Ok`]`(`[`Permit`]`)` if the channel has capacity available and /// has not closed. - /// - [`Err`]`(`[TrySendError::Closed`]`)` if the channel is closed (the - /// [`Receiver`] or [`SerReceiver`]) has been dropped. This indicates that - /// subsequent calls to `try_reserve` or [`reserve`] on this channel will - /// always fail. + /// - [`Err`]([`TrySendError::Disconnected`]`<()>`) if the [`Receiver`] or + /// [`SerReceiver`]) has been dropped. This indicates that subsequent + /// calls to [`send`], `try_send`, [`try_reserve`], or [`reserve`] on this + /// channel will always fail. + /// - [`Err`]([`TrySendError::Error`]``) if the channel has been closed + /// with an error using the [`Sender::close_with_error`] or + /// [`Receiver::close_with_error`] methods. This indicates that subsequent + /// calls to [`send`](Self::send), `try_send`, [`try_reserve`], or + /// [`reserve`] on this channel will always fail. /// - [`Err`]`(`[`TrySendError::Full`]`)` if the channel does not currently /// have capacity to send another message without waiting. A subsequent /// call to `try_reserve` may complete successfully, once capacity has /// become available again. + /// + /// [`reserve`]: Self::reserve + /// [`try_reserve`]: Self::try_reserve pub fn try_reserve(&self) -> Result, TrySendError> { let pipe = self.pipe.core().try_reserve()?; let cell = self.pipe.elems()[pipe.idx as usize].get_mut(); From 018a806befabc67748b66dbd6bb63c1303854a44 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 12:15:22 -0800 Subject: [PATCH 20/21] s/nak/reject --- source/mgnp/src/conn_table.rs | 10 ++++++++-- source/mgnp/src/message.rs | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index 722f83f..0ef6990 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -119,6 +119,12 @@ impl ConnTable { .await } + pub(crate) async fn reset_all(&mut self) { + for entry in self.conns.iter_mut() { + if entry + } + } + fn cleanup_dead(&mut self) { // receiving a data frame from the conn table borrows it, so we must // remove the dead index from the *previous* next_outbound call before @@ -241,7 +247,7 @@ impl ConnTable { tracing::trace!(id.remote = %local_id, ?identity, "process_inbound: CONNECT"); match registry.connect(identity, frame.body).await { Ok(channel) => Some(self.accept(local_id, channel)), - Err(reason) => Some(OutboundFrame::nak(local_id, reason)), + Err(reason) => Some(OutboundFrame::reject(local_id, reason)), } } } @@ -342,7 +348,7 @@ impl ConnTable { // Accepted, we got a local ID! Some(local_id) => OutboundFrame::ack(local_id, remote_id), // Conn table is full, can't accept this stream. - None => OutboundFrame::nak(remote_id, Rejection::ConnTableFull(CAPACITY)), + None => OutboundFrame::reject(remote_id, Rejection::ConnTableFull(CAPACITY)), } } diff --git a/source/mgnp/src/message.rs b/source/mgnp/src/message.rs index 26048b2..090f906 100644 --- a/source/mgnp/src/message.rs +++ b/source/mgnp/src/message.rs @@ -111,7 +111,7 @@ pub enum Rejection { NotFound, /// The connection was rejected by the [`Service`](crate::registry::Service). /// - /// The body of this [`NAK`](Header::Nak) frame may contain additional bytes + /// The body of this [`REJECT`](Header::Reject) frame may contain additional bytes /// which can be interpreted as a [service-specific `ConnectError` /// value](crate::registry::Service::ConnectError).] ServiceRejected, @@ -304,7 +304,7 @@ impl<'data> Frame> { } } - pub fn nak(remote_id: Id, reason: Rejection) -> Self { + pub fn reject(remote_id: Id, reason: Rejection) -> Self { Self { header: Header::Reject { remote_id, reason }, body: OutboundData::Empty, // todo From 51be5267bbe1e0ea3431d144c380d9b35dfd6888 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 22 Nov 2023 12:28:10 -0800 Subject: [PATCH 21/21] gah didnt mean to do that --- source/mgnp/src/conn_table.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/source/mgnp/src/conn_table.rs b/source/mgnp/src/conn_table.rs index 0ef6990..a09ee09 100644 --- a/source/mgnp/src/conn_table.rs +++ b/source/mgnp/src/conn_table.rs @@ -119,12 +119,6 @@ impl ConnTable { .await } - pub(crate) async fn reset_all(&mut self) { - for entry in self.conns.iter_mut() { - if entry - } - } - fn cleanup_dead(&mut self) { // receiving a data frame from the conn table borrows it, so we must // remove the dead index from the *previous* next_outbound call before