Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion protocols/stream/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Default for Behaviour {

impl Behaviour {
pub fn new() -> Self {
let (dial_sender, dial_receiver) = mpsc::channel(0);
let (dial_sender, dial_receiver) = mpsc::channel(32);

Self {
shared: Arc::new(Mutex::new(Shared::new(dial_sender))),
Expand Down
15 changes: 8 additions & 7 deletions protocols/stream/src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{

use futures::{
channel::{mpsc, oneshot},
SinkExt as _, StreamExt as _,
StreamExt as _,
};
use libp2p_identity::PeerId;
use libp2p_swarm::{Stream, StreamProtocol};
Expand Down Expand Up @@ -48,14 +48,15 @@ impl Control {
) -> Result<Stream, OpenStreamError> {
tracing::debug!(%peer, "Requesting new stream");

let mut new_stream_sender = Shared::lock(&self.shared).sender(peer);

let (sender, receiver) = oneshot::channel();

new_stream_sender
.send(NewStream { protocol, sender })
.await
.map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?;
Shared::send_new_stream(
&self.shared,
peer,
NewStream { protocol, sender },
)
.await
.map_err(OpenStreamError::Io)?;

let stream = receiver
.await
Expand Down
58 changes: 53 additions & 5 deletions protocols/stream/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
sync::{Arc, Mutex, MutexGuard},
};

use futures::channel::mpsc;
use futures::{channel::mpsc, SinkExt as _};
use libp2p_identity::PeerId;
use libp2p_swarm::{ConnectionId, Stream, StreamProtocol};
use rand::seq::IteratorRandom as _;
Expand Down Expand Up @@ -122,7 +122,7 @@ impl Shared {
}
}

pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender<NewStream> {
fn prepare_sender(&mut self, peer: PeerId) -> SenderAction {
let maybe_sender = self
.connections
.iter()
Expand All @@ -134,7 +134,9 @@ impl Shared {
Some(sender) => {
tracing::debug!("Returning sender to existing connection");

sender.clone()
SenderAction::Connected {
sender: sender.clone(),
}
}
None => {
tracing::debug!(%peer, "Not connected to peer, initiating dial");
Expand All @@ -144,9 +146,44 @@ impl Shared {
.entry(peer)
.or_insert_with(|| mpsc::channel(0));

let _ = self.dial_sender.try_send(peer);
SenderAction::Dial {
pending_sender: sender.clone(),
dial_sender: self.dial_sender.clone(),
peer_to_dial: peer,
}
}
}
}

/// Delivers `new_stream` to `peer`, requesting a dial first when no connection
/// to `peer` is currently available.
///
/// The value returned by `Shared::lock` lives only inside the inner block that
/// calls `prepare_sender`, so it is dropped before the first `.await`; only the
/// cloned channel senders are carried across suspension points.
///
/// Both sends are awaited rather than attempted with `try_send`, so a full
/// channel makes the caller wait instead of silently dropping the request.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind [`io::ErrorKind::ConnectionReset`] when the
/// dial request or the stream request cannot be sent on its channel.
pub(crate) async fn send_new_stream(
    shared: &Arc<Mutex<Shared>>,
    peer: PeerId,
    new_stream: NewStream,
) -> io::Result<()> {
    // Decide where the request goes while locked, then release before awaiting.
    let action = {
        let mut shared = Shared::lock(shared);
        shared.prepare_sender(peer)
    };

    match action {
        SenderAction::Connected { mut sender } => sender
            .send(new_stream)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e)),
        SenderAction::Dial {
            mut pending_sender,
            mut dial_sender,
            peer_to_dial,
        } => {
            // Ask for the dial first; the error is already owned here, so it is
            // moved into the `io::Error` without cloning.
            dial_sender
                .send(peer_to_dial)
                .await
                .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?;

            // Then queue the stream request on the pending channel for this peer.
            pending_sender
                .send(new_stream)
                .await
                .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))
        }
    }
}
Expand All @@ -171,3 +208,14 @@ impl Shared {
receiver
}
}

/// Outcome of `Shared::prepare_sender`: describes how `Shared::send_new_stream`
/// should deliver a `NewStream` request once the lock on `Shared` is released.
enum SenderAction {
/// A sender for an existing connection to the peer was found; the request is
/// sent on it directly.
Connected {
/// Clone of the sender belonging to the existing connection.
sender: mpsc::Sender<NewStream>,
},
/// No connection to the peer was found; a dial is requested first and the
/// stream request is then sent on the peer's pending channel.
Dial {
/// Clone of the sender of the pending channel for this peer, which
/// `prepare_sender` creates with `mpsc::channel(0)` if it is not present yet.
pending_sender: mpsc::Sender<NewStream>,
/// Clone of `Shared::dial_sender`, used to send the dial request.
dial_sender: mpsc::Sender<PeerId>,
/// The peer sent on `dial_sender`.
peer_to_dial: PeerId,
},
}
44 changes: 44 additions & 0 deletions protocols/stream/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,47 @@ async fn dial_errors_are_propagated() {
assert_eq!(e.kind(), io::ErrorKind::NotConnected);
assert_eq!("Dial error: no addresses for peer.", e.to_string());
}

/// Opens 50 streams concurrently to random peers and checks that every attempt
/// finishes with an error within five seconds instead of hanging.
#[tokio::test]
async fn backpressure_on_many_concurrent_dials() {
    let env_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::DEBUG.into())
        .from_env()
        .unwrap();
    let _ = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_test_writer()
        .try_init();

    let swarm1 = Swarm::new_ephemeral_tokio(|_| stream::Behaviour::new());
    let control = swarm1.behaviour().new_control();

    tokio::spawn(swarm1.loop_on_next());

    // One task per attempt. Each targets a freshly generated random peer, and
    // the test expects every attempt to return an error rather than hang.
    let attempts: Vec<_> = (0..50)
        .map(|_| {
            let mut per_task_control = control.clone();
            tokio::spawn(async move {
                let outcome = per_task_control
                    .open_stream(PeerId::random(), PROTOCOL)
                    .await;
                assert!(outcome.is_err());
            })
        })
        .collect();

    // Every task must finish within the time limit; a timeout means a request
    // was left waiting forever.
    let limit = std::time::Duration::from_secs(5);
    for attempt in attempts {
        tokio::time::timeout(limit, attempt)
            .await
            .expect("Task should not hang - backpressure should work")
            .expect("Task should complete successfully");
    }
}
Loading