From 4c51bec8e3d284188604094d226173880bdd26b8 Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Sat, 10 Aug 2024 19:29:21 +0200 Subject: [PATCH] tests work --- packages/util/src/trans/locally_shuffled.rs | 2 +- packages/util/src/trans/sequential.rs | 2 +- packages/utracker/Cargo.toml | 1 + packages/utracker/src/announce.rs | 5 + packages/utracker/src/client/dispatcher.rs | 182 ++++++++++++++---- packages/utracker/src/client/error.rs | 23 ++- packages/utracker/src/client/mod.rs | 47 +++-- packages/utracker/src/error.rs | 9 +- packages/utracker/src/option.rs | 40 +++- packages/utracker/src/request.rs | 2 + packages/utracker/src/response.rs | 2 + packages/utracker/src/scrape.rs | 34 ++-- packages/utracker/src/server/dispatcher.rs | 153 ++++++++++----- packages/utracker/src/server/handler.rs | 23 +-- packages/utracker/src/server/mod.rs | 18 +- packages/utracker/tests/common/mod.rs | 85 +++++--- .../utracker/tests/test_announce_start.rs | 26 ++- packages/utracker/tests/test_announce_stop.rs | 27 ++- packages/utracker/tests/test_client_drop.rs | 13 +- packages/utracker/tests/test_client_full.rs | 14 +- packages/utracker/tests/test_connect.rs | 20 +- packages/utracker/tests/test_connect_cache.rs | 18 +- packages/utracker/tests/test_scrape.rs | 13 +- packages/utracker/tests/test_server_drop.rs | 2 +- 24 files changed, 540 insertions(+), 221 deletions(-) diff --git a/packages/util/src/trans/locally_shuffled.rs b/packages/util/src/trans/locally_shuffled.rs index f9a2faec4..6d40bb928 100644 --- a/packages/util/src/trans/locally_shuffled.rs +++ b/packages/util/src/trans/locally_shuffled.rs @@ -20,7 +20,7 @@ const TRANSACTION_ID_PREALLOC_LEN: usize = 2048; /// transaction type (such as u64) but also works with smaller types. #[allow(clippy::module_name_repetitions)] -#[derive(Default)] +#[derive(Debug, Default)] pub struct LocallyShuffledIds { sequential: SequentialIds, stored_ids: Vec, diff --git a/packages/util/src/trans/sequential.rs b/packages/util/src/trans/sequential.rs index ea414540d..a652692f4 100644 --- a/packages/util/src/trans/sequential.rs +++ b/packages/util/src/trans/sequential.rs @@ -7,7 +7,7 @@ use crate::trans::TransactionIds; /// Generates sequentially unique ids and wraps when overflow occurs. #[allow(clippy::module_name_repetitions)] -#[derive(Default)] +#[derive(Debug, Default)] pub struct SequentialIds { next_id: T, } diff --git a/packages/utracker/Cargo.toml b/packages/utracker/Cargo.toml index b80b7e2a4..13961b77f 100644 --- a/packages/utracker/Cargo.toml +++ b/packages/utracker/Cargo.toml @@ -26,6 +26,7 @@ nom = "7" rand = "0" umio = "0" tracing = "0" +thiserror = "1" [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/packages/utracker/src/announce.rs b/packages/utracker/src/announce.rs index 1b2e171ea..82a0c6454 100644 --- a/packages/utracker/src/announce.rs +++ b/packages/utracker/src/announce.rs @@ -12,6 +12,7 @@ use nom::multi::count; use nom::number::complete::{be_i32, be_i64, be_u16, be_u32, be_u8}; use nom::sequence::tuple; use nom::IResult; +use tracing::instrument; use util::bt::{self, InfoHash, PeerId}; use util::convert; @@ -47,6 +48,7 @@ pub struct AnnounceRequest<'a> { #[allow(clippy::too_many_arguments)] impl<'a> AnnounceRequest<'a> { /// Create a new `AnnounceRequest`. + #[instrument(skip())] #[must_use] pub fn new( hash: InfoHash, @@ -58,6 +60,7 @@ impl<'a> AnnounceRequest<'a> { port: u16, options: AnnounceOptions<'a>, ) -> AnnounceRequest<'a> { + tracing::trace!("new announce request"); AnnounceRequest { info_hash: hash, peer_id, @@ -324,7 +327,9 @@ pub struct ClientState { impl ClientState { /// Create a new `ClientState`. #[must_use] + #[instrument(skip())] pub fn new(bytes_downloaded: i64, bytes_left: i64, bytes_uploaded: i64, event: AnnounceEvent) -> ClientState { + tracing::trace!("new client state"); ClientState { downloaded: bytes_downloaded, left: bytes_left, diff --git a/packages/utracker/src/client/dispatcher.rs b/packages/utracker/src/client/dispatcher.rs index 2daafe4d4..b425dc454 100644 --- a/packages/utracker/src/client/dispatcher.rs +++ b/packages/utracker/src/client/dispatcher.rs @@ -6,11 +6,13 @@ use std::thread; use chrono::offset::Utc; use chrono::{DateTime, Duration}; -use futures::future::Either; +use futures::executor::block_on; +use futures::future::{BoxFuture, Either}; use futures::sink::Sink; -use futures::SinkExt; +use futures::{FutureExt, SinkExt}; use handshake::{DiscoveryInfo, InitiateMessage, Protocol}; use nom::IResult; +use tracing::instrument; use umio::external::{self, Timeout}; use umio::{Dispatcher, ELoopBuilder, Provider}; use util::bt::PeerId; @@ -30,12 +32,14 @@ const CONNECTION_ID_VALID_DURATION_MILLIS: i64 = 60000; const MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS: u64 = 8; /// Internal dispatch timeout. +#[derive(Debug)] enum DispatchTimeout { Connect(ClientToken), CleanUp, } /// Internal dispatch message for clients. +#[derive(Debug)] pub enum DispatchMessage { Request(SocketAddr, ClientToken, ClientRequest), StartTimer, @@ -46,6 +50,7 @@ pub enum DispatchMessage { /// /// Assumes `msg_capacity` is less than `usize::max_value`(). #[allow(clippy::module_name_repetitions)] +#[instrument(skip())] pub fn create_dispatcher( bind: SocketAddr, handshaker: H, @@ -53,8 +58,11 @@ pub fn create_dispatcher( limiter: RequestLimiter, ) -> std::io::Result> where - H: Sink> + DiscoveryInfo + Send + Unpin + 'static, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { + tracing::debug!("creating dispatcher"); + // Timer capacity is plus one for the cache cleanup timer let builder = ELoopBuilder::new() .channel_capacity(msg_capacity) @@ -81,6 +89,7 @@ where // ----------------------------------------------------------------------------// /// Dispatcher that executes requests asynchronously. +#[derive(Debug)] struct ClientDispatcher { handshaker: H, pid: PeerId, @@ -93,10 +102,14 @@ struct ClientDispatcher { impl ClientDispatcher where - H: Sink> + DiscoveryInfo + Send + Unpin + 'static, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { /// Create a new `ClientDispatcher`. + #[instrument(skip(), ret)] pub fn new(handshaker: H, bind: SocketAddr, limiter: RequestLimiter) -> ClientDispatcher { + tracing::debug!("new client dispatcher"); + let peer_id = handshaker.peer_id(); let port = handshaker.port(); @@ -112,7 +125,10 @@ where } /// Shutdown the current dispatcher, notifying all pending requests. + #[instrument(skip(self, provider))] pub fn shutdown(&mut self, provider: &mut Provider<'_, ClientDispatcher>) { + tracing::debug!("shutting down client dispatcher"); + // Notify all active requests with the appropriate error for token_index in 0..self.active_requests.len() { let next_token = *self.active_requests.keys().nth(token_index).unwrap(); @@ -126,16 +142,20 @@ where } /// Finish a request by sending the result back to the client. - pub async fn notify_client(&mut self, token: ClientToken, result: ClientResult) { - self.handshaker - .send(Ok(ClientMetadata::new(token, result).into())) - .await - .unwrap_or_else(|_| panic!("NEED TO FIX")); + #[instrument(skip(self))] + pub fn notify_client(&mut self, token: ClientToken, result: ClientResult) { + tracing::info!("notifying clients"); + + match block_on(self.handshaker.send(Ok(ClientMetadata::new(token, result).into()))) { + Ok(()) => tracing::debug!("client metadata sent"), + Err(e) => tracing::error!("sending client metadata failed with error: {e}"), + } self.limiter.acknowledge(); } /// Process a request to be sent to the given address and associated with the given token. + #[instrument(skip(self, provider))] pub fn send_request( &mut self, provider: &mut Provider<'_, ClientDispatcher>, @@ -143,9 +163,15 @@ where token: ClientToken, request: ClientRequest, ) { + tracing::debug!("sending request"); + + let bound_addr = self.bound_addr; + // Check for IP version mismatch between source addr and dest addr - match (self.bound_addr, addr) { + match (bound_addr, addr) { (SocketAddr::V4(_), SocketAddr::V6(_)) | (SocketAddr::V6(_), SocketAddr::V4(_)) => { + tracing::error!(%bound_addr, %addr, "ip version mismatch between bound address and address"); + self.notify_client(token, Err(ClientError::IPVersionMismatch)); return; @@ -158,23 +184,30 @@ where } /// Process a response received from some tracker and match it up against our sent requests. - pub async fn recv_response( + #[instrument(skip(self, provider, response))] + pub fn recv_response( &mut self, provider: &mut Provider<'_, ClientDispatcher>, - addr: SocketAddr, response: &TrackerResponse<'_>, + addr: SocketAddr, ) { + tracing::debug!("receiving response"); + let token = ClientToken(response.transaction_id()); let conn_timer = if let Some(conn_timer) = self.active_requests.remove(&token) { if conn_timer.message_params().0 == addr { conn_timer } else { + tracing::error!(?conn_timer, %addr, "different message prams"); + return; - } // TODO: Add Logging (Server Receive Addr Different Than Send Addr) + } } else { + tracing::error!(?token, "token not in active requests"); + return; - }; // TODO: Add Logging (Server Gave Us Invalid Transaction Id) + }; provider.clear_timeout( conn_timer @@ -194,10 +227,14 @@ where (&ClientRequest::Announce(hash, _), ResponseType::Announce(res)) => { // Forward contact information on to the handshaker for addr in res.peers().iter() { - self.handshaker - .send(Ok(InitiateMessage::new(Protocol::BitTorrent, hash, addr).into())) - .await - .unwrap_or_else(|_| panic!("NEED TO FIX")); + tracing::info!("sending will block if unable to send!"); + match block_on( + self.handshaker + .send(Ok(InitiateMessage::new(Protocol::BitTorrent, hash, addr).into())), + ) { + Ok(()) => tracing::debug!("handshake for: {addr} initiated"), + Err(e) => tracing::warn!("handshake for: {addr} failed with: {e}"), + } } self.notify_client(token, Ok(ClientResponse::Announce(res.to_owned()))); @@ -218,14 +255,23 @@ where /// Process an existing request, either re requesting a connection id or sending the actual request again. /// /// If this call is the result of a timeout, that will decide whether to cancel the request or not. + #[instrument(skip(self, provider))] fn process_request(&mut self, provider: &mut Provider<'_, ClientDispatcher>, token: ClientToken, timed_out: bool) { + tracing::debug!("processing request"); + let Some(mut conn_timer) = self.active_requests.remove(&token) else { + tracing::error!(?token, "token not in active requests"); + return; - }; // TODO: Add logging + }; // Resolve the duration of the current timeout to use let Some(next_timeout) = conn_timer.current_timeout(timed_out) else { - self.notify_client(token, Err(ClientError::MaxTimeout)); + let err = ClientError::MaxTimeout; + + tracing::error!("error reached timeout: {err}"); + + self.notify_client(token, Err(err)); return; }; @@ -270,12 +316,15 @@ where let mut write_success = false; provider.outgoing(|bytes| { let mut writer = Cursor::new(bytes); - write_success = tracker_request.write_bytes(&mut writer).is_ok(); - - if write_success { - Some((writer.position().try_into().unwrap(), addr)) - } else { - None + match tracker_request.write_bytes(&mut writer) { + Ok(()) => { + write_success = true; + Some((writer.position().try_into().unwrap(), addr)) + } + Err(e) => { + tracing::error!("failed to write out the tracker request with error: {e}"); + None + } } }); @@ -289,27 +338,40 @@ where self.active_requests.insert(token, conn_timer); } else { - self.notify_client(token, Err(ClientError::MaxLength)); + let err = ClientError::MaxLength; + tracing::warn!("notifying client with error: {err}"); + + self.notify_client(token, Err(err)); } } } impl Dispatcher for ClientDispatcher where - H: Sink> + DiscoveryInfo + Send + Unpin + 'static, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { type Timeout = DispatchTimeout; type Message = DispatchMessage; + #[instrument(skip(self, provider))] fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { - let IResult::Ok((_, response)) = TrackerResponse::from_bytes(message) else { - return; // TODO: Add Logging - }; + let () = match TrackerResponse::from_bytes(message) { + IResult::Ok((_, response)) => { + tracing::debug!("received an incoming response: {response:?}"); - self.recv_response(&mut provider, addr, &response); + self.recv_response(&mut provider, &response, addr); + } + Err(e) => { + tracing::error!("received an incoming error message: {e}"); + } + }; } + #[instrument(skip(self, provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { + tracing::debug!("received notify"); + match message { DispatchMessage::Request(addr, token, req_type) => { self.send_request(&mut provider, addr, token, req_type); @@ -319,7 +381,10 @@ where } } + #[instrument(skip(self, provider))] fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: DispatchTimeout) { + tracing::debug!("received timeout"); + match timeout { DispatchTimeout::Connect(token) => self.process_request(&mut provider, token, true), DispatchTimeout::CleanUp => { @@ -344,6 +409,22 @@ struct ConnectTimer { timeout_id: Option, } +impl std::fmt::Debug for ConnectTimer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let timeout_id = match self.timeout_id { + Some(_) => "Some(_)", + None => "None", + }; + + f.debug_struct("ConnectTimer") + .field("addr", &self.addr) + .field("attempt", &self.attempt) + .field("request", &self.request) + .field("timeout_id", &timeout_id) + .finish() + } +} + impl ConnectTimer { /// Create a new `ConnectTimer`. pub fn new(addr: SocketAddr, request: ClientRequest) -> ConnectTimer { @@ -356,8 +437,13 @@ impl ConnectTimer { } /// Yields the current timeout value to use or None if the request should time out completely. + #[instrument(skip(), ret)] pub fn current_timeout(&mut self, timed_out: bool) -> Option { + tracing::debug!("getting current timeout"); + if self.attempt == MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS { + tracing::warn!("request has reached maximum timeout attempts: {MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS}"); + None } else { if timed_out { @@ -379,21 +465,27 @@ impl ConnectTimer { } /// Yields the message parameters for the current connection. + #[instrument(skip(), ret)] pub fn message_params(&self) -> (SocketAddr, &ClientRequest) { + tracing::debug!("getting message parameters"); + (self.addr, &self.request) } } /// Calculates the timeout for the request given the attempt count. +#[instrument(skip(), ret)] fn calculate_message_timeout_millis(attempt: u64) -> u64 { - #[allow(clippy::cast_possible_truncation)] - let attempt = attempt as u32; + tracing::debug!("calculation message timeout in milliseconds"); + + let attempt = attempt.try_into().unwrap_or(u32::MAX); (15 * 2u64.pow(attempt)) * 1000 } // ----------------------------------------------------------------------------// /// Cache for storing connection ids associated with a specific server address. +#[derive(Debug)] struct ConnectIdCache { cache: HashMap)>, } @@ -404,10 +496,17 @@ impl ConnectIdCache { ConnectIdCache { cache: HashMap::new() } } - /// Get an un expired connection id for the given addr. + /// Get an active connection id for the given addr. + #[instrument(skip(self), ret)] fn get(&mut self, addr: SocketAddr) -> Option { + tracing::debug!("getting connection id"); + match self.cache.entry(addr) { - Entry::Vacant(_) => None, + Entry::Vacant(_) => { + tracing::warn!("connection id for {addr} not in cache"); + + None + } Entry::Occupied(occ) => { let curr_time = Utc::now(); let prev_time = occ.get().1; @@ -415,6 +514,8 @@ impl ConnectIdCache { if is_expired(curr_time, prev_time) { occ.remove(); + tracing::warn!("connection id was already expired"); + None } else { Some(occ.get().0) @@ -424,14 +525,20 @@ impl ConnectIdCache { } /// Put an un expired connection id into cache for the given addr. + #[instrument(skip(self))] fn put(&mut self, addr: SocketAddr, connect_id: u64) { + tracing::debug!("setting expired connection id"); + let curr_time = Utc::now(); self.cache.insert(addr, (connect_id, curr_time)); } /// Removes all entries that have expired. + #[instrument(skip(self))] fn clean_expired(&mut self) { + tracing::debug!("cleaning expired connection id(s)"); + let curr_time = Utc::now(); let mut curr_index = 0; @@ -448,7 +555,10 @@ impl ConnectIdCache { } /// Returns true if the connect id received at `prev_time` is now expired. +#[instrument(skip(), ret)] fn is_expired(curr_time: DateTime, prev_time: DateTime) -> bool { + tracing::debug!("checking if a previous time is now expired"); + let valid_duration = Duration::milliseconds(CONNECTION_ID_VALID_DURATION_MILLIS); let difference = prev_time.signed_duration_since(curr_time); diff --git a/packages/utracker/src/client/error.rs b/packages/utracker/src/client/error.rs index 69af8cb65..dc60ab898 100644 --- a/packages/utracker/src/client/error.rs +++ b/packages/utracker/src/client/error.rs @@ -1,3 +1,5 @@ +use thiserror::Error; + use crate::error::ErrorResponse; /// Result type for a `ClientRequest`. @@ -5,18 +7,23 @@ pub type ClientResult = Result; /// Errors occurring as the result of a `ClientRequest`. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq, Eq)] pub enum ClientError { - /// Request timeout reached. + #[error("Request timeout reached")] MaxTimeout, - /// Request length exceeded the packet length. + + #[error("Request length exceeded the packet length")] MaxLength, - /// Client shut down the request client. + + #[error("Client shut down the request client")] ClientShutdown, - /// Server sent us an invalid message. + + #[error("Server sent us an invalid message")] ServerError, - /// Requested to send from IPv4 to IPv6 or vice versa. + + #[error("Requested to send from IPv4 to IPv6 or vice versa")] IPVersionMismatch, - /// Server returned an error message. - ServerMessage(ErrorResponse<'static>), + + #[error("Server returned an error message : {0}")] + ServerMessage(#[from] ErrorResponse<'static>), } diff --git a/packages/utracker/src/client/mod.rs b/packages/utracker/src/client/mod.rs index 4bd8bde47..27e557c64 100644 --- a/packages/utracker/src/client/mod.rs +++ b/packages/utracker/src/client/mod.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use futures::future::Either; use futures::sink::Sink; use handshake::{DiscoveryInfo, InitiateMessage}; +use tracing::instrument; use umio::external::Sender; use util::bt::InfoHash; use util::trans::{LocallyShuffledIds, TransactionIds}; @@ -21,6 +22,7 @@ pub mod error; /// Capacity of outstanding requests (assuming each request uses at most 1 timer at any time) const DEFAULT_CAPACITY: usize = 4096; +#[derive(Debug)] pub enum HandshakerMessage { InitiateMessage(InitiateMessage), ClientMetadata(ClientMetadata), @@ -125,18 +127,6 @@ pub struct TrackerClient { } impl TrackerClient { - /// Create a new `TrackerClient`. - /// - /// # Errors - /// - /// It would return a IO error if unable build a new client. - pub fn new(bind: SocketAddr, handshaker: H) -> std::io::Result - where - H: Sink> + DiscoveryInfo + Send + Unpin + 'static, - { - TrackerClient::with_capacity(bind, handshaker, DEFAULT_CAPACITY) - } - /// Create a new `TrackerClient` with the given message capacity. /// /// Panics if capacity == `usize::max_value`(). @@ -148,10 +138,24 @@ impl TrackerClient { /// # Panics /// /// It would panic if the desired capacity is too large. - pub fn with_capacity(bind: SocketAddr, handshaker: H, capacity: usize) -> std::io::Result + #[instrument(skip())] + pub fn new(bind: SocketAddr, handshaker: H, capacity_or_default: Option) -> std::io::Result where - H: Sink> + DiscoveryInfo + Send + Unpin + 'static, + H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, + H::Error: std::fmt::Display, { + tracing::info!("running client"); + + let capacity = if let Some(capacity) = capacity_or_default { + tracing::debug!("with capacity {capacity}"); + + capacity + } else { + tracing::debug!("with default capacity: {DEFAULT_CAPACITY}"); + + DEFAULT_CAPACITY + }; + // Need channel capacity to be 1 more in case channel is saturated and client // is dropped so shutdown message can get through in the worst case let (chan_capacity, would_overflow) = capacity.overflowing_add(1); @@ -162,8 +166,10 @@ impl TrackerClient { // Limit the capacity of messages (channel capacity - 1) let limiter = RequestLimiter::new(capacity); - dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone()).map(|chan| TrackerClient { - send: chan, + let dispatcher = dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + + Ok(TrackerClient { + send: dispatcher, limiter, generator: TokenGenerator::new(), }) @@ -176,7 +182,10 @@ impl TrackerClient { /// # Panics /// /// It would panic if unable to send request message. + #[instrument(skip(self))] pub fn request(&mut self, addr: SocketAddr, request: ClientRequest) -> Option { + tracing::debug!("requesting"); + if self.limiter.can_initiate() { let token = self.generator.generate(); self.send @@ -185,6 +194,8 @@ impl TrackerClient { Some(token) } else { + tracing::debug!("initiation was limited"); + None } } @@ -227,7 +238,7 @@ impl TokenGenerator { // ----------------------------------------------------------------------------// /// Limits requests based on the current number of outstanding requests. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct RequestLimiter { active: Arc, capacity: usize, @@ -255,7 +266,7 @@ impl RequestLimiter { // If the number of requests stored previously was less than the capacity, // then the add is considered good and a request can (SHOULD) be made. - if current_active_requests < self.capacity { + if current_active_requests <= self.capacity { true } else { // Act as if the request just completed (decrement back down) diff --git a/packages/utracker/src/error.rs b/packages/utracker/src/error.rs index fa99a5f32..318ab0e55 100644 --- a/packages/utracker/src/error.rs +++ b/packages/utracker/src/error.rs @@ -8,14 +8,21 @@ use nom::character::complete::not_line_ending; use nom::combinator::map_res; use nom::sequence::terminated; use nom::IResult; +use thiserror::Error; /// Error reported by the server and sent to the client. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Error, Debug, Clone, PartialEq, Eq)] pub struct ErrorResponse<'a> { message: Cow<'a, str>, } +impl<'a> std::fmt::Display for ErrorResponse<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Server Error: {}", self.message) + } +} + impl<'a> ErrorResponse<'a> { /// Create a new `ErrorResponse`. #[must_use] diff --git a/packages/utracker/src/option.rs b/packages/utracker/src/option.rs index 63945a62e..1d388fe3b 100644 --- a/packages/utracker/src/option.rs +++ b/packages/utracker/src/option.rs @@ -13,6 +13,7 @@ use nom::multi::length_data; use nom::number::complete::be_u8; use nom::sequence::tuple; use nom::IResult; +use tracing::instrument; const END_OF_OPTIONS_BYTE: u8 = 0x00; const NO_OPERATION_BYTE: u8 = 0x01; @@ -72,10 +73,12 @@ impl<'a> AnnounceOptions<'a> { /// # Panics /// /// It would panic if the chunk length is too large. + #[instrument(skip(self, writer))] pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where W: Write, { + tracing::trace!("writing {} options", self.raw_options.len()); for (byte, content) in &self.raw_options { for content_chunk in content.chunks(u8::MAX as usize) { let content_chunk_len: u8 = content_chunk.len().try_into().unwrap(); @@ -88,9 +91,17 @@ impl<'a> AnnounceOptions<'a> { // If we can fit it in, include the option terminating byte, otherwise as per the // spec, we can leave it out since we are assuming this is the end of the packet. - writer.write_u8(END_OF_OPTIONS_BYTE)?; - - Ok(()) + match writer.write_u8(END_OF_OPTIONS_BYTE) { + Ok(()) => Ok(()), + Err(e) => { + if e.kind() == std::io::ErrorKind::WriteZero { + tracing::trace!("no space to write ending marker"); + Ok(()) + } else { + Err(e) + } + } + } } /// Search for and construct the given `AnnounceOption` from the current `AnnounceOptions`. @@ -221,15 +232,36 @@ impl<'a> AnnounceOption<'a> for URLDataOption<'a> { #[cfg(test)] mod tests { + use std::io::Write; + use std::sync::Once; use nom::IResult; + use tracing::level_filters::LevelFilter; use super::{AnnounceOptions, URLDataOption}; + #[allow(dead_code)] + pub static INIT: Once = Once::new(); + + #[allow(dead_code)] + pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); + } + #[test] - #[ignore = "unable to write into a too-small buffer..."] fn positive_write_eof_option() { + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::INFO); + }); + let mut received = []; let options = AnnounceOptions::new(); diff --git a/packages/utracker/src/request.rs b/packages/utracker/src/request.rs index 5878f2a5a..b5be624f5 100644 --- a/packages/utracker/src/request.rs +++ b/packages/utracker/src/request.rs @@ -20,6 +20,7 @@ pub const CONNECT_ID_PROTOCOL_ID: u64 = 0x0417_2710_1980; /// Enumerates all types of requests that can be made to a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub enum RequestType<'a> { Connect, Announce(AnnounceRequest<'a>), @@ -40,6 +41,7 @@ impl<'a> RequestType<'a> { /// `TrackerRequest` which encapsulates any request sent to a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerRequest<'a> { // Both the connection id and transaction id are technically not unsigned according // to the spec, but since they are just bits we will keep them as unsigned since it diff --git a/packages/utracker/src/response.rs b/packages/utracker/src/response.rs index 0f002bd16..11fb95a55 100644 --- a/packages/utracker/src/response.rs +++ b/packages/utracker/src/response.rs @@ -18,6 +18,7 @@ const ERROR_ACTION_ID: u32 = 3; /// Enumerates all types of responses that can be received from a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub enum ResponseType<'a> { Connect(u64), Announce(AnnounceResponse<'a>), @@ -40,6 +41,7 @@ impl<'a> ResponseType<'a> { /// `TrackerResponse` which encapsulates any response sent from a tracker. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerResponse<'a> { transaction_id: u32, response_type: ResponseType<'a>, diff --git a/packages/utracker/src/scrape.rs b/packages/utracker/src/scrape.rs index 532b0b4d0..355c9b647 100644 --- a/packages/utracker/src/scrape.rs +++ b/packages/utracker/src/scrape.rs @@ -128,17 +128,12 @@ impl<'a> ScrapeRequest<'a> { } fn parse_request(bytes: &[u8]) -> IResult<&[u8], ScrapeRequest<'_>> { - let Some(remainder_bytes) = SCRAPE_STATS_BYTES.checked_sub(bytes.len()) else { - return Err(nom::Err::Error(nom::error::Error::new( - bytes, - nom::error::ErrorKind::TooLarge, - ))); - }; + let remainder_bytes = NonZero::new(bytes.len() % bt::INFO_HASH_LEN); - let remainder_bytes = NonZero::new(remainder_bytes); + let needed = remainder_bytes.and_then(|rem| bt::INFO_HASH_LEN.checked_sub(rem.into()).and_then(NonZero::new)); - if let Some(remainder_bytes) = remainder_bytes { - Err(nom::Err::Incomplete(Needed::Size(remainder_bytes))) + if let Some(needed) = needed { + Err(nom::Err::Incomplete(Needed::Size(needed))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; @@ -229,17 +224,12 @@ impl<'a> ScrapeResponse<'a> { } fn parse_response(bytes: &[u8]) -> IResult<&[u8], ScrapeResponse<'_>> { - let Some(remainder_bytes) = SCRAPE_STATS_BYTES.checked_sub(bytes.len()) else { - return Err(nom::Err::Error(nom::error::Error::new( - bytes, - nom::error::ErrorKind::TooLarge, - ))); - }; + let remainder_bytes = NonZero::new(bytes.len() % SCRAPE_STATS_BYTES); - let remainder_bytes = NonZero::new(remainder_bytes); + let needed = remainder_bytes.and_then(|rem| SCRAPE_STATS_BYTES.checked_sub(rem.into()).and_then(NonZero::new)); - if let Some(remainder_bytes) = remainder_bytes { - Err(nom::Err::Incomplete(Needed::Size(remainder_bytes))) + if let Some(needed) = needed { + Err(nom::Err::Incomplete(Needed::Size(needed))) } else { let end_of_bytes = &bytes[bytes.len()..bytes.len()]; @@ -443,11 +433,11 @@ mod tests { fn positive_parse_request_empty() { let hash_none = []; - let received = ScrapeRequest::from_bytes(&hash_none); + let received = ScrapeRequest::from_bytes(&hash_none).unwrap(); let expected = ScrapeRequest::new(); - assert_eq!(received, IResult::Ok((&b""[..], expected))); + assert_eq!(received, (&b""[..], expected)); } #[test] @@ -484,11 +474,11 @@ mod tests { fn positive_parse_response_empty() { let stats_bytes = []; - let received = ScrapeResponse::from_bytes(&stats_bytes); + let received = ScrapeResponse::from_bytes(&stats_bytes).unwrap(); let expected = ScrapeResponse::new(); - assert_eq!(received, IResult::Ok((&b""[..], expected))); + assert_eq!(received, (&b""[..], expected)); } #[test] diff --git a/packages/utracker/src/server/dispatcher.rs b/packages/utracker/src/server/dispatcher.rs index 3143fd4d3..c90165046 100644 --- a/packages/utracker/src/server/dispatcher.rs +++ b/packages/utracker/src/server/dispatcher.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use std::thread; use nom::IResult; +use tracing::instrument; use umio::external::Sender; use umio::{Dispatcher, ELoopBuilder, Provider}; @@ -16,16 +17,20 @@ use crate::server::handler::ServerHandler; const EXPECTED_PACKET_LENGTH: usize = 1500; /// Internal dispatch message for servers. +#[derive(Debug)] pub enum DispatchMessage { Shutdown, } /// Create a new background dispatcher to service requests. #[allow(clippy::module_name_repetitions)] +#[instrument(skip())] pub fn create_dispatcher(bind: SocketAddr, handler: H) -> std::io::Result> where - H: ServerHandler + 'static, + H: ServerHandler + std::fmt::Debug + 'static, { + tracing::debug!("create dispatcher"); + let builder = ELoopBuilder::new() .channel_capacity(1) .timer_capacity(0) @@ -47,29 +52,36 @@ where // ----------------------------------------------------------------------------// /// Dispatcher that executes requests asynchronously. +#[derive(Debug)] struct ServerDispatcher where - H: ServerHandler, + H: ServerHandler + std::fmt::Debug, { handler: H, } impl ServerDispatcher where - H: ServerHandler, + H: ServerHandler + std::fmt::Debug, { /// Create a new `ServerDispatcher`. + #[instrument(skip(), ret)] fn new(handler: H) -> ServerDispatcher { + tracing::debug!("new"); + ServerDispatcher { handler } } /// Forward the request on to the appropriate handler method. + #[instrument(skip(self, provider))] fn process_request( &mut self, provider: &mut Provider<'_, ServerDispatcher>, request: &TrackerRequest<'_>, addr: SocketAddr, ) { + tracing::debug!("process request"); + let conn_id = request.connection_id(); let trans_id = request.transaction_id(); @@ -77,7 +89,12 @@ where &RequestType::Connect => { if conn_id == request::CONNECT_ID_PROTOCOL_ID { self.forward_connect(provider, trans_id, addr); - } // TODO: Add Logging + } else { + tracing::warn!( + "request was not `CONNECT_ID_PROTOCOL_ID`, i.e. {}, but {conn_id}.", + request::CONNECT_ID_PROTOCOL_ID + ); + } } RequestType::Announce(req) => { self.forward_announce(provider, trans_id, conn_id, req, addr); @@ -89,19 +106,28 @@ where } /// Forward a connect request on to the appropriate handler method. + #[instrument(skip(self, provider))] fn forward_connect(&mut self, provider: &mut Provider<'_, ServerDispatcher>, trans_id: u32, addr: SocketAddr) { - self.handler.connect(addr, |result| { - let response_type = match result { - Ok(conn_id) => ResponseType::Connect(conn_id), - Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), - }; - let response = TrackerResponse::new(trans_id, response_type); - - write_response(provider, &response, addr); - }); + tracing::debug!("forward connect"); + + let Some(attempt) = self.handler.connect(addr) else { + tracing::warn!("connect attempt canceled"); + + return; + }; + + let response_type = match attempt { + Ok(conn_id) => ResponseType::Connect(conn_id), + Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), + }; + + let response = TrackerResponse::new(trans_id, response_type); + + write_response(provider, &response, addr); } /// Forward an announce request on to the appropriate handler method. + #[instrument(skip(self, provider))] fn forward_announce( &mut self, provider: &mut Provider<'_, ServerDispatcher>, @@ -110,18 +136,25 @@ where request: &AnnounceRequest<'_>, addr: SocketAddr, ) { - self.handler.announce(addr, conn_id, request, |result| { - let response_type = match result { - Ok(response) => ResponseType::Announce(response), - Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), - }; - let response = TrackerResponse::new(trans_id, response_type); - - write_response(provider, &response, addr); - }); + tracing::debug!("forward announce"); + + let Some(attempt) = self.handler.announce(addr, conn_id, request) else { + tracing::warn!("announce attempt canceled"); + + return; + }; + + let response_type = match attempt { + Ok(response) => ResponseType::Announce(response), + Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), + }; + let response = TrackerResponse::new(trans_id, response_type); + + write_response(provider, &response, addr); } /// Forward a scrape request on to the appropriate handler method. + #[instrument(skip(self, provider))] fn forward_scrape( &mut self, provider: &mut Provider<'_, ServerDispatcher>, @@ -130,55 +163,81 @@ where request: &ScrapeRequest<'_>, addr: SocketAddr, ) { - self.handler.scrape(addr, conn_id, request, |result| { - let response_type = match result { - Ok(response) => ResponseType::Scrape(response), - Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), - }; - let response = TrackerResponse::new(trans_id, response_type); - - write_response(provider, &response, addr); - }); + tracing::debug!("forward scrape"); + + let Some(attempt) = self.handler.scrape(addr, conn_id, request) else { + tracing::warn!("connect scrape canceled"); + + return; + }; + + let response_type = match attempt { + Ok(response) => ResponseType::Scrape(response), + Err(err_msg) => ResponseType::Error(ErrorResponse::new(err_msg)), + }; + + let response = TrackerResponse::new(trans_id, response_type); + + write_response(provider, &response, addr); } } /// Write the given tracker response through to the given provider. +#[instrument(skip(provider))] fn write_response(provider: &mut Provider<'_, ServerDispatcher>, response: &TrackerResponse<'_>, addr: SocketAddr) where - H: ServerHandler, + H: ServerHandler + std::fmt::Debug, { + tracing::debug!("write response"); + provider.outgoing(|buffer| { let mut cursor = Cursor::new(buffer); - let success = response.write_bytes(&mut cursor).is_ok(); - if success { - Some((cursor.position().try_into().unwrap(), addr)) - } else { - None - } // TODO: Add Logging + match response.write_bytes(&mut cursor) { + Ok(()) => Some((cursor.position().try_into().unwrap(), addr)), + Err(e) => { + tracing::error!("error writing response to cursor: {e}"); + None + } + } }); } impl Dispatcher for ServerDispatcher where - H: ServerHandler, + H: ServerHandler + std::fmt::Debug, { type Timeout = (); type Message = DispatchMessage; + #[instrument(skip(self, provider))] fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { - let IResult::Ok((_, request)) = TrackerRequest::from_bytes(message) else { - return; // TODO: Add Logging - }; + let () = match TrackerRequest::from_bytes(message) { + IResult::Ok((_, request)) => { + tracing::debug!("received an incoming request: {request:?}"); - self.process_request(&mut provider, &request, addr); + self.process_request(&mut provider, &request, addr); + } + Err(e) => { + tracing::error!("received an incoming error message: {e}"); + } + }; } + #[instrument(skip(self, provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { - match message { - DispatchMessage::Shutdown => provider.shutdown(), - } + let () = match message { + DispatchMessage::Shutdown => { + tracing::debug!("received a shutdown notification"); + + provider.shutdown(); + } + }; } - fn timeout(&mut self, _: Provider<'_, Self>, (): ()) {} + #[instrument(skip(self))] + fn timeout(&mut self, _: Provider<'_, Self>, (): ()) { + tracing::error!("timeout not yet supported!"); + unimplemented!(); + } } diff --git a/packages/utracker/src/server/handler.rs b/packages/utracker/src/server/handler.rs index 45f9ff563..e7682cd8b 100644 --- a/packages/utracker/src/server/handler.rs +++ b/packages/utracker/src/server/handler.rs @@ -13,23 +13,16 @@ pub type ServerResult<'a, T> = Result; pub trait ServerHandler: Send { /// Service a connection id request from the given address. - /// - /// If the result callback is not called, no response will be sent. - fn connect(&mut self, addr: SocketAddr, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, u64>); + fn connect(&mut self, addr: SocketAddr) -> Option>; /// Service an announce request with the given connect id. - /// - /// If the result callback is not called, no response will be sent. - fn announce<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &AnnounceRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, AnnounceResponse<'a>>); + fn announce( + &mut self, + addr: SocketAddr, + id: u64, + req: &AnnounceRequest<'_>, + ) -> Option>>; /// Service a scrape request with the given connect id. - /// - /// If the result callback is not called, no response will be sent. - fn scrape<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &ScrapeRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, ScrapeResponse<'a>>); + fn scrape(&mut self, addr: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>>; } diff --git a/packages/utracker/src/server/mod.rs b/packages/utracker/src/server/mod.rs index bf7eb2cfc..ae837f7ec 100644 --- a/packages/utracker/src/server/mod.rs +++ b/packages/utracker/src/server/mod.rs @@ -1,6 +1,7 @@ use std::io; use std::net::SocketAddr; +use tracing::instrument; use umio::external::Sender; use crate::server::dispatcher::DispatchMessage; @@ -13,8 +14,9 @@ pub mod handler; /// /// Server will shutdown on drop. #[allow(clippy::module_name_repetitions)] +#[derive(Debug)] pub struct TrackerServer { - send: Sender, + dispatcher: Sender, } impl TrackerServer { @@ -23,17 +25,25 @@ impl TrackerServer { /// # Errors /// /// It would return an IO Error if unable to run the server. + #[instrument(skip(), ret)] pub fn run(bind: SocketAddr, handler: H) -> std::io::Result where - H: ServerHandler + 'static, + H: ServerHandler + std::fmt::Debug + 'static, { - dispatcher::create_dispatcher(bind, handler).map(|send| TrackerServer { send }) + tracing::info!("running server"); + + let dispatcher = dispatcher::create_dispatcher(bind, handler)?; + + Ok(TrackerServer { dispatcher }) } } impl Drop for TrackerServer { + #[instrument(skip(self))] fn drop(&mut self) { - self.send + tracing::debug!("server was dropped, sending shutdown notification..."); + + self.dispatcher .send(DispatchMessage::Shutdown) .expect("bip_utracker: TrackerServer Failed To Send Shutdown Message"); } diff --git a/packages/utracker/tests/common/mod.rs b/packages/utracker/tests/common/mod.rs index c150f5d09..784f639d9 100644 --- a/packages/utracker/tests/common/mod.rs +++ b/packages/utracker/tests/common/mod.rs @@ -1,12 +1,14 @@ use std::collections::{HashMap, HashSet}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::{Arc, Mutex, Once}; +use std::time::Duration; use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; use futures::sink::SinkExt; use futures::stream::StreamExt; use futures::{Sink, Stream}; use handshake::DiscoveryInfo; +use tracing::instrument; use tracing::level_filters::LevelFilter; use util::bt::{InfoHash, PeerId}; use util::trans::{LocallyShuffledIds, TransactionIds}; @@ -15,6 +17,9 @@ use utracker::contact::{CompactPeers, CompactPeersV4, CompactPeersV6}; use utracker::scrape::{ScrapeRequest, ScrapeResponse, ScrapeStats}; use utracker::{HandshakerMessage, ServerHandler, ServerResult}; +#[allow(dead_code)] +pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(1000); + const NUM_PEERS_RETURNED: usize = 20; #[allow(dead_code)] @@ -39,11 +44,12 @@ pub fn tracing_stderr_init(filter: LevelFilter) { tracing::info!("Logging initialized"); } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct MockTrackerHandler { inner: Arc>, } +#[derive(Debug)] pub struct InnerMockTrackerHandler { cids: HashSet, cid_generator: LocallyShuffledIds, @@ -52,7 +58,10 @@ pub struct InnerMockTrackerHandler { #[allow(dead_code)] impl MockTrackerHandler { + #[instrument(skip(), ret)] pub fn new() -> MockTrackerHandler { + tracing::debug!("new mock handler"); + MockTrackerHandler { inner: Arc::new(Mutex::new(InnerMockTrackerHandler { cids: HashSet::new(), @@ -68,22 +77,27 @@ impl MockTrackerHandler { } impl ServerHandler for MockTrackerHandler { - fn connect(&mut self, _: SocketAddr, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, u64>), - { + #[instrument(skip(self), ret)] + fn connect(&mut self, addr: SocketAddr) -> Option> { + tracing::debug!("mock connect"); + let mut inner_lock = self.inner.lock().unwrap(); let cid = inner_lock.cid_generator.generate(); inner_lock.cids.insert(cid); - result(Ok(cid)); + Some(Ok(cid)) } - fn announce<'b, R>(&mut self, addr: SocketAddr, id: u64, req: &AnnounceRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, AnnounceResponse<'a>>), - { + #[instrument(skip(self), ret)] + fn announce( + &mut self, + addr: SocketAddr, + id: u64, + req: &AnnounceRequest<'_>, + ) -> Option>> { + tracing::debug!("mock announce"); + let mut inner_lock = self.inner.lock().unwrap(); if inner_lock.cids.contains(&id) { @@ -133,21 +147,21 @@ impl ServerHandler for MockTrackerHandler { CompactPeers::V6(v6_peers) }; - result(Ok(AnnounceResponse::new( + Some(Ok(AnnounceResponse::new( 1800, peers.len().try_into().unwrap(), peers.len().try_into().unwrap(), compact_peers, - ))); + ))) } else { - result(Err("Connection ID Is Invalid")); + Some(Err("Connection ID Is Invalid")) } } - fn scrape<'b, R>(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'b>, result: R) - where - R: for<'a> FnOnce(ServerResult<'a, ScrapeResponse<'a>>), - { + #[instrument(skip(self), ret)] + fn scrape(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>> { + tracing::debug!("mock scrape"); + let mut inner_lock = self.inner.lock().unwrap(); if inner_lock.cids.contains(&id) { @@ -163,9 +177,9 @@ impl ServerHandler for MockTrackerHandler { )); } - result(Ok(response)); + Some(Ok(response)) } else { - result(Err("Connection ID Is Invalid")); + Some(Err("Connection ID Is Invalid")) } } } @@ -179,15 +193,11 @@ pub fn handshaker() -> (MockHandshakerSink, MockHandshakerStream) { (MockHandshakerSink { send }, MockHandshakerStream { recv }) } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct MockHandshakerSink { send: UnboundedSender, } -pub struct MockHandshakerStream { - recv: UnboundedReceiver, -} - impl DiscoveryInfo for MockHandshakerSink { fn port(&self) -> u16 { 6969 @@ -201,41 +211,60 @@ impl DiscoveryInfo for MockHandshakerSink { impl Sink> for MockHandshakerSink { type Error = std::io::Error; + #[instrument(skip(self, cx), ret)] fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + tracing::debug!("polling ready"); + self.send .poll_ready(cx) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "SendError")) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } + #[instrument(skip(self), ret)] fn start_send(mut self: std::pin::Pin<&mut Self>, item: std::io::Result) -> Result<(), Self::Error> { + tracing::debug!("starting send"); + self.send .start_send(item?) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "SendError")) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } + #[instrument(skip(self, cx), ret)] fn poll_flush( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + tracing::debug!("polling flush"); + self.send .poll_flush_unpin(cx) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "SendError")) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } + #[instrument(skip(self, cx), ret)] fn poll_close( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { + tracing::debug!("polling close"); + self.send .poll_close_unpin(cx) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "SendError")) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } } +pub struct MockHandshakerStream { + recv: UnboundedReceiver, +} + impl Stream for MockHandshakerStream { type Item = std::io::Result; + #[instrument(skip(self, cx), ret)] fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + tracing::debug!("polling next"); + self.recv.poll_next_unpin(cx).map(|maybe| maybe.map(Ok)) } } diff --git a/packages/utracker/tests/test_announce_start.rs b/packages/utracker/tests/test_announce_start.rs index 7d2ef258d..72a8a7d93 100644 --- a/packages/utracker/tests/test_announce_start.rs +++ b/packages/utracker/tests/test_announce_start.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use std::thread::{self}; use std::time::Duration; -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use handshake::Protocol; use tracing::level_filters::LevelFilter; @@ -15,10 +15,10 @@ mod common; #[tokio::test] async fn positive_announce_started() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); - let (sink, mut stream) = handshaker(); + let (handshaker_sender, mut handshaker_receiver) = handshaker(); let server_addr = "127.0.0.1:3501".parse().unwrap(); let mock_handler = MockTrackerHandler::new(); @@ -26,9 +26,11 @@ async fn positive_announce_started() { thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), handshaker_sender, None).unwrap(); let hash = [0u8; bt::INFO_HASH_LEN].into(); + + tracing::warn!("sending announce"); let _send_token = client .request( server_addr, @@ -36,7 +38,13 @@ async fn positive_announce_started() { ) .unwrap(); - let init_msg = match stream.next().await.unwrap().unwrap() { + tracing::warn!("receiving initiate message"); + let init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(message) => message, HandshakerMessage::ClientMetadata(_) => unreachable!(), }; @@ -47,7 +55,13 @@ async fn positive_announce_started() { assert_eq!(&exp_peer_addr, init_msg.address()); assert_eq!(&hash, init_msg.hash()); - let metadata = match stream.next().await.unwrap().unwrap() { + tracing::warn!("receiving client metadata"); + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; diff --git a/packages/utracker/tests/test_announce_stop.rs b/packages/utracker/tests/test_announce_stop.rs index 713ad00e9..80e4d337e 100644 --- a/packages/utracker/tests/test_announce_stop.rs +++ b/packages/utracker/tests/test_announce_stop.rs @@ -1,7 +1,7 @@ use std::thread::{self}; use std::time::Duration; -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -13,7 +13,7 @@ mod common; #[tokio::test] async fn positive_announce_stopped() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, mut stream) = handshaker(); @@ -24,7 +24,7 @@ async fn positive_announce_stopped() { thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink, None).unwrap(); let info_hash = [0u8; bt::INFO_HASH_LEN].into(); @@ -37,12 +37,22 @@ async fn positive_announce_stopped() { ) .unwrap(); - let _init_msg = match stream.next().await.unwrap().unwrap() { + let _init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(message) => message, HandshakerMessage::ClientMetadata(_) => unreachable!(), }; - let metadata = match stream.next().await.unwrap().unwrap() { + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; @@ -62,7 +72,12 @@ async fn positive_announce_stopped() { ) .unwrap(); - let metadata = match stream.next().await.unwrap().unwrap() { + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; diff --git a/packages/utracker/tests/test_client_drop.rs b/packages/utracker/tests/test_client_drop.rs index 1fc4cba1d..3564cb9e7 100644 --- a/packages/utracker/tests/test_client_drop.rs +++ b/packages/utracker/tests/test_client_drop.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use common::{handshaker, tracing_stderr_init, INIT}; +use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -12,7 +12,7 @@ mod common; #[tokio::test] async fn positive_client_request_failed() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, mut stream) = handshaker(); @@ -21,7 +21,7 @@ async fn positive_client_request_failed() { // Don't actually create the server since we want the request to wait for a little bit until we drop let send_token = { - let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink, None).unwrap(); client .request( @@ -35,7 +35,12 @@ async fn positive_client_request_failed() { }; // Client is now dropped - let metadata = match stream.next().await.unwrap().unwrap() { + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; diff --git a/packages/utracker/tests/test_client_full.rs b/packages/utracker/tests/test_client_full.rs index 8d29aae8e..94b9dbd59 100644 --- a/packages/utracker/tests/test_client_full.rs +++ b/packages/utracker/tests/test_client_full.rs @@ -1,6 +1,6 @@ use std::mem; -use common::{handshaker, tracing_stderr_init, INIT}; +use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -12,7 +12,7 @@ mod common; #[tokio::test] async fn positive_client_request_dropped() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, stream) = handshaker(); @@ -21,9 +21,12 @@ async fn positive_client_request_dropped() { let request_capacity = 10; - let mut client = TrackerClient::with_capacity("127.0.0.1:4504".parse().unwrap(), sink, request_capacity).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4504".parse().unwrap(), sink, Some(request_capacity)).unwrap(); + + tracing::warn!("sending announce requests to fill buffer"); + for i in 1..=request_capacity { + tracing::warn!("request {i} of {request_capacity}"); - for _ in 0..request_capacity { client .request( server_addr, @@ -35,6 +38,7 @@ async fn positive_client_request_dropped() { .unwrap(); } + tracing::warn!("sending one more announce request, it should fail"); assert!(client .request( server_addr, @@ -47,6 +51,6 @@ async fn positive_client_request_dropped() { mem::drop(client); - let buffer: Vec<_> = stream.collect().await; + let buffer: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect()).await.unwrap(); assert_eq!(request_capacity, buffer.len()); } diff --git a/packages/utracker/tests/test_connect.rs b/packages/utracker/tests/test_connect.rs index 450ad29b3..ac90106a5 100644 --- a/packages/utracker/tests/test_connect.rs +++ b/packages/utracker/tests/test_connect.rs @@ -1,7 +1,7 @@ use std::thread::{self}; use std::time::Duration; -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -13,7 +13,7 @@ mod common; #[tokio::test] async fn positive_receive_connect_id() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, mut stream) = handshaker(); @@ -24,7 +24,7 @@ async fn positive_receive_connect_id() { thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink, None).unwrap(); let send_token = client .request( @@ -36,12 +36,22 @@ async fn positive_receive_connect_id() { ) .unwrap(); - let _init_msg = match stream.next().await.unwrap().unwrap() { + let _init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(message) => message, HandshakerMessage::ClientMetadata(_) => unreachable!(), }; - let metadata = match stream.next().await.unwrap().unwrap() { + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; diff --git a/packages/utracker/tests/test_connect_cache.rs b/packages/utracker/tests/test_connect_cache.rs index fe63581d3..095f82358 100644 --- a/packages/utracker/tests/test_connect_cache.rs +++ b/packages/utracker/tests/test_connect_cache.rs @@ -1,7 +1,7 @@ use std::thread::{self}; use std::time::Duration; -use common::{tracing_stderr_init, MockTrackerHandler, INIT}; +use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -12,7 +12,7 @@ mod common; #[tokio::test] async fn positive_connection_id_cache() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, mut stream) = common::handshaker(); @@ -23,13 +23,17 @@ async fn positive_connection_id_cache() { thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink, None).unwrap(); let first_hash = [0u8; bt::INFO_HASH_LEN].into(); let second_hash = [1u8; bt::INFO_HASH_LEN].into(); client.request(server_addr, ClientRequest::Scrape(first_hash)).unwrap(); - stream.next().await.unwrap().unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap(); assert_eq!(mock_handler.num_active_connect_ids(), 1); @@ -38,7 +42,11 @@ async fn positive_connection_id_cache() { } for _ in 0..10 { - stream.next().await.unwrap().unwrap(); + tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap(); } assert_eq!(mock_handler.num_active_connect_ids(), 1); diff --git a/packages/utracker/tests/test_scrape.rs b/packages/utracker/tests/test_scrape.rs index 68d5be768..e103cc990 100644 --- a/packages/utracker/tests/test_scrape.rs +++ b/packages/utracker/tests/test_scrape.rs @@ -1,7 +1,7 @@ use std::thread::{self}; use std::time::Duration; -use common::{handshaker, tracing_stderr_init, MockTrackerHandler, INIT}; +use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT}; use futures::StreamExt as _; use tracing::level_filters::LevelFilter; use util::bt::{self}; @@ -12,7 +12,7 @@ mod common; #[tokio::test] async fn positive_scrape() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let (sink, mut stream) = handshaker(); @@ -23,13 +23,18 @@ async fn positive_scrape() { thread::sleep(Duration::from_millis(100)); - let mut client = TrackerClient::new("127.0.0.1:4507".parse().unwrap(), sink).unwrap(); + let mut client = TrackerClient::new("127.0.0.1:4507".parse().unwrap(), sink, None).unwrap(); let send_token = client .request(server_addr, ClientRequest::Scrape([0u8; bt::INFO_HASH_LEN].into())) .unwrap(); - let metadata = match stream.next().await.unwrap().unwrap() { + let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next()) + .await + .unwrap() + .unwrap() + .unwrap() + { HandshakerMessage::InitiateMessage(_) => unreachable!(), HandshakerMessage::ClientMetadata(metadata) => metadata, }; diff --git a/packages/utracker/tests/test_server_drop.rs b/packages/utracker/tests/test_server_drop.rs index 21d2090b6..23002461a 100644 --- a/packages/utracker/tests/test_server_drop.rs +++ b/packages/utracker/tests/test_server_drop.rs @@ -12,7 +12,7 @@ mod common; #[allow(unused)] fn positive_server_dropped() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::INFO); + tracing_stderr_init(LevelFilter::ERROR); }); let server_addr = "127.0.0.1:3508".parse().unwrap();