Skip to content

Commit

Permalink
Make some things infallible
Browse files Browse the repository at this point in the history
  • Loading branch information
flub committed Jan 23, 2025
1 parent 4741be8 commit 68ce91e
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 57 deletions.
2 changes: 1 addition & 1 deletion iroh-relay/src/server/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl Clients {
/// peer is gone from the network.
///
/// Must be passed a matching connection_id.
pub(super) async fn unregister<'a>(&self, connection_id: u64, node_id: NodeId) {
pub(super) async fn unregister(&self, connection_id: u64, node_id: NodeId) {
trace!(
node_id = node_id.fmt_short(),
connection_id,
Expand Down
2 changes: 1 addition & 1 deletion iroh/examples/dht_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn chat_server(args: Args) -> anyhow::Result<()> {
};
tokio::spawn(async move {
let connection = connecting.await?;
let remote_node_id = connection.remote_node_id()?;
let remote_node_id = connection.remote_node_id();
println!("got connection from {}", remote_node_id);
// just leave the tasks hanging. this is just an example.
let (mut writer, mut reader) = connection.accept_bi().await?;
Expand Down
2 changes: 1 addition & 1 deletion iroh/examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl ProtocolHandler for Echo {
// Wait for the connection to be fully established.
let connection = connecting.await?;
// We can get the remote's node id from the connection.
let node_id = connection.remote_node_id()?;
let node_id = connection.remote_node_id();
println!("accepted connection from {node_id}");

// Our protocol is a simple request-response protocol, so we expect the
Expand Down
6 changes: 3 additions & 3 deletions iroh/examples/listen-unreliable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> {
// accept incoming connections, returns a normal QUIC connection

while let Some(incoming) = endpoint.accept().await {
let mut connecting = match incoming.accept() {
let connecting = match incoming.accept() {
Ok(connecting) => connecting,
Err(err) => {
warn!("incoming connection failed: {err:#}");
Expand All @@ -67,9 +67,9 @@ async fn main() -> anyhow::Result<()> {
continue;
}
};
let alpn = connecting.alpn().await?;
let conn = connecting.await?;
let node_id = conn.remote_node_id()?;
let alpn = conn.alpn();
let node_id = conn.remote_node_id();
info!(
"new (unreliable) connection from {node_id} with ALPN {} (coming from {})",
String::from_utf8_lossy(&alpn),
Expand Down
6 changes: 3 additions & 3 deletions iroh/examples/listen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
);
// accept incoming connections, returns a normal QUIC connection
while let Some(incoming) = endpoint.accept().await {
let mut connecting = match incoming.accept() {
let connecting = match incoming.accept() {
Ok(connecting) => connecting,
Err(err) => {
warn!("incoming connection failed: {err:#}");
Expand All @@ -68,9 +68,9 @@ async fn main() -> anyhow::Result<()> {
continue;
}
};
let alpn = connecting.alpn().await?;
let conn = connecting.await?;
let node_id = conn.remote_node_id()?;
let alpn = conn.alpn();
let node_id = conn.remote_node_id();
info!(
"new connection from {node_id} with ALPN {} (coming from {})",
String::from_utf8_lossy(&alpn),
Expand Down
2 changes: 1 addition & 1 deletion iroh/examples/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl ProtocolHandler for BlobSearch {
// Wait for the connection to be fully established.
let connection = connecting.await?;
// We can get the remote's node id from the connection.
let node_id = connection.remote_node_id()?;
let node_id = connection.remote_node_id();
println!("accepted connection from {node_id}");

// Our protocol is a simple request-response protocol, so we expect the
Expand Down
2 changes: 1 addition & 1 deletion iroh/examples/transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async fn provide(size: u64, relay_url: Option<String>, relay_only: bool) -> anyh
}
};
let conn = connecting.await?;
let node_id = conn.remote_node_id()?;
let node_id = conn.remote_node_id();
info!(
"new connection from {node_id} with ALPN {} (coming from {})",
String::from_utf8_lossy(TRANSFER_ALPN),
Expand Down
112 changes: 66 additions & 46 deletions iroh/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1257,16 +1257,24 @@ impl Connecting {
}

/// Extracts the ALPN protocol from the peer's handshake data.
// Note, we could totally provide this method to be on a Connection as well. But we'd
// need to wrap Connection too.
pub async fn alpn(&mut self) -> Result<Vec<u8>> {
pub async fn alpn(&mut self) -> Result<Vec<u8>, ConnectionError> {
let data = self.handshake_data().await?;
match data.downcast::<quinn::crypto::rustls::HandshakeData>() {
Ok(data) => match data.protocol {
Some(protocol) => Ok(protocol),
None => bail!("no ALPN protocol available"),
},
Err(_) => bail!("unknown handshake type"),
let data = data
.downcast::<quinn::crypto::rustls::HandshakeData>()
.expect("fixed crypto setup for iroh");
match data.protocol {
Some(proto) => Ok(proto),
None => {
// Using QUIC's CONNECTION_REFUSED error on the server-side is a bit odd,
// but perhaps not a total crime. Strictly speaking for QUIC this is an
// application error, but if we surfaced this as an application error it
// would be even more confusing for the user.
Err(ConnectionError::TransportError(TransportError {
code: TransportErrorCode::CONNECTION_REFUSED,
frame: None,
reason: String::from("iroh connections must use an ALPN"),
}))
}
}
}
}
Expand All @@ -1281,6 +1289,27 @@ impl Future for Connecting {
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Ready(Ok(inner)) => {
let conn = Connection { inner };

// An iroh connection MUST always have an ALPN. Note that this check is
// only valid for accepting connections, but we **currently** do not use
// this `Connecting` struct when establishing connections. Once we do we'll
// have to make this check conditional.
let handshake_data = conn.handshake_data();
let handshake_data = handshake_data
.downcast::<quinn::crypto::rustls::HandshakeData>()
.expect("fixed crypto setup for iroh");
if handshake_data.protocol.is_none() {
// Using QUIC's CONNECTION_REFUSED error on the server-side is a bit odd,
// but perhaps not a total crime. Strictly speaking for QUIC this is an
// application error, but if we surfaced this as an application error it
// would be even more confusing for the user.
return Poll::Ready(Err(ConnectionError::TransportError(TransportError {
code: TransportErrorCode::CONNECTION_REFUSED,
frame: None,
reason: String::from("iroh connections must use an ALPN"),
})));
}

try_send_rtt_msg(&conn, this.ep);
Poll::Ready(Ok(conn))
}
Expand Down Expand Up @@ -1511,17 +1540,20 @@ impl Connection {
///
/// [`Connection::handshake_data()`]: crate::Connecting::handshake_data
#[inline]
pub fn handshake_data(&self) -> Option<Box<dyn Any>> {
self.inner.handshake_data()
pub fn handshake_data(&self) -> Box<dyn Any> {
self.inner
.handshake_data()
.expect("we always have handshake data")
}

/// Extracts the ALPN protocol from the peer's handshake data.
pub fn alpn(&self) -> Option<Vec<u8>> {
let data = self.handshake_data()?;
match data.downcast::<quinn::crypto::rustls::HandshakeData>() {
Ok(data) => data.protocol,
Err(_) => None,
}
pub fn alpn(&self) -> Vec<u8> {
let data = self.handshake_data();
let data = data
.downcast::<quinn::crypto::rustls::HandshakeData>()
.expect("fixed crypto setup for iroh");
data.protocol
.expect("checked in <Connecting as Future>::poll")
}

/// Cryptographic identity of the peer.
Expand All @@ -1533,8 +1565,10 @@ impl Connection {
/// [`Session`]: quinn_proto::crypto::Session
/// [`downcast`]: Box::downcast
#[inline]
pub fn peer_identity(&self) -> Option<Box<dyn Any>> {
self.inner.peer_identity()
pub fn peer_identity(&self) -> Box<dyn Any> {
self.inner
.peer_identity()
.expect("we always have a peer identity")
}

/// Returns the [`NodeId`] from the peer's TLS certificate.
Expand All @@ -1545,25 +1579,14 @@ impl Connection {
/// connection.
///
/// [`PublicKey`]: iroh_base::PublicKey
// TODO: Would be nice if this could be infallible.
pub fn remote_node_id(&self) -> Result<NodeId> {
pub fn remote_node_id(&self) -> NodeId {
let data = self.peer_identity();
match data {
None => bail!("no peer certificate found"),
Some(data) => match data.downcast::<Vec<rustls::pki_types::CertificateDer>>() {
Ok(certs) => {
if certs.len() != 1 {
bail!(
"expected a single peer certificate, but {} found",
certs.len()
);
}
let cert = tls::certificate::parse(&certs[0])?;
Ok(cert.peer_id())
}
Err(_) => bail!("invalid peer certificate"),
},
}
let certs = data
.downcast::<Vec<rustls::pki_types::CertificateDer>>()
.expect("we always have a cert");
debug_assert_eq!(certs.len(), 1);
let cert = tls::certificate::parse(&certs[0]).expect("valid cert");
cert.peer_id()
}

/// A stable identifier for this connection.
Expand Down Expand Up @@ -1624,10 +1647,7 @@ impl Connection {
/// function.
fn try_send_rtt_msg(conn: &Connection, magic_ep: &Endpoint) {
// If we can't notify the rtt-actor that's not great but not critical.
let Ok(node_id) = conn.remote_node_id() else {
warn!(?conn, "failed to get remote node id");
return;
};
let node_id = conn.remote_node_id();
let Ok(conn_type_changes) = magic_ep.conn_type(node_id) else {
warn!(?conn, "failed to create conn_type stream");
return;
Expand Down Expand Up @@ -1953,7 +1973,7 @@ mod tests {
info!("[server] round {i}");
let incoming = ep.accept().await.unwrap();
let conn = incoming.await.unwrap();
let node_id = conn.remote_node_id().unwrap();
let node_id = conn.remote_node_id();
info!(%i, peer = %node_id.fmt_short(), "accepted connection");
let (mut send, mut recv) = conn.accept_bi().await.unwrap();
let mut buf = vec![0u8; chunk_size];
Expand Down Expand Up @@ -2066,10 +2086,10 @@ mod tests {

async fn accept_world(ep: Endpoint, src: NodeId) {
let incoming = ep.accept().await.unwrap();
let mut iconn = incoming.accept().unwrap();
let alpn = iconn.alpn().await.unwrap();
let iconn = incoming.accept().unwrap();
let conn = iconn.await.unwrap();
let node_id = conn.remote_node_id().unwrap();
let alpn = conn.alpn();
let node_id = conn.remote_node_id();
assert_eq!(node_id, src);
assert_eq!(alpn, TEST_ALPN);
let (mut send, mut recv) = conn.accept_bi().await.unwrap();
Expand Down Expand Up @@ -2159,7 +2179,7 @@ mod tests {
async fn accept(ep: &Endpoint) -> NodeId {
let incoming = ep.accept().await.unwrap();
let conn = incoming.await.unwrap();
let node_id = conn.remote_node_id().unwrap();
let node_id = conn.remote_node_id();
tracing::info!(node_id=%node_id.fmt_short(), "accepted connection");
node_id
}
Expand Down

0 comments on commit 68ce91e

Please sign in to comment.