From 68ce91e3e3161e979d4c13eb60a0851594259d04 Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Thu, 23 Jan 2025 12:21:43 +0100 Subject: [PATCH] Make some things infallible --- iroh-relay/src/server/clients.rs | 2 +- iroh/examples/dht_discovery.rs | 2 +- iroh/examples/echo.rs | 2 +- iroh/examples/listen-unreliable.rs | 6 +- iroh/examples/listen.rs | 6 +- iroh/examples/search.rs | 2 +- iroh/examples/transfer.rs | 2 +- iroh/src/endpoint.rs | 112 +++++++++++++++++------------ 8 files changed, 77 insertions(+), 57 deletions(-) diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 2164f149d4..2f60d7b078 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -71,7 +71,7 @@ impl Clients { /// peer is gone from the network. /// /// Must be passed a matching connection_id. - pub(super) async fn unregister<'a>(&self, connection_id: u64, node_id: NodeId) { + pub(super) async fn unregister(&self, connection_id: u64, node_id: NodeId) { trace!( node_id = node_id.fmt_short(), connection_id, diff --git a/iroh/examples/dht_discovery.rs b/iroh/examples/dht_discovery.rs index c514474d76..00df0d27dc 100644 --- a/iroh/examples/dht_discovery.rs +++ b/iroh/examples/dht_discovery.rs @@ -88,7 +88,7 @@ async fn chat_server(args: Args) -> anyhow::Result<()> { }; tokio::spawn(async move { let connection = connecting.await?; - let remote_node_id = connection.remote_node_id()?; + let remote_node_id = connection.remote_node_id(); println!("got connection from {}", remote_node_id); // just leave the tasks hanging. this is just an example. let (mut writer, mut reader) = connection.accept_bi().await?; diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index 89bd643a34..e26ce3b183 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -77,7 +77,7 @@ impl ProtocolHandler for Echo { // Wait for the connection to be fully established. let connection = connecting.await?; // We can get the remote's node id from the connection. - let node_id = connection.remote_node_id()?; + let node_id = connection.remote_node_id(); println!("accepted connection from {node_id}"); // Our protocol is a simple request-response protocol, so we expect the diff --git a/iroh/examples/listen-unreliable.rs b/iroh/examples/listen-unreliable.rs index 5028ca520a..1123daa43f 100644 --- a/iroh/examples/listen-unreliable.rs +++ b/iroh/examples/listen-unreliable.rs @@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> { // accept incoming connections, returns a normal QUIC connection while let Some(incoming) = endpoint.accept().await { - let mut connecting = match incoming.accept() { + let connecting = match incoming.accept() { Ok(connecting) => connecting, Err(err) => { warn!("incoming connection failed: {err:#}"); @@ -67,9 +67,9 @@ async fn main() -> anyhow::Result<()> { continue; } }; - let alpn = connecting.alpn().await?; let conn = connecting.await?; - let node_id = conn.remote_node_id()?; + let alpn = conn.alpn(); + let node_id = conn.remote_node_id(); info!( "new (unreliable) connection from {node_id} with ALPN {} (coming from {})", String::from_utf8_lossy(&alpn), diff --git a/iroh/examples/listen.rs b/iroh/examples/listen.rs index fb93e5342d..90f00be056 100644 --- a/iroh/examples/listen.rs +++ b/iroh/examples/listen.rs @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> { ); // accept incoming connections, returns a normal QUIC connection while let Some(incoming) = endpoint.accept().await { - let mut connecting = match incoming.accept() { + let connecting = match incoming.accept() { Ok(connecting) => connecting, Err(err) => { warn!("incoming connection failed: {err:#}"); @@ -68,9 +68,9 @@ async fn main() -> anyhow::Result<()> { continue; } }; - let alpn = connecting.alpn().await?; let conn = connecting.await?; - let node_id = conn.remote_node_id()?; + let alpn = conn.alpn(); + let node_id = conn.remote_node_id(); info!( "new connection from {node_id} with ALPN {} (coming from {})", String::from_utf8_lossy(&alpn), diff --git a/iroh/examples/search.rs b/iroh/examples/search.rs index d60b629038..0d032465a2 100644 --- a/iroh/examples/search.rs +++ b/iroh/examples/search.rs @@ -134,7 +134,7 @@ impl ProtocolHandler for BlobSearch { // Wait for the connection to be fully established. let connection = connecting.await?; // We can get the remote's node id from the connection. - let node_id = connection.remote_node_id()?; + let node_id = connection.remote_node_id(); println!("accepted connection from {node_id}"); // Our protocol is a simple request-response protocol, so we expect the diff --git a/iroh/examples/transfer.rs b/iroh/examples/transfer.rs index a80e64b705..9f480f1dd5 100644 --- a/iroh/examples/transfer.rs +++ b/iroh/examples/transfer.rs @@ -121,7 +121,7 @@ async fn provide(size: u64, relay_url: Option, relay_only: bool) -> anyh } }; let conn = connecting.await?; - let node_id = conn.remote_node_id()?; + let node_id = conn.remote_node_id(); info!( "new connection from {node_id} with ALPN {} (coming from {})", String::from_utf8_lossy(TRANSFER_ALPN), diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 4ceaa1e51d..584c4ee029 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -1257,16 +1257,24 @@ impl Connecting { } /// Extracts the ALPN protocol from the peer's handshake data. - // Note, we could totally provide this method to be on a Connection as well. But we'd - // need to wrap Connection too. - pub async fn alpn(&mut self) -> Result> { + pub async fn alpn(&mut self) -> Result, ConnectionError> { let data = self.handshake_data().await?; - match data.downcast::() { - Ok(data) => match data.protocol { - Some(protocol) => Ok(protocol), - None => bail!("no ALPN protocol available"), - }, - Err(_) => bail!("unknown handshake type"), + let data = data + .downcast::() + .expect("fixed crypto setup for iroh"); + match data.protocol { + Some(proto) => Ok(proto), + None => { + // Using QUIC's CONNECTION_REFUSED error on the server-side is a bit odd, + // but perhaps not a total crime. Strictly speaking for QUIC this is an + // application error, but if we surfaced this as an application error it + // would be even more confusing for the user. + Err(ConnectionError::TransportError(TransportError { + code: TransportErrorCode::CONNECTION_REFUSED, + frame: None, + reason: String::from("iroh connections must use an ALPN"), + })) + } } } } @@ -1281,6 +1289,27 @@ impl Future for Connecting { Poll::Ready(Err(err)) => Poll::Ready(Err(err)), Poll::Ready(Ok(inner)) => { let conn = Connection { inner }; + + // An iroh connection MUST always have an ALPN. Note that this check is + // only valid for accepting connections, but we **currently** do not use + // this `Connecting` struct when establishing connections. Once we do we'll + // have to make this check conditional. + let handshake_data = conn.handshake_data(); + let handshake_data = handshake_data + .downcast::() + .expect("fixed crypto setup for iroh"); + if handshake_data.protocol.is_none() { + // Using QUIC's CONNECTION_REFUSED error on the server-side is a bit odd, + // but perhaps not a total crime. Strictly speaking for QUIC this is an + // application error, but if we surfaced this as an application error it + // would be even more confusing for the user. + return Poll::Ready(Err(ConnectionError::TransportError(TransportError { + code: TransportErrorCode::CONNECTION_REFUSED, + frame: None, + reason: String::from("iroh connections must use an ALPN"), + }))); + } + try_send_rtt_msg(&conn, this.ep); Poll::Ready(Ok(conn)) } @@ -1511,17 +1540,20 @@ impl Connection { /// /// [`Connection::handshake_data()`]: crate::Connecting::handshake_data #[inline] - pub fn handshake_data(&self) -> Option> { - self.inner.handshake_data() + pub fn handshake_data(&self) -> Box { + self.inner + .handshake_data() + .expect("we always have handshake data") } /// Extracts the ALPN protocol from the peer's handshake data. - pub fn alpn(&self) -> Option> { - let data = self.handshake_data()?; - match data.downcast::() { - Ok(data) => data.protocol, - Err(_) => None, - } + pub fn alpn(&self) -> Vec { + let data = self.handshake_data(); + let data = data + .downcast::() + .expect("fixed crypto setup for iroh"); + data.protocol + .expect("checked in ::poll") } /// Cryptographic identity of the peer. @@ -1533,8 +1565,10 @@ impl Connection { /// [`Session`]: quinn_proto::crypto::Session /// [`downcast`]: Box::downcast #[inline] - pub fn peer_identity(&self) -> Option> { - self.inner.peer_identity() + pub fn peer_identity(&self) -> Box { + self.inner + .peer_identity() + .expect("we always have a peer identity") } /// Returns the [`NodeId`] from the peer's TLS certificate. @@ -1545,25 +1579,14 @@ impl Connection { /// connection. /// /// [`PublicKey`]: iroh_base::PublicKey - // TODO: Would be nice if this could be infallible. - pub fn remote_node_id(&self) -> Result { + pub fn remote_node_id(&self) -> NodeId { let data = self.peer_identity(); - match data { - None => bail!("no peer certificate found"), - Some(data) => match data.downcast::>() { - Ok(certs) => { - if certs.len() != 1 { - bail!( - "expected a single peer certificate, but {} found", - certs.len() - ); - } - let cert = tls::certificate::parse(&certs[0])?; - Ok(cert.peer_id()) - } - Err(_) => bail!("invalid peer certificate"), - }, - } + let certs = data + .downcast::>() + .expect("we always have a cert"); + debug_assert_eq!(certs.len(), 1); + let cert = tls::certificate::parse(&certs[0]).expect("valid cert"); + cert.peer_id() } /// A stable identifier for this connection. @@ -1624,10 +1647,7 @@ impl Connection { /// function. fn try_send_rtt_msg(conn: &Connection, magic_ep: &Endpoint) { // If we can't notify the rtt-actor that's not great but not critical. - let Ok(node_id) = conn.remote_node_id() else { - warn!(?conn, "failed to get remote node id"); - return; - }; + let node_id = conn.remote_node_id(); let Ok(conn_type_changes) = magic_ep.conn_type(node_id) else { warn!(?conn, "failed to create conn_type stream"); return; @@ -1953,7 +1973,7 @@ mod tests { info!("[server] round {i}"); let incoming = ep.accept().await.unwrap(); let conn = incoming.await.unwrap(); - let node_id = conn.remote_node_id().unwrap(); + let node_id = conn.remote_node_id(); info!(%i, peer = %node_id.fmt_short(), "accepted connection"); let (mut send, mut recv) = conn.accept_bi().await.unwrap(); let mut buf = vec![0u8; chunk_size]; @@ -2066,10 +2086,10 @@ mod tests { async fn accept_world(ep: Endpoint, src: NodeId) { let incoming = ep.accept().await.unwrap(); - let mut iconn = incoming.accept().unwrap(); - let alpn = iconn.alpn().await.unwrap(); + let iconn = incoming.accept().unwrap(); let conn = iconn.await.unwrap(); - let node_id = conn.remote_node_id().unwrap(); + let alpn = conn.alpn(); + let node_id = conn.remote_node_id(); assert_eq!(node_id, src); assert_eq!(alpn, TEST_ALPN); let (mut send, mut recv) = conn.accept_bi().await.unwrap(); @@ -2159,7 +2179,7 @@ mod tests { async fn accept(ep: &Endpoint) -> NodeId { let incoming = ep.accept().await.unwrap(); let conn = incoming.await.unwrap(); - let node_id = conn.remote_node_id().unwrap(); + let node_id = conn.remote_node_id(); tracing::info!(node_id=%node_id.fmt_short(), "accepted connection"); node_id }