diff --git a/cSpell.json b/cSpell.json index 97205b8f0..cd222e986 100644 --- a/cSpell.json +++ b/cSpell.json @@ -22,10 +22,12 @@ "codegen", "compat", "concated", + "Condvar", "coppersurfer", "cpupool", "curr", "cust", + "cvar", "Cyberneering", "demonii", "Deque", @@ -51,6 +53,7 @@ "metainfo", "mpmc", "myapp", + "nanos", "natted", "nextest", "Oneshot", @@ -65,6 +68,7 @@ "rebootstrapping", "recvd", "reqq", + "reregister", "ringbuffer", "rpath", "rqst", diff --git a/contrib/umio/Cargo.toml b/contrib/umio/Cargo.toml index 7307f6b44..6ca916208 100644 --- a/contrib/umio/Cargo.toml +++ b/contrib/umio/Cargo.toml @@ -16,4 +16,8 @@ repository.workspace = true version.workspace = true [dependencies] -mio = "0.5" +mio = { version = "1", features = ["net", "os-poll"] } +tracing = "0" + +[dev-dependencies] +tracing-subscriber = "0" diff --git a/contrib/umio/src/buffer.rs b/contrib/umio/src/buffer.rs index c6b21b72f..41e479bc8 100644 --- a/contrib/umio/src/buffer.rs +++ b/contrib/umio/src/buffer.rs @@ -1,3 +1,7 @@ +use std::ops::{Deref, DerefMut}; + +use tracing::instrument; + #[allow(clippy::module_name_repetitions)] pub struct BufferPool { // Use Stack For Temporal Locality @@ -5,20 +9,39 @@ pub struct BufferPool { buffer_size: usize, } +impl std::fmt::Debug for BufferPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufferPool") + .field("buffers_len", &self.buffers.len()) + .field("buffer_size", &self.buffer_size) + .finish() + } +} + impl BufferPool { + #[instrument(skip())] pub fn new(buffer_size: usize) -> BufferPool { let buffers = Vec::new(); BufferPool { buffers, buffer_size } } + #[instrument(skip(self), fields(remaining= %self.buffers.len()))] pub fn pop(&mut self) -> Buffer { - self.buffers.pop().unwrap_or(Buffer::new(self.buffer_size)) + if let Some(buffer) = self.buffers.pop() { + tracing::trace!(?buffer, "popping old buffer taken from pool"); + buffer + } else { + let buffer = Buffer::new(self.buffer_size); + tracing::trace!(?buffer, "creating new buffer..."); + buffer + } } + #[instrument(skip(self, buffer), fields(existing= %self.buffers.len()))] pub fn push(&mut self, mut buffer: Buffer) { + tracing::trace!("Pushing buffer back to pool"); buffer.reset_position(); - self.buffers.push(buffer); } } @@ -27,42 +50,58 @@ impl BufferPool { /// Reusable region of memory for incoming and outgoing messages. pub struct Buffer { - buffer: Vec, - bytes_written: usize, + buffer: std::io::Cursor>, +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Buffer").field("buffer", &self.as_ref()).finish() + } } impl Buffer { + #[instrument(skip())] fn new(len: usize) -> Buffer { Buffer { - buffer: vec![0u8; len], - bytes_written: 0, + buffer: std::io::Cursor::new(vec![0_u8; len]), } } fn reset_position(&mut self) { - self.set_written(0); + self.set_position(0); + } +} + +impl Deref for Buffer { + type Target = std::io::Cursor>; + + fn deref(&self) -> &Self::Target { + &self.buffer } +} - /// Update the number of bytes written to the buffer. - pub fn set_written(&mut self, bytes: usize) { - self.bytes_written = bytes; +impl DerefMut for Buffer { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.buffer } } impl AsRef<[u8]> for Buffer { fn as_ref(&self) -> &[u8] { - &self.buffer[..self.bytes_written] + self.get_ref().split_at(self.buffer.position().try_into().unwrap()).0 } } impl AsMut<[u8]> for Buffer { fn as_mut(&mut self) -> &mut [u8] { - &mut self.buffer[self.bytes_written..] + let pos = self.buffer.position().try_into().unwrap(); + self.get_mut().split_at_mut(pos).1 } } #[cfg(test)] mod tests { + use super::{Buffer, BufferPool}; const DEFAULT_BUFFER_SIZE: usize = 1500; @@ -79,7 +118,8 @@ mod tests { #[test] fn positive_buffer_len_update() { let mut buffer = Buffer::new(DEFAULT_BUFFER_SIZE); - buffer.set_written(DEFAULT_BUFFER_SIZE - 1); + + buffer.set_position((DEFAULT_BUFFER_SIZE - 1).try_into().unwrap()); assert_eq!(buffer.as_mut().len(), 1); assert_eq!(buffer.as_ref().len(), DEFAULT_BUFFER_SIZE - 1); diff --git a/contrib/umio/src/dispatcher.rs b/contrib/umio/src/dispatcher.rs index a62fd5979..7ccda4a77 100644 --- a/contrib/umio/src/dispatcher.rs +++ b/contrib/umio/src/dispatcher.rs @@ -1,135 +1,195 @@ use std::collections::VecDeque; use std::net::SocketAddr; +use std::sync::mpsc::Sender; -use mio::udp::UdpSocket; -use mio::{EventLoop, EventSet, Handler, PollOpt, Token}; +use mio::net::UdpSocket; +use mio::{Interest, Poll, Waker}; +use tracing::{instrument, Level}; use crate::buffer::{Buffer, BufferPool}; -use crate::{provider, Provider}; +use crate::eloop::ShutdownHandle; +use crate::provider::TimeoutAction; +use crate::{Provider, UDP_SOCKET_TOKEN}; -/// Handles events occurring within the event loop. -pub trait Dispatcher: Sized { - type Timeout; - type Message: Send; +pub trait Dispatcher: Sized + std::fmt::Debug { + type TimeoutToken: std::fmt::Debug; + type Message: std::fmt::Debug; - /// Process an incoming message from the given address. - #[allow(unused)] - fn incoming(&mut self, provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) {} - - /// Process a message sent via the event loop channel. - #[allow(unused)] - fn notify(&mut self, provider: Provider<'_, Self>, message: Self::Message) {} - - /// Process a timeout that has been triggered. - #[allow(unused)] - fn timeout(&mut self, provider: Provider<'_, Self>, timeout: Self::Timeout) {} + fn incoming(&mut self, _provider: Provider<'_, Self>, _message: &[u8], _addr: SocketAddr) {} + fn notify(&mut self, _provider: Provider<'_, Self>, _message: Self::Message) {} + fn timeout(&mut self, _provider: Provider<'_, Self>, _timeout: Self::TimeoutToken) {} } -//----------------------------------------------------------------------------// - -const UDP_SOCKET_TOKEN: Token = Token(2); +pub struct DispatchHandler +where + D: std::fmt::Debug, +{ + pub dispatch: D, + pub out_queue: VecDeque<(Buffer, SocketAddr)>, + socket: UdpSocket, + pub buffer_pool: BufferPool, + current_interest: Interest, + pub timer_sender: Sender>, +} -pub struct DispatchHandler { - dispatch: D, - out_queue: VecDeque<(Buffer, SocketAddr)>, - udp_socket: UdpSocket, - buffer_pool: BufferPool, - current_set: EventSet, +impl std::fmt::Debug for DispatchHandler +where + D: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DispatchHandler") + .field("dispatch", &self.dispatch) + .field("out_queue_len", &self.out_queue.len()) + .field("socket", &self.socket) + .field("buffer_pool", &self.buffer_pool) + .field("current_interest", &self.current_interest) + .field("timer_sender", &self.timer_sender) + .finish() + } } -impl DispatchHandler { +impl DispatchHandler +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new( - udp_socket: UdpSocket, + mut socket: UdpSocket, buffer_size: usize, dispatch: D, - event_loop: &mut EventLoop>, - ) -> DispatchHandler { + poll: &mut Poll, + timer_sender: Sender>, + ) -> DispatchHandler + where + D: std::fmt::Debug, + ::TimeoutToken: std::fmt::Debug, + ::Message: std::fmt::Debug, + { let buffer_pool = BufferPool::new(buffer_size); let out_queue = VecDeque::new(); - event_loop - .register(&udp_socket, UDP_SOCKET_TOKEN, EventSet::readable(), PollOpt::edge()) + poll.registry() + .register(&mut socket, UDP_SOCKET_TOKEN, Interest::READABLE) .unwrap(); DispatchHandler { dispatch, out_queue, - udp_socket, + socket, buffer_pool, - current_set: EventSet::readable(), + current_interest: Interest::READABLE, + timer_sender, } } + #[instrument(skip(self, waker, shutdown_handle))] + pub fn handle_message(&mut self, waker: &Waker, shutdown_handle: &mut ShutdownHandle, message: D::Message) { + tracing::trace!("message received"); + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + + self.dispatch.notify(provider, message); + } + + #[instrument(skip(self, waker, shutdown_handle))] + pub fn handle_timeout(&mut self, waker: &Waker, shutdown_handle: &mut ShutdownHandle, token: D::TimeoutToken) { + tracing::trace!("timeout expired"); + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + + self.dispatch.timeout(provider, token); + } + + #[instrument(skip(self))] pub fn handle_write(&mut self) { + tracing::trace!("handle write"); + if let Some((buffer, addr)) = self.out_queue.pop_front() { - self.udp_socket.send_to(buffer.as_ref(), &addr).unwrap(); + let bytes = self.socket.send_to(buffer.as_ref(), addr).unwrap(); + + tracing::debug!(?buffer, ?bytes, ?addr, "sent"); self.buffer_pool.push(buffer); - }; + } } + #[instrument(skip(self))] pub fn handle_read(&mut self) -> Option<(Buffer, SocketAddr)> { + tracing::trace!("handle read"); + let mut buffer = self.buffer_pool.pop(); - if let Ok(Some((bytes, addr))) = self.udp_socket.recv_from(buffer.as_mut()) { - buffer.set_written(bytes); + match self.socket.recv_from(buffer.as_mut()) { + Ok((bytes, addr)) => { + buffer.set_position(bytes.try_into().unwrap()); + tracing::trace!(?buffer, "DispatchHandler: Read {bytes} bytes from {addr}"); - Some((buffer, addr)) - } else { - None + Some((buffer, addr)) + } + Err(e) => { + tracing::error!("DispatchHandler: Failed to read from UDP socket: {e}"); + None + } } } -} - -impl Handler for DispatchHandler { - type Timeout = D::Timeout; - type Message = D::Message; - - fn ready(&mut self, event_loop: &mut EventLoop, token: Token, events: EventSet) { - if token != UDP_SOCKET_TOKEN { - return; - } - - if events.is_writable() { - self.handle_write(); - } - - if events.is_readable() { - let Some((buffer, addr)) = self.handle_read() else { - return; - }; - { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.incoming(provider, buffer.as_ref(), addr); + #[instrument(skip(self, waker, shutdown_handle, event, poll))] + pub fn handle_event( + &mut self, + waker: &Waker, + shutdown_handle: &mut ShutdownHandle, + event: &mio::event::Event, + poll: &mut Poll, + ) where + T: std::fmt::Debug, + { + tracing::trace!(?event, "handle event"); + + if event.token() == UDP_SOCKET_TOKEN { + if event.is_writable() { + self.handle_write(); } - self.buffer_pool.push(buffer); + if event.is_readable() { + if let Some((buffer, addr)) = self.handle_read() { + let provider = Provider::new( + &mut self.buffer_pool, + &mut self.out_queue, + waker, + shutdown_handle, + &self.timer_sender, + ); + self.dispatch.incoming(provider, buffer.as_ref(), addr); + self.buffer_pool.push(buffer); + } + } } - } - fn notify(&mut self, event_loop: &mut EventLoop, msg: Self::Message) { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.notify(provider, msg); + self.update_interest(poll); } - fn timeout(&mut self, event_loop: &mut EventLoop, timeout: Self::Timeout) { - let provider = provider::new(&mut self.buffer_pool, &mut self.out_queue, event_loop); - - self.dispatch.timeout(provider, timeout); - } + #[instrument(skip(self, poll))] + fn update_interest(&mut self, poll: &mut Poll) { + tracing::trace!("update interest"); - fn tick(&mut self, event_loop: &mut EventLoop) { - self.current_set = if self.out_queue.is_empty() { - EventSet::readable() + self.current_interest = if self.out_queue.is_empty() { + Interest::READABLE } else { - EventSet::readable() | EventSet::writable() + Interest::READABLE | Interest::WRITABLE }; - event_loop - .reregister(&self.udp_socket, UDP_SOCKET_TOKEN, self.current_set, PollOpt::edge()) + poll.registry() + .reregister(&mut self.socket, UDP_SOCKET_TOKEN, self.current_interest) .unwrap(); } } diff --git a/contrib/umio/src/eloop.rs b/contrib/umio/src/eloop.rs index cedf030bf..f787128dc 100644 --- a/contrib/umio/src/eloop.rs +++ b/contrib/umio/src/eloop.rs @@ -1,16 +1,254 @@ -use std::io::Result; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, HashSet, VecDeque}; +use std::marker::PhantomData; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::sync::mpsc::{self, Receiver, SendError, Sender}; +use std::sync::{Arc, Condvar, Mutex, OnceLock, Weak}; +use std::thread::JoinHandle; +use std::time::Instant; -use mio::udp::UdpSocket; -use mio::{EventLoop, EventLoopConfig, Sender}; +use mio::net::UdpSocket; +use mio::{Events, Poll, Waker}; +use tracing::{instrument, Level}; use crate::dispatcher::{DispatchHandler, Dispatcher}; +use crate::provider::TimeoutAction; +use crate::WAKER_TOKEN; const DEFAULT_BUFFER_SIZE: usize = 1500; const DEFAULT_CHANNEL_CAPACITY: usize = 4096; const DEFAULT_TIMER_CAPACITY: usize = 65536; -/// Builder for specifying attributes of an event loop. +#[derive(Debug)] +pub struct MessageSender +where + T: std::fmt::Debug, +{ + sender: Sender, + waker: Arc, +} + +impl Clone for MessageSender +where + T: std::fmt::Debug, +{ + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + waker: self.waker.clone(), + } + } +} + +impl MessageSender +where + T: std::fmt::Debug, +{ + #[instrument(skip(), ret(level = Level::TRACE))] + fn new(sender: Sender, waker: Arc) -> Self { + Self { sender, waker } + } + + #[instrument(skip(self))] + pub fn send(&self, msg: T) -> Result<(), SendError> { + tracing::trace!("sending message"); + + let res = self.sender.send(msg); + + self.waker.wake().unwrap(); + + res + } +} + +#[derive(Debug)] +struct Timeout { + when: Instant, + token: Weak, +} + +impl Ord for Timeout { + fn cmp(&self, other: &Self) -> Ordering { + other.when.cmp(&self.when) // from smallest to largest + } +} + +impl PartialOrd for Timeout { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for Timeout {} + +impl PartialEq for Timeout { + fn eq(&self, other: &Self) -> bool { + self.when == other.when + } +} + +#[derive(Debug)] +struct LoopWaker +where + T: std::fmt::Debug, +{ + waker: Arc, + active: HashSet>, + pending: Arc<(Mutex>>, Condvar)>, + finished: Arc>>>, + _handle: JoinHandle<()>, +} + +impl LoopWaker +where + T: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Weak: Send, + Arc: Send, +{ + #[instrument(skip(waker, shutdown_handle), ret(level = Level::TRACE))] + fn new(waker: Arc, shutdown_handle: ShutdownHandle) -> Self { + let pending: Arc<(Mutex>>, Condvar)> = Arc::default(); + let finished: Arc>>> = Arc::default(); + + let handle = { + let pending = pending.clone(); + let finished = finished.clone(); + let waker = waker.clone(); + + std::thread::spawn(move || { + let mut timeouts: BinaryHeap> = BinaryHeap::default(); + let mut elapsed = VecDeque::default(); + + while !shutdown_handle.is_shutdown() { + { + let (lock, cvar) = &*pending; + let mut pending = lock.lock().unwrap(); + + while pending.is_empty() && timeouts.is_empty() { + pending = cvar.wait(pending).unwrap(); + } + + timeouts.append(&mut pending); + } + + while let Some(timeout) = timeouts.pop() { + let Some(token) = Weak::upgrade(&timeout.token) else { + continue; + }; + + match timeout.when.checked_duration_since(Instant::now()) { + Some(wait) => { + std::thread::sleep(wait); + elapsed.push_back(token); + break; + } + None => elapsed.push_back(token), + } + } + + let mut finished = finished.lock().unwrap(); + finished.append(&mut elapsed); + waker.wake().unwrap(); + } + }) + }; + + Self { + waker, + active: HashSet::default(), + pending, + finished, + _handle: handle, + } + } + + #[instrument(skip(self))] + fn next(&mut self) -> Option { + let token = self.finished.lock().unwrap().pop_front()?; + + let token = if self.remove(&token) { Arc::into_inner(token) } else { None }; + + tracing::trace!(?token, "next timeout"); + + token + } + + #[instrument(skip(self))] + fn remove(&mut self, token: &T) -> bool { + let remove = self.active.remove(token); + + tracing::trace!(%remove, "removed timeout"); + + remove + } + + #[instrument(skip(self))] + fn push(&mut self, when: Instant, token: T) -> bool { + let token = Arc::new(token); + + let timeout = Timeout { + when, + token: Arc::downgrade(&token), + }; + + let inserted = self.active.insert(token); + + if inserted { + let (lock, cvar) = &*self.pending; + + lock.lock().unwrap().push(timeout); + cvar.notify_one(); + }; + + tracing::trace!(%inserted, "new timeout"); + + inserted + } +} + +#[derive(Debug, Clone)] +pub struct ShutdownHandle { + handle: Arc>, + waker: Arc, +} + +impl ShutdownHandle { + #[instrument(skip(), ret(level = Level::TRACE))] + fn new(waker: Arc) -> Self { + Self { + handle: Arc::default(), + waker, + } + } + + #[must_use] + pub fn is_shutdown(&self) -> bool { + self.handle.get().is_some() + } + + #[instrument(skip(self))] + pub fn shutdown(&mut self) { + if self.handle.set(()).is_ok() { + tracing::info!("shutdown called"); + } else { + tracing::debug!("shutdown already called"); + }; + + match self.waker.wake() { + Ok(()) => tracing::trace!("waking... shutdown"), + Err(e) => tracing::trace!("error waking... shutdown: {e}"), + } + } +} + +impl Drop for ShutdownHandle { + #[instrument(skip(self))] + fn drop(&mut self) { + self.shutdown(); + } +} + +#[derive(Debug)] pub struct ELoopBuilder { channel_capacity: usize, timer_capacity: usize, @@ -19,57 +257,55 @@ pub struct ELoopBuilder { } impl ELoopBuilder { - /// Create a new event loop builder. #[must_use] pub fn new() -> ELoopBuilder { Self::default() } - /// Manually set the maximum channel message capacity. #[must_use] pub fn channel_capacity(mut self, capacity: usize) -> ELoopBuilder { self.channel_capacity = capacity; - self } - /// Manually set the maximum timer capacity. #[must_use] pub fn timer_capacity(mut self, capacity: usize) -> ELoopBuilder { self.timer_capacity = capacity; - self } - /// Manually set the bind address for the udp socket in the event loop. #[must_use] pub fn bind_address(mut self, address: SocketAddr) -> ELoopBuilder { self.bind_address = address; - self } - /// Manually set the length of buffers provided by the event loop. #[must_use] pub fn buffer_length(mut self, length: usize) -> ELoopBuilder { self.buffer_size = length; - self } - /// Build the event loop with the current builder. + /// Builds an `ELoop` instance with the specified configuration. /// /// # Errors /// - /// It would error when the builder config has an problem. - pub fn build(self) -> Result> { + /// This function will return an error if creating the `Poll` or `Waker` fails. + pub fn build(self) -> std::io::Result<(ELoop, SocketAddr, ShutdownHandle)> + where + D: Dispatcher + std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Arc<::TimeoutToken>: Send, + Weak<::TimeoutToken>: Send, + { ELoop::from_builder(&self) } } impl Default for ELoopBuilder { fn default() -> Self { - let default_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)); + let default_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); ELoopBuilder { channel_capacity: DEFAULT_CHANNEL_CAPACITY, @@ -80,46 +316,141 @@ impl Default for ELoopBuilder { } } -//----------------------------------------------------------------------------// - -/// Wrapper around the main application event loop. -pub struct ELoop { +#[derive(Debug)] +pub struct ELoop +where + D: Dispatcher, + ::Message: std::fmt::Debug, +{ buffer_size: usize, - socket_addr: SocketAddr, - event_loop: EventLoop>, + socket: Option, + poll: Poll, + events: Events, + loop_waker: LoopWaker, + shutdown_handle: ShutdownHandle, + message_sender: MessageSender, + message_receiver: Receiver, + timeout_sender: Sender>, + timeout_receiver: Receiver>, + _marker: PhantomData, } -impl ELoop { - fn from_builder(builder: &ELoopBuilder) -> Result> { - let mut event_loop_config = EventLoopConfig::new(); - event_loop_config - .notify_capacity(builder.channel_capacity) - .timer_capacity(builder.timer_capacity); +impl ELoop +where + D: Dispatcher + std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + Arc<::TimeoutToken>: Send, + Weak<::TimeoutToken>: Send, +{ + #[instrument(skip(), err, ret(level = Level::TRACE))] + fn from_builder(builder: &ELoopBuilder) -> std::io::Result<(ELoop, SocketAddr, ShutdownHandle)> { + let poll = Poll::new()?; + let events = Events::with_capacity(builder.channel_capacity); + + let (message_sender, message_receiver) = mpsc::channel(); + let (timeout_sender, timeout_receiver) = mpsc::channel(); + + let socket = UdpSocket::bind(builder.bind_address)?; + + let bound_socket = socket.local_addr()?; + + let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN)?); - let event_loop = EventLoop::configured(event_loop_config)?; + let shutdown_handle = ShutdownHandle::new(waker.clone()); - Ok(ELoop { - buffer_size: builder.buffer_size, - socket_addr: builder.bind_address, - event_loop, - }) + let loop_waker = LoopWaker::new(waker.clone(), shutdown_handle.clone()); + let message_sender = MessageSender::new(message_sender, waker); + + Ok(( + ELoop { + buffer_size: builder.buffer_size, + socket: Some(socket), + poll, + events, + loop_waker, + shutdown_handle: shutdown_handle.clone(), + message_sender, + message_receiver, + timeout_sender, + timeout_receiver, + _marker: PhantomData, + }, + bound_socket, + shutdown_handle, + )) + } + + #[must_use] + pub fn waker(&self) -> &Waker { + &self.loop_waker.waker } - /// Grab a channel to send messages to the event loop. + /// Creates a channel for sending messages to the event loop. #[must_use] - pub fn channel(&self) -> Sender { - self.event_loop.channel() + #[instrument(skip(self))] + pub fn channel(&self) -> MessageSender<::Message> { + self.message_sender.clone() } - /// Run the event loop with the given dispatcher until a shutdown occurs. + /// Runs the event loop with the provided dispatcher. /// /// # Errors /// - /// It would error if unable to bind to the socket. - pub fn run(&mut self, dispatcher: D) -> Result<()> { - let udp_socket = UdpSocket::bound(&self.socket_addr)?; - let mut dispatch_handler = DispatchHandler::new(udp_socket, self.buffer_size, dispatcher, &mut self.event_loop); + /// This function will return an error if binding the UDP socket or polling events fails. + #[instrument(skip(self, dispatcher))] + pub fn run(&mut self, dispatcher: D) -> std::io::Result<()> + where + D: std::fmt::Debug, + ::Message: std::fmt::Debug, + ::TimeoutToken: std::hash::Hash + std::cmp::Eq + std::fmt::Debug + 'static, + { + let mut dispatch_handler = DispatchHandler::new( + self.socket.take().unwrap(), + self.buffer_size, + dispatcher, + &mut self.poll, + self.timeout_sender.clone(), + ); + + loop { + if self.shutdown_handle.is_shutdown() { + tracing::debug!("shutting down..."); + break; + } + + // Handle timeouts + while let Some(token) = self.loop_waker.next() { + dispatch_handler.handle_timeout(&self.loop_waker.waker, &mut self.shutdown_handle, token); + } + + // Handle events + for event in &self.events { + dispatch_handler.handle_event::(&self.loop_waker.waker, &mut self.shutdown_handle, event, &mut self.poll); + } + + // Handle messages + while let Ok(message) = self.message_receiver.try_recv() { + dispatch_handler.handle_message(&self.loop_waker.waker, &mut self.shutdown_handle, message); + } + + // Add timeouts + while let Ok(action) = self.timeout_receiver.try_recv() { + match action { + TimeoutAction::Add { token, when } => { + tracing::trace!(?token, ?when, "set timeout"); + self.loop_waker.push(when, token); + } + TimeoutAction::Remove { token } => { + tracing::trace!(?token, "clear timeout"); + self.loop_waker.remove(&token); + } + } + } + + self.poll.poll(&mut self.events, None)?; + } - self.event_loop.run(&mut dispatch_handler) + Ok(()) } } diff --git a/contrib/umio/src/external.rs b/contrib/umio/src/external.rs index 2ef9d04f8..1b6104ce2 100644 --- a/contrib/umio/src/external.rs +++ b/contrib/umio/src/external.rs @@ -1 +1 @@ -pub use mio::{Sender, Timeout, TimerError, TimerResult}; +pub use mio::{Events, Interest, Token, Waker}; diff --git a/contrib/umio/src/lib.rs b/contrib/umio/src/lib.rs index c246f2dfe..a29c4867e 100644 --- a/contrib/umio/src/lib.rs +++ b/contrib/umio/src/lib.rs @@ -1,17 +1,14 @@ -//! Message Based Readiness API -//! -//! This library is a thin wrapper around mio for clients who wish to -//! use a single udp socket in conjunction with message passing and -//! timeouts. - mod buffer; mod dispatcher; mod eloop; mod provider; -/// Exports of bare mio types. +const WAKER_TOKEN: Token = Token(0); +const UDP_SOCKET_TOKEN: Token = Token(2); + pub mod external; pub use dispatcher::Dispatcher; -pub use eloop::{ELoop, ELoopBuilder}; +pub use eloop::{ELoop, ELoopBuilder, MessageSender, ShutdownHandle}; +use mio::Token; pub use provider::Provider; diff --git a/contrib/umio/src/provider.rs b/contrib/umio/src/provider.rs index b9961292f..94bfab5bb 100644 --- a/contrib/umio/src/provider.rs +++ b/contrib/umio/src/provider.rs @@ -1,72 +1,159 @@ use std::collections::VecDeque; +use std::io::Write; +use std::marker::PhantomData; use std::net::SocketAddr; +use std::sync::mpsc; +use std::time::Instant; -use mio::{EventLoop, Sender, Timeout, TimerResult}; +use mio::Waker; +use tracing::instrument; use crate::buffer::{Buffer, BufferPool}; -use crate::dispatcher::{DispatchHandler, Dispatcher}; +use crate::dispatcher::Dispatcher; +use crate::eloop::ShutdownHandle; -/// Provides services to dispatcher clients. -pub struct Provider<'a, D: Dispatcher> { - buffer_pool: &'a mut BufferPool, - out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, - event_loop: &'a mut EventLoop>, +pub enum TimeoutAction +where + T: std::fmt::Debug, +{ + Add { token: T, when: Instant }, + Remove { token: T }, } -pub fn new<'a, D: Dispatcher>( +#[derive(Debug)] +pub struct Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ buffer_pool: &'a mut BufferPool, + buffer: Option, out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, - event_loop: &'a mut EventLoop>, -) -> Provider<'a, D> { - Provider { - buffer_pool, - out_queue, - event_loop, - } + waker: &'a Waker, + shutdown_handle: &'a mut ShutdownHandle, + timer_sender: &'a mpsc::Sender>, + outgoing_socket: Option, + _marker: PhantomData, } -impl<'a, D: Dispatcher> Provider<'a, D> { - /// Grab a channel to send messages to the event loop. - #[must_use] - pub fn channel(&self) -> Sender { - self.event_loop.channel() - } +impl<'a, D> Write for Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip(self), fields(buffer= ?self.buffer))] + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let dest = self.buffer.get_or_insert_with(|| self.buffer_pool.pop()); - /// Execute a closure with a buffer and send the buffer contents to the - /// destination address or reclaim the buffer and do not send anything. - pub fn outgoing(&mut self, out: F) - where - F: FnOnce(&mut [u8]) -> Option<(usize, SocketAddr)>, - { - let mut buffer = self.buffer_pool.pop(); - let opt_send_to = out(buffer.as_mut()); + let wrote = dest.write(buf)?; + + tracing::trace!(%wrote, "write"); + + Ok(wrote) + } - match opt_send_to { - None => self.buffer_pool.push(buffer), - Some((bytes, addr)) => { - buffer.set_written(bytes); + #[instrument(skip(self))] + fn flush(&mut self) -> std::io::Result<()> { + if let Some(buffer) = self.buffer.take() { + tracing::trace!(?buffer, "flushing..."); + if let Some(addr) = self.outgoing_socket { self.out_queue.push_back((buffer, addr)); + self.wake(); + } else { + self.buffer_pool.push(buffer); + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "No outgoing socket address set", + )); } + } else { + tracing::warn!("flush empty"); + } + Ok(()) + } +} + +impl<'a, D> Provider<'a, D> +where + D: Dispatcher + std::fmt::Debug, +{ + #[instrument(skip())] + pub fn new( + buffer_pool: &'a mut BufferPool, + out_queue: &'a mut VecDeque<(Buffer, SocketAddr)>, + waker: &'a Waker, + shutdown_handle: &'a mut ShutdownHandle, + timer_sender: &'a mpsc::Sender>, + ) -> Provider<'a, D> { + Provider { + buffer_pool, + buffer: None, + out_queue, + waker, + timer_sender, + shutdown_handle, + outgoing_socket: None, + _marker: PhantomData, } } - /// Set a timeout with the given delay and token. + #[instrument(skip(self))] + pub fn set_dest(&mut self, dest: SocketAddr) -> Option { + self.outgoing_socket.replace(dest) + } + + /// Wakes the event loop. + /// + /// # Panics + /// + /// This function will panic if waking the event loop fails. + #[instrument(skip(self))] + pub fn wake(&self) { + self.waker.wake().expect("Failed to wake the event loop"); + } + + /// Sets a timeout with the given token and delay. /// /// # Errors /// - /// It would error when the timeout returns in a error. - pub fn set_timeout(&mut self, token: D::Timeout, delay: u64) -> TimerResult { - self.event_loop.timeout_ms(token, delay) + /// This function will return an error if sending message fails. + #[instrument(skip(self, token, when))] + pub fn set_timeout(&mut self, token: D::TimeoutToken, when: Instant) -> Result<(), Box> + where + D::TimeoutToken: 'static, + { + tracing::trace!(?token, ?when, "set timeout"); + + self.timer_sender.send(TimeoutAction::Add { token, when })?; + self.wake(); + Ok(()) } - /// Clear a timeout using the provided timeout identifier. - pub fn clear_timeout(&mut self, timeout: Timeout) -> bool { - self.event_loop.clear_timeout(timeout) + /// Removes a timeout + /// + /// # Errors + /// + /// This function will return an error if sending message fails. + #[instrument(skip(self))] + pub fn remove_timeout(&mut self, token: D::TimeoutToken) -> Result<(), Box> + where + D::TimeoutToken: 'static, + { + tracing::trace!("remove timeout"); + + self.timer_sender.send(TimeoutAction::Remove { token })?; + self.wake(); + Ok(()) } - /// Shutdown the event loop. + /// Shuts down the event loop. + /// + /// # Panics + /// + /// This function will panic if sending the shutdown signal fails. + #[instrument(skip(self))] pub fn shutdown(&mut self) { - self.event_loop.shutdown(); + tracing::debug!("shutdown"); + + self.shutdown_handle.shutdown(); } } diff --git a/contrib/umio/tests/common/mod.rs b/contrib/umio/tests/common/mod.rs index 97102082b..b90e0a028 100644 --- a/contrib/umio/tests/common/mod.rs +++ b/contrib/umio/tests/common/mod.rs @@ -1,8 +1,31 @@ -use std::net::SocketAddr; +use std::io::Write; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::mpsc::{self}; +use std::sync::Once; +use std::time::Instant; +use tracing::level_filters::LevelFilter; +use tracing::{instrument, Level}; use umio::{Dispatcher, Provider}; +pub const LOOPBACK_IPV4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)); + +#[allow(dead_code)] +pub static INIT: Once = Once::new(); + +#[allow(dead_code)] +pub fn tracing_stderr_init(filter: LevelFilter) { + let builder = tracing_subscriber::fmt() + .with_max_level(filter) + .with_ansi(true) + .with_writer(std::io::stderr); + + builder.pretty().with_file(true).init(); + + tracing::info!("Logging initialized"); +} + +#[derive(Debug)] pub struct MockDispatcher { send: mpsc::Sender, } @@ -16,12 +39,13 @@ pub enum MockMessage { SendNotify, SendMessage(Vec, SocketAddr), - SendTimeout(u32, u64), + SendTimeout(u32, Instant), Shutdown, } impl MockDispatcher { + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new() -> (MockDispatcher, mpsc::Receiver) { let (send, recv) = mpsc::channel(); @@ -30,28 +54,29 @@ impl MockDispatcher { } impl Dispatcher for MockDispatcher { - type Timeout = u32; + type TimeoutToken = u32; type Message = MockMessage; + #[instrument(skip())] fn incoming(&mut self, _: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { let owned_message = message.to_vec(); - + tracing::trace!("MockDispatcher: Received message from {addr}"); self.send.send(MockMessage::MessageReceived(owned_message, addr)).unwrap(); } + #[instrument(skip(provider))] fn notify(&mut self, mut provider: Provider<'_, Self>, msg: Self::Message) { + tracing::trace!("MockDispatcher: Received notification {msg:?}"); match msg { MockMessage::SendMessage(message, addr) => { - provider.outgoing(|buffer| { - for (src, dst) in message.iter().zip(buffer.as_mut().iter_mut()) { - *dst = *src; - } + let _ = provider.set_dest(addr); + + let _ = provider.write(&message).unwrap(); - Some((message.len(), addr)) - }); + let () = provider.flush().unwrap(); } - MockMessage::SendTimeout(token, delay) => { - provider.set_timeout(token, delay).unwrap(); + MockMessage::SendTimeout(token, when) => { + provider.set_timeout(token, when).unwrap(); } MockMessage::SendNotify => { self.send.send(MockMessage::NotifyReceived).unwrap(); @@ -63,7 +88,9 @@ impl Dispatcher for MockDispatcher { } } - fn timeout(&mut self, _: Provider<'_, Self>, token: Self::Timeout) { + #[instrument(skip())] + fn timeout(&mut self, _: Provider<'_, Self>, token: Self::TimeoutToken) { + tracing::trace!("MockDispatcher: Timeout received for token {token}"); self.send.send(MockMessage::TimeoutReceived(token)).unwrap(); } } diff --git a/contrib/umio/tests/test_incoming.rs b/contrib/umio/tests/test_incoming.rs index 85bfb1545..c2adb68dd 100644 --- a/contrib/umio/tests/test_incoming.rs +++ b/contrib/umio/tests/test_incoming.rs @@ -1,40 +1,50 @@ use std::net::UdpSocket; -use std::thread::{self}; +use std::thread; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; +/// Tests that an incoming message is correctly received and processed. #[test] fn positive_receive_incoming_message() { - let eloop_addr = "127.0.0.1:5050".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + tracing::trace!("Starting test: positive_receive_incoming_message"); + + let (mut eloop, eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { + let handle = thread::spawn(move || { eloop.run(dispatcher).unwrap(); }); - thread::sleep(Duration::from_millis(50)); - let socket_addr = "127.0.0.1:5051".parse().unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); + let socket = UdpSocket::bind(LOOPBACK_IPV4).unwrap(); + let socket_addr = socket.local_addr().unwrap(); let message = b"This Is A Test Message"; - socket.send_to(&message[..], eloop_addr).unwrap(); + tracing::trace!("Sending message to event loop"); + socket.send_to(&message[..], eloop_socket).unwrap(); thread::sleep(Duration::from_millis(50)); - match dispatch_recv.try_recv() { + tracing::trace!("Checking for received message"); + let res: Result = dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::MessageReceived(msg, addr)) => { assert_eq!(&msg[..], &message[..]); - assert_eq!(addr, socket_addr); } _ => panic!("ELoop Failed To Receive Incoming Message"), - } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); + }; } diff --git a/contrib/umio/tests/test_notify.rs b/contrib/umio/tests/test_notify.rs index 5b624dee6..db9ba0769 100644 --- a/contrib/umio/tests/test_notify.rs +++ b/contrib/umio/tests/test_notify.rs @@ -1,31 +1,39 @@ use std::thread::{self}; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_send_notify() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { + let handle = thread::spawn(move || { eloop.run(dispatcher).unwrap(); }); thread::sleep(Duration::from_millis(50)); + tracing::trace!("Sending MockMessage::SendNotify"); dispatch_send.send(MockMessage::SendNotify).unwrap(); thread::sleep(Duration::from_millis(50)); - match dispatch_recv.try_recv() { + let res = dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::NotifyReceived) => (), _ => panic!("ELoop Failed To Receive Incoming Message"), } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); } diff --git a/contrib/umio/tests/test_outgoing.rs b/contrib/umio/tests/test_outgoing.rs index 3db172f75..b2d36bd85 100644 --- a/contrib/umio/tests/test_outgoing.rs +++ b/contrib/umio/tests/test_outgoing.rs @@ -1,39 +1,47 @@ use std::net::UdpSocket; -use std::thread::{self}; +use std::thread; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_send_outgoing_message() { - let eloop_addr = "127.0.0.1:5052".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + tracing::trace!("Starting test: positive_send_outgoing_message"); + let (mut eloop, eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, _) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { + let handle = thread::spawn(move || { eloop.run(dispatcher).unwrap(); }); - thread::sleep(Duration::from_millis(50)); let message = b"This Is A Test Message"; let mut message_recv = [0u8; 22]; - let socket_addr = "127.0.0.1:5053".parse().unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); + let socket = UdpSocket::bind(LOOPBACK_IPV4).unwrap(); + let socket_addr = socket.local_addr().unwrap(); // Get the actual address + + tracing::trace!("sending message to: {socket_addr}"); dispatch_send .send(MockMessage::SendMessage(message.to_vec(), socket_addr)) .unwrap(); - thread::sleep(Duration::from_millis(50)); + tracing::trace!("receiving message from: {eloop_socket}"); + socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap(); let (bytes, addr) = socket.recv_from(&mut message_recv).unwrap(); + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); // Wait for the event loop to finish + assert_eq!(bytes, message.len()); assert_eq!(&message[..], &message_recv[..]); - assert_eq!(addr, eloop_addr); - - dispatch_send.send(MockMessage::Shutdown).unwrap(); + assert_eq!(addr, eloop_socket); } diff --git a/contrib/umio/tests/test_shutdown.rs b/contrib/umio/tests/test_shutdown.rs index dd7023f2f..954f90b89 100644 --- a/contrib/umio/tests/test_shutdown.rs +++ b/contrib/umio/tests/test_shutdown.rs @@ -1,26 +1,30 @@ use std::thread::{self}; use std::time::Duration; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_execute_shutdown() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, _) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { + let handle = thread::spawn(move || { eloop.run(dispatcher).unwrap(); }); thread::sleep(Duration::from_millis(50)); dispatch_send.send(MockMessage::Shutdown).unwrap(); - thread::sleep(Duration::from_millis(50)); + handle.join().unwrap(); assert!(dispatch_send.send(MockMessage::SendNotify).is_err()); } diff --git a/contrib/umio/tests/test_timeout.rs b/contrib/umio/tests/test_timeout.rs index 60f537016..19be83f89 100644 --- a/contrib/umio/tests/test_timeout.rs +++ b/contrib/umio/tests/test_timeout.rs @@ -1,34 +1,43 @@ use std::thread::{self}; -use std::time::Duration; +use std::time::{Duration, Instant}; -use common::{MockDispatcher, MockMessage}; +use common::{tracing_stderr_init, MockDispatcher, MockMessage, INIT, LOOPBACK_IPV4}; +use tracing::level_filters::LevelFilter; use umio::ELoopBuilder; mod common; #[test] fn positive_send_notify() { - let eloop_addr = "127.0.0.1:0".parse().unwrap(); - let mut eloop = ELoopBuilder::new().bind_address(eloop_addr).build().unwrap(); + INIT.call_once(|| { + tracing_stderr_init(LevelFilter::ERROR); + }); + + let (mut eloop, _eloop_socket, _shutdown_handle) = ELoopBuilder::new().bind_address(LOOPBACK_IPV4).build().unwrap(); let (dispatcher, dispatch_recv) = MockDispatcher::new(); let dispatch_send = eloop.channel(); - thread::spawn(move || { + let handle = thread::spawn(move || { eloop.run(dispatcher).unwrap(); }); thread::sleep(Duration::from_millis(50)); let token = 5; - dispatch_send.send(MockMessage::SendTimeout(token, 50)).unwrap(); + let timeout_at = Instant::now() + Duration::from_millis(50); + dispatch_send.send(MockMessage::SendTimeout(token, timeout_at)).unwrap(); thread::sleep(Duration::from_millis(300)); - match dispatch_recv.try_recv() { + let res = dispatch_recv.try_recv(); + + dispatch_send.send(MockMessage::Shutdown).unwrap(); + handle.join().unwrap(); + + match res { Ok(MockMessage::TimeoutReceived(tkn)) => { assert_eq!(tkn, token); } - _ => panic!("ELoop Failed To Receive Timeout"), + Ok(other) => panic!("Received Other: {other:?}"), + Err(e) => panic!("Received Error: {e}"), } - - dispatch_send.send(MockMessage::Shutdown).unwrap(); } diff --git a/packages/dht/examples/debug.rs b/packages/dht/examples/debug.rs index 4998bb4ed..eb43ca27a 100644 --- a/packages/dht/examples/debug.rs +++ b/packages/dht/examples/debug.rs @@ -43,7 +43,7 @@ impl HandshakerTrait for SimpleHandshaker { self.filter.insert(addr); self.count += 1; - println!("Received new peer {:?}, total unique peers {}", addr, self.count); + tracing::trace!("Received new peer {:?}, total unique peers {}", addr, self.count); Box::pin(std::future::ready(())) } @@ -85,7 +85,7 @@ async fn main() { let mut events = dht.events().await; tasks.spawn(async move { while let Some(event) = events.next().await { - println!("\nReceived Dht Event {event:?}"); + tracing::trace!("\nReceived Dht Event {event:?}"); } }); diff --git a/packages/handshake/examples/handshake_torrent.rs b/packages/handshake/examples/handshake_torrent.rs index 0d8bc20b8..ab9568b65 100644 --- a/packages/handshake/examples/handshake_torrent.rs +++ b/packages/handshake/examples/handshake_torrent.rs @@ -36,7 +36,7 @@ async fn main() -> std::io::Result<()> { .await .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - println!("\nConnection With Peer Established...Closing In 10 Seconds"); + tracing::trace!("\nConnection With Peer Established...Closing In 10 Seconds"); sleep(Duration::from_secs(10)).await; diff --git a/packages/utracker/src/announce.rs b/packages/utracker/src/announce.rs index 7906b6ec6..52c2df12a 100644 --- a/packages/utracker/src/announce.rs +++ b/packages/utracker/src/announce.rs @@ -96,10 +96,12 @@ impl<'a> AnnounceRequest<'a> { /// # Errors /// /// It would return an IO error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> + pub fn write_bytes(&self, mut writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { + tracing::trace!("write_bytes"); + writer.write_all(self.info_hash.as_ref())?; writer.write_all(self.peer_id.as_ref())?; @@ -316,7 +318,7 @@ fn parse_response<'a>( // ----------------------------------------------------------------------------// /// Announce state of a client reported to the server. -#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ClientState { downloaded: i64, left: i64, @@ -400,7 +402,7 @@ fn parse_state(bytes: &[u8]) -> IResult<&[u8], ClientState> { /// Announce event of a client reported to the server. #[allow(clippy::module_name_repetitions)] -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnnounceEvent { /// No event is reported. None, diff --git a/packages/utracker/src/client/dispatcher.rs b/packages/utracker/src/client/dispatcher.rs index 2d0dd8b9b..99114c624 100644 --- a/packages/utracker/src/client/dispatcher.rs +++ b/packages/utracker/src/client/dispatcher.rs @@ -1,18 +1,17 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::SocketAddr; +use std::ops::{Deref, DerefMut}; +use std::time::{Duration, Instant, UNIX_EPOCH}; -use chrono::offset::Utc; -use chrono::{DateTime, Duration}; use futures::executor::block_on; use futures::future::{BoxFuture, Either}; use futures::sink::Sink; use futures::{FutureExt, SinkExt}; use handshake::{DiscoveryInfo, InitiateMessage, Protocol}; use nom::IResult; -use tracing::instrument; -use umio::external::{self, Timeout}; -use umio::{Dispatcher, ELoopBuilder, Provider}; +use tracing::{instrument, Level}; +use umio::{Dispatcher, ELoopBuilder, MessageSender, Provider, ShutdownHandle}; use util::bt::PeerId; use super::HandshakerMessage; @@ -30,12 +29,64 @@ const CONNECTION_ID_VALID_DURATION_MILLIS: i64 = 60000; const MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS: u64 = 8; /// Internal dispatch timeout. -#[derive(Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] enum DispatchTimeout { Connect(ClientToken), CleanUp, } +impl Default for DispatchTimeout { + fn default() -> Self { + Self::CleanUp + } +} + +#[derive(Default, Clone, Copy, Debug)] +struct TimeoutToken { + id: TimeoutId, + dispatch: DispatchTimeout, +} + +impl TimeoutToken { + fn new(dispatch: DispatchTimeout) -> (Self, TimeoutId) { + let id = TimeoutId::default(); + (Self { id, dispatch }, id) + } + + fn cleanup(id: TimeoutId) -> Self { + Self { + id, + dispatch: DispatchTimeout::CleanUp, + } + } +} + +impl Ord for TimeoutToken { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.id.cmp(&other.id) + } +} + +impl PartialOrd for TimeoutToken { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for TimeoutToken {} + +impl PartialEq for TimeoutToken { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl std::hash::Hash for TimeoutToken { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} + /// Internal dispatch message for clients. #[derive(Debug)] pub enum DispatchMessage { @@ -54,7 +105,7 @@ pub fn create_dispatcher( handshaker: H, msg_capacity: usize, limiter: RequestLimiter, -) -> std::io::Result> +) -> std::io::Result<(MessageSender, SocketAddr, ShutdownHandle)> where H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, @@ -68,7 +119,7 @@ where .bind_address(bind) .buffer_length(EXPECTED_PACKET_LENGTH); - let mut eloop = builder.build()?; + let (mut eloop, socket, shutdown) = builder.build()?; let channel = eloop.channel(); let dispatch = ClientDispatcher::new(handshaker, bind, limiter); @@ -81,7 +132,7 @@ where .send(DispatchMessage::StartTimer) .expect("bip_utracker: ELoop Failed To Start Connect ID Timer..."); - Ok(channel) + Ok((channel, socket, shutdown)) } // ----------------------------------------------------------------------------// @@ -104,7 +155,7 @@ where H::Error: std::fmt::Display, { /// Create a new `ClientDispatcher`. - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new(handshaker: H, bind: SocketAddr, limiter: RequestLimiter) -> ClientDispatcher { tracing::debug!("new client dispatcher"); @@ -142,7 +193,7 @@ where /// Finish a request by sending the result back to the client. #[instrument(skip(self))] pub fn notify_client(&mut self, token: ClientToken, result: ClientResult) { - tracing::info!("notifying clients"); + tracing::trace!("notifying clients"); match block_on(self.handshaker.send(Ok(ClientMetadata::new(token, result).into()))) { Ok(()) => tracing::debug!("client metadata sent"), @@ -153,7 +204,7 @@ where } /// Process a request to be sent to the given address and associated with the given token. - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, addr, token, request))] pub fn send_request( &mut self, provider: &mut Provider<'_, ClientDispatcher>, @@ -161,7 +212,7 @@ where token: ClientToken, request: ClientRequest, ) { - tracing::debug!("sending request"); + tracing::debug!(?addr, ?token, ?request, "sending request"); let bound_addr = self.bound_addr; @@ -182,14 +233,14 @@ where } /// Process a response received from some tracker and match it up against our sent requests. - #[instrument(skip(self, provider, response))] + #[instrument(skip(self, provider, response, addr))] pub fn recv_response( &mut self, provider: &mut Provider<'_, ClientDispatcher>, response: &TrackerResponse<'_>, addr: SocketAddr, ) { - tracing::debug!("receiving response"); + tracing::debug!(?response, ?addr, "receiving response"); let token = ClientToken(response.transaction_id()); @@ -207,11 +258,11 @@ where return; }; - provider.clear_timeout( - conn_timer - .timeout_id() - .expect("bip_utracker: Failed To Clear Request Timeout"), - ); + if let Some(clear_timeout_token) = conn_timer.timeout_id().map(TimeoutToken::cleanup) { + provider + .remove_timeout(clear_timeout_token) + .expect("bip_utracker: Failed To Clear Request Timeout"); + }; // Check if the response requires us to update the connection timer if let &ResponseType::Connect(id) = response.response_type() { @@ -225,7 +276,7 @@ where (&ClientRequest::Announce(hash, _), ResponseType::Announce(res)) => { // Forward contact information on to the handshaker for addr in res.peers().iter() { - tracing::info!("sending will block if unable to send!"); + tracing::debug!("sending will block if unable to send!"); match block_on( self.handshaker .send(Ok(InitiateMessage::new(Protocol::BitTorrent, hash, addr).into())), @@ -253,9 +304,9 @@ where /// Process an existing request, either re requesting a connection id or sending the actual request again. /// /// If this call is the result of a timeout, that will decide whether to cancel the request or not. - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, token, timed_out))] fn process_request(&mut self, provider: &mut Provider<'_, ClientDispatcher>, token: ClientToken, timed_out: bool) { - tracing::debug!("processing request"); + tracing::debug!(?token, ?timed_out, "processing request"); let Some(mut conn_timer) = self.active_requests.remove(&token) else { tracing::error!(?token, "token not in active requests"); @@ -312,34 +363,37 @@ where // Try to write the request out to the server let mut write_success = false; - provider.outgoing(|bytes| { - let mut writer = std::io::Cursor::new(bytes); - match tracker_request.write_bytes(&mut writer) { + provider.set_dest(addr); + + { + match tracker_request.write_bytes(provider) { Ok(()) => { write_success = true; - Some((writer.position().try_into().unwrap(), addr)) } Err(e) => { - tracing::error!("failed to write out the tracker request with error: {e}"); - None + tracing::error!(?e, "failed to write out the tracker request with error"); } - } - }); + }; + } + + let next_timeout_at = Instant::now().checked_add(Duration::from_millis(next_timeout)).unwrap(); + + let (timeout_token, timeout_id) = TimeoutToken::new(DispatchTimeout::Connect(token)); + + let () = provider + .set_timeout(timeout_token, next_timeout_at) + .expect("bip_utracker: Failed To Set Timeout For Request"); // If message was not sent (too long to fit) then end the request if write_success { - conn_timer.set_timeout_id( - provider - .set_timeout(DispatchTimeout::Connect(token), next_timeout) - .expect("bip_utracker: Failed To Set Timeout For Request"), - ); + conn_timer.set_timeout_id(timeout_id); self.active_requests.insert(token, conn_timer); } else { - let err = ClientError::MaxLength; - tracing::warn!("notifying client with error: {err}"); + let e = ClientError::MaxLength; + tracing::warn!(?e, "notifying client with error"); - self.notify_client(token, Err(err)); + self.notify_client(token, Err(e)); } } } @@ -349,47 +403,53 @@ where H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, { - type Timeout = DispatchTimeout; + type TimeoutToken = TimeoutToken; type Message = DispatchMessage; - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, message, addr))] fn incoming(&mut self, mut provider: Provider<'_, Self>, message: &[u8], addr: SocketAddr) { + tracing::debug!(?message, %addr, "received incoming"); + let () = match TrackerResponse::from_bytes(message) { IResult::Ok((_, response)) => { - tracing::debug!("received an incoming response: {response:?}"); + tracing::trace!(?response, %addr, "received an incoming response"); self.recv_response(&mut provider, &response, addr); } Err(e) => { - tracing::error!("received an incoming error message: {e}"); + tracing::error!(%e, "received an incoming error message"); } }; } - #[instrument(skip(self, provider))] + #[instrument(skip(self, provider, message))] fn notify(&mut self, mut provider: Provider<'_, Self>, message: DispatchMessage) { - tracing::debug!("received notify"); + tracing::debug!(?message, "received notify"); match message { DispatchMessage::Request(addr, token, req_type) => { self.send_request(&mut provider, addr, token, req_type); } - DispatchMessage::StartTimer => self.timeout(provider, DispatchTimeout::CleanUp), + DispatchMessage::StartTimer => self.timeout(provider, TimeoutToken::default()), DispatchMessage::Shutdown => self.shutdown(&mut provider), } } - #[instrument(skip(self, provider))] - fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: DispatchTimeout) { - tracing::debug!("received timeout"); + #[instrument(skip(self, provider, timeout))] + fn timeout(&mut self, mut provider: Provider<'_, Self>, timeout: TimeoutToken) { + tracing::debug!(?timeout, "received timeout"); - match timeout { + match timeout.dispatch { DispatchTimeout::Connect(token) => self.process_request(&mut provider, token, true), DispatchTimeout::CleanUp => { self.id_cache.clean_expired(); + let next_timeout_at = Instant::now() + .checked_add(Duration::from_millis(CONNECTION_ID_VALID_DURATION_MILLIS as u64)) + .unwrap(); + provider - .set_timeout(DispatchTimeout::CleanUp, CONNECTION_ID_VALID_DURATION_MILLIS as u64) + .set_timeout(TimeoutToken::default(), next_timeout_at) .expect("bip_utracker: Failed To Restart Connect Id Cleanup Timer"); } }; @@ -398,29 +458,47 @@ where // ----------------------------------------------------------------------------// +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct TimeoutId { + id: u128, +} + +impl Default for TimeoutId { + fn default() -> Self { + Self { + id: UNIX_EPOCH.elapsed().unwrap().as_nanos(), + } + } +} + +impl TimeoutId { + fn new(id: u128) -> Self { + Self { id } + } +} + +impl Deref for TimeoutId { + type Target = u128; + + fn deref(&self) -> &Self::Target { + &self.id + } +} + +impl DerefMut for TimeoutId { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.id + } +} + /// Contains logic for making sure a valid connection id is present /// and correctly timing out when sending requests to the server. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct ConnectTimer { addr: SocketAddr, attempt: u64, request: ClientRequest, - timeout_id: Option, -} - -impl std::fmt::Debug for ConnectTimer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let timeout_id = match self.timeout_id { - Some(_) => "Some(_)", - None => "None", - }; - - f.debug_struct("ConnectTimer") - .field("addr", &self.addr) - .field("attempt", &self.attempt) - .field("request", &self.request) - .field("timeout_id", &timeout_id) - .finish() - } + timeout_id: Option, } impl ConnectTimer { @@ -435,10 +513,8 @@ impl ConnectTimer { } /// Yields the current timeout value to use or None if the request should time out completely. - #[instrument(skip(), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] pub fn current_timeout(&mut self, timed_out: bool) -> Option { - tracing::debug!("getting current timeout"); - if self.attempt == MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS { tracing::warn!("request has reached maximum timeout attempts: {MAXIMUM_REQUEST_RETRANSMIT_ATTEMPTS}"); @@ -453,31 +529,31 @@ impl ConnectTimer { } /// Yields the current timeout id if one is set. - pub fn timeout_id(&self) -> Option { + pub fn timeout_id(&self) -> Option { self.timeout_id } /// Sets a new timeout id. - pub fn set_timeout_id(&mut self, id: Timeout) { + pub fn set_timeout_id(&mut self, id: TimeoutId) { self.timeout_id = Some(id); } /// Yields the message parameters for the current connection. - #[instrument(skip(), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] pub fn message_params(&self) -> (SocketAddr, &ClientRequest) { - tracing::debug!("getting message parameters"); - (self.addr, &self.request) } } /// Calculates the timeout for the request given the attempt count. -#[instrument(skip(), ret)] +#[instrument(skip())] fn calculate_message_timeout_millis(attempt: u64) -> u64 { - tracing::debug!("calculation message timeout in milliseconds"); - let attempt = attempt.try_into().unwrap_or(u32::MAX); - (15 * 2u64.pow(attempt)) * 1000 + let timeout = (15 * 2u64.pow(attempt)) * 1000; + + tracing::debug!(attempt, timeout, "calculated message timeout in milliseconds"); + + timeout } // ----------------------------------------------------------------------------// @@ -485,7 +561,7 @@ fn calculate_message_timeout_millis(attempt: u64) -> u64 { /// Cache for storing connection ids associated with a specific server address. #[derive(Debug)] struct ConnectIdCache { - cache: HashMap)>, + cache: HashMap, } impl ConnectIdCache { @@ -495,18 +571,16 @@ impl ConnectIdCache { } /// Get an active connection id for the given addr. - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn get(&mut self, addr: SocketAddr) -> Option { - tracing::debug!("getting connection id"); - match self.cache.entry(addr) { Entry::Vacant(_) => { - tracing::warn!("connection id for {addr} not in cache"); + tracing::debug!("connection id for {addr} not in cache"); None } Entry::Occupied(occ) => { - let curr_time = Utc::now(); + let curr_time = Instant::now(); let prev_time = occ.get().1; if is_expired(curr_time, prev_time) { @@ -525,9 +599,9 @@ impl ConnectIdCache { /// Put an un expired connection id into cache for the given addr. #[instrument(skip(self))] fn put(&mut self, addr: SocketAddr, connect_id: u64) { - tracing::debug!("setting expired connection id"); + tracing::trace!("setting un expired connection id"); - let curr_time = Utc::now(); + let curr_time = Instant::now(); self.cache.insert(addr, (connect_id, curr_time)); } @@ -535,9 +609,8 @@ impl ConnectIdCache { /// Removes all entries that have expired. #[instrument(skip(self))] fn clean_expired(&mut self) { - tracing::debug!("cleaning expired connection id(s)"); - - let curr_time = Utc::now(); + let curr_time = Instant::now(); + let mut removed = 0; let mut curr_index = 0; let mut opt_curr_entry = self.cache.iter().skip(curr_index).map(|(&k, &v)| (k, v)).next(); @@ -549,16 +622,20 @@ impl ConnectIdCache { curr_index += 1; opt_curr_entry = self.cache.iter().skip(curr_index).map(|(&k, &v)| (k, v)).next(); } + + if removed != 0 { + tracing::debug!(%removed, "expired connection id(s)"); + } } } /// Returns true if the connect id received at `prev_time` is now expired. -#[instrument(skip(), ret)] -fn is_expired(curr_time: DateTime, prev_time: DateTime) -> bool { - tracing::debug!("checking if a previous time is now expired"); - - let valid_duration = Duration::milliseconds(CONNECTION_ID_VALID_DURATION_MILLIS); - let difference = prev_time.signed_duration_since(curr_time); - - difference >= valid_duration +#[instrument(skip(), ret(level = Level::TRACE))] +fn is_expired(curr_time: Instant, prev_time: Instant) -> bool { + let Some(difference) = curr_time.checked_duration_since(prev_time) else { + // in future + return true; + }; + + difference >= Duration::from_millis(CONNECTION_ID_VALID_DURATION_MILLIS as u64) } diff --git a/packages/utracker/src/client/mod.rs b/packages/utracker/src/client/mod.rs index fb0449d85..a6c777090 100644 --- a/packages/utracker/src/client/mod.rs +++ b/packages/utracker/src/client/mod.rs @@ -6,7 +6,7 @@ use futures::future::Either; use futures::sink::Sink; use handshake::{DiscoveryInfo, InitiateMessage}; use tracing::instrument; -use umio::external::Sender; +use umio::{MessageSender, ShutdownHandle}; use util::bt::InfoHash; use util::trans::{LocallyShuffledIds, TransactionIds}; @@ -41,7 +41,7 @@ impl From for HandshakerMessage { /// Request made by the `TrackerClient`. #[allow(clippy::module_name_repetitions)] -#[derive(Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ClientRequest { Announce(InfoHash, ClientState), Scrape(InfoHash), @@ -119,10 +119,12 @@ impl ClientResponse { /// Client will shutdown on drop. #[allow(clippy::module_name_repetitions)] pub struct TrackerClient { - send: Sender, + send: MessageSender, // We are in charge of incrementing this, background worker is in charge of decrementing limiter: RequestLimiter, generator: TokenGenerator, + bound_socket: SocketAddr, + shutdown_handle: ShutdownHandle, } impl TrackerClient { @@ -143,14 +145,12 @@ impl TrackerClient { H: Sink> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static, H::Error: std::fmt::Display, { - tracing::info!("running client"); - let capacity = if let Some(capacity) = capacity_or_default { - tracing::debug!("with capacity {capacity}"); + tracing::trace!("with capacity {capacity}"); capacity } else { - tracing::debug!("with default capacity: {DEFAULT_CAPACITY}"); + tracing::trace!("with default capacity: {DEFAULT_CAPACITY}"); DEFAULT_CAPACITY }; @@ -165,12 +165,17 @@ impl TrackerClient { // Limit the capacity of messages (channel capacity - 1) let limiter = RequestLimiter::new(capacity); - let dispatcher = dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + let (dispatcher, bound_socket, shutdown_handle) = + dispatcher::create_dispatcher(bind, handshaker, chan_capacity, limiter.clone())?; + + tracing::info!(?bound_socket, "running client"); Ok(TrackerClient { send: dispatcher, limiter, generator: TokenGenerator::new(), + bound_socket, + shutdown_handle, }) } @@ -183,12 +188,15 @@ impl TrackerClient { /// It would panic if unable to send request message. #[instrument(skip(self))] pub fn request(&mut self, addr: SocketAddr, request: ClientRequest) -> Option { - tracing::debug!("requesting"); - if self.limiter.can_initiate() { let token = self.generator.generate(); + + let message = DispatchMessage::Request(addr, token, request); + + tracing::debug!(?message, "requesting"); + self.send - .send(DispatchMessage::Request(addr, token, request)) + .send(message) .expect("bip_utracker: Failed To Send Client Request Message..."); Some(token) @@ -212,7 +220,7 @@ impl Drop for TrackerClient { /// Associates a `ClientRequest` with a `ClientResponse`. #[allow(clippy::module_name_repetitions)] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ClientToken(u32); /// Generates tokens which double as transaction ids. diff --git a/packages/utracker/src/request.rs b/packages/utracker/src/request.rs index 67180dd73..d0b8a1a09 100644 --- a/packages/utracker/src/request.rs +++ b/packages/utracker/src/request.rs @@ -8,6 +8,7 @@ use nom::combinator::{map, map_res}; use nom::number::complete::{be_u32, be_u64}; use nom::sequence::tuple; use nom::IResult; +use tracing::instrument; use crate::announce::AnnounceRequest; use crate::scrape::ScrapeRequest; @@ -76,10 +77,13 @@ impl<'a> TrackerRequest<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. - pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> + #[instrument(skip(self, writer), err)] + pub fn write_bytes(&self, mut writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { + tracing::trace!("write_bytes"); + writer.write_u64::(self.connection_id())?; match self.request_type() { @@ -96,16 +100,18 @@ impl<'a> TrackerRequest<'a> { writer.write_u32::(action_id)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } RequestType::Scrape(req) => { writer.write_u32::(crate::SCRAPE_ACTION_ID)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } }; + writer.flush()?; + Ok(()) } diff --git a/packages/utracker/src/response.rs b/packages/utracker/src/response.rs index 51742a290..fa7048bac 100644 --- a/packages/utracker/src/response.rs +++ b/packages/utracker/src/response.rs @@ -91,22 +91,24 @@ impl<'a> TrackerResponse<'a> { writer.write_u32::(action_id)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } ResponseType::Scrape(req) => { writer.write_u32::(crate::SCRAPE_ACTION_ID)?; writer.write_u32::(self.transaction_id())?; - req.write_bytes(writer)?; + req.write_bytes(&mut writer)?; } ResponseType::Error(err) => { writer.write_u32::(ERROR_ACTION_ID)?; writer.write_u32::(self.transaction_id())?; - err.write_bytes(writer)?; + err.write_bytes(&mut writer)?; } }; + writer.flush(); + Ok(()) } diff --git a/packages/utracker/src/scrape.rs b/packages/utracker/src/scrape.rs index 8e0bc64e6..e033c3c3a 100644 --- a/packages/utracker/src/scrape.rs +++ b/packages/utracker/src/scrape.rs @@ -9,6 +9,7 @@ use nom::combinator::map_res; use nom::number::complete::be_i32; use nom::sequence::tuple; use nom::{IResult, Needed}; +use tracing::instrument; use util::bt::{self, InfoHash}; use util::convert; @@ -97,6 +98,7 @@ impl<'a> ScrapeRequest<'a> { /// # Errors /// /// It would return an IO Error if unable to write the bytes. + #[instrument(skip(self, writer), err)] pub fn write_bytes(&self, mut writer: W) -> std::io::Result<()> where W: std::io::Write, diff --git a/packages/utracker/src/server/dispatcher.rs b/packages/utracker/src/server/dispatcher.rs index 1d413658e..d83f31b58 100644 --- a/packages/utracker/src/server/dispatcher.rs +++ b/packages/utracker/src/server/dispatcher.rs @@ -1,9 +1,8 @@ use std::net::SocketAddr; use nom::IResult; -use tracing::instrument; -use umio::external::Sender; -use umio::{Dispatcher, ELoopBuilder, Provider}; +use tracing::{instrument, Level}; +use umio::{Dispatcher, ELoopBuilder, MessageSender, Provider, ShutdownHandle}; use crate::announce::AnnounceRequest; use crate::error::ErrorResponse; @@ -23,11 +22,14 @@ pub enum DispatchMessage { /// Create a new background dispatcher to service requests. #[allow(clippy::module_name_repetitions)] #[instrument(skip())] -pub fn create_dispatcher(bind: SocketAddr, handler: H) -> std::io::Result> +pub fn create_dispatcher( + bind: SocketAddr, + handler: H, +) -> std::io::Result<(MessageSender, SocketAddr, ShutdownHandle)> where H: ServerHandler + std::fmt::Debug + 'static, { - tracing::debug!("create dispatcher"); + tracing::trace!("create dispatcher"); let builder = ELoopBuilder::new() .channel_capacity(1) @@ -35,7 +37,7 @@ where .bind_address(bind) .buffer_length(EXPECTED_PACKET_LENGTH); - let mut eloop = builder.build()?; + let (mut eloop, socket, shutdown) = builder.build()?; let channel = eloop.channel(); let dispatch = ServerDispatcher::new(handler); @@ -44,7 +46,7 @@ where eloop.run(dispatch).expect("bip_utracker: ELoop Shutdown Unexpectedly..."); }); - Ok(channel) + Ok((channel, socket, shutdown)) } // ----------------------------------------------------------------------------// @@ -63,10 +65,8 @@ where H: ServerHandler + std::fmt::Debug, { /// Create a new `ServerDispatcher`. - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] fn new(handler: H) -> ServerDispatcher { - tracing::debug!("new"); - ServerDispatcher { handler } } @@ -78,7 +78,7 @@ where request: &TrackerRequest<'_>, addr: SocketAddr, ) { - tracing::debug!("process request"); + tracing::trace!("process request"); let conn_id = request.connection_id(); let trans_id = request.transaction_id(); @@ -106,8 +106,6 @@ where /// Forward a connect request on to the appropriate handler method. #[instrument(skip(self, provider))] fn forward_connect(&mut self, provider: &mut Provider<'_, ServerDispatcher>, trans_id: u32, addr: SocketAddr) { - tracing::debug!("forward connect"); - let Some(attempt) = self.handler.connect(addr) else { tracing::warn!("connect attempt canceled"); @@ -121,6 +119,8 @@ where let response = TrackerResponse::new(trans_id, response_type); + tracing::trace!(?response, "forward connect"); + write_response(provider, &response, addr); } @@ -134,8 +134,6 @@ where request: &AnnounceRequest<'_>, addr: SocketAddr, ) { - tracing::debug!("forward announce"); - let Some(attempt) = self.handler.announce(addr, conn_id, request) else { tracing::warn!("announce attempt canceled"); @@ -148,6 +146,8 @@ where }; let response = TrackerResponse::new(trans_id, response_type); + tracing::trace!(?response, "forward announce"); + write_response(provider, &response, addr); } @@ -188,24 +188,21 @@ where { tracing::debug!("write response"); - provider.outgoing(|buffer| { - let mut cursor = std::io::Cursor::new(buffer); + provider.set_dest(addr); - match response.write_bytes(&mut cursor) { - Ok(()) => Some((cursor.position().try_into().unwrap(), addr)), - Err(e) => { - tracing::error!("error writing response to cursor: {e}"); - None - } + match response.write_bytes(provider) { + Ok(()) => (), + Err(e) => { + tracing::error!(%e, "error writing response to cursor"); } - }); + } } impl Dispatcher for ServerDispatcher where H: ServerHandler + std::fmt::Debug, { - type Timeout = (); + type TimeoutToken = (); type Message = DispatchMessage; #[instrument(skip(self, provider))] @@ -217,7 +214,7 @@ where self.process_request(&mut provider, &request, addr); } Err(e) => { - tracing::error!("received an incoming error message: {e}"); + tracing::error!(%e, "received an incoming error message"); } }; } diff --git a/packages/utracker/src/server/mod.rs b/packages/utracker/src/server/mod.rs index 5e2126208..5c96b641c 100644 --- a/packages/utracker/src/server/mod.rs +++ b/packages/utracker/src/server/mod.rs @@ -1,7 +1,7 @@ use std::net::SocketAddr; -use tracing::instrument; -use umio::external::Sender; +use tracing::{instrument, Level}; +use umio::{MessageSender, ShutdownHandle}; use crate::server::dispatcher::DispatchMessage; use crate::server::handler::ServerHandler; @@ -15,7 +15,9 @@ pub mod handler; #[allow(clippy::module_name_repetitions)] #[derive(Debug)] pub struct TrackerServer { - dispatcher: Sender, + dispatcher: MessageSender, + bound_socket: SocketAddr, + shutdown_handle: ShutdownHandle, } impl TrackerServer { @@ -24,16 +26,20 @@ impl TrackerServer { /// # Errors /// /// It would return an IO Error if unable to run the server. - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn run(bind: SocketAddr, handler: H) -> std::io::Result where H: ServerHandler + std::fmt::Debug + 'static, { - tracing::info!("running server"); + let (dispatcher, bound_socket, shutdown_handle) = dispatcher::create_dispatcher(bind, handler)?; - let dispatcher = dispatcher::create_dispatcher(bind, handler)?; + tracing::info!(?bound_socket, "running server"); - Ok(TrackerServer { dispatcher }) + Ok(TrackerServer { + dispatcher, + bound_socket, + shutdown_handle, + }) } } diff --git a/packages/utracker/tests/common/mod.rs b/packages/utracker/tests/common/mod.rs index f410126fa..fc38b33b1 100644 --- a/packages/utracker/tests/common/mod.rs +++ b/packages/utracker/tests/common/mod.rs @@ -8,8 +8,8 @@ use futures::sink::SinkExt; use futures::stream::StreamExt; use futures::{Sink, Stream}; use handshake::DiscoveryInfo; -use tracing::instrument; use tracing::level_filters::LevelFilter; +use tracing::{instrument, Level}; use util::bt::{InfoHash, PeerId}; use util::trans::{LocallyShuffledIds, TransactionIds}; use utracker::announce::{AnnounceEvent, AnnounceRequest, AnnounceResponse}; @@ -58,7 +58,7 @@ pub struct InnerMockTrackerHandler { #[allow(dead_code)] impl MockTrackerHandler { - #[instrument(skip(), ret)] + #[instrument(skip(), ret(level = Level::TRACE))] pub fn new() -> MockTrackerHandler { tracing::debug!("new mock handler"); @@ -77,7 +77,7 @@ impl MockTrackerHandler { } impl ServerHandler for MockTrackerHandler { - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn connect(&mut self, addr: SocketAddr) -> Option> { tracing::debug!("mock connect"); @@ -89,7 +89,7 @@ impl ServerHandler for MockTrackerHandler { Some(Ok(cid)) } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn announce( &mut self, addr: SocketAddr, @@ -158,7 +158,7 @@ impl ServerHandler for MockTrackerHandler { } } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn scrape(&mut self, _: SocketAddr, id: u64, req: &ScrapeRequest<'_>) -> Option>> { tracing::debug!("mock scrape"); @@ -211,16 +211,16 @@ impl DiscoveryInfo for MockHandshakerSink { impl Sink> for MockHandshakerSink { type Error = std::io::Error; - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_ready(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - tracing::debug!("polling ready"); + tracing::trace!("polling ready"); self.send .poll_ready(cx) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self), ret)] + #[instrument(skip(self), ret(level = Level::TRACE))] fn start_send(mut self: std::pin::Pin<&mut Self>, item: std::io::Result) -> Result<(), Self::Error> { tracing::debug!("starting send"); @@ -229,19 +229,19 @@ impl Sink> for MockHandshakerSink { .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_flush( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - tracing::debug!("polling flush"); + tracing::trace!("polling flush"); self.send .poll_flush_unpin(cx) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) } - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_close( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -261,9 +261,9 @@ pub struct MockHandshakerStream { impl Stream for MockHandshakerStream { type Item = std::io::Result; - #[instrument(skip(self, cx), ret)] + #[instrument(skip(self, cx), ret(level = Level::TRACE))] fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { - tracing::debug!("polling next"); + tracing::trace!("polling next"); self.recv.poll_next_unpin(cx).map(|maybe| maybe.map(Ok)) } diff --git a/packages/utracker/tests/test_announce_start.rs b/packages/utracker/tests/test_announce_start.rs index 37ef75f64..f946dd60b 100644 --- a/packages/utracker/tests/test_announce_start.rs +++ b/packages/utracker/tests/test_announce_start.rs @@ -14,7 +14,7 @@ mod common; #[tokio::test] async fn positive_announce_started() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::ERROR); + tracing_stderr_init(LevelFilter::INFO); }); let (handshaker_sender, mut handshaker_receiver) = handshaker(); @@ -29,7 +29,7 @@ async fn positive_announce_started() { let hash = [0u8; bt::INFO_HASH_LEN].into(); - tracing::warn!("sending announce"); + tracing::debug!("sending announce"); let _send_token = client .request( server_addr, @@ -37,7 +37,7 @@ async fn positive_announce_started() { ) .unwrap(); - tracing::warn!("receiving initiate message"); + tracing::debug!("receiving initiate message"); let init_msg = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) .await .unwrap() @@ -54,7 +54,7 @@ async fn positive_announce_started() { assert_eq!(&exp_peer_addr, init_msg.address()); assert_eq!(&hash, init_msg.hash()); - tracing::warn!("receiving client metadata"); + tracing::debug!("receiving client metadata"); let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, handshaker_receiver.next()) .await .unwrap() diff --git a/packages/utracker/tests/test_client_drop.rs b/packages/utracker/tests/test_client_drop.rs index 3564cb9e7..7cad967a2 100644 --- a/packages/utracker/tests/test_client_drop.rs +++ b/packages/utracker/tests/test_client_drop.rs @@ -12,7 +12,7 @@ mod common; #[tokio::test] async fn positive_client_request_failed() { INIT.call_once(|| { - tracing_stderr_init(LevelFilter::ERROR); + tracing_stderr_init(LevelFilter::INFO); }); let (sink, mut stream) = handshaker();