From 85bc539253365aa876d2c4eec9b7b4e4b5fa12df Mon Sep 17 00:00:00 2001 From: hanabi1224 Date: Tue, 9 Dec 2025 00:09:46 +0800 Subject: [PATCH] feat(swarm): smart dialing --- Cargo.lock | 2 +- Cargo.toml | 2 +- swarm/CHANGELOG.md | 4 + swarm/Cargo.toml | 2 +- swarm/src/connection/pool.rs | 38 +- swarm/src/connection/pool/concurrent_dial.rs | 71 +++- swarm/src/connection/pool/dial_ranker.rs | 326 ++++++++++++++++++ .../src/connection/pool/dial_ranker/tests.rs | 202 +++++++++++ swarm/src/lib.rs | 94 ++++- 9 files changed, 700 insertions(+), 41 deletions(-) create mode 100644 swarm/src/connection/pool/dial_ranker.rs create mode 100644 swarm/src/connection/pool/dial_ranker/tests.rs diff --git a/Cargo.lock b/Cargo.lock index fdeada37500..bd9aa9d87a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3068,7 +3068,7 @@ dependencies = [ [[package]] name = "libp2p-swarm" -version = "0.47.0" +version = "0.47.1" dependencies = [ "criterion", "either", diff --git a/Cargo.toml b/Cargo.toml index 9e65d1ab62f..b5bcbd30e77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,7 +102,7 @@ libp2p-rendezvous = { version = "0.17.0", path = "protocols/rendezvous" } libp2p-request-response = { version = "0.29.0", path = "protocols/request-response" } libp2p-server = { version = "0.12.7", path = "misc/server" } libp2p-stream = { version = "0.4.0-alpha", path = "protocols/stream" } -libp2p-swarm = { version = "0.47.0", path = "swarm" } +libp2p-swarm = { version = "0.47.1", path = "swarm" } libp2p-swarm-derive = { version = "=0.35.1", path = "swarm-derive" } # `libp2p-swarm-derive` may not be compatible with different `libp2p-swarm` non-breaking releases. E.g. `libp2p-swarm` might introduce a new enum variant `FromSwarm` (which is `#[non-exhaustive]`) in a non-breaking release. Older versions of `libp2p-swarm-derive` would not forward this enum variant within the `NetworkBehaviour` hierarchy. Thus the version pinning is required. libp2p-swarm-test = { version = "0.6.0", path = "swarm-test" } libp2p-tcp = { version = "0.44.0", path = "transports/tcp" } diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 90d104156f6..ff96b46190e 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.47.1 + +- Add smart dialing support. + ## 0.47.0 - Remove `async-std` support. diff --git a/swarm/Cargo.toml b/swarm/Cargo.toml index 81d03ca9559..29b3ccc60ed 100644 --- a/swarm/Cargo.toml +++ b/swarm/Cargo.toml @@ -3,7 +3,7 @@ name = "libp2p-swarm" edition.workspace = true rust-version = { workspace = true } description = "The libp2p swarm" -version = "0.47.0" +version = "0.47.1" authors = ["Parity Technologies "] license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" diff --git a/swarm/src/connection/pool.rs b/swarm/src/connection/pool.rs index 37ae63af033..bbf30fe2a91 100644 --- a/swarm/src/connection/pool.rs +++ b/swarm/src/connection/pool.rs @@ -18,20 +18,23 @@ // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. + use std::{ collections::HashMap, convert::Infallible, fmt, num::{NonZeroU8, NonZeroUsize}, pin::Pin, + sync::Arc, task::{Context, Poll, Waker}, }; -use concurrent_dial::ConcurrentDial; +use concurrent_dial::{ConcurrentDial, Dial}; +use dial_ranker::DialRanker; use fnv::FnvHashMap; use futures::{ channel::{mpsc, oneshot}, - future::{poll_fn, BoxFuture, Either}, + future::{poll_fn, Either}, prelude::*, ready, stream::{FuturesUnordered, SelectAll}, @@ -44,16 +47,16 @@ use libp2p_core::{ use tracing::Instrument; use web_time::{Duration, Instant}; +use super::{ + Connected, Connection, ConnectionError, ConnectionId, IncomingInfo, + PendingInboundConnectionError, PendingOutboundConnectionError, PendingPoint, +}; use crate::{ - connection::{ - Connected, Connection, ConnectionError, ConnectionId, IncomingInfo, - PendingInboundConnectionError, PendingOutboundConnectionError, PendingPoint, - }, - transport::TransportError, - ConnectedPoint, ConnectionHandler, Executor, Multiaddr, PeerId, + transport::TransportError, ConnectedPoint, ConnectionHandler, Executor, Multiaddr, PeerId, }; mod concurrent_dial; +pub(crate) mod dial_ranker; mod task; enum ExecSwitch { @@ -142,6 +145,9 @@ where /// How long a connection should be kept alive once it starts idling. idle_connection_timeout: Duration, + + /// Ranker that determines the ranking of outgoing connection attempts. + dial_ranker: Option>, } #[derive(Debug)] @@ -333,6 +339,7 @@ where no_established_connections_waker: None, established_connection_events: Default::default(), new_connection_dropped_listeners: Default::default(), + dial_ranker: config.dial_ranker, } } @@ -413,15 +420,7 @@ where /// that establishes and negotiates the connection. pub(crate) fn add_outgoing( &mut self, - dials: Vec< - BoxFuture< - 'static, - ( - Multiaddr, - Result<(PeerId, StreamMuxerBox), TransportError>, - ), - >, - >, + dials: Vec<(Multiaddr, Dial)>, peer: Option, role_override: Endpoint, port_use: PortUse, @@ -438,7 +437,7 @@ where self.executor.spawn( task::new_for_pending_outgoing_connection( connection_id, - ConcurrentDial::new(dials, concurrency_factor), + ConcurrentDial::new(dials, concurrency_factor, self.dial_ranker.clone()), abort_receiver, self.pending_connection_events_tx.clone(), ) @@ -979,6 +978,8 @@ pub(crate) struct PoolConfig { pub(crate) per_connection_event_buffer_size: usize, /// Number of addresses concurrently dialed for a single outbound connection attempt. pub(crate) dial_concurrency_factor: NonZeroU8, + /// Ranker that determines the ranking of outgoing connection attempts. + pub(crate) dial_ranker: Option>, /// How long a connection should be kept alive once it is idling. pub(crate) idle_connection_timeout: Duration, /// The configured override for substream protocol upgrades, if any. @@ -1000,6 +1001,7 @@ impl PoolConfig { idle_connection_timeout: Duration::from_secs(10), substream_upgrade_protocol_override: None, max_negotiating_inbound_streams: 128, + dial_ranker: None, } } diff --git a/swarm/src/connection/pool/concurrent_dial.rs b/swarm/src/connection/pool/concurrent_dial.rs index 99f0b385884..23e897cf077 100644 --- a/swarm/src/connection/pool/concurrent_dial.rs +++ b/swarm/src/connection/pool/concurrent_dial.rs @@ -19,22 +19,28 @@ // DEALINGS IN THE SOFTWARE. use std::{ + collections::HashMap, num::NonZeroU8, pin::Pin, + sync::Arc, task::{Context, Poll}, + time::Duration, }; use futures::{ future::{BoxFuture, Future}, ready, stream::{FuturesUnordered, StreamExt}, + FutureExt, }; +use futures_timer::Delay; use libp2p_core::muxing::StreamMuxerBox; use libp2p_identity::PeerId; +use super::DialRanker; use crate::{transport::TransportError, Multiaddr}; -type Dial = BoxFuture< +pub(crate) type Dial = BoxFuture< 'static, ( Multiaddr, @@ -43,29 +49,56 @@ type Dial = BoxFuture< >; pub(crate) struct ConcurrentDial { + concurrency_factor: NonZeroU8, dials: FuturesUnordered, - pending_dials: Box + Send>, + pending_dials: Box, Dial)> + Send>, errors: Vec<(Multiaddr, TransportError)>, } impl Unpin for ConcurrentDial {} impl ConcurrentDial { - pub(crate) fn new(pending_dials: Vec, concurrency_factor: NonZeroU8) -> Self { - let mut pending_dials = pending_dials.into_iter(); - + pub(crate) fn new( + pending_dials: Vec<(Multiaddr, Dial)>, + concurrency_factor: NonZeroU8, + dial_ranker: Option>, + ) -> Self { let dials = FuturesUnordered::new(); - for dial in pending_dials.by_ref() { - dials.push(dial); - if dials.len() == concurrency_factor.get() as usize { - break; - } - } - + let pending_dials: Vec<_> = if let Some(dial_ranker) = dial_ranker { + let addresses = pending_dials.iter().map(|(k, _)| k.clone()).collect(); + let mut dials: HashMap = HashMap::from_iter(pending_dials); + dial_ranker(addresses) + .into_iter() + .filter_map(|(addr, delay)| dials.remove(&addr).map(|dial| (addr, delay, dial))) + .collect() + } else { + pending_dials + .into_iter() + .map(|(addr, dial)| (addr, None, dial)) + .collect() + }; Self { + concurrency_factor, dials, errors: Default::default(), - pending_dials: Box::new(pending_dials), + pending_dials: Box::new(pending_dials.into_iter()), + } + } + + fn dial_pending(&mut self) -> bool { + if let Some((_, delay, dial)) = self.pending_dials.next() { + self.dials.push( + async move { + if let Some(delay) = delay { + Delay::new(delay).await; + } + dial.await + } + .boxed(), + ); + true + } else { + false } } } @@ -92,12 +125,16 @@ impl Future for ConcurrentDial { } Some((addr, Err(e))) => { self.errors.push((addr, e)); - if let Some(dial) = self.pending_dials.next() { - self.dials.push(dial) - } + self.dial_pending(); } None => { - return Poll::Ready(Err(std::mem::take(&mut self.errors))); + while self.dials.len() < self.concurrency_factor.get() as usize + && self.dial_pending() + {} + + if self.dials.is_empty() { + return Poll::Ready(Err(std::mem::take(&mut self.errors))); + } } } } diff --git a/swarm/src/connection/pool/dial_ranker.rs b/swarm/src/connection/pool/dial_ranker.rs new file mode 100644 index 00000000000..ceb467f913b --- /dev/null +++ b/swarm/src/connection/pool/dial_ranker.rs @@ -0,0 +1,326 @@ +// Copyright 2021 Protocol Labs. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +#[cfg(test)] +mod tests; + +use std::{borrow::Cow, cmp::Ordering}; + +use libp2p_core::multiaddr::Protocol; + +use super::*; + +// The 250ms value is from happy eyeballs RFC 8305. This is a rough estimate of 1 RTT +// Duration by which TCP dials are delayed relative to the last QUIC dial +const PUBLIC_TCP_DELAY: Duration = Duration::from_millis(250); +const PRIVATE_TCP_DELAY: Duration = Duration::from_millis(30); + +// duration by which QUIC dials are delayed relative to previous QUIC dial +const PUBLIC_QUIC_DELAY: Duration = Duration::from_millis(250); +const PRIVATE_QUIC_DELAY: Duration = Duration::from_millis(30); + +// RelayDelay is the duration by which relay dials are delayed relative to direct addresses +const RELAY_DELAY: Duration = Duration::from_millis(250); + +// delay for other transport addresses. This will apply to /webrtc-direct. +const PUBLIC_OTHER_DELAY: Duration = Duration::from_millis(1000); +const PRIVATE_OTHER_DELAY: Duration = Duration::from_millis(100); + +pub(crate) type DialRanker = fn(Vec) -> Vec<(Multiaddr, Option)>; + +// Ported from +pub(crate) fn smart_dial_ranker( + mut addresses: Vec, +) -> Vec<(Multiaddr, Option)> { + let mut relay = vec![]; + addresses.retain(|a| { + if a.iter().any(|p| matches!(p, Protocol::P2pCircuit)) { + relay.push(a.clone()); + false + } else { + true + } + }); + let mut private = vec![]; + addresses.retain(|a| { + if let Some(Protocol::Ip4(ip4)) = a.iter().find(|p| matches!(p, Protocol::Ip4(_))) { + if ip4.is_private() { + private.push(a.clone()); + false + } else { + true + } + } else if a.iter().any(|p| matches!(p, Protocol::Ip6zone(_))) { + private.push(a.clone()); + false + } else if let Some(dns) = get_dns(a) { + if is_dns_private(dns.as_ref()) { + private.push(a.clone()); + false + } else { + true + } + } else { + true + } + }); + let mut public = vec![]; + addresses.retain(|a| { + if a.iter() + .any(|p| matches!(p, Protocol::Ip4(_) | Protocol::Ip6(_))) + { + public.push(a.clone()); + false + } else { + true + } + }); + let relay_offset = if public.is_empty() { + Duration::ZERO + } else { + RELAY_DELAY + }; + let mut result = Vec::with_capacity(addresses.len()); + result.extend(get_addresses_delay( + private, + PRIVATE_TCP_DELAY, + PRIVATE_QUIC_DELAY, + PRIVATE_OTHER_DELAY, + Duration::ZERO, + )); + result.extend(get_addresses_delay( + public, + PUBLIC_TCP_DELAY, + PUBLIC_QUIC_DELAY, + PUBLIC_OTHER_DELAY, + Duration::ZERO, + )); + result.extend(get_addresses_delay( + relay, + PUBLIC_TCP_DELAY, + PUBLIC_QUIC_DELAY, + PUBLIC_OTHER_DELAY, + relay_offset, + )); + let max_delay = if let Some((_, Some(delay))) = result.last() { + *delay + } else { + Duration::ZERO + }; + result.extend( + addresses + .into_iter() + .map(|a| (a, Some(max_delay + PUBLIC_OTHER_DELAY))), + ); + result +} + +// Ported from +fn get_addresses_delay( + mut addresses: Vec, + tcp_delay: Duration, + quic_delay: Duration, + other_delay: Duration, + offset: Duration, +) -> Vec<(Multiaddr, Option)> { + if addresses.is_empty() { + return vec![]; + } + + addresses.sort_by_key(score); + + // addrs is now sorted by (Transport, IPVersion). Reorder addrs for happy eyeballs dialing. + // For QUIC and TCP, if we have both IPv6 and IPv4 addresses, move the + // highest priority IPv4 address to the second position. + let mut happy_eyeballs_quic = false; + let mut happy_eyeballs_tcp = false; + let mut tcp_start_index = 0; + { + // If the first QUIC address is IPv6 move the first QUIC IPv4 address to second position + if is_quic_address(&addresses[0]) + && addresses[0].iter().any(|p| matches!(p, Protocol::Ip6(_))) + { + for j in 1..addresses.len() { + let addr = &addresses[j]; + if is_quic_address(addr) && addr.iter().any(|p| matches!(p, Protocol::Ip4(_))) { + // The first IPv4 address is at position j + // Move the jth element at position 1 shifting the affected elements + if j > 1 { + let tmp = addresses.remove(j); + addresses.insert(1, tmp); + } + happy_eyeballs_quic = true; + tcp_start_index = j + 1; + break; + } + } + } + + while tcp_start_index < addresses.len() { + if addresses[tcp_start_index] + .iter() + .any(|p| matches!(p, Protocol::Tcp(_))) + { + break; + } + tcp_start_index += 1; + } + + // If the first TCP address is IPv6 move the first TCP IPv4 address to second position + if tcp_start_index < addresses.len() + && addresses[tcp_start_index] + .iter() + .any(|p| matches!(p, Protocol::Ip6(_))) + { + for j in (tcp_start_index + 1)..addresses.len() { + let addr = &addresses[j]; + if addr.iter().any(|p| matches!(p, Protocol::Tcp(_))) + && addr.iter().any(|p| matches!(p, Protocol::Ip4(_))) + { + // First TCP IPv4 address is at position j, move it to position tcpStartIdx+1 + // which is the second priority TCP address + if j > tcp_start_index + 1 { + let tmp = addresses.remove(j); + addresses.insert(tcp_start_index + 1, tmp); + } + happy_eyeballs_tcp = true; + break; + } + } + } + } + + let mut result = Vec::with_capacity(addresses.len()); + let mut tcp_first_dial_delay = Duration::ZERO; + let mut last_quic_or_tcp_delay = Duration::ZERO; + for (i, addr) in addresses.into_iter().enumerate() { + let mut delay = Duration::ZERO; + if is_quic_address(&addr) { + // We dial an IPv6 address, then after quicDelay an IPv4 + // address, then after a further quicDelay we dial the rest of the addresses. + match i.cmp(&1) { + Ordering::Equal => { + delay = quic_delay; + } + Ordering::Greater => { + // If we have happy eyeballs for QUIC, dials after the second position + // will be delayed by 2*quicDelay + if happy_eyeballs_quic { + delay = 2 * quic_delay; + } else { + delay = quic_delay; + } + } + _ => {} + } + last_quic_or_tcp_delay = delay; + tcp_first_dial_delay = delay + tcp_delay; + } else if addr.iter().any(|p| matches!(p, Protocol::Tcp(_))) { + // We dial an IPv6 address, then after tcpDelay an IPv4 + // address, then after a further tcpDelay we dial the rest of the addresses. + match i.cmp(&(tcp_start_index + 1)) { + Ordering::Equal => { + delay = tcp_delay; + } + Ordering::Greater => { + // If we have happy eyeballs for TCP, dials after the second position + // will be delayed by 2*tcpDelay + if happy_eyeballs_tcp { + delay = 2 * tcp_delay; + } else { + delay = tcp_delay; + } + } + _ => {} + } + delay += tcp_first_dial_delay; + last_quic_or_tcp_delay = delay; + } else { + // if it's neither quic, webtransport, tcp, or websocket address + delay = last_quic_or_tcp_delay + other_delay; + } + match offset + delay { + Duration::ZERO => { + result.push((addr, None)); + } + d => { + result.push((addr, Some(d))); + } + } + } + result +} + +// Ported from +fn score(a: &Multiaddr) -> i32 { + let mut ip4_weight = 0; + if a.iter().any(|p| matches!(p, Protocol::Ip4(_))) { + ip4_weight = 1 << 18; + } + + if a.iter().any(|p| matches!(p, Protocol::WebTransport)) { + if let Some(Protocol::Udp(p)) = a.iter().find(|p| matches!(p, Protocol::Udp(_))) { + return ip4_weight + (1 << 19) + p as i32; + } + } + + if a.iter().any(|p| matches!(p, Protocol::Quic)) { + if let Some(Protocol::Udp(p)) = a.iter().find(|p| matches!(p, Protocol::Udp(_))) { + return ip4_weight + (1 << 17) + p as i32; + } + } + + if a.iter().any(|p| matches!(p, Protocol::QuicV1)) { + if let Some(Protocol::Udp(p)) = a.iter().find(|p| matches!(p, Protocol::Udp(_))) { + return ip4_weight + p as i32; + } + } + + if let Some(Protocol::Tcp(p)) = a.iter().find(|p| matches!(p, Protocol::Tcp(_))) { + return ip4_weight + (1 << 20) + p as i32; + } + + if a.iter().any(|p| matches!(p, Protocol::WebRTCDirect)) { + return 1 << 21; + } + + 1 << 30 +} + +fn is_quic_address(a: &Multiaddr) -> bool { + a.iter() + .any(|p| matches!(p, Protocol::Quic | Protocol::QuicV1)) +} + +fn get_dns(a: &Multiaddr) -> Option> { + if let Some(Protocol::Dns(dns)) = a.iter().find(|p| matches!(p, Protocol::Dns(_))) { + Some(dns) + } else if let Some(Protocol::Dns4(dns)) = a.iter().find(|p| matches!(p, Protocol::Dns4(_))) { + Some(dns) + } else if let Some(Protocol::Dns6(dns)) = a.iter().find(|p| matches!(p, Protocol::Dns6(_))) { + Some(dns) + } else { + None + } +} + +fn is_dns_private(dns: &str) -> bool { + dns == "localhost" || dns.ends_with(".localhost") +} diff --git a/swarm/src/connection/pool/dial_ranker/tests.rs b/swarm/src/connection/pool/dial_ranker/tests.rs new file mode 100644 index 00000000000..e6aa0715bc9 --- /dev/null +++ b/swarm/src/connection/pool/dial_ranker/tests.rs @@ -0,0 +1,202 @@ +// Copyright 2021 Protocol Labs. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Tests are ported from + +use super::*; + +#[test] +fn test_quic_delay_ipv4() { + let q1v1: Multiaddr = "/ip4/1.2.3.4/udp/1/quic-v1".parse().unwrap(); + let q2v1: Multiaddr = "/ip4/1.2.3.4/udp/2/quic-v1".parse().unwrap(); + let q3v1: Multiaddr = "/ip4/1.2.3.4/udp/3/quic-v1".parse().unwrap(); + + let addresses = vec![q1v1.clone(), q2v1.clone(), q3v1.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v1, None), + (q2v1, Some(PUBLIC_QUIC_DELAY)), + (q3v1, Some(PUBLIC_QUIC_DELAY)), + ] + ) +} + +#[test] +fn test_quic_delay_ipv6() { + let q1v16: Multiaddr = "/ip6/1::2/udp/1/quic-v1".parse().unwrap(); + let q2v16: Multiaddr = "/ip6/1::2/udp/2/quic-v1".parse().unwrap(); + let q3v16: Multiaddr = "/ip6/1::2/udp/3/quic-v1".parse().unwrap(); + + let addresses = vec![q1v16.clone(), q2v16.clone(), q3v16.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v16, None), + (q2v16, Some(PUBLIC_QUIC_DELAY)), + (q3v16, Some(PUBLIC_QUIC_DELAY)), + ] + ) +} + +#[test] +fn test_quic_delay_ipv4_ipv6() { + let q2v1: Multiaddr = "/ip4/1.2.3.4/udp/2/quic-v1".parse().unwrap(); + let q1v16: Multiaddr = "/ip6/1::2/udp/1/quic-v1".parse().unwrap(); + + let addresses = vec![q1v16.clone(), q2v1.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![(q1v16, None), (q2v1, Some(PUBLIC_QUIC_DELAY)),] + ) +} + +#[test] +fn test_quic_with_tcp_ipv6_ipv4() { + let q1v1: Multiaddr = "/ip4/1.2.3.4/udp/1/quic-v1".parse().unwrap(); + let q2v1: Multiaddr = "/ip4/1.2.3.4/udp/2/quic-v1".parse().unwrap(); + + let q1v16: Multiaddr = "/ip6/1::2/udp/1/quic-v1".parse().unwrap(); + let q2v16: Multiaddr = "/ip6/1::2/udp/2/quic-v1".parse().unwrap(); + let q3v16: Multiaddr = "/ip6/1::2/udp/3/quic-v1".parse().unwrap(); + + let t1: Multiaddr = "/ip4/1.2.3.5/tcp/1".parse().unwrap(); + let t1v6: Multiaddr = "/ip6/1::2/tcp/1".parse().unwrap(); + let t2: Multiaddr = "/ip4/1.2.3.4/tcp/2".parse().unwrap(); + let t3: Multiaddr = "/ip4/1.2.3.4/tcp/3".parse().unwrap(); + + let addresses = vec![ + q1v1.clone(), + q1v16.clone(), + q2v16.clone(), + q3v16.clone(), + q2v1.clone(), + t1.clone(), + t1v6.clone(), + t2.clone(), + t3.clone(), + ]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v16, None), + (q1v1, Some(PUBLIC_QUIC_DELAY)), + (q2v16, Some(2 * PUBLIC_QUIC_DELAY)), + (q3v16, Some(2 * PUBLIC_QUIC_DELAY)), + (q2v1, Some(2 * PUBLIC_QUIC_DELAY)), + (t1v6, Some(3 * PUBLIC_QUIC_DELAY)), + (t1, Some(4 * PUBLIC_QUIC_DELAY)), + (t2, Some(5 * PUBLIC_QUIC_DELAY)), + (t3, Some(5 * PUBLIC_QUIC_DELAY)), + ] + ) +} + +#[test] +fn test_quic_ip4_with_tcp() { + let q1v1: Multiaddr = "/ip4/1.2.3.4/udp/1/quic-v1".parse().unwrap(); + + let t1: Multiaddr = "/ip4/1.2.3.5/tcp/1".parse().unwrap(); + let t1v6: Multiaddr = "/ip6/1::2/tcp/1".parse().unwrap(); + let t2: Multiaddr = "/ip4/1.2.3.4/tcp/2".parse().unwrap(); + + let addresses = vec![q1v1.clone(), t2.clone(), t1v6.clone(), t1.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v1, None), + (t1v6, Some(PUBLIC_QUIC_DELAY)), + (t1, Some(2 * PUBLIC_QUIC_DELAY)), + (t2, Some(3 * PUBLIC_QUIC_DELAY)), + ] + ) +} + +#[test] +fn test_quic_ip4_with_tcp_ipv4() { + let q1v1: Multiaddr = "/ip4/1.2.3.4/udp/1/quic-v1".parse().unwrap(); + + let t1: Multiaddr = "/ip4/1.2.3.5/tcp/1".parse().unwrap(); + let t2: Multiaddr = "/ip4/1.2.3.4/tcp/2".parse().unwrap(); + let t3: Multiaddr = "/ip4/1.2.3.4/tcp/3".parse().unwrap(); + + let addresses = vec![q1v1.clone(), t2.clone(), t3.clone(), t1.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v1, None), + (t1, Some(PUBLIC_TCP_DELAY)), + (t2, Some(2 * PUBLIC_QUIC_DELAY)), + (t3, Some(2 * PUBLIC_TCP_DELAY)), + ] + ) +} + +#[test] +fn test_quic_ip4_with_two_tcp() { + let q1v1: Multiaddr = "/ip4/1.2.3.4/udp/1/quic-v1".parse().unwrap(); + + let t1v6: Multiaddr = "/ip6/1::2/tcp/1".parse().unwrap(); + let t2: Multiaddr = "/ip4/1.2.3.4/tcp/2".parse().unwrap(); + + let addresses = vec![q1v1.clone(), t1v6.clone(), t2.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (q1v1, None), + (t1v6, Some(PUBLIC_TCP_DELAY)), + (t2, Some(2 * PUBLIC_TCP_DELAY)), + ] + ) +} + +#[test] +fn test_tcp_ip4_ip6() { + let t1: Multiaddr = "/ip4/1.2.3.5/tcp/1".parse().unwrap(); + let t1v6: Multiaddr = "/ip6/1::2/tcp/1".parse().unwrap(); + let t2: Multiaddr = "/ip4/1.2.3.4/tcp/2".parse().unwrap(); + let t3: Multiaddr = "/ip4/1.2.3.4/tcp/3".parse().unwrap(); + + let addresses = vec![t1.clone(), t2.clone(), t1v6.clone(), t3.clone()]; + let output = smart_dial_ranker(addresses); + assert_eq!( + output, + vec![ + (t1v6, None), + (t1, Some(PUBLIC_TCP_DELAY)), + (t2, Some(2 * PUBLIC_TCP_DELAY)), + (t3, Some(2 * PUBLIC_TCP_DELAY)), + ] + ) +} + +#[test] +fn test_empty() { + let addresses = vec![]; + let output = smart_dial_ranker(addresses); + assert_eq!(output, vec![]) +} diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index d0ae6118190..96d7382ee6f 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -97,6 +97,7 @@ use std::{ error, fmt, io, num::{NonZeroU32, NonZeroU8, NonZeroUsize}, pin::Pin, + sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -109,7 +110,10 @@ pub use behaviour::{ }; pub use connection::{pool::ConnectionCounters, ConnectionError, ConnectionId, SupportedProtocols}; use connection::{ - pool::{EstablishedConnection, Pool, PoolConfig, PoolEvent}, + pool::{ + dial_ranker::{smart_dial_ranker, DialRanker}, + EstablishedConnection, Pool, PoolConfig, PoolEvent, + }, IncomingInfo, PendingInboundConnectionError, PendingOutboundConnectionError, }; use dial_opts::{DialOpts, PeerCondition}; @@ -515,7 +519,7 @@ where let dials = addresses .into_iter() - .map(|a| match peer_id.map_or(Ok(a.clone()), |p| a.with_p2p(p)) { + .map(|a| (a.clone(), match peer_id.map_or(Ok(a.clone()), |p| a.with_p2p(p)) { Ok(address) => { let dial = self.transport.dial( address.clone(), @@ -539,7 +543,7 @@ where Err(TransportError::MultiaddrNotSupported(address)), )) .boxed(), - }) + },)) .collect(); self.pool.add_outgoing( @@ -1509,6 +1513,17 @@ impl Config { self.pool_config.idle_connection_timeout = timeout; self } + + /// Sets a dial ranker that determines the ranking of outgoing connection attempts. + pub fn with_dial_ranker(mut self, dial_ranker: DialRanker) -> Self { + self.pool_config.dial_ranker = Some(Arc::new(dial_ranker)); + self + } + + /// Enables smart dialing. + pub fn with_smart_dial_ranker(self) -> Self { + self.with_dial_ranker(smart_dial_ranker) + } } /// Possible errors when trying to establish or upgrade an outbound connection. @@ -2115,6 +2130,79 @@ mod tests { QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _); } + #[test] + fn concurrent_smart_dialing() { + #[derive(Clone, Debug)] + struct DialConcurrencyFactor(NonZeroU8); + + impl Arbitrary for DialConcurrencyFactor { + fn arbitrary(g: &mut Gen) -> Self { + Self(NonZeroU8::new(g.gen_range(1..11)).unwrap()) + } + } + + fn prop(concurrency_factor: DialConcurrencyFactor) { + tokio::runtime::Runtime::new().unwrap().block_on(async { + let mut swarm = new_test_swarm( + Config::with_tokio_executor() + .with_smart_dial_ranker() + .with_dial_concurrency_factor(concurrency_factor.0), + ); + + // Listen on `concurrency_factor + 1` addresses. + // + // `+ 2` to ensure a subset of addresses is dialed by network_2. + let num_listen_addrs = concurrency_factor.0.get() + 2; + let mut listen_addresses = Vec::new(); + let mut transports = Vec::new(); + for _ in 0..num_listen_addrs { + let mut transport = transport::MemoryTransport::default().boxed(); + transport + .listen_on(ListenerId::next(), "/memory/0".parse().unwrap()) + .unwrap(); + + match transport.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { + listen_addresses.push(listen_addr); + } + _ => panic!("Expected `NewListenAddr` event."), + } + + transports.push(transport); + } + + // Have swarm dial each listener and wait for each listener to receive the incoming + // connections. + swarm + .dial( + DialOpts::peer_id(PeerId::random()) + .addresses(listen_addresses) + .build(), + ) + .unwrap(); + for mut transport in transports.into_iter() { + match futures::future::select(transport.select_next_some(), swarm.next()).await + { + future::Either::Left((TransportEvent::Incoming { .. }, _)) => {} + future::Either::Left(_) => { + panic!("Unexpected transport event.") + } + future::Either::Right((e, _)) => { + panic!("Expect swarm to not emit any event {e:?}") + } + } + } + + match swarm.next().await.unwrap() { + SwarmEvent::OutgoingConnectionError { .. } => {} + e => panic!("Unexpected swarm event {e:?}"), + } + }) + } + + QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _); + } + #[tokio::test] async fn invalid_peer_id() { // Checks whether dialing an address containing the wrong peer id raises an error