From 5fc3314db7a09c68a7176920db55494d49cfd10e Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 30 Dec 2024 23:41:29 -0800 Subject: [PATCH] Read socket rework (#35) * 1gbps throughput * warnings * tokio mpsc * migrate sender to mpsc * reducing acks sent * fix server accept logic * cleanup + minor speedup * codecov bring back bins --------- Co-authored-by: Frank Lee <> --- .github/workflows/bluefin.yml | 14 +- Cargo.toml | 9 +- src/bin/client.rs | 78 ++--- src/bin/server.rs | 78 +++-- src/core/mod.rs | 11 + src/core/packet.rs | 13 + src/net/client.rs | 16 +- src/net/connection.rs | 51 +-- src/net/mod.rs | 44 +-- src/net/server.rs | 40 +-- src/utils/mod.rs | 39 +++ src/worker/conn_reader.rs | 162 ++++++++++ src/worker/mod.rs | 1 + src/worker/reader.rs | 27 +- src/worker/writer.rs | 576 +++++++++++++++------------------ tests/basic/basic_handshake.rs | 263 +++++++-------- 16 files changed, 772 insertions(+), 650 deletions(-) create mode 100644 src/worker/conn_reader.rs diff --git a/.github/workflows/bluefin.yml b/.github/workflows/bluefin.yml index 974da28..a2c906c 100644 --- a/.github/workflows/bluefin.yml +++ b/.github/workflows/bluefin.yml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ ubuntu-latest, macos-latest ] include: - os: ubuntu-latest target: Linux @@ -23,11 +23,11 @@ jobs: target: Macos steps: - - uses: actions/checkout@v3 - - name: Build - run: cargo build --verbose - - name: Run tests - run: cargo test --verbose + - uses: actions/checkout@v3 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose coverage: runs-on: ubuntu-latest @@ -40,7 +40,7 @@ jobs: - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: Generate code coverage - run: cargo llvm-cov --all-features --workspace --lcov --ignore-filename-regex "error.rs|*/bin/*.rs" --output-path lcov.info + run: cargo llvm-cov --all-features --workspace --lcov --ignore-filename-regex "error.rs" --output-path lcov.info - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: diff --git a/Cargo.toml b/Cargo.toml index 6c8f9b2..79c69a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,15 +9,14 @@ repository = "https://github.com/franklee26/bluefin" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -etherparse = "0.15.0" local-ip-address = "0.6.3" rand = "0.8.5" rstest = "0.23.0" -thiserror = "2.0.3" -tokio = { version = "1.41.1", features = ["full", "tracing"] } +thiserror = "2.0.9" +tokio = { version = "1.42.0", features = ["full", "tracing"] } console-subscriber = "0.4.1" libc = "0.2.164" -sysctl = "0.6.0" +socket2 = "0.5.8" [dev-dependencies] local-ip-address = "0.6.3" @@ -38,7 +37,7 @@ path = "src/bin/server.rs" unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage,coverage_nightly)', 'cfg(kani)'] } [profile.release] -opt-level = 3 +opt-level = 3 codegen-units = 1 lto = "fat" debug = true diff --git a/src/bin/client.rs b/src/bin/client.rs index 6d388e9..accc13f 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -16,54 +16,56 @@ async fn main() -> BluefinResult<()> { let ports = [1320, 1322, 1323, 1324, 1325]; let mut tasks = vec![]; for ix in 0..2 { - // sleep(Duration::from_secs(3)).await; - let task = spawn(async move { - let mut total_bytes = 0; - let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( + let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(127, 0, 0, 1), + ports[ix], + ))); + if let Ok(mut conn) = client + .connect(std::net::SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::new(127, 0, 0, 1), - ports[ix], - ))); - let mut conn = client - .connect(std::net::SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new(127, 0, 0, 1), - 1318, - ))) - .await?; + 1318, + ))) + .await + { + let task = spawn(async move { + let mut total_bytes = 0; - let bytes = [1, 2, 3, 4, 5, 6, 7]; - let mut size = conn.send(&bytes).await?; - total_bytes += size; - println!("Sent {} bytes", size); + let bytes = [1, 2, 3, 4, 5, 6, 7]; + let mut size = conn.send(&bytes)?; + total_bytes += size; + println!("Sent {} bytes", size); - size = conn.send(&[12, 12, 12, 12, 12, 12]).await?; - total_bytes += size; - println!("Sent {} bytes", size); + size = conn.send(&[12, 12, 12, 12, 12, 12])?; + total_bytes += size; + println!("Sent {} bytes", size); - size = conn.send(&[13; 100]).await?; - total_bytes += size; - println!("Sent {} bytes", size); + size = conn.send(&[13; 100])?; + total_bytes += size; + println!("Sent {} bytes", size); - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; - size = conn.send(&[14, 14, 14, 14, 14, 14]).await?; - total_bytes += size; - println!("Sent {} bytes", size); + size = conn.send(&[14, 14, 14, 14, 14, 14])?; + total_bytes += size; + println!("Sent {} bytes", size); - for ix in 0..5000000 { - // let my_array: [u8; 32] = rand::random(); let my_array = [0u8; 1500]; - size = conn.send(&my_array).await?; - total_bytes += size; - if ix % 3000 == 0 { - sleep(Duration::from_millis(3)).await; + for ix in 0..10000000 { + // let my_array: [u8; 32] = rand::random(); + size = conn.send(&my_array)?; + total_bytes += size; + if ix % 4000 == 0 { + sleep(Duration::from_millis(1)).await; + } } - } - println!("Sent {} bytes", total_bytes); - sleep(Duration::from_secs(2)).await; + println!("Sent {} bytes", total_bytes); + sleep(Duration::from_secs(3)).await; - Ok::<(), BluefinError>(()) - }); - tasks.push(task); + Ok::<(), BluefinError>(()) + }); + tasks.push(task); + sleep(Duration::from_millis(1)).await; + } } for t in tasks { diff --git a/src/bin/server.rs b/src/bin/server.rs index 8f26ad8..55ce9dd 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,11 +1,10 @@ #![cfg_attr(coverage_nightly, feature(coverage_attribute))] +use bluefin::{net::server::BluefinServer, utils::common::BluefinResult}; use std::{ cmp::{max, min}, net::{Ipv4Addr, SocketAddrV4}, time::Instant, }; - -use bluefin::{net::server::BluefinServer, utils::common::BluefinResult}; use tokio::{spawn, task::JoinSet}; #[cfg_attr(coverage_nightly, coverage(off))] @@ -24,28 +23,24 @@ async fn run() -> BluefinResult<()> { Ipv4Addr::new(127, 0, 0, 1), 1318, ))); - server.set_num_reader_workers(300)?; + server.set_num_reader_workers(3)?; server.bind().await?; let mut join_set = JoinSet::new(); - const MAX_NUM_CONNECTIONS: usize = 2; - for conn_num in 0..MAX_NUM_CONNECTIONS { - let mut s = server.clone(); - let _num = conn_num; + let mut _num = 0; + while let Ok(mut conn) = server.accept().await { let _ = join_set.spawn(async move { - let _conn = s.accept().await; - - match _conn { - Ok(mut conn) => { let mut total_bytes = 0; - let mut recv_bytes = [0u8; 80000]; + let mut recv_bytes = [0u8; 10000]; let mut min_bytes = usize::MAX; let mut max_bytes = 0; - let mut iteration = 1; + let mut iteration: i64 = 1; let mut num_iterations_without_print = 0; + let mut max_throughput = 0.0; + let mut min_throughput = f64::MAX; let now = Instant::now(); loop { - let size = conn.recv(&mut recv_bytes, 80000).await.unwrap(); + let size = conn.recv(&mut recv_bytes, 10000).await.unwrap(); total_bytes += size; min_bytes = min(size, min_bytes); max_bytes = max(size, max_bytes); @@ -53,39 +48,68 @@ async fn run() -> BluefinResult<()> { /* println!( - "({:x}_{:x}) >>> Received: {:?} (total: {})", + "({:x}_{:x}) >>> Received: {} bytes", conn.src_conn_id, conn.dst_conn_id, - &recv_bytes[..size], total_bytes ); */ num_iterations_without_print += 1; - if total_bytes >= 100000 && num_iterations_without_print == 200 { + if total_bytes >= 1000000 && num_iterations_without_print == 3500 { let elapsed = now.elapsed().as_secs(); + if elapsed == 0 { + eprintln!("(#{})Total bytes: {} (0s???)", _num, total_bytes); + num_iterations_without_print = 0; + continue; + } let through_put = u64::try_from(total_bytes).unwrap() / elapsed; + let through_put_mb = through_put as f64 / 1e6; let avg_recv_bytes: f64 = total_bytes as f64 / iteration as f64; - eprintln!( - "{} {:.1} kb/s or {:.1} mb/s (read {:.1} kb/iteration, min: {:.1} kb, max: {:.1} kb)", + + if through_put_mb > max_throughput { + max_throughput = through_put_mb; + } + + if through_put_mb < min_throughput { + min_throughput = through_put_mb; + } + + if through_put_mb < 1000.0 { + eprintln!( + "{} {:.1} kb/s or {:.1} mb/s (read {:.1} kb/iteration, min: {:.1} kb, max: {:.1} kb) (max {:.1} mb/s, min {:.1} mb/s)", _num, through_put as f64 / 1e3, - through_put as f64 / 1e6, + through_put_mb, avg_recv_bytes / 1e3, min_bytes as f64 / 1e3, - max_bytes as f64 / 1e3 + max_bytes as f64 / 1e3, + max_throughput, + min_throughput ); - num_iterations_without_print = 0; + } else { + eprintln!( + "{} {:.2} gb/s (read {:.1} kb/iter, min: {:.1} kb, max: {:.1} kb) (max {:.2} gb/s, min {:.1} kb/s)", + _num, + through_put_mb / 1e3, + avg_recv_bytes / 1e3, + min_bytes as f64 / 1e3, + max_bytes as f64 / 1e3, + max_throughput / 1e3, + min_throughput + ); + } + num_iterations_without_print = 0; // break; } iteration += 1; } - } - Err(e) => { - eprintln!("Could not accept connection due to error: {:?}", e); - } - } }); + _num += 1; + if _num >= 2 { + break; + } } + join_set.join_all().await; Ok(()) } diff --git a/src/core/mod.rs b/src/core/mod.rs index bc9c550..c71e2f6 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -5,6 +5,17 @@ pub mod error; pub mod header; pub mod packet; +pub trait Extract: Default { + /// Replace self with default and returns the initial value. + fn extract(&mut self) -> Self; +} + +impl Extract for T { + fn extract(&mut self) -> Self { + std::mem::replace(self, T::default()) + } +} + pub trait Serialisable { fn serialise(&self) -> Vec; fn deserialise(bytes: &[u8]) -> Result diff --git a/src/core/packet.rs b/src/core/packet.rs index 2bcf941..0e1cded 100644 --- a/src/core/packet.rs +++ b/src/core/packet.rs @@ -51,6 +51,19 @@ impl Serialisable for BluefinPacket { } } +impl Default for BluefinPacket { + #[allow(invalid_value)] + #[inline] + fn default() -> Self { + // SAFETY + // Actually, this isn't safe and access to this kind of zero'd value would result + // in panics. There does not exist a 'default' bluefin packet. Therefore, the + // purpose of this is to quickly instantiate a 'filler' bluefin packet BUT this + // default value should NEVER be read/used. + unsafe { std::mem::zeroed() } + } +} + impl BluefinPacket { #[inline] pub fn builder() -> BluefinPacketBuilder { diff --git a/src/net/client.rs b/src/net/client.rs index 23f8f46..fc9f69c 100644 --- a/src/net/client.rs +++ b/src/net/client.rs @@ -7,6 +7,11 @@ use std::{ use rand::Rng; use tokio::{net::UdpSocket, sync::RwLock}; +use super::{ + connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}, + AckBuffer, ConnectionManagedBuffers, +}; +use crate::utils::get_udp_socket; use crate::{ core::{context::BluefinHost, error::BluefinError, header::PacketType, Serialisable}, net::{ @@ -15,12 +20,7 @@ use crate::{ utils::common::BluefinResult, }; -use super::{ - connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}, - AckBuffer, ConnectionManagedBuffers, -}; - -const NUM_TX_WORKERS_FOR_CLIENT_DEFAULT: u16 = 5; +const NUM_TX_WORKERS_FOR_CLIENT_DEFAULT: u16 = 1; pub struct BluefinClient { socket: Option>, @@ -53,7 +53,7 @@ impl BluefinClient { } pub async fn connect(&mut self, dst_addr: SocketAddr) -> BluefinResult { - let socket = Arc::new(UdpSocket::bind(self.src_addr).await?); + let socket = Arc::new(get_udp_socket(self.src_addr)?); self.socket = Some(Arc::clone(&socket)); self.dst_addr = Some(dst_addr); @@ -137,8 +137,8 @@ impl BluefinClient { packet_number + 2, Arc::clone(&conn_buffer), Arc::clone(&ack_buff), - Arc::clone(self.socket.as_ref().unwrap()), self.dst_addr.unwrap(), + self.src_addr, )) } } diff --git a/src/net/connection.rs b/src/net/connection.rs index 20430ab..20c0576 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -7,18 +7,19 @@ use std::{ time::Duration, }; -use tokio::{net::UdpSocket, time::timeout}; +use tokio::time::timeout; use crate::{ core::{context::BluefinHost, error::BluefinError, packet::BluefinPacket}, utils::common::BluefinResult, - worker::{reader::ReaderRxChannel, writer::WriterTxChannel}, + worker::{reader::ReaderRxChannel, writer::WriterHandler}, }; use super::{ - build_and_start_writer_rx_channel, + build_and_start_ack_consumer_workers, build_and_start_conn_reader_tx_channels, + get_connected_udp_socket, ordered_bytes::{ConsumeResult, OrderedBytes}, - AckBuffer, ConnectionManagedBuffers, WriterQueue, + AckBuffer, ConnectionManagedBuffers, }; pub const MAX_BUFFER_SIZE: usize = 2000; @@ -254,7 +255,7 @@ pub struct BluefinConnection { pub src_conn_id: u32, pub dst_conn_id: u32, reader_rx: ReaderRxChannel, - writer_tx: WriterTxChannel, + writer_handler: WriterHandler, } impl BluefinConnection { @@ -264,32 +265,40 @@ impl BluefinConnection { next_send_packet_num: u64, conn_buffer: Arc>, ack_buffer: Arc>, - socket: Arc, dst_addr: SocketAddr, + src_addr: SocketAddr, ) -> Self { - let writer_queue = Arc::new(Mutex::new(WriterQueue::new())); - let ack_queue = Arc::new(Mutex::new(WriterQueue::new())); - let writer_tx = WriterTxChannel::new(Arc::clone(&writer_queue), Arc::clone(&ack_queue)); - - build_and_start_writer_rx_channel( - Arc::clone(&writer_queue), - Arc::clone(&ack_queue), - Arc::clone(&socket), - 1, - dst_addr, - Arc::clone(&ack_buffer), + build_and_start_ack_consumer_workers(1, Arc::clone(&ack_buffer)); + let s = get_connected_udp_socket(src_addr, dst_addr); + if let Err(e) = s { + panic!("Failed to get connected sockets due to error: {:?}", e); + } + let conn_socket = Arc::new(s.unwrap()); + + let mut writer_handler = WriterHandler::new( + Arc::clone(&conn_socket), next_send_packet_num, src_conn_id, dst_conn_id, ); + if let Err(e) = writer_handler.start() { + panic!("Cannot start connection due to error: {:?}", e); + } + + let conn_bufs = Arc::new(ConnectionManagedBuffers { + conn_buff: Arc::clone(&conn_buffer), + ack_buff: Arc::clone(&ack_buffer), + }); + + let _ = build_and_start_conn_reader_tx_channels(Arc::clone(&conn_socket), conn_bufs); - let reader_rx = ReaderRxChannel::new(Arc::clone(&conn_buffer), writer_tx.clone()); + let reader_rx = ReaderRxChannel::new(Arc::clone(&conn_buffer), writer_handler.clone()); Self { src_conn_id, dst_conn_id, reader_rx, - writer_tx, + writer_handler, } } @@ -300,10 +309,10 @@ impl BluefinConnection { } #[inline] - pub async fn send(&mut self, buf: &[u8]) -> BluefinResult { + pub fn send(&mut self, buf: &[u8]) -> BluefinResult { // TODO! This returns the total bytes sent (including bluefin payload). This // really should only return the total payload bytes - let _ = self.writer_tx.send(buf).await?; + self.writer_handler.send_data(buf)?; Ok(buf.len()) } } diff --git a/src/net/mod.rs b/src/net/mod.rs index b8428da..861d16f 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,7 +1,4 @@ -use std::{ - net::SocketAddr, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; use ack_handler::{AckBuffer, AckConsumer}; use connection::{ConnectionBuffer, ConnectionManager}; @@ -13,10 +10,8 @@ use crate::{ header::{BluefinHeader, BluefinSecurityFields, PacketType}, packet::BluefinPacket, }, - worker::{ - reader::ReaderTxChannel, - writer::{WriterQueue, WriterRxChannel}, - }, + utils::{common::BluefinResult, get_connected_udp_socket}, + worker::{conn_reader::ConnReaderHandler, reader::ReaderTxChannel}, }; pub mod ack_handler; @@ -57,34 +52,21 @@ fn build_and_start_tx( } } -fn build_and_start_writer_rx_channel( - data_queue: Arc>, - ack_queue: Arc>, +#[inline] +fn build_and_start_conn_reader_tx_channels( socket: Arc, + conn_bufs: Arc, +) -> BluefinResult<()> { + let handler = ConnReaderHandler::new(socket, conn_bufs); + handler.start() +} + +#[inline] +fn build_and_start_ack_consumer_workers( num_ack_consumer_workers: u8, - dst_addr: SocketAddr, ack_buffer: Arc>, - next_packet_num: u64, - src_conn_id: u32, - dst_conn_id: u32, ) { let largest_recv_acked_packet_num = Arc::new(RwLock::new(0)); - let mut rx = WriterRxChannel::new( - data_queue, - ack_queue, - socket, - dst_addr, - next_packet_num, - src_conn_id, - dst_conn_id, - ); - let mut cloned = rx.clone(); - spawn(async move { - cloned.run_data().await; - }); - spawn(async move { - rx.run_ack().await; - }); let ack_consumer = AckConsumer::new(Arc::clone(&ack_buffer), largest_recv_acked_packet_num); for _ in 0..num_ack_consumer_workers { diff --git a/src/net/server.rs b/src/net/server.rs index 7af2723..9b20772 100644 --- a/src/net/server.rs +++ b/src/net/server.rs @@ -1,5 +1,4 @@ use std::{ - mem, net::SocketAddr, sync::{Arc, Mutex}, time::Duration, @@ -11,7 +10,7 @@ use tokio::{net::UdpSocket, sync::RwLock}; use crate::{ core::{context::BluefinHost, error::BluefinError, header::PacketType, Serialisable}, net::{build_empty_encrypted_packet, connection::HandshakeConnectionBuffer}, - utils::common::BluefinResult, + utils::{common::BluefinResult, get_udp_socket}, }; use super::{ @@ -19,10 +18,8 @@ use super::{ connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}, AckBuffer, ConnectionManagedBuffers, }; -use std::os::fd::AsRawFd; -const NUM_TX_WORKERS_FOR_SERVER_DEFAULT: u16 = 10; +const NUM_TX_WORKERS_FOR_SERVER_DEFAULT: u16 = 1; -#[derive(Clone)] pub struct BluefinServer { socket: Option>, src_addr: SocketAddr, @@ -54,38 +51,9 @@ impl BluefinServer { } pub async fn bind(&mut self) -> BluefinResult<()> { - let socket = UdpSocket::bind(self.src_addr).await?; - let socket_fd = socket.as_raw_fd(); + let socket = get_udp_socket(self.src_addr)?; self.socket = Some(Arc::new(socket)); - #[cfg(target_os = "macos")] - { - use sysctl::Sysctl; - if let Ok(ctl) = sysctl::Ctl::new("net.inet.udp.maxdgram") { - match ctl.set_value_string("16000") { - Ok(s) => { - println!("Successfully set net.inet.udp.maxdgram to {}", s) - } - Err(e) => eprintln!("Failed to set net.inet.udp.maxdgram due to err: {:?}", e), - } - } - } - - #[cfg(any(target_os = "linux", target_os = "macos"))] - unsafe { - let optval: libc::c_int = 1; - let ret = libc::setsockopt( - socket_fd, - libc::SOL_SOCKET, - libc::SO_REUSEPORT, - &optval as *const _ as *const libc::c_void, - mem::size_of_val(&optval) as libc::socklen_t, - ); - if ret != 0 { - return Err(BluefinError::InvalidSocketError); - } - } - build_and_start_tx( self.num_reader_workers, Arc::clone(self.socket.as_ref().unwrap()), @@ -165,8 +133,8 @@ impl BluefinServer { packet_number + 1, Arc::clone(&conn_buffer), Arc::clone(&ack_buffer), - Arc::clone(self.socket.as_ref().unwrap()), addr, + self.src_addr, )) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index f6e5d5f..4d9d633 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,41 @@ +use std::net::SocketAddr; + +use common::BluefinResult; +use tokio::net::UdpSocket; + pub mod common; pub mod window; + +#[inline] +pub(crate) fn get_udp_socket(src_addr: SocketAddr) -> BluefinResult { + let s = get_udp_socket_impl(src_addr)?; + let udp_sock: std::net::UdpSocket = s.into(); + let socket = udp_sock.try_into()?; + + Ok(socket) +} + +#[inline] +pub(crate) fn get_connected_udp_socket( + src_addr: SocketAddr, + dst_addr: SocketAddr, +) -> BluefinResult { + let socket = get_udp_socket_impl(src_addr)?; + socket.connect(&socket2::SockAddr::from(dst_addr))?; + + let udp_sock: std::net::UdpSocket = socket.into(); + let s = udp_sock.try_into()?; + + Ok(s) +} + +#[inline] +fn get_udp_socket_impl(src_addr: SocketAddr) -> BluefinResult { + let udp_sock = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::DGRAM, None)?; + udp_sock.set_reuse_address(true)?; + udp_sock.set_reuse_port(true)?; + udp_sock.set_cloexec(true)?; + udp_sock.set_nonblocking(true).unwrap(); + udp_sock.bind(&socket2::SockAddr::from(src_addr))?; + Ok(udp_sock) +} diff --git a/src/worker/conn_reader.rs b/src/worker/conn_reader.rs new file mode 100644 index 0000000..e63cb38 --- /dev/null +++ b/src/worker/conn_reader.rs @@ -0,0 +1,162 @@ +use tokio::net::UdpSocket; +use tokio::spawn; +use tokio::sync::mpsc::{self}; + +use crate::core::error::BluefinError; +use crate::core::header::PacketType; +use crate::core::packet::BluefinPacket; +use crate::net::ack_handler::AckBuffer; +use crate::net::connection::ConnectionBuffer; +use crate::net::{ConnectionManagedBuffers, MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM}; +use crate::utils::common::BluefinResult; +use std::sync::{Arc, MutexGuard}; + +const DEFAULT_NUMBER_OF_TASKS_TO_SPAWN: usize = 3; + +pub(crate) struct ConnReaderHandler { + socket: Arc, + conn_bufs: Arc, +} + +impl ConnReaderHandler { + pub(crate) fn new(socket: Arc, conn_bufs: Arc) -> Self { + Self { socket, conn_bufs } + } + + pub(crate) fn start(&self) -> BluefinResult<()> { + let (tx, rx) = mpsc::channel::>(1024); + for _ in 0..Self::get_num_cpu_cores() { + let tx_cloned = tx.clone(); + let socket_cloned = self.socket.clone(); + spawn(async move { + let _ = ConnReaderHandler::tx_impl(socket_cloned, tx_cloned).await; + }); + } + + let conn_bufs = self.conn_bufs.clone(); + spawn(async move { + let _ = ConnReaderHandler::rx_impl(rx, &*conn_bufs).await; + }); + Ok(()) + } + + #[allow(unreachable_code)] + #[inline] + fn get_num_cpu_cores() -> usize { + // For linux, let's use all the cpu cores available. + #[cfg(target_os = "linux")] + { + use std::thread::available_parallelism; + if let Ok(num_cpu_cores) = available_parallelism() { + return num_cpu_cores.get(); + } + } + + // For macos (at least silicon macs), we can't seem to use + // SO_REUSEPORT to our benefit. We will pretend we have one core. + #[cfg(target_os = "macos")] + { + return 1; + } + + // For everything else, we assume the default. + DEFAULT_NUMBER_OF_TASKS_TO_SPAWN + } + + #[inline] + async fn tx_impl( + socket: Arc, + tx: mpsc::Sender>, + ) -> BluefinResult<()> { + let mut buf = [0u8; MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM]; + loop { + let size = socket.recv(&mut buf).await?; + let packets = BluefinPacket::from_bytes(&buf[..size])?; + + if packets.len() == 0 { + continue; + } + + let _ = tx.send(packets).await; + } + } + + #[inline] + async fn rx_impl( + mut rx: mpsc::Receiver>, + conn_bufs: &ConnectionManagedBuffers, + ) { + loop { + if let Some(packets) = rx.recv().await { + let _ = Self::buffer_in_packets(packets, conn_bufs); + } + } + } + + #[inline] + fn buffer_in_packets( + packets: Vec, + conn_bufs: &ConnectionManagedBuffers, + ) -> BluefinResult<()> { + // Nothing to do if empty + if packets.is_empty() { + return Ok(()); + } + + // Peek at the first packet and acquire the buffer. + let p = packets.first().unwrap(); + match p.header.type_field { + PacketType::Ack => { + let guard = conn_bufs.ack_buff.lock().unwrap(); + Self::buffer_in_ack_packets(guard, packets) + } + _ => { + let guard = conn_bufs.conn_buff.lock().unwrap(); + Self::buffer_in_data_packets(guard, packets) + } + } + } + + #[inline] + fn buffer_in_ack_packets( + mut guard: MutexGuard<'_, AckBuffer>, + packets: Vec, + ) -> BluefinResult<()> { + let mut e: Option = None; + for p in packets { + if let Err(err) = guard.buffer_in_ack_packet(p) { + e = Some(err); + } + } + guard.wake()?; + + if e.is_some() { + return Err(e.unwrap()); + } + Ok(()) + } + + #[inline] + fn buffer_in_data_packets( + mut guard: MutexGuard<'_, ConnectionBuffer>, + packets: Vec, + ) -> BluefinResult<()> { + let mut e: Option = None; + for p in packets { + if let Err(err) = guard.buffer_in_bytes(p) { + e = Some(err); + } + } + + if let Some(w) = guard.get_waker() { + w.wake_by_ref(); + } else { + return Err(BluefinError::NoSuchWakerError); + } + + if e.is_some() { + return Err(e.unwrap()); + } + Ok(()) + } +} diff --git a/src/worker/mod.rs b/src/worker/mod.rs index c9134a0..b4dfca4 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -1,2 +1,3 @@ +pub mod conn_reader; pub mod reader; pub mod writer; diff --git a/src/worker/reader.rs b/src/worker/reader.rs index 33f8254..9aced05 100644 --- a/src/worker/reader.rs +++ b/src/worker/reader.rs @@ -19,7 +19,7 @@ use crate::{ utils::common::BluefinResult, }; -use super::writer::WriterTxChannel; +use super::writer::WriterHandler; #[derive(Clone)] /// [ReaderTxChannel] is the transmission channel for the receiving [ReaderRxChannel]. This channel will when @@ -43,7 +43,9 @@ pub(crate) struct ReaderTxChannel { /// *receives* bytes *from* the buffer. pub(crate) struct ReaderRxChannel { future: ReaderRxChannelFuture, - writer_tx_channel: WriterTxChannel, + writer_handler: WriterHandler, + packets_consumed: usize, + packets_consumed_before_ack: usize, } #[derive(Clone)] @@ -69,14 +71,13 @@ impl Future for ReaderRxChannelFuture { } impl ReaderRxChannel { - pub(crate) fn new( - buffer: Arc>, - writer_tx_channel: WriterTxChannel, - ) -> Self { + pub(crate) fn new(buffer: Arc>, writer_handler: WriterHandler) -> Self { let future = ReaderRxChannelFuture { buffer }; Self { future, - writer_tx_channel, + writer_handler, + packets_consumed: 0, + packets_consumed_before_ack: 200, } } @@ -93,19 +94,23 @@ impl ReaderRxChannel { }; let num_packets_consumed = consume_res.get_num_packets_consumed(); let base_packet_num = consume_res.get_base_packet_number(); + self.packets_consumed += num_packets_consumed; // We need to send an ack. - if num_packets_consumed > 0 && base_packet_num != 0 { + if num_packets_consumed > 0 + && base_packet_num != 0 + && self.packets_consumed >= self.packets_consumed_before_ack + { if let Err(e) = self - .writer_tx_channel + .writer_handler .send_ack(base_packet_num, num_packets_consumed) - .await { eprintln!( "Failed to send ack packet after reads due to error: {:?}", e ); } + self.packets_consumed = 0; } Ok((consume_res.get_bytes_consumed(), addr)) @@ -195,8 +200,6 @@ impl ReaderTxChannel { let packet_src_conn_id = packet.header.source_connection_id; if !is_hello && !is_client_ack { // If not hello, we buffer in the bytes - // Could not buffer in packet... buffer is likely full. We will have to discard the - // packet. conn_buff.buffer_in_bytes(packet)?; } else { conn_buff.buffer_in_packet(packet)?; diff --git a/src/worker/writer.rs b/src/worker/writer.rs index 6305a8f..f89d212 100644 --- a/src/worker/writer.rs +++ b/src/worker/writer.rs @@ -1,17 +1,15 @@ -use std::{ - cmp::min, - collections::VecDeque, - future::Future, - net::SocketAddr, - sync::{Arc, Mutex}, - task::{Poll, Waker}, - time::Duration, -}; +use std::{cmp::min, collections::VecDeque, sync::Arc}; -use tokio::{net::UdpSocket, time::sleep}; +use tokio::{ + net::UdpSocket, + spawn, + sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, +}; +use crate::core::Extract; use crate::{ core::{ + error::BluefinError, header::{BluefinHeader, BluefinSecurityFields, PacketType}, packet::BluefinPacket, Serialisable, @@ -20,31 +18,184 @@ use crate::{ utils::common::BluefinResult, }; -/// Each writer queue holds a queue of `WriterQueueData` -enum WriterQueueData { - Payload(Vec), - Ack { - base_packet_num: u64, - num_packets_consumed: usize, - }, +#[derive(Clone, Copy)] +struct AckData { + base_packet_num: u64, + num_packets_consumed: usize, } -pub(crate) struct WriterQueue { - queue: VecDeque, - waker: Option, +#[derive(Clone)] +pub(crate) struct WriterHandler { + socket: Arc, + next_packet_num: u64, + data_sender: Option>>, + ack_sender: Option>, + src_conn_id: u32, + dst_conn_id: u32, } -impl WriterQueue { - pub(crate) fn new() -> Self { +impl WriterHandler { + pub(crate) fn new( + socket: Arc, + next_packet_num: u64, + src_conn_id: u32, + dst_conn_id: u32, + ) -> Self { Self { - queue: VecDeque::new(), - waker: None, + socket, + src_conn_id, + dst_conn_id, + next_packet_num, + data_sender: None, + ack_sender: None, + } + } + + pub(crate) fn start(&mut self) -> BluefinResult<()> { + let (data_s, data_r) = mpsc::unbounded_channel(); + let (ack_s, ack_r) = mpsc::unbounded_channel(); + self.data_sender = Some(data_s); + self.ack_sender = Some(ack_s); + + let next_packet_num = self.next_packet_num; + let src_conn_id = self.src_conn_id; + let dst_conn_id = self.dst_conn_id; + let socket = Arc::clone(&self.socket); + spawn(async move { + Self::read_data(data_r, next_packet_num, src_conn_id, dst_conn_id, socket).await; + }); + + let socket = Arc::clone(&self.socket); + spawn(async move { + Self::read_ack(ack_r, socket, src_conn_id, dst_conn_id).await; + }); + + Ok(()) + } + + #[inline] + pub(crate) fn send_data(&self, payload: &[u8]) -> BluefinResult<()> { + if self.data_sender.is_none() { + return Err(BluefinError::WriteError( + "Sender is not available. Cannot send.".to_string(), + )); + } + + if let Err(e) = self.data_sender.as_ref().unwrap().send(payload.to_vec()) { + return Err(BluefinError::WriteError(format!( + "Failed to send data due to error: {:?}", + e + ))); } + Ok(()) } #[inline] - pub(crate) fn consume_data( - &mut self, + pub(crate) fn send_ack( + &self, + base_packet_num: u64, + num_packets_consumed: usize, + ) -> BluefinResult<()> { + if self.ack_sender.is_none() { + return Err(BluefinError::WriteError( + "Ack sender is not available. Cannot send.".to_string(), + )); + } + + let data = AckData { + base_packet_num, + num_packets_consumed, + }; + + if let Err(e) = self.ack_sender.as_ref().unwrap().send(data) { + return Err(BluefinError::WriteError(format!( + "Failed to send ack due to error: {:?}", + e + ))); + } + Ok(()) + } + + #[inline] + async fn read_ack( + mut rx: UnboundedReceiver, + socket: Arc, + src_conn_id: u32, + dst_conn_id: u32, + ) { + let mut ack_queue = VecDeque::new(); + let mut b = vec![]; + let limit = 10; + loop { + let size = rx.recv_many(&mut b, limit).await; + for i in 0..size { + ack_queue.push_back(b[i]); + } + + if let Err(e) = socket.writable().await { + eprintln!("Cannot write to socket due to err: {:?}", e); + continue; + } + + if let Some(data) = Self::consume_acks(&mut ack_queue, src_conn_id, dst_conn_id) { + if let Err(e) = socket.try_send(&data) { + eprintln!( + "Encountered error {} while sending ack packet across wire", + e + ); + continue; + } + } + } + } + + #[inline] + async fn read_data( + mut rx: UnboundedReceiver>, + next_packet_num: u64, + src_conn_id: u32, + dst_conn_id: u32, + socket: Arc, + ) { + let mut data_queue = VecDeque::new(); + let limit = 10; + let mut next_packet_num = next_packet_num; + let mut b = Vec::with_capacity(limit); + loop { + b.clear(); + let size = rx.recv_many(&mut b, limit).await; + for i in 0..size { + // Extract is a small optimization quicker. We avoid a (potentially) + // costly clone by moving the bytes out of the vec and replacing it + // via a zeroed default value. + data_queue.push_back(b[i].extract()); + } + + if let Err(e) = socket.writable().await { + eprintln!("Cannot write to socket due to err: {:?}", e); + continue; + } + + if let Some(data) = Self::consume_data( + &mut data_queue, + &mut next_packet_num, + src_conn_id, + dst_conn_id, + ) { + if let Err(e) = socket.try_send(&data) { + eprintln!( + "Encountered error {} while sending data packet across wire", + e + ); + continue; + } + } + } + } + + #[inline] + fn consume_data( + queue: &mut VecDeque>, next_packet_num: &mut u64, src_conn_id: u32, dst_conn_id: u32, @@ -62,7 +213,7 @@ impl WriterQueue { security_fields, ); - while !self.queue.is_empty() && bytes_remaining > 20 { + while !queue.is_empty() && bytes_remaining > 20 { // We already have some bytes left over and it's more than we can afford. Take what // we can and end. if running_payload.len() >= bytes_remaining - 20 { @@ -85,8 +236,7 @@ impl WriterQueue { } if !running_payload.is_empty() { - self.queue - .push_front(WriterQueueData::Payload(running_payload.to_vec())); + queue.push_front(running_payload.to_vec()); } return Some(ans); } @@ -111,36 +261,30 @@ impl WriterQueue { } // We have room - let data = self.queue.pop_front().unwrap(); - match data { - WriterQueueData::Payload(p) => { - let potential_bytes_len = p.len(); - if potential_bytes_len + running_payload.len() > MAX_BLUEFIN_PAYLOAD_SIZE_BYTES - { - // We cannot simply fit both payloads into this packet. - running_payload.extend(p); - - // Try to take as much as we can - let max_bytes_to_take = min( - running_payload.len(), - min(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, bytes_remaining - 20), - ); - header.with_packet_number(*next_packet_num); - header.type_specific_payload = max_bytes_to_take as u16; - let packet = BluefinPacket::builder() - .header(header) - .payload(running_payload[..max_bytes_to_take].to_vec()) - .build(); - ans.extend(packet.serialise()); - *next_packet_num += 1; - bytes_remaining -= max_bytes_to_take + 20; - running_payload = running_payload[max_bytes_to_take..].to_vec(); - } else { - // We can fit both the payload and the left over bytes - running_payload.extend(p); - } - } - _ => unreachable!(), + let data = queue.pop_front().unwrap(); + let potential_bytes_len = data.len(); + if potential_bytes_len + running_payload.len() > MAX_BLUEFIN_PAYLOAD_SIZE_BYTES { + // We cannot simply fit both payloads into this packet. + running_payload.extend(data); + + // Try to take as much as we can + let max_bytes_to_take = min( + running_payload.len(), + min(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, bytes_remaining - 20), + ); + header.with_packet_number(*next_packet_num); + header.type_specific_payload = max_bytes_to_take as u16; + let packet = BluefinPacket::builder() + .header(header) + .payload(running_payload[..max_bytes_to_take].to_vec()) + .build(); + ans.extend(packet.serialise()); + *next_packet_num += 1; + bytes_remaining -= max_bytes_to_take + 20; + running_payload = running_payload[max_bytes_to_take..].to_vec(); + } else { + // We can fit both the payload and the left over bytes + running_payload.extend(data); } } @@ -164,8 +308,7 @@ impl WriterQueue { // Re-queue the remaining bytes if !running_payload.is_empty() { - self.queue - .push_front(WriterQueueData::Payload(running_payload)); + queue.push_front(running_payload); } if ans.is_empty() { @@ -174,7 +317,11 @@ impl WriterQueue { Some(ans) } - pub(crate) fn consume_acks(&mut self, src_conn_id: u32, dst_conn_id: u32) -> Option> { + fn consume_acks( + queue: &mut VecDeque, + src_conn_id: u32, + dst_conn_id: u32, + ) -> Option> { let mut bytes = vec![]; let security_fields = BluefinSecurityFields::new(false, 0x0); let mut header = BluefinHeader::new( @@ -184,23 +331,15 @@ impl WriterQueue { 0, security_fields, ); - while !self.queue.is_empty() && bytes.len() <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { - let data = self.queue.pop_front().unwrap(); - match data { - WriterQueueData::Ack { - base_packet_num: b, - num_packets_consumed: c, - } => { - header.packet_number = b; - header.type_specific_payload = c as u16; - if bytes.len() + 20 <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { - bytes.extend(header.serialise()); - } else { - self.queue.push_front(data); - break; - } - } - _ => unreachable!(), + while !queue.is_empty() && bytes.len() <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { + let data = queue.pop_front().unwrap(); + header.packet_number = data.base_packet_num; + header.type_specific_payload = data.num_packets_consumed as u16; + if bytes.len() + 20 <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { + bytes.extend(header.serialise()); + } else { + queue.push_front(data); + break; } } @@ -212,244 +351,47 @@ impl WriterQueue { } } -/// Queues write requests to be sent. Each connection can have one or more [WriterTxChannel]. -#[derive(Clone)] -pub(crate) struct WriterTxChannel { - data_queue: Arc>, - ack_queue: Arc>, - num_runs_without_sleep: u32, -} - -impl WriterTxChannel { - pub(crate) fn new( - data_queue: Arc>, - ack_queue: Arc>, - ) -> Self { - Self { - data_queue, - ack_queue, - num_runs_without_sleep: 0, - } - } - - /// ONLY for sending data - pub(crate) async fn send(&mut self, payload: &[u8]) -> BluefinResult { - let bytes = payload.len(); - let data = WriterQueueData::Payload(payload.to_vec()); - - { - let mut guard = self.data_queue.lock().unwrap(); - guard.queue.push_back(data); - - // Signal to Rx channel that we have new packets in the queue - if let Some(ref waker) = guard.waker { - waker.wake_by_ref(); - } - } - - self.num_runs_without_sleep += 1; - if self.num_runs_without_sleep >= 100 { - sleep(Duration::from_nanos(10)).await; - self.num_runs_without_sleep = 0; - } - - Ok(bytes) - } - - pub(crate) async fn send_ack( - &mut self, - base_packet_num: u64, - num_packets_consumed: usize, - ) -> BluefinResult<()> { - let data = WriterQueueData::Ack { - base_packet_num, - num_packets_consumed, - }; - - { - let mut guard = self.ack_queue.lock().unwrap(); - guard.queue.push_back(data); - - // Signal to Rx channel that we have new packets in the queue - if let Some(ref waker) = guard.waker { - waker.wake_by_ref(); - } - } - - Ok(()) - } -} - -#[derive(Clone)] -struct WriterRxChannelDataFuture { - data_queue: Arc>, -} - -#[derive(Clone)] -struct WriterRxChannelAckFuture { - ack_queue: Arc>, -} - -/// Consumes queued requests and sends them across the wire. For now, each connection -/// has one and only one [WriterRxChannel]. This channel must run two separate jobs: -/// [WriterRxChannel::run_data], which reads out of the data queue and sends bluefin -/// packets w/ payloads across the wire AND [WriterRxChannel::run_ack], which reads -/// acks out of the ack queue and sends bluefin ack packets across the wire. -#[derive(Clone)] -pub(crate) struct WriterRxChannel { - data_future: WriterRxChannelDataFuture, - ack_future: WriterRxChannelAckFuture, - dst_addr: SocketAddr, - next_packet_num: u64, - src_conn_id: u32, - dst_conn_id: u32, - socket: Arc, -} - -impl Future for WriterRxChannelDataFuture { - type Output = usize; - - fn poll( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - let mut guard = self.data_queue.lock().unwrap(); - let num_packets_to_send = guard.queue.len(); - if num_packets_to_send == 0 { - guard.waker = Some(cx.waker().clone()); - return Poll::Pending; - } - Poll::Ready(num_packets_to_send) - } -} - -impl Future for WriterRxChannelAckFuture { - type Output = usize; - - fn poll( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - let mut guard = self.ack_queue.lock().unwrap(); - let num_packets_to_send = guard.queue.len(); - if num_packets_to_send == 0 { - guard.waker = Some(cx.waker().clone()); - return Poll::Pending; - } - Poll::Ready(num_packets_to_send) - } -} - -impl WriterRxChannel { - pub(crate) fn new( - data_queue: Arc>, - ack_queue: Arc>, - socket: Arc, - dst_addr: SocketAddr, - next_packet_num: u64, - src_conn_id: u32, - dst_conn_id: u32, - ) -> Self { - let data_future = WriterRxChannelDataFuture { - data_queue: Arc::clone(&data_queue), - }; - let ack_future = WriterRxChannelAckFuture { - ack_queue: Arc::clone(&ack_queue), - }; - Self { - data_future, - ack_future, - dst_addr, - next_packet_num, - src_conn_id, - dst_conn_id, - socket: Arc::clone(&socket), - } - } - pub(crate) async fn run_ack(&mut self) { - loop { - let _ = self.ack_future.clone().await; - let mut guard = self.ack_future.ack_queue.lock().unwrap(); - let bytes = guard.consume_acks(self.src_conn_id, self.dst_conn_id); - match bytes { - None => continue, - Some(b) => { - if let Err(e) = self.socket.try_send_to(&b, self.dst_addr) { - eprintln!( - "Encountered error {} while sending ack packet across wire", - e - ); - continue; - } - } - } - guard.waker = None; - } - } - - pub(crate) async fn run_data(&mut self) { - loop { - let _ = self.data_future.clone().await; - let mut guard = self.data_future.data_queue.lock().unwrap(); - let bytes = guard.consume_data( - &mut self.next_packet_num, - self.src_conn_id, - self.dst_conn_id, - ); - match bytes { - None => continue, - Some(b) => { - if let Err(e) = self.socket.try_send_to(&b, self.dst_addr) { - eprintln!( - "Encountered error {} while sending data packet across wire", - e - ); - continue; - } - } - } - guard.waker = None; - } - } -} - #[cfg(kani)] mod verification_tests { - use crate::worker::writer::WriterQueue; + use crate::worker::writer::WriterHandler; + use std::collections::VecDeque; #[kani::proof] fn kani_writer_queue_consume_empty_data_behaves_as_expected() { - let mut writer_q = WriterQueue::new(); let mut next_packet_num = kani::any(); + let mut queue = VecDeque::new(); let prev = next_packet_num; - assert!(writer_q - .consume_data(&mut next_packet_num, kani::any(), kani::any()) - .is_none()); + assert!(WriterHandler::consume_data( + &mut queue, + &mut next_packet_num, + kani::any(), + kani::any() + ) + .is_none()); assert_eq!(next_packet_num, prev); } #[kani::proof] fn kani_writer_queue_consume_empty_ack_behaves_as_expected() { - let mut writer_q = WriterQueue::new(); - assert!(writer_q.consume_acks(kani::any(), kani::any()).is_none()); + let mut queue = VecDeque::new(); + assert!(WriterHandler::consume_acks(&mut queue, kani::any(), kani::any()).is_none()); } } #[cfg(test)] mod tests { use rstest::rstest; + use std::collections::VecDeque; + use crate::worker::writer::{AckData, WriterHandler}; use crate::{ core::{header::PacketType, packet::BluefinPacket}, net::{ MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM, MAX_BLUEFIN_PACKETS_IN_UDP_DATAGRAM, MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, }, - worker::writer::WriterQueue, }; - use super::WriterQueueData; - #[rstest] #[test] #[case(550)] @@ -461,15 +403,15 @@ mod tests { assert_ne!(expected_byte_size, 0); assert!(expected_byte_size <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); - let mut writer_q = WriterQueue::new(); + let mut queue = VecDeque::new(); for _ in 0..num_acks { - writer_q.queue.push_back(WriterQueueData::Ack { + queue.push_back(AckData { base_packet_num: 1, num_packets_consumed: 3, }); } - let consume_res = writer_q.consume_acks(0xbcd, 0x521); + let consume_res = WriterHandler::consume_acks(&mut queue, 0xbcd, 0x521); assert!(consume_res.is_some()); let consume = consume_res.unwrap(); @@ -492,7 +434,7 @@ mod tests { } // Because we are adding at most 1 datagram worth of acks, we get nothing more - assert!(writer_q.consume_acks(0x0, 0x0).is_none()); + assert!(WriterHandler::consume_acks(&mut queue, 0x0, 0x0).is_none()); } #[rstest] @@ -511,15 +453,15 @@ mod tests { assert!(expected_byte_size > MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); assert!(num_datagrams > 1 && num_datagrams <= 10); - let mut writer_q = WriterQueue::new(); + let mut queue = VecDeque::new(); for ix in 0..num_acks { - writer_q.queue.push_back(WriterQueueData::Ack { + queue.push_back(AckData { base_packet_num: ix as u64, num_packets_consumed: ix + 1, }); } - let consume_res = writer_q.consume_acks(0xbcd, 0x521); + let consume_res = WriterHandler::consume_acks(&mut queue, 0xbcd, 0x521); assert!(consume_res.is_some()); let consume = consume_res.unwrap(); @@ -540,13 +482,13 @@ mod tests { assert_eq!(p.header.type_specific_payload as usize, ix + 1); p_num = ix; } - assert!(p_num != 0); + assert_ne!(p_num, 0); let mut actual_num_acks = 0; actual_num_acks += packets.len(); let mut counter = 0; - let mut consume_res = writer_q.consume_acks(0x0, 0x0); + let mut consume_res = WriterHandler::consume_acks(&mut queue, 0x0, 0x0); while counter <= 10 && consume_res.is_some() { let consume = consume_res.unwrap(); let packets_res = BluefinPacket::from_bytes(&consume); @@ -563,7 +505,7 @@ mod tests { p_num += packets.len(); actual_num_acks += packets.len(); - consume_res = writer_q.consume_acks(0x0, 0x0); + consume_res = WriterHandler::consume_acks(&mut queue, 0x0, 0x0); counter += 1; } assert_eq!(num_acks, actual_num_acks); @@ -594,18 +536,17 @@ mod tests { let bytes_total = payload_size_total + (20 * num_packets_total); assert!(bytes_total <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); - let mut writer_q = WriterQueue::new(); + let mut queue = VecDeque::new(); for ix in 0..num_iterations { let data = vec![ix as u8; payload_size]; - writer_q - .queue - .push_back(WriterQueueData::Payload(data.to_vec())); + queue.push_back(data.to_vec()); } let mut next_packet_num = 0; let src_conn_id = 0x123; let dst_conn_id = 0xabc; - let consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + let consume_res = + WriterHandler::consume_data(&mut queue, &mut next_packet_num, src_conn_id, dst_conn_id); assert!(consume_res.is_some()); let consume = consume_res.unwrap(); @@ -642,9 +583,9 @@ mod tests { // Since we added less than the max amount of bytes we can stuff in a datagram, one consume consumes // all of the data. There should be nothing left. assert!(consume.len() <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); - assert!(writer_q - .consume_data(&mut next_packet_num, 0x123, 0x456) - .is_none()); + assert!( + WriterHandler::consume_data(&mut queue, &mut next_packet_num, 0x123, 0x456).is_none() + ); } #[rstest] @@ -673,19 +614,18 @@ mod tests { assert!(num_datagrams >= 1 && num_datagrams <= 10); let mut expected_data = vec![]; - let mut writer_q = WriterQueue::new(); + let mut queue = VecDeque::new(); for ix in 0..num_iterations { let data = vec![ix as u8; payload_size]; expected_data.extend_from_slice(&data); - writer_q - .queue - .push_back(WriterQueueData::Payload(data.to_vec())); + queue.push_back(data.to_vec()); } let mut next_packet_num = 0; let src_conn_id = 0x123; let dst_conn_id = 0xabc; - let consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + let consume_res = + WriterHandler::consume_data(&mut queue, &mut next_packet_num, src_conn_id, dst_conn_id); assert!(consume_res.is_some()); let consume = consume_res.unwrap(); @@ -710,7 +650,8 @@ mod tests { // Fetch the rest of the data. Our tests won't go beyond 10 datagrams worth of data // so we assert the count here just in case. let mut counter = 0; - let mut consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + let mut consume_res = + WriterHandler::consume_data(&mut queue, &mut next_packet_num, src_conn_id, dst_conn_id); while counter < 10 && consume_res.is_some() { let consume = consume_res.as_ref().unwrap(); assert_ne!(consume.len(), 0); @@ -729,7 +670,12 @@ mod tests { } counter += 1; - consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + consume_res = WriterHandler::consume_data( + &mut queue, + &mut next_packet_num, + src_conn_id, + dst_conn_id, + ); } assert_eq!(num_datagrams, 1 + counter); assert_eq!(expected_data, actual_data); diff --git a/tests/basic/basic_handshake.rs b/tests/basic/basic_handshake.rs index 6aec79f..1e511dc 100644 --- a/tests/basic/basic_handshake.rs +++ b/tests/basic/basic_handshake.rs @@ -1,14 +1,14 @@ +use bluefin::net::{client::BluefinClient, server::BluefinServer}; use core::str; +use local_ip_address::list_afinet_netifas; +use rstest::{fixture, rstest}; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddrV4}, time::Duration, }; - -use bluefin::net::{client::BluefinClient, server::BluefinServer}; -use local_ip_address::list_afinet_netifas; -use rstest::{fixture, rstest}; use tokio::{ + spawn, task::JoinSet, time::{sleep, timeout}, }; @@ -136,21 +136,12 @@ async fn basic_server_client_connection_send_recv( } // Now flip around and let the server send 5 bytes - let size = timeout(Duration::from_secs(1), conn.send(&[5, 4, 3, 2, 1])) - .await - .expect("Server timed out while trying to send five bytes") - .expect("Server encountered error while trying to send bytes"); - assert_eq!(size, 5); + let size = conn.send(&[5, 4, 3, 2, 1]); + assert!(size.is_ok_and(|s| s == 5), "Failed to send bytes"); // Send another 10 bytes - let size = timeout( - Duration::from_secs(1), - conn.send(&[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]), - ) - .await - .expect("Server timed out while trying to send ten bytes") - .expect("Server encountered error while trying to send bytes"); - assert_eq!(size, 10); + let size = conn.send(&[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]); + assert!(size.is_ok_and(|s| s == 10), "Failed to send bytes"); }); let loopback_cloned = loopback_ip_addr.clone(); @@ -171,77 +162,53 @@ async fn basic_server_client_connection_send_recv( // Send 7 bytes let bytes = [1, 2, 3, 4, 5, 6, 7]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #1") - .expect("Client encountered error while sending"); - assert_eq!(size, 7); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 7), "Failed to send bytes"); total_num_bytes_sent += 7; // Send 50 bytes let bytes = [10; 50]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #2") - .expect("Client encountered error while sending"); - assert_eq!(size, 50); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 50), "Failed to send bytes"); total_num_bytes_sent += 50; // Send 3 bytes let bytes = [8, 8, 8]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #3") - .expect("Client encountered error while sending"); - assert_eq!(size, 3); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 3), "Failed to send bytes"); total_num_bytes_sent += 3; // Send 40 bytes let bytes = [99; 40]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #4") - .expect("Client encountered error while sending"); - assert_eq!(size, 40); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 40), "Failed to send bytes"); total_num_bytes_sent += 40; // Send 500 bytes let bytes = [27; 500]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #5") - .expect("Client encountered error while sending"); - assert_eq!(size, 500); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 500), "Failed to send bytes"); total_num_bytes_sent += 500; // Send 399 bytes let bytes = [18; 399]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #6") - .expect("Client encountered error while sending"); - assert_eq!(size, 399); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 399), "Failed to send bytes"); total_num_bytes_sent += 399; // Send 1 byte let bytes = [19]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect("Client timed out while sending batch #7") - .expect("Client encountered error while sending"); - assert_eq!(size, 1); + let size = conn.send(&bytes); + assert!(size.is_ok_and(|s| s == 1), "Failed to send bytes"); total_num_bytes_sent += 1; // We will send 2000 bytes now in batches of 250 bytes for round_num in 0..8 { let bytes = [round_num; BATCH_SIZE]; - let size = timeout(Duration::from_secs(3), conn.send(&bytes)) - .await - .expect(&format!( - "Client timed out while sending batch #{}", - 8 + round_num - )) - .expect("Client encountered error while sending"); + let size = conn.send(&bytes).expect(&format!( + "Client timed out while sending batch #{}", + 8 + round_num + )); assert_eq!(size, BATCH_SIZE); total_num_bytes_sent += BATCH_SIZE; } @@ -285,106 +252,102 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I let client_ports: [u16; NUM_CONNECTIONS] = [1420, 1421, 1422]; let loopback_cloned = loopback_ip_addr.clone(); let data = Arc::new(generate_connection_date(NUM_CONNECTIONS)); + let data_cloned = Arc::clone(&data); - for conn_num in 0..NUM_CONNECTIONS { - let mut s = server.clone(); - let data_cloned = Arc::clone(&data); - join_set.spawn(async move { - let mut conn = timeout(Duration::from_secs(10), s.accept()) - .await - .expect(&format!( - "Server #{} timed out waiting to accept connection from client", - conn_num - )) - .expect("Failed to create bluefin connection"); - - // The test will first send a key of five bytes. - let mut key_buf: [u8; 5] = [0; 5]; - let size = timeout(Duration::from_secs(1), conn.recv(&mut key_buf, 5)) - .await - .expect("Server timed out while waiting for key") - .expect("Server encountered error while receiving"); - assert_eq!(size, 5); - let key = match str::from_utf8(&key_buf) { - Ok(s) => s, - Err(_) => panic!("Could not retrieve key from client"), - }; - - let expected_data = data_cloned.get(key).expect("Could not fetch expected data"); - let mut stitched_bytes: Vec = Vec::new(); - let mut buf = [0u8; 1500]; - loop { - let size = timeout(Duration::from_secs(1), conn.recv(&mut buf, 1500)) + join_set.spawn(async move { + let mut conn_num = 0; + while let Ok(mut conn) = timeout(Duration::from_secs(10), server.accept()) + .await + .expect(&format!( + "Server #{} timed out waiting to accept connection from client", + conn_num + )) + { + let data_cloned = Arc::clone(&data_cloned); + spawn(async move { + // The test will first send a key of five bytes. + let mut key_buf: [u8; 5] = [0; 5]; + let size = timeout(Duration::from_secs(1), conn.recv(&mut key_buf, 5)) .await - .expect("Server timed out while waiting for data") - .expect("Server encountered error while receiving data"); - assert_ne!(size, 0); - assert!(size <= 1500); - stitched_bytes.extend_from_slice(&buf[..size]); - - if stitched_bytes.len() == MAX_BYTES_SENT_PER_CONNECTION { - break; + .expect("Server timed out while waiting for key") + .expect("Server encountered error while receiving"); + assert_eq!(size, 5); + let key = match str::from_utf8(&key_buf) { + Ok(s) => s, + Err(_) => panic!("Could not retrieve key from client"), + }; + + let expected_data = data_cloned.get(key).expect("Could not fetch expected data"); + let mut stitched_bytes: Vec = Vec::new(); + let mut buf = [0u8; 1500]; + loop { + let size = timeout(Duration::from_secs(1), conn.recv(&mut buf, 1500)) + .await + .expect("Server timed out while waiting for data") + .expect("Server encountered error while receiving data"); + assert_ne!(size, 0); + assert!(size <= 1500); + stitched_bytes.extend_from_slice(&buf[..size]); + + if stitched_bytes.len() == MAX_BYTES_SENT_PER_CONNECTION { + break; + } } + assert_eq!(stitched_bytes, *expected_data); + }); + conn_num += 1; + + if conn_num >= NUM_CONNECTIONS { + break; } - assert_eq!(stitched_bytes, *expected_data); - }); - } + } + }); + for conn_num in 0..NUM_CONNECTIONS { // Random amount of time to sleep let data_cloned = Arc::clone(&data); - join_set.spawn(async move { - let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( + let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( + loopback_cloned, + client_ports[conn_num], + ))); + + if let Ok(mut conn) = client + .connect(std::net::SocketAddr::V4(SocketAddrV4::new( loopback_cloned, - client_ports[conn_num], - ))); - - let mut conn = client - .connect(std::net::SocketAddr::V4(SocketAddrV4::new( - loopback_cloned, - 1419, - ))) - .await - .expect(&format!( - "Client #{} timed out waiting to connect to server", - conn_num - )); - - // Tell the server who we are by sending the key. Key is five bytes. - let key = format!("key_{}", conn_num); - let size = timeout(Duration::from_secs(1), conn.send(key.as_bytes())) - .await - .expect("Client timed out after sending key") - .expect("Client encountered error while sending"); - assert_eq!(size, 5); - - sleep(Duration::from_millis(10)).await; - - // Now begin sending the actual data in batches of 32 bytes - let mut total_bytes_sent = 5; - let data_to_send = data_cloned.get(&key).unwrap(); - let max_num_iterations = data_to_send.len() / 32; - let mut start_ix = 0; - let mut num_iterations = 0; - while num_iterations < max_num_iterations { - let size = timeout( - Duration::from_secs(1), - conn.send(&data_to_send[start_ix..start_ix + 32]), - ) - .await - .expect("Client timed out after sending data") - .expect("Client encountered error while sending"); - assert_eq!(size, 32); - start_ix += 32; - total_bytes_sent += size; - num_iterations += 1; - } + 1419, + ))) + .await + { + join_set.spawn(async move { + // Tell the server who we are by sending the key. Key is five bytes. + let key = format!("key_{}", conn_num); + let size = conn.send(key.as_bytes()); + assert!(size.is_ok_and(|s| s == 5), "Failed to send bytes"); + + sleep(Duration::from_millis(10)).await; + + // Now begin sending the actual data in batches of 32 bytes + let mut total_bytes_sent = 5; + let data_to_send = data_cloned.get(&key).unwrap(); + let max_num_iterations = data_to_send.len() / 32; + let mut start_ix = 0; + let mut num_iterations = 0; + while num_iterations < max_num_iterations { + let size = conn.send(&data_to_send[start_ix..start_ix + 32]); + assert!(size.is_ok_and(|s| s == 32), "Failed to send bytes"); + start_ix += 32; + total_bytes_sent += 32; + num_iterations += 1; + } - assert_eq!( - total_bytes_sent, - MAX_BYTES_SENT_PER_CONNECTION + 5, - "Did not send the expected number of bytes" - ); - }); + assert_eq!( + total_bytes_sent, + MAX_BYTES_SENT_PER_CONNECTION + 5, + "Did not send the expected number of bytes" + ); + }); + sleep(Duration::from_millis(5)).await; + } } join_set.join_all().await;