From 066008d3a960644dd44e52f6d33822c85bf02114 Mon Sep 17 00:00:00 2001
From: Cameron Garnham
Date: Sun, 11 Aug 2024 13:58:19 +0200
Subject: [PATCH 1/4] rework: upgrade all dependencies

---
 examples/get_metadata/Cargo.toml | 14 +-
 examples/get_metadata/src/main.rs | 483 ++---
 examples/simple_torrent/Cargo.toml | 10 +-
 examples/simple_torrent/src/main.rs | 1079 +++++------
 packages/bencode/Cargo.toml | 4 +-
 packages/bencode/src/access/convert.rs | 78 +-
 packages/bencode/src/error.rs | 141 +-
 packages/bencode/src/lib.rs | 6 +-
 packages/bencode/src/reference/bencode_ref.rs | 7 +-
 packages/bencode/src/reference/decode.rs | 81 +-
 packages/bencode/src/reference/decode_opt.rs | 2 -
 packages/dht/Cargo.toml | 12 +-
 packages/dht/examples/debug.rs | 61 +-
 packages/dht/src/builder.rs | 75 +-
 packages/dht/src/error.rs | 53 +-
 packages/dht/src/handshaker_trait.rs | 3 +-
 packages/dht/src/message/announce_peer.rs | 12 +-
 packages/dht/src/message/compact_info.rs | 3 +-
 packages/dht/src/message/error.rs | 26 +-
 packages/dht/src/message/find_node.rs | 8 +-
 packages/dht/src/message/get_peers.rs | 17 +-
 packages/dht/src/message/mod.rs | 14 +-
 packages/dht/src/message/ping.rs | 6 +-
 packages/dht/src/message/request.rs | 18 +-
 packages/dht/src/message/response.rs | 58 +-
 packages/dht/src/router.rs | 39 +-
 packages/dht/src/routing/node.rs | 56 +-
 packages/dht/src/routing/table.rs | 13 +-
 packages/dht/src/worker/bootstrap.rs | 338 ++--
 packages/dht/src/worker/handler.rs | 1596 +++++++++--------
 packages/dht/src/worker/lookup.rs | 424 +++--
 packages/dht/src/worker/messenger.rs | 66 +-
 packages/dht/src/worker/mod.rs | 23 +-
 packages/dht/src/worker/refresh.rs | 101 +-
 packages/disk/Cargo.toml | 21 +-
 packages/disk/benches/disk_benchmark.rs | 129 +-
 packages/disk/examples/add_torrent.rs | 51 +-
 .../disk/src/disk/fs/cache/file_handle.rs | 20 +-
 packages/disk/src/disk/fs/mod.rs | 25 +-
 packages/disk/src/disk/fs/native.rs | 32 +-
 packages/disk/src/disk/manager/builder.rs | 28 +-
 packages/disk/src/disk/manager/mod.rs | 84 +-
 packages/disk/src/disk/manager/sink.rs | 140 +-
 packages/disk/src/disk/manager/stream.rs | 71 +-
 packages/disk/src/disk/tasks/context.rs | 120 +-
 .../src/disk/tasks/helpers/piece_accessor.rs | 40 +-
 .../src/disk/tasks/helpers/piece_checker.rs | 97 +-
 packages/disk/src/disk/tasks/mod.rs | 279 +--
 packages/disk/src/error.rs | 85 +-
 packages/disk/src/lib.rs | 4 +-
 packages/disk/src/memory/block.rs | 7 +-
 packages/disk/tests/add_torrent.rs | 41 +-
 packages/disk/tests/common/mod.rs | 151 +-
 packages/disk/tests/complete_torrent.rs | 278 ++-
 .../tests/disk_manager_send_backpressure.rs | 62 +-
 packages/disk/tests/load_block.rs | 63 +-
 packages/disk/tests/process_block.rs | 51 +-
 packages/disk/tests/remove_torrent.rs | 63 +-
 packages/disk/tests/resume_torrent.rs | 142 +-
 packages/handshake/Cargo.toml | 17 +-
 .../handshake/examples/handshake_torrent.rs | 87 +-
 packages/handshake/src/bittorrent/framed.rs | 487 +++--
 packages/handshake/src/bittorrent/message.rs | 81 +-
 packages/handshake/src/filter/mod.rs | 1 -
 packages/handshake/src/handshake/builder.rs | 18 +-
 packages/handshake/src/handshake/config.rs | 1 -
 .../src/handshake/handler/handshaker.rs | 297 ++-
 .../src/handshake/handler/initiator.rs | 99 +-
 .../src/handshake/handler/listener.rs | 69 +-
 .../handshake/src/handshake/handler/mod.rs | 68 +-
 .../handshake/src/handshake/handler/timer.rs | 51 -
 packages/handshake/src/handshake/mod.rs | 150 +-
 packages/handshake/src/handshake/sink.rs | 36 +-
 packages/handshake/src/handshake/stream.rs | 18 +-
 packages/handshake/src/local_addr.rs | 11 +-
 packages/handshake/src/message/complete.rs | 1 +
 packages/handshake/src/message/extensions.rs | 39 +-
 packages/handshake/src/message/protocol.rs | 91 +-
 packages/handshake/src/transport.rs | 176 +-
 packages/handshake/tests/common/mod.rs | 19 +
 .../tests/test_byte_after_handshake.rs | 57 +-
 .../tests/test_bytes_after_handshake.rs | 65 +-
 packages/handshake/tests/test_connect.rs | 94 +-
 .../handshake/tests/test_filter_allow_all.rs | 106 +-
 .../handshake/tests/test_filter_block_all.rs | 106 +-
 .../tests/test_filter_whitelist_diff_data.rs | 106 +-
 .../tests/test_filter_whitelist_same_data.rs | 106 +-
 packages/magnet/Cargo.toml | 3 +-
 packages/magnet/src/lib.rs | 4 +-
 packages/metainfo/Cargo.toml | 10 +-
 packages/metainfo/examples/create_torrent.rs | 13 +-
 packages/metainfo/src/accessor.rs | 42 +-
 packages/metainfo/src/builder/buffer.rs | 11 +-
 packages/metainfo/src/builder/mod.rs | 8 +-
 packages/metainfo/src/builder/worker.rs | 32 +-
 packages/metainfo/src/error.rs | 37 +-
 packages/metainfo/src/metainfo.rs | 57 +-
 packages/metainfo/src/parse.rs | 22 +-
 packages/peer/Cargo.toml | 20 +-
 packages/peer/src/codec.rs | 87 +-
 packages/peer/src/lib.rs | 12 +-
 packages/peer/src/macros.rs | 31 -
 packages/peer/src/manager/builder.rs | 69 +-
 packages/peer/src/manager/error.rs | 40 +-
 packages/peer/src/manager/fused.rs | 180 +-
 packages/peer/src/manager/messages.rs | 78 +-
 packages/peer/src/manager/mod.rs | 142 +-
 packages/peer/src/manager/sink.rs | 421 ++---
 packages/peer/src/manager/stream.rs | 218 +--
 packages/peer/src/manager/task.rs | 399 ++---
 packages/peer/src/message/bencode_util.rs | 12 +-
 .../peer/src/message/bits_ext/handshake.rs | 281 ++-
 packages/peer/src/message/bits_ext/mod.rs | 99 +-
 packages/peer/src/message/bits_ext/port.rs | 62 +-
 packages/peer/src/message/mod.rs | 348 +++-
 packages/peer/src/message/null.rs | 1 +
 packages/peer/src/message/prot_ext/mod.rs | 180 +-
 .../peer/src/message/prot_ext/ut_metadata.rs | 45 +-
 packages/peer/src/message/standard.rs | 514 +++++-
 packages/peer/src/protocol/extension.rs | 78 +-
 packages/peer/src/protocol/mod.rs | 27 +-
 packages/peer/src/protocol/null.rs | 33 +-
 packages/peer/src/protocol/unit.rs | 30 +-
 packages/peer/src/protocol/wire.rs | 70 +-
 .../peer/tests/common/connected_channel.rs | 79 +
 packages/peer/tests/common/mod.rs | 156 +-
 .../tests/peer_manager_send_backpressure.rs | 126 +-
 packages/select/Cargo.toml | 15 +-
 packages/select/src/discovery/error.rs | 37 +-
 packages/select/src/discovery/ut_metadata.rs | 363 ++--
 packages/select/src/error.rs | 16 +-
 packages/select/src/extended/mod.rs | 64 +-
 packages/select/src/revelation/error.rs | 46 +-
 packages/select/src/revelation/honest.rs | 230 ++-
 packages/select/src/revelation/mod.rs | 4 +-
 packages/select/src/uber/mod.rs | 306 +---
 packages/select/src/uber/sink.rs | 86 +
 packages/select/src/uber/stream.rs | 44 +
 packages/select/tests/common/mod.rs | 27 +
 packages/select/tests/select_tests.rs | 208 ++-
 packages/util/Cargo.toml | 8 +-
 packages/util/src/contiguous.rs | 4 +-
 packages/util/src/lib.rs | 3 -
 packages/util/src/send/mod.rs | 33 -
 packages/util/src/send/split_sender.rs | 123 --
 packages/util/src/sha/mod.rs | 11 +
 packages/util/src/trans/locally_shuffled.rs | 2 +-
 packages/util/src/trans/sequential.rs | 2 +-
 packages/utracker/Cargo.toml | 16 +-
 packages/utracker/src/announce.rs | 313 ++--
 packages/utracker/src/client/dispatcher.rs | 193 +-
 packages/utracker/src/client/error.rs | 23 +-
 packages/utracker/src/client/mod.rs | 69 +-
packages/utracker/src/contact.rs | 68 +- packages/utracker/src/error.rs | 28 +- packages/utracker/src/lib.rs | 2 +- packages/utracker/src/option.rs | 155 +- packages/utracker/src/request.rs | 66 +- packages/utracker/src/response.rs | 70 +- packages/utracker/src/scrape.rs | 81 +- packages/utracker/src/server/dispatcher.rs | 163 +- packages/utracker/src/server/handler.rs | 23 +- packages/utracker/src/server/mod.rs | 21 +- packages/utracker/tests/common/mod.rs | 166 +- .../utracker/tests/test_announce_start.rs | 56 +- packages/utracker/tests/test_announce_stop.rs | 68 +- packages/utracker/tests/test_client_drop.rs | 42 +- packages/utracker/tests/test_client_full.rs | 31 +- packages/utracker/tests/test_connect.rs | 50 +- packages/utracker/tests/test_connect_cache.rs | 37 +- packages/utracker/tests/test_scrape.rs | 39 +- packages/utracker/tests/test_server_drop.rs | 7 +- 172 files changed, 9474 insertions(+), 7490 deletions(-) delete mode 100644 packages/handshake/src/handshake/handler/timer.rs delete mode 100644 packages/peer/src/macros.rs create mode 100644 packages/peer/tests/common/connected_channel.rs create mode 100644 packages/select/src/uber/sink.rs create mode 100644 packages/select/src/uber/stream.rs create mode 100644 packages/select/tests/common/mod.rs delete mode 100644 packages/util/src/send/mod.rs delete mode 100644 packages/util/src/send/split_sender.rs diff --git a/examples/get_metadata/Cargo.toml b/examples/get_metadata/Cargo.toml index d8ea5f419..3603c99f8 100644 --- a/examples/get_metadata/Cargo.toml +++ b/examples/get_metadata/Cargo.toml @@ -16,12 +16,14 @@ version.workspace = true [dependencies] dht = { path = "../../packages/dht" } handshake = { path = "../../packages/handshake" } +metainfo = { path = "../../packages/metainfo" } peer = { path = "../../packages/peer" } select = { path = "../../packages/select" } -clap = "3" -futures = "0.1" -hex = "0.4" -pendulum = "0.3" -tokio-codec = "0.1" -tokio-core = "0.1" +clap = "4" +futures = "0" +hex = "0" +tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0", features = ["codec"] } +tracing = "0" +tracing-subscriber = "0" diff --git a/examples/get_metadata/src/main.rs b/examples/get_metadata/src/main.rs index 8a4efeb49..b7fac99d5 100644 --- a/examples/get_metadata/src/main.rs +++ b/examples/get_metadata/src/main.rs @@ -1,55 +1,57 @@ -use std::fmt::Debug; -use std::fs::File; -use std::io::Write; +use std::io::Write as _; use std::net::SocketAddr; +use std::sync::{Arc, Once}; use std::time::Duration; -use clap::clap_app; +use clap::{Arg, ArgMatches, Command}; use dht::handshaker_trait::HandshakerTrait; use dht::{DhtBuilder, DhtEvent, Router}; -use futures::future::{self, Either, Loop}; -use futures::sink::Wait; -use futures::{Future, Sink, Stream}; +use futures::future::{BoxFuture, Either}; +use futures::{FutureExt, Sink, SinkExt as _, StreamExt}; use handshake::transports::TcpTransport; use handshake::{ DiscoveryInfo, Extension, Extensions, HandshakerBuilder, HandshakerConfig, InfoHash, InitiateMessage, PeerId, Protocol, }; use hex::FromHex; +use metainfo::Metainfo; use peer::messages::builders::ExtendedMessageBuilder; use peer::messages::{BitsExtensionMessage, PeerExtensionProtocolMessage, PeerWireProtocolMessage}; use peer::protocols::{NullProtocol, PeerExtensionProtocol, PeerWireProtocol}; -use peer::{IPeerManagerMessage, OPeerManagerMessage, PeerInfo, PeerManagerBuilder, PeerProtocolCodec}; -use pendulum::future::TimerBuilder; -use pendulum::HashedWheelBuilder; -use 
select::discovery::error::DiscoveryError; +use peer::{ + PeerInfo, PeerManagerBuilder, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage, PeerProtocolCodec, +}; use select::discovery::{IDiscoveryMessage, ODiscoveryMessage, UtMetadataModule}; -use select::{ControlMessage, DiscoveryTrait, IExtendedMessage, IUberMessage, OExtendedMessage, OUberMessage, UberModuleBuilder}; -use tokio_core::reactor::Core; +use select::{ControlMessage, IExtendedMessage, IUberMessage, OUberMessage, UberModuleBuilder}; +use tokio::signal; +use tokio_util::codec::Framed; +use tracing::level_filters::LevelFilter; + +pub static INIT: Once = Once::new(); // Legacy Handshaker, when bip_dht is migrated, it will accept S directly struct LegacyHandshaker { port: u16, id: PeerId, - sender: Wait, + sender: S, } impl LegacyHandshaker where - S: DiscoveryInfo + Sink, + S: DiscoveryInfo + Unpin, { pub fn new(sink: S) -> LegacyHandshaker { LegacyHandshaker { port: sink.port(), id: sink.peer_id(), - sender: sink.wait(), + sender: sink, } } } impl HandshakerTrait for LegacyHandshaker where - S: Sink + Send, - S::SinkError: Debug, + S: Sink + Send + Unpin, + S::Error: std::fmt::Debug, { type MetadataEnvelope = (); @@ -61,246 +63,285 @@ where self.port } - fn connect(&mut self, _expected: Option, hash: InfoHash, addr: SocketAddr) { - self.sender - .send(InitiateMessage::new(Protocol::BitTorrent, hash, addr)) - .unwrap(); + fn connect(&mut self, _expected: Option, hash: InfoHash, addr: SocketAddr) -> BoxFuture<'_, ()> { + async move { + self.sender + .send(InitiateMessage::new(Protocol::BitTorrent, hash, addr)) + .await + .unwrap(); + } + .boxed() } fn metadata(&mut self, _data: ()) {} } +fn parse_arguments() -> ArgMatches { + Command::new("get_metadata") + .version("1.0") + .author("Andrew ") + .about("Download torrent file from info hash") + .arg( + Arg::new("infohash") + .short('i') + .long("infohash") + .required(true) + .value_name("INFOHASH") + .help("InfoHash of the torrent"), + ) + .arg( + Arg::new("output") + .short('f') + .long("output") + .required(true) + .value_name("OUTPUT") + .help("Output to write the torrent file to"), + ) + .get_matches() +} + +fn extract_arguments(matches: &ArgMatches) -> (String, String) { + let hash = matches.get_one::("infohash").unwrap().to_string(); + let output = matches.get_one::("output").unwrap().to_string(); + (hash, output) +} + +pub fn tracing_stdout_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stdout); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +async fn ctrl_c() { + signal::ctrl_c().await.expect("failed to listen for event"); + tracing::warn!("Ctrl-C received, shutting down..."); +} + +enum SendUber { + Finished(Result<(), select::error::Error>), + Interrupted, +} + +enum MainDht { + Finished(Box), + Interrupted, +} + #[allow(clippy::too_many_lines)] -fn main() { - let matches = clap_app!(myapp => - (version: "1.0") - (author: "Andrew ") - (about: "Download torrent file from info hash") - (@arg hash: -h +required +takes_value "InfoHash of the torrent") - //(@arg peer: -p +required +takes_value "Single peer to connect to of the form addr:port") - (@arg output: -f +required +takes_value "Output to write the torrent file to") - ) - .get_matches(); - let hash = matches.value_of("hash").unwrap(); - //let addr = matches.value_of("peer").unwrap().parse().unwrap(); - let output = matches.value_of("output").unwrap(); - - let 
hash: Vec = FromHex::from_hex(hash).unwrap(); - let info_hash = InfoHash::from_hash(&hash[..]).unwrap(); - - // Create our main "core" event loop - let mut core = Core::new().unwrap(); +#[tokio::main] +async fn main() { + INIT.call_once(|| { + tracing_stdout_init(LevelFilter::TRACE); + }); + + // Parse command-line arguments + let matches = parse_arguments(); + let (hash, output) = extract_arguments(&matches); + + let hash: Vec = FromHex::from_hex(hash).expect("Invalid hex in hash argument"); + let info_hash = InfoHash::from_hash(&hash[..]).expect("Failed to create InfoHash"); // Activate the extension protocol via the handshake bits let mut extensions = Extensions::new(); extensions.add(Extension::ExtensionProtocol); // Create a handshaker that can initiate connections with peers - let (handshaker_send, handshaker_recv) = HandshakerBuilder::new() + let (handshaker, mut tasks) = HandshakerBuilder::new() .with_extensions(extensions) .with_config( // Set a low handshake timeout so we don't wait on peers that aren't listening on tcp HandshakerConfig::default().with_connect_timeout(Duration::from_millis(500)), ) - .build(TcpTransport, &core.handle()) - .unwrap() - .into_parts(); + .build(TcpTransport) + .await + .expect("it should build a handshaker pair"); + let (handshaker_send, mut handshaker_recv) = handshaker.into_parts(); + // Create a peer manager that will hold our peers and heartbeat/send messages to them - let (peer_manager_send, peer_manager_recv) = PeerManagerBuilder::new().build(core.handle()).into_parts(); + let (mut peer_manager_send, peer_manager_recv) = PeerManagerBuilder::new().build().into_parts(); // Hook up a future that feeds incoming (handshaken) peers over to the peer manager - core.handle().spawn( - handshaker_recv - .map_err(|()| ()) - .map(|complete_msg| { - // Our handshaker finished handshaking some peer, get - // the peer info as well as the peer itself (socket) - let (_, extensions, hash, pid, addr, sock) = complete_msg.into_parts(); - - // Only connect to peer that support the extension protocol... 
- if extensions.contains(Extension::ExtensionProtocol) { - // Frame our socket with the peer wire protocol with no - // extensions (nested null protocol), and a max payload of 24KB - let peer = tokio_codec::Decoder::framed( - PeerProtocolCodec::with_max_payload( - PeerWireProtocol::new(PeerExtensionProtocol::new(NullProtocol::new())), - 24 * 1024, - ), - sock, - ); - - // Create our peer identifier used by our peer manager - let peer_info = PeerInfo::new(addr, pid, hash, extensions); - - // Map to a message that can be fed to our peer manager - IPeerManagerMessage::AddPeer(peer_info, peer) - } else { - panic!("Chosen Peer Does Not Support Extended Messages") - } - }) - .forward(peer_manager_send.clone().sink_map_err(|_| ())) - .map(|_| ()), - ); + tasks.spawn(async move { + while let Some(complete_msg) = handshaker_recv.next().await { + let (_, extensions, hash, pid, addr, sock) = complete_msg.unwrap().into_parts(); + + if extensions.contains(Extension::ExtensionProtocol) { + let peer = Framed::new( + sock, + PeerProtocolCodec::with_max_payload( + PeerWireProtocol::new(PeerExtensionProtocol::new(NullProtocol::new())), + 24 * 1024, + ), + ); + + let peer_info = PeerInfo::new(addr, pid, hash, extensions); + + peer_manager_send + .send(Ok(PeerManagerInputMessage::AddPeer(peer_info, peer))) + .await + .unwrap(); + } else { + panic!("Chosen Peer Does Not Support Extended Messages"); + } + } + }); // Create our UtMetadata selection module - let (uber_send, uber_recv) = { - let mut this = UberModuleBuilder::new().with_extended_builder(Some(ExtendedMessageBuilder::new())); + let (mut uber_send, mut uber_recv) = { + let builder = UberModuleBuilder::new().with_extended_builder(Some(ExtendedMessageBuilder::new())); let module = UtMetadataModule::new(); - this.discovery.push(Box::new(module) - as Box< - dyn DiscoveryTrait< - SinkItem = IDiscoveryMessage, - SinkError = Box, - Item = ODiscoveryMessage, - Error = Box, - >, - >); - this + builder.discovery.lock().unwrap().push(Arc::new(module)); + + builder } .build() - .split(); + .into_parts(); // Tell the uber module we want to download metainfo for the given hash - let uber_send = core - .run( - uber_send - .send(IUberMessage::Discovery(IDiscoveryMessage::DownloadMetainfo(info_hash))) - .map_err(|_| ()), - ) - .unwrap(); + let send_to_uber = uber_send + .send(IUberMessage::Discovery(Box::new(IDiscoveryMessage::DownloadMetainfo( + info_hash, + )))) + .boxed(); + + // Await either the sending to uber or the Ctrl-C signal + let send_to_uber = tokio::select! 
{ + res = send_to_uber => SendUber::Finished(res), + () = ctrl_c() => SendUber::Interrupted, + }; - let timer = TimerBuilder::default().build(HashedWheelBuilder::default().build()); - let timer_recv = timer.sleep_stream(Duration::from_millis(100)).unwrap().map(Either::B); + let () = match send_to_uber { + SendUber::Finished(Ok(())) => (), - let merged_recv = peer_manager_recv.map(Either::A).map_err(|()| ()).select(timer_recv); + SendUber::Finished(Err(e)) => { + tracing::warn!("send to uber failed with error: {e}"); + tasks.shutdown().await; + return; + } + + SendUber::Interrupted => { + tracing::warn!("setup was canceled..."); + tasks.shutdown().await; + return; + } + }; + + let timer = futures::stream::unfold(tokio::time::interval(Duration::from_millis(100)), |mut interval| async move { + interval.tick().await; + Some(((), interval)) + }); + + let mut merged_recv = futures::stream::select(peer_manager_recv.map(Either::Left), timer.map(Either::Right)).boxed(); // Hook up a future that receives messages from the peer manager - core.handle().spawn(future::loop_fn( - (merged_recv, info_hash, uber_send.sink_map_err(|_| ())), - |(merged_recv, info_hash, select_send)| { - merged_recv - .into_future() - .map_err(|_| ()) - .and_then(move |(opt_item, merged_recv)| { - let opt_message = match opt_item.unwrap() { - Either::A(OPeerManagerMessage::ReceivedMessage( - info, - PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(extended)), - )) => Some(IUberMessage::Extended(IExtendedMessage::ReceivedExtendedMessage( - info, extended, - ))), - Either::A(OPeerManagerMessage::ReceivedMessage( - info, - PeerWireProtocolMessage::ProtExtension(PeerExtensionProtocolMessage::UtMetadata(message)), - )) => Some(IUberMessage::Discovery(IDiscoveryMessage::ReceivedUtMetadataMessage( - info, message, - ))), - Either::A(OPeerManagerMessage::PeerAdded(info)) => { - println!("Connected To Peer: {info:?}"); - Some(IUberMessage::Control(ControlMessage::PeerConnected(info))) - } - Either::A(OPeerManagerMessage::PeerRemoved(info)) => { - println!("We Removed Peer {info:?} From The Peer Manager"); - Some(IUberMessage::Control(ControlMessage::PeerDisconnected(info))) - } - Either::A(OPeerManagerMessage::PeerDisconnect(info)) => { - println!("Peer {info:?} Disconnected From Us"); - Some(IUberMessage::Control(ControlMessage::PeerDisconnected(info))) - } - Either::A(OPeerManagerMessage::PeerError(info, error)) => { - println!("Peer {info:?} Disconnected With Error: {error:?}"); - Some(IUberMessage::Control(ControlMessage::PeerDisconnected(info))) - } - Either::B(()) => Some(IUberMessage::Control(ControlMessage::Tick(Duration::from_millis(100)))), - Either::A(_) => None, - }; - - match opt_message { - Some(message) => Either::A( - select_send - .send(message) - .map(move |select_send| Loop::Continue((merged_recv, info_hash, select_send))), - ), - None => Either::B(future::ok(Loop::Continue((merged_recv, info_hash, select_send)))), + tasks.spawn(async move { + let mut uber_send = uber_send.clone(); + + while let Some(item) = merged_recv.next().await { + let message = if let Either::Left(message) = item { + match message { + Ok(PeerManagerOutputMessage::PeerAdded(info)) => { + tracing::info!("Connected To Peer: {info:?}"); + IUberMessage::Control(Box::new(ControlMessage::PeerConnected(info))) } - }) - }, - )); + Ok(PeerManagerOutputMessage::PeerRemoved(info)) => { + tracing::info!("We Removed Peer {info:?} From The Peer Manager"); + IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info))) + } + 
Ok(PeerManagerOutputMessage::SentMessage(_, _)) => todo!(), + Ok(PeerManagerOutputMessage::ReceivedMessage(info, message)) => match message { + PeerWireProtocolMessage::BitsExtension(message) => match message { + BitsExtensionMessage::Extended(extended) => { + IUberMessage::Extended(Box::new(IExtendedMessage::ReceivedExtendedMessage(info, extended))) + } + BitsExtensionMessage::Port(_) => unimplemented!(), + }, + PeerWireProtocolMessage::ProtExtension(message) => match message { + Ok(PeerExtensionProtocolMessage::UtMetadata(message)) => { + IUberMessage::Discovery(Box::new(IDiscoveryMessage::ReceivedUtMetadataMessage(info, message))) + } + _ => unimplemented!(), + }, + _ => unimplemented!(), + }, + Ok(PeerManagerOutputMessage::PeerDisconnect(info)) => { + tracing::info!("Peer {info:?} Disconnected From Us"); + IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info))) + } + Err(e) => { + let info = match e { + PeerManagerOutputError::PeerError(info, _) + | PeerManagerOutputError::PeerErrorAndMissing(info, _) + | PeerManagerOutputError::PeerRemovedAndMissing(info) + | PeerManagerOutputError::PeerDisconnectedAndMissing(info) => info, + }; + + tracing::info!("Peer {info:?} Disconnected With Error: {e:?}"); + IUberMessage::Control(Box::new(ControlMessage::PeerDisconnected(info))) + } + } + } else { + IUberMessage::Control(Box::new(ControlMessage::Tick(Duration::from_millis(100)))) + }; + + uber_send.send(message).await.unwrap(); + } + }); // Setup the dht which will be the only peer discovery service we use in this example let legacy_handshaker = LegacyHandshaker::new(handshaker_send); - let dht = DhtBuilder::with_router(Router::uTorrent) - .set_read_only(false) - .start_mainline(legacy_handshaker) - .unwrap(); - - println!("Bootstrapping Dht..."); - for message in dht.events() { - if let DhtEvent::BootstrapCompleted = message { - break; + + let main_dht = async move { + let dht = DhtBuilder::with_router(Router::uTorrent) + .set_read_only(false) + .start_mainline(legacy_handshaker) + .await + .expect("it should start the dht mainline"); + + tracing::info!("Bootstrapping Dht..."); + while let Some(message) = dht.events().await.next().await { + if let DhtEvent::BootstrapCompleted = message { + break; + } + } + tracing::info!("Bootstrap Complete..."); + + dht.search(info_hash, true).await; + + loop { + if let Some(Ok(OUberMessage::Discovery(ODiscoveryMessage::DownloadedMetainfo(metainfo)))) = uber_recv.next().await { + break metainfo; + } } } - println!("Bootstrap Complete..."); - - dht.search(info_hash, true); - - /* - // Send the peer given from the command line over to the handshaker to initiate a connection - core.run( - handshaker_send - .send(InitiateMessage::new(Protocol::BitTorrent, info_hash, addr)) - .map_err(|_| ()), - ).unwrap(); - */ - - let metainfo = core - .run(future::loop_fn( - (uber_recv, peer_manager_send.sink_map_err(|_| ()), None), - |(select_recv, map_peer_manager_send, mut opt_metainfo)| { - select_recv - .into_future() - .map_err(|_| ()) - .and_then(move |(opt_message, select_recv)| { - let opt_message = opt_message.and_then(|message| match message { - OUberMessage::Extended(OExtendedMessage::SendExtendedMessage(info, ext_message)) => { - Some(IPeerManagerMessage::SendMessage( - info, - 0, - PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(ext_message)), - )) - } - OUberMessage::Discovery(ODiscoveryMessage::SendUtMetadataMessage(info, message)) => { - Some(IPeerManagerMessage::SendMessage( - info, - 0, - 
PeerWireProtocolMessage::ProtExtension(PeerExtensionProtocolMessage::UtMetadata(message)), - )) - } - OUberMessage::Discovery(ODiscoveryMessage::DownloadedMetainfo(metainfo)) => { - opt_metainfo = Some(metainfo); - None - } - _ => { - panic!("Unexpected Message For Uber Module...") - } - }); - - match (opt_message, opt_metainfo.take()) { - (Some(message), _) => Either::A( - map_peer_manager_send - .send(message) - .map(move |peer_manager_send| Loop::Continue((select_recv, peer_manager_send, opt_metainfo))), - ), - (None, None) => { - Either::B(future::ok(Loop::Continue((select_recv, map_peer_manager_send, opt_metainfo)))) - } - (None, Some(metainfo)) => Either::B(future::ok(Loop::Break(metainfo))), - } - }) - }, - )) - .unwrap(); + .boxed(); + + // Await either the sending to uber or the Ctrl-C signal + let main_dht = tokio::select! { + res = main_dht => MainDht::Finished(Box::new(res)), + () = ctrl_c() => MainDht::Interrupted, + }; + + let metainfo = match main_dht { + MainDht::Finished(metainfo) => metainfo, + + MainDht::Interrupted => { + tracing::warn!("setup was canceled..."); + tasks.shutdown().await; + return; + } + }; // Write the metainfo file out to the user provided path - File::create(output).unwrap().write_all(&metainfo.to_bytes()).unwrap(); + std::fs::File::create(output) + .expect("Failed to create output file") + .write_all(&metainfo.to_bytes()) + .expect("Failed to write metainfo to file"); + + tasks.shutdown().await; } diff --git a/examples/simple_torrent/Cargo.toml b/examples/simple_torrent/Cargo.toml index 6e4d18046..58b8d08ca 100644 --- a/examples/simple_torrent/Cargo.toml +++ b/examples/simple_torrent/Cargo.toml @@ -20,7 +20,9 @@ handshake = { path = "../../packages/handshake" } metainfo = { path = "../../packages/metainfo" } peer = { path = "../../packages/peer" } -clap = "3" -futures = "0.1" -tokio-codec = "0.1" -tokio-core = "0.1" +clap = "4" +futures = "0" +tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0", features = ["codec"] } +tracing = "0" +tracing-subscriber = "0" diff --git a/examples/simple_torrent/src/main.rs b/examples/simple_torrent/src/main.rs index 8a4e23da4..6b197e0b1 100644 --- a/examples/simple_torrent/src/main.rs +++ b/examples/simple_torrent/src/main.rs @@ -1,75 +1,44 @@ -use std::cell::RefCell; -use std::cmp; use std::collections::HashMap; -use std::fs::File; -use std::io::Read; -use std::rc::Rc; +use std::io::Read as _; +use std::sync::{Arc, Once}; -use clap::clap_app; use disk::fs::NativeFileSystem; use disk::fs_cache::FileHandleCache; -use disk::{Block, BlockMetadata, BlockMut, DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::future::{Either, Loop}; -use futures::sync::mpsc; -use futures::{future, stream, Future, Sink, Stream}; +use disk::{ + Block, BlockMetadata, BlockMut, DiskManager, DiskManagerBuilder, DiskManagerSink, DiskManagerStream, IDiskMessage, InfoHash, + ODiskMessage, +}; +use futures::channel::mpsc; +use futures::future::Either; +use futures::lock::Mutex; +use futures::{stream, SinkExt as _, StreamExt as _}; use handshake::transports::TcpTransport; -//use bip_dht::{DhtBuilder, Handshaker, Router}; -use handshake::{Extensions, HandshakerBuilder, HandshakerConfig, InitiateMessage, PeerId, Protocol}; +use handshake::{ + Extensions, Handshaker, HandshakerBuilder, HandshakerConfig, HandshakerStream, InitiateMessage, PeerId, Protocol, +}; use metainfo::{Info, Metainfo}; use peer::messages::{BitFieldMessage, HaveMessage, PeerWireProtocolMessage, PieceMessage, RequestMessage}; use 
peer::protocols::{NullProtocol, PeerWireProtocol}; -use peer::{IPeerManagerMessage, OPeerManagerMessage, PeerInfo, PeerManagerBuilder, PeerProtocolCodec}; -use tokio_core::reactor::Core; - -/* - Things this example doesn't do, because of the lack of bip_select: - * Logic for piece selection is not abstracted (and is pretty bad) - * We will unconditionally upload pieces to a peer (regardless whether or not they were choked) - * We don't add an info hash filter to bip_handshake after we have as many peers as we need/want - * We don't do any banning of malicious peers - - Things the example doesn't do, unrelated to bip_select: - * Matching peers up to disk requests isn't as good as it could be - * doesn't use a shared BytesMut for servicing piece requests - * Good logging -*/ - -/* -// Legacy Handshaker, when bip_dht is migrated, it will accept S directly -struct LegacyHandshaker { - port: u16, - id: PeerId, - sender: Wait -} - -impl LegacyHandshaker where S: DiscoveryInfo + Sink { - pub fn new(sink: S) -> LegacyHandshaker { - LegacyHandshaker{ port: sink.port(), id: sink.peer_id(), sender: sink.wait() } - } -} - -impl Handshaker for LegacyHandshaker where S: Sink + Send, S::SinkError: Debug { - type MetadataEnvelope = (); - - fn id(&self) -> PeerId { self.id } - - fn port(&self) -> u16 { self.port } - - fn connect(&mut self, _expected: Option, hash: InfoHash, addr: SocketAddr) { - self.sender.send(InitiateMessage::new(Protocol::BitTorrent, hash, addr)); - } - - fn metadata(&mut self, _data: ()) { () } -} -*/ - -// How many requests can be in flight at once. +use peer::{ + PeerInfo, PeerManagerBuilder, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage, PeerManagerSink, + PeerManagerStream, PeerProtocolCodec, +}; +use tokio::net::TcpStream; +use tokio::signal; +use tokio::task::JoinSet; +use tokio_util::bytes::BytesMut; +use tokio_util::codec::{Decoder, Framed}; +use tracing::level_filters::LevelFilter; + +// Maximum number of requests that can be in flight at once. 
const MAX_PENDING_BLOCKS: usize = 50; -// Some enum to store our selection state updates +pub static INIT: Once = Once::new(); + +// Enum to store our selection state updates #[allow(dead_code)] #[derive(Debug)] -enum SelectState { +enum PeerSelectionState { Choke(PeerInfo), UnChoke(PeerInfo), Interested(PeerInfo), @@ -85,498 +54,602 @@ enum SelectState { TorrentAdded, } -#[allow(clippy::too_many_lines)] -fn main() { - // Command line argument parsing - let matches = clap_app!(myapp => - - (version: "1.0") - (author: "Andrew ") - (about: "Simple torrent downloading") - (@arg file: -f +required +takes_value "Location of the torrent file") - (@arg dir: -d +takes_value "Download directory to use") - (@arg peer: -p +takes_value "Single peer to connect to of the form addr:port") - ) - .get_matches(); - let file = matches.value_of("file").unwrap(); - let dir = matches.value_of("dir").unwrap(); - let peer_addr = matches.value_of("peer").unwrap().parse().unwrap(); +enum Downloader { + Finished, + Interrupted, +} - // Load in our torrent file - let mut metainfo_bytes = Vec::new(); - File::open(file).unwrap().read_to_end(&mut metainfo_bytes).unwrap(); +enum Setup { + Finished((NativeDiskManager, PeerManager, TcpHandshaker), JoinSet<()>), + Interrupted, +} - // Parse out our torrent file - let metainfo = Metainfo::from_bytes(metainfo_bytes).unwrap(); - let info_hash = metainfo.info().info_hash(); +pub fn tracing_stdout_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stdout); - // Create our main "core" event loop - let mut core = Core::new().unwrap(); + builder.pretty().with_file(true).init(); - // Create a disk manager to handle storing/loading blocks (we add in a file handle cache - // to avoid anti virus causing slow file opens/closes, will cache up to 100 file handles) - let (disk_manager_send, disk_manager_recv) = DiskManagerBuilder::new() - // Reducing our sink and stream capacities allow us to constrain memory usage - // (though for spiky downloads, this could effectively throttle us, which is ok too.) - .with_sink_buffer_capacity(1) - .with_stream_buffer_capacity(0) - .build(FileHandleCache::new(NativeFileSystem::with_directory(dir), 100)) - .into_parts(); + tracing::info!("Logging initialized"); +} - // Create a handshaker that can initiate connections with peers - let (handshaker_send, handshaker_recv) = HandshakerBuilder::new() - .with_peer_id(PeerId::from_hash("-BI0000-000000000000".as_bytes()).unwrap()) - // We would ideally add a filter to the handshaker to block - // peers when we have enough of them for a given hash, but - // since this is a minimal example, we will rely on peer - // manager backpressure (handshaker -> peer manager will - // block when we reach our max peers). Setting these to low - // values so we don't have more than 2 unused tcp connections. 
- .with_config(HandshakerConfig::default().with_wait_buffer_size(0).with_done_buffer_size(0)) - .build::(TcpTransport, &core.handle()) // Will handshake over TCP (could swap this for UTP in the future) +async fn ctrl_c() { + signal::ctrl_c().await.expect("failed to listen for event"); + tracing::warn!("Ctrl-C received, shutting down..."); +} + +#[tokio::main] +async fn main() { + INIT.call_once(|| { + tracing_stdout_init(LevelFilter::TRACE); + }); + + // Parse command-line arguments + let matched_arguments = parse_arguments(); + let (torrent_file_path, download_directory, peer_address) = extract_arguments(&matched_arguments); + + // Load and parse the torrent file + let (metainfo, info_hash) = load_and_parse_torrent_file(&torrent_file_path); + + // Create a JoinSet to manage background tasks + let tasks = Arc::new(Mutex::new(JoinSet::new())); + + // Setup the managers. + let setup = setup(download_directory); + + // Await either the completion of the setup or the Ctrl-C signal + let setup = tokio::select! { + setup = setup => Setup::Finished(setup.0, setup.1), + () = ctrl_c() => Setup::Interrupted, + }; + + let (managers, mut handshaker_tasks) = match setup { + Setup::Finished(managers, handshaker_tasks) => (managers, handshaker_tasks), + Setup::Interrupted => { + tracing::warn!("setup was canceled..."); + return; + } + }; + + let downloader = downloader(tasks.clone(), managers, peer_address, metainfo, info_hash); + + // Await either the completion of the downloader or the Ctrl-C signal + let status = tokio::select! { + () = downloader => Downloader::Finished, + () = ctrl_c() => Downloader::Interrupted, + }; + + match status { + Downloader::Finished => { + while let Some(result) = handshaker_tasks.try_join_next() { + if let Err(e) = result { + eprintln!("Task failed: {e:?}"); + } + } + handshaker_tasks.shutdown().await; + + while let Some(result) = tasks.lock().await.try_join_next() { + if let Err(e) = result { + eprintln!("Task failed: {e:?}"); + } + } + tasks.lock().await.shutdown().await; + } + Downloader::Interrupted => { + handshaker_tasks.shutdown().await; + tasks.lock().await.shutdown().await; + } + } +} + +fn parse_arguments() -> clap::ArgMatches { + clap::Command::new("simple_torrent") + .version("1.0") + .author("Andrew ") + .about("Simple torrent downloading") + .arg( + clap::Arg::new("file") + .short('f') + .required(true) + .value_name("FILE") + .help("Location of the torrent file"), + ) + .arg( + clap::Arg::new("dir") + .short('d') + .value_name("DIR") + .help("Download directory to use"), + ) + .arg( + clap::Arg::new("peer") + .short('p') + .value_name("PEER") + .help("Single peer to connect to of the form addr:port"), + ) + .get_matches() +} + +fn extract_arguments(matches: &clap::ArgMatches) -> (String, String, String) { + let torrent_file_path = matches.get_one::("file").unwrap().to_string(); + let download_directory = matches.get_one::("dir").unwrap().to_string(); + let peer_address = matches.get_one::("peer").unwrap().to_string(); + (torrent_file_path, download_directory, peer_address) +} + +fn load_and_parse_torrent_file(torrent_file_path: &str) -> (Metainfo, InfoHash) { + let mut torrent_file_bytes = Vec::new(); + std::fs::File::open(torrent_file_path) .unwrap() - .into_parts(); - // Create a peer manager that will hold our peers and heartbeat/send messages to them - let (peer_manager_send, peer_manager_recv) = PeerManagerBuilder::new() - // Similar to the disk manager sink and stream capacities, we can constrain those - // for the peer manager as well. 
- .with_sink_buffer_capacity(0) - .with_stream_buffer_capacity(0) - .build(core.handle()) - .into_parts(); - - // Hook up a future that feeds incoming (handshaken) peers over to the peer manager - let map_peer_manager_send = peer_manager_send.clone().sink_map_err(|_| ()); - core.handle().spawn( - handshaker_recv - .map_err(|()| ()) - .map(|complete_msg| { - // Our handshaker finished handshaking some peer, get - // the peer info as well as the peer itself (socket) - let (_, _, hash, pid, addr, sock) = complete_msg.into_parts(); - // Frame our socket with the peer wire protocol with no extensions (nested null protocol), and a max payload of 24KB - let peer = tokio_codec::Decoder::framed( - PeerProtocolCodec::with_max_payload(PeerWireProtocol::new(NullProtocol::new()), 24 * 1024), - sock, - ); + .read_to_end(&mut torrent_file_bytes) + .unwrap(); - // Create our peer identifier used by our peer manager - let peer_info = PeerInfo::new(addr, pid, hash, Extensions::new()); + let metainfo = Metainfo::from_bytes(torrent_file_bytes).unwrap(); + let info_hash = metainfo.info().info_hash(); + (metainfo, info_hash) +} - // Map to a message that can be fed to our peer manager - IPeerManagerMessage::AddPeer(peer_info, peer) - }) - .forward(map_peer_manager_send) - .map(|_| ()), - ); +async fn setup(download_directory: String) -> ((NativeDiskManager, PeerManager, TcpHandshaker), JoinSet<()>) { + // Setup disk manager for handling file operations + let disk_manager = setup_disk_manager(&download_directory); - // Will hold a mapping of BlockMetadata -> Vec to track which peers to send a queued block to - let disk_request_map = Rc::new(RefCell::new(HashMap::new())); - let (select_send, select_recv) = mpsc::channel(50); + // Setup peer manager for managing peer communication + let peer_manager = setup_peer_manager(); - // Map out the errors for these sinks so they match - let map_select_send = select_send.clone().sink_map_err(|_| ()); - let map_disk_manager_send = disk_manager_send.clone().sink_map_err(|()| ()); + // Setup handshaker for managing peer connections + let (handshaker, handshaker_tasks) = setup_handshaker().await; - // Hook up a future that receives messages from the peer manager, and forwards request to the disk manager or selection manager (using loop fn - // here because we need to be able to access state, like request_map and a different future combinator wouldn't let us keep it around to access) - core.handle().spawn(future::loop_fn( - ( - peer_manager_recv, - info_hash, - disk_request_map.clone(), - map_select_send, - map_disk_manager_send, - ), - |(peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send)| { - peer_manager_recv - .into_future() - .map_err(|_| ()) - .and_then(move |(opt_item, peer_manager_recv)| { - let opt_message = match opt_item.unwrap() { - OPeerManagerMessage::ReceivedMessage(info, message) => { - match message { - PeerWireProtocolMessage::Choke => Some(Either::A(SelectState::Choke(info))), - PeerWireProtocolMessage::UnChoke => Some(Either::A(SelectState::UnChoke(info))), - PeerWireProtocolMessage::Interested => Some(Either::A(SelectState::Interested(info))), - PeerWireProtocolMessage::UnInterested => Some(Either::A(SelectState::UnInterested(info))), - PeerWireProtocolMessage::Have(have) => Some(Either::A(SelectState::Have(info, have))), - PeerWireProtocolMessage::BitField(bitfield) => { - Some(Either::A(SelectState::BitField(info, bitfield))) - } - PeerWireProtocolMessage::Request(request) => { - let block_metadata = BlockMetadata::new( - 
info_hash, - u64::from(request.piece_index()), - u64::from(request.block_offset()), - request.block_length(), - ); - let mut request_map_mut = disk_request_map.borrow_mut(); - - // Add the block metadata to our request map, and add the peer as an entry there - let block_entry = request_map_mut.entry(block_metadata); - let peers_requested = block_entry.or_insert(Vec::new()); - - peers_requested.push(info); - - Some(Either::B(IDiskMessage::LoadBlock(BlockMut::new( - block_metadata, - vec![0u8; block_metadata.block_length()].into(), - )))) - } - PeerWireProtocolMessage::Piece(piece) => { - let block_metadata = BlockMetadata::new( - info_hash, - u64::from(piece.piece_index()), - u64::from(piece.block_offset()), - piece.block_length(), - ); - - // Peer sent us a block, send it over to the disk manager to be processed - Some(Either::B(IDiskMessage::ProcessBlock(Block::new( - block_metadata, - piece.block(), - )))) - } - _ => None, - } - } - OPeerManagerMessage::PeerAdded(info) => Some(Either::A(SelectState::NewPeer(info))), - OPeerManagerMessage::SentMessage(_, _) => None, - OPeerManagerMessage::PeerRemoved(info) => { - println!("We Removed Peer {info:?} From The Peer Manager"); - Some(Either::A(SelectState::RemovedPeer(info))) - } - OPeerManagerMessage::PeerDisconnect(info) => { - println!("Peer {info:?} Disconnected From Us"); - Some(Either::A(SelectState::RemovedPeer(info))) - } - OPeerManagerMessage::PeerError(info, error) => { - println!("Peer {info:?} Disconnected With Error: {error:?}"); - Some(Either::A(SelectState::RemovedPeer(info))) - } - }; - - // Could optimize out the box, but for the example, this is cleaner and shorter - let result_future: Box, Error = ()>> = match opt_message { - Some(Either::A(select_message)) => Box::new(select_send.send(select_message).map(move |select_send| { - Loop::Continue((peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send)) - })), - Some(Either::B(disk_message)) => { - Box::new(disk_manager_send.send(disk_message).map(move |disk_manager_send| { - Loop::Continue((peer_manager_recv, info_hash, disk_request_map, select_send, disk_manager_send)) - })) - } - None => Box::new(future::ok(Loop::Continue(( - peer_manager_recv, - info_hash, - disk_request_map, - select_send, - disk_manager_send, - )))), - }; - - result_future - }) - }, + ((disk_manager, peer_manager, handshaker), handshaker_tasks) +} + +async fn downloader( + tasks: Arc>>, + managers: (NativeDiskManager, PeerManager, TcpHandshaker), + peer_address: String, + metainfo: Metainfo, + info_hash: InfoHash, +) { + let (disk_manager, peer_manager, handshaker) = managers; + + // Setup disk manager for handling file operations + let (mut disk_manager_sender, disk_manager_receiver) = disk_manager.into_parts(); + + // Setup peer manager for managing peer communication + let (peer_manager_sender, peer_manager_receiver) = peer_manager.into_parts(); + + // Setup handshaker for managing peer connections + let (mut handshaker_sender, handshaker_receiver) = handshaker.into_parts(); + + // Handle new incoming connections + tasks + .lock() + .await + .spawn(handle_new_connections(handshaker_receiver, peer_manager_sender.clone())); + + // Shared state for managing disk requests + let disk_request_map = Arc::new(Mutex::new(HashMap::new())); + let (selection_sender, selection_receiver) = mpsc::channel(50); + + // Handle messages from the peer manager + tasks.lock().await.spawn(handle_peer_manager_messages( + peer_manager_receiver, + info_hash, + disk_request_map.clone(), + 
selection_sender.clone(), + disk_manager_sender.clone(), )); - // Map out the errors for these sinks so they match - let map_select_send = select_send.clone().sink_map_err(|_| ()); - let map_peer_manager_send = peer_manager_send.clone().sink_map_err(|_| ()); - - // Hook up a future that receives from the disk manager, and forwards to the peer manager or select manager - core.handle().spawn(future::loop_fn( - ( - disk_manager_recv, - disk_request_map.clone(), - map_select_send, - map_peer_manager_send, - ), - |(disk_manager_recv, disk_request_map, select_send, peer_manager_send)| { - disk_manager_recv - .into_future() - .map_err(|_| ()) - .and_then(|(opt_item, disk_manager_recv)| { - let opt_message = match opt_item.unwrap() { - ODiskMessage::BlockLoaded(block) => { - let (metadata, block) = block.into_parts(); - - // Lookup the peer info given the block metadata - let mut request_map_mut = disk_request_map.borrow_mut(); - let peer_list = request_map_mut.get_mut(&metadata).unwrap(); - let peer_info = peer_list.remove(1); - - // Pack up our block into a peer wire protocol message and send it off to the peer - #[allow(clippy::cast_possible_truncation)] - let piece = - PieceMessage::new(metadata.piece_index() as u32, metadata.block_offset() as u32, block.freeze()); - let pwp_message = PeerWireProtocolMessage::Piece(piece); - - Some(Either::B(IPeerManagerMessage::SendMessage(peer_info, 0, pwp_message))) - } - ODiskMessage::TorrentAdded(_) => Some(Either::A(SelectState::TorrentAdded)), - ODiskMessage::TorrentSynced(_) => Some(Either::A(SelectState::TorrentSynced)), - ODiskMessage::FoundGoodPiece(_, index) => Some(Either::A(SelectState::GoodPiece(index))), - ODiskMessage::FoundBadPiece(_, index) => Some(Either::A(SelectState::BadPiece(index))), - ODiskMessage::BlockProcessed(_) => Some(Either::A(SelectState::BlockProcessed)), - _ => None, - }; - - // Could optimize out the box, but for the example, this is cleaner and shorter - let result_future: Box, Error = ()>> = match opt_message { - Some(Either::A(select_message)) => Box::new(select_send.send(select_message).map(|select_send| { - Loop::Continue((disk_manager_recv, disk_request_map, select_send, peer_manager_send)) - })), - Some(Either::B(peer_message)) => { - Box::new(peer_manager_send.send(peer_message).map(|peer_manager_send| { - Loop::Continue((disk_manager_recv, disk_request_map, select_send, peer_manager_send)) - })) - } - None => Box::new(future::ok(Loop::Continue(( - disk_manager_recv, - disk_request_map, - select_send, - peer_manager_send, - )))), - }; - - result_future - }) - }, + // Handle messages from the disk manager + tasks.lock().await.spawn(handle_disk_manager_messages( + disk_manager_receiver, + disk_request_map.clone(), + selection_sender.clone(), + peer_manager_sender.clone(), )); - // Generate data structure to track the requests we need to make, the requests that have been fulfilled, and an active peers list - let piece_requests = generate_requests(metainfo.info(), 16 * 1024); + // Generate piece requests for the torrent + let piece_requests = generate_piece_requests(metainfo.info(), 16 * 1024); - // Have our disk manager allocate space for our torrent and start tracking it - core.run(disk_manager_send.send(IDiskMessage::AddTorrent(metainfo.clone()))) + // Add the torrent to the disk manager + disk_manager_sender + .send(IDiskMessage::AddTorrent(metainfo.clone())) + .await .unwrap(); - // For any pieces we already have on the file system (and are marked as good), we will be removing them from our requests map - let 
(select_recv, piece_requests, cur_pieces) = core - .run(future::loop_fn( - (select_recv, piece_requests, 0), - |(select_recv, mut piece_requests, cur_pieces)| { - select_recv - .into_future() - .map(move |(opt_item, select_recv)| { - match opt_item.unwrap() { - // Disk manager identified a good piece already downloaded - SelectState::GoodPiece(index) => { - piece_requests.retain(|req| u64::from(req.piece_index()) != index); - Loop::Continue((select_recv, piece_requests, cur_pieces + 1)) - } - // Disk manager is finished identifying good pieces, torrent has been added - SelectState::TorrentAdded => Loop::Break((select_recv, piece_requests, cur_pieces)), - // Shouldn't be receiving any other messages... - message => panic!("Unexpected Message Received In Selection Receiver: {message:?}"), - } - }) - .map_err(|_| ()) - }, + // Handle existing pieces and update the selection receiver + let (selection_receiver, piece_requests, current_pieces) = + handle_existing_pieces(selection_receiver, piece_requests, 0).await; + + // Initiate connection to the specified peer + handshaker_sender + .send(InitiateMessage::new( + Protocol::BitTorrent, + info_hash, + peer_address.parse().unwrap(), )) + .await .unwrap(); - /* - // Setup the dht which will be the only peer discovery service we use in this example - let legacy_handshaker = LegacyHandshaker::new(handshaker_send); - let dht = DhtBuilder::with_router(Router::uTorrent) - .set_read_only(false) - .start_mainline(legacy_handshaker).unwrap(); - - dht.search(info_hash, true); - */ - - // Send the peer given from the command line over to the handshaker to initiate a connection - core.run( - handshaker_send - .send(InitiateMessage::new(Protocol::BitTorrent, info_hash, peer_addr)) - .map_err(|_| ()), - ) - .unwrap(); - - // Finally, setup our main event loop to drive the tasks we setup earlier - let map_peer_manager_send = peer_manager_send.sink_map_err(|_| ()); + // Print current status of pieces and requests let total_pieces = metainfo.info().pieces().count(); println!( "Current Pieces: {}\nTotal Pieces: {}\nRequests Left: {}", - cur_pieces, + current_pieces, total_pieces, piece_requests.len() ); - let result: Result<(), ()> = core.run(future::loop_fn( - ( - select_recv, - map_peer_manager_send, - piece_requests, - None, - false, - 0, - cur_pieces, - total_pieces, - ), - |( - select_recv, - map_peer_manager_send, - mut piece_requests, - mut opt_peer, - mut unchoked, - mut blocks_pending, - mut cur_pieces, - total_pieces, - )| { - select_recv - .into_future() - .map_err(|_| ()) - .and_then(move |(opt_message, select_recv)| { - // Handle the current selection message, decide any control messages we need to send - let send_messages = match opt_message.unwrap() { - SelectState::BlockProcessed => { - // Disk manager let us know a block was processed (one of our requests made it - // from the peer manager, to the disk manager, and this is the acknowledgement) - blocks_pending -= 1; - vec![] - } - SelectState::Choke(_) => { - // Peer choked us, cant be sending any requests to them for now - unchoked = false; - vec![] - } - SelectState::UnChoke(_) => { - // Peer unchoked us, we can continue sending sending requests to them - unchoked = true; - vec![] - } - SelectState::NewPeer(info) => { - // A new peer connected to us, store its contact info (just supported one peer atm), - // and go ahead and express our interest in them, and unchoke them (we can upload to them) - // We don't send a bitfield message (just to keep things simple). 
- opt_peer = Some(info); - vec![ - IPeerManagerMessage::SendMessage(info, 0, PeerWireProtocolMessage::Interested), - IPeerManagerMessage::SendMessage(info, 0, PeerWireProtocolMessage::UnChoke), - ] - } - SelectState::GoodPiece(piece) => { - // Disk manager has processed enough blocks to make up a piece, and that piece - // was verified to be good (checksummed). Go ahead and increment the number of - // pieces we have. We don't handle bad pieces here (since we deleted our request - // but ideally, we would recreate those requests and resend/blacklist the peer). - cur_pieces += 1; - - if let Some(peer) = opt_peer { - // Send our have message back to the peer - vec![IPeerManagerMessage::SendMessage( - peer, - 0, - PeerWireProtocolMessage::Have(HaveMessage::new(piece.try_into().unwrap())), - )] - } else { - vec![] - } - } - // Decided not to handle these two cases here - SelectState::RemovedPeer(info) => panic!("Peer {info:?} Got Disconnected"), - SelectState::BadPiece(_) => panic!("Peer Gave Us Bad Piece"), - _ => vec![], - }; - - // Need a type annotation of this return type, provide that - let result: Box, Error = ()>> = if cur_pieces == total_pieces { - // We have all of the (unique) pieces required for our torrent - Box::new(future::ok(Loop::Break(()))) - } else if let Some(peer) = opt_peer { - // We have peer contact info, if we are unchoked, see if we can queue up more requests - let next_piece_requests = if unchoked { - let take_blocks = cmp::min(MAX_PENDING_BLOCKS - blocks_pending, piece_requests.len()); - blocks_pending += take_blocks; - - piece_requests - .drain(0..take_blocks) - .map(move |item| { - Ok::<_, ()>(IPeerManagerMessage::SendMessage( - peer, - 0, - PeerWireProtocolMessage::Request(item), - )) - }) - .collect() - } else { - vec![] - }; - - // First, send any control messages, then, send any more piece requests - Box::new( - map_peer_manager_send - .send_all(stream::iter_result(send_messages.into_iter().map(Ok::<_, ()>))) - .map_err(|()| ()) - .and_then(|(map_peer_manager_send, _)| { - map_peer_manager_send.send_all(stream::iter_result(next_piece_requests)) - }) - .map_err(|()| ()) - .map(move |(map_peer_manager_send, _)| { - Loop::Continue(( - select_recv, - map_peer_manager_send, - piece_requests, - opt_peer, - unchoked, - blocks_pending, - cur_pieces, - total_pieces, - )) - }), - ) - } else { - // Not done yet, and we don't have any peer info stored (haven't received the peer yet) - Box::new(future::ok(Loop::Continue(( - select_recv, - map_peer_manager_send, - piece_requests, - opt_peer, - unchoked, - blocks_pending, - cur_pieces, - total_pieces, - )))) - }; - - result - }) - }, - )); + // Handle selection messages and manage piece requests + let () = handle_selection_messages( + selection_receiver, + peer_manager_sender, + piece_requests, + None, + false, + 0, + current_pieces, + total_pieces, + ) + .await; +} + +type NativeDiskManager = DiskManager>; + +fn setup_disk_manager(download_directory: &str) -> DiskManager> { + let filesystem = FileHandleCache::new(NativeFileSystem::with_directory(download_directory), 100); + + DiskManagerBuilder::new() + .with_sink_buffer_capacity(1) + .with_stream_buffer_capacity(0) + .build(Arc::new(filesystem)) +} + +type TcpHandshaker = Handshaker; + +async fn setup_handshaker() -> (TcpHandshaker, JoinSet<()>) { + HandshakerBuilder::new() + .with_peer_id(PeerId::from_hash("-BI0000-000000000000".as_bytes()).unwrap()) + .with_config(HandshakerConfig::default().with_wait_buffer_size(0).with_done_buffer_size(0)) + .build(TcpTransport) + 
.await + .unwrap() +} + +type PeerManager = peer::PeerManager< + Framed>>, + PeerWireProtocolMessage, +>; + +#[allow(clippy::type_complexity)] +fn setup_peer_manager() -> PeerManager { + PeerManagerBuilder::new() + .with_sink_buffer_capacity(0) + .with_stream_buffer_capacity(0) + .build() +} + +async fn handle_new_connections( + handshaker_receiver: HandshakerStream, + peer_manager_sender: PeerManagerSink< + Framed>>, + PeerWireProtocolMessage, + >, +) { + let new_connections = handshaker_receiver + .filter_map(|item| async move { Some(item.expect("it should not have a failure when making the handshake")) }); + + let new_peers = new_connections.map(|message| { + let (_, _, hash, peer_id, address, socket) = message.into_parts(); + let framed_socket = Decoder::framed( + PeerProtocolCodec::with_max_payload(PeerWireProtocol::new(NullProtocol::new()), 24 * 1024), + socket, + ); + + let peer_info = PeerInfo::new(address, peer_id, hash, Extensions::new()); + + Ok(Ok(PeerManagerInputMessage::AddPeer(peer_info, framed_socket))) + }); + + new_peers.forward(peer_manager_sender.clone()).await.unwrap(); +} - result.unwrap(); +async fn handle_peer_manager_messages( + mut peer_manager_receiver: PeerManagerStream< + Framed>>, + PeerWireProtocolMessage, + >, + info_hash: InfoHash, + disk_request_map: Arc>>>, + selection_sender: mpsc::Sender, + disk_manager_sender: DiskManagerSink>, +) { + while let Some(result) = peer_manager_receiver.next().await { + let opt_message = match result { + Ok(PeerManagerOutputMessage::ReceivedMessage(peer_info, message)) => match message { + PeerWireProtocolMessage::Choke => Some(Either::Left(PeerSelectionState::Choke(peer_info))), + PeerWireProtocolMessage::UnChoke => Some(Either::Left(PeerSelectionState::UnChoke(peer_info))), + PeerWireProtocolMessage::Interested => Some(Either::Left(PeerSelectionState::Interested(peer_info))), + PeerWireProtocolMessage::UnInterested => Some(Either::Left(PeerSelectionState::UnInterested(peer_info))), + PeerWireProtocolMessage::Have(have_message) => { + Some(Either::Left(PeerSelectionState::Have(peer_info, have_message))) + } + PeerWireProtocolMessage::BitField(bitfield_message) => { + Some(Either::Left(PeerSelectionState::BitField(peer_info, bitfield_message))) + } + PeerWireProtocolMessage::Request(request_message) => { + let block_metadata = BlockMetadata::new( + info_hash, + u64::from(request_message.piece_index()), + u64::from(request_message.block_offset()), + request_message.block_length(), + ); + let mut request_map_mut = disk_request_map.lock().await; + + let block_entry = request_map_mut.entry(block_metadata); + let peers_requested = block_entry.or_insert(Vec::new()); + + peers_requested.push(peer_info); + + Some(Either::Right(IDiskMessage::LoadBlock(BlockMut::new( + block_metadata, + BytesMut::with_capacity(block_metadata.block_length()), + )))) + } + PeerWireProtocolMessage::Piece(piece_message) => { + let block_metadata = BlockMetadata::new( + info_hash, + u64::from(piece_message.piece_index()), + u64::from(piece_message.block_offset()), + piece_message.block_length(), + ); + + Some(Either::Right(IDiskMessage::ProcessBlock(Block::new( + block_metadata, + piece_message.block(), + )))) + } + _ => None, + }, + Ok(PeerManagerOutputMessage::PeerAdded(peer_info)) => Some(Either::Left(PeerSelectionState::NewPeer(peer_info))), + Ok(PeerManagerOutputMessage::PeerRemoved(peer_info)) => { + println!("Removed Peer {peer_info:?} From The Peer Manager"); + Some(Either::Left(PeerSelectionState::RemovedPeer(peer_info))) + } + 
Ok(PeerManagerOutputMessage::PeerDisconnect(peer_info)) => { + println!("Peer {peer_info:?} Disconnected From Us"); + Some(Either::Left(PeerSelectionState::RemovedPeer(peer_info))) + } + Err(PeerManagerOutputError::PeerError(peer_info, error)) => { + println!("Peer {peer_info:?} Disconnected With Error: {error:?}"); + Some(Either::Left(PeerSelectionState::RemovedPeer(peer_info))) + } + + Err(_) | Ok(PeerManagerOutputMessage::SentMessage(_, _)) => None, + }; + + if let Some(message) = opt_message { + match message { + Either::Left(selection_message) => { + if (selection_sender.clone().send(selection_message).await).is_err() { + break; + } + } + Either::Right(disk_message) => { + if (disk_manager_sender.clone().send(disk_message).await).is_err() { + break; + } + } + } + } + } +} + +async fn handle_disk_manager_messages( + mut disk_manager_receiver: DiskManagerStream, + disk_request_map: Arc>>>, + selection_sender: mpsc::Sender, + peer_manager_sender: PeerManagerSink< + Framed>>, + PeerWireProtocolMessage, + >, +) { + while let Some(result) = disk_manager_receiver.next().await { + let opt_message = match result { + Ok(ODiskMessage::BlockLoaded(block)) => { + let (metadata, block) = block.into_parts(); + + let mut request_map_mut = disk_request_map.lock().await; + let peer_list = request_map_mut.get_mut(&metadata).unwrap(); + let peer_info = peer_list.remove(0); + + let piece_message = PieceMessage::new( + metadata.piece_index().try_into().unwrap(), + metadata.block_offset().try_into().unwrap(), + block.freeze(), + ); + let pwp_message = PeerWireProtocolMessage::Piece(piece_message); + + Some(Either::Right(PeerManagerInputMessage::SendMessage(peer_info, 0, pwp_message))) + } + Ok(ODiskMessage::TorrentAdded(_)) => Some(Either::Left(PeerSelectionState::TorrentAdded)), + Ok(ODiskMessage::TorrentSynced(_)) => Some(Either::Left(PeerSelectionState::TorrentSynced)), + Ok(ODiskMessage::FoundGoodPiece(_, index)) => Some(Either::Left(PeerSelectionState::GoodPiece(index))), + Ok(ODiskMessage::FoundBadPiece(_, index)) => Some(Either::Left(PeerSelectionState::BadPiece(index))), + Ok(ODiskMessage::BlockProcessed(_)) => Some(Either::Left(PeerSelectionState::BlockProcessed)), + _ => None, + }; + + if let Some(message) = opt_message { + match message { + Either::Left(selection_message) => { + if (selection_sender.clone().send(selection_message).await).is_err() { + break; + } + } + Either::Right(peer_message) => { + if (peer_manager_sender.clone().send(Ok(peer_message)).await).is_err() { + break; + } + } + } + } + } +} + +async fn handle_existing_pieces( + mut selection_receiver: mpsc::Receiver, + mut piece_requests: Vec, + mut current_pieces: usize, +) -> (mpsc::Receiver, Vec, usize) { + loop { + match selection_receiver.next().await { + Some(PeerSelectionState::GoodPiece(index)) => { + piece_requests.retain(|req| u64::from(req.piece_index()) != index); + current_pieces += 1; + } + None | Some(PeerSelectionState::TorrentAdded) => { + break (selection_receiver, piece_requests, current_pieces); + } + Some(message) => { + panic!("Unexpected Message Received In Selection mpsc::Receiver: {message:?}"); + } + } + } +} + +#[allow(clippy::too_many_arguments)] +async fn handle_selection_messages( + mut selection_receiver: mpsc::Receiver, + mut peer_manager_sender: PeerManagerSink< + Framed>>, + PeerWireProtocolMessage, + >, + mut piece_requests: Vec, + mut optional_peer: Option, + mut is_unchoked: bool, + mut pending_blocks: usize, + mut current_pieces: usize, + total_pieces: usize, +) { + while let Some(state) = 
selection_receiver.next().await { + let control_messages = match state { + PeerSelectionState::BlockProcessed => { + pending_blocks -= 1; + vec![] + } + PeerSelectionState::Choke(_) => { + is_unchoked = false; + vec![] + } + PeerSelectionState::UnChoke(_) => { + is_unchoked = true; + vec![] + } + PeerSelectionState::NewPeer(peer_info) => { + optional_peer = Some(peer_info); + vec![ + PeerManagerInputMessage::SendMessage(peer_info, 0, PeerWireProtocolMessage::Interested), + PeerManagerInputMessage::SendMessage(peer_info, 0, PeerWireProtocolMessage::UnChoke), + ] + } + PeerSelectionState::GoodPiece(piece_index) => { + current_pieces += 1; + + if let Some(peer_info) = optional_peer { + vec![PeerManagerInputMessage::SendMessage( + peer_info, + 0, + PeerWireProtocolMessage::Have(HaveMessage::new(piece_index.try_into().unwrap())), + )] + } else { + vec![] + } + } + PeerSelectionState::RemovedPeer(peer_info) => { + eprintln!("Peer {peer_info:?} Got Disconnected"); + vec![] + } + PeerSelectionState::BadPiece(_) => { + eprintln!("Peer Gave Us Bad Piece"); + vec![] + } + _ => vec![], + }; + + if current_pieces == total_pieces { + println!("All pieces have been successfully downloaded."); + break; + } else if let Some(peer_info) = optional_peer { + let next_piece_requests = if is_unchoked { + let take_blocks = std::cmp::min(MAX_PENDING_BLOCKS - pending_blocks, piece_requests.len()); + pending_blocks += take_blocks; + + piece_requests + .drain(0..take_blocks) + .map(move |item| { + Ok::<_, _>(PeerManagerInputMessage::SendMessage( + peer_info, + 0, + PeerWireProtocolMessage::Request(item), + )) + }) + .collect() + } else { + vec![] + }; + + // First, send any control messages, then, send any more piece requests + if let Err(e) = peer_manager_sender + .send_all(&mut stream::iter( + control_messages.into_iter().map(Ok::<_, _>).map(Ok::<_, _>), + )) + .await + { + eprintln!("Error sending control messages: {e:?}"); + break; + } + + if let Err(e) = peer_manager_sender + .send_all(&mut stream::iter(next_piece_requests).map(Ok::<_, _>)) + .await + { + eprintln!("Error sending piece requests: {e:?}"); + break; + } + } + } } /// Generate a mapping of piece index to list of block requests for that piece, given a block size. /// /// Note, most clients will drop connections for peers requesting block sizes above 16KB. 
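// Illustrative sketch, not part of the patch: how a single piece is split into block
// requests under the 16 KiB convention mentioned above. The piece and block sizes used
// in `main` are assumed example values; the real sizes come from the torrent's metainfo.
fn block_requests_for_piece(piece_length: u64, block_size: u64) -> Vec<(u64, u64)> {
    let mut requests = Vec::new();

    // All whole blocks: (offset within the piece, block length).
    let whole_blocks = piece_length / block_size;
    for block_index in 0..whole_blocks {
        requests.push((block_index * block_size, block_size));
    }

    // A smaller trailing block, if the piece length is not a multiple of the block size.
    let partial_block_length = piece_length % block_size;
    if partial_block_length != 0 {
        requests.push((whole_blocks * block_size, partial_block_length));
    }

    requests
}

fn main() {
    // A 40 KiB piece with 16 KiB blocks yields two whole blocks and one 8 KiB tail.
    assert_eq!(
        block_requests_for_piece(40 * 1024, 16 * 1024),
        vec![(0, 16 * 1024), (16 * 1024, 16 * 1024), (32 * 1024, 8 * 1024)]
    );
}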
-fn generate_requests(info: &Info, block_size: usize) -> Vec { +fn generate_piece_requests(info: &Info, block_size: usize) -> Vec { let mut requests = Vec::new(); // Grab our piece length, and the sum of the lengths of each file in the torrent - let piece_len: u64 = info.piece_length(); + let piece_length: u64 = info.piece_length(); let mut total_file_length: u64 = info.files().map(metainfo::File::length).sum(); // Loop over each piece (keep subtracting total file length by piece size, use cmp::min to handle last, smaller piece) let mut piece_index: u64 = 0; while total_file_length != 0 { - let next_piece_len = cmp::min(total_file_length, piece_len); + let next_piece_length = std::cmp::min(total_file_length, piece_length); // For all whole blocks, push the block index and block_size - let whole_blocks = next_piece_len / block_size as u64; + let whole_blocks = next_piece_length / block_size as u64; for block_index in 0..whole_blocks { let block_offset = block_index * block_size as u64; - #[allow(clippy::cast_possible_truncation)] - requests.push(RequestMessage::new(piece_index as u32, block_offset as u32, block_size)); + requests.push(RequestMessage::new( + piece_index.try_into().unwrap(), + block_offset.try_into().unwrap(), + block_size, + )); } // Check for any last smaller block within the current piece - let partial_block_length = next_piece_len % block_size as u64; + let partial_block_length = next_piece_length % block_size as u64; if partial_block_length != 0 { let block_offset = whole_blocks * block_size as u64; @@ -588,7 +661,7 @@ fn generate_requests(info: &Info, block_size: usize) -> Vec { } // Take this piece out of the total length, increment to the next piece - total_file_length -= next_piece_len; + total_file_length -= next_piece_length; piece_index += 1; } diff --git a/packages/bencode/Cargo.toml b/packages/bencode/Cargo.toml index 62e701a10..cd25ac5f1 100644 --- a/packages/bencode/Cargo.toml +++ b/packages/bencode/Cargo.toml @@ -16,10 +16,10 @@ repository.workspace = true version.workspace = true [dependencies] -error-chain = "0.12" +thiserror = "1" [dev-dependencies] -criterion = "0.5" +criterion = "0" [[bench]] harness = false diff --git a/packages/bencode/src/access/convert.rs b/packages/bencode/src/access/convert.rs index 42b04f267..b2eb41d15 100644 --- a/packages/bencode/src/access/convert.rs +++ b/packages/bencode/src/access/convert.rs @@ -2,7 +2,7 @@ use crate::access::bencode::{BRefAccess, BRefAccessExt}; use crate::access::dict::BDictAccess; use crate::access::list::BListAccess; -use crate::{BencodeConvertError, BencodeConvertErrorKind}; +use crate::BencodeConvertError; /// Trait for extended casting of bencode objects and converting conversion errors into application specific errors. pub trait BConvertExt: BConvert { @@ -12,12 +12,10 @@ pub trait BConvertExt: BConvert { B: BRefAccessExt<'a>, E: AsRef<[u8]>, { - bencode.bytes_ext().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "Bytes".to_owned(), - })), - ) + bencode.bytes_ext().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "Bytes".to_owned(), + })) } /// See `BConvert::convert_str`. 
@@ -26,12 +24,10 @@ pub trait BConvertExt: BConvert { B: BRefAccessExt<'a>, E: AsRef<[u8]>, { - bencode.str_ext().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "UTF-8 Bytes".to_owned(), - })), - ) + bencode.str_ext().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "UTF-8 Bytes".to_owned(), + })) } /// See `BConvert::lookup_and_convert_bytes`. @@ -77,12 +73,10 @@ pub trait BConvert { B: BRefAccess, E: AsRef<[u8]>, { - bencode.int().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "Integer".to_owned(), - })), - ) + bencode.int().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "Integer".to_owned(), + })) } /// Attempt to convert the given bencode value into bytes. @@ -93,12 +87,10 @@ pub trait BConvert { B: BRefAccess, E: AsRef<[u8]>, { - bencode.bytes().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "Bytes".to_owned(), - })), - ) + bencode.bytes().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "Bytes".to_owned(), + })) } /// Attempt to convert the given bencode value into a UTF-8 string. @@ -109,12 +101,10 @@ pub trait BConvert { B: BRefAccess, E: AsRef<[u8]>, { - bencode.str().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "UTF-8 Bytes".to_owned(), - })), - ) + bencode.str().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "UTF-8 Bytes".to_owned(), + })) } /// Attempt to convert the given bencode value into a list. @@ -125,12 +115,10 @@ pub trait BConvert { B: BRefAccess, E: AsRef<[u8]>, { - bencode.list().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "List".to_owned(), - })), - ) + bencode.list().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "List".to_owned(), + })) } /// Attempt to convert the given bencode value into a dictionary. @@ -141,12 +129,10 @@ pub trait BConvert { B: BRefAccess, E: AsRef<[u8]>, { - bencode.dict().ok_or( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::WrongType { - key: error_key.as_ref().to_owned(), - expected_type: "Dictionary".to_owned(), - })), - ) + bencode.dict().ok_or(self.handle_error(BencodeConvertError::WrongType { + key: error_key.as_ref().to_owned(), + expected_type: "Dictionary".to_owned(), + })) } /// Look up a value in a dictionary of bencoded values using the given key. 
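// Illustrative sketch, not part of the patch: the conversion pattern the hunks above settle
// on after dropping error-chain. A `thiserror` enum variant is constructed directly where
// `BencodeConvertError::from_kind(..)` used to be. The names below are simplified stand-ins
// for the real trait methods; `thiserror` is assumed to be available, as in this patch.
use thiserror::Error;

#[derive(Error, Debug)]
enum ConvertError {
    #[error("Wrong Type In Bencode For {key:?} Expected Type {expected_type}")]
    WrongType { key: Vec<u8>, expected_type: String },
}

fn convert_int(value: Option<i64>, error_key: &[u8]) -> Result<i64, ConvertError> {
    // `ok_or` builds the error variant in place, mirroring the `.ok_or(self.handle_error(..))`
    // calls in the hunks above.
    value.ok_or(ConvertError::WrongType {
        key: error_key.to_owned(),
        expected_type: "Integer".to_owned(),
    })
}

fn main() {
    assert!(convert_int(Some(7), b"port").is_ok());
    assert!(convert_int(None, b"port").is_err());
}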
@@ -159,11 +145,7 @@ pub trait BConvert { match dictionary.lookup(key_ref) { Some(n) => Ok(n), - None => Err( - self.handle_error(BencodeConvertError::from_kind(BencodeConvertErrorKind::MissingKey { - key: key_ref.to_owned(), - })), - ), + None => Err(self.handle_error(BencodeConvertError::MissingKey { key: key_ref.to_owned() })), } } diff --git a/packages/bencode/src/error.rs b/packages/bencode/src/error.rs index 18ebe9605..6e661a068 100644 --- a/packages/bencode/src/error.rs +++ b/packages/bencode/src/error.rs @@ -1,101 +1,52 @@ -use error_chain::error_chain; +use thiserror::Error; -error_chain! { - types { - BencodeParseError, BencodeParseErrorKind, BencodeParseResultExt, BencodeParseResult; - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum BencodeParseError { + #[error("Incomplete Number Of Bytes At {pos}")] + BytesEmpty { pos: usize }, - errors { - BytesEmpty { - pos: usize - } { - description("Incomplete Number Of Bytes") - display("Incomplete Number Of Bytes At {:?}", pos) - } - InvalidByte { - pos: usize - } { - description("Invalid Byte Found") - display("Invalid Byte Found At {:?}", pos) - } - InvalidIntNoDelimiter { - pos: usize - } { - description("Invalid Integer Found With No Delimiter") - display("Invalid Integer Found With No Delimiter At {:?}", pos) - } - InvalidIntNegativeZero { - pos: usize - } { - description("Invalid Integer Found As Negative Zero") - display("Invalid Integer Found As Negative Zero At {:?}", pos) - } - InvalidIntZeroPadding { - pos: usize - } { - description("Invalid Integer Found With Zero Padding") - display("Invalid Integer Found With Zero Padding At {:?}", pos) - } - InvalidIntParseError { - pos: usize - } { - description("Invalid Integer Found To Fail Parsing") - display("Invalid Integer Found To Fail Parsing At {:?}", pos) - } - InvalidKeyOrdering { - pos: usize, - key: Vec - } { - description("Invalid Dictionary Key Ordering Found") - display("Invalid Dictionary Key Ordering Found At {:?} For Key {:?}", pos, key) - } - InvalidKeyDuplicates { - pos: usize, - key: Vec - } { - description("Invalid Dictionary Duplicate Keys Found") - display("Invalid Dictionary Key Found At {:?} For Key {:?}", pos, key) - } - InvalidLengthNegative { - pos: usize - } { - description("Invalid Byte Length Found As Negative") - display("Invalid Byte Length Found As Negative At {:?}", pos) - } - InvalidLengthOverflow { - pos: usize - } { - description("Invalid Byte Length Found To Overflow Buffer Length") - display("Invalid Byte Length Found To Overflow Buffer Length At {:?}", pos) - } - InvalidRecursionExceeded { - pos: usize, - max: usize - } { - description("Invalid Recursion Limit Exceeded") - display("Invalid Recursion Limit Exceeded At {:?} For Limit {:?}", pos, max) - } - } + #[error("Invalid Byte Found At {pos}")] + InvalidByte { pos: usize }, + + #[error("Invalid Integer Found With No Delimiter At {pos}")] + InvalidIntNoDelimiter { pos: usize }, + + #[error("Invalid Integer Found As Negative Zero At {pos}")] + InvalidIntNegativeZero { pos: usize }, + + #[error("Invalid Integer Found With Zero Padding At {pos}")] + InvalidIntZeroPadding { pos: usize }, + + #[error("Invalid Integer Found To Fail Parsing At {pos}")] + InvalidIntParseError { pos: usize }, + + #[error("Invalid Dictionary Key Ordering Found At {pos} For Key {key:?}")] + InvalidKeyOrdering { pos: usize, key: Vec }, + + #[error("Invalid Dictionary Key Found At {pos} For Key {key:?}")] + InvalidKeyDuplicates { pos: usize, key: Vec }, + + #[error("Invalid Byte Length Found 
As Negative At {pos}")] + InvalidLengthNegative { pos: usize }, + + #[error("Invalid Byte Length Found To Overflow Buffer Length At {pos}")] + InvalidLengthOverflow { pos: usize }, + + #[error("Invalid Recursion Limit Exceeded At {pos} For Limit {max}")] + InvalidRecursionExceeded { pos: usize, max: usize }, } -error_chain! { - types { - BencodeConvertError, BencodeConvertErrorKind, BencodeConvertResultExt, BencodeConvertResult; - } +pub type BencodeParseResult = Result; - errors { - MissingKey { - key: Vec - } { - description("Missing Key In Bencode") - display("Missing Key In Bencode For {:?}", key) - } - WrongType { - key: Vec, - expected_type: String - } { - description("Wrong Type In Bencode") - display("Wrong Type In Bencode For {:?} Expected Type {}", key, expected_type) - } - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum BencodeConvertError { + #[error("Missing Key In Bencode For {key:?}")] + MissingKey { key: Vec }, + + #[error("Wrong Type In Bencode For {key:?} Expected Type {expected_type}")] + WrongType { key: Vec, expected_type: String }, } + +pub type BencodeConvertResult = Result; diff --git a/packages/bencode/src/lib.rs b/packages/bencode/src/lib.rs index 78e113b66..09aaa6867 100644 --- a/packages/bencode/src/lib.rs +++ b/packages/bencode/src/lib.rs @@ -7,7 +7,6 @@ //! ```rust //! extern crate bencode; //! -//! use std::default::Default; //! use bencode::{BencodeRef, BRefAccess, BDecodeOpt}; //! //! fn main() { @@ -63,10 +62,7 @@ pub use crate::access::bencode::{BMutAccess, BRefAccess, MutKind, RefKind}; pub use crate::access::convert::BConvert; pub use crate::access::dict::BDictAccess; pub use crate::access::list::BListAccess; -pub use crate::error::{ - BencodeConvertError, BencodeConvertErrorKind, BencodeConvertResult, BencodeParseError, BencodeParseErrorKind, - BencodeParseResult, -}; +pub use crate::error::{BencodeConvertError, BencodeConvertResult, BencodeParseError, BencodeParseResult}; pub use crate::mutable::bencode_mut::BencodeMut; pub use crate::reference::bencode_ref::BencodeRef; pub use crate::reference::decode_opt::BDecodeOpt; diff --git a/packages/bencode/src/reference/bencode_ref.rs b/packages/bencode/src/reference/bencode_ref.rs index 760dd3016..73aaad039 100644 --- a/packages/bencode/src/reference/bencode_ref.rs +++ b/packages/bencode/src/reference/bencode_ref.rs @@ -4,7 +4,7 @@ use std::str; use crate::access::bencode::{BRefAccess, BRefAccessExt, RefKind}; use crate::access::dict::BDictAccess; use crate::access::list::BListAccess; -use crate::error::{BencodeParseError, BencodeParseErrorKind, BencodeParseResult}; +use crate::error::{BencodeParseError, BencodeParseResult}; use crate::reference::decode; use crate::reference::decode_opt::BDecodeOpt; @@ -41,9 +41,7 @@ impl<'a> BencodeRef<'a> { let (bencode, end_pos) = decode::decode(bytes, 0, opts, 0)?; if end_pos != bytes.len() && opts.enforce_full_decode() { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::BytesEmpty { - pos: end_pos, - })); + return Err(BencodeParseError::BytesEmpty { pos: end_pos }); } Ok(bencode) @@ -125,7 +123,6 @@ impl<'a> BRefAccessExt<'a> for BencodeRef<'a> { #[cfg(test)] mod tests { - use std::default::Default; use crate::access::bencode::BRefAccess; use crate::reference::bencode_ref::BencodeRef; diff --git a/packages/bencode/src/reference/decode.rs b/packages/bencode/src/reference/decode.rs index d2aa180f8..97c5cf1ff 100644 --- a/packages/bencode/src/reference/decode.rs +++ b/packages/bencode/src/reference/decode.rs @@ -1,16 +1,14 @@ 
use std::collections::btree_map::Entry; use std::collections::BTreeMap; -use std::str::{self}; +use std::str; -use crate::error::{BencodeParseError, BencodeParseErrorKind, BencodeParseResult}; +use crate::error::{BencodeParseError, BencodeParseResult}; use crate::reference::bencode_ref::{BencodeRef, Inner}; use crate::reference::decode_opt::BDecodeOpt; pub fn decode(bytes: &[u8], pos: usize, opts: BDecodeOpt, depth: usize) -> BencodeParseResult<(BencodeRef<'_>, usize)> { if depth >= opts.max_recursion() { - return Err(BencodeParseError::from_kind( - BencodeParseErrorKind::InvalidRecursionExceeded { pos, max: depth }, - )); + return Err(BencodeParseError::InvalidRecursionExceeded { pos, max: depth }); } let curr_byte = peek_byte(bytes, pos)?; @@ -32,7 +30,7 @@ pub fn decode(bytes: &[u8], pos: usize, opts: BDecodeOpt, depth: usize) -> Benco // Include the length digit, don't increment position Ok((Inner::Bytes(bencode, &bytes[pos..next_pos]).into(), next_pos)) } - _ => Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidByte { pos })), + _ => Err(BencodeParseError::InvalidByte { pos }), } } @@ -40,32 +38,24 @@ fn decode_int(bytes: &[u8], pos: usize, delim: u8) -> BencodeParseResult<(i64, u let (_, begin_decode) = bytes.split_at(pos); let Some(relative_end_pos) = begin_decode.iter().position(|n| *n == delim) else { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidIntNoDelimiter { - pos, - })); + return Err(BencodeParseError::InvalidIntNoDelimiter { pos }); }; let int_byte_slice = &begin_decode[..relative_end_pos]; if int_byte_slice.len() > 1 { // Negative zero is not allowed (this would not be caught when converting) if int_byte_slice[0] == b'-' && int_byte_slice[1] == b'0' { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidIntNegativeZero { - pos, - })); + return Err(BencodeParseError::InvalidIntNegativeZero { pos }); } // Zero padding is illegal, and unspecified for key lengths (we disallow both) if int_byte_slice[0] == b'0' { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidIntZeroPadding { - pos, - })); + return Err(BencodeParseError::InvalidIntZeroPadding { pos }); } } let Ok(int_str) = str::from_utf8(int_byte_slice) else { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidIntParseError { - pos, - })); + return Err(BencodeParseError::InvalidIntParseError { pos }); }; // Position of end of integer type, next byte is the start of the next value @@ -73,31 +63,24 @@ fn decode_int(bytes: &[u8], pos: usize, delim: u8) -> BencodeParseResult<(i64, u let next_pos = absolute_end_pos + 1; match int_str.parse::() { Ok(n) => Ok((n, next_pos)), - Err(_) => Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidIntParseError { - pos, - })), + Err(_) => Err(BencodeParseError::InvalidIntParseError { pos }), } } +use std::convert::TryFrom; + fn decode_bytes(bytes: &[u8], pos: usize) -> BencodeParseResult<(&[u8], usize)> { let (num_bytes, start_pos) = decode_int(bytes, pos, crate::BYTE_LEN_END)?; if num_bytes < 0 { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidLengthNegative { - pos, - })); + return Err(BencodeParseError::InvalidLengthNegative { pos }); } - // Should be safe to cast to usize (TODO: Check if cast would overflow to provide - // a more helpful error message, otherwise, parsing will probably fail with an - // unrelated message). 
- let num_bytes = - usize::try_from(num_bytes).map_err(|_| BencodeParseErrorKind::Msg(format!("input length is too long: {num_bytes}")))?; + // Use usize::try_from to handle potential overflow + let num_bytes = usize::try_from(num_bytes).map_err(|_| BencodeParseError::InvalidLengthOverflow { pos })?; if num_bytes > bytes[start_pos..].len() { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidLengthOverflow { - pos, - })); + return Err(BencodeParseError::InvalidLengthOverflow { pos }); } let next_pos = start_pos + num_bytes; @@ -140,10 +123,10 @@ fn decode_dict( // Spec says that the keys must be in alphabetical order match (bencode_dict.keys().last(), opts.check_key_sort()) { (Some(last_key), true) if key_bytes < *last_key => { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidKeyOrdering { + return Err(BencodeParseError::InvalidKeyOrdering { pos: curr_pos, key: key_bytes.to_vec(), - })) + }) } _ => (), }; @@ -153,10 +136,10 @@ fn decode_dict( match bencode_dict.entry(key_bytes) { Entry::Vacant(n) => n.insert(value), Entry::Occupied(_) => { - return Err(BencodeParseError::from_kind(BencodeParseErrorKind::InvalidKeyDuplicates { + return Err(BencodeParseError::InvalidKeyDuplicates { pos: curr_pos, key: key_bytes.to_vec(), - })) + }) } }; @@ -169,15 +152,11 @@ fn decode_dict( } fn peek_byte(bytes: &[u8], pos: usize) -> BencodeParseResult { - bytes - .get(pos) - .copied() - .ok_or_else(|| BencodeParseError::from_kind(BencodeParseErrorKind::BytesEmpty { pos })) + bytes.get(pos).copied().ok_or(BencodeParseError::BytesEmpty { pos }) } #[cfg(test)] mod tests { - use std::default::Default; use crate::access::bencode::BRefAccess; use crate::reference::bencode_ref::BencodeRef; @@ -329,13 +308,13 @@ mod tests { } #[test] - #[should_panic = "BencodeParseError(InvalidByte { pos: 0 }"] + #[should_panic = "InvalidByte { pos: 0 }"] fn negative_decode_bytes_neg_len() { BencodeRef::decode(BYTES_NEG_LEN, BDecodeOpt::default()).unwrap(); } #[test] - #[should_panic = "BencodeParseError(BytesEmpty { pos: 20 }"] + #[should_panic = "BytesEmpty { pos: 20 }"] fn negative_decode_bytes_extra() { BencodeRef::decode(BYTES_EXTRA, BDecodeOpt::default()).unwrap(); } @@ -348,49 +327,49 @@ mod tests { } #[test] - #[should_panic = "BencodeParseError(InvalidIntParseError { pos: 1 }"] + #[should_panic = "InvalidIntParseError { pos: 1 }"] fn negative_decode_int_nan() { super::decode_int(INT_NAN, 1, crate::BEN_END).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidIntZeroPadding { pos: 1 }"] + #[should_panic = "InvalidIntZeroPadding { pos: 1 }"] fn negative_decode_int_leading_zero() { super::decode_int(INT_LEADING_ZERO, 1, crate::BEN_END).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidIntZeroPadding { pos: 1 }"] + #[should_panic = "InvalidIntZeroPadding { pos: 1 }"] fn negative_decode_int_double_zero() { super::decode_int(INT_DOUBLE_ZERO, 1, crate::BEN_END).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidIntNegativeZero { pos: 1 }"] + #[should_panic = "InvalidIntNegativeZero { pos: 1 }"] fn negative_decode_int_negative_zero() { super::decode_int(INT_NEGATIVE_ZERO, 1, crate::BEN_END).unwrap(); } #[test] - #[should_panic = " BencodeParseError(InvalidIntParseError { pos: 1 }"] + #[should_panic = " InvalidIntParseError { pos: 1 }"] fn negative_decode_int_double_negative() { super::decode_int(INT_DOUBLE_NEGATIVE, 1, crate::BEN_END).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidKeyOrdering { pos: 15, key: [97, 95, 107, 101, 
121] }"] + #[should_panic = "InvalidKeyOrdering { pos: 15, key: [97, 95, 107, 101, 121] }"] fn negative_decode_dict_unordered_keys() { BencodeRef::decode(DICT_UNORDERED_KEYS, BDecodeOpt::new(5, true, true)).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidKeyDuplicates { pos: 18, key: [97, 95, 107, 101, 121] }"] + #[should_panic = "InvalidKeyDuplicates { pos: 18, key: [97, 95, 107, 101, 121] }"] fn negative_decode_dict_dup_keys_same_data() { BencodeRef::decode(DICT_DUP_KEYS_SAME_DATA, BDecodeOpt::default()).unwrap(); } #[test] - #[should_panic = "BencodeParseError(InvalidKeyDuplicates { pos: 18, key: [97, 95, 107, 101, 121] }"] + #[should_panic = "InvalidKeyDuplicates { pos: 18, key: [97, 95, 107, 101, 121] }"] fn negative_decode_dict_dup_keys_diff_data() { BencodeRef::decode(DICT_DUP_KEYS_DIFF_DATA, BDecodeOpt::default()).unwrap(); } diff --git a/packages/bencode/src/reference/decode_opt.rs b/packages/bencode/src/reference/decode_opt.rs index e8d9a8337..8409cc72c 100644 --- a/packages/bencode/src/reference/decode_opt.rs +++ b/packages/bencode/src/reference/decode_opt.rs @@ -1,5 +1,3 @@ -use std::default::Default; - const DEFAULT_MAX_RECURSION: usize = 50; const DEFAULT_CHECK_KEY_SORT: bool = false; const DEFAULT_ENFORCE_FULL_DECODE: bool = true; diff --git a/packages/dht/Cargo.toml b/packages/dht/Cargo.toml index e00b8e08f..bf5d8e85e 100644 --- a/packages/dht/Cargo.toml +++ b/packages/dht/Cargo.toml @@ -20,9 +20,11 @@ bencode = { path = "../bencode" } handshake = { path = "../handshake" } util = { path = "../util" } -chrono = "0.4" +chrono = "0" crc = "3" -error-chain = "0.12" -log = "0.4" -mio = "0.5" -rand = "0.8" +futures = "0" +rand = "0" +thiserror = "1" +tokio = { version = "1", features = ["full"] } +tracing = "0" +tracing-subscriber = "0" diff --git a/packages/dht/examples/debug.rs b/packages/dht/examples/debug.rs index 87c00814e..4998bb4ed 100644 --- a/packages/dht/examples/debug.rs +++ b/packages/dht/examples/debug.rs @@ -1,27 +1,17 @@ use std::collections::HashSet; -use std::io::{self, Read}; +use std::io::Read as _; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use std::thread::{self}; +use std::sync::Once; use dht::handshaker_trait::HandshakerTrait; use dht::{DhtBuilder, Router}; +use futures::future::BoxFuture; +use futures::StreamExt; +use tokio::task::JoinSet; +use tracing::level_filters::LevelFilter; use util::bt::{InfoHash, PeerId}; -struct SimpleLogger; - -impl log::Log for SimpleLogger { - fn enabled(&self, metadata: &log::Metadata<'_>) -> bool { - metadata.level() <= log::Level::Info - } - - fn log(&self, record: &log::Record<'_>) { - if self.enabled(record.metadata()) { - println!("{} - {}", record.level(), record.args()); - } - } - - fn flush(&self) {} -} +static INIT: Once = Once::new(); struct SimpleHandshaker { filter: HashSet, @@ -46,23 +36,37 @@ impl HandshakerTrait for SimpleHandshaker { } /// Initiates a handshake with the given socket address. - fn connect(&mut self, _: Option, _: InfoHash, addr: SocketAddr) { + fn connect(&mut self, _: Option, _: InfoHash, addr: SocketAddr) -> BoxFuture<'_, ()> { if self.filter.contains(&addr) { - return; + return Box::pin(std::future::ready(())); } self.filter.insert(addr); self.count += 1; println!("Received new peer {:?}, total unique peers {}", addr, self.count); + + Box::pin(std::future::ready(())) } /// Send the given Metadata back to the client. 
fn metadata(&mut self, (): Self::MetadataEnvelope) {} } -fn main() { - log::set_logger(&SimpleLogger).unwrap(); - log::set_max_level(log::LevelFilter::max()); +fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt().with_max_level(filter).with_ansi(true); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +#[tokio::main] +async fn main() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let mut tasks = JoinSet::new(); let hash = InfoHash::from_bytes(b"My Unique Info Hash"); @@ -74,23 +78,24 @@ fn main() { .set_source_addr(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 6889))) .set_read_only(false) .start_mainline(handshaker) + .await .unwrap(); // Spawn a thread to listen to and report events - let events = dht.events(); - thread::spawn(move || { - for event in events { + let mut events = dht.events().await; + tasks.spawn(async move { + while let Some(event) = events.next().await { println!("\nReceived Dht Event {event:?}"); } }); // Let the user announce or search on our info hash - let stdin = io::stdin(); + let stdin = std::io::stdin(); let stdin_lock = stdin.lock(); for byte in stdin_lock.bytes() { match &[byte.unwrap()] { - b"a" => dht.search(hash, true), - b"s" => dht.search(hash, false), + b"a" => dht.search(hash, true).await, + b"s" => dht.search(hash, false).await, _ => (), } } diff --git a/packages/dht/src/builder.rs b/packages/dht/src/builder.rs index a584d2ab3..6396b5cea 100644 --- a/packages/dht/src/builder.rs +++ b/packages/dht/src/builder.rs @@ -1,10 +1,11 @@ use std::collections::HashSet; -use std::io; -use std::net::{SocketAddr, UdpSocket}; -use std::sync::mpsc::{self, Receiver}; +use std::net::SocketAddr; +use std::sync::Arc; -use log::warn; -use mio::Sender; +use futures::channel::mpsc; +use futures::SinkExt as _; +use tokio::net::UdpSocket; +use tokio::task::JoinSet; use util::bt::InfoHash; use util::net; @@ -14,39 +15,48 @@ use crate::worker::{self, DhtEvent, OneshotTask, ShutdownCause}; /// Maintains a Distributed Hash (Routing) Table. pub struct MainlineDht { - send: Sender, + main_task_sender: mpsc::Sender, + _tasks: JoinSet<()>, } impl MainlineDht { /// Start the `MainlineDht` with the given `DhtBuilder` and Handshaker. 
- fn with_builder(builder: DhtBuilder, handshaker: H) -> io::Result + async fn with_builder(builder: DhtBuilder, handshaker: H) -> std::io::Result where H: HandshakerTrait + 'static, { - let send_sock = UdpSocket::bind(builder.src_addr)?; - let recv_sock = send_sock.try_clone()?; + let send_sock = Arc::new(UdpSocket::bind(builder.src_addr).await?); + let recv_sock = send_sock.clone(); - let kill_sock = send_sock.try_clone()?; + let kill_sock = send_sock.clone(); let kill_addr = send_sock.local_addr()?; - let send = worker::start_mainline_dht( - send_sock, + let (main_task_sender, tasks) = worker::start_mainline_dht( + &send_sock, recv_sock, builder.read_only, builder.ext_addr, handshaker, kill_sock, kill_addr, - )?; + ); let nodes: Vec = builder.nodes.into_iter().collect(); let routers: Vec = builder.routers.into_iter().collect(); - if send.send(OneshotTask::StartBootstrap(routers, nodes)).is_err() { - warn!("bip_dt: MainlineDht failed to send a start bootstrap message..."); + if main_task_sender + .clone() + .send(OneshotTask::StartBootstrap(routers, nodes)) + .await + .is_err() + { + tracing::warn!("bip_dt: MainlineDht failed to send a start bootstrap message..."); } - Ok(MainlineDht { send }) + Ok(MainlineDht { + main_task_sender, + _tasks: tasks, + }) } /// Perform a search for the given `InfoHash` with an optional announce on the closest nodes. @@ -57,9 +67,15 @@ impl MainlineDht { /// /// If the initial bootstrap has not finished, the search will be queued and executed once /// the bootstrap has completed. - pub fn search(&self, hash: InfoHash, announce: bool) { - if self.send.send(OneshotTask::StartLookup(hash, announce)).is_err() { - warn!("bip_dht: MainlineDht failed to send a start lookup message..."); + pub async fn search(&self, hash: InfoHash, announce: bool) { + if self + .main_task_sender + .clone() + .send(OneshotTask::StartLookup(hash, announce)) + .await + .is_err() + { + tracing::warn!("bip_dht: MainlineDht failed to send a start lookup message..."); } } @@ -68,11 +84,11 @@ impl MainlineDht { /// It is important to at least monitor the DHT for shutdown events as any calls /// after that event occurs will not be processed but no indication will be given. #[must_use] - pub fn events(&self) -> Receiver { - let (send, recv) = mpsc::channel(); + pub async fn events(&self) -> mpsc::Receiver { + let (send, recv) = mpsc::channel(1); - if self.send.send(OneshotTask::RegisterSender(send)).is_err() { - warn!("bip_dht: MainlineDht failed to send a register sender message..."); + if let Err(e) = self.main_task_sender.clone().send(OneshotTask::RegisterSender(send)).await { + tracing::warn!("bip_dht: MainlineDht failed to send a register sender message..., {e}"); // TODO: Should we push a Shutdown event through the sender here? We would need // to know the cause or create a new cause for this specific scenario since the // client could have been lazy and wasn't monitoring this until after it shutdown! @@ -84,8 +100,13 @@ impl MainlineDht { impl Drop for MainlineDht { fn drop(&mut self) { - if self.send.send(OneshotTask::Shutdown(ShutdownCause::ClientInitiated)).is_err() { - warn!( + if self + .main_task_sender + .clone() + .try_send(OneshotTask::Shutdown(ShutdownCause::ClientInitiated)) + .is_err() + { + tracing::warn!( "bip_dht: MainlineDht failed to send a shutdown message (may have already been \ shutdown)..." ); @@ -198,10 +219,10 @@ impl DhtBuilder { /// # Errors /// /// It would return error if unable to build from the handshaker. 
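// Illustrative sketch, not part of the patch: the bounded `futures::channel::mpsc` usage the
// reworked builder relies on. `send(..).await` applies back-pressure in async contexts, while
// `try_send` is the non-blocking variant used from `Drop`, where awaiting is not possible.
// The channel payload type and buffer size here are assumed example values.
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};

#[tokio::main]
async fn main() {
    let (mut sender, mut receiver) = mpsc::channel::<&'static str>(1);

    // Asynchronous send: waits for capacity instead of failing when the buffer is full.
    sender.send("bootstrap").await.unwrap();
    assert_eq!(receiver.next().await, Some("bootstrap"));

    // Non-blocking send: returns an error immediately if the channel is full or closed.
    sender.try_send("shutdown").unwrap();
    assert_eq!(receiver.next().await, Some("shutdown"));
}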
- pub fn start_mainline(self, handshaker: H) -> io::Result + pub async fn start_mainline(self, handshaker: H) -> std::io::Result where H: HandshakerTrait + 'static, { - MainlineDht::with_builder(self, handshaker) + MainlineDht::with_builder(self, handshaker).await } } diff --git a/packages/dht/src/error.rs b/packages/dht/src/error.rs index a77688a36..299ed8480 100644 --- a/packages/dht/src/error.rs +++ b/packages/dht/src/error.rs @@ -1,42 +1,21 @@ -use std::io; - use bencode::BencodeConvertError; -use error_chain::error_chain; +use thiserror::Error; use crate::message::error::ErrorMessage; -error_chain! { - types { - DhtError, DhtErrorKind, DhtResultExt, DhtResult; - } - - foreign_links { - Bencode(BencodeConvertError); - Io(io::Error); - } - - errors { - InvalidMessage { - code: String - } { - description("Node Sent An Invalid Message") - display("Node Sent An Invalid Message With Message Code {}", code) - } - InvalidResponse { - details: String - } { - description("Node Sent Us An Invalid Response") - display("Node Sent Us An Invalid Response: {}", details) - } - UnsolicitedResponse { - description("Node Sent Us An Unsolicited Response") - display("Node Sent Us An Unsolicited Response") - } - InvalidRequest { - msg: ErrorMessage<'static> - } { - description("Node Sent Us An Invalid Request Message") - display("Node Sent Us An Invalid Request Message With Code {:?} And Message {}", msg.error_code(), msg.error_message()) - } - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum DhtError { + #[error("Bencode error: {0}")] + Bencode(#[from] BencodeConvertError), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Node Sent An Invalid Message With Message Code {code}")] + InvalidMessage { code: String }, + #[error("Node Sent Us An Invalid Response: {details}")] + InvalidResponse { details: String }, + #[error("Node Sent Us An Unsolicited Response")] + UnsolicitedResponse, + #[error("Node Sent Us An Invalid Request Message With Code {msg:?} And Message {msg}")] + InvalidRequest { msg: ErrorMessage<'static> }, } diff --git a/packages/dht/src/handshaker_trait.rs b/packages/dht/src/handshaker_trait.rs index 8e3fc6098..221e40129 100644 --- a/packages/dht/src/handshaker_trait.rs +++ b/packages/dht/src/handshaker_trait.rs @@ -1,5 +1,6 @@ use std::net::SocketAddr; +use futures::future::BoxFuture; use util::bt::{InfoHash, PeerId}; /// Trait for peer discovery services to forward peer contact information and metadata. @@ -17,7 +18,7 @@ pub trait HandshakerTrait: Send { fn port(&self) -> u16; /// Connect to the given address with the `InfoHash` and expecting the `PeerId`. - fn connect(&mut self, expected: Option, hash: InfoHash, addr: SocketAddr); + fn connect(&mut self, expected: Option, hash: InfoHash, addr: SocketAddr) -> BoxFuture<'_, ()>; /// Send the given Metadata back to the client. 
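// Illustrative sketch, not part of the patch: returning a `BoxFuture<'_, ()>` from a trait
// method, as the reworked `connect` signature above requires. The `Counter` type is a
// hypothetical stand-in; `futures` and `tokio` are assumed dependencies, as in this patch.
use futures::future::BoxFuture;

struct Counter {
    count: usize,
}

impl Counter {
    // Synchronous work can run before the boxed future is built; `Box::pin(async move { .. })`
    // (or `Box::pin(std::future::ready(..))`, as in the debug example above) produces the
    // boxed, pinned future the signature asks for.
    fn bump(&mut self) -> BoxFuture<'_, usize> {
        self.count += 1;
        let current = self.count;
        Box::pin(async move { current })
    }
}

#[tokio::main]
async fn main() {
    let mut counter = Counter { count: 0 };
    assert_eq!(counter.bump().await, 1);
    assert_eq!(counter.bump().await, 2);
}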
fn metadata(&mut self, data: Self::MetadataEnvelope); diff --git a/packages/dht/src/message/announce_peer.rs b/packages/dht/src/message/announce_peer.rs index e4cfa5f3c..156d64a3c 100644 --- a/packages/dht/src/message/announce_peer.rs +++ b/packages/dht/src/message/announce_peer.rs @@ -4,7 +4,7 @@ use bencode::{ben_bytes, ben_int, ben_map, BConvert, BDictAccess, BRefAccess}; use util::bt::{InfoHash, NodeId}; -use crate::error::DhtResult; +use crate::error::DhtError; use crate::message; use crate::message::request::{self, RequestValidate}; @@ -52,7 +52,10 @@ impl<'a> AnnouncePeerRequest<'a> { /// # Errors /// /// This function will return an error unable to get bytes unable do lookup. - pub fn from_parts(rqst_root: &'a dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts( + rqst_root: &'a dyn BDictAccess, + trans_id: &'a [u8], + ) -> Result, DhtError> where B: BRefAccess, { @@ -151,7 +154,10 @@ impl<'a> AnnouncePeerResponse<'a> { /// # Errors /// /// This function will return an error unable to get bytes or unable to validate the node id. - pub fn from_parts(rqst_root: &dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts( + rqst_root: &dyn BDictAccess, + trans_id: &'a [u8], + ) -> Result, DhtError> where B: BRefAccess, { diff --git a/packages/dht/src/message/compact_info.rs b/packages/dht/src/message/compact_info.rs index 8a55d3217..b06a1fce9 100644 --- a/packages/dht/src/message/compact_info.rs +++ b/packages/dht/src/message/compact_info.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::fmt::Debug; use std::hash::Hash; use std::net::{Ipv4Addr, SocketAddrV4}; @@ -132,7 +131,7 @@ where impl<'a, B> IntoIterator for CompactValueInfo<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { type Item = SocketAddrV4; type IntoIter = CompactValueInfoIter<'a, B>; diff --git a/packages/dht/src/message/error.rs b/packages/dht/src/message/error.rs index 319251208..ae3c103a5 100644 --- a/packages/dht/src/message/error.rs +++ b/packages/dht/src/message/error.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use bencode::ext::BConvertExt; use bencode::{ben_bytes, ben_int, ben_list, ben_map, BConvert, BDictAccess, BListAccess, BRefAccess, BencodeConvertError}; -use crate::error::{DhtError, DhtErrorKind, DhtResult}; +use crate::error::DhtError; use crate::message; const ERROR_ARGS_KEY: &str = "e"; @@ -27,15 +27,15 @@ pub enum ErrorCode { } impl ErrorCode { - fn new(code: u8) -> DhtResult { + fn new(code: u8) -> Result { match code { GENERIC_ERROR_CODE => Ok(ErrorCode::GenericError), SERVER_ERROR_CODE => Ok(ErrorCode::ServerError), PROTOCOL_ERROR_CODE => Ok(ErrorCode::ProtocolError), METHOD_UNKNOWN_CODE => Ok(ErrorCode::MethodUnknown), - unknown => Err(DhtError::from_kind(DhtErrorKind::InvalidResponse { + unknown => Err(DhtError::InvalidResponse { details: format!("Error Message Invalid Error Code {unknown:?}"), - })), + }), } } } @@ -57,14 +57,14 @@ impl From for u8 { struct ErrorValidate; impl ErrorValidate { - fn extract_error_args(self, args: &dyn BListAccess) -> DhtResult<(u8, String)> + fn extract_error_args(self, args: &dyn BListAccess) -> Result<(u8, String), DhtError> where B: BRefAccess, { if args.len() != NUM_ERROR_ARGS { - return Err(DhtError::from_kind(DhtErrorKind::InvalidResponse { + return Err(DhtError::InvalidResponse { details: format!("Error Message Invalid Number Of Error Args: {}", args.len()), - })); + }); } let code = self.convert_int(&args[0], 
format!("{ERROR_ARGS_KEY}[0]"))?; @@ -118,7 +118,7 @@ impl<'a> ErrorMessage<'a> { /// # Errors /// /// This function will return an error if unable to lookup the error. - pub fn from_parts(root: &dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts(root: &dyn BDictAccess, trans_id: &'a [u8]) -> Result, DhtError> where B: BRefAccess, { @@ -169,3 +169,13 @@ impl<'a> ErrorMessage<'a> { .encode() } } + +impl<'a> std::fmt::Display for ErrorMessage<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ErrorMessage {{ trans_id: {:?}, code: {:?}, message: {} }}", + self.trans_id, self.code, self.message + ) + } +} diff --git a/packages/dht/src/message/find_node.rs b/packages/dht/src/message/find_node.rs index e5f5b65f3..55607cf99 100644 --- a/packages/dht/src/message/find_node.rs +++ b/packages/dht/src/message/find_node.rs @@ -1,7 +1,7 @@ use bencode::{ben_bytes, ben_map, BConvert, BDictAccess, BRefAccess}; use util::bt::NodeId; -use crate::error::DhtResult; +use crate::error::DhtError; use crate::message; use crate::message::compact_info::CompactNodeInfo; use crate::message::request::{self, RequestValidate}; @@ -37,7 +37,7 @@ impl<'a> FindNodeRequest<'a> { rqst_root: &dyn BDictAccess, trans_id: &'a [u8], target_key: &str, - ) -> DhtResult> + ) -> Result, DhtError> where B: BRefAccess, { @@ -97,7 +97,7 @@ impl<'a> FindNodeResponse<'a> { /// # Errors /// /// This function will return an error if unable to validate the nodes. - pub fn new(trans_id: &'a [u8], node_id: NodeId, nodes: &'a [u8]) -> DhtResult> { + pub fn new(trans_id: &'a [u8], node_id: NodeId, nodes: &'a [u8]) -> Result, DhtError> { let validate = ResponseValidate::new(trans_id); let compact_nodes = validate.validate_nodes(nodes)?; @@ -113,7 +113,7 @@ impl<'a> FindNodeResponse<'a> { /// # Errors /// /// This function will return an error if unable to lookup and and validate node. - pub fn from_parts(rsp_root: &'a dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts(rsp_root: &'a dyn BDictAccess, trans_id: &'a [u8]) -> Result, DhtError> where B: BRefAccess, { diff --git a/packages/dht/src/message/get_peers.rs b/packages/dht/src/message/get_peers.rs index 152f3090f..52c1cca55 100644 --- a/packages/dht/src/message/get_peers.rs +++ b/packages/dht/src/message/get_peers.rs @@ -1,12 +1,11 @@ use std::collections::BTreeMap; -use std::fmt::Debug; use std::ops::Deref; use bencode::inner::BCowConvert; use bencode::{ben_bytes, ben_map, BConvert, BDictAccess, BMutAccess, BRefAccess, BencodeMut}; use util::bt::{InfoHash, NodeId}; -use crate::error::{DhtError, DhtErrorKind, DhtResult}; +use crate::error::DhtError; use crate::message; use crate::message::compact_info::{CompactNodeInfo, CompactValueInfo}; use crate::message::request::{self, RequestValidate}; @@ -35,7 +34,7 @@ impl<'a> GetPeersRequest<'a> { /// # Errors /// /// This function will return an error if unable to lookup, convert, and validate node. 
- pub fn from_parts(rqst_root: &dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts(rqst_root: &dyn BDictAccess, trans_id: &'a [u8]) -> Result, DhtError> where B: BRefAccess, { @@ -85,7 +84,7 @@ impl<'a> GetPeersRequest<'a> { pub enum CompactInfoType<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { Nodes(CompactNodeInfo<'a>), Values(CompactValueInfo<'a, B::BType>), @@ -97,7 +96,7 @@ where pub struct GetPeersResponse<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { trans_id: &'a [u8], node_id: NodeId, @@ -110,7 +109,7 @@ where impl<'a, B> GetPeersResponse<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { #[must_use] pub fn new( @@ -135,7 +134,7 @@ where pub fn from_parts( rsp_root: &'a dyn BDictAccess, trans_id: &'a [u8], - ) -> DhtResult> { + ) -> Result, DhtError> { let validate = ResponseValidate::new(trans_id); let node_id_bytes = validate.lookup_and_convert_bytes(rsp_root, message::NODE_ID_KEY)?; @@ -163,9 +162,9 @@ where CompactInfoType::Values(values_info) } (Err(_), Err(_)) => { - return Err(DhtError::from_kind(DhtErrorKind::InvalidResponse { + return Err(DhtError::InvalidResponse { details: "Failed To Find nodes Or values In Node Response".to_owned(), - })) + }) } }; diff --git a/packages/dht/src/message/mod.rs b/packages/dht/src/message/mod.rs index 3f20d8f3b..68e4c3771 100644 --- a/packages/dht/src/message/mod.rs +++ b/packages/dht/src/message/mod.rs @@ -1,9 +1,7 @@ -use std::fmt::Debug; - use bencode::ext::BConvertExt; use bencode::{BConvert, BRefAccess, BencodeConvertError}; -use crate::error::{DhtError, DhtErrorKind, DhtResult}; +use crate::error::DhtError; use crate::message::error::ErrorMessage; use crate::message::request::RequestType; use crate::message::response::{ExpectedResponse, ResponseType}; @@ -61,7 +59,7 @@ impl BConvertExt for MessageValidate {} pub enum MessageType<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { Request(RequestType<'a>), Response(ResponseType<'a, B>), @@ -71,14 +69,14 @@ where impl<'a, B> MessageType<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { /// Create a new `MessageType` /// /// # Errors /// /// This function will return an error if unable to lookup, convert and crate type. 
- pub fn new(message: &'a B::BType, trans_mapper: T) -> DhtResult> + pub fn new(message: &'a B::BType, trans_mapper: T) -> Result, DhtError> where T: Fn(&[u8]) -> ExpectedResponse, { @@ -103,9 +101,9 @@ where let err_message = ErrorMessage::from_parts(msg_root, trans_id)?; Ok(MessageType::Error(err_message)) } - unknown => Err(DhtError::from_kind(DhtErrorKind::InvalidMessage { + unknown => Err(DhtError::InvalidMessage { code: unknown.to_owned(), - })), + }), } } } diff --git a/packages/dht/src/message/ping.rs b/packages/dht/src/message/ping.rs index 0bb1ec6a7..483a28629 100644 --- a/packages/dht/src/message/ping.rs +++ b/packages/dht/src/message/ping.rs @@ -4,7 +4,7 @@ use bencode::{ben_bytes, ben_map, BConvert, BDictAccess, BRefAccess}; use util::bt::NodeId; -use crate::error::DhtResult; +use crate::error::DhtError; use crate::message; use crate::message::request::{self, RequestValidate}; @@ -26,7 +26,7 @@ impl<'a> PingRequest<'a> { /// # Errors /// /// This function will return an error if unable to lookup, convert, and validate nodes. - pub fn from_parts(rqst_root: &dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts(rqst_root: &dyn BDictAccess, trans_id: &'a [u8]) -> Result, DhtError> where B: BRefAccess, { @@ -82,7 +82,7 @@ impl<'a> PingResponse<'a> { /// # Errors /// /// This function will return an error if unable to generate the ping request from the root. - pub fn from_parts(rsp_root: &dyn BDictAccess, trans_id: &'a [u8]) -> DhtResult> + pub fn from_parts(rsp_root: &dyn BDictAccess, trans_id: &'a [u8]) -> Result, DhtError> where B: BRefAccess, { diff --git a/packages/dht/src/message/request.rs b/packages/dht/src/message/request.rs index 61e36a584..5c37dfb6d 100644 --- a/packages/dht/src/message/request.rs +++ b/packages/dht/src/message/request.rs @@ -2,7 +2,7 @@ use bencode::ext::BConvertExt; use bencode::{BConvert, BDictAccess, BRefAccess, BencodeConvertError}; use util::bt::{InfoHash, NodeId}; -use crate::error::{DhtError, DhtErrorKind, DhtResult}; +use crate::error::DhtError; use crate::message; use crate::message::announce_peer::AnnouncePeerRequest; use crate::message::error::{ErrorCode, ErrorMessage}; @@ -38,7 +38,7 @@ impl<'a> RequestValidate<'a> { /// # Errors /// /// This function will return an error if to generate the `NodeId`. - pub fn validate_node_id(&self, node_id: &[u8]) -> DhtResult { + pub fn validate_node_id(&self, node_id: &[u8]) -> Result { NodeId::from_hash(node_id).map_err(|_| { let error_msg = ErrorMessage::new( self.trans_id.to_owned(), @@ -46,7 +46,7 @@ impl<'a> RequestValidate<'a> { format!("Node ID With Length {} Is Not Valid", node_id.len()), ); - DhtError::from_kind(DhtErrorKind::InvalidRequest { msg: error_msg }) + DhtError::InvalidRequest { msg: error_msg } }) } @@ -55,7 +55,7 @@ impl<'a> RequestValidate<'a> { /// # Errors /// /// This function will return an error if to generate the `InfoHash`. - pub fn validate_info_hash(&self, info_hash: &[u8]) -> DhtResult { + pub fn validate_info_hash(&self, info_hash: &[u8]) -> Result { InfoHash::from_hash(info_hash).map_err(|_| { let error_msg = ErrorMessage::new( self.trans_id.to_owned(), @@ -63,7 +63,7 @@ impl<'a> RequestValidate<'a> { format!("InfoHash With Length {} Is Not Valid", info_hash.len()), ); - DhtError::from_kind(DhtErrorKind::InvalidRequest { msg: error_msg }) + DhtError::InvalidRequest { msg: error_msg } }) } } @@ -95,7 +95,11 @@ impl<'a> RequestType<'a> { /// # Errors /// /// This function will return an error if unable to lookup, convert, and generate correct type. 
- pub fn from_parts(root: &'a dyn BDictAccess, trans_id: &'a [u8], rqst_type: &str) -> DhtResult> + pub fn from_parts( + root: &'a dyn BDictAccess, + trans_id: &'a [u8], + rqst_type: &str, + ) -> Result, DhtError> where B: BRefAccess, { @@ -138,7 +142,7 @@ impl<'a> RequestType<'a> { format!("Received Unknown Request Method: {unknown}"), ); - Err(DhtError::from_kind(DhtErrorKind::InvalidRequest { msg: error_message })) + Err(DhtError::InvalidRequest { msg: error_message }) } } } diff --git a/packages/dht/src/message/response.rs b/packages/dht/src/message/response.rs index fb5a28250..a5d8be600 100644 --- a/packages/dht/src/message/response.rs +++ b/packages/dht/src/message/response.rs @@ -1,10 +1,8 @@ -use std::fmt::Debug; - use bencode::ext::BConvertExt; use bencode::{BConvert, BDictAccess, BListAccess, BRefAccess, BencodeConvertError}; use util::bt::NodeId; -use crate::error::{DhtError, DhtErrorKind, DhtResult}; +use crate::error::DhtError; use crate::message::announce_peer::AnnouncePeerResponse; use crate::message::compact_info::{CompactNodeInfo, CompactValueInfo}; use crate::message::find_node::FindNodeResponse; @@ -31,15 +29,13 @@ impl<'a> ResponseValidate<'a> { /// # Errors /// /// This function will return an error if to generate the `NodeId`. - pub fn validate_node_id(&self, node_id: &[u8]) -> DhtResult { - NodeId::from_hash(node_id).map_err(|_| { - DhtError::from_kind(DhtErrorKind::InvalidResponse { - details: format!( - "TID {:?} Found Node ID With Invalid Length {:?}", - self.trans_id, - node_id.len() - ), - }) + pub fn validate_node_id(&self, node_id: &[u8]) -> Result { + NodeId::from_hash(node_id).map_err(|_| DhtError::InvalidResponse { + details: format!( + "TID {:?} Found Node ID With Invalid Length {:?}", + self.trans_id, + node_id.len() + ), }) } @@ -48,16 +44,14 @@ impl<'a> ResponseValidate<'a> { /// # Errors /// /// This function will return an error if to generate the `CompactNodeInfo`. - pub fn validate_nodes<'b>(&self, nodes: &'b [u8]) -> DhtResult> { - CompactNodeInfo::new(nodes).map_err(|_| { - DhtError::from_kind(DhtErrorKind::InvalidResponse { - details: format!( - "TID {:?} Found Nodes Structure With {} Number Of Bytes Instead \ + pub fn validate_nodes<'b>(&self, nodes: &'b [u8]) -> Result, DhtError> { + CompactNodeInfo::new(nodes).map_err(|_| DhtError::InvalidResponse { + details: format!( + "TID {:?} Found Nodes Structure With {} Number Of Bytes Instead \ Of Correct Multiple", - self.trans_id, - nodes.len() - ), - }) + self.trans_id, + nodes.len() + ), }) } @@ -66,26 +60,24 @@ impl<'a> ResponseValidate<'a> { /// # Errors /// /// This function will return an error if to generate the `CompactValueInfo`. 
- pub fn validate_values<'b, B>(&self, values: &'b dyn BListAccess) -> DhtResult> + pub fn validate_values<'b, B>(&self, values: &'b dyn BListAccess) -> Result, DhtError> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { for bencode in values { match bencode.bytes() { Some(_) => (), None => { - return Err(DhtError::from_kind(DhtErrorKind::InvalidResponse { + return Err(DhtError::InvalidResponse { details: format!("TID {:?} Found Values Structure As Non Bytes Type", self.trans_id), - })) + }) } } } - CompactValueInfo::new(values).map_err(|_| { - DhtError::from_kind(DhtErrorKind::InvalidResponse { - details: format!("TID {:?} Found Values Structure With Wrong Number Of Bytes", self.trans_id), - }) + CompactValueInfo::new(values).map_err(|_| DhtError::InvalidResponse { + details: format!("TID {:?} Found Values Structure With Wrong Number Of Bytes", self.trans_id), }) } } @@ -119,7 +111,7 @@ pub enum ExpectedResponse { pub enum ResponseType<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { Ping(PingResponse<'a>), FindNode(FindNodeResponse<'a>), @@ -131,7 +123,7 @@ where impl<'a, B> ResponseType<'a, B> where B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { /// Creates a new `ResponseType` from parts. /// @@ -142,7 +134,7 @@ where root: &'a dyn BDictAccess, trans_id: &'a [u8], rsp_type: &ExpectedResponse, - ) -> DhtResult> + ) -> Result, DhtError> where B: BRefAccess, { @@ -172,7 +164,7 @@ where ExpectedResponse::PutData => { unimplemented!(); } - ExpectedResponse::None => Err(DhtError::from_kind(DhtErrorKind::UnsolicitedResponse)), + ExpectedResponse::None => Err(DhtError::UnsolicitedResponse), } } } diff --git a/packages/dht/src/router.rs b/packages/dht/src/router.rs index 69d8a9b15..64356d534 100644 --- a/packages/dht/src/router.rs +++ b/packages/dht/src/router.rs @@ -1,5 +1,3 @@ -use std::fmt::{self, Display, Formatter}; -use std::io::{self, Error, ErrorKind}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; use std::vec::IntoIter; @@ -28,7 +26,7 @@ pub enum Router { impl Router { // TODO: USES DEPRECATED FUNCTIONS - // pub fn hostname(&self) -> io::Result> { + // pub fn hostname(&self) -> std::io::Result> { // match self { // &Router::uTorrent => Ok(UTORRENT_DHT.0.into_cow()), // &Router::BitComet => Ok(BITCOMET_DHT.0.into_cow()), @@ -44,12 +42,13 @@ impl Router { /// # Errors /// /// This function will return an error if unable to fund any ipv4 address. - pub fn ipv4_addr(&self) -> io::Result { + pub fn ipv4_addr(&self) -> std::io::Result { let mut addrs = self.socket_addrs()?; - addrs - .find_map(map_ipv4) - .ok_or(Error::new(ErrorKind::Other, "No IPv4 Addresses Found For Host")) + addrs.find_map(map_ipv4).ok_or(std::io::Error::new( + std::io::ErrorKind::Other, + "No IPv4 Addresses Found For Host", + )) } /// Returns the [`SocketAddrV6`] of this [`Router`]. @@ -57,12 +56,13 @@ impl Router { /// # Errors /// /// This function will return an error if unable to fund any ipv6 address. 
- pub fn ipv6_addr(&self) -> io::Result { + pub fn ipv6_addr(&self) -> std::io::Result { let mut addrs = self.socket_addrs()?; - addrs - .find_map(map_ipv6) - .ok_or(Error::new(ErrorKind::Other, "No IPv6 Addresses Found For Host")) + addrs.find_map(map_ipv6).ok_or(std::io::Error::new( + std::io::ErrorKind::Other, + "No IPv6 Addresses Found For Host", + )) } /// Returns the [`SocketAddr`] of this [`Router`]. @@ -70,15 +70,16 @@ impl Router { /// # Errors /// /// This function will return an error if unable to fund a socket address. - pub fn socket_addr(&self) -> io::Result { + pub fn socket_addr(&self) -> std::io::Result { let mut addrs = self.socket_addrs()?; - addrs - .next() - .ok_or(Error::new(ErrorKind::Other, "No SocketAddresses Found For Host")) + addrs.next().ok_or(std::io::Error::new( + std::io::ErrorKind::Other, + "No SocketAddresses Found For Host", + )) } - fn socket_addrs(&self) -> io::Result> { + fn socket_addrs(&self) -> std::io::Result> { match *self { Router::uTorrent => UTORRENT_DHT.to_socket_addrs(), Router::BitTorrent => BITTORRENT_DHT.to_socket_addrs(), @@ -106,14 +107,14 @@ fn map_ipv6(addr: SocketAddr) -> Option { } } -impl Display for Router { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { +impl std::fmt::Display for Router { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { match *self { Router::uTorrent => f.write_fmt(format_args!("{}:{}", UTORRENT_DHT.0, UTORRENT_DHT.1)), Router::BitTorrent => f.write_fmt(format_args!("{}:{}", BITTORRENT_DHT.0, BITTORRENT_DHT.1)), Router::BitComet => f.write_fmt(format_args!("{}:{}", BITCOMET_DHT.0, BITCOMET_DHT.1)), Router::Transmission => f.write_fmt(format_args!("{}:{}", TRANSMISSION_DHT.0, TRANSMISSION_DHT.1)), - Router::Custom(n) => Display::fmt(&n, f), + Router::Custom(n) => std::fmt::Display::fmt(&n, f), } } } diff --git a/packages/dht/src/routing/node.rs b/packages/dht/src/routing/node.rs index cef493d04..7abc3f6ac 100644 --- a/packages/dht/src/routing/node.rs +++ b/packages/dht/src/routing/node.rs @@ -2,9 +2,10 @@ #![allow(unused)] use std::cell::Cell; -use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; use chrono::{DateTime, Duration, Utc}; use util::bt::NodeId; @@ -41,9 +42,9 @@ pub enum NodeStatus { pub struct Node { id: NodeId, addr: SocketAddr, - last_request: Cell>>, - last_response: Cell>>, - refresh_requests: Cell, + last_request: Arc>>>, + last_response: Arc>>>, + refresh_requests: Arc, } impl Node { @@ -52,9 +53,9 @@ impl Node { Node { id, addr, - last_response: Cell::new(Some(Utc::now())), - last_request: Cell::new(None), - refresh_requests: Cell::new(0), + last_response: Arc::new(Mutex::new(Some(Utc::now()))), + last_request: Arc::default(), + refresh_requests: Arc::default(), } } @@ -67,9 +68,9 @@ impl Node { Node { id, addr, - last_response: Cell::new(Some(last_response)), - last_request: Cell::new(None), - refresh_requests: Cell::new(0), + last_response: Arc::new(Mutex::new(Some(last_response))), + last_request: Arc::default(), + refresh_requests: Arc::default(), } } @@ -78,31 +79,29 @@ impl Node { Node { id, addr, - last_response: Cell::new(None), - last_request: Cell::new(None), - refresh_requests: Cell::new(0), + last_response: Arc::default(), + last_request: Arc::default(), + refresh_requests: Arc::default(), } } /// Record that we sent the node a request. 
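Sketch (editorial, not part of the patch): `Cell` is not `Sync`, so once the worker moves from a single mio thread onto tokio tasks the per-node bookkeeping has to become thread-safe. The new fields above use `Arc<Mutex<Option<DateTime<Utc>>>>` for the two timestamps and `Arc<AtomicUsize>` for the refresh counter; the same pattern, reduced to a standalone example with illustrative names:

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    use chrono::{DateTime, Utc};

    struct Tracked {
        last_response: Arc<Mutex<Option<DateTime<Utc>>>>,
        refresh_requests: Arc<AtomicUsize>,
    }

    impl Tracked {
        // Updates go through `&self`, and the type is Send + Sync, so the same
        // value can be touched from several tasks at once.
        fn remote_response(&self) {
            *self.last_response.lock().unwrap() = Some(Utc::now());
            self.refresh_requests.store(0, Ordering::Relaxed);
        }
    }

One behavioural point worth keeping in mind: if `Node::clone` now clones these `Arc`s, clones share the counters and timestamps, whereas the old `Cell` copies were independent.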
pub fn local_request(&self) { if self.status() != NodeStatus::Good { - let num_requests = self.refresh_requests.get() + 1; - - self.refresh_requests.set(num_requests); + let num_requests = self.refresh_requests.fetch_add(1, Ordering::SeqCst) + 1; } } /// Record that the node sent us a request. pub fn remote_request(&self) { - self.last_request.set(Some(Utc::now())); + *self.last_request.lock().unwrap() = Some(Utc::now()); } /// Record that the node sent us a response. pub fn remote_response(&self) { - self.last_response.set(Some(Utc::now())); + *self.last_response.lock().unwrap() = Some(Utc::now()); - self.refresh_requests.set(0); + self.refresh_requests.store(0, Ordering::Relaxed); } pub fn id(&self) -> NodeId { @@ -190,16 +189,16 @@ impl Clone for Node { } } -impl Debug for Node { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> { +impl std::fmt::Debug for Node { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { f.write_fmt(format_args!( "Node{{ id: {:?}, addr: {:?}, last_request: {:?}, \ last_response: {:?}, refresh_requests: {:?} }}", self.id, self.addr, - self.last_request.get(), - self.last_response.get(), - self.refresh_requests.get() + self.last_request.lock().unwrap(), + self.last_response.lock().unwrap(), + self.refresh_requests.load(Ordering::Relaxed) )) } } @@ -215,7 +214,7 @@ impl Debug for Node { /// to us before, but not recently. fn recently_responded(node: &Node, curr_time: DateTime) -> NodeStatus { // Check if node has ever responded to us - let since_response = match node.last_response.get() { + let since_response = match *node.last_response.lock().unwrap() { Some(response_time) => curr_time - response_time, None => return NodeStatus::Bad, }; @@ -238,7 +237,7 @@ fn recently_requested(node: &Node, curr_time: DateTime) -> NodeStatus { let max_last_request = Duration::minutes(MAX_LAST_SEEN_MINS); // Check if the node has recently request from us - if let Some(request_time) = node.last_request.get() { + if let Some(request_time) = *node.last_request.lock().unwrap() { let since_request = curr_time - request_time; if since_request < max_last_request { @@ -247,7 +246,7 @@ fn recently_requested(node: &Node, curr_time: DateTime) -> NodeStatus { } // Check if we have request from node multiple times already without response - if node.refresh_requests.get() < MAX_REFRESH_REQUESTS { + if node.refresh_requests.load(Ordering::Relaxed) < MAX_REFRESH_REQUESTS { NodeStatus::Questionable } else { NodeStatus::Bad @@ -256,7 +255,6 @@ fn recently_requested(node: &Node, curr_time: DateTime) -> NodeStatus { #[cfg(test)] mod tests { - use std::iter; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use chrono::Duration; @@ -336,7 +334,7 @@ mod tests { let time_offset = Duration::minutes(super::MAX_LAST_SEEN_MINS); let idle_time = bip_test::travel_into_past(time_offset); - node.last_response.set(Some(idle_time)); + *node.last_response.lock().unwrap() = (Some(idle_time)); assert_eq!(node.status(), NodeStatus::Questionable); } diff --git a/packages/dht/src/routing/table.rs b/packages/dht/src/routing/table.rs index cdf02a0cf..67263f50a 100644 --- a/packages/dht/src/routing/table.rs +++ b/packages/dht/src/routing/table.rs @@ -1,7 +1,3 @@ -// TODO: Remove when we use find_node, -#![allow(unused)] - -use std::cmp; use std::iter::Filter; use std::slice::Iter; @@ -184,6 +180,7 @@ pub enum BucketContents<'a> { Assorted(&'a Bucket), } +#[allow(dead_code)] impl<'a> BucketContents<'a> { fn is_empty(&self) -> bool { matches!(self, &BucketContents::Empty) @@ 
-384,7 +381,7 @@ fn next_bucket_index(num_buckets: usize, start_index: usize, curr_index: usize) // to the right. All assuming we can actually do this without going out of bounds. match curr_index.cmp(&start_index) { - cmp::Ordering::Less => { + std::cmp::Ordering::Less => { let offset = (start_index - curr_index) + 1; let right_index = start_index.checked_add(offset); @@ -398,7 +395,7 @@ fn next_bucket_index(num_buckets: usize, start_index: usize, curr_index: usize) None } } - cmp::Ordering::Equal => { + std::cmp::Ordering::Equal => { let right_index = start_index.checked_add(1); let left_index = start_index.checked_sub(1); @@ -410,7 +407,7 @@ fn next_bucket_index(num_buckets: usize, start_index: usize, curr_index: usize) None } } - cmp::Ordering::Greater => { + std::cmp::Ordering::Greater => { let offset = curr_index - start_index; let left_index = start_index.checked_sub(offset); @@ -480,7 +477,7 @@ mod tests { #[test] fn positive_initial_empty_buckets() { let table_id = [1u8; bt::NODE_ID_LEN]; - let mut table = RoutingTable::new(table_id.into()); + let table = RoutingTable::new(table_id.into()); // First buckets should be empty assert_eq!(table.buckets().take(table::MAX_BUCKETS).count(), table::MAX_BUCKETS); diff --git a/packages/dht/src/worker/bootstrap.rs b/packages/dht/src/worker/bootstrap.rs index ff2bdeae6..f278d7941 100644 --- a/packages/dht/src/worker/bootstrap.rs +++ b/packages/dht/src/worker/bootstrap.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; -use std::sync::mpsc::SyncSender; - -use log::{error, info, warn}; -use mio::{EventLoop, Timeout}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, RwLock}; + +use futures::channel::mpsc::{self, SendError}; +use futures::future::BoxFuture; +use futures::{FutureExt as _, SinkExt as _}; +use tokio::task::JoinSet; +use tokio::time::{sleep, Duration}; use util::bt::{self, NodeId}; use crate::handshaker_trait::HandshakerTrait; @@ -12,7 +16,6 @@ use crate::routing::bucket::Bucket; use crate::routing::node::{Node, NodeStatus}; use crate::routing::table::{self, BucketContents, RoutingTable}; use crate::transaction::{MIDGenerator, TransactionID}; -use crate::worker::handler::DhtHandler; use crate::worker::ScheduledTaskCheck; const BOOTSTRAP_INITIAL_TIMEOUT: u64 = 2500; @@ -36,11 +39,12 @@ pub enum BootstrapStatus { #[allow(clippy::module_name_repetitions)] pub struct TableBootstrap { table_id: NodeId, - id_generator: MIDGenerator, + id_generator: Mutex, starting_nodes: Vec, - active_messages: HashMap, + active_messages: Mutex>, starting_routers: HashSet, - curr_bootstrap_bucket: usize, + curr_bootstrap_bucket: AtomicUsize, + tasks: Arc>>>, } impl TableBootstrap { @@ -52,73 +56,84 @@ impl TableBootstrap { TableBootstrap { table_id, - id_generator, + id_generator: Mutex::new(id_generator), starting_nodes: nodes, starting_routers: router_filter, - active_messages: HashMap::new(), - curr_bootstrap_bucket: 0, + active_messages: Mutex::default(), + curr_bootstrap_bucket: AtomicUsize::default(), + tasks: Arc::default(), } } - pub fn start_bootstrap( - &mut self, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> BootstrapStatus - where - H: HandshakerTrait, - { + pub async fn start_bootstrap( + &self, + mut out: mpsc::Sender<(Vec, SocketAddr)>, + mut scheduled_task_sender: mpsc::Sender, + ) -> Result { // Reset the bootstrap state - self.active_messages.clear(); - self.curr_bootstrap_bucket = 0; + self.active_messages.lock().unwrap().clear(); + 
self.curr_bootstrap_bucket.store(0, Ordering::Relaxed); // Generate transaction id for the initial bootstrap messages - let trans_id = self.id_generator.generate(); + let trans_id = self.id_generator.lock().unwrap().generate(); // Set a timer to begin the actual bootstrap - let res_timeout = event_loop.timeout_ms( - (BOOTSTRAP_INITIAL_TIMEOUT, ScheduledTaskCheck::BootstrapTimeout(trans_id)), - BOOTSTRAP_INITIAL_TIMEOUT, - ); - let Ok(timeout) = res_timeout else { - error!("bip_dht: Failed to set a timeout for the start of a table bootstrap..."); - return BootstrapStatus::Failed; - }; + let abort = self.tasks.lock().unwrap().spawn(async move { + sleep(Duration::from_millis(BOOTSTRAP_INITIAL_TIMEOUT)).await; + + match scheduled_task_sender + .send(ScheduledTaskCheck::BootstrapTimeout(trans_id)) + .await + { + Ok(()) => { + tracing::debug!("sent scheduled bootstrap timeout"); + Ok(()) + } + Err(e) => { + tracing::debug!("error sending scheduled bootstrap timeout: {e}"); + Err(e) + } + } + }); // Insert the timeout into the active bootstraps just so we can check if a response was valid (and begin the bucket bootstraps) - self.active_messages.insert(trans_id, timeout); + self.active_messages.lock().unwrap().insert( + trans_id, + tokio::time::Instant::now() + Duration::from_millis(BOOTSTRAP_INITIAL_TIMEOUT), + ); let find_node_msg = FindNodeRequest::new(trans_id.as_ref(), self.table_id, self.table_id).encode(); // Ping all initial routers and nodes for addr in self.starting_routers.iter().chain(self.starting_nodes.iter()) { - if out.send((find_node_msg.clone(), *addr)).is_err() { - error!("bip_dht: Failed to send bootstrap message to router through channel..."); - return BootstrapStatus::Failed; + if out.send((find_node_msg.clone(), *addr)).await.is_err() { + tracing::error!("bip_dht: Failed to send bootstrap message to router through channel..."); + abort.abort(); + return Err(BootstrapStatus::Failed); } } - self.current_bootstrap_status() + Ok(self.current_bootstrap_status()) } pub fn is_router(&self, addr: &SocketAddr) -> bool { self.starting_routers.contains(addr) } - pub fn recv_response( - &mut self, - trans_id: &TransactionID, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, + pub async fn recv_response( + &self, + trans_id: TransactionID, + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, ) -> BootstrapStatus where - H: HandshakerTrait, + H: HandshakerTrait + 'static, { // Process the message transaction id - let timeout = if let Some(t) = self.active_messages.get(trans_id) { + let _timeout = if let Some(t) = self.active_messages.lock().unwrap().get(&trans_id) { *t } else { - warn!( + tracing::warn!( "bip_dht: Received expired/unsolicited node response for an active table \ bootstrap..." 
); @@ -127,35 +142,32 @@ impl TableBootstrap { // If this response was from the initial bootstrap, we don't want to clear the timeout or remove // the token from the map as we want to wait until the proper timeout has been triggered before starting - if self.curr_bootstrap_bucket != 0 { + if self.curr_bootstrap_bucket.load(Ordering::Acquire) != 0 { // Message was not from the initial ping - // Remove the timeout from the event loop - event_loop.clear_timeout(timeout); - // Remove the token from the mapping - self.active_messages.remove(trans_id); + self.active_messages.lock().unwrap().remove(&trans_id); } // Check if we need to bootstrap on the next bucket - if self.active_messages.is_empty() { - return self.bootstrap_next_bucket(table, out, event_loop); + if self.active_messages.lock().unwrap().is_empty() { + return self.bootstrap_next_bucket::(table, out, scheduled_task_sender).await; } self.current_bootstrap_status() } - pub fn recv_timeout( - &mut self, - trans_id: &TransactionID, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, + pub async fn recv_timeout( + &self, + trans_id: TransactionID, + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, ) -> BootstrapStatus where - H: HandshakerTrait, + H: HandshakerTrait + 'static, { - if self.active_messages.remove(trans_id).is_none() { - warn!( + if self.active_messages.lock().unwrap().remove(&trans_id).is_none() { + tracing::warn!( "bip_dht: Received expired/unsolicited node timeout for an active table \ bootstrap..." ); @@ -163,8 +175,8 @@ impl TableBootstrap { } // Check if we need to bootstrap on the next bucket - if self.active_messages.is_empty() { - return self.bootstrap_next_bucket(table, out, event_loop); + if self.active_messages.lock().unwrap().is_empty() { + return self.bootstrap_next_bucket::(table, out, scheduled_task_sender).await; } self.current_bootstrap_status() @@ -172,100 +184,123 @@ impl TableBootstrap { // Returns true if there are more buckets to bootstrap, false otherwise fn bootstrap_next_bucket( - &mut self, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> BootstrapStatus + &self, + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, + ) -> BoxFuture<'_, BootstrapStatus> where - H: HandshakerTrait, + H: HandshakerTrait + 'static, { - let target_id = flip_id_bit_at_index(self.table_id, self.curr_bootstrap_bucket); - - // Get the optimal iterator to bootstrap the current bucket - if self.curr_bootstrap_bucket == 0 || self.curr_bootstrap_bucket == 1 { - let iter = table - .closest_nodes(target_id) - .filter(|n| n.status() == NodeStatus::Questionable); - - self.send_bootstrap_requests(iter, target_id, table, out, event_loop) - } else { - let mut buckets = table.buckets().skip(self.curr_bootstrap_bucket - 2); - let dummy_bucket = Bucket::new(); - - // Sloppy probabilities of our target node residing at the node - let percent_25_bucket = if let Some(bucket) = buckets.next() { - match bucket { - BucketContents::Empty => dummy_bucket.iter(), - BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), - } - } else { - dummy_bucket.iter() - }; - let percent_50_bucket = if let Some(bucket) = buckets.next() { - match bucket { - BucketContents::Empty => dummy_bucket.iter(), - BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), - } + async move { + let bootstrap_bucket = self.curr_bootstrap_bucket.load(Ordering::Relaxed); + + 
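Sketch (editorial, not part of the patch): the mio `EventLoop::timeout_ms`/`clear_timeout` pair is replaced by spawning a task on the `JoinSet` that sleeps and then pushes a `ScheduledTaskCheck` back over the internal channel, while `active_messages` keeps a `tokio::time::Instant` deadline purely for bookkeeping. Stripped down to a standalone example (the `Check` enum and channel are illustrative, not from the patch):

    use futures::channel::mpsc;
    use futures::SinkExt as _;
    use tokio::task::JoinSet;
    use tokio::time::{sleep, Duration};

    enum Check {
        Timeout(u64),
    }

    fn schedule_timeout(tasks: &mut JoinSet<()>, mut check_tx: mpsc::Sender<Check>, id: u64, ms: u64) {
        // JoinSet::spawn hands back an AbortHandle; start_bootstrap keeps it so
        // the timer can be cancelled if sending the initial pings fails.
        let _abort = tasks.spawn(async move {
            sleep(Duration::from_millis(ms)).await;
            let _ = check_tx.send(Check::Timeout(id)).await;
        });
    }

Timers that fire late are no longer cancelled explicitly; the receive paths simply warn and return when a transaction id is no longer present in `active_messages`.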
let target_id = flip_id_bit_at_index(self.table_id, bootstrap_bucket); + + // Get the optimal iterator to bootstrap the current bucket + if bootstrap_bucket == 0 || bootstrap_bucket == 1 { + let questionable_nodes: Vec = table + .read() + .unwrap() + .closest_nodes(target_id) + .filter(|n| n.status() == NodeStatus::Questionable) + .cloned() + .collect(); + + self.send_bootstrap_requests::( + questionable_nodes.iter(), + target_id, + table.clone(), + out.clone(), + scheduled_task_sender, + ) + .await } else { - dummy_bucket.iter() - }; - let percent_100_bucket = if let Some(bucket) = buckets.next() { - match bucket { - BucketContents::Empty => dummy_bucket.iter(), - BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), - } - } else { - dummy_bucket.iter() - }; - - // TODO: Figure out why chaining them in reverse gives us more total nodes on average, perhaps it allows us to fill up the lower - // buckets faster at the cost of less nodes in the higher buckets (since lower buckets are very easy to fill)...Although it should - // even out since we are stagnating buckets, so doing it in reverse may make sense since on the 3rd iteration, it allows us to ping - // questionable nodes in our first buckets right off the bat. - let iter = percent_25_bucket - .chain(percent_50_bucket) - .chain(percent_100_bucket) - .filter(|n| n.status() == NodeStatus::Questionable); - - self.send_bootstrap_requests(iter, target_id, table, out, event_loop) + let questionable_nodes: Vec = { + let routing_table = table.read().unwrap(); + let mut buckets = routing_table.buckets().skip(bootstrap_bucket - 2); + let dummy_bucket = Bucket::new(); + + // Sloppy probabilities of our target node residing at the node + let percent_25_bucket = if let Some(bucket) = buckets.next() { + match bucket { + BucketContents::Empty => dummy_bucket.iter(), + BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), + } + } else { + dummy_bucket.iter() + }; + let percent_50_bucket = if let Some(bucket) = buckets.next() { + match bucket { + BucketContents::Empty => dummy_bucket.iter(), + BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), + } + } else { + dummy_bucket.iter() + }; + let percent_100_bucket = if let Some(bucket) = buckets.next() { + match bucket { + BucketContents::Empty => dummy_bucket.iter(), + BucketContents::Sorted(b) | BucketContents::Assorted(b) => b.iter(), + } + } else { + dummy_bucket.iter() + }; + + // TODO: Figure out why chaining them in reverse gives us more total nodes on average, perhaps it allows us to fill up the lower + // buckets faster at the cost of less nodes in the higher buckets (since lower buckets are very easy to fill)...Although it should + // even out since we are stagnating buckets, so doing it in reverse may make sense since on the 3rd iteration, it allows us to ping + // questionable nodes in our first buckets right off the bat. 
+ percent_25_bucket + .chain(percent_50_bucket) + .chain(percent_100_bucket) + .filter(|n| n.status() == NodeStatus::Questionable) + .cloned() + .collect() + }; + + self.send_bootstrap_requests::( + questionable_nodes.iter(), + target_id, + table.clone(), + out.clone(), + scheduled_task_sender, + ) + .await + } } + .boxed() } - fn send_bootstrap_requests<'a, H, I>( - &mut self, + async fn send_bootstrap_requests<'a, H, I>( + &self, nodes: I, target_id: NodeId, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, + table: Arc>, + mut out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, ) -> BootstrapStatus where I: Iterator, - H: HandshakerTrait, + H: HandshakerTrait + 'static, { - info!("bip_dht: bootstrap::send_bootstrap_requests {}", self.curr_bootstrap_bucket); + let bootstrap_bucket = self.curr_bootstrap_bucket.load(Ordering::Relaxed); + + tracing::info!("bip_dht: bootstrap::send_bootstrap_requests {}", bootstrap_bucket); let mut messages_sent = 0; for node in nodes.take(BOOTSTRAP_PINGS_PER_BUCKET) { // Generate a transaction id - let trans_id = self.id_generator.generate(); + let trans_id = self.id_generator.lock().unwrap().generate(); let find_node_msg = FindNodeRequest::new(trans_id.as_ref(), self.table_id, target_id).encode(); // Add a timeout for the node - let res_timeout = event_loop.timeout_ms( - (BOOTSTRAP_NODE_TIMEOUT, ScheduledTaskCheck::BootstrapTimeout(trans_id)), - BOOTSTRAP_NODE_TIMEOUT, - ); - let Ok(timeout) = res_timeout else { - error!("bip_dht: Failed to set a timeout for the start of a table bootstrap..."); - return BootstrapStatus::Failed; - }; + let timeout = tokio::time::Instant::now() + Duration::from_millis(BOOTSTRAP_NODE_TIMEOUT); // Send the message to the node - if out.send((find_node_msg, node.addr())).is_err() { - error!("bip_dht: Could not send a bootstrap message through the channel..."); + if out.send((find_node_msg, node.addr())).await.is_err() { + tracing::error!("bip_dht: Could not send a bootstrap message through the channel..."); return BootstrapStatus::Failed; } @@ -273,23 +308,46 @@ impl TableBootstrap { node.local_request(); // Create an entry for the timeout in the map - self.active_messages.insert(trans_id, timeout); + self.active_messages.lock().unwrap().insert(trans_id, timeout); messages_sent += 1; + + // Schedule a timeout check + let mut this_scheduled_task_sender = scheduled_task_sender.clone(); + self.tasks.lock().unwrap().spawn(async move { + sleep(Duration::from_millis(BOOTSTRAP_INITIAL_TIMEOUT)).await; + + match this_scheduled_task_sender + .send(ScheduledTaskCheck::BootstrapTimeout(trans_id)) + .await + { + Ok(()) => { + tracing::debug!("sent scheduled bootstrap timeout"); + Ok(()) + } + Err(e) => { + tracing::debug!("error sending scheduled bootstrap timeout: {e}"); + Err(e) + } + } + }); } - self.curr_bootstrap_bucket += 1; - if self.curr_bootstrap_bucket == table::MAX_BUCKETS { + let bootstrap_bucket = self.curr_bootstrap_bucket.fetch_add(1, Ordering::AcqRel) + 1; + + if (bootstrap_bucket) == table::MAX_BUCKETS { BootstrapStatus::Completed } else if messages_sent == 0 { - self.bootstrap_next_bucket(table, out, event_loop) + self.bootstrap_next_bucket::(table, out, scheduled_task_sender).await } else { - return BootstrapStatus::Bootstrapping; + BootstrapStatus::Bootstrapping } } fn current_bootstrap_status(&self) -> BootstrapStatus { - if self.curr_bootstrap_bucket == table::MAX_BUCKETS || self.active_messages.is_empty() { + let bootstrap_bucket = 
self.curr_bootstrap_bucket.load(Ordering::Relaxed); + + if bootstrap_bucket == table::MAX_BUCKETS || self.active_messages.lock().unwrap().is_empty() { BootstrapStatus::Idle } else { BootstrapStatus::Bootstrapping diff --git a/packages/dht/src/worker/handler.rs b/packages/dht/src/worker/handler.rs index 0ba9d901e..55bb46771 100644 --- a/packages/dht/src/worker/handler.rs +++ b/packages/dht/src/worker/handler.rs @@ -1,12 +1,15 @@ use std::collections::HashMap; use std::convert::AsRef; -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}; -use std::sync::mpsc::{self, SyncSender}; -use std::{io, mem, thread}; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, RwLock}; use bencode::{ben_bytes, BDecodeOpt, BencodeMut, BencodeRef}; -use log::{error, info, log_enabled, warn}; -use mio::{self, EventLoop, Handler}; +use futures::channel::mpsc; +use futures::future::BoxFuture; +use futures::{FutureExt, SinkExt, StreamExt as _}; +use tokio::net::UdpSocket; +use tokio::task::JoinSet; use util::bt::InfoHash; use util::convert; use util::net::IpAddr; @@ -32,66 +35,80 @@ use crate::worker::lookup::{LookupStatus, TableLookup}; use crate::worker::refresh::{RefreshStatus, TableRefresh}; use crate::worker::{DhtEvent, OneshotTask, ScheduledTaskCheck, ShutdownCause}; -// TODO: Update modules to use find_node on the routing table to update the status of a given node. - const MAX_BOOTSTRAP_ATTEMPTS: usize = 3; const BOOTSTRAP_GOOD_NODE_THRESHOLD: usize = 10; +enum Task { + Main(OneshotTask), + Scheduled(ScheduledTaskCheck), +} + /// Spawns a DHT handler that maintains our routing table and executes our actions on the DHT. #[allow(clippy::module_name_repetitions)] pub fn create_dht_handler( table: RoutingTable, - out: SyncSender<(Vec, SocketAddr)>, + out: mpsc::Sender<(Vec, SocketAddr)>, read_only: bool, handshaker: H, - kill_sock: UdpSocket, + kill_sock: Arc, kill_addr: SocketAddr, -) -> io::Result> +) -> (mpsc::Sender, JoinSet<()>) where H: HandshakerTrait + 'static, { - let mut handler = DhtHandler::new(table, out, read_only, handshaker); - let mut event_loop = EventLoop::new()?; - - let loop_channel = event_loop.channel(); - - thread::spawn(move || { - if event_loop.run(&mut handler).is_err() { - error!("bip_dht: EventLoop shut down with an error..."); + let (main_task_sender, main_task_receiver) = mpsc::channel(100); + let (scheduled_task_sender, scheduled_task_receiver) = mpsc::channel(100); + + let main_task_receiver = main_task_receiver.map(Task::Main); + let scheduled_task_receiver = scheduled_task_receiver.map(Task::Scheduled); + + let mut tasks_receiver = futures::stream::select(main_task_receiver, scheduled_task_receiver); + + let handler = DhtHandler::new( + table, + out, + main_task_sender.clone(), + scheduled_task_sender, + read_only, + handshaker, + ); + + let mut tasks = JoinSet::new(); + + tasks.spawn(async move { + while let Some(task) = tasks_receiver.next().await { + match task { + Task::Main(main_task) => handler.handle_task(main_task).await, + Task::Scheduled(scheduled_task) => handler.handle_scheduled_task(scheduled_task).await, + } } - // Make sure the handler and event loop are dropped before sending our incoming messenger kill - // message so that the incoming messenger can not send anything through their event loop channel. 
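Sketch (editorial, not part of the patch): with mio gone, `DhtHandler` no longer implements `Handler::notify`/`Handler::timeout`. Instead `create_dht_handler` opens two bounded channels, one for public `OneshotTask` commands and one for internal `ScheduledTaskCheck` timers, tags each side with the `Task` enum, and merges them with `futures::stream::select` so a single loop spawned on the `JoinSet` serves both in arrival order. A self-contained example of that merge pattern (names and payloads are illustrative, not from the patch):

    use futures::channel::mpsc;
    use futures::StreamExt as _;

    enum Task {
        Main(String),
        Scheduled(u32),
    }

    #[tokio::main]
    async fn main() {
        let (mut main_tx, main_rx) = mpsc::channel::<String>(8);
        let (mut sched_tx, sched_rx) = mpsc::channel::<u32>(8);

        main_tx.try_send("start_lookup".to_string()).unwrap();
        sched_tx.try_send(42).unwrap();
        drop((main_tx, sched_tx));

        // `select` yields from whichever stream has an item ready, so one loop
        // drives both the command channel and the timer channel.
        let mut tasks = futures::stream::select(main_rx.map(Task::Main), sched_rx.map(Task::Scheduled));
        while let Some(task) = tasks.next().await {
            match task {
                Task::Main(cmd) => println!("command: {cmd}"),
                Task::Scheduled(id) => println!("timer: {id}"),
            }
        }
    }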
- mem::drop(event_loop); - mem::drop(handler); - // When event loop stops, we need to "wake" the incoming messenger with a socket message, // when it processes the message and tries to pass it to us, it will see that our channel // is closed and know that it should shut down. The outgoing messenger will shut itself down. - // TODO: This will not work if kill_addr is set to a default route 0.0.0.0, need to find another - // work around (potentially finding out the actual addresses for the current machine beforehand?) - if kill_sock.send_to(&b"0"[..], kill_addr).is_err() { - error!("bip_dht: Failed to send a wake up message to the incoming channel..."); + if kill_sock.send_to(&b"0"[..], kill_addr).await.is_err() { + tracing::error!("bip_dht: Failed to send a wake up message to the incoming channel..."); } - info!("bip_dht: DhtHandler gracefully shut down, exiting thread..."); + tracing::info!("bip_dht: DhtHandler gracefully shut down, exiting thread..."); }); - Ok(loop_channel) + (main_task_sender, tasks) } // ----------------------------------------------------------------------------// /// Actions that we can perform on our `RoutingTable`. +#[derive(Clone)] enum TableAction { /// Lookup action. - Lookup(TableLookup), + Lookup(Arc), /// Refresh action. - Refresh(Box), + Refresh(Arc), /// Bootstrap action. /// /// Includes number of bootstrap attempts. - Bootstrap(TableBootstrap, usize), + Bootstrap(Arc, Arc), } /// Actions that we want to perform on our `RoutingTable` after bootstrapping finishes. @@ -102,35 +119,41 @@ enum PostBootstrapAction { Refresh(Box, TransactionID), } -/// Storage for our `EventLoop` to invoke actions upon. #[allow(clippy::module_name_repetitions)] pub struct DhtHandler { - detached: DetachedDhtHandler, - table_actions: HashMap, -} + handshaker: futures::lock::Mutex, + routing_table: Arc>, + table_actions: Mutex>, + + out_channel: mpsc::Sender<(Vec, SocketAddr)>, + main_task_sender: mpsc::Sender, + scheduled_task_sender: mpsc::Sender, -/// Storage separate from the table actions allowing us to hold mutable references -/// to table actions while still being able to pass around the bulky parameters. -struct DetachedDhtHandler { read_only: bool, - handshaker: H, - out_channel: SyncSender<(Vec, SocketAddr)>, - token_store: TokenStore, - aid_generator: AIDGenerator, - bootstrapping: bool, - routing_table: RoutingTable, - active_stores: AnnounceStorage, + bootstrapping: AtomicBool, + + token_store: Mutex, + aid_generator: Mutex, + active_stores: Mutex, + // If future actions is not empty, that means we are still bootstrapping // since we will always spin up a table refresh action after bootstrapping. 
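Sketch (editorial, not part of the patch): `DetachedDhtHandler` disappears because the borrow-splitting it existed for is no longer needed; every mutable piece of state now sits behind its own `Mutex`, `RwLock` or atomic, and the `TableAction` variants hold `Arc`s so an action can be cheaply cloned out of the `table_actions` map. The async methods below keep those `std::sync` guards tightly scoped so none is held across an `.await`; a minimal example of that pattern with illustrative types:

    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    #[derive(Clone)]
    struct Action(Arc<String>);

    struct Handler {
        actions: Mutex<HashMap<u32, Action>>,
    }

    impl Handler {
        async fn handle(&self, id: u32) {
            // Lock, clone the Arc out, and drop the guard, all in one statement.
            let action = self.actions.lock().unwrap().get(&id).cloned();

            if let Some(Action(name)) = action {
                // Safe to await here: no std mutex guard is held.
                tokio::task::yield_now().await;
                println!("running {name}");
            }
        }
    }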
- future_actions: Vec, - event_notifiers: Vec>, + future_actions: Mutex>, + event_notifiers: Mutex>>, } impl DhtHandler where - H: HandshakerTrait, + H: HandshakerTrait + 'static, { - fn new(table: RoutingTable, out: SyncSender<(Vec, SocketAddr)>, read_only: bool, handshaker: H) -> DhtHandler { + fn new( + table: RoutingTable, + out: mpsc::Sender<(Vec, SocketAddr)>, + main_task_sender: mpsc::Sender, + scheduled_task_sender: mpsc::Sender, + read_only: bool, + handshaker: H, + ) -> DhtHandler { let mut aid_generator = AIDGenerator::new(); // Insert the refresh task to execute after the bootstrap @@ -139,868 +162,901 @@ where let table_refresh = Box::new(TableRefresh::new(mid_generator)); let future_actions = vec![PostBootstrapAction::Refresh(table_refresh, refresh_trans_id)]; - let detached = DetachedDhtHandler { + DhtHandler { read_only, - handshaker, + handshaker: futures::lock::Mutex::new(handshaker), out_channel: out, - token_store: TokenStore::new(), - aid_generator, - bootstrapping: false, - routing_table: table, - active_stores: AnnounceStorage::new(), - future_actions, - event_notifiers: Vec::new(), - }; - - DhtHandler { - detached, - table_actions: HashMap::new(), + token_store: Mutex::new(TokenStore::new()), + aid_generator: Mutex::new(aid_generator), + bootstrapping: AtomicBool::default(), + routing_table: Arc::new(RwLock::new(table)), + active_stores: Mutex::new(AnnounceStorage::new()), + future_actions: Mutex::new(future_actions), + event_notifiers: Mutex::default(), + table_actions: Mutex::new(HashMap::new()), + main_task_sender, + scheduled_task_sender, } } -} - -impl Handler for DhtHandler -where - H: HandshakerTrait, -{ - type Timeout = (u64, ScheduledTaskCheck); - type Message = OneshotTask; - fn notify(&mut self, event_loop: &mut EventLoop>, task: OneshotTask) { + async fn handle_task(&self, task: OneshotTask) { match task { OneshotTask::Incoming(buffer, addr) => { - handle_incoming(self, event_loop, &buffer[..], addr); + self.handle_incoming(&buffer[..], addr).await; } OneshotTask::RegisterSender(send) => { - handle_register_sender(self, send); + self.handle_register_sender(send); } OneshotTask::StartBootstrap(routers, nodes) => { - handle_start_bootstrap(self, event_loop, routers, nodes); + self.handle_start_bootstrap(routers, nodes).await; } OneshotTask::StartLookup(info_hash, should_announce) => { - handle_start_lookup( - &mut self.table_actions, - &mut self.detached, - event_loop, - info_hash, - should_announce, - ); + self.handle_start_lookup(info_hash, should_announce).await; } OneshotTask::Shutdown(cause) => { - handle_shutdown(self, event_loop, cause); + self.handle_shutdown(cause); } } } - fn timeout(&mut self, event_loop: &mut EventLoop>, data: (u64, ScheduledTaskCheck)) { - let (_, task) = data; + #[allow(clippy::too_many_lines)] + async fn handle_incoming(&self, buffer: &[u8], addr: SocketAddr) { + // Parse the buffer as a bencoded message + let Ok(bencode) = BencodeRef::decode(buffer, BDecodeOpt::default()) else { + tracing::warn!("bip_dht: Received invalid bencode data..."); + return; + }; - match task { - ScheduledTaskCheck::TableRefresh(trans_id) => { - handle_check_table_refresh(&mut self.table_actions, &self.detached, event_loop, trans_id); - } - ScheduledTaskCheck::BootstrapTimeout(trans_id) => { - handle_check_bootstrap_timeout(self, event_loop, trans_id); - } - ScheduledTaskCheck::LookupTimeout(trans_id) => { - handle_check_lookup_timeout(self, event_loop, trans_id); + // Parse the bencode as a message + // Check to make sure we issued the transaction 
id (or that it is still valid) + let message = MessageType::>::new(&bencode, |trans| { + // Check if we can interpret the response transaction id as one of ours. + let Some(trans_id) = TransactionID::from_bytes(trans) else { + return ExpectedResponse::None; + }; + + // Match the response action id with our current actions + let Some(table_action) = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned() else { + return ExpectedResponse::None; + }; + + match table_action { + TableAction::Lookup(_) => ExpectedResponse::GetPeers, + TableAction::Refresh(_) | TableAction::Bootstrap(_, _) => ExpectedResponse::FindNode, } - ScheduledTaskCheck::LookupEndGame(trans_id) => { - handle_check_lookup_endgame(self, event_loop, trans_id); + }); + + // Do not process requests if we are read only + // TODO: Add read only flags to messages we send it we are read only! + // Also, check for read only flags on responses we get before adding nodes + // to our RoutingTable. + if self.read_only { + if let Ok(MessageType::Request(_)) = message { + return; } } - } -} - -// ----------------------------------------------------------------------------// - -/// Shut down the event loop by sending it a shutdown message with the given cause. -fn shutdown_event_loop(event_loop: &EventLoop>, cause: ShutdownCause) -where - H: HandshakerTrait, -{ - if event_loop.channel().send(OneshotTask::Shutdown(cause)).is_err() { - error!("bip_dht: Failed to sent a shutdown message to the EventLoop..."); - } -} - -/// Broadcast the given event to all of the event notifiers. -fn broadcast_dht_event(notifiers: &mut Vec>, event: DhtEvent) { - notifiers.retain(|send| send.send(event).is_ok()); -} - -/// Number of good nodes in the `RoutingTable`. -fn num_good_nodes(table: &RoutingTable) -> usize { - table - .closest_nodes(table.node_id()) - .filter(|n| n.status() == NodeStatus::Good) - .count() -} - -/// We should rebootstrap if we have a low number of nodes. -fn should_rebootstrap(table: &RoutingTable) -> bool { - num_good_nodes(table) <= BOOTSTRAP_GOOD_NODE_THRESHOLD -} -/// Broadcast that the bootstrap has completed. -/// IMPORTANT: Should call this instead of `broadcast_dht_event`()! -fn broadcast_bootstrap_completed( - action_id: ActionID, - table_actions: &mut HashMap, - work_storage: &mut DetachedDhtHandler, - event_loop: &mut EventLoop>, -) where - H: HandshakerTrait, -{ - // Send notification that the bootstrap has completed. - broadcast_dht_event(&mut work_storage.event_notifiers, DhtEvent::BootstrapCompleted); + // Process the given message + match message { + Ok(MessageType::Request(RequestType::Ping(p))) => { + tracing::info!("bip_dht: Received a PingRequest..."); + let node = Node::as_good(p.node_id(), addr); - // Indicates we are out of the bootstrapping phase - work_storage.bootstrapping = false; + let ping_rsp = { + let routing_table = self.routing_table.read().unwrap(); - // Remove the bootstrap action from our table actions - table_actions.remove(&action_id); + // Node requested from us, mark it in the routing table + if let Some(n) = routing_table.find_node(&node) { + n.remote_request(); + } - // Start the post bootstrap actions. - let mut future_actions = work_storage.future_actions.split_off(0); - for table_action in future_actions.drain(..) 
{ - match table_action { - PostBootstrapAction::Lookup(info_hash, should_announce) => { - handle_start_lookup(table_actions, work_storage, event_loop, info_hash, should_announce); - } - PostBootstrapAction::Refresh(refresh, trans_id) => { - table_actions.insert(trans_id.action_id(), TableAction::Refresh(refresh)); + PingResponse::new(p.transaction_id(), routing_table.node_id()) + }; - handle_check_table_refresh(table_actions, work_storage, event_loop, trans_id); - } - } - } -} + let ping_msg = ping_rsp.encode(); -/// Attempt to rebootstrap or shutdown the dht if we have no nodes after rebootstrapping multiple time. -/// Returns None if the DHT is shutting down, Some(true) if the rebootstrap process started, Some(false) if a rebootstrap is not necessary. -fn attempt_rebootstrap( - bootstrap: &mut TableBootstrap, - attempts: &mut usize, - work_storage: &mut DetachedDhtHandler, - event_loop: &mut EventLoop>, -) -> Option -where - H: HandshakerTrait, -{ - // Increment the bootstrap counter - *attempts += 1; + let mut out = self.out_channel.clone(); - warn!("bip_dht: Bootstrap attempt {} failed, attempting a rebootstrap...", *attempts); + let sent = out.send((ping_msg, addr)).await; - // Check if we reached the maximum bootstrap attempts - if *attempts >= MAX_BOOTSTRAP_ATTEMPTS { - if num_good_nodes(&work_storage.routing_table) == 0 { - // Failed to get any nodes in the rebootstrap attempts, shut down - shutdown_event_loop(event_loop, ShutdownCause::BootstrapFailed); - None - } else { - Some(false) - } - } else { - let bootstrap_status = bootstrap.start_bootstrap(&work_storage.out_channel, event_loop); - - match bootstrap_status { - BootstrapStatus::Idle => Some(false), - BootstrapStatus::Bootstrapping => Some(true), - BootstrapStatus::Failed => { - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - None - } - BootstrapStatus::Completed => { - if should_rebootstrap(&work_storage.routing_table) { - attempt_rebootstrap(bootstrap, attempts, work_storage, event_loop) - } else { - Some(false) + if sent.is_err() { + tracing::error!("bip_dht: Failed to send a ping response on the out channel..."); + self.handle_shutdown(ShutdownCause::Unspecified); } } - } - } -} - -// ----------------------------------------------------------------------------// + Ok(MessageType::Request(RequestType::FindNode(f))) => { + tracing::info!("bip_dht: Received a FindNodeRequest..."); + let node = Node::as_good(f.node_id(), addr); -#[allow(clippy::too_many_lines)] -fn handle_incoming(handler: &mut DhtHandler, event_loop: &mut EventLoop>, buffer: &[u8], addr: SocketAddr) -where - H: HandshakerTrait, -{ - let (work_storage, table_actions) = (&mut handler.detached, &mut handler.table_actions); - - // Parse the buffer as a bencoded message - let Ok(bencode) = BencodeRef::decode(buffer, BDecodeOpt::default()) else { - warn!("bip_dht: Received invalid bencode data..."); - return; - }; - - // Parse the bencode as a message - // Check to make sure we issued the transaction id (or that it is still valid) - let message = MessageType::>::new(&bencode, |trans| { - // Check if we can interpret the response transaction id as one of ours. 
- let Some(trans_id) = TransactionID::from_bytes(trans) else { - return ExpectedResponse::None; - }; + let find_node_msg = { + let routing_table = self.routing_table.read().unwrap(); - // Match the response action id with our current actions - match table_actions.get(&trans_id.action_id()) { - Some(&TableAction::Lookup(_)) => ExpectedResponse::GetPeers, - Some(&TableAction::Refresh(_) | &TableAction::Bootstrap(_, _)) => ExpectedResponse::FindNode, - None => ExpectedResponse::None, - } - }); + // Node requested from us, mark it in the routing table + if let Some(n) = routing_table.find_node(&node) { + n.remote_request(); + } - // Do not process requests if we are read only - // TODO: Add read only flags to messages we send it we are read only! - // Also, check for read only flags on responses we get before adding nodes - // to our RoutingTable. - if work_storage.read_only { - if let Ok(MessageType::Request(_)) = message { - return; - } - } + // Grab the closest nodes + let mut closest_nodes_bytes = Vec::with_capacity(26 * 8); + for node in routing_table.closest_nodes(f.target_id()).take(8) { + closest_nodes_bytes.extend_from_slice(&node.encode()); + } - // Process the given message - match message { - Ok(MessageType::Request(RequestType::Ping(p))) => { - info!("bip_dht: Received a PingRequest..."); - let node = Node::as_good(p.node_id(), addr); + let find_node_rsp = + FindNodeResponse::new(f.transaction_id(), routing_table.node_id(), &closest_nodes_bytes).unwrap(); + find_node_rsp.encode() + }; - // Node requested from us, mark it in the routing table - if let Some(n) = work_storage.routing_table.find_node(&node) { - n.remote_request(); + if self.out_channel.clone().send((find_node_msg, addr)).await.is_err() { + tracing::error!("bip_dht: Failed to send a find node response on the out channel..."); + self.handle_shutdown(ShutdownCause::Unspecified); + } } + Ok(MessageType::Request(RequestType::GetPeers(g))) => { + tracing::info!("bip_dht: Received a GetPeersRequest..."); + let node = Node::as_good(g.node_id(), addr); - let ping_rsp = PingResponse::new(p.transaction_id(), work_storage.routing_table.node_id()); - let ping_msg = ping_rsp.encode(); + let get_peers_msg = { + let routing_table = self.routing_table.read().unwrap(); - if work_storage.out_channel.send((ping_msg, addr)).is_err() { - error!("bip_dht: Failed to send a ping response on the out channel..."); - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - } - } - Ok(MessageType::Request(RequestType::FindNode(f))) => { - info!("bip_dht: Received a FindNodeRequest..."); - let node = Node::as_good(f.node_id(), addr); + // Node requested from us, mark it in the routing table + if let Some(n) = routing_table.find_node(&node) { + n.remote_request(); + } - // Node requested from us, mark it in the routing table - if let Some(n) = work_storage.routing_table.find_node(&node) { - n.remote_request(); - } + // TODO: Move socket address serialization code into bip_util + // TODO: Check what the maximum number of values we can give without overflowing a udp packet + // Also, if we aren't going to give all of the contacts, we may want to shuffle which ones we give + let mut contact_info_bytes = Vec::with_capacity(6 * 20); + self.active_stores.lock().unwrap().find_items(&g.info_hash(), |addr| { + let mut bytes = [0u8; 6]; + let port = addr.port(); + + match addr { + SocketAddr::V4(v4_addr) => { + for (src, dst) in convert::ipv4_to_bytes_be(*v4_addr.ip()).iter().zip(bytes.iter_mut()) { + *dst = *src; + } + } + SocketAddr::V6(_) => { + 
tracing::error!("AnnounceStorage contained an IPv6 Address..."); + return; + } + }; - // Grab the closest nodes - let mut closest_nodes_bytes = Vec::with_capacity(26 * 8); - for node in work_storage.routing_table.closest_nodes(f.target_id()).take(8) { - closest_nodes_bytes.extend_from_slice(&node.encode()); - } + bytes[4] = (port >> 8) as u8; + bytes[5] = (port & 0x00FF) as u8; - let find_node_rsp = - FindNodeResponse::new(f.transaction_id(), work_storage.routing_table.node_id(), &closest_nodes_bytes).unwrap(); - let find_node_msg = find_node_rsp.encode(); + contact_info_bytes.extend_from_slice(&bytes); + }); + // Grab the bencoded list (ugh, we really have to do this, better apis I say!!!) + let mut contact_info_bencode = Vec::with_capacity(contact_info_bytes.len() / 6); + for chunk_index in 0..(contact_info_bytes.len() / 6) { + let (start, end) = (chunk_index * 6, chunk_index * 6 + 6); - if work_storage.out_channel.send((find_node_msg, addr)).is_err() { - error!("bip_dht: Failed to send a find node response on the out channel..."); - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - } - } - Ok(MessageType::Request(RequestType::GetPeers(g))) => { - info!("bip_dht: Received a GetPeersRequest..."); - let node = Node::as_good(g.node_id(), addr); - - // Node requested from us, mark it in the routing table - if let Some(n) = work_storage.routing_table.find_node(&node) { - n.remote_request(); - } - - // TODO: Move socket address serialization code into bip_util - // TODO: Check what the maximum number of values we can give without overflowing a udp packet - // Also, if we aren't going to give all of the contacts, we may want to shuffle which ones we give - let mut contact_info_bytes = Vec::with_capacity(6 * 20); - work_storage.active_stores.find_items(&g.info_hash(), |addr| { - let mut bytes = [0u8; 6]; - let port = addr.port(); - - match addr { - SocketAddr::V4(v4_addr) => { - for (src, dst) in convert::ipv4_to_bytes_be(*v4_addr.ip()).iter().zip(bytes.iter_mut()) { - *dst = *src; - } + contact_info_bencode.push(ben_bytes!(&contact_info_bytes[start..end])); } - SocketAddr::V6(_) => { - error!("AnnounceStorage contained an IPv6 Address..."); - return; + + // Grab the closest nodes + let mut closest_nodes_bytes = Vec::with_capacity(26 * 8); + for node in routing_table.closest_nodes(g.info_hash()).take(8) { + closest_nodes_bytes.extend_from_slice(&node.encode()); } - }; - bytes[4] = (port >> 8) as u8; - bytes[5] = (port & 0x00FF) as u8; + // Wrap up the nodes/values we are going to be giving them + let token = self.token_store.lock().unwrap().checkout(IpAddr::from_socket_addr(addr)); + let compact_info_type = if contact_info_bencode.is_empty() { + CompactInfoType::Nodes(CompactNodeInfo::new(&closest_nodes_bytes).unwrap()) + } else { + CompactInfoType::>::Both( + CompactNodeInfo::new(&closest_nodes_bytes).unwrap(), + CompactValueInfo::new(&contact_info_bencode).unwrap(), + ) + }; - contact_info_bytes.extend_from_slice(&bytes); - }); - // Grab the bencoded list (ugh, we really have to do this, better apis I say!!!) 
- let mut contact_info_bencode = Vec::with_capacity(contact_info_bytes.len() / 6); - for chunk_index in 0..(contact_info_bytes.len() / 6) { - let (start, end) = (chunk_index * 6, chunk_index * 6 + 6); + let get_peers_rsp = GetPeersResponse::>::new( + g.transaction_id(), + routing_table.node_id(), + Some(token.as_ref()), + compact_info_type, + ); - contact_info_bencode.push(ben_bytes!(&contact_info_bytes[start..end])); - } + get_peers_rsp.encode() + }; - // Grab the closest nodes - let mut closest_nodes_bytes = Vec::with_capacity(26 * 8); - for node in work_storage.routing_table.closest_nodes(g.info_hash()).take(8) { - closest_nodes_bytes.extend_from_slice(&node.encode()); + if self.out_channel.clone().send((get_peers_msg, addr)).await.is_err() { + tracing::error!("bip_dht: Failed to send a get peers response on the out channel..."); + self.handle_shutdown(ShutdownCause::Unspecified); + } } + Ok(MessageType::Request(RequestType::AnnouncePeer(a))) => { + tracing::info!("bip_dht: Received an AnnouncePeerRequest..."); + let node = Node::as_good(a.node_id(), addr); - // Wrap up the nodes/values we are going to be giving them - let token = work_storage.token_store.checkout(IpAddr::from_socket_addr(addr)); - let compact_info_type = if contact_info_bencode.is_empty() { - CompactInfoType::Nodes(CompactNodeInfo::new(&closest_nodes_bytes).unwrap()) - } else { - CompactInfoType::>::Both( - CompactNodeInfo::new(&closest_nodes_bytes).unwrap(), - CompactValueInfo::new(&contact_info_bencode).unwrap(), - ) - }; - - let get_peers_rsp = GetPeersResponse::>::new( - g.transaction_id(), - work_storage.routing_table.node_id(), - Some(token.as_ref()), - compact_info_type, - ); - let get_peers_msg = get_peers_rsp.encode(); + let response_msg = { + let routing_table = self.routing_table.read().unwrap(); - if work_storage.out_channel.send((get_peers_msg, addr)).is_err() { - error!("bip_dht: Failed to send a get peers response on the out channel..."); - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - } - } - Ok(MessageType::Request(RequestType::AnnouncePeer(a))) => { - info!("bip_dht: Received an AnnouncePeerRequest..."); - let node = Node::as_good(a.node_id(), addr); + // Node requested from us, mark it in the routing table + if let Some(n) = routing_table.find_node(&node) { + n.remote_request(); + } - // Node requested from us, mark it in the routing table - if let Some(n) = work_storage.routing_table.find_node(&node) { - n.remote_request(); - } + // Validate the token + let is_valid = match Token::new(a.token()) { + Ok(t) => self.token_store.lock().unwrap().checkin(IpAddr::from_socket_addr(addr), t), + Err(_) => false, + }; - // Validate the token - let is_valid = match Token::new(a.token()) { - Ok(t) => work_storage.token_store.checkin(IpAddr::from_socket_addr(addr), t), - Err(_) => false, - }; + // Create a socket address based on the implied/explicit port number + let connect_addr = match a.connect_port() { + ConnectPort::Implied => addr, + ConnectPort::Explicit(port) => match addr { + SocketAddr::V4(v4_addr) => SocketAddr::V4(SocketAddrV4::new(*v4_addr.ip(), port)), + SocketAddr::V6(v6_addr) => { + SocketAddr::V6(SocketAddrV6::new(*v6_addr.ip(), port, v6_addr.flowinfo(), v6_addr.scope_id())) + } + }, + }; - // Create a socket address based on the implied/explicit port number - let connect_addr = match a.connect_port() { - ConnectPort::Implied => addr, - ConnectPort::Explicit(port) => match addr { - SocketAddr::V4(v4_addr) => SocketAddr::V4(SocketAddrV4::new(*v4_addr.ip(), port)), - 
SocketAddr::V6(v6_addr) => { - SocketAddr::V6(SocketAddrV6::new(*v6_addr.ip(), port, v6_addr.flowinfo(), v6_addr.scope_id())) + // Resolve type of response we are going to send + if !is_valid { + // Node gave us an invalid token + tracing::warn!("bip_dht: Remote node sent us an invalid token for an AnnounceRequest..."); + ErrorMessage::new( + a.transaction_id().to_vec(), + ErrorCode::ProtocolError, + "Received An Invalid Token".to_owned(), + ) + .encode() + } else if self.active_stores.lock().unwrap().add_item(a.info_hash(), connect_addr) { + // Node successfully stored the value with us, send an announce response + AnnouncePeerResponse::new(a.transaction_id(), routing_table.node_id()).encode() + } else { + // Node unsuccessfully stored the value with us, send them an error message + // TODO: Spec doesn't actually say what error message to send, or even if we should send one... + tracing::warn!( + "bip_dht: AnnounceStorage failed to store contact information because it \ + is full..." + ); + ErrorMessage::new( + a.transaction_id().to_vec(), + ErrorCode::ServerError, + "Announce Storage Is Full".to_owned(), + ) + .encode() } - }, - }; - - // Resolve type of response we are going to send - let response_msg = if !is_valid { - // Node gave us an invalid token - warn!("bip_dht: Remote node sent us an invalid token for an AnnounceRequest..."); - ErrorMessage::new( - a.transaction_id().to_vec(), - ErrorCode::ProtocolError, - "Received An Invalid Token".to_owned(), - ) - .encode() - } else if work_storage.active_stores.add_item(a.info_hash(), connect_addr) { - // Node successfully stored the value with us, send an announce response - AnnouncePeerResponse::new(a.transaction_id(), work_storage.routing_table.node_id()).encode() - } else { - // Node unsuccessfully stored the value with us, send them an error message - // TODO: Spec doesn't actually say what error message to send, or even if we should send one... - warn!( - "bip_dht: AnnounceStorage failed to store contact information because it \ - is full..." 
- ); - ErrorMessage::new( - a.transaction_id().to_vec(), - ErrorCode::ServerError, - "Announce Storage Is Full".to_owned(), - ) - .encode() - }; + }; - if work_storage.out_channel.send((response_msg, addr)).is_err() { - error!("bip_dht: Failed to send an announce peer response on the out channel..."); - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); + if self.out_channel.clone().send((response_msg, addr)).await.is_err() { + tracing::error!("bip_dht: Failed to send an announce peer response on the out channel..."); + self.handle_shutdown(ShutdownCause::Unspecified); + } } - } - Ok(MessageType::Response(ResponseType::FindNode(f))) => { - info!("bip_dht: Received a FindNodeResponse..."); - let trans_id = TransactionID::from_bytes(f.transaction_id()).unwrap(); - let node = Node::as_good(f.node_id(), addr); + Ok(MessageType::Response(ResponseType::FindNode(f))) => { + tracing::info!("bip_dht: Received a FindNodeResponse..."); + let trans_id = TransactionID::from_bytes(f.transaction_id()).unwrap(); + let node = Node::as_good(f.node_id(), addr); - // Add the payload nodes as questionable - for (id, v4_addr) in f.nodes() { - let sock_addr = SocketAddr::V4(v4_addr); + let opt_bootstrap = { + let mut routing_table = self.routing_table.write().unwrap(); - work_storage.routing_table.add_node(&Node::as_questionable(id, sock_addr)); - } + // Add the payload nodes as questionable + for (id, v4_addr) in f.nodes() { + let sock_addr = SocketAddr::V4(v4_addr); - let bootstrap_complete = { - let opt_bootstrap = match table_actions.get_mut(&trans_id.action_id()) { - Some(&mut TableAction::Refresh(_)) => { - work_storage.routing_table.add_node(&node); - None + routing_table.add_node(&Node::as_questionable(id, sock_addr)); } - Some(&mut TableAction::Bootstrap(ref mut bootstrap, ref mut attempts)) => { - if !bootstrap.is_router(&node.addr()) { - work_storage.routing_table.add_node(&node); + + // Match the response action id with our current actions + let table_action = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); + + match table_action { + Some(TableAction::Refresh(_)) => { + routing_table.add_node(&node); + None + } + Some(TableAction::Bootstrap(bootstrap, attempts)) => { + if !bootstrap.is_router(&node.addr()) { + routing_table.add_node(&node); + } + Some((bootstrap, attempts)) + } + Some(TableAction::Lookup(_)) => { + tracing::error!("bip_dht: Resolved a FindNodeResponse ActionID to a TableLookup..."); + None + } + None => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a FindNodeResponse but no \ + action found..." + ); + None } - Some((bootstrap, attempts)) - } - Some(&mut TableAction::Lookup(_)) => { - error!("bip_dht: Resolved a FindNodeResponse ActionID to a TableLookup..."); - None - } - None => { - error!( - "bip_dht: Resolved a TransactionID to a FindNodeResponse but no \ - action found..." 
- ); - None } }; - if let Some((bootstrap, attempts)) = opt_bootstrap { - match bootstrap.recv_response(&trans_id, &work_storage.routing_table, &work_storage.out_channel, event_loop) { - BootstrapStatus::Idle => true, - BootstrapStatus::Bootstrapping => false, - BootstrapStatus::Failed => { - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - false - } - BootstrapStatus::Completed => { - if should_rebootstrap(&work_storage.routing_table) { - attempt_rebootstrap(bootstrap, attempts, work_storage, event_loop) == Some(false) - } else { - true + let bootstrap_complete = { + if let Some((bootstrap, attempts)) = opt_bootstrap { + let response = bootstrap + .recv_response::( + trans_id, + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await; + + match response { + BootstrapStatus::Idle => true, + BootstrapStatus::Bootstrapping => false, + BootstrapStatus::Failed => { + self.handle_shutdown(ShutdownCause::Unspecified); + false + } + BootstrapStatus::Completed => { + if should_rebootstrap(&self.routing_table.read().unwrap()) { + attempt_rebootstrap( + bootstrap.clone(), + attempts.clone(), + self.routing_table.clone(), + self.out_channel.clone(), + self.main_task_sender.clone(), + self.scheduled_task_sender.clone(), + ) + .await + == Some(false) + } else { + true + } } } + } else { + false } - } else { - false + }; + + if bootstrap_complete { + self.broadcast_bootstrap_completed(trans_id.action_id()).await; } - }; - if bootstrap_complete { - broadcast_bootstrap_completed(trans_id.action_id(), table_actions, work_storage, event_loop); - } + let routing_table = self.routing_table.read().unwrap(); - if log_enabled!(log::Level::Info) { - let mut total = 0; + if tracing::enabled!(tracing::Level::INFO) { + let mut total = 0; - for (index, bucket) in work_storage.routing_table.buckets().enumerate() { - let num_nodes = match bucket { - BucketContents::Empty => 0, - BucketContents::Sorted(b) => b.iter().filter(|n| n.status() == NodeStatus::Good).count(), - BucketContents::Assorted(b) => b.iter().filter(|n| n.status() == NodeStatus::Good).count(), - }; - total += num_nodes; + for (index, bucket) in routing_table.buckets().enumerate() { + let num_nodes = match bucket { + BucketContents::Empty => 0, + BucketContents::Sorted(b) => b.iter().filter(|n| n.status() == NodeStatus::Good).count(), + BucketContents::Assorted(b) => b.iter().filter(|n| n.status() == NodeStatus::Good).count(), + }; + total += num_nodes; - if num_nodes != 0 { - print!("Bucket {index}: {num_nodes} | "); + if num_nodes != 0 { + print!("Bucket {index}: {num_nodes} | "); + } } - } - print!("\nTotal: {total}\n\n\n"); + print!("\nTotal: {total}\n\n\n"); + } } - } - Ok(MessageType::Response(ResponseType::GetPeers(g))) => { - // info!("bip_dht: Received a GetPeersResponse..."); - let trans_id = TransactionID::from_bytes(g.transaction_id()).unwrap(); - let node = Node::as_good(g.node_id(), addr); - - work_storage.routing_table.add_node(&node); - - let opt_lookup = { - match table_actions.get_mut(&trans_id.action_id()) { - Some(&mut TableAction::Lookup(ref mut lookup)) => Some(lookup), - Some(&mut TableAction::Refresh(_)) => { - error!( - "bip_dht: Resolved a GetPeersResponse ActionID to a \ - TableRefresh..." - ); - None - } - Some(&mut TableAction::Bootstrap(_, _)) => { - error!( - "bip_dht: Resolved a GetPeersResponse ActionID to a \ - TableBootstrap..." - ); - None - } - None => { - error!( - "bip_dht: Resolved a TransactionID to a GetPeersResponse but no \ - action found..." 
- ); - None - } + Ok(MessageType::Response(ResponseType::GetPeers(g))) => { + tracing::info!("bip_dht: Received a GetPeersResponse..."); + let trans_id = TransactionID::from_bytes(g.transaction_id()).unwrap(); + let node = Node::as_good(g.node_id(), addr); + + { + let mut routing_table = self.routing_table.write().unwrap(); + + routing_table.add_node(&node); } - }; - if let Some(lookup) = opt_lookup { - match lookup.recv_response( - node, - &trans_id, - g, - &work_storage.routing_table, - &work_storage.out_channel, - event_loop, - ) { - LookupStatus::Searching => (), - LookupStatus::Completed => broadcast_dht_event( - &mut work_storage.event_notifiers, - DhtEvent::LookupCompleted(lookup.info_hash()), - ), - LookupStatus::Failed => shutdown_event_loop(event_loop, ShutdownCause::Unspecified), - LookupStatus::Values(values) => { - for v4_addr in values { - let sock_addr = SocketAddr::V4(v4_addr); - work_storage.handshaker.connect(None, lookup.info_hash(), sock_addr); + let opt_lookup = { + let table_action = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); + + match table_action { + Some(TableAction::Lookup(lookup)) => Some(lookup), + Some(TableAction::Refresh(_)) => { + tracing::error!( + "bip_dht: Resolved a GetPeersResponse ActionID to a \ + TableRefresh..." + ); + None + } + Some(TableAction::Bootstrap(_, _)) => { + tracing::error!( + "bip_dht: Resolved a GetPeersResponse ActionID to a \ + TableBootstrap..." + ); + None + } + None => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a GetPeersResponse but no \ + action found..." + ); + None + } + } + }; + + if let Some(lookup) = opt_lookup { + match lookup + .recv_response( + node, + trans_id, + g, + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await + { + LookupStatus::Searching => (), + LookupStatus::Completed => { + self.broadcast_dht_event(DhtEvent::LookupCompleted(lookup.info_hash())); + } + LookupStatus::Failed => self.handle_shutdown(ShutdownCause::Unspecified), + LookupStatus::Values(values) => { + for v4_addr in values { + let sock_addr = SocketAddr::V4(v4_addr); + self.handshaker + .lock() + .await + .connect(None, lookup.info_hash(), sock_addr) + .await; + } } } } } - } - Ok(MessageType::Response(ResponseType::Ping(_))) => { - info!("bip_dht: Received a PingResponse..."); + Ok(MessageType::Response(ResponseType::Ping(_))) => { + tracing::info!("bip_dht: Received a PingResponse..."); - // Yeah...we should never be getting this type of response (we never use this message) - } - Ok(MessageType::Response(ResponseType::AnnouncePeer(_))) => { - info!("bip_dht: Received an AnnouncePeerResponse..."); - } - Ok(MessageType::Error(e)) => { - info!("bip_dht: Received an ErrorMessage..."); + // Yeah...we should never be getting this type of response (we never use this message) + } + Ok(MessageType::Response(ResponseType::AnnouncePeer(_))) => { + tracing::info!("bip_dht: Received an AnnouncePeerResponse..."); + } + Ok(MessageType::Error(e)) => { + tracing::info!("bip_dht: Received an ErrorMessage..."); - warn!("bip_dht: KRPC error message from {:?}: {:?}", addr, e); - } - Err(e) => { - warn!("bip_dht: Error parsing KRPC message: {:?}", e); + tracing::warn!("bip_dht: KRPC error message from {:?}: {:?}", addr, e); + } + Err(e) => { + tracing::warn!("bip_dht: Error parsing KRPC message: {:?}", e); + } } } -} -fn handle_register_sender(handler: &mut DhtHandler, sender: mpsc::Sender) { - handler.detached.event_notifiers.push(sender); -} + fn 
handle_register_sender(&self, sender: mpsc::Sender) { + self.event_notifiers.lock().unwrap().push(sender); + } -fn handle_start_bootstrap( - handler: &mut DhtHandler, - event_loop: &mut EventLoop>, - routers: Vec, - nodes: Vec, -) where - H: HandshakerTrait, -{ - let (work_storage, table_actions) = (&mut handler.detached, &mut handler.table_actions); + fn handle_start_bootstrap(&self, routers: Vec, nodes: Vec) -> BoxFuture<'_, ()> { + async move { + let router_iter = routers.into_iter().filter_map(|r| r.ipv4_addr().ok().map(SocketAddr::V4)); + + let mid_generator = self.aid_generator.lock().unwrap().generate(); + let action_id = mid_generator.action_id(); - let router_iter = routers.into_iter().filter_map(|r| r.ipv4_addr().ok().map(SocketAddr::V4)); + let bootstrap_complete = { + let table_bootstrap = { + let routing_table = self.routing_table.read().unwrap(); + TableBootstrap::new(routing_table.node_id(), mid_generator, nodes, router_iter) + }; - let mid_generator = work_storage.aid_generator.generate(); - let action_id = mid_generator.action_id(); - let mut table_bootstrap = TableBootstrap::new(work_storage.routing_table.node_id(), mid_generator, nodes, router_iter); + // Begin the bootstrap operation + let bootstrap_status = table_bootstrap + .start_bootstrap(self.out_channel.clone(), self.scheduled_task_sender.clone()) + .await; - // Begin the bootstrap operation - let bootstrap_status = table_bootstrap.start_bootstrap(&work_storage.out_channel, event_loop); + self.bootstrapping.store(true, Ordering::SeqCst); + self.table_actions.lock().unwrap().insert( + action_id, + TableAction::Bootstrap(Arc::new(table_bootstrap), Arc::new(AtomicUsize::default())), + ); - work_storage.bootstrapping = true; - table_actions.insert(action_id, TableAction::Bootstrap(table_bootstrap, 0)); + match bootstrap_status { + Ok(BootstrapStatus::Idle) => true, + Ok(BootstrapStatus::Bootstrapping) => false, + Err(BootstrapStatus::Failed) => { + self.handle_shutdown(ShutdownCause::Unspecified); + false + } + Ok(_) | Err(_) => unreachable!(), + } + }; - let bootstrap_complete = match bootstrap_status { - BootstrapStatus::Idle => true, - BootstrapStatus::Bootstrapping => false, - BootstrapStatus::Failed => { - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - false + if bootstrap_complete { + self.broadcast_bootstrap_completed(action_id).await; + } } - BootstrapStatus::Completed => { - // Check if our bootstrap was actually good - if should_rebootstrap(&work_storage.routing_table) { - let Some(&mut TableAction::Bootstrap(ref mut bootstrap, ref mut attempts)) = table_actions.get_mut(&action_id) - else { - panic!("bip_dht: Bug, in DhtHandler...") - }; + .boxed() + } - attempt_rebootstrap(bootstrap, attempts, work_storage, event_loop) == Some(false) + fn handle_start_lookup(&self, info_hash: InfoHash, should_announce: bool) -> BoxFuture<'_, ()> { + async move { + let mid_generator = self.aid_generator.lock().unwrap().generate(); + let action_id = mid_generator.action_id(); + + if self.bootstrapping.load(Ordering::Acquire) { + // Queue it up if we are currently bootstrapping + self.future_actions + .lock() + .unwrap() + .push(PostBootstrapAction::Lookup(info_hash, should_announce)); } else { - true + let node_id = self.routing_table.read().unwrap().node_id(); + // Start the lookup right now if not bootstrapping + match TableLookup::new( + node_id, + info_hash, + mid_generator, + should_announce, + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await + { + 
Some(lookup) => { + self.table_actions + .lock() + .unwrap() + .insert(action_id, TableAction::Lookup(Arc::new(lookup))); + } + None => self.handle_shutdown(ShutdownCause::Unspecified), + } } } - }; + .boxed() + } - if bootstrap_complete { - broadcast_bootstrap_completed(action_id, table_actions, work_storage, event_loop); + fn handle_shutdown(&self, cause: ShutdownCause) { + self.broadcast_dht_event(DhtEvent::ShuttingDown(cause)); } -} -fn handle_start_lookup( - table_actions: &mut HashMap, - work_storage: &mut DetachedDhtHandler, - event_loop: &mut EventLoop>, - info_hash: InfoHash, - should_announce: bool, -) where - H: HandshakerTrait, -{ - let mid_generator = work_storage.aid_generator.generate(); - let action_id = mid_generator.action_id(); - - if work_storage.bootstrapping { - // Queue it up if we are currently bootstrapping - work_storage - .future_actions - .push(PostBootstrapAction::Lookup(info_hash, should_announce)); - } else { - // Start the lookup right now if not bootstrapping - match TableLookup::new( - work_storage.routing_table.node_id(), - info_hash, - mid_generator, - should_announce, - &work_storage.routing_table, - &work_storage.out_channel, - event_loop, - ) { - Some(lookup) => { - table_actions.insert(action_id, TableAction::Lookup(lookup)); - } - None => shutdown_event_loop(event_loop, ShutdownCause::Unspecified), + async fn handle_scheduled_task(&self, task: ScheduledTaskCheck) { + match task { + ScheduledTaskCheck::TableRefresh(trans_id) => { + self.handle_check_table_refresh(trans_id).await; + } + ScheduledTaskCheck::BootstrapTimeout(trans_id) => { + self.handle_check_bootstrap_timeout(trans_id).await; + } + ScheduledTaskCheck::LookupTimeout(trans_id) => { + self.handle_check_lookup_timeout(trans_id).await; + } + ScheduledTaskCheck::LookupEndGame(trans_id) => { + self.handle_check_lookup_endgame(trans_id).await; + } } } -} -fn handle_shutdown(handler: &mut DhtHandler, event_loop: &mut EventLoop>, cause: ShutdownCause) -where - H: HandshakerTrait, -{ - let (work_storage, _) = (&mut handler.detached, &mut handler.table_actions); + async fn handle_check_table_refresh(&self, trans_id: TransactionID) { + let table_actions = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); + + let opt_refresh_status = match table_actions { + Some(TableAction::Refresh(refresh)) => Some( + refresh + .continue_refresh( + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await, + ), + Some(TableAction::Lookup(_)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table refresh but TableLookup \ + found..." + ); + None + } + Some(TableAction::Bootstrap(_, _)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table refresh but \ + TableBootstrap found..." + ); + None + } + None => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table refresh but no action \ + found..." 
+ ); + None + } + }; - broadcast_dht_event(&mut work_storage.event_notifiers, DhtEvent::ShuttingDown(cause)); + match opt_refresh_status { + Some(RefreshStatus::Refreshing) | None => (), + Some(RefreshStatus::Failed) => self.handle_shutdown(ShutdownCause::Unspecified), + } + } - event_loop.shutdown(); -} + async fn handle_check_bootstrap_timeout(&self, trans_id: TransactionID) { + let bootstrap_complete = { + let table_actions = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); + + let opt_bootstrap_info = match table_actions { + Some(TableAction::Bootstrap(bootstrap, attempts)) => Some(( + bootstrap + .recv_timeout::( + trans_id, + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await, + bootstrap, + attempts, + )), + Some(TableAction::Lookup(_)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table bootstrap but \ + TableLookup found..." + ); + None + } + Some(TableAction::Refresh(_)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table bootstrap but \ + TableRefresh found..." + ); + None + } + None => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table bootstrap but no \ + action found..." + ); + None + } + }; -fn handle_check_table_refresh( - table_actions: &mut HashMap, - work_storage: &DetachedDhtHandler, - event_loop: &mut EventLoop>, - trans_id: TransactionID, -) where - H: HandshakerTrait, -{ - let opt_refresh_status = match table_actions.get_mut(&trans_id.action_id()) { - Some(&mut TableAction::Refresh(ref mut refresh)) => { - Some(refresh.continue_refresh(&work_storage.routing_table, &work_storage.out_channel, event_loop)) - } - Some(&mut TableAction::Lookup(_)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table refresh but TableLookup \ - found..." - ); - None - } - Some(&mut TableAction::Bootstrap(_, _)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table refresh but \ - TableBootstrap found..." - ); - None - } - None => { - error!( - "bip_dht: Resolved a TransactionID to a check table refresh but no action \ - found..." 
- ); - None - } - }; + match opt_bootstrap_info { + Some((BootstrapStatus::Idle, _, _)) => true, + Some((BootstrapStatus::Bootstrapping, _, _)) | None => false, + Some((BootstrapStatus::Failed, _, _)) => { + self.handle_shutdown(ShutdownCause::Unspecified); + false + } + Some((BootstrapStatus::Completed, bootstrap, attempts)) => { + // Check if our bootstrap was actually good + if should_rebootstrap(&self.routing_table.read().unwrap()) { + attempt_rebootstrap( + bootstrap, + attempts, + self.routing_table.clone(), + self.out_channel.clone(), + self.main_task_sender.clone(), + self.scheduled_task_sender.clone(), + ) + .await + == Some(false) + } else { + true + } + } + } + }; - match opt_refresh_status { - Some(RefreshStatus::Refreshing) | None => (), - Some(RefreshStatus::Failed) => shutdown_event_loop(event_loop, ShutdownCause::Unspecified), + if bootstrap_complete { + self.broadcast_bootstrap_completed(trans_id.action_id()).await; + } } -} -fn handle_check_bootstrap_timeout( - handler: &mut DhtHandler, - event_loop: &mut EventLoop>, - trans_id: TransactionID, -) where - H: HandshakerTrait, -{ - let (work_storage, table_actions) = (&mut handler.detached, &mut handler.table_actions); - - let bootstrap_complete = { - let opt_bootstrap_info = match table_actions.get_mut(&trans_id.action_id()) { - Some(&mut TableAction::Bootstrap(ref mut bootstrap, ref mut attempts)) => Some(( - bootstrap.recv_timeout(&trans_id, &work_storage.routing_table, &work_storage.out_channel, event_loop), - bootstrap, - attempts, + async fn handle_check_lookup_timeout(&self, trans_id: TransactionID) { + let table_actions = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); + + let opt_lookup_info = match table_actions { + Some(TableAction::Lookup(lookup)) => Some(( + lookup + .recv_timeout( + trans_id, + self.routing_table.clone(), + self.out_channel.clone(), + self.scheduled_task_sender.clone(), + ) + .await, + lookup.info_hash(), )), - Some(&mut TableAction::Lookup(_)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table bootstrap but \ - TableLookup found..." + Some(TableAction::Bootstrap(_, _)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but TableBootstrap \ + found..." ); None } - Some(&mut TableAction::Refresh(_)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table bootstrap but \ - TableRefresh found..." + Some(TableAction::Refresh(_)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but TableRefresh \ + found..." ); None } None => { - error!( - "bip_dht: Resolved a TransactionID to a check table bootstrap but no \ - action found..." + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but no action \ + found..." 
); None } }; - match opt_bootstrap_info { - Some((BootstrapStatus::Idle, _, _)) => true, - Some((BootstrapStatus::Bootstrapping, _, _)) | None => false, - Some((BootstrapStatus::Failed, _, _)) => { - shutdown_event_loop(event_loop, ShutdownCause::Unspecified); - false - } - Some((BootstrapStatus::Completed, bootstrap, attempts)) => { - // Check if our bootstrap was actually good - if should_rebootstrap(&work_storage.routing_table) { - attempt_rebootstrap(bootstrap, attempts, work_storage, event_loop) == Some(false) - } else { - true + match opt_lookup_info { + Some((LookupStatus::Searching, _)) | None => (), + Some((LookupStatus::Completed, info_hash)) => { + self.broadcast_dht_event(DhtEvent::LookupCompleted(info_hash)); + } + Some((LookupStatus::Failed, _)) => self.handle_shutdown(ShutdownCause::Unspecified), + Some((LookupStatus::Values(v), info_hash)) => { + // Add values to handshaker + for v4_addr in v { + let sock_addr = SocketAddr::V4(v4_addr); + + self.handshaker.lock().await.connect(None, info_hash, sock_addr).await; } } } - }; - - if bootstrap_complete { - broadcast_bootstrap_completed(trans_id.action_id(), table_actions, work_storage, event_loop); } -} -fn handle_check_lookup_timeout(handler: &mut DhtHandler, event_loop: &mut EventLoop>, trans_id: TransactionID) -where - H: HandshakerTrait, -{ - let (work_storage, table_actions) = (&mut handler.detached, &mut handler.table_actions); - - let opt_lookup_info = match table_actions.get_mut(&trans_id.action_id()) { - Some(&mut TableAction::Lookup(ref mut lookup)) => Some(( - lookup.recv_timeout(&trans_id, &work_storage.routing_table, &work_storage.out_channel, event_loop), - lookup.info_hash(), - )), - Some(&mut TableAction::Bootstrap(_, _)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but TableBootstrap \ - found..." - ); - None - } - Some(&mut TableAction::Refresh(_)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but TableRefresh \ - found..." - ); - None - } - None => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but no action \ - found..." - ); - None - } - }; + async fn handle_check_lookup_endgame(&self, trans_id: TransactionID) { + let table_actions = self.table_actions.lock().unwrap().get(&trans_id.action_id()).cloned(); - match opt_lookup_info { - Some((LookupStatus::Searching, _)) | None => (), - Some((LookupStatus::Completed, info_hash)) => { - broadcast_dht_event(&mut work_storage.event_notifiers, DhtEvent::LookupCompleted(info_hash)); - } - Some((LookupStatus::Failed, _)) => shutdown_event_loop(event_loop, ShutdownCause::Unspecified), - Some((LookupStatus::Values(v), info_hash)) => { - // Add values to handshaker - for v4_addr in v { - let sock_addr = SocketAddr::V4(v4_addr); + let opt_lookup_info = match table_actions { + Some(TableAction::Lookup(lookup)) => { + let handshaker_port = self.handshaker.lock().await.port(); + + Some(( + lookup + .recv_finished(handshaker_port, self.routing_table.clone(), self.out_channel.clone()) + .await, + lookup.info_hash(), + )) + } + Some(TableAction::Bootstrap(_, _)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but TableBootstrap \ + found..." + ); + None + } + Some(TableAction::Refresh(_)) => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but TableRefresh \ + found..." + ); + None + } + None => { + tracing::error!( + "bip_dht: Resolved a TransactionID to a check table lookup but no action \ + found..." 
+ ); + None + } + }; + + match opt_lookup_info { + Some((LookupStatus::Searching, _)) | None => (), + Some((LookupStatus::Completed, info_hash)) => { + self.broadcast_dht_event(DhtEvent::LookupCompleted(info_hash)); + } + Some((LookupStatus::Failed, _)) => self.handle_shutdown(ShutdownCause::Unspecified), + Some((LookupStatus::Values(v), info_hash)) => { + // Add values to handshaker + for v4_addr in v { + let sock_addr = SocketAddr::V4(v4_addr); - work_storage.handshaker.connect(None, info_hash, sock_addr); + self.handshaker.lock().await.connect(None, info_hash, sock_addr).await; + } } } } -} -fn handle_check_lookup_endgame(handler: &mut DhtHandler, event_loop: &EventLoop>, trans_id: TransactionID) -where - H: HandshakerTrait, -{ - let (work_storage, table_actions) = (&mut handler.detached, &mut handler.table_actions); - - let opt_lookup_info = match table_actions.remove(&trans_id.action_id()) { - Some(TableAction::Lookup(mut lookup)) => Some(( - lookup.recv_finished( - work_storage.handshaker.port(), - &work_storage.routing_table, - &work_storage.out_channel, - ), - lookup.info_hash(), - )), - Some(TableAction::Bootstrap(_, _)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but TableBootstrap \ - found..." - ); - None - } - Some(TableAction::Refresh(_)) => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but TableRefresh \ - found..." - ); - None - } - None => { - error!( - "bip_dht: Resolved a TransactionID to a check table lookup but no action \ - found..." - ); - None + fn broadcast_dht_event(&self, event: DhtEvent) { + self.event_notifiers + .lock() + .unwrap() + .retain(|send| send.clone().try_send(event).is_ok()); + } + + async fn broadcast_bootstrap_completed(&self, action_id: ActionID) { + // Send notification that the bootstrap has completed. + self.broadcast_dht_event(DhtEvent::BootstrapCompleted); + + // Indicates we are out of the bootstrapping phase + self.bootstrapping.store(false, Ordering::Release); + + // Remove the bootstrap action from our table actions + { + let mut table_actions = self.table_actions.lock().unwrap(); + table_actions.remove(&action_id); } - }; + // Start the post bootstrap actions. + let mut future_actions = self.future_actions.lock().unwrap().split_off(0); + for table_action in future_actions.drain(..) { + match table_action { + PostBootstrapAction::Lookup(info_hash, should_announce) => { + drop(table_action); + self.handle_start_lookup(info_hash, should_announce).await; + } + PostBootstrapAction::Refresh(refresh, trans_id) => { + { + let mut table_actions = self.table_actions.lock().unwrap(); + table_actions.insert(trans_id.action_id(), TableAction::Refresh(Arc::new(*refresh))); + } - match opt_lookup_info { - Some((LookupStatus::Searching, _)) | None => (), - Some((LookupStatus::Completed, info_hash)) => { - broadcast_dht_event(&mut work_storage.event_notifiers, DhtEvent::LookupCompleted(info_hash)); + self.handle_check_table_refresh(trans_id).await; + } + } } - Some((LookupStatus::Failed, _)) => shutdown_event_loop(event_loop, ShutdownCause::Unspecified), - Some((LookupStatus::Values(v), info_hash)) => { - // Add values to handshaker - for v4_addr in v { - let sock_addr = SocketAddr::V4(v4_addr); + } +} + +// ----------------------------------------------------------------------------// - work_storage.handshaker.connect(None, info_hash, sock_addr); +/// Attempt to rebootstrap or shutdown the dht if we have no nodes after rebootstrapping multiple time. 
+/// Returns `None` if the DHT is shutting down, `Some(true)` if the rebootstrap process started, `Some(false)` if a rebootstrap is not necessary.
+fn attempt_rebootstrap(
+    bootstrap: Arc<TableBootstrap>,
+    attempts: Arc<AtomicUsize>,
+    routing_table: Arc<RwLock<RoutingTable>>,
+    out: mpsc::Sender<(Vec<u8>, SocketAddr)>,
+    main_task_sender: mpsc::Sender<OneshotTask>,
+    scheduled_task_sender: mpsc::Sender<ScheduledTaskCheck>,
+) -> BoxFuture<'static, Option<bool>> {
+    async move {
+        // Increment the bootstrap counter
+        let attempt = attempts.fetch_add(1, Ordering::AcqRel) + 1;
+
+        tracing::warn!("bip_dht: Bootstrap attempt {} failed, attempting a rebootstrap...", attempt);
+
+        // Check if we reached the maximum bootstrap attempts
+        if attempt >= MAX_BOOTSTRAP_ATTEMPTS {
+            if num_good_nodes(&routing_table.read().unwrap()) == 0 {
+                // Failed to get any nodes in the rebootstrap attempts, shut down
+                shutdown_event_loop(main_task_sender, ShutdownCause::BootstrapFailed).await;
+                None
+            } else {
+                Some(false)
+            }
+        } else {
+            let bootstrap_status = bootstrap.start_bootstrap(out.clone(), scheduled_task_sender.clone()).await;
+
+            match bootstrap_status {
+                Ok(BootstrapStatus::Idle) => Some(false),
+                Ok(BootstrapStatus::Bootstrapping) => Some(true),
+                Err(BootstrapStatus::Failed) => {
+                    shutdown_event_loop(main_task_sender, ShutdownCause::Unspecified).await;
+                    None
+                }
+                Ok(_) | Err(_) => unreachable!(),
            }
        }
    }
+    .boxed()
+}
+
+/// Shut down the event loop by sending it a shutdown message with the given cause.
+async fn shutdown_event_loop(mut main_task_sender: mpsc::Sender<OneshotTask>, cause: ShutdownCause) {
+    if main_task_sender.send(OneshotTask::Shutdown(cause)).await.is_err() {
+        tracing::error!("bip_dht: Failed to send a shutdown message to the EventLoop...");
+    }
+}
+
+/// Number of good nodes in the `RoutingTable`.
+fn num_good_nodes(table: &RoutingTable) -> usize {
+    table
+        .closest_nodes(table.node_id())
+        .filter(|n| n.status() == NodeStatus::Good)
+        .count()
+}
+
+/// We should rebootstrap if we have a low number of nodes.
+fn should_rebootstrap(table: &RoutingTable) -> bool {
+    num_good_nodes(table) <= BOOTSTRAP_GOOD_NODE_THRESHOLD
 }
diff --git a/packages/dht/src/worker/lookup.rs b/packages/dht/src/worker/lookup.rs
index f79152d78..e1643bf08 100644
--- a/packages/dht/src/worker/lookup.rs
+++ b/packages/dht/src/worker/lookup.rs
@@ -1,34 +1,30 @@
 use std::collections::{HashMap, HashSet};
-use std::fmt::Debug;
 use std::net::{SocketAddr, SocketAddrV4};
-use std::sync::mpsc::SyncSender;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Mutex, RwLock};
 
 use bencode::BRefAccess;
-use log::{error, warn};
-use mio::{EventLoop, Timeout};
+use futures::channel::mpsc;
+use futures::channel::mpsc::SendError;
+use futures::future::BoxFuture;
+use futures::{FutureExt, SinkExt as _};
+use tokio::task::JoinSet;
+use tokio::time::{sleep, Duration, Instant};
 use util::bt::{self, InfoHash, NodeId};
 use util::net;
 use util::sha::ShaHash;
 
-use crate::handshaker_trait::HandshakerTrait;
 use crate::message::announce_peer::{AnnouncePeerRequest, ConnectPort};
 use crate::message::get_peers::{CompactInfoType, GetPeersRequest, GetPeersResponse};
 use crate::routing::bucket;
 use crate::routing::node::{Node, NodeStatus};
 use crate::routing::table::RoutingTable;
 use crate::transaction::{MIDGenerator, TransactionID};
-use crate::worker::handler::DhtHandler;
 use crate::worker::ScheduledTaskCheck;
 
 const LOOKUP_TIMEOUT_MS: u64 = 1500;
 const ENDGAME_TIMEOUT_MS: u64 = 1500;
 
-// Currently using the aggressive variant of the standard lookup procedure.
-// https://people.kth.se/~rauljc/p2p11/jimenez2011subsecond.pdf - -// TODO: Handle case where a request round fails, should we fail the whole lookup (clear active lookups?) -// TODO: Clean up the code in this module. - const INITIAL_PICK_NUM: usize = 4; // Alpha const ITERATIVE_PICK_NUM: usize = 3; // Beta const ANNOUNCE_PICK_NUM: usize = 8; // # Announces @@ -49,137 +45,127 @@ pub enum LookupStatus { pub struct TableLookup { table_id: NodeId, target_id: InfoHash, - in_endgame: bool, - // If we have received any values in the lookup. - recv_values: bool, - id_generator: MIDGenerator, + in_endgame: AtomicBool, + recv_values: AtomicBool, + id_generator: Mutex, will_announce: bool, - // DistanceToBeat is the distance that the responses of the current lookup needs to beat, - // interestingly enough (and super important), this distance may not be equal to the - // requested node's distance - active_lookups: HashMap, - announce_tokens: HashMap>, - requested_nodes: HashSet, - // Storing whether or not it has ever been pinged so that we - // can perform the brute force lookup if the lookup failed - all_sorted_nodes: Vec<(Distance, Node, bool)>, + active_lookups: Mutex>, + announce_tokens: Mutex>>, + requested_nodes: Mutex>, + all_sorted_nodes: Mutex)>>, + tasks: Arc>>>, } -// Gather nodes - impl TableLookup { - pub fn new( + pub fn new<'a>( table_id: NodeId, target_id: InfoHash, id_generator: MIDGenerator, will_announce: bool, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> Option - where - H: HandshakerTrait, - { - // Pick a buckets worth of nodes and put them into the all_sorted_nodes list - let mut all_sorted_nodes = Vec::with_capacity(bucket::MAX_BUCKET_SIZE); - for node in table - .closest_nodes(target_id) - .filter(|n| n.status() == NodeStatus::Good) - .take(bucket::MAX_BUCKET_SIZE) - { - insert_sorted_node(&mut all_sorted_nodes, target_id, node.clone(), false); - } + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, + ) -> BoxFuture<'a, Option> { + async move { + let all_sorted_nodes = Mutex::new(Vec::with_capacity(bucket::MAX_BUCKET_SIZE)); + + for node in table + .read() + .unwrap() + .closest_nodes(target_id) + .filter(|n| n.status() == NodeStatus::Good) + .take(bucket::MAX_BUCKET_SIZE) + { + insert_sorted_node(&all_sorted_nodes, target_id, node.clone(), false); + } - // Call pick_initial_nodes with the all_sorted_nodes list as an iterator - let initial_pick_nodes = pick_initial_nodes(all_sorted_nodes.iter_mut()); - let initial_pick_nodes_filtered = initial_pick_nodes.iter().filter(|&&(_, good)| good).map(|(node, _)| { - let distance_to_beat = node.id() ^ target_id; - - (node, distance_to_beat) - }); - - // Construct the lookup table structure - let mut table_lookup = TableLookup { - table_id, - target_id, - in_endgame: false, - recv_values: false, - id_generator, - will_announce, - all_sorted_nodes, - announce_tokens: HashMap::new(), - requested_nodes: HashSet::new(), - active_lookups: HashMap::with_capacity(INITIAL_PICK_NUM), - }; + let initial_pick_nodes = pick_initial_nodes(all_sorted_nodes.lock().unwrap().iter_mut()); + let initial_pick_nodes_filtered = initial_pick_nodes.iter().filter(|&&(_, good)| good).map(|(node, _)| { + let distance_to_beat = node.id() ^ target_id; + + (node, distance_to_beat) + }); + + let table_lookup = TableLookup { + table_id, + target_id, + in_endgame: AtomicBool::default(), + recv_values: AtomicBool::default(), + id_generator: Mutex::new(id_generator), + 
will_announce, + all_sorted_nodes, + announce_tokens: Mutex::new(HashMap::new()), + requested_nodes: Mutex::new(HashSet::new()), + active_lookups: Mutex::new(HashMap::with_capacity(INITIAL_PICK_NUM)), + tasks: Arc::default(), + }; - // Call start_request_round with the list of initial_nodes (return even if the search completed...for now :D) - if table_lookup.start_request_round(initial_pick_nodes_filtered, table, out, event_loop) == LookupStatus::Failed { - None - } else { - Some(table_lookup) + if table_lookup + .start_request_round( + initial_pick_nodes_filtered, + table.clone(), + out.clone(), + &scheduled_task_sender, + ) + .await + == LookupStatus::Failed + { + None + } else { + Some(table_lookup) + } } + .boxed() } pub fn info_hash(&self) -> InfoHash { self.target_id } - pub fn recv_response( - &mut self, + pub async fn recv_response( + &self, node: Node, - trans_id: &TransactionID, + trans_id: TransactionID, msg: GetPeersResponse<'_, B>, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, ) -> LookupStatus where - H: HandshakerTrait, B: BRefAccess + Clone, - B::BType: PartialEq + Eq + core::hash::Hash + Debug, + B::BType: PartialEq + Eq + core::hash::Hash + std::fmt::Debug, { - // Process the message transaction id - let Some((dist_to_beat, timeout)) = self.active_lookups.remove(trans_id) else { - warn!( + let Some((dist_to_beat, _)) = self.active_lookups.lock().unwrap().remove(&trans_id) else { + tracing::warn!( "bip_dht: Received expired/unsolicited node response for an active table \ lookup..." ); return self.current_lookup_status(); }; - // Cancel the timeout (if this is not an endgame response) - if !self.in_endgame { - event_loop.clear_timeout(timeout); - } - - // Add the announce token to our list of tokens if let Some(token) = msg.token() { - self.announce_tokens.insert(node, token.to_vec()); + self.announce_tokens.lock().unwrap().insert(node, token.to_vec()); } - // Pull out the contact information from the message let (opt_values, opt_nodes) = match msg.info_type() { CompactInfoType::Nodes(n) => (None, Some(n)), CompactInfoType::Values(v) => { - self.recv_values = true; + self.recv_values.store(true, Ordering::Relaxed); (Some(v.into_iter().collect()), None) } CompactInfoType::Both(n, v) => (Some(v.into_iter().collect()), Some(n)), }; - // Check if we beat the distance, get the next distance to beat let (iterate_nodes, next_dist_to_beat) = if let Some(nodes) = opt_nodes { #[allow(clippy::mutable_key_type)] let requested_nodes = &self.requested_nodes; - // Filter for nodes that we have already requested from let already_requested = |node_info: &(NodeId, SocketAddrV4)| { let node = Node::as_questionable(node_info.0, SocketAddr::V4(node_info.1)); - !requested_nodes.contains(&node) + !requested_nodes.lock().unwrap().contains(&node) }; - // Get the closest distance (or the current distance) let next_dist_to_beat = nodes .into_iter() .filter(&already_requested) @@ -193,27 +179,24 @@ impl TableLookup { } }); - // Check if we got closer (equal to is not enough) let iterate_nodes = if next_dist_to_beat < dist_to_beat { let iterate_nodes = pick_iterate_nodes(nodes.into_iter().filter(&already_requested), self.target_id); - // Push nodes into the all nodes list for (id, v4_addr) in nodes { let addr = SocketAddr::V4(v4_addr); let node = Node::as_questionable(id, addr); let will_ping = iterate_nodes.iter().any(|(n, _)| n == &node); - 
insert_sorted_node(&mut self.all_sorted_nodes, self.target_id, node, will_ping); + insert_sorted_node(&self.all_sorted_nodes, self.target_id, node, will_ping); } Some(iterate_nodes) } else { - // Push nodes into the all nodes list for (id, v4_addr) in nodes { let addr = SocketAddr::V4(v4_addr); let node = Node::as_questionable(id, addr); - insert_sorted_node(&mut self.all_sorted_nodes, self.target_id, node, false); + insert_sorted_node(&self.all_sorted_nodes, self.target_id, node, false); } None @@ -224,18 +207,21 @@ impl TableLookup { (None, dist_to_beat) }; - // Check if we need to iterate (not in the endgame already) - if !self.in_endgame { - // If the node gave us a closer id than its own to the target id, continue the search + if !self.in_endgame.load(Ordering::Relaxed) { if let Some(ref nodes) = iterate_nodes { let filtered_nodes = nodes.iter().filter(|&&(_, good)| good).map(|(n, _)| (n, next_dist_to_beat)); - if self.start_request_round(filtered_nodes, table, out, event_loop) == LookupStatus::Failed { + if self + .start_request_round(filtered_nodes, table.clone(), out.clone(), &scheduled_task_sender) + .await + == LookupStatus::Failed + { return LookupStatus::Failed; } } - // If there are not more active lookups, start the endgame - if self.active_lookups.is_empty() && self.start_endgame_round(table, out, event_loop) == LookupStatus::Failed { + if self.active_lookups.lock().unwrap().is_empty() + && self.start_endgame_round(table, out.clone(), scheduled_task_sender).await == LookupStatus::Failed + { return LookupStatus::Failed; } } @@ -246,56 +232,57 @@ impl TableLookup { } } - pub fn recv_timeout( - &mut self, - trans_id: &TransactionID, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> LookupStatus - where - H: HandshakerTrait, - { - if self.active_lookups.remove(trans_id).is_none() { - warn!( + pub async fn recv_timeout( + &self, + trans_id: TransactionID, + table: Arc>, + out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: mpsc::Sender, + ) -> LookupStatus { + if self.active_lookups.lock().unwrap().remove(&trans_id).is_none() { + tracing::warn!( "bip_dht: Received expired/unsolicited node timeout for an active table \ lookup..." 
); return self.current_lookup_status(); } - if !self.in_endgame { - // If there are not more active lookups, start the endgame - if self.active_lookups.is_empty() && self.start_endgame_round(table, out, event_loop) == LookupStatus::Failed { - return LookupStatus::Failed; - } + if !self.in_endgame.load(Ordering::Relaxed) + && self.active_lookups.lock().unwrap().is_empty() + && self.start_endgame_round(table, out.clone(), scheduled_task_sender).await == LookupStatus::Failed + { + return LookupStatus::Failed; } self.current_lookup_status() } - pub fn recv_finished( - &mut self, + pub async fn recv_finished( + &self, handshake_port: u16, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, + table: Arc>, + mut out: mpsc::Sender<(Vec, SocketAddr)>, ) -> LookupStatus { let mut fatal_error = false; - // Announce if we were told to if self.will_announce { - // Partial borrow so the filter function doesn't capture all of self #[allow(clippy::mutable_key_type)] let announce_tokens = &self.announce_tokens; + let mut node_announces = Vec::new(); + for (_, node, _) in self .all_sorted_nodes + .lock() + .unwrap() .iter() - .filter(|&(_, node, _)| announce_tokens.contains_key(node)) + .filter(|&(_, node, _)| announce_tokens.lock().unwrap().contains_key(node)) .take(ANNOUNCE_PICK_NUM) + .cloned() { - let trans_id = self.id_generator.generate(); - let token = announce_tokens.get(node).unwrap(); + let trans_id = self.id_generator.lock().unwrap().generate(); + let announce_tokens = announce_tokens.lock().unwrap(); + let token = announce_tokens.get(&node).unwrap(); let announce_peer_req = AnnouncePeerRequest::new( trans_id.as_ref(), @@ -306,26 +293,29 @@ impl TableLookup { ); let announce_peer_msg = announce_peer_req.encode(); - if out.send((announce_peer_msg, node.addr())).is_err() { - error!( + node_announces.push((node, announce_peer_msg)); + } + + for (node, announce_peer_msg) in node_announces { + if out.send((announce_peer_msg, node.addr())).await.is_err() { + tracing::error!( "bip_dht: TableLookup announce request failed to send through the out \ - channel..." + channel..." ); fatal_error = true; } + let routing_table = table.read().unwrap(); if !fatal_error { - // We requested from the node, mark it down if the node is in our routing table - if let Some(n) = table.find_node(node) { + if let Some(n) = routing_table.find_node(&node) { n.local_request(); } } } } - // This may not be cleared since we didn't set a timeout for each node, any nodes that didn't respond would still be in here. 
- self.active_lookups.clear(); - self.in_endgame = false; + self.active_lookups.lock().unwrap().clear(); + self.in_endgame.store(false, Ordering::Relaxed); if fatal_error { LookupStatus::Failed @@ -335,115 +325,117 @@ impl TableLookup { } fn current_lookup_status(&self) -> LookupStatus { - if self.in_endgame || !self.active_lookups.is_empty() { + if self.in_endgame.load(Ordering::Relaxed) || !self.active_lookups.lock().unwrap().is_empty() { LookupStatus::Searching } else { LookupStatus::Completed } } - fn start_request_round<'a, H, I>( - &mut self, + async fn start_request_round<'a, I>( + &self, nodes: I, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, + table: Arc>, + mut out: mpsc::Sender<(Vec, SocketAddr)>, + scheduled_task_sender: &mpsc::Sender, ) -> LookupStatus where I: Iterator, - H: HandshakerTrait, { - // Loop through the given nodes let mut messages_sent = 0; for (node, dist_to_beat) in nodes { - // Generate a transaction id for this message - let trans_id = self.id_generator.generate(); + let trans_id = self.id_generator.lock().unwrap().generate(); - // Try to start a timeout for the node - let res_timeout = event_loop.timeout_ms((0, ScheduledTaskCheck::LookupTimeout(trans_id)), LOOKUP_TIMEOUT_MS); - let Ok(timeout) = res_timeout else { - error!("bip_dht: Failed to set a timeout for a table lookup..."); - return LookupStatus::Failed; - }; + let timeout = Instant::now() + Duration::from_millis(LOOKUP_TIMEOUT_MS); - // Associate the transaction id with the distance the returned nodes must beat and the timeout token - self.active_lookups.insert(trans_id, (dist_to_beat, timeout)); + self.active_lookups.lock().unwrap().insert(trans_id, (dist_to_beat, timeout)); - // Send the message to the node let get_peers_msg = GetPeersRequest::new(trans_id.as_ref(), self.table_id, self.target_id).encode(); - if out.send((get_peers_msg, node.addr())).is_err() { - error!("bip_dht: Could not send a lookup message through the channel..."); + if out.send((get_peers_msg, node.addr())).await.is_err() { + tracing::error!("bip_dht: Could not send a lookup message through the channel..."); return LookupStatus::Failed; } - // We requested from the node, mark it down - self.requested_nodes.insert(node.clone()); + self.requested_nodes.lock().unwrap().insert(node.clone()); + + let routing_table = table.read().unwrap(); - // Update the node in the routing table - if let Some(n) = table.find_node(node) { + if let Some(n) = routing_table.find_node(node) { n.local_request(); } messages_sent += 1; + + // Schedule a timeout check + let mut this_scheduled_task_sender = scheduled_task_sender.clone(); + self.tasks.lock().unwrap().spawn(async move { + sleep(Duration::from_millis(LOOKUP_TIMEOUT_MS)).await; + + match this_scheduled_task_sender + .send(ScheduledTaskCheck::LookupTimeout(trans_id)) + .await + { + Ok(()) => { + tracing::debug!("sent scheduled lookup timeout"); + Ok(()) + } + Err(e) => { + tracing::debug!("error sending scheduled lookup timeout: {e}"); + Err(e) + } + } + }); } if messages_sent == 0 { - self.active_lookups.clear(); + self.active_lookups.lock().unwrap().clear(); LookupStatus::Completed } else { LookupStatus::Searching } } - fn start_endgame_round( - &mut self, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> LookupStatus - where - H: HandshakerTrait, - { - // Entering the endgame phase - self.in_endgame = true; - - // Try to start a global message timeout for the endgame - let res_timeout = 
event_loop.timeout_ms( - (0, ScheduledTaskCheck::LookupEndGame(self.id_generator.generate())), - ENDGAME_TIMEOUT_MS, - ); - let Ok(timeout) = res_timeout else { - error!("bip_dht: Failed to set a timeout for table lookup endgame..."); - return LookupStatus::Failed; - }; + async fn start_endgame_round( + &self, + table: Arc>, + mut out: mpsc::Sender<(Vec, SocketAddr)>, + _scheduled_task_sender: mpsc::Sender, + ) -> LookupStatus { + self.in_endgame.store(true, Ordering::SeqCst); + + let timeout = Instant::now() + Duration::from_millis(ENDGAME_TIMEOUT_MS); + + let mut endgame_messages = Vec::new(); - // Request all unpinged nodes if we didn't receive any values - if !self.recv_values { - for node_info in self.all_sorted_nodes.iter_mut().filter(|&&mut (_, _, req)| !req) { - let &mut (ref node_dist, ref node, ref mut req) = node_info; + if !self.recv_values.load(Ordering::SeqCst) { + { + let all_nodes = self.all_sorted_nodes.lock().unwrap(); + for node_info in all_nodes.iter().filter(|(_, _, req)| !req.load(Ordering::Acquire)) { + let (node_dist, node, req) = node_info; + + let trans_id = self.id_generator.lock().unwrap().generate(); + + self.active_lookups.lock().unwrap().insert(trans_id, (*node_dist, timeout)); - // Generate a transaction id for this message - let trans_id = self.id_generator.generate(); + let get_peers_msg = GetPeersRequest::new(trans_id.as_ref(), self.table_id, self.target_id).encode(); - // Associate the transaction id with this node's distance and its timeout token - // We don't actually need to keep track of this information, but we do still need to - // filter out unsolicited responses by using the active_lookups map!!! - self.active_lookups.insert(trans_id, (*node_dist, timeout)); + endgame_messages.push((node.clone(), get_peers_msg, req.clone())); + } + } - // Send the message to the node - let get_peers_msg = GetPeersRequest::new(trans_id.as_ref(), self.table_id, self.target_id).encode(); - if out.send((get_peers_msg, node.addr())).is_err() { - error!("bip_dht: Could not send an endgame message through the channel..."); + for (node, get_peers_msg, req) in endgame_messages { + if out.send((get_peers_msg, node.addr())).await.is_err() { + tracing::error!("bip_dht: Could not send an endgame message through the channel..."); return LookupStatus::Failed; } - // Mark that we requested from the node in the RoutingTable - if let Some(n) = table.find_node(node) { + let routing_table = table.read().unwrap(); + + if let Some(n) = routing_table.find_node(&node) { n.local_request(); } - // Mark that we requested from the node - *req = true; + req.store(true, Ordering::Release); } } @@ -454,7 +446,7 @@ impl TableLookup { /// Picks a number of nodes from the sorted distance iterator to ping on the first round. fn pick_initial_nodes<'a, I>(sorted_nodes: I) -> [(Node, bool); INITIAL_PICK_NUM] where - I: Iterator, + I: Iterator)>, { let dummy_id = [0u8; bt::NODE_ID_LEN].into(); let default = (Node::as_bad(dummy_id, net::default_route_v4()), false); @@ -465,13 +457,12 @@ where dst.1 = true; // Mark that the node has been requested from - src.2 = true; + src.2.store(true, Ordering::Relaxed); } pick_nodes } -/// Picks a number of nodes from the unsorted distance iterator to ping on iterative rounds. 
 fn pick_iterate_nodes<I>(unsorted_nodes: I, target_id: InfoHash) -> [(Node, bool); ITERATIVE_PICK_NUM]
 where
     I: Iterator<Item = (NodeId, SocketAddrV4)>,
 {
@@ -490,14 +481,11 @@ where
     pick_nodes
 }
 
-/// Inserts the node into the slice if a slot in the slice is unused or a node
-/// in the slice is further from the target id than the node being inserted.
 fn insert_closest_nodes(nodes: &mut [(Node, bool)], target_id: InfoHash, new_node: Node) {
     let new_distance = target_id ^ new_node.id();
 
     for &mut (ref mut old_node, ref mut used) in &mut *nodes {
         if *used {
-            // Slot is in use, see if our node is closer to the target
             let old_distance = target_id ^ old_node.id();
 
             if new_distance < old_distance {
@@ -505,7 +493,6 @@ fn insert_closest_nodes(nodes: &mut [(Node, bool)], target_id: InfoHash, new_nod
                 return;
             }
         } else {
-            // Slot was not in use, go ahead and place the node
             *old_node = new_node;
             *used = true;
             return;
@@ -513,25 +500,18 @@ fn insert_closest_nodes(nodes: &mut [(Node, bool)], target_id: InfoHash, new_nod
     }
 }
 
-/// Inserts the Node into the list of nodes based on its distance from the target node.
-///
-/// Nodes at the start of the list are closer to the target node than nodes at the end.
-fn insert_sorted_node(nodes: &mut Vec<(Distance, Node, bool)>, target: InfoHash, node: Node, pinged: bool) {
+fn insert_sorted_node(nodes: &Mutex<Vec<(Distance, Node, Arc<AtomicBool>)>>, target: InfoHash, node: Node, pinged: bool) {
+    let mut nodes = nodes.lock().unwrap();
     let node_id = node.id();
     let node_dist = target ^ node_id;
 
-    // Perform a search by distance from the target id
     let search_result = nodes.binary_search_by(|&(dist, _, _)| dist.cmp(&node_dist));
     match search_result {
         Ok(dup_index) => {
-            // TODO: Bug here, what happens when multiple nodes with the same distance are
-            // present, but we don't get the index of the duplicate node (its in the list) from
-            // the search, then we would have a duplicate node in the list!
-            // Insert only if this node is different (it is ok if they have the same id)
             if nodes[dup_index].1 != node {
-                nodes.insert(dup_index, (node_dist, node, pinged));
+                nodes.insert(dup_index, (node_dist, node, Arc::new(AtomicBool::new(pinged))));
             }
         }
-        Err(ins_index) => nodes.insert(ins_index, (node_dist, node, pinged)),
+        Err(ins_index) => nodes.insert(ins_index, (node_dist, node, Arc::new(AtomicBool::new(pinged)))),
     };
 }
diff --git a/packages/dht/src/worker/messenger.rs b/packages/dht/src/worker/messenger.rs
index dd237c895..38808fefc 100644
--- a/packages/dht/src/worker/messenger.rs
+++ b/packages/dht/src/worker/messenger.rs
@@ -1,38 +1,43 @@
-use std::net::{SocketAddr, UdpSocket};
-use std::sync::mpsc::{self, SyncSender};
-use std::thread;
+use std::net::SocketAddr;
+use std::sync::Arc;
 
-use log::{info, warn};
-use mio::Sender;
+use futures::channel::mpsc;
+use futures::stream::StreamExt;
+use futures::SinkExt as _;
+use tokio::net::UdpSocket;
+use tokio::task;
 
 use crate::worker::OneshotTask;
 
 const OUTGOING_MESSAGE_CAPACITY: usize = 4096;
 
 #[allow(clippy::module_name_repetitions)]
-pub fn create_outgoing_messenger(socket: UdpSocket) -> SyncSender<(Vec<u8>, SocketAddr)> {
-    let (send, recv) = mpsc::sync_channel::<(Vec<u8>, SocketAddr)>(OUTGOING_MESSAGE_CAPACITY);
+pub fn create_outgoing_messenger(socket: &Arc<UdpSocket>) -> mpsc::Sender<(Vec<u8>, SocketAddr)> {
+    #[allow(clippy::type_complexity)]
+    let (send, mut recv): (mpsc::Sender<(Vec<u8>, SocketAddr)>, mpsc::Receiver<(Vec<u8>, SocketAddr)>) =
+        mpsc::channel(OUTGOING_MESSAGE_CAPACITY);
 
-    thread::spawn(move || {
-        for (message, addr) in recv {
-            send_bytes(&socket, &message[..], addr);
+    let socket = socket.clone();
+    task::spawn(async move {
+        while let Some((message, addr)) = recv.next().await {
+            send_bytes(&socket, &message[..], addr).await;
         }
 
-        info!("bip_dht: Outgoing messenger received a channel hangup, exiting thread...");
+        tracing::info!("bip_dht: Outgoing messenger received a channel hangup, exiting thread...");
     });
 
     send
}
 
-fn send_bytes(socket: &UdpSocket, bytes: &[u8], addr: SocketAddr) {
+async fn send_bytes(socket: &UdpSocket, bytes: &[u8], addr: SocketAddr) {
    let mut bytes_sent = 0;
 
    while bytes_sent != bytes.len() {
-        if let Ok(num_sent) = socket.send_to(&bytes[bytes_sent..], addr) {
+        if let Ok(num_sent) = socket.send_to(&bytes[bytes_sent..], &addr).await {
            bytes_sent += num_sent;
        } else {
            // TODO: Maybe shut down in this case, will fail on every write...
- warn!( + tracing::warn!( "bip_dht: Outgoing messenger failed to write {} bytes to {}; {} bytes written \ before error...", bytes.len(), @@ -45,25 +50,28 @@ fn send_bytes(socket: &UdpSocket, bytes: &[u8], addr: SocketAddr) { } #[allow(clippy::module_name_repetitions)] -pub fn create_incoming_messenger(socket: UdpSocket, send: Sender) { - thread::spawn(move || { - let mut channel_is_open = true; +pub fn create_incoming_messenger(socket: Arc, send: mpsc::Sender) { + task::spawn(async move { + let mut buffer = vec![0u8; 1500]; - while channel_is_open { - let mut buffer = vec![0u8; 1500]; - - if let Ok((size, addr)) = socket.recv_from(&mut buffer) { - buffer.truncate(size); - channel_is_open = send_message(&send, buffer, addr); - } else { - warn!("bip_dht: Incoming messenger failed to receive bytes..."); - }; + loop { + match socket.recv_from(&mut buffer).await { + Ok((size, addr)) => { + let message = buffer[..size].to_vec(); + if !send_message(&send, message, addr).await { + break; + } + } + Err(_) => { + tracing::warn!("bip_dht: Incoming messenger failed to receive bytes..."); + } + } } - info!("bip_dht: Incoming messenger received a channel hangup, exiting thread..."); + tracing::info!("bip_dht: Incoming messenger received a channel hangup, exiting thread..."); }); } -fn send_message(send: &Sender, bytes: Vec, addr: SocketAddr) -> bool { - send.send(OneshotTask::Incoming(bytes, addr)).is_ok() +async fn send_message(send: &mpsc::Sender, bytes: Vec, addr: SocketAddr) -> bool { + send.clone().send(OneshotTask::Incoming(bytes, addr)).await.is_ok() } diff --git a/packages/dht/src/worker/mod.rs b/packages/dht/src/worker/mod.rs index a99f338ae..fd21e1d87 100644 --- a/packages/dht/src/worker/mod.rs +++ b/packages/dht/src/worker/mod.rs @@ -1,7 +1,9 @@ -use std::io; -use std::net::{SocketAddr, UdpSocket}; -use std::sync::mpsc; +use std::net::SocketAddr; +use std::sync::Arc; +use futures::channel::mpsc; +use tokio::net::UdpSocket; +use tokio::task::JoinSet; use util::bt::InfoHash; use crate::handshaker_trait::HandshakerTrait; @@ -40,6 +42,7 @@ pub enum ScheduledTaskCheck { /// Check the progress of a current lookup. LookupTimeout(TransactionID), /// Check the progress of the lookup endgame. + #[allow(dead_code)] LookupEndGame(TransactionID), } @@ -68,14 +71,14 @@ pub enum ShutdownCause { /// Spawns the necessary workers that make up our local DHT node and connects them via channels /// so that they can send and receive DHT messages. pub fn start_mainline_dht( - send_socket: UdpSocket, - recv_socket: UdpSocket, + send_socket: &Arc, + recv_socket: Arc, read_only: bool, _: Option, handshaker: H, - kill_sock: UdpSocket, + kill_sock: Arc, kill_addr: SocketAddr, -) -> io::Result> +) -> (mpsc::Sender, JoinSet<()>) where H: HandshakerTrait + 'static, { @@ -83,9 +86,9 @@ where // TODO: Utilize the security extension. 
let routing_table = RoutingTable::new(table::random_node_id()); - let message_sender = handler::create_dht_handler(routing_table, outgoing, read_only, handshaker, kill_sock, kill_addr)?; + let message_sender = handler::create_dht_handler(routing_table, outgoing, read_only, handshaker, kill_sock, kill_addr); - messenger::create_incoming_messenger(recv_socket, message_sender.clone()); + messenger::create_incoming_messenger(recv_socket, message_sender.0.clone()); - Ok(message_sender) + message_sender } diff --git a/packages/dht/src/worker/refresh.rs b/packages/dht/src/worker/refresh.rs index 0bae2945c..aff63a398 100644 --- a/packages/dht/src/worker/refresh.rs +++ b/packages/dht/src/worker/refresh.rs @@ -1,16 +1,17 @@ use std::net::SocketAddr; -use std::sync::mpsc::SyncSender; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, RwLock}; -use log::{error, info}; -use mio::EventLoop; +use futures::channel::mpsc::{self, SendError}; +use futures::SinkExt as _; +use tokio::task::JoinSet; +use tokio::time::{sleep, Duration}; use util::bt::{self, NodeId}; -use crate::handshaker_trait::HandshakerTrait; use crate::message::find_node::FindNodeRequest; use crate::routing::node::NodeStatus; use crate::routing::table::{self, RoutingTable}; use crate::transaction::MIDGenerator; -use crate::worker::handler::DhtHandler; use crate::worker::ScheduledTaskCheck; const REFRESH_INTERVAL_TIMEOUT: u64 = 6000; @@ -25,49 +26,60 @@ pub enum RefreshStatus { #[allow(clippy::module_name_repetitions)] pub struct TableRefresh { - id_generator: MIDGenerator, - curr_refresh_bucket: usize, + id_generator: Mutex, + curr_refresh_bucket: AtomicUsize, + tasks: Arc>>>, } impl TableRefresh { pub fn new(id_generator: MIDGenerator) -> TableRefresh { TableRefresh { - id_generator, - curr_refresh_bucket: 0, + id_generator: Mutex::new(id_generator), + curr_refresh_bucket: AtomicUsize::default(), + tasks: Arc::default(), } } - pub fn continue_refresh( - &mut self, - table: &RoutingTable, - out: &SyncSender<(Vec, SocketAddr)>, - event_loop: &mut EventLoop>, - ) -> RefreshStatus - where - H: HandshakerTrait, - { - if self.curr_refresh_bucket == table::MAX_BUCKETS { - self.curr_refresh_bucket = 0; - } - let target_id = flip_id_bit_at_index(table.node_id(), self.curr_refresh_bucket); + pub async fn continue_refresh( + &self, + table: Arc>, + mut out: mpsc::Sender<(Vec, SocketAddr)>, + mut scheduled_task_sender: mpsc::Sender, + ) -> RefreshStatus { + let refresh_bucket = match self.curr_refresh_bucket.load(Ordering::Relaxed) { + table::MAX_BUCKETS => { + self.curr_refresh_bucket.store(0, Ordering::Relaxed); + 0 + } + refresh_bucket => refresh_bucket, + }; + + let (node, target_id, node_id) = { + let routing_table = table.read().unwrap(); + let node_id = routing_table.node_id(); + let target_id = flip_id_bit_at_index(node_id, refresh_bucket); + let node = routing_table + .closest_nodes(target_id) + .find(|n| n.status() == NodeStatus::Questionable) + .cloned(); + + tracing::info!("bip_dht: Performing a refresh for bucket {}", refresh_bucket); + + (node, target_id, node_id) + }; - info!("bip_dht: Performing a refresh for bucket {}", self.curr_refresh_bucket); // Ping the closest questionable node - for node in table - .closest_nodes(target_id) - .filter(|n| n.status() == NodeStatus::Questionable) - .take(1) - { + if let Some(node) = node { // Generate a transaction id for the request - let trans_id = self.id_generator.generate(); + let trans_id = self.id_generator.lock().unwrap().generate(); // Construct the message - let 
find_node_req = FindNodeRequest::new(trans_id.as_ref(), table.node_id(), target_id); + let find_node_req = FindNodeRequest::new(trans_id.as_ref(), node_id, target_id); let find_node_msg = find_node_req.encode(); // Send the message - if out.send((find_node_msg, node.addr())).is_err() { - error!( + if out.send((find_node_msg, node.addr())).await.is_err() { + tracing::error!( "bip_dht: TableRefresh failed to send a refresh message to the out \ channel..." ); @@ -79,18 +91,25 @@ impl TableRefresh { } // Generate a dummy transaction id (only the action id will be used) - let trans_id = self.id_generator.generate(); + let trans_id = self.id_generator.lock().unwrap().generate(); // Start a timer for the next refresh - if event_loop - .timeout_ms((0, ScheduledTaskCheck::TableRefresh(trans_id)), REFRESH_INTERVAL_TIMEOUT) - .is_err() - { - error!("bip_dht: TableRefresh failed to set a timeout for the next refresh..."); - return RefreshStatus::Failed; - } + self.tasks.lock().unwrap().spawn(async move { + sleep(Duration::from_millis(REFRESH_INTERVAL_TIMEOUT)).await; + + match scheduled_task_sender.send(ScheduledTaskCheck::TableRefresh(trans_id)).await { + Ok(()) => { + tracing::debug!("sent scheduled refresh timeout"); + Ok(()) + } + Err(e) => { + tracing::debug!("error sending scheduled refresh timeout: {e}"); + Err(e) + } + } + }); - self.curr_refresh_bucket += 1; + self.curr_refresh_bucket.fetch_add(1, Ordering::SeqCst); RefreshStatus::Refreshing } diff --git a/packages/disk/Cargo.toml b/packages/disk/Cargo.toml index bd5c14171..608c70af0 100644 --- a/packages/disk/Cargo.toml +++ b/packages/disk/Cargo.toml @@ -19,18 +19,19 @@ version.workspace = true metainfo = { path = "../metainfo" } util = { path = "../util" } -bytes = "0.4" -crossbeam = "0.8" -error-chain = "0.12" -futures = "0.1" -futures-cpupool = "0.1" -log = "0.4" -lru-cache = "0.1" +bytes = "1" +crossbeam = "0" +futures = "0" +lru-cache = "0" +pin-project = "1" +thiserror = "1" +tokio = { version = "1", features = ["full"] } +tracing = "0" [dev-dependencies] -criterion = "0.5" -rand = "0.8" -tokio-core = "0.1" +criterion = { version = "0", features = ["async_tokio"] } +rand = "0" +tracing-subscriber = "0" [[bench]] harness = false diff --git a/packages/disk/benches/disk_benchmark.rs b/packages/disk/benches/disk_benchmark.rs index 327ce1b0e..2fbc581a7 100644 --- a/packages/disk/benches/disk_benchmark.rs +++ b/packages/disk/benches/disk_benchmark.rs @@ -1,22 +1,23 @@ -use std::fs; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use bytes::BytesMut; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use disk::error::TorrentError; use disk::fs::NativeFileSystem; use disk::fs_cache::FileHandleCache; use disk::{Block, BlockMetadata, DiskManagerBuilder, FileSystem, IDiskMessage, InfoHash, ODiskMessage}; -use futures::sink::{self, Sink}; -use futures::stream::{self, Stream}; +use futures::{SinkExt, StreamExt}; use metainfo::{DirectAccessor, Metainfo, MetainfoBuilder, PieceLength}; +use rand::Rng; +use tokio::sync::Mutex; /// Set to true if you are playing around with anything that could affect file /// sizes for an existing or new benchmarks. As a precaution, if the disk manager /// sees an existing file with a different size but same name as one of the files -/// in the torrent, it wont touch it and a `TorrentError` will be generated. +/// in the torrent, it won't touch it and a `TorrentError` will be generated. const WIPE_DATA_DIR: bool = false; -// TODO: Benchmark multi file torrents!!! 
+// TODO: Benchmark multi-file torrents!!! /// Generates a torrent with a single file of the given length. /// @@ -24,7 +25,7 @@ const WIPE_DATA_DIR: bool = false; fn generate_single_file_torrent(piece_len: usize, file_len: usize) -> (Metainfo, Vec) { let mut buffer = vec![0u8; file_len]; - rand::Rng::fill(&mut rand::thread_rng(), buffer.as_mut_slice()); + rand::thread_rng().fill(buffer.as_mut_slice()); let metainfo_bytes = { let accessor = DirectAccessor::new("benchmark_file", &buffer[..]); @@ -40,43 +41,59 @@ fn generate_single_file_torrent(piece_len: usize, file_len: usize) -> (Metainfo, } /// Adds the given metainfo file to the given sender, and waits for the added notification. -fn add_metainfo_file(metainfo: Metainfo, block_send: &mut sink::Wait, block_recv: &mut stream::Wait) +async fn add_metainfo_file(metainfo: Metainfo, block_send: Arc>, block_recv: Arc>) where - S: Sink, - R: Stream, + S: futures::Sink + Unpin, + S::Error: std::fmt::Debug, + R: futures::Stream> + Unpin, { - block_send.send(IDiskMessage::AddTorrent(metainfo)).unwrap(); + { + let mut block_send_guard = block_send.lock().await; + block_send_guard.send(IDiskMessage::AddTorrent(metainfo)).await.unwrap(); + } - for res_message in block_recv { - match res_message.unwrap() { + while let Some(res_message) = { + let mut block_recv_guard = block_recv.lock().await; + block_recv_guard.next().await + } { + let error = match res_message.unwrap() { ODiskMessage::TorrentAdded(_) => { break; } - ODiskMessage::FoundGoodPiece(_, _) => (), - _ => panic!("Didn't Receive TorrentAdded"), + ODiskMessage::FoundGoodPiece(_, _) => continue, + ODiskMessage::TorrentError(_, error) => error, + + other => panic!("should receive `TorrentAdded` or `FoundGoodPiece`, but got: {other:?}"), + }; + + match error { + TorrentError::ExistingInfoHash { .. } => break, + other => panic!("should receive `TorrentAdded` or `FoundGoodPiece`, but got: {other:?}"), } } } struct ProcessBlockData where - S: Sink, - R: Stream, + S: futures::Sink + Unpin, + S::Error: std::fmt::Debug, + R: futures::Stream> + Unpin, { piece_length: usize, block_length: usize, info_hash: InfoHash, bytes: Vec, - block_send: Arc>>, - block_recv: Arc>>, + block_send: Arc>, + block_recv: Arc>, } /// Pushes the given bytes as piece blocks to the given sender, and blocks until all notifications /// of the blocks being processed have been received (does not check piece messages). 
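As an aside, the send-then-drain pattern that `add_metainfo_file` and `process_blocks` follow in this patch can be sketched in isolation with the futures 0.3 traits; this is a minimal illustration only, and the `send_and_drain` helper together with its generic bounds is an assumption for illustration, not code from the patch.

    use disk::{IDiskMessage, ODiskMessage};
    use futures::{SinkExt, StreamExt};

    // Hypothetical helper (not part of this patch): push every message into the
    // sink, then read the stream until one BlockProcessed arrives per message sent.
    async fn send_and_drain<S, R, E>(mut sink: S, mut stream: R, messages: Vec<IDiskMessage>)
    where
        S: futures::Sink<IDiskMessage> + Unpin,
        S::Error: std::fmt::Debug,
        R: futures::Stream<Item = Result<ODiskMessage, E>> + Unpin,
        E: std::fmt::Debug,
    {
        let mut pending = 0usize;
        for message in messages {
            sink.send(message).await.unwrap();
            pending += 1;
        }
        while pending > 0 {
            match stream.next().await.unwrap().unwrap() {
                ODiskMessage::BlockProcessed(_) => pending -= 1,
                ODiskMessage::FoundGoodPiece(_, _) | ODiskMessage::FoundBadPiece(_, _) => {}
                other => panic!("unexpected message: {other:?}"),
            }
        }
    }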
-fn process_blocks(data: &ProcessBlockData) +async fn process_blocks(data: Arc>) where - S: Sink, - R: Stream, + S: futures::Sink + Unpin, + S::Error: std::fmt::Debug, + R: futures::Stream> + Unpin, { let piece_length = data.piece_length; let block_length = data.block_length; @@ -98,16 +115,23 @@ where bytes.freeze(), ); - block_send.lock().unwrap().send(IDiskMessage::ProcessBlock(block)).unwrap(); + { + let mut block_send_guard = block_send.lock().await; + block_send_guard.send(IDiskMessage::ProcessBlock(block)).await.unwrap(); + } // MutexGuard is dropped here + blocks_sent += 1; } } - for res_message in &mut *block_recv.lock().unwrap() { + while let Some(res_message) = { + let mut block_recv_guard = block_recv.lock().await; + block_recv_guard.next().await + } { match res_message.unwrap() { ODiskMessage::BlockProcessed(_) => blocks_sent -= 1, ODiskMessage::FoundGoodPiece(_, _) | ODiskMessage::FoundBadPiece(_, _) => (), - _ => panic!("Unexpected Message Received In process_blocks"), + other => panic!("should receive `BlockProcessed`, `FoundGoodPiece` or `FoundBadPiece`, but got: {other:?}"), } if blocks_sent == 0 { @@ -123,9 +147,10 @@ fn bench_process_file_with_fs( piece_length: usize, block_length: usize, file_length: usize, - fs: F, + fs: Arc, ) where - F: FileSystem + Send + Sync + 'static, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { let (metainfo, bytes) = generate_single_file_torrent(piece_length, file_length); let info_hash = metainfo.info().info_hash(); @@ -135,23 +160,33 @@ fn bench_process_file_with_fs( .with_stream_buffer_capacity(1_000_000) .build(fs); - let (d_send, d_recv) = disk_manager.split(); - - let block_send = Arc::new(Mutex::new(d_send.wait())); - let block_recv = Arc::new(Mutex::new(d_recv.wait())); + let (d_send, d_recv) = disk_manager.into_parts(); - add_metainfo_file(metainfo, &mut block_send.lock().unwrap(), &mut block_recv.lock().unwrap()); + let block_send = Arc::new(Mutex::new(d_send)); + let block_recv = Arc::new(Mutex::new(d_recv)); let data = ProcessBlockData { piece_length, block_length, info_hash, bytes, - block_send, - block_recv, + block_send: block_send.clone(), + block_recv: block_recv.clone(), }; - c.bench_with_input(id, &data, |b, i| b.iter(|| process_blocks(i))); + let runner = &tokio::runtime::Runtime::new().unwrap(); + + c.bench_with_input(id, &Arc::new(data), |b, i| { + let metainfo_clone = metainfo.clone(); + b.to_async(runner).iter(move || { + let data = i.clone(); + let metainfo = metainfo_clone.clone(); + async move { + add_metainfo_file(metainfo, data.block_send.clone(), data.block_recv.clone()).await; + process_blocks(data).await; + } + }); + }); } fn bench_native_fs_1_mb_pieces_128_kb_blocks(c: &mut Criterion) { @@ -161,13 +196,13 @@ fn bench_native_fs_1_mb_pieces_128_kb_blocks(c: &mut Criterion) { let data_directory = "target/bench_data/bench_native_fs_1_mb_pieces_128_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = NativeFileSystem::with_directory(data_directory); let id = BenchmarkId::new("bench_native_fs", "1_mb_pieces_128_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } fn bench_native_fs_1_mb_pieces_16_kb_blocks(c: &mut Criterion) { @@ -177,13 +212,13 @@ fn bench_native_fs_1_mb_pieces_16_kb_blocks(c: &mut Criterion) { let data_directory = 
"target/bench_data/bench_native_fs_1_mb_pieces_16_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = NativeFileSystem::with_directory(data_directory); let id = BenchmarkId::new("bench_native_fs", "1_mb_pieces_16_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } fn bench_native_fs_1_mb_pieces_2_kb_blocks(c: &mut Criterion) { @@ -193,13 +228,13 @@ fn bench_native_fs_1_mb_pieces_2_kb_blocks(c: &mut Criterion) { let data_directory = "target/bench_data/bench_native_fs_1_mb_pieces_2_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = NativeFileSystem::with_directory(data_directory); let id = BenchmarkId::new("bench_native_fs", "1_mb_pieces_2_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } fn bench_file_handle_cache_fs_1_mb_pieces_128_kb_blocks(c: &mut Criterion) { @@ -209,13 +244,13 @@ fn bench_file_handle_cache_fs_1_mb_pieces_128_kb_blocks(c: &mut Criterion) { let data_directory = "target/bench_data/bench_native_fs_1_mb_pieces_128_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = FileHandleCache::new(NativeFileSystem::with_directory(data_directory), 1); let id = BenchmarkId::new("bench_file_handle_cache_fs", "1_mb_pieces_128_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } fn bench_file_handle_cache_fs_1_mb_pieces_16_kb_blocks(c: &mut Criterion) { @@ -225,13 +260,13 @@ fn bench_file_handle_cache_fs_1_mb_pieces_16_kb_blocks(c: &mut Criterion) { let data_directory = "target/bench_data/bench_native_fs_1_mb_pieces_16_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = FileHandleCache::new(NativeFileSystem::with_directory(data_directory), 1); let id = BenchmarkId::new("bench_file_handle_cache_fs", "1_mb_pieces_16_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } fn bench_file_handle_cache_fs_1_mb_pieces_2_kb_blocks(c: &mut Criterion) { @@ -241,13 +276,13 @@ fn bench_file_handle_cache_fs_1_mb_pieces_2_kb_blocks(c: &mut Criterion) { let data_directory = "target/bench_data/bench_native_fs_1_mb_pieces_2_kb_blocks"; if WIPE_DATA_DIR { - drop(fs::remove_dir_all(data_directory)); + drop(std::fs::remove_dir_all(data_directory)); } let filesystem = FileHandleCache::new(NativeFileSystem::with_directory(data_directory), 1); let id = BenchmarkId::new("bench_file_handle_cache_fs", "1_mb_pieces_2_kb_blocks"); - bench_process_file_with_fs(c, id, piece_length, block_length, file_length, filesystem); + bench_process_file_with_fs(c, id, piece_length, block_length, file_length, Arc::new(filesystem)); } criterion_group!( diff --git a/packages/disk/examples/add_torrent.rs b/packages/disk/examples/add_torrent.rs index a7565a269..f2bab7fe5 100644 --- 
a/packages/disk/examples/add_torrent.rs +++ b/packages/disk/examples/add_torrent.rs @@ -1,17 +1,33 @@ -use std::fs::File; -use std::io::{self, BufRead, Read, Write}; +use std::io::{BufRead, Read as _, Write as _}; +use std::sync::{Arc, Once}; use disk::fs::NativeFileSystem; use disk::{DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::{Future, Sink, Stream}; +use futures::{SinkExt, StreamExt}; use metainfo::Metainfo; +use tracing::level_filters::LevelFilter; -fn main() { - println!("Utility For Allocating Disk Space For A Torrent File"); +static INIT: Once = Once::new(); - let stdin = io::stdin(); +fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt().with_max_level(filter).with_ansi(true); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +#[tokio::main] +async fn main() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + tracing::info!("Utility For Allocating Disk Space For A Torrent File"); + + let stdin = std::io::stdin(); let mut input_lines = stdin.lock().lines(); - let mut stdout = io::stdout(); + let mut stdout = std::io::stdout(); print!("Enter the destination download directory: "); stdout.flush().unwrap(); @@ -22,25 +38,26 @@ fn main() { let torrent_path = input_lines.next().unwrap().unwrap(); let mut torrent_bytes = Vec::new(); - File::open(torrent_path).unwrap().read_to_end(&mut torrent_bytes).unwrap(); + std::fs::File::open(torrent_path) + .unwrap() + .read_to_end(&mut torrent_bytes) + .unwrap(); let metainfo_file = Metainfo::from_bytes(torrent_bytes).unwrap(); - let native_fs = NativeFileSystem::with_directory(download_path); - let disk_manager = DiskManagerBuilder::new().build(native_fs); + let filesystem = NativeFileSystem::with_directory(download_path); + let disk_manager = DiskManagerBuilder::new().build(Arc::new(filesystem)); - let (disk_send, disk_recv) = disk_manager.split(); + let (mut disk_send, mut disk_recv) = disk_manager.into_parts(); let total_pieces = metainfo_file.info().pieces().count(); - disk_send.send(IDiskMessage::AddTorrent(metainfo_file)).wait().unwrap(); - - println!(); + disk_send.send(IDiskMessage::AddTorrent(metainfo_file)).await.unwrap(); let mut good_pieces = 0; - for recv_msg in disk_recv.wait() { + while let Some(recv_msg) = disk_recv.next().await { match recv_msg.unwrap() { ODiskMessage::TorrentAdded(hash) => { - println!("Torrent With Hash {hash:?} Successfully Added"); - println!("Torrent Has {good_pieces} Good Pieces Out Of {total_pieces} Total Pieces"); + tracing::info!("Torrent With Hash {hash:?} Successfully Added"); + tracing::info!("Torrent Has {good_pieces} Good Pieces Out Of {total_pieces} Total Pieces"); break; } ODiskMessage::FoundGoodPiece(_, _) => good_pieces += 1, diff --git a/packages/disk/src/disk/fs/cache/file_handle.rs b/packages/disk/src/disk/fs/cache/file_handle.rs index b5428d24a..53bf90db7 100644 --- a/packages/disk/src/disk/fs/cache/file_handle.rs +++ b/packages/disk/src/disk/fs/cache/file_handle.rs @@ -1,4 +1,3 @@ -use std::io; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; @@ -14,7 +13,8 @@ use crate::disk::fs::FileSystem; #[allow(clippy::module_name_repetitions)] pub struct FileHandleCache where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { cache: Mutex>>>, inner: F, @@ -22,7 +22,8 @@ where impl FileHandleCache where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { /// Create a new `FileHandleCache` with the given handle capacity and an /// 
inner `FileSystem` which will be called for handles not in the cache. @@ -48,11 +49,12 @@ where impl<F> FileSystem for FileHandleCache<F> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc<F>: Send + Sync, { type File = Arc<Mutex<F::File>>; - fn open_file<P>(&self, path: P) -> io::Result<Self::File> + fn open_file<P>(&self, path: P) -> std::io::Result<Self::File>
where P: AsRef<Path> + Send + 'static, { @@ -71,7 +73,7 @@ where }) } - fn sync_file<P>(&self, path: P) -> io::Result<()> + fn sync_file<P>(&self, path: P) -> std::io::Result<()>
where P: AsRef<Path> + Send + 'static, { @@ -80,7 +82,7 @@ where self.inner.sync_file(path) } - fn file_size(&self, file: &Self::File) -> io::Result<u64> { + fn file_size(&self, file: &Self::File) -> std::io::Result<u64> { let lock_file = file .lock() .expect("bip_disk: Failed To Lock File In FileHandleCache::file_size"); @@ -88,7 +90,7 @@ where self.inner.file_size(&*lock_file) } - fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> io::Result<usize> { + fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> std::io::Result<usize> { let mut lock_file = file .lock() .expect("bip_disk: Failed To Lock File In FileHandleCache::read_file"); @@ -96,7 +98,7 @@ where self.inner.read_file(&mut *lock_file, offset, buffer) } - fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> io::Result<usize> { + fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> std::io::Result<usize> { let mut lock_file = file .lock() .expect("bip_disk: Failed To Lock File In FileHandleCache::write_file"); diff --git a/packages/disk/src/disk/fs/mod.rs b/packages/disk/src/disk/fs/mod.rs index 46260a8fa..63c0d0786 100644 --- a/packages/disk/src/disk/fs/mod.rs +++ b/packages/disk/src/disk/fs/mod.rs @@ -1,5 +1,5 @@ -use std::io::{self}; use std::path::Path; +use std::sync::Arc; pub mod cache; pub mod native; @@ -18,7 +18,7 @@ pub trait FileSystem { /// # Errors /// /// It would return an IO error if there is an problem. - fn open_file<P>(&self, path: P) -> io::Result<Self::File> + fn open_file<P>(&self, path: P) -> std::io::Result<Self::File>
where P: AsRef<Path> + Send + 'static; @@ -27,7 +27,7 @@ pub trait FileSystem { /// # Errors /// /// It would return an IO error if there is an problem. - fn sync_file<P>(&self, path: P) -> io::Result<()> + fn sync_file<P>(&self, path: P) -> std::io::Result<()>
where P: AsRef<Path> + Send + 'static; @@ -36,7 +36,7 @@ pub trait FileSystem { /// # Errors /// /// It would return an IO error if there is an problem. - fn file_size(&self, file: &Self::File) -> io::Result<u64>; + fn file_size(&self, file: &Self::File) -> std::io::Result<u64>; /// Read the contents of the file at the given offset. /// @@ -45,7 +45,7 @@ /// # Errors /// /// It would return an IO error if there is an problem. - fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> io::Result<usize>; + fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> std::io::Result<usize>; /// Write the contents of the file at the given offset. /// @@ -55,38 +55,39 @@ /// # Errors /// /// It would return an IO error if there is an problem. - fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> io::Result<usize>; + fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> std::io::Result<usize>; } impl<'a, F> FileSystem for &'a F where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc<F>: Send + Sync, { type File = F::File; - fn open_file<P>(&self, path: P) -> io::Result<Self::File> + fn open_file<P>(&self, path: P) -> std::io::Result<Self::File>
where P: AsRef<Path> + Send + 'static, { FileSystem::open_file(*self, path) } - fn sync_file<P>(&self, path: P) -> io::Result<()> + fn sync_file<P>(&self, path: P) -> std::io::Result<()>
where P: AsRef<Path> + Send + 'static, { FileSystem::sync_file(*self, path) } - fn file_size(&self, file: &Self::File) -> io::Result<u64> { + fn file_size(&self, file: &Self::File) -> std::io::Result<u64> { FileSystem::file_size(*self, file) } - fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> io::Result<usize> { + fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> std::io::Result<usize> { FileSystem::read_file(*self, file, offset, buffer) } - fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> io::Result<usize> { + fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> std::io::Result<usize> { FileSystem::write_file(*self, file, offset, buffer) } } diff --git a/packages/disk/src/disk/fs/native.rs b/packages/disk/src/disk/fs/native.rs index fc53638f4..a48a34644 100644 --- a/packages/disk/src/disk/fs/native.rs +++ b/packages/disk/src/disk/fs/native.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; -use std::fs::{self, File, OpenOptions}; -use std::io::{self, Read, Seek, SeekFrom, Write}; +use std::io::{Read as _, Seek as _, Write as _}; use std::path::{Path, PathBuf}; use crate::disk::fs::FileSystem; @@ -10,12 +9,12 @@ use crate::disk::fs::FileSystem; /// File that exists on disk. #[allow(clippy::module_name_repetitions)] pub struct NativeFile { - file: File, + file: std::fs::File, } impl NativeFile { /// Create a new `NativeFile`. - fn new(file: File) -> NativeFile { + fn new(file: std::fs::File) -> NativeFile { NativeFile { file } } } @@ -41,7 +40,7 @@ impl NativeFileSystem { impl FileSystem for NativeFileSystem { type File = NativeFile; - fn open_file<P>(&self, path: P) -> io::Result<Self::File> + fn open_file<P>(&self, path: P) -> std::io::Result<Self::File>
where P: AsRef<Path> + Send + 'static, { @@ -51,25 +50,25 @@ impl FileSystem for NativeFileSystem { Ok(NativeFile::new(file)) } - fn sync_file<P>(&self, _path: P) -> io::Result<()> + fn sync_file<P>(&self, _path: P) -> std::io::Result<()>
where P: AsRef<Path> + Send + 'static, { Ok(()) } - fn file_size(&self, file: &NativeFile) -> io::Result<u64> { + fn file_size(&self, file: &NativeFile) -> std::io::Result<u64> { file.file.metadata().map(|metadata| metadata.len()) } - fn read_file(&self, file: &mut NativeFile, offset: u64, buffer: &mut [u8]) -> io::Result<usize> { - file.file.seek(SeekFrom::Start(offset))?; + fn read_file(&self, file: &mut NativeFile, offset: u64, buffer: &mut [u8]) -> std::io::Result<usize> { + file.file.seek(std::io::SeekFrom::Start(offset))?; file.file.read(buffer) } - fn write_file(&self, file: &mut NativeFile, offset: u64, buffer: &[u8]) -> io::Result<usize> { - file.file.seek(SeekFrom::Start(offset))?; + fn write_file(&self, file: &mut NativeFile, offset: u64, buffer: &[u8]) -> std::io::Result<usize> { + file.file.seek(std::io::SeekFrom::Start(offset))?; file.file.write(buffer) } @@ -78,22 +77,25 @@ impl FileSystem for NativeFileSystem { /// Create a new file with read and write options. /// /// Intermediate directories will be created if they do not exist. -fn create_new_file<P>(path: P) -> io::Result<File> +fn create_new_file<P>(path: P) -> std::io::Result<std::fs::File>
where P: AsRef<Path>, { match path.as_ref().parent() { Some(parent_dir) => { - fs::create_dir_all(parent_dir)?; + std::fs::create_dir_all(parent_dir)?; - OpenOptions::new() + std::fs::OpenOptions::new() .read(true) .write(true) .create(true) .truncate(false) .open(&path) } - None => Err(io::Error::new(io::ErrorKind::InvalidInput, "File Path Has No Paren't")), + None => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "File Path Has No Paren't", + )), } } diff --git a/packages/disk/src/disk/manager/builder.rs b/packages/disk/src/disk/manager/builder.rs index 324b66df2..dd9d6d030 100644 --- a/packages/disk/src/disk/manager/builder.rs +++ b/packages/disk/src/disk/manager/builder.rs @@ -1,22 +1,24 @@ -use futures_cpupool::Builder; +use std::sync::Arc; use crate::disk::fs::FileSystem; use crate::disk::manager::DiskManager; const DEFAULT_PENDING_SIZE: usize = 10; const DEFAULT_COMPLETED_SIZE: usize = 10; +const DEFAULT_THREAD_POOL_SIZE: usize = 4; /// `DiskManagerBuilder` for building `DiskManager`s with different settings. #[allow(clippy::module_name_repetitions)] pub struct DiskManagerBuilder { - builder: Builder, + thread_pool_size: usize, pending_size: usize, completed_size: usize, } + impl Default for DiskManagerBuilder { fn default() -> Self { Self { - builder: Builder::new(), + thread_pool_size: DEFAULT_THREAD_POOL_SIZE, pending_size: DEFAULT_PENDING_SIZE, completed_size: DEFAULT_COMPLETED_SIZE, } @@ -30,10 +32,10 @@ impl DiskManagerBuilder { DiskManagerBuilder::default() } - /// Use a custom `Builder` for the `CpuPool`. + /// Specify the number of threads for the `ThreadPool`. #[must_use] - pub fn with_worker_config(mut self, config: Builder) -> DiskManagerBuilder { - self.builder = config; + pub fn with_thread_pool_size(mut self, size: usize) -> DiskManagerBuilder { + self.thread_pool_size = size; self } @@ -51,9 +53,10 @@ impl DiskManagerBuilder { self } - /// Retrieve the `CpuPool` builder. - pub fn worker_config(&mut self) -> &mut Builder { - &mut self.builder + /// Retrieve the `ThreadPool` size. + #[must_use] + pub fn thread_pool_size(&self) -> usize { + self.thread_pool_size } /// Retrieve the sink buffer capacity. @@ -69,10 +72,11 @@ impl DiskManagerBuilder { } /// Build a `DiskManager` with the given `FileSystem`. - pub fn build<F>(self, fs: F) -> DiskManager<F> + pub fn build<F>(self, fs: Arc<F>) -> DiskManager<F> where - F: FileSystem + Send + Sync + 'static, + F: FileSystem + Sync + 'static, + Arc<F>: Send + Sync, { - DiskManager::from_builder(self, fs) + DiskManager::from_builder(&self, fs) } } diff --git a/packages/disk/src/disk/manager/mod.rs b/packages/disk/src/disk/manager/mod.rs index 92d27b670..6fe1cbc2a 100644 --- a/packages/disk/src/disk/manager/mod.rs +++ b/packages/disk/src/disk/manager/mod.rs @@ -1,49 +1,55 @@ +//! `DiskManager` object which handles the storage of `Blocks` to the `FileSystem`.
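For reference, the reworked builder above can be exercised roughly as follows; this is a minimal sketch assuming only the APIs visible in this patch (`with_thread_pool_size`, the buffer-capacity setters, `build(Arc<F>)`, and `into_parts()`), with an illustrative directory path.

    use std::sync::Arc;
    use disk::fs::NativeFileSystem;
    use disk::DiskManagerBuilder;

    fn build_disk_manager_example() {
        // The builder now takes an Arc<F> and a plain thread count instead of a
        // futures-cpupool Builder (the directory below is illustrative only).
        let filesystem = Arc::new(NativeFileSystem::with_directory("target/example_data"));

        let disk_manager = DiskManagerBuilder::new()
            .with_thread_pool_size(4)
            .with_sink_buffer_capacity(1_000_000)
            .with_stream_buffer_capacity(1_000_000)
            .build(filesystem);

        // The manager splits into an IDiskMessage sink and an ODiskMessage stream.
        let (_disk_send, _disk_recv) = disk_manager.into_parts();
    }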
+ use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use std::task::{Context, Poll}; use crossbeam::queue::SegQueue; -use futures::sync::mpsc; -use futures::{Poll, Sink, StartSend, Stream}; -use sink::DiskManagerSink; -use stream::DiskManagerStream; +use futures::channel::mpsc; +use futures::Stream; +use pin_project::pin_project; +pub use sink::DiskManagerSink; +pub use stream::DiskManagerStream; -use crate::disk::fs::FileSystem; -use crate::disk::tasks::context::DiskManagerContext; -use crate::disk::{IDiskMessage, ODiskMessage}; -use crate::DiskManagerBuilder; +use super::tasks::context::DiskManagerContext; +use super::{IDiskMessage, ODiskMessage}; +use crate::{DiskManagerBuilder, FileSystem}; pub mod builder; pub mod sink; pub mod stream; -/// `DiskManager` object which handles the storage of `Blocks` to the `FileSystem`. #[allow(clippy::module_name_repetitions)] +#[pin_project] #[derive(Debug)] -pub struct DiskManager { +pub struct DiskManager +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ + #[pin] sink: DiskManagerSink, + #[pin] stream: DiskManagerStream, } -impl DiskManager { +impl DiskManager +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ /// Create a `DiskManager` from the given `DiskManagerBuilder`. - pub fn from_builder(mut builder: DiskManagerBuilder, fs: F) -> DiskManager { + pub fn from_builder(builder: &DiskManagerBuilder, fs: Arc) -> DiskManager { let cur_sink_capacity = Arc::new(AtomicUsize::new(0)); let sink_capacity = builder.sink_buffer_capacity(); let stream_capacity = builder.stream_buffer_capacity(); - let pool_builder = builder.worker_config(); let (out_send, out_recv) = mpsc::channel(stream_capacity); let context = DiskManagerContext::new(out_send, fs); - let task_queue = Arc::new(SegQueue::new()); + let wake_queue = Arc::new(SegQueue::new()); - let sink = DiskManagerSink::new( - pool_builder.create(), - context, - sink_capacity, - cur_sink_capacity.clone(), - task_queue.clone(), - ); - let stream = DiskManagerStream::new(out_recv, cur_sink_capacity, task_queue.clone()); + let sink = DiskManagerSink::new(context, sink_capacity, cur_sink_capacity.clone(), wake_queue.clone()); + let stream = DiskManagerStream::new(out_recv, cur_sink_capacity, wake_queue.clone()); DiskManager { sink, stream } } @@ -57,27 +63,37 @@ impl DiskManager { } } -impl Sink for DiskManager +impl futures::Sink for DiskManager where F: FileSystem + Send + Sync + 'static, { - type SinkItem = IDiskMessage; - type SinkError = (); + type Error = std::io::Error; - fn start_send(&mut self, item: IDiskMessage) -> StartSend { - self.sink.start_send(item) + fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_ready(cx) } - fn poll_complete(&mut self) -> Poll<(), ()> { - self.sink.poll_complete() + fn start_send(self: std::pin::Pin<&mut Self>, item: IDiskMessage) -> Result<(), Self::Error> { + self.project().sink.start_send(item) + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_flush(cx) + } + + fn poll_close(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_close(cx) } } -impl Stream for DiskManager { - type Item = ODiskMessage; - type Error = (); +impl Stream for DiskManager +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ + type Item = Result; - fn poll(&mut self) -> Poll, ()> { - self.stream.poll() + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + 
self.project().stream.poll_next(cx) } } diff --git a/packages/disk/src/disk/manager/sink.rs b/packages/disk/src/disk/manager/sink.rs index 194d46a2c..8ac81dac3 100644 --- a/packages/disk/src/disk/manager/sink.rs +++ b/packages/disk/src/disk/manager/sink.rs @@ -1,13 +1,11 @@ //! `DiskManagerSink` which is the sink portion of a `DiskManager`. use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; use crossbeam::queue::SegQueue; -use futures::task::{self, Task}; -use futures::{Async, AsyncSink, Poll, Sink, StartSend}; -use futures_cpupool::CpuPool; -use log::info; +use tokio::task::JoinSet; use crate::disk::tasks; use crate::disk::tasks::context::DiskManagerContext; @@ -15,92 +13,136 @@ use crate::{FileSystem, IDiskMessage}; #[allow(clippy::module_name_repetitions)] #[derive(Debug)] -pub struct DiskManagerSink { - pool: CpuPool, +pub struct DiskManagerSink +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ context: DiskManagerContext, max_capacity: usize, cur_capacity: Arc, - task_queue: Arc>, + wake_queue: Arc>, + task_set: Arc>>, } -impl Clone for DiskManagerSink { +impl Clone for DiskManagerSink +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ fn clone(&self) -> DiskManagerSink { DiskManagerSink { - pool: self.pool.clone(), context: self.context.clone(), max_capacity: self.max_capacity, cur_capacity: self.cur_capacity.clone(), - task_queue: self.task_queue.clone(), + wake_queue: self.wake_queue.clone(), + task_set: self.task_set.clone(), } } } -impl DiskManagerSink { +impl DiskManagerSink +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ pub(super) fn new( - pool: CpuPool, context: DiskManagerContext, max_capacity: usize, cur_capacity: Arc, - task_queue: Arc>, + wake_queue: Arc>, ) -> DiskManagerSink { DiskManagerSink { - pool, context, max_capacity, cur_capacity, - task_queue, + wake_queue, + task_set: Arc::default(), } } - fn try_submit_work(&self) -> bool { - let cur_capacity = self.cur_capacity.fetch_add(1, Ordering::SeqCst); + fn try_submit_work(&self, waker: &Waker) -> Result { + let cap = self.cur_capacity.fetch_add(1, Ordering::SeqCst) + 1; + let max = self.max_capacity; + + #[allow(clippy::comparison_chain)] + if cap < max { + tracing::trace!("now have {cap} of capacity: {max}"); + + Ok(cap) + } else if cap == max { + tracing::trace!("at max capacity: {max}"); - if cur_capacity < self.max_capacity { - true + Ok(cap) } else { + self.wake_queue.push(waker.clone()); + tracing::debug!("now have {} pending wakers...", self.wake_queue.len()); + self.cur_capacity.fetch_sub(1, Ordering::SeqCst); + tracing::debug!("at over capacity: {cap} of {max}"); - false + Err(cap) } } } -impl Sink for DiskManagerSink +impl futures::Sink for DiskManagerSink where - F: FileSystem + Send + Sync + 'static, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { - type SinkItem = IDiskMessage; - type SinkError = (); - - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - info!("Starting Send For DiskManagerSink With IDiskMessage"); + type Error = std::io::Error; - if self.try_submit_work() { - info!("DiskManagerSink Submitted Work On First Attempt"); - tasks::execute_on_pool(item, &self.pool, self.context.clone()); - - return Ok(AsyncSink::Ready); + fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.try_submit_work(cx.waker()) { + Ok(_remaining) => Poll::Ready(Ok(())), + Err(_full) => Poll::Pending, } + } - // We split the sink 
and stream, which means these could be polled in different event loops (I think), - // so we need to add our task, but then try to submit work again, in case the receiver processed work - // right after we tried to submit the first time. - info!("DiskManagerSink Failed To Submit Work On First Attempt, Adding Task To Queue"); - self.task_queue.push(task::current()); - - if self.try_submit_work() { - // Receiver will look at the queue but wake us up, even though we don't need it to now... - info!("DiskManagerSink Submitted Work On Second Attempt"); - tasks::execute_on_pool(item, &self.pool, self.context.clone()); + fn start_send(self: std::pin::Pin<&mut Self>, item: IDiskMessage) -> Result<(), Self::Error> { + tracing::trace!("Starting Send For DiskManagerSink With IDiskMessage"); + self.task_set + .lock() + .unwrap() + .spawn(tasks::execute(item, self.context.clone())); + Ok(()) + } - Ok(AsyncSink::Ready) - } else { - // Receiver will look at the queue eventually... - Ok(AsyncSink::NotReady(item)) + fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Ok(mut task_set) = self.task_set.try_lock() else { + tracing::warn!("unable to get task_set lock"); + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + tracing::debug!("flushing the {} tasks", task_set.len()); + + while let Some(ready) = match task_set.poll_join_next(cx) { + Poll::Ready(ready) => ready, + Poll::Pending => { + tracing::debug!("all {} task(s) are still pending...", task_set.len()); + return Poll::Pending; + } + } { + match ready { + Ok(()) => { + tracing::trace!("task completed... with {} remaining...", task_set.len()); + + continue; + } + Err(e) => { + tracing::error!("task completed... with {} remaining, with error: {e}", task_set.len()); + + return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))); + } + } } + + Poll::Ready(Ok(())) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(Async::Ready(())) + fn poll_close(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) } } diff --git a/packages/disk/src/disk/manager/stream.rs b/packages/disk/src/disk/manager/stream.rs index 17bbb8f8b..1bccdb188 100644 --- a/packages/disk/src/disk/manager/stream.rs +++ b/packages/disk/src/disk/manager/stream.rs @@ -2,67 +2,72 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll, Waker}; use crossbeam::queue::SegQueue; -use futures::sync::mpsc::Receiver; -use futures::task::Task; -use futures::{Async, Poll, Stream}; -use log::info; +use futures::channel::mpsc; +use futures::{Stream, StreamExt as _}; use crate::ODiskMessage; #[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct DiskManagerStream { - recv: Receiver, - cur_capacity: Arc, - task_queue: Arc>, + recv: mpsc::Receiver, + pub cur_capacity: Arc, + wake_queue: Arc>, } impl DiskManagerStream { pub(super) fn new( - recv: Receiver, + recv: mpsc::Receiver, cur_capacity: Arc, - task_queue: Arc>, + wake_queue: Arc>, ) -> DiskManagerStream { DiskManagerStream { recv, cur_capacity, - task_queue, + wake_queue, } } - fn complete_work(&self) { - self.cur_capacity.fetch_sub(1, Ordering::SeqCst); + fn complete_work(&self) -> usize { + let cap = self.cur_capacity.fetch_sub(1, Ordering::SeqCst) - 1; + + tracing::debug!( + "Notify next waker: {} that there is space again: {cap}", + self.wake_queue.len() + ); + if let Some(waker) = self.wake_queue.pop() { + waker.wake(); + }; + + cap } } impl Stream for DiskManagerStream { - type Item 
= ODiskMessage; - type Error = (); - - fn poll(&mut self) -> Poll, ()> { - info!("Polling DiskManagerStream For ODiskMessage"); + type Item = Result; - match self.recv.poll() { - res @ Ok(Async::Ready(Some( - ODiskMessage::TorrentAdded(_) - | ODiskMessage::TorrentRemoved(_) - | ODiskMessage::TorrentSynced(_) - | ODiskMessage::BlockLoaded(_) - | ODiskMessage::BlockProcessed(_), - ))) => { - self.complete_work(); + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::trace!("Polling DiskManagerStream For ODiskMessage"); - info!("Notifying DiskManager That We Can Submit More Work"); - - while let Some(task) = self.task_queue.pop() { - task.notify(); + match self.recv.poll_next_unpin(cx) { + Poll::Ready(Some(msg)) => { + match msg { + ODiskMessage::TorrentAdded(_) + | ODiskMessage::TorrentRemoved(_) + | ODiskMessage::TorrentSynced(_) + | ODiskMessage::BlockLoaded(_) + | ODiskMessage::BlockProcessed(_) => { + self.complete_work(); + } + _ => {} } - - res + Poll::Ready(Some(Ok(msg))) } - other => other, + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } diff --git a/packages/disk/src/disk/tasks/context.rs b/packages/disk/src/disk/tasks/context.rs index 21a3389b9..33fc1b14a 100644 --- a/packages/disk/src/disk/tasks/context.rs +++ b/packages/disk/src/disk/tasks/context.rs @@ -1,89 +1,115 @@ +use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, RwLock}; -use futures::sink::{Sink, Wait}; -use futures::sync::mpsc::Sender; +use futures::channel::mpsc; +use futures::future::BoxFuture; +use futures::lock::Mutex; +use futures::sink::SinkExt; use metainfo::Metainfo; use util::bt::InfoHash; use crate::disk::tasks::helpers::piece_checker::PieceCheckerState; use crate::disk::ODiskMessage; +use crate::FileSystem; #[allow(clippy::module_name_repetitions)] #[derive(Debug)] -pub struct DiskManagerContext { - torrents: Arc>>>, - out: Sender, +pub struct DiskManagerContext +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ + torrents: Arc>>, + pub out: mpsc::Sender, fs: Arc, } -#[derive(Debug)] +impl Clone for DiskManagerContext +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ + fn clone(&self) -> Self { + Self { + torrents: self.torrents.clone(), + out: self.out.clone(), + fs: self.fs.clone(), + } + } +} + +#[derive(Debug, Clone)] pub struct MetainfoState { - file: Metainfo, - state: PieceCheckerState, + pub file: Metainfo, + pub checker: Arc>, } impl MetainfoState { - pub fn new(file: Metainfo, state: PieceCheckerState) -> MetainfoState { - MetainfoState { file, state } + pub fn new(file: Metainfo, state: Arc>) -> MetainfoState { + MetainfoState { file, checker: state } } } -impl DiskManagerContext { - pub fn new(out: Sender, fs: F) -> DiskManagerContext { +impl DiskManagerContext +where + F: FileSystem + Sync + 'static, + Arc: Send + Sync, +{ + pub fn new(out: mpsc::Sender, fs: Arc) -> DiskManagerContext { DiskManagerContext { torrents: Arc::new(RwLock::new(HashMap::new())), out, - fs: Arc::new(fs), + fs, } } - pub fn blocking_sender(&self) -> Wait> { - self.out.clone().wait() + #[allow(dead_code)] + pub async fn send_message(&mut self, message: ODiskMessage) -> Result<(), futures::channel::mpsc::SendError> { + self.out.send(message).await } - pub fn filesystem(&self) -> &F { + pub fn filesystem(&self) -> &Arc { &self.fs } - pub fn insert_torrent(&self, file: Metainfo, state: PieceCheckerState) -> bool { + pub fn insert_torrent( + &self, 
+ file: Metainfo, + state: &Arc>, + ) -> Result)> { let mut write_torrents = self .torrents .write() .expect("bip_disk: DiskManagerContext::insert_torrents Failed To Write Torrent"); let hash = file.info().info_hash(); - let hash_not_exists = !write_torrents.contains_key(&hash); - if hash_not_exists { - write_torrents.insert(hash, Mutex::new(MetainfoState::new(file, state))); - } + let entry = write_torrents.entry(hash); - hash_not_exists + match entry { + Entry::Occupied(key) => Err((hash, key.get().clone().into())), + Entry::Vacant(vac) => { + vac.insert(MetainfoState::new(file, state.clone())); + Ok(hash) + } + } } - pub fn update_torrent(&self, hash: InfoHash, call: C) -> bool + pub async fn update_torrent<'a, C, D>(self, hash: InfoHash, with_state: C) -> Option where - C: FnOnce(&Metainfo, &mut PieceCheckerState), + C: FnOnce(Arc, MetainfoState) -> BoxFuture<'a, D>, { - let read_torrents = self - .torrents - .read() - .expect("bip_disk: DiskManagerContext::update_torrent Failed To Read Torrent"); + let state = { + let read_torrents = self + .torrents + .read() + .expect("bip_disk: DiskManagerContext::update_torrent Failed To Read Torrent"); - match read_torrents.get(&hash) { - Some(state) => { - let mut lock_state = state - .lock() - .expect("bip_disk: DiskManagerContext::update_torrent Failed To Lock State"); - let deref_state = &mut *lock_state; + read_torrents.get(&hash)?.clone() + }; - call(&deref_state.file, &mut deref_state.state); - - true - } - None => false, - } + Some(with_state(self.fs.clone(), state.clone()).await) } pub fn remove_torrent(&self, hash: InfoHash) -> bool { @@ -92,16 +118,6 @@ impl DiskManagerContext { .write() .expect("bip_disk: DiskManagerContext::remove_torrent Failed To Write Torrent"); - write_torrents.remove(&hash).is_some_and(|_| true) - } -} - -impl Clone for DiskManagerContext { - fn clone(&self) -> DiskManagerContext { - DiskManagerContext { - torrents: self.torrents.clone(), - out: self.out.clone(), - fs: self.fs.clone(), - } + write_torrents.remove(&hash).is_some() } } diff --git a/packages/disk/src/disk/tasks/helpers/piece_accessor.rs b/packages/disk/src/disk/tasks/helpers/piece_accessor.rs index f81b78b83..57a461bad 100644 --- a/packages/disk/src/disk/tasks/helpers/piece_accessor.rs +++ b/packages/disk/src/disk/tasks/helpers/piece_accessor.rs @@ -1,25 +1,25 @@ -use std::{cmp, io}; - -use metainfo::Info; +use std::sync::Arc; use crate::disk::fs::FileSystem; +use crate::disk::tasks::context::MetainfoState; use crate::disk::tasks::helpers; use crate::memory::block::BlockMetadata; -pub struct PieceAccessor<'a, F> { - fs: F, - info_dict: &'a Info, +pub struct PieceAccessor { + fs: Arc, + state: MetainfoState, } -impl<'a, F> PieceAccessor<'a, F> +impl PieceAccessor where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { - pub fn new(fs: F, info_dict: &'a Info) -> PieceAccessor<'a, F> { - PieceAccessor { fs, info_dict } + pub fn new(fs: Arc, state: MetainfoState) -> PieceAccessor { + PieceAccessor { fs, state } } - pub fn read_piece(&self, piece_buffer: &mut [u8], message: &BlockMetadata) -> io::Result<()> { + pub fn read_piece(&self, piece_buffer: &mut [u8], message: &BlockMetadata) -> std::io::Result<()> { self.run_with_file_regions(message, |mut file, offset, begin, end| { let bytes_read = self.fs.read_file(&mut file, offset, &mut piece_buffer[begin..end])?; assert_eq!(bytes_read, end - begin); @@ -28,7 +28,7 @@ where }) } - pub fn write_piece(&self, piece_buffer: &[u8], message: &BlockMetadata) -> io::Result<()> { + pub fn 
write_piece(&self, piece_buffer: &[u8], message: &BlockMetadata) -> std::io::Result<()> { self.run_with_file_regions(message, |mut file, offset, begin, end| { let bytes_written = self.fs.write_file(&mut file, offset, &piece_buffer[begin..end])?; assert_eq!(bytes_written, end - begin); @@ -39,31 +39,29 @@ where /// Run the given closure with the file, the file offset, and the read/write buffer start (inclusive) and end (exclusive) indices. /// TODO: We do not detect when/if the file size changes after the initial file size check, so the returned number of - fn run_with_file_regions(&self, message: &BlockMetadata, mut callback: C) -> io::Result<()> + fn run_with_file_regions(&self, message: &BlockMetadata, mut callback: C) -> std::io::Result<()> where - C: FnMut(F::File, u64, usize, usize) -> io::Result<()>, + C: FnMut(F::File, u64, usize, usize) -> std::io::Result<()>, { - let piece_length = self.info_dict.piece_length(); - - let mut total_bytes_to_skip = (message.piece_index() * piece_length) + message.block_offset(); + let mut total_bytes_to_skip = (message.piece_index() * self.state.file.info().piece_length()) + message.block_offset(); let mut total_bytes_accessed = 0; let total_block_length = message.block_length() as u64; - for file in self.info_dict.files() { + for file in self.state.file.info().files() { let total_file_size = file.length(); let mut bytes_to_access = total_file_size; - let min_bytes_to_skip = cmp::min(total_bytes_to_skip, bytes_to_access); + let min_bytes_to_skip = std::cmp::min(total_bytes_to_skip, bytes_to_access); total_bytes_to_skip -= min_bytes_to_skip; bytes_to_access -= min_bytes_to_skip; if bytes_to_access > 0 && total_bytes_accessed < total_block_length { - let file_path = helpers::build_path(self.info_dict.directory(), file); + let file_path = helpers::build_path(self.state.file.info().directory(), file); let fs_file = self.fs.open_file(file_path)?; let total_max_bytes_to_access = total_block_length - total_bytes_accessed; - let actual_bytes_to_access = cmp::min(total_max_bytes_to_access, bytes_to_access); + let actual_bytes_to_access = std::cmp::min(total_max_bytes_to_access, bytes_to_access); let offset = total_file_size - bytes_to_access; #[allow(clippy::cast_possible_truncation)] diff --git a/packages/disk/src/disk/tasks/helpers/piece_checker.rs b/packages/disk/src/disk/tasks/helpers/piece_checker.rs index 8c1637c89..44f68ffdc 100644 --- a/packages/disk/src/disk/tasks/helpers/piece_checker.rs +++ b/packages/disk/src/disk/tasks/helpers/piece_checker.rs @@ -1,69 +1,76 @@ use std::collections::{HashMap, HashSet}; -use std::{cmp, io}; +use std::sync::Arc; -use metainfo::Info; +use futures::future::BoxFuture; +use futures::lock::Mutex; +use metainfo::{Info, Metainfo}; use util::bt::InfoHash; use crate::disk::fs::FileSystem; +use crate::disk::tasks::context::MetainfoState; use crate::disk::tasks::helpers; use crate::disk::tasks::helpers::piece_accessor::PieceAccessor; -use crate::error::{TorrentError, TorrentErrorKind, TorrentResult}; +use crate::error::{TorrentError, TorrentResult}; use crate::memory::block::BlockMetadata; /// Calculates hashes on existing files within the file system given and reports good/bad pieces. 
-pub struct PieceChecker<'a, F> { - fs: F, - info_dict: &'a Info, - checker_state: &'a mut PieceCheckerState, +pub struct PieceChecker { + fs: Arc, + state: MetainfoState, } -impl<'a, F> PieceChecker<'a, F> +impl<'a, F> PieceChecker where - F: FileSystem + 'a, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { /// Create the initial `PieceCheckerState` for the `PieceChecker`. - pub fn init_state(fs: F, info_dict: &'a Info) -> TorrentResult { + pub async fn init_state(fs: Arc, info_dict: Info) -> TorrentResult>> { let total_blocks = info_dict.pieces().count(); - let last_piece_size = last_piece_size(info_dict); + let last_piece_size = last_piece_size(&info_dict); - let mut checker_state = PieceCheckerState::new(total_blocks, last_piece_size); + let checker_state = Arc::new(Mutex::new(PieceCheckerState::new(total_blocks, last_piece_size))); + + let file = Metainfo::new(info_dict.clone()); + + let state = MetainfoState::new(file, checker_state.clone()); { - let mut piece_checker = PieceChecker::with_state(fs, info_dict, &mut checker_state); + let mut piece_checker = PieceChecker::with_state(fs, state); piece_checker.validate_files_sizes()?; - piece_checker.fill_checker_state(); - piece_checker.calculate_diff()?; + piece_checker.fill_checker_state().await; + piece_checker.calculate_diff().await?; } Ok(checker_state) } /// Create a new `PieceChecker` with the given state. - pub fn with_state(fs: F, info_dict: &'a Info, checker_state: &'a mut PieceCheckerState) -> PieceChecker<'a, F> { - PieceChecker { - fs, - info_dict, - checker_state, - } + pub fn with_state(fs: Arc, state: MetainfoState) -> PieceChecker { + PieceChecker { fs, state } } /// Calculate the diff of old to new good/bad pieces and store them in the piece checker state /// to be retrieved by the caller. - pub fn calculate_diff(self) -> io::Result<()> { - let piece_length = self.info_dict.piece_length(); + pub async fn calculate_diff(self) -> std::io::Result<()> { + let piece_length = self.state.file.info().piece_length(); // TODO: Use Block Allocator let mut piece_buffer = vec![0u8; piece_length.try_into().unwrap()]; - let info_dict = self.info_dict; - let piece_accessor = PieceAccessor::new(&self.fs, self.info_dict); + let piece_accessor = PieceAccessor::new(self.fs.clone(), self.state.clone()); - self.checker_state + self.state + .checker + .lock() + .await .run_with_whole_pieces(piece_length.try_into().unwrap(), |message| { piece_accessor.read_piece(&mut piece_buffer[..message.block_length()], message)?; let calculated_hash = InfoHash::from_bytes(&piece_buffer[..message.block_length()]); let expected_hash = InfoHash::from_hash( - info_dict + self.state + .file + .info() .pieces() .nth(message.piece_index().try_into().unwrap()) .expect("bip_peer: Piece Checker Failed To Retrieve Expected Hash"), @@ -80,15 +87,16 @@ where /// /// This is done once when a torrent file is added to see if we have any good pieces that /// the caller can use to skip (if the torrent was partially downloaded before). 
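To make the piece bookkeeping concrete, here is a small worked example of the arithmetic behind `fill_checker_state` below; the numbers are illustrative, and the modulo definition of the trailing piece is an assumption about the `last_piece_size` helper, which is not shown in this hunk.

    fn piece_arithmetic_example() {
        // Illustrative only: a 2.5 MiB torrent with 1 MiB pieces.
        let piece_length: u64 = 1_048_576;
        let total_bytes: u64 = 2_621_440;

        let full_pieces = total_bytes / piece_length; // 2 full pieces
        let last_piece_size = (total_bytes % piece_length) as usize; // assumed trailing size

        assert_eq!(full_pieces, 2);
        assert_eq!(last_piece_size, 524_288);

        // fill_checker_state() queues one pending block per full piece and, because
        // last_piece_size != 0 here, one extra pending block for the final partial piece.
    }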
- fn fill_checker_state(&mut self) { - let piece_length = self.info_dict.piece_length(); - let total_bytes: u64 = self.info_dict.files().map(metainfo::File::length).sum(); + async fn fill_checker_state(&mut self) { + let piece_length = self.state.file.info().piece_length(); + let total_bytes: u64 = self.state.file.info().files().map(metainfo::File::length).sum(); let full_pieces = total_bytes / piece_length; - let last_piece_size = last_piece_size(self.info_dict); + let last_piece_size = last_piece_size(self.state.file.info()); + let mut check_state = self.state.checker.lock().await; for piece_index in 0..full_pieces { - self.checker_state.add_pending_block(BlockMetadata::with_default_hash( + check_state.add_pending_block(BlockMetadata::with_default_hash( piece_index, 0, piece_length.try_into().unwrap(), @@ -96,8 +104,7 @@ where } if last_piece_size != 0 { - self.checker_state - .add_pending_block(BlockMetadata::with_default_hash(full_pieces, 0, last_piece_size)); + check_state.add_pending_block(BlockMetadata::with_default_hash(full_pieces, 0, last_piece_size)); } } @@ -108,8 +115,8 @@ where /// size, an error will be thrown as we do not want to overwrite and existing file that maybe just had the same /// name as a file in our dictionary. fn validate_files_sizes(&mut self) -> TorrentResult<()> { - for file in self.info_dict.files() { - let file_path = helpers::build_path(self.info_dict.directory(), file); + for file in self.state.file.info().files() { + let file_path = helpers::build_path(self.state.file.info().directory(), file); let expected_size = file.length(); self.fs @@ -128,11 +135,11 @@ where .write_file(&mut file, expected_size - 1, &[0]) .expect("bip_peer: Failed To Create File When Validating Sizes"); } else if !size_matches { - return Err(TorrentError::from_kind(TorrentErrorKind::ExistingFileSizeCheck { + return Err(TorrentError::ExistingFileSizeCheck { file_path, expected_size, actual_size, - })); + }); } Ok(()) @@ -190,12 +197,12 @@ impl PieceCheckerState { /// Run the given closures against `NewGood` and `NewBad` messages. Each of the messages will /// then either be dropped (`NewBad`) or converted to `OldGood` (`NewGood`). - pub fn run_with_diff(&mut self, mut callback: F) + pub async fn run_with_diff(&mut self, mut callback: F) where - F: FnMut(&PieceState), + F: FnMut(&PieceState) -> BoxFuture<'_, ()>, { for piece_state in self.new_states.drain(..) { - callback(&piece_state); + callback(&piece_state).await; self.old_states.insert(piece_state); } @@ -203,9 +210,9 @@ impl PieceCheckerState { /// Pass any pieces that have not been identified as `OldGood` into the callback which determines /// if the piece is good or bad so it can be marked as `NewGood` or `NewBad`. - fn run_with_whole_pieces(&mut self, piece_length: usize, mut callback: F) -> io::Result<()> + fn run_with_whole_pieces(&mut self, piece_length: usize, mut callback: F) -> std::io::Result<()> where - F: FnMut(&BlockMetadata) -> io::Result, + F: FnMut(&BlockMetadata) -> std::io::Result, { self.merge_pieces(); @@ -299,7 +306,7 @@ fn merge_piece_messages(message_a: &BlockMetadata, message_b: &BlockMetadata) -> // If start b falls between start and end a, then start a is where we start, and we end at the max of end a // or end b, then calculate the length from end minus start. Vice versa if a falls between start and end b. 
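// For example (illustrative numbers, not from this patch): block A covering piece bytes [0, 16384) and block B covering [8192, 24576) satisfy the first case, since start_b (8192) falls inside [start_a, end_a]; the merged block then runs from start_a (0) to max(end_a, end_b) (24576), for a merged length of 24576 bytes at offset 0.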
if start_b >= start_a && start_b <= end_a { - let end_to_take = cmp::max(end_a, end_b); + let end_to_take = std::cmp::max(end_a, end_b); let length = end_to_take - start_a; Some(BlockMetadata::new( @@ -309,7 +316,7 @@ fn merge_piece_messages(message_a: &BlockMetadata, message_b: &BlockMetadata) -> length.try_into().unwrap(), )) } else if start_a >= start_b && start_a <= end_b { - let end_to_take = cmp::max(end_a, end_b); + let end_to_take = std::cmp::max(end_a, end_b); let length = end_to_take - start_b; Some(BlockMetadata::new( diff --git a/packages/disk/src/disk/tasks/mod.rs b/packages/disk/src/disk/tasks/mod.rs index 546cfbc07..004a09e8d 100644 --- a/packages/disk/src/disk/tasks/mod.rs +++ b/packages/disk/src/disk/tasks/mod.rs @@ -1,7 +1,8 @@ -use futures::sink::Wait; -use futures::sync::mpsc::Sender; -use futures_cpupool::CpuPool; -use log::info; +use std::sync::Arc; + +use futures::channel::mpsc; +use futures::lock::Mutex; +use futures::{FutureExt, SinkExt as _}; use metainfo::Metainfo; use util::bt::InfoHash; @@ -10,200 +11,218 @@ use crate::disk::tasks::context::DiskManagerContext; use crate::disk::tasks::helpers::piece_accessor::PieceAccessor; use crate::disk::tasks::helpers::piece_checker::{PieceChecker, PieceCheckerState, PieceState}; use crate::disk::{IDiskMessage, ODiskMessage}; -use crate::error::{BlockError, BlockErrorKind, BlockResult, TorrentError, TorrentErrorKind, TorrentResult}; +use crate::error::{BlockError, BlockResult, TorrentError, TorrentResult}; use crate::memory::block::{Block, BlockMut}; pub mod context; mod helpers; -pub fn execute_on_pool(msg: IDiskMessage, pool: &CpuPool, context: DiskManagerContext) +pub async fn execute(msg: IDiskMessage, context: DiskManagerContext) where - F: FileSystem + Send + Sync + 'static, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { - pool.spawn_fn(move || { - let mut blocking_sender = context.blocking_sender(); + let mut sender = context.out.clone(); - let out_msg = match msg { - IDiskMessage::AddTorrent(metainfo) => { - let info_hash = metainfo.info().info_hash(); + let out_msg = match msg { + IDiskMessage::AddTorrent(metainfo) => { + let info_hash = metainfo.info().info_hash(); - match execute_add_torrent(metainfo, &context, &mut blocking_sender) { - Ok(()) => ODiskMessage::TorrentAdded(info_hash), - Err(err) => ODiskMessage::TorrentError(info_hash, err), - } + match execute_add_torrent(metainfo, context, sender.clone()).await { + Ok(()) => ODiskMessage::TorrentAdded(info_hash), + Err(err) => ODiskMessage::TorrentError(info_hash, err), } - IDiskMessage::RemoveTorrent(hash) => match execute_remove_torrent(hash, &context) { - Ok(()) => ODiskMessage::TorrentRemoved(hash), - Err(err) => ODiskMessage::TorrentError(hash, err), - }, - IDiskMessage::SyncTorrent(hash) => match execute_sync_torrent(hash, &context) { - Ok(()) => ODiskMessage::TorrentSynced(hash), - Err(err) => ODiskMessage::TorrentError(hash, err), - }, - IDiskMessage::LoadBlock(mut block) => match execute_load_block(&mut block, &context) { - Ok(()) => ODiskMessage::BlockLoaded(block), - Err(err) => ODiskMessage::LoadBlockError(block, err), - }, - IDiskMessage::ProcessBlock(block) => match execute_process_block(&block, &context, &mut blocking_sender) { - Ok(()) => ODiskMessage::BlockProcessed(block), - Err(err) => ODiskMessage::ProcessBlockError(block, err), - }, - }; - - blocking_sender - .send(out_msg) - .expect("bip_disk: Failed To Send Out Message In execute_on_pool"); - blocking_sender - .flush() - .expect("bip_disk: Failed to Flush Out Messages In 
execute_on_pool"); - - Ok::<(), ()>(()) - }) - .forget(); + } + IDiskMessage::RemoveTorrent(hash) => match execute_remove_torrent(hash, &context) { + Ok(()) => ODiskMessage::TorrentRemoved(hash), + Err(err) => ODiskMessage::TorrentError(hash, err), + }, + IDiskMessage::SyncTorrent(hash) => match execute_sync_torrent(hash, context).await { + Ok(()) => ODiskMessage::TorrentSynced(hash), + Err(err) => ODiskMessage::TorrentError(hash, err), + }, + IDiskMessage::LoadBlock(mut block) => match execute_load_block(&mut block, context).await { + Ok(()) => ODiskMessage::BlockLoaded(block), + Err(err) => ODiskMessage::LoadBlockError(block, err), + }, + IDiskMessage::ProcessBlock(block) => match execute_process_block(&block, context, sender.clone()).await { + Ok(()) => ODiskMessage::BlockProcessed(block), + Err(err) => ODiskMessage::ProcessBlockError(block, err), + }, + }; + + tracing::trace!("sending output disk message: {out_msg:?}"); + + sender + .send(out_msg) + .await + .expect("bip_disk: Failed To Send Out Message In execute_on_pool"); + + tracing::debug!("finished sending output message... "); } -fn execute_add_torrent( +async fn execute_add_torrent( file: Metainfo, - context: &DiskManagerContext, - blocking_sender: &mut Wait>, + context: DiskManagerContext, + sender: mpsc::Sender, ) -> TorrentResult<()> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { let info_hash = file.info().info_hash(); - let mut init_state = PieceChecker::init_state(context.filesystem(), file.info())?; + let init_state = PieceChecker::init_state(context.filesystem().clone(), file.info().clone()).await?; // In case we are resuming a download, we need to send the diff for the newly added torrent - send_piece_diff(&mut init_state, info_hash, blocking_sender, true); + send_piece_diff(&init_state, info_hash, sender, true).await; - if context.insert_torrent(file, init_state) { - Ok(()) - } else { - Err(TorrentError::from_kind(TorrentErrorKind::ExistingInfoHash { - hash: info_hash, - })) + match context.insert_torrent(file, &init_state) { + Ok(_) => Ok(()), + Err((hash, _)) => Err(TorrentError::ExistingInfoHash { hash }), } } fn execute_remove_torrent(hash: InfoHash, context: &DiskManagerContext) -> TorrentResult<()> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { if context.remove_torrent(hash) { Ok(()) } else { - Err(TorrentError::from_kind(TorrentErrorKind::InfoHashNotFound { hash })) + Err(TorrentError::InfoHashNotFound { hash }) } } -fn execute_sync_torrent(hash: InfoHash, context: &DiskManagerContext) -> TorrentResult<()> +async fn execute_sync_torrent(hash: InfoHash, context: DiskManagerContext) -> TorrentResult<()> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { - let filesystem = context.filesystem(); + let filesystem = context.filesystem().clone(); - let mut sync_result = Ok(()); - let found_hash = context.update_torrent(hash, |metainfo_file, _| { - let opt_parent_dir = metainfo_file.info().directory(); + let sync_result = context + .update_torrent(hash, |_, state| { + let opt_parent_dir = state.file.info().directory(); - for file in metainfo_file.info().files() { - let path = helpers::build_path(opt_parent_dir, file); + for file in state.file.info().files() { + let path = helpers::build_path(opt_parent_dir, file); - sync_result = filesystem.sync_file(path); - } - }); + match filesystem.sync_file(path) { + Ok(()) => continue, + Err(e) => return std::future::ready(Err(e)).boxed(), + } + } - if found_hash { - Ok(sync_result?) 
- } else { - Err(TorrentError::from_kind(TorrentErrorKind::InfoHashNotFound { hash })) + std::future::ready(Ok(())).boxed() + }) + .await; + + match sync_result { + Some(result) => Ok(result?), + None => Err(TorrentError::InfoHashNotFound { hash }), } } -fn execute_load_block(block: &mut BlockMut, context: &DiskManagerContext) -> BlockResult<()> +async fn execute_load_block(block: &mut BlockMut, context: DiskManagerContext) -> BlockResult<()> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { let metadata = block.metadata(); let info_hash = metadata.info_hash(); + let context = context.clone(); - let mut access_result = Ok(()); - let found_hash = context.update_torrent(info_hash, |metainfo_file, _| { - let piece_accessor = PieceAccessor::new(context.filesystem(), metainfo_file.info()); + let access_result = context + .update_torrent(info_hash, |fs, state| { + async move { + let piece_accessor = PieceAccessor::new(fs, state); - // Read The Piece In From The Filesystem - access_result = piece_accessor.read_piece(&mut *block, &metadata); - }); + // Read The Piece In From The Filesystem; + piece_accessor.read_piece(&mut *block, &metadata) + } + .boxed() + }) + .await; - if found_hash { - Ok(access_result?) - } else { - Err(BlockError::from_kind(BlockErrorKind::InfoHashNotFound { hash: info_hash })) + match access_result { + Some(result) => Ok(result?), + None => Err(BlockError::InfoHashNotFound { hash: info_hash }), } } -fn execute_process_block( +async fn execute_process_block( block: &Block, - context: &DiskManagerContext, - blocking_sender: &mut Wait>, + context: DiskManagerContext, + sender: mpsc::Sender, ) -> BlockResult<()> where - F: FileSystem, + F: FileSystem + Sync + 'static, + Arc: Send + Sync, { let metadata = block.metadata(); let info_hash = metadata.info_hash(); - let mut block_result = Ok(()); - let found_hash = context.update_torrent(info_hash, |metainfo_file, checker_state| { - info!( - "Processing Block, Acquired Torrent Lock For {:?}", - metainfo_file.info().info_hash() - ); + let block_result = context + .update_torrent(info_hash, |fs, state| { + tracing::trace!("Updating Blocks for Torrent: {info_hash}"); - let piece_accessor = PieceAccessor::new(context.filesystem(), metainfo_file.info()); + async move { + let piece_accessor = PieceAccessor::new(fs.clone(), state.clone()); - // Write Out Piece Out To The Filesystem And Recalculate The Diff - block_result = piece_accessor.write_piece(block, &metadata).and_then(|()| { - checker_state.add_pending_block(metadata); + // Write Out Piece Out To The Filesystem And Recalculate The Diff + let block_result = match piece_accessor.write_piece(block, &metadata) { + Ok(()) => { + state.checker.lock().await.add_pending_block(metadata); - PieceChecker::with_state(context.filesystem(), metainfo_file.info(), checker_state).calculate_diff() - }); + PieceChecker::with_state(fs, state.clone()).calculate_diff().await + } + Err(e) => Err(e), + }; - send_piece_diff(checker_state, metainfo_file.info().info_hash(), blocking_sender, false); + send_piece_diff(&state.checker, state.file.info().info_hash(), sender.clone(), false).await; - info!( - "Processing Block, Released Torrent Lock For {:?}", - metainfo_file.info().info_hash() - ); - }); + block_result + } + .boxed() + }) + .await; - if found_hash { - Ok(block_result?) 
- } else { - Err(BlockError::from_kind(BlockErrorKind::InfoHashNotFound { hash: info_hash })) + tracing::debug!("Finished Updating Torrent: {info_hash}, {block_result:?}"); + + match block_result { + Some(result) => Ok(result?), + None => Err(BlockError::InfoHashNotFound { hash: info_hash }), } } -fn send_piece_diff( - checker_state: &mut PieceCheckerState, +async fn send_piece_diff( + checker_state: &Arc>, hash: InfoHash, - blocking_sender: &mut Wait>, + sender: mpsc::Sender, ignore_bad: bool, ) { - checker_state.run_with_diff(|piece_state| { - let opt_out_msg = match (piece_state, ignore_bad) { - (&PieceState::Good(index), _) => Some(ODiskMessage::FoundGoodPiece(hash, index)), - (&PieceState::Bad(index), false) => Some(ODiskMessage::FoundBadPiece(hash, index)), - (&PieceState::Bad(_), true) => None, - }; - - if let Some(out_msg) = opt_out_msg { - blocking_sender - .send(out_msg) - .expect("bip_disk: Failed To Send Piece State Message"); - blocking_sender - .flush() - .expect("bip_disk: Failed To Flush Piece State Message"); - } - }); + checker_state + .lock() + .await + .run_with_diff(|piece_state| { + let mut sender = sender.clone(); + + async move { + let opt_out_msg = match (piece_state, ignore_bad) { + (&PieceState::Good(index), _) => Some(ODiskMessage::FoundGoodPiece(hash, index)), + (&PieceState::Bad(index), false) => Some(ODiskMessage::FoundBadPiece(hash, index)), + (&PieceState::Bad(_), true) => None, + }; + + if let Some(out_msg) = opt_out_msg { + sender + .send(out_msg) + .await + .expect("bip_disk: Failed To Send Piece State Message"); + } + } + .boxed() + }) + .await; } diff --git a/packages/disk/src/error.rs b/packages/disk/src/error.rs index 96853b09c..799ae0284 100644 --- a/packages/disk/src/error.rs +++ b/packages/disk/src/error.rs @@ -1,58 +1,41 @@ -use std::io; use std::path::PathBuf; -use error_chain::error_chain; +use thiserror::Error; use util::bt::InfoHash; -error_chain! { - types { - BlockError, BlockErrorKind, BlockResultExt, BlockResult; - } - - foreign_links { - Io(io::Error); - } - - errors { - InfoHashNotFound { - hash: InfoHash - } { - description("Failed To Load/Process Block Because Torrent Is Not Loaded") - display("Failed To Load/Process Block Because The InfoHash {:?} It Is Not Currently Added", hash) - } - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum BlockError { + #[error("IO error")] + Io(#[from] std::io::Error), + + #[error("Failed To Load/Process Block Because The InfoHash {hash:?} Is Not Currently Added")] + InfoHashNotFound { hash: InfoHash }, } -error_chain! 
{ - types { - TorrentError, TorrentErrorKind, TorrentResultExt, TorrentResult; - } - - foreign_links { - Block(BlockError); - Io(io::Error); - } - - errors { - ExistingFileSizeCheck { - file_path: PathBuf, - expected_size: u64, - actual_size: u64 - } { - description("Failed To Add Torrent Because Size Checker Failed For A File") - display("Failed To Add Torrent Because Size Checker Failed For {:?} Where File Size Was {} But Should Have Been {}", file_path, actual_size, expected_size) - } - ExistingInfoHash { - hash: InfoHash - } { - description("Failed To Add Torrent Because Another Torrent With The Same InfoHash Is Already Added") - display("Failed To Add Torrent Because Another Torrent With The Same InfoHash {:?} Is Already Added", hash) - } - InfoHashNotFound { - hash: InfoHash - } { - description("Failed To Remove Torrent Because It Is Not Currently Added") - display("Failed To Remove Torrent Because The InfoHash {:?} It Is Not Currently Added", hash) - } - } +pub type BlockResult = Result; + +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum TorrentError { + #[error("Block error")] + Block(#[from] BlockError), + + #[error("IO error")] + Io(#[from] std::io::Error), + + #[error("Failed To Add Torrent Because Size Checker Failed For {file_path:?} Where File Size Was {actual_size} But Should Have Been {expected_size}")] + ExistingFileSizeCheck { + file_path: PathBuf, + expected_size: u64, + actual_size: u64, + }, + + #[error("Failed To Add Torrent Because Another Torrent With The Same InfoHash {hash:?} Is Already Added")] + ExistingInfoHash { hash: InfoHash }, + + #[error("Failed To Remove Torrent Because The InfoHash {hash:?} Is Not Currently Added")] + InfoHashNotFound { hash: InfoHash }, } + +pub type TorrentResult = Result; diff --git a/packages/disk/src/lib.rs b/packages/disk/src/lib.rs index a4a4ab909..80fc3c2dc 100644 --- a/packages/disk/src/lib.rs +++ b/packages/disk/src/lib.rs @@ -6,9 +6,7 @@ pub mod error; pub use crate::disk::fs::FileSystem; pub use crate::disk::manager::builder::DiskManagerBuilder; -pub use crate::disk::manager::sink::DiskManagerSink; -pub use crate::disk::manager::stream::DiskManagerStream; -pub use crate::disk::manager::DiskManager; +pub use crate::disk::manager::{DiskManager, DiskManagerSink, DiskManagerStream}; pub use crate::disk::{IDiskMessage, ODiskMessage}; pub use crate::memory::block::{Block, BlockMetadata, BlockMut}; diff --git a/packages/disk/src/memory/block.rs b/packages/disk/src/memory/block.rs index b57bc0efa..22d418354 100644 --- a/packages/disk/src/memory/block.rs +++ b/packages/disk/src/memory/block.rs @@ -61,7 +61,7 @@ impl Default for BlockMetadata { //----------------------------------------------------------------------------// /// `Block` of immutable memory. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Block { metadata: BlockMetadata, block_data: Bytes, @@ -101,7 +101,7 @@ impl Deref for Block { /// `BlockMut` of mutable memory. #[allow(clippy::module_name_repetitions)] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BlockMut { metadata: BlockMetadata, block_data: BytesMut, @@ -109,15 +109,18 @@ pub struct BlockMut { impl BlockMut { /// Create a new `BlockMut`. + #[must_use] pub fn new(metadata: BlockMetadata, block_data: BytesMut) -> BlockMut { BlockMut { metadata, block_data } } /// Access the metadata for the block. 
+ #[must_use] pub fn metadata(&self) -> BlockMetadata { self.metadata } + #[must_use] pub fn into_parts(self) -> (BlockMetadata, BytesMut) { (self.metadata, self.block_data) } diff --git a/packages/disk/tests/add_torrent.rs b/packages/disk/tests/add_torrent.rs index 10755c320..72ad5d6f8 100644 --- a/packages/disk/tests/add_torrent.rs +++ b/packages/disk/tests/add_torrent.rs @@ -1,15 +1,21 @@ -use common::{core_loop_with_timeout, random_buffer, InMemoryFileSystem, MultiFileDirectAccessor}; -use disk::{DiskManagerBuilder, FileSystem, IDiskMessage, ODiskMessage}; -use futures::future::{Future, Loop}; -use futures::sink::Sink; -use futures::stream::Stream; +use common::{ + random_buffer, runtime_loop_with_timeout, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, DEFAULT_TIMEOUT, + INIT, +}; +use disk::{DiskManagerBuilder, FileSystem as _, IDiskMessage, ODiskMessage}; +use futures::future::{self, Either}; +use futures::{FutureExt, SinkExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tracing::level_filters::LevelFilter; mod common; -#[test] -fn positive_add_torrent() { +#[tokio::test] +async fn positive_add_torrent() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(50), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); @@ -26,20 +32,19 @@ fn positive_add_torrent() { // Spin up a disk manager and add our created torrent to it let filesystem = InMemoryFileSystem::new(); - let disk_manager = DiskManagerBuilder::new().build(filesystem.clone()); + let disk_manager = DiskManagerBuilder::new().build(filesystem.me()); - let (send, recv) = disk_manager.split(); - send.send(IDiskMessage::AddTorrent(metainfo_file)).wait().unwrap(); + let (mut send, recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file)).await.unwrap(); // Verify that zero pieces are marked as good - let mut core = Core::new().unwrap(); - - // Run a core loop until we get the TorrentAdded message - let good_pieces = core_loop_with_timeout(&mut core, 500, (0, recv), |good_pieces, recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => Loop::Break(good_pieces), - ODiskMessage::FoundGoodPiece(_, _) => Loop::Continue((good_pieces + 1, recv)), + // Run a runtime loop until we get the TorrentAdded message + let good_pieces = runtime_loop_with_timeout(DEFAULT_TIMEOUT, (0, recv), |good_pieces, recv, msg| match msg { + Ok(ODiskMessage::TorrentAdded(_)) => Either::Left(future::ready(good_pieces).boxed()), + Ok(ODiskMessage::FoundGoodPiece(_, _)) => Either::Right(future::ready((good_pieces + 1, recv)).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; assert_eq!(0, good_pieces); diff --git a/packages/disk/tests/common/mod.rs b/packages/disk/tests/common/mod.rs index 017dea4c8..2db3d0668 100644 --- a/packages/disk/tests/common/mod.rs +++ b/packages/disk/tests/common/mod.rs @@ -1,64 +1,84 @@ -use std::cmp; use std::collections::HashMap; -use std::io::{self}; use std::path::{Path, PathBuf}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, Once, Weak}; use std::time::Duration; use bytes::BytesMut; use disk::{BlockMetadata, BlockMut, FileSystem, IDiskMessage}; -use futures::future::{self, Future, Loop}; -use futures::sink::{Sink, Wait}; +use futures::future::BoxFuture; use futures::stream::Stream; +use futures::{future, Sink, SinkExt as _, StreamExt as _}; use 
metainfo::{Accessor, IntoAccessor, PieceAccess}; -use tokio_core::reactor::{Core, Timeout}; +use tokio::time::timeout; +use tracing::level_filters::LevelFilter; use util::bt::InfoHash; +#[allow(dead_code)] +pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(500); + +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + /// Generate buffer of size random bytes. pub fn random_buffer(size: usize) -> Vec { let mut buffer = vec![0u8; size]; - rand::Rng::fill(&mut rand::thread_rng(), buffer.as_mut_slice()); - buffer } -/// Initiate a core loop with the given timeout, state, and closure. +/// Initiate a runtime loop with the given timeout, state, and closure. /// /// Returns R or panics if an error occurred in the loop (including a timeout). #[allow(dead_code)] -pub fn core_loop_with_timeout(core: &mut Core, timeout_ms: u64, state: (I, S), call: F) -> R +pub async fn runtime_loop_with_timeout<'a, 'b, I, S, F, R>(timeout_time: Duration, initial_state: (I, S), mut call: F) -> R where - F: FnMut(I, S, S::Item) -> Loop, - S: Stream, + F: FnMut(I, S, S::Item) -> future::Either, BoxFuture<'b, (I, S)>>, + S: Stream + Unpin, + R: 'static, + I: std::fmt::Debug + Clone, { - let timeout = Timeout::new(Duration::from_millis(timeout_ms), &core.handle()) - .unwrap() - .then(|_| Err(())); - - // Have to stick the call in our init state so that we transfer ownership between loops - core.run( - future::loop_fn((call, state), |(mut call, (init, stream))| { - stream.into_future().map(|(opt_msg, stream)| { - let msg = opt_msg.unwrap_or_else(|| panic!("End Of Stream Reached")); - - match call(init, stream, msg) { - Loop::Continue((init, stream)) => Loop::Continue((call, (init, stream))), - Loop::Break(ret) => Loop::Break(ret), + let mut state = initial_state; + loop { + let (init, mut stream) = state; + if let Some(msg) = { + timeout(timeout_time, stream.next()) + .await + .unwrap_or_else(|_| panic!("timeout while waiting for next message: {timeout_time:?}, {init:?}")) + } { + match call(init.clone(), stream, msg) { + future::Either::Left(fut) => { + return timeout(timeout_time, fut) + .await + .unwrap_or_else(|_| panic!("timeout waiting for final processing: {timeout_time:?}, {init:?}")); } - }) - }) - .map_err(|_| ()) - .select(timeout) - .map(|(item, _)| item), - ) - .unwrap_or_else(|_| panic!("Core Loop Timed Out")) + future::Either::Right(fut) => { + state = timeout(timeout_time, fut) + .await + .unwrap_or_else(|_| panic!("timeout waiting for next loop state: {timeout_time:?}, {init:?}")); + } + } + } else { + panic!("End Of Stream Reached"); + } + } } /// Send block with the given metadata and entire data given. 
#[allow(dead_code)] -pub fn send_block( - blocking_send: &mut Wait, +pub async fn send_block( + sink: &mut S, data: &[u8], hash: InfoHash, piece_index: u64, @@ -66,9 +86,14 @@ pub fn send_block( block_len: usize, modify: M, ) where - S: Sink, + S: Sink + Unpin, M: Fn(&mut [u8]), + >::Error: std::fmt::Display, { + tracing::trace!( + "sending block for torrent: {hash}, index: {piece_index}, block_offset: {block_offset}, block_length: {block_len}" + ); + let mut bytes = BytesMut::new(); bytes.extend_from_slice(data); @@ -76,9 +101,9 @@ pub fn send_block( modify(&mut block[..]); - blocking_send - .send(IDiskMessage::ProcessBlock(block.into())) - .unwrap_or_else(|_| panic!("Failed To Send Process Block Message")); + sink.send(IDiskMessage::ProcessBlock(block.into())) + .await + .unwrap_or_else(|e| panic!("Failed To Send Process Block Message: {e}")); } //----------------------------------------------------------------------------// @@ -99,7 +124,7 @@ impl MultiFileDirectAccessor { impl IntoAccessor for MultiFileDirectAccessor { type Accessor = MultiFileDirectAccessor; - fn into_accessor(self) -> io::Result { + fn into_accessor(self) -> std::io::Result { Ok(self) } } @@ -111,7 +136,7 @@ impl Accessor for MultiFileDirectAccessor { Some(self.dir.as_ref()) } - fn access_metadata(&self, mut callback: C) -> io::Result<()> + fn access_metadata(&self, mut callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { @@ -122,9 +147,9 @@ impl Accessor for MultiFileDirectAccessor { Ok(()) } - fn access_pieces(&self, mut callback: C) -> io::Result<()> + fn access_pieces(&self, mut callback: C) -> std::io::Result<()> where - C: for<'a> FnMut(PieceAccess<'a>) -> io::Result<()>, + C: for<'a> FnMut(PieceAccess<'a>) -> std::io::Result<()>, { for (buffer, _) in &self.files { callback(PieceAccess::Compute(&mut &buffer[..]))?; @@ -137,16 +162,24 @@ impl Accessor for MultiFileDirectAccessor { //----------------------------------------------------------------------------// /// Allow us to mock out the file system. -#[derive(Clone)] +#[derive(Debug)] pub struct InMemoryFileSystem { - files: Arc>>>, + #[allow(dead_code)] + me: Weak, + files: Mutex>>, } impl InMemoryFileSystem { - pub fn new() -> InMemoryFileSystem { - InMemoryFileSystem { - files: Arc::new(Mutex::new(HashMap::new())), - } + pub fn new() -> Arc { + Arc::new_cyclic(|me| Self { + me: me.clone(), + files: Mutex::default(), + }) + } + + #[allow(dead_code)] + pub fn me(&self) -> Arc { + self.me.upgrade().unwrap() } pub fn run_with_lock(&self, call: C) -> R @@ -166,7 +199,7 @@ pub struct InMemoryFile { impl FileSystem for InMemoryFileSystem { type File = InMemoryFile; - fn open_file

<P>(&self, path: P) -> io::Result<Self::File> + fn open_file<P>

(&self, path: P) -> std::io::Result<Self::File> where P: AsRef<Path> + Send + 'static, { @@ -181,40 +214,40 @@ impl FileSystem for InMemoryFileSystem { Ok(InMemoryFile { path: file_path }) } - fn sync_file<P>

(&self, _path: P) -> io::Result<()> + fn sync_file<P>
(&self, _path: P) -> std::io::Result<()> where P: AsRef + Send + 'static, { Ok(()) } - fn file_size(&self, file: &Self::File) -> io::Result { + fn file_size(&self, file: &Self::File) -> std::io::Result { self.run_with_lock(|files| { files .get(&file.path) .map(|file| file.len() as u64) - .ok_or(io::Error::new(io::ErrorKind::NotFound, "File Not Found")) + .ok_or(std::io::Error::new(std::io::ErrorKind::NotFound, "File Not Found")) }) } - fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> io::Result { + fn read_file(&self, file: &mut Self::File, offset: u64, buffer: &mut [u8]) -> std::io::Result { self.run_with_lock(|files| { files .get(&file.path) .map(|file_buffer| { let cast_offset: usize = offset.try_into().unwrap(); - let bytes_to_copy = cmp::min(file_buffer.len() - cast_offset, buffer.len()); + let bytes_to_copy = std::cmp::min(file_buffer.len() - cast_offset, buffer.len()); let bytes = &file_buffer[cast_offset..(bytes_to_copy + cast_offset)]; buffer.clone_from_slice(bytes); bytes_to_copy }) - .ok_or(io::Error::new(io::ErrorKind::NotFound, "File Not Found")) + .ok_or(std::io::Error::new(std::io::ErrorKind::NotFound, "File Not Found")) }) } - fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> io::Result { + fn write_file(&self, file: &mut Self::File, offset: u64, buffer: &[u8]) -> std::io::Result { self.run_with_lock(|files| { files .get_mut(&file.path) @@ -226,16 +259,16 @@ impl FileSystem for InMemoryFileSystem { file_buffer.resize(last_byte_pos, 0); } - let bytes_to_copy = cmp::min(file_buffer.len() - cast_offset, buffer.len()); + let bytes_to_copy = std::cmp::min(file_buffer.len() - cast_offset, buffer.len()); if bytes_to_copy != 0 { file_buffer[cast_offset..(cast_offset + bytes_to_copy)].clone_from_slice(buffer); } - // TODO: If the file is full, this will return zero, we should also simulate io::ErrorKind::WriteZero + // TODO: If the file is full, this will return zero, we should also simulate std::io::ErrorKind::WriteZero bytes_to_copy }) - .ok_or(io::Error::new(io::ErrorKind::NotFound, "File Not Found")) + .ok_or(std::io::Error::new(std::io::ErrorKind::NotFound, "File Not Found")) }) } } diff --git a/packages/disk/tests/complete_torrent.rs b/packages/disk/tests/complete_torrent.rs index 27a38ab4b..d7ffcf172 100644 --- a/packages/disk/tests/complete_torrent.rs +++ b/packages/disk/tests/complete_torrent.rs @@ -1,16 +1,27 @@ -use common::{core_loop_with_timeout, random_buffer, send_block, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{ + random_buffer, runtime_loop_with_timeout, send_block, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, + DEFAULT_TIMEOUT, INIT, +}; use disk::{DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::future::Loop; -use futures::sink::Sink; -use futures::stream::Stream; +use futures::future::{self, Either}; +use futures::{FutureExt, SinkExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tokio::task::JoinSet; +use tracing::level_filters::LevelFilter; mod common; +#[allow(unused_variables)] +#[allow(unreachable_code)] #[allow(clippy::too_many_lines)] -#[test] -fn positive_complete_torrent() { +#[tokio::test] +async fn positive_complete_torrent() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let mut tasks = JoinSet::new(); + // Create some "files" as random bytes let data_a = (random_buffer(1023), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); 
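
The reworked InMemoryFileSystem test double above hands out owning handles through a stored weak self-reference (the `me` field plus the `me()` accessor) built with `Arc::new_cyclic`, which is what allows `DiskManagerBuilder::new().build(filesystem.me())` to take an owned `Arc` without threading an extra handle through the test. A minimal, self-contained sketch of that pattern, using a placeholder `SelfHandle` type rather than anything from this patch, might look like:

    use std::sync::{Arc, Weak};

    // Placeholder type: stands in for the test filesystem, which keeps a
    // weak pointer back to itself so it can mint more owning handles later.
    #[derive(Debug)]
    struct SelfHandle {
        me: Weak<SelfHandle>,
    }

    impl SelfHandle {
        fn new() -> Arc<SelfHandle> {
            // `Arc::new_cyclic` hands us the Weak before the value exists,
            // so the struct can capture a pointer to its own future Arc.
            Arc::new_cyclic(|me| SelfHandle { me: me.clone() })
        }

        fn me(&self) -> Arc<SelfHandle> {
            // Upgrading cannot fail while the caller still holds an Arc.
            self.me.upgrade().expect("self handle still alive")
        }
    }

    fn main() {
        let fs = SelfHandle::new();
        let extra_handle = fs.me();
        assert_eq!(Arc::strong_count(&fs), 2);
        drop(extra_handle);
    }

Because the self-pointer is weak, this does not create a reference cycle: once every external `Arc` is dropped, the value is freed as usual.
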
@@ -27,121 +38,71 @@ fn positive_complete_torrent() { let filesystem = InMemoryFileSystem::new(); let disk_manager = DiskManagerBuilder::new().build(filesystem.clone()); - let (send, recv) = disk_manager.split(); - let mut blocking_send = send.wait(); - blocking_send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).unwrap(); + let (mut send, recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).await.unwrap(); // Verify that zero pieces are marked as good - let mut core = Core::new().unwrap(); - - // Run a core loop until we get the TorrentAdded message - let (good_pieces, recv) = core_loop_with_timeout(&mut core, 500, (0, recv), |good_pieces, recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => Loop::Break((good_pieces, recv)), - ODiskMessage::FoundGoodPiece(_, _) => Loop::Continue((good_pieces + 1, recv)), + // Run a runtime loop until we get the TorrentAdded message + let (good_pieces, recv) = runtime_loop_with_timeout(DEFAULT_TIMEOUT, (0, recv), |good_pieces, recv, msg| match msg { + Ok(ODiskMessage::TorrentAdded(_)) => Either::Left(future::ready((good_pieces, recv)).boxed()), + Ok(ODiskMessage::FoundGoodPiece(_, _)) => Either::Right(future::ready((good_pieces + 1, recv)).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; // Make sure we have no good pieces assert_eq!(0, good_pieces); - // Send a couple blocks that are known to be good, then one bad block - let mut files_bytes = Vec::new(); - files_bytes.extend_from_slice(&data_a.0); - files_bytes.extend_from_slice(&data_b.0); - - // Send piece 0 with a bad last block - send_block( - &mut blocking_send, - &files_bytes[0..500], - metainfo_file.info().info_hash(), - 0, - 0, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[500..1000], - metainfo_file.info().info_hash(), - 0, - 500, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[1000..1024], - metainfo_file.info().info_hash(), - 0, - 1000, - 24, - |bytes| { - bytes[0] = !bytes[0]; - }, - ); - - // Send piece 1 with good blocks - crate::send_block( - &mut blocking_send, - &files_bytes[1024..(1024 + 500)], - metainfo_file.info().info_hash(), - 1, - 0, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[(1024 + 500)..(1024 + 1000)], - metainfo_file.info().info_hash(), - 1, - 500, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[(1024 + 1000)..(1024 + 1024)], - metainfo_file.info().info_hash(), - 1, - 1000, - 24, - |_| (), - ); - - // Send piece 2 with good blocks - crate::send_block( - &mut blocking_send, - &files_bytes[2048..(2048 + 500)], - metainfo_file.info().info_hash(), - 2, - 0, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[(2048 + 500)..(2048 + 975)], - metainfo_file.info().info_hash(), - 2, - 500, - 475, - |_| (), - ); - - // Verify that piece 0 is bad, but piece 1 and 2 are good - let (recv, piece_zero_good, piece_one_good, piece_two_good) = crate::core_loop_with_timeout( - &mut core, - 500, + let files_bytes = { + let mut b = Vec::new(); + b.extend_from_slice(&data_a.0); + b.extend_from_slice(&data_b.0); + b + }; + + tracing::debug!("send two blocks that are known to be good, then one bad block"); + + let send_one_bad_and_two_good = { + let mut send = send.clone(); + let data = files_bytes.clone(); + let info_hash = metainfo_file.info().info_hash(); + + async move { + // Send piece 0 with a bad last block + send_block(&mut send, 
&data[0..500], info_hash, 0, 0, 500, |_| ()).await; + send_block(&mut send, &data[500..1000], info_hash, 0, 500, 500, |_| ()).await; + send_block(&mut send, &data[1000..1024], info_hash, 0, 1000, 24, |bytes| { + bytes[0] = !bytes[0]; + }) + .await; + + // Send piece 1 with good blocks + send_block(&mut send, &data[1024..(1024 + 500)], info_hash, 1, 0, 500, |_| ()).await; + send_block(&mut send, &data[(1024 + 500)..(1024 + 1000)], info_hash, 1, 500, 500, |_| ()).await; + send_block(&mut send, &data[(1024 + 1000)..(1024 + 1024)], info_hash, 1, 1000, 24, |_| ()).await; + + // Send piece 2 with good blocks + send_block(&mut send, &data[2048..(2048 + 500)], info_hash, 2, 0, 500, |_| ()).await; + send_block(&mut send, &data[(2048 + 500)..(2048 + 975)], info_hash, 2, 500, 475, |_| ()).await; + } + .boxed() + }; + + tasks.spawn(send_one_bad_and_two_good); + + tracing::debug!("verify that piece 0 is bad, but piece 1 and 2 are good"); + + let (recv, piece_zero_good, piece_one_good, piece_two_good) = runtime_loop_with_timeout( + DEFAULT_TIMEOUT, ((false, false, false, 0), recv), |(piece_zero_good, piece_one_good, piece_two_good, messages_recvd), recv, msg| { let messages_recvd = messages_recvd + 1; // Map BlockProcessed to a None piece index so we don't update our state let (opt_piece_index, new_value) = match msg { - ODiskMessage::FoundGoodPiece(_, index) => (Some(index), true), - ODiskMessage::FoundBadPiece(_, index) => (Some(index), false), - ODiskMessage::BlockProcessed(_) => (None, false), + Ok(ODiskMessage::FoundGoodPiece(_, index)) => (Some(index), true), + Ok(ODiskMessage::FoundBadPiece(_, index)) => (Some(index), false), + Ok(ODiskMessage::BlockProcessed(_)) => (None, false), unexpected => panic!("Unexpected Message: {unexpected:?}"), }; @@ -155,60 +116,61 @@ fn positive_complete_torrent() { // One message for each block (8 blocks), plus 3 messages for bad/good if messages_recvd == (8 + 3) { - Loop::Break((recv, piece_zero_good, piece_one_good, piece_two_good)) + Either::Left(future::ready((recv, piece_zero_good, piece_one_good, piece_two_good)).boxed()) } else { - Loop::Continue(((piece_zero_good, piece_one_good, piece_two_good, messages_recvd), recv)) + Either::Right(future::ready(((piece_zero_good, piece_one_good, piece_two_good, messages_recvd), recv)).boxed()) } }, - ); + ) + .await; // Assert whether or not pieces were good assert!(!piece_zero_good); assert!(piece_one_good); assert!(piece_two_good); - // Resend piece 0 with good blocks - crate::send_block( - &mut blocking_send, - &files_bytes[0..500], - metainfo_file.info().info_hash(), - 0, - 0, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[500..1000], - metainfo_file.info().info_hash(), - 0, - 500, - 500, - |_| (), - ); - crate::send_block( - &mut blocking_send, - &files_bytes[1000..1024], - metainfo_file.info().info_hash(), - 0, - 1000, - 24, - |_| (), - ); - - // Verify that piece 0 is now good - let piece_zero_good = crate::core_loop_with_timeout( - &mut core, - 500, + { + tokio::task::yield_now().await; + + while let Some(task) = tasks.try_join_next() { + match task { + Ok(()) => continue, + Err(e) => panic!("task joined with error: {e}"), + } + } + + assert!(tasks.is_empty(), "all the tasks should have finished now"); + } + + tracing::debug!("resend piece 0 with good blocks"); + + let resend_with_good_blocks = { + let mut send = send.clone(); + let data = files_bytes.clone(); + let info_hash = metainfo_file.info().info_hash(); + async move { + send_block(&mut send, &data[0..500], info_hash, 0, 
0, 500, |_| ()).await; + send_block(&mut send, &data[500..1000], info_hash, 0, 500, 500, |_| ()).await; + send_block(&mut send, &data[1000..1024], info_hash, 0, 1000, 24, |_| ()).await; + } + .boxed() + }; + + tasks.spawn(resend_with_good_blocks); + + tracing::debug!("verify that piece 0 is now good"); + + let piece_zero_good = runtime_loop_with_timeout( + DEFAULT_TIMEOUT, ((false, 0), recv), |(piece_zero_good, messages_recvd), recv, msg| { let messages_recvd = messages_recvd + 1; // Map BlockProcessed to a None piece index so we don't update our state let (opt_piece_index, new_value) = match msg { - ODiskMessage::FoundGoodPiece(_, index) => (Some(index), true), - ODiskMessage::FoundBadPiece(_, index) => (Some(index), false), - ODiskMessage::BlockProcessed(_) => (None, false), + Ok(ODiskMessage::FoundGoodPiece(_, index)) => (Some(index), true), + Ok(ODiskMessage::FoundBadPiece(_, index)) => (Some(index), false), + Ok(ODiskMessage::BlockProcessed(_)) => (None, false), unexpected => panic!("Unexpected Message: {unexpected:?}"), }; @@ -220,12 +182,26 @@ fn positive_complete_torrent() { // One message for each block (3 blocks), plus 1 messages for bad/good if messages_recvd == (3 + 1) { - Loop::Break(piece_zero_good) + Either::Left(future::ready(piece_zero_good).boxed()) } else { - Loop::Continue(((piece_zero_good, messages_recvd), recv)) + Either::Right(future::ready(((piece_zero_good, messages_recvd), recv)).boxed()) } }, - ); + ) + .await; + + { + tokio::task::yield_now().await; + + while let Some(task) = tasks.try_join_next() { + match task { + Ok(()) => continue, + Err(e) => panic!("task joined with error: {e}"), + } + } + + assert!(tasks.is_empty(), "all the tasks should have finished now"); + } // Assert whether or not piece was good assert!(piece_zero_good); diff --git a/packages/disk/tests/disk_manager_send_backpressure.rs b/packages/disk/tests/disk_manager_send_backpressure.rs index 03e282aff..9b5e3d55c 100644 --- a/packages/disk/tests/disk_manager_send_backpressure.rs +++ b/packages/disk/tests/disk_manager_send_backpressure.rs @@ -1,22 +1,23 @@ -use common::{random_buffer, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{random_buffer, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, DEFAULT_TIMEOUT, INIT}; use disk::{DiskManagerBuilder, IDiskMessage}; -use futures::future::Future; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::{future, AsyncSink}; +use futures::{FutureExt, SinkExt as _, StreamExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tracing::level_filters::LevelFilter; mod common; -#[test] -fn positive_disk_manager_send_backpressure() { +#[tokio::test] +async fn positive_disk_manager_send_backpressure() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(50), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); let data_c = (random_buffer(0), "/path/to/file/c".into()); - // Create our accessor for our in memory files and create a torrent file for them + // Create our accessor for our in-memory files and create a torrent file for them let files_accessor = MultiFileDirectAccessor::new("/my/downloads/".into(), vec![data_a.clone(), data_b.clone(), data_c.clone()]); let metainfo_bytes = MetainfoBuilder::new() @@ -28,33 +29,40 @@ fn positive_disk_manager_send_backpressure() { // Spin up a disk manager and add our created torrent to it let filesystem = 
InMemoryFileSystem::new(); - let (m_send, m_recv) = DiskManagerBuilder::new() + let (mut m_send, mut m_recv) = DiskManagerBuilder::new() .with_sink_buffer_capacity(1) .build(filesystem.clone()) - .split(); - - let mut core = Core::new().unwrap(); + .into_parts(); // Add a torrent, so our receiver has a single torrent added message buffered - let mut m_send = core.run(m_send.send(IDiskMessage::AddTorrent(metainfo_file))).unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, m_send.send(IDiskMessage::AddTorrent(metainfo_file))) + .await + .unwrap() + .unwrap(); // Try to send a remove message (but it should fail) - let (result, m_send) = core - .run(future::lazy(|| { - future::ok::<_, ()>((m_send.start_send(IDiskMessage::RemoveTorrent(info_hash)), m_send)) - })) - .unwrap(); - match result { - Ok(AsyncSink::NotReady(_)) => (), - _ => panic!("Unexpected Result From Backpressure Test"), - }; + assert!( + m_send.send(IDiskMessage::RemoveTorrent(info_hash)).now_or_never().is_none(), + "it should have back_pressure" + ); // Receive from our stream to unblock the backpressure - let m_recv = core.run(m_recv.into_future().map(|(_, recv)| recv).map_err(|_| ())).unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, m_recv.next()) + .await + .unwrap() + .unwrap() + .unwrap(); // Try to send a remove message again which should go through - drop(core.run(m_send.send(IDiskMessage::RemoveTorrent(info_hash))).unwrap()); + tokio::time::timeout(DEFAULT_TIMEOUT, m_send.send(IDiskMessage::RemoveTorrent(info_hash))) + .await + .unwrap() + .unwrap(); - // Receive confirmation (just so the pool doesn't panic because we ended before it could send the message back) - drop(core.run(m_recv.into_future().map(|(_, recv)| recv).map_err(|_| ())).unwrap()); + // Receive confirmation + tokio::time::timeout(DEFAULT_TIMEOUT, m_recv.next()) + .await + .unwrap() + .unwrap() + .unwrap(); } diff --git a/packages/disk/tests/load_block.rs b/packages/disk/tests/load_block.rs index 78cfe8eee..5b877516a 100644 --- a/packages/disk/tests/load_block.rs +++ b/packages/disk/tests/load_block.rs @@ -1,21 +1,24 @@ use bytes::BytesMut; -use common::{core_loop_with_timeout, random_buffer, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{random_buffer, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, INIT}; use disk::{Block, BlockMetadata, BlockMut, DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::future::Loop; -use futures::sink::Sink; -use futures::stream::Stream; +use futures::{SinkExt as _, StreamExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tokio::time::{timeout, Duration}; +use tracing::level_filters::LevelFilter; mod common; -#[test] -fn positive_load_block() { +#[tokio::test] +async fn positive_load_block() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(1023), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); - // Create our accessor for our in memory files and create a torrent file for them + // Create our accessor for our in-memory files and create a torrent file for them let files_accessor = MultiFileDirectAccessor::new("/my/downloads/".into(), vec![data_a.clone(), data_b.clone()]); let metainfo_bytes = MetainfoBuilder::new() .set_piece_length(PieceLength::Custom(1024)) @@ -23,7 +26,7 @@ fn positive_load_block() { .unwrap(); let metainfo_file = Metainfo::from_bytes(metainfo_bytes).unwrap(); - // Spin up a disk 
manager and add our created torrent to its + // Spin up a disk manager and add our created torrent to it let filesystem = InMemoryFileSystem::new(); let disk_manager = DiskManagerBuilder::new().build(filesystem.clone()); @@ -39,28 +42,30 @@ fn positive_load_block() { ); let load_block = BlockMut::new(BlockMetadata::new(metainfo_file.info().info_hash(), 1, 0, 50), load_block); - let (send, recv) = disk_manager.split(); - let mut blocking_send = send.wait(); - blocking_send.send(IDiskMessage::AddTorrent(metainfo_file)).unwrap(); + let (mut send, mut recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file)).await.unwrap(); - let mut core = Core::new().unwrap(); - let (pblock, lblock) = core_loop_with_timeout( - &mut core, - 500, - ((blocking_send, Some(process_block), Some(load_block)), recv), - |(mut blocking_send, opt_pblock, opt_lblock), recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => { - blocking_send.send(IDiskMessage::ProcessBlock(opt_pblock.unwrap())).unwrap(); - Loop::Continue(((blocking_send, None, opt_lblock), recv)) - } - ODiskMessage::BlockProcessed(block) => { - blocking_send.send(IDiskMessage::LoadBlock(opt_lblock.unwrap())).unwrap(); - Loop::Continue(((blocking_send, Some(block), None), recv)) + let timeout_duration = Duration::from_millis(500); + let result = timeout(timeout_duration, async { + loop { + match recv.next().await { + Some(Ok(ODiskMessage::TorrentAdded(_))) => { + send.send(IDiskMessage::ProcessBlock(process_block.clone())).await.unwrap(); + } + Some(Ok(ODiskMessage::BlockProcessed(_block))) => { + send.send(IDiskMessage::LoadBlock(load_block.clone())).await.unwrap(); + } + Some(Ok(ODiskMessage::BlockLoaded(block))) => { + return (process_block, block); + } + Some(unexpected) => panic!("Unexpected Message: {unexpected:?}"), + None => panic!("End Of Stream Reached"), } - ODiskMessage::BlockLoaded(block) => Loop::Break((opt_pblock.unwrap(), block)), - unexpected => panic!("Unexpected Message: {unexpected:?}"), - }, - ); + } + }) + .await; + + let (pblock, lblock) = result.unwrap(); // Verify lblock contains our data assert_eq!(*pblock, *lblock); diff --git a/packages/disk/tests/process_block.rs b/packages/disk/tests/process_block.rs index b6d4f2132..42906e7e2 100644 --- a/packages/disk/tests/process_block.rs +++ b/packages/disk/tests/process_block.rs @@ -1,16 +1,21 @@ use bytes::BytesMut; -use common::{core_loop_with_timeout, random_buffer, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{ + random_buffer, runtime_loop_with_timeout, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, DEFAULT_TIMEOUT, + INIT, +}; use disk::{Block, BlockMetadata, DiskManagerBuilder, FileSystem, IDiskMessage, ODiskMessage}; -use futures::future::Loop; -use futures::sink::Sink; -use futures::stream::Stream; +use futures::{future, FutureExt as _, SinkExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tracing::level_filters::LevelFilter; mod common; -#[test] -fn positive_process_block() { +#[tokio::test] +async fn positive_process_block() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(1023), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); @@ -35,24 +40,28 @@ fn positive_process_block() { process_bytes.freeze(), ); - let (send, recv) = disk_manager.split(); - let mut blocking_send = send.wait(); - 
blocking_send.send(IDiskMessage::AddTorrent(metainfo_file)).unwrap(); + let (mut send, recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file)).await.unwrap(); + + runtime_loop_with_timeout( + DEFAULT_TIMEOUT, + ((send, Some(process_block)), recv), + |(mut send, opt_pblock), recv, msg| match msg { + Ok(ODiskMessage::TorrentAdded(_)) => { + let fut = async move { + send.send(IDiskMessage::ProcessBlock(opt_pblock.unwrap())).await.unwrap(); - let mut core = Core::new().unwrap(); - core_loop_with_timeout( - &mut core, - 500, - ((blocking_send, Some(process_block)), recv), - |(mut blocking_send, opt_pblock), recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => { - blocking_send.send(IDiskMessage::ProcessBlock(opt_pblock.unwrap())).unwrap(); - Loop::Continue(((blocking_send, None), recv)) + ((send, None), recv) + } + .boxed(); + + future::Either::Right(fut) } - ODiskMessage::BlockProcessed(_) => Loop::Break(()), + Ok(ODiskMessage::BlockProcessed(_)) => future::Either::Left(future::ready(()).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), }, - ); + ) + .await; // Verify block was updated in data_b let mut received_file_b = filesystem.open_file(data_b.1).unwrap(); diff --git a/packages/disk/tests/remove_torrent.rs b/packages/disk/tests/remove_torrent.rs index 877497cf1..ec3606bff 100644 --- a/packages/disk/tests/remove_torrent.rs +++ b/packages/disk/tests/remove_torrent.rs @@ -1,16 +1,21 @@ use bytes::BytesMut; -use common::{core_loop_with_timeout, random_buffer, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{ + random_buffer, runtime_loop_with_timeout, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, DEFAULT_TIMEOUT, + INIT, +}; use disk::{Block, BlockMetadata, DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::future::Loop; -use futures::sink::Sink; -use futures::stream::Stream; +use futures::{future, FutureExt, SinkExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tracing::level_filters::LevelFilter; mod common; -#[test] -fn positive_remove_torrent() { +#[tokio::test] +async fn positive_remove_torrent() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(50), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); @@ -30,27 +35,32 @@ fn positive_remove_torrent() { let filesystem = InMemoryFileSystem::new(); let disk_manager = DiskManagerBuilder::new().build(filesystem.clone()); - let (send, recv) = disk_manager.split(); - let mut blocking_send = send.wait(); - blocking_send.send(IDiskMessage::AddTorrent(metainfo_file)).unwrap(); + let (mut send, recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file)).await.unwrap(); // Verify that zero pieces are marked as good - let mut core = Core::new().unwrap(); + let (mut send, good_pieces, recv) = runtime_loop_with_timeout( + DEFAULT_TIMEOUT, + ((send, 0), recv), + |(mut send, good_pieces), recv, msg| match msg { + Ok(ODiskMessage::TorrentAdded(_)) => { + let fut = async move { + send.send(IDiskMessage::RemoveTorrent(info_hash)).await.unwrap(); + + ((send, good_pieces), recv) + } + .boxed(); - let (mut blocking_send, good_pieces, recv) = core_loop_with_timeout( - &mut core, - 500, - ((blocking_send, 0), recv), - |(mut blocking_send, good_pieces), recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => { - 
blocking_send.send(IDiskMessage::RemoveTorrent(info_hash)).unwrap(); - Loop::Continue(((blocking_send, good_pieces), recv)) + future::Either::Right(fut) + } + Ok(ODiskMessage::TorrentRemoved(_)) => future::Either::Left(future::ready((send, good_pieces, recv)).boxed()), + Ok(ODiskMessage::FoundGoodPiece(_, _)) => { + future::Either::Right(future::ready(((send, good_pieces + 1), recv)).boxed()) } - ODiskMessage::TorrentRemoved(_) => Loop::Break((blocking_send, good_pieces, recv)), - ODiskMessage::FoundGoodPiece(_, _) => Loop::Continue(((blocking_send, good_pieces + 1), recv)), unexpected => panic!("Unexpected Message: {unexpected:?}"), }, - ); + ) + .await; assert_eq!(0, good_pieces); @@ -59,10 +69,11 @@ fn positive_remove_torrent() { let process_block = Block::new(BlockMetadata::new(info_hash, 0, 0, 50), process_bytes.freeze()); - blocking_send.send(IDiskMessage::ProcessBlock(process_block)).unwrap(); + send.send(IDiskMessage::ProcessBlock(process_block)).await.unwrap(); - crate::core_loop_with_timeout(&mut core, 500, ((), recv), |(), _, msg| match msg { - ODiskMessage::ProcessBlockError(_, _) => Loop::Break(()), + runtime_loop_with_timeout(DEFAULT_TIMEOUT, ((), recv), |(), _, msg| match msg { + Ok(ODiskMessage::ProcessBlockError(_, _)) => future::Either::Left(future::ready(()).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; } diff --git a/packages/disk/tests/resume_torrent.rs b/packages/disk/tests/resume_torrent.rs index c724cf714..5fc09dc43 100644 --- a/packages/disk/tests/resume_torrent.rs +++ b/packages/disk/tests/resume_torrent.rs @@ -1,16 +1,21 @@ -use common::{core_loop_with_timeout, random_buffer, send_block, InMemoryFileSystem, MultiFileDirectAccessor}; +use common::{ + random_buffer, runtime_loop_with_timeout, send_block, tracing_stderr_init, InMemoryFileSystem, MultiFileDirectAccessor, + DEFAULT_TIMEOUT, INIT, +}; use disk::{DiskManagerBuilder, IDiskMessage, ODiskMessage}; -use futures::future::Loop; -use futures::sink::Sink; -use futures::stream::Stream; +use futures::{future, FutureExt, SinkExt as _}; use metainfo::{Metainfo, MetainfoBuilder, PieceLength}; -use tokio_core::reactor::Core; +use tracing::level_filters::LevelFilter; mod common; #[allow(clippy::too_many_lines)] -#[test] -fn positive_complete_torrent() { +#[tokio::test] +async fn positive_complete_torrent() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + // Create some "files" as random bytes let data_a = (random_buffer(1023), "/path/to/file/a".into()); let data_b = (random_buffer(2000), "/path/to/file/b".into()); @@ -28,19 +33,18 @@ fn positive_complete_torrent() { let filesystem = InMemoryFileSystem::new(); let disk_manager = DiskManagerBuilder::new().build(filesystem.clone()); - let (send, recv) = disk_manager.split(); - let mut blocking_send = send.wait(); - blocking_send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).unwrap(); + let (mut send, recv) = disk_manager.into_parts(); + send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).await.unwrap(); // Verify that zero pieces are marked as good - let mut core = Core::new().unwrap(); - // Run a core loop until we get the TorrentAdded message - let (good_pieces, recv) = crate::core_loop_with_timeout(&mut core, 500, (0, recv), |good_pieces, recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => Loop::Break((good_pieces, recv)), - ODiskMessage::FoundGoodPiece(_, _) => Loop::Continue((good_pieces + 1, recv)), + // Run a runtime loop until we get the TorrentAdded message + let 
(good_pieces, recv) = runtime_loop_with_timeout(DEFAULT_TIMEOUT, (0, recv), |good_pieces, recv, msg| match msg { + Ok(ODiskMessage::TorrentAdded(_)) => future::Either::Left(future::ready((good_pieces, recv)).boxed()), + Ok(ODiskMessage::FoundGoodPiece(_, _)) => future::Either::Right(future::ready((good_pieces + 1, recv)).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; // Make sure we have no good pieces assert_eq!(0, good_pieces); @@ -52,46 +56,48 @@ fn positive_complete_torrent() { // Send piece 0 send_block( - &mut blocking_send, + &mut send, &files_bytes[0..500], metainfo_file.info().info_hash(), 0, 0, 500, |_| (), - ); + ) + .await; send_block( - &mut blocking_send, + &mut send, &files_bytes[500..1000], metainfo_file.info().info_hash(), 0, 500, 500, |_| (), - ); + ) + .await; send_block( - &mut blocking_send, + &mut send, &files_bytes[1000..1024], metainfo_file.info().info_hash(), 0, 1000, 24, |_| (), - ); + ) + .await; // Verify that piece 0 is good - let (recv, piece_zero_good) = core_loop_with_timeout( - &mut core, - 500, + let (recv, piece_zero_good) = runtime_loop_with_timeout( + DEFAULT_TIMEOUT, ((false, 0), recv), |(piece_zero_good, messages_recvd), recv, msg| { let messages_recvd = messages_recvd + 1; // Map BlockProcessed to a None piece index so we don't update our state let (opt_piece_index, new_value) = match msg { - ODiskMessage::FoundGoodPiece(_, index) => (Some(index), true), - ODiskMessage::FoundBadPiece(_, index) => (Some(index), false), - ODiskMessage::BlockProcessed(_) => (None, false), + Ok(ODiskMessage::FoundGoodPiece(_, index)) => (Some(index), true), + Ok(ODiskMessage::FoundBadPiece(_, index)) => (Some(index), false), + Ok(ODiskMessage::BlockProcessed(_)) => (None, false), unexpected => panic!("Unexpected Message: {unexpected:?}"), }; @@ -102,99 +108,106 @@ fn positive_complete_torrent() { }; if messages_recvd == (3 + 1) { - Loop::Break((recv, piece_zero_good)) + future::Either::Left(future::ready((recv, piece_zero_good)).boxed()) } else { - Loop::Continue(((piece_zero_good, messages_recvd), recv)) + future::Either::Right(future::ready(((piece_zero_good, messages_recvd), recv)).boxed()) } }, - ); + ) + .await; // Assert whether or not pieces were good assert!(piece_zero_good); // Remove the torrent from our manager - blocking_send.send(IDiskMessage::RemoveTorrent(info_hash)).unwrap(); + send.send(IDiskMessage::RemoveTorrent(info_hash)).await.unwrap(); // Verify that our torrent was removed - let recv = crate::core_loop_with_timeout(&mut core, 500, ((), recv), |(), recv, msg| match msg { - ODiskMessage::TorrentRemoved(_) => Loop::Break(recv), + let recv = runtime_loop_with_timeout(DEFAULT_TIMEOUT, ((), recv), |(), recv, msg| match msg { + Ok(ODiskMessage::TorrentRemoved(_)) => future::Either::Left(future::ready(recv).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; // Re-add our torrent and verify that we see our good first block - blocking_send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).unwrap(); + send.send(IDiskMessage::AddTorrent(metainfo_file.clone())).await.unwrap(); let (recv, piece_zero_good) = - crate::core_loop_with_timeout(&mut core, 500, (false, recv), |piece_zero_good, recv, msg| match msg { - ODiskMessage::TorrentAdded(_) => Loop::Break((recv, piece_zero_good)), - ODiskMessage::FoundGoodPiece(_, 0) => Loop::Continue((true, recv)), + runtime_loop_with_timeout(DEFAULT_TIMEOUT, (false, recv), |piece_zero_good, recv, msg| match msg { + 
Ok(ODiskMessage::TorrentAdded(_)) => future::Either::Left(future::ready((recv, piece_zero_good)).boxed()), + Ok(ODiskMessage::FoundGoodPiece(_, 0)) => future::Either::Right(future::ready((true, recv)).boxed()), unexpected => panic!("Unexpected Message: {unexpected:?}"), - }); + }) + .await; assert!(piece_zero_good); // Send piece 1 - crate::send_block( - &mut blocking_send, + send_block( + &mut send, &files_bytes[1024..(1024 + 500)], metainfo_file.info().info_hash(), 1, 0, 500, |_| (), - ); - crate::send_block( - &mut blocking_send, + ) + .await; + send_block( + &mut send, &files_bytes[(1024 + 500)..(1024 + 1000)], metainfo_file.info().info_hash(), 1, 500, 500, |_| (), - ); - crate::send_block( - &mut blocking_send, + ) + .await; + send_block( + &mut send, &files_bytes[(1024 + 1000)..(1024 + 1024)], metainfo_file.info().info_hash(), 1, 1000, 24, |_| (), - ); + ) + .await; // Send piece 2 - crate::send_block( - &mut blocking_send, + send_block( + &mut send, &files_bytes[2048..(2048 + 500)], metainfo_file.info().info_hash(), 2, 0, 500, |_| (), - ); - crate::send_block( - &mut blocking_send, + ) + .await; + send_block( + &mut send, &files_bytes[(2048 + 500)..(2048 + 975)], metainfo_file.info().info_hash(), 2, 500, 475, |_| (), - ); + ) + .await; // Verify last two blocks are good - let (piece_one_good, piece_two_good) = crate::core_loop_with_timeout( - &mut core, - 500, + let (piece_one_good, piece_two_good) = runtime_loop_with_timeout( + DEFAULT_TIMEOUT, ((false, false, 0), recv), |(piece_one_good, piece_two_good, messages_recvd), recv, msg| { let messages_recvd = messages_recvd + 1; // Map BlockProcessed to a None piece index so we don't update our state let (opt_piece_index, new_value) = match msg { - ODiskMessage::FoundGoodPiece(_, index) => (Some(index), true), - ODiskMessage::FoundBadPiece(_, index) => (Some(index), false), - ODiskMessage::BlockProcessed(_) => (None, false), + Ok(ODiskMessage::FoundGoodPiece(_, index)) => (Some(index), true), + Ok(ODiskMessage::FoundBadPiece(_, index)) => (Some(index), false), + Ok(ODiskMessage::BlockProcessed(_)) => (None, false), unexpected => panic!("Unexpected Message: {unexpected:?}"), }; @@ -206,12 +219,13 @@ fn positive_complete_torrent() { }; if messages_recvd == (5 + 2) { - Loop::Break((piece_one_good, piece_two_good)) + future::Either::Left(future::ready((piece_one_good, piece_two_good)).boxed()) } else { - Loop::Continue(((piece_one_good, piece_two_good, messages_recvd), recv)) + future::Either::Right(future::ready(((piece_one_good, piece_two_good, messages_recvd), recv)).boxed()) } }, - ); + ) + .await; assert!(piece_one_good); assert!(piece_two_good); diff --git a/packages/handshake/Cargo.toml b/packages/handshake/Cargo.toml index 1f6c3ace5..705ea3d80 100644 --- a/packages/handshake/Cargo.toml +++ b/packages/handshake/Cargo.toml @@ -18,10 +18,13 @@ version.workspace = true [dependencies] util = { path = "../util" } -bytes = "0.4" -futures = "0.1" -nom = "3" -rand = "0.8" -tokio-core = "0.1" -tokio-io = "0.1" -tokio-timer = "0.1" +bytes = "1" +futures = "0" +nom = "7" +pin-project = "1" +rand = "0" +tokio = { version = "1", features = ["full"] } +tracing = "0" + +[dev-dependencies] +tracing-subscriber = "0" diff --git a/packages/handshake/examples/handshake_torrent.rs b/packages/handshake/examples/handshake_torrent.rs index 5f250a8c2..0d8bc20b8 100644 --- a/packages/handshake/examples/handshake_torrent.rs +++ b/packages/handshake/examples/handshake_torrent.rs @@ -1,67 +1,70 @@ -use std::io::{self, BufRead, Write}; +use std::io::BufRead as _; 
use std::net::{SocketAddr, ToSocketAddrs}; -use std::thread; use std::time::Duration; -use futures::{Future, Sink, Stream}; +use futures::SinkExt; use handshake::transports::TcpTransport; use handshake::{HandshakerBuilder, InitiateMessage, Protocol}; -use tokio_core::reactor::Core; +use tokio::time::sleep; -fn main() { - let mut stdout = io::stdout(); - let stdin = io::stdin(); +#[tokio::main] +async fn main() -> std::io::Result<()> { + let mut stdout = std::io::stdout(); + let stdin = std::io::stdin(); let mut lines = stdin.lock().lines(); - stdout.write_all(b"Enter An InfoHash In Hex Format: ").unwrap(); - stdout.flush().unwrap(); + // Prompt for InfoHash + prompt(&mut stdout, "Enter An InfoHash In Hex Format: ")?; + let hex_hash = read_line(&mut lines)?; + let info_hash = hex_to_bytes(&hex_hash).into(); - let hex_hash = lines.next().unwrap().unwrap(); - let hash = hex_to_bytes(&hex_hash).into(); - - stdout.write_all(b"Enter An Address And Port (eg: addr:port): ").unwrap(); - stdout.flush().unwrap(); - - let str_addr = lines.next().unwrap().unwrap(); - let addr = str_to_addr(&str_addr); - - let mut core = Core::new().unwrap(); + // Prompt for Address and Port + prompt(&mut stdout, "Enter An Address And Port (eg: addr:port): ")?; + let address = read_line(&mut lines)?; + let socket_addr = str_to_addr(&address)?; // Show up as a uTorrent client... let peer_id = (*b"-UT2060-000000000000").into(); - let handshaker = HandshakerBuilder::new() + let (mut handshaker, mut tasks) = HandshakerBuilder::new() .with_peer_id(peer_id) - .build(TcpTransport, &core.handle()) - .unwrap() - .send(InitiateMessage::new(Protocol::BitTorrent, hash, addr)) - .wait() - .unwrap(); + .build(TcpTransport) + .await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - let _peer = core - .run(handshaker.into_future().map(|(opt_peer, _)| opt_peer.unwrap())) - .unwrap_or_else(|_| panic!("")); + handshaker + .send(InitiateMessage::new(Protocol::BitTorrent, info_hash, socket_addr)) + .await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; println!("\nConnection With Peer Established...Closing In 10 Seconds"); - thread::sleep(Duration::from_millis(10000)); -} + sleep(Duration::from_secs(10)).await; -fn hex_to_bytes(hex: &str) -> [u8; 20] { - let mut exact_bytes = [0u8; 20]; + tasks.shutdown().await; - for (byte_index, byte) in exact_bytes.iter_mut().enumerate() { - let high_index = byte_index * 2; - let low_index = (byte_index * 2) + 1; + Ok(()) +} - let hex_chunk = &hex[high_index..=low_index]; - let byte_value = u8::from_str_radix(hex_chunk, 16).unwrap(); +fn prompt(writer: &mut W, message: &str) -> std::io::Result<()> { + writer.write_all(message.as_bytes())?; + writer.flush() +} - *byte = byte_value; - } +fn read_line(lines: &mut std::io::Lines) -> std::io::Result { + lines.next().unwrap_or_else(|| Ok(String::new())) +} - exact_bytes +fn hex_to_bytes(hex: &str) -> [u8; 20] { + let mut bytes = [0u8; 20]; + for (i, byte) in bytes.iter_mut().enumerate() { + let hex_chunk = &hex[i * 2..=i * 2 + 1]; + *byte = u8::from_str_radix(hex_chunk, 16).unwrap(); + } + bytes } -fn str_to_addr(addr: &str) -> SocketAddr { - addr.to_socket_addrs().unwrap().next().unwrap() +fn str_to_addr(addr: &str) -> std::io::Result { + addr.to_socket_addrs()? 
+ .next() + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid address format")) } diff --git a/packages/handshake/src/bittorrent/framed.rs b/packages/handshake/src/bittorrent/framed.rs index 27faf59a2..e41e4f106 100644 --- a/packages/handshake/src/bittorrent/framed.rs +++ b/packages/handshake/src/bittorrent/framed.rs @@ -1,145 +1,291 @@ -use std::io::{self, Cursor}; +//! This module provides the `FramedHandshake` struct, which implements a framed transport for the `BitTorrent` handshake protocol. +//! It supports both reading from and writing to an underlying asynchronous stream, handling the framing of handshake messages. -use bytes::buf::BufMut; -use bytes::BytesMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Buf as _, BufMut, BytesMut}; use futures::sink::Sink; use futures::stream::Stream; -use futures::{Async, AsyncSink, Poll, StartSend}; -use nom::IResult; -use tokio_io::{try_nb, AsyncRead, AsyncWrite}; +use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::instrument; use crate::bittorrent::message::{self, HandshakeMessage}; +/// Represents the state of the handshake process. +#[derive(Debug)] enum HandshakeState { Waiting, - Length(u8), + Reading, + Ready, Finished, + Errored, } -// We can't use the built in frames because they may buffer more -// bytes than we need for a handshake. That is unacceptable for us -// because we are giving a raw socket to the client of this library. -// We don't want to steal any of their bytes during our handshake! +/// A framed transport for the `BitTorrent` handshake protocol. +/// +/// This struct wraps an underlying asynchronous stream and provides methods to read and write `HandshakeMessage` instances. #[allow(clippy::module_name_repetitions)] -pub struct FramedHandshake { +#[pin_project] +#[derive(Debug)] +pub struct FramedHandshake +where + S: std::fmt::Debug + Unpin, +{ + #[pin] sock: S, + write_buffer: BytesMut, read_buffer: Vec, read_pos: usize, state: HandshakeState, } -impl FramedHandshake { +impl FramedHandshake +where + S: std::fmt::Debug + Unpin, +{ + /// Creates a new `FramedHandshake` with the given socket. + /// + /// # Arguments + /// + /// * `sock` - The underlying asynchronous stream. pub fn new(sock: S) -> FramedHandshake { FramedHandshake { sock, write_buffer: BytesMut::with_capacity(1), - read_buffer: vec![0], + read_buffer: Vec::default(), read_pos: 0, state: HandshakeState::Waiting, } } + /// Consumes the `FramedHandshake`, returning the underlying socket. 
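// NOTE: read-path state machine used by the `Stream` impl further below, for reference
// (summary only; the states are the enum introduced above):
//
//   Waiting  -- first byte (protocol name length) read          --> Reading
//   Reading  -- buffer filled to the full handshake length      --> Ready
//   Ready    -- bytes parsed into a HandshakeMessage            --> Finished (or Errored)
//   Finished -- subsequent polls                                --> None
//
// Only the exact handshake length is ever read, so any bytes the peer sends after the
// handshake remain in the socket for the caller; that is why this framing stays
// hand-rolled instead of using a stock length-delimited codec.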
pub fn into_inner(self) -> S { self.sock } } -impl Sink for FramedHandshake +impl Sink for FramedHandshake where - S: AsyncWrite, + Si: AsyncWrite + std::fmt::Debug + Unpin, { - type SinkItem = HandshakeMessage; - type SinkError = io::Error; + type Error = std::io::Error; + + #[instrument(skip(self, _cx))] + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + tracing::trace!("poll_ready called"); + + Poll::Ready(Ok(())) + } + + #[instrument(skip(self))] + fn start_send(mut self: Pin<&mut Self>, item: HandshakeMessage) -> Result<(), Self::Error> { + tracing::trace!("start_send called with item: {item:?}"); + let mut cursor = std::io::Cursor::new(Vec::with_capacity(item.write_len())); + item.write_bytes_sync(&mut cursor)?; - fn start_send(&mut self, item: HandshakeMessage) -> StartSend { self.write_buffer.reserve(item.write_len()); - item.write_bytes(self.write_buffer.by_ref().writer())?; + self.write_buffer.put_slice(cursor.get_ref()); - Ok(AsyncSink::Ready) + Ok(()) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - loop { - let write_result = self.sock.write_buf(&mut Cursor::new(&self.write_buffer)); + #[instrument(skip(self, cx))] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::trace!("poll_flush called"); + + let mut project = self.project(); - match try_nb!(write_result) { - Async::Ready(0) => return Err(io::Error::new(io::ErrorKind::WriteZero, "Failed To Write Bytes")), - Async::Ready(written) => { - self.write_buffer.split_to(written); + while !project.write_buffer.is_empty() { + let res = project.sock.as_mut().poll_write(cx, project.write_buffer); + + match res { + Poll::Ready(Ok(0)) => { + tracing::error!("Failed to write bytes: WriteZero"); + return Err(std::io::Error::new(std::io::ErrorKind::WriteZero, "Failed To Write Bytes")).into(); + } + Poll::Ready(Ok(written)) => { + tracing::trace!("Wrote {} bytes", written); + project.write_buffer.advance(written); + } + Poll::Ready(Err(e)) => { + tracing::error!("Error writing bytes: {:?}", e); + return Err(e).into(); } - Async::NotReady => return Ok(Async::NotReady), + Poll::Pending => return Poll::Pending, } + } + project.sock.as_mut().poll_flush(cx) + } - if self.write_buffer.is_empty() { - try_nb!(self.sock.flush()); + #[instrument(skip(self, cx))] + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::trace!("poll_close called"); - return Ok(Async::Ready(())); - } + match self.as_mut().poll_flush(cx)? 
{ + Poll::Ready(()) => self.project().sock.poll_shutdown(cx), + Poll::Pending => Poll::Pending, } } } -impl Stream for FramedHandshake +impl Stream for FramedHandshake where - S: AsyncRead, + St: AsyncRead + std::fmt::Debug + Unpin, { - type Item = HandshakeMessage; - type Error = io::Error; + type Item = Result; + + #[instrument(skip(self, cx))] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + HandshakeState::Waiting => { + tracing::trace!("handshake waiting..."); + let mut this = self.project(); + + assert!(this.read_buffer.is_empty()); + assert_eq!(0, *this.read_pos); + + this.read_buffer.push(0); + assert_eq!(1, this.read_buffer.len()); + + let mut buf = ReadBuf::new(this.read_buffer); + buf.set_filled(0); + + tracing::trace!("Read Buffer: {buf:?}"); + tracing::trace!("Sock Buffer: {:?}", this.sock); + + { + let ready = match this.sock.as_mut().poll_read(cx, &mut buf) { + Poll::Ready(ready) => ready, + Poll::Pending => { + this.read_buffer.clear(); + tracing::trace!("socket pending..."); + return Poll::Pending; + } + }; + + if let Err(e) = ready { + tracing::error!("Error reading bytes: {:?}", e); + *this.state = HandshakeState::Errored; + return Poll::Ready(Some(Err(e))); + }; + } + + let filled = buf.filled(); - fn poll(&mut self) -> Poll, Self::Error> { - loop { - match self.state { - HandshakeState::Waiting => { - let read_result = AsyncRead::read_buf(&mut self.sock, &mut Cursor::new(&mut self.read_buffer[..])); + let byte = match filled.len() { + 0 => { + tracing::trace!("zero bytes read... pending"); + this.read_buffer.clear(); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + 1 => filled[0], + 2.. => unreachable!("bip_handshake: limited by buffer size {filled:?}"), + }; + + let length = message::write_len_with_protocol_len(byte); + + tracing::debug!("length byte: {byte}, expands to: {length} bytes"); + + this.read_buffer.resize(length, 0); + *this.read_pos = 1; + *this.state = HandshakeState::Reading; + + cx.waker().wake_by_ref(); + return Poll::Pending; + } + HandshakeState::Reading => { + tracing::trace!("handshake reading..."); + let mut this = self.project(); - match try_nb!(read_result) { - Async::Ready(0) => return Ok(Async::Ready(None)), - Async::Ready(1) => { - let length = self.read_buffer[0]; + assert!(!this.read_buffer.is_empty()); - self.state = HandshakeState::Length(length); + let length = this.read_buffer.len(); + let pos = this.read_pos; - self.read_pos = 1; - self.read_buffer = vec![0u8; message::write_len_with_protocol_len(length)]; - self.read_buffer[0] = length; + assert_ne!(0, *pos); + assert!(*pos < length); + + let mut buf = ReadBuf::new(this.read_buffer); + buf.set_filled(*pos); + tracing::trace!("have {pos} bytes out of {length}..."); + + { + let ready = match this.sock.as_mut().poll_read(cx, &mut buf) { + Poll::Ready(ready) => ready, + Poll::Pending => { + tracing::trace!("socket pending..."); + return Poll::Pending; } - Async::Ready(read) => panic!("bip_handshake: Expected To Read Single Byte, Read {read:?}"), - Async::NotReady => return Ok(Async::NotReady), - } + }; + + if let Err(e) = ready { + tracing::error!("Error reading bytes: {:?}", e); + *this.state = HandshakeState::Errored; + return Poll::Ready(Some(Err(e))); + }; } - HandshakeState::Length(length) => { - let expected_length = message::write_len_with_protocol_len(length); - - if self.read_pos == expected_length { - match HandshakeMessage::from_bytes(&self.read_buffer) { - IResult::Done(_, message) => { - self.state = 
HandshakeState::Finished; - - return Ok(Async::Ready(Some(message))); - } - IResult::Incomplete(_) => panic!("bip_handshake: HandshakeMessage Failed With Incomplete Bytes"), - IResult::Error(_) => { - return Err(io::Error::new(io::ErrorKind::InvalidData, "HandshakeMessage Failed To Parse")) - } - } - } else { - let read_result = { - let mut cursor = Cursor::new(&mut self.read_buffer[self.read_pos..]); - - try_nb!(AsyncRead::read_buf(&mut self.sock, &mut cursor)) - }; - - match read_result { - Async::Ready(0) => return Ok(Async::Ready(None)), - Async::Ready(read) => { - self.read_pos += read; - } - Async::NotReady => return Ok(Async::NotReady), - } + + let filled = buf.filled().len(); + assert!(filled <= length); + assert!(*pos <= filled); + let added = filled - *pos; + *pos = filled; + + tracing::trace!("read {added} bytes, for a total of: {pos} / {length}..."); + + if filled == length { + tracing::trace!("have full amount"); + *this.state = HandshakeState::Ready; + }; + + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + HandshakeState::Ready => { + tracing::trace!("handshake ready..."); + + assert!(!self.read_buffer.is_empty()); + assert_eq!(self.read_pos, self.read_buffer.len()); + + let buf = std::mem::take(&mut self.read_buffer); + + match HandshakeMessage::from_bytes(&buf) { + Ok(((), message)) => { + tracing::trace!("Parsed HandshakeMessage: {:?}", message); + self.state = HandshakeState::Finished; + + return Poll::Ready(Some(Ok(message))); + } + Err(nom::Err::Incomplete(needed)) => { + tracing::error!("Failed to parse incomplete HandshakeMessage: {needed:?}"); + self.state = HandshakeState::Errored; + + return Poll::Ready(Some(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse incomplete HandshakeMessage: {needed:?}"), + )))); + } + Err(e) => { + tracing::error!("Failed to parse HandshakeMessage"); + self.state = HandshakeState::Errored; + + return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)))); } } - HandshakeState::Finished => return Ok(Async::Ready(None)), + } + HandshakeState::Finished => { + tracing::trace!("handshake finished..."); + return Poll::Ready(None); + } + + HandshakeState::Errored => { + tracing::warn!("handshake polled while errored..."); + return Poll::Ready(None); } } } @@ -147,11 +293,13 @@ where #[cfg(test)] mod tests { - use std::io::{Cursor, Write}; - use futures::sink::Sink; - use futures::stream::Stream; - use futures::Future; + use std::sync::Once; + + use futures::stream::StreamExt; + use futures::SinkExt as _; + use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWriteExt as _}; + use tracing::level_filters::LevelFilter; use util::bt::{self, InfoHash, PeerId}; use super::FramedHandshake; @@ -159,6 +307,19 @@ mod tests { use crate::message::extensions::{self, Extensions}; use crate::message::protocol::Protocol; + pub static INIT: Once = Once::new(); + + pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); + } + fn any_peer_id() -> PeerId { [22u8; bt::PEER_ID_LEN].into() } @@ -171,24 +332,53 @@ mod tests { [255u8; extensions::NUM_EXTENSION_BYTES].into() } - #[test] - fn positive_write_handshake_message() { + #[tokio::test] + async fn write_and_read_into_async_buffer() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let mut v = vec![0; 100]; + //let mut 
buf = std::io::Cursor::new(v); + let mut read_buf = tokio::io::ReadBuf::new(&mut v); + + let data: Box>> = Box::new(std::io::Cursor::new((0..100).collect())); + + let data_reader = &mut (data as Box); + + let mut a = data_reader.read_buf(&mut read_buf).await.unwrap(); + a += data_reader.read_buf(&mut read_buf).await.unwrap(); + a += data_reader.read_buf(&mut read_buf).await.unwrap(); + + assert_eq!(a, 100); + } + + #[tokio::test] + async fn positive_write_handshake_message() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let message = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), any_info_hash(), any_peer_id()); - let write_frame = FramedHandshake::new(Cursor::new(Vec::new())) - .send(message.clone()) - .wait() - .unwrap(); - let recv_buffer = write_frame.into_inner().into_inner(); + let mut framed_handshake = FramedHandshake::new(std::io::Cursor::new(Vec::new())); + + framed_handshake.send(message.clone()).await.unwrap(); + + let sock = framed_handshake.into_inner(); let mut exp_buffer = Vec::new(); - message.write_bytes(&mut exp_buffer).unwrap(); + message.write_bytes(&mut exp_buffer).await.unwrap(); - assert_eq!(exp_buffer, recv_buffer); + assert_eq!(exp_buffer, sock.into_inner()); } - #[test] - fn positive_write_multiple_handshake_messages() { + #[tokio::test] + async fn positive_write_multiple_handshake_messages() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let message_one = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), any_info_hash(), any_peer_id()); let message_two = HandshakeMessage::from_parts( Protocol::Custom(vec![5, 6, 7]), @@ -197,65 +387,98 @@ mod tests { any_peer_id(), ); - let write_frame = FramedHandshake::new(Cursor::new(Vec::new())) - .send(message_one.clone()) - .wait() - .unwrap() - .send(message_two.clone()) - .wait() - .unwrap(); - let recv_buffer = write_frame.into_inner().into_inner(); + let mut framed_handshake = FramedHandshake::new(std::io::Cursor::new(Vec::new())); + + framed_handshake.send(message_one.clone()).await.unwrap(); + framed_handshake.send(message_two.clone()).await.unwrap(); + + let sock = framed_handshake.into_inner(); let mut exp_buffer = Vec::new(); - message_one.write_bytes(&mut exp_buffer).unwrap(); - message_two.write_bytes(&mut exp_buffer).unwrap(); + message_one.write_bytes(&mut exp_buffer).await.unwrap(); + message_two.write_bytes(&mut exp_buffer).await.unwrap(); - assert_eq!(exp_buffer, recv_buffer); + assert_eq!(exp_buffer, sock.into_inner()); } - #[test] - fn positive_read_handshake_message() { + #[tokio::test] + async fn positive_read_handshake_message() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let exp_message = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), any_info_hash(), any_peer_id()); + tracing::trace!("Handshake Message: {:?}", exp_message); - let mut buffer = Vec::new(); - exp_message.write_bytes(&mut buffer).unwrap(); + let mut buffer = std::io::Cursor::new(Vec::new()); + exp_message.write_bytes(&mut buffer).await.unwrap(); + buffer.set_position(0); + + tracing::trace!("Buffer before reading: {:?}", buffer); + let mut framed_handshake = FramedHandshake::new(buffer); + + let recv_message = match framed_handshake.next().await { + Some(Ok(msg)) => msg, + Some(Err(e)) => panic!("Error reading message: {e:?}"), + None => panic!("Expected a message but got None"), + }; + assert!(framed_handshake.next().await.is_none()); - let mut read_iter = 
FramedHandshake::new(&buffer[..]).wait(); - let recv_message = read_iter.next().unwrap().unwrap(); - assert!(read_iter.next().is_none()); + tracing::trace!("Received message: {:?}", recv_message); assert_eq!(exp_message, recv_message); } - #[test] - fn positive_read_byte_after_handshake() { + #[tokio::test] + async fn positive_read_byte_after_handshake() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let exp_message = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), any_info_hash(), any_peer_id()); - let mut buffer = Vec::new(); - exp_message.write_bytes(&mut buffer).unwrap(); - // Write some bytes right after the handshake, make sure - // our framed handshake doesn't read/buffer these (we need - // to be able to read them afterwards) - buffer.write_all(&[55]).unwrap(); + let mut buffer = std::io::Cursor::new(Vec::new()); + let () = exp_message.write_bytes(&mut buffer).await.unwrap(); + let () = buffer.write_all(&[55]).await.unwrap(); + let () = buffer.set_position(0); - let read_frame = FramedHandshake::new(&buffer[..]).into_future().wait().ok().unwrap().1; - let buffer_ref = read_frame.into_inner(); + tracing::trace!("Buffer before reading: {:?}", buffer); + + let mut framed_handshake = FramedHandshake::new(buffer); + let message = framed_handshake.next().await.unwrap().unwrap(); - assert_eq!(&[55], buffer_ref); + assert_eq!(exp_message, message); + + let sock = framed_handshake.into_inner(); + + let position: usize = sock.position().try_into().unwrap(); + let buffer = sock.get_ref(); + let remaining = &buffer[position..]; + + assert_eq!([55], remaining); } - #[test] - fn positive_read_bytes_after_handshake() { + #[tokio::test] + async fn positive_read_bytes_after_handshake() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let exp_message = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), any_info_hash(), any_peer_id()); let mut buffer = Vec::new(); - exp_message.write_bytes(&mut buffer).unwrap(); + exp_message.write_bytes(&mut buffer).await.unwrap(); // Write some bytes right after the handshake, make sure // our framed handshake doesn't read/buffer these (we need // to be able to read them afterwards) - buffer.write_all(&[55, 54, 21]).unwrap(); + drop(buffer.write_all(&[55, 54, 21]).await); + + let read_frame = { + let mut frame = FramedHandshake::new(&buffer[..]); + frame.next().await; + frame + }; - let read_frame = FramedHandshake::new(&buffer[..]).into_future().wait().ok().unwrap().1; let buffer_ref = read_frame.into_inner(); assert_eq!(&[55, 54, 21], buffer_ref); diff --git a/packages/handshake/src/bittorrent/message.rs b/packages/handshake/src/bittorrent/message.rs index caa42429c..f828248d3 100644 --- a/packages/handshake/src/bittorrent/message.rs +++ b/packages/handshake/src/bittorrent/message.rs @@ -1,7 +1,8 @@ -use std::io; -use std::io::Write; - -use nom::{call, do_parse, take, IResult}; +use nom::bytes::complete::take; +use nom::combinator::map_res; +use nom::sequence::tuple; +use nom::IResult; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; use util::bt::{self, InfoHash, PeerId}; use crate::message::extensions::{self, Extensions}; @@ -30,16 +31,30 @@ impl HandshakeMessage { HandshakeMessage { prot, ext, hash, pid } } - pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], HandshakeMessage> { + pub fn from_bytes(bytes: &Vec) -> IResult<(), HandshakeMessage> { parse_remote_handshake(bytes) } - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + #[allow(dead_code)] + pub async 
fn write_bytes(&self, writer: &mut W) -> std::io::Result<()> + where + W: AsyncWrite + Unpin, + { + self.prot.write_bytes(writer).await?; + self.ext.write_bytes(writer).await?; + writer.write_all(self.hash.as_ref()).await?; + + writer.write_all(self.pid.as_ref()).await?; + + Ok(()) + } + + pub fn write_bytes_sync(&self, writer: &mut W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { - self.prot.write_bytes(&mut writer)?; - self.ext.write_bytes(&mut writer)?; + self.prot.write_bytes_sync(writer)?; + self.ext.write_bytes_sync(writer)?; writer.write_all(self.hash.as_ref())?; writer.write_all(self.pid.as_ref())?; @@ -61,33 +76,35 @@ pub fn write_len_with_protocol_len(protocol_len: u8) -> usize { 1 + (protocol_len as usize) + extensions::NUM_EXTENSION_BYTES + bt::INFO_HASH_LEN + bt::PEER_ID_LEN } -fn parse_remote_handshake(bytes: &[u8]) -> IResult<&[u8], HandshakeMessage> { - do_parse!(bytes, - prot: call!(Protocol::from_bytes) >> - ext: call!(Extensions::from_bytes) >> - hash: call!(parse_remote_hash) >> - pid: call!(parse_remote_pid) >> - (HandshakeMessage::from_parts(prot, ext, hash, pid)) - ) +#[allow(clippy::ptr_arg)] +fn parse_remote_handshake(bytes: &Vec) -> IResult<(), HandshakeMessage> { + let res = tuple(( + Protocol::from_bytes, + Extensions::from_bytes, + parse_remote_hash, + parse_remote_pid, + ))(bytes); + + let (_, (prot, ext, hash, pid)) = res.map_err(|e: nom::Err>| e.map_input(|_| ()))?; + + Ok(((), HandshakeMessage::from_parts(prot, ext, hash, pid))) } fn parse_remote_hash(bytes: &[u8]) -> IResult<&[u8], InfoHash> { - do_parse!(bytes, - hash: take!(bt::INFO_HASH_LEN) >> - (InfoHash::from_hash(hash).unwrap()) - ) + map_res(take(bt::INFO_HASH_LEN), |hash: &[u8]| { + InfoHash::from_hash(hash).map_err(|_| nom::Err::Error((bytes, nom::error::ErrorKind::LengthValue))) + })(bytes) } fn parse_remote_pid(bytes: &[u8]) -> IResult<&[u8], PeerId> { - do_parse!(bytes, - pid: take!(bt::PEER_ID_LEN) >> - (PeerId::from_hash(pid).unwrap()) - ) + map_res(take(bt::PEER_ID_LEN), |pid: &[u8]| { + PeerId::from_hash(pid).map_err(|_| nom::Err::Error((bytes, nom::error::ErrorKind::LengthValue))) + })(bytes) } #[cfg(test)] mod tests { - use std::io::Write; + use std::io::Write as _; use util::bt::{self, InfoHash, PeerId}; @@ -118,8 +135,8 @@ mod tests { let exp_message = HandshakeMessage::from_parts(exp_protocol.clone(), exp_extensions, exp_hash, exp_pid); - exp_protocol.write_bytes(&mut buffer).unwrap(); - exp_extensions.write_bytes(&mut buffer).unwrap(); + exp_protocol.write_bytes_sync(&mut buffer).unwrap(); + exp_extensions.write_bytes_sync(&mut buffer).unwrap(); buffer.write_all(exp_hash.as_ref()).unwrap(); buffer.write_all(exp_pid.as_ref()).unwrap(); @@ -139,8 +156,8 @@ mod tests { let exp_message = HandshakeMessage::from_parts(exp_protocol.clone(), exp_extensions, exp_hash, exp_pid); - exp_protocol.write_bytes(&mut buffer).unwrap(); - exp_extensions.write_bytes(&mut buffer).unwrap(); + exp_protocol.write_bytes_sync(&mut buffer).unwrap(); + exp_extensions.write_bytes_sync(&mut buffer).unwrap(); buffer.write_all(exp_hash.as_ref()).unwrap(); buffer.write_all(exp_pid.as_ref()).unwrap(); @@ -160,8 +177,8 @@ mod tests { let exp_message = HandshakeMessage::from_parts(exp_protocol.clone(), exp_extensions, exp_hash, exp_pid); - exp_protocol.write_bytes(&mut buffer).unwrap(); - exp_extensions.write_bytes(&mut buffer).unwrap(); + exp_protocol.write_bytes_sync(&mut buffer).unwrap(); + exp_extensions.write_bytes_sync(&mut buffer).unwrap(); buffer.write_all(exp_hash.as_ref()).unwrap(); 
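// NOTE: the parsers above use nom 7 function combinators in place of the old
// `do_parse!`/`take!` macros. The fixed-width-field pattern in isolation (this helper is
// illustrative only, not part of the patch):
fn take_20_bytes(input: &[u8]) -> nom::IResult<&[u8], [u8; 20]> {
    use nom::bytes::complete::take;
    use nom::combinator::map_res;
    // `take(20)` slices exactly 20 bytes; `map_res` converts them and turns a failed
    // conversion into a parse error instead of a panic.
    map_res(take(20usize), |bytes: &[u8]| <[u8; 20]>::try_from(bytes))(input)
}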
buffer.write_all(exp_pid.as_ref()).unwrap(); diff --git a/packages/handshake/src/filter/mod.rs b/packages/handshake/src/filter/mod.rs index 09f3a2a02..323e75e0f 100644 --- a/packages/handshake/src/filter/mod.rs +++ b/packages/handshake/src/filter/mod.rs @@ -1,5 +1,4 @@ use std::any::Any; -use std::cmp::{Eq, PartialEq}; use std::net::SocketAddr; use util::bt::{InfoHash, PeerId}; diff --git a/packages/handshake/src/handshake/builder.rs b/packages/handshake/src/handshake/builder.rs index a0a9e05e7..bd53929ee 100644 --- a/packages/handshake/src/handshake/builder.rs +++ b/packages/handshake/src/handshake/builder.rs @@ -1,17 +1,15 @@ -//! Build configuration for `Handshaker` object creation. - use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use rand::Rng as _; -use tokio_core::reactor::Handle; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::task::JoinSet; use util::bt::PeerId; use util::convert; use super::Handshaker; -use crate::handshake::config::HandshakerConfig; -use crate::message::extensions::Extensions; -use crate::transport::Transport; +use crate::{Extensions, HandshakerConfig, Transport}; +/// Build configuration for `Handshaker` object creation. #[allow(clippy::module_name_repetitions)] #[derive(Copy, Clone)] pub struct HandshakerBuilder { @@ -100,10 +98,12 @@ impl HandshakerBuilder { /// # Errors /// /// Returns a IO error if unable to build. - pub fn build(&self, transport: T, handle: &Handle) -> std::io::Result> + pub async fn build(&self, transport: T) -> std::io::Result<(Handshaker, JoinSet<()>)> where - T: Transport + 'static, + T: Transport + Send + Sync + 'static, + ::Socket: AsyncWrite + AsyncRead + std::fmt::Debug + Send + Sync, + ::Listener: Send, { - Handshaker::with_builder(self, transport, handle) + Handshaker::with_builder(self, transport).await } } diff --git a/packages/handshake/src/handshake/config.rs b/packages/handshake/src/handshake/config.rs index a09a783ae..b3e011e7c 100644 --- a/packages/handshake/src/handshake/config.rs +++ b/packages/handshake/src/handshake/config.rs @@ -1,4 +1,3 @@ -use std::default::Default; use std::time::Duration; const DEFAULT_HANDSHAKE_BUFFER_SIZE: usize = 1000; diff --git a/packages/handshake/src/handshake/handler/handshaker.rs b/packages/handshake/src/handshake/handler/handshaker.rs index 1e899ee2c..9b4c3f28e 100644 --- a/packages/handshake/src/handshake/handler/handshaker.rs +++ b/packages/handshake/src/handshake/handler/handshaker.rs @@ -1,164 +1,149 @@ use std::net::SocketAddr; +use std::time::Duration; -use futures::future::Future; -use futures::sink::Sink; -use futures::stream::Stream; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::future::BoxFuture; +use futures::{FutureExt as _, SinkExt as _, StreamExt as _}; +use tokio::io::{AsyncRead, AsyncWrite}; use util::bt::PeerId; use crate::bittorrent::framed::FramedHandshake; use crate::bittorrent::message::HandshakeMessage; use crate::filter::filters::Filters; use crate::handshake::handler; -use crate::handshake::handler::timer::HandshakeTimer; use crate::handshake::handler::HandshakeType; use crate::message::complete::CompleteMessage; use crate::message::extensions::Extensions; use crate::message::initiate::InitiateMessage; -pub fn execute_handshake( - item: HandshakeType, - context: &(Extensions, PeerId, Filters, HandshakeTimer), -) -> Box>, Error = ()>> +#[allow(clippy::module_name_repetitions)] +pub fn execute_handshake<'a, S>( + item: std::io::Result>, + context: &(Extensions, PeerId, Filters, Duration), +) -> BoxFuture<'a, std::io::Result>>> where - S: AsyncRead + 
AsyncWrite + 'static, + S: AsyncWrite + AsyncRead + std::fmt::Debug + Send + Unpin + 'a, { - let (ext, pid, filters, timer) = context; + let (ext, pid, filters, timeout) = context; match item { - HandshakeType::Initiate(sock, init_msg) => initiate_handshake(sock, init_msg, *ext, *pid, filters.clone(), timer.clone()), - HandshakeType::Complete(sock, addr) => complete_handshake(sock, addr, *ext, *pid, filters.clone(), timer.clone()), + Ok(HandshakeType::Initiate(sock, init_msg)) => { + initiate_handshake(sock, init_msg, *ext, *pid, filters.clone(), *timeout).boxed() + } + Ok(HandshakeType::Complete(sock, addr)) => complete_handshake(sock, addr, *ext, *pid, filters.clone(), *timeout).boxed(), + Err(err) => async move { Err(err) }.boxed(), } } -fn initiate_handshake( +async fn initiate_handshake( sock: S, init_msg: InitiateMessage, ext: Extensions, pid: PeerId, filters: Filters, - timer: HandshakeTimer, -) -> Box>, Error = ()>> + timeout: Duration, +) -> std::io::Result>> where - S: AsyncRead + AsyncWrite + 'static, + S: AsyncWrite + AsyncRead + std::fmt::Debug + Unpin, { - let framed = FramedHandshake::new(sock); + let mut framed = FramedHandshake::new(sock); let (prot, hash, addr) = init_msg.into_parts(); let handshake_msg = HandshakeMessage::from_parts(prot.clone(), ext, hash, pid); - let composed_future = timer - .timeout(framed.send(handshake_msg).map_err(|_| ())) - .and_then(move |framed| { - timer - .timeout( - framed - .into_future() - .map_err(|_| ()) - .and_then(|(opt_msg, framed)| opt_msg.ok_or(()).map(|msg| (msg, framed))), - ) - .and_then(move |(msg, framed)| { - let (remote_prot, remote_ext, remote_hash, remote_pid) = msg.into_parts(); - let socket = framed.into_inner(); - - // Check that it responds with the same hash and protocol, also check our filters - if remote_hash != hash - || remote_prot != prot - || handler::should_filter( - Some(&addr), - Some(&remote_prot), - Some(&remote_ext), - Some(&remote_hash), - Some(&remote_pid), - &filters, - ) - { - Err(()) - } else { - Ok(Some(CompleteMessage::new( - prot, - ext.union(&remote_ext), - hash, - remote_pid, - addr, - socket, - ))) - } - }) - }) - .or_else(|()| Ok(None)); - - Box::new(composed_future) + let send_result = tokio::time::timeout(timeout, framed.send(handshake_msg)).await; + if send_result.is_err() { + return Ok(None); + } + + let recv_result = tokio::time::timeout(timeout, framed.next()).await; + let Ok(Some(Ok(msg))) = recv_result else { return Ok(None) }; + + let (remote_prot, remote_ext, remote_hash, remote_pid) = msg.into_parts(); + let socket = framed.into_inner(); + + if remote_hash != hash { + Err(std::io::Error::new(std::io::ErrorKind::Other, "not matching hash")) + } else if remote_prot != prot { + Err(std::io::Error::new(std::io::ErrorKind::Other, "not matching port")) + } else if handler::should_filter( + Some(&addr), + Some(&remote_prot), + Some(&remote_ext), + Some(&remote_hash), + Some(&remote_pid), + &filters, + ) { + Err(std::io::Error::new(std::io::ErrorKind::Other, "should not filter")) + } else { + Ok(Some(CompleteMessage::new( + prot, + ext.union(&remote_ext), + hash, + remote_pid, + addr, + socket, + ))) + } } -fn complete_handshake( +async fn complete_handshake( sock: S, addr: SocketAddr, ext: Extensions, pid: PeerId, filters: Filters, - timer: HandshakeTimer, -) -> Box>, Error = ()>> + timeout: Duration, +) -> std::io::Result>> where - S: AsyncRead + AsyncWrite + 'static, + S: AsyncWrite + AsyncRead + std::fmt::Debug + Unpin, { - let framed = FramedHandshake::new(sock); - - let composed_future 
= timer - .timeout( - framed - .into_future() - .map_err(|_| ()) - .and_then(|(opt_msg, framed)| opt_msg.ok_or(()).map(|msg| (msg, framed))), - ) - .and_then(move |(msg, framed)| { - let (remote_prot, remote_ext, remote_hash, remote_pid) = msg.into_parts(); - - // Check our filters - if handler::should_filter( - Some(&addr), - Some(&remote_prot), - Some(&remote_ext), - Some(&remote_hash), - Some(&remote_pid), - &filters, - ) { - Err(()) - } else { - let handshake_msg = HandshakeMessage::from_parts(remote_prot.clone(), ext, remote_hash, pid); - - Ok(timer.timeout(framed.send(handshake_msg).map_err(|_| ()).map(move |framed| { - let socket = framed.into_inner(); - - Some(CompleteMessage::new( - remote_prot, - ext.union(&remote_ext), - remote_hash, - remote_pid, - addr, - socket, - )) - }))) - } - }) - .flatten() - .or_else(|()| Ok(None)); - - Box::new(composed_future) + let mut framed = FramedHandshake::new(sock); + + let recv_result = tokio::time::timeout(timeout, framed.next()).await; + let Ok(Some(Ok(msg))) = recv_result else { return Ok(None) }; + + let (remote_prot, remote_ext, remote_hash, remote_pid) = msg.into_parts(); + + if handler::should_filter( + Some(&addr), + Some(&remote_prot), + Some(&remote_ext), + Some(&remote_hash), + Some(&remote_pid), + &filters, + ) { + Err(std::io::Error::new(std::io::ErrorKind::Other, "should not filter")) + } else { + let handshake_msg = HandshakeMessage::from_parts(remote_prot.clone(), ext, remote_hash, pid); + + let send_result = tokio::time::timeout(timeout, framed.send(handshake_msg)).await; + if send_result.is_err() { + return Ok(None); + } + + let socket = framed.into_inner(); + + Ok(Some(CompleteMessage::new( + remote_prot, + ext.union(&remote_ext), + remote_hash, + remote_pid, + addr, + socket, + ))) + } } #[cfg(test)] mod tests { - use std::io::Cursor; + use std::time::Duration; - use futures::future::{self, Future}; - use tokio_timer; use util::bt::{self, InfoHash, PeerId}; use super::HandshakeMessage; use crate::filter::filters::Filters; - use crate::handshake::handler::timer::HandshakeTimer; + use crate::handshake::handler::handshaker; use crate::message::extensions::{self, Extensions}; use crate::message::initiate::InitiateMessage; use crate::message::protocol::Protocol; @@ -179,25 +164,18 @@ mod tests { [255u8; extensions::NUM_EXTENSION_BYTES].into() } - fn any_handshake_timer() -> HandshakeTimer { - HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(100)) - } - - #[test] - fn positive_initiate_handshake() { + #[tokio::test] + async fn positive_initiate_handshake() { let remote_pid = any_peer_id(); let remote_addr = "1.2.3.4:5".parse().unwrap(); let remote_protocol = Protocol::BitTorrent; let remote_hash = any_info_hash(); let remote_message = HandshakeMessage::from_parts(remote_protocol, any_extensions(), remote_hash, remote_pid); - // Setup our buffer so that the first portion is zeroed out (so our function can write to it), and the second half is our - // serialized message (so our function can read from it). 
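// NOTE: the two halves above are complementary -- one side initiates, the other
// completes -- so they can be exercised back to back over an in-memory pipe. A hedged
// sketch of such a check (this test is not in the patch; it assumes `tokio::io::duplex`
// as the transport and reuses the tests module's existing imports):
#[tokio::test]
async fn loopback_handshake_sketch() {
    use std::time::Duration;
    let (a, b) = tokio::io::duplex(128);
    let hash: InfoHash = [7u8; bt::INFO_HASH_LEN].into();
    let ext: Extensions = [0u8; extensions::NUM_EXTENSION_BYTES].into();
    let init = InitiateMessage::new(Protocol::BitTorrent, hash, "1.2.3.4:5".parse().unwrap());

    let initiate = super::initiate_handshake(a, init, ext, [1u8; bt::PEER_ID_LEN].into(), Filters::new(), Duration::from_secs(1));
    let complete = super::complete_handshake(b, "5.6.7.8:9".parse().unwrap(), ext, [2u8; bt::PEER_ID_LEN].into(), Filters::new(), Duration::from_secs(1));

    // drive both ends concurrently; each should come back with a CompleteMessage
    let (initiated, completed) = tokio::join!(initiate, complete);
    assert!(initiated.unwrap().is_some());
    assert!(completed.unwrap().is_some());
}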
- let mut writer = Cursor::new(vec![0u8; remote_message.write_len() * 2]); + let mut writer = std::io::Cursor::new(Vec::with_capacity(remote_message.write_len() * 2)); writer.set_position(remote_message.write_len() as u64); - // Write out message to the second half of the buffer - remote_message.write_bytes(&mut writer).unwrap(); + remote_message.write_bytes(&mut writer).await.unwrap(); writer.set_position(0); let init_hash = any_info_hash(); @@ -207,14 +185,18 @@ mod tests { let init_ext = any_extensions(); let init_pid = any_other_peer_id(); let init_filters = Filters::new(); - let init_timer = any_handshake_timer(); - // Wrap in lazy since we can call wait on non sized types... - let complete_message = - future::lazy(|| super::initiate_handshake(writer, init_message, init_ext, init_pid, init_filters, init_timer)) - .wait() - .unwrap() - .unwrap(); + let complete_message = handshaker::initiate_handshake( + writer, + init_message, + init_ext, + init_pid, + init_filters, + Duration::from_millis(100), + ) + .await + .unwrap() + .unwrap(); assert_eq!(init_prot, *complete_message.protocol()); assert_eq!(init_ext, *complete_message.extensions()); @@ -222,46 +204,49 @@ mod tests { assert_eq!(remote_pid, *complete_message.peer_id()); assert_eq!(remote_addr, *complete_message.address()); - let sent_message = HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[..remote_message.write_len()]) - .unwrap() - .1; + let sent_message = + HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[..remote_message.write_len()].to_vec()) + .unwrap() + .1; let local_message = HandshakeMessage::from_parts(init_prot, init_ext, init_hash, init_pid); - let recv_message = HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[remote_message.write_len()..]) - .unwrap() - .1; + let recv_message = + HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[remote_message.write_len()..].to_vec()) + .unwrap() + .1; assert_eq!(local_message, sent_message); assert_eq!(remote_message, recv_message); } - #[test] - fn positive_complete_handshake() { + #[tokio::test] + async fn positive_complete_handshake() { let remote_pid = any_peer_id(); let remote_addr = "1.2.3.4:5".parse().unwrap(); let remote_protocol = Protocol::BitTorrent; let remote_hash = any_info_hash(); let remote_message = HandshakeMessage::from_parts(Protocol::BitTorrent, any_extensions(), remote_hash, remote_pid); - // Setup our buffer so that the second portion is zeroed out (so our function can write to it), and the first half is our - // serialized message (so our function can read from it). - let mut writer = Cursor::new(vec![0u8; remote_message.write_len() * 2]); + let mut writer = std::io::Cursor::new(vec![0u8; remote_message.write_len() * 2]); - // Write out message to the first half of the buffer - remote_message.write_bytes(&mut writer).unwrap(); + remote_message.write_bytes(&mut writer).await.unwrap(); writer.set_position(0); let comp_ext = any_extensions(); let comp_pid = any_other_peer_id(); let comp_filters = Filters::new(); - let comp_timer = any_handshake_timer(); - // Wrap in lazy since we can call wait on non sized types... 
- let complete_message = - future::lazy(|| super::complete_handshake(writer, remote_addr, comp_ext, comp_pid, comp_filters, comp_timer)) - .wait() - .unwrap() - .unwrap(); + let complete_message = handshaker::complete_handshake( + writer, + remote_addr, + comp_ext, + comp_pid, + comp_filters, + Duration::from_millis(100), + ) + .await + .unwrap() + .unwrap(); assert_eq!(remote_protocol, *complete_message.protocol()); assert_eq!(comp_ext, *complete_message.extensions()); @@ -269,14 +254,16 @@ mod tests { assert_eq!(remote_pid, *complete_message.peer_id()); assert_eq!(remote_addr, *complete_message.address()); - let sent_message = HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[remote_message.write_len()..]) - .unwrap() - .1; + let sent_message = + HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[remote_message.write_len()..].to_vec()) + .unwrap() + .1; let local_message = HandshakeMessage::from_parts(remote_protocol, comp_ext, remote_hash, comp_pid); - let recv_message = HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[..remote_message.write_len()]) - .unwrap() - .1; + let recv_message = + HandshakeMessage::from_bytes(&complete_message.socket().get_ref()[..remote_message.write_len()].to_vec()) + .unwrap() + .1; assert_eq!(local_message, sent_message); assert_eq!(remote_message, recv_message); diff --git a/packages/handshake/src/handshake/handler/initiator.rs b/packages/handshake/src/handshake/handler/initiator.rs index 7dafbcad1..98430ad69 100644 --- a/packages/handshake/src/handshake/handler/initiator.rs +++ b/packages/handshake/src/handshake/handler/initiator.rs @@ -1,23 +1,28 @@ -use futures::future::{self, Future}; -use tokio_core::reactor::Handle; +/// Handle the initiation of connections, which are returned as a `HandshakeType`. +#[allow(clippy::module_name_repetitions)] +use std::time::Duration; + +use futures::future::{self, BoxFuture}; +use futures::{FutureExt, TryFutureExt as _}; use crate::filter::filters::Filters; use crate::handshake::handler; -use crate::handshake::handler::timer::HandshakeTimer; use crate::handshake::handler::HandshakeType; use crate::message::initiate::InitiateMessage; use crate::transport::Transport; /// Handle the initiation of connections, which are returned as a `HandshakeType`. 
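// NOTE: `initiator_handler` below leans on the reworked `Transport` trait (changed in
// another part of this patch and not visible here). The assumed shape is roughly that
// `connect` takes the target address plus a timeout instead of a reactor `Handle`, and
// returns a sendable future:
//
//     fn connect(&self, addr: SocketAddr, timeout: Duration)
//         -> impl Future<Output = std::io::Result<Self::Socket>> + Send;
//
// which is what lets the handler `.map_ok(..)` the socket straight into a
// `HandshakeType::Initiate` and box the whole thing.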
#[allow(clippy::module_name_repetitions)] -pub fn initiator_handler( +pub fn initiator_handler<'a, 'b, T>( item: InitiateMessage, - context: &(T, Filters, Handle, HandshakeTimer), -) -> Box>, Error = ()>> + context: &'b (T, Filters, Duration), +) -> BoxFuture<'a, std::io::Result>>> where - T: Transport, + T: Transport + Send + Sync + 'a, + ::Socket: Send + Sync, { - let (transport, filters, handle, timer) = context; + let (transport, filters, timeout) = context; + let timeout = *timeout; if handler::should_filter( Some(item.address()), @@ -27,18 +32,12 @@ where None, filters, ) { - Box::new(future::ok(None)) + future::ok(None).boxed() } else { - let res_connect = transport - .connect(item.address(), handle) - .map(|connect| timer.timeout(connect)); - - Box::new( - future::lazy(|| res_connect) - .flatten() - .map(|socket| Some(HandshakeType::Initiate(socket, item))) - .or_else(|_| Ok(None)), - ) + transport + .connect(*item.address(), timeout) + .map_ok(|s| Some(HandshakeType::Initiate(s, item))) + .boxed() } } @@ -46,14 +45,10 @@ where mod tests { use std::time::Duration; - use futures::Future; - use tokio_core::reactor::Core; - use tokio_timer; use util::bt::{self, InfoHash, PeerId}; use crate::filter::filters::test_filters::{BlockAddrFilter, BlockPeerIdFilter, BlockProtocolFilter}; use crate::filter::filters::Filters; - use crate::handshake::handler::timer::HandshakeTimer; use crate::handshake::handler::HandshakeType; use crate::message::initiate::InitiateMessage; use crate::message::protocol::Protocol; @@ -67,16 +62,16 @@ mod tests { [55u8; bt::INFO_HASH_LEN].into() } - #[test] - fn positive_empty_filter() { - let core = Core::new().unwrap(); + #[tokio::test] + async fn positive_empty_filter() { let exp_message = InitiateMessage::new(Protocol::BitTorrent, any_info_hash(), "1.2.3.4:5".parse().unwrap()); - let timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(1000)); - let recv_enum_item = - super::initiator_handler(exp_message.clone(), &(MockTransport, Filters::new(), core.handle(), timer)) - .wait() - .unwrap(); + let recv_enum_item = super::initiator_handler( + exp_message.clone(), + &(MockTransport, Filters::new(), Duration::from_millis(1000)), + ) + .await + .unwrap(); let recv_item = match recv_enum_item { Some(HandshakeType::Initiate(_, msg)) => msg, Some(HandshakeType::Complete(_, _)) | None => panic!("Expected HandshakeType::Initiate"), @@ -85,19 +80,17 @@ mod tests { assert_eq!(exp_message, recv_item); } - #[test] - fn positive_passes_filter() { - let core = Core::new().unwrap(); - let timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(1000)); - + #[tokio::test] + async fn positive_passes_filter() { let filters = Filters::new(); filters.add_filter(BlockAddrFilter::new("2.3.4.5:6".parse().unwrap())); let exp_message = InitiateMessage::new(Protocol::BitTorrent, any_info_hash(), "1.2.3.4:5".parse().unwrap()); - let recv_enum_item = super::initiator_handler(exp_message.clone(), &(MockTransport, filters, core.handle(), timer)) - .wait() - .unwrap(); + let recv_enum_item = + super::initiator_handler(exp_message.clone(), &(MockTransport, filters, Duration::from_millis(1000))) + .await + .unwrap(); let recv_item = match recv_enum_item { Some(HandshakeType::Initiate(_, msg)) => msg, Some(HandshakeType::Complete(_, _)) | None => panic!("Expected HandshakeType::Initiate"), @@ -106,19 +99,17 @@ mod tests { assert_eq!(exp_message, recv_item); } - #[test] - fn positive_needs_data_filter() { - let core = Core::new().unwrap(); - let 
timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(1000)); - + #[tokio::test] + async fn positive_needs_data_filter() { let filters = Filters::new(); filters.add_filter(BlockPeerIdFilter::new(any_peer_id())); let exp_message = InitiateMessage::new(Protocol::BitTorrent, any_info_hash(), "1.2.3.4:5".parse().unwrap()); - let recv_enum_item = super::initiator_handler(exp_message.clone(), &(MockTransport, filters, core.handle(), timer)) - .wait() - .unwrap(); + let recv_enum_item = + super::initiator_handler(exp_message.clone(), &(MockTransport, filters, Duration::from_millis(1000))) + .await + .unwrap(); let recv_item = match recv_enum_item { Some(HandshakeType::Initiate(_, msg)) => msg, Some(HandshakeType::Complete(_, _)) | None => panic!("Expected HandshakeType::Initiate"), @@ -127,11 +118,8 @@ mod tests { assert_eq!(exp_message, recv_item); } - #[test] - fn positive_fails_filter() { - let core = Core::new().unwrap(); - let timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(1000)); - + #[tokio::test] + async fn positive_fails_filter() { let filters = Filters::new(); filters.add_filter(BlockProtocolFilter::new(Protocol::Custom(vec![1, 2, 3, 4]))); @@ -141,9 +129,10 @@ mod tests { "1.2.3.4:5".parse().unwrap(), ); - let recv_enum_item = super::initiator_handler(exp_message.clone(), &(MockTransport, filters, core.handle(), timer)) - .wait() - .unwrap(); + let recv_enum_item = + super::initiator_handler(exp_message.clone(), &(MockTransport, filters, Duration::from_millis(1000))) + .await + .unwrap(); match recv_enum_item { None => (), Some(HandshakeType::Initiate(_, _) | HandshakeType::Complete(_, _)) => panic!("Expected No Handshake"), diff --git a/packages/handshake/src/handshake/handler/listener.rs b/packages/handshake/src/handshake/handler/listener.rs index d688a141c..44b9ccfae 100644 --- a/packages/handshake/src/handshake/handler/listener.rs +++ b/packages/handshake/src/handshake/handler/listener.rs @@ -1,7 +1,9 @@ +use std::cell::Cell; use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; use futures::future::Future; -use futures::{Async, Poll}; use crate::filter::filters::Filters; use crate::handshake::handler; @@ -9,12 +11,15 @@ use crate::handshake::handler::HandshakeType; #[allow(clippy::module_name_repetitions)] pub struct ListenerHandler { - opt_item: Option>, + opt_item: Cell>>>, } impl ListenerHandler { - pub fn new(item: (S, SocketAddr), context: &Filters) -> ListenerHandler { - let (sock, addr) = item; + pub fn new(item: std::io::Result<(S, SocketAddr)>, context: &Filters) -> ListenerHandler { + let (sock, addr) = match item { + Ok(item) => item, + Err(e) => return ListenerHandler { opt_item: Err(e).into() }, + }; let opt_item = if handler::should_filter(Some(&addr), None, None, None, None, context) { None @@ -22,22 +27,22 @@ impl ListenerHandler { Some(HandshakeType::Complete(sock, addr)) }; - ListenerHandler { opt_item } + ListenerHandler { + opt_item: Ok(opt_item).into(), + } } } impl Future for ListenerHandler { - type Item = Option>; - type Error = (); + type Output = std::io::Result>>; - fn poll(&mut self) -> Poll>, ()> { - Ok(Async::Ready(self.opt_item.take())) + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + Poll::Ready(self.opt_item.replace(Ok(None))) } } #[cfg(test)] mod tests { - use futures::Future; use super::ListenerHandler; use crate::filter::filters::test_filters::{BlockAddrFilter, BlockProtocolFilter}; @@ -45,66 +50,66 @@ mod tests { use 
crate::handshake::handler::HandshakeType; use crate::message::protocol::Protocol; - #[test] - fn positive_empty_filter() { + #[tokio::test] + async fn positive_empty_filter() { let exp_item = ("Testing", "0.0.0.0:0".parse().unwrap()); - let handler = ListenerHandler::new(exp_item, &Filters::new()); + let handler = ListenerHandler::new(Ok(exp_item), &Filters::new()); - let recv_enum_item = handler.wait().unwrap(); + let recv_enum_item = handler.await.unwrap().unwrap(); let recv_item = match recv_enum_item { - Some(HandshakeType::Complete(sock, addr)) => (sock, addr), - Some(HandshakeType::Initiate(_, _)) | None => panic!("Expected HandshakeType::Complete"), + HandshakeType::Complete(sock, addr) => (sock, addr), + HandshakeType::Initiate(_, _) => panic!("Expected HandshakeType::Complete"), }; assert_eq!(exp_item, recv_item); } - #[test] - fn positive_passes_filter() { + #[tokio::test] + async fn positive_passes_filter() { let filters = Filters::new(); filters.add_filter(BlockAddrFilter::new("1.2.3.4:5".parse().unwrap())); let exp_item = ("Testing", "0.0.0.0:0".parse().unwrap()); - let handler = ListenerHandler::new(exp_item, &filters); + let handler = ListenerHandler::new(Ok(exp_item), &filters); - let recv_enum_item = handler.wait().unwrap(); + let recv_enum_item = handler.await.unwrap().unwrap(); let recv_item = match recv_enum_item { - Some(HandshakeType::Complete(sock, addr)) => (sock, addr), - Some(HandshakeType::Initiate(_, _)) | None => panic!("Expected HandshakeType::Complete"), + HandshakeType::Complete(sock, addr) => (sock, addr), + HandshakeType::Initiate(_, _) => panic!("Expected HandshakeType::Complete"), }; assert_eq!(exp_item, recv_item); } - #[test] - fn positive_needs_data_filter() { + #[tokio::test] + async fn positive_needs_data_filter() { let filters = Filters::new(); filters.add_filter(BlockProtocolFilter::new(Protocol::BitTorrent)); let exp_item = ("Testing", "0.0.0.0:0".parse().unwrap()); - let handler = ListenerHandler::new(exp_item, &filters); + let handler = ListenerHandler::new(Ok(exp_item), &filters); - let recv_enum_item = handler.wait().unwrap(); + let recv_enum_item = handler.await.unwrap().unwrap(); let recv_item = match recv_enum_item { - Some(HandshakeType::Complete(sock, addr)) => (sock, addr), - Some(HandshakeType::Initiate(_, _)) | None => panic!("Expected HandshakeType::Complete"), + HandshakeType::Complete(sock, addr) => (sock, addr), + HandshakeType::Initiate(_, _) => panic!("Expected HandshakeType::Complete"), }; assert_eq!(exp_item, recv_item); } - #[test] - fn positive_fails_filter() { + #[tokio::test] + async fn positive_fails_filter() { let filters = Filters::new(); filters.add_filter(BlockAddrFilter::new("0.0.0.0:0".parse().unwrap())); let exp_item = ("Testing", "0.0.0.0:0".parse().unwrap()); - let handler = ListenerHandler::new(exp_item, &filters); + let handler = ListenerHandler::new(Ok(exp_item), &filters); - let recv_enum_item = handler.wait().unwrap(); + let recv_enum_item = handler.await.unwrap(); if let Some(HandshakeType::Complete(_, _) | HandshakeType::Initiate(_, _)) = recv_enum_item { panic!("Expected No HandshakeType") diff --git a/packages/handshake/src/handshake/handler/mod.rs b/packages/handshake/src/handshake/handler/mod.rs index e40d6920f..592f77a9b 100644 --- a/packages/handshake/src/handshake/handler/mod.rs +++ b/packages/handshake/src/handshake/handler/mod.rs @@ -1,9 +1,8 @@ use std::net::SocketAddr; +use std::pin::Pin; -use futures::future::{self, Future, IntoFuture, Loop}; -use futures::sink::Sink; -use 
futures::stream::Stream; -use tokio_core::reactor::Handle; +use futures::sink::SinkExt; +use futures::stream::StreamExt; use util::bt::{InfoHash, PeerId}; use crate::filter::filters::Filters; @@ -15,67 +14,34 @@ use crate::message::protocol::Protocol; pub mod handshaker; pub mod initiator; pub mod listener; -pub mod timer; pub enum HandshakeType { Initiate(S, InitiateMessage), Complete(S, SocketAddr), } -enum LoopError { - Terminate, - Recoverable(D), -} - /// Create loop for feeding the handler with the items coming from the stream, and forwarding the result to the sink. /// /// If the stream is used up, or an error is propagated from any of the elements, the loop will terminate. #[allow(clippy::module_name_repetitions)] -pub fn loop_handler(stream: M, handler: H, sink: K, context: C, handle: &Handle) +pub async fn loop_handler(mut stream: M, mut handler: H, mut sink: K, context: Pin>) where - M: Stream + 'static, - H: FnMut(M::Item, &C) -> F + 'static, - K: Sink + 'static, - F: IntoFuture> + 'static, - R: 'static, + M: futures::Stream + Unpin, + H: for<'a> FnMut(M::Item, &'a C) -> F, + K: futures::Sink> + Unpin, + F: futures::Future>>, C: 'static, { - handle.spawn(future::loop_fn( - (stream, handler, sink, context), - |(stream, mut handler, sink, context)| { - // We will terminate the loop if, the stream gives us an error, the stream gives us None, the handler gives - // us an error, or the sink gives us an error. If the handler gives us Ok(None), we will map that to a - // recoverable error (since our Ok(Some) result would have to continue with its own future, we hijack - // the error to store an immediate value). We finally map any recoverable errors back to an Ok value - // so we can continue with the loop in that case. - stream - .into_future() - .map_err(|_| LoopError::Terminate) - .and_then(|(opt_item, stream)| opt_item.ok_or(LoopError::Terminate).map(|item| (item, stream))) - .and_then(move |(item, stream)| { - let into_future = handler(item, &context); + while let Some(item) = stream.next().await { + let Ok(maybe_result) = handler(item, &context).await else { + break; + }; - into_future - .into_future() - .map_err(|_| LoopError::Terminate) - .and_then(move |opt_result| match opt_result { - Some(result) => Ok((result, stream, handler, context, sink)), - None => Err(LoopError::Recoverable((stream, handler, context, sink))), - }) - }) - .and_then(|(result, stream, handler, context, sink)| { - sink.send(result) - .map_err(|_| LoopError::Terminate) - .map(|sink| Loop::Continue((stream, handler, sink, context))) - }) - .or_else(|loop_error| match loop_error { - LoopError::Terminate => Err(()), - LoopError::Recoverable((stream, handler, context, sink)) => { - Ok(Loop::Continue((stream, handler, sink, context))) - } - }) - }, - )); + drop(match maybe_result { + Some(result) => sink.send(Ok(result)).await, + None => continue, + }); + } } /// Computes whether or not we should filter given the parameters and filters. 
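// NOTE: `loop_handler` above is now a plain async pump from a stream, through a handler,
// into a sink. The same shape in miniature (the channel types and the name `pump` are
// illustrative only):
async fn pump<I, O>(
    mut input: futures::channel::mpsc::Receiver<I>,
    mut output: futures::channel::mpsc::Sender<O>,
    mut step: impl FnMut(I) -> Option<O>,
) {
    use futures::{SinkExt as _, StreamExt as _};
    while let Some(item) = input.next().await {
        if let Some(out) = step(item) {
            // the handler above ignores a failed send and keeps looping; stopping here is
            // the other reasonable choice once the receiving side has gone away
            if output.send(out).await.is_err() {
                break;
            }
        }
    }
}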
diff --git a/packages/handshake/src/handshake/handler/timer.rs b/packages/handshake/src/handshake/handler/timer.rs deleted file mode 100644 index da2b15c0b..000000000 --- a/packages/handshake/src/handshake/handler/timer.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::time::Duration; - -use futures::Future; -use tokio_timer::{Timeout, TimeoutError, Timer}; - -#[allow(clippy::module_name_repetitions)] -#[derive(Clone)] -pub struct HandshakeTimer { - timer: Timer, - duration: Duration, -} - -impl HandshakeTimer { - pub fn new(timer: Timer, duration: Duration) -> HandshakeTimer { - HandshakeTimer { timer, duration } - } - - pub fn timeout(&self, future: F) -> Timeout - where - F: Future, - E: From>, - { - self.timer.timeout(future, self.duration) - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use futures::future::{self, Future}; - use tokio_timer; - - use super::HandshakeTimer; - - #[test] - fn positive_finish_before_timeout() { - let timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(50)); - let result = timer.timeout(future::ok::<&'static str, ()>("Hello")).wait().unwrap(); - - assert_eq!("Hello", result); - } - - #[test] - #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: ()")] - fn negative_finish_after_timeout() { - let timer = HandshakeTimer::new(tokio_timer::wheel().build(), Duration::from_millis(50)); - - timer.timeout(future::empty::<(), ()>()).wait().unwrap(); - } -} diff --git a/packages/handshake/src/handshake/mod.rs b/packages/handshake/src/handshake/mod.rs index a89cefab0..f1dd8d219 100644 --- a/packages/handshake/src/handshake/mod.rs +++ b/packages/handshake/src/handshake/mod.rs @@ -1,28 +1,19 @@ -use std::time::Duration; -use std::{cmp, io}; +use std::task::{Context, Poll}; use builder::HandshakerBuilder; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::sync::mpsc::{self, SendError}; -use futures::{Poll, StartSend}; +use futures::channel::mpsc; +use futures::{Sink, SinkExt as _, Stream, StreamExt as _}; +use handler::listener::ListenerHandler; +use handler::{handshaker, initiator}; use sink::HandshakerSink; use stream::HandshakerStream; -use tokio_core::reactor::Handle; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_timer::{self}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::task::JoinSet; use util::bt::PeerId; -use crate::discovery::DiscoveryInfo; use crate::filter::filters::Filters; -use crate::filter::{HandshakeFilter, HandshakeFilters}; -use crate::handshake::handler::listener::ListenerHandler; -use crate::handshake::handler::timer::HandshakeTimer; -use crate::handshake::handler::{handshaker, initiator}; -use crate::local_addr::LocalAddr; -use crate::message::complete::CompleteMessage; -use crate::message::initiate::InitiateMessage; -use crate::transport::Transport; +use crate::local_addr::LocalAddr as _; +use crate::{CompleteMessage, DiscoveryInfo, HandshakeFilter, HandshakeFilters, InitiateMessage, Transport}; pub mod builder; pub mod config; @@ -30,8 +21,6 @@ pub mod handler; pub mod sink; pub mod stream; -//----------------------------------------------------------------------------------// - /// Handshaker which is both `Stream` and `Sink`. 
pub struct Handshaker { sink: HandshakerSink, @@ -59,15 +48,39 @@ impl DiscoveryInfo for Handshaker { } } +impl HandshakeFilters for Handshaker { + fn add_filter(&self, filter: F) + where + F: HandshakeFilter + PartialEq + Eq + Send + Sync + 'static, + { + self.sink.add_filter(filter); + } + + fn remove_filter(&self, filter: F) + where + F: HandshakeFilter + PartialEq + Eq + Send + Sync + 'static, + { + self.sink.remove_filter(filter); + } + + fn clear_filters(&self) { + self.sink.clear_filters(); + } +} + impl Handshaker where - S: AsyncRead + AsyncWrite + 'static, + S: AsyncRead + AsyncWrite + std::fmt::Debug + Send + Sync + Unpin + 'static, { - fn with_builder(builder: &HandshakerBuilder, transport: T, handle: &Handle) -> io::Result> + async fn with_builder(builder: &HandshakerBuilder, transport: T) -> std::io::Result<(Handshaker, JoinSet<()>)> where - T: Transport + 'static, + T: Transport + Send + Sync + 'static, + ::Listener: Send, { - let listener = transport.listen(&builder.bind, handle)?; + let config = builder.config; + let timeout = std::cmp::max(config.handshake_timeout(), config.connect_timeout()); + + let listener = transport.listen(builder.bind, timeout).await?; // Resolve our "real" public port let open_port = if builder.port == 0 { @@ -76,89 +89,68 @@ where builder.port }; - let config = builder.config; let (addr_send, addr_recv) = mpsc::channel(config.sink_buffer_size()); let (hand_send, hand_recv) = mpsc::channel(config.wait_buffer_size()); let (sock_send, sock_recv) = mpsc::channel(config.done_buffer_size()); let filters = Filters::new(); - let (handshake_timer, initiate_timer) = configured_handshake_timers(config.handshake_timeout(), config.connect_timeout()); // Hook up our pipeline of handlers which will take some connection info, process it, and forward it - handler::loop_handler( + + let mut tasks = JoinSet::new(); + + tasks.spawn(handler::loop_handler( addr_recv, initiator::initiator_handler, hand_send.clone(), - (transport, filters.clone(), handle.clone(), initiate_timer), - handle, - ); - handler::loop_handler(listener, ListenerHandler::new, hand_send, filters.clone(), handle); - handler::loop_handler( - hand_recv.map(Result::Ok).buffer_unordered(100), + Box::pin((transport, filters.clone(), timeout)), + )); + + tasks.spawn(handler::loop_handler( + listener, + ListenerHandler::new, + hand_send, + Box::pin(filters.clone()), + )); + + tasks.spawn(handler::loop_handler( + hand_recv, handshaker::execute_handshake, sock_send, - (builder.ext, builder.pid, filters.clone(), handshake_timer), - handle, - ); + Box::pin((builder.ext, builder.pid, filters.clone(), timeout)), + )); let sink = HandshakerSink::new(addr_send, open_port, builder.pid, filters); let stream = HandshakerStream::new(sock_recv); - Ok(Handshaker { sink, stream }) + Ok((Handshaker { sink, stream }, tasks)) } } -/// Configure a timer wheel and create a `HandshakeTimer`. 
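// NOTE: wiring spawned by `with_builder` above, for reference -- three `loop_handler`
// tasks on the returned `JoinSet`:
//
//   HandshakerSink --(addr chan)--> initiator_handler ----\
//                                                          +--(hand chan)--> execute_handshake --(sock chan)--> HandshakerStream
//   Transport::listen ------------> ListenerHandler   ----/
//
// Dropping or shutting down the `JoinSet` (as the example binary does with
// `tasks.shutdown().await`) tears the whole pipeline down.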
-fn configured_handshake_timers(duration_one: Duration, duration_two: Duration) -> (HandshakeTimer, HandshakeTimer) { - let timer = tokio_timer::wheel() - .num_slots(64) - .max_timeout(cmp::max(duration_one, duration_two)) - .build(); - - ( - HandshakeTimer::new(timer.clone(), duration_one), - HandshakeTimer::new(timer, duration_two), - ) -} - -impl Sink for Handshaker { - type SinkItem = InitiateMessage; - type SinkError = SendError; +impl Sink for Handshaker { + type Error = mpsc::SendError; - fn start_send(&mut self, item: InitiateMessage) -> StartSend> { - self.sink.start_send(item) + fn poll_ready(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink.poll_ready_unpin(cx) } - fn poll_complete(&mut self) -> Poll<(), SendError> { - self.sink.poll_complete() + fn start_send(mut self: std::pin::Pin<&mut Self>, item: InitiateMessage) -> Result<(), Self::Error> { + self.sink.start_send_unpin(item) } -} - -impl Stream for Handshaker { - type Item = CompleteMessage; - type Error = (); - fn poll(&mut self) -> Poll>, ()> { - self.stream.poll() + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink.poll_flush_unpin(cx) } -} -impl HandshakeFilters for Handshaker { - fn add_filter(&self, filter: F) - where - F: HandshakeFilter + PartialEq + Eq + Send + Sync + 'static, - { - self.sink.add_filter(filter); + fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink.poll_close_unpin(cx) } +} - fn remove_filter(&self, filter: F) - where - F: HandshakeFilter + PartialEq + Eq + Send + Sync + 'static, - { - self.sink.remove_filter(filter); - } +impl Stream for Handshaker { + type Item = std::io::Result>; - fn clear_filters(&self) { - self.sink.clear_filters(); + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.stream.poll_next_unpin(cx) } } diff --git a/packages/handshake/src/handshake/sink.rs b/packages/handshake/src/handshake/sink.rs index 191d16451..ad3b843ee 100644 --- a/packages/handshake/src/handshake/sink.rs +++ b/packages/handshake/src/handshake/sink.rs @@ -1,24 +1,27 @@ //! `Sink` portion of the `Handshaker` for initiating handshakes. 
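+//!
+//! A minimal sketch, assuming a handshaker already built with `HandshakerBuilder`
+//! (`peer_addr` and `info_hash` are placeholders):
+//!
+//! ```ignore
+//! use futures::SinkExt as _;
+//!
+//! let (mut sink, _stream) = handshaker.into_parts();
+//! sink.send(InitiateMessage::new(Protocol::BitTorrent, info_hash, peer_addr))
+//!     .await
+//!     .unwrap();
+//! ```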
+use futures::channel::mpsc; use futures::sink::Sink; -use futures::sync::mpsc::{SendError, Sender}; -use futures::{Poll, StartSend}; +use futures::task::{Context, Poll}; +use futures::SinkExt as _; use util::bt::PeerId; +use crate::discovery::DiscoveryInfo; use crate::filter::filters::Filters; -use crate::{DiscoveryInfo, HandshakeFilter, HandshakeFilters, InitiateMessage}; +use crate::filter::{HandshakeFilter, HandshakeFilters}; +use crate::message::initiate::InitiateMessage; #[allow(clippy::module_name_repetitions)] #[derive(Clone)] pub struct HandshakerSink { - send: Sender, + send: mpsc::Sender, port: u16, pid: PeerId, filters: Filters, } impl HandshakerSink { - pub(super) fn new(send: Sender, port: u16, pid: PeerId, filters: Filters) -> HandshakerSink { + pub(super) fn new(send: mpsc::Sender, port: u16, pid: PeerId, filters: Filters) -> HandshakerSink { HandshakerSink { send, port, @@ -38,16 +41,23 @@ impl DiscoveryInfo for HandshakerSink { } } -impl Sink for HandshakerSink { - type SinkItem = InitiateMessage; - type SinkError = SendError; +impl Sink for HandshakerSink { + type Error = mpsc::SendError; - fn start_send(&mut self, item: InitiateMessage) -> StartSend> { - self.send.start_send(item) + fn poll_ready(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send.poll_ready_unpin(cx) } - fn poll_complete(&mut self) -> Poll<(), SendError> { - self.send.poll_complete() + fn start_send(mut self: std::pin::Pin<&mut Self>, item: InitiateMessage) -> Result<(), Self::Error> { + self.send.start_send_unpin(item) + } + + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send.poll_flush_unpin(cx) + } + + fn poll_close(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send.poll_close_unpin(cx) } } @@ -70,3 +80,5 @@ impl HandshakeFilters for HandshakerSink { self.filters.clear_filters(); } } + +//----------------------------------------------------------------------------------// diff --git a/packages/handshake/src/handshake/stream.rs b/packages/handshake/src/handshake/stream.rs index 9b862ae69..b217cb04c 100644 --- a/packages/handshake/src/handshake/stream.rs +++ b/packages/handshake/src/handshake/stream.rs @@ -1,26 +1,26 @@ -//! `Stream` portion of the `Handshaker` for completed handshakes. +use std::task::{Context, Poll}; -use futures::sync::mpsc::Receiver; -use futures::{Poll, Stream}; +use futures::channel::mpsc; +use futures::{Stream, StreamExt as _}; use crate::CompleteMessage; +/// `Stream` portion of the `Handshaker` for completed handshakes. #[allow(clippy::module_name_repetitions)] pub struct HandshakerStream { - recv: Receiver>, + recv: mpsc::Receiver>>, } impl HandshakerStream { - pub(super) fn new(recv: Receiver>) -> HandshakerStream { + pub(super) fn new(recv: mpsc::Receiver>>) -> HandshakerStream { HandshakerStream { recv } } } impl Stream for HandshakerStream { - type Item = CompleteMessage; - type Error = (); + type Item = std::io::Result>; - fn poll(&mut self) -> Poll>, ()> { - self.recv.poll() + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.recv.poll_next_unpin(cx) } } diff --git a/packages/handshake/src/local_addr.rs b/packages/handshake/src/local_addr.rs index 1b1e90032..8190467cf 100644 --- a/packages/handshake/src/local_addr.rs +++ b/packages/handshake/src/local_addr.rs @@ -1,8 +1,5 @@ -use std::io; use std::net::SocketAddr; -use tokio_core::net::TcpStream; - /// Trait for getting the local address. 
pub trait LocalAddr { @@ -11,11 +8,5 @@ pub trait LocalAddr { /// # Errors /// /// It would return an IO Error if unable to obtain the local address. - fn local_addr(&self) -> io::Result; -} - -impl LocalAddr for TcpStream { - fn local_addr(&self) -> io::Result { - TcpStream::local_addr(self) - } + fn local_addr(&self) -> std::io::Result; } diff --git a/packages/handshake/src/message/complete.rs b/packages/handshake/src/message/complete.rs index d1cffe44e..7c019c842 100644 --- a/packages/handshake/src/message/complete.rs +++ b/packages/handshake/src/message/complete.rs @@ -7,6 +7,7 @@ use crate::message::protocol::Protocol; /// Message containing completed handshaking information. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct CompleteMessage { prot: Protocol, ext: Extensions, diff --git a/packages/handshake/src/message/extensions.rs b/packages/handshake/src/message/extensions.rs index 9e65a0971..7ca75fdd8 100644 --- a/packages/handshake/src/message/extensions.rs +++ b/packages/handshake/src/message/extensions.rs @@ -1,7 +1,6 @@ -use std::io; -use std::io::Write; - -use nom::{be_u8, call, count_fixed, do_parse, error_position, IResult}; +use nom::bytes::complete::take; +use nom::IResult; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; /// Number of bytes that the extension protocol takes. pub const NUM_EXTENSION_BYTES: usize = 8; @@ -32,7 +31,10 @@ impl Extensions { } /// Create a new `Extensions` by parsing the given bytes. - #[must_use] + /// + /// # Errors + /// + /// This function will return an error if unable to construct from bytes. pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], Extensions> { parse_extension_bits(bytes) } @@ -67,14 +69,26 @@ impl Extensions { self.bytes[byte_index] & (0x80 >> bit_index) != 0 } - /// Write the `Extensions` to the given writer. + /// Write the `Extensions` to the given async writer. /// /// # Errors /// /// It would return an IO error if unable to write bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub async fn write_bytes(&self, writer: &mut W) -> std::io::Result<()> + where + W: AsyncWrite + Unpin, + { + writer.write_all(&self.bytes[..]).await + } + + /// Write the `Extensions` to the given writer. + /// + /// # Errors + /// + /// This function will return an error if unable to write bytes. + pub fn write_bytes_sync(&self, writer: &mut W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(&self.bytes[..]) } @@ -106,12 +120,11 @@ impl From<[u8; NUM_EXTENSION_BYTES]> for Extensions { } /// Parse the given bytes for extension bits. 
-#[allow(deprecated)] fn parse_extension_bits(bytes: &[u8]) -> IResult<&[u8], Extensions> { - do_parse!(bytes, - bytes: count_fixed!(u8, be_u8, NUM_EXTENSION_BYTES) >> - (Extensions::with_bytes(bytes)) - ) + let (remaining, bytes) = take(NUM_EXTENSION_BYTES)(bytes)?; + let mut array = [0u8; NUM_EXTENSION_BYTES]; + array.copy_from_slice(bytes); + Ok((remaining, Extensions::with_bytes(array))) } #[cfg(test)] diff --git a/packages/handshake/src/message/protocol.rs b/packages/handshake/src/message/protocol.rs index 7a1480ed1..8e62197f1 100644 --- a/packages/handshake/src/message/protocol.rs +++ b/packages/handshake/src/message/protocol.rs @@ -1,7 +1,8 @@ -use std::io; -use std::io::Write; - -use nom::{be_u8, call, do_parse, error_node_position, error_position, map, switch, take, value, IResult}; +use nom::bytes::complete::take; +use nom::number::complete::u8; +use nom::sequence::tuple; +use nom::IResult; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; const BT_PROTOCOL: &[u8] = b"BitTorrent protocol"; const BT_PROTOCOL_LEN: u8 = 19; @@ -15,7 +16,10 @@ pub enum Protocol { impl Protocol { /// Create a `Protocol` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// This function will return an error if unable to construct from bytes. pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], Protocol> { parse_protocol(bytes) } @@ -25,12 +29,33 @@ impl Protocol { /// # Errors /// /// It would return an IO Error if unable to write bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub async fn write_bytes(&self, writer: &mut W) -> std::io::Result<()> + where + W: AsyncWrite + Unpin, + { + let (len, bytes) = match self { + Protocol::BitTorrent => (BT_PROTOCOL_LEN as usize, BT_PROTOCOL), + Protocol::Custom(prot) => (prot.len(), &prot[..]), + }; + + #[allow(clippy::cast_possible_truncation)] + writer.write_all(&[len as u8][..]).await?; + writer.write_all(bytes).await?; + + Ok(()) + } + + /// Write the `Extensions` to the given writer. + /// + /// # Errors + /// + /// This function will return an error if unable to write bytes. 
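+    ///
+    /// A minimal sketch (the buffer name is a placeholder):
+    ///
+    /// ```ignore
+    /// let mut buf = Vec::new();
+    /// Protocol::BitTorrent.write_bytes_sync(&mut buf).unwrap();
+    /// assert_eq!(buf[0], 19); // length prefix of "BitTorrent protocol"
+    /// ```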
+ pub fn write_bytes_sync(&self, writer: &mut W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { let (len, bytes) = match self { - &Protocol::BitTorrent => (BT_PROTOCOL_LEN as usize, BT_PROTOCOL), + Protocol::BitTorrent => (BT_PROTOCOL_LEN as usize, BT_PROTOCOL), Protocol::Custom(prot) => (prot.len(), &prot[..]), }; @@ -45,7 +70,7 @@ impl Protocol { #[must_use] pub fn write_len(&self) -> usize { match self { - &Protocol::BitTorrent => BT_PROTOCOL_LEN as usize, + Protocol::BitTorrent => BT_PROTOCOL_LEN as usize, Protocol::Custom(custom) => custom.len(), } } @@ -55,19 +80,45 @@ fn parse_protocol(bytes: &[u8]) -> IResult<&[u8], Protocol> { parse_real_protocol(bytes) } -#[allow(unreachable_patterns, unused)] fn parse_real_protocol(bytes: &[u8]) -> IResult<&[u8], Protocol> { - switch!(bytes, parse_raw_protocol, - // TODO: Move back to using constant here, for now, MIR compiler error occurs - b"BitTorrent protocol" => value!(Protocol::BitTorrent) | - custom => value!(Protocol::Custom(custom.to_vec())) - ) + let (remaining, (_length, raw_protocol)) = tuple((u8, take(bytes[0] as usize)))(bytes)?; + if raw_protocol == BT_PROTOCOL { + Ok((remaining, Protocol::BitTorrent)) + } else { + Ok((remaining, Protocol::Custom(raw_protocol.to_vec()))) + } } +#[allow(dead_code)] fn parse_raw_protocol(bytes: &[u8]) -> IResult<&[u8], &[u8]> { - do_parse!(bytes, - length: be_u8 >> - raw_protocol: take!(length) >> - (raw_protocol) - ) + let (remaining, (_length, raw_protocol)) = tuple((u8, take(bytes[0] as usize)))(bytes)?; + Ok((remaining, raw_protocol)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_bytes_bittorrent() { + let input = [ + 19, b'B', b'i', b't', b'T', b'o', b'r', b'r', b'e', b'n', b't', b' ', b'p', b'r', b'o', b't', b'o', b'c', b'o', b'l', + ]; + let (_, protocol) = Protocol::from_bytes(&input).unwrap(); + assert_eq!(protocol, Protocol::BitTorrent); + } + + #[tokio::test] + async fn test_write_bytes_bittorrent() { + let protocol = Protocol::BitTorrent; + let mut buffer = Vec::new(); + let () = protocol.write_bytes(&mut buffer).await.unwrap(); + assert_eq!( + buffer, + vec![ + 19, b'B', b'i', b't', b'T', b'o', b'r', b'r', b'e', b'n', b't', b' ', b'p', b'r', b'o', b't', b'o', b'c', b'o', + b'l' + ] + ); + } } diff --git a/packages/handshake/src/transport.rs b/packages/handshake/src/transport.rs index 9071357e4..46d7ee5c3 100644 --- a/packages/handshake/src/transport.rs +++ b/packages/handshake/src/transport.rs @@ -1,91 +1,108 @@ -use std::io; use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; -use futures::future::Future; -use futures::stream::Stream; -use futures::Poll; -use tokio_core::net::{Incoming, TcpListener, TcpStream, TcpStreamNew}; -use tokio_core::reactor::Handle; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::future::BoxFuture; +use futures::{Future, FutureExt as _, Stream, TryFutureExt as _}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, TcpStream}; use crate::local_addr::LocalAddr; /// Trait for initializing connections over an abstract `Transport`. pub trait Transport { - /// Concrete socket. - type Socket: AsyncRead + AsyncWrite + 'static; + /// The type of socket used by this transport. + type Socket: AsyncRead + AsyncWrite + Unpin + 'static; - /// Future `Self::Socket`. - type FutureSocket: Future + 'static; + /// The future that resolves to a `Socket`. + type FutureSocket: Future> + Send + 'static; - /// Concrete listener. 
- type Listener: Stream + LocalAddr + 'static; + /// The type of listener used by this transport. + type Listener: Stream> + LocalAddr + Unpin + 'static; - /// Connect to the given address over this transport, using the supplied `Handle`. + /// The future that resolves to a `Listener`. + type FutureListener: Future> + Send + 'static; + + /// Connect to the given address using this transport. /// /// # Errors /// - /// It would return an IO Error if unable to connect to socket. - fn connect(&self, addr: &SocketAddr, handle: &Handle) -> io::Result; + /// Returns an IO error if unable to connect to the socket. + fn connect(&self, addr: SocketAddr, timeout: Duration) -> Self::FutureSocket; - /// Listen to the given address for this transport, using the supplied `Handle`. + /// Listen on the given address using this transport. /// /// # Errors /// - /// It would return an IO Error if unable to listen to socket. - fn listen(&self, addr: &SocketAddr, handle: &Handle) -> io::Result; + /// Returns an IO error if unable to bind to the socket. + fn listen(&self, addr: SocketAddr, timeout: Duration) -> Self::FutureListener; } //----------------------------------------------------------------------------------// -/// Defines a `Transport` operating over TCP. +/// A `Transport` implementation for TCP. #[allow(clippy::module_name_repetitions)] pub struct TcpTransport; impl Transport for TcpTransport { type Socket = TcpStream; - type FutureSocket = TcpStreamNew; - type Listener = TcpListenerStream; + type FutureSocket = BoxFuture<'static, std::io::Result>; + type Listener = TcpListenerStream; + type FutureListener = BoxFuture<'static, std::io::Result>; + + fn connect(&self, addr: SocketAddr, timeout: Duration) -> Self::FutureSocket { + let socket = TcpStream::connect(addr); + let socket = tokio::time::timeout(timeout, socket) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e)) + .boxed(); - fn connect(&self, addr: &SocketAddr, handle: &Handle) -> io::Result { - Ok(TcpStream::connect(addr, handle)) + socket.map(|s| s.and_then(|s| s)).boxed() } - fn listen(&self, addr: &SocketAddr, handle: &Handle) -> io::Result { - let listener = TcpListener::bind(addr, handle)?; - let listen_addr = listener.local_addr()?; + fn listen(&self, addr: SocketAddr, timeout: Duration) -> Self::FutureListener { + let listener = TcpListener::bind(addr); + + let listener = tokio::time::timeout(timeout, listener) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::TimedOut, e)) + .boxed(); - Ok(TcpListenerStream::new(listen_addr, listener.incoming())) + let listener = listener.map(|l| l.and_then(|l| l)).boxed(); + + listener.map_ok(TcpListenerStream::new).boxed() } } -/// Convenient object that wraps a listener stream `L`, and also implements `LocalAddr`. -pub struct TcpListenerStream { - listen_addr: SocketAddr, - listener: L, +//----------------------------------------------------------------------------------// + +/// A custom stream for `TcpListener`. +pub struct TcpListenerStream { + listener: TcpListener, } -impl TcpListenerStream { - fn new(listen_addr: SocketAddr, listener: L) -> TcpListenerStream { - TcpListenerStream { listen_addr, listener } +impl TcpListenerStream { + /// Creates a new `TcpListenerStream` from a `TcpListener`. 
+ fn new(listener: TcpListener) -> Self { + TcpListenerStream { listener } } } -impl Stream for TcpListenerStream -where - L: Stream, -{ - type Item = L::Item; - type Error = L::Error; - - fn poll(&mut self) -> Poll, Self::Error> { - self.listener.poll() +impl LocalAddr for TcpListenerStream { + fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() } } -impl LocalAddr for TcpListenerStream { - fn local_addr(&self) -> io::Result { - Ok(self.listen_addr) +impl Stream for TcpListenerStream { + type Item = std::io::Result<(TcpStream, SocketAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let listener = &self.listener; + match listener.poll_accept(cx) { + Poll::Ready(Ok((socket, addr))) => Poll::Ready(Some(Ok((socket, addr)))), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } } @@ -93,41 +110,48 @@ impl LocalAddr for TcpListenerStream { #[cfg(test)] pub mod test_transports { - use std::io::{self, Cursor}; + use std::net::SocketAddr; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::time::Duration; - use futures::future::{self, FutureResult}; + use futures::future::{self, BoxFuture}; + use futures::io::{self}; use futures::stream::{self, Empty, Stream}; - use futures::Poll; - use tokio_core::reactor::Handle; + use futures::{FutureExt as _, StreamExt as _}; use super::Transport; - use crate::local_addr::LocalAddr; + use crate::LocalAddr; + /// A mock transport for testing purposes. pub struct MockTransport; impl Transport for MockTransport { - type Socket = Cursor>; - type FutureSocket = FutureResult; + type Socket = std::io::Cursor>; + type FutureSocket = BoxFuture<'static, std::io::Result>; type Listener = MockListener; + type FutureListener = BoxFuture<'static, std::io::Result>; - fn connect(&self, _addr: &SocketAddr, _handle: &Handle) -> io::Result { - Ok(future::ok(Cursor::new(Vec::new()))) + fn connect(&self, _addr: SocketAddr, _timeout: Duration) -> Self::FutureSocket { + future::ok(std::io::Cursor::new(Vec::new())).boxed() } - fn listen(&self, addr: &SocketAddr, _handle: &Handle) -> io::Result { - Ok(MockListener::new(*addr)) + fn listen(&self, addr: SocketAddr, _timeout: Duration) -> Self::FutureListener { + future::ok(MockListener::new(addr)).boxed() } } //----------------------------------------------------------------------------------// + /// A mock listener for testing purposes. pub struct MockListener { addr: SocketAddr, - empty: Empty<(Cursor>, SocketAddr), io::Error>, + empty: Empty>, SocketAddr)>>, } impl MockListener { + /// Creates a new `MockListener` with the given address. 
fn new(addr: SocketAddr) -> MockListener { MockListener { addr, @@ -137,17 +161,41 @@ pub mod test_transports { } impl LocalAddr for MockListener { - fn local_addr(&self) -> io::Result { + fn local_addr(&self) -> std::io::Result { Ok(self.addr) } } impl Stream for MockListener { - type Item = (Cursor>, SocketAddr); - type Error = io::Error; + type Item = std::io::Result<(std::io::Cursor>, SocketAddr)>; - fn poll(&mut self) -> Poll, Self::Error> { - self.empty.poll() + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().empty.poll_next_unpin(cx) } } + + //----------------------------------------------------------------------------------// + + #[tokio::test] + async fn test_mock_transport_connect() { + let transport = MockTransport; + let addr = "127.0.0.1:8080".parse().unwrap(); + let timeout = Duration::from_secs(1); + + let socket = transport.connect(addr, timeout).await; + assert!(socket.is_ok()); + } + + #[tokio::test] + async fn test_mock_transport_listen() { + let transport = MockTransport; + let addr = "127.0.0.1:8080".parse().unwrap(); + let timeout = Duration::from_secs(1); + + let listener = transport.listen(addr, timeout).await; + assert!(listener.is_ok()); + + let listener = listener.unwrap(); + assert_eq!(listener.local_addr().unwrap(), addr); + } } diff --git a/packages/handshake/tests/common/mod.rs b/packages/handshake/tests/common/mod.rs index b9b0494d5..c952d9cc5 100644 --- a/packages/handshake/tests/common/mod.rs +++ b/packages/handshake/tests/common/mod.rs @@ -1,8 +1,27 @@ //----------------------------------------------------------------------------------// +use std::sync::Once; + +use tracing::level_filters::LevelFilter; + +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + #[allow(dead_code)] #[derive(PartialEq, Eq, Debug)] pub enum TimeoutResult { TimedOut, GotResult, } + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} diff --git a/packages/handshake/tests/test_byte_after_handshake.rs b/packages/handshake/tests/test_byte_after_handshake.rs index 14e5018e1..d750f6f74 100644 --- a/packages/handshake/tests/test_byte_after_handshake.rs +++ b/packages/handshake/tests/test_byte_after_handshake.rs @@ -1,33 +1,35 @@ -use std::io::{Read, Write}; +use std::io::{Read as _, Write as _}; use std::net::TcpStream; -use std::thread; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::stream::StreamExt; use handshake::transports::TcpTransport; use handshake::{DiscoveryInfo, HandshakerBuilder}; -use tokio_core::reactor::Core; -use tokio_io::io; +use tokio::io::AsyncReadExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; mod common; -#[test] -fn positive_recover_bytes() { - let mut core = Core::new().unwrap(); +#[tokio::test] +async fn positive_recover_bytes() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (mut handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); 
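+    // `build()` now returns the handshaker together with a `JoinSet` of its
+    // background handler tasks; the set is shut down once the test finishes.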
handshaker_one_addr.set_port(handshaker_one.port()); - thread::spawn(move || { + tasks_one.spawn_blocking(move || { let mut stream = TcpStream::connect(handshaker_one_addr).unwrap(); let mut write_buffer = Vec::new(); @@ -43,20 +45,21 @@ fn positive_recover_bytes() { stream.read_exact(&mut vec![0u8; expect_read_length][..]).unwrap(); }); - let recv_buffer = core - .run( - handshaker_one - .into_future() - .map_err(|_| ()) - .and_then(|(opt_message, _)| { - let (_, _, _, _, _, sock) = opt_message.unwrap().into_parts(); - - io::read_exact(sock, vec![0u8; 1]).map_err(|_| ()) - }) - .and_then(|(_, buf)| Ok(buf)), - ) - .unwrap(); + let test = tokio::spawn(async move { + if let Some(message) = handshaker_one.next().await { + let (_, _, _, _, _, mut sock) = message.unwrap().into_parts(); + + let mut recv_buffer = vec![0u8; 1]; + sock.read_exact(&mut recv_buffer).await.unwrap(); + + // Assert that our buffer contains the bytes after the handshake + assert_eq!(vec![55], recv_buffer); + } else { + panic!("Failed to receive handshake message"); + } + }); - // Assert that our buffer contains the bytes after the handshake - assert_eq!(vec![55], recv_buffer); + let res = test.await; + tasks_one.shutdown().await; + res.unwrap(); } diff --git a/packages/handshake/tests/test_bytes_after_handshake.rs b/packages/handshake/tests/test_bytes_after_handshake.rs index 75782195e..82ec37833 100644 --- a/packages/handshake/tests/test_bytes_after_handshake.rs +++ b/packages/handshake/tests/test_bytes_after_handshake.rs @@ -1,34 +1,38 @@ -use std::io::{Read, Write}; -use std::net::TcpStream; -use std::thread; +use std::io::{Read as _, Write as _}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::stream::StreamExt; use handshake::transports::TcpTransport; use handshake::{DiscoveryInfo, HandshakerBuilder}; -use tokio_core::reactor::Core; -use tokio_io::io; +use tokio::io::AsyncReadExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; mod common; -#[test] -fn positive_recover_bytes() { - let mut core = Core::new().unwrap(); +#[tokio::test] +async fn positive_recover_bytes() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let mut handshaker_one_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (mut handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); handshaker_one_addr.set_port(handshaker_one.port()); - thread::spawn(move || { + tasks_one.spawn_blocking(move || { let mut stream = TcpStream::connect(handshaker_one_addr).unwrap(); + let mut write_buffer = Vec::new(); write_buffer.write_all(&[1, 1]).unwrap(); @@ -43,20 +47,23 @@ fn positive_recover_bytes() { stream.read_exact(&mut vec![0u8; expect_read_length][..]).unwrap(); }); - let recv_buffer = core - .run( - handshaker_one - .into_future() - .map_err(|_| ()) - .and_then(|(opt_message, _)| { - let (_, _, _, _, _, sock) = opt_message.unwrap().into_parts(); - - io::read_exact(sock, vec![0u8; 100]).map_err(|_| ()) - }) - .and_then(|(_, buf)| Ok(buf)), - ) - .unwrap(); + let test = tokio::spawn(async move { + if let Some(message) = handshaker_one.next().await { + let (_, _, 
_, _, _, mut sock) = message.unwrap().into_parts(); + + let mut recv_buffer = vec![0u8; 100]; + sock.read_exact(&mut recv_buffer).await.unwrap(); + + // Assert that our buffer contains the bytes after the handshake + assert_eq!(vec![55u8; 100], recv_buffer); + } else { + panic!("Failed to receive handshake message"); + } + }); + + let res = test.await; + + tasks_one.shutdown().await; - // Assert that our buffer contains the bytes after the handshake - assert_eq!(vec![55u8; 100], recv_buffer); + res.unwrap(); } diff --git a/packages/handshake/tests/test_connect.rs b/packages/handshake/tests/test_connect.rs index 43173dd41..a430289fa 100644 --- a/packages/handshake/tests/test_connect.rs +++ b/packages/handshake/tests/test_connect.rs @@ -1,61 +1,81 @@ -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::future::try_join; +use futures::sink::SinkExt; +use futures::stream::StreamExt; use handshake::transports::TcpTransport; use handshake::{DiscoveryInfo, HandshakerBuilder, InitiateMessage, Protocol}; -use tokio_core::reactor::Core; +use tokio::net::TcpStream; +use tracing::level_filters::LevelFilter; use util::bt::{self}; mod common; -#[test] -fn positive_connect() { - let mut core = Core::new().unwrap(); +#[tokio::test] +async fn positive_connect() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (mut handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_one_addr = handshaker_one_addr; handshaker_one_addr.set_port(handshaker_one.port()); - let mut handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_two_pid = [5u8; bt::PEER_ID_LEN].into(); - let handshaker_two = HandshakerBuilder::new() + let (mut handshaker_two, mut tasks_two) = HandshakerBuilder::new() .with_bind_addr(handshaker_two_addr) .with_peer_id(handshaker_two_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); - handshaker_two_addr.set_port(handshaker_two.port()); - - let (item_one, item_two) = core - .run( - handshaker_one - .send(InitiateMessage::new( - Protocol::BitTorrent, - [55u8; bt::INFO_HASH_LEN].into(), - handshaker_two_addr, - )) - .map_err(|_| ()) - .and_then(|handshaker_one| { - handshaker_one - .into_future() - .join(handshaker_two.into_future()) - .map_err(|_| ()) - }) - .map(|((opt_item_one, _), (opt_item_two, _))| (opt_item_one.unwrap(), opt_item_two.unwrap())), - ) - .unwrap(); + let test = tokio::spawn(async move { + let mut handshaker_two_addr = handshaker_two_addr; + handshaker_two_addr.set_port(handshaker_two.port()); + + // Send the initiate message first + handshaker_one + .send(InitiateMessage::new( + Protocol::BitTorrent, + [55u8; bt::INFO_HASH_LEN].into(), + handshaker_two_addr, + )) + .await + .unwrap(); + + let handshaker_one_future = async { + let message: handshake::CompleteMessage = handshaker_one.next().await.unwrap().unwrap(); + Ok::<_, ()>(message) + }; + + let handshaker_two_future = async { + let message: handshake::CompleteMessage = handshaker_two.next().await.unwrap().unwrap(); + 
Ok::<_, ()>(message) + }; + + let (item_one, item_two) = try_join(handshaker_one_future, handshaker_two_future).await.unwrap(); + + // Result from handshaker one should match handshaker two's listen address + assert_eq!(handshaker_two_addr, *item_one.address()); + + assert_eq!(handshaker_one_pid, *item_two.peer_id()); + assert_eq!(handshaker_two_pid, *item_one.peer_id()); + }); + + let res = test.await; - // Result from handshaker one should match handshaker two's listen address - assert_eq!(handshaker_two_addr, *item_one.address()); + tasks_one.shutdown().await; + tasks_two.shutdown().await; - assert_eq!(handshaker_one_pid, *item_two.peer_id()); - assert_eq!(handshaker_two_pid, *item_one.peer_id()); + res.unwrap(); } diff --git a/packages/handshake/tests/test_filter_allow_all.rs b/packages/handshake/tests/test_filter_allow_all.rs index db6cbb1a9..d67032589 100644 --- a/packages/handshake/tests/test_filter_allow_all.rs +++ b/packages/handshake/tests/test_filter_allow_all.rs @@ -2,15 +2,15 @@ use std::any::Any; use std::net::SocketAddr; use std::time::Duration; -use common::TimeoutResult; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::sink::SinkExt; +use futures::stream::{self, StreamExt}; +use futures::FutureExt as _; use handshake::transports::TcpTransport; use handshake::{ DiscoveryInfo, Extensions, FilterDecision, HandshakeFilter, HandshakeFilters, HandshakerBuilder, InitiateMessage, Protocol, }; -use tokio_core::reactor::{Core, Timeout}; +use tracing::level_filters::LevelFilter; use util::bt::{self, InfoHash, PeerId}; mod common; @@ -40,66 +40,78 @@ impl HandshakeFilter for FilterAllowAll { } } -#[test] -fn test_filter_all() { - let mut core = Core::new().unwrap(); - let handle = core.handle(); +#[tokio::test] +async fn test_filter_all() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_one_addr = handshaker_one_addr; handshaker_one_addr.set_port(handshaker_one.port()); // Filter all incoming handshake requests handshaker_one.add_filter(FilterAllowAll); - let mut handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_two_pid = [5u8; bt::PEER_ID_LEN].into(); - let handshaker_two = HandshakerBuilder::new() + let (handshaker_two, mut tasks_two) = HandshakerBuilder::new() .with_bind_addr(handshaker_two_addr) .with_peer_id(handshaker_two_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_two_addr = handshaker_two_addr; handshaker_two_addr.set_port(handshaker_two.port()); let (_, stream_one) = handshaker_one.into_parts(); - let (sink_two, stream_two) = handshaker_two.into_parts(); - - let timeout_result = core - .run( - sink_two - .send(InitiateMessage::new( - Protocol::BitTorrent, - [55u8; bt::INFO_HASH_LEN].into(), - handshaker_one_addr, - )) - .map_err(|_| ()) - .and_then(|_| { - let timeout = Timeout::new(Duration::from_millis(50), &handle) - .unwrap() - .map(|()| TimeoutResult::TimedOut) - 
.map_err(|_| ()); - - let result_one = stream_one.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - let result_two = stream_two.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - - result_one - .select(result_two) - .map(|_| TimeoutResult::GotResult) - .map_err(|_| ()) - .select(timeout) - .map(|(item, _)| item) - .map_err(|_| ()) - }), - ) - .unwrap(); - - assert_eq!(TimeoutResult::GotResult, timeout_result); + let (mut sink_two, stream_two) = handshaker_two.into_parts(); + + let test = tokio::spawn(async move { + sink_two + .send(InitiateMessage::new( + Protocol::BitTorrent, + [55u8; bt::INFO_HASH_LEN].into(), + handshaker_one_addr, + )) + .await + .unwrap(); + + let get_handshake = async move { + let mut merged = stream::select(stream_one, stream_two); + loop { + tokio::time::sleep(Duration::from_millis(5)).await; + + let Some(res) = merged.next().now_or_never() else { + continue; + }; + break res; + } + }; + + let res = tokio::time::timeout(Duration::from_millis(50), get_handshake).await; + + if let Ok(item) = res { + tracing::debug!("handshake was produced: {item:?}"); + } else { + panic!("expected item, but got a timeout!"); + } + }); + + let res = test.await; + + tasks_one.shutdown().await; + tasks_two.shutdown().await; + + res.unwrap(); } diff --git a/packages/handshake/tests/test_filter_block_all.rs b/packages/handshake/tests/test_filter_block_all.rs index ba434c473..fc1e2c342 100644 --- a/packages/handshake/tests/test_filter_block_all.rs +++ b/packages/handshake/tests/test_filter_block_all.rs @@ -2,15 +2,15 @@ use std::any::Any; use std::net::SocketAddr; use std::time::Duration; -use common::TimeoutResult; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::sink::SinkExt; +use futures::stream::{self, StreamExt}; +use futures::FutureExt; use handshake::transports::TcpTransport; use handshake::{ DiscoveryInfo, Extensions, FilterDecision, HandshakeFilter, HandshakeFilters, HandshakerBuilder, InitiateMessage, Protocol, }; -use tokio_core::reactor::{Core, Timeout}; +use tracing::level_filters::LevelFilter; use util::bt::{self, InfoHash, PeerId}; mod common; @@ -40,66 +40,78 @@ impl HandshakeFilter for FilterBlockAll { } } -#[test] -fn test_filter_all() { - let mut core = Core::new().unwrap(); - let handle = core.handle(); +#[tokio::test] +async fn test_filter_all() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_one_addr = handshaker_one_addr; handshaker_one_addr.set_port(handshaker_one.port()); // Filter all incoming handshake requests handshaker_one.add_filter(FilterBlockAll); - let mut handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_two_pid = [5u8; bt::PEER_ID_LEN].into(); - let handshaker_two = HandshakerBuilder::new() + let (handshaker_two, mut tasks_two) = HandshakerBuilder::new() .with_bind_addr(handshaker_two_addr) .with_peer_id(handshaker_two_pid) - .build(TcpTransport, &core.handle()) + 
.build(TcpTransport) + .await .unwrap(); + let mut handshaker_two_addr = handshaker_two_addr; handshaker_two_addr.set_port(handshaker_two.port()); let (_, stream_one) = handshaker_one.into_parts(); - let (sink_two, stream_two) = handshaker_two.into_parts(); - - let timeout_result = core - .run( - sink_two - .send(InitiateMessage::new( - Protocol::BitTorrent, - [55u8; bt::INFO_HASH_LEN].into(), - handshaker_one_addr, - )) - .map_err(|_| ()) - .and_then(|_| { - let timeout = Timeout::new(Duration::from_millis(50), &handle) - .unwrap() - .map(|()| TimeoutResult::TimedOut) - .map_err(|_| ()); - - let result_one = stream_one.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - let result_two = stream_two.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - - result_one - .select(result_two) - .map(|_| TimeoutResult::GotResult) - .map_err(|_| ()) - .select(timeout) - .map(|(item, _)| item) - .map_err(|_| ()) - }), - ) - .unwrap(); - - assert_eq!(TimeoutResult::TimedOut, timeout_result); + let (mut sink_two, stream_two) = handshaker_two.into_parts(); + + let test = tokio::spawn(async move { + sink_two + .send(InitiateMessage::new( + Protocol::BitTorrent, + [55u8; bt::INFO_HASH_LEN].into(), + handshaker_one_addr, + )) + .await + .unwrap(); + + let get_handshake = async move { + let mut merged = stream::select(stream_one, stream_two); + loop { + tokio::time::sleep(Duration::from_millis(5)).await; + + let Some(res) = merged.next().now_or_never() else { + continue; + }; + break res; + } + }; + + let res = tokio::time::timeout(Duration::from_millis(50), get_handshake).await; + + if let Ok(item) = res { + panic!("expected timeout, but got a result: {item:?}"); + } else { + tracing::debug!("timeout was reached"); + } + }); + + let res = test.await; + + tasks_one.shutdown().await; + tasks_two.shutdown().await; + + res.unwrap(); } diff --git a/packages/handshake/tests/test_filter_whitelist_diff_data.rs b/packages/handshake/tests/test_filter_whitelist_diff_data.rs index 86a09f449..2d071c1c0 100644 --- a/packages/handshake/tests/test_filter_whitelist_diff_data.rs +++ b/packages/handshake/tests/test_filter_whitelist_diff_data.rs @@ -1,13 +1,13 @@ use std::any::Any; use std::time::Duration; -use common::TimeoutResult; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::sink::SinkExt; +use futures::stream::{self, StreamExt}; +use futures::FutureExt as _; use handshake::transports::TcpTransport; use handshake::{DiscoveryInfo, FilterDecision, HandshakeFilter, HandshakeFilters, HandshakerBuilder, InitiateMessage, Protocol}; -use tokio_core::reactor::{Core, Timeout}; +use tracing::level_filters::LevelFilter; use util::bt::{self, InfoHash}; mod common; @@ -46,20 +46,23 @@ impl HandshakeFilter for FilterAllowHash { } } -#[test] -fn test_filter_whitelist_diff_data() { - let mut core = Core::new().unwrap(); - let handle = core.handle(); +#[tokio::test] +async fn test_filter_whitelist_diff_data() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let 
mut handshaker_one_addr = handshaker_one_addr; handshaker_one_addr.set_port(handshaker_one.port()); // Filter all incoming handshake requests hash's, then whitelist let allow_info_hash = [55u8; bt::INFO_HASH_LEN].into(); @@ -67,48 +70,59 @@ fn test_filter_whitelist_diff_data() { handshaker_one.add_filter(FilterBlockAllHash); handshaker_one.add_filter(FilterAllowHash { hash: allow_info_hash }); - let mut handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_two_pid = [5u8; bt::PEER_ID_LEN].into(); - let handshaker_two = HandshakerBuilder::new() + let (handshaker_two, mut tasks_two) = HandshakerBuilder::new() .with_bind_addr(handshaker_two_addr) .with_peer_id(handshaker_two_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_two_addr = handshaker_two_addr; handshaker_two_addr.set_port(handshaker_two.port()); let (_, stream_one) = handshaker_one.into_parts(); - let (sink_two, stream_two) = handshaker_two.into_parts(); - - let timeout_result = core - .run( - sink_two - .send(InitiateMessage::new( - Protocol::BitTorrent, - [54u8; bt::INFO_HASH_LEN].into(), - handshaker_one_addr, - )) - .map_err(|_| ()) - .and_then(|_| { - let timeout = Timeout::new(Duration::from_millis(50), &handle) - .unwrap() - .map(|()| TimeoutResult::TimedOut) - .map_err(|_| ()); - - let result_one = stream_one.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - let result_two = stream_two.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - - result_one - .select(result_two) - .map(|_| TimeoutResult::GotResult) - .map_err(|_| ()) - .select(timeout) - .map(|(item, _)| item) - .map_err(|_| ()) - }), - ) - .unwrap(); + let (mut sink_two, stream_two) = handshaker_two.into_parts(); + + let test = tokio::spawn(async move { + // Send the initiate message + sink_two + .send(InitiateMessage::new( + Protocol::BitTorrent, + [54u8; bt::INFO_HASH_LEN].into(), + handshaker_one_addr, + )) + .await + .unwrap(); + + // Use tokio timeout to wait for the result + let get_handshake = async move { + let mut merged = stream::select(stream_one, stream_two); + loop { + tokio::time::sleep(Duration::from_millis(5)).await; + + let Some(res) = merged.next().now_or_never() else { + continue; + }; + break res; + } + }; + + let res = tokio::time::timeout(Duration::from_millis(50), get_handshake).await; + + if let Ok(item) = res { + panic!("expected timeout, but got a result: {item:?}"); + } else { + tracing::debug!("timeout was reached"); + } + }); + + let res = test.await; + + tasks_one.shutdown().await; + tasks_two.shutdown().await; - assert_eq!(TimeoutResult::TimedOut, timeout_result); + res.unwrap(); } diff --git a/packages/handshake/tests/test_filter_whitelist_same_data.rs b/packages/handshake/tests/test_filter_whitelist_same_data.rs index 21064f516..236bc4328 100644 --- a/packages/handshake/tests/test_filter_whitelist_same_data.rs +++ b/packages/handshake/tests/test_filter_whitelist_same_data.rs @@ -1,13 +1,13 @@ use std::any::Any; use std::time::Duration; -use common::TimeoutResult; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::Future; +use common::{tracing_stderr_init, INIT}; +use futures::sink::SinkExt; +use futures::stream::{self, StreamExt}; +use futures::FutureExt as _; use handshake::transports::TcpTransport; use handshake::{DiscoveryInfo, FilterDecision, HandshakeFilter, HandshakeFilters, HandshakerBuilder, InitiateMessage, Protocol}; -use 
tokio_core::reactor::{Core, Timeout}; +use tracing::level_filters::LevelFilter; use util::bt::{self, InfoHash}; mod common; @@ -46,20 +46,23 @@ impl HandshakeFilter for FilterAllowHash { } } -#[test] -fn test_filter_whitelist_same_data() { - let mut core = Core::new().unwrap(); - let handle = core.handle(); +#[tokio::test] +async fn test_filter_whitelist_same_data() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_one_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into(); - let handshaker_one = HandshakerBuilder::new() + let (handshaker_one, mut tasks_one) = HandshakerBuilder::new() .with_bind_addr(handshaker_one_addr) .with_peer_id(handshaker_one_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_one_addr = handshaker_one_addr; handshaker_one_addr.set_port(handshaker_one.port()); // Filter all incoming handshake requests hash's, then whitelist let allow_info_hash = [55u8; bt::INFO_HASH_LEN].into(); @@ -67,48 +70,59 @@ fn test_filter_whitelist_same_data() { handshaker_one.add_filter(FilterBlockAllHash); handshaker_one.add_filter(FilterAllowHash { hash: allow_info_hash }); - let mut handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); + let handshaker_two_addr = "127.0.0.1:0".parse().unwrap(); let handshaker_two_pid = [5u8; bt::PEER_ID_LEN].into(); - let handshaker_two = HandshakerBuilder::new() + let (handshaker_two, mut tasks_two) = HandshakerBuilder::new() .with_bind_addr(handshaker_two_addr) .with_peer_id(handshaker_two_pid) - .build(TcpTransport, &core.handle()) + .build(TcpTransport) + .await .unwrap(); + let mut handshaker_two_addr = handshaker_two_addr; handshaker_two_addr.set_port(handshaker_two.port()); let (_, stream_one) = handshaker_one.into_parts(); - let (sink_two, stream_two) = handshaker_two.into_parts(); - - let timeout_result = core - .run( - sink_two - .send(InitiateMessage::new( - Protocol::BitTorrent, - allow_info_hash, - handshaker_one_addr, - )) - .map_err(|_| ()) - .and_then(|_| { - let timeout = Timeout::new(Duration::from_millis(50), &handle) - .unwrap() - .map(|()| TimeoutResult::TimedOut) - .map_err(|_| ()); - - let result_one = stream_one.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - let result_two = stream_two.into_future().map(|_| TimeoutResult::GotResult).map_err(|_| ()); - - result_one - .select(result_two) - .map(|_| TimeoutResult::GotResult) - .map_err(|_| ()) - .select(timeout) - .map(|(item, _)| item) - .map_err(|_| ()) - }), - ) - .unwrap(); + let (mut sink_two, stream_two) = handshaker_two.into_parts(); + + let test = tokio::spawn(async move { + // Send the initiate message + sink_two + .send(InitiateMessage::new( + Protocol::BitTorrent, + allow_info_hash, + handshaker_one_addr, + )) + .await + .unwrap(); + + // Use tokio timeout to wait for the result + let get_handshake = async move { + let mut merged = stream::select(stream_one, stream_two); + loop { + tokio::time::sleep(Duration::from_millis(5)).await; + + let Some(res) = merged.next().now_or_never() else { + continue; + }; + break res; + } + }; + + let res = tokio::time::timeout(Duration::from_millis(50), get_handshake).await; + + if let Ok(item) = res { + tracing::debug!("handshake was produced: {item:?}"); + } else { + panic!("expected item, but got a timeout!"); + } + }); + + let res = test.await; + + tasks_one.shutdown().await; + tasks_two.shutdown().await; - 
assert_eq!(TimeoutResult::GotResult, timeout_result); + res.unwrap(); } diff --git a/packages/magnet/Cargo.toml b/packages/magnet/Cargo.toml index 2f49cd965..854531d11 100644 --- a/packages/magnet/Cargo.toml +++ b/packages/magnet/Cargo.toml @@ -18,5 +18,6 @@ version.workspace = true [dependencies] util = { path = "../util" } -base32 = "0.4" +base32 = "0" url = "2" + diff --git a/packages/magnet/src/lib.rs b/packages/magnet/src/lib.rs index 7cc3fdf23..d1b19d64d 100644 --- a/packages/magnet/src/lib.rs +++ b/packages/magnet/src/lib.rs @@ -1,5 +1,3 @@ -use std::default::Default; - use url::Url; use util::bt::InfoHash; use util::sha::ShaHash; @@ -27,7 +25,7 @@ impl Topic { } } else if s.starts_with("urn:btih:") && s.len() == 9 + 32 { // BitTorrent Info Hash, base-32 - base32::decode(base32::Alphabet::RFC4648 { padding: true }, &s[9..]).and_then(|hash| { + base32::decode(base32::Alphabet::Rfc4648 { padding: true }, &s[9..]).and_then(|hash| { match ShaHash::from_hash(&hash[..]) { Ok(sha_hash) => Some(Topic::BitTorrentInfoHash(sha_hash)), Err(_) => None, diff --git a/packages/metainfo/Cargo.toml b/packages/metainfo/Cargo.toml index 50978b2e3..2a5278ce6 100644 --- a/packages/metainfo/Cargo.toml +++ b/packages/metainfo/Cargo.toml @@ -19,15 +19,15 @@ version.workspace = true bencode = { path = "../bencode" } util = { path = "../util" } -crossbeam = "0.8" -error-chain = "0.12" +crossbeam = "0" +thiserror = "1" walkdir = "2" [dev-dependencies] -chrono = "0.4" -criterion = "0.5" +chrono = "0" +criterion = "0" pbr = "1" -rand = "0.8" +rand = "0" [[bench]] harness = false diff --git a/packages/metainfo/examples/create_torrent.rs b/packages/metainfo/examples/create_torrent.rs index fed862b37..99a1e550a 100644 --- a/packages/metainfo/examples/create_torrent.rs +++ b/packages/metainfo/examples/create_torrent.rs @@ -1,18 +1,17 @@ -use std::fs::File; -use std::io::{self, BufRead, Write}; +use std::io::{BufRead as _, Write as _}; use std::path::Path; use chrono::offset::{TimeZone, Utc}; -use metainfo::error::ParseResult; +use metainfo::error::ParseError; use metainfo::{Metainfo, MetainfoBuilder}; use pbr::ProgressBar; fn main() { println!("\nIMPORTANT: Remember to run in release mode for real world performance...\n"); - let input = io::stdin(); + let input = std::io::stdin(); let mut input_lines = input.lock().lines(); - let mut output = io::stdout(); + let mut output = std::io::stdout(); output.write_all(b"Enter A Source Folder/File: ").unwrap(); output.flush().unwrap(); @@ -24,7 +23,7 @@ fn main() { match create_torrent(src_path) { Ok(bytes) => { - let mut output_file = File::create(dst_path).unwrap(); + let mut output_file = std::fs::File::create(dst_path).unwrap(); output_file.write_all(&bytes).unwrap(); print_metainfo_overview(&bytes); @@ -34,7 +33,7 @@ fn main() { } /// Create a torrent from the given source path. -fn create_torrent(src_path: S) -> ParseResult> +fn create_torrent(src_path: S) -> Result, ParseError> where S: AsRef, { diff --git a/packages/metainfo/src/accessor.rs b/packages/metainfo/src/accessor.rs index cf979b2b2..c01aacdcd 100644 --- a/packages/metainfo/src/accessor.rs +++ b/packages/metainfo/src/accessor.rs @@ -1,5 +1,3 @@ -use std::fs::File; -use std::io::{self, Cursor, Read}; use std::path::{Path, PathBuf}; use util::sha::ShaHash; @@ -15,7 +13,7 @@ pub trait IntoAccessor { /// # Errors /// /// It would return an IO error if unable to convert to an ancestor. 
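+    ///
+    /// A minimal sketch, assuming the blanket implementation for path-like
+    /// values (the path itself is a placeholder):
+    ///
+    /// ```ignore
+    /// let accessor = "path/to/files".into_accessor().unwrap();
+    /// ```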
- fn into_accessor(self) -> io::Result; + fn into_accessor(self) -> std::io::Result; } /// Trait for accessing the data used to construct a torrent file. @@ -28,7 +26,7 @@ pub trait Accessor { /// # Errors /// /// It would return an IO error if unable to access the metadata. - fn access_metadata(&self, callback: C) -> io::Result<()> + fn access_metadata(&self, callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path); @@ -37,9 +35,9 @@ pub trait Accessor { /// # Errors /// /// It would return an IO error if unable to access the pieces. - fn access_pieces(&self, callback: C) -> io::Result<()> + fn access_pieces(&self, callback: C) -> std::io::Result<()> where - C: for<'a> FnMut(PieceAccess<'a>) -> io::Result<()>; + C: for<'a> FnMut(PieceAccess<'a>) -> std::io::Result<()>; } impl<'a, T> Accessor for &'a T @@ -50,16 +48,16 @@ where Accessor::access_directory(*self) } - fn access_metadata(&self, callback: C) -> io::Result<()> + fn access_metadata(&self, callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { Accessor::access_metadata(*self, callback) } - fn access_pieces(&self, callback: C) -> io::Result<()> + fn access_pieces(&self, callback: C) -> std::io::Result<()> where - C: for<'b> FnMut(PieceAccess<'b>) -> io::Result<()>, + C: for<'b> FnMut(PieceAccess<'b>) -> std::io::Result<()>, { Accessor::access_pieces(*self, callback) } @@ -77,7 +75,7 @@ where /// (though not required). pub enum PieceAccess<'a> { /// Hash should be computed from the bytes read. - Compute(&'a mut dyn Read), + Compute(&'a mut dyn std::io::Read), /// Hash given should be used directly as the next checksum. PreComputed(ShaHash), } @@ -101,7 +99,7 @@ impl FileAccessor { /// # Panics /// /// It would panic if unable to get the last directory name. - pub fn new(path: T) -> io::Result + pub fn new(path: T) -> std::io::Result where T: AsRef, { @@ -124,7 +122,7 @@ impl FileAccessor { impl IntoAccessor for FileAccessor { type Accessor = FileAccessor; - fn into_accessor(self) -> io::Result { + fn into_accessor(self) -> std::io::Result { Ok(self) } } @@ -135,7 +133,7 @@ where { type Accessor = FileAccessor; - fn into_accessor(self) -> io::Result { + fn into_accessor(self) -> std::io::Result { FileAccessor::new(self) } } @@ -145,7 +143,7 @@ impl Accessor for FileAccessor { self.directory_name.as_ref().map(std::convert::AsRef::as_ref) } - fn access_metadata(&self, mut callback: C) -> io::Result<()> + fn access_metadata(&self, mut callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { @@ -172,13 +170,13 @@ impl Accessor for FileAccessor { Ok(()) } - fn access_pieces(&self, mut callback: C) -> io::Result<()> + fn access_pieces(&self, mut callback: C) -> std::io::Result<()> where - C: for<'a> FnMut(PieceAccess<'a>) -> io::Result<()>, + C: for<'a> FnMut(PieceAccess<'a>) -> std::io::Result<()>, { for res_entry in WalkDir::new(&self.absolute_path).into_iter().filter(entry_file_filter) { let entry = res_entry?; - let mut file = File::open(entry.path())?; + let mut file = std::fs::File::open(entry.path())?; callback(PieceAccess::Compute(&mut file))?; } @@ -215,7 +213,7 @@ impl<'a> DirectAccessor<'a> { impl<'a> IntoAccessor for DirectAccessor<'a> { type Accessor = DirectAccessor<'a>; - fn into_accessor(self) -> io::Result> { + fn into_accessor(self) -> std::io::Result> { Ok(self) } } @@ -225,7 +223,7 @@ impl<'a> Accessor for DirectAccessor<'a> { None } - fn access_metadata(&self, mut callback: C) -> io::Result<()> + fn access_metadata(&self, mut callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { @@ 
-237,11 +235,11 @@ impl<'a> Accessor for DirectAccessor<'a> { Ok(()) } - fn access_pieces(&self, mut callback: C) -> io::Result<()> + fn access_pieces(&self, mut callback: C) -> std::io::Result<()> where - C: for<'b> FnMut(PieceAccess<'b>) -> io::Result<()>, + C: for<'b> FnMut(PieceAccess<'b>) -> std::io::Result<()>, { - let mut cursor = Cursor::new(self.file_contents); + let mut cursor = std::io::Cursor::new(self.file_contents); callback(PieceAccess::Compute(&mut cursor)) } diff --git a/packages/metainfo/src/builder/buffer.rs b/packages/metainfo/src/builder/buffer.rs index 6175549fc..32af8d7a7 100644 --- a/packages/metainfo/src/builder/buffer.rs +++ b/packages/metainfo/src/builder/buffer.rs @@ -1,5 +1,4 @@ -use core::time; -use std::{io, thread}; +use std::time::Duration; use crossbeam::queue::SegQueue; @@ -35,7 +34,7 @@ impl PieceBuffers { /// Checkout a piece buffer (possibly blocking) to be used. pub fn checkout(&self) -> PieceBuffer { let mut pb = None; - let ten_millis = time::Duration::from_millis(10); + let ten_millis = Duration::from_millis(10); while pb.is_none() { pb = self.piece_queue.pop(); @@ -44,7 +43,7 @@ impl PieceBuffers { break; } - thread::sleep(ten_millis); + std::thread::sleep(ten_millis); continue; } @@ -76,9 +75,9 @@ impl PieceBuffer { } } - pub fn write_bytes(&mut self, mut callback: C) -> io::Result + pub fn write_bytes(&mut self, mut callback: C) -> std::io::Result where - C: FnMut(&mut [u8]) -> io::Result, + C: FnMut(&mut [u8]) -> std::io::Result, { let new_bytes_read = callback(&mut self.buffer[self.bytes_read..])?; self.bytes_read += new_bytes_read; diff --git a/packages/metainfo/src/builder/mod.rs b/packages/metainfo/src/builder/mod.rs index cfe3b7be9..a2dcdbb1b 100644 --- a/packages/metainfo/src/builder/mod.rs +++ b/packages/metainfo/src/builder/mod.rs @@ -4,7 +4,7 @@ use bencode::{ben_bytes, ben_int, ben_map, BMutAccess, BRefAccess, BencodeMut}; use util::sha::{self, ShaHash}; use crate::accessor::{Accessor, IntoAccessor}; -use crate::error::ParseResult; +use crate::error::ParseError; use crate::parse; mod buffer; @@ -270,7 +270,7 @@ impl<'a> MetainfoBuilder<'a> { /// # Errors /// /// It would return an error if unable to get the accessor. - pub fn build(self, threads: usize, accessor: A, progress: C) -> ParseResult> + pub fn build(self, threads: usize, accessor: A, progress: C) -> Result, ParseError> where A: IntoAccessor, C: FnMut(f64) + Send + 'static, @@ -346,7 +346,7 @@ impl<'a> InfoBuilder<'a> { /// # Errors /// /// It would return an error if unable to get the accessor. 
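+    ///
+    /// A minimal sketch (assumes `InfoBuilder::new()` defaults; the source path
+    /// is a placeholder and the progress callback is ignored):
+    ///
+    /// ```ignore
+    /// let info_bytes = InfoBuilder::new()
+    ///     .build(1, "path/to/files", |_progress| {})
+    ///     .unwrap();
+    /// ```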
- pub fn build(self, threads: usize, accessor: A, progress: C) -> ParseResult> + pub fn build(self, threads: usize, accessor: A, progress: C) -> Result, ParseError> where A: IntoAccessor, C: FnMut(f64) + Send + 'static, @@ -366,7 +366,7 @@ fn build_with_accessor<'a, A, C>( opt_root: Option>, info: BencodeMut<'a>, piece_length: PieceLength, -) -> ParseResult> +) -> Result, ParseError> where A: Accessor, C: FnMut(f64) + Send + 'static, diff --git a/packages/metainfo/src/builder/worker.rs b/packages/metainfo/src/builder/worker.rs index ffcd3072d..5396ad0ac 100644 --- a/packages/metainfo/src/builder/worker.rs +++ b/packages/metainfo/src/builder/worker.rs @@ -1,13 +1,11 @@ -use std::sync::mpsc::{self, Receiver, Sender}; -use std::sync::Arc; -use std::thread; +use std::sync::{mpsc, Arc}; use crossbeam::queue::SegQueue; use util::sha::ShaHash; use crate::accessor::{Accessor, PieceAccess}; use crate::builder::buffer::{PieceBuffer, PieceBuffers}; -use crate::error::ParseResult; +use crate::error::ParseError; /// Messages sent to the master hasher. pub enum MasterMessage { @@ -32,7 +30,7 @@ pub fn start_hasher_workers( num_pieces: u64, num_workers: usize, progress: C, -) -> ParseResult> +) -> Result, ParseError> where A: Accessor, C: FnMut(f64) + Send + 'static, @@ -53,13 +51,13 @@ where let share_work_queue = work_queue.clone(); let share_piece_buffers = piece_buffers.clone(); - thread::spawn(move || { + std::thread::spawn(move || { start_hash_worker(&share_master_send, &share_work_queue, &share_piece_buffers); }); } // Create a worker thread to execute the user callback for the progress update - thread::spawn(move || { + std::thread::spawn(move || { start_progress_updater(prog_recv, num_pieces, progress); }); @@ -74,11 +72,11 @@ where fn start_hash_master( accessor: A, num_workers: usize, - recv: &Receiver, + recv: &mpsc::Receiver, work: &Arc>, buffers: &Arc, - progress_sender: &Sender, -) -> ParseResult> + progress_sender: &mpsc::Sender, +) -> Result, ParseError> where A: Accessor, { @@ -162,7 +160,7 @@ where // ----------------------------------------------------------------------------// -fn start_progress_updater(recv: Receiver, num_pieces: u64, mut progress: C) +fn start_progress_updater(recv: mpsc::Receiver, num_pieces: u64, mut progress: C) where C: FnMut(f64), { @@ -177,7 +175,7 @@ where // ----------------------------------------------------------------------------// /// Starts a hasher worker which will hash all of the buffers it receives. -fn start_hash_worker(send: &Sender, work: &Arc>, buffers: &Arc) { +fn start_hash_worker(send: &mpsc::Sender, work: &Arc>, buffers: &Arc) { let mut work_to_do = true; // Loop until we are instructed to stop working @@ -206,7 +204,7 @@ fn start_hash_worker(send: &Sender, work: &Arc(&self, _: C) -> io::Result<()> + fn access_metadata(&self, _: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { @@ -266,12 +264,12 @@ mod tests { } /// Access the sequential pieces that make up all of the files. 
- fn access_pieces(&self, mut callback: C) -> io::Result<()> + fn access_pieces(&self, mut callback: C) -> std::io::Result<()> where - C: for<'a> FnMut(PieceAccess<'a>) -> io::Result<()>, + C: for<'a> FnMut(PieceAccess<'a>) -> std::io::Result<()>, { for range in &self.buffer_ranges { - let mut next_region = Cursor::new(self.contiguous_buffer.index(range.clone())); + let mut next_region = std::io::Cursor::new(self.contiguous_buffer.index(range.clone())); callback(PieceAccess::Compute(&mut next_region))?; } diff --git a/packages/metainfo/src/error.rs b/packages/metainfo/src/error.rs index dbd62af8e..091ecb55a 100644 --- a/packages/metainfo/src/error.rs +++ b/packages/metainfo/src/error.rs @@ -1,29 +1,24 @@ //! Errors for torrent file building and parsing. -use std::io; - use bencode::{BencodeConvertError, BencodeParseError}; -use error_chain::error_chain; +use thiserror::Error; use walkdir; -error_chain! { - types { - ParseError, ParseErrorKind, ParseResultEx, ParseResult; - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum ParseError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Directory error: {0}")] + Dir(#[from] walkdir::Error), + + #[error("Bencode conversion error: {0}")] + BencodeConvert(#[from] BencodeConvertError), - foreign_links { - Io(io::Error); - Dir(walkdir::Error); - BencodeConvert(BencodeConvertError); - BencodeParse(BencodeParseError); - } + #[error("Bencode parse error: {0}")] + BencodeParse(#[from] BencodeParseError), - errors { - MissingData { - details: String - } { - description("Missing Data Detected In File") - display("Missing Data Detected In File: {}", details) - } - } + #[error("Missing Data Detected In File: {details}")] + MissingData { details: String }, } diff --git a/packages/metainfo/src/metainfo.rs b/packages/metainfo/src/metainfo.rs index 0fd00d167..1feaa082a 100644 --- a/packages/metainfo/src/metainfo.rs +++ b/packages/metainfo/src/metainfo.rs @@ -1,7 +1,5 @@ //! Accessing the fields of a Metainfo file. -use std::fmt::Debug; -use std::io; use std::path::{Path, PathBuf}; use bencode::{BDecodeOpt, BDictAccess, BRefAccess, BencodeRef}; @@ -10,7 +8,7 @@ use util::sha::{self, ShaHash}; use crate::accessor::{Accessor, IntoAccessor, PieceAccess}; use crate::builder::{InfoBuilder, MetainfoBuilder, PieceLength}; -use crate::error::{ParseError, ParseErrorKind, ParseResult}; +use crate::error::ParseError; use crate::iter::{Files, Pieces}; use crate::parse; @@ -27,12 +25,25 @@ pub struct Metainfo { } impl Metainfo { + #[must_use] + pub fn new(info: Info) -> Self { + Self { + comment: None, + announce: None, + announce_list: None, + encoding: None, + created_by: None, + creation_date: None, + info, + } + } + /// Read a `Metainfo` from metainfo file bytes. /// /// # Errors /// /// It would return an error if unable to parse the bytes as a [`Metainfo`] - pub fn from_bytes(bytes: B) -> ParseResult + pub fn from_bytes(bytes: B) -> Result where B: AsRef<[u8]>, { @@ -119,7 +130,7 @@ impl From for Metainfo { } /// Parses the given metainfo bytes and builds a Metainfo from them. -fn parse_meta_bytes(bytes: &[u8]) -> ParseResult { +fn parse_meta_bytes(bytes: &[u8]) -> Result { let root_bencode = BencodeRef::decode(bytes, BDecodeOpt::default())?; let root_dict = parse::parse_root_dict(&root_bencode)?; @@ -170,7 +181,7 @@ impl Info { /// # Errors /// /// It would return an error if unable to parse bytes into [`Info`]. 
- pub fn from_bytes(bytes: B) -> ParseResult + pub fn from_bytes(bytes: B) -> Result where B: AsRef<[u8]>, { @@ -247,7 +258,7 @@ impl Info { impl IntoAccessor for Info { type Accessor = Info; - fn into_accessor(self) -> io::Result { + fn into_accessor(self) -> std::io::Result { Ok(self) } } @@ -255,7 +266,7 @@ impl IntoAccessor for Info { impl<'a> IntoAccessor for &'a Info { type Accessor = &'a Info; - fn into_accessor(self) -> io::Result<&'a Info> { + fn into_accessor(self) -> std::io::Result<&'a Info> { Ok(self) } } @@ -265,7 +276,7 @@ impl Accessor for Info { self.directory() } - fn access_metadata(&self, mut callback: C) -> io::Result<()> + fn access_metadata(&self, mut callback: C) -> std::io::Result<()> where C: FnMut(u64, &Path), { @@ -276,9 +287,9 @@ impl Accessor for Info { Ok(()) } - fn access_pieces(&self, mut callback: C) -> io::Result<()> + fn access_pieces(&self, mut callback: C) -> std::io::Result<()> where - C: for<'a> FnMut(PieceAccess<'a>) -> io::Result<()>, + C: for<'a> FnMut(PieceAccess<'a>) -> std::io::Result<()>, { for piece in self.pieces() { callback(PieceAccess::PreComputed(ShaHash::from_hash(piece).unwrap()))?; @@ -289,14 +300,14 @@ impl Accessor for Info { } /// Parses the given info dictionary bytes and builds a Metainfo from them. -fn parse_info_bytes(bytes: &[u8]) -> ParseResult { +fn parse_info_bytes(bytes: &[u8]) -> Result { let info_bencode = BencodeRef::decode(bytes, BDecodeOpt::default())?; parse_info_dictionary(&info_bencode) } /// Parses the given info dictionary and builds an Info from it. -fn parse_info_dictionary(info_bencode: &BencodeRef<'_>) -> ParseResult { +fn parse_info_dictionary(info_bencode: &BencodeRef<'_>) -> Result { let info_hash = InfoHash::from_bytes(info_bencode.buffer()); let info_dict = parse::parse_root_dict(info_bencode)?; @@ -352,10 +363,10 @@ where } /// Validates and allocates the hash pieces on the heap. -fn allocate_pieces(pieces: &[u8]) -> ParseResult> { +fn allocate_pieces(pieces: &[u8]) -> Result, ParseError> { if pieces.len() % sha::SHA_HASH_LEN != 0 { let error_msg = format!("Piece Hash Length Of {} Is Invalid", pieces.len()); - Err(ParseError::from_kind(ParseErrorKind::MissingData { details: error_msg })) + Err(ParseError::MissingData { details: error_msg }) } else { let mut hash_buffers = Vec::with_capacity(pieces.len() / sha::SHA_HASH_LEN); let mut hash_bytes = [0u8; sha::SHA_HASH_LEN]; @@ -384,7 +395,7 @@ pub struct File { impl File { /// Parse the info dictionary and generate a single file File. - fn as_single_file(info_dict: &dyn BDictAccess) -> ParseResult + fn as_single_file(info_dict: &dyn BDictAccess) -> Result where B: BRefAccess, { @@ -400,7 +411,7 @@ impl File { } /// Parse the file dictionary and generate a multi file File. 
- fn as_multi_file(file_dict: &dyn BDictAccess) -> ParseResult + fn as_multi_file(file_dict: &dyn BDictAccess) -> Result where B: BRefAccess, { @@ -906,16 +917,14 @@ mod tests { } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ParseError(BencodeParse(BencodeParseError(BytesEmpty { pos: 0 }, State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })), State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })" - )] + #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: BencodeParse(BytesEmpty { pos: 0 })")] fn negative_parse_from_empty_bytes() { Metainfo::from_bytes(b"").unwrap(); } #[test] #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ParseError(BencodeConvert(BencodeConvertError(MissingKey { key: [112, 105, 101, 99, 101, 32, 108, 101, 110, 103, 116, 104] }, State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })), State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })" + expected = "called `Result::unwrap()` on an `Err` value: BencodeConvert(MissingKey { key: [112, 105, 101, 99, 101, 32, 108, 101, 110, 103, 116, 104] })" )] fn negative_parse_with_no_piece_length() { let tracker = "udp://dummy_domain.com:8989"; @@ -942,7 +951,7 @@ mod tests { #[test] #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ParseError(BencodeConvert(BencodeConvertError(MissingKey { key: [112, 105, 101, 99, 101, 115] }, State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })), State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })" + expected = "called `Result::unwrap()` on an `Err` value: BencodeConvert(MissingKey { key: [112, 105, 101, 99, 101, 115] })" )] fn negative_parse_with_no_pieces() { let tracker = "udp://dummy_domain.com:8989"; @@ -967,7 +976,7 @@ mod tests { #[test] #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ParseError(BencodeConvert(BencodeConvertError(MissingKey { key: [102, 105, 108, 101, 115] }, State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })), State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })" + expected = "called `Result::unwrap()` on an `Err` value: BencodeConvert(MissingKey { key: [102, 105, 108, 101, 115] })" )] fn negative_parse_from_single_file_with_no_file_length() { let tracker = "udp://dummy_domain.com:8989"; @@ -992,7 +1001,7 @@ mod tests { #[test] #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ParseError(BencodeConvert(BencodeConvertError(MissingKey { key: [110, 97, 109, 101] }, State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })), State { next_error: None, backtrace: InternalBacktrace { backtrace: None } })" + expected = "called `Result::unwrap()` on an `Err` value: BencodeConvert(MissingKey { key: [110, 97, 109, 101] })" )] fn negative_parse_from_single_file_with_no_file_name() { let tracker = "udp://dummy_domain.com:8989"; diff --git a/packages/metainfo/src/parse.rs b/packages/metainfo/src/parse.rs index a7274ba71..2cfa23a51 100644 --- a/packages/metainfo/src/parse.rs +++ b/packages/metainfo/src/parse.rs @@ -1,6 +1,6 @@ use bencode::{BConvert, BDictAccess, BListAccess, BRefAccess, BencodeConvertError}; -use crate::error::{ParseError, ParseResult}; +use crate::error::ParseError; /// Struct implemented the `BencodeConvert` trait for decoding the metainfo file. 
struct MetainfoConverter; @@ -42,7 +42,7 @@ pub const PATH_KEY: &[u8] = b"path"; /// Parses the root bencode as a dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_root_dict(root_bencode: &B) -> ParseResult<&dyn BDictAccess> +pub fn parse_root_dict(root_bencode: &B) -> Result<&dyn BDictAccess, ParseError> where B: BRefAccess, { @@ -122,7 +122,7 @@ where /// Parses the info dictionary from the root dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_info_bencode(root_dict: &dyn BDictAccess) -> ParseResult<&B> +pub fn parse_info_bencode(root_dict: &dyn BDictAccess) -> Result<&B, ParseError> where B: BRefAccess, { @@ -133,7 +133,7 @@ where /// Parses the piece length from the info dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_piece_length(info_dict: &dyn BDictAccess) -> ParseResult +pub fn parse_piece_length(info_dict: &dyn BDictAccess) -> Result where B: BRefAccess, { @@ -144,7 +144,7 @@ where /// Parses the pieces from the info dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_pieces<'a, B>(info_dict: &'a dyn BDictAccess) -> ParseResult<&'a [u8]> +pub fn parse_pieces<'a, B>(info_dict: &'a dyn BDictAccess) -> Result<&'a [u8], ParseError> where B: BRefAccess + 'a, { @@ -162,7 +162,7 @@ where /// Parses the name from the info dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_name<'a, B>(info_dict: &'a dyn BDictAccess) -> ParseResult<&'a str> +pub fn parse_name<'a, B>(info_dict: &'a dyn BDictAccess) -> Result<&'a str, ParseError> where B: BRefAccess + 'a, { @@ -171,7 +171,7 @@ where /// Parses the files list from the info dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_files_list(info_dict: &dyn BDictAccess) -> ParseResult<&dyn BListAccess> +pub fn parse_files_list(info_dict: &dyn BDictAccess) -> Result<&dyn BListAccess, ParseError> where B: BRefAccess + PartialEq, { @@ -182,7 +182,7 @@ where /// Parses the file dictionary from the file bencode. #[allow(clippy::module_name_repetitions)] -pub fn parse_file_dict(file_bencode: &B) -> ParseResult<&dyn BDictAccess> +pub fn parse_file_dict(file_bencode: &B) -> Result<&dyn BDictAccess, ParseError> where B: BRefAccess, { @@ -191,7 +191,7 @@ where /// Parses the length from the info or file dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_length(info_or_file_dict: &dyn BDictAccess) -> ParseResult +pub fn parse_length(info_or_file_dict: &dyn BDictAccess) -> Result where B: BRefAccess, { @@ -211,7 +211,7 @@ where /// Parses the path list from the file dictionary. #[allow(clippy::module_name_repetitions)] -pub fn parse_path_list(file_dict: &dyn BDictAccess) -> ParseResult<&dyn BListAccess> +pub fn parse_path_list(file_dict: &dyn BDictAccess) -> Result<&dyn BListAccess, ParseError> where B: BRefAccess, { @@ -220,7 +220,7 @@ where /// Parses the path string from the path bencode. 
#[allow(clippy::module_name_repetitions)] -pub fn parse_path_str(path_bencode: &B) -> ParseResult<&str> +pub fn parse_path_str(path_bencode: &B) -> Result<&str, ParseError> where B: BRefAccess, { diff --git a/packages/peer/Cargo.toml b/packages/peer/Cargo.toml index dfcd04a9a..d5242f4c6 100644 --- a/packages/peer/Cargo.toml +++ b/packages/peer/Cargo.toml @@ -21,11 +21,15 @@ handshake = { path = "../handshake" } util = { path = "../util" } byteorder = "1" -bytes = "0.4" -crossbeam = "0.8" -error-chain = "0.12" -futures = "0.1" -nom = "3" -tokio-core = "0.1" -tokio-io = "0.1" -tokio-timer = "0.1" +bytes = "1" +crossbeam = "0" +futures = "0" +nom = "7" +pin-project = "1" +thiserror = "1" +tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0", features = ["codec"] } +tracing = "0" + +[dev-dependencies] +tracing-subscriber = "0" diff --git a/packages/peer/src/codec.rs b/packages/peer/src/codec.rs index 18be8ce58..fbab54cac 100644 --- a/packages/peer/src/codec.rs +++ b/packages/peer/src/codec.rs @@ -1,14 +1,13 @@ //! Codecs operating over `PeerProtocol`s. -use std::io; - use bytes::{BufMut, BytesMut}; -use tokio_io::codec::{Decoder, Encoder}; +use tokio_util::codec::{Decoder, Encoder}; use crate::protocol::PeerProtocol; /// Codec operating over some `PeerProtocol`. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct PeerProtocolCodec
<P> { protocol: P, max_payload: Option<usize>, } @@ -40,48 +39,64 @@ impl<P> PeerProtocolCodec<P> { impl<P> Decoder for PeerProtocolCodec<P> where P: PeerProtocol, + <P as PeerProtocol>
::ProtocolMessageError: std::error::Error + Send + Sync + 'static, { type Item = P::ProtocolMessage; - type Error = io::Error; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let bytes_needed = self.protocol.bytes_needed(src)?; - fn decode(&mut self, src: &mut BytesMut) -> io::Result> { - let src_len = src.len(); + let Some(bytes_needed) = bytes_needed else { + return Ok(None); + }; - let bytes = match self.protocol.bytes_needed(src.as_ref())? { - Some(needed) if self.max_payload.is_some_and(|max_payload| needed > max_payload) => { - return Err(io::Error::new( - io::ErrorKind::Other, + if let Some(max_payload) = self.max_payload { + if bytes_needed > max_payload { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, "PeerProtocolCodec Enforced Maximum Payload Check For Peer", - )) + )); } - Some(needed) if needed <= src_len => src.split_to(needed).freeze(), - Some(_) | None => return Ok(None), }; - self.protocol.parse_bytes(bytes).map(Some) + let bytes = if bytes_needed <= src.len() { + src.split_to(bytes_needed).freeze() + } else { + return Ok(None); + }; + + match self.protocol.parse_bytes(&bytes) { + Ok(item) => item.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)), + Err(err) => Err(err), + } + .map(Some) } } -impl
<P> Encoder for PeerProtocolCodec<P> +impl<P> Encoder<std::io::Result<P::ProtocolMessage>> for PeerProtocolCodec<P>
where P: PeerProtocol, { - type Item = P::ProtocolMessage; - type Error = io::Error; + type Error = std::io::Error; + + fn encode(&mut self, item: std::io::Result, dst: &mut BytesMut) -> Result<(), Self::Error> { + let message = Ok(item?); + + let size = self.protocol.message_size(&message)?; - fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> io::Result<()> { - dst.reserve(self.protocol.message_size(&item)); + dst.reserve(size); - self.protocol.write_bytes(&item, dst.writer()) + let _ = self.protocol.write_bytes(&message, dst.writer())?; + + Ok(()) } } #[cfg(test)] mod tests { - use std::io::{self, Write}; - - use bytes::{Bytes, BytesMut}; - use tokio_io::codec::Decoder; + use bytes::BytesMut; + use tokio_util::codec::Decoder as _; use super::PeerProtocolCodec; use crate::protocol::PeerProtocol; @@ -90,24 +105,29 @@ mod tests { impl PeerProtocol for ConsumeProtocol { type ProtocolMessage = (); + type ProtocolMessageError = std::io::Error; - fn bytes_needed(&mut self, bytes: &[u8]) -> io::Result> { + fn bytes_needed(&mut self, bytes: &[u8]) -> std::io::Result> { Ok(Some(bytes.len())) } - fn parse_bytes(&mut self, _bytes: Bytes) -> io::Result { - Ok(()) + fn parse_bytes(&mut self, _: &[u8]) -> std::io::Result> { + Ok(Ok(())) } - fn write_bytes(&mut self, _message: &Self::ProtocolMessage, _writer: W) -> io::Result<()> + fn write_bytes( + &mut self, + _: &Result, + _: W, + ) -> std::io::Result where - W: Write, + W: std::io::Write, { - Ok(()) + Ok(0) } - fn message_size(&mut self, _message: &Self::ProtocolMessage) -> usize { - 0 + fn message_size(&mut self, _: &Result) -> std::io::Result { + Ok(0) } } @@ -118,7 +138,8 @@ mod tests { bytes.extend_from_slice(&[0u8; 100]); - assert_eq!(Some(()), codec.decode(&mut bytes).unwrap()); + let () = codec.decode(&mut bytes).unwrap().unwrap(); + assert_eq!(bytes.len(), 0); } diff --git a/packages/peer/src/lib.rs b/packages/peer/src/lib.rs index 98d1a8006..df33a64b4 100644 --- a/packages/peer/src/lib.rs +++ b/packages/peer/src/lib.rs @@ -1,6 +1,3 @@ -#[macro_use] -mod macros; - mod codec; mod manager; mod message; @@ -9,7 +6,7 @@ mod protocol; pub use codec::PeerProtocolCodec; pub use crate::manager::builder::PeerManagerBuilder; -pub use crate::manager::messages::{IPeerManagerMessage, ManagedMessage, MessageId, OPeerManagerMessage}; +pub use crate::manager::messages::{ManagedMessage, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage}; pub use crate::manager::peer_info::PeerInfo; pub use crate::manager::sink::PeerManagerSink; pub use crate::manager::stream::PeerManagerStream; @@ -25,15 +22,16 @@ pub mod messages { pub use crate::message::{ BitFieldIter, BitFieldMessage, BitsExtensionMessage, CancelMessage, ExtendedMessage, ExtendedType, HaveMessage, - NullProtocolMessage, PeerExtensionProtocolMessage, PeerWireProtocolMessage, PieceMessage, PortMessage, RequestMessage, - UtMetadataDataMessage, UtMetadataMessage, UtMetadataRejectMessage, UtMetadataRequestMessage, + NullProtocolMessage, PeerExtensionProtocolMessage, PeerExtensionProtocolMessageError, PeerWireProtocolMessage, + PeerWireProtocolMessageError, PieceMessage, PortMessage, RequestMessage, UtMetadataDataMessage, UtMetadataMessage, + UtMetadataRejectMessage, UtMetadataRequestMessage, }; } /// `PeerManager` error types. 
#[allow(clippy::module_name_repetitions)] pub mod error { - pub use crate::manager::error::{PeerManagerError, PeerManagerErrorKind, PeerManagerResult, PeerManagerResultExt}; + pub use crate::manager::error::{PeerManagerError, PeerManagerResult}; } /// Implementations of `PeerProtocol`. diff --git a/packages/peer/src/macros.rs b/packages/peer/src/macros.rs deleted file mode 100644 index 1fdb16a7d..000000000 --- a/packages/peer/src/macros.rs +++ /dev/null @@ -1,31 +0,0 @@ -#[allow(unused_macro_rules)] -macro_rules! throwaway_input ( - - ($res:expr) => ( - { - match $res { - IResult::Done(_, result) => IResult::Done((), result), - IResult::Error(e) => IResult::Error(e), - IResult::Incomplete(i) => IResult::Incomplete(i) - } - } - ); - ($i:expr, $func:path) => ( - { - throwaway_input!($func($i)) - } - ); - ($i:expr, $submac:ident!( $($args:tt)* )) => ( - { - throwaway_input!($submac!($i, $($args)*)) - } - ); -); - -macro_rules! ignore_input ( - ($i:expr, $submac:ident!( $($args:tt)* )) => ( - { - $submac!($($args)*) - } - ); -); diff --git a/packages/peer/src/manager/builder.rs b/packages/peer/src/manager/builder.rs index 27bce0c0d..25b04bf46 100644 --- a/packages/peer/src/manager/builder.rs +++ b/packages/peer/src/manager/builder.rs @@ -1,11 +1,9 @@ -use std::io; use std::time::Duration; use futures::sink::Sink; -use futures::stream::Stream; -use tokio_core::reactor::Handle; +use futures::{Stream, TryStream}; -use crate::manager::{ManagedMessage, PeerManager}; +use super::{ManagedMessage, PeerManager}; const DEFAULT_PEER_CAPACITY: usize = 1000; const DEFAULT_SINK_BUFFER_CAPACITY: usize = 100; @@ -17,99 +15,104 @@ const DEFAULT_HEARTBEAT_TIMEOUT_MILLIS: u64 = 2 * 60 * 1000; #[allow(clippy::module_name_repetitions)] #[derive(Default, Copy, Clone)] pub struct PeerManagerBuilder { - peer: usize, - sink_buffer: usize, - stream_buffer: usize, + peer_capacity: usize, + sink_buffer_capacity: usize, + stream_buffer_capacity: usize, heartbeat_interval: Duration, heartbeat_timeout: Duration, } impl PeerManagerBuilder { - /// Create a new `PeerManagerBuilder`. + /// Creates a new `PeerManagerBuilder` with default values. #[must_use] pub fn new() -> PeerManagerBuilder { PeerManagerBuilder { - peer: DEFAULT_PEER_CAPACITY, - sink_buffer: DEFAULT_SINK_BUFFER_CAPACITY, - stream_buffer: DEFAULT_STREAM_BUFFER_CAPACITY, + peer_capacity: DEFAULT_PEER_CAPACITY, + sink_buffer_capacity: DEFAULT_SINK_BUFFER_CAPACITY, + stream_buffer_capacity: DEFAULT_STREAM_BUFFER_CAPACITY, heartbeat_interval: Duration::from_millis(DEFAULT_HEARTBEAT_INTERVAL_MILLIS), heartbeat_timeout: Duration::from_millis(DEFAULT_HEARTBEAT_TIMEOUT_MILLIS), } } - /// Max number of peers we can manage. + /// Sets the maximum number of peers that can be managed. #[must_use] pub fn with_peer_capacity(mut self, capacity: usize) -> PeerManagerBuilder { - self.peer = capacity; + self.peer_capacity = capacity; self } - /// Capacity of pending sent messages. + /// Sets the capacity of the sink buffer for pending sent messages. #[must_use] pub fn with_sink_buffer_capacity(mut self, capacity: usize) -> PeerManagerBuilder { - self.sink_buffer = capacity; + self.sink_buffer_capacity = capacity; self } - /// Capacity of pending received messages. + /// Sets the capacity of the stream buffer for pending received messages. #[must_use] pub fn with_stream_buffer_capacity(mut self, capacity: usize) -> PeerManagerBuilder { - self.stream_buffer = capacity; + self.stream_buffer_capacity = capacity; self } - /// Interval at which we send keep-alive messages. 
+ /// Sets the interval at which keep-alive messages are sent. #[must_use] pub fn with_heartbeat_interval(mut self, interval: Duration) -> PeerManagerBuilder { self.heartbeat_interval = interval; self } - /// Timeout at which we disconnect from the peer without seeing a keep-alive message. + /// Sets the timeout duration after which a peer is disconnected if no keep-alive message is received. #[must_use] pub fn with_heartbeat_timeout(mut self, timeout: Duration) -> PeerManagerBuilder { self.heartbeat_timeout = timeout; self } - /// Retrieve the peer capacity. + /// Retrieves the peer capacity. #[must_use] pub fn peer_capacity(&self) -> usize { - self.peer + self.peer_capacity } - /// Retrieve the sink buffer capacity. + /// Retrieves the sink buffer capacity. #[must_use] pub fn sink_buffer_capacity(&self) -> usize { - self.sink_buffer + self.sink_buffer_capacity } - /// Retrieve the stream buffer capacity. + /// Retrieves the stream buffer capacity. #[must_use] pub fn stream_buffer_capacity(&self) -> usize { - self.stream_buffer + self.stream_buffer_capacity } - /// Retrieve the heartbeat interval `Duration`. + /// Retrieves the heartbeat interval `Duration`. #[must_use] pub fn heartbeat_interval(&self) -> Duration { self.heartbeat_interval } - /// Retrieve the heartbeat timeout `Duration`. + /// Retrieves the heartbeat timeout `Duration`. #[must_use] pub fn heartbeat_timeout(&self) -> Duration { self.heartbeat_timeout } - /// Build a `PeerManager` from the current `PeerManagerBuilder`. + /// Builds a `PeerManager` from the current `PeerManagerBuilder` configuration. #[must_use] - pub fn build
<P>(self, handle: Handle) -> PeerManager<P>
+ pub fn build(self) -> PeerManager where - P: Sink + Stream, - P::SinkItem: ManagedMessage, - P::Item: ManagedMessage, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - PeerManager::from_builder(self, handle) + PeerManager::from_builder(self) } } diff --git a/packages/peer/src/manager/error.rs b/packages/peer/src/manager/error.rs index e747eb646..30644ebe9 100644 --- a/packages/peer/src/manager/error.rs +++ b/packages/peer/src/manager/error.rs @@ -1,19 +1,33 @@ -use error_chain::error_chain; +use thiserror::Error; use crate::manager::peer_info::PeerInfo; -error_chain! { - types { - PeerManagerError, PeerManagerErrorKind, PeerManagerResultExt, PeerManagerResult; - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum PeerManagerError { + #[error("Input message was an error: {0}")] + InputMessageError(std::io::Error), - errors { - PeerNotFound { - info: PeerInfo - } { - description("Peer Was Not Found") - display("Peer Was Not Found With PeerInfo {:?}", info) - } + #[error("Peer Was Not Found With PeerInfo {0:?}")] + PeerNotFound(PeerInfo), - } + #[error("Unable to add new peer to full capacity store. Actual Size: {0} ")] + PeerStoreFull(usize), + + #[error("Unable to add an already existing Peer {0:?}")] + PeerAlreadyExists(PeerInfo), + + #[error("Failed to Get Lock For Peer")] + LockFailed, + + #[error("Failed to Send to Sink")] + SendFailed(SendErr), + + #[error("Failed to Flush to Sink")] + FlushFailed(SendErr), + + #[error("Failed to close Sink")] + CloseFailed(SendErr), } + +pub type PeerManagerResult = Result>; diff --git a/packages/peer/src/manager/fused.rs b/packages/peer/src/manager/fused.rs index afc043fb7..075ff4625 100644 --- a/packages/peer/src/manager/fused.rs +++ b/packages/peer/src/manager/fused.rs @@ -1,137 +1,131 @@ -use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use futures::stream::{Fuse, Stream}; -use futures::{Async, Future, Poll}; -use tokio_timer::{Sleep, TimeoutError, Timer}; +use futures::{StreamExt, TryStream}; +use tokio::time::Instant; /// Error type for `PersistentStream`. -pub enum PersistentError { +pub enum PersistentError { Disconnect, - Timeout, - IoError(io::Error), + StreamError(Err), } -impl From> for PersistentError { - fn from(error: TimeoutError) -> PersistentError { - match error { - TimeoutError::Timer(_, err) => { - panic!("bip_peer: Timer Error In Peer Stream, Timer Capacity Is Probably Too Small: {err}") - } - TimeoutError::TimedOut(_) => PersistentError::Timeout, - } - } +/// Error type for `RecurringTimeoutStream`. +pub enum RecurringTimeoutError { + Disconnect, + Timeout, + StreamError(Err), } -/// Stream for persistent connections, where a value of None from the underlying -/// stream maps to an actual error, and calling poll multiple times will always -/// return such error. -pub struct PersistentStream { - stream: Fuse, +/// A stream wrapper that ensures persistent connections. If the underlying stream yields `None`, +/// it is treated as an error, and subsequent polls will continue to return this error. +pub struct PersistentStream +where + St: Stream>, + St: TryStream, +{ + stream: Fuse, } -impl PersistentStream +impl PersistentStream where - S: Stream, + St: Stream>, + St: TryStream, { - /// Create a new `PersistentStream`. - pub fn new(stream: S) -> PersistentStream { + /// Creates a new `PersistentStream`. 
+ pub fn new(stream: St) -> PersistentStream { PersistentStream { stream: stream.fuse() } } } -impl Stream for PersistentStream +impl Stream for PersistentStream where - S: Stream, + St: Stream>, + St: TryStream + Unpin, { - type Item = S::Item; - type Error = PersistentError; - - fn poll(&mut self) -> Poll, Self::Error> { - self.stream - .poll() - .map_err(PersistentError::IoError) - .and_then(|item| match item { - Async::Ready(None) => Err(PersistentError::Disconnect), - other => Ok(other), - }) + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let ready = match self.stream.poll_next_unpin(cx) { + Poll::Ready(ready) => ready, + Poll::Pending => return Poll::Pending, + }; + + let Some(item) = ready else { + return Poll::Ready(Some(Err(PersistentError::Disconnect))); + }; + + match item { + Ok(message) => Poll::Ready(Some(Ok(message))), + Err(err) => Poll::Ready(Some(Err(PersistentError::StreamError(err)))), + } } } //----------------------------------------------------------------------------// -/// Error type for `RecurringTimeoutStream`. -pub enum RecurringTimeoutError { - /// None and any errors are mapped to this type... - Disconnect, - Timeout, -} - -/// Stream similar to `tokio_timer::TimeoutStream`, but which doesn't return -/// the underlying stream if a single timeout occurs. Instead, it signals that -/// the timeout occurred before the stream produced an item, but keeps the -/// stream in tact (does not return it), so that we can continue polling. -/// -/// Whereas `tokio_timer::TimeoutStream` would be used for detecting if a -/// client timed out, `RecurringTimeoutStream` could be used for a local -/// stream to send heartbeats if, for example, the local client hadn't sent -/// any other message to the client for n seconds and we would like to send -/// some heartbeat message in that case, but continue polling the stream. -pub struct RecurringTimeoutStream { - dur: Duration, - timer: Timer, - sleep: Sleep, - stream: S, +/// A stream wrapper that enforces a recurring timeout. If the underlying stream does not yield +/// an item within the specified duration, a timeout error is returned. +pub struct RecurringTimeoutStream +where + St: Stream>, + St: TryStream, +{ + stream: Fuse, + timeout: Duration, + deadline: Instant, } -impl RecurringTimeoutStream { - pub fn new(stream: S, timer: Timer, dur: Duration) -> RecurringTimeoutStream { - let sleep = timer.sleep(dur); - +impl RecurringTimeoutStream +where + St: Stream>, + St: TryStream, +{ + /// Creates a new `RecurringTimeoutStream`. 
+ pub fn new(stream: St, timeout: Duration) -> RecurringTimeoutStream { RecurringTimeoutStream { - dur, - timer, - sleep, - stream, + stream: stream.fuse(), + timeout, + deadline: Instant::now() + timeout, } } } -impl Stream for RecurringTimeoutStream +impl Stream for RecurringTimeoutStream where - S: Stream, + St: Stream>, + St: TryStream + Unpin, { - type Item = S::Item; - type Error = RecurringTimeoutError; - - fn poll(&mut self) -> Poll, RecurringTimeoutError> { - // First, try polling the future - match self.stream.poll() { - Ok(Async::NotReady) => {} - Ok(Async::Ready(Some(v))) => { - // Reset the timeout - self.sleep = self.timer.sleep(self.dur); - - // Return the value - return Ok(Async::Ready(Some(v))); + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let ready = match self.stream.poll_next_unpin(cx) { + Poll::Ready(ready) => ready, + Poll::Pending => { + let now = Instant::now(); + if now > self.deadline { + self.deadline = now + self.timeout; + + return Poll::Ready(Some(Err(RecurringTimeoutError::Timeout))); + } + return Poll::Pending; } - Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), - Err(_) => return Err(RecurringTimeoutError::Disconnect), - } + }; - let Ok(poll) = self.sleep.poll() else { - panic!("bip_peer: Timer Error In Manager Stream, Timer Capacity Is Probably Too Small...") + let Some(item) = ready else { + return Poll::Ready(Some(Err(RecurringTimeoutError::Disconnect))); }; - // Now check the timer - match poll { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(()) => { + match item { + Ok(message) => { // Reset the timeout - self.sleep = self.timer.sleep(self.dur); + self.deadline = Instant::now() + self.timeout; - Err(RecurringTimeoutError::Timeout) + Poll::Ready(Some(Ok(message))) } + Err(err) => Poll::Ready(Some(Err(RecurringTimeoutError::StreamError(err)))), } } } diff --git a/packages/peer/src/manager/messages.rs b/packages/peer/src/manager/messages.rs index 235e7f1ca..73387832e 100644 --- a/packages/peer/src/manager/messages.rs +++ b/packages/peer/src/manager/messages.rs @@ -1,17 +1,17 @@ -use futures::Sink; +use futures::{Sink, Stream, TryStream}; +use thiserror::Error; -use crate::PeerInfo; +use crate::manager::peer_info::PeerInfo; -/// Trait for giving `PeerManager` message information it needs. +/// Trait for providing `PeerManager` with necessary message information. /// -/// For any `PeerProtocol` (or plain `Codec`), that wants to be managed -/// by `PeerManager`, it must ensure that it's message type implements -/// this trait so that we have the hooks necessary to manage the peer. -pub trait ManagedMessage { - /// Retrieve a keep alive message variant. +/// Any `PeerProtocol` (or plain `Codec`) that wants to be managed by `PeerManager` +/// must ensure that its message type implements this trait to provide the necessary hooks. +pub trait ManagedMessage: std::fmt::Debug { + /// Retrieves a keep-alive message variant. fn keep_alive() -> Self; - /// Whether or not this message is a keep alive message. + /// Checks whether this message is a keep-alive message. fn is_keep_alive(&self) -> bool; } @@ -20,35 +20,55 @@ pub trait ManagedMessage { /// Identifier for matching sent messages with received messages. pub type MessageId = u64; -/// Message that can be sent to the `PeerManager`. -pub enum IPeerManagerMessage
<P>
- PeerError(PeerInfo, std::io::Error), } diff --git a/packages/peer/src/manager/mod.rs b/packages/peer/src/manager/mod.rs index 66972c1ea..554adca29 100644 --- a/packages/peer/src/manager/mod.rs +++ b/packages/peer/src/manager/mod.rs @@ -1,20 +1,17 @@ use std::collections::HashMap; +use std::marker::PhantomData; use std::sync::{Arc, Mutex}; -use std::time::Duration; -use std::{cmp, io}; use crossbeam::queue::SegQueue; +use error::PeerManagerError; +use futures::channel::mpsc::{self, SendError}; use futures::sink::Sink; use futures::stream::Stream; -use futures::sync::mpsc; -use futures::{Poll, StartSend}; -use messages::{IPeerManagerMessage, ManagedMessage, OPeerManagerMessage}; +use futures::{SinkExt as _, StreamExt, TryStream}; use sink::PeerManagerSink; -use stream::PeerManagerStream; -use tokio_core::reactor::Handle; -use crate::manager::builder::PeerManagerBuilder; -use crate::manager::error::PeerManagerError; +use super::ManagedMessage; +use crate::{PeerManagerBuilder, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage, PeerManagerStream}; pub mod builder; pub mod error; @@ -26,91 +23,116 @@ pub mod stream; mod fused; mod task; -// We configure our tick duration based on this, could let users configure this in the future... -const DEFAULT_TIMER_SLOTS: usize = 2048; - /// Manages a set of peers with beating hearts. #[allow(clippy::module_name_repetitions)] -pub struct PeerManager
<P>
+pub struct PeerManager where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - sink: PeerManagerSink
<P>, - stream: PeerManagerStream<P>
, + sink: PeerManagerSink, + stream: PeerManagerStream, + _peer_marker: PhantomData, } -impl
<P> PeerManager<P>
+impl PeerManager where - P: Sink + Stream, - P::SinkItem: ManagedMessage, - P::Item: ManagedMessage, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { /// Create a new `PeerManager` from the given `PeerManagerBuilder`. #[must_use] - pub fn from_builder(builder: PeerManagerBuilder, handle: Handle) -> PeerManager
<P>
{ - // We use one timer for manager heartbeat intervals, and one for peer heartbeat timeouts - let maximum_timers = builder.peer_capacity() * 2; - let pow_maximum_timers = if maximum_timers & (maximum_timers - 1) == 0 { - maximum_timers - } else { - maximum_timers.next_power_of_two() - }; - - // Figure out the right tick duration to get num slots of 2048. - // TODO: We could probably let users change this in the future... - let max_duration = cmp::max(builder.heartbeat_interval(), builder.heartbeat_timeout()); - let tick_duration = Duration::from_millis(max_duration.as_secs() * 1000 / (DEFAULT_TIMER_SLOTS as u64) + 1); - let timer = tokio_timer::wheel() - .tick_duration(tick_duration) - .max_capacity(pow_maximum_timers + 1) - .channel_capacity(pow_maximum_timers) - .num_slots(DEFAULT_TIMER_SLOTS) - .build(); - + pub fn from_builder(builder: PeerManagerBuilder) -> PeerManager { let (res_send, res_recv) = mpsc::channel(builder.stream_buffer_capacity()); let peers = Arc::new(Mutex::new(HashMap::new())); let task_queue = Arc::new(SegQueue::new()); - let sink = PeerManagerSink::new(handle, timer, builder, res_send, peers.clone(), task_queue.clone()); - let stream = PeerManagerStream::new(res_recv, peers, task_queue); + let sink = PeerManagerSink::new(builder, res_send, peers.clone(), task_queue.clone()); + let stream = PeerManagerStream::new(res_recv, peers); - PeerManager { sink, stream } + PeerManager { + sink, + stream, + _peer_marker: PhantomData, + } } /// Break the `PeerManager` into a sink and stream. /// /// The returned sink implements `Clone`. - pub fn into_parts(self) -> (PeerManagerSink
<P>, PeerManagerStream<P>
) { + pub fn into_parts(self) -> (PeerManagerSink, PeerManagerStream) { (self.sink, self.stream) } } -impl
<P> Sink for PeerManager<P>
+impl Sink>> for PeerManager where - P: Sink + Stream + 'static, - P::SinkItem: ManagedMessage, - P::Item: ManagedMessage, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - type SinkItem = IPeerManagerMessage
<P>
; - type SinkError = PeerManagerError; + type Error = PeerManagerError; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink.poll_ready_unpin(cx) + } + + fn start_send( + mut self: std::pin::Pin<&mut Self>, + item: std::io::Result>, + ) -> Result<(), Self::Error> { + self.sink.start_send_unpin(item) + } - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.sink.start_send(item) + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink.poll_flush_unpin(cx) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.sink.poll_complete() + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.sink.poll_close_unpin(cx) } } -impl
<P> Stream for PeerManager<P>
+impl Stream for PeerManager where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - type Item = OPeerManagerMessage; - type Error = (); + type Item = Result, PeerManagerOutputError>; - fn poll(&mut self) -> Poll, Self::Error> { - self.stream.poll() + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) } } diff --git a/packages/peer/src/manager/sink.rs b/packages/peer/src/manager/sink.rs index 4216f623a..efe157415 100644 --- a/packages/peer/src/manager/sink.rs +++ b/packages/peer/src/manager/sink.rs @@ -1,258 +1,275 @@ -//! Sink half of a `PeerManager`. - use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::pin::Pin; use std::sync::{Arc, Mutex}; use crossbeam::queue::SegQueue; +use futures::channel::mpsc::{self, SendError}; use futures::sink::Sink; -use futures::stream::Stream; -use futures::sync::mpsc::Sender; -use futures::task::{self as futures_task, Task}; -use futures::{Async, AsyncSink, Poll, StartSend}; -use tokio_core::reactor::Handle; -use tokio_timer::{self, Timer}; - -use super::messages::{IPeerManagerMessage, ManagedMessage, OPeerManagerMessage}; -use super::task; +use futures::task::{Context, Poll}; +use futures::{SinkExt as _, Stream, TryStream}; + +use super::messages::{PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage}; +use super::task::run_peer; use crate::manager::builder::PeerManagerBuilder; -use crate::manager::error::{PeerManagerError, PeerManagerErrorKind}; +use crate::manager::error::PeerManagerError; use crate::manager::peer_info::PeerInfo; +use crate::manager::ManagedMessage; +/// Sink half of a `PeerManager`. #[allow(clippy::module_name_repetitions)] -pub struct PeerManagerSink
<P>
+pub struct PeerManagerSink where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - handle: Handle, - timer: Timer, - build: PeerManagerBuilder, - send: Sender>, - peers: Arc>>>>, - task_queue: Arc>, + builder: PeerManagerBuilder, + sender: mpsc::Sender, PeerManagerOutputError>>, + #[allow(clippy::type_complexity)] + peers: Arc>>>>, + task_queue: Arc>>, } -impl
<P> Clone for PeerManagerSink<P>
+impl Clone for PeerManagerSink where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - fn clone(&self) -> PeerManagerSink
<P>
{ + fn clone(&self) -> PeerManagerSink { PeerManagerSink { - handle: self.handle.clone(), - timer: self.timer.clone(), - build: self.build, - send: self.send.clone(), + builder: self.builder, + sender: self.sender.clone(), peers: self.peers.clone(), task_queue: self.task_queue.clone(), } } } -impl
<P> PeerManagerSink<P>
+impl PeerManagerSink where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - pub(super) fn new( - handle: Handle, - timer: Timer, - build: PeerManagerBuilder, - send: Sender>, - peers: Arc>>>>, - task_queue: Arc>, - ) -> PeerManagerSink
<P>
{ + #[allow(clippy::type_complexity)] + pub fn new( + builder: PeerManagerBuilder, + sender: mpsc::Sender, PeerManagerOutputError>>, + peers: Arc>>>>, + task_queue: Arc>>, + ) -> PeerManagerSink { PeerManagerSink { - handle, - timer, - build, - send, + builder, + sender, peers, task_queue, } } +} - fn run_with_lock_sink(&mut self, item: I, call: F, not: G) -> StartSend - where - F: FnOnce( - I, - &mut Handle, - &mut Timer, - &mut PeerManagerBuilder, - &mut Sender>, - &mut HashMap>>, - ) -> StartSend, - G: FnOnce(I) -> T, - { - let (result, took_lock) = if let Ok(mut guard) = self.peers.try_lock() { - let result = call( - item, - &mut self.handle, - &mut self.timer, - &mut self.build, - &mut self.send, - &mut *guard, - ); - - // Closure could return not ready, need to stash in that case - if result.as_ref().map(futures::AsyncSink::is_not_ready).unwrap_or(false) { - self.task_queue.push(futures_task::current()); +impl PeerManagerSink +where + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, +{ + fn handle_message( + &self, + item: std::io::Result>, + ) -> Result<(), PeerManagerError> { + let message = match item { + Ok(message) => message, + Err(e) => { + tracing::debug!("got input message error {e}"); + return Err(PeerManagerError::InputMessageError(e)); } + }; - (result, true) - } else { - self.task_queue.push(futures_task::current()); - - if let Ok(mut guard) = self.peers.try_lock() { - let result = call( - item, - &mut self.handle, - &mut self.timer, - &mut self.build, - &mut self.send, - &mut *guard, - ); - - // Closure could return not ready, need to stash in that case - if result.as_ref().map(futures::AsyncSink::is_not_ready).unwrap_or(false) { - self.task_queue.push(futures_task::current()); - } + match message { + PeerManagerInputMessage::AddPeer(info, peer) => self.add_peer(info, peer), + PeerManagerInputMessage::RemovePeer(info) => self.remove_peer(info), + PeerManagerInputMessage::SendMessage(info, mid, peer_message) => self.send_message(info, mid, peer_message), + } + } - (result, true) - } else { - (Ok(AsyncSink::NotReady(not(item))), false) - } + fn add_peer(&self, info: PeerInfo, peer: Peer) -> Result<(), PeerManagerError> { + tracing::trace!("adding peer: {peer:?}, with info: {info:?}"); + + let Ok(mut guard) = self.peers.try_lock() else { + tracing::debug!("failed to get peers lock"); + return Err(PeerManagerError::LockFailed); }; - if took_lock { - // Just notify a single person waiting on the lock to reduce contention - if let Some(task) = self.task_queue.pop() { - task.notify(); - } + let cur = guard.len(); + let max = self.builder.peer_capacity(); + + if cur >= max { + tracing::debug!("max peers reached: {cur} of max: {max}"); + return Err(PeerManagerError::PeerStoreFull(guard.len())); } - result + match guard.entry(info) { + Entry::Occupied(_) => { + tracing::debug!("peer already exists: {info:?}"); + return Err(PeerManagerError::PeerAlreadyExists(info)); + } + Entry::Vacant(vac) => { + let (sender, task) = run_peer(peer, info, self.sender.clone(), &self.builder); + vac.insert(sender); + self.task_queue.push(task); // Add the task to the task queue + } + }; + + Ok(()) } - fn run_with_lock_poll(&mut self, call: F) -> Poll - where - F: FnOnce( - &mut Handle, - &mut Timer, - &mut PeerManagerBuilder, - &mut Sender>, - &mut HashMap>>, - ) -> Poll, - { - let (result, took_lock) = if let Ok(mut guard) = self.peers.try_lock() { - let result = call( - &mut self.handle, - &mut self.timer, - &mut 
self.build, - &mut self.send, - &mut *guard, - ); - - (result, true) - } else { - // Stash a task - self.task_queue.push(futures_task::current()); - - // Try to get lock again in case of race condition - if let Ok(mut guard) = self.peers.try_lock() { - let result = call( - &mut self.handle, - &mut self.timer, - &mut self.build, - &mut self.send, - &mut *guard, - ); - - (result, true) - } else { - (Ok(Async::NotReady), false) - } + fn remove_peer(&self, info: PeerInfo) -> Result<(), PeerManagerError> { + tracing::trace!("removing peer, with info: {info:?}"); + + let Ok(mut guard) = self.peers.try_lock() else { + tracing::debug!("failed to get peers lock"); + return Err(PeerManagerError::LockFailed); }; - if took_lock { - // Just notify a single person waiting on the lock to reduce contention - if let Some(task) = self.task_queue.pop() { - task.notify(); - } - } + let peer_sender = guard.get_mut(&info).ok_or(PeerManagerError::PeerNotFound(info))?; + + peer_sender + .start_send(PeerManagerInputMessage::RemovePeer(info)) + .map_err(PeerManagerError::SendFailed)?; + + Ok(()) + } + + fn send_message(&self, info: PeerInfo, mid: u64, msg: Message) -> Result<(), PeerManagerError> { + tracing::trace!("sending message {msg:?}, with info: {info:?}, and mid: {mid}"); + + let Ok(mut guard) = self.peers.try_lock() else { + tracing::debug!("failed to get peers lock"); + return Err(PeerManagerError::LockFailed); + }; + + let peer_sender = guard.get_mut(&info).ok_or(PeerManagerError::PeerNotFound(info))?; - result + peer_sender + .start_send(PeerManagerInputMessage::SendMessage(info, mid, msg)) + .map_err(PeerManagerError::SendFailed)?; + + Ok(()) } } -impl
<P> Sink for PeerManagerSink<P>
+impl Sink>> for PeerManagerSink where - P: Sink + Stream + 'static, - P::SinkItem: ManagedMessage, - P::Item: ManagedMessage, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - type SinkItem = IPeerManagerMessage
<P>
; - type SinkError = PeerManagerError; - - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match item { - IPeerManagerMessage::AddPeer(info, peer) => self.run_with_lock_sink( - (info, peer), - |(info, peer), handle, timer, builder, send, peers| { - if peers.len() >= builder.peer_capacity() { - Ok(AsyncSink::NotReady(IPeerManagerMessage::AddPeer(info, peer))) - } else { - match peers.entry(info) { - Entry::Occupied(_) => Err(PeerManagerError::from_kind(PeerManagerErrorKind::PeerNotFound { info })), - Entry::Vacant(vac) => { - vac.insert(task::run_peer(peer, info, send.clone(), timer.clone(), builder, handle)); - - Ok(AsyncSink::Ready) - } - } - } - }, - |(info, peer)| IPeerManagerMessage::AddPeer(info, peer), - ), - IPeerManagerMessage::RemovePeer(info) => self.run_with_lock_sink( - info, - |info, _, _, _, _, peers| { - peers - .get_mut(&info) - .ok_or_else(|| PeerManagerError::from_kind(PeerManagerErrorKind::PeerNotFound { info })) - .and_then(|send| { - send.start_send(IPeerManagerMessage::RemovePeer(info)) - .map_err(|_| panic!("bip_peer: PeerManager Failed To Send RemovePeer")) - }) - }, - |info| IPeerManagerMessage::RemovePeer(info), - ), - IPeerManagerMessage::SendMessage(info, mid, peer_message) => self.run_with_lock_sink( - (info, mid, peer_message), - |(info, mid, peer_message), _, _, _, _, peers| { - peers - .get_mut(&info) - .ok_or_else(|| PeerManagerError::from_kind(PeerManagerErrorKind::PeerNotFound { info })) - .and_then(|send| { - send.start_send(IPeerManagerMessage::SendMessage(info, mid, peer_message)) - .map_err(|_| panic!("bip_peer: PeerManager Failed to Send SendMessage")) - }) - }, - |(info, mid, peer_message)| IPeerManagerMessage::SendMessage(info, mid, peer_message), - ), + type Error = PeerManagerError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Ok(mut guard) = self.peers.try_lock() else { + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + for peer_sender in guard.values_mut() { + match peer_sender.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(PeerManagerError::SendFailed(e))), + Poll::Pending => return Poll::Pending, + } + } + + Poll::Ready(Ok(())) + } + + fn start_send( + self: Pin<&mut Self>, + item: std::io::Result>, + ) -> Result<(), Self::Error> { + tracing::trace!("handling message: {item:?}"); + self.handle_message(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::trace!("flushing..."); + + let Ok(mut guard) = self.peers.try_lock() else { + tracing::debug!("failed to get peers lock... will reschedule with waker"); + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + for peer_sender in guard.values_mut() { + match peer_sender.poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(PeerManagerError::FlushFailed(e))), + Poll::Pending => { + tracing::debug!("pending to flush peer sender... 
will reschedule with waker"); + return Poll::Pending; + } + } } + + Poll::Ready(Ok(())) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.run_with_lock_poll(|_, _, _, _, peers| { - for peer_mut in peers.values_mut() { - // Needs type hint in case poll fails (so that error type matches) - let result: Poll<(), Self::SinkError> = peer_mut - .poll_complete() - .map_err(|_| panic!("bip_peer: PeerManaged Failed To Poll Peer")); + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::trace!("closing..."); - result?; + let Ok(mut guard) = self.peers.try_lock() else { + tracing::debug!("failed to get peers lock... will reschedule with waker"); + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + for peer_sender in guard.values_mut() { + match peer_sender.poll_close_unpin(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(PeerManagerError::FlushFailed(e))), + Poll::Pending => { + tracing::debug!("pending to flush peer sender... will reschedule with waker"); + return Poll::Pending; + } + } + } + + while let Some(task) = self.task_queue.pop() { + if !task.is_finished() { + tracing::debug!("task is not finished... will reschedule with waker"); + self.task_queue.push(task); + cx.waker().wake_by_ref(); + return Poll::Pending; } + } - Ok(Async::Ready(())) - }) + Poll::Ready(Ok(())) } } diff --git a/packages/peer/src/manager/stream.rs b/packages/peer/src/manager/stream.rs index ffd6052e4..7d0b0174a 100644 --- a/packages/peer/src/manager/stream.rs +++ b/packages/peer/src/manager/stream.rs @@ -1,138 +1,138 @@ -//! Stream half of a `PeerManager`. - use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; -use crossbeam::queue::SegQueue; -use futures::sink::Sink; +use futures::channel::mpsc; use futures::stream::Stream; -use futures::sync::mpsc::{Receiver, Sender}; -use futures::task::{self as futures_task, Task}; -use futures::{Async, Poll}; +use futures::{Sink, StreamExt, TryStream}; +use pin_project::pin_project; -use super::messages::{IPeerManagerMessage, OPeerManagerMessage}; +use super::messages::{ManagedMessage, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage}; use crate::manager::peer_info::PeerInfo; +/// Stream half of a `PeerManager`. #[allow(clippy::module_name_repetitions)] -#[allow(clippy::option_option)] -pub struct PeerManagerStream
<P>
+#[pin_project] +pub struct PeerManagerStream where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - recv: Receiver>, - peers: Arc>>>>, - task_queue: Arc>, - opt_pending: Option>>, + recv: mpsc::Receiver, PeerManagerOutputError>>, + #[allow(clippy::type_complexity)] + peers: Arc>>>>, + opt_pending: Option, PeerManagerOutputError>>, } -impl
<P> PeerManagerStream<P>
+impl PeerManagerStream where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - pub(super) fn new( - recv: Receiver>, - peers: Arc>>>>, - task_queue: Arc>, - ) -> PeerManagerStream
<P>
{ - PeerManagerStream { + #[allow(clippy::type_complexity)] + pub fn new( + recv: mpsc::Receiver, PeerManagerOutputError>>, + peers: Arc>>>>, + ) -> Self { + Self { recv, peers, - task_queue, opt_pending: None, } } - - fn run_with_lock_poll(&mut self, item: I, call: F, not: G) -> Poll - where - F: FnOnce(I, &mut HashMap>>) -> Poll, - G: FnOnce(I) -> Option>, - { - let (result, took_lock) = if let Ok(mut guard) = self.peers.try_lock() { - let result = call(item, &mut *guard); - - // Nothing calling us will return NotReady, so we don't have to push to queue here - - (result, true) - } else { - // Couldn't get the lock, stash a task away - self.task_queue.push(futures_task::current()); - - // Try to get the lock once more, in case of a race condition with stashing the task - if let Ok(mut guard) = self.peers.try_lock() { - let result = call(item, &mut *guard); - - // Nothing calling us will return NotReady, so we don't have to push to queue here - - (result, true) - } else { - // If we couldn't get the lock, stash the item - self.opt_pending = Some(not(item)); - - (Ok(Async::NotReady), false) - } - }; - - if took_lock { - // Just notify a single person waiting on the lock to reduce contention - if let Some(task) = self.task_queue.pop() { - task.notify(); - } - } - - result - } } -impl
<P> Stream for PeerManagerStream<P>
+impl Stream for PeerManagerStream where - P: Sink + Stream, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - type Item = OPeerManagerMessage; - type Error = (); - - fn poll(&mut self) -> Poll, Self::Error> { - // Intercept and propagate any messages indicating the peer shutdown so we can remove them from our peer map - let next_message = self - .opt_pending - .take() - .map(|pending| Ok(Async::Ready(pending))) - .unwrap_or_else(|| self.recv.poll()); + type Item = Result, PeerManagerOutputError>; + + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let next_message = match self.opt_pending.take() { + Some(message) => message, + None => match self.recv.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + }, + }; - next_message.and_then(|result| match result { - Async::Ready(Some(OPeerManagerMessage::PeerRemoved(info))) => self.run_with_lock_poll( - info, - |info, peers| { - peers - .remove(&info) - .unwrap_or_else(|| panic!("bip_peer: Received PeerRemoved Message With No Matching Peer In Map")); + let ready = match next_message { + Err(err) => match err { + PeerManagerOutputError::PeerError(info, _) => { + let Ok(mut peers) = self.peers.try_lock() else { + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + match peers.remove(&info) { + Some(peer) => { + drop(peer); + Poll::Ready(Some(Ok(PeerManagerOutputMessage::PeerRemoved(info)))) + } + None => Poll::Ready(Some(Err(PeerManagerOutputError::PeerErrorAndMissing( + info, + Some(Box::new(err)), + )))), + } + } + PeerManagerOutputError::PeerErrorAndMissing(_, _) + | PeerManagerOutputError::PeerDisconnectedAndMissing(_) + | PeerManagerOutputError::PeerRemovedAndMissing(_) => Poll::Ready(Some(Err(err))), + }, + Ok(PeerManagerOutputMessage::PeerRemoved(info)) => { + let Ok(mut peers) = self.peers.try_lock() else { + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + match peers.remove(&info) { + Some(peer) => { + drop(peer); + Poll::Ready(Some(Ok(PeerManagerOutputMessage::PeerRemoved(info)))) + } + None => Poll::Ready(Some(Err(PeerManagerOutputError::PeerRemovedAndMissing(info)))), + } + } - Ok(Async::Ready(Some(OPeerManagerMessage::PeerRemoved(info)))) - }, - |info| Some(OPeerManagerMessage::PeerRemoved(info)), - ), - Async::Ready(Some(OPeerManagerMessage::PeerDisconnect(info))) => self.run_with_lock_poll( - info, - |info, peers| { - peers - .remove(&info) - .unwrap_or_else(|| panic!("bip_peer: Received PeerDisconnect Message With No Matching Peer In Map")); + Ok(PeerManagerOutputMessage::PeerDisconnect(info)) => { + let Ok(mut peers) = self.peers.try_lock() else { + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + match peers.remove(&info) { + Some(peer) => { + drop(peer); + Poll::Ready(Some(Ok(PeerManagerOutputMessage::PeerRemoved(info)))) + } + None => Poll::Ready(Some(Err(PeerManagerOutputError::PeerDisconnectedAndMissing(info)))), + } + } - Ok(Async::Ready(Some(OPeerManagerMessage::PeerDisconnect(info)))) - }, - |info| Some(OPeerManagerMessage::PeerDisconnect(info)), - ), - Async::Ready(Some(OPeerManagerMessage::PeerError(info, error))) => self.run_with_lock_poll( - (info, error), - |(info, error), peers| { - peers - .remove(&info) - .unwrap_or_else(|| panic!("bip_peer: Received PeerError Message With No Matching Peer In Map")); + Ok(msg) => Poll::Ready(Some(Ok(msg))), + }; - 
Ok(Async::Ready(Some(OPeerManagerMessage::PeerError(info, error)))) - }, - |(info, error)| Some(OPeerManagerMessage::PeerError(info, error)), - ), - other => Ok(other), - }) + ready } } diff --git a/packages/peer/src/manager/task.rs b/packages/peer/src/manager/task.rs index 8946773b1..7b5504cd6 100644 --- a/packages/peer/src/manager/task.rs +++ b/packages/peer/src/manager/task.rs @@ -1,259 +1,182 @@ -#![allow(deprecated)] - -use std::io; - -use futures::future::{self, Future, Loop}; -use futures::sink::Sink; -use futures::stream::{MergedItem, Stream}; -use futures::sync::mpsc::{self, Sender}; -use tokio_core::reactor::Handle; -use tokio_timer::Timer; +use futures::channel::mpsc::{self, SendError}; +use futures::stream::SplitSink; +use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; +use thiserror::Error; +use tokio::task::{self, JoinHandle}; +use super::fused::{PersistentError, PersistentStream, RecurringTimeoutError, RecurringTimeoutStream}; +use super::messages::{PeerManagerInputMessage, PeerManagerOutputMessage}; use crate::manager::builder::PeerManagerBuilder; -use crate::manager::fused::{PersistentError, PersistentStream, RecurringTimeoutError, RecurringTimeoutStream}; use crate::manager::peer_info::PeerInfo; -use crate::manager::{IPeerManagerMessage, ManagedMessage, OPeerManagerMessage}; +use crate::manager::ManagedMessage; +use crate::PeerManagerOutputError; + +#[derive(Error, Debug)] +enum PeerError { + #[error("Manager Error")] + ManagerDisconnect(ManagerSendErr), + #[error("Stream Finished")] + Disconnected, + #[error("Peer Error")] + PeerDisconnect(PeerSendErr), + #[error("Peer Removed")] + PeerRemoved(PeerInfo), +} -// Separated from MergedError to -enum PeerError { - // We need to send a heartbeat (no messages sent from manager for a while) - ManagerHeartbeatInterval, - // Manager error (or expected shutdown) - ManagerDisconnect, - // Peer errors - PeerDisconnect, - Peer(io::Error), - PeerNoHeartbeat, +enum UnifiedError { + Peer(PersistentError), + Manager(RecurringTimeoutError), } -#[allow(dead_code)] -enum MergedError { - Peer(PeerError), - // Fake error types (used to stash future "futures" into an error type to be - // executed in a different future transformation, so we don't have to box them) - StageOne(A), - StageTwo(B), - StageThree(C), +enum MergedError { + Disconnect, + StreamError(Err), + Timeout, } -//----------------------------------------------------------------------------// +impl From> for MergedError { + fn from(err: UnifiedError) -> Self { + match err { + UnifiedError::Peer(PersistentError::Disconnect) | UnifiedError::Manager(RecurringTimeoutError::Disconnect) => { + Self::Disconnect + } + UnifiedError::Peer(PersistentError::StreamError(err)) + | UnifiedError::Manager(RecurringTimeoutError::StreamError(err)) => Self::StreamError(err), + UnifiedError::Manager(RecurringTimeoutError::Timeout) => Self::Timeout, + } + } +} -#[allow(clippy::too_many_lines)] -pub fn run_peer
<P>
( - peer: P, +enum UnifiedItem +where + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, +{ + Peer(Message), + Manager(PeerManagerInputMessage), +} + +pub fn run_peer( + peer: Peer, info: PeerInfo, - o_send: Sender>, - timer: Timer, + mut send: mpsc::Sender, PeerManagerOutputError>>, builder: &PeerManagerBuilder, - handle: &Handle, -) -> Sender> +) -> (mpsc::Sender>, JoinHandle<()>) where - P: Stream + Sink + 'static, - P::SinkItem: ManagedMessage, - P::Item: ManagedMessage, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + StreamExt + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, { - let (m_send, m_recv) = mpsc::channel(builder.sink_buffer_capacity()); - let (p_send, p_recv) = peer.split(); + let (manager_send, manager_recv) = mpsc::channel(builder.sink_buffer_capacity()); + let (mut peer_send, peer_recv) = peer.split(); + + let heartbeat_interval = builder.heartbeat_interval(); + + let peer_stream = Box::pin( + PersistentStream::new(peer_recv) + .map_err(UnifiedError::Peer) + .map_ok(|i| UnifiedItem::Peer(i)), + ); + + let manager_stream = Box::pin( + RecurringTimeoutStream::new(manager_recv.map(Ok), heartbeat_interval) + .map_err(UnifiedError::Manager) + .map_ok(|i| UnifiedItem::Manager(i)), + ); - // Build a stream that will timeout if no message is sent for heartbeat_timeout and teardown (don't preserve) the underlying stream - let p_stream = timer - .timeout_stream(PersistentStream::new(p_recv), builder.heartbeat_timeout()) - .map_err(|error| match error { - PersistentError::Disconnect => PeerError::PeerDisconnect, - PersistentError::Timeout => PeerError::PeerNoHeartbeat, - PersistentError::IoError(err) => PeerError::Peer(err), - }); - // Build a stream that will notify us of no message is sent for heartbeat_interval and done teardown (preserve) the underlying stream - let m_stream = RecurringTimeoutStream::new(m_recv, timer, builder.heartbeat_interval()).map_err(|error| match error { - RecurringTimeoutError::Disconnect => PeerError::ManagerDisconnect, - RecurringTimeoutError::Timeout => PeerError::ManagerHeartbeatInterval, + let mut merged_stream = Box::pin(futures::stream::select(peer_stream, manager_stream).map_err(MergedError::from)); + + let task = task::spawn(async move { + if send.send(Ok(PeerManagerOutputMessage::PeerAdded(info))).await.is_err() { + return; + } + + while let Some(result) = merged_stream.as_mut().next().await { + if handle_stream_result::(result, &mut peer_send, &mut send, &info) + .await + .is_err() + { + break; + } + } }); - let merged_stream = m_stream.merge(p_stream); + (manager_send, task) +} - handle.spawn( - o_send - .send(OPeerManagerMessage::PeerAdded(info)) - .map_err(|_| ()) - .and_then(move |o_send| { - future::loop_fn( - (merged_stream, o_send, p_send, info), - |(merged_stream, o_send, p_send, info)| { - // Our return tuple takes the form (merged_stream, Option, Option, Option, is_good) where each stage (A, B, C), - // will execute one of those options (if present), since each future transform can only execute a single future and we have 2^3 possible combinations - // (Some or None = 2)^(3 Options = 3) - merged_stream - .into_future() - .then(move |result| { - let result = match result { - Ok(( - Some(MergedItem::First(IPeerManagerMessage::SendMessage(p_info, mid, p_message))), - merged_stream, - )) => Ok(( - merged_stream, - Some(p_message), - None, - Some(OPeerManagerMessage::SentMessage(p_info, mid)), - true, - )), - 
Ok((Some(MergedItem::First(IPeerManagerMessage::RemovePeer(p_info))), merged_stream)) => { - Ok(( - merged_stream, - None, - None, - Some(OPeerManagerMessage::PeerRemoved(p_info)), - false, - )) - } - Ok((Some(MergedItem::Second(peer_message)), merged_stream)) => { - Ok((merged_stream, None, Some(peer_message), None, true)) - } - Ok(( - Some(MergedItem::Both( - IPeerManagerMessage::SendMessage(p_info, mid, p_message), - peer_message, - )), - merged_stream, - )) => Ok(( - merged_stream, - Some(p_message), - Some(peer_message), - Some(OPeerManagerMessage::SentMessage(p_info, mid)), - true, - )), - Ok(( - Some(MergedItem::Both(IPeerManagerMessage::RemovePeer(p_info), peer_message)), - merged_stream, - )) => Ok(( - merged_stream, - None, - Some(peer_message), - Some(OPeerManagerMessage::PeerRemoved(p_info)), - false, - )), - Ok((Some(_), _)) => { - panic!("bip_peer: Peer Future Received Invalid Message From Peer Manager") - } - Err((PeerError::ManagerHeartbeatInterval, merged_stream)) => { - Ok((merged_stream, Some(P::SinkItem::keep_alive()), None, None, true)) - } - // In this case, the manager and peer probably both disconnected at the same time? Treat as a manager disconnect. - Ok((None, _)) => Err(MergedError::Peer(PeerError::ManagerDisconnect)), - Err((PeerError::ManagerDisconnect, _)) => { - Err(MergedError::Peer(PeerError::ManagerDisconnect)) - } - Err((PeerError::PeerDisconnect | PeerError::PeerNoHeartbeat, merged_stream)) => Ok(( - merged_stream, - None, - None, - Some(OPeerManagerMessage::PeerDisconnect(info)), - false, - )), - Err((PeerError::Peer(err), merged_stream)) => Ok(( - merged_stream, - None, - None, - Some(OPeerManagerMessage::PeerError(info, err)), - false, - )), - }; +async fn handle_stream_result( + result: Result, MergedError>, + peer_send: &mut SplitSink>, + manager_send: &mut mpsc::Sender, PeerManagerOutputError>>, + info: &PeerInfo, +) -> Result<(), PeerError<>>::Error, SendError>> +where + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, +{ + match result { + Ok(UnifiedItem::Peer(message)) => { + // Handle peer message + manager_send + .send(Ok(PeerManagerOutputMessage::ReceivedMessage(*info, message))) + .await + .map_err(PeerError::ManagerDisconnect)?; - match result { - Ok((merged_stream, opt_send, opt_recv, opt_ack, is_good)) => { - if let Some(send) = opt_send { - Ok(p_send - .send(send) - .map_err(|_| MergedError::Peer(PeerError::PeerDisconnect)) - .and_then(move |p_send| { - Err(MergedError::StageOne(( - merged_stream, - o_send, - p_send, - info, - opt_recv, - opt_ack, - is_good, - ))) - })) - } else { - Err(MergedError::StageOne(( - merged_stream, - o_send, - p_send, - info, - opt_recv, - opt_ack, - is_good, - ))) - } - } - Err(err) => Err(err), - } - }) - .flatten() - .or_else(|error| { - match error { - MergedError::StageOne((merged_stream, o_send, p_send, info, opt_recv, opt_ack, is_good)) => { - if let Some(recv) = opt_recv { - if !recv.is_keep_alive() { - return Ok(o_send - .send(OPeerManagerMessage::ReceivedMessage(info, recv)) - .map_err(|_| MergedError::Peer(PeerError::ManagerDisconnect)) - .and_then(move |o_send| { - Err(MergedError::StageTwo(( - merged_stream, - o_send, - p_send, - info, - opt_ack, - is_good, - ))) - })); - } - } + Ok(()) + } + Ok(UnifiedItem::Manager(PeerManagerInputMessage::AddPeer(_, _))) => panic!("invalid message"), + Ok(UnifiedItem::Manager(PeerManagerInputMessage::RemovePeer(info))) => { + manager_send + 
.send(Ok(PeerManagerOutputMessage::PeerRemoved(info))) + .await + .map_err(PeerError::ManagerDisconnect)?; - // Either we had no recv message (from remote), or it was a keep alive message, which we don't propagate - Err(MergedError::StageTwo((merged_stream, o_send, p_send, info, opt_ack, is_good))) - } - err => Err(err), - } - }) - .flatten() - .or_else(|error| match error { - MergedError::StageTwo((merged_stream, o_send, p_send, info, opt_ack, is_good)) => { - if let Some(ack) = opt_ack { - Ok(o_send - .send(ack) - .map_err(|_| MergedError::Peer(PeerError::ManagerDisconnect)) - .and_then(move |o_send| { - Err(MergedError::StageThree((merged_stream, o_send, p_send, info, is_good))) - })) - } else { - Err(MergedError::StageThree((merged_stream, o_send, p_send, info, is_good))) - } - } - err => Err(err), - }) - .flatten() - .or_else(|error| { - match error { - MergedError::StageThree((merged_stream, o_send, p_send, info, is_good)) => { - // Connection is good if no errors occurred (we do this so we can use the same plumbing) - // for sending "acks" back to our manager when an error occurs, we just have None, None, - // Some, false when we want to send an error message to the manager, but terminate the connection. - if is_good { - Ok(Loop::Continue((merged_stream, o_send, p_send, info))) - } else { - Ok(Loop::Break(())) - } - } - _ => Ok(Loop::Break(())), - } - }) - }, - ) - }), - ); + Err(PeerError::PeerRemoved(info)) + } + Ok(UnifiedItem::Manager(PeerManagerInputMessage::SendMessage(info, id, message))) => { + peer_send.send(Ok(message)).await.map_err(PeerError::PeerDisconnect)?; + manager_send + .send(Ok(PeerManagerOutputMessage::SentMessage(info, id))) + .await + .map_err(PeerError::ManagerDisconnect)?; + + Ok(()) + } + Err(MergedError::Disconnect) => Err(PeerError::Disconnected), + Err(MergedError::StreamError(e)) => { + // Handle stream error + manager_send + .send(Err(PeerManagerOutputError::PeerError(*info, e))) + .await + .map_err(PeerError::ManagerDisconnect)?; + + Err(PeerError::PeerRemoved(*info)) + } + Err(MergedError::Timeout) => { + peer_send + .send(Ok(Message::keep_alive())) + .await + .map_err(PeerError::PeerDisconnect)?; - m_send + Ok(()) + } + } } diff --git a/packages/peer/src/message/bencode_util.rs b/packages/peer/src/message/bencode_util.rs index fe9c3885a..5cff6320c 100644 --- a/packages/peer/src/message/bencode_util.rs +++ b/packages/peer/src/message/bencode_util.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::{io, str}; +use std::str; use bencode::{BConvert, BDictAccess, BRefAccess, BencodeConvertError}; use util::convert; @@ -12,10 +12,10 @@ pub const CONVERT: IoErrorBencodeConvert = IoErrorBencodeConvert; pub struct IoErrorBencodeConvert; impl BConvert for IoErrorBencodeConvert { - type Error = io::Error; + type Error = std::io::Error; fn handle_error(&self, error: BencodeConvertError) -> Self::Error { - io::Error::new(io::ErrorKind::Other, error.to_string()) + std::io::Error::new(std::io::ErrorKind::Other, error.to_string()) } } @@ -164,7 +164,7 @@ pub const MESSAGE_TYPE_KEY: &[u8] = b"msg_type"; pub const PIECE_INDEX_KEY: &[u8] = b"piece"; pub const TOTAL_SIZE_KEY: &[u8] = b"total_size"; -pub fn parse_message_type(root: &dyn BDictAccess) -> io::Result +pub fn parse_message_type(root: &dyn BDictAccess) -> std::io::Result where V: BRefAccess, { @@ -173,14 +173,14 @@ where .map(|msg_type| msg_type.try_into().unwrap()) } -pub fn parse_piece_index(root: &dyn BDictAccess) -> io::Result +pub fn 
parse_piece_index(root: &dyn BDictAccess) -> std::io::Result where V: BRefAccess, { CONVERT.lookup_and_convert_int(root, PIECE_INDEX_KEY) } -pub fn parse_total_size(root: &dyn BDictAccess) -> io::Result +pub fn parse_total_size(root: &dyn BDictAccess) -> std::io::Result where V: BRefAccess, { diff --git a/packages/peer/src/message/bits_ext/handshake.rs b/packages/peer/src/message/bits_ext/handshake.rs index 7301f0fcb..d8e606006 100644 --- a/packages/peer/src/message/bits_ext/handshake.rs +++ b/packages/peer/src/message/bits_ext/handshake.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use std::io::{self, Write}; -use std::mem; +use std::io::Write as _; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use bencode::{ben_bytes, ben_int, BConvert, BDecodeOpt, BMutAccess, BencodeMut, BencodeRef}; @@ -26,7 +25,11 @@ pub struct ExtendedMessageBuilder { } impl ExtendedMessageBuilder { - /// Create a new `ExtendedMessageBuilder`. + /// Creates a new `ExtendedMessageBuilder`. + /// + /// # Returns + /// + /// A new `ExtendedMessageBuilder` instance. #[must_use] pub fn new() -> ExtendedMessageBuilder { ExtendedMessageBuilder { @@ -42,14 +45,31 @@ impl ExtendedMessageBuilder { } } - /// Set our client identification in the message. + /// Sets our client identification in the message. + /// + /// # Parameters + /// + /// - `id`: The client identification string. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_our_id(mut self, id: Option) -> ExtendedMessageBuilder { self.our_id = id; self } - /// Set the given `ExtendedType` to map to the given value. + /// Sets the given `ExtendedType` to map to the given value. + /// + /// # Parameters + /// + /// - `ext_type`: The extended type. + /// - `opt_value`: The optional value to map to the extended type. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_extended_type(mut self, ext_type: ExtendedType, opt_value: Option) -> ExtendedMessageBuilder { if let Some(value) = opt_value { @@ -60,49 +80,106 @@ impl ExtendedMessageBuilder { self } - /// Set our tcp port. + /// Sets our TCP port. + /// + /// # Parameters + /// + /// - `tcp`: The TCP port. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_our_tcp_port(mut self, tcp: Option) -> ExtendedMessageBuilder { self.our_tcp_port = tcp; self } - /// Set the ip address that we see them as. + /// Sets the IP address that we see them as. + /// + /// # Parameters + /// + /// - `ip`: The IP address. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_their_ip(mut self, ip: Option) -> ExtendedMessageBuilder { self.their_ip = ip; self } - /// Set our ipv6 address. + /// Sets our IPv6 address. + /// + /// # Parameters + /// + /// - `ipv6`: The IPv6 address. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_our_ipv6_addr(mut self, ipv6: Option) -> ExtendedMessageBuilder { self.our_ipv6_addr = ipv6; self } - /// Set our ipv4 address. + /// Sets our IPv4 address. + /// + /// # Parameters + /// + /// - `ipv4`: The IPv4 address. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_our_ipv4_addr(mut self, ipv4: Option) -> ExtendedMessageBuilder { self.our_ipv4_addr = ipv4; self } - /// Set the maximum number of queued requests we support. + /// Sets the maximum number of queued requests we support. 
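// Illustration only, not part of the patch: a minimal sketch of chaining the builder
// methods documented in this hunk. Method names come from this file; argument types
// and concrete values are assumed for the example.
let extended = ExtendedMessageBuilder::new()
    .with_our_id(Some("my-client/1.0".to_string()))
    .with_our_tcp_port(Some(6881))
    .with_max_requests(Some(250))
    .with_metadata_size(Some(31_235))
    .build();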
+ /// + /// # Parameters + /// + /// - `max_requests`: The maximum number of queued requests. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_max_requests(mut self, max_requests: Option) -> ExtendedMessageBuilder { self.our_max_requests = max_requests; self } - /// Set the info dictionary metadata size. + /// Sets the info dictionary metadata size. + /// + /// # Parameters + /// + /// - `metadata_size`: The metadata size. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_metadata_size(mut self, metadata_size: Option) -> ExtendedMessageBuilder { self.metadata_size = metadata_size; self } - /// Set a custom entry in the message with the given dictionary key. + /// Sets a custom entry in the message with the given dictionary key. + /// + /// # Parameters + /// + /// - `key`: The dictionary key. + /// - `opt_value`: The optional value to set for the key. + /// + /// # Returns + /// + /// The updated `ExtendedMessageBuilder`. #[must_use] pub fn with_custom_entry(mut self, key: String, opt_value: Option>) -> ExtendedMessageBuilder { if let Some(value) = opt_value { @@ -113,13 +190,27 @@ impl ExtendedMessageBuilder { self } - /// Build an `ExtendedMessage` with the current options. + /// Builds an `ExtendedMessage` with the current options. + /// + /// # Returns + /// + /// The built `ExtendedMessage`. #[must_use] pub fn build(self) -> ExtendedMessage { ExtendedMessage::from_builder(self) } } +/// Encodes the builder's options into a bencode format. +/// +/// # Parameters +/// +/// - `builder`: The `ExtendedMessageBuilder` instance. +/// - `custom_entries`: A map of custom entries. +/// +/// # Returns +/// +/// A vector of bytes representing the bencoded data. fn bencode_from_builder(builder: &ExtendedMessageBuilder, mut custom_entries: HashMap>) -> Vec { let opt_our_ip = builder.their_ip.map(|their_ip| match their_ip { IpAddr::V4(ipv4_addr) => convert::ipv4_to_bytes_be(ipv4_addr).to_vec(), @@ -174,9 +265,9 @@ fn bencode_from_builder(builder: &ExtendedMessageBuilder, mut custom_entries: Ha // ----------------------------------------------------------------------------// -// Terminology is written as if we were receiving the message. Example: Our ip is -// the ip that the sender sees us as. So if were sending this message, it would be -// the ip we see the client as. +// Terminology is written as if we were receiving the message. Example: Our IP is +// the IP that the sender sees us as. So if we're sending this message, it would be +// the IP we see the client as. const ROOT_ERROR_KEY: &str = "ExtendedMessage"; @@ -192,7 +283,15 @@ pub enum ExtendedType { } impl ExtendedType { - /// Create an `ExtendedType` from the given identifier. + /// Creates an `ExtendedType` from the given identifier. + /// + /// # Parameters + /// + /// - `id`: The identifier string. + /// + /// # Returns + /// + /// An `ExtendedType` instance corresponding to the identifier. #[must_use] pub fn from_id(id: &str) -> ExtendedType { match id { @@ -202,12 +301,16 @@ impl ExtendedType { } } - /// Retrieve the message id corresponding to the given `ExtendedType`. + /// Retrieves the message id corresponding to the given `ExtendedType`. + /// + /// # Returns + /// + /// The message id as a string slice. 
#[must_use] pub fn id(&self) -> &str { match self { - &ExtendedType::UtMetadata => UT_METADATA_ID, - &ExtendedType::UtPex => UT_PEX_ID, + ExtendedType::UtMetadata => UT_METADATA_ID, + ExtendedType::UtPex => UT_PEX_ID, ExtendedType::Custom(id) => id, } } @@ -230,11 +333,19 @@ pub struct ExtendedMessage { } impl ExtendedMessage { - /// Create an `ExtendedMessage` from an `ExtendedMessageBuilder`. + /// Creates an `ExtendedMessage` from an `ExtendedMessageBuilder`. + /// + /// # Parameters + /// + /// - `builder`: The `ExtendedMessageBuilder` instance. + /// + /// # Returns + /// + /// An `ExtendedMessage` instance. #[must_use] pub fn from_builder(mut builder: ExtendedMessageBuilder) -> ExtendedMessage { let mut custom_entries = HashMap::new(); - mem::swap(&mut custom_entries, &mut builder.custom_entries); + std::mem::swap(&mut custom_entries, &mut builder.custom_entries); let encoded_bytes = bencode_from_builder(&builder, custom_entries); let mut raw_bencode = BytesMut::with_capacity(encoded_bytes.len()); @@ -253,16 +364,28 @@ impl ExtendedMessage { } } - /// Parse an `ExtendedMessage` from some raw bencode of the given length. - pub fn parse_bytes(_input: (), mut bytes: Bytes, len: u32) -> IResult<(), io::Result> { + /// Parses an `ExtendedMessage` from some raw bencode of the given length. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// - `len`: The length of the bencode data. + /// + /// # Returns + /// + /// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `ExtendedMessage`. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into an `ExtendedMessage`. + pub fn parse_bytes(bytes: &[u8], len: u32) -> IResult<&[u8], std::io::Result> { let cast_len = message::u32_to_usize(len); if bytes.len() >= cast_len { - let raw_bencode = bytes.split_to(cast_len); - let clone_raw_bencode = raw_bencode.clone(); + let (raw_bencode, _) = bytes.split_at(cast_len); - let res_extended_message = BencodeRef::decode(&raw_bencode, BDecodeOpt::default()) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string())) + let res_extended_message = BencodeRef::decode(raw_bencode, BDecodeOpt::default()) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string())) .and_then(|bencode| { let ben_dict = bencode_util::CONVERT.convert_dict(&bencode, ROOT_ERROR_KEY)?; @@ -284,91 +407,139 @@ impl ExtendedMessage { our_ipv4_addr, our_max_requests, metadata_size, - raw_bencode: clone_raw_bencode, + raw_bencode: Bytes::copy_from_slice(raw_bencode), }) }); - IResult::Done((), res_extended_message) + // Clone the remaining bytes to avoid returning a reference to the parameter + IResult::Ok((bytes, res_extended_message)) } else { - IResult::Incomplete(Needed::Size(cast_len - bytes.len())) + Err(nom::Err::Incomplete(Needed::new(cast_len - bytes.len()))) } } - /// Write the `ExtendedMessage` out to the given writer. + /// Writes the `ExtendedMessage` out to the given writer. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// - /// It will return an IP error if unable to write the bytes. + /// This function will return an error if unable to write the bytes. /// /// # Panics /// - /// It would panic if the bencode size it too large. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + /// This function will panic if the bencode size is too large. 
+ pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - let real_length = 2 + self.bencode_size(); - message::write_length_id_pair( - &mut writer, - real_length.try_into().unwrap(), - Some(bits_ext::EXTENDED_MESSAGE_ID), - )?; + let real_length: u32 = (self.bencode_size() + 2) + .try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Unsupported, e))?; - writer.write_all(&[bits_ext::EXTENDED_MESSAGE_HANDSHAKE_ID]); + let len = message::write_length_id_pair(&mut writer, real_length, Some(bits_ext::EXTENDED_MESSAGE_ID))?; - writer.write_all(self.raw_bencode.as_ref()) + let () = writer.write_all(&[bits_ext::EXTENDED_MESSAGE_HANDSHAKE_ID])?; + let () = writer.write_all(self.raw_bencode.as_ref())?; + Ok(real_length as usize + len) } - /// Get the size of the bencode portion of this message. + /// Gets the size of the bencode portion of this message. + /// + /// # Returns + /// + /// The size of the bencode portion in bytes. pub fn bencode_size(&self) -> usize { self.raw_bencode.len() } - /// Query for the id corresponding to the given `ExtendedType`. + /// Queries for the id corresponding to the given `ExtendedType`. + /// + /// # Parameters + /// + /// - `ext_type`: The extended type. + /// + /// # Returns + /// + /// An optional u8 representing the id. pub fn query_id(&self, ext_type: &ExtendedType) -> Option { self.id_map.get(ext_type).copied() } - /// Retrieve our id from the message. + /// Retrieves our id from the message. + /// + /// # Returns + /// + /// An optional string slice representing our id. pub fn our_id(&self) -> Option<&str> { self.our_id.as_deref() } - /// Retrieve our tcp port from the message. + /// Retrieves our TCP port from the message. + /// + /// # Returns + /// + /// An optional u16 representing our TCP port. pub fn our_tcp_port(&self) -> Option { self.our_tcp_port } - /// Retrieve their ip address from the message. + /// Retrieves their IP address from the message. + /// + /// # Returns + /// + /// An optional `IpAddr` representing their IP address. pub fn their_ip(&self) -> Option { self.their_ip } - /// Retrieve our ipv6 address from the message. + /// Retrieves our IPv6 address from the message. + /// + /// # Returns + /// + /// An optional `Ipv6Addr` representing our IPv6 address. pub fn our_ipv6_addr(&self) -> Option { self.our_ipv6_addr } - /// Retrieve our ipv4 address from the message. + /// Retrieves our IPv4 address from the message. + /// + /// # Returns + /// + /// An optional `Ipv4Addr` representing our IPv4 address. pub fn our_ipv4_addr(&self) -> Option { self.our_ipv4_addr } - /// Retrieve our max queued requests from the message. + /// Retrieves our max queued requests from the message. + /// + /// # Returns + /// + /// An optional i64 representing our max queued requests. pub fn our_max_requests(&self) -> Option { self.our_max_requests } - /// Retrieve the info dictionary metadata size from the message. + /// Retrieves the info dictionary metadata size from the message. + /// + /// # Returns + /// + /// An optional i64 representing the metadata size. pub fn metadata_size(&self) -> Option { self.metadata_size } - /// Retrieve a raw `BencodeRef` representing the current message. + /// Retrieves a raw `BencodeRef` representing the current message. /// /// # Panics /// - /// It would panic if unable to decode the bencode. + /// This function will panic if unable to decode the bencode. + /// + /// # Returns + /// + /// A `BencodeRef` representing the current message. 
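// Illustration only, not part of the patch: a minimal sketch of the accessors documented
// above, assuming `extended` is an `ExtendedMessage` received during the extension handshake.
if let Some(ut_metadata_id) = extended.query_id(&ExtendedType::UtMetadata) {
    // The peer supports metadata exchange: `ut_metadata_id` is the message id to address it
    // with, and `metadata_size()` is the advertised size of the info dictionary.
    let _ = (ut_metadata_id, extended.metadata_size());
}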
pub fn bencode_ref(&self) -> BencodeRef<'_> { // We already verified that this is valid bencode BencodeRef::decode(&self.raw_bencode, BDecodeOpt::default()).unwrap() diff --git a/packages/peer/src/message/bits_ext/mod.rs b/packages/peer/src/message/bits_ext/mod.rs index b350717ba..84a1dd676 100644 --- a/packages/peer/src/message/bits_ext/mod.rs +++ b/packages/peer/src/message/bits_ext/mod.rs @@ -1,13 +1,17 @@ use std::collections::HashMap; -use std::io::{self, Write}; +use std::io::Write as _; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use bencode::{BConvert, BDecodeOpt, BMutAccess, BencodeMut, BencodeRef}; use byteorder::{BigEndian, WriteBytesExt}; use bytes::Bytes; -use nom::{ - alt, be_u16, be_u32, be_u8, call, error_node_position, error_position, map, switch, tuple, tuple_parser, IResult, Needed, -}; +use nom::branch::alt; +use nom::bytes::complete::{take, take_while}; +use nom::combinator::map; +use nom::error::Error; +use nom::number::complete::{be_u16, be_u32, be_u8}; +use nom::sequence::tuple; +use nom::IResult; use util::convert; use crate::message; @@ -32,55 +36,88 @@ pub use self::port::PortMessage; /// Sent after the handshake if the corresponding extension bit is set. #[derive(Clone, Debug, PartialEq, Eq)] pub enum BitsExtensionMessage { - /// message for determining the port a peer's DHT is listening on. + /// Message for determining the port a peer's DHT is listening on. Port(PortMessage), /// Message for sending a peer the map of extensions we support. Extended(ExtendedMessage), } impl BitsExtensionMessage { - pub fn parse_bytes(_input: (), bytes: Bytes) -> IResult<(), io::Result> { - parse_extension(bytes) + /// Parses a byte slice into a `BitsExtensionMessage`. + /// + /// # Parameters + /// + /// - `input`: The byte slice to parse. + /// + /// # Returns + /// + /// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `BitsExtensionMessage`. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `BitsExtensionMessage`. + pub fn parse_bytes<'a>(input: &'a [u8]) -> IResult<&'a [u8], std::io::Result> { + let port_fn = |input: &'a [u8]| -> IResult<&'a [u8], std::io::Result> { + let (_, (message_len, message_id)) = tuple((be_u32, be_u8))(input)?; + + if (message_len, message_id) == (PORT_MESSAGE_LEN, PORT_MESSAGE_ID) { + let (_, res_port) = PortMessage::parse_bytes(&input[message::HEADER_LEN..])?; + Ok((input, Ok(BitsExtensionMessage::Port(res_port)))) + } else { + Err(nom::Err::Error(nom::error::Error { + input, + code: nom::error::ErrorKind::Switch, + })) + } + }; + + let ext_fn = |input: &'a [u8]| -> IResult<&'a [u8], std::io::Result> { + let (_, (message_len, extended_message_id, extended_message_handshake_id)) = tuple((be_u32, be_u8, be_u8))(input)?; + + if (message_len, extended_message_id, extended_message_handshake_id) + == (message_len, EXTENDED_MESSAGE_ID, EXTENDED_MESSAGE_HANDSHAKE_ID) + { + let (_, res_extended) = ExtendedMessage::parse_bytes(&input[message::HEADER_LEN + 1..], message_len - 2)?; + Ok((input, res_extended.map(BitsExtensionMessage::Extended))) + } else { + Err(nom::Err::Error(nom::error::Error { + input, + code: nom::error::ErrorKind::Switch, + })) + } + }; + + alt((port_fn, ext_fn))(input) } - /// Writes bytes into the current [`BitsExtensionMessage`]. + /// Writes the current state of the `BitsExtensionMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. 
/// /// # Errors /// /// This function will return an error if unable to write the bytes. - pub fn write_bytes(&self, writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { match self { - &BitsExtensionMessage::Port(msg) => msg.write_bytes(writer), + BitsExtensionMessage::Port(msg) => msg.write_bytes(writer), BitsExtensionMessage::Extended(msg) => msg.write_bytes(writer), } } + /// Returns the size of the `BitsExtensionMessage`. + /// + /// # Returns + /// + /// The size of the message in bytes. pub fn message_size(&self) -> usize { match self { - &BitsExtensionMessage::Port(_) => PORT_MESSAGE_LEN as usize, + BitsExtensionMessage::Port(_) => PORT_MESSAGE_LEN as usize, BitsExtensionMessage::Extended(msg) => BASE_EXTENDED_MESSAGE_LEN as usize + msg.bencode_size(), } } } - -fn parse_extension(mut bytes: Bytes) -> IResult<(), io::Result> { - let header_bytes = bytes.clone(); - - alt!( - (), - ignore_input!(switch!(header_bytes.as_ref(), throwaway_input!(tuple!(be_u32, be_u8)), - (PORT_MESSAGE_LEN, PORT_MESSAGE_ID) => map!( - call!(PortMessage::parse_bytes, &bytes.split_off(message::HEADER_LEN)), - |res_port| res_port.map(BitsExtensionMessage::Port) - ) - )) | ignore_input!(switch!(header_bytes.as_ref(), throwaway_input!(tuple!(be_u32, be_u8, be_u8)), - (message_len, EXTENDED_MESSAGE_ID, EXTENDED_MESSAGE_HANDSHAKE_ID) => map!( - call!(ExtendedMessage::parse_bytes, bytes.split_off(message::HEADER_LEN + 1), message_len - 2), - |res_extended| res_extended.map(BitsExtensionMessage::Extended) - ) - )) - ) -} diff --git a/packages/peer/src/message/bits_ext/port.rs b/packages/peer/src/message/bits_ext/port.rs index 3bc45eca4..ea7656bbe 100644 --- a/packages/peer/src/message/bits_ext/port.rs +++ b/packages/peer/src/message/bits_ext/port.rs @@ -1,9 +1,9 @@ -use std::io; -use std::io::Write; +use std::io::Write as _; -use byteorder::WriteBytesExt; -use bytes::{BigEndian, Bytes}; -use nom::{be_u16, call, map, IResult}; +use bytes::{BufMut, Bytes, BytesMut}; +use nom::combinator::map; +use nom::number::complete::be_u16; +use nom::IResult; use crate::message; use crate::message::bits_ext; @@ -16,34 +16,56 @@ pub struct PortMessage { } impl PortMessage { + /// Creates a new `PortMessage`. + /// + /// # Parameters + /// + /// - `port`: The DHT port number. + /// + /// # Returns + /// + /// A new `PortMessage` instance. #[must_use] pub fn new(port: u16) -> PortMessage { PortMessage { port } } - pub fn parse_bytes(_input: (), bytes: &Bytes) -> IResult<(), io::Result> { - match parse_port(bytes.as_ref()) { - IResult::Done(_, result) => IResult::Done((), Ok(result)), - IResult::Error(err) => IResult::Error(err), - IResult::Incomplete(need) => IResult::Incomplete(need), - } + /// Parses a byte slice into a `PortMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// + /// # Returns + /// + /// An `IResult` containing the remaining byte slice and the parsed `PortMessage`. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `PortMessage`. + pub fn parse_bytes(bytes: &[u8]) -> IResult<&[u8], PortMessage> { + map(be_u16, PortMessage::new)(bytes) } - /// Writes bytes into the current [`PortMessage`]. + /// Writes the current state of the `PortMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// /// This function will return an error if unable to write the bytes. 
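// Illustration only, not part of the patch: a round-trip sketch for the reworked
// byte-level API, assuming the `PortMessage` methods shown in this hunk.
let message = PortMessage::new(6881);
let mut buffer: Vec<u8> = Vec::new();
let written = message.write_bytes(&mut buffer).unwrap(); // 4-byte length + 1-byte id + 2-byte port
// `parse_bytes` takes only the 2-byte payload that follows the header:
let (_rest, _parsed) = PortMessage::parse_bytes(&buffer[written - 2..]).unwrap();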
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - message::write_length_id_pair(&mut writer, bits_ext::PORT_MESSAGE_LEN, Some(bits_ext::PORT_MESSAGE_ID))?; + let length_len = message::write_length_id_pair(&mut writer, bits_ext::PORT_MESSAGE_LEN, Some(bits_ext::PORT_MESSAGE_ID))?; - writer.write_u16::(self.port) - } -} + let mut buf = BytesMut::with_capacity(2); + let () = buf.put_u16(self.port); + let () = writer.write_all(&buf)?; -fn parse_port(bytes: &[u8]) -> IResult<&[u8], PortMessage> { - map!(bytes, be_u16, PortMessage::new) + Ok(length_len + 2) + } } diff --git a/packages/peer/src/message/mod.rs b/packages/peer/src/message/mod.rs index 49b1d8fdd..080865789 100644 --- a/packages/peer/src/message/mod.rs +++ b/packages/peer/src/message/mod.rs @@ -4,13 +4,18 @@ // Nom has lots of unused warnings atm, keep this here for now. -use std::io::{self, Write}; +use std::io::Write as _; use byteorder::{BigEndian, WriteBytesExt}; use bytes::Bytes; -use nom::{alt, be_u32, be_u8, call, error_node_position, error_position, map, opt, switch, tuple, tuple_parser, value, IResult}; +use nom::branch::alt; +use nom::bytes::complete::take; +use nom::combinator::{all_consuming, map, map_res, opt, value}; +use nom::number::complete::{be_u32, be_u8}; +use nom::sequence::{preceded, tuple}; +use nom::IResult; +use thiserror::Error; -use crate::manager::messages::ManagedMessage; use crate::protocol::PeerProtocol; // TODO: Propagate failures to cast values to/from usize @@ -52,16 +57,30 @@ pub use crate::message::bits_ext::{BitsExtensionMessage, ExtendedMessage, Extend pub use crate::message::null::NullProtocolMessage; #[allow(clippy::module_name_repetitions)] pub use crate::message::prot_ext::{ - PeerExtensionProtocolMessage, UtMetadataDataMessage, UtMetadataMessage, UtMetadataRejectMessage, UtMetadataRequestMessage, + PeerExtensionProtocolMessage, PeerExtensionProtocolMessageError, UtMetadataDataMessage, UtMetadataMessage, + UtMetadataRejectMessage, UtMetadataRequestMessage, }; #[allow(clippy::module_name_repetitions)] pub use crate::message::standard::{BitFieldIter, BitFieldMessage, CancelMessage, HaveMessage, PieceMessage, RequestMessage}; +use crate::ManagedMessage; + +#[derive(Error, Debug, Clone)] +pub enum PeerWireProtocolMessageError {} + +impl From for std::io::Error { + fn from(err: PeerWireProtocolMessageError) -> Self { + std::io::Error::new(std::io::ErrorKind::Other, err) + } +} /// Enumeration of messages for `PeerWireProtocol`. #[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone)] pub enum PeerWireProtocolMessage
<P>
where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { /// Message to keep the connection alive. KeepAlive, @@ -93,12 +112,14 @@ where /// /// In reality, this can be any type that implements `ProtocolMessage` if, for example, /// you are running a private swarm where you know all nodes support a given message(s). - ProtExtension(P::ProtocolMessage), + ProtExtension(Result), } impl
<P> ManagedMessage for PeerWireProtocolMessage<P>
where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { fn keep_alive() -> PeerWireProtocolMessage
<P>
{ PeerWireProtocolMessage::KeepAlive @@ -111,17 +132,18 @@ where impl
<P> PeerWireProtocolMessage<P>
where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { /// Bytes Needed to encode Byte Slice /// /// # Errors /// /// This function will not return an error. - pub fn bytes_needed(bytes: &[u8]) -> io::Result> { - match be_u32(bytes) { - // We need 4 bytes for the length, plus whatever the length is... - IResult::Done(_, length) => Ok(Some(MESSAGE_LENGTH_LEN_BYTES + u32_to_usize(length))), + pub fn bytes_needed(bytes: &[u8]) -> std::io::Result> { + match be_u32::<_, nom::error::Error<&[u8]>>(bytes) { + Ok((_, length)) => Ok(Some(MESSAGE_LENGTH_LEN_BYTES + u32_to_usize(length))), _ => Ok(None), } } @@ -131,11 +153,11 @@ where /// # Errors /// /// This function will return an error if unable to parse bytes for supplied protocol. - pub fn parse_bytes(bytes: Bytes, ext_protocol: &mut P) -> io::Result> { + pub fn parse_bytes(bytes: &[u8], ext_protocol: &mut P) -> std::io::Result> { match parse_message(bytes, ext_protocol) { - IResult::Done((), result) => result, - _ => Err(io::Error::new( - io::ErrorKind::Other, + Ok((_, result)) => result, + _ => Err(std::io::Error::new( + std::io::ErrorKind::Other, "Failed To Parse PeerWireProtocolMessage", )), } @@ -146,9 +168,9 @@ where /// # Errors /// /// This function will return an error if unable to write bytes. - pub fn write_bytes(&self, writer: W, ext_protocol: &mut P) -> io::Result<()> + pub fn write_bytes(&self, writer: W, ext_protocol: &mut P) -> std::io::Result where - W: Write, + W: std::io::Write, { match self { &PeerWireProtocolMessage::KeepAlive => write_length_id_pair(writer, KEEP_ALIVE_MESSAGE_LEN, None), @@ -170,7 +192,12 @@ where } } - pub fn message_size(&self, ext_protocol: &mut P) -> usize { + /// Retrieve how many bytes the message will occupy on the wire. + /// + /// # Errors + /// + /// This function will return an error if unable to calculate the message length. + pub fn message_size(&self, ext_protocol: &mut P) -> std::io::Result { let message_specific_len = match self { &PeerWireProtocolMessage::KeepAlive => KEEP_ALIVE_MESSAGE_LEN as usize, &PeerWireProtocolMessage::Choke => CHOKE_MESSAGE_LEN as usize, @@ -183,24 +210,25 @@ where PeerWireProtocolMessage::Piece(msg) => BASE_PIECE_MESSAGE_LEN as usize + msg.block().len(), &PeerWireProtocolMessage::Cancel(_) => CANCEL_MESSAGE_LEN as usize, PeerWireProtocolMessage::BitsExtension(ext) => ext.message_size(), - PeerWireProtocolMessage::ProtExtension(ext) => ext_protocol.message_size(ext), + PeerWireProtocolMessage::ProtExtension(ext) => ext_protocol.message_size(ext)?, }; - MESSAGE_LENGTH_LEN_BYTES + message_specific_len + Ok(MESSAGE_LENGTH_LEN_BYTES + message_specific_len) } } /// Write a length and optional id out to the given writer. -fn write_length_id_pair(mut writer: W, length: u32, opt_id: Option) -> io::Result<()> +fn write_length_id_pair(mut writer: W, length: u32, opt_id: Option) -> std::io::Result where - W: Write, + W: std::io::Write, { writer.write_u32::(length)?; if let Some(id) = opt_id { - writer.write_u8(id) + let () = writer.write_u8(id)?; + Ok(5) } else { - Ok(()) + Ok(4) } } @@ -208,7 +236,7 @@ where /// /// Panics if parsing failed for any reason. fn parse_message_length(bytes: &[u8]) -> usize { - if let IResult::Done(_, len) = be_u32(bytes) { + if let Ok((_, len)) = be_u32::<_, nom::error::Error<&[u8]>>(bytes) { u32_to_usize(len) } else { panic!("bip_peer: Message Length Was Less Than 4 Bytes") @@ -224,54 +252,226 @@ fn u32_to_usize(value: u32) -> usize { // the number of bytes needed will be returned. However, that number of bytes is on a per parser // basis. 
If possible, we should return the number of bytes needed for the rest of the WHOLE message. // This allows clients to only re-invoke the parser when it knows it has enough of the data. -fn parse_message
<P>
(mut bytes: Bytes, ext_protocol: &mut P) -> IResult<(), io::Result>> + +fn parse_keep_alive
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + tuple(( + be_u32::<_, nom::error::Error<&[u8]>>, + opt(be_u8::<_, nom::error::Error<&[u8]>>), + )), + |_| Ok(PeerWireProtocolMessage::KeepAlive), + )(input) +} + +fn parse_choke
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + tuple(( + value(CHOKE_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(CHOKE_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + |_| Ok(PeerWireProtocolMessage::Choke), + )(input) +} + +fn parse_unchoke
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + tuple(( + value(UNCHOKE_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(UNCHOKE_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + |_| Ok(PeerWireProtocolMessage::UnChoke), + )(input) +} + +fn parse_interested
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + tuple(( + value(INTERESTED_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(INTERESTED_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + |_| Ok(PeerWireProtocolMessage::Interested), + )(input) +} + +fn parse_uninterested
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + tuple(( + value(UNINTERESTED_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(UNINTERESTED_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + |_| Ok(PeerWireProtocolMessage::UnInterested), + )(input) +} + +fn parse_have
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + preceded( + tuple(( + value(HAVE_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(HAVE_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + take(4_usize), + ), + |have| HaveMessage::parse_bytes(have).map(PeerWireProtocolMessage::Have), + )(input) +} + +fn parse_bitfield
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + preceded( + tuple(( + value(BASE_BITFIELD_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(BITFIELD_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + take(4_usize), + ), + |bitfield| BitFieldMessage::parse_bytes(bitfield).map(PeerWireProtocolMessage::BitField), + )(input) +} + +fn parse_request
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + preceded( + tuple(( + value(REQUEST_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(REQUEST_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + take(4_usize), + ), + |request| RequestMessage::parse_bytes(request).map(PeerWireProtocolMessage::Request), + )(input) +} + +fn parse_piece
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + preceded( + tuple(( + value(BASE_PIECE_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(PIECE_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + take(4_usize), + ), + |piece| { + let len = parse_message_length(piece); + PieceMessage::parse_bytes(piece, len).map(PeerWireProtocolMessage::Piece) + }, + )(input) +} + +fn parse_cancel
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + preceded( + tuple(( + value(CANCEL_MESSAGE_LEN, be_u32::<_, nom::error::Error<&[u8]>>), + value(Some(CANCEL_MESSAGE_ID), be_u8::<_, nom::error::Error<&[u8]>>), + )), + take(4_usize), + ), + |cancel| CancelMessage::parse_bytes(cancel).map(PeerWireProtocolMessage::Cancel), + )(input) +} + +fn parse_bits_extension
<P>
(input: &[u8]) -> IResult<&[u8], std::io::Result<PeerWireProtocolMessage<P>>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + |input| BitsExtensionMessage::parse_bytes(input), + |res_bits_ext| res_bits_ext.map(|bits_ext| PeerWireProtocolMessage::BitsExtension(bits_ext)), + )(input) +} + +fn parse_prot_extension<'a, P>( + input: &'a [u8], + ext_protocol: &mut P, +) -> IResult<&'a [u8], std::io::Result>> +where + P: PeerProtocol + Clone + std::fmt::Debug, +
<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, +{ + map( + |input| match ext_protocol.parse_bytes(input) { + Ok(msg) => Ok((input, Ok(PeerWireProtocolMessage::ProtExtension(msg)))), + Err(_) => Err(nom::Err::Error(nom::error::Error { + input, + code: nom::error::ErrorKind::Fail, + })), + }, + |result| result, + )(input) +} + +fn parse_message<'a, P>(bytes: &'a [u8], ext_protocol: &mut P) -> IResult<&'a [u8], std::io::Result>> where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +
<P as PeerProtocol>::ProtocolMessage: std::fmt::Debug, +<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { - let header_bytes = bytes.clone(); - - // Attempt to parse a built in message type, otherwise, see if it is an extension type. - alt!( - (), - ignore_input!(switch!(header_bytes.as_ref(), throwaway_input!(tuple!(be_u32, opt!(be_u8))), - (KEEP_ALIVE_MESSAGE_LEN, None) => value!( - Ok(PeerWireProtocolMessage::KeepAlive) - ) | - (CHOKE_MESSAGE_LEN, Some(CHOKE_MESSAGE_ID)) => value!( - Ok(PeerWireProtocolMessage::Choke) - ) | - (UNCHOKE_MESSAGE_LEN, Some(UNCHOKE_MESSAGE_ID)) => value!( - Ok(PeerWireProtocolMessage::UnChoke) - ) | - (INTERESTED_MESSAGE_LEN, Some(INTERESTED_MESSAGE_ID)) => value!( - Ok(PeerWireProtocolMessage::Interested) - ) | - (UNINTERESTED_MESSAGE_LEN, Some(UNINTERESTED_MESSAGE_ID)) => value!( - Ok(PeerWireProtocolMessage::UnInterested) - ) | - (HAVE_MESSAGE_LEN, Some(HAVE_MESSAGE_ID)) => map!( - call!(HaveMessage::parse_bytes, &bytes.split_off(HEADER_LEN)), - |res_have| res_have.map(|have| PeerWireProtocolMessage::Have(have)) - ) | - (message_len, Some(BITFIELD_MESSAGE_ID)) => map!( - call!(BitFieldMessage::parse_bytes, bytes.split_off(HEADER_LEN), message_len - 1), - |res_bitfield| res_bitfield.map(|bitfield| PeerWireProtocolMessage::BitField(bitfield)) - ) | - (REQUEST_MESSAGE_LEN, Some(REQUEST_MESSAGE_ID)) => map!( - call!(RequestMessage::parse_bytes, &bytes.split_off(HEADER_LEN)), - |res_request| res_request.map(|request| PeerWireProtocolMessage::Request(request)) - ) | - (message_len, Some(PIECE_MESSAGE_ID)) => map!( - call!(PieceMessage::parse_bytes, &bytes.split_off(HEADER_LEN), message_len - 1), - |res_piece| res_piece.map(|piece| PeerWireProtocolMessage::Piece(piece)) - ) | - (CANCEL_MESSAGE_LEN, Some(CANCEL_MESSAGE_ID)) => map!( - call!(CancelMessage::parse_bytes, &bytes.split_off(HEADER_LEN)), - |res_cancel| res_cancel.map(|cancel| PeerWireProtocolMessage::Cancel(cancel)) - ) - )) | map!(call!(BitsExtensionMessage::parse_bytes, bytes.clone()), |res_bits_ext| { - res_bits_ext.map(|bits_ext| PeerWireProtocolMessage::BitsExtension(bits_ext)) - }) | map!(value!(ext_protocol.parse_bytes(bytes)), |res_prot_ext| res_prot_ext - .map(|prot_ext| PeerWireProtocolMessage::ProtExtension(prot_ext))) - ) + alt(( + parse_keep_alive, + parse_choke, + parse_unchoke, + parse_interested, + parse_uninterested, + parse_have, + parse_bitfield, + parse_request, + parse_piece, + parse_cancel, + parse_bits_extension, + |input| parse_prot_extension(input, ext_protocol), + ))(bytes) } diff --git a/packages/peer/src/message/null.rs b/packages/peer/src/message/null.rs index ec8976886..2af79c6f4 100644 --- a/packages/peer/src/message/null.rs +++ b/packages/peer/src/message/null.rs @@ -1,3 +1,4 @@ /// Enumeration of messages for `NullProtocol`. 
#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub enum NullProtocolMessage {} diff --git a/packages/peer/src/message/prot_ext/mod.rs b/packages/peer/src/message/prot_ext/mod.rs index 8d53411fe..c3743c09d 100644 --- a/packages/peer/src/message/prot_ext/mod.rs +++ b/packages/peer/src/message/prot_ext/mod.rs @@ -1,11 +1,21 @@ -use std::io::{self, Write}; +use std::io::Write as _; +use std::ops::Deref; +use std::rc::Rc; +use std::sync::Arc; use bencode::{BConvert, BDecodeOpt, BencodeRef}; use byteorder::{BigEndian, WriteBytesExt}; use bytes::Bytes; -use nom::{ - alt, be_u32, be_u8, call, error_node_position, error_position, map, switch, tuple, tuple_parser, value, ErrorKind, IResult, -}; +use nom::branch::alt; +use nom::bytes::complete::{take, take_until}; +use nom::combinator::{map, value}; +use nom::error::{ErrorKind, ParseError}; +use nom::multi::length_data; +use nom::number::complete::{be_u32, be_u8}; +use nom::sequence::{pair, tuple}; +use nom::IResult; +use thiserror::Error; +use ut_metadata::UtMetadataMessageError; use crate::message::{self, bencode_util, bits_ext, ExtendedMessage, ExtendedType, PeerWireProtocolMessage}; use crate::protocol::PeerProtocol; @@ -16,26 +26,68 @@ mod ut_metadata; pub use self::ut_metadata::{UtMetadataDataMessage, UtMetadataMessage, UtMetadataRejectMessage, UtMetadataRequestMessage}; +#[derive(Debug, Clone)] + +pub struct ByteVecDisplay(Vec); + +impl std::fmt::Display for ByteVecDisplay { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ByteVecDisplay").field(&self.0).finish() + } +} + +impl Deref for ByteVecDisplay { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Error, Debug, Clone)] +pub enum PeerExtensionProtocolMessageError { + #[error("Error from UtMetadata")] + UtMetadataError(UtMetadataMessageError), + + #[error("Error from UtMetadata")] + UnknownId(), + + #[error("Failed to Parse Extension Id: {0}")] + UnknownExtensionId(u8), + + #[error("Failed to Parse Extension: {0}")] + ParseExtensionError(Arc>>), +} + +impl From for PeerExtensionProtocolMessageError { + fn from(value: UtMetadataMessageError) -> Self { + Self::UtMetadataError(value) + } +} + /// Enumeration of `BEP 10` extension protocol compatible messages. +#[derive(Debug)] pub enum PeerExtensionProtocolMessage
<P>
where P: PeerProtocol, { UtMetadata(UtMetadataMessage), //UtPex(UtPexMessage), - Custom(P::ProtocolMessage), + Custom(Result<P::ProtocolMessage, P::ProtocolMessageError>), } impl
<P>
PeerExtensionProtocolMessage
<P>
where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessage: std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { /// Returns the number of bytes needed to encode a given slice. /// /// # Errors /// /// This function should not return an error. - pub fn bytes_needed(bytes: &[u8]) -> io::Result<Option<usize>> { + pub fn bytes_needed(bytes: &[u8]) -> std::io::Result<Option<usize>> { // Follows same length prefix logic as our normal wire protocol... PeerWireProtocolMessage::
<P>
::bytes_needed(bytes) } @@ -46,17 +98,19 @@ where /// /// This function will return an error if unable to parse. pub fn parse_bytes( - bytes: Bytes, + bytes: &[u8], extended: &ExtendedMessage, custom_prot: &mut P, - ) -> io::Result> { - match parse_extensions(bytes, extended, custom_prot) { - IResult::Done((), result) => result, - _ => Err(io::Error::new( - io::ErrorKind::Other, - "Failed To Parse PeerExtensionProtocolMessage", - )), - } + ) -> std::io::Result, PeerExtensionProtocolMessageError>> { + // pass through an inner `std::io::Error`, and wrap any nom-error. + let res = match parse_extensions(bytes, extended, custom_prot) { + Ok((_, result)) => result?, + Err(err) => Err(PeerExtensionProtocolMessageError::ParseExtensionError(Arc::new( + err.to_owned().map_input(ByteVecDisplay), + ))), + }; + + Ok(res) } /// Write Bytes from the current state. @@ -68,83 +122,115 @@ where /// # Panics /// /// This function will panic if the message is too long. - pub fn write_bytes(&self, mut writer: W, extended: &ExtendedMessage, custom_prot: &mut P) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W, extended: &ExtendedMessage, custom_prot: &mut P) -> std::io::Result where - W: Write, + W: std::io::Write, { match self { PeerExtensionProtocolMessage::UtMetadata(msg) => { let Some(ext_id) = extended.query_id(&ExtendedType::UtMetadata) else { - return Err(io::Error::new( - io::ErrorKind::Other, + return Err(std::io::Error::new( + std::io::ErrorKind::Other, "Can't Send UtMetadataMessage As We Have No Id Mapping", )); }; let total_len = (2 + msg.message_size()); - message::write_length_id_pair( + let id_length = message::write_length_id_pair( &mut writer, total_len.try_into().unwrap(), Some(bits_ext::EXTENDED_MESSAGE_ID), )?; writer.write_u8(ext_id)?; - msg.write_bytes(writer) + let () = msg.write_bytes(writer)?; + + Ok(id_length + total_len) } PeerExtensionProtocolMessage::Custom(msg) => custom_prot.write_bytes(msg, writer), } } - pub fn message_size(&self, custom_prot: &mut P) -> usize { + /// Retrieve how many bytes the message will occupy on the wire. + /// + /// # Errors + /// + /// This function will return an error if unable to calculate the message length. + pub fn message_size(&self, custom_prot: &mut P) -> std::io::Result { match self { - PeerExtensionProtocolMessage::UtMetadata(msg) => msg.message_size(), + PeerExtensionProtocolMessage::UtMetadata(msg) => Ok(msg.message_size()), PeerExtensionProtocolMessage::Custom(msg) => custom_prot.message_size(msg), } } } -fn parse_extensions
<P>
( - mut bytes: Bytes, +fn parse_extensions<'a, P>( + bytes: &'a [u8], extended: &ExtendedMessage, custom_prot: &mut P, -) -> IResult<(), io::Result>> +) -> IResult<&'a [u8], std::io::Result, PeerExtensionProtocolMessageError>>> where P: PeerProtocol, { - let header_bytes = bytes.clone(); + let ut_metadata_fn = |input: &'a [u8]| -> IResult< + &'a [u8], + std::io::Result, PeerExtensionProtocolMessageError>>, + > { + let (_, (message_len, extended_message_id, message_id)) = tuple((be_u32, be_u8, be_u8))(input)?; + + if extended_message_id == bits_ext::EXTENDED_MESSAGE_ID { + let from = EXTENSION_HEADER_LEN; + let to = EXTENSION_HEADER_LEN + message_len as usize - 2; + + let bytes = Bytes::copy_from_slice(&input[from..to]); + + Ok((input, parse_extensions_with_id(bytes, extended, message_id))) + } else { + Ok(( + input, + Ok(Err(PeerExtensionProtocolMessageError::UnknownExtensionId( + extended_message_id, + ))), + )) + } + }; + + let custom_fn = |input: &'a [u8]| -> IResult< + &'a [u8], + std::io::Result, PeerExtensionProtocolMessageError>>, + > { + Ok(( + input, + custom_prot + .parse_bytes(input) + .map(|item| Ok(PeerExtensionProtocolMessage::Custom(item))), + )) + }; // Attempt to parse a built in message type, otherwise, see if it is an extension type. - alt!( - (), - ignore_input!(switch!(header_bytes.as_ref(), throwaway_input!(tuple!(be_u32, be_u8, be_u8)), - (message_len, bits_ext::EXTENDED_MESSAGE_ID, message_id) => - call!(parse_extensions_with_id, bytes.split_off(EXTENSION_HEADER_LEN).split_to(message_len as usize - 2), extended, message_id) - )) | map!(value!(custom_prot.parse_bytes(bytes)), |res_cust_ext| res_cust_ext - .map(|cust_ext| PeerExtensionProtocolMessage::Custom(cust_ext))) - ) + alt((ut_metadata_fn, custom_fn))(bytes) } fn parse_extensions_with_id
<P>
( - _input: (), bytes: Bytes, extended: &ExtendedMessage, id: u8, -) -> IResult<(), io::Result>> +) -> std::io::Result, PeerExtensionProtocolMessageError>> where P: PeerProtocol, { - let lt_metadata_id = extended.query_id(&ExtendedType::UtMetadata); + let Some(lt_metadata_id) = extended.query_id(&ExtendedType::UtMetadata) else { + return Ok(Err(PeerExtensionProtocolMessageError::UnknownId())); + }; //let ut_pex_id = extended.query_id(&ExtendedType::UtPex); - let result = if lt_metadata_id == Some(id) { - UtMetadataMessage::parse_bytes(bytes).map(|lt_metadata_msg| PeerExtensionProtocolMessage::UtMetadata(lt_metadata_msg)) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - format!("Unknown Id For PeerExtensionProtocolMessage: {id}"), - )) + let item = UtMetadataMessage::parse_bytes(bytes)?; + + let item = match item { + Ok(message) => Ok(PeerExtensionProtocolMessage::UtMetadata(message)), + Err(err) => Err(PeerExtensionProtocolMessageError::UtMetadataError(err)), }; - IResult::Done((), result) + Ok(item) } diff --git a/packages/peer/src/message/prot_ext/ut_metadata.rs b/packages/peer/src/message/prot_ext/ut_metadata.rs index f3a4eb0f5..1c8c58dfb 100644 --- a/packages/peer/src/message/prot_ext/ut_metadata.rs +++ b/packages/peer/src/message/prot_ext/ut_metadata.rs @@ -1,9 +1,10 @@ -use std::io; -use std::io::Write; +use std::io::Write as _; use bencode::{ben_int, ben_map, BConvert, BDecodeOpt, BencodeRef}; use bytes::Bytes; +use thiserror::Error; +use super::PeerExtensionProtocolMessageError; use crate::message::bencode_util; const REQUEST_MESSAGE_TYPE_ID: u8 = 0; @@ -12,6 +13,13 @@ const REJECT_MESSAGE_TYPE_ID: u8 = 2; const ROOT_ERROR_KEY: &str = "PeerExtensionProtocolMessage"; +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug, Clone)] +pub enum UtMetadataMessageError { + #[error("Failed to match message type: {0}")] + UnknownMessageType(u8), +} + /// Enumeration of messages for `PeerExtensionProtocolMessage::UtMetadata`. #[allow(clippy::module_name_repetitions)] #[derive(Clone, Debug, Hash, PartialEq, Eq)] @@ -27,7 +35,7 @@ impl UtMetadataMessage { /// # Errors /// /// This function will return an error if unable to parse given bytes into type. - pub fn parse_bytes(mut bytes: Bytes) -> io::Result { + pub fn parse_bytes(mut bytes: Bytes) -> std::io::Result> { // Our bencode is pretty flat, and we don't want to enforce a full decode, as data // messages have the raw data appended outside of the bencode structure... let decode_opts = BDecodeOpt::new(2, false, false); @@ -41,7 +49,7 @@ impl UtMetadataMessage { let bencode_bytes = bytes.split_to(bencode.buffer().len()); let extra_bytes = bytes; - match msg_type { + let message = match msg_type { REQUEST_MESSAGE_TYPE_ID => Ok(UtMetadataMessage::Request(UtMetadataRequestMessage::with_bytes( piece, &bencode_bytes, @@ -60,14 +68,13 @@ impl UtMetadataMessage { &bencode_bytes, ))) } - other => Err(io::Error::new( - io::ErrorKind::Other, - format!("Failed To Recognize Message Type For UtMetadataMessage: {msg_type}"), - )), - } + other => Err(UtMetadataMessageError::UnknownMessageType(other)), + }; + + Ok(message) } - Err(err) => Err(io::Error::new( - io::ErrorKind::Other, + Err(err) => Err(std::io::Error::new( + std::io::ErrorKind::Other, format!("Failed To Parse UtMetadataMessage As Bencode: {err}"), )), } @@ -78,9 +85,9 @@ impl UtMetadataMessage { /// # Errors /// /// This function will return an error if unable to write the bytes. 
- pub fn write_bytes(&self, writer: W) -> io::Result<()> + pub fn write_bytes(&self, writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { match self { UtMetadataMessage::Request(request) => request.write_bytes(writer), @@ -136,9 +143,9 @@ impl UtMetadataRequestMessage { /// # Errors /// /// This function will return an error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { let encoded_bytes = (ben_map! { bencode_util::MESSAGE_TYPE_KEY => ben_int!(i64::from(REQUEST_MESSAGE_TYPE_ID)), @@ -202,9 +209,9 @@ impl UtMetadataDataMessage { /// # Errors /// /// This function will return an error if unable to write bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { let encoded_bytes = (ben_map! { bencode_util::MESSAGE_TYPE_KEY => ben_int!(i64::from(DATA_MESSAGE_TYPE_ID)), @@ -271,9 +278,9 @@ impl UtMetadataRejectMessage { /// # Errors /// /// This function will return an error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { let encoded_bytes = (ben_map! { bencode_util::MESSAGE_TYPE_KEY => ben_int!(i64::from(REJECT_MESSAGE_TYPE_ID)), diff --git a/packages/peer/src/message/standard.rs b/packages/peer/src/message/standard.rs index 238080851..73aef3c85 100644 --- a/packages/peer/src/message/standard.rs +++ b/packages/peer/src/message/standard.rs @@ -1,8 +1,12 @@ -use std::io::{self, Write}; +use std::io::Write as _; use byteorder::{BigEndian, WriteBytesExt}; use bytes::Bytes; -use nom::{be_u32, call, do_parse, map, take, tuple, tuple_parser, value, IResult, Needed}; +use nom::bytes::complete::take; +use nom::combinator::{map, map_res}; +use nom::number::complete::be_u32; +use nom::sequence::tuple; +use nom::{IResult, Needed}; use crate::message; @@ -13,37 +17,85 @@ pub struct HaveMessage { } impl HaveMessage { + /// Creates a new `HaveMessage`. + /// + /// # Parameters + /// + /// - `piece_index`: The index of the piece that you have. + /// + /// # Returns + /// + /// A new `HaveMessage` instance. #[must_use] pub fn new(piece_index: u32) -> HaveMessage { HaveMessage { piece_index } } - pub fn parse_bytes(_input: (), bytes: &Bytes) -> IResult<(), io::Result> { - throwaway_input!(parse_have(bytes.as_ref())) + /// Parses a byte slice into a `HaveMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// + /// # Returns + /// + /// An `io::Result` containing the parsed `HaveMessage` or an error if parsing fails. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `HaveMessage`. + pub fn parse_bytes(bytes: &[u8]) -> std::io::Result { + match parse_have(bytes) { + Ok((_, msg)) => msg, + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::Other, "Failed to parse HaveMessage")), + } } - /// Write-out current state as bytes. + /// Writes the current state of the `HaveMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// /// This function will return an error if unable to write bytes. 
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - message::write_length_id_pair(&mut writer, message::HAVE_MESSAGE_LEN, Some(message::HAVE_MESSAGE_ID))?; + let id_length = message::write_length_id_pair(&mut writer, message::HAVE_MESSAGE_LEN, Some(message::HAVE_MESSAGE_ID))?; + let () = writer.write_u32::(self.piece_index)?; - writer.write_u32::(self.piece_index) + Ok(id_length + 4) // + u32 } + /// Returns the piece index of the `HaveMessage`. + /// + /// # Returns + /// + /// The piece index. #[must_use] pub fn piece_index(&self) -> u32 { self.piece_index } } -fn parse_have(bytes: &[u8]) -> IResult<&[u8], io::Result> { - map!(bytes, be_u32, |index| Ok(HaveMessage::new(index))) +/// Parses a byte slice into a `HaveMessage`. +/// +/// # Parameters +/// +/// - `bytes`: The byte slice to parse. +/// +/// # Returns +/// +/// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `HaveMessage`. +/// +/// # Errors +/// +/// This function will return an error if the byte slice cannot be parsed into a `HaveMessage`. +fn parse_have(bytes: &[u8]) -> IResult<&[u8], std::io::Result> { + map(be_u32, |index| Ok(HaveMessage::new(index)))(bytes) } // ----------------------------------------------------------------------------// @@ -57,26 +109,48 @@ pub struct BitFieldMessage { } impl BitFieldMessage { + /// Creates a new `BitFieldMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The bytes representing the bitfield. + /// + /// # Returns + /// + /// A new `BitFieldMessage` instance. pub fn new(bytes: Bytes) -> BitFieldMessage { BitFieldMessage { bytes } } - pub fn parse_bytes(_input: (), mut bytes: Bytes, len: u32) -> IResult<(), io::Result> { - let cast_len = message::u32_to_usize(len); - - if bytes.len() >= cast_len { - IResult::Done( - (), - Ok(BitFieldMessage { - bytes: bytes.split_to(cast_len), - }), - ) - } else { - IResult::Incomplete(Needed::Size(cast_len - bytes.len())) + /// Parses a byte slice into a `BitFieldMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// + /// # Returns + /// + /// An `io::Result` containing the parsed `BitFieldMessage` or an error if parsing fails. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `BitFieldMessage`. + pub fn parse_bytes(bytes: &[u8]) -> std::io::Result { + let len = bytes.len(); + match parse_bitfield(bytes, len) { + Ok((_, msg)) => msg, + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to parse BitFieldMessage", + )), } } - /// Write-out current state as bytes. + /// Writes the current state of the `BitFieldMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// @@ -84,31 +158,72 @@ impl BitFieldMessage { /// /// # Panics /// - /// This function will panic if the the length is too long. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + /// This function will panic if the length is too long. 
+ pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - let actual_length = self.bytes.len() + 1; - message::write_length_id_pair( - &mut writer, - actual_length.try_into().unwrap(), - Some(message::BITFIELD_MESSAGE_ID), - )?; + let message_length: u32 = self + .bytes + .len() + .try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Unsupported, e))?; + + let actual_length = message_length + 1; // + Some(message::BITFIELD_MESSAGE_ID); - writer.write_all(&self.bytes) + let id_length = message::write_length_id_pair(&mut writer, actual_length, Some(message::BITFIELD_MESSAGE_ID))?; + let () = writer.write_all(&self.bytes)?; + + Ok(id_length + message_length as usize) } + /// Returns the bitfield bytes. + /// + /// # Returns + /// + /// A slice of the bitfield bytes. pub fn bitfield(&self) -> &[u8] { &self.bytes } + /// Returns an iterator over the `BitFieldMessage` that yields `HaveMessage`s. + /// + /// # Returns + /// + /// An iterator over the `BitFieldMessage`. #[allow(clippy::iter_without_into_iter)] pub fn iter(&self) -> BitFieldIter { BitFieldIter::new(self.bytes.clone()) } } +/// Parses a byte slice into a `BitFieldMessage`. +/// +/// # Parameters +/// +/// - `bytes`: The byte slice to parse. +/// - `len`: The length of the bitfield. +/// +/// # Returns +/// +/// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `BitFieldMessage`. +/// +/// # Errors +/// +/// This function will return an error if the byte slice cannot be parsed into a `BitFieldMessage`. +fn parse_bitfield(bytes: &[u8], len: usize) -> IResult<&[u8], std::io::Result> { + if bytes.len() >= len { + Ok(( + &bytes[len..], + Ok(BitFieldMessage { + bytes: Bytes::copy_from_slice(&bytes[..len]), + }), + )) + } else { + Err(nom::Err::Incomplete(Needed::new(len - bytes.len()))) + } +} + /// Iterator for a `BitFieldMessage` to `HaveMessage`s. pub struct BitFieldIter { bytes: Bytes, @@ -144,7 +259,6 @@ impl Iterator for BitFieldIter { } // ----------------------------------------------------------------------------// - /// Message for requesting a block from a peer. #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] pub struct RequestMessage { @@ -154,6 +268,17 @@ pub struct RequestMessage { } impl RequestMessage { + /// Creates a new `RequestMessage`. + /// + /// # Parameters + /// + /// - `piece_index`: The index of the piece being requested. + /// - `block_offset`: The offset within the piece. + /// - `block_length`: The length of the block being requested. + /// + /// # Returns + /// + /// A new `RequestMessage` instance. #[must_use] pub fn new(piece_index: u32, block_offset: u32, block_length: usize) -> RequestMessage { RequestMessage { @@ -163,11 +288,34 @@ impl RequestMessage { } } - pub fn parse_bytes(_input: (), bytes: &Bytes) -> IResult<(), io::Result> { - throwaway_input!(parse_request(bytes.as_ref())) + /// Parses a byte slice into a `RequestMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// + /// # Returns + /// + /// An `io::Result` containing the parsed `RequestMessage` or an error if parsing fails. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `RequestMessage`. + pub fn parse_bytes(bytes: &[u8]) -> std::io::Result { + match parse_request(bytes) { + Ok((_, msg)) => msg, + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to parse RequestMessage", + )), + } } - /// Write-out current state as bytes. 
+ /// Writes the current state of the `RequestMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// @@ -176,37 +324,74 @@ impl RequestMessage { /// # Panics /// /// This function will panic if the `block_length` is too large. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - message::write_length_id_pair(&mut writer, message::REQUEST_MESSAGE_LEN, Some(message::REQUEST_MESSAGE_ID))?; + let id_length = + message::write_length_id_pair(&mut writer, message::REQUEST_MESSAGE_LEN, Some(message::REQUEST_MESSAGE_ID))?; + + let () = writer.write_u32::(self.piece_index)?; + let () = writer.write_u32::(self.block_offset)?; + { + let block_length: u32 = self + .block_length() + .try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Unsupported, e))?; + let () = writer.write_u32::(block_length)?; + } - writer.write_u32::(self.piece_index)?; - writer.write_u32::(self.block_offset)?; - writer.write_u32::(self.block_length.try_into().unwrap()) + Ok(id_length + 12) // + u32 * 3 } + /// Returns the piece index of the `RequestMessage`. + /// + /// # Returns + /// + /// The piece index. #[must_use] pub fn piece_index(&self) -> u32 { self.piece_index } + /// Returns the block offset of the `RequestMessage`. + /// + /// # Returns + /// + /// The block offset. #[must_use] pub fn block_offset(&self) -> u32 { self.block_offset } + /// Returns the block length of the `RequestMessage`. + /// + /// # Returns + /// + /// The block length. #[must_use] pub fn block_length(&self) -> usize { self.block_length } } -fn parse_request(bytes: &[u8]) -> IResult<&[u8], io::Result> { - map!(bytes, tuple!(be_u32, be_u32, be_u32), |(index, offset, length)| Ok( - RequestMessage::new(index, offset, message::u32_to_usize(length)) - )) +/// Parses a byte slice into a `RequestMessage`. +/// +/// # Parameters +/// +/// - `bytes`: The byte slice to parse. +/// +/// # Returns +/// +/// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `RequestMessage`. +/// +/// # Errors +/// +/// This function will return an error if the byte slice cannot be parsed into a `RequestMessage`. +fn parse_request(bytes: &[u8]) -> IResult<&[u8], std::io::Result> { + map(tuple((be_u32, be_u32, be_u32)), |(index, offset, length)| { + Ok(RequestMessage::new(index, offset, message::u32_to_usize(length))) + })(bytes) } // ----------------------------------------------------------------------------// @@ -223,6 +408,17 @@ pub struct PieceMessage { } impl PieceMessage { + /// Creates a new `PieceMessage`. + /// + /// # Parameters + /// + /// - `piece_index`: The index of the piece. + /// - `block_offset`: The offset within the piece. + /// - `block`: The block of data. + /// + /// # Returns + /// + /// A new `PieceMessage` instance. pub fn new(piece_index: u32, block_offset: u32, block: Bytes) -> PieceMessage { // TODO: Check that users Bytes wont overflow a u32 PieceMessage { @@ -232,11 +428,32 @@ impl PieceMessage { } } - pub fn parse_bytes(_input: (), bytes: &Bytes, len: u32) -> IResult<(), io::Result> { - throwaway_input!(parse_piece(bytes, len)) + /// Parses a byte slice into a `PieceMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// - `len`: The length of the piece. + /// + /// # Returns + /// + /// An `io::Result` containing the parsed `PieceMessage` or an error if parsing fails. 
+ /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `PieceMessage`. + pub fn parse_bytes(bytes: &[u8], len: usize) -> std::io::Result { + match parse_piece(bytes, len) { + Ok((_, msg)) => msg, + Err(_) => Err(std::io::Error::new(std::io::ErrorKind::Other, "Failed to parse PieceMessage")), + } } - /// Write-out current state as bytes. + /// Writes the current state of the `PieceMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// @@ -245,48 +462,91 @@ impl PieceMessage { /// # Panics /// /// This function will panic if the block length is too large. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - let actual_length = self.block_length() + 9; - message::write_length_id_pair( + let block_length: u32 = self + .block_length() + .try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Unsupported, e))?; + + let actual_length = self.block_length() + 9; // + Some(message::PIECE_MESSAGE_ID) + 2 * u32 + + let length_length = message::write_length_id_pair( &mut writer, actual_length.try_into().unwrap(), Some(message::PIECE_MESSAGE_ID), )?; - writer.write_u32::(self.piece_index)?; - writer.write_u32::(self.block_offset)?; + let () = writer.write_u32::(self.piece_index)?; + let () = writer.write_u32::(self.block_offset)?; + let () = writer.write_all(&self.block[..])?; - writer.write_all(&self.block[..]) + Ok(length_length + block_length as usize + 8) // + 2 * u32 } + /// Returns the piece index of the `PieceMessage`. + /// + /// # Returns + /// + /// The piece index. + #[must_use] pub fn piece_index(&self) -> u32 { self.piece_index } + /// Returns the block offset of the `PieceMessage`. + /// + /// # Returns + /// + /// The block offset. + #[must_use] pub fn block_offset(&self) -> u32 { self.block_offset } + /// Returns the block length of the `PieceMessage`. + /// + /// # Returns + /// + /// The block length. + #[must_use] pub fn block_length(&self) -> usize { self.block.len() } + /// Returns the block of the `PieceMessage`. + /// + /// # Returns + /// + /// The block. pub fn block(&self) -> Bytes { self.block.clone() } } -fn parse_piece(bytes: &Bytes, len: u32) -> IResult<&[u8], io::Result> { - do_parse!(bytes.as_ref(), - piece_index: be_u32 >> - block_offset: be_u32 >> - block_len: value!(message::u32_to_usize(len - 8)) >> - block: map!(take!(block_len), |_| bytes.slice(8, 8 + block_len)) >> - (Ok(PieceMessage::new(piece_index, block_offset, block))) - ) +/// Parses a byte slice into a `PieceMessage`. +/// +/// # Parameters +/// +/// - `bytes`: The byte slice to parse. +/// - `len`: The length of the piece. +/// +/// # Returns +/// +/// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `PieceMessage`. +/// +/// # Errors +/// +/// This function will return an error if the byte slice cannot be parsed into a `PieceMessage`. +fn parse_piece(bytes: &[u8], len: usize) -> IResult<&[u8], std::io::Result> { + map( + tuple((be_u32, be_u32, take(len - 8))), + |(piece_index, block_offset, block): (u32, u32, &[u8])| { + Ok(PieceMessage::new(piece_index, block_offset, Bytes::copy_from_slice(block))) + }, + )(bytes) } // ----------------------------------------------------------------------------// @@ -300,6 +560,17 @@ pub struct CancelMessage { } impl CancelMessage { + /// Creates a new `CancelMessage`. 
+ /// + /// # Parameters + /// + /// - `piece_index`: The index of the piece. + /// - `block_offset`: The offset within the piece. + /// - `block_length`: The length of the block. + /// + /// # Returns + /// + /// A new `CancelMessage` instance. #[must_use] pub fn new(piece_index: u32, block_offset: u32, block_length: usize) -> CancelMessage { CancelMessage { @@ -309,11 +580,34 @@ impl CancelMessage { } } - pub fn parse_bytes(_input: (), bytes: &Bytes) -> IResult<(), io::Result> { - throwaway_input!(parse_cancel(bytes.as_ref())) + /// Parses a byte slice into a `CancelMessage`. + /// + /// # Parameters + /// + /// - `bytes`: The byte slice to parse. + /// + /// # Returns + /// + /// An `io::Result` containing the parsed `CancelMessage` or an error if parsing fails. + /// + /// # Errors + /// + /// This function will return an error if the byte slice cannot be parsed into a `CancelMessage`. + pub fn parse_bytes(bytes: &[u8]) -> std::io::Result { + match parse_cancel(bytes) { + Ok((_, msg)) => msg, + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to parse CancelMessage", + )), + } } - /// Write-out current state as bytes. + /// Writes the current state of the `CancelMessage` as bytes. + /// + /// # Parameters + /// + /// - `writer`: The writer to which the bytes will be written. /// /// # Errors /// @@ -322,37 +616,74 @@ impl CancelMessage { /// # Panics /// /// This function will panic if the block length is too large. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - message::write_length_id_pair(&mut writer, message::CANCEL_MESSAGE_LEN, Some(message::CANCEL_MESSAGE_ID))?; + let id_length = + message::write_length_id_pair(&mut writer, message::CANCEL_MESSAGE_LEN, Some(message::CANCEL_MESSAGE_ID))?; + + let () = writer.write_u32::(self.piece_index)?; + let () = writer.write_u32::(self.block_offset)?; + { + let block_length: u32 = self + .block_length + .try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Unsupported, e))?; + let () = writer.write_u32::(block_length)?; + } - writer.write_u32::(self.piece_index)?; - writer.write_u32::(self.block_offset)?; - writer.write_u32::(self.block_length.try_into().unwrap()) + Ok(id_length + 12) // + 3 * u32 } + /// Returns the piece index of the `CancelMessage`. + /// + /// # Returns + /// + /// The piece index. #[must_use] pub fn piece_index(&self) -> u32 { self.piece_index } + /// Returns the block offset of the `CancelMessage`. + /// + /// # Returns + /// + /// The block offset. #[must_use] pub fn block_offset(&self) -> u32 { self.block_offset } + /// Returns the block length of the `CancelMessage`. + /// + /// # Returns + /// + /// The block length. #[must_use] pub fn block_length(&self) -> usize { self.block_length } } -fn parse_cancel(bytes: &[u8]) -> IResult<&[u8], io::Result> { - map!(bytes, tuple!(be_u32, be_u32, be_u32), |(index, offset, length)| Ok( - CancelMessage::new(index, offset, message::u32_to_usize(length)) - )) +/// Parses a byte slice into a `CancelMessage`. +/// +/// # Parameters +/// +/// - `bytes`: The byte slice to parse. +/// +/// # Returns +/// +/// An `IResult` containing the remaining byte slice and an `io::Result` with the parsed `CancelMessage`. +/// +/// # Errors +/// +/// This function will return an error if the byte slice cannot be parsed into a `CancelMessage`. 
+fn parse_cancel(bytes: &[u8]) -> IResult<&[u8], std::io::Result> { + map(tuple((be_u32, be_u32, be_u32)), |(index, offset, length)| { + Ok(CancelMessage::new(index, offset, message::u32_to_usize(length))) + })(bytes) } #[cfg(test)] @@ -370,20 +701,14 @@ mod tests { #[test] fn positive_bitfield_iter_no_messages() { - let mut bytes = Bytes::new(); - bytes.extend_from_slice(&[0x00, 0x00, 0x00]); - - let bitfield = BitFieldMessage::new(bytes); + let bitfield = BitFieldMessage::new(Bytes::copy_from_slice(&[0x00, 0x00, 0x00])); assert_eq!(0, bitfield.iter().count()); } #[test] fn positive_bitfield_iter_single_message_beginning() { - let mut bytes = Bytes::new(); - bytes.extend_from_slice(&[0x80, 0x00, 0x00]); - - let bitfield = BitFieldMessage::new(bytes); + let bitfield = BitFieldMessage::new(Bytes::copy_from_slice(&[0x80, 0x00, 0x00])); assert_eq!(1, bitfield.iter().count()); assert_eq!(HaveMessage::new(0), bitfield.iter().next().unwrap()); @@ -391,10 +716,7 @@ mod tests { #[test] fn positive_bitfield_iter_single_message_middle() { - let mut bytes = Bytes::new(); - bytes.extend_from_slice(&[0x00, 0x01, 0x00]); - - let bitfield = BitFieldMessage::new(bytes); + let bitfield = BitFieldMessage::new(Bytes::copy_from_slice(&[0x00, 0x01, 0x00])); assert_eq!(1, bitfield.iter().count()); assert_eq!(HaveMessage::new(15), bitfield.iter().next().unwrap()); @@ -402,10 +724,7 @@ mod tests { #[test] fn positive_bitfield_iter_single_message_ending() { - let mut bytes = Bytes::new(); - bytes.extend_from_slice(&[0x00, 0x00, 0x01]); - - let bitfield = BitFieldMessage::new(bytes); + let bitfield = BitFieldMessage::new(Bytes::copy_from_slice(&[0x00, 0x00, 0x01])); assert_eq!(1, bitfield.iter().count()); assert_eq!(HaveMessage::new(23), bitfield.iter().next().unwrap()); @@ -413,10 +732,7 @@ mod tests { #[test] fn positive_bitfield_iter_multiple_messages() { - let mut bytes = Bytes::new(); - bytes.extend_from_slice(&[0xAF, 0x00, 0xC1]); - - let bitfield = BitFieldMessage::new(bytes); + let bitfield = BitFieldMessage::new(Bytes::copy_from_slice(&[0xAF, 0x00, 0xC1])); let messages: Vec = bitfield.iter().collect(); assert_eq!(9, messages.len()); diff --git a/packages/peer/src/protocol/extension.rs b/packages/peer/src/protocol/extension.rs index f1f34cb6f..76c0c400a 100644 --- a/packages/peer/src/protocol/extension.rs +++ b/packages/peer/src/protocol/extension.rs @@ -1,18 +1,22 @@ -use std::io::{self, Write}; - -use bytes::Bytes; - -use crate::message::{ExtendedMessage, PeerExtensionProtocolMessage}; +use crate::message::{ExtendedMessage, PeerExtensionProtocolMessage, PeerExtensionProtocolMessageError}; use crate::protocol::{NestedPeerProtocol, PeerProtocol}; /// Protocol for `BEP 10` peer extensions. -pub struct PeerExtensionProtocol
<P>
{ + +#[derive(Debug, Clone)] +pub struct PeerExtensionProtocol
<P>
+where + P: Clone, +{ our_extended_msg: Option<ExtendedMessage>, their_extended_msg: Option<ExtendedMessage>, custom_protocol: P, } -impl
<P>
PeerExtensionProtocol
<P>
{ +impl
<P>
PeerExtensionProtocol
<P>
+where + P: Clone, +{ /// Create a new `PeerExtensionProtocol` with the given (nested) custom extension protocol. /// /// Notes for `PeerWireProtocol` apply to this custom extension protocol. @@ -27,57 +31,77 @@ impl
<P>
PeerExtensionProtocol
<P>
{ impl
<P>
PeerProtocol for PeerExtensionProtocol
<P>
where - P: PeerProtocol, + P: PeerProtocol + Clone + std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessage: std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { type ProtocolMessage = PeerExtensionProtocolMessage
<P>
; + type ProtocolMessageError = PeerExtensionProtocolMessageError; - fn bytes_needed(&mut self, bytes: &[u8]) -> io::Result<Option<usize>> { + fn bytes_needed(&mut self, bytes: &[u8]) -> std::io::Result<Option<usize>> { PeerExtensionProtocolMessage::
<P>
::bytes_needed(bytes) } - fn parse_bytes(&mut self, bytes: Bytes) -> io::Result { + fn parse_bytes(&mut self, bytes: &[u8]) -> std::io::Result> { match self.our_extended_msg { Some(ref extended_msg) => PeerExtensionProtocolMessage::parse_bytes(bytes, extended_msg, &mut self.custom_protocol), - None => Err(io::Error::new( - io::ErrorKind::Other, + None => Err(std::io::Error::new( + std::io::ErrorKind::Other, "Extension Message Received From Peer Before Extended Message...", )), } } - fn write_bytes(&mut self, message: &Self::ProtocolMessage, writer: W) -> io::Result<()> + fn write_bytes( + &mut self, + item: &Result, + writer: W, + ) -> std::io::Result where - W: Write, + W: std::io::Write, { + let message = match item { + Ok(message) => message, + Err(err) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err.clone())), + }; + match self.their_extended_msg { - Some(ref extended_msg) => { - PeerExtensionProtocolMessage::write_bytes(message, writer, extended_msg, &mut self.custom_protocol) - } - None => Err(io::Error::new( - io::ErrorKind::Other, + Some(ref extended_msg) => Ok(PeerExtensionProtocolMessage::write_bytes( + message, + writer, + extended_msg, + &mut self.custom_protocol, + )?), + None => Err(std::io::Error::new( + std::io::ErrorKind::Other, "Extension Message Sent From Us Before Extended Message...", )), } } - fn message_size(&mut self, message: &Self::ProtocolMessage) -> usize { + fn message_size(&mut self, item: &Result) -> std::io::Result { + let message = match item { + Ok(message) => message, + Err(err) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err.clone())), + }; + message.message_size(&mut self.custom_protocol) } } impl
<P>
NestedPeerProtocol<ExtendedMessage> for PeerExtensionProtocol
<P>
where - P: NestedPeerProtocol, + P: NestedPeerProtocol + Clone, { - fn received_message(&mut self, message: &ExtendedMessage) { - self.custom_protocol.received_message(message); - + fn received_message(&mut self, message: &ExtendedMessage) -> usize { self.their_extended_msg = Some(message.clone()); - } - fn sent_message(&mut self, message: &ExtendedMessage) { - self.custom_protocol.sent_message(message); + self.custom_protocol.received_message(message) + } + fn sent_message(&mut self, message: &ExtendedMessage) -> usize { self.our_extended_msg = Some(message.clone()); + + self.custom_protocol.sent_message(message) } } diff --git a/packages/peer/src/protocol/mod.rs b/packages/peer/src/protocol/mod.rs index 9438887e3..17b1a5743 100644 --- a/packages/peer/src/protocol/mod.rs +++ b/packages/peer/src/protocol/mod.rs @@ -1,9 +1,5 @@ //! Generic `PeerProtocol` implementations. -use std::io::{self, Write}; - -use bytes::Bytes; - pub mod extension; pub mod null; pub mod unit; @@ -14,6 +10,7 @@ pub mod wire; pub trait PeerProtocol { /// Type of message the protocol operates with. type ProtocolMessage; + type ProtocolMessageError; /// Total number of bytes needed to parse a complete message. This is not /// in addition to what we were given, this is the total number of bytes, so @@ -26,26 +23,34 @@ pub trait PeerProtocol { /// # Errors /// /// This function will return an IO result if unable to calculate the bytes needed. - fn bytes_needed(&mut self, bytes: &[u8]) -> io::Result>; + fn bytes_needed(&mut self, bytes: &[u8]) -> std::io::Result>; /// Parse a `ProtocolMessage` from the given bytes. /// /// # Errors /// /// This function will return an IO error if unable to parse the bytes into a [`Self::ProtocolMessage`]. - fn parse_bytes(&mut self, bytes: Bytes) -> io::Result; + fn parse_bytes(&mut self, bytes: &[u8]) -> std::io::Result>; /// Write a `ProtocolMessage` to the given writer. /// /// # Errors /// /// This function will return an error if it fails to write-out. - fn write_bytes(&mut self, message: &Self::ProtocolMessage, writer: W) -> io::Result<()> + fn write_bytes( + &mut self, + item: &Result, + writer: W, + ) -> std::io::Result where - W: Write; + W: std::io::Write; /// Retrieve how many bytes the message will occupy on the wire. - fn message_size(&mut self, message: &Self::ProtocolMessage) -> usize; + /// + /// # Errors + /// + /// This function will return an error if unable to calculate the message length. + fn message_size(&mut self, message: &Result) -> std::io::Result; } /// Trait for nested peer protocols to see higher level peer protocol messages. @@ -64,8 +69,8 @@ pub trait PeerProtocol { #[allow(clippy::module_name_repetitions)] pub trait NestedPeerProtocol { /// Notify a nested protocol that we have received the given message. - fn received_message(&mut self, message: &M); + fn received_message(&mut self, message: &M) -> usize; /// Notify a nested protocol that we have sent the given message. 
- fn sent_message(&mut self, message: &M); + fn sent_message(&mut self, message: &M) -> usize; } diff --git a/packages/peer/src/protocol/null.rs b/packages/peer/src/protocol/null.rs index 87741b680..0532e2510 100644 --- a/packages/peer/src/protocol/null.rs +++ b/packages/peer/src/protocol/null.rs @@ -1,7 +1,3 @@ -use std::io::{self, Write}; - -use bytes::Bytes; - use crate::message::NullProtocolMessage; use crate::protocol::{NestedPeerProtocol, PeerProtocol}; @@ -14,7 +10,7 @@ use crate::protocol::{NestedPeerProtocol, PeerProtocol}; /// Of course, you should make sure that you don't tell peers /// that you support any extended messages. #[allow(clippy::module_name_repetitions)] -#[derive(Default)] +#[derive(Debug, Default, Clone)] pub struct NullProtocol; impl NullProtocol { @@ -27,34 +23,39 @@ impl NullProtocol { impl PeerProtocol for NullProtocol { type ProtocolMessage = NullProtocolMessage; + type ProtocolMessageError = std::io::Error; - fn bytes_needed(&mut self, _bytes: &[u8]) -> io::Result> { + fn bytes_needed(&mut self, _: &[u8]) -> std::io::Result> { Ok(Some(0)) } - fn parse_bytes(&mut self, _bytes: Bytes) -> io::Result { - Err(io::Error::new( - io::ErrorKind::Other, + fn parse_bytes(&mut self, _: &[u8]) -> std::io::Result> { + Err(std::io::Error::new( + std::io::ErrorKind::Other, "Attempted To Parse Bytes As Null Protocol", )) } - fn write_bytes(&mut self, _message: &Self::ProtocolMessage, _writer: W) -> io::Result<()> + fn write_bytes(&mut self, _: &Result, _: W) -> std::io::Result where - W: Write, + W: std::io::Write, { panic!( "bip_peer: NullProtocol::write_bytes Was Called...Wait, How Did You Construct An Instance Of NullProtocolMessage? :)" - ) + ); } - fn message_size(&mut self, _message: &Self::ProtocolMessage) -> usize { - 0 + fn message_size(&mut self, _: &Result) -> std::io::Result { + Ok(0) } } impl NestedPeerProtocol for NullProtocol { - fn received_message(&mut self, _message: &M) {} + fn received_message(&mut self, _message: &M) -> usize { + 0 + } - fn sent_message(&mut self, _message: &M) {} + fn sent_message(&mut self, _message: &M) -> usize { + 0 + } } diff --git a/packages/peer/src/protocol/unit.rs b/packages/peer/src/protocol/unit.rs index b6f492c27..07370ca1a 100644 --- a/packages/peer/src/protocol/unit.rs +++ b/packages/peer/src/protocol/unit.rs @@ -1,7 +1,3 @@ -use std::io::{self, Write}; - -use bytes::Bytes; - use crate::protocol::{NestedPeerProtocol, PeerProtocol}; /// Unit protocol which will always return a unit if called. 
@@ -20,28 +16,34 @@ impl UnitProtocol { impl PeerProtocol for UnitProtocol { type ProtocolMessage = (); - fn bytes_needed(&mut self, _bytes: &[u8]) -> io::Result> { + type ProtocolMessageError = std::io::Error; + + fn bytes_needed(&mut self, _: &[u8]) -> std::io::Result> { Ok(Some(0)) } - fn parse_bytes(&mut self, _bytes: Bytes) -> io::Result { - Ok(()) + fn parse_bytes(&mut self, _: &[u8]) -> std::io::Result> { + Ok(Ok(())) } - fn write_bytes(&mut self, _message: &Self::ProtocolMessage, _writer: W) -> io::Result<()> + fn write_bytes(&mut self, _: &Result, _: W) -> std::io::Result where - W: Write, + W: std::io::Write, { - Ok(()) + Ok(0) } - fn message_size(&mut self, _message: &Self::ProtocolMessage) -> usize { - 0 + fn message_size(&mut self, _: &Result) -> std::io::Result { + Ok(0) } } impl NestedPeerProtocol for UnitProtocol { - fn received_message(&mut self, _message: &M) {} + fn received_message(&mut self, _message: &M) -> usize { + 0 + } - fn sent_message(&mut self, _message: &M) {} + fn sent_message(&mut self, _message: &M) -> usize { + 0 + } } diff --git a/packages/peer/src/protocol/wire.rs b/packages/peer/src/protocol/wire.rs index a93cde5b3..833958097 100644 --- a/packages/peer/src/protocol/wire.rs +++ b/packages/peer/src/protocol/wire.rs @@ -1,16 +1,19 @@ -use std::io::{self, Write}; - -use bytes::Bytes; - -use crate::message::{BitsExtensionMessage, ExtendedMessage, PeerWireProtocolMessage}; +use crate::message::{BitsExtensionMessage, ExtendedMessage, PeerWireProtocolMessage, PeerWireProtocolMessageError}; use crate::protocol::{NestedPeerProtocol, PeerProtocol}; /// Protocol for peer wire messages. -pub struct PeerWireProtocol
<P>
{ +#[derive(Debug, Clone)] +pub struct PeerWireProtocol
<P>
+where + P: Clone, +{ ext_protocol: P, } -impl
<P>
PeerWireProtocol
<P>
{ +impl
<P>
PeerWireProtocol
<P>
+where + P: Clone, +{ /// Create a new `PeerWireProtocol` with the given extension protocol. /// /// Important to note that nested protocol should follow the same message length format @@ -23,40 +26,59 @@ impl
<P>
PeerWireProtocol
<P>
{ impl
<P>
PeerProtocol for PeerWireProtocol
<P>
where - P: PeerProtocol + NestedPeerProtocol<ExtendedMessage>, + P: PeerProtocol + NestedPeerProtocol<ExtendedMessage> + Clone + std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessage: std::fmt::Debug, +
<P as PeerProtocol>
::ProtocolMessageError: std::fmt::Debug, { type ProtocolMessage = PeerWireProtocolMessage
<P>
; - fn bytes_needed(&mut self, bytes: &[u8]) -> io::Result<Option<usize>> { + type ProtocolMessageError = PeerWireProtocolMessageError; + + fn bytes_needed(&mut self, bytes: &[u8]) -> std::io::Result<Option<usize>> { PeerWireProtocolMessage::
<P>
::bytes_needed(bytes) } - fn parse_bytes(&mut self, bytes: Bytes) -> io::Result { - match PeerWireProtocolMessage::parse_bytes(bytes, &mut self.ext_protocol) { - Ok(PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(msg))) => { + fn parse_bytes(&mut self, bytes: &[u8]) -> std::io::Result> { + match PeerWireProtocolMessage::parse_bytes(bytes, &mut self.ext_protocol)? { + PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(msg)) => { self.ext_protocol.received_message(&msg); - Ok(PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(msg))) + Ok(Ok(PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended( + msg, + )))) } - other => other, + other => Ok(Ok(other)), } } - fn write_bytes(&mut self, message: &Self::ProtocolMessage, writer: W) -> io::Result<()> + fn write_bytes( + &mut self, + item: &Result, + writer: W, + ) -> std::io::Result where - W: Write, + W: std::io::Write, { - match (message.write_bytes(writer, &mut self.ext_protocol), message) { - (Ok(()), &PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(ref msg))) => { - self.ext_protocol.sent_message(msg); + let message = match item { + Ok(message) => message, + Err(err) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err.clone())), + }; - Ok(()) - } - (other, _) => other, - } + let message_bytes_written = message.write_bytes(writer, &mut self.ext_protocol)?; + + let PeerWireProtocolMessage::BitsExtension(BitsExtensionMessage::Extended(extended_message)) = message else { + return Ok(message_bytes_written); + }; + + Ok(self.ext_protocol.sent_message(extended_message)) } - fn message_size(&mut self, message: &Self::ProtocolMessage) -> usize { + fn message_size(&mut self, item: &Result) -> std::io::Result { + let message = match item { + Ok(message) => message, + Err(err) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err.clone())), + }; + message.message_size(&mut self.ext_protocol) } } diff --git a/packages/peer/tests/common/connected_channel.rs b/packages/peer/tests/common/connected_channel.rs new file mode 100644 index 000000000..f68a13f6b --- /dev/null +++ b/packages/peer/tests/common/connected_channel.rs @@ -0,0 +1,79 @@ +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use futures::channel::mpsc; +use futures::{Sink, SinkExt as _, Stream, StreamExt as _}; + +#[derive(Debug)] +pub struct ConnectedChannel { + send: mpsc::Sender, + recv: Arc>>, +} + +impl Clone for ConnectedChannel { + fn clone(&self) -> Self { + Self { + send: self.send.clone(), + recv: self.recv.clone(), + } + } +} + +impl Sink for ConnectedChannel { + type Error = std::io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send + .poll_ready_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::ConnectionAborted, e)) + } + + fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.send + .start_send_unpin(item) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::ConnectionAborted, e)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send + .poll_flush_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::ConnectionAborted, e)) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.send + .poll_close_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::ConnectionAborted, e)) + } +} + +impl Stream for ConnectedChannel { + type Item 
= O; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Ok(recv) = self.recv.try_lock() else { + cx.waker().wake_by_ref(); + return Poll::Pending; + }; + + Pin::new(recv).poll_next_unpin(cx) + } +} + +#[must_use] +pub fn connected_channel(capacity: usize) -> (ConnectedChannel, ConnectedChannel) { + let (send_one, recv_one) = futures::channel::mpsc::channel(capacity); + let (send_two, recv_two) = futures::channel::mpsc::channel(capacity); + + ( + ConnectedChannel { + send: send_one, + recv: Arc::new(Mutex::new(recv_two)), + }, + ConnectedChannel { + send: send_two, + recv: Arc::new(Mutex::new(recv_one)), + }, + ) +} diff --git a/packages/peer/tests/common/mod.rs b/packages/peer/tests/common/mod.rs index 42af63826..437412b6a 100644 --- a/packages/peer/tests/common/mod.rs +++ b/packages/peer/tests/common/mod.rs @@ -1,56 +1,130 @@ -use std::io; +use std::sync::Once; +use std::time::Duration; +use futures::channel::mpsc::SendError; use futures::sink::Sink; use futures::stream::Stream; -use futures::sync::mpsc::{self, Receiver, Sender}; -use futures::{Poll, StartSend}; +use futures::{SinkExt as _, StreamExt as _, TryStream}; +use peer::error::PeerManagerError; +use peer::{ManagedMessage, PeerInfo, PeerManagerInputMessage, PeerManagerOutputError, PeerManagerOutputMessage}; +use thiserror::Error; +use tokio::time::error::Elapsed; +use tracing::level_filters::LevelFilter; -pub struct ConnectedChannel { - send: Sender, - recv: Receiver, +pub mod connected_channel; + +#[derive(Debug, Error)] +pub enum Error +where + Message: ManagedMessage + Send + 'static, +{ + #[error("Send Timed Out")] + SendTimedOut(Elapsed), + + #[error("Receive Timed Out")] + ReceiveTimedOut(Elapsed), + + #[error("mpsc::Receiver Closed")] + ReceiverClosed(), + + #[error("Peer Manager Input Error {0}")] + PeerManagerErr(#[from] PeerManagerError), + + #[error("Peer Manager Output Error {0}")] + PeerManagerOutputErr(#[from] PeerManagerOutputError), + + #[error("Failed to correct response, but got: {0:?}")] + WrongResponse(#[from] PeerManagerOutputMessage), + + #[error("Failed to receive Peer Added with matching infohash: got: {0:?}, expected: {1:?}")] + InfoHashMissMatch(PeerInfo, PeerInfo), } -impl Sink for ConnectedChannel { - type SinkItem = I; - type SinkError = io::Error; +#[allow(dead_code)] +pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(500); - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.send - .start_send(item) - .map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "Sender Failed To Send")) - } +#[allow(dead_code)] +pub static INIT: Once = Once::new(); - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.send - .poll_complete() - .map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "Sender Failed To Send")) - } +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); } -impl Stream for ConnectedChannel { - type Item = O; - type Error = io::Error; +pub async fn add_peer( + send: &mut Si, + recv: &mut St, + info: PeerInfo, + peer: Peer, +) -> Result<(), Error> +where + Si: Sink>, Error = PeerManagerError> + Unpin, + St: Stream, PeerManagerOutputError>> + Unpin, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, +{ + let () = 
tokio::time::timeout(DEFAULT_TIMEOUT, send.send(Ok(PeerManagerInputMessage::AddPeer(info, peer)))) + .await + .map_err(|e| Error::SendTimedOut(e))??; + + let response = tokio::time::timeout(DEFAULT_TIMEOUT, recv.next()) + .await + .map(|res| res.ok_or(Error::ReceiverClosed())) + .map_err(|e| Error::ReceiveTimedOut(e))???; - fn poll(&mut self) -> Poll, Self::Error> { - self.recv - .poll() - .map_err(|()| io::Error::new(io::ErrorKind::Other, "Receiver Failed To Receive")) + if let PeerManagerOutputMessage::PeerAdded(info_recv) = response { + if info_recv == info { + Ok(()) + } else { + Err(Error::InfoHashMissMatch(info_recv, info)) + } + } else { + Err(Error::from(response)) } } -#[must_use] -pub fn connected_channel(capacity: usize) -> (ConnectedChannel, ConnectedChannel) { - let (send_one, recv_one) = mpsc::channel(capacity); - let (send_two, recv_two) = mpsc::channel(capacity); - - ( - ConnectedChannel { - send: send_one, - recv: recv_two, - }, - ConnectedChannel { - send: send_two, - recv: recv_one, - }, - ) +pub async fn remove_peer(send: &mut Si, recv: &mut St, info: PeerInfo) -> Result<(), Error> +where + Si: Sink>, Error = PeerManagerError> + Unpin, + St: Stream, PeerManagerOutputError>> + Unpin, + Peer: Sink> + + Stream> + + TryStream + + std::fmt::Debug + + Send + + Unpin + + 'static, + Message: ManagedMessage + Send + 'static, +{ + let () = tokio::time::timeout(DEFAULT_TIMEOUT, send.send(Ok(PeerManagerInputMessage::RemovePeer(info)))) + .await + .map_err(|e| Error::SendTimedOut(e))??; + + let response = tokio::time::timeout(DEFAULT_TIMEOUT, recv.next()) + .await + .map(|res| res.ok_or(Error::ReceiverClosed())) + .map_err(|e| Error::ReceiveTimedOut(e))???; + + if let PeerManagerOutputMessage::PeerRemoved(info_recv) = response { + if info_recv == info { + Ok(()) + } else { + Err(Error::InfoHashMissMatch(info_recv, info)) + } + } else { + Err(Error::from(response)) + } } diff --git a/packages/peer/tests/peer_manager_send_backpressure.rs b/packages/peer/tests/peer_manager_send_backpressure.rs index dea17e95a..8bb0f5d66 100644 --- a/packages/peer/tests/peer_manager_send_backpressure.rs +++ b/packages/peer/tests/peer_manager_send_backpressure.rs @@ -1,107 +1,63 @@ -use common::{connected_channel, ConnectedChannel}; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::{future, AsyncSink, Future}; +use common::connected_channel::{connected_channel, ConnectedChannel}; +use common::{add_peer, remove_peer, tracing_stderr_init, INIT}; +use futures::SinkExt as _; use handshake::Extensions; +use peer::error::PeerManagerError; use peer::messages::PeerWireProtocolMessage; use peer::protocols::NullProtocol; -use peer::{IPeerManagerMessage, OPeerManagerMessage, PeerInfo, PeerManagerBuilder}; -use tokio_core::reactor::Core; +use peer::{PeerInfo, PeerManagerBuilder, PeerManagerInputMessage}; +use tracing::level_filters::LevelFilter; use util::bt; mod common; -#[test] -fn positive_peer_manager_send_backpressure() { - type Peer = ConnectedChannel, PeerWireProtocolMessage>; +type Peer = ConnectedChannel< + Result, std::io::Error>, + Result, std::io::Error>, +>; - let mut core = Core::new().unwrap(); - let manager = PeerManagerBuilder::new().with_peer_capacity(1).build(core.handle()); +#[tokio::test] +async fn positive_peer_manager_send_backpressure() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let (mut send, mut recv) = PeerManagerBuilder::new() + .with_peer_capacity(1) + .build::>() + .into_parts(); // Create two peers let (peer_one, peer_two): (Peer, 
Peer) = connected_channel(5); - let peer_one_info = PeerInfo::new( - "127.0.0.1:0".parse().unwrap(), - [0u8; bt::PEER_ID_LEN].into(), - [0u8; bt::INFO_HASH_LEN].into(), - Extensions::new(), - ); - let peer_two_info = PeerInfo::new( - "127.0.0.1:1".parse().unwrap(), - [1u8; bt::PEER_ID_LEN].into(), - [1u8; bt::INFO_HASH_LEN].into(), - Extensions::new(), - ); + let peer_one_info = create_peer_info("127.0.0.1:0", [0u8; bt::PEER_ID_LEN], [0u8; bt::INFO_HASH_LEN]); + let peer_two_info = create_peer_info("127.0.0.1:1", [1u8; bt::PEER_ID_LEN], [1u8; bt::INFO_HASH_LEN]); // Add peer one to the manager - let manager = core - .run(manager.send(IPeerManagerMessage::AddPeer(peer_one_info, peer_one))) - .unwrap(); + add_peer(&mut send, &mut recv, peer_one_info, peer_one).await.unwrap(); - // Check that peer one was added - let (response, mut manager) = core - .run( - manager - .into_future() - .map(|(opt_item, stream)| (opt_item.unwrap(), stream)) - .map_err(|_| ()), - ) - .unwrap(); - match response { - OPeerManagerMessage::PeerAdded(info) => assert_eq!(peer_one_info, info), - _ => panic!("Unexpected First Peer Manager Response"), + // Try to add peer two, but make sure it was denied (start send returned not ready) + let Err(full) = send.start_send_unpin(Ok(PeerManagerInputMessage::AddPeer(peer_two_info, peer_two.clone()))) else { + panic!("it should not add to full peer store") }; - // Try to add peer two, but make sure it was denied (start send returned not ready) - let (response, manager) = core - .run(future::lazy(|| { - future::ok::<_, ()>(( - manager.start_send(IPeerManagerMessage::AddPeer(peer_two_info, peer_two)), - manager, - )) - })) - .unwrap(); - let peer_two = match response { - Ok(AsyncSink::NotReady(IPeerManagerMessage::AddPeer(info, peer_two))) => { - assert_eq!(peer_two_info, info); - peer_two - } - _ => panic!("Unexpected Second Peer Manager Response"), + let PeerManagerError::PeerStoreFull(capacity) = full else { + panic!("it should be a peer store full error, but got: {full:?}") }; - // Remove peer one from the manager - let manager = core - .run(manager.send(IPeerManagerMessage::RemovePeer(peer_one_info))) - .unwrap(); + assert_eq!(capacity, 1); - // Check that peer one was removed - let (response, manager) = core - .run( - manager - .into_future() - .map(|(opt_item, stream)| (opt_item.unwrap(), stream)) - .map_err(|_| ()), - ) - .unwrap(); - match response { - OPeerManagerMessage::PeerRemoved(info) => assert_eq!(peer_one_info, info), - _ => panic!("Unexpected Third Peer Manager Response"), - }; + // Remove peer one from the manager + remove_peer(&mut send, &mut recv, peer_one_info).await.unwrap(); // Try to add peer two, but make sure it goes through - let manager = core - .run(manager.send(IPeerManagerMessage::AddPeer(peer_two_info, peer_two))) - .unwrap(); - let (response, _manager) = core - .run( - manager - .into_future() - .map(|(opt_item, stream)| (opt_item.unwrap(), stream)) - .map_err(|_| ()), - ) - .unwrap(); - match response { - OPeerManagerMessage::PeerAdded(info) => assert_eq!(peer_two_info, info), - _ => panic!("Unexpected Fourth Peer Manager Response"), - }; + add_peer(&mut send, &mut recv, peer_two_info, peer_two).await.unwrap(); +} + +fn create_peer_info(addr: &str, peer_id: [u8; bt::PEER_ID_LEN], info_hash: [u8; bt::INFO_HASH_LEN]) -> PeerInfo { + PeerInfo::new( + addr.parse().expect("Invalid address"), + peer_id.into(), + info_hash.into(), + Extensions::new(), + ) } diff --git a/packages/select/Cargo.toml b/packages/select/Cargo.toml index 142fd54b0..443f35945 
100644 --- a/packages/select/Cargo.toml +++ b/packages/select/Cargo.toml @@ -22,12 +22,13 @@ peer = { path = "../peer" } util = { path = "../util" } utracker = { path = "../utracker" } -bit-set = "0.5" -bytes = "0.4" -error-chain = "0.12" -futures = "0.1" -log = "0.4" -rand = "0.8" +bit-set = "0" +bytes = "1" +futures = "0" +rand = "0" +thiserror = "1" +tracing = "0" [dev-dependencies] -futures-test = { git = "https://github.com/carllerche/better-future.git" } +tokio = { version = "1", features = ["full"] } +tracing-subscriber = "0" diff --git a/packages/select/src/discovery/error.rs b/packages/select/src/discovery/error.rs index d88355482..275ca5eef 100644 --- a/packages/select/src/discovery/error.rs +++ b/packages/select/src/discovery/error.rs @@ -1,33 +1,16 @@ //! Module for discovery error types. -use error_chain::error_chain; use handshake::InfoHash; use peer::PeerInfo; +use thiserror::Error; -error_chain! { - types { - DiscoveryError, DiscoveryErrorKind, DiscoveryResultExt; - } - - errors { - InvalidMessage { - info: PeerInfo, - message: String - } { - description("Peer Sent An Invalid Message") - display("Peer {:?} Sent An Invalid Message: {:?}", info, message) - } - InvalidMetainfoExists { - hash: InfoHash - } { - description("Metainfo Has Already Been Added") - display("Metainfo With Hash {:?} Has Already Been Added", hash) - } - InvalidMetainfoNotExists { - hash: InfoHash - } { - description("Metainfo Was Not Already Added") - display("Metainfo With Hash {:?} Was Not Already Added", hash) - } - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum DiscoveryError { + #[error("Peer {info:?} Sent An Invalid Message: {message:?}")] + InvalidMessage { info: PeerInfo, message: String }, + #[error("Metainfo With Hash {hash:?} Has Already Been Added")] + InvalidMetainfoExists { hash: InfoHash }, + #[error("Metainfo With Hash {hash:?} Was Not Already Added")] + InvalidMetainfoNotExists { hash: InfoHash }, } diff --git a/packages/select/src/discovery/ut_metadata.rs b/packages/select/src/discovery/ut_metadata.rs index d24e0aa85..ad15d70db 100644 --- a/packages/select/src/discovery/ut_metadata.rs +++ b/packages/select/src/discovery/ut_metadata.rs @@ -1,13 +1,14 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; -use std::io::Write; +use std::io::Write as _; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; use std::time::Duration; use bytes::BytesMut; -use futures::task::Task; -use futures::{task, Async, AsyncSink, Poll, Sink, StartSend, Stream}; +use futures::sink::Sink; +use futures::stream::Stream; use handshake::InfoHash; -use log::info; use metainfo::{Info, Metainfo}; use peer::messages::builders::ExtendedMessageBuilder; use peer::messages::{ @@ -16,14 +17,13 @@ use peer::messages::{ use peer::PeerInfo; use rand::{self, Rng}; -use crate::discovery::error::{DiscoveryError, DiscoveryErrorKind}; +use crate::discovery::error::DiscoveryError; use crate::discovery::{IDiscoveryMessage, ODiscoveryMessage}; use crate::extended::{ExtendedListener, ExtendedPeerInfo}; use crate::ControlMessage; const REQUEST_TIMEOUT_MILLIS: u64 = 2000; const MAX_REQUEST_SIZE: usize = 16 * 1024; - const MAX_ACTIVE_REQUESTS: usize = 100; const MAX_PEER_REQUESTS: usize = 100; @@ -49,16 +49,6 @@ struct ActivePeers { metadata_size: i64, } -/// Module for sending/receiving metadata from other peers. -/// -/// If you are using this module, you should make sure to handshake -/// peers with `Extension::ExtensionProtocol` active. 
Failure to do -/// this will result in this module not sending any messages. -/// -/// Metadata will be retrieved when `IDiscoveryMessage::DownloadMetadata` -/// is received, and will be served when -/// `IDiscoveryMessage::Control(ControlMessage::AddTorrent)` is received. - #[allow(clippy::module_name_repetitions)] #[derive(Default)] pub struct UtMetadataModule { @@ -67,12 +57,11 @@ pub struct UtMetadataModule { active_peers: HashMap, active_requests: Vec, peer_requests: VecDeque, - opt_sink: Option, - opt_stream: Option, + opt_sink_waker: Option, + opt_stream_waker: Option, } impl UtMetadataModule { - /// Create a new `UtMetadataModule`. #[must_use] pub fn new() -> UtMetadataModule { UtMetadataModule { @@ -81,40 +70,34 @@ impl UtMetadataModule { active_peers: HashMap::new(), active_requests: Vec::new(), peer_requests: VecDeque::new(), - opt_sink: None, - opt_stream: None, + opt_sink_waker: None, + opt_stream_waker: None, } } - fn add_torrent(&mut self, metainfo: &Metainfo) -> StartSend> { + fn add_torrent(&mut self, metainfo: &Metainfo) -> Result<(), DiscoveryError> { let info_hash = metainfo.info().info_hash(); - match self.completed_map.entry(info_hash) { - Entry::Occupied(_) => Err(Box::new(DiscoveryError::from_kind( - DiscoveryErrorKind::InvalidMetainfoExists { hash: info_hash }, - ))), + Entry::Occupied(_) => Err(DiscoveryError::InvalidMetainfoExists { hash: info_hash }), Entry::Vacant(vac) => { let info_bytes = metainfo.info().to_bytes(); vac.insert(info_bytes); - - Ok(AsyncSink::Ready) + Ok(()) } } } - fn remove_torrent(&mut self, metainfo: &Metainfo) -> StartSend> { + fn remove_torrent(&mut self, metainfo: &Metainfo) -> Result<(), DiscoveryError> { if self.completed_map.remove(&metainfo.info().info_hash()).is_none() { - Err(Box::new(DiscoveryError::from_kind( - DiscoveryErrorKind::InvalidMetainfoNotExists { - hash: metainfo.info().info_hash(), - }, - ))) + Err(DiscoveryError::InvalidMetainfoNotExists { + hash: metainfo.info().info_hash(), + }) } else { - Ok(AsyncSink::Ready) + Ok(()) } } - fn add_peer(&mut self, info: PeerInfo, ext_info: &ExtendedPeerInfo) -> futures::AsyncSink { + fn add_peer(&mut self, info: PeerInfo, ext_info: &ExtendedPeerInfo) { let our_support = ext_info .our_message() .and_then(|msg| msg.query_id(&ExtendedType::UtMetadata)) @@ -125,14 +108,14 @@ impl UtMetadataModule { .is_some(); let opt_metadata_size = ext_info.their_message().and_then(ExtendedMessage::metadata_size); - info!( + tracing::info!( "Our Support For UtMetadata Is {:?} And {:?} Support For UtMetadata Is {:?} With Metadata Size {:?}", our_support, info.addr(), they_support, opt_metadata_size ); - // If peer supports it, but they don't have the metadata size, then they probably don't have the file yet... 
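The `error_chain` blocks above are replaced by plain `thiserror` enums. A self-contained sketch of the same pattern, with a hypothetical `Hash` newtype standing in for the real `InfoHash`:

```rust
use thiserror::Error;

// Illustrative stand-in for the `InfoHash` used by `DiscoveryError` above.
#[derive(Debug, Clone, Copy)]
struct Hash([u8; 20]);

#[derive(Error, Debug)]
enum ExampleError {
    // `#[error(...)]` generates the Display impl that error_chain's
    // `display(...)` clause used to provide.
    #[error("Metainfo With Hash {hash:?} Has Already Been Added")]
    AlreadyAdded { hash: Hash },
}

fn main() {
    let err = ExampleError::AlreadyAdded { hash: Hash([0u8; 20]) };
    println!("{err}");
}
```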
+ if let (true, true, Some(metadata_size)) = (our_support, they_support, opt_metadata_size) { self.active_peers .entry(*info.hash()) @@ -143,124 +126,78 @@ impl UtMetadataModule { .peers .insert(info); } - - AsyncSink::Ready } - fn remove_peer(&mut self, info: PeerInfo) -> futures::AsyncSink { - let empty_peers = if let Some(active_peers) = self.active_peers.get_mut(info.hash()) { + fn remove_peer(&mut self, info: PeerInfo) { + if let Some(active_peers) = self.active_peers.get_mut(info.hash()) { active_peers.peers.remove(&info); - - active_peers.peers.is_empty() - } else { - false - }; - - if empty_peers { - self.active_peers.remove(info.hash()); + if active_peers.peers.is_empty() { + self.active_peers.remove(info.hash()); + } } - - AsyncSink::Ready } - fn apply_tick(&mut self, duration: Duration) -> futures::AsyncSink { - let active_requests = &mut self.active_requests; - let active_peers = &mut self.active_peers; - let pending_map = &mut self.pending_map; - - // Retain only the requests that aren't expired - active_requests.retain(|request| { + fn apply_tick(&mut self, duration: Duration) { + self.active_requests.retain(|request| { let is_expired = request.left.checked_sub(duration).is_none(); - if is_expired { - // Peer didn't respond to our request, remove from active peers - if let Some(active) = active_peers.get_mut(request.sent_to.hash()) { + if let Some(active) = self.active_peers.get_mut(request.sent_to.hash()) { active.peers.remove(&request.sent_to); } - - // Push request back to pending - pending_map.get_mut(request.sent_to.hash()).map(|opt_pending| { - opt_pending.as_mut().map(|pending| { - pending.messages.push(request.message); - }) - }); + if let Some(Some(pending)) = self.pending_map.get_mut(request.sent_to.hash()) { + pending.messages.push(request.message); + } } - !is_expired }); - // Go back through and subtract from the left over requests, they wont underflow - for active_request in &mut *active_requests { + for active_request in &mut self.active_requests { active_request.left -= duration; } - - AsyncSink::Ready } - fn download_metainfo(&mut self, hash: InfoHash) -> futures::AsyncSink { + fn download_metainfo(&mut self, hash: InfoHash) { self.pending_map.entry(hash).or_insert(None); - - AsyncSink::Ready } - fn recv_request(&mut self, info: PeerInfo, request: UtMetadataRequestMessage) -> futures::AsyncSink { - if self.peer_requests.len() == MAX_PEER_REQUESTS { - AsyncSink::NotReady(IDiscoveryMessage::ReceivedUtMetadataMessage( - info, - UtMetadataMessage::Request(request), - )) - } else { + fn recv_request(&mut self, info: PeerInfo, request: UtMetadataRequestMessage) { + if self.peer_requests.len() < MAX_PEER_REQUESTS { self.peer_requests.push_back(PeerRequest { send_to: info, request }); - - AsyncSink::Ready } } - fn recv_data(&mut self, info: PeerInfo, data: &UtMetadataDataMessage) -> futures::AsyncSink { - // See if we can find the request that we made to the peer for that piece - let opt_index = self + fn recv_data(&mut self, info: PeerInfo, data: &UtMetadataDataMessage) { + if let Some(index) = self .active_requests .iter() - .position(|request| request.sent_to == info && request.message.piece() == data.piece()); - - // If so, go ahead and process it, if not, ignore it (could ban peer...) 
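The rewritten `apply_tick` above first drops expired requests (a `checked_sub` that underflows marks expiry) and only then subtracts the tick from the survivors, so the subtraction itself cannot underflow. The same idea, reduced to a plain `Vec` of remaining time budgets:

```rust
use std::time::Duration;

fn main() {
    let tick = Duration::from_millis(1500);
    let mut budgets = vec![Duration::from_millis(2000), Duration::from_millis(1000)];

    // A request whose remaining budget underflows on this tick has expired.
    budgets.retain(|left| left.checked_sub(tick).is_some());

    // Survivors are guaranteed to have at least `tick` left, so plain `-=` is safe.
    for left in &mut budgets {
        *left -= tick;
    }

    assert_eq!(budgets, vec![Duration::from_millis(500)]);
}
```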
- if let Some(index) = opt_index { + .position(|request| request.sent_to == info && request.message.piece() == data.piece()) + { self.active_requests.swap_remove(index); - - if let Some(&mut Some(ref mut pending)) = self.pending_map.get_mut(info.hash()) { + if let Some(Some(pending)) = self.pending_map.get_mut(info.hash()) { let piece: usize = data.piece().try_into().unwrap(); let data_offset = piece.checked_mul(MAX_REQUEST_SIZE).unwrap(); - pending.left -= 1; (&mut pending.bytes.as_mut_slice()[data_offset..]) .write_all(data.data().as_ref()) .unwrap(); } } - - AsyncSink::Ready } - fn recv_reject(_info: PeerInfo, _reject: UtMetadataRejectMessage) -> futures::AsyncSink { + fn recv_reject(_info: PeerInfo, _reject: UtMetadataRejectMessage) { // TODO: Remove any requests after receiving a reject, for now, we will just timeout - AsyncSink::Ready } - //-------------------------------------------------------------------------------// - - fn retrieve_completed_download(&mut self) -> Option>> { + fn retrieve_completed_download(&mut self) -> Option> { let opt_completed_hash = self .pending_map .iter() - .find(|(_, opt_pending)| opt_pending.as_ref().is_some_and(|pending| pending.left == 0)) + .find(|(_, opt_pending)| opt_pending.as_ref().map_or(false, |pending| pending.left == 0)) .map(|(hash, _)| *hash); opt_completed_hash.and_then(|completed_hash| { let completed = self.pending_map.remove(&completed_hash).unwrap().unwrap(); - - // Clean up other structures since the download is complete self.active_peers.remove(&completed_hash); - match Info::from_bytes(&completed.bytes[..]) { Ok(info) => Some(Ok(ODiscoveryMessage::DownloadedMetainfo(info.into()))), Err(_) => self.retrieve_completed_download(), @@ -268,152 +205,112 @@ impl UtMetadataModule { }) } - fn retrieve_piece_request(&mut self) -> Option>> { + fn retrieve_piece_request(&mut self) -> Option> { for (hash, opt_pending) in &mut self.pending_map { - let has_ready_requests = opt_pending.as_ref().is_some_and(|pending| !pending.messages.is_empty()); - let has_active_peers = self.active_peers.get(hash).is_some_and(|peers| !peers.peers.is_empty()); - - if has_ready_requests && has_active_peers { - let pending = opt_pending.as_mut().unwrap(); - - let mut active_peers = self.active_peers.get(hash).unwrap().peers.iter(); - let num_active_peers = active_peers.len(); - let selected_peer_num = rand::thread_rng().gen::() % num_active_peers; - - let selected_peer = active_peers.nth(selected_peer_num).unwrap(); - let selected_message = pending.messages.pop().unwrap(); - - self.active_requests - .push(generate_active_request(selected_message, *selected_peer)); - - info!( - "Requesting Piece {:?} For Hash {:?}", - selected_message.piece(), - selected_peer.hash() - ); - return Some(Ok(ODiscoveryMessage::SendUtMetadataMessage( - *selected_peer, - UtMetadataMessage::Request(selected_message), - ))); + if let Some(pending) = opt_pending { + if !pending.messages.is_empty() { + if let Some(active_peers) = self.active_peers.get(hash) { + if !active_peers.peers.is_empty() { + let mut active_peers_iter = active_peers.peers.iter(); + let num_active_peers = active_peers_iter.len(); + let selected_peer_num = rand::thread_rng().gen::() % num_active_peers; + let selected_peer = active_peers_iter.nth(selected_peer_num).unwrap(); + let selected_message = pending.messages.pop().unwrap(); + self.active_requests + .push(generate_active_request(selected_message, *selected_peer)); + tracing::info!( + "Requesting Piece {:?} For Hash {:?}", + selected_message.piece(), + 
selected_peer.hash() + ); + return Some(Ok(ODiscoveryMessage::SendUtMetadataMessage( + *selected_peer, + UtMetadataMessage::Request(selected_message), + ))); + } + } + } } } - None } - fn retrieve_piece_response(&mut self) -> Option>> { + fn retrieve_piece_response(&mut self) -> Option> { while let Some(request) = self.peer_requests.pop_front() { let hash = request.send_to.hash(); let piece: usize = request.request.piece().try_into().unwrap(); - let start = piece * MAX_REQUEST_SIZE; let end = start + MAX_REQUEST_SIZE; - if let Some(data) = self.completed_map.get(hash) { if start <= data.len() && end <= data.len() { let info_slice = &data[start..end]; let mut info_payload = BytesMut::with_capacity(info_slice.len()); - info_payload.extend_from_slice(info_slice); let message = UtMetadataDataMessage::new( piece.try_into().unwrap(), info_slice.len().try_into().unwrap(), info_payload.freeze(), ); - return Some(Ok(ODiscoveryMessage::SendUtMetadataMessage( request.send_to, UtMetadataMessage::Data(message), ))); } - // else { - // // Peer asked for a piece outside of the range...don't respond to that - // } } } - None } - //-------------------------------------------------------------------------------// - fn initialize_pending(&mut self) -> bool { let mut pending_tasks_available = false; - - // Initialize PendingInfo once we get peers that have told us the metadata size for (hash, opt_pending) in &mut self.pending_map { if opt_pending.is_none() { - let opt_pending_info = self - .active_peers - .get(hash) - .map(|active_peers| pending_info_from_metadata_size(active_peers.metadata_size)); - - *opt_pending = opt_pending_info; + if let Some(active_peers) = self.active_peers.get(hash) { + *opt_pending = Some(pending_info_from_metadata_size(active_peers.metadata_size)); + } } - - // If pending is there, and the messages array is not empty - pending_tasks_available |= opt_pending.as_ref().is_some_and(|pending| !pending.messages.is_empty()); + pending_tasks_available |= opt_pending.as_ref().map_or(false, |pending| !pending.messages.is_empty()); } - pending_tasks_available } fn validate_downloaded(&mut self) -> bool { let mut completed_downloads_available = false; - - // Sweep over all "pending" requests, and check if completed downloads pass hash validation - // If not, set them back to None so they get re-initialized - // If yes, mark down that we have completed downloads for (&expected_hash, opt_pending) in &mut self.pending_map { - let should_reset = opt_pending.as_mut().is_some_and(|pending| { + if let Some(pending) = opt_pending { if pending.left == 0 { let real_hash = InfoHash::from_bytes(&pending.bytes[..]); - let needs_reset = real_hash != expected_hash; - - // If we don't need a reset, we finished and validation passed! - completed_downloads_available |= !needs_reset; - - // If we need a reset, we finished and validation failed! 
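`retrieve_piece_response` above serves a metadata piece by slicing a `MAX_REQUEST_SIZE` window out of the completed info bytes after a bounds check, and `pending_info_from_metadata_size` further below splits a download into pieces with a ceiling division. A tiny sketch of both calculations on made-up numbers:

```rust
// Sketch only: the piece arithmetic used by the ut_metadata request/response paths.
const MAX_REQUEST_SIZE: usize = 16 * 1024;

fn main() {
    // How many pieces a 40 KiB info dictionary needs (ceiling division).
    let metadata_size: usize = 40 * 1024;
    let num_pieces = if metadata_size % MAX_REQUEST_SIZE != 0 {
        metadata_size / MAX_REQUEST_SIZE + 1
    } else {
        metadata_size / MAX_REQUEST_SIZE
    };
    assert_eq!(num_pieces, 3);

    // Offset of piece 1 inside the metadata, as in the data/response handlers above.
    let piece = 1usize;
    let start = piece * MAX_REQUEST_SIZE;
    let end = start + MAX_REQUEST_SIZE;
    assert!(start < metadata_size && end <= metadata_size);
}
```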
- needs_reset - } else { - false + if real_hash == expected_hash { + completed_downloads_available = true; + } else { + *opt_pending = None; + } } - }); - - if should_reset { - *opt_pending = None; } } - completed_downloads_available } - //-------------------------------------------------------------------------------// - fn check_stream_unblock(&mut self) { - // Will invalidate downloads that don't pass hash check let downloads_available = self.validate_downloaded(); - // Will potentially re-initialize downloads that failed hash check let tasks_available = self.initialize_pending(); - let free_task_queue_space = self.active_requests.len() != MAX_ACTIVE_REQUESTS; let peer_requests_available = !self.peer_requests.is_empty(); - - // Check if stream is currently blocked AND either we can queue more requests OR we can service some requests OR we have complete downloads - let should_unblock = self.opt_stream.is_some() + let should_unblock = self.opt_stream_waker.is_some() && ((free_task_queue_space && tasks_available) || peer_requests_available || downloads_available); - if should_unblock { - self.opt_stream.take().unwrap().notify(); + if let Some(waker) = self.opt_stream_waker.take() { + waker.wake(); + } } } fn check_sink_unblock(&mut self) { - // Check if sink is currently blocked AND max peer requests has not been reached - let should_unblock = self.opt_sink.is_some() && self.peer_requests.len() != MAX_PEER_REQUESTS; - + let should_unblock = self.opt_sink_waker.is_some() && self.peer_requests.len() != MAX_PEER_REQUESTS; if should_unblock { - self.opt_sink.take().unwrap().notify(); + if let Some(waker) = self.opt_sink_waker.take() { + waker.wake(); + } } } } @@ -428,20 +325,16 @@ fn generate_active_request(message: UtMetadataRequestMessage, peer: PeerInfo) -> fn pending_info_from_metadata_size(metadata_size: i64) -> PendingInfo { let cast_metadata_size: usize = metadata_size.try_into().unwrap(); - let bytes = vec![0u8; cast_metadata_size]; let mut messages = Vec::new(); - let num_pieces = if cast_metadata_size % MAX_REQUEST_SIZE != 0 { cast_metadata_size / MAX_REQUEST_SIZE + 1 } else { cast_metadata_size / MAX_REQUEST_SIZE }; - for index in 0..num_pieces { messages.push(UtMetadataRequestMessage::new(index.try_into().unwrap())); } - PendingInfo { messages, left: num_pieces, @@ -449,8 +342,6 @@ fn pending_info_from_metadata_size(metadata_size: i64) -> PendingInfo { } } -//-------------------------------------------------------------------------------// - impl ExtendedListener for UtMetadataModule { fn extend(&self, _info: &PeerInfo, builder: ExtendedMessageBuilder) -> ExtendedMessageBuilder { builder.with_extended_type(ExtendedType::UtMetadata, Some(5)) @@ -458,75 +349,79 @@ impl ExtendedListener for UtMetadataModule { fn on_update(&mut self, info: &PeerInfo, extended: &ExtendedPeerInfo) { self.add_peer(*info, extended); - - // Check if we need to unblock the stream after performing our work self.check_stream_unblock(); } } -//-------------------------------------------------------------------------------// +impl Sink for UtMetadataModule { + type Error = DiscoveryError; -impl Sink for UtMetadataModule { - type SinkItem = IDiscoveryMessage; - type SinkError = Box; + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.peer_requests.len() < MAX_PEER_REQUESTS { + Poll::Ready(Ok(())) + } else { + self.opt_sink_waker = Some(cx.waker().clone()); + Poll::Pending + } + } - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - let start_send = match item { + fn 
start_send(mut self: Pin<&mut Self>, item: IDiscoveryMessage) -> Result<(), Self::Error> { + match item { IDiscoveryMessage::Control(ControlMessage::AddTorrent(metainfo)) => self.add_torrent(&metainfo), IDiscoveryMessage::Control(ControlMessage::RemoveTorrent(metainfo)) => self.remove_torrent(&metainfo), - // don't add the peer yet, use listener to get notified when they send extension messages - IDiscoveryMessage::Control(ControlMessage::PeerConnected(_)) => Ok(AsyncSink::Ready), - IDiscoveryMessage::Control(ControlMessage::PeerDisconnected(info)) => StartSend::Ok(self.remove_peer(info)), - IDiscoveryMessage::Control(ControlMessage::Tick(duration)) => StartSend::Ok(self.apply_tick(duration)), - IDiscoveryMessage::DownloadMetainfo(hash) => StartSend::Ok(self.download_metainfo(hash)), + IDiscoveryMessage::Control(ControlMessage::PeerConnected(_)) => Ok(()), + IDiscoveryMessage::Control(ControlMessage::PeerDisconnected(info)) => { + self.remove_peer(info); + Ok(()) + } + IDiscoveryMessage::Control(ControlMessage::Tick(duration)) => { + self.apply_tick(duration); + Ok(()) + } + IDiscoveryMessage::DownloadMetainfo(hash) => { + self.download_metainfo(hash); + Ok(()) + } IDiscoveryMessage::ReceivedUtMetadataMessage(info, UtMetadataMessage::Request(msg)) => { - StartSend::Ok(self.recv_request(info, msg)) + self.recv_request(info, msg); + Ok(()) } IDiscoveryMessage::ReceivedUtMetadataMessage(info, UtMetadataMessage::Data(msg)) => { - StartSend::Ok(self.recv_data(info, &msg)) + self.recv_data(info, &msg); + Ok(()) } IDiscoveryMessage::ReceivedUtMetadataMessage(info, UtMetadataMessage::Reject(msg)) => { - StartSend::Ok(UtMetadataModule::recv_reject(info, msg)) + UtMetadataModule::recv_reject(info, msg); + Ok(()) } - }; - - // Check if we need to unblock the stream after performing our work - self.check_stream_unblock(); - - // Check if we need to block the sink, if so, set the task - if start_send.as_ref().map(futures::AsyncSink::is_not_ready).unwrap_or(false) { - self.opt_sink = Some(task::current()); } - start_send } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(Async::Ready(())) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } impl Stream for UtMetadataModule { - type Item = ODiscoveryMessage; - type Error = Box; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - // Check if we completed any downloads - // Or if we can send any requests - // Or if we can send any responses + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let opt_result = self .retrieve_completed_download() .or_else(|| self.retrieve_piece_request()) .or_else(|| self.retrieve_piece_response()); - // Check if we can unblock the sink after performing our work self.check_sink_unblock(); - // Check if we need to block the stream, if so, set the task if let Some(result) = opt_result { - result.map(|value| Async::Ready(Some(value))) + Poll::Ready(Some(result)) } else { - self.opt_stream = Some(task::current()); - Ok(Async::NotReady) + self.opt_stream_waker = Some(cx.waker().clone()); + Poll::Pending } } } diff --git a/packages/select/src/error.rs b/packages/select/src/error.rs index 8b318404b..4712f4093 100644 --- a/packages/select/src/error.rs +++ b/packages/select/src/error.rs @@ -1,15 +1,11 @@ //! Module for uber error types. 
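The new `Sink`/`Stream` impls above replace the old `task::current()`/`notify()` calls with stored `Waker`s: `poll_ready`/`poll_next` park the task by cloning `cx.waker()`, and the `check_*_unblock` helpers wake it once work is available. A minimal, self-contained stream using the same pattern (all names here are illustrative, not from the crate):

```rust
// Sketch only: the waker pattern the impls above rely on, reduced to a
// queue-backed Stream.
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};

use futures::stream::{Stream, StreamExt as _};

struct QueueStream {
    queue: VecDeque<u32>,
    waker: Option<Waker>,
}

impl QueueStream {
    // Producers call this; it wakes the task parked in `poll_next`.
    fn push(&mut self, item: u32) {
        self.queue.push_back(item);
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}

impl Stream for QueueStream {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(item) = self.queue.pop_front() {
            Poll::Ready(Some(item))
        } else {
            // Park: remember the current task so `push` can wake it later.
            self.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let mut stream = QueueStream { queue: VecDeque::new(), waker: None };
    stream.push(1);
    assert_eq!(stream.next().await, Some(1));
}
```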
-use error_chain::error_chain; +use thiserror::Error; -use crate::discovery::error::{DiscoveryError, DiscoveryErrorKind}; +use crate::discovery::error::DiscoveryError; -error_chain! { - types { - UberError, UberErrorKind, UberResultExt; - } - - links { - Discovery(DiscoveryError, DiscoveryErrorKind); - } +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + Discovery(#[from] DiscoveryError), } diff --git a/packages/select/src/extended/mod.rs b/packages/select/src/extended/mod.rs index 9786d6335..53dabc1b1 100644 --- a/packages/select/src/extended/mod.rs +++ b/packages/select/src/extended/mod.rs @@ -1,12 +1,14 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; -use futures::task::Task; -use futures::{task, Async, Poll, Stream}; +use futures::stream::Stream; use peer::messages::builders::ExtendedMessageBuilder; use peer::messages::ExtendedMessage; use peer::PeerInfo; -use crate::error::UberError; +use crate::error::Error; use crate::ControlMessage; /// Enumeration of extended messages that can be sent to the extended module. @@ -68,24 +70,25 @@ impl ExtendedPeerInfo { //------------------------------------------------------------------------------// #[allow(clippy::module_name_repetitions)] +#[derive(Clone)] pub struct ExtendedModule { builder: ExtendedMessageBuilder, - peers: HashMap, - out_queue: VecDeque, - opt_task: Option, + peers: Arc>>, + out_queue: Arc>>, + opt_waker: Arc>>, } impl ExtendedModule { pub fn new(builder: ExtendedMessageBuilder) -> ExtendedModule { ExtendedModule { builder, - peers: HashMap::new(), - out_queue: VecDeque::new(), - opt_task: None, + peers: Arc::default(), + out_queue: Arc::default(), + opt_waker: Arc::default(), } } - pub fn process_message(&mut self, message: IExtendedMessage, d_modules: &mut [Box]) + pub fn process_message(&self, message: IExtendedMessage, d_modules: &mut [Arc]) where D: ExtendedListener + ?Sized, { @@ -102,22 +105,25 @@ impl ExtendedModule { let ext_peer_info = ExtendedPeerInfo::new(Some(ext_message.clone()), None); for d_module in d_modules { - d_module.on_update(&info, &ext_peer_info); + Arc::get_mut(d_module).unwrap().on_update(&info, &ext_peer_info); } - self.peers.insert(info, ext_peer_info); + self.peers.lock().unwrap().insert(info, ext_peer_info); self.out_queue + .lock() + .unwrap() .push_back(OExtendedMessage::SendExtendedMessage(info, ext_message)); } IExtendedMessage::Control(ControlMessage::PeerDisconnected(info)) => { - self.peers.remove(&info); + self.peers.lock().unwrap().remove(&info); } IExtendedMessage::ReceivedExtendedMessage(info, ext_message) => { - let ext_peer_info = self.peers.get_mut(&info).unwrap(); - ext_peer_info.update_theirs(ext_message); + if let Some(ext_peer_info) = self.peers.lock().unwrap().get_mut(&info) { + ext_peer_info.update_theirs(ext_message); - for d_module in d_modules { - d_module.on_update(&info, ext_peer_info); + for d_module in d_modules { + Arc::get_mut(d_module).unwrap().on_update(&info, ext_peer_info); + } } } IExtendedMessage::Control(_) => (), @@ -126,28 +132,24 @@ impl ExtendedModule { self.check_stream_unblock(); } - fn check_stream_unblock(&mut self) { - if !self.out_queue.is_empty() { - if let Some(task) = self.opt_task.take() { - task.notify(); + fn check_stream_unblock(&self) { + if !self.out_queue.lock().unwrap().is_empty() { + if let Some(waker) = self.opt_waker.lock().unwrap().take() { + waker.wake(); } } } } impl Stream for ExtendedModule { - type Item = OExtendedMessage; - 
type Error = Box; - - fn poll(&mut self) -> Poll, Self::Error> { - let opt_message = self.out_queue.pop_front(); + type Item = Result; - if let Some(message) = opt_message { - Ok(Async::Ready(Some(message))) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(message) = self.out_queue.lock().unwrap().pop_front() { + Poll::Ready(Some(Ok(message))) } else { - self.opt_task = Some(task::current()); - - Ok(Async::NotReady) + self.opt_waker.lock().unwrap().replace(cx.waker().clone()); + Poll::Pending } } } diff --git a/packages/select/src/revelation/error.rs b/packages/select/src/revelation/error.rs index bb6216d3b..bc3d9061a 100644 --- a/packages/select/src/revelation/error.rs +++ b/packages/select/src/revelation/error.rs @@ -1,40 +1,18 @@ //! Module for revelation error types. -use error_chain::error_chain; use handshake::InfoHash; use peer::PeerInfo; +use thiserror::Error; -error_chain! { - types { - RevealError, RevealErrorKind, RevealResultExt; - } - - errors { - InvalidMessage { - info: PeerInfo, - message: String - } { - description("Peer Sent An Invalid Message") - display("Peer {:?} Sent An Invalid Message: {:?}", info, message) - } - InvalidMetainfoExists { - hash: InfoHash - } { - description("Metainfo Has Already Been Added") - display("Metainfo With Hash {:?} Has Already Been Added", hash) - } - InvalidMetainfoNotExists { - hash: InfoHash - } { - description("Metainfo Was Not Already Added") - display("Metainfo With Hash {:?} Was Not Already Added", hash) - } - InvalidPieceOutOfRange { - hash: InfoHash, - index: u64 - } { - description("Piece Index Was Out Of Range") - display("Piece Index {:?} Was Out Of Range For Hash {:?}", index, hash) - } - } +#[allow(clippy::module_name_repetitions)] +#[derive(Error, Debug)] +pub enum RevealError { + #[error("Peer {info:?} Sent An Invalid Message: {message:?}")] + InvalidMessage { info: PeerInfo, message: String }, + #[error("Metainfo With Hash {hash:?} Has Already Been Added")] + InvalidMetainfoExists { hash: InfoHash }, + #[error("Metainfo With Hash {hash:?} Was Not Already Added")] + InvalidMetainfoNotExists { hash: InfoHash }, + #[error("Piece Index {index:?} Was Out Of Range For Hash {hash:?}")] + InvalidPieceOutOfRange { hash: InfoHash, index: u64 }, } diff --git a/packages/select/src/revelation/honest.rs b/packages/select/src/revelation/honest.rs index f1041d6f0..e359108f3 100644 --- a/packages/select/src/revelation/honest.rs +++ b/packages/select/src/revelation/honest.rs @@ -1,29 +1,43 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; use bit_set::BitSet; use bytes::{BufMut, BytesMut}; -use futures::task::Task; -use futures::{task, Async, AsyncSink, Poll, Sink, StartSend, Stream}; +use futures::{Sink, Stream}; use handshake::InfoHash; use metainfo::Metainfo; use peer::messages::{BitFieldMessage, HaveMessage}; use peer::PeerInfo; +use tracing::instrument; -use crate::revelation::error::{RevealError, RevealErrorKind}; +use crate::revelation::error::RevealError; use crate::revelation::{IRevealMessage, ORevealMessage}; use crate::ControlMessage; -/// Revelation module that will honestly report any pieces we have to peers. 
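`ExtendedModule` now keeps its peers, output queue, and waker behind `Arc<Mutex<…>>` so clones of the module (held by the uber sink and stream halves) observe the same state. A toy sketch of that sharing, assuming nothing about the real types:

```rust
// Sketch only: two cloned handles over one shared queue.
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

#[derive(Clone, Default)]
struct SharedQueue {
    out_queue: Arc<Mutex<VecDeque<&'static str>>>,
}

fn main() {
    let producer = SharedQueue::default();
    let consumer = producer.clone(); // both handles point at the same queue

    producer.out_queue.lock().unwrap().push_back("hello");
    assert_eq!(consumer.out_queue.lock().unwrap().pop_front(), Some("hello"));
}
```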
- #[allow(clippy::module_name_repetitions)] #[derive(Default)] -pub struct HonestRevealModule { +pub struct HonestRevealModuleBuilder { torrents: HashMap, out_queue: VecDeque, - // Shared bytes container to write bitfield messages to out_bytes: BytesMut, - opt_stream: Option, +} + +impl HonestRevealModuleBuilder { + #[must_use] + pub fn new() -> HonestRevealModuleBuilder { + HonestRevealModuleBuilder { + torrents: HashMap::new(), + out_queue: VecDeque::new(), + out_bytes: BytesMut::new(), + } + } + + #[must_use] + pub fn build(self) -> HonestRevealModule { + HonestRevealModule::from_builder(self) + } } struct PeersInfo { @@ -32,25 +46,49 @@ struct PeersInfo { peers: HashSet, } +#[allow(clippy::module_name_repetitions)] +pub struct HonestRevealModule { + torrents: HashMap, + out_queue: VecDeque, + out_bytes: BytesMut, + opt_stream_waker: Option, +} + impl HonestRevealModule { - /// Create a new `HonestRevelationModule`. #[must_use] - pub fn new() -> HonestRevealModule { + pub fn from_builder(builder: HonestRevealModuleBuilder) -> HonestRevealModule { HonestRevealModule { - torrents: HashMap::new(), - out_queue: VecDeque::new(), - out_bytes: BytesMut::new(), - opt_stream: None, + torrents: builder.torrents, + out_queue: builder.out_queue, + out_bytes: builder.out_bytes, + opt_stream_waker: None, } } - fn add_torrent(&mut self, metainfo: &Metainfo) -> StartSend> { + fn handle_message(&mut self, message: IRevealMessage) -> Result<(), RevealError> { + match message { + IRevealMessage::Control(ControlMessage::AddTorrent(metainfo)) => self.add_torrent(&metainfo), + IRevealMessage::Control(ControlMessage::RemoveTorrent(metainfo)) => self.remove_torrent(&metainfo), + IRevealMessage::Control(ControlMessage::PeerConnected(info)) => self.add_peer(info), + IRevealMessage::Control(ControlMessage::PeerDisconnected(info)) => self.remove_peer(info), + IRevealMessage::FoundGoodPiece(hash, index) => self.insert_piece(hash, index), + IRevealMessage::Control(ControlMessage::Tick(_)) + | IRevealMessage::ReceivedBitField(_, _) + | IRevealMessage::ReceivedHave(_, _) => Ok(()), + } + } + + #[instrument(skip(self))] + fn add_torrent(&mut self, metainfo: &Metainfo) -> Result<(), RevealError> { + tracing::trace!("adding torrent"); + let info_hash = metainfo.info().info_hash(); match self.torrents.entry(info_hash) { - Entry::Occupied(_) => Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidMetainfoExists { - hash: info_hash, - }))), + Entry::Occupied(_) => { + tracing::error!("invalid metainfo"); + Err(RevealError::InvalidMetainfoExists { hash: info_hash }) + } Entry::Vacant(vac) => { let num_pieces = metainfo.info().pieces().count(); @@ -64,161 +102,159 @@ impl HonestRevealModule { }; vac.insert(peers_info); - Ok(AsyncSink::Ready) + Ok(()) } } } - fn remove_torrent(&mut self, metainfo: &Metainfo) -> StartSend> { + #[instrument(skip(self))] + fn remove_torrent(&mut self, metainfo: &Metainfo) -> Result<(), RevealError> { + tracing::trace!("removing torrent"); + let info_hash = metainfo.info().info_hash(); if self.torrents.remove(&info_hash).is_none() { - Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidMetainfoNotExists { - hash: info_hash, - }))) + Err(RevealError::InvalidMetainfoNotExists { hash: info_hash }) } else { - Ok(AsyncSink::Ready) + Ok(()) } } - fn add_peer(&mut self, peer: PeerInfo) -> StartSend> { + #[instrument(skip(self))] + fn add_peer(&mut self, peer: PeerInfo) -> Result<(), RevealError> { + tracing::trace!("adding peer"); let info_hash = *peer.hash(); let out_bytes = &mut 
self.out_bytes; let out_queue = &mut self.out_queue; let Some(peers_info) = self.torrents.get_mut(&info_hash) else { - return Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidMetainfoNotExists { - hash: info_hash, - }))); + tracing::error!("adding peer error"); + return Err(RevealError::InvalidMetainfoNotExists { hash: info_hash }); }; - // Add the peer to our list, so we send have messages to them peers_info.peers.insert(peer); - // If our bitfield has any pieces in it, send the bitfield, otherwise, don't send it if !peers_info.status.is_empty() { - // Get our current bitfield, write it to our shared bytes let bitfield_slice = peers_info.status.get_ref().storage(); - // Bitfield stores index 0 at bit 7 from the left, we want index 0 to be at bit 0 from the left insert_reversed_bits(out_bytes, bitfield_slice); - // Split off what we wrote, send this in the message, will be re-used on drop let bitfield_bytes = out_bytes.split_off(0).freeze(); let bitfield = BitFieldMessage::new(bitfield_bytes); - // Enqueue the bitfield message so that we send it to the peer - out_queue.push_back(ORevealMessage::SendBitField(peer, bitfield)); + let message = ORevealMessage::SendBitField(peer, bitfield); + tracing::trace!("sending message: {message:?}"); + + out_queue.push_back(message); + if let Some(waker) = self.opt_stream_waker.take() { + waker.wake(); + } } - Ok(AsyncSink::Ready) + Ok(()) } - fn remove_peer(&mut self, peer: PeerInfo) -> StartSend> { + #[instrument(skip(self))] + fn remove_peer(&mut self, peer: PeerInfo) -> Result<(), RevealError> { let info_hash = *peer.hash(); let Some(peers_info) = self.torrents.get_mut(&info_hash) else { - return Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidMetainfoNotExists { - hash: info_hash, - }))); + return Err(RevealError::InvalidMetainfoNotExists { hash: info_hash }); }; peers_info.peers.remove(&peer); - Ok(AsyncSink::Ready) + Ok(()) } - fn insert_piece(&mut self, hash: InfoHash, index: u64) -> StartSend> { + #[instrument(skip(self))] + fn insert_piece(&mut self, hash: InfoHash, index: u64) -> Result<(), RevealError> { + tracing::trace!("inserting piece"); + let out_queue = &mut self.out_queue; let Some(peers_info) = self.torrents.get_mut(&hash) else { - return Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidMetainfoNotExists { - hash, - }))); + return Err(RevealError::InvalidMetainfoNotExists { hash }); }; let index: usize = index.try_into().unwrap(); if index >= peers_info.num_pieces { - Err(Box::new(RevealError::from_kind(RevealErrorKind::InvalidPieceOutOfRange { + Err(RevealError::InvalidPieceOutOfRange { index: index.try_into().unwrap(), hash, - }))) + }) } else { - // Queue up all have messages for peer in &peers_info.peers { - out_queue.push_back(ORevealMessage::SendHave(*peer, HaveMessage::new(index.try_into().unwrap()))); + let message = ORevealMessage::SendHave(*peer, HaveMessage::new(index.try_into().unwrap())); + tracing::trace!("sending message: {message:?}"); + + out_queue.push_back(message); + + if let Some(waker) = self.opt_stream_waker.take() { + waker.wake(); + } } - // Insert into bitfield peers_info.status.insert(index); - Ok(AsyncSink::Ready) + Ok(()) } } - //------------------------------------------------------// + #[instrument(skip(self))] + fn poll_next_message(&mut self, cx: &mut Context<'_>) -> Poll>> { + tracing::trace!("polling for next message"); - fn check_stream_unblock(&mut self) { - if !self.out_queue.is_empty() { - self.opt_stream.take().as_ref().map(Task::notify); - } - } -} + if let 
Some(message) = self.out_queue.pop_front() { + tracing::trace!("sending message {message:?}"); -/// Inserts the slice into the `BytesMut` but reverses the bits in each byte. -fn insert_reversed_bits(bytes: &mut BytesMut, slice: &[u8]) { - for mut byte in slice.iter().copied() { - let mut reversed_byte = 0; + Poll::Ready(Some(Ok(message))) + } else { + tracing::trace!("no messages found... pending"); - for _ in 0..8 { - // Make room for the bit - reversed_byte <<= 1; - // Push the last bit over - reversed_byte |= byte & 0x01; - // Push the last bit off - byte >>= 1; + self.opt_stream_waker = Some(cx.waker().clone()); + Poll::Pending } - - bytes.put_u8(reversed_byte); } } -impl Sink for HonestRevealModule { - type SinkItem = IRevealMessage; - type SinkError = Box; +impl Sink for HonestRevealModule { + type Error = RevealError; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - let result = match item { - IRevealMessage::Control(ControlMessage::AddTorrent(metainfo)) => self.add_torrent(&metainfo), - IRevealMessage::Control(ControlMessage::RemoveTorrent(metainfo)) => self.remove_torrent(&metainfo), - IRevealMessage::Control(ControlMessage::PeerConnected(info)) => self.add_peer(info), - IRevealMessage::Control(ControlMessage::PeerDisconnected(info)) => self.remove_peer(info), - IRevealMessage::FoundGoodPiece(hash, index) => self.insert_piece(hash, index), - IRevealMessage::Control(ControlMessage::Tick(_)) - | IRevealMessage::ReceivedBitField(_, _) - | IRevealMessage::ReceivedHave(_, _) => Ok(AsyncSink::Ready), - }; + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } - self.check_stream_unblock(); + fn start_send(mut self: Pin<&mut Self>, item: IRevealMessage) -> Result<(), Self::Error> { + self.handle_message(item) + } - result + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(Async::Ready(())) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) } } impl Stream for HonestRevealModule { - type Item = ORevealMessage; - type Error = Box; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - let next_item = self.out_queue.pop_front().map(|item| Ok(Async::Ready(Some(item)))); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_message(cx) + } +} - next_item.unwrap_or_else(|| { - self.opt_stream = Some(task::current()); +fn insert_reversed_bits(bytes: &mut BytesMut, slice: &[u8]) { + for mut byte in slice.iter().copied() { + let mut reversed_byte = 0; - Ok(Async::NotReady) - }) + for _ in 0..8 { + reversed_byte <<= 1; + reversed_byte |= byte & 0x01; + byte >>= 1; + } + + bytes.put_u8(reversed_byte); } } diff --git a/packages/select/src/revelation/mod.rs b/packages/select/src/revelation/mod.rs index 8796c50c2..a53e218c7 100644 --- a/packages/select/src/revelation/mod.rs +++ b/packages/select/src/revelation/mod.rs @@ -10,9 +10,10 @@ pub mod error; mod honest; -pub use self::honest::HonestRevealModule; +pub use self::honest::{HonestRevealModule, HonestRevealModuleBuilder}; /// Enumeration of revelation messages that can be sent to a revelation module. +#[derive(Debug)] pub enum IRevealMessage { /// Control message. Control(ControlMessage), @@ -25,6 +26,7 @@ pub enum IRevealMessage { } /// Enumeration of revelation messages that can be received from a revelation module. +#[derive(Debug)] pub enum ORevealMessage { /// Send a `BitFieldMessage`. 
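`insert_reversed_bits` above converts the `BitSet` storage (piece 0 in the low bit of a byte) into BitTorrent wire order (piece 0 in the high bit) by reversing each byte. A standalone restatement of that loop, checked against the values the bitfield tests further below assert; `reverse_bits` is just an illustrative free function:

```rust
// Sketch only: per-byte bit reversal as performed by `insert_reversed_bits`.
fn reverse_bits(mut byte: u8) -> u8 {
    let mut reversed = 0u8;
    for _ in 0..8 {
        reversed <<= 1;          // make room for the next bit
        reversed |= byte & 0x01; // copy the lowest bit over
        byte >>= 1;              // drop the bit we just copied
    }
    reversed
}

fn main() {
    // Piece 0 set => storage byte 0b0000_0001 => wire byte 0b1000_0000.
    assert_eq!(reverse_bits(0b0000_0001), 0x80);
    // Pieces 8 and 15 set => storage byte 0b1000_0001 => wire byte 0x81.
    assert_eq!(reverse_bits(0b1000_0001), 0x81);
}
```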
SendBitField(PeerInfo, BitFieldMessage), diff --git a/packages/select/src/uber/mod.rs b/packages/select/src/uber/mod.rs index fd99496df..617069028 100644 --- a/packages/select/src/uber/mod.rs +++ b/packages/select/src/uber/mod.rs @@ -1,22 +1,34 @@ -use futures::{Async, AsyncSink, Poll, Sink, StartSend, Stream}; +use std::sync::{Arc, Mutex}; + +use futures::{Sink, Stream}; use peer::messages::builders::ExtendedMessageBuilder; +use sink::UberSink; +use stream::UberStream; use crate::discovery::error::DiscoveryError; use crate::discovery::{IDiscoveryMessage, ODiscoveryMessage}; -use crate::error::UberError; use crate::extended::{ExtendedListener, ExtendedModule, IExtendedMessage, OExtendedMessage}; use crate::ControlMessage; +pub mod sink; +pub mod stream; + pub trait DiscoveryTrait: ExtendedListener - + Sink> - + Stream> + + Sink + + Stream> + + Send + + Unpin + + 'static { } impl DiscoveryTrait for T where T: ExtendedListener - + Sink> - + Stream> + + Sink + + Stream> + + Send + + Unpin + + 'static { } @@ -24,11 +36,11 @@ impl DiscoveryTrait for T where #[derive(Clone, Debug, PartialEq, Eq)] pub enum IUberMessage { /// Broadcast a control message out to all modules. - Control(ControlMessage), + Control(Box), /// Send an extended message to the extended module. - Extended(IExtendedMessage), + Extended(Box), /// Send a discovery message to all discovery modules. - Discovery(IDiscoveryMessage), + Discovery(Box), } /// Enumeration of uber messages that can be received from the uber module. @@ -40,23 +52,14 @@ pub enum OUberMessage { Discovery(ODiscoveryMessage), } -type UberDiscovery = Vec< - Box< - dyn DiscoveryTrait< - SinkItem = IDiscoveryMessage, - SinkError = Box, - Item = ODiscoveryMessage, - Error = Box, - >, - >, +type UberDiscovery = Arc< + Mutex> + Send + Sync>>>, >; /// Builder for constructing an `UberModule`. - #[allow(clippy::module_name_repetitions)] #[derive(Default)] pub struct UberModuleBuilder { - // TODO: Remove these bounds when something like https://github.com/rust-lang/rust/pull/45047 lands pub discovery: UberDiscovery, ext_builder: Option, } @@ -66,7 +69,7 @@ impl UberModuleBuilder { #[must_use] pub fn new() -> UberModuleBuilder { UberModuleBuilder { - discovery: Vec::new(), + discovery: Arc::default(), ext_builder: None, } } @@ -82,23 +85,22 @@ impl UberModuleBuilder { } /// Add the given discovery module to the list of discovery modules. + /// + /// # Panics + /// + /// It would panic if unable to get a lock for the discovery. 
#[must_use] - pub fn with_discovery_module(mut self, module: T) -> UberModuleBuilder + pub fn with_discovery_module(self, module: T) -> UberModuleBuilder where T: ExtendedListener - + Sink> - + Stream> + + Sink + + Stream> + + Send + + Sync + + Unpin + 'static, { - self.discovery.push(Box::new(module) - as Box< - dyn DiscoveryTrait< - SinkItem = IDiscoveryMessage, - SinkError = Box, - Item = ODiscoveryMessage, - Error = Box, - >, - >); + self.discovery.lock().unwrap().push(Arc::new(module)); self } @@ -109,242 +111,32 @@ impl UberModuleBuilder { } } -//----------------------------------------------------------------------// - -// TODO: Try to get generic is_ready trait into futures-rs -trait IsReady { - fn is_ready(&self) -> bool; -} - -impl IsReady for AsyncSink { - fn is_ready(&self) -> bool { - self.is_ready() - } -} - -impl IsReady for Async { - fn is_ready(&self) -> bool { - self.is_ready() - } -} - //----------------------------------------------------------------------// /// Module for multiplexing messages across zero or more other modules. #[allow(clippy::module_name_repetitions)] pub struct UberModule { - discovery: UberDiscovery, - extended: Option, - last_sink_state: Option, - last_stream_state: Option, -} - -#[derive(Debug, Copy, Clone)] -enum ModuleState { - Extended, - Discovery(usize), + sink: UberSink, + stream: UberStream, } impl UberModule { /// Create an `UberModule` from the given `UberModuleBuilder`. pub fn from_builder(builder: UberModuleBuilder) -> UberModule { - UberModule { - discovery: builder.discovery, - extended: builder.ext_builder.map(ExtendedModule::new), - last_sink_state: None, - last_stream_state: None, - } - } - - /// Get the next state after the given state, return `Some(next_state`) or None if the given state was the last state. - /// - /// We return the next state regardless of the message we are processing at the time. So if we don't recognize the tuple of - /// next state and message, we ignore it. This makes the implementation a lot easier as we don't have to do an exhaustive match - /// on all possible states and messages, as only a subset will be valid. - fn next_state(&self, state: Option) -> Option { - match state { - None => { - if self.extended.is_some() { - Some(ModuleState::Extended) - } else if !self.discovery.is_empty() { - Some(ModuleState::Discovery(0)) - } else { - None - } - } - Some(ModuleState::Extended) => { - if self.discovery.is_empty() { - None - } else { - Some(ModuleState::Discovery(0)) - } - } - Some(ModuleState::Discovery(index)) => { - if index + 1 < self.discovery.len() { - Some(ModuleState::Discovery(index + 1)) - } else { - None - } - } - } - } - - /// Loop over all states until we finish, or hit an error. - /// - /// Takes care of saving/resetting states if we hit an error/finish. 
- fn loop_states( - &mut self, - is_sink: bool, - init: Result, - get_next_state: G, - assign_state: A, - logic: L, - ) -> Result - where - G: Fn(&UberModule) -> Option, - A: Fn(&mut UberModule, Option), - L: Fn(&mut UberModule, ModuleState) -> Result, - R: IsReady, - { - let is_stream = !is_sink; - let mut result = init; - let mut opt_next_state = get_next_state(self); - - // Sink yields on: - // - NotReady - // - Error - // While stream yields on: - // - Ready - // - Error - - // TODO: Kind of need to make a full transition where the state doesn't change for this logic to work - // (cant start back at the middle when we get woken up, since we don't know what woke us up) - - let mut should_continue = result - .as_ref() - .map(|a| (is_sink && a.is_ready()) || (is_stream && !a.is_ready())) - .unwrap_or(false); - // While we are ready, and we haven't exhausted states, continue to loop - while should_continue && opt_next_state.is_some() { - let next_state = opt_next_state.unwrap(); - result = logic(self, next_state); - should_continue = result - .as_ref() - .map(|a| (is_sink && a.is_ready()) || (is_stream && !a.is_ready())) - .unwrap_or(false); + let discovery = builder.discovery; + let extended = builder.ext_builder.map(ExtendedModule::new); - // If we don't need to return to the user because of this error, mark it as done - if should_continue { - assign_state(self, opt_next_state); - } - opt_next_state = get_next_state(self); - } - - // If there was no next state, AND we would have continued regardless, set back to None - if opt_next_state.is_none() && should_continue { - assign_state(self, None); - } - - result - } - - /// Run the `start_send` logic for the current module for the given message. - fn start_sink_state(&mut self, message: &IUberMessage) -> StartSend<(), Box> { - self.loop_states( - true, - Ok(AsyncSink::Ready), - |uber| uber.next_state(uber.last_sink_state), - |uber, state| { - uber.last_sink_state = state; - }, - |uber, state| match (state, message) { - (ModuleState::Discovery(index), IUberMessage::Control(control)) => uber.discovery[index] - .start_send(IDiscoveryMessage::Control(control.clone())) - .map(|a| a.map(|_| ())) - .map_err(|e| Box::new(Into::::into(*e))), - (ModuleState::Discovery(index), IUberMessage::Discovery(discovery)) => uber.discovery[index] - .start_send(discovery.clone()) - .map(|a| a.map(|_| ())) - .map_err(|e| Box::new(Into::::into(*e))), - (ModuleState::Extended, IUberMessage::Control(control)) => { - let d_modules = &mut uber.discovery[..]; - - uber.extended.as_mut().map_or(Ok(AsyncSink::Ready), |ext_module| { - ext_module.process_message(IExtendedMessage::Control(control.clone()), d_modules); - - Ok(AsyncSink::Ready) - }) - } - (ModuleState::Extended, IUberMessage::Extended(extended)) => { - let d_modules = &mut uber.discovery[..]; - - uber.extended.as_mut().map_or(Ok(AsyncSink::Ready), |ext_module| { - ext_module.process_message(extended.clone(), d_modules); - - Ok(AsyncSink::Ready) - }) - } - _ => Ok(AsyncSink::Ready), - }, - ) - } - - fn poll_sink_state(&mut self) -> Poll<(), Box> { - self.loop_states( - true, - Ok(Async::Ready(())), - |uber| uber.next_state(uber.last_sink_state), - |uber, state| { - uber.last_sink_state = state; - }, - |uber, state| match state { - ModuleState::Discovery(index) => uber.discovery[index].poll_complete().map_err(|e| Box::new(Into::into(*e))), - ModuleState::Extended => Ok(Async::Ready(())), - }, - ) - } - - fn poll_stream_state(&mut self) -> Poll, Box> { - self.loop_states( - false, - Ok(Async::NotReady), - |uber| 
uber.next_state(uber.last_stream_state), - |uber, state| { - uber.last_stream_state = state; - }, - |uber, state| match state { - ModuleState::Extended => uber.extended.as_mut().map_or(Ok(Async::Ready(None)), |ext_module| { - ext_module - .poll() - .map(|async_opt_message| async_opt_message.map(|opt_message| opt_message.map(OUberMessage::Extended))) - }), - ModuleState::Discovery(index) => uber.discovery[index] - .poll() - .map(|async_opt_message| async_opt_message.map(|opt_message| opt_message.map(OUberMessage::Discovery))) - .map_err(|e| Box::new(Into::into(*e))), + UberModule { + sink: UberSink { + discovery: discovery.clone(), + extended: extended.clone(), }, - ) - } -} - -impl Sink for UberModule { - type SinkItem = IUberMessage; - type SinkError = Box; - - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - // Currently we don't return NotReady from the module directly, so no saving our task state here - self.start_sink_state(&item).map(|a| a.map(|()| item)) + stream: UberStream { discovery, extended }, + } } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.poll_sink_state() - } -} - -impl Stream for UberModule { - type Item = OUberMessage; - type Error = Box; - - fn poll(&mut self) -> Poll, Self::Error> { - self.poll_stream_state() + /// Splits the `UberModule` into its parts. + #[must_use] + pub fn into_parts(self) -> (UberSink, UberStream) { + (self.sink, self.stream) } } diff --git a/packages/select/src/uber/sink.rs b/packages/select/src/uber/sink.rs new file mode 100644 index 000000000..222df041c --- /dev/null +++ b/packages/select/src/uber/sink.rs @@ -0,0 +1,86 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::{Sink, SinkExt as _}; + +use super::{IUberMessage, UberDiscovery}; +use crate::discovery::IDiscoveryMessage; +use crate::error::Error; +use crate::extended::ExtendedModule; +use crate::IExtendedMessage; + +//----------------------------------------------------------------------// +/// `Sink` portion of the `UberModule` for sending messages. 
+#[allow(clippy::module_name_repetitions)] +#[derive(Clone)] +pub struct UberSink { + pub(super) discovery: UberDiscovery, + pub(super) extended: Option, +} + +impl UberSink { + fn handle_message(&mut self, message: IUberMessage) -> Result<(), Error> { + match message { + IUberMessage::Control(control) => { + if let Some(extended) = &mut self.extended { + let mut discovery = self.discovery.lock().unwrap(); + let d_modules = discovery.as_mut_slice(); + extended.process_message(IExtendedMessage::Control(*control.clone()), d_modules); + } + for discovery in self.discovery.lock().unwrap().iter_mut() { + Arc::get_mut(discovery) + .unwrap() + .start_send_unpin(IDiscoveryMessage::Control(*control.clone()))?; + } + } + IUberMessage::Extended(extended) => { + if let Some(ext_module) = &mut self.extended { + let mut discovery = self.discovery.lock().unwrap(); + let d_modules = discovery.as_mut_slice(); + ext_module.process_message(*extended.clone(), d_modules); + } + } + IUberMessage::Discovery(discovery) => { + for discovery_module in self.discovery.lock().unwrap().iter_mut() { + Arc::get_mut(discovery_module).unwrap().start_send_unpin(*discovery.clone())?; + } + } + } + Ok(()) + } + + fn poll_discovery_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + for discovery in self.discovery.lock().unwrap().iter_mut() { + match Arc::get_mut(discovery).unwrap().poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(Error::Discovery(e))), + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + Poll::Ready(Ok(())) + } +} + +impl Sink for UberSink { + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: IUberMessage) -> Result<(), Self::Error> { + self.handle_message(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_discovery_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_discovery_flush(cx) + } +} diff --git a/packages/select/src/uber/stream.rs b/packages/select/src/uber/stream.rs new file mode 100644 index 000000000..3b755397a --- /dev/null +++ b/packages/select/src/uber/stream.rs @@ -0,0 +1,44 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::{Stream, StreamExt as _}; + +use super::{OUberMessage, UberDiscovery}; +use crate::error::Error; +use crate::extended::ExtendedModule; + +//----------------------------------------------------------------------// +/// `Stream` portion of the `UberModule` for receiving messages. 
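`UberSink::handle_message` above fans one incoming message out to the extended module and to every registered discovery module, cloning the payload once per module. The shape of that fan-out, reduced to plain `futures` mpsc channels standing in for the discovery modules:

```rust
// Sketch only: one message cloned into several inner sinks.
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};

#[tokio::main]
async fn main() {
    // Two stand-ins for registered discovery modules.
    let (mut tx_a, mut rx_a) = mpsc::unbounded::<String>();
    let (mut tx_b, mut rx_b) = mpsc::unbounded::<String>();

    // One incoming control message is cloned into every module's sink.
    let control = String::from("tick");
    for tx in [&mut tx_a, &mut tx_b] {
        tx.send(control.clone()).await.unwrap();
    }

    assert_eq!(rx_a.next().await.as_deref(), Some("tick"));
    assert_eq!(rx_b.next().await.as_deref(), Some("tick"));
}
```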
+#[allow(clippy::module_name_repetitions)] +pub struct UberStream { + pub(super) discovery: UberDiscovery, + pub(super) extended: Option, +} + +impl Stream for UberStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(extended) = &mut self.extended { + match extended.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(message))) => return Poll::Ready(Some(Ok(OUberMessage::Extended(message)))), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => (), + } + }; + + for discovery in self.discovery.lock().unwrap().iter_mut() { + match Arc::get_mut(discovery).unwrap().poll_next_unpin(cx) { + Poll::Ready(Some(Ok(message))) => return Poll::Ready(Some(Ok(OUberMessage::Discovery(message)))), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(Error::Discovery(e)))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => (), + } + } + + cx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/packages/select/tests/common/mod.rs b/packages/select/tests/common/mod.rs new file mode 100644 index 000000000..c952d9cc5 --- /dev/null +++ b/packages/select/tests/common/mod.rs @@ -0,0 +1,27 @@ +//----------------------------------------------------------------------------------// + +use std::sync::Once; + +use tracing::level_filters::LevelFilter; + +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + +#[allow(dead_code)] +#[derive(PartialEq, Eq, Debug)] +pub enum TimeoutResult { + TimedOut, + GotResult, +} + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} diff --git a/packages/select/tests/select_tests.rs b/packages/select/tests/select_tests.rs index f03bfb90e..367f5c07e 100644 --- a/packages/select/tests/select_tests.rs +++ b/packages/select/tests/select_tests.rs @@ -1,14 +1,19 @@ -use futures::{Async, Sink, Stream}; -use futures_test::harness::Harness; +use std::time::Duration; + +use common::{tracing_stderr_init, INIT}; +use futures::{SinkExt as _, StreamExt as _}; use handshake::Extensions; use metainfo::{DirectAccessor, Metainfo, MetainfoBuilder, PieceLength}; use peer::PeerInfo; -use select::revelation::error::RevealErrorKind; -use select::revelation::{HonestRevealModule, IRevealMessage, ORevealMessage}; +use select::revelation::error::RevealError; +use select::revelation::{HonestRevealModuleBuilder, IRevealMessage, ORevealMessage}; use select::ControlMessage; +use tracing::level_filters::LevelFilter; use util::bt; use util::bt::InfoHash; +mod common; + fn metainfo(num_pieces: usize) -> Metainfo { let data = vec![0u8; num_pieces]; @@ -30,145 +35,188 @@ fn peer_info(hash: InfoHash) -> PeerInfo { ) } -#[test] -fn positive_add_and_remove_metainfo() { - let (send, _recv) = HonestRevealModule::new().split(); - let metainfo = metainfo(1); +#[tokio::test] +async fn positive_add_and_remove_metainfo() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); - let mut block_send = send.wait(); + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); + let metainfo = metainfo(1); - block_send + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo.clone()))) + .await .unwrap(); - block_send + module 
.send(IRevealMessage::Control(ControlMessage::RemoveTorrent(metainfo.clone()))) + .await .unwrap(); } -#[test] -fn positive_send_bitfield_single_piece() { - let (send, recv) = HonestRevealModule::new().split(); +#[tokio::test] +async fn positive_send_bitfield_single_piece() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); let metainfo = metainfo(8); let info_hash = metainfo.info().info_hash(); let peer_info = peer_info(info_hash); - let mut block_send = send.wait(); - let mut block_recv = recv.wait(); - - block_send + tracing::debug!("sending add torrent..."); + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo))) + .await .unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).unwrap(); - block_send + + tracing::debug!("sending found good piece..."); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).await.unwrap(); + + tracing::debug!("sending peer connected..."); + module .send(IRevealMessage::Control(ControlMessage::PeerConnected(peer_info))) + .await .unwrap(); - let ORevealMessage::SendBitField(info, bitfield) = block_recv.next().unwrap().unwrap() else { - panic!("Received Unexpected Message") - }; - - assert_eq!(peer_info, info); - assert_eq!(1, bitfield.bitfield().len()); - assert_eq!(0x80, bitfield.bitfield()[0]); + tracing::debug!("receiving send bit field..."); + let message = module.next().await.unwrap(); + if let Ok(ORevealMessage::SendBitField(info, bitfield)) = message { + assert_eq!(peer_info, info); + assert_eq!(1, bitfield.bitfield().len()); + assert_eq!(0x80, bitfield.bitfield()[0]); + } else { + panic!("Received Unexpected Message"); + } } -#[test] -fn positive_send_bitfield_multiple_pieces() { - let (send, recv) = HonestRevealModule::new().split(); +#[tokio::test] +async fn positive_send_bitfield_multiple_pieces() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); let metainfo = metainfo(16); let info_hash = metainfo.info().info_hash(); let peer_info = peer_info(info_hash); - let mut block_send = send.wait(); - let mut block_recv = recv.wait(); - - block_send + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo))) + .await .unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 8)).unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 15)).unwrap(); - block_send + module.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).await.unwrap(); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 8)).await.unwrap(); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 15)).await.unwrap(); + module .send(IRevealMessage::Control(ControlMessage::PeerConnected(peer_info))) + .await .unwrap(); - let ORevealMessage::SendBitField(info, bitfield) = block_recv.next().unwrap().unwrap() else { - panic!("Received Unexpected Message") - }; - - assert_eq!(peer_info, info); - assert_eq!(2, bitfield.bitfield().len()); - assert_eq!(0x80, bitfield.bitfield()[0]); - assert_eq!(0x81, bitfield.bitfield()[1]); + let message = module.next().await.unwrap(); + if let Ok(ORevealMessage::SendBitField(info, bitfield)) = message { + assert_eq!(peer_info, info); + assert_eq!(2, bitfield.bitfield().len()); + assert_eq!(0x80, bitfield.bitfield()[0]); + assert_eq!(0x81, bitfield.bitfield()[1]); + } else { + 
panic!("Received Unexpected Message"); + } } -#[test] -fn negative_do_not_send_empty_bitfield() { - let (send, recv) = HonestRevealModule::new().split(); +#[tokio::test] +async fn negative_do_not_send_empty_bitfield() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); let metainfo = metainfo(16); let info_hash = metainfo.info().info_hash(); let peer_info = peer_info(info_hash); - let mut block_send = send.wait(); - let mut non_block_recv = Harness::new(recv); - - block_send + tracing::debug!("sending add torrent..."); + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo))) + .await .unwrap(); - block_send + + tracing::debug!("sending peer connected..."); + module .send(IRevealMessage::Control(ControlMessage::PeerConnected(peer_info))) + .await .unwrap(); - assert!(non_block_recv.poll_next().as_ref().map(Async::is_not_ready).unwrap_or(false)); + tracing::debug!("attempt to receive a message..."); + let res = tokio::time::timeout(Duration::from_millis(50), module.next()).await; + + if let Ok(item) = res { + panic!("expected timeout, but got a result: {item:?}"); + } else { + tracing::debug!("timeout was reached"); + } } -#[test] -fn negative_found_good_piece_out_of_range() { - let (send, _recv) = HonestRevealModule::new().split(); +#[tokio::test] +async fn negative_found_good_piece_out_of_range() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); let metainfo = metainfo(8); let info_hash = metainfo.info().info_hash(); - let mut block_send = send.wait(); - - block_send + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo))) + .await .unwrap(); - let error = block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 8)).unwrap_err(); - match error.kind() { - &RevealErrorKind::InvalidPieceOutOfRange { hash, index } => { + let error = module.send(IRevealMessage::FoundGoodPiece(info_hash, 8)).await.unwrap_err(); + match error { + RevealError::InvalidPieceOutOfRange { hash, index } => { assert_eq!(info_hash, hash); assert_eq!(8, index); } _ => { - panic!("Received Unexpected Message") + panic!("Received Unexpected Error: {error:?}"); } - }; + } } -#[test] -fn negative_all_pieces_good_found_good_piece_out_of_range() { - let (send, _recv) = HonestRevealModule::new().split(); +#[tokio::test] +async fn negative_all_pieces_good_found_good_piece_out_of_range() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + + let builder = HonestRevealModuleBuilder::new(); + let mut module = builder.build(); let metainfo = metainfo(3); let info_hash = metainfo.info().info_hash(); - let mut block_send = send.wait(); - - block_send + module .send(IRevealMessage::Control(ControlMessage::AddTorrent(metainfo))) + .await .unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 1)).unwrap(); - block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 2)).unwrap(); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 0)).await.unwrap(); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 1)).await.unwrap(); + module.send(IRevealMessage::FoundGoodPiece(info_hash, 2)).await.unwrap(); - let error = block_send.send(IRevealMessage::FoundGoodPiece(info_hash, 3)).unwrap_err(); - match error.kind() { - &RevealErrorKind::InvalidPieceOutOfRange { hash, index } => 
{ + let error = module.send(IRevealMessage::FoundGoodPiece(info_hash, 3)).await.unwrap_err(); + match error { + RevealError::InvalidPieceOutOfRange { hash, index } => { assert_eq!(info_hash, hash); assert_eq!(3, index); } _ => { - panic!("Received Unexpected Message") + panic!("Received Unexpected Error: {error:?}"); } - }; + } } diff --git a/packages/util/Cargo.toml b/packages/util/Cargo.toml index 0daedc47a..8d885abac 100644 --- a/packages/util/Cargo.toml +++ b/packages/util/Cargo.toml @@ -16,7 +16,7 @@ repository.workspace = true version.workspace = true [dependencies] -chrono = "0.4" -num = "0.4" -rand = "0.8" -rust-crypto = "0.2" +chrono = "0" +num = "0" +rand = "0" +rust-crypto = "0" diff --git a/packages/util/src/contiguous.rs b/packages/util/src/contiguous.rs index b517c0f69..bee54024f 100644 --- a/packages/util/src/contiguous.rs +++ b/packages/util/src/contiguous.rs @@ -1,5 +1,3 @@ -use std::cmp; - /// Trait for metadata, reading, and writing to a contiguous buffer that doesn't re allocate. #[allow(clippy::module_name_repetitions)] pub trait ContiguousBuffer { @@ -125,7 +123,7 @@ where break; } let available_capacity = buffer.capacity() - buffer.length(); - let amount_to_write = cmp::min(available_capacity, data.len() - bytes_written); + let amount_to_write = std::cmp::min(available_capacity, data.len() - bytes_written); let (start, end) = (bytes_written, bytes_written + amount_to_write); diff --git a/packages/util/src/lib.rs b/packages/util/src/lib.rs index aa2849a35..df2f9343c 100644 --- a/packages/util/src/lib.rs +++ b/packages/util/src/lib.rs @@ -12,9 +12,6 @@ pub mod convert; /// Networking primitives and helpers. pub mod net; -/// Generic sender utilities. -pub mod send; - /// Hash primitives and helpers. pub mod sha; diff --git a/packages/util/src/send/mod.rs b/packages/util/src/send/mod.rs deleted file mode 100644 index 430f8dfee..000000000 --- a/packages/util/src/send/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::sync::mpsc::{self, TrySendError}; - -mod split_sender; - -pub use crate::send::split_sender::{split_sender, SplitSender, SplitSenderAck}; - -/// Trait for generic sender implementations. -pub trait TrySender: Send { - /// Send data through the concrete channel. - /// - /// If the channel is full, return the data back to the caller; if - /// the channel has hung up, the channel should NOT return the data - /// back to the caller but SHOULD panic as hang ups are considered - /// program logic errors. - fn try_send(&self, data: T) -> Option; -} - -impl TrySender for mpsc::Sender { - fn try_send(&self, data: T) -> Option { - self.send(data).expect("bip_util: mpsc::Sender Signaled A Hang Up"); - - None - } -} - -impl TrySender for mpsc::SyncSender { - fn try_send(&self, data: T) -> Option { - self.try_send(data).err().map(|err| match err { - TrySendError::Full(data) => data, - TrySendError::Disconnected(_) => panic!("bip_util: mpsc::SyncSender Signaled A Hang Up"), - }) - } -} diff --git a/packages/util/src/send/split_sender.rs b/packages/util/src/send/split_sender.rs deleted file mode 100644 index ec4bf7432..000000000 --- a/packages/util/src/send/split_sender.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; - -use crate::send::TrySender; - -/// Create two `SplitSenders` over a single Sender with corresponding capacities. 
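
The TrySender and SplitSender utilities being deleted here were the crate's way of signalling backpressure over std mpsc channels. A bounded futures channel gives the same refuse-when-full behaviour; a minimal sketch, assuming futures' documented buffer-plus-senders capacity (all names here are illustrative and not part of this patch):

    use futures::channel::mpsc;

    fn main() {
        // With a buffer of 0 the single sender owns exactly one slot.
        let (mut tx, _rx) = mpsc::channel::<u32>(0);

        assert!(tx.try_send(1).is_ok());

        // The second send is refused and the rejected value is handed back,
        // mirroring the Option<T> return of the removed TrySender::try_send.
        let err = tx.try_send(2).unwrap_err();
        assert!(err.is_full());
        assert_eq!(err.into_inner(), 2);
    }
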
-pub fn split_sender(send: S, cap_one: usize, cap_two: usize) -> (SplitSender, SplitSender) -where - S: TrySender + Clone, - T: Send, -{ - (SplitSender::new(send.clone(), cap_one), SplitSender::new(send, cap_two)) -} - -/// `SplitSender` allows dividing the capacity of a single channel into multiple channels. -pub struct SplitSender { - send: S, - count: Arc, - capacity: usize, -} - -impl Clone for SplitSender -where - S: Clone, -{ - fn clone(&self) -> SplitSender { - SplitSender { - send: self.send.clone(), - count: self.count.clone(), - capacity: self.capacity, - } - } -} - -unsafe impl Sync for SplitSender where S: Sync {} - -impl SplitSender { - /// Create a new `SplitSender`. - pub fn new(send: S, capacity: usize) -> SplitSender { - SplitSender { - send, - count: Arc::new(AtomicUsize::new(0)), - capacity, - } - } - - /// Create a new `SplitSenderAck` that can be used to ack sent messages. - pub fn sender_ack(&self) -> SplitSenderAck { - SplitSenderAck { - count: self.count.clone(), - } - } - - fn try_count_increment(&self) -> bool { - let our_count = self.count.fetch_add(1, Ordering::SeqCst); - - if our_count < self.capacity { - true - } else { - // Failed to get a passable count, revert our add - self.count.fetch_sub(1, Ordering::SeqCst); - - false - } - } -} - -impl TrySender for SplitSender -where - S: TrySender, - T: Send, -{ - fn try_send(&self, data: T) -> Option { - let should_send = self.try_count_increment(); - - if should_send { - self.send.try_send(data) - } else { - Some(data) - } - } -} - -// ----------------------------------------------------------------------------// - -/// `SplitSenderAck` allows a client to ack messages received from a `SplitSender`. -#[allow(clippy::module_name_repetitions)] -pub struct SplitSenderAck { - count: Arc, -} - -impl SplitSenderAck { - /// Ack a message received from a `SplitSender`. - pub fn ack(&self) { - self.count.fetch_sub(1, Ordering::SeqCst); - } -} - -#[cfg(test)] -mod tests { - use std::sync::mpsc; - - use super::SplitSender; - use crate::send::TrySender; - - #[test] - fn positive_send_zero_capacity() { - let (send, recv) = mpsc::channel(); - let split_sender = SplitSender::new(send, 0); - - assert!(split_sender.try_send(()).is_some()); - assert!(recv.try_recv().is_err()); - } - - #[test] - fn positive_send_one_capacity() { - let (send, recv) = mpsc::channel(); - let split_sender = SplitSender::new(send, 1); - - assert!(split_sender.try_send(()).is_none()); - assert!(recv.try_recv().is_ok()); - } -} diff --git a/packages/util/src/sha/mod.rs b/packages/util/src/sha/mod.rs index 0d2e51ed6..296da78a3 100644 --- a/packages/util/src/sha/mod.rs +++ b/packages/util/src/sha/mod.rs @@ -52,6 +52,17 @@ impl ShaHash { } } +impl std::fmt::Display for ShaHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "0x")?; + + for byte in &self.hash { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + impl AsRef<[u8]> for ShaHash { fn as_ref(&self) -> &[u8] { &self.hash diff --git a/packages/util/src/trans/locally_shuffled.rs b/packages/util/src/trans/locally_shuffled.rs index f9a2faec4..6d40bb928 100644 --- a/packages/util/src/trans/locally_shuffled.rs +++ b/packages/util/src/trans/locally_shuffled.rs @@ -20,7 +20,7 @@ const TRANSACTION_ID_PREALLOC_LEN: usize = 2048; /// transaction type (such as u64) but also works with smaller types. 
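
The Display implementation added to ShaHash above settles on a 0x-prefixed lowercase hex rendering. The same convention on a free-standing type, so the formatting can be read in isolation (Digest is a hypothetical stand-in, not part of this patch):

    struct Digest {
        hash: [u8; 20],
    }

    impl std::fmt::Display for Digest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "0x")?;
            for byte in &self.hash {
                write!(f, "{byte:02x}")?;
            }
            Ok(())
        }
    }

    fn main() {
        let digest = Digest { hash: [0xab; 20] };
        // Prints 0xabab...ab: forty hex digits after the prefix.
        assert!(format!("{digest}").starts_with("0xab"));
        assert_eq!(format!("{digest}").len(), 2 + 40);
    }
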
#[allow(clippy::module_name_repetitions)] -#[derive(Default)] +#[derive(Debug, Default)] pub struct LocallyShuffledIds { sequential: SequentialIds, stored_ids: Vec, diff --git a/packages/util/src/trans/sequential.rs b/packages/util/src/trans/sequential.rs index ea414540d..a652692f4 100644 --- a/packages/util/src/trans/sequential.rs +++ b/packages/util/src/trans/sequential.rs @@ -7,7 +7,7 @@ use crate::trans::TransactionIds; /// Generates sequentially unique ids and wraps when overflow occurs. #[allow(clippy::module_name_repetitions)] -#[derive(Default)] +#[derive(Debug, Default)] pub struct SequentialIds { next_id: T, } diff --git a/packages/utracker/Cargo.toml b/packages/utracker/Cargo.toml index 0d6a78d72..29dce63e8 100644 --- a/packages/utracker/Cargo.toml +++ b/packages/utracker/Cargo.toml @@ -20,8 +20,14 @@ handshake = { path = "../handshake" } util = { path = "../util" } byteorder = "1" -chrono = "0.4" -futures = "0.1" -nom = "3" -rand = "0.8" -umio = "0.3" +chrono = "0" +futures = "0" +nom = "7" +rand = "0" +thiserror = "1" +tracing = "0" +umio = "0" + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +tracing-subscriber = "0" diff --git a/packages/utracker/src/announce.rs b/packages/utracker/src/announce.rs index ba131fea1..7906b6ec6 100644 --- a/packages/utracker/src/announce.rs +++ b/packages/utracker/src/announce.rs @@ -1,14 +1,18 @@ #![allow(deprecated)] //! Messaging primitives for announcing. -use std::io::{self, Write}; +use std::io::Write as _; use std::net::{Ipv4Addr, Ipv6Addr}; use byteorder::{BigEndian, WriteBytesExt}; -use nom::{ - alt, be_i32, be_i64, be_u16, be_u32, be_u8, call, count_fixed, do_parse, error_node_position, error_position, map, named, - switch, tag, take, tuple, tuple_parser, value, IResult, -}; +use nom::branch::alt; +use nom::bytes::complete::{tag, take}; +use nom::combinator::{map, value}; +use nom::multi::count; +use nom::number::complete::{be_i32, be_i64, be_u16, be_u32, be_u8}; +use nom::sequence::tuple; +use nom::IResult; +use tracing::instrument; use util::bt::{self, InfoHash, PeerId}; use util::convert; @@ -44,6 +48,7 @@ pub struct AnnounceRequest<'a> { #[allow(clippy::too_many_arguments)] impl<'a> AnnounceRequest<'a> { /// Create a new `AnnounceRequest`. + #[instrument(skip())] #[must_use] pub fn new( hash: InfoHash, @@ -55,6 +60,7 @@ impl<'a> AnnounceRequest<'a> { port: u16, options: AnnounceOptions<'a>, ) -> AnnounceRequest<'a> { + tracing::trace!("new announce request"); AnnounceRequest { info_hash: hash, peer_id, @@ -68,11 +74,19 @@ impl<'a> AnnounceRequest<'a> { } /// Construct an IPv4 `AnnounceRequest` from the given bytes. + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v4(bytes: &'a [u8]) -> IResult<&'a [u8], AnnounceRequest<'a>> { parse_request(bytes, SourceIP::from_bytes_v4) } /// Construct an IPv6 `AnnounceRequest` from the given bytes. + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v6(bytes: &'a [u8]) -> IResult<&'a [u8], AnnounceRequest<'a>> { parse_request(bytes, SourceIP::from_bytes_v6) } @@ -82,9 +96,9 @@ impl<'a> AnnounceRequest<'a> { /// # Errors /// /// It would return an IO error if unable to write the bytes. 
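
announce.rs now leans on tracing's #[instrument] attribute: each call opens a span named after the function and records its arguments as fields unless they are skipped. A small sketch of the pattern (function and field names are illustrative only):

    use tracing::instrument;

    // `payload` is skipped because byte buffers are noisy as span fields;
    // `id` stays and is captured through its Debug impl.
    #[instrument(skip(payload))]
    fn store(id: u32, payload: &[u8]) {
        tracing::trace!(len = payload.len(), "storing payload");
    }
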
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(self.info_hash.as_ref())?; writer.write_all(self.peer_id.as_ref())?; @@ -174,17 +188,21 @@ impl<'a> AnnounceRequest<'a> { /// Parse an `AnnounceRequest` with the given `SourceIP` type constructor. fn parse_request(bytes: &[u8], ip_type: fn(bytes: &[u8]) -> IResult<&[u8], SourceIP>) -> IResult<&[u8], AnnounceRequest<'_>> { - do_parse!(bytes, - info_hash: map!(take!(bt::INFO_HASH_LEN), |bytes| InfoHash::from_hash(bytes).unwrap()) >> - peer_id: map!(take!(bt::PEER_ID_LEN), |bytes| PeerId::from_hash(bytes).unwrap()) >> - state: call!(ClientState::from_bytes) >> - ip: call!(ip_type) >> - key: be_u32 >> - num_want: call!(DesiredPeers::from_bytes) >> - port: be_u16 >> - options: call!(AnnounceOptions::from_bytes) >> - (AnnounceRequest::new(info_hash, peer_id, state, ip, key, num_want, port, options)) - ) + let (bytes, (info_hash, peer_id, state, ip, key, num_want, port, options)) = tuple(( + map(take(bt::INFO_HASH_LEN), |bytes: &[u8]| InfoHash::from_hash(bytes).unwrap()), + map(take(bt::PEER_ID_LEN), |bytes: &[u8]| PeerId::from_hash(bytes).unwrap()), + ClientState::from_bytes, + ip_type, + be_u32, + DesiredPeers::from_bytes, + be_u16, + AnnounceOptions::from_bytes, + ))(bytes)?; + + Ok(( + bytes, + AnnounceRequest::new(info_hash, peer_id, state, ip, key, num_want, port, options), + )) } // ----------------------------------------------------------------------------// @@ -212,11 +230,19 @@ impl<'a> AnnounceResponse<'a> { } /// Construct an IPv4 `AnnounceResponse` from the given bytes. + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v4(bytes: &'a [u8]) -> IResult<&'a [u8], AnnounceResponse<'a>> { parse_response(bytes, CompactPeers::from_bytes_v4) } /// Construct an IPv6 `AnnounceResponse` from the given bytes. + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v6(bytes: &'a [u8]) -> IResult<&'a [u8], AnnounceResponse<'a>> { parse_response(bytes, CompactPeers::from_bytes_v6) } @@ -226,9 +252,9 @@ impl<'a> AnnounceResponse<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_i32::(self.interval)?; writer.write_i32::(self.leechers)?; @@ -282,13 +308,9 @@ fn parse_response<'a>( bytes: &'a [u8], peers_type: fn(bytes: &'a [u8]) -> IResult<&'a [u8], CompactPeers<'a>>, ) -> IResult<&'a [u8], AnnounceResponse<'a>> { - do_parse!(bytes, - interval: be_i32 >> - leechers: be_i32 >> - seeders: be_i32 >> - peers: call!(peers_type) >> - (AnnounceResponse::new(interval, leechers, seeders, peers)) - ) + let (bytes, (interval, leechers, seeders, peers)) = tuple((be_i32, be_i32, be_i32, peers_type))(bytes)?; + + Ok((bytes, AnnounceResponse::new(interval, leechers, seeders, peers))) } // ----------------------------------------------------------------------------// @@ -305,7 +327,9 @@ pub struct ClientState { impl ClientState { /// Create a new `ClientState`. 
#[must_use] + #[instrument(skip())] pub fn new(bytes_downloaded: i64, bytes_left: i64, bytes_uploaded: i64, event: AnnounceEvent) -> ClientState { + tracing::trace!("new client state"); ClientState { downloaded: bytes_downloaded, left: bytes_left, @@ -315,7 +339,10 @@ impl ClientState { } /// Construct the `ClientState` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], ClientState> { parse_state(bytes) } @@ -325,9 +352,9 @@ impl ClientState { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_i64::(self.downloaded)?; writer.write_i64::(self.left)?; @@ -364,13 +391,9 @@ impl ClientState { } fn parse_state(bytes: &[u8]) -> IResult<&[u8], ClientState> { - do_parse!(bytes, - downloaded: be_i64 >> - left: be_i64 >> - uploaded: be_i64 >> - event: call!(AnnounceEvent::from_bytes) >> - (ClientState::new(downloaded, left, uploaded, event)) - ) + let (bytes, (downloaded, left, uploaded, event)) = tuple((be_i64, be_i64, be_i64, AnnounceEvent::from_bytes))(bytes)?; + + Ok((bytes, ClientState::new(downloaded, left, uploaded, event))) } // ----------------------------------------------------------------------------// @@ -391,7 +414,10 @@ pub enum AnnounceEvent { impl AnnounceEvent { /// Construct an `AnnounceEvent` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], AnnounceEvent> { parse_event(bytes) } @@ -401,9 +427,9 @@ impl AnnounceEvent { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_i32::(self.as_id())?; @@ -423,12 +449,15 @@ impl AnnounceEvent { } fn parse_event(bytes: &[u8]) -> IResult<&[u8], AnnounceEvent> { - switch!(bytes, be_i32, - ANNOUNCE_NONE_EVENT => value!(AnnounceEvent::None) | - ANNOUNCE_COMPLETED_EVENT => value!(AnnounceEvent::Completed) | - ANNOUNCE_STARTED_EVENT => value!(AnnounceEvent::Started) | - ANNOUNCE_STOPPED_EVENT => value!(AnnounceEvent::Stopped) - ) + let (bytes, event_id) = be_i32(bytes)?; + let event = match event_id { + ANNOUNCE_NONE_EVENT => AnnounceEvent::None, + ANNOUNCE_COMPLETED_EVENT => AnnounceEvent::Completed, + ANNOUNCE_STARTED_EVENT => AnnounceEvent::Started, + ANNOUNCE_STOPPED_EVENT => AnnounceEvent::Stopped, + _ => return Err(nom::Err::Error(nom::error::Error::new(bytes, nom::error::ErrorKind::Switch))), + }; + Ok((bytes, event)) } // ----------------------------------------------------------------------------// @@ -448,13 +477,19 @@ pub enum SourceIP { impl SourceIP { /// Construct the IPv4 `SourceIP` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v4(bytes: &[u8]) -> IResult<&[u8], SourceIP> { parse_preference_v4(bytes) } /// Construct the IPv6 `SourceIP` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. 
pub fn from_bytes_v6(bytes: &[u8]) -> IResult<&[u8], SourceIP> { parse_preference_v6(bytes) } @@ -464,9 +499,9 @@ impl SourceIP { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, writer: W) -> io::Result<()> + pub fn write_bytes(&self, writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { match *self { SourceIP::ImpliedV4 => SourceIP::write_bytes_slice(writer, &IMPLIED_IPV4_ID[..]), @@ -492,35 +527,43 @@ impl SourceIP { } /// Write the given byte slice to the given writer. - fn write_bytes_slice(mut writer: W, bytes: &[u8]) -> io::Result<()> + fn write_bytes_slice(mut writer: W, bytes: &[u8]) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(bytes) } } fn parse_preference_v4(bytes: &[u8]) -> IResult<&[u8], SourceIP> { - alt!(bytes, - tag!(IMPLIED_IPV4_ID) => { |_| SourceIP::ImpliedV4 } | - parse_ipv4 => { SourceIP::ExplicitV4 } - ) + let (bytes, ip) = alt(( + map(tag(&IMPLIED_IPV4_ID), |_| SourceIP::ImpliedV4), + map(parse_ipv4, SourceIP::ExplicitV4), + ))(bytes)?; + Ok((bytes, ip)) } -named!(parse_ipv4<&[u8], Ipv4Addr>, - map!(count_fixed!(u8, be_u8, 4), convert::bytes_be_to_ipv4) -); +fn parse_ipv4(bytes: &[u8]) -> IResult<&[u8], Ipv4Addr> { + let (bytes, ip_bytes) = take(4usize)(bytes)?; + let ip_array: [u8; 4] = ip_bytes.try_into().expect("slice with incorrect length"); + let ip = convert::bytes_be_to_ipv4(ip_array); + Ok((bytes, ip)) +} fn parse_preference_v6(bytes: &[u8]) -> IResult<&[u8], SourceIP> { - alt!(bytes, - tag!(IMPLIED_IPV6_ID) => { |_| SourceIP::ImpliedV6 } | - parse_ipv6 => { SourceIP::ExplicitV6 } - ) + let (bytes, ip) = alt(( + map(tag(&IMPLIED_IPV6_ID), |_| SourceIP::ImpliedV6), + map(parse_ipv6, SourceIP::ExplicitV6), + ))(bytes)?; + Ok((bytes, ip)) } -named!(parse_ipv6<&[u8], Ipv6Addr>, - map!(count_fixed!(u8, be_u8, 16), convert::bytes_be_to_ipv6) -); +fn parse_ipv6(bytes: &[u8]) -> IResult<&[u8], Ipv6Addr> { + let (bytes, ip_bytes) = take(16usize)(bytes)?; + let ip_array: [u8; 16] = ip_bytes.try_into().expect("slice with incorrect length"); + let ip = convert::bytes_be_to_ipv6(ip_array); + Ok((bytes, ip)) +} // ----------------------------------------------------------------------------// @@ -535,7 +578,10 @@ pub enum DesiredPeers { impl DesiredPeers { /// Construct the `DesiredPeers` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &[u8]) -> IResult<&[u8], DesiredPeers> { parse_desired(bytes) } @@ -545,9 +591,9 @@ impl DesiredPeers { /// # Errors /// /// It would return an IO Error if unable to write the bytes. 
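
parse_ipv4 and parse_ipv6 above take a fixed number of bytes and then expect the slice-to-array conversion to succeed, which is safe because take() has already guaranteed the length. If the panic path is undesirable, the conversion can instead be folded into the parser with map_res; a sketch of the IPv4 case (an alternative shape, not what this patch does):

    use std::net::Ipv4Addr;

    use nom::bytes::complete::take;
    use nom::combinator::map_res;
    use nom::IResult;

    fn parse_ipv4_alt(bytes: &[u8]) -> IResult<&[u8], Ipv4Addr> {
        // A failed array conversion becomes a nom error instead of a panic.
        let (rest, octets) = map_res(take(4usize), |b: &[u8]| <[u8; 4]>::try_from(b))(bytes)?;
        Ok((rest, Ipv4Addr::from(octets)))
    }
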
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { let write_value = match self { DesiredPeers::Default => DEFAULT_NUM_WANT, @@ -560,17 +606,19 @@ impl DesiredPeers { } fn parse_desired(bytes: &[u8]) -> IResult<&[u8], DesiredPeers> { - // Tuple trick used to subvert the unused pattern warning (nom tries to catch all) - switch!(bytes, tuple!(be_i32, value!(true)), - (DEFAULT_NUM_WANT, true) => value!(DesiredPeers::Default) | - (specified_peers, true) => value!(DesiredPeers::Specified(specified_peers)) - ) + let (bytes, num_want) = be_i32(bytes)?; + let desired_peers = if num_want == DEFAULT_NUM_WANT { + DesiredPeers::Default + } else { + DesiredPeers::Specified(num_want) + }; + Ok((bytes, desired_peers)) } #[cfg(test)] mod tests { - use std::io::Write; - use std::net::Ipv4Addr; + use std::io::Write as _; + use std::net::{Ipv4Addr, Ipv6Addr}; use byteorder::{BigEndian, WriteBytesExt}; use nom::IResult; @@ -578,6 +626,7 @@ mod tests { use util::convert; use super::{AnnounceEvent, AnnounceRequest, AnnounceResponse, ClientState, DesiredPeers, SourceIP}; + use crate::announce::{parse_ipv4, parse_ipv6}; use crate::contact::{CompactPeers, CompactPeersV4, CompactPeersV6}; use crate::option::AnnounceOptions; @@ -815,9 +864,7 @@ mod tests { bytes.write_i32::(num_want).unwrap(); bytes.write_u16::(port).unwrap(); - let IResult::Done(_, received) = AnnounceRequest::from_bytes_v4(&bytes) else { - panic!("AnnounceRequest Parsing Failed...") - }; + let received = AnnounceRequest::from_bytes_v4(&bytes).unwrap().1; assert_eq!(received.info_hash(), InfoHash::from(info_hash)); assert_eq!(received.peer_id(), PeerId::from(peer_id)); @@ -850,7 +897,7 @@ mod tests { let received = AnnounceRequest::from_bytes_v4(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -862,14 +909,14 @@ mod tests { bytes.write_i32::(leechers).unwrap(); bytes.write_i32::(seeders).unwrap(); - let received_v4 = AnnounceResponse::from_bytes_v4(&bytes); - let received_v6 = AnnounceResponse::from_bytes_v6(&bytes); + let received_v4 = AnnounceResponse::from_bytes_v4(&bytes).unwrap().1; + let received_v6 = AnnounceResponse::from_bytes_v6(&bytes).unwrap().1; let expected_v4 = AnnounceResponse::new(interval, leechers, seeders, CompactPeers::V4(CompactPeersV4::new())); let expected_v6 = AnnounceResponse::new(interval, leechers, seeders, CompactPeers::V6(CompactPeersV6::new())); - assert_eq!(received_v4, IResult::Done(&b""[..], expected_v4)); - assert_eq!(received_v6, IResult::Done(&b""[..], expected_v6)); + assert_eq!(received_v4, expected_v4); + assert_eq!(received_v6, expected_v6); } #[test] @@ -895,19 +942,19 @@ mod tests { let mut bytes_v6 = bytes.clone(); peers_v6.write_bytes(&mut bytes_v6).unwrap(); - let received_v4 = AnnounceResponse::from_bytes_v4(&bytes_v4); - let received_v6 = AnnounceResponse::from_bytes_v6(&bytes_v6); + let received_v4 = AnnounceResponse::from_bytes_v4(&bytes_v4).unwrap().1; + let received_v6 = AnnounceResponse::from_bytes_v6(&bytes_v6).unwrap().1; let expected_v4 = AnnounceResponse::new(interval, leechers, seeders, CompactPeers::V4(peers_v4)); let expected_v6 = AnnounceResponse::new(interval, leechers, seeders, CompactPeers::V6(peers_v6)); - assert_eq!(received_v4, IResult::Done(&b""[..], expected_v4)); - assert_eq!(received_v6, IResult::Done(&b""[..], expected_v6)); + assert_eq!(received_v4, expected_v4); + assert_eq!(received_v6, expected_v6); } #[test] fn 
positive_parse_state() { - let (downloaded, left, uploaded) = (202_340, 52340, 5043); + let (downloaded, left, uploaded) = (202_340, 52_340, 5_043); let mut bytes = Vec::new(); bytes.write_i64::(downloaded).unwrap(); @@ -915,15 +962,15 @@ mod tests { bytes.write_i64::(uploaded).unwrap(); bytes.write_i32::(super::ANNOUNCE_NONE_EVENT).unwrap(); - let received = ClientState::from_bytes(&bytes); + let received = ClientState::from_bytes(&bytes).unwrap().1; let expected = ClientState::new(downloaded, left, uploaded, AnnounceEvent::None); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] fn negative_parse_incomplete_state() { - let (downloaded, left, uploaded) = (202_340, 52340, 5043); + let (downloaded, left, uploaded) = (202_340, 52_340, 5_043); let mut bytes = Vec::new(); bytes.write_i64::(downloaded).unwrap(); @@ -932,7 +979,7 @@ mod tests { let received = ClientState::from_bytes(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -940,10 +987,10 @@ mod tests { let mut bytes = Vec::new(); bytes.write_i32::(super::ANNOUNCE_NONE_EVENT).unwrap(); - let received = AnnounceEvent::from_bytes(&bytes); + let received = AnnounceEvent::from_bytes(&bytes).unwrap().1; let expected = AnnounceEvent::None; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -951,10 +998,10 @@ mod tests { let mut bytes = Vec::new(); bytes.write_i32::(super::ANNOUNCE_COMPLETED_EVENT).unwrap(); - let received = AnnounceEvent::from_bytes(&bytes); + let received = AnnounceEvent::from_bytes(&bytes).unwrap().1; let expected = AnnounceEvent::Completed; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -962,10 +1009,10 @@ mod tests { let mut bytes = Vec::new(); bytes.write_i32::(super::ANNOUNCE_STARTED_EVENT).unwrap(); - let received = AnnounceEvent::from_bytes(&bytes); + let received = AnnounceEvent::from_bytes(&bytes).unwrap().1; let expected = AnnounceEvent::Started; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -982,10 +1029,10 @@ mod tests { let mut bytes = Vec::new(); bytes.write_i32::(super::ANNOUNCE_STOPPED_EVENT).unwrap(); - let received = AnnounceEvent::from_bytes(&bytes); + let received = AnnounceEvent::from_bytes(&bytes).unwrap().1; let expected = AnnounceEvent::Stopped; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -993,20 +1040,20 @@ mod tests { let mut bytes = Vec::new(); bytes.write_all(&super::IMPLIED_IPV4_ID).unwrap(); - let received = SourceIP::from_bytes_v4(&bytes); + let received = SourceIP::from_bytes_v4(&bytes).unwrap().1; let expected = SourceIP::ImpliedV4; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] fn positive_parse_explicit_v4_source() { let bytes = [127, 0, 0, 1]; - let received = SourceIP::from_bytes_v4(&bytes); + let received = SourceIP::from_bytes_v4(&bytes).unwrap().1; let expected = SourceIP::ExplicitV4(Ipv4Addr::new(127, 0, 0, 1)); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -1014,10 +1061,10 @@ mod tests { let mut bytes = Vec::new(); bytes.write_all(&super::IMPLIED_IPV6_ID).unwrap(); - let received = SourceIP::from_bytes_v6(&bytes); + let received = SourceIP::from_bytes_v6(&bytes).unwrap().1; let expected = SourceIP::ImpliedV6; - assert_eq!(received, 
IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -1025,10 +1072,10 @@ mod tests { let ip = "ADBB:234A:55BD:FF34:3D3A:FFFF:234A:55BD".parse().unwrap(); // cspell:disable-line let bytes = convert::ipv6_to_bytes_be(ip); - let received = SourceIP::from_bytes_v6(&bytes); + let received = SourceIP::from_bytes_v6(&bytes).unwrap().1; let expected = SourceIP::ExplicitV6(ip); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] @@ -1037,7 +1084,7 @@ mod tests { let received = SourceIP::from_bytes_v4(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -1046,7 +1093,7 @@ mod tests { let received = SourceIP::from_bytes_v6(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -1055,7 +1102,7 @@ mod tests { let received = SourceIP::from_bytes_v4(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -1064,26 +1111,58 @@ mod tests { let received = SourceIP::from_bytes_v6(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] fn positive_parse_desired_peers_default() { let default_bytes = convert::four_bytes_to_array(u32::MAX); - let received = DesiredPeers::from_bytes(&default_bytes); + let received = DesiredPeers::from_bytes(&default_bytes).unwrap().1; let expected = DesiredPeers::Default; - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); } #[test] fn positive_parse_desired_peers_specified() { let specified_bytes = convert::four_bytes_to_array(50); - let received = DesiredPeers::from_bytes(&specified_bytes); + let received = DesiredPeers::from_bytes(&specified_bytes).unwrap().1; let expected = DesiredPeers::Specified(50); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, expected); + } + + #[test] + fn test_parse_ipv4() { + // Valid IPv4 address + let bytes = [192, 168, 0, 1]; + let expected_ip = Ipv4Addr::new(192, 168, 0, 1); + let (remaining, parsed_ip) = parse_ipv4(&bytes).unwrap(); + assert_eq!(remaining, &[]); + assert_eq!(parsed_ip, expected_ip); + + // Invalid IPv4 address (not enough bytes) + let bytes = [192, 168]; + let result = parse_ipv4(&bytes); + assert!(result.is_err()); + } + + #[test] + fn test_parse_ipv6() { + // Valid IPv6 address + let bytes = [ + 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34, + ]; + let expected_ip = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + let (remaining, parsed_ip) = parse_ipv6(&bytes).unwrap(); + assert_eq!(remaining, &[]); + assert_eq!(parsed_ip, expected_ip); + + // Invalid IPv6 address (not enough bytes) + let bytes = [0x20, 0x01, 0x0d, 0xb8]; + let result = parse_ipv6(&bytes); + assert!(result.is_err()); } } diff --git a/packages/utracker/src/client/dispatcher.rs b/packages/utracker/src/client/dispatcher.rs index 248b5a762..2d0dd8b9b 100644 --- a/packages/utracker/src/client/dispatcher.rs +++ b/packages/utracker/src/client/dispatcher.rs @@ -1,19 +1,21 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::io::{self, Cursor}; use std::net::SocketAddr; -use std::thread; use chrono::offset::Utc; use chrono::{DateTime, Duration}; -use futures::future::Either; -use futures::sink::{Sink, Wait}; +use futures::executor::block_on; +use futures::future::{BoxFuture, Either}; +use futures::sink::Sink; +use futures::{FutureExt, SinkExt}; use handshake::{DiscoveryInfo, 
InitiateMessage, Protocol}; use nom::IResult; +use tracing::instrument; use umio::external::{self, Timeout}; use umio::{Dispatcher, ELoopBuilder, Provider}; use util::bt::PeerId; +use super::HandshakerMessage; use crate::announce::{AnnounceRequest, DesiredPeers, SourceIP}; use crate::client::error::{ClientError, ClientResult}; use crate::client::{ClientMetadata, ClientRequest, ClientResponse, ClientToken, RequestLimiter}; @@ -28,12 +30,14 @@ const CONNECTION_ID_VALID_DURATION_MILLIS: i64 = 60000; const MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS: u64 = 8; /// Internal dispatch timeout. +#[derive(Debug)] enum DispatchTimeout { Connect(ClientToken), CleanUp, } /// Internal dispatch message for clients. +#[derive(Debug)] pub enum DispatchMessage { Request(SocketAddr, ClientToken, ClientRequest), StartTimer, @@ -44,16 +48,19 @@ pub enum DispatchMessage { /// /// Assumes `msg_capacity` is less than `usize::max_value`(). #[allow(clippy::module_name_repetitions)] +#[instrument(skip())] pub fn create_dispatcher( bind: SocketAddr, handshaker: H, msg_capacity: usize, limiter: RequestLimiter, -) -> io::Result> +) -> std::io::Result> where - H: Sink + DiscoveryInfo + 'static + Send, - H::SinkItem: From>, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { + tracing::debug!("creating dispatcher"); + // Timer capacity is plus one for the cache cleanup timer let builder = ELoopBuilder::new() .channel_capacity(msg_capacity) @@ -66,7 +73,7 @@ where let dispatch = ClientDispatcher::new(handshaker, bind, limiter); - thread::spawn(move || { + std::thread::spawn(move || { eloop.run(dispatch).expect("bip_utracker: ELoop Shutdown Unexpectedly..."); }); @@ -80,8 +87,9 @@ where // ----------------------------------------------------------------------------// /// Dispatcher that executes requests asynchronously. +#[derive(Debug)] struct ClientDispatcher { - handshaker: Wait, + handshaker: H, pid: PeerId, port: u16, bound_addr: SocketAddr, @@ -92,16 +100,19 @@ struct ClientDispatcher { impl ClientDispatcher where - H: Sink + DiscoveryInfo, - H::SinkItem: From>, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { /// Create a new `ClientDispatcher`. + #[instrument(skip(), ret)] pub fn new(handshaker: H, bind: SocketAddr, limiter: RequestLimiter) -> ClientDispatcher { + tracing::debug!("new client dispatcher"); + let peer_id = handshaker.peer_id(); let port = handshaker.port(); ClientDispatcher { - handshaker: handshaker.wait(), + handshaker, pid: peer_id, port, bound_addr: bind, @@ -112,7 +123,10 @@ where } /// Shutdown the current dispatcher, notifying all pending requests. + #[instrument(skip(self, provider))] pub fn shutdown(&mut self, provider: &mut Provider<'_, ClientDispatcher>) { + tracing::debug!("shutting down client dispatcher"); + // Notify all active requests with the appropriate error for token_index in 0..self.active_requests.len() { let next_token = *self.active_requests.keys().nth(token_index).unwrap(); @@ -126,15 +140,20 @@ where } /// Finish a request by sending the result back to the client. 
+ #[instrument(skip(self))] pub fn notify_client(&mut self, token: ClientToken, result: ClientResult) { - self.handshaker - .send(Either::B(ClientMetadata::new(token, result)).into()) - .unwrap_or_else(|_| panic!("NEED TO FIX")); + tracing::info!("notifying clients"); + + match block_on(self.handshaker.send(Ok(ClientMetadata::new(token, result).into()))) { + Ok(()) => tracing::debug!("client metadata sent"), + Err(e) => tracing::error!("sending client metadata failed with error: {e}"), + } self.limiter.acknowledge(); } /// Process a request to be sent to the given address and associated with the given token. + #[instrument(skip(self, provider))] pub fn send_request( &mut self, provider: &mut Provider<'_, ClientDispatcher>, @@ -142,9 +161,15 @@ where token: ClientToken, request: ClientRequest, ) { + tracing::debug!("sending request"); + + let bound_addr = self.bound_addr; + // Check for IP version mismatch between source addr and dest addr - match (self.bound_addr, addr) { + match (bound_addr, addr) { (SocketAddr::V4(_), SocketAddr::V6(_)) | (SocketAddr::V6(_), SocketAddr::V4(_)) => { + tracing::error!(%bound_addr, %addr, "ip version mismatch between bound address and address"); + self.notify_client(token, Err(ClientError::IPVersionMismatch)); return; @@ -157,23 +182,30 @@ where } /// Process a response received from some tracker and match it up against our sent requests. + #[instrument(skip(self, provider, response))] pub fn recv_response( &mut self, provider: &mut Provider<'_, ClientDispatcher>, - addr: SocketAddr, response: &TrackerResponse<'_>, + addr: SocketAddr, ) { + tracing::debug!("receiving response"); + let token = ClientToken(response.transaction_id()); let conn_timer = if let Some(conn_timer) = self.active_requests.remove(&token) { if conn_timer.message_params().0 == addr { conn_timer } else { + tracing::error!(?conn_timer, %addr, "different message prams"); + return; - } // TODO: Add Logging (Server Receive Addr Different Than Send Addr) + } } else { + tracing::error!(?token, "token not in active requests"); + return; - }; // TODO: Add Logging (Server Gave Us Invalid Transaction Id) + }; provider.clear_timeout( conn_timer @@ -193,9 +225,14 @@ where (&ClientRequest::Announce(hash, _), ResponseType::Announce(res)) => { // Forward contact information on to the handshaker for addr in res.peers().iter() { - self.handshaker - .send(Either::A(InitiateMessage::new(Protocol::BitTorrent, hash, addr)).into()) - .unwrap_or_else(|_| panic!("NEED TO FIX")); + tracing::info!("sending will block if unable to send!"); + match block_on( + self.handshaker + .send(Ok(InitiateMessage::new(Protocol::BitTorrent, hash, addr).into())), + ) { + Ok(()) => tracing::debug!("handshake for: {addr} initiated"), + Err(e) => tracing::warn!("handshake for: {addr} failed with: {e}"), + } } self.notify_client(token, Ok(ClientResponse::Announce(res.to_owned()))); @@ -216,14 +253,23 @@ where /// Process an existing request, either re requesting a connection id or sending the actual request again. /// /// If this call is the result of a timeout, that will decide whether to cancel the request or not. 
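
notify_client and the announce branch above push into the handshaker sink with futures::executor::block_on, which is acceptable here because the dispatcher runs on its own umio thread rather than inside an async runtime. The shape of that call in isolation, using a plain futures channel as a stand-in for the handshaker (names are illustrative):

    use futures::channel::mpsc;
    use futures::executor::block_on;
    use futures::SinkExt as _;

    fn notify_blocking(sink: &mut mpsc::Sender<String>, message: String) {
        // Blocks the current (non-async) thread until the sink accepts the item.
        match block_on(sink.send(message)) {
            Ok(()) => tracing::debug!("message delivered"),
            Err(e) => tracing::error!("delivery failed: {e}"),
        }
    }
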
+ #[instrument(skip(self, provider))] fn process_request(&mut self, provider: &mut Provider<'_, ClientDispatcher>, token: ClientToken, timed_out: bool) { + tracing::debug!("processing request"); + let Some(mut conn_timer) = self.active_requests.remove(&token) else { + tracing::error!(?token, "token not in active requests"); + return; - }; // TODO: Add logging + }; // Resolve the duration of the current timeout to use let Some(next_timeout) = conn_timer.current_timeout(timed_out) else { - self.notify_client(token, Err(ClientError::MaxTimeout)); + let err = ClientError::MaxTimeout; + + tracing::error!("error reached timeout: {err}"); + + self.notify_client(token, Err(err)); return; }; @@ -267,13 +313,16 @@ where // Try to write the request out to the server let mut write_success = false; provider.outgoing(|bytes| { - let mut writer = Cursor::new(bytes); - write_success = tracker_request.write_bytes(&mut writer).is_ok(); - - if write_success { - Some((writer.position().try_into().unwrap(), addr)) - } else { - None + let mut writer = std::io::Cursor::new(bytes); + match tracker_request.write_bytes(&mut writer) { + Ok(()) => { + write_success = true; + Some((writer.position().try_into().unwrap(), addr)) + } + Err(e) => { + tracing::error!("failed to write out the tracker request with error: {e}"); + None + } } }); @@ -287,28 +336,40 @@ where self.active_requests.insert(token, conn_timer); } else { - self.notify_client(token, Err(ClientError::MaxLength)); + let err = ClientError::MaxLength; + tracing::warn!("notifying client with error: {err}"); + + self.notify_client(token, Err(err)); } } } impl Dispatcher for ClientDispatcher where - H: Sink + DiscoveryInfo, - H::SinkItem: From>, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { type Timeout = DispatchTimeout; type Message = DispatchMessage; + #[instrument(skip(self, provider))] fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { - let IResult::Done(_, response) = TrackerResponse::from_bytes(message) else { - return; // TODO: Add Logging - }; + let () = match TrackerResponse::from_bytes(message) { + IResult::Ok((_, response)) => { + tracing::debug!("received an incoming response: {response:?}"); - self.recv_response(&mut provider, addr, &response); + self.recv_response(&mut provider, &response, addr); + } + Err(e) => { + tracing::error!("received an incoming error message: {e}"); + } + }; } + #[instrument(skip(self, provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { + tracing::debug!("received notify"); + match message { DispatchMessage::Request(addr, token, req_type) => { self.send_request(&mut provider, addr, token, req_type); @@ -318,7 +379,10 @@ where } } + #[instrument(skip(self, provider))] fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: DispatchTimeout) { + tracing::debug!("received timeout"); + match timeout { DispatchTimeout::Connect(token) => self.process_request(&mut provider, token, true), DispatchTimeout::CleanUp => { @@ -343,6 +407,22 @@ struct ConnectTimer { timeout_id: Option, } +impl std::fmt::Debug for ConnectTimer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let timeout_id = match self.timeout_id { + Some(_) => "Some(_)", + None => "None", + }; + + f.debug_struct("ConnectTimer") + .field("addr", &self.addr) + .field("attempt", &self.attempt) + .field("request", &self.request) + .field("timeout_id", &timeout_id) + .finish() + } +} + impl 
ConnectTimer { /// Create a new `ConnectTimer`. pub fn new(addr: SocketAddr, request: ClientRequest) -> ConnectTimer { @@ -355,8 +435,13 @@ impl ConnectTimer { } /// Yields the current timeout value to use or None if the request should time out completely. + #[instrument(skip(), ret)] pub fn current_timeout(&mut self, timed_out: bool) -> Option { + tracing::debug!("getting current timeout"); + if self.attempt == MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS { + tracing::warn!("request has reached maximum timeout attempts: {MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS}"); + None } else { if timed_out { @@ -378,21 +463,27 @@ impl ConnectTimer { } /// Yields the message parameters for the current connection. + #[instrument(skip(), ret)] pub fn message_params(&self) -> (SocketAddr, &ClientRequest) { + tracing::debug!("getting message parameters"); + (self.addr, &self.request) } } /// Calculates the timeout for the request given the attempt count. +#[instrument(skip(), ret)] fn calculate_message_timeout_millis(attempt: u64) -> u64 { - #[allow(clippy::cast_possible_truncation)] - let attempt = attempt as u32; + tracing::debug!("calculation message timeout in milliseconds"); + + let attempt = attempt.try_into().unwrap_or(u32::MAX); (15 * 2u64.pow(attempt)) * 1000 } // ----------------------------------------------------------------------------// /// Cache for storing connection ids associated with a specific server address. +#[derive(Debug)] struct ConnectIdCache { cache: HashMap)>, } @@ -403,10 +494,17 @@ impl ConnectIdCache { ConnectIdCache { cache: HashMap::new() } } - /// Get an un expired connection id for the given addr. + /// Get an active connection id for the given addr. + #[instrument(skip(self), ret)] fn get(&mut self, addr: SocketAddr) -> Option { + tracing::debug!("getting connection id"); + match self.cache.entry(addr) { - Entry::Vacant(_) => None, + Entry::Vacant(_) => { + tracing::warn!("connection id for {addr} not in cache"); + + None + } Entry::Occupied(occ) => { let curr_time = Utc::now(); let prev_time = occ.get().1; @@ -414,6 +512,8 @@ impl ConnectIdCache { if is_expired(curr_time, prev_time) { occ.remove(); + tracing::warn!("connection id was already expired"); + None } else { Some(occ.get().0) @@ -423,14 +523,20 @@ impl ConnectIdCache { } /// Put an un expired connection id into cache for the given addr. + #[instrument(skip(self))] fn put(&mut self, addr: SocketAddr, connect_id: u64) { + tracing::debug!("setting expired connection id"); + let curr_time = Utc::now(); self.cache.insert(addr, (connect_id, curr_time)); } /// Removes all entries that have expired. + #[instrument(skip(self))] fn clean_expired(&mut self) { + tracing::debug!("cleaning expired connection id(s)"); + let curr_time = Utc::now(); let mut curr_index = 0; @@ -447,7 +553,10 @@ impl ConnectIdCache { } /// Returns true if the connect id received at `prev_time` is now expired. +#[instrument(skip(), ret)] fn is_expired(curr_time: DateTime, prev_time: DateTime) -> bool { + tracing::debug!("checking if a previous time is now expired"); + let valid_duration = Duration::milliseconds(CONNECTION_ID_VALID_DURATION_MILLIS); let difference = prev_time.signed_duration_since(curr_time); diff --git a/packages/utracker/src/client/error.rs b/packages/utracker/src/client/error.rs index 69af8cb65..dc60ab898 100644 --- a/packages/utracker/src/client/error.rs +++ b/packages/utracker/src/client/error.rs @@ -1,3 +1,5 @@ +use thiserror::Error; + use crate::error::ErrorResponse; /// Result type for a `ClientRequest`. 
@@ -5,18 +7,23 @@ pub type ClientResult = Result; /// Errors occurring as the result of a `ClientRequest`. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq, Eq)] pub enum ClientError { - /// Request timeout reached. + #[error("Request timeout reached")] MaxTimeout, - /// Request length exceeded the packet length. + + #[error("Request length exceeded the packet length")] MaxLength, - /// Client shut down the request client. + + #[error("Client shut down the request client")] ClientShutdown, - /// Server sent us an invalid message. + + #[error("Server sent us an invalid message")] ServerError, - /// Requested to send from IPv4 to IPv6 or vice versa. + + #[error("Requested to send from IPv4 to IPv6 or vice versa")] IPVersionMismatch, - /// Server returned an error message. - ServerMessage(ErrorResponse<'static>), + + #[error("Server returned an error message : {0}")] + ServerMessage(#[from] ErrorResponse<'static>), } diff --git a/packages/utracker/src/client/mod.rs b/packages/utracker/src/client/mod.rs index f6690e746..fb0449d85 100644 --- a/packages/utracker/src/client/mod.rs +++ b/packages/utracker/src/client/mod.rs @@ -1,4 +1,3 @@ -use std::io; use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -6,6 +5,7 @@ use std::sync::Arc; use futures::future::Either; use futures::sink::Sink; use handshake::{DiscoveryInfo, InitiateMessage}; +use tracing::instrument; use umio::external::Sender; use util::bt::InfoHash; use util::trans::{LocallyShuffledIds, TransactionIds}; @@ -21,6 +21,24 @@ pub mod error; /// Capacity of outstanding requests (assuming each request uses at most 1 timer at any time) const DEFAULT_CAPACITY: usize = 4096; +#[derive(Debug)] +pub enum HandshakerMessage { + InitiateMessage(InitiateMessage), + ClientMetadata(ClientMetadata), +} + +impl From for HandshakerMessage { + fn from(message: InitiateMessage) -> Self { + Self::InitiateMessage(message) + } +} + +impl From for HandshakerMessage { + fn from(metadata: ClientMetadata) -> Self { + Self::ClientMetadata(metadata) + } +} + /// Request made by the `TrackerClient`. #[allow(clippy::module_name_repetitions)] #[derive(Debug)] @@ -108,19 +126,6 @@ pub struct TrackerClient { } impl TrackerClient { - /// Create a new `TrackerClient`. - /// - /// # Errors - /// - /// It would return a IO error if unable build a new client. - pub fn new(bind: SocketAddr, handshaker: H) -> io::Result - where - H: Sink + DiscoveryInfo + Send + 'static, - H::SinkItem: From>, - { - TrackerClient::with_capacity(bind, handshaker, DEFAULT_CAPACITY) - } - /// Create a new `TrackerClient` with the given message capacity. /// /// Panics if capacity == `usize::max_value`(). @@ -132,11 +137,24 @@ impl TrackerClient { /// # Panics /// /// It would panic if the desired capacity is too large. 
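
ClientError above now gets its Display and From impls from thiserror instead of relying on doc comments per variant. The mechanism in miniature (FetchError and the config path are made-up example names):

    use thiserror::Error;

    #[derive(Error, Debug)]
    pub enum FetchError {
        // #[error] supplies the Display text for the variant.
        #[error("request timed out")]
        Timeout,

        // #[from] also generates From<std::io::Error>, so `?` converts automatically.
        #[error("transport failed: {0}")]
        Transport(#[from] std::io::Error),
    }

    fn read_config() -> Result<Vec<u8>, FetchError> {
        Ok(std::fs::read("config.bin")?)
    }
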
- pub fn with_capacity(bind: SocketAddr, handshaker: H, capacity: usize) -> io::Result + #[instrument(skip())] + pub fn new(bind: SocketAddr, handshaker: H, capacity_or_default: Option) -> std::io::Result where - H: Sink + DiscoveryInfo + Send + 'static, - H::SinkItem: From>, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { + tracing::info!("running client"); + + let capacity = if let Some(capacity) = capacity_or_default { + tracing::debug!("with capacity {capacity}"); + + capacity + } else { + tracing::debug!("with default capacity: {DEFAULT_CAPACITY}"); + + DEFAULT_CAPACITY + }; + // Need channel capacity to be 1 more in case channel is saturated and client // is dropped so shutdown message can get through in the worst case let (chan_capacity, would_overflow) = capacity.overflowing_add(1); @@ -147,8 +165,10 @@ impl TrackerClient { // Limit the capacity of messages (channel capacity - 1) let limiter = RequestLimiter::new(capacity); - dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone()).map(|chan| TrackerClient { - send: chan, + let dispatcher = dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + + Ok(TrackerClient { + send: dispatcher, limiter, generator: TokenGenerator::new(), }) @@ -161,7 +181,10 @@ impl TrackerClient { /// # Panics /// /// It would panic if unable to send request message. + #[instrument(skip(self))] pub fn request(&mut self, addr: SocketAddr, request: ClientRequest) -> Option { + tracing::debug!("requesting"); + if self.limiter.can_initiate() { let token = self.generator.generate(); self.send @@ -170,6 +193,8 @@ impl TrackerClient { Some(token) } else { + tracing::debug!("initiation was limited"); + None } } @@ -212,7 +237,7 @@ impl TokenGenerator { // ----------------------------------------------------------------------------// /// Limits requests based on the current number of outstanding requests. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct RequestLimiter { active: Arc, capacity: usize, @@ -236,11 +261,11 @@ impl RequestLimiter { /// /// It is invalid to not make the request after this returns true. pub fn can_initiate(&self) -> bool { - let current_active_requests = self.active.fetch_add(1, Ordering::AcqRel); + let current_active_requests = self.active.fetch_add(1, Ordering::AcqRel) + 1; // If the number of requests stored previously was less than the capacity, // then the add is considered good and a request can (SHOULD) be made. - if current_active_requests < self.capacity { + if current_active_requests <= self.capacity { true } else { // Act as if the request just completed (decrement back down) diff --git a/packages/utracker/src/contact.rs b/packages/utracker/src/contact.rs index 129123e13..9ee87b290 100644 --- a/packages/utracker/src/contact.rs +++ b/packages/utracker/src/contact.rs @@ -1,7 +1,7 @@ //! Messaging primitives for contact information. use std::borrow::Cow; -use std::io::{self, Write}; +use std::io::Write as _; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use nom::{IResult, Needed}; @@ -21,22 +21,26 @@ pub enum CompactPeers<'a> { impl<'a> CompactPeers<'a> { /// Construct a `CompactPeers::V4` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. 
pub fn from_bytes_v4(bytes: &'a [u8]) -> IResult<&'a [u8], CompactPeers<'a>> { match CompactPeersV4::from_bytes(bytes) { - IResult::Done(i, peers) => IResult::Done(i, CompactPeers::V4(peers)), - IResult::Error(err) => IResult::Error(err), - IResult::Incomplete(need) => IResult::Incomplete(need), + IResult::Ok((i, peers)) => IResult::Ok((i, CompactPeers::V4(peers))), + IResult::Err(err) => IResult::Err(err), } } /// Construct a `CompactPeers::V6` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes_v6(bytes: &'a [u8]) -> IResult<&'a [u8], CompactPeers<'a>> { match CompactPeersV6::from_bytes(bytes) { - IResult::Done(i, peers) => IResult::Done(i, CompactPeers::V6(peers)), - IResult::Error(err) => IResult::Error(err), - IResult::Incomplete(need) => IResult::Incomplete(need), + IResult::Ok((i, peers)) => IResult::Ok((i, CompactPeers::V6(peers))), + IResult::Err(err) => IResult::Err(err), } } @@ -45,9 +49,9 @@ impl<'a> CompactPeers<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, writer: W) -> io::Result<()> + pub fn write_bytes(&self, writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { match self { CompactPeers::V4(peers) => peers.write_bytes(writer), @@ -127,7 +131,10 @@ impl<'a> CompactPeersV4<'a> { } /// Construct a `CompactPeersV4` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], CompactPeersV4<'a>> { parse_peers_v4(bytes) } @@ -137,9 +144,9 @@ impl<'a> CompactPeersV4<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(&self.peers)?; @@ -173,16 +180,16 @@ fn parse_peers_v4(bytes: &[u8]) -> IResult<&[u8], CompactPeersV4<'_>> { let remainder_bytes = bytes.len() % SOCKET_ADDR_V4_BYTES; if remainder_bytes != 0 { - IResult::Incomplete(Needed::Size(SOCKET_ADDR_V4_BYTES - remainder_bytes)) + Err(nom::Err::Incomplete(nom::Needed::new(SOCKET_ADDR_V4_BYTES - remainder_bytes))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; - IResult::Done( + IResult::Ok(( end_of_bytes, CompactPeersV4 { peers: Cow::Borrowed(bytes), }, - ) + )) } } @@ -246,7 +253,10 @@ impl<'a> CompactPeersV6<'a> { } /// Construct a `CompactPeersV6` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], CompactPeersV6<'a>> { parse_peers_v6(bytes) } @@ -256,9 +266,9 @@ impl<'a> CompactPeersV6<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. 
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(&self.peers)?; @@ -292,16 +302,16 @@ fn parse_peers_v6(bytes: &[u8]) -> IResult<&[u8], CompactPeersV6<'_>> { let remainder_bytes = bytes.len() % SOCKET_ADDR_V6_BYTES; if remainder_bytes != 0 { - IResult::Incomplete(Needed::Size(SOCKET_ADDR_V6_BYTES - remainder_bytes)) + Err(nom::Err::Incomplete(nom::Needed::new(SOCKET_ADDR_V6_BYTES - remainder_bytes))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; - IResult::Done( + IResult::Ok(( end_of_bytes, CompactPeersV6 { peers: Cow::Borrowed(bytes), }, - ) + )) } } @@ -377,7 +387,7 @@ mod tests { let received = CompactPeersV4::from_bytes(&bytes); let expected = CompactPeersV4::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -389,7 +399,7 @@ mod tests { expected.insert("127.0.0.1:15".parse().unwrap()); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -402,7 +412,7 @@ mod tests { expected.insert("127.0.0.1:15".parse().unwrap()); expected.insert("127.0.0.1:256".parse().unwrap()); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -468,7 +478,7 @@ mod tests { let received = CompactPeersV6::from_bytes(&bytes); let expected = CompactPeersV6::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -482,7 +492,7 @@ mod tests { expected.insert("[ADBB:234A:55BD:FF34:3D3A::234A:55BD]:256".parse().unwrap()); // cspell:disable-line - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -498,7 +508,7 @@ mod tests { expected.insert("[ADBB:234A:55BD:FF34:3D3A::234A:55BD]:256".parse().unwrap()); // cspell:disable-line expected.insert("[DABB:234A:55BD:FF34:3D3A::234A:55BD]:512".parse().unwrap()); // cspell:disable-line - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] diff --git a/packages/utracker/src/error.rs b/packages/utracker/src/error.rs index 93ab3e4d3..87fb5c182 100644 --- a/packages/utracker/src/error.rs +++ b/packages/utracker/src/error.rs @@ -1,17 +1,28 @@ //! Messaging primitives for server errors. use std::borrow::Cow; -use std::io::{self, Write}; +use std::io::Write as _; -use nom::{call, error_position, map, map_res, take, take_str, IResult}; +use nom::bytes::complete::take; +use nom::character::complete::not_line_ending; +use nom::combinator::map_res; +use nom::sequence::terminated; +use nom::IResult; +use thiserror::Error; /// Error reported by the server and sent to the client. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq, Eq)] pub struct ErrorResponse<'a> { message: Cow<'a, str>, } +impl<'a> std::fmt::Display for ErrorResponse<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Server Error: {}", self.message) + } +} + impl<'a> ErrorResponse<'a> { /// Create a new `ErrorResponse`. #[must_use] @@ -22,8 +33,13 @@ impl<'a> ErrorResponse<'a> { } /// Construct an `ErrorResponse` from the given bytes. 
+ /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], ErrorResponse<'a>> { - map!(bytes, take_str!(bytes.len()), ErrorResponse::new) + let (remaining, message) = map_res(terminated(not_line_ending, take(0usize)), std::str::from_utf8)(bytes)?; + Ok((remaining, ErrorResponse::new(message))) } /// Write the `ErrorResponse` to the given writer. @@ -31,9 +47,9 @@ impl<'a> ErrorResponse<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(self.message.as_bytes())?; diff --git a/packages/utracker/src/lib.rs b/packages/utracker/src/lib.rs index 8e4435547..2985a7a8a 100644 --- a/packages/utracker/src/lib.rs +++ b/packages/utracker/src/lib.rs @@ -28,6 +28,6 @@ mod server; pub use util::bt::{InfoHash, PeerId}; pub use crate::client::error::{ClientError, ClientResult}; -pub use crate::client::{ClientMetadata, ClientRequest, ClientResponse, ClientToken, TrackerClient}; +pub use crate::client::{ClientMetadata, ClientRequest, ClientResponse, ClientToken, HandshakerMessage, TrackerClient}; pub use crate::server::handler::{ServerHandler, ServerResult}; pub use crate::server::TrackerServer; diff --git a/packages/utracker/src/option.rs b/packages/utracker/src/option.rs index ce637007c..bd0a4d01c 100644 --- a/packages/utracker/src/option.rs +++ b/packages/utracker/src/option.rs @@ -3,10 +3,17 @@ use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::io::{self, Write}; +use std::io::Write as _; use byteorder::WriteBytesExt; -use nom::{alt, be_u8, call, do_parse, eof, error_position, length_bytes, length_data, map, named, tag, take, IResult}; +use nom::branch::alt; +use nom::bytes::complete::{tag, take}; +use nom::combinator::{eof, map}; +use nom::multi::length_data; +use nom::number::complete::be_u8; +use nom::sequence::tuple; +use nom::IResult; +use tracing::instrument; const END_OF_OPTIONS_BYTE: u8 = 0x00; const NO_OPERATION_BYTE: u8 = 0x01; @@ -46,13 +53,15 @@ impl<'a> AnnounceOptions<'a> { } /// Parse a set of `AnnounceOptions` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], AnnounceOptions<'a>> { let mut raw_options = HashMap::new(); - map!(bytes, call!(parse_options, &mut raw_options), |_| { - AnnounceOptions { raw_options } - }) + let (remaining, _) = parse_options(bytes, &mut raw_options)?; + Ok((remaining, AnnounceOptions { raw_options })) } /// Write the `AnnounceOptions` to the given writer. @@ -63,11 +72,13 @@ impl<'a> AnnounceOptions<'a> { /// /// # Panics /// - /// It would panic if the chuck length is too large. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + /// It would panic if the chunk length is too large. 
+ #[instrument(skip(self, writer))] + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { + tracing::trace!("writing {} options", self.raw_options.len()); for (byte, content) in &self.raw_options { for content_chunk in content.chunks(u8::MAX as usize) { let content_chunk_len: u8 = content_chunk.len().try_into().unwrap(); @@ -80,10 +91,17 @@ impl<'a> AnnounceOptions<'a> { // If we can fit it in, include the option terminating byte, otherwise as per the // spec, we can leave it out since we are assuming this is the end of the packet. - // TODO: Allow unused when the compile flag is stabilized - writer.write_u8(END_OF_OPTIONS_BYTE); - - Ok(()) + match writer.write_u8(END_OF_OPTIONS_BYTE) { + Ok(()) => Ok(()), + Err(e) => { + if e.kind() == std::io::ErrorKind::WriteZero { + tracing::trace!("no space to write ending marker"); + Ok(()) + } else { + Err(e) + } + } + } } /// Search for and construct the given `AnnounceOption` from the current `AnnounceOptions`. @@ -109,10 +127,6 @@ impl<'a> AnnounceOptions<'a> { let mut bytes = vec![0u8; option.option_length()]; option.write_option(&mut bytes[..]); - // Unfortunately we cannot return the replaced value unless we modified the - // AnnounceOption::read_option method to accept a Cow and give it that because - // we cant guarantee that the buffer is not Cow::Owned at the moment and would be - // dropped (replaced) after being constructed. self.insert_bytes(O::option_byte(), bytes); } @@ -138,61 +152,50 @@ fn parse_options<'a>(bytes: &'a [u8], option_map: &mut HashMap let mut curr_bytes = bytes; let mut eof = false; - // Iteratively try all parsers until one succeeds and check whether the eof has been reached. - // Return early on incomplete or error. while !eof { - let parse_result = alt!( - curr_bytes, - parse_end_option | call!(parse_no_option) | call!(parse_user_option, option_map) - ); + let parse_result = alt((parse_end_option, parse_no_option, |input| { + parse_user_option(input, option_map) + }))(curr_bytes); match parse_result { - IResult::Done(new_bytes, found_eof) => { + Ok((new_bytes, found_eof)) => { eof = found_eof; - curr_bytes = new_bytes; } - some_error => { - return some_error; + Err(e) => { + return Err(e); } }; } - IResult::Done(curr_bytes, eof) + Ok((curr_bytes, eof)) } /// Parse an end of buffer or the end of option byte. -named!(parse_end_option<&[u8], bool>, map!(alt!( - eof!() | tag!([END_OF_OPTIONS_BYTE]) -), |_| true)); +fn parse_end_option(input: &[u8]) -> IResult<&[u8], bool> { + map(alt((eof, tag([END_OF_OPTIONS_BYTE]))), |_| true)(input) +} /// Parse a noop byte. -fn parse_no_option(bytes: &[u8]) -> IResult<&[u8], bool> { - map!(bytes, tag!([NO_OPERATION_BYTE]), |_| false) +fn parse_no_option(input: &[u8]) -> IResult<&[u8], bool> { + map(tag([NO_OPERATION_BYTE]), |_| false)(input) } /// Parse a user defined option. 
-fn parse_user_option<'a>(bytes: &'a [u8], option_map: &mut HashMap>) -> IResult<&'a [u8], bool> { - do_parse!(bytes, - option_byte: be_u8 >> - option_contents: length_bytes!(byte_usize) >> - ({ - match option_map.entry(option_byte) { - Entry::Occupied(mut occ) => { occ.get_mut().to_mut().extend_from_slice(option_contents); }, - Entry::Vacant(vac) => { vac.insert(Cow::Borrowed(option_contents)); } - }; - - false - }) - ) -} +fn parse_user_option<'a>(input: &'a [u8], option_map: &mut HashMap>) -> IResult<&'a [u8], bool> { + let (input, (option_byte, option_contents)) = tuple((be_u8, length_data(be_u8)))(input)?; -/// Parse a single byte as an unsigned pointer size. -named!(byte_usize<&[u8], usize>, map!( - be_u8, |b| b as usize -)); + match option_map.entry(option_byte) { + Entry::Occupied(mut occ) => { + occ.get_mut().to_mut().extend_from_slice(option_contents); + } + Entry::Vacant(vac) => { + vac.insert(Cow::Borrowed(option_contents)); + } + }; -// ----------------------------------------------------------------------------// + Ok((input, false)) +} /// Concatenated PATH and QUERY of a UDP tracker URL. #[allow(clippy::module_name_repetitions)] @@ -229,14 +232,36 @@ impl<'a> AnnounceOption<'a> for URLDataOption<'a> { #[cfg(test)] mod tests { - use std::io::Write; + + use std::io::Write as _; + use std::sync::Once; use nom::IResult; + use tracing::level_filters::LevelFilter; use super::{AnnounceOptions, URLDataOption}; + #[allow(dead_code)] + pub static INIT: Once = Once::new(); + + #[allow(dead_code)] + pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); + } + #[test] fn positive_write_eof_option() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let mut received = []; let options = AnnounceOptions::new(); @@ -303,7 +328,7 @@ mod tests { let received = AnnounceOptions::from_bytes(&bytes); let expected = AnnounceOptions::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -313,7 +338,7 @@ mod tests { let received = AnnounceOptions::from_bytes(&bytes); let expected = AnnounceOptions::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -323,7 +348,7 @@ mod tests { let received = AnnounceOptions::from_bytes(&bytes); let expected = AnnounceOptions::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -337,7 +362,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -351,7 +376,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -365,7 +390,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -389,7 +414,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes); expected.insert(&url_data); - 
assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -407,7 +432,7 @@ mod tests { let url_data = URLDataOption::new(&bytes[2..]); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -439,7 +464,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes[..]); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -473,7 +498,7 @@ mod tests { let url_data = URLDataOption::new(&url_data_bytes[..]); expected.insert(&url_data); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -482,7 +507,7 @@ mod tests { let received = AnnounceOptions::from_bytes(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } #[test] @@ -491,6 +516,6 @@ mod tests { let received = AnnounceOptions::from_bytes(&bytes); - assert!(received.is_incomplete()); + assert!(received.is_err()); } } diff --git a/packages/utracker/src/request.rs b/packages/utracker/src/request.rs index 8bfe07fd1..67180dd73 100644 --- a/packages/utracker/src/request.rs +++ b/packages/utracker/src/request.rs @@ -1,9 +1,13 @@ //! Messaging primitives for requests. -use std::io::{self, Write}; +use std::io::Write as _; use byteorder::{BigEndian, WriteBytesExt}; -use nom::{be_u32, be_u64, call, error_node_position, error_position, map, switch, tuple, tuple_parser, value, IResult}; +use nom::bytes::complete::take; +use nom::combinator::{map, map_res}; +use nom::number::complete::{be_u32, be_u64}; +use nom::sequence::tuple; +use nom::IResult; use crate::announce::AnnounceRequest; use crate::scrape::ScrapeRequest; @@ -16,6 +20,7 @@ pub const CONNECT_ID_PROTOCOL_ID: u64 = 0x0417_2710_1980; /// Enumerates all types of requests that can be made to a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub enum RequestType<'a> { Connect, Announce(AnnounceRequest<'a>), @@ -36,6 +41,7 @@ impl<'a> RequestType<'a> { /// `TrackerRequest` which encapsulates any request sent to a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerRequest<'a> { // Both the connection id and transaction id are technically not unsigned according // to the spec, but since they are just bits we will keep them as unsigned since it @@ -57,7 +63,10 @@ impl<'a> TrackerRequest<'a> { } /// Create a new `TrackerRequest` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], TrackerRequest<'a>> { parse_request(bytes) } @@ -67,9 +76,9 @@ impl<'a> TrackerRequest<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. 
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_u64::(self.connection_id())?; @@ -133,18 +142,37 @@ impl<'a> TrackerRequest<'a> { } fn parse_request(bytes: &[u8]) -> IResult<&[u8], TrackerRequest<'_>> { - switch!(bytes, tuple!(be_u64, be_u32, be_u32), - (CONNECT_ID_PROTOCOL_ID, crate::CONNECT_ACTION_ID, tid) => value!( - TrackerRequest::new(CONNECT_ID_PROTOCOL_ID, tid, RequestType::Connect) - ) | - (cid, crate::ANNOUNCE_IPV4_ACTION_ID, tid) => map!(call!(AnnounceRequest::from_bytes_v4), |ann_req| { - TrackerRequest::new(cid, tid, RequestType::Announce(ann_req)) - }) | - (cid, crate::SCRAPE_ACTION_ID, tid) => map!(call!(ScrapeRequest::from_bytes), |scr_req| { - TrackerRequest::new(cid, tid, RequestType::Scrape(scr_req)) - }) | - (cid, crate::ANNOUNCE_IPV6_ACTION_ID, tid) => map!(call!(AnnounceRequest::from_bytes_v6), |ann_req| { - TrackerRequest::new(cid, tid, RequestType::Announce(ann_req)) - }) - ) + let (remaining, (connection_id, action_id, transaction_id)) = tuple((be_u64, be_u32, be_u32))(bytes)?; + + match (connection_id, action_id) { + (CONNECT_ID_PROTOCOL_ID, crate::CONNECT_ACTION_ID) => Ok(( + remaining, + TrackerRequest::new(CONNECT_ID_PROTOCOL_ID, transaction_id, RequestType::Connect), + )), + (cid, crate::ANNOUNCE_IPV4_ACTION_ID) => { + let (remaining, ann_req) = AnnounceRequest::from_bytes_v4(remaining)?; + Ok(( + remaining, + TrackerRequest::new(cid, transaction_id, RequestType::Announce(ann_req)), + )) + } + (cid, crate::SCRAPE_ACTION_ID) => { + let (remaining, scr_req) = ScrapeRequest::from_bytes(remaining)?; + Ok(( + remaining, + TrackerRequest::new(cid, transaction_id, RequestType::Scrape(scr_req)), + )) + } + (cid, crate::ANNOUNCE_IPV6_ACTION_ID) => { + let (remaining, ann_req) = AnnounceRequest::from_bytes_v6(remaining)?; + Ok(( + remaining, + TrackerRequest::new(cid, transaction_id, RequestType::Announce(ann_req)), + )) + } + _ => Err(nom::Err::Error(nom::error::Error::new( + remaining, + nom::error::ErrorKind::Switch, + ))), + } } diff --git a/packages/utracker/src/response.rs b/packages/utracker/src/response.rs index 88084b83c..51742a290 100644 --- a/packages/utracker/src/response.rs +++ b/packages/utracker/src/response.rs @@ -1,9 +1,12 @@ //! Messaging primitives for responses. -use std::io::{self, Write}; +use std::io::Write as _; use byteorder::{BigEndian, WriteBytesExt}; -use nom::{be_u32, be_u64, call, error_node_position, error_position, map, switch, tuple, tuple_parser, IResult}; +use nom::combinator::map; +use nom::number::complete::{be_u32, be_u64}; +use nom::sequence::tuple; +use nom::IResult; use crate::announce::AnnounceResponse; use crate::contact::CompactPeers; @@ -15,6 +18,7 @@ const ERROR_ACTION_ID: u32 = 3; /// Enumerates all types of responses that can be received from a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub enum ResponseType<'a> { Connect(u64), Announce(AnnounceResponse<'a>), @@ -37,6 +41,7 @@ impl<'a> ResponseType<'a> { /// `TrackerResponse` which encapsulates any response sent from a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerResponse<'a> { transaction_id: u32, response_type: ResponseType<'a>, @@ -53,7 +58,10 @@ impl<'a> TrackerResponse<'a> { } /// Create a new `TrackerResponse` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. 
pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], TrackerResponse<'a>> { parse_response(bytes) } @@ -63,9 +71,9 @@ impl<'a> TrackerResponse<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { match self.response_type() { &ResponseType::Connect(id) => { @@ -125,19 +133,41 @@ impl<'a> TrackerResponse<'a> { } fn parse_response(bytes: &[u8]) -> IResult<&[u8], TrackerResponse<'_>> { - switch!(bytes, tuple!(be_u32, be_u32), - (crate::CONNECT_ACTION_ID, tid) => map!(be_u64, |cid| TrackerResponse::new(tid, ResponseType::Connect(cid)) ) | - (crate::ANNOUNCE_IPV4_ACTION_ID, tid) => map!(call!(AnnounceResponse::from_bytes_v4), |ann_res| { - TrackerResponse::new(tid, ResponseType::Announce(ann_res)) - }) | - (crate::SCRAPE_ACTION_ID, tid) => map!(call!(ScrapeResponse::from_bytes), |scr_res| { - TrackerResponse::new(tid, ResponseType::Scrape(scr_res)) - }) | - (ERROR_ACTION_ID, tid) => map!(call!(ErrorResponse::from_bytes), |err_res| { - TrackerResponse::new(tid, ResponseType::Error(err_res)) - }) | - (crate::ANNOUNCE_IPV6_ACTION_ID, tid) => map!(call!(AnnounceResponse::from_bytes_v6), |ann_req| { - TrackerResponse::new(tid, ResponseType::Announce(ann_req)) - }) - ) + let (remaining, (action_id, transaction_id)) = tuple((be_u32, be_u32))(bytes)?; + + match action_id { + crate::CONNECT_ACTION_ID => { + let (remaining, connection_id) = be_u64(remaining)?; + Ok(( + remaining, + TrackerResponse::new(transaction_id, ResponseType::Connect(connection_id)), + )) + } + crate::ANNOUNCE_IPV4_ACTION_ID => { + let (remaining, ann_res) = AnnounceResponse::from_bytes_v4(remaining)?; + Ok(( + remaining, + TrackerResponse::new(transaction_id, ResponseType::Announce(ann_res)), + )) + } + crate::SCRAPE_ACTION_ID => { + let (remaining, scr_res) = ScrapeResponse::from_bytes(remaining)?; + Ok((remaining, TrackerResponse::new(transaction_id, ResponseType::Scrape(scr_res)))) + } + ERROR_ACTION_ID => { + let (remaining, err_res) = ErrorResponse::from_bytes(remaining)?; + Ok((remaining, TrackerResponse::new(transaction_id, ResponseType::Error(err_res)))) + } + crate::ANNOUNCE_IPV6_ACTION_ID => { + let (remaining, ann_res) = AnnounceResponse::from_bytes_v6(remaining)?; + Ok(( + remaining, + TrackerResponse::new(transaction_id, ResponseType::Announce(ann_res)), + )) + } + _ => Err(nom::Err::Error(nom::error::Error::new( + remaining, + nom::error::ErrorKind::Switch, + ))), + } } diff --git a/packages/utracker/src/scrape.rs b/packages/utracker/src/scrape.rs index 855c0a32c..8e0bc64e6 100644 --- a/packages/utracker/src/scrape.rs +++ b/packages/utracker/src/scrape.rs @@ -1,9 +1,14 @@ //! Messaging primitives for scraping. 
use std::borrow::Cow; -use std::io::{self, Write}; - -use nom::{be_i32, call, do_parse, IResult, Needed}; +use std::io::Write as _; +use std::num::NonZero; + +use nom::bytes::complete::take; +use nom::combinator::map_res; +use nom::number::complete::be_i32; +use nom::sequence::tuple; +use nom::{IResult, Needed}; use util::bt::{self, InfoHash}; use util::convert; @@ -54,12 +59,8 @@ impl ScrapeStats { } fn parse_stats(bytes: &[u8]) -> IResult<&[u8], ScrapeStats> { - do_parse!(bytes, - seeders: be_i32 >> - downloaded: be_i32 >> - leechers: be_i32 >> - (ScrapeStats::new(seeders, downloaded, leechers)) - ) + let (remaining, (seeders, downloaded, leechers)) = tuple((be_i32, be_i32, be_i32))(bytes)?; + Ok((remaining, ScrapeStats::new(seeders, downloaded, leechers))) } // ----------------------------------------------------------------------------// @@ -81,7 +82,10 @@ impl<'a> ScrapeRequest<'a> { } /// Construct a `ScrapeRequest` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], ScrapeRequest<'a>> { parse_request(bytes) } @@ -93,9 +97,9 @@ impl<'a> ScrapeRequest<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(&self.hashes) } @@ -124,19 +128,21 @@ impl<'a> ScrapeRequest<'a> { } fn parse_request(bytes: &[u8]) -> IResult<&[u8], ScrapeRequest<'_>> { - let remainder_bytes = bytes.len() % bt::INFO_HASH_LEN; + let remainder_bytes = NonZero::new(bytes.len() % bt::INFO_HASH_LEN); + + let needed = remainder_bytes.and_then(|rem| bt::INFO_HASH_LEN.checked_sub(rem.into()).and_then(NonZero::new)); - if remainder_bytes != 0 { - IResult::Incomplete(Needed::Size(bt::INFO_HASH_LEN - remainder_bytes)) + if let Some(needed) = needed { + Err(nom::Err::Incomplete(Needed::Size(needed))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; - IResult::Done( + Ok(( end_of_bytes, ScrapeRequest { hashes: Cow::Borrowed(bytes), }, - ) + )) } } @@ -159,7 +165,10 @@ impl<'a> ScrapeResponse<'a> { } /// Construct a `ScrapeResponse` from the given bytes. - #[must_use] + /// + /// # Errors + /// + /// It will return an error when unable to parse the bytes. pub fn from_bytes(bytes: &'a [u8]) -> IResult<&'a [u8], ScrapeResponse<'a>> { parse_response(bytes) } @@ -171,9 +180,9 @@ impl<'a> ScrapeResponse<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. 
- pub fn write_bytes(&self, mut writer: W) -> io::Result<()> + pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where - W: Write, + W: std::io::Write, { writer.write_all(&self.stats) } @@ -215,19 +224,21 @@ impl<'a> ScrapeResponse<'a> { } fn parse_response(bytes: &[u8]) -> IResult<&[u8], ScrapeResponse<'_>> { - let remainder_bytes = bytes.len() % SCRAPE_STATS_BYTES; + let remainder_bytes = NonZero::new(bytes.len() % SCRAPE_STATS_BYTES); - if remainder_bytes != 0 { - IResult::Incomplete(Needed::Size(SCRAPE_STATS_BYTES - remainder_bytes)) + let needed = remainder_bytes.and_then(|rem| SCRAPE_STATS_BYTES.checked_sub(rem.into()).and_then(NonZero::new)); + + if let Some(needed) = needed { + Err(nom::Err::Incomplete(Needed::Size(needed))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; - IResult::Done( + Ok(( end_of_bytes, ScrapeResponse { stats: Cow::Borrowed(bytes), }, - ) + )) } } @@ -299,7 +310,7 @@ impl<'a> Iterator for ScrapeResponseIter<'a> { self.offset = end; match ScrapeStats::from_bytes(&self.stats[start..end]) { - IResult::Done(_, stats) => Some(stats), + Ok((_, stats)) => Some(stats), _ => panic!("Bug In ScrapeResponseIter Caused ScrapeStats Parsing To Fail..."), } } @@ -420,13 +431,13 @@ mod tests { #[test] fn positive_parse_request_empty() { - let hash_one = []; + let hash_none = []; - let received = ScrapeRequest::from_bytes(&hash_one); + let received = ScrapeRequest::from_bytes(&hash_none).unwrap(); let expected = ScrapeRequest::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, (&b""[..], expected)); } #[test] @@ -438,7 +449,7 @@ mod tests { let mut expected = ScrapeRequest::new(); expected.insert(hash_one.into()); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -456,18 +467,18 @@ mod tests { expected.insert(hash_one.into()); expected.insert(hash_two.into()); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] fn positive_parse_response_empty() { let stats_bytes = []; - let received = ScrapeResponse::from_bytes(&stats_bytes); + let received = ScrapeResponse::from_bytes(&stats_bytes).unwrap(); let expected = ScrapeResponse::new(); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, (&b""[..], expected)); } #[test] @@ -479,7 +490,7 @@ mod tests { let mut expected = ScrapeResponse::new(); expected.insert(ScrapeStats::new(255, 256, 512)); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } #[test] @@ -492,6 +503,6 @@ mod tests { expected.insert(ScrapeStats::new(255, 256, 512)); expected.insert(ScrapeStats::new(1, 2, 3)); - assert_eq!(received, IResult::Done(&b""[..], expected)); + assert_eq!(received, IResult::Ok((&b""[..], expected))); } } diff --git a/packages/utracker/src/server/dispatcher.rs b/packages/utracker/src/server/dispatcher.rs index 1bb216456..1d413658e 100644 --- a/packages/utracker/src/server/dispatcher.rs +++ b/packages/utracker/src/server/dispatcher.rs @@ -1,8 +1,7 @@ -use std::io::{self, Cursor}; use std::net::SocketAddr; -use std::thread; use nom::IResult; +use tracing::instrument; use umio::external::Sender; use umio::{Dispatcher, ELoopBuilder, Provider}; @@ -16,16 +15,20 @@ use crate::server::handler::ServerHandler; const EXPECTED_PACKET_LENGTH: usize = 1500; /// Internal dispatch message for servers. 
+#[derive(Debug)]
 pub enum DispatchMessage {
     Shutdown,
 }
 /// Create a new background dispatcher to service requests.
 #[allow(clippy::module_name_repetitions)]
-pub fn create_dispatcher<H>(bind: SocketAddr, handler: H) -> io::Result<Sender<DispatchMessage>>
+#[instrument(skip())]
+pub fn create_dispatcher<H>(bind: SocketAddr, handler: H) -> std::io::Result<Sender<DispatchMessage>>
 where
-    H: ServerHandler + 'static,
+    H: ServerHandler + std::fmt::Debug + 'static,
 {
+    tracing::debug!("create dispatcher");
+
     let builder = ELoopBuilder::new()
         .channel_capacity(1)
         .timer_capacity(0)
@@ -37,7 +40,7 @@ where
     let dispatch = ServerDispatcher::new(handler);
-    thread::spawn(move || {
+    std::thread::spawn(move || {
         eloop.run(dispatch).expect("bip_utracker: ELoop Shutdown Unexpectedly...");
     });
@@ -47,29 +50,36 @@ where
 // ----------------------------------------------------------------------------//
 /// Dispatcher that executes requests asynchronously.
+#[derive(Debug)]
 struct ServerDispatcher<H>
 where
-    H: ServerHandler,
+    H: ServerHandler + std::fmt::Debug,
 {
     handler: H,
 }
 impl<H> ServerDispatcher<H>
 where
-    H: ServerHandler,
+    H: ServerHandler + std::fmt::Debug,
 {
     /// Create a new `ServerDispatcher`.
+    #[instrument(skip(), ret)]
     fn new(handler: H) -> ServerDispatcher<H> {
+        tracing::debug!("new");
+
         ServerDispatcher { handler }
     }
     /// Forward the request on to the appropriate handler method.
+    #[instrument(skip(self, provider))]
     fn process_request(
         &mut self,
         provider: &mut Provider<'_, ServerDispatcher<H>>,
         request: &TrackerRequest<'_>,
         addr: SocketAddr,
     ) {
+        tracing::debug!("process request");
+
         let conn_id = request.connection_id();
         let trans_id = request.transaction_id();
@@ -77,7 +87,12 @@ where
             &RequestType::Connect => {
                 if conn_id == request::CONNECT_ID_PROTOCOL_ID {
                     self.forward_connect(provider, trans_id, addr);
-                } // TODO: Add Logging
+                } else {
+                    tracing::warn!(
+                        "request was not `CONNECT_ID_PROTOCOL_ID`, i.e. {}, but {conn_id}.",
+                        request::CONNECT_ID_PROTOCOL_ID
+                    );
+                }
             }
             RequestType::Announce(req) => {
                 self.forward_announce(provider, trans_id, conn_id, req, addr);
@@ -89,19 +104,28 @@ where
     }
     /// Forward a connect request on to the appropriate handler method.
+    #[instrument(skip(self, provider))]
     fn forward_connect(&mut self, provider: &mut Provider<'_, ServerDispatcher<H>>, trans_id: u32, addr: SocketAddr) {
-        self.handler.connect(addr, |result| {
-            let response_type = match result {
-                Ok(conn_id) => ResponseType::Connect(conn_id),
-                Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
-            };
-            let response = TrackerResponse::new(trans_id, response_type);
-
-            write_response(provider, &response, addr);
-        });
+        tracing::debug!("forward connect");
+
+        let Some(attempt) = self.handler.connect(addr) else {
+            tracing::warn!("connect attempt canceled");
+
+            return;
+        };
+
+        let response_type = match attempt {
+            Ok(conn_id) => ResponseType::Connect(conn_id),
+            Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
+        };
+
+        let response = TrackerResponse::new(trans_id, response_type);
+
+        write_response(provider, &response, addr);
     }
     /// Forward an announce request on to the appropriate handler method.
+    #[instrument(skip(self, provider))]
     fn forward_announce(
         &mut self,
         provider: &mut Provider<'_, ServerDispatcher<H>>,
@@ -110,18 +134,25 @@ where
         request: &AnnounceRequest<'_>,
         addr: SocketAddr,
     ) {
-        self.handler.announce(addr, conn_id, request, |result| {
-            let response_type = match result {
-                Ok(response) => ResponseType::Announce(response),
-                Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
-            };
-            let response = TrackerResponse::new(trans_id, response_type);
-
-            write_response(provider, &response, addr);
-        });
+        tracing::debug!("forward announce");
+
+        let Some(attempt) = self.handler.announce(addr, conn_id, request) else {
+            tracing::warn!("announce attempt canceled");
+
+            return;
+        };
+
+        let response_type = match attempt {
+            Ok(response) => ResponseType::Announce(response),
+            Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
+        };
+        let response = TrackerResponse::new(trans_id, response_type);
+
+        write_response(provider, &response, addr);
     }
     /// Forward a scrape request on to the appropriate handler method.
+    #[instrument(skip(self, provider))]
     fn forward_scrape(
         &mut self,
         provider: &mut Provider<'_, ServerDispatcher<H>>,
@@ -130,55 +161,81 @@ where
         request: &ScrapeRequest<'_>,
         addr: SocketAddr,
     ) {
-        self.handler.scrape(addr, conn_id, request, |result| {
-            let response_type = match result {
-                Ok(response) => ResponseType::Scrape(response),
-                Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
-            };
-            let response = TrackerResponse::new(trans_id, response_type);
-
-            write_response(provider, &response, addr);
-        });
+        tracing::debug!("forward scrape");
+
+        let Some(attempt) = self.handler.scrape(addr, conn_id, request) else {
+            tracing::warn!("connect scrape canceled");
+
+            return;
+        };
+
+        let response_type = match attempt {
+            Ok(response) => ResponseType::Scrape(response),
+            Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)),
+        };
+
+        let response = TrackerResponse::new(trans_id, response_type);
+
+        write_response(provider, &response, addr);
     }
 }
 /// Write the given tracker response through to the given provider.
+#[instrument(skip(provider))]
 fn write_response<H>(provider: &mut Provider<'_, ServerDispatcher<H>>, response: &TrackerResponse<'_>, addr: SocketAddr)
 where
-    H: ServerHandler,
+    H: ServerHandler + std::fmt::Debug,
 {
+    tracing::debug!("write response");
+
     provider.outgoing(|buffer| {
-        let mut cursor = Cursor::new(buffer);
-        let success = response.write_bytes(&mut cursor).is_ok();
-
-        if success {
-            Some((cursor.position().try_into().unwrap(), addr))
-        } else {
-            None
-        } // TODO: Add Logging
+        let mut cursor = std::io::Cursor::new(buffer);
+
+        match response.write_bytes(&mut cursor) {
+            Ok(()) => Some((cursor.position().try_into().unwrap(), addr)),
+            Err(e) => {
+                tracing::error!("error writing response to cursor: {e}");
+                None
+            }
+        }
     });
 }
 impl<H> Dispatcher for ServerDispatcher<H>
 where
-    H: ServerHandler,
+    H: ServerHandler + std::fmt::Debug,
 {
     type Timeout = ();
     type Message = DispatchMessage;
+    #[instrument(skip(self, provider))]
     fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {
-        let IResult::Done(_, request) = TrackerRequest::from_bytes(message) else {
-            return; // TODO: Add Logging
-        };
+        let () = match TrackerRequest::from_bytes(message) {
+            IResult::Ok((_, request)) => {
+                tracing::debug!("received an incoming request: {request:?}");
-        self.process_request(&mut provider, &request, addr);
+                self.process_request(&mut provider, &request, addr);
+            }
+            Err(e) => {
+                tracing::error!("received an incoming error message: {e}");
+            }
+        };
     }
+    #[instrument(skip(self, provider))]
     fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) {
-        match message {
-            DispatchMessage::Shutdown => provider.shutdown(),
-        }
+        let () = match message {
+            DispatchMessage::Shutdown => {
+                tracing::debug!("received a shutdown notification");
+
+                provider.shutdown();
+            }
+        };
     }
-    fn timeout(&mut self, _: Provider<'_, Self>, (): ()) {}
+    #[instrument(skip(self))]
+    fn timeout(&mut self, _: Provider<'_, Self>, (): ()) {
+        tracing::error!("timeout not yet supported!");
+        unimplemented!();
+    }
 }
diff --git a/packages/utracker/src/server/handler.rs b/packages/utracker/src/server/handler.rs
index 45f9ff563..e7682cd8b 100644
--- a/packages/utracker/src/server/handler.rs
+++ b/packages/utracker/src/server/handler.rs
@@ -13,23 +13,16 @@ pub type ServerResult<'a, T> = Result<T, &'a str>;
 pub trait ServerHandler: Send {
     /// Service a connection id request from the given address.
-    ///
-    /// If the result callback is not called, no response will be sent.
-    fn connect<R>(&mut self, addr: SocketAddr, result: R)
-    where
-        R: for<'a> FnOnce(ServerResult<'a, u64>);
+    fn connect(&mut self, addr: SocketAddr) -> Option<ServerResult<'_, u64>>;
     /// Service an announce request with the given connect id.
-    ///
-    /// If the result callback is not called, no response will be sent.
-    fn announce<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &AnnounceRequest<'b>, result: R)
-    where
-        R: for<'a> FnOnce(ServerResult<'a, AnnounceResponse<'a>>);
+    fn announce(
+        &mut self,
+        addr: SocketAddr,
+        id: u64,
+        req: &AnnounceRequest<'_>,
+    ) -> Option<ServerResult<'_, AnnounceResponse<'_>>>;
     /// Service a scrape request with the given connect id.
-    ///
-    /// If the result callback is not called, no response will be sent.
- fn scrape<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &ScrapeRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, ScrapeResponse<'a>>); + fn scrape(&mut self, addr: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>>; } diff --git a/packages/utracker/src/server/mod.rs b/packages/utracker/src/server/mod.rs index 2dd9f03b0..5e2126208 100644 --- a/packages/utracker/src/server/mod.rs +++ b/packages/utracker/src/server/mod.rs @@ -1,6 +1,6 @@ -use std::io; use std::net::SocketAddr; +use tracing::instrument; use umio::external::Sender; use crate::server::dispatcher::DispatchMessage; @@ -13,8 +13,9 @@ pub mod handler; /// /// Server will shutdown on drop. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerServer { - send: Sender, + dispatcher: Sender, } impl TrackerServer { @@ -23,17 +24,25 @@ impl TrackerServer { /// # Errors /// /// It would return an IO Error if unable to run the server. - pub fn run(bind: SocketAddr, handler: H) -> io::Result + #[instrument(skip(), ret)] + pub fn run(bind: SocketAddr, handler: H) -> std::io::Result where - H: ServerHandler + 'static, + H: ServerHandler + std::fmt::Debug + 'static, { - dispatcher::create_dispatcher(bind, handler).map(|send| TrackerServer { send }) + tracing::info!("running server"); + + let dispatcher = dispatcher::create_dispatcher(bind, handler)?; + + Ok(TrackerServer { dispatcher }) } } impl Drop for TrackerServer { + #[instrument(skip(self))] fn drop(&mut self) { - self.send + tracing::debug!("server was dropped, sending shutdown notification..."); + + self.dispatcher .send(DispatchMessage::Shutdown) .expect("bip_utracker: TrackerServer Failed To Send Shutdown Message"); } diff --git a/packages/utracker/tests/common/mod.rs b/packages/utracker/tests/common/mod.rs index 7c7995275..f410126fa 100644 --- a/packages/utracker/tests/common/mod.rs +++ b/packages/utracker/tests/common/mod.rs @@ -1,36 +1,67 @@ use std::collections::{HashMap, HashSet}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::sync::{Arc, Mutex}; - -use futures::future::Either; -use futures::sink::Sink; -use futures::stream::Stream; -use futures::sync::mpsc::{self, SendError, UnboundedReceiver, UnboundedSender}; -use futures::{Poll, StartSend}; -use handshake::{DiscoveryInfo, InitiateMessage}; +use std::sync::{Arc, Mutex, Once}; +use std::time::Duration; + +use futures::channel::mpsc; +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use futures::{Sink, Stream}; +use handshake::DiscoveryInfo; +use tracing::instrument; +use tracing::level_filters::LevelFilter; use util::bt::{InfoHash, PeerId}; use util::trans::{LocallyShuffledIds, TransactionIds}; use utracker::announce::{AnnounceEvent, AnnounceRequest, AnnounceResponse}; use utracker::contact::{CompactPeers, CompactPeersV4, CompactPeersV6}; use utracker::scrape::{ScrapeRequest, ScrapeResponse, ScrapeStats}; -use utracker::{ClientMetadata, ServerHandler, ServerResult}; +use utracker::{HandshakerMessage, ServerHandler, ServerResult}; + +#[allow(dead_code)] +pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(1000); const NUM_PEERS_RETURNED: usize = 20; -#[derive(Clone)] +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + +#[allow(dead_code)] +#[derive(PartialEq, Eq, Debug)] +pub enum TimeoutResult { + TimedOut, + GotResult, +} + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + 
builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +#[derive(Debug, Clone)] pub struct MockTrackerHandler { inner: Arc>, } -struct InnerMockTrackerHandler { +#[derive(Debug)] +pub struct InnerMockTrackerHandler { cids: HashSet, cid_generator: LocallyShuffledIds, peers_map: HashMap>, } +#[allow(dead_code)] impl MockTrackerHandler { - #[allow(dead_code)] + #[instrument(skip(), ret)] pub fn new() -> MockTrackerHandler { + tracing::debug!("new mock handler"); + MockTrackerHandler { inner: Arc::new(Mutex::new(InnerMockTrackerHandler { cids: HashSet::new(), @@ -40,29 +71,33 @@ impl MockTrackerHandler { } } - #[allow(dead_code)] pub fn num_active_connect_ids(&self) -> usize { self.inner.lock().unwrap().cids.len() } } impl ServerHandler for MockTrackerHandler { - fn connect(&mut self, _: SocketAddr, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, u64>), - { + #[instrument(skip(self), ret)] + fn connect(&mut self, addr: SocketAddr) -> Option> { + tracing::debug!("mock connect"); + let mut inner_lock = self.inner.lock().unwrap(); let cid = inner_lock.cid_generator.generate(); inner_lock.cids.insert(cid); - result(Ok(cid)); + Some(Ok(cid)) } - fn announce<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &AnnounceRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, AnnounceResponse<'a>>), - { + #[instrument(skip(self), ret)] + fn announce( + &mut self, + addr: SocketAddr, + id: u64, + req: &AnnounceRequest<'_>, + ) -> Option>> { + tracing::debug!("mock announce"); + let mut inner_lock = self.inner.lock().unwrap(); if inner_lock.cids.contains(&id) { @@ -112,21 +147,21 @@ impl ServerHandler for MockTrackerHandler { CompactPeers::V6(v6_peers) }; - result(Ok(AnnounceResponse::new( + Some(Ok(AnnounceResponse::new( 1800, peers.len().try_into().unwrap(), peers.len().try_into().unwrap(), compact_peers, - ))); + ))) } else { - result(Err("Connection ID Is Invalid")); + Some(Err("Connection ID Is Invalid")) } } - fn scrape<'b, R>(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, ScrapeResponse<'a>>), - { + #[instrument(skip(self), ret)] + fn scrape(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>> { + tracing::debug!("mock scrape"); + let mut inner_lock = self.inner.lock().unwrap(); if inner_lock.cids.contains(&id) { @@ -142,9 +177,9 @@ impl ServerHandler for MockTrackerHandler { )); } - result(Ok(response)); + Some(Ok(response)) } else { - result(Err("Connection ID Is Invalid")); + Some(Err("Connection ID Is Invalid")) } } } @@ -158,13 +193,9 @@ pub fn handshaker() -> (MockHandshakerSink, MockHandshakerStream) { (MockHandshakerSink { send }, MockHandshakerStream { recv }) } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct MockHandshakerSink { - send: UnboundedSender>, -} - -pub struct MockHandshakerStream { - recv: UnboundedReceiver>, + send: mpsc::UnboundedSender, } impl DiscoveryInfo for MockHandshakerSink { @@ -177,24 +208,63 @@ impl DiscoveryInfo for MockHandshakerSink { } } -impl Sink for MockHandshakerSink { - type SinkItem = Either; - type SinkError = SendError; +impl Sink> for MockHandshakerSink { + type Error = std::io::Error; + + #[instrument(skip(self, cx), ret)] + fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + tracing::debug!("polling ready"); + + self.send + .poll_ready(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + } + + #[instrument(skip(self), ret)] + fn 
start_send(mut self: std::pin::Pin<&mut Self>, item: std::io::Result) -> Result<(), Self::Error> { + tracing::debug!("starting send"); - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.send.start_send(item) + self.send + .start_send(item?) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.send.poll_complete() + #[instrument(skip(self, cx), ret)] + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + tracing::debug!("polling flush"); + + self.send + .poll_flush_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + } + + #[instrument(skip(self, cx), ret)] + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + tracing::debug!("polling close"); + + self.send + .poll_close_unpin(cx) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } } +pub struct MockHandshakerStream { + recv: mpsc::UnboundedReceiver, +} + impl Stream for MockHandshakerStream { - type Item = Either; - type Error = (); + type Item = std::io::Result; + + #[instrument(skip(self, cx), ret)] + fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + tracing::debug!("polling next"); - fn poll(&mut self) -> Poll, Self::Error> { - self.recv.poll() + self.recv.poll_next_unpin(cx).map(|maybe| maybe.map(Ok)) } } diff --git a/packages/utracker/tests/test_announce_start.rs b/packages/utracker/tests/test_announce_start.rs index 33c337c15..37ef75f64 100644 --- a/packages/utracker/tests/test_announce_start.rs +++ b/packages/utracker/tests/test_announce_start.rs @@ -1,43 +1,51 @@ use std::net::SocketAddr; -use std::thread::{self}; use std::time::Duration; -use common::{handshaker, MockTrackerHandler}; -use futures::future::Either; -use futures::stream::Stream; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; use handshake::Protocol; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; -use utracker::{ClientRequest, TrackerClient, TrackerServer}; +use utracker::{ClientRequest, HandshakerMessage, TrackerClient, TrackerServer}; mod common; -#[test] -#[allow(unused)] -fn positive_announce_started() { - let (sink, stream) = handshaker(); +#[tokio::test] +async fn positive_announce_started() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (handshaker_sender, mut handshaker_receiver) = handshaker(); let server_addr = "127.0.0.1:3501".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let server = TrackerServer::run(server_addr, mock_handler).unwrap(); + let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); - thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), handshaker_sender, None).unwrap(); let hash = [0u8; bt::INFO_HASH_LEN].into(); - let send_token = client + + tracing::warn!("sending announce"); + let _send_token = client .request( server_addr, ClientRequest::Announce(hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)), ) .unwrap(); - let mut blocking_stream = stream.wait(); - - let init_msg = match 
blocking_stream.next().unwrap().unwrap() { - Either::A(a) => a, - Either::B(_) => unreachable!(), + tracing::warn!("receiving initiate message"); + let init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(message) => message, + HandshakerMessage::ClientMetadata(_) => unreachable!(), }; let exp_peer_addr: SocketAddr = "127.0.0.1:6969".parse().unwrap(); @@ -46,9 +54,15 @@ fn positive_announce_started() { assert_eq!(&exp_peer_addr, init_msg.address()); assert_eq!(&hash, init_msg.hash()); - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + tracing::warn!("receiving client metadata"); + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; let metadata_result = metadata.result().as_ref().unwrap().announce_response().unwrap(); diff --git a/packages/utracker/tests/test_announce_stop.rs b/packages/utracker/tests/test_announce_stop.rs index f9ade00ef..4fc2fc505 100644 --- a/packages/utracker/tests/test_announce_stop.rs +++ b/packages/utracker/tests/test_announce_stop.rs @@ -1,52 +1,61 @@ -use std::thread::{self}; use std::time::Duration; -use common::{handshaker, MockTrackerHandler}; -use futures::future::Either; -use futures::stream::Stream; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; -use utracker::{ClientRequest, TrackerClient, TrackerServer}; +use utracker::{ClientRequest, HandshakerMessage, TrackerClient, TrackerServer}; mod common; -#[test] -#[allow(unused)] -fn positive_announce_stopped() { - let (sink, stream) = handshaker(); +#[tokio::test] +async fn positive_announce_stopped() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, mut stream) = handshaker(); let server_addr = "127.0.0.1:3502".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let server = TrackerServer::run(server_addr, mock_handler).unwrap(); + let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); - thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink, None).unwrap(); let info_hash = [0u8; bt::INFO_HASH_LEN].into(); - let mut blocking_stream = stream.wait(); // Started { - let send_token = client + let _send_token = client .request( server_addr, ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)), ) .unwrap(); - let init_msg = match blocking_stream.next().unwrap().unwrap() { - Either::A(a) => a, - Either::B(_) => unreachable!(), + let _init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(message) => message, + HandshakerMessage::ClientMetadata(_) => unreachable!(), }; - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, 
stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; - assert_eq!(send_token, metadata.token()); - let response = metadata.result().as_ref().unwrap().announce_response().unwrap(); assert_eq!(response.leechers(), 1); assert_eq!(response.seeders(), 1); @@ -55,20 +64,23 @@ fn positive_announce_stopped() { // Stopped { - let send_token = client + let _send_token = client .request( server_addr, ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Stopped)), ) .unwrap(); - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; - assert_eq!(send_token, metadata.token()); - let response = metadata.result().as_ref().unwrap().announce_response().unwrap(); assert_eq!(response.leechers(), 0); assert_eq!(response.seeders(), 0); diff --git a/packages/utracker/tests/test_client_drop.rs b/packages/utracker/tests/test_client_drop.rs index 8f4c94985..3564cb9e7 100644 --- a/packages/utracker/tests/test_client_drop.rs +++ b/packages/utracker/tests/test_client_drop.rs @@ -1,22 +1,27 @@ -use common::handshaker; -use futures::future::Either; -use futures::stream::Stream; +use std::net::SocketAddr; + +use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; -use utracker::{ClientError, ClientRequest, TrackerClient}; +use utracker::{ClientError, ClientRequest, HandshakerMessage, TrackerClient}; mod common; -#[test] -#[allow(unused)] -fn positive_client_request_failed() { - let (sink, stream) = handshaker(); +#[tokio::test] +async fn positive_client_request_failed() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, mut stream) = handshaker(); - let server_addr = "127.0.0.1:3503".parse().unwrap(); - // don't actually create the server :) since we want the request to wait for a little bit until we drop + let server_addr: SocketAddr = "127.0.0.1:3503".parse().unwrap(); + // Don't actually create the server since we want the request to wait for a little bit until we drop let send_token = { - let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink, None).unwrap(); client .request( @@ -30,17 +35,20 @@ fn positive_client_request_failed() { }; // Client is now dropped - let mut blocking_stream = stream.wait(); - - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; assert_eq!(send_token, metadata.token()); match metadata.result() { - &Err(ClientError::ClientShutdown) => (), + Err(ClientError::ClientShutdown) => (), _ => panic!("Did Not Receive ClientShutdown..."), } } diff --git a/packages/utracker/tests/test_client_full.rs 
b/packages/utracker/tests/test_client_full.rs index 79d6c85fd..f2e78d150 100644 --- a/packages/utracker/tests/test_client_full.rs +++ b/packages/utracker/tests/test_client_full.rs @@ -1,26 +1,30 @@ -use std::mem; - -use common::handshaker; -use futures::stream::Stream; -use futures::Future; +use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; use utracker::{ClientRequest, TrackerClient}; mod common; -#[test] -#[allow(unused)] -fn positive_client_request_dropped() { - let (sink, mut stream) = handshaker(); +#[tokio::test] +async fn positive_client_request_dropped() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, stream) = handshaker(); let server_addr = "127.0.0.1:3504".parse().unwrap(); let request_capacity = 10; - let mut client = TrackerClient::with_capacity("127.0.0.1:4504".parse().unwrap(), sink, request_capacity).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4504".parse().unwrap(), sink, Some(request_capacity)).unwrap(); + + tracing::warn!("sending announce requests to fill buffer"); + for i in 1..=request_capacity { + tracing::warn!("request {i} of {request_capacity}"); - for _ in 0..request_capacity { client .request( server_addr, @@ -32,6 +36,7 @@ fn positive_client_request_dropped() { .unwrap(); } + tracing::warn!("sending one more announce request, it should fail"); assert!(client .request( server_addr, @@ -42,8 +47,8 @@ fn positive_client_request_dropped() { ) .is_none()); - mem::drop(client); + std::mem::drop(client); - let buffer = stream.collect().wait().unwrap(); + let buffer: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect()).await.unwrap(); assert_eq!(request_capacity, buffer.len()); } diff --git a/packages/utracker/tests/test_connect.rs b/packages/utracker/tests/test_connect.rs index 4d0833c09..2111d5681 100644 --- a/packages/utracker/tests/test_connect.rs +++ b/packages/utracker/tests/test_connect.rs @@ -1,27 +1,29 @@ -use std::thread::{self}; use std::time::Duration; -use common::{handshaker, MockTrackerHandler}; -use futures::future::Either; -use futures::stream::Stream; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; -use utracker::{ClientRequest, TrackerClient, TrackerServer}; +use utracker::{ClientRequest, HandshakerMessage, TrackerClient, TrackerServer}; mod common; -#[test] -#[allow(unused)] -fn positive_receive_connect_id() { - let (sink, stream) = handshaker(); +#[tokio::test] +async fn positive_receive_connect_id() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, mut stream) = handshaker(); let server_addr = "127.0.0.1:3505".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let server = TrackerServer::run(server_addr, mock_handler).unwrap(); + let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); - thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink, None).unwrap(); let send_token = client .request( @@ -33,16 +35,24 @@ fn positive_receive_connect_id() { ) .unwrap(); - let 
mut blocking_stream = stream.wait(); - - let _init_msg = match blocking_stream.next().unwrap().unwrap() { - Either::A(a) => a, - Either::B(_) => unreachable!(), + let _init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(message) => message, + HandshakerMessage::ClientMetadata(_) => unreachable!(), }; - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; assert_eq!(send_token, metadata.token()); diff --git a/packages/utracker/tests/test_connect_cache.rs b/packages/utracker/tests/test_connect_cache.rs index d2ec46864..25b89f277 100644 --- a/packages/utracker/tests/test_connect_cache.rs +++ b/packages/utracker/tests/test_connect_cache.rs @@ -1,33 +1,38 @@ -use std::thread::{self}; use std::time::Duration; -use common::{handshaker, MockTrackerHandler}; -use futures::stream::Stream; +use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::{ClientRequest, TrackerClient, TrackerServer}; mod common; -#[test] -#[allow(unused)] -fn positive_connection_id_cache() { - let (sink, mut stream) = handshaker(); +#[tokio::test] +async fn positive_connection_id_cache() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, mut stream) = common::handshaker(); let server_addr = "127.0.0.1:3506".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let server = TrackerServer::run(server_addr, mock_handler.clone()).unwrap(); + let _server = TrackerServer::run(server_addr, mock_handler.clone()).unwrap(); - thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink, None).unwrap(); let first_hash = [0u8; bt::INFO_HASH_LEN].into(); let second_hash = [1u8; bt::INFO_HASH_LEN].into(); - let mut blocking_stream = stream.wait(); - client.request(server_addr, ClientRequest::Scrape(first_hash)).unwrap(); - blocking_stream.next().unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap(); assert_eq!(mock_handler.num_active_connect_ids(), 1); @@ -36,7 +41,11 @@ fn positive_connection_id_cache() { } for _ in 0..10 { - blocking_stream.next().unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap(); } assert_eq!(mock_handler.num_active_connect_ids(), 1); diff --git a/packages/utracker/tests/test_scrape.rs b/packages/utracker/tests/test_scrape.rs index af245f466..607f02e66 100644 --- a/packages/utracker/tests/test_scrape.rs +++ b/packages/utracker/tests/test_scrape.rs @@ -1,36 +1,41 @@ -use std::thread::{self}; use std::time::Duration; -use common::{handshaker, MockTrackerHandler}; -use futures::future::Either; -use futures::stream::Stream; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use futures::StreamExt as _; +use tracing::level_filters::LevelFilter; use util::bt::{self}; -use 
utracker::{ClientRequest, TrackerClient, TrackerServer}; +use utracker::{ClientRequest, HandshakerMessage, TrackerClient, TrackerServer}; mod common; -#[test] -#[allow(unused)] -fn positive_scrape() { - let (sink, stream) = handshaker(); +#[tokio::test] +async fn positive_scrape() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (sink, mut stream) = handshaker(); let server_addr = "127.0.0.1:3507".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let server = TrackerServer::run(server_addr, mock_handler).unwrap(); + let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); - thread::sleep(Duration::from_millis(100)); + std::thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4507".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4507".parse().unwrap(), sink, None).unwrap(); let send_token = client .request(server_addr, ClientRequest::Scrape([0u8; bt::INFO_HASH_LEN].into())) .unwrap(); - let mut blocking_stream = stream.wait(); - - let metadata = match blocking_stream.next().unwrap().unwrap() { - Either::B(b) => b, - Either::A(_) => unreachable!(), + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { + HandshakerMessage::InitiateMessage(_) => unreachable!(), + HandshakerMessage::ClientMetadata(metadata) => metadata, }; assert_eq!(send_token, metadata.token()); diff --git a/packages/utracker/tests/test_server_drop.rs b/packages/utracker/tests/test_server_drop.rs index 1e4ee51a5..23002461a 100644 --- a/packages/utracker/tests/test_server_drop.rs +++ b/packages/utracker/tests/test_server_drop.rs @@ -1,7 +1,8 @@ use std::net::UdpSocket; use std::time::Duration; -use common::MockTrackerHandler; +use common::{tracing_stderr_init, MockTrackerHandler, INIT}; +use tracing::level_filters::LevelFilter; use utracker::request::{self, RequestType, TrackerRequest}; use utracker::TrackerServer; @@ -10,6 +11,10 @@ mod common; #[test] #[allow(unused)] fn positive_server_dropped() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + let server_addr = "127.0.0.1:3508".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); From 0e26eb7055015d1f1870bcda8c77ca77aa9dd6ac Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Thu, 15 Aug 2024 18:50:40 +0200 Subject: [PATCH 2/4] docs: update main readme todo --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c39bafb52..413975125 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,14 @@ In this fork we have: - [x] Implemented continuous integration using github workflows. ([#8]) - [x] Update some of the project dependencies. ([#9], [#17], [#26], [#27]) - [x] Preformed a general cleanup of the codebase. ([#10], [#16], [#18], [#29], [#31]) +- [x] Updated all dependencies to modern versions. ( [#19], [#20], [#21], [#22], [#23], [#25]) -The future goals are: -- [ ] Update the other dependencies (__Significant Work Required__). ( [#19], [#20], [#21], [#22], [#23], [#25]) +The future goals are: - [ ] Publish updated versions of the crates. ([#37]) - [ ] Increase coverage of unit tests. ([#38]) +- [ ] Remove dependency on umio in `utracker` package. Instead use Tokio. ([#53]) +- [ ] Overhaul the old `mio` architecture in the `dht` package. Making better use of Tokio. 
([#54]) __We would like to make a special thanks to all the developers who had contributed to and created this great project.__ @@ -94,6 +96,9 @@ additional terms or conditions. [#25]: https://github.com/torrust/bittorrent-infrastructure-project/issues/25 [#37]: https://github.com/torrust/bittorrent-infrastructure-project/issues/37 [#38]: https://github.com/torrust/bittorrent-infrastructure-project/issues/38 +[#53]: https://github.com/torrust/bittorrent-infrastructure-project/issues/53 +[#54]: https://github.com/torrust/bittorrent-infrastructure-project/issues/54 + [t_i37]: https://img.shields.io/github/issues/detail/title/torrust/bittorrent-infrastructure-project/37?style=for-the-badge& [s_i37]: https://img.shields.io/github/issues/detail/state/torrust/bittorrent-infrastructure-project/37?style=for-the-badge&label=%E3%80%80 @@ -113,4 +118,4 @@ additional terms or conditions. [c_bip_select]: https://crates.io/crates/bip_select [c_bip_dht]: https://crates.io/crates/bip_dht [c_bip_metainfo]: https://crates.io/crates/bip_metainfo -[c_bip_utracker]: https://crates.io/crates/bip_utracker \ No newline at end of file +[c_bip_utracker]: https://crates.io/crates/bip_utracker From e952875811e6a5b117c5eb22e74c1dd517e3ef2f Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Mon, 19 Aug 2024 11:21:51 +0200 Subject: [PATCH 3/4] deps: add umio as contrib package --- Cargo.toml | 1 + contrib/umio/Cargo.toml | 19 ++++ contrib/umio/README.md | 23 +++++ contrib/umio/src/buffer.rs | 87 ++++++++++++++++++ contrib/umio/src/dispatcher.rs | 135 ++++++++++++++++++++++++++++ contrib/umio/src/eloop.rs | 125 ++++++++++++++++++++++++++ contrib/umio/src/external.rs | 1 + contrib/umio/src/lib.rs | 17 ++++ contrib/umio/src/provider.rs | 72 +++++++++++++++ contrib/umio/tests/common/mod.rs | 69 ++++++++++++++ contrib/umio/tests/test_incoming.rs | 40 +++++++++ contrib/umio/tests/test_notify.rs | 31 +++++++ contrib/umio/tests/test_outgoing.rs | 39 ++++++++ contrib/umio/tests/test_shutdown.rs | 26 ++++++ contrib/umio/tests/test_timeout.rs | 34 +++++++ packages/utracker/Cargo.toml | 3 +- 16 files changed, 721 insertions(+), 1 deletion(-) create mode 100644 contrib/umio/Cargo.toml create mode 100644 contrib/umio/README.md create mode 100644 contrib/umio/src/buffer.rs create mode 100644 contrib/umio/src/dispatcher.rs create mode 100644 contrib/umio/src/eloop.rs create mode 100644 contrib/umio/src/external.rs create mode 100644 contrib/umio/src/lib.rs create mode 100644 contrib/umio/src/provider.rs create mode 100644 contrib/umio/tests/common/mod.rs create mode 100644 contrib/umio/tests/test_incoming.rs create mode 100644 contrib/umio/tests/test_notify.rs create mode 100644 contrib/umio/tests/test_outgoing.rs create mode 100644 contrib/umio/tests/test_shutdown.rs create mode 100644 contrib/umio/tests/test_timeout.rs diff --git a/Cargo.toml b/Cargo.toml index e8fd1d4f8..a2305115e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "contrib/umio", "examples/get_metadata", "examples/simple_torrent", "packages/bencode", diff --git a/contrib/umio/Cargo.toml b/contrib/umio/Cargo.toml new file mode 100644 index 000000000..7307f6b44 --- /dev/null +++ b/contrib/umio/Cargo.toml @@ -0,0 +1,19 @@ +[package] +authors = ["Andrew "] +description = "Message Based Readiness API In Rust" +keywords = ["message", "mio", "readyness"] +name = "umio" +readme = "README.md" + +categories.workspace = true +documentation.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true 
+publish.workspace = true + +repository.workspace = true +version.workspace = true + +[dependencies] +mio = "0.5" diff --git a/contrib/umio/README.md b/contrib/umio/README.md new file mode 100644 index 000000000..0fe5411cc --- /dev/null +++ b/contrib/umio/README.md @@ -0,0 +1,23 @@ +umio-rs +======= +Message Based Readiness API In Rust. + +Thin layer over mio for working with a single udp socket while retaining access to timers and event loop channels. + + +License +------- + +Licensed under either of + + * Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +Contribution +------------ + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any +additional terms or conditions. diff --git a/contrib/umio/src/buffer.rs b/contrib/umio/src/buffer.rs new file mode 100644 index 000000000..c6b21b72f --- /dev/null +++ b/contrib/umio/src/buffer.rs @@ -0,0 +1,87 @@ +#[allow(clippy::module_name_repetitions)] +pub struct BufferPool { + // Use Stack For Temporal Locality + buffers: Vec, + buffer_size: usize, +} + +impl BufferPool { + pub fn new(buffer_size: usize) -> BufferPool { + let buffers = Vec::new(); + + BufferPool { buffers, buffer_size } + } + + pub fn pop(&mut self) -> Buffer { + self.buffers.pop().unwrap_or(Buffer::new(self.buffer_size)) + } + + pub fn push(&mut self, mut buffer: Buffer) { + buffer.reset_position(); + + self.buffers.push(buffer); + } +} + +//----------------------------------------------------------------------------// + +/// Reusable region of memory for incoming and outgoing messages. +pub struct Buffer { + buffer: Vec, + bytes_written: usize, +} + +impl Buffer { + fn new(len: usize) -> Buffer { + Buffer { + buffer: vec![0u8; len], + bytes_written: 0, + } + } + + fn reset_position(&mut self) { + self.set_written(0); + } + + /// Update the number of bytes written to the buffer. + pub fn set_written(&mut self, bytes: usize) { + self.bytes_written = bytes; + } +} + +impl AsRef<[u8]> for Buffer { + fn as_ref(&self) -> &[u8] { + &self.buffer[..self.bytes_written] + } +} + +impl AsMut<[u8]> for Buffer { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.buffer[self.bytes_written..] + } +} + +#[cfg(test)] +mod tests { + use super::{Buffer, BufferPool}; + + const DEFAULT_BUFFER_SIZE: usize = 1500; + + #[test] + fn positive_buffer_pool_buffer_len() { + let mut buffers = BufferPool::new(DEFAULT_BUFFER_SIZE); + let mut buffer = buffers.pop(); + + assert_eq!(buffer.as_mut().len(), DEFAULT_BUFFER_SIZE); + assert_eq!(buffer.as_ref().len(), 0); + } + + #[test] + fn positive_buffer_len_update() { + let mut buffer = Buffer::new(DEFAULT_BUFFER_SIZE); + buffer.set_written(DEFAULT_BUFFER_SIZE - 1); + + assert_eq!(buffer.as_mut().len(), 1); + assert_eq!(buffer.as_ref().len(), DEFAULT_BUFFER_SIZE - 1); + } +} diff --git a/contrib/umio/src/dispatcher.rs b/contrib/umio/src/dispatcher.rs new file mode 100644 index 000000000..a62fd5979 --- /dev/null +++ b/contrib/umio/src/dispatcher.rs @@ -0,0 +1,135 @@ +use std::collections::VecDeque; +use std::net::SocketAddr; + +use mio::udp::UdpSocket; +use mio::{EventLoop, EventSet, Handler, PollOpt, Token}; + +use crate::buffer::{Buffer, BufferPool}; +use crate::{provider, Provider}; + +/// Handles events occurring within the event loop. 
+pub trait Dispatcher: Sized { + type Timeout; + type Message: Send; + + /// Process an incoming message from the given address. + #[allow(unused)] + fn incoming(&mut self, provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {} + + /// Process a message sent via the event loop channel. + #[allow(unused)] + fn notify(&mut self, provider: Provider<'_, Self>, message: Self::Message) {} + + /// Process a timeout that has been triggered. + #[allow(unused)] + fn timeout(&mut self, provider: Provider<'_, Self>, timeout: Self::Timeout) {} +} + +//----------------------------------------------------------------------------// + +const UDP_SOCKET_TOKEN: Token = Token(2); + +pub struct DispatchHandler { + dispatch: D, + out_queue: VecDeque<(Buffer, SocketAddr)>, + udp_socket: UdpSocket, + buffer_pool: BufferPool, + current_set: EventSet, +} + +impl DispatchHandler { + pub fn new( + udp_socket: UdpSocket, + buffer_size: usize, + dispatch: D, + event_loop: &mut EventLoop>, + ) -> DispatchHandler { + let buffer_pool = BufferPool::new(buffer_size); + let out_queue = VecDeque::new(); + + event_loop + .register(&udp_socket, UDP_SOCKET_TOKEN, EventSet::readable(), PollOpt::edge()) + .unwrap(); + + DispatchHandler { + dispatch, + out_queue, + udp_socket, + buffer_pool, + current_set: EventSet::readable(), + } + } + + pub fn handle_write(&mut self) { + if let Some((buffer, addr)) = self.out_queue.pop_front() { + self.udp_socket.send_to(buffer.as_ref(), &addr).unwrap(); + + self.buffer_pool.push(buffer); + }; + } + + pub fn handle_read(&mut self) -> Option<(Buffer, SocketAddr)> { + let mut buffer = self.buffer_pool.pop(); + + if let Ok(Some((bytes, addr))) = self.udp_socket.recv_from(buffer.as_mut()) { + buffer.set_written(bytes); + + Some((buffer, addr)) + } else { + None + } + } +} + +impl Handler for DispatchHandler { + type Timeout = D::Timeout; + type Message = D::Message; + + fn ready(&mut self, event_loop: &mut EventLoop, token: Token, events: EventSet) { + if token != UDP_SOCKET_TOKEN { + return; + } + + if events.is_writable() { + self.handle_write(); + } + + if events.is_readable() { + let Some((buffer, addr)) = self.handle_read() else { + return; + }; + + { + let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); + + self.dispatch.incoming(provider, buffer.as_ref(), addr); + } + + self.buffer_pool.push(buffer); + } + } + + fn notify(&mut self, event_loop: &mut EventLoop, msg: Self::Message) { + let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); + + self.dispatch.notify(provider, msg); + } + + fn timeout(&mut self, event_loop: &mut EventLoop, timeout: Self::Timeout) { + let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); + + self.dispatch.timeout(provider, timeout); + } + + fn tick(&mut self, event_loop: &mut EventLoop) { + self.current_set = if self.out_queue.is_empty() { + EventSet::readable() + } else { + EventSet::readable() | EventSet::writable() + }; + + event_loop + .reregister(&self.udp_socket, UDP_SOCKET_TOKEN, self.current_set, PollOpt::edge()) + .unwrap(); + } +} diff --git a/contrib/umio/src/eloop.rs b/contrib/umio/src/eloop.rs new file mode 100644 index 000000000..cedf030bf --- /dev/null +++ b/contrib/umio/src/eloop.rs @@ -0,0 +1,125 @@ +use std::io::Result; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +use mio::udp::UdpSocket; +use mio::{EventLoop, EventLoopConfig, Sender}; + +use crate::dispatcher::{DispatchHandler, Dispatcher}; + +const DEFAULT_BUFFER_SIZE: usize = 
1500; +const DEFAULT_CHANNEL_CAPACITY: usize = 4096; +const DEFAULT_TIMER_CAPACITY: usize = 65536; + +/// Builder for specifying attributes of an event loop. +pub struct ELoopBuilder { + channel_capacity: usize, + timer_capacity: usize, + buffer_size: usize, + bind_address: SocketAddr, +} + +impl ELoopBuilder { + /// Create a new event loop builder. + #[must_use] + pub fn new() -> ELoopBuilder { + Self::default() + } + + /// Manually set the maximum channel message capacity. + #[must_use] + pub fn channel_capacity(mut self, capacity: usize) -> ELoopBuilder { + self.channel_capacity = capacity; + + self + } + + /// Manually set the maximum timer capacity. + #[must_use] + pub fn timer_capacity(mut self, capacity: usize) -> ELoopBuilder { + self.timer_capacity = capacity; + + self + } + + /// Manually set the bind address for the udp socket in the event loop. + #[must_use] + pub fn bind_address(mut self, address: SocketAddr) -> ELoopBuilder { + self.bind_address = address; + + self + } + + /// Manually set the length of buffers provided by the event loop. + #[must_use] + pub fn buffer_length(mut self, length: usize) -> ELoopBuilder { + self.buffer_size = length; + + self + } + + /// Build the event loop with the current builder. + /// + /// # Errors + /// + /// It would error when the builder config has an problem. + pub fn build(self) -> Result> { + ELoop::from_builder(&self) + } +} + +impl Default for ELoopBuilder { + fn default() -> Self { + let default_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)); + + ELoopBuilder { + channel_capacity: DEFAULT_CHANNEL_CAPACITY, + timer_capacity: DEFAULT_TIMER_CAPACITY, + buffer_size: DEFAULT_BUFFER_SIZE, + bind_address: default_addr, + } + } +} + +//----------------------------------------------------------------------------// + +/// Wrapper around the main application event loop. +pub struct ELoop { + buffer_size: usize, + socket_addr: SocketAddr, + event_loop: EventLoop>, +} + +impl ELoop { + fn from_builder(builder: &ELoopBuilder) -> Result> { + let mut event_loop_config = EventLoopConfig::new(); + event_loop_config + .notify_capacity(builder.channel_capacity) + .timer_capacity(builder.timer_capacity); + + let event_loop = EventLoop::configured(event_loop_config)?; + + Ok(ELoop { + buffer_size: builder.buffer_size, + socket_addr: builder.bind_address, + event_loop, + }) + } + + /// Grab a channel to send messages to the event loop. + #[must_use] + pub fn channel(&self) -> Sender { + self.event_loop.channel() + } + + /// Run the event loop with the given dispatcher until a shutdown occurs. + /// + /// # Errors + /// + /// It would error if unable to bind to the socket. + pub fn run(&mut self, dispatcher: D) -> Result<()> { + let udp_socket = UdpSocket::bound(&self.socket_addr)?; + let mut dispatch_handler = DispatchHandler::new(udp_socket, self.buffer_size, dispatcher, &mut self.event_loop); + + self.event_loop.run(&mut dispatch_handler) + } +} diff --git a/contrib/umio/src/external.rs b/contrib/umio/src/external.rs new file mode 100644 index 000000000..2ef9d04f8 --- /dev/null +++ b/contrib/umio/src/external.rs @@ -0,0 +1 @@ +pub use mio::{Sender, Timeout, TimerError, TimerResult}; diff --git a/contrib/umio/src/lib.rs b/contrib/umio/src/lib.rs new file mode 100644 index 000000000..c246f2dfe --- /dev/null +++ b/contrib/umio/src/lib.rs @@ -0,0 +1,17 @@ +//! Message Based Readiness API +//! +//! This library is a thin wrapper around mio for clients who wish to +//! 
use a single udp socket in conjunction with message passing and +//! timeouts. + +mod buffer; +mod dispatcher; +mod eloop; +mod provider; + +/// Exports of bare mio types. +pub mod external; + +pub use dispatcher::Dispatcher; +pub use eloop::{ELoop, ELoopBuilder}; +pub use provider::Provider; diff --git a/contrib/umio/src/provider.rs b/contrib/umio/src/provider.rs new file mode 100644 index 000000000..b9961292f --- /dev/null +++ b/contrib/umio/src/provider.rs @@ -0,0 +1,72 @@ +use std::collections::VecDeque; +use std::net::SocketAddr; + +use mio::{EventLoop, Sender, Timeout, TimerResult}; + +use crate::buffer::{Buffer, BufferPool}; +use crate::dispatcher::{DispatchHandler, Dispatcher}; + +/// Provides services to dispatcher clients. +pub struct Provider<'a, D: Dispatcher> { + buffer_pool: &'a mut BufferPool, + out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, + event_loop: &'a mut EventLoop>, +} + +pub fn new<'a, D: Dispatcher>( + buffer_pool: &'a mut BufferPool, + out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, + event_loop: &'a mut EventLoop>, +) -> Provider<'a, D> { + Provider { + buffer_pool, + out_queue, + event_loop, + } +} + +impl<'a, D: Dispatcher> Provider<'a, D> { + /// Grab a channel to send messages to the event loop. + #[must_use] + pub fn channel(&self) -> Sender { + self.event_loop.channel() + } + + /// Execute a closure with a buffer and send the buffer contents to the + /// destination address or reclaim the buffer and do not send anything. + pub fn outgoing(&mut self, out: F) + where + F: FnOnce(&mut [u8]) -> Option<(usize, SocketAddr)>, + { + let mut buffer = self.buffer_pool.pop(); + let opt_send_to = out(buffer.as_mut()); + + match opt_send_to { + None => self.buffer_pool.push(buffer), + Some((bytes, addr)) => { + buffer.set_written(bytes); + + self.out_queue.push_back((buffer, addr)); + } + } + } + + /// Set a timeout with the given delay and token. + /// + /// # Errors + /// + /// It would error when the timeout returns in a error. + pub fn set_timeout(&mut self, token: D::Timeout, delay: u64) -> TimerResult { + self.event_loop.timeout_ms(token, delay) + } + + /// Clear a timeout using the provided timeout identifier. + pub fn clear_timeout(&mut self, timeout: Timeout) -> bool { + self.event_loop.clear_timeout(timeout) + } + + /// Shutdown the event loop. 
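As a rough usage sketch (illustrative only, not taken from the patch): the Provider API documented above — the closure-based `outgoing`, `set_timeout`, and `clear_timeout` — is driven from a `Dispatcher` implementation. A minimal echo dispatcher against this mio-0.5-era API could look like the following; the type name and the timeout token are assumptions for illustration.

    use std::net::SocketAddr;
    use umio::{Dispatcher, Provider};

    struct EchoDispatcher;

    impl Dispatcher for EchoDispatcher {
        type Timeout = u32;
        type Message = ();

        fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {
            // Borrow a pooled buffer, copy the payload in, and return Some((len, addr))
            // to queue it for sending; returning None would reclaim the buffer instead.
            provider.outgoing(|buffer| {
                let len = message.len().min(buffer.len());
                buffer[..len].copy_from_slice(&message[..len]);
                Some((len, addr))
            });

            // Arm a 500 ms timeout; the token comes back through `Dispatcher::timeout`.
            let _ = provider.set_timeout(0, 500);
        }
    }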
+ pub fn shutdown(&mut self) { + self.event_loop.shutdown(); + } +} diff --git a/contrib/umio/tests/common/mod.rs b/contrib/umio/tests/common/mod.rs new file mode 100644 index 000000000..97102082b --- /dev/null +++ b/contrib/umio/tests/common/mod.rs @@ -0,0 +1,69 @@ +use std::net::SocketAddr; +use std::sync::mpsc::{self}; + +use umio::{Dispatcher, Provider}; + +pub struct MockDispatcher { + send: mpsc::Sender, +} + +#[allow(dead_code)] +#[derive(Debug)] +pub enum MockMessage { + MessageReceived(Vec, SocketAddr), + TimeoutReceived(u32), + NotifyReceived, + + SendNotify, + SendMessage(Vec, SocketAddr), + SendTimeout(u32, u64), + + Shutdown, +} + +impl MockDispatcher { + pub fn new() -> (MockDispatcher, mpsc::Receiver) { + let (send, recv) = mpsc::channel(); + + (MockDispatcher { send }, recv) + } +} + +impl Dispatcher for MockDispatcher { + type Timeout = u32; + type Message = MockMessage; + + fn incoming(&mut self, _: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { + let owned_message = message.to_vec(); + + self.send.send(MockMessage::MessageReceived(owned_message, addr)).unwrap(); + } + + fn notify(&mut self, mut provider: Provider<'_, Self>, msg: Self::Message) { + match msg { + MockMessage::SendMessage(message, addr) => { + provider.outgoing(|buffer| { + for (src, dst) in message.iter().zip(buffer.as_mut().iter_mut()) { + *dst = *src; + } + + Some((message.len(), addr)) + }); + } + MockMessage::SendTimeout(token, delay) => { + provider.set_timeout(token, delay).unwrap(); + } + MockMessage::SendNotify => { + self.send.send(MockMessage::NotifyReceived).unwrap(); + } + MockMessage::Shutdown => { + provider.shutdown(); + } + _ => panic!("Invalid Message To Send To Dispatcher: {msg:?}"), + } + } + + fn timeout(&mut self, _: Provider<'_, Self>, token: Self::Timeout) { + self.send.send(MockMessage::TimeoutReceived(token)).unwrap(); + } +} diff --git a/contrib/umio/tests/test_incoming.rs b/contrib/umio/tests/test_incoming.rs new file mode 100644 index 000000000..85bfb1545 --- /dev/null +++ b/contrib/umio/tests/test_incoming.rs @@ -0,0 +1,40 @@ +use std::net::UdpSocket; +use std::thread::{self}; +use std::time::Duration; + +use common::{MockDispatcher, MockMessage}; +use umio::ELoopBuilder; + +mod common; + +#[test] +fn positive_receive_incoming_message() { + let eloop_addr = "127.0.0.1:5050".parse().unwrap(); + let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + + let (dispatcher, dispatch_recv) = MockDispatcher::new(); + let dispatch_send = eloop.channel(); + + thread::spawn(move || { + eloop.run(dispatcher).unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let socket_addr = "127.0.0.1:5051".parse().unwrap(); + let socket = UdpSocket::bind(socket_addr).unwrap(); + let message = b"This Is A Test Message"; + + socket.send_to(&message[..], eloop_addr).unwrap(); + thread::sleep(Duration::from_millis(50)); + + match dispatch_recv.try_recv() { + Ok(MockMessage::MessageReceived(msg, addr)) => { + assert_eq!(&msg[..], &message[..]); + + assert_eq!(addr, socket_addr); + } + _ => panic!("ELoop Failed To Receive Incoming Message"), + } + + dispatch_send.send(MockMessage::Shutdown).unwrap(); +} diff --git a/contrib/umio/tests/test_notify.rs b/contrib/umio/tests/test_notify.rs new file mode 100644 index 000000000..5b624dee6 --- /dev/null +++ b/contrib/umio/tests/test_notify.rs @@ -0,0 +1,31 @@ +use std::thread::{self}; +use std::time::Duration; + +use common::{MockDispatcher, MockMessage}; +use umio::ELoopBuilder; + +mod common; + +#[test] +fn 
positive_send_notify() { + let eloop_addr = "127.0.0.1:0".parse().unwrap(); + let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + + let (dispatcher, dispatch_recv) = MockDispatcher::new(); + let dispatch_send = eloop.channel(); + + thread::spawn(move || { + eloop.run(dispatcher).unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + dispatch_send.send(MockMessage::SendNotify).unwrap(); + thread::sleep(Duration::from_millis(50)); + + match dispatch_recv.try_recv() { + Ok(MockMessage::NotifyReceived) => (), + _ => panic!("ELoop Failed To Receive Incoming Message"), + } + + dispatch_send.send(MockMessage::Shutdown).unwrap(); +} diff --git a/contrib/umio/tests/test_outgoing.rs b/contrib/umio/tests/test_outgoing.rs new file mode 100644 index 000000000..3db172f75 --- /dev/null +++ b/contrib/umio/tests/test_outgoing.rs @@ -0,0 +1,39 @@ +use std::net::UdpSocket; +use std::thread::{self}; +use std::time::Duration; + +use common::{MockDispatcher, MockMessage}; +use umio::ELoopBuilder; + +mod common; + +#[test] +fn positive_send_outgoing_message() { + let eloop_addr = "127.0.0.1:5052".parse().unwrap(); + let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + + let (dispatcher, _) = MockDispatcher::new(); + let dispatch_send = eloop.channel(); + + thread::spawn(move || { + eloop.run(dispatcher).unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let message = b"This Is A Test Message"; + let mut message_recv = [0u8; 22]; + let socket_addr = "127.0.0.1:5053".parse().unwrap(); + let socket = UdpSocket::bind(socket_addr).unwrap(); + dispatch_send + .send(MockMessage::SendMessage(message.to_vec(), socket_addr)) + .unwrap(); + thread::sleep(Duration::from_millis(50)); + + let (bytes, addr) = socket.recv_from(&mut message_recv).unwrap(); + + assert_eq!(bytes, message.len()); + assert_eq!(&message[..], &message_recv[..]); + assert_eq!(addr, eloop_addr); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); +} diff --git a/contrib/umio/tests/test_shutdown.rs b/contrib/umio/tests/test_shutdown.rs new file mode 100644 index 000000000..dd7023f2f --- /dev/null +++ b/contrib/umio/tests/test_shutdown.rs @@ -0,0 +1,26 @@ +use std::thread::{self}; +use std::time::Duration; + +use common::{MockDispatcher, MockMessage}; +use umio::ELoopBuilder; + +mod common; + +#[test] +fn positive_execute_shutdown() { + let eloop_addr = "127.0.0.1:0".parse().unwrap(); + let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + + let (dispatcher, _) = MockDispatcher::new(); + let dispatch_send = eloop.channel(); + + thread::spawn(move || { + eloop.run(dispatcher).unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + thread::sleep(Duration::from_millis(50)); + + assert!(dispatch_send.send(MockMessage::SendNotify).is_err()); +} diff --git a/contrib/umio/tests/test_timeout.rs b/contrib/umio/tests/test_timeout.rs new file mode 100644 index 000000000..60f537016 --- /dev/null +++ b/contrib/umio/tests/test_timeout.rs @@ -0,0 +1,34 @@ +use std::thread::{self}; +use std::time::Duration; + +use common::{MockDispatcher, MockMessage}; +use umio::ELoopBuilder; + +mod common; + +#[test] +fn positive_send_notify() { + let eloop_addr = "127.0.0.1:0".parse().unwrap(); + let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + + let (dispatcher, dispatch_recv) = MockDispatcher::new(); + let dispatch_send = eloop.channel(); + + thread::spawn(move || { + 
eloop.run(dispatcher).unwrap(); + }); + thread::sleep(Duration::from_millis(50)); + + let token = 5; + dispatch_send.send(MockMessage::SendTimeout(token, 50)).unwrap(); + thread::sleep(Duration::from_millis(300)); + + match dispatch_recv.try_recv() { + Ok(MockMessage::TimeoutReceived(tkn)) => { + assert_eq!(tkn, token); + } + _ => panic!("ELoop Failed To Receive Timeout"), + } + + dispatch_send.send(MockMessage::Shutdown).unwrap(); +} diff --git a/packages/utracker/Cargo.toml b/packages/utracker/Cargo.toml index 29dce63e8..7c38d5269 100644 --- a/packages/utracker/Cargo.toml +++ b/packages/utracker/Cargo.toml @@ -19,6 +19,8 @@ version.workspace = true handshake = { path = "../handshake" } util = { path = "../util" } +umio = { path = "../../contrib/umio" } + byteorder = "1" chrono = "0" futures = "0" @@ -26,7 +28,6 @@ nom = "7" rand = "0" thiserror = "1" tracing = "0" -umio = "0" [dev-dependencies] tokio = { version = "1", features = ["full"] } From f889ad6be341ea9d01ac526b424a8552f8cb03d0 Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Wed, 21 Aug 2024 13:09:59 +0200 Subject: [PATCH 4/4] deps: update umio to mio v1 --- cSpell.json | 4 + contrib/umio/Cargo.toml | 6 +- contrib/umio/src/buffer.rs | 66 ++- contrib/umio/src/dispatcher.rs | 224 +++++---- contrib/umio/src/eloop.rs | 432 ++++++++++++++++-- contrib/umio/src/external.rs | 2 +- contrib/umio/src/lib.rs | 13 +- contrib/umio/src/provider.rs | 171 +++++-- contrib/umio/tests/common/mod.rs | 54 ++- contrib/umio/tests/test_incoming.rs | 52 ++- contrib/umio/tests/test_notify.rs | 38 +- contrib/umio/tests/test_outgoing.rs | 44 +- contrib/umio/tests/test_shutdown.rs | 30 +- contrib/umio/tests/test_timeout.rs | 43 +- packages/dht/examples/debug.rs | 4 +- .../handshake/examples/handshake_torrent.rs | 2 +- packages/utracker/Cargo.toml | 1 - packages/utracker/src/announce.rs | 11 +- packages/utracker/src/client/dispatcher.rs | 322 ++++++++----- packages/utracker/src/client/mod.rs | 51 ++- packages/utracker/src/option.rs | 2 +- packages/utracker/src/request.rs | 54 ++- packages/utracker/src/response.rs | 8 +- packages/utracker/src/scrape.rs | 2 + packages/utracker/src/server/dispatcher.rs | 74 +-- packages/utracker/src/server/mod.rs | 33 +- packages/utracker/tests/common/mod.rs | 31 +- .../utracker/tests/test_announce_start.rs | 18 +- packages/utracker/tests/test_announce_stop.rs | 15 +- packages/utracker/tests/test_client_drop.rs | 43 +- packages/utracker/tests/test_client_full.rs | 59 ++- packages/utracker/tests/test_connect.rs | 13 +- packages/utracker/tests/test_connect_cache.rs | 19 +- packages/utracker/tests/test_scrape.rs | 13 +- packages/utracker/tests/test_server_drop.rs | 14 +- 35 files changed, 1380 insertions(+), 588 deletions(-) diff --git a/cSpell.json b/cSpell.json index 97205b8f0..cd222e986 100644 --- a/cSpell.json +++ b/cSpell.json @@ -22,10 +22,12 @@ "codegen", "compat", "concated", + "Condvar", "coppersurfer", "cpupool", "curr", "cust", + "cvar", "Cyberneering", "demonii", "Deque", @@ -51,6 +53,7 @@ "metainfo", "mpmc", "myapp", + "nanos", "natted", "nextest", "Oneshot", @@ -65,6 +68,7 @@ "rebootstrapping", "recvd", "reqq", + "reregister", "ringbuffer", "rpath", "rqst", diff --git a/contrib/umio/Cargo.toml b/contrib/umio/Cargo.toml index 7307f6b44..6ca916208 100644 --- a/contrib/umio/Cargo.toml +++ b/contrib/umio/Cargo.toml @@ -16,4 +16,8 @@ repository.workspace = true version.workspace = true [dependencies] -mio = "0.5" +mio = { version = "1", features = ["net", "os-poll"] } +tracing = "0" + +[dev-dependencies] 
+tracing-subscriber = "0" diff --git a/contrib/umio/src/buffer.rs b/contrib/umio/src/buffer.rs index c6b21b72f..41e479bc8 100644 --- a/contrib/umio/src/buffer.rs +++ b/contrib/umio/src/buffer.rs @@ -1,3 +1,7 @@ +use std::ops::{Deref, DerefMut}; + +use tracing::instrument; + #[allow(clippy::module_name_repetitions)] pub struct BufferPool { // Use Stack For Temporal Locality @@ -5,20 +9,39 @@ pub struct BufferPool { buffer_size: usize, } +impl std::fmt::Debug for BufferPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferPool") + .field("buffers_len", &self.buffers.len()) + .field("buffer_size", &self.buffer_size) + .finish() + } +} + impl BufferPool { + #[instrument(skip())] pub fn new(buffer_size: usize) -> BufferPool { let buffers = Vec::new(); BufferPool { buffers, buffer_size } } + #[instrument(skip(self), fields(remaining= %self.buffers.len()))] pub fn pop(&mut self) -> Buffer { - self.buffers.pop().unwrap_or(Buffer::new(self.buffer_size)) + if let Some(buffer) = self.buffers.pop() { + tracing::trace!(?buffer, "popping old buffer taken from pool"); + buffer + } else { + let buffer = Buffer::new(self.buffer_size); + tracing::trace!(?buffer, "creating new buffer..."); + buffer + } } + #[instrument(skip(self, buffer), fields(existing= %self.buffers.len()))] pub fn push(&mut self, mut buffer: Buffer) { + tracing::trace!("Pushing buffer back to pool"); buffer.reset_position(); - self.buffers.push(buffer); } } @@ -27,42 +50,58 @@ impl BufferPool { /// Reusable region of memory for incoming and outgoing messages. pub struct Buffer { - buffer: Vec, - bytes_written: usize, + buffer: std::io::Cursor>, +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Buffer").field("buffer", &self.as_ref()).finish() + } } impl Buffer { + #[instrument(skip())] fn new(len: usize) -> Buffer { Buffer { - buffer: vec![0u8; len], - bytes_written: 0, + buffer: std::io::Cursor::new(vec![0_u8; len]), } } fn reset_position(&mut self) { - self.set_written(0); + self.set_position(0); + } +} + +impl Deref for Buffer { + type Target = std::io::Cursor>; + + fn deref(&self) -> &Self::Target { + &self.buffer } +} - /// Update the number of bytes written to the buffer. - pub fn set_written(&mut self, bytes: usize) { - self.bytes_written = bytes; +impl DerefMut for Buffer { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.buffer } } impl AsRef<[u8]> for Buffer { fn as_ref(&self) -> &[u8] { - &self.buffer[..self.bytes_written] + self.get_ref().split_at(self.buffer.position().try_into().unwrap()).0 } } impl AsMut<[u8]> for Buffer { fn as_mut(&mut self) -> &mut [u8] { - &mut self.buffer[self.bytes_written..] 
+ let pos = self.buffer.position().try_into().unwrap(); + self.get_mut().split_at_mut(pos).1 } } #[cfg(test)] mod tests { + use super::{Buffer, BufferPool}; const DEFAULT_BUFFER_SIZE: usize = 1500; @@ -79,7 +118,8 @@ mod tests { #[test] fn positive_buffer_len_update() { let mut buffer = Buffer::new(DEFAULT_BUFFER_SIZE); - buffer.set_written(DEFAULT_BUFFER_SIZE - 1); + + buffer.set_position((DEFAULT_BUFFER_SIZE - 1).try_into().unwrap()); assert_eq!(buffer.as_mut().len(), 1); assert_eq!(buffer.as_ref().len(), DEFAULT_BUFFER_SIZE - 1); diff --git a/contrib/umio/src/dispatcher.rs b/contrib/umio/src/dispatcher.rs index a62fd5979..bdf6808bf 100644 --- a/contrib/umio/src/dispatcher.rs +++ b/contrib/umio/src/dispatcher.rs @@ -1,135 +1,195 @@ use std::collections::VecDeque; use std::net::SocketAddr; +use std::sync::mpsc; -use mio::udp::UdpSocket; -use mio::{EventLoop, EventSet, Handler, PollOpt, Token}; +use mio::net::UdpSocket; +use mio::{Interest, Poll, Waker}; +use tracing::{instrument, Level}; use crate::buffer::{Buffer, BufferPool}; -use crate::{provider, Provider}; +use crate::eloop::ShutdownHandle; +use crate::provider::TimeoutAction; +use crate::{Provider, UDP_SOCKET_TOKEN}; -/// Handles events occurring within the event loop. -pub trait Dispatcher: Sized { - type Timeout; - type Message: Send; +pub trait Dispatcher: Sized + std::fmt::Debug { + type TimeoutToken: std::fmt::Debug; + type Message: std::fmt::Debug; - /// Process an incoming message from the given address. - #[allow(unused)] - fn incoming(&mut self, provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {} - - /// Process a message sent via the event loop channel. - #[allow(unused)] - fn notify(&mut self, provider: Provider<'_, Self>, message: Self::Message) {} - - /// Process a timeout that has been triggered. 
- #[allow(unused)] - fn timeout(&mut self, provider: Provider<'_, Self>, timeout: Self::Timeout) {} + fn incoming(&mut self, _provider: Provider<'_, Self>, _message: &[u8], _addr: SocketAddr) {} + fn notify(&mut self, _provider: Provider<'_, Self>, _message: Self::Message) {} + fn timeout(&mut self, _provider: Provider<'_, Self>, _timeout: Self::TimeoutToken) {} } -//----------------------------------------------------------------------------// - -const UDP_SOCKET_TOKEN: Token = Token(2); +pub struct DispatchHandler +where + D: std::fmt::Debug, +{ + pub dispatch: D, + pub out_queue: VecDeque<(Buffer, SocketAddr)>, + socket: UdpSocket, + pub buffer_pool: BufferPool, + current_interest: Interest, + pub timer_sender: mpsc::Sender>, +} -pub struct DispatchHandler { - dispatch: D, - out_queue: VecDeque<(Buffer, SocketAddr)>, - udp_socket: UdpSocket, - buffer_pool: BufferPool, - current_set: EventSet, +impl std::fmt::Debug for DispatchHandler +where + D: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DispatchHandler") + .field("dispatch", &self.dispatch) + .field("out_queue_len", &self.out_queue.len()) + .field("socket", &self.socket) + .field("buffer_pool", &self.buffer_pool) + .field("current_interest", &self.current_interest) + .field("timer_sender", &self.timer_sender) + .finish() + } } -impl DispatchHandler { +impl DispatchHandler +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new( - udp_socket: UdpSocket, + mut socket: UdpSocket, buffer_size: usize, dispatch: D, - event_loop: &mut EventLoop>, - ) -> DispatchHandler { + poll: &mut Poll, + timer_sender: mpsc::Sender>, + ) -> DispatchHandler + where + D: std::fmt::Debug, + ::TimeoutToken: std::fmt::Debug, + ::Message: std::fmt::Debug, + { let buffer_pool = BufferPool::new(buffer_size); let out_queue = VecDeque::new(); - event_loop - .register(&udp_socket, UDP_SOCKET_TOKEN, EventSet::readable(), PollOpt::edge()) + poll.registry() + .register(&mut socket, UDP_SOCKET_TOKEN, Interest::READABLE) .unwrap(); DispatchHandler { dispatch, out_queue, - udp_socket, + socket, buffer_pool, - current_set: EventSet::readable(), + current_interest: Interest::READABLE, + timer_sender, } } + #[instrument(skip(self, waker, shutdown_handle))] + pub fn handle_message(&mut self, waker: &Waker, shutdown_handle: &mut ShutdownHandle, message: D::Message) { + tracing::trace!("message received"); + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + + self.dispatch.notify(provider, message); + } + + #[instrument(skip(self, waker, shutdown_handle))] + pub fn handle_timeout(&mut self, waker: &Waker, shutdown_handle: &mut ShutdownHandle, token: D::TimeoutToken) { + tracing::trace!("timeout expired"); + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + + self.dispatch.timeout(provider, token); + } + + #[instrument(skip(self))] pub fn handle_write(&mut self) { + tracing::trace!("handle write"); + if let Some((buffer, addr)) = self.out_queue.pop_front() { - self.udp_socket.send_to(buffer.as_ref(), &addr).unwrap(); + let bytes = self.socket.send_to(buffer.as_ref(), addr).unwrap(); + + tracing::debug!(?buffer, ?bytes, ?addr, "sent"); self.buffer_pool.push(buffer); - }; + } } + #[instrument(skip(self))] pub fn handle_read(&mut self) -> Option<(Buffer, SocketAddr)> { + tracing::trace!("handle 
read"); + let mut buffer = self.buffer_pool.pop(); - if let Ok(Some((bytes, addr))) = self.udp_socket.recv_from(buffer.as_mut()) { - buffer.set_written(bytes); + match self.socket.recv_from(buffer.as_mut()) { + Ok((bytes, addr)) => { + buffer.set_position(bytes.try_into().unwrap()); + tracing::trace!(?buffer, "DispatchHandler: Read {bytes} bytes from {addr}"); - Some((buffer, addr)) - } else { - None + Some((buffer, addr)) + } + Err(e) => { + tracing::error!("DispatchHandler: Failed to read from UDP socket: {e}"); + None + } } } -} - -impl Handler for DispatchHandler { - type Timeout = D::Timeout; - type Message = D::Message; - - fn ready(&mut self, event_loop: &mut EventLoop, token: Token, events: EventSet) { - if token != UDP_SOCKET_TOKEN { - return; - } - - if events.is_writable() { - self.handle_write(); - } - - if events.is_readable() { - let Some((buffer, addr)) = self.handle_read() else { - return; - }; - { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.incoming(provider, buffer.as_ref(), addr); + #[instrument(skip(self, waker, shutdown_handle, event, poll))] + pub fn handle_event( + &mut self, + waker: &Waker, + shutdown_handle: &mut ShutdownHandle, + event: &mio::event::Event, + poll: &mut Poll, + ) where + T: std::fmt::Debug, + { + tracing::trace!(?event, "handle event"); + + if event.token() == UDP_SOCKET_TOKEN { + if event.is_writable() { + self.handle_write(); } - self.buffer_pool.push(buffer); + if event.is_readable() { + if let Some((buffer, addr)) = self.handle_read() { + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + self.dispatch.incoming(provider, buffer.as_ref(), addr); + self.buffer_pool.push(buffer); + } + } } - } - fn notify(&mut self, event_loop: &mut EventLoop, msg: Self::Message) { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.notify(provider, msg); + self.update_interest(poll); } - fn timeout(&mut self, event_loop: &mut EventLoop, timeout: Self::Timeout) { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.timeout(provider, timeout); - } + #[instrument(skip(self, poll))] + fn update_interest(&mut self, poll: &mut Poll) { + tracing::trace!("update interest"); - fn tick(&mut self, event_loop: &mut EventLoop) { - self.current_set = if self.out_queue.is_empty() { - EventSet::readable() + self.current_interest = if self.out_queue.is_empty() { + Interest::READABLE } else { - EventSet::readable() | EventSet::writable() + Interest::READABLE | Interest::WRITABLE }; - event_loop - .reregister(&self.udp_socket, UDP_SOCKET_TOKEN, self.current_set, PollOpt::edge()) + poll.registry() + .reregister(&mut self.socket, UDP_SOCKET_TOKEN, self.current_interest) .unwrap(); } } diff --git a/contrib/umio/src/eloop.rs b/contrib/umio/src/eloop.rs index cedf030bf..c216fbf1f 100644 --- a/contrib/umio/src/eloop.rs +++ b/contrib/umio/src/eloop.rs @@ -1,16 +1,262 @@ -use std::io::Result; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashSet, VecDeque}; +use std::marker::PhantomData; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::sync::{mpsc, Arc, Condvar, Mutex, OnceLock, Weak}; +use std::thread::JoinHandle; +use std::time::Instant; -use mio::udp::UdpSocket; -use mio::{EventLoop, EventLoopConfig, Sender}; +use mio::net::UdpSocket; +use mio::{Events, Poll, Waker}; +use tracing::{instrument, Level}; use 
crate::dispatcher::{DispatchHandler, Dispatcher}; +use crate::provider::TimeoutAction; +use crate::WAKER_TOKEN; const DEFAULT_BUFFER_SIZE: usize = 1500; const DEFAULT_CHANNEL_CAPACITY: usize = 4096; const DEFAULT_TIMER_CAPACITY: usize = 65536; -/// Builder for specifying attributes of an event loop. +#[derive(Debug)] +pub struct MessageSender +where + T: std::fmt::Debug, +{ + sender: mpsc::Sender, + waker: Arc, +} + +impl Clone for MessageSender +where + T: std::fmt::Debug, +{ + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + waker: self.waker.clone(), + } + } +} + +impl MessageSender +where + T: std::fmt::Debug, +{ + #[instrument(skip(), ret(level = Level::TRACE))] + fn new(sender: mpsc::Sender, waker: Arc) -> Self { + Self { sender, waker } + } + + #[instrument(skip(self))] + pub fn send(&self, msg: T) -> Result<(), mpsc::SendError> { + tracing::trace!("sending message"); + + let res = self.sender.send(msg); + + self.waker.wake().unwrap(); + + res + } +} + +#[derive(Debug)] +struct Timeout { + when: Instant, + token: Weak, +} + +impl Ord for Timeout { + fn cmp(&self, other: &Self) -> Ordering { + other.when.cmp(&self.when) // from smallest to largest + } +} + +impl PartialOrd for Timeout { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for Timeout {} + +impl PartialEq for Timeout { + fn eq(&self, other: &Self) -> bool { + self.when == other.when + } +} + +#[derive(Debug)] +struct LoopWaker +where + T: std::fmt::Debug, +{ + waker: Arc, + active: HashSet>, + pending: Arc<(Mutex>>, Condvar)>, + finished: Arc>>>, + _handle: JoinHandle<()>, +} + +impl LoopWaker +where + T: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Weak: Send, + Arc: Send, +{ + #[instrument(skip(waker, shutdown_handle), ret(level = Level::TRACE))] + fn new(waker: Arc, shutdown_handle: ShutdownHandle) -> Self { + let pending: Arc<(Mutex>>, Condvar)> = Arc::default(); + let finished: Arc>>> = Arc::default(); + + let handle = { + let pending = pending.clone(); + let finished = finished.clone(); + let waker = waker.clone(); + + std::thread::spawn(move || { + let mut timeouts: BinaryHeap> = BinaryHeap::default(); + let mut elapsed = VecDeque::default(); + + while !shutdown_handle.is_shutdown() { + { + let (lock, cvar) = &*pending; + let mut pending = lock.lock().unwrap(); + + while pending.is_empty() && timeouts.is_empty() { + pending = cvar.wait(pending).unwrap(); + + if shutdown_handle.is_shutdown() { + return; + } + } + + timeouts.append(&mut pending); + } + + while let Some(timeout) = timeouts.pop() { + let Some(token) = Weak::upgrade(&timeout.token) else { + continue; + }; + + match timeout.when.checked_duration_since(Instant::now()) { + Some(wait) => { + std::thread::sleep(wait); + elapsed.push_back(token); + break; + } + None => elapsed.push_back(token), + } + } + + let mut finished = finished.lock().unwrap(); + finished.append(&mut elapsed); + waker.wake().unwrap(); + } + }) + }; + + Self { + waker, + active: HashSet::default(), + pending, + finished, + _handle: handle, + } + } + + #[instrument(skip(self))] + fn next(&mut self) -> Option { + let token = self.finished.lock().unwrap().pop_front()?; + + let token = if self.remove(&token) { Arc::into_inner(token) } else { None }; + + tracing::trace!(?token, "next timeout"); + + token + } + + #[instrument(skip(self))] + fn remove(&mut self, token: &T) -> bool { + let remove = self.active.remove(token); + + tracing::trace!(%remove, "removed timeout"); + + remove + } + + pub fn clear(&mut self) { 
+ self.active.clear(); + self.pending.0.lock().unwrap().clear(); + } + + #[instrument(skip(self))] + fn push(&mut self, when: Instant, token: T) -> bool { + let token = Arc::new(token); + + let timeout = Timeout { + when, + token: Arc::downgrade(&token), + }; + + let inserted = self.active.insert(token); + + if inserted { + let (lock, cvar) = &*self.pending; + + lock.lock().unwrap().push(timeout); + cvar.notify_one(); + }; + + tracing::trace!(%inserted, "new timeout"); + + inserted + } +} + +#[derive(Debug, Clone)] +pub struct ShutdownHandle { + handle: Arc>, + waker: Arc, +} + +impl ShutdownHandle { + #[instrument(skip(), ret(level = Level::TRACE))] + fn new(waker: Arc) -> Self { + Self { + handle: Arc::default(), + waker, + } + } + + #[must_use] + pub fn is_shutdown(&self) -> bool { + self.handle.get().is_some() + } + + #[instrument(skip(self))] + pub fn shutdown(&mut self) { + if self.handle.set(()).is_ok() { + tracing::info!("shutdown called"); + } else { + tracing::debug!("shutdown already called"); + }; + + match self.waker.wake() { + Ok(()) => tracing::trace!("waking... shutdown"), + Err(e) => tracing::trace!("error waking... shutdown: {e}"), + } + } +} + +impl Drop for ShutdownHandle { + #[instrument(skip(self))] + fn drop(&mut self) { + self.shutdown(); + } +} + +#[derive(Debug)] pub struct ELoopBuilder { channel_capacity: usize, timer_capacity: usize, @@ -19,57 +265,55 @@ pub struct ELoopBuilder { } impl ELoopBuilder { - /// Create a new event loop builder. #[must_use] pub fn new() -> ELoopBuilder { Self::default() } - /// Manually set the maximum channel message capacity. #[must_use] pub fn channel_capacity(mut self, capacity: usize) -> ELoopBuilder { self.channel_capacity = capacity; - self } - /// Manually set the maximum timer capacity. #[must_use] pub fn timer_capacity(mut self, capacity: usize) -> ELoopBuilder { self.timer_capacity = capacity; - self } - /// Manually set the bind address for the udp socket in the event loop. #[must_use] pub fn bind_address(mut self, address: SocketAddr) -> ELoopBuilder { self.bind_address = address; - self } - /// Manually set the length of buffers provided by the event loop. #[must_use] pub fn buffer_length(mut self, length: usize) -> ELoopBuilder { self.buffer_size = length; - self } - /// Build the event loop with the current builder. + /// Builds an `ELoop` instance with the specified configuration. /// /// # Errors /// - /// It would error when the builder config has an problem. - pub fn build(self) -> Result> { + /// This function will return an error if creating the `Poll` or `Waker` fails. + pub fn build(self) -> std::io::Result<(ELoop, SocketAddr, ShutdownHandle)> + where + D: Dispatcher + std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Arc<::TimeoutToken>: Send, + Weak<::TimeoutToken>: Send, + { ELoop::from_builder(&self) } } impl Default for ELoopBuilder { fn default() -> Self { - let default_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)); + let default_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); ELoopBuilder { channel_capacity: DEFAULT_CHANNEL_CAPACITY, @@ -80,46 +324,146 @@ impl Default for ELoopBuilder { } } -//----------------------------------------------------------------------------// - -/// Wrapper around the main application event loop. 
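One detail worth noting in the `LoopWaker` above: `Timeout`'s `Ord` implementation compares `other.when` against `self.when`, which turns `BinaryHeap` (normally a max-heap) into a min-heap so the earliest deadline is popped first. A tiny standalone illustration of the same trick, here expressed with `std::cmp::Reverse` (illustrative only, not part of the patch):

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;
    use std::time::{Duration, Instant};

    fn main() {
        let now = Instant::now();
        let mut deadlines = BinaryHeap::new();

        // `Reverse` flips the ordering the same way the manual `Ord` impl on `Timeout` does.
        deadlines.push(Reverse(now + Duration::from_millis(300)));
        deadlines.push(Reverse(now + Duration::from_millis(100)));
        deadlines.push(Reverse(now + Duration::from_millis(200)));

        // Pops come out soonest-first: +100 ms, then +200 ms, then +300 ms.
        assert_eq!(deadlines.pop().unwrap().0, now + Duration::from_millis(100));
        assert_eq!(deadlines.pop().unwrap().0, now + Duration::from_millis(200));
        assert_eq!(deadlines.pop().unwrap().0, now + Duration::from_millis(300));
    }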
-pub struct ELoop { +#[derive(Debug)] +pub struct ELoop +where + D: Dispatcher, + ::Message: std::fmt::Debug, +{ buffer_size: usize, - socket_addr: SocketAddr, - event_loop: EventLoop>, + socket: Option, + poll: Poll, + events: Events, + loop_waker: LoopWaker, + shutdown_handle: ShutdownHandle, + message_sender: MessageSender, + message_receiver: mpsc::Receiver, + timeout_sender: mpsc::Sender>, + timeout_receiver: mpsc::Receiver>, + _marker: PhantomData, } -impl ELoop { - fn from_builder(builder: &ELoopBuilder) -> Result> { - let mut event_loop_config = EventLoopConfig::new(); - event_loop_config - .notify_capacity(builder.channel_capacity) - .timer_capacity(builder.timer_capacity); +impl ELoop +where + D: Dispatcher + std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Arc<::TimeoutToken>: Send, + Weak<::TimeoutToken>: Send, +{ + #[instrument(skip(), err, ret(level = Level::TRACE))] + fn from_builder(builder: &ELoopBuilder) -> std::io::Result<(ELoop, SocketAddr, ShutdownHandle)> { + let poll = Poll::new()?; + let events = Events::with_capacity(builder.channel_capacity); + + let (message_sender, message_receiver) = mpsc::channel(); + let (timeout_sender, timeout_receiver) = mpsc::channel(); + + let socket = UdpSocket::bind(builder.bind_address)?; + + let bound_socket = socket.local_addr()?; + + let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN)?); - let event_loop = EventLoop::configured(event_loop_config)?; + let shutdown_handle = ShutdownHandle::new(waker.clone()); - Ok(ELoop { - buffer_size: builder.buffer_size, - socket_addr: builder.bind_address, - event_loop, - }) + let loop_waker = LoopWaker::new(waker.clone(), shutdown_handle.clone()); + let message_sender = MessageSender::new(message_sender, waker); + + Ok(( + ELoop { + buffer_size: builder.buffer_size, + socket: Some(socket), + poll, + events, + loop_waker, + shutdown_handle: shutdown_handle.clone(), + message_sender, + message_receiver, + timeout_sender, + timeout_receiver, + _marker: PhantomData, + }, + bound_socket, + shutdown_handle, + )) + } + + #[must_use] + pub fn waker(&self) -> &Waker { + &self.loop_waker.waker } - /// Grab a channel to send messages to the event loop. + /// Creates a channel for sending messages to the event loop. #[must_use] - pub fn channel(&self) -> Sender { - self.event_loop.channel() + #[instrument(skip(self))] + pub fn channel(&self) -> MessageSender<::Message> { + self.message_sender.clone() } - /// Run the event loop with the given dispatcher until a shutdown occurs. + /// Runs the event loop with the provided dispatcher. /// /// # Errors /// - /// It would error if unable to bind to the socket. - pub fn run(&mut self, dispatcher: D) -> Result<()> { - let udp_socket = UdpSocket::bound(&self.socket_addr)?; - let mut dispatch_handler = DispatchHandler::new(udp_socket, self.buffer_size, dispatcher, &mut self.event_loop); + /// This function will return an error if binding the UDP socket or polling events fails. 
+ #[instrument(skip(self, dispatcher))] + pub fn run(&mut self, dispatcher: D, started_eloop_sender: mpsc::SyncSender>) -> std::io::Result<()> + where + D: std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + { + let mut dispatch_handler = DispatchHandler::new( + self.socket.take().unwrap(), + self.buffer_size, + dispatcher, + &mut self.poll, + self.timeout_sender.clone(), + ); + + let () = started_eloop_sender + .send(Ok(())) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?; + + loop { + if self.shutdown_handle.is_shutdown() { + self.loop_waker.clear(); + tracing::debug!("shutting down..."); + break; + } + + // Handle timeouts + while let Some(token) = self.loop_waker.next() { + dispatch_handler.handle_timeout(&self.loop_waker.waker, &mut self.shutdown_handle, token); + } + + // Handle events + for event in &self.events { + dispatch_handler.handle_event::(&self.loop_waker.waker, &mut self.shutdown_handle, event, &mut self.poll); + } + + // Handle messages + while let Ok(message) = self.message_receiver.try_recv() { + dispatch_handler.handle_message(&self.loop_waker.waker, &mut self.shutdown_handle, message); + } + + // Add timeouts + while let Ok(action) = self.timeout_receiver.try_recv() { + match action { + TimeoutAction::Add { token, when } => { + tracing::trace!(?token, ?when, "set timeout"); + self.loop_waker.push(when, token); + } + TimeoutAction::Remove { token } => { + tracing::trace!(?token, "clear timeout"); + self.loop_waker.remove(&token); + } + } + } + + self.poll.poll(&mut self.events, None)?; + } - self.event_loop.run(&mut dispatch_handler) + Ok(()) } } diff --git a/contrib/umio/src/external.rs b/contrib/umio/src/external.rs index 2ef9d04f8..1b6104ce2 100644 --- a/contrib/umio/src/external.rs +++ b/contrib/umio/src/external.rs @@ -1 +1 @@ -pub use mio::{Sender, Timeout, TimerError, TimerResult}; +pub use mio::{Events, Interest, Token, Waker}; diff --git a/contrib/umio/src/lib.rs b/contrib/umio/src/lib.rs index c246f2dfe..a29c4867e 100644 --- a/contrib/umio/src/lib.rs +++ b/contrib/umio/src/lib.rs @@ -1,17 +1,14 @@ -//! Message Based Readiness API -//! -//! This library is a thin wrapper around mio for clients who wish to -//! use a single udp socket in conjunction with message passing and -//! timeouts. - mod buffer; mod dispatcher; mod eloop; mod provider; -/// Exports of bare mio types. +const WAKER_TOKEN: Token = Token(0); +const UDP_SOCKET_TOKEN: Token = Token(2); + pub mod external; pub use dispatcher::Dispatcher; -pub use eloop::{ELoop, ELoopBuilder}; +pub use eloop::{ELoop, ELoopBuilder, MessageSender, ShutdownHandle}; +use mio::Token; pub use provider::Provider; diff --git a/contrib/umio/src/provider.rs b/contrib/umio/src/provider.rs index b9961292f..94bfab5bb 100644 --- a/contrib/umio/src/provider.rs +++ b/contrib/umio/src/provider.rs @@ -1,72 +1,159 @@ use std::collections::VecDeque; +use std::io::Write; +use std::marker::PhantomData; use std::net::SocketAddr; +use std::sync::mpsc; +use std::time::Instant; -use mio::{EventLoop, Sender, Timeout, TimerResult}; +use mio::Waker; +use tracing::instrument; use crate::buffer::{Buffer, BufferPool}; -use crate::dispatcher::{DispatchHandler, Dispatcher}; +use crate::dispatcher::Dispatcher; +use crate::eloop::ShutdownHandle; -/// Provides services to dispatcher clients. 
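Stepping back from the `ELoop` rework above: callers now get the bound socket address and a `ShutdownHandle` out of `build()`, and `run()` reports readiness over a sync channel before it starts polling. A caller-side sketch, assuming some `MyDispatcher` type that implements `Dispatcher` (hypothetical here, comparable to the mock dispatcher used by the tests further down):

    use std::sync::mpsc;

    use umio::ELoopBuilder;

    fn main() {
        let (mut eloop, _bound_addr, shutdown_handle) = ELoopBuilder::new()
            .bind_address("127.0.0.1:0".parse().unwrap())
            .build()
            .unwrap();

        let sender = eloop.channel();
        let (started_tx, started_rx) = mpsc::sync_channel(0);

        let handle = std::thread::spawn(move || eloop.run(MyDispatcher::default(), started_tx));

        // Block until the loop has registered its socket and is polling.
        started_rx.recv().unwrap().unwrap();

        // ... post `Dispatcher::Message` values through `sender` here ...
        drop(sender);

        // Dropping the handle (or calling shutdown() on it) wakes the loop and stops it.
        drop(shutdown_handle);
        handle.join().unwrap().unwrap();
    }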
-pub struct Provider<'a, D: Dispatcher> { - buffer_pool: &'a mut BufferPool, - out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, - event_loop: &'a mut EventLoop>, +pub enum TimeoutAction +where + T: std::fmt::Debug, +{ + Add { token: T, when: Instant }, + Remove { token: T }, } -pub fn new<'a, D: Dispatcher>( +#[derive(Debug)] +pub struct Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ buffer_pool: &'a mut BufferPool, + buffer: Option, out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, - event_loop: &'a mut EventLoop>, -) -> Provider<'a, D> { - Provider { - buffer_pool, - out_queue, - event_loop, - } + waker: &'a Waker, + shutdown_handle: &'a mut ShutdownHandle, + timer_sender: &'a mpsc::Sender>, + outgoing_socket: Option, + _marker: PhantomData, } -impl<'a, D: Dispatcher> Provider<'a, D> { - /// Grab a channel to send messages to the event loop. - #[must_use] - pub fn channel(&self) -> Sender { - self.event_loop.channel() - } +impl<'a, D> Write for Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip(self), fields(buffer= ?self.buffer))] + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let dest = self.buffer.get_or_insert_with(|| self.buffer_pool.pop()); - /// Execute a closure with a buffer and send the buffer contents to the - /// destination address or reclaim the buffer and do not send anything. - pub fn outgoing(&mut self, out: F) - where - F: FnOnce(&mut [u8]) -> Option<(usize, SocketAddr)>, - { - let mut buffer = self.buffer_pool.pop(); - let opt_send_to = out(buffer.as_mut()); + let wrote = dest.write(buf)?; + + tracing::trace!(%wrote, "write"); + + Ok(wrote) + } - match opt_send_to { - None => self.buffer_pool.push(buffer), - Some((bytes, addr)) => { - buffer.set_written(bytes); + #[instrument(skip(self))] + fn flush(&mut self) -> std::io::Result<()> { + if let Some(buffer) = self.buffer.take() { + tracing::trace!(?buffer, "flushing..."); + if let Some(addr) = self.outgoing_socket { self.out_queue.push_back((buffer, addr)); + self.wake(); + } else { + self.buffer_pool.push(buffer); + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "No outgoing socket address set", + )); } + } else { + tracing::warn!("flush empty"); + } + Ok(()) + } +} + +impl<'a, D> Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip())] + pub fn new( + buffer_pool: &'a mut BufferPool, + out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, + waker: &'a Waker, + shutdown_handle: &'a mut ShutdownHandle, + timer_sender: &'a mpsc::Sender>, + ) -> Provider<'a, D> { + Provider { + buffer_pool, + buffer: None, + out_queue, + waker, + timer_sender, + shutdown_handle, + outgoing_socket: None, + _marker: PhantomData, } } - /// Set a timeout with the given delay and token. + #[instrument(skip(self))] + pub fn set_dest(&mut self, dest: SocketAddr) -> Option { + self.outgoing_socket.replace(dest) + } + + /// Wakes the event loop. + /// + /// # Panics + /// + /// This function will panic if waking the event loop fails. + #[instrument(skip(self))] + pub fn wake(&self) { + self.waker.wake().expect("Failed to wake the event loop"); + } + + /// Sets a timeout with the given token and delay. /// /// # Errors /// - /// It would error when the timeout returns in a error. - pub fn set_timeout(&mut self, token: D::Timeout, delay: u64) -> TimerResult { - self.event_loop.timeout_ms(token, delay) + /// This function will return an error if sending message fails. 
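With the closure-based `outgoing` removed, outgoing datagrams are now produced through the `std::io::Write` impl shown above: point the provider at a destination with `set_dest`, write the payload, then `flush()` to queue the datagram and wake the poll loop; timers are scheduled against an absolute `Instant`. The earlier echo sketch, redone against the reworked API (illustrative only; the updated `MockDispatcher` below uses the same pattern in its `notify` handler):

    use std::io::Write as _;
    use std::net::SocketAddr;
    use std::time::{Duration, Instant};
    use umio::{Dispatcher, Provider};

    #[derive(Debug)]
    struct EchoDispatcher;

    impl Dispatcher for EchoDispatcher {
        type TimeoutToken = u32;
        type Message = ();

        fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {
            // Point the provider at the peer, write the payload into a pooled buffer,
            // and flush() to queue the datagram and wake the poll loop.
            let _previous_dest = provider.set_dest(addr);
            provider.write_all(message).unwrap();
            provider.flush().unwrap();

            // Timeouts now take an absolute deadline instead of a millisecond delay.
            provider.set_timeout(0, Instant::now() + Duration::from_millis(500)).unwrap();
        }
    }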
+ #[instrument(skip(self, token, when))] + pub fn set_timeout(&mut self, token: D::TimeoutToken, when: Instant) -> Result<(), Box> + where + D::TimeoutToken: 'static, + { + tracing::trace!(?token, ?when, "set timeout"); + + self.timer_sender.send(TimeoutAction::Add { token, when })?; + self.wake(); + Ok(()) } - /// Clear a timeout using the provided timeout identifier. - pub fn clear_timeout(&mut self, timeout: Timeout) -> bool { - self.event_loop.clear_timeout(timeout) + /// Removes a timeout + /// + /// # Errors + /// + /// This function will return an error if sending message fails. + #[instrument(skip(self))] + pub fn remove_timeout(&mut self, token: D::TimeoutToken) -> Result<(), Box> + where + D::TimeoutToken: 'static, + { + tracing::trace!("remove timeout"); + + self.timer_sender.send(TimeoutAction::Remove { token })?; + self.wake(); + Ok(()) } - /// Shutdown the event loop. + /// Shuts down the event loop. + /// + /// # Panics + /// + /// This function will panic if sending the shutdown signal fails. + #[instrument(skip(self))] pub fn shutdown(&mut self) { - self.event_loop.shutdown(); + tracing::debug!("shutdown"); + + self.shutdown_handle.shutdown(); } } diff --git a/contrib/umio/tests/common/mod.rs b/contrib/umio/tests/common/mod.rs index 97102082b..3d7e204ca 100644 --- a/contrib/umio/tests/common/mod.rs +++ b/contrib/umio/tests/common/mod.rs @@ -1,8 +1,30 @@ -use std::net::SocketAddr; -use std::sync::mpsc::{self}; +use std::io::Write; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::sync::{mpsc, Once}; +use std::time::Instant; +use tracing::level_filters::LevelFilter; +use tracing::{instrument, Level}; use umio::{Dispatcher, Provider}; +pub const LOOPBACK_IPV4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)); + +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +#[derive(Debug)] pub struct MockDispatcher { send: mpsc::Sender, } @@ -16,12 +38,13 @@ pub enum MockMessage { SendNotify, SendMessage(Vec, SocketAddr), - SendTimeout(u32, u64), + SendTimeout(u32, Instant), Shutdown, } impl MockDispatcher { + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new() -> (MockDispatcher, mpsc::Receiver) { let (send, recv) = mpsc::channel(); @@ -30,28 +53,29 @@ impl MockDispatcher { } impl Dispatcher for MockDispatcher { - type Timeout = u32; + type TimeoutToken = u32; type Message = MockMessage; + #[instrument(skip())] fn incoming(&mut self, _: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { let owned_message = message.to_vec(); - + tracing::trace!("MockDispatcher: Received message from {addr}"); self.send.send(MockMessage::MessageReceived(owned_message, addr)).unwrap(); } + #[instrument(skip(provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, msg: Self::Message) { + tracing::trace!("MockDispatcher: Received notification {msg:?}"); match msg { MockMessage::SendMessage(message, addr) => { - provider.outgoing(|buffer| { - for (src, dst) in message.iter().zip(buffer.as_mut().iter_mut()) { - *dst = *src; - } + let _ = provider.set_dest(addr); + + let _ = provider.write(&message).unwrap(); - Some((message.len(), addr)) - }); + let () = provider.flush().unwrap(); } - MockMessage::SendTimeout(token, delay) => { - provider.set_timeout(token, 
delay).unwrap(); + MockMessage::SendTimeout(token, when) => { + provider.set_timeout(token, when).unwrap(); } MockMessage::SendNotify => { self.send.send(MockMessage::NotifyReceived).unwrap(); @@ -63,7 +87,9 @@ impl Dispatcher for MockDispatcher { } } - fn timeout(&mut self, _: Provider<'_, Self>, token: Self::Timeout) { + #[instrument(skip())] + fn timeout(&mut self, _: Provider<'_, Self>, token: Self::TimeoutToken) { + tracing::trace!("MockDispatcher: Timeout received for token {token}"); self.send.send(MockMessage::TimeoutReceived(token)).unwrap(); } } diff --git a/contrib/umio/tests/test_incoming.rs b/contrib/umio/tests/test_incoming.rs index 85bfb1545..cee9b3394 100644 --- a/contrib/umio/tests/test_incoming.rs +++ b/contrib/umio/tests/test_incoming.rs @@ -1,40 +1,58 @@ use std::net::UdpSocket; -use std::thread::{self}; +use std::sync::mpsc; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; +/// Tests that an incoming message is correctly received and processed. #[test] fn positive_receive_incoming_message() { - let eloop_addr = "127.0.0.1:5050".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + tracing::trace!("Starting test: positive_receive_incoming_message"); + + let (mut eloop, eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { - eloop.run(dispatcher).unwrap(); - }); - thread::sleep(Duration::from_millis(50)); + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver.recv().unwrap().unwrap(); + + handle + }; - let socket_addr = "127.0.0.1:5051".parse().unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); + let socket = UdpSocket::bind(LOOPBACK_IPV4).unwrap(); + let socket_addr = socket.local_addr().unwrap(); let message = b"This Is A Test Message"; - socket.send_to(&message[..], eloop_addr).unwrap(); - thread::sleep(Duration::from_millis(50)); + tracing::trace!("Sending message to event loop"); + socket.send_to(&message[..], eloop_socket).unwrap(); + std::thread::sleep(Duration::from_millis(50)); - match dispatch_recv.try_recv() { + tracing::trace!("Checking for received message"); + let res: Result = dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::MessageReceived(msg, addr)) => { assert_eq!(&msg[..], &message[..]); - assert_eq!(addr, socket_addr); } _ => panic!("ELoop Failed To Receive Incoming Message"), - } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); + }; } diff --git a/contrib/umio/tests/test_notify.rs b/contrib/umio/tests/test_notify.rs index 5b624dee6..93de4ad8f 100644 --- a/contrib/umio/tests/test_notify.rs +++ b/contrib/umio/tests/test_notify.rs @@ -1,31 +1,45 @@ -use std::thread::{self}; +use std::sync::mpsc; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use 
umio::ELoopBuilder; mod common; #[test] fn positive_send_notify() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { - eloop.run(dispatcher).unwrap(); - }); - thread::sleep(Duration::from_millis(50)); + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver.recv().unwrap().unwrap(); + handle + }; + tracing::trace!("Sending MockMessage::SendNotify"); dispatch_send.send(MockMessage::SendNotify).unwrap(); - thread::sleep(Duration::from_millis(50)); + std::thread::sleep(Duration::from_millis(50)); - match dispatch_recv.try_recv() { + let res = dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::NotifyReceived) => (), _ => panic!("ELoop Failed To Receive Incoming Message"), } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); } diff --git a/contrib/umio/tests/test_outgoing.rs b/contrib/umio/tests/test_outgoing.rs index 3db172f75..3829fc2b2 100644 --- a/contrib/umio/tests/test_outgoing.rs +++ b/contrib/umio/tests/test_outgoing.rs @@ -1,39 +1,55 @@ use std::net::UdpSocket; -use std::thread::{self}; +use std::sync::mpsc; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_send_outgoing_message() { - let eloop_addr = "127.0.0.1:5052".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + tracing::trace!("Starting test: positive_send_outgoing_message"); + let (mut eloop, eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, _) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { - eloop.run(dispatcher).unwrap(); - }); - thread::sleep(Duration::from_millis(50)); + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver.recv().unwrap().unwrap(); + + handle + }; let message = b"This Is A Test Message"; let mut message_recv = [0u8; 22]; - let socket_addr = "127.0.0.1:5053".parse().unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); + let socket = UdpSocket::bind(LOOPBACK_IPV4).unwrap(); + let socket_addr = socket.local_addr().unwrap(); // Get the actual address + + tracing::trace!("sending message to: {socket_addr}"); dispatch_send .send(MockMessage::SendMessage(message.to_vec(), socket_addr)) .unwrap(); - thread::sleep(Duration::from_millis(50)); + tracing::trace!("receiving message from: {eloop_socket}"); + socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap(); let (bytes, addr) = socket.recv_from(&mut message_recv).unwrap(); + 
dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); // Wait for the event loop to finish + assert_eq!(bytes, message.len()); assert_eq!(&message[..], &message_recv[..]); - assert_eq!(addr, eloop_addr); - - dispatch_send.send(MockMessage::Shutdown).unwrap(); + assert_eq!(addr, eloop_socket); } diff --git a/contrib/umio/tests/test_shutdown.rs b/contrib/umio/tests/test_shutdown.rs index dd7023f2f..cbe130bf0 100644 --- a/contrib/umio/tests/test_shutdown.rs +++ b/contrib/umio/tests/test_shutdown.rs @@ -1,26 +1,36 @@ -use std::thread::{self}; -use std::time::Duration; +use std::sync::mpsc; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_execute_shutdown() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, _) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { - eloop.run(dispatcher).unwrap(); - }); - thread::sleep(Duration::from_millis(50)); + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver.recv().unwrap().unwrap(); + + handle + }; dispatch_send.send(MockMessage::Shutdown).unwrap(); - thread::sleep(Duration::from_millis(50)); + handle.join().unwrap(); assert!(dispatch_send.send(MockMessage::SendNotify).is_err()); } diff --git a/contrib/umio/tests/test_timeout.rs b/contrib/umio/tests/test_timeout.rs index 60f537016..56b5c33b1 100644 --- a/contrib/umio/tests/test_timeout.rs +++ b/contrib/umio/tests/test_timeout.rs @@ -1,34 +1,51 @@ +use std::sync::mpsc; use std::thread::{self}; -use std::time::Duration; +use std::time::{Duration, Instant}; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_send_notify() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { - eloop.run(dispatcher).unwrap(); - }); - thread::sleep(Duration::from_millis(50)); + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver.recv().unwrap().unwrap(); + + handle + }; let token = 5; - dispatch_send.send(MockMessage::SendTimeout(token, 50)).unwrap(); + let timeout_at = Instant::now() + Duration::from_millis(50); + dispatch_send.send(MockMessage::SendTimeout(token, timeout_at)).unwrap(); thread::sleep(Duration::from_millis(300)); - match dispatch_recv.try_recv() { + let res = 
dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::TimeoutReceived(tkn)) => { assert_eq!(tkn, token); } - _ => panic!("ELoop Failed To Receive Timeout"), + Ok(other) => panic!("Received Other: {other:?}"), + Err(e) => panic!("Received Error: {e}"), } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); } diff --git a/packages/dht/examples/debug.rs b/packages/dht/examples/debug.rs index 4998bb4ed..eb43ca27a 100644 --- a/packages/dht/examples/debug.rs +++ b/packages/dht/examples/debug.rs @@ -43,7 +43,7 @@ impl HandshakerTrait for SimpleHandshaker { self.filter.insert(addr); self.count += 1; - println!("Received new peer {:?}, total unique peers {}", addr, self.count); + tracing::trace!("Received new peer {:?}, total unique peers {}", addr, self.count); Box::pin(std::future::ready(())) } @@ -85,7 +85,7 @@ async fn main() { let mut events = dht.events().await; tasks.spawn(async move { while let Some(event) = events.next().await { - println!("\nReceived Dht Event {event:?}"); + tracing::trace!("\nReceived Dht Event {event:?}"); } }); diff --git a/packages/handshake/examples/handshake_torrent.rs b/packages/handshake/examples/handshake_torrent.rs index 0d8bc20b8..ab9568b65 100644 --- a/packages/handshake/examples/handshake_torrent.rs +++ b/packages/handshake/examples/handshake_torrent.rs @@ -36,7 +36,7 @@ async fn main() -> std::io::Result<()> { .await .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - println!("\nConnection With Peer Established...Closing In 10 Seconds"); + tracing::trace!("\nConnection With Peer Established...Closing In 10 Seconds"); sleep(Duration::from_secs(10)).await; diff --git a/packages/utracker/Cargo.toml b/packages/utracker/Cargo.toml index 7c38d5269..d5f766ac7 100644 --- a/packages/utracker/Cargo.toml +++ b/packages/utracker/Cargo.toml @@ -22,7 +22,6 @@ util = { path = "../util" } umio = { path = "../../contrib/umio" } byteorder = "1" -chrono = "0" futures = "0" nom = "7" rand = "0" diff --git a/packages/utracker/src/announce.rs b/packages/utracker/src/announce.rs index 7906b6ec6..6a227938e 100644 --- a/packages/utracker/src/announce.rs +++ b/packages/utracker/src/announce.rs @@ -96,10 +96,14 @@ impl<'a> AnnounceRequest<'a> { /// # Errors /// /// It would return an IO error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> + #[allow(clippy::needless_borrows_for_generic_args)] + #[instrument(skip(self, writer), err)] + pub fn write_bytes(&self, mut writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { + tracing::trace!("write_bytes"); + writer.write_all(self.info_hash.as_ref())?; writer.write_all(self.peer_id.as_ref())?; @@ -316,7 +320,7 @@ fn parse_response<'a>( // ----------------------------------------------------------------------------// /// Announce state of a client reported to the server. -#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ClientState { downloaded: i64, left: i64, @@ -400,7 +404,7 @@ fn parse_state(bytes: &[u8]) -> IResult<&[u8], ClientState> { /// Announce event of a client reported to the server. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnnounceEvent { /// No event is reported. 
None, @@ -591,6 +595,7 @@ impl DesiredPeers { /// # Errors /// /// It would return an IO Error if unable to write the bytes. + #[instrument(skip(self, writer), err)] pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where W: std::io::Write, diff --git a/packages/utracker/src/client/dispatcher.rs b/packages/utracker/src/client/dispatcher.rs index 2d0dd8b9b..274a8ed92 100644 --- a/packages/utracker/src/client/dispatcher.rs +++ b/packages/utracker/src/client/dispatcher.rs @@ -1,18 +1,18 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::SocketAddr; +use std::ops::{Deref, DerefMut}; +use std::sync::mpsc; +use std::time::{Duration, Instant, UNIX_EPOCH}; -use chrono::offset::Utc; -use chrono::{DateTime, Duration}; use futures::executor::block_on; use futures::future::{BoxFuture, Either}; use futures::sink::Sink; use futures::{FutureExt, SinkExt}; use handshake::{DiscoveryInfo, InitiateMessage, Protocol}; use nom::IResult; -use tracing::instrument; -use umio::external::{self, Timeout}; -use umio::{Dispatcher, ELoopBuilder, Provider}; +use tracing::{instrument, Level}; +use umio::{Dispatcher, ELoopBuilder, MessageSender, Provider, ShutdownHandle}; use util::bt::PeerId; use super::HandshakerMessage; @@ -30,18 +30,70 @@ const CONNECTION_ID_VALID_DURATION_MILLIS: i64 = 60000; const MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS: u64 = 8; /// Internal dispatch timeout. -#[derive(Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] enum DispatchTimeout { Connect(ClientToken), CleanUp, } +impl Default for DispatchTimeout { + fn default() -> Self { + Self::CleanUp + } +} + +#[derive(Default, Clone, Copy, Debug)] +struct TimeoutToken { + id: TimeoutId, + dispatch: DispatchTimeout, +} + +impl TimeoutToken { + fn new(dispatch: DispatchTimeout) -> (Self, TimeoutId) { + let id = TimeoutId::default(); + (Self { id, dispatch }, id) + } + + fn cleanup(id: TimeoutId) -> Self { + Self { + id, + dispatch: DispatchTimeout::CleanUp, + } + } +} + +impl Ord for TimeoutToken { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.id.cmp(&other.id) + } +} + +impl PartialOrd for TimeoutToken { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for TimeoutToken {} + +impl PartialEq for TimeoutToken { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl std::hash::Hash for TimeoutToken { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} + /// Internal dispatch message for clients. #[derive(Debug)] pub enum DispatchMessage { Request(SocketAddr, ClientToken, ClientRequest), StartTimer, - Shutdown, + Shutdown(mpsc::SyncSender>), } /// Create a new background dispatcher to execute request and send responses back. 
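// --- illustrative sketch -----------------------------------------------------
// `Shutdown` now carries an acknowledgement channel, turning shutdown into a
// blocking rendezvous: the requester waits until the dispatcher has notified
// its pending requests and stopped the loop, which is what the `Drop` impls of
// `TrackerClient` and `TrackerServer` do further down. Assuming the message
// sender handle returned by `create_dispatcher` is generic over the message
// type, the requester side looks like this:
use std::sync::mpsc;

fn shutdown_blocking(channel: &umio::MessageSender<DispatchMessage>) {
    let (finished_tx, finished_rx) = mpsc::sync_channel(0);

    channel
        .send(DispatchMessage::Shutdown(finished_tx))
        .expect("dispatcher already stopped");

    // Returns only once the event loop has acknowledged with `Ok(())`.
    finished_rx.recv().unwrap().unwrap();
}
// -----------------------------------------------------------------------------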
@@ -54,7 +106,7 @@ pub fn create_dispatcher( handshaker: H, msg_capacity: usize, limiter: RequestLimiter, -) -> std::io::Result> +) -> std::io::Result<(MessageSender, SocketAddr, ShutdownHandle)> where H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, @@ -68,20 +120,30 @@ where .bind_address(bind) .buffer_length(EXPECTED_PACKET_LENGTH); - let mut eloop = builder.build()?; + let (mut eloop, socket, shutdown) = builder.build()?; let channel = eloop.channel(); - let dispatch = ClientDispatcher::new(handshaker, bind, limiter); + let dispatcher = ClientDispatcher::new(handshaker, bind, limiter); + + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); - std::thread::spawn(move || { - eloop.run(dispatch).expect("bip_utracker: ELoop Shutdown Unexpectedly..."); - }); + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver + .recv() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))??; + + handle + }; channel .send(DispatchMessage::StartTimer) .expect("bip_utracker: ELoop Failed To Start Connect ID Timer..."); - Ok(channel) + Ok((channel, socket, shutdown)) } // ----------------------------------------------------------------------------// @@ -104,7 +166,7 @@ where H::Error: std::fmt::Display, { /// Create a new `ClientDispatcher`. - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new(handshaker: H, bind: SocketAddr, limiter: RequestLimiter) -> ClientDispatcher { tracing::debug!("new client dispatcher"); @@ -123,26 +185,30 @@ where } /// Shutdown the current dispatcher, notifying all pending requests. - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider), fields(unfinished_requests= %self.active_requests.len()))] pub fn shutdown(&mut self, provider: &mut Provider<'_, ClientDispatcher>) { - tracing::debug!("shutting down client dispatcher"); + tracing::debug!("shuting down..."); + + let mut active_requests = std::mem::take(&mut self.active_requests); + let mut unfinished_requests = active_requests.drain(); // Notify all active requests with the appropriate error - for token_index in 0..self.active_requests.len() { - let next_token = *self.active_requests.keys().nth(token_index).unwrap(); + for (client_token, connect_timer) in unfinished_requests.by_ref() { + tracing::trace!(?client_token, ?connect_timer, "removing..."); - self.notify_client(next_token, Err(ClientError::ClientShutdown)); - } - // TODO: Clear active timeouts - self.active_requests.clear(); + if let Some(id) = connect_timer.timeout_id() { + provider.remove_timeout(TimeoutToken::cleanup(id)).unwrap(); + } + self.notify_client(client_token, Err(ClientError::ClientShutdown)); + } provider.shutdown(); } /// Finish a request by sending the result back to the client. #[instrument(skip(self))] pub fn notify_client(&mut self, token: ClientToken, result: ClientResult) { - tracing::info!("notifying clients"); + tracing::trace!("notifying clients"); match block_on(self.handshaker.send(Ok(ClientMetadata::new(token, result).into()))) { Ok(()) => tracing::debug!("client metadata sent"), @@ -153,7 +219,7 @@ where } /// Process a request to be sent to the given address and associated with the given token. 
- #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, addr, token, request))] pub fn send_request( &mut self, provider: &mut Provider<'_, ClientDispatcher>, @@ -161,7 +227,7 @@ where token: ClientToken, request: ClientRequest, ) { - tracing::debug!("sending request"); + tracing::debug!(?addr, ?token, ?request, "sending request"); let bound_addr = self.bound_addr; @@ -182,14 +248,14 @@ where } /// Process a response received from some tracker and match it up against our sent requests. - #[instrument(skip(self, provider, response))] + #[instrument(skip(self, provider, response, addr))] pub fn recv_response( &mut self, provider: &mut Provider<'_, ClientDispatcher>, response: &TrackerResponse<'_>, addr: SocketAddr, ) { - tracing::debug!("receiving response"); + tracing::debug!(?response, ?addr, "receiving response"); let token = ClientToken(response.transaction_id()); @@ -207,11 +273,11 @@ where return; }; - provider.clear_timeout( - conn_timer - .timeout_id() - .expect("bip_utracker: Failed To Clear Request Timeout"), - ); + if let Some(clear_timeout_token) = conn_timer.timeout_id().map(TimeoutToken::cleanup) { + provider + .remove_timeout(clear_timeout_token) + .expect("bip_utracker: Failed To Clear Request Timeout"); + }; // Check if the response requires us to update the connection timer if let &ResponseType::Connect(id) = response.response_type() { @@ -225,7 +291,7 @@ where (&ClientRequest::Announce(hash, _), ResponseType::Announce(res)) => { // Forward contact information on to the handshaker for addr in res.peers().iter() { - tracing::info!("sending will block if unable to send!"); + tracing::debug!("sending will block if unable to send!"); match block_on( self.handshaker .send(Ok(InitiateMessage::new(Protocol::BitTorrent, hash, addr).into())), @@ -253,9 +319,9 @@ where /// Process an existing request, either re requesting a connection id or sending the actual request again. /// /// If this call is the result of a timeout, that will decide whether to cancel the request or not. 
- #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, token, timed_out))] fn process_request(&mut self, provider: &mut Provider<'_, ClientDispatcher>, token: ClientToken, timed_out: bool) { - tracing::debug!("processing request"); + tracing::debug!(?token, ?timed_out, "processing request"); let Some(mut conn_timer) = self.active_requests.remove(&token) else { tracing::error!(?token, "token not in active requests"); @@ -312,34 +378,37 @@ where // Try to write the request out to the server let mut write_success = false; - provider.outgoing(|bytes| { - let mut writer = std::io::Cursor::new(bytes); - match tracker_request.write_bytes(&mut writer) { + provider.set_dest(addr); + + { + match tracker_request.write_bytes(provider) { Ok(()) => { write_success = true; - Some((writer.position().try_into().unwrap(), addr)) } Err(e) => { - tracing::error!("failed to write out the tracker request with error: {e}"); - None + tracing::error!(?e, "failed to write out the tracker request with error"); } - } - }); + }; + } + + let next_timeout_at = Instant::now().checked_add(Duration::from_millis(next_timeout)).unwrap(); + + let (timeout_token, timeout_id) = TimeoutToken::new(DispatchTimeout::Connect(token)); + + let () = provider + .set_timeout(timeout_token, next_timeout_at) + .expect("bip_utracker: Failed To Set Timeout For Request"); // If message was not sent (too long to fit) then end the request if write_success { - conn_timer.set_timeout_id( - provider - .set_timeout(DispatchTimeout::Connect(token), next_timeout) - .expect("bip_utracker: Failed To Set Timeout For Request"), - ); + conn_timer.set_timeout_id(timeout_id); self.active_requests.insert(token, conn_timer); } else { - let err = ClientError::MaxLength; - tracing::warn!("notifying client with error: {err}"); + let e = ClientError::MaxLength; + tracing::warn!(?e, "notifying client with error"); - self.notify_client(token, Err(err)); + self.notify_client(token, Err(e)); } } } @@ -349,47 +418,57 @@ where H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, { - type Timeout = DispatchTimeout; + type TimeoutToken = TimeoutToken; type Message = DispatchMessage; - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, message, addr))] fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { + tracing::debug!(?message, %addr, "received incoming"); + let () = match TrackerResponse::from_bytes(message) { IResult::Ok((_, response)) => { - tracing::debug!("received an incoming response: {response:?}"); + tracing::trace!(?response, %addr, "received an incoming response"); self.recv_response(&mut provider, &response, addr); } Err(e) => { - tracing::error!("received an incoming error message: {e}"); + tracing::error!(%e, "received an incoming error message"); } }; } - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, message))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { - tracing::debug!("received notify"); + tracing::debug!(?message, "received notify"); match message { DispatchMessage::Request(addr, token, req_type) => { self.send_request(&mut provider, addr, token, req_type); } - DispatchMessage::StartTimer => self.timeout(provider, DispatchTimeout::CleanUp), - DispatchMessage::Shutdown => self.shutdown(&mut provider), + DispatchMessage::StartTimer => self.timeout(provider, TimeoutToken::default()), + DispatchMessage::Shutdown(shutdown_finished_sender) => { + self.shutdown(&mut 
provider); + + let () = shutdown_finished_sender.send(Ok(())).unwrap(); + } } } - #[instrument(skip(self, provider))] - fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: DispatchTimeout) { - tracing::debug!("received timeout"); + #[instrument(skip(self, provider, timeout))] + fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: TimeoutToken) { + tracing::debug!(?timeout, "received timeout"); - match timeout { + match timeout.dispatch { DispatchTimeout::Connect(token) => self.process_request(&mut provider, token, true), DispatchTimeout::CleanUp => { self.id_cache.clean_expired(); + let next_timeout_at = Instant::now() + .checked_add(Duration::from_millis(CONNECTION_ID_VALID_DURATION_MILLIS as u64)) + .unwrap(); + provider - .set_timeout(DispatchTimeout::CleanUp, CONNECTION_ID_VALID_DURATION_MILLIS as u64) + .set_timeout(TimeoutToken::default(), next_timeout_at) .expect("bip_utracker: Failed To Restart Connect Id Cleanup Timer"); } }; @@ -398,29 +477,47 @@ where // ----------------------------------------------------------------------------// +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct TimeoutId { + id: u128, +} + +impl Default for TimeoutId { + fn default() -> Self { + Self { + id: UNIX_EPOCH.elapsed().unwrap().as_nanos(), + } + } +} + +impl TimeoutId { + fn new(id: u128) -> Self { + Self { id } + } +} + +impl Deref for TimeoutId { + type Target = u128; + + fn deref(&self) -> &Self::Target { + &self.id + } +} + +impl DerefMut for TimeoutId { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.id + } +} + /// Contains logic for making sure a valid connection id is present /// and correctly timing out when sending requests to the server. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct ConnectTimer { addr: SocketAddr, attempt: u64, request: ClientRequest, - timeout_id: Option, -} - -impl std::fmt::Debug for ConnectTimer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let timeout_id = match self.timeout_id { - Some(_) => "Some(_)", - None => "None", - }; - - f.debug_struct("ConnectTimer") - .field("addr", &self.addr) - .field("attempt", &self.attempt) - .field("request", &self.request) - .field("timeout_id", &timeout_id) - .finish() - } + timeout_id: Option, } impl ConnectTimer { @@ -435,10 +532,8 @@ impl ConnectTimer { } /// Yields the current timeout value to use or None if the request should time out completely. - #[instrument(skip(), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] pub fn current_timeout(&mut self, timed_out: bool) -> Option { - tracing::debug!("getting current timeout"); - if self.attempt == MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS { tracing::warn!("request has reached maximum timeout attempts: {MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS}"); @@ -453,31 +548,31 @@ impl ConnectTimer { } /// Yields the current timeout id if one is set. - pub fn timeout_id(&self) -> Option { + pub fn timeout_id(&self) -> Option { self.timeout_id } /// Sets a new timeout id. - pub fn set_timeout_id(&mut self, id: Timeout) { + pub fn set_timeout_id(&mut self, id: TimeoutId) { self.timeout_id = Some(id); } /// Yields the message parameters for the current connection. - #[instrument(skip(), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] pub fn message_params(&self) -> (SocketAddr, &ClientRequest) { - tracing::debug!("getting message parameters"); - (self.addr, &self.request) } } /// Calculates the timeout for the request given the attempt count. 
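// --- illustrative sketch -----------------------------------------------------
// Why `TimeoutToken::cleanup(id)` can cancel a timer that was registered as
// `DispatchTimeout::Connect(..)`: the Hash/Ord/PartialEq impls above compare
// tokens by their `TimeoutId` alone and ignore the dispatch variant, so the
// event loop finds the same entry under either wrapper. A quick check of that
// contract, assuming the impls stay as written:
#[test]
fn timeout_tokens_compare_by_id_only() {
    let (connect_token, id) = TimeoutToken::new(DispatchTimeout::Connect(ClientToken(0)));

    assert_eq!(connect_token, TimeoutToken::cleanup(id));
}
// -----------------------------------------------------------------------------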
-#[instrument(skip(), ret)] +#[instrument(skip())] fn calculate_message_timeout_millis(attempt: u64) -> u64 { - tracing::debug!("calculation message timeout in milliseconds"); - let attempt = attempt.try_into().unwrap_or(u32::MAX); - (15 * 2u64.pow(attempt)) * 1000 + let timeout = (15 * 2u64.pow(attempt)) * 1000; + + tracing::debug!(attempt, timeout, "calculated message timeout in milliseconds"); + + timeout } // ----------------------------------------------------------------------------// @@ -485,7 +580,7 @@ fn calculate_message_timeout_millis(attempt: u64) -> u64 { /// Cache for storing connection ids associated with a specific server address. #[derive(Debug)] struct ConnectIdCache { - cache: HashMap)>, + cache: HashMap, } impl ConnectIdCache { @@ -495,18 +590,16 @@ impl ConnectIdCache { } /// Get an active connection id for the given addr. - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn get(&mut self, addr: SocketAddr) -> Option { - tracing::debug!("getting connection id"); - match self.cache.entry(addr) { Entry::Vacant(_) => { - tracing::warn!("connection id for {addr} not in cache"); + tracing::debug!("connection id for {addr} not in cache"); None } Entry::Occupied(occ) => { - let curr_time = Utc::now(); + let curr_time = Instant::now(); let prev_time = occ.get().1; if is_expired(curr_time, prev_time) { @@ -525,9 +618,9 @@ impl ConnectIdCache { /// Put an un expired connection id into cache for the given addr. #[instrument(skip(self))] fn put(&mut self, addr: SocketAddr, connect_id: u64) { - tracing::debug!("setting expired connection id"); + tracing::trace!("setting un expired connection id"); - let curr_time = Utc::now(); + let curr_time = Instant::now(); self.cache.insert(addr, (connect_id, curr_time)); } @@ -535,9 +628,8 @@ impl ConnectIdCache { /// Removes all entries that have expired. #[instrument(skip(self))] fn clean_expired(&mut self) { - tracing::debug!("cleaning expired connection id(s)"); - - let curr_time = Utc::now(); + let curr_time = Instant::now(); + let mut removed = 0; let mut curr_index = 0; let mut opt_curr_entry = self.cache.iter().skip(curr_index).map(|(&k, &v)| (k, v)).next(); @@ -549,16 +641,20 @@ impl ConnectIdCache { curr_index += 1; opt_curr_entry = self.cache.iter().skip(curr_index).map(|(&k, &v)| (k, v)).next(); } + + if removed != 0 { + tracing::debug!(%removed, "expired connection id(s)"); + } } } /// Returns true if the connect id received at `prev_time` is now expired. 
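// --- illustrative sketch -----------------------------------------------------
// The retransmit schedule produced by `calculate_message_timeout_millis`
// doubles from 15 seconds, and `MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS` (8) caps
// it, so a request that never gets an answer waits 15s, 30s, 60s, ... 1920s
// before being abandoned; connection ids themselves go stale after
// `CONNECTION_ID_VALID_DURATION_MILLIS` (60s). A quick sanity check, assuming
// the formula stays as written:
#[test]
fn backoff_doubles_from_fifteen_seconds() {
    let schedule: Vec<u64> = (0u64..8).map(calculate_message_timeout_millis).collect();

    assert_eq!(
        schedule,
        [15_000, 30_000, 60_000, 120_000, 240_000, 480_000, 960_000, 1_920_000]
    );
}
// -----------------------------------------------------------------------------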
-#[instrument(skip(), ret)] -fn is_expired(curr_time: DateTime, prev_time: DateTime) -> bool { - tracing::debug!("checking if a previous time is now expired"); - - let valid_duration = Duration::milliseconds(CONNECTION_ID_VALID_DURATION_MILLIS); - let difference = prev_time.signed_duration_since(curr_time); - - difference >= valid_duration +#[instrument(skip(), ret(level = Level::TRACE))] +fn is_expired(curr_time: Instant, prev_time: Instant) -> bool { + let Some(difference) = curr_time.checked_duration_since(prev_time) else { + // in future + return true; + }; + + difference >= Duration::from_millis(CONNECTION_ID_VALID_DURATION_MILLIS as u64) } diff --git a/packages/utracker/src/client/mod.rs b/packages/utracker/src/client/mod.rs index fb0449d85..6a2fa2b23 100644 --- a/packages/utracker/src/client/mod.rs +++ b/packages/utracker/src/client/mod.rs @@ -1,12 +1,12 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{mpsc, Arc}; use futures::future::Either; use futures::sink::Sink; use handshake::{DiscoveryInfo, InitiateMessage}; use tracing::instrument; -use umio::external::Sender; +use umio::{MessageSender, ShutdownHandle}; use util::bt::InfoHash; use util::trans::{LocallyShuffledIds, TransactionIds}; @@ -41,7 +41,7 @@ impl From for HandshakerMessage { /// Request made by the `TrackerClient`. #[allow(clippy::module_name_repetitions)] -#[derive(Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ClientRequest { Announce(InfoHash, ClientState), Scrape(InfoHash), @@ -119,14 +119,16 @@ impl ClientResponse { /// Client will shutdown on drop. #[allow(clippy::module_name_repetitions)] pub struct TrackerClient { - send: Sender, + send: MessageSender, // We are in charge of incrementing this, background worker is in charge of decrementing limiter: RequestLimiter, generator: TokenGenerator, + bound_socket: SocketAddr, + shutdown_handle: ShutdownHandle, } impl TrackerClient { - /// Create a new `TrackerClient` with the given message capacity. + /// Run a new `TrackerClient` with the given message capacity. /// /// Panics if capacity == `usize::max_value`(). /// @@ -138,19 +140,17 @@ impl TrackerClient { /// /// It would panic if the desired capacity is too large. 
#[instrument(skip())] - pub fn new(bind: SocketAddr, handshaker: H, capacity_or_default: Option) -> std::io::Result + pub fn run(bind: SocketAddr, handshaker: H, capacity_or_default: Option) -> std::io::Result where H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, { - tracing::info!("running client"); - let capacity = if let Some(capacity) = capacity_or_default { - tracing::debug!("with capacity {capacity}"); + tracing::trace!("with capacity {capacity}"); capacity } else { - tracing::debug!("with default capacity: {DEFAULT_CAPACITY}"); + tracing::trace!("with default capacity: {DEFAULT_CAPACITY}"); DEFAULT_CAPACITY }; @@ -165,12 +165,17 @@ impl TrackerClient { // Limit the capacity of messages (channel capacity - 1) let limiter = RequestLimiter::new(capacity); - let dispatcher = dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + let (dispatcher, bound_socket, shutdown_handle) = + dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + + tracing::info!(?bound_socket, "running client"); Ok(TrackerClient { send: dispatcher, limiter, generator: TokenGenerator::new(), + bound_socket, + shutdown_handle, }) } @@ -183,12 +188,15 @@ impl TrackerClient { /// It would panic if unable to send request message. #[instrument(skip(self))] pub fn request(&mut self, addr: SocketAddr, request: ClientRequest) -> Option { - tracing::debug!("requesting"); - if self.limiter.can_initiate() { let token = self.generator.generate(); + + let message = DispatchMessage::Request(addr, token, request); + + tracing::debug!(?message, "requesting"); + self.send - .send(DispatchMessage::Request(addr, token, request)) + .send(message) .expect("bip_utracker: Failed To Send Client Request Message..."); Some(token) @@ -198,13 +206,24 @@ impl TrackerClient { None } } + + #[must_use] + pub fn local_addr(&self) -> SocketAddr { + self.bound_socket + } } impl Drop for TrackerClient { + #[instrument(skip(self))] fn drop(&mut self) { + tracing::info!("shutting down"); + let (shutdown_finished_sender, shutdown_finished_receiver) = mpsc::sync_channel(0); + self.send - .send(DispatchMessage::Shutdown) + .send(DispatchMessage::Shutdown(shutdown_finished_sender)) .expect("bip_utracker: Failed To Send Client Shutdown Message..."); + + shutdown_finished_receiver.recv().unwrap().unwrap(); } } @@ -212,7 +231,7 @@ impl Drop for TrackerClient { /// Associates a `ClientRequest` with a `ClientResponse`. #[allow(clippy::module_name_repetitions)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ClientToken(u32); /// Generates tokens which double as transaction ids. diff --git a/packages/utracker/src/option.rs b/packages/utracker/src/option.rs index bd0a4d01c..304a85f98 100644 --- a/packages/utracker/src/option.rs +++ b/packages/utracker/src/option.rs @@ -73,7 +73,7 @@ impl<'a> AnnounceOptions<'a> { /// # Panics /// /// It would panic if the chunk length is too large. 
- #[instrument(skip(self, writer))] + #[instrument(skip(self, writer), err)] pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where W: std::io::Write, diff --git a/packages/utracker/src/request.rs b/packages/utracker/src/request.rs index 67180dd73..495a1a3cb 100644 --- a/packages/utracker/src/request.rs +++ b/packages/utracker/src/request.rs @@ -8,6 +8,7 @@ use nom::combinator::{map, map_res}; use nom::number::complete::{be_u32, be_u64}; use nom::sequence::tuple; use nom::IResult; +use tracing::instrument; use crate::announce::AnnounceRequest; use crate::scrape::ScrapeRequest; @@ -76,35 +77,40 @@ impl<'a> TrackerRequest<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> + #[allow(clippy::needless_borrows_for_generic_args)] + #[instrument(skip(self, writer), err)] + pub fn write_bytes(&self, mut writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { writer.write_u64::(self.connection_id())?; - match self.request_type() { - &RequestType::Connect => { - writer.write_u32::(crate::CONNECT_ACTION_ID)?; - writer.write_u32::(self.transaction_id())?; - } - RequestType::Announce(req) => { - let action_id = if req.source_ip().is_ipv4() { - crate::ANNOUNCE_IPV4_ACTION_ID - } else { - crate::ANNOUNCE_IPV6_ACTION_ID - }; - writer.write_u32::(action_id)?; - writer.write_u32::(self.transaction_id())?; - - req.write_bytes(writer)?; - } - RequestType::Scrape(req) => { - writer.write_u32::(crate::SCRAPE_ACTION_ID)?; - writer.write_u32::(self.transaction_id())?; - - req.write_bytes(writer)?; - } - }; + { + match self.request_type() { + &RequestType::Connect => { + writer.write_u32::(crate::CONNECT_ACTION_ID)?; + writer.write_u32::(self.transaction_id())?; + } + RequestType::Announce(req) => { + let action_id = if req.source_ip().is_ipv4() { + crate::ANNOUNCE_IPV4_ACTION_ID + } else { + crate::ANNOUNCE_IPV6_ACTION_ID + }; + writer.write_u32::(action_id)?; + writer.write_u32::(self.transaction_id())?; + + req.write_bytes(&mut writer)?; + } + RequestType::Scrape(req) => { + writer.write_u32::(crate::SCRAPE_ACTION_ID)?; + writer.write_u32::(self.transaction_id())?; + + req.write_bytes(&mut writer)?; + } + }; + } + writer.flush()?; Ok(()) } diff --git a/packages/utracker/src/response.rs b/packages/utracker/src/response.rs index 51742a290..fa7048bac 100644 --- a/packages/utracker/src/response.rs +++ b/packages/utracker/src/response.rs @@ -91,22 +91,24 @@ impl<'a> TrackerResponse<'a> { writer.write_u32::(action_id)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } ResponseType::Scrape(req) => { writer.write_u32::(crate::SCRAPE_ACTION_ID)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } ResponseType::Error(err) => { writer.write_u32::(ERROR_ACTION_ID)?; writer.write_u32::(self.transaction_id())?; - err.write_bytes(writer)?; + err.write_bytes(&mut writer)?; } }; + writer.flush(); + Ok(()) } diff --git a/packages/utracker/src/scrape.rs b/packages/utracker/src/scrape.rs index 8e0bc64e6..e033c3c3a 100644 --- a/packages/utracker/src/scrape.rs +++ b/packages/utracker/src/scrape.rs @@ -9,6 +9,7 @@ use nom::combinator::map_res; use nom::number::complete::be_i32; use nom::sequence::tuple; use nom::{IResult, Needed}; +use tracing::instrument; use util::bt::{self, InfoHash}; use util::convert; @@ -97,6 +98,7 @@ impl<'a> ScrapeRequest<'a> { /// # Errors /// /// It would return an IO Error 
if unable to write the bytes. + #[instrument(skip(self, writer), err)] pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where W: std::io::Write, diff --git a/packages/utracker/src/server/dispatcher.rs b/packages/utracker/src/server/dispatcher.rs index 1d413658e..d8fb5b044 100644 --- a/packages/utracker/src/server/dispatcher.rs +++ b/packages/utracker/src/server/dispatcher.rs @@ -1,9 +1,9 @@ use std::net::SocketAddr; +use std::sync::mpsc; use nom::IResult; -use tracing::instrument; -use umio::external::Sender; -use umio::{Dispatcher, ELoopBuilder, Provider}; +use tracing::{instrument, Level}; +use umio::{Dispatcher, ELoopBuilder, MessageSender, Provider, ShutdownHandle}; use crate::announce::AnnounceRequest; use crate::error::ErrorResponse; @@ -17,17 +17,20 @@ const EXPECTED_PACKET_LENGTH: usize = 1500; /// Internal dispatch message for servers. #[derive(Debug)] pub enum DispatchMessage { - Shutdown, + Shutdown(mpsc::SyncSender>), } /// Create a new background dispatcher to service requests. #[allow(clippy::module_name_repetitions)] #[instrument(skip())] -pub fn create_dispatcher(bind: SocketAddr, handler: H) -> std::io::Result> +pub fn create_dispatcher( + bind: SocketAddr, + handler: H, +) -> std::io::Result<(MessageSender, SocketAddr, ShutdownHandle)> where H: ServerHandler + std::fmt::Debug + 'static, { - tracing::debug!("create dispatcher"); + tracing::trace!("create dispatcher"); let builder = ELoopBuilder::new() .channel_capacity(1) @@ -35,16 +38,26 @@ where .bind_address(bind) .buffer_length(EXPECTED_PACKET_LENGTH); - let mut eloop = builder.build()?; + let (mut eloop, socket, shutdown) = builder.build()?; let channel = eloop.channel(); - let dispatch = ServerDispatcher::new(handler); + let dispatcher = ServerDispatcher::new(handler); + + let handle = { + let (started_eloop_sender, started_eloop_receiver) = mpsc::sync_channel(0); + + let handle = std::thread::spawn(move || { + eloop.run(dispatcher, started_eloop_sender).unwrap(); + }); + + let () = started_eloop_receiver + .recv() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))??; - std::thread::spawn(move || { - eloop.run(dispatch).expect("bip_utracker: ELoop Shutdown Unexpectedly..."); - }); + handle + }; - Ok(channel) + Ok((channel, socket, shutdown)) } // ----------------------------------------------------------------------------// @@ -63,10 +76,8 @@ where H: ServerHandler + std::fmt::Debug, { /// Create a new `ServerDispatcher`. - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] fn new(handler: H) -> ServerDispatcher { - tracing::debug!("new"); - ServerDispatcher { handler } } @@ -78,7 +89,7 @@ where request: &TrackerRequest<'_>, addr: SocketAddr, ) { - tracing::debug!("process request"); + tracing::trace!("process request"); let conn_id = request.connection_id(); let trans_id = request.transaction_id(); @@ -106,8 +117,6 @@ where /// Forward a connect request on to the appropriate handler method. 
#[instrument(skip(self, provider))] fn forward_connect(&mut self, provider: &mut Provider<'_, ServerDispatcher>, trans_id: u32, addr: SocketAddr) { - tracing::debug!("forward connect"); - let Some(attempt) = self.handler.connect(addr) else { tracing::warn!("connect attempt canceled"); @@ -121,6 +130,8 @@ where let response = TrackerResponse::new(trans_id, response_type); + tracing::trace!(?response, "forward connect"); + write_response(provider, &response, addr); } @@ -134,8 +145,6 @@ where request: &AnnounceRequest<'_>, addr: SocketAddr, ) { - tracing::debug!("forward announce"); - let Some(attempt) = self.handler.announce(addr, conn_id, request) else { tracing::warn!("announce attempt canceled"); @@ -148,6 +157,8 @@ where }; let response = TrackerResponse::new(trans_id, response_type); + tracing::trace!(?response, "forward announce"); + write_response(provider, &response, addr); } @@ -188,24 +199,21 @@ where { tracing::debug!("write response"); - provider.outgoing(|buffer| { - let mut cursor = std::io::Cursor::new(buffer); + provider.set_dest(addr); - match response.write_bytes(&mut cursor) { - Ok(()) => Some((cursor.position().try_into().unwrap(), addr)), - Err(e) => { - tracing::error!("error writing response to cursor: {e}"); - None - } + match response.write_bytes(provider) { + Ok(()) => (), + Err(e) => { + tracing::error!(%e, "error writing response to cursor"); } - }); + } } impl Dispatcher for ServerDispatcher where H: ServerHandler + std::fmt::Debug, { - type Timeout = (); + type TimeoutToken = (); type Message = DispatchMessage; #[instrument(skip(self, provider))] @@ -217,7 +225,7 @@ where self.process_request(&mut provider, &request, addr); } Err(e) => { - tracing::error!("received an incoming error message: {e}"); + tracing::error!(%e, "received an incoming error message"); } }; } @@ -225,10 +233,12 @@ where #[instrument(skip(self, provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { let () = match message { - DispatchMessage::Shutdown => { + DispatchMessage::Shutdown(shutdown_finished_sender) => { tracing::debug!("received a shutdown notification"); provider.shutdown(); + + let () = shutdown_finished_sender.send(Ok(())).unwrap(); } }; } diff --git a/packages/utracker/src/server/mod.rs b/packages/utracker/src/server/mod.rs index 5e2126208..1c971a2c8 100644 --- a/packages/utracker/src/server/mod.rs +++ b/packages/utracker/src/server/mod.rs @@ -1,7 +1,8 @@ use std::net::SocketAddr; +use std::sync::mpsc; -use tracing::instrument; -use umio::external::Sender; +use tracing::{instrument, Level}; +use umio::{MessageSender, ShutdownHandle}; use crate::server::dispatcher::DispatchMessage; use crate::server::handler::ServerHandler; @@ -15,7 +16,9 @@ pub mod handler; #[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct TrackerServer { - dispatcher: Sender, + dispatcher: MessageSender, + bound_socket: SocketAddr, + shutdown_handle: ShutdownHandle, } impl TrackerServer { @@ -24,26 +27,38 @@ impl TrackerServer { /// # Errors /// /// It would return an IO Error if unable to run the server. 
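// --- illustrative sketch -----------------------------------------------------
// Wiring the reworked server and client together the way the updated
// integration tests do: both bind to an ephemeral port, so requests are
// addressed to `server.local_addr()` rather than a hard-coded port. The
// handshaker bounds below are the ones `create_dispatcher` requires; the tests
// satisfy them with an mpsc-backed mock sink, and the import paths follow
// those tests.
use futures::Sink;
use handshake::DiscoveryInfo;
use util::bt;
use utracker::announce::{AnnounceEvent, ClientState};
use utracker::{ClientRequest, HandshakerMessage, ServerHandler, TrackerClient, TrackerServer};

fn announce_to_local_server<H>(
    handler: impl ServerHandler + std::fmt::Debug + 'static,
    handshaker: H,
) -> std::io::Result<()>
where
    H: Sink<std::io::Result<HandshakerMessage>> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static,
    H::Error: std::fmt::Display,
{
    let bind: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();

    let server = TrackerServer::run(bind, handler)?;
    let mut client = TrackerClient::run(bind, handshaker, None)?;

    let info_hash = [0u8; bt::INFO_HASH_LEN].into();
    let token = client.request(
        server.local_addr(),
        ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)),
    );
    assert!(token.is_some(), "request was rejected by the rate limiter");

    // Dropping `client` and `server` performs the blocking shutdown handshake.
    Ok(())
}
// -----------------------------------------------------------------------------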
- #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn run(bind: SocketAddr, handler: H) -> std::io::Result where H: ServerHandler + std::fmt::Debug + 'static, { - tracing::info!("running server"); + let (dispatcher, bound_socket, shutdown_handle) = dispatcher::create_dispatcher(bind, handler)?; - let dispatcher = dispatcher::create_dispatcher(bind, handler)?; + tracing::info!(?bound_socket, "running server"); - Ok(TrackerServer { dispatcher }) + Ok(TrackerServer { + dispatcher, + bound_socket, + shutdown_handle, + }) + } + + #[must_use] + pub fn local_addr(&self) -> SocketAddr { + self.bound_socket } } impl Drop for TrackerServer { #[instrument(skip(self))] fn drop(&mut self) { - tracing::debug!("server was dropped, sending shutdown notification..."); + tracing::info!("shutting down"); + let (shutdown_finished_sender, shutdown_finished_receiver) = mpsc::sync_channel(0); self.dispatcher - .send(DispatchMessage::Shutdown) + .send(DispatchMessage::Shutdown(shutdown_finished_sender)) .expect("bip_utracker: TrackerServer Failed To Send Shutdown Message"); + + shutdown_finished_receiver.recv().unwrap().unwrap(); } } diff --git a/packages/utracker/tests/common/mod.rs b/packages/utracker/tests/common/mod.rs index f410126fa..d5bbd076a 100644 --- a/packages/utracker/tests/common/mod.rs +++ b/packages/utracker/tests/common/mod.rs @@ -1,5 +1,5 @@ use std::collections::{HashMap, HashSet}; -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::{Arc, Mutex, Once}; use std::time::Duration; @@ -8,8 +8,8 @@ use futures::sink::SinkExt; use futures::stream::StreamExt; use futures::{Sink, Stream}; use handshake::DiscoveryInfo; -use tracing::instrument; use tracing::level_filters::LevelFilter; +use tracing::{instrument, Level}; use util::bt::{InfoHash, PeerId}; use util::trans::{LocallyShuffledIds, TransactionIds}; use utracker::announce::{AnnounceEvent, AnnounceRequest, AnnounceResponse}; @@ -20,6 +20,9 @@ use utracker::{HandshakerMessage, ServerHandler, ServerResult}; #[allow(dead_code)] pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(1000); +#[allow(dead_code)] +pub const LOOPBACK_IPV4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)); + const NUM_PEERS_RETURNED: usize = 20; #[allow(dead_code)] @@ -58,7 +61,7 @@ pub struct InnerMockTrackerHandler { #[allow(dead_code)] impl MockTrackerHandler { - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new() -> MockTrackerHandler { tracing::debug!("new mock handler"); @@ -77,7 +80,7 @@ impl MockTrackerHandler { } impl ServerHandler for MockTrackerHandler { - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn connect(&mut self, addr: SocketAddr) -> Option> { tracing::debug!("mock connect"); @@ -89,7 +92,7 @@ impl ServerHandler for MockTrackerHandler { Some(Ok(cid)) } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn announce( &mut self, addr: SocketAddr, @@ -158,7 +161,7 @@ impl ServerHandler for MockTrackerHandler { } } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn scrape(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>> { tracing::debug!("mock scrape"); @@ -211,16 +214,16 @@ impl DiscoveryInfo for MockHandshakerSink { impl Sink> for MockHandshakerSink { type Error = std::io::Error; - #[instrument(skip(self, cx), ret)] + 
#[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - tracing::debug!("polling ready"); + tracing::trace!("polling ready"); self.send .poll_ready(cx) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn start_send(mut self: std::pin::Pin<&mut Self>, item: std::io::Result) -> Result<(), Self::Error> { tracing::debug!("starting send"); @@ -229,19 +232,19 @@ impl Sink> for MockHandshakerSink { .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_flush( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - tracing::debug!("polling flush"); + tracing::trace!("polling flush"); self.send .poll_flush_unpin(cx) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_close( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -261,9 +264,9 @@ pub struct MockHandshakerStream { impl Stream for MockHandshakerStream { type Item = std::io::Result; - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - tracing::debug!("polling next"); + tracing::trace!("polling next"); self.recv.poll_next_unpin(cx).map(|maybe| maybe.map(Ok)) } diff --git a/packages/utracker/tests/test_announce_start.rs b/packages/utracker/tests/test_announce_start.rs index 37ef75f64..b25f3f4ad 100644 --- a/packages/utracker/tests/test_announce_start.rs +++ b/packages/utracker/tests/test_announce_start.rs @@ -1,7 +1,6 @@ use std::net::SocketAddr; -use std::time::Duration; -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4}; use futures::StreamExt as _; use handshake::Protocol; use tracing::level_filters::LevelFilter; @@ -19,25 +18,22 @@ async fn positive_announce_started() { let (handshaker_sender, mut handshaker_receiver) = handshaker(); - let server_addr = "127.0.0.1:3501".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); + let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap(); - std::thread::sleep(Duration::from_millis(100)); - - let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), handshaker_sender, None).unwrap(); + let mut client = TrackerClient::run(LOOPBACK_IPV4, handshaker_sender, None).unwrap(); let hash = [0u8; bt::INFO_HASH_LEN].into(); - tracing::warn!("sending announce"); + tracing::debug!("sending announce"); let _send_token = client .request( - server_addr, + server.local_addr(), ClientRequest::Announce(hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)), ) .unwrap(); - tracing::warn!("receiving initiate message"); + tracing::debug!("receiving initiate message"); let init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) .await .unwrap() @@ -54,7 +50,7 @@ async fn positive_announce_started() { assert_eq!(&exp_peer_addr, init_msg.address()); assert_eq!(&hash, init_msg.hash()); - 
tracing::warn!("receiving client metadata"); + tracing::debug!("receiving client metadata"); let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) .await .unwrap() diff --git a/packages/utracker/tests/test_announce_stop.rs b/packages/utracker/tests/test_announce_stop.rs index 4fc2fc505..bca57c523 100644 --- a/packages/utracker/tests/test_announce_stop.rs +++ b/packages/utracker/tests/test_announce_stop.rs @@ -1,6 +1,4 @@ -use std::time::Duration; - -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -17,13 +15,10 @@ async fn positive_announce_stopped() { let (sink, mut stream) = handshaker(); - let server_addr = "127.0.0.1:3502".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); - let _server = TrackerServer::run(server_addr, mock_handler).unwrap(); - - std::thread::sleep(Duration::from_millis(100)); + let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap(); - let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink, None).unwrap(); + let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap(); let info_hash = [0u8; bt::INFO_HASH_LEN].into(); @@ -31,7 +26,7 @@ async fn positive_announce_stopped() { { let _send_token = client .request( - server_addr, + server.local_addr(), ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)), ) .unwrap(); @@ -66,7 +61,7 @@ async fn positive_announce_stopped() { { let _send_token = client .request( - server_addr, + server.local_addr(), ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Stopped)), ) .unwrap(); diff --git a/packages/utracker/tests/test_client_drop.rs b/packages/utracker/tests/test_client_drop.rs index 3564cb9e7..28797f3ef 100644 --- a/packages/utracker/tests/test_client_drop.rs +++ b/packages/utracker/tests/test_client_drop.rs @@ -1,7 +1,8 @@ use std::net::SocketAddr; -use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT}; +use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4}; use futures::StreamExt as _; +use tokio::net::UdpSocket; use tracing::level_filters::LevelFilter; use util::bt::{self}; use utracker::announce::{AnnounceEvent, ClientState}; @@ -15,17 +16,22 @@ async fn positive_client_request_failed() { tracing_stderr_init(LevelFilter::ERROR); }); - let (sink, mut stream) = handshaker(); + let (sink, stream) = handshaker(); - let server_addr: SocketAddr = "127.0.0.1:3503".parse().unwrap(); - // Don't actually create the server since we want the request to wait for a little bit until we drop + // We bind a temp socket, then drop it... 
+    let disconnected_addr: SocketAddr = {
+        let socket = UdpSocket::bind(LOOPBACK_IPV4).await.unwrap();
+        socket.local_addr().unwrap()
+    };
+
+    tokio::task::yield_now().await;
 
     let send_token = {
-        let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink, None).unwrap();
+        let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();
 
         client
             .request(
-                server_addr,
+                disconnected_addr,
                 ClientRequest::Announce(
                     [0u8; bt::INFO_HASH_LEN].into(),
                     ClientState::new(0, 0, 0, AnnounceEvent::None),
@@ -35,20 +41,21 @@ async fn positive_client_request_failed() {
     };
     // Client is now dropped
 
-    let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next())
+    let mut messages: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect())
         .await
-        .unwrap()
-        .unwrap()
-        .unwrap()
-    {
-        HandshakerMessage::InitiateMessage(_) => unreachable!(),
-        HandshakerMessage::ClientMetadata(metadata) => metadata,
-    };
+        .expect("it should not time out");
+
+    while let Some(message) = messages.pop() {
+        let metadata = match message.expect("it should be a handshake message") {
+            HandshakerMessage::InitiateMessage(_) => unreachable!(),
+            HandshakerMessage::ClientMetadata(metadata) => metadata,
+        };
 
-    assert_eq!(send_token, metadata.token());
+        assert_eq!(send_token, metadata.token());
 
-    match metadata.result() {
-        Err(ClientError::ClientShutdown) => (),
-        _ => panic!("Did Not Receive ClientShutdown..."),
+        match metadata.result() {
+            Err(ClientError::ClientShutdown) => (),
+            _ => panic!("Did Not Receive ClientShutdown..."),
+        }
     }
 }
diff --git a/packages/utracker/tests/test_client_full.rs b/packages/utracker/tests/test_client_full.rs
index f2e78d150..b74f068a7 100644
--- a/packages/utracker/tests/test_client_full.rs
+++ b/packages/utracker/tests/test_client_full.rs
@@ -1,5 +1,8 @@
-use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT};
+use std::net::SocketAddr;
+
+use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
 use futures::StreamExt as _;
+use tokio::net::UdpSocket;
 use tracing::level_filters::LevelFilter;
 use util::bt::{self};
@@ -15,40 +18,48 @@ async fn positive_client_request_dropped() {
 
     let (sink, stream) = handshaker();
 
-    let server_addr = "127.0.0.1:3504".parse().unwrap();
+    // We bind a temp socket, then drop it...
+    let disconnected_addr: SocketAddr = {
+        let socket = UdpSocket::bind(LOOPBACK_IPV4).await.unwrap();
+        socket.local_addr().unwrap()
+    };
+
+    tokio::task::yield_now().await;
 
     let request_capacity = 10;
 
-    let mut client = TrackerClient::new("127.0.0.1:4504".parse().unwrap(), sink, Some(request_capacity)).unwrap();
+    {
+        let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, Some(request_capacity)).unwrap();
+
+        tracing::warn!("sending announce requests to fill buffer");
+        for i in 1..=request_capacity {
+            tracing::warn!("request {i} of {request_capacity}");
 
-    tracing::warn!("sending announce requests to fill buffer");
-    for i in 1..=request_capacity {
-        tracing::warn!("request {i} of {request_capacity}");
+            client
+                .request(
+                    disconnected_addr,
+                    ClientRequest::Announce(
+                        [0u8; bt::INFO_HASH_LEN].into(),
+                        ClientState::new(0, 0, 0, AnnounceEvent::Started),
+                    ),
+                )
+                .unwrap();
+        }
 
-        client
+        tracing::warn!("sending one more announce request, it should fail");
+        assert!(client
             .request(
-                server_addr,
+                disconnected_addr,
                 ClientRequest::Announce(
                     [0u8; bt::INFO_HASH_LEN].into(),
-                    ClientState::new(0, 0, 0, AnnounceEvent::Started),
-                ),
+                    ClientState::new(0, 0, 0, AnnounceEvent::Started)
+                )
             )
-            .unwrap();
+            .is_none());
     }
+    // Client is now dropped
 
-    tracing::warn!("sending one more announce request, it should fail");
-    assert!(client
-        .request(
-            server_addr,
-            ClientRequest::Announce(
-                [0u8; bt::INFO_HASH_LEN].into(),
-                ClientState::new(0, 0, 0, AnnounceEvent::Started)
-            )
-        )
-        .is_none());
-
-    std::mem::drop(client);
-
+    tracing::warn!("collecting remaining messages");
     let buffer: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect()).await.unwrap();
 
     assert_eq!(request_capacity, buffer.len());
 }
diff --git a/packages/utracker/tests/test_connect.rs b/packages/utracker/tests/test_connect.rs
index 2111d5681..8bbf07845 100644
--- a/packages/utracker/tests/test_connect.rs
+++ b/packages/utracker/tests/test_connect.rs
@@ -1,6 +1,4 @@
-use std::time::Duration;
-
-use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
+use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
 use futures::StreamExt as _;
 use tracing::level_filters::LevelFilter;
 use util::bt::{self};
@@ -17,17 +15,14 @@ async fn positive_receive_connect_id() {
 
     let (sink, mut stream) = handshaker();
 
-    let server_addr = "127.0.0.1:3505".parse().unwrap();
     let mock_handler = MockTrackerHandler::new();
-    let _server = TrackerServer::run(server_addr, mock_handler).unwrap();
-
-    std::thread::sleep(Duration::from_millis(100));
+    let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();
 
-    let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink, None).unwrap();
+    let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();
 
     let send_token = client
         .request(
-            server_addr,
+            server.local_addr(),
             ClientRequest::Announce(
                 [0u8; bt::INFO_HASH_LEN].into(),
                 ClientState::new(0, 0, 0, AnnounceEvent::None),
diff --git a/packages/utracker/tests/test_connect_cache.rs b/packages/utracker/tests/test_connect_cache.rs
index 25b89f277..0cfdd6b57 100644
--- a/packages/utracker/tests/test_connect_cache.rs
+++ b/packages/utracker/tests/test_connect_cache.rs
@@ -1,6 +1,4 @@
-use std::time::Duration;
-
-use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
+use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
 use futures::StreamExt as _;
 use tracing::level_filters::LevelFilter;
 use util::bt::{self};
@@ -16,18 +14,17 @@ async fn positive_connection_id_cache() {
 
     let (sink, mut stream) = common::handshaker();
 
-    let server_addr = "127.0.0.1:3506".parse().unwrap();
     let mock_handler = MockTrackerHandler::new();
-    let _server = TrackerServer::run(server_addr, mock_handler.clone()).unwrap();
-
-    std::thread::sleep(Duration::from_millis(100));
+    let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler.clone()).unwrap();
 
-    let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink, None).unwrap();
+    let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();
 
     let first_hash = [0u8; bt::INFO_HASH_LEN].into();
     let second_hash = [1u8; bt::INFO_HASH_LEN].into();
 
-    client.request(server_addr, ClientRequest::Scrape(first_hash)).unwrap();
+    client
+        .request(server.local_addr(), ClientRequest::Scrape(first_hash))
+        .unwrap();
     tokio::time::timeout(DEFAULT_TIMEOUT, stream.next())
         .await
         .unwrap()
@@ -37,7 +34,9 @@ async fn positive_connection_id_cache() {
     assert_eq!(mock_handler.num_active_connect_ids(), 1);
 
     for _ in 0..10 {
-        client.request(server_addr, ClientRequest::Scrape(second_hash)).unwrap();
+        client
+            .request(server.local_addr(), ClientRequest::Scrape(second_hash))
+            .unwrap();
     }
 
     for _ in 0..10 {
diff --git a/packages/utracker/tests/test_scrape.rs b/packages/utracker/tests/test_scrape.rs
index 607f02e66..3fb7d6999 100644
--- a/packages/utracker/tests/test_scrape.rs
+++ b/packages/utracker/tests/test_scrape.rs
@@ -1,6 +1,4 @@
-use std::time::Duration;
-
-use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
+use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
 use futures::StreamExt as _;
 use tracing::level_filters::LevelFilter;
 use util::bt::{self};
@@ -16,16 +14,13 @@ async fn positive_scrape() {
 
     let (sink, mut stream) = handshaker();
 
-    let server_addr = "127.0.0.1:3507".parse().unwrap();
     let mock_handler = MockTrackerHandler::new();
-    let _server = TrackerServer::run(server_addr, mock_handler).unwrap();
-
-    std::thread::sleep(Duration::from_millis(100));
+    let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();
 
-    let mut client = TrackerClient::new("127.0.0.1:4507".parse().unwrap(), sink, None).unwrap();
+    let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();
 
     let send_token = client
-        .request(server_addr, ClientRequest::Scrape([0u8; bt::INFO_HASH_LEN].into()))
+        .request(server.local_addr(), ClientRequest::Scrape([0u8; bt::INFO_HASH_LEN].into()))
         .unwrap();
 
     let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next())
diff --git a/packages/utracker/tests/test_server_drop.rs b/packages/utracker/tests/test_server_drop.rs
index 23002461a..5ddcfdebb 100644
--- a/packages/utracker/tests/test_server_drop.rs
+++ b/packages/utracker/tests/test_server_drop.rs
@@ -1,7 +1,7 @@
 use std::net::UdpSocket;
 use std::time::Duration;
 
-use common::{tracing_stderr_init, MockTrackerHandler, INIT};
+use common::{tracing_stderr_init, MockTrackerHandler, INIT, LOOPBACK_IPV4};
 use tracing::level_filters::LevelFilter;
 use utracker::request::{self, RequestType, TrackerRequest};
 use utracker::TrackerServer;
@@ -15,12 +15,12 @@ fn positive_server_dropped() {
         tracing_stderr_init(LevelFilter::ERROR);
     });
 
-    let server_addr = "127.0.0.1:3508".parse().unwrap();
     let mock_handler = MockTrackerHandler::new();
 
-    {
-        let server = TrackerServer::run(server_addr, mock_handler).unwrap();
-    }
+    let old_server_socket = {
+        let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();
+        server.local_addr()
+    };
 
     // Server is now shut down
 
     let mut send_message = Vec::new();
@@ -28,8 +28,8 @@ fn positive_server_dropped() {
     let request = TrackerRequest::new(request::CONNECT_ID_PROTOCOL_ID, 0, RequestType::Connect);
     request.write_bytes(&mut send_message).unwrap();
 
-    let socket = UdpSocket::bind("127.0.0.1:4508").unwrap();
-    socket.send_to(&send_message, server_addr);
+    let socket = UdpSocket::bind(LOOPBACK_IPV4).unwrap();
+    socket.send_to(&send_message, old_server_socket);
 
     let mut receive_message = vec![0u8; 1500];
     socket.set_read_timeout(Some(Duration::from_millis(200)));