diff --git a/.github/workflows/bluefin.yml b/.github/workflows/bluefin.yml index 77c1bfc..974da28 100644 --- a/.github/workflows/bluefin.yml +++ b/.github/workflows/bluefin.yml @@ -28,6 +28,7 @@ jobs: run: cargo build --verbose - name: Run tests run: cargo test --verbose + coverage: runs-on: ubuntu-latest env: @@ -39,10 +40,21 @@ jobs: - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: Generate code coverage - run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info + run: cargo llvm-cov --all-features --workspace --lcov --ignore-filename-regex "error.rs|*/bin/*.rs" --output-path lcov.info - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos files: lcov.info fail_ci_if_error: false + + kani: + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - name: Checkout bluefin + uses: actions/checkout@v4 + - name: Run Kani + uses: model-checking/kani-github-action@v1.1 + \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 9b493c7..6c8f9b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,11 @@ etherparse = "0.15.0" local-ip-address = "0.6.3" rand = "0.8.5" rstest = "0.23.0" -thiserror = "1.0.39" -tokio = { version = "1.41.1", features = ["full"] } +thiserror = "2.0.3" +tokio = { version = "1.41.1", features = ["full", "tracing"] } +console-subscriber = "0.4.1" +libc = "0.2.164" +sysctl = "0.6.0" [dev-dependencies] local-ip-address = "0.6.3" @@ -31,6 +34,9 @@ path = "src/bin/client.rs" name = "server" path = "src/bin/server.rs" +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage,coverage_nightly)', 'cfg(kani)'] } + [profile.release] opt-level = 3 codegen-units = 1 diff --git a/src/bin/client.rs b/src/bin/client.rs index 90a1abe..6d388e9 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -1,3 +1,4 @@ +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] use 
std::{ net::{Ipv4Addr, SocketAddrV4}, time::Duration, @@ -8,9 +9,11 @@ use bluefin::{ }; use tokio::{spawn, time::sleep}; +#[cfg_attr(coverage_nightly, coverage(off))] #[tokio::main] async fn main() -> BluefinResult<()> { - let ports = [1320, 1322]; + // console_subscriber::init(); + let ports = [1320, 1322, 1323, 1324, 1325]; let mut tasks = vec![]; for ix in 0..2 { // sleep(Duration::from_secs(3)).await; @@ -40,21 +43,23 @@ async fn main() -> BluefinResult<()> { total_bytes += size; println!("Sent {} bytes", size); - sleep(Duration::from_secs(2)).await; + sleep(Duration::from_secs(1)).await; size = conn.send(&[14, 14, 14, 14, 14, 14]).await?; total_bytes += size; println!("Sent {} bytes", size); - for ix in 0..200000 { - let my_array: [u8; 32] = rand::random(); + for ix in 0..5000000 { + // let my_array: [u8; 32] = rand::random(); + let my_array = [0u8; 1500]; size = conn.send(&my_array).await?; total_bytes += size; - if ix % 1250 == 0 { - sleep(Duration::from_millis(10)).await; + if ix % 3000 == 0 { + sleep(Duration::from_millis(3)).await; } } println!("Sent {} bytes", total_bytes); + sleep(Duration::from_secs(2)).await; Ok::<(), BluefinError>(()) }); diff --git a/src/bin/server.rs b/src/bin/server.rs index 510d1e7..08722ea 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,37 +1,48 @@ +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] use std::{ + cmp::{max, min}, net::{Ipv4Addr, SocketAddrV4}, - time::Duration, + time::{Duration, Instant}, }; use bluefin::{net::server::BluefinServer, utils::common::BluefinResult}; use tokio::{spawn, time::sleep}; +#[cfg_attr(coverage_nightly, coverage(off))] #[tokio::main] async fn main() -> BluefinResult<()> { + // console_subscriber::init(); let mut server = BluefinServer::new(std::net::SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::new(127, 0, 0, 1), 1318, ))); + server.set_num_reader_workers(50)?; server.bind().await?; - const MAX_NUM_CONNECTIONS: usize = 5; - for _ in 0..MAX_NUM_CONNECTIONS { + const 
MAX_NUM_CONNECTIONS: usize = 3; + for conn_num in 0..MAX_NUM_CONNECTIONS { let mut s = server.clone(); let _ = spawn(async move { - let mut total_bytes = 0; + let _num = conn_num; loop { - println!(); let _conn = s.accept().await; match _conn { Ok(mut conn) => { spawn(async move { + let mut total_bytes = 0; + let mut recv_bytes = [0u8; 500000]; + let mut min_bytes = usize::MAX; + let mut max_bytes = 0; + let mut iteration = 1; + let now = Instant::now(); loop { - let mut recv_bytes = [0u8; 1024]; - let size = conn.recv(&mut recv_bytes, 1024).await.unwrap(); + // eprintln!("Waiting..."); + let size = conn.recv(&mut recv_bytes, 500000).await.unwrap(); total_bytes += size; - - println!("total bytes: {}", total_bytes); + min_bytes = min(size, min_bytes); + max_bytes = max(size, max_bytes); + // eprintln!("read {} bytes --- total bytes: {}", size, total_bytes); /* println!( @@ -42,6 +53,22 @@ async fn main() -> BluefinResult<()> { total_bytes ); */ + if total_bytes >= 100000 { + let elapsed = now.elapsed().as_secs(); + let through_put = u64::try_from(total_bytes).unwrap() / elapsed; + let avg_recv_bytes: f64 = total_bytes as f64 / iteration as f64; + eprintln!( + "{} {:.1} kb/s or {:.1} mb/s (read {:.1} kb/iteration, min: {:.1} kb, max: {:.1} kb)", + _num, + through_put as f64 / 1e3, + through_put as f64 / 1e6, + avg_recv_bytes / 1e3, + min_bytes as f64 / 1e3, + max_bytes as f64 / 1e3 + ); + // break; + } + iteration += 1; } }); } diff --git a/src/core/error.rs b/src/core/error.rs index f5cfcaf..0da3fb0 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -1,6 +1,6 @@ use thiserror::Error; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] pub enum BluefinError { #[error("Unable to serialise data")] SerialiseError, @@ -11,8 +11,8 @@ pub enum BluefinError { #[error("Connection buffer does not exist")] BufferDoesNotExist, - #[error("Current buffer is full.")] - BufferFullError, + #[error("Current buffer is full: `{0}`")] + BufferFullError(String), 
#[error("Current buffer is empty.")] BufferEmptyError, diff --git a/src/core/packet.rs b/src/core/packet.rs index 7128fcf..2bcf941 100644 --- a/src/core/packet.rs +++ b/src/core/packet.rs @@ -1,6 +1,6 @@ -use crate::core::header::BluefinHeader; +use crate::{core::header::BluefinHeader, utils::common::BluefinResult}; -use super::{error::BluefinError, Serialisable}; +use super::{error::BluefinError, header::PacketType, Serialisable}; #[derive(Clone, Debug)] pub struct BluefinPacket { @@ -66,6 +66,55 @@ impl BluefinPacket { // Header is always 20 bytes self.payload.len() + 20 } + + /// Converts an array of bytes into a vector of bluefin packets. The array of bytes must be + /// a valid stream of bluefin packet bytes. Otherwise, an error is returned. + #[inline] + pub fn from_bytes(bytes: &[u8]) -> BluefinResult> { + if bytes.len() < 20 { + return Err(BluefinError::ReadError( + "Array must be at least 20 bytes to contain at least one bluefin packet" + .to_string(), + )); + } + let mut packets = vec![]; + let mut cursor = 0; + while cursor < bytes.len() && cursor + 20 <= bytes.len() { + let header = BluefinHeader::deserialise(&bytes[cursor..cursor + 20])?; + match header.type_field { + PacketType::Ack + | PacketType::UnencryptedClientHello + | PacketType::UnencryptedServerHello + | PacketType::ClientAck => { + // Acks + handshake packets contain no payload (for now) + let packet = BluefinPacket::builder().header(header).build(); + packets.push(packet); + cursor += 20; + } + _ => { + // This is some data field + let payload_len = header.type_specific_payload as usize; + if cursor + 20 >= bytes.len() || cursor + 19 + payload_len >= bytes.len() { + return Err(BluefinError::ReadError( + "Cannot read all bytes specified by header".to_string(), + )); + } + let payload = &bytes[cursor + 20..cursor + 20 + payload_len]; + let packet = BluefinPacket::builder() + .header(header) + .payload(payload.to_vec()) + .build(); + packets.push(packet); + cursor = cursor + 20 + payload_len; 
+ } + }; + } + + if cursor != bytes.len() { + return Err(BluefinError::ReadError("Was not able to read all bytes into bluefin packets. Likely indicates corrupted UDP datagram.".to_string())); + } + Ok(packets) + } } impl BluefinPacketBuilder { @@ -89,3 +138,139 @@ impl BluefinPacketBuilder { } } } + +#[cfg(test)] +mod tests { + use crate::core::{ + error::BluefinError, + header::{BluefinHeader, BluefinSecurityFields, PacketType}, + Serialisable, + }; + + use super::BluefinPacket; + + #[test] + fn cannot_deserialise_invalid_bytes_into_bluefin_packets() { + let security_fields = BluefinSecurityFields::new(false, 0x0); + + let mut header = BluefinHeader::new(0x0, 0x0, PacketType::Ack, 13, security_fields); + let payload: [u8; 32] = rand::random(); + header.type_field = PacketType::UnencryptedData; + header.type_specific_payload = 32; + header.version = 13; + let mut packet = BluefinPacket::builder() + .header(header.clone()) + .payload(payload.to_vec()) + .build(); + assert_eq!(packet.len(), 52); + assert!(BluefinPacket::from_bytes(&packet.serialise()).is_ok()); + + // Incorrectly specify the length to be 33 instead of 32 + packet.header.type_specific_payload = (payload.len() + 1) as u16; + assert!( + BluefinPacket::from_bytes(&packet.serialise()).is_err_and(|e| e + == BluefinError::ReadError( + "Cannot read all bytes specified by header".to_string() + )) + ); + + // Now test again but specify a payload length under the actual payload len + packet.header.type_specific_payload = (payload.len() - 1) as u16; + assert!( + BluefinPacket::from_bytes(&packet.serialise()).is_err_and(|e| e + == BluefinError::ReadError( + "Was not able to read all bytes into bluefin packets. 
Likely indicates corrupted UDP datagram.".to_string() + )) + ); + } + + #[test] + fn able_to_deserialise_bytes_into_multiple_bluefin_packets_correctly() { + // Build 6 packets + let mut packets = vec![]; + let security_fields = BluefinSecurityFields::new(false, 0x0); + + // Push in an ack + let mut header = BluefinHeader::new(0x0, 0x0, PacketType::Ack, 13, security_fields); + let mut packet = BluefinPacket::builder().header(header.clone()).build(); + packets.push(packet); + + // Push in data payload with 32 bytes + let payload: [u8; 32] = rand::random(); + header.type_field = PacketType::UnencryptedData; + header.type_specific_payload = 32; + header.version = 13; + packet = BluefinPacket::builder() + .header(header.clone()) + .payload(payload.to_vec()) + .build(); + packets.push(packet); + + // Push in data payload with 20 bytes + let payload: [u8; 20] = rand::random(); + header.type_field = PacketType::UnencryptedData; + header.type_specific_payload = 20; + header.destination_connection_id = 0x123; + header.version = 15; + packet = BluefinPacket::builder() + .header(header.clone()) + .payload(payload.to_vec()) + .build(); + packets.push(packet); + + // Push in an ack + header.type_field = PacketType::Ack; + packet = BluefinPacket::builder().header(header.clone()).build(); + header.version = 0; + packets.push(packet); + + // Push in an client hello + header.type_field = PacketType::UnencryptedClientHello; + packet = BluefinPacket::builder().header(header.clone()).build(); + header.version = 5; + packets.push(packet); + + // Push in data payload with 15 bytes + let payload: [u8; 15] = rand::random(); + header.type_field = PacketType::UnencryptedData; + header.destination_connection_id = 0x0; + header.source_connection_id = 0xabc; + header.type_specific_payload = 15; + packet = BluefinPacket::builder() + .header(header.clone()) + .payload(payload.to_vec()) + .build(); + packets.push(packet); + + // Serialise packets and place into array + let mut bytes = vec![]; + 
for p in &packets { + bytes.extend_from_slice(&p.serialise()); + } + + // Total bytes should be the sum of the payloads plus all of the headers + assert_eq!(bytes.len(), 32 + 20 + 15 + (6 * 20)); + + // We were able to correctly restore the packets + let rebuilt_packets_res = BluefinPacket::from_bytes(&bytes); + assert!(rebuilt_packets_res.is_ok()); + + let rebuild_packets = rebuilt_packets_res.unwrap(); + assert_eq!(rebuild_packets.len(), packets.len()); + + for i in 0..packets.len() { + let expected = &packets[i]; + let actual = &rebuild_packets[i]; + assert_eq!( + expected.header.source_connection_id, + actual.header.source_connection_id + ); + assert_eq!( + expected.header.destination_connection_id, + actual.header.destination_connection_id + ); + assert_eq!(expected.header.version, actual.header.version); + assert_eq!(expected.payload, actual.payload); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index c3cda43..da98698 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,3 +2,4 @@ pub mod core; pub mod net; pub mod utils; pub mod worker; +extern crate libc; diff --git a/src/main.rs b/src/main.rs index 7f755fb..fff2383 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,2 +1,4 @@ +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] +#[cfg_attr(coverage_nightly, coverage(off))] #[tokio::main] async fn main() {} diff --git a/src/net/ack_handler.rs b/src/net/ack_handler.rs new file mode 100644 index 0000000..b096d01 --- /dev/null +++ b/src/net/ack_handler.rs @@ -0,0 +1,106 @@ +use std::{ + future::Future, + sync::{Arc, Mutex}, + task::{Poll, Waker}, + time::Duration, +}; + +use tokio::{sync::RwLock, time::sleep}; + +use crate::{ + core::{error::BluefinError, packet::BluefinPacket}, + utils::{ + common::BluefinResult, + window::{SlidingWindow, SlidingWindowConsumeResult}, + }, +}; + +#[derive(Clone)] +pub(crate) struct AckBuffer { + received_acks: SlidingWindow, + waker: Option, +} + +impl AckBuffer { + pub(crate) fn new(smallest_expected_recv_packet_num: u64) -> 
Self { + Self { + received_acks: SlidingWindow::new(smallest_expected_recv_packet_num), + waker: None, + } + } + + /// Buffers in the received ack + pub(crate) fn buffer_in_ack_packet(&mut self, packet: BluefinPacket) -> BluefinResult<()> { + let num_packets_to_ack = packet.header.type_specific_payload; + let base_packet_num = packet.header.packet_number; + for ix in 0..num_packets_to_ack { + self.received_acks + .insert_packet_number(base_packet_num + ix as u64)?; + } + Ok(()) + } + + #[inline] + fn consume(&mut self) -> Option { + self.received_acks.consume() + } + + #[inline] + pub(crate) fn wake(&mut self) -> BluefinResult<()> { + if let Some(ref waker) = self.waker { + waker.wake_by_ref(); + return Ok(()); + } + Err(BluefinError::NoSuchWakerError) + } +} + +#[derive(Clone)] +struct AckConsumerFuture { + ack_buff: Arc>, +} + +impl Future for AckConsumerFuture { + type Output = SlidingWindowConsumeResult; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut guard = self.ack_buff.lock().unwrap(); + if let Some(res) = guard.consume() { + return Poll::Ready(res); + } + guard.waker = Some(cx.waker().clone()); + Poll::Pending + } +} + +#[derive(Clone)] +pub(crate) struct AckConsumer { + future: AckConsumerFuture, + largest_recv_acked_packet_num: Arc>, +} + +impl AckConsumer { + pub(crate) fn new( + ack_buff: Arc>, + largest_recv_acked_packet_num: Arc>, + ) -> Self { + let future = AckConsumerFuture { ack_buff }; + Self { + future, + largest_recv_acked_packet_num, + } + } + + pub(crate) async fn run(&self) { + loop { + let res = self.future.clone().await; + + { + let mut guard = self.largest_recv_acked_packet_num.write().await; + *guard = res.largest_packet_number; + } + + sleep(Duration::from_micros(5)).await; + } + } +} diff --git a/src/net/client.rs b/src/net/client.rs index 33e7ae9..7b3d479 100644 --- a/src/net/client.rs +++ b/src/net/client.rs @@ -15,15 +15,19 @@ use crate::{ utils::common::BluefinResult, }; -use 
super::connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}; +use super::{ + connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}, + AckBuffer, ConnectionManagedBuffers, +}; -const NUM_TX_WORKERS_FOR_CLIENT: u8 = 5; +const NUM_TX_WORKERS_FOR_CLIENT_DEFAULT: u16 = 5; pub struct BluefinClient { socket: Option>, src_addr: SocketAddr, dst_addr: Option, conn_manager: Arc>, + num_reader_workers: u16, } impl BluefinClient { @@ -33,17 +37,17 @@ impl BluefinClient { dst_addr: None, conn_manager: Arc::new(RwLock::new(ConnectionManager::new())), src_addr, + num_reader_workers: NUM_TX_WORKERS_FOR_CLIENT_DEFAULT, } } pub async fn connect(&mut self, dst_addr: SocketAddr) -> BluefinResult { let socket = Arc::new(UdpSocket::bind(self.src_addr).await?); - // socket.connect(dst_addr).await?; self.socket = Some(Arc::clone(&socket)); self.dst_addr = Some(dst_addr); build_and_start_tx( - NUM_TX_WORKERS_FOR_CLIENT, + self.num_reader_workers, Arc::clone(self.socket.as_ref().unwrap()), Arc::clone(&self.conn_manager), Arc::new(Mutex::new(Vec::new())), @@ -51,10 +55,16 @@ impl BluefinClient { ); let src_conn_id: u32 = rand::thread_rng().gen(); + let packet_number: u64 = rand::thread_rng().gen(); let conn_buffer = Arc::new(Mutex::new(ConnectionBuffer::new( src_conn_id, BluefinHost::Client, ))); + let ack_buff = Arc::new(Mutex::new(AckBuffer::new(packet_number + 2))); + let conn_mgrs_buffs = ConnectionManagedBuffers { + conn_buff: Arc::clone(&conn_buffer), + ack_buff: Arc::clone(&ack_buff), + }; let handshake_buf = HandshakeConnectionBuffer::new(Arc::clone(&conn_buffer)); // Register the connection @@ -62,10 +72,9 @@ impl BluefinClient { self.conn_manager .write() .await - .insert(&hello_key, Arc::clone(&conn_buffer))?; + .insert(&hello_key, conn_mgrs_buffs.clone())?; // send the client hello - let packet_number: u64 = rand::thread_rng().gen(); let packet = build_empty_encrypted_packet( src_conn_id, 0x0, @@ -92,10 +101,11 @@ impl BluefinClient { } // delete the 
old hello entry and insert the new connection entry - let mut guard = self.conn_manager.write().await; - let _ = guard.remove(&hello_key); - let _ = guard.insert(&key, Arc::clone(&conn_buffer)); - drop(guard); + { + let mut guard = self.conn_manager.write().await; + let _ = guard.remove(&hello_key); + let _ = guard.insert(&key, conn_mgrs_buffs); + } // send the client ack let packet = build_empty_encrypted_packet( @@ -115,6 +125,7 @@ impl BluefinClient { dst_conn_id, packet_number + 2, Arc::clone(&conn_buffer), + Arc::clone(&ack_buff), Arc::clone(self.socket.as_ref().unwrap()), self.dst_addr.unwrap(), )) diff --git a/src/net/connection.rs b/src/net/connection.rs index 498fb5e..20430ab 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -1,7 +1,6 @@ use std::{ collections::HashMap, future::Future, - io::Write, net::SocketAddr, sync::{Arc, Mutex}, task::{Poll, Waker}, @@ -11,12 +10,7 @@ use std::{ use tokio::{net::UdpSocket, time::timeout}; use crate::{ - core::{ - context::BluefinHost, - error::BluefinError, - header::{BluefinHeader, BluefinSecurityFields, PacketType}, - packet::BluefinPacket, - }, + core::{context::BluefinHost, error::BluefinError, packet::BluefinPacket}, utils::common::BluefinResult, worker::{reader::ReaderRxChannel, writer::WriterTxChannel}, }; @@ -24,7 +18,7 @@ use crate::{ use super::{ build_and_start_writer_rx_channel, ordered_bytes::{ConsumeResult, OrderedBytes}, - WriterQueue, + AckBuffer, ConnectionManagedBuffers, WriterQueue, }; pub const MAX_BUFFER_SIZE: usize = 2000; @@ -135,28 +129,29 @@ impl ConnectionBuffer { } #[inline] - pub(crate) fn buffer_in_bytes(&mut self, packet: &BluefinPacket) -> BluefinResult<()> { + pub(crate) fn buffer_in_bytes(&mut self, packet: BluefinPacket) -> BluefinResult<()> { self.ordered_bytes.buffer_in_packet(packet) } #[inline] - pub(crate) fn buffer_in_packet(&mut self, packet: &BluefinPacket) -> BluefinResult<()> { + pub(crate) fn buffer_in_packet(&mut self, packet: BluefinPacket) -> 
BluefinResult<()> { if self.packet.is_some() { - return Err(BluefinError::BufferFullError); + return Err(BluefinError::BufferFullError( + "Buffer already contains a packet. Could not buffer another packet.".to_string(), + )); } - self.packet = Some(packet.clone()); + let packet_num = packet.header.packet_number; + self.packet = Some(packet); // We always set the start packet numbers once. For servers, we set in advance // that the start number is the first client hello we get + 2. (There is an ack) // For the client, we set it to + 1 (the next message we get should be data) if !self.set_start_packet_number { if self.host_type == BluefinHost::PackLeader { - self.ordered_bytes - .set_start_packet_number(packet.header.packet_number + 2); + self.ordered_bytes.set_start_packet_number(packet_num + 2); } else if self.host_type == BluefinHost::Client { - self.ordered_bytes - .set_start_packet_number(packet.header.packet_number + 1); + self.ordered_bytes.set_start_packet_number(packet_num + 1); } self.set_start_packet_number = true; } @@ -164,15 +159,11 @@ impl ConnectionBuffer { Ok(()) } - #[inline] - pub(crate) fn buffer_in_ack_packet(&mut self, _packet: &BluefinPacket) -> BluefinResult<()> { - self.ordered_bytes.buffer_in_ack_packet(_packet) - } - #[inline] pub(crate) fn consume( &mut self, bytes_to_read: usize, + buf: &mut [u8], ) -> BluefinResult<(ConsumeResult, SocketAddr)> { if self.addr.is_none() { return Err(BluefinError::Unexpected( @@ -180,14 +171,17 @@ impl ConnectionBuffer { )); } - let consume_res = self.ordered_bytes.consume(bytes_to_read)?; - if consume_res.get_bytes().len() > bytes_to_read { + let consume_res = self.ordered_bytes.consume(bytes_to_read, buf)?; + Ok((consume_res, self.addr.unwrap())) + } + + pub(crate) fn peek(&self) -> BluefinResult<()> { + if self.addr.is_none() { return Err(BluefinError::Unexpected( - "Consumed more bytes than specified".to_string(), + "Cannot consume buffer because addr is field is none".to_string(), )); } - - 
Ok((consume_res, self.addr.unwrap())) + self.ordered_bytes.peek() } #[inline] @@ -212,7 +206,7 @@ impl ConnectionBuffer { pub(crate) struct ConnectionManager { /// Key: {src_conn_id}_{dst_conn_id} /// Value: The connection buffer - map: HashMap>>, + map: HashMap, } impl ConnectionManager { @@ -226,7 +220,7 @@ impl ConnectionManager { pub(crate) fn insert( &mut self, key: &str, - element: Arc>, + element: ConnectionManagedBuffers, ) -> BluefinResult<()> { if self.map.contains_key(key) { return Err(BluefinError::ConnectionAlreadyExists); @@ -238,7 +232,7 @@ impl ConnectionManager { } #[inline] - pub(crate) fn get(&self, key: &str) -> Option>> { + pub(crate) fn get(&self, key: &str) -> Option { self.map.get(key).cloned() } @@ -255,11 +249,10 @@ impl ConnectionManager { /// connection established between a client and server after the handshake process /// has completed successfully. A bluefin connection allows users to [receive](BluefinConnection::recv) /// and to [send](BluefinConnection::send) bytes across the wire. 
+#[derive(Clone)] pub struct BluefinConnection { pub src_conn_id: u32, pub dst_conn_id: u32, - // This is the *next* packet number we must use - packet_num: Arc>, reader_rx: ReaderRxChannel, writer_tx: WriterTxChannel, } @@ -268,35 +261,33 @@ impl BluefinConnection { pub(crate) fn new( src_conn_id: u32, dst_conn_id: u32, - packet_num: u64, + next_send_packet_num: u64, conn_buffer: Arc>, + ack_buffer: Arc>, socket: Arc, dst_addr: SocketAddr, ) -> Self { - let shared_packet_num = Arc::new(tokio::sync::Mutex::new(packet_num)); let writer_queue = Arc::new(Mutex::new(WriterQueue::new())); - let writer_tx = WriterTxChannel::new(Arc::clone(&writer_queue)); + let ack_queue = Arc::new(Mutex::new(WriterQueue::new())); + let writer_tx = WriterTxChannel::new(Arc::clone(&writer_queue), Arc::clone(&ack_queue)); build_and_start_writer_rx_channel( Arc::clone(&writer_queue), + Arc::clone(&ack_queue), Arc::clone(&socket), - 2, + 1, dst_addr, - ); - - let reader_rx = ReaderRxChannel::new( - Arc::clone(&conn_buffer), - Arc::clone(&socket), + Arc::clone(&ack_buffer), + next_send_packet_num, src_conn_id, dst_conn_id, - writer_tx.clone(), - Arc::clone(&shared_packet_num), ); + let reader_rx = ReaderRxChannel::new(Arc::clone(&conn_buffer), writer_tx.clone()); + Self { src_conn_id, dst_conn_id, - packet_num: Arc::clone(&shared_packet_num), reader_rx, writer_tx, } @@ -304,36 +295,15 @@ impl BluefinConnection { #[inline] pub async fn recv(&mut self, buf: &mut [u8], len: usize) -> BluefinResult { - self.reader_rx.set_bytes_to_read(len); - let (bytes, _) = self.reader_rx.read().await?; - let size = buf.as_mut().write(&bytes)?; - return Ok(size); + let (size, _) = self.reader_rx.read(len, buf).await?; + return Ok(size as usize); } #[inline] pub async fn send(&mut self, buf: &[u8]) -> BluefinResult { - // create bluefin packet and send - let security_fields = BluefinSecurityFields::new(false, 0x0); - let mut header = BluefinHeader::new( - self.src_conn_id, - self.dst_conn_id, - 
PacketType::UnencryptedData, - 0x0, - security_fields, - ); - let mut packet_num = self.packet_num.lock().await; - header.with_packet_number(*packet_num); - let packet = BluefinPacket::builder() - .header(header) - .payload(buf.to_vec()) - .build(); // TODO! This returns the total bytes sent (including bluefin payload). This // really should only return the total payload bytes - let _ = self.writer_tx.send(packet).await?; - - // HANDLE THIS! - *packet_num += 1; - + let _ = self.writer_tx.send(buf).await?; Ok(buf.len()) } } diff --git a/src/net/mod.rs b/src/net/mod.rs index 79adf14..b8428da 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -3,7 +3,8 @@ use std::{ sync::{Arc, Mutex}, }; -use connection::ConnectionManager; +use ack_handler::{AckBuffer, AckConsumer}; +use connection::{ConnectionBuffer, ConnectionManager}; use tokio::{net::UdpSocket, spawn, sync::RwLock}; use crate::{ @@ -18,15 +19,28 @@ use crate::{ }, }; +pub mod ack_handler; pub mod client; pub mod connection; pub mod ordered_bytes; pub mod server; +pub(crate) const BLUEFIN_HEADER_SIZE_BYTES: usize = 20; +pub(crate) const MAX_BLUEFIN_PAYLOAD_SIZE_BYTES: usize = 1500; +pub(crate) const MAX_BLUEFIN_PACKETS_IN_UDP_DATAGRAM: usize = 10; +pub(crate) const MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM: usize = MAX_BLUEFIN_PACKETS_IN_UDP_DATAGRAM + * (BLUEFIN_HEADER_SIZE_BYTES + MAX_BLUEFIN_PAYLOAD_SIZE_BYTES); + +#[derive(Clone)] +pub(crate) struct ConnectionManagedBuffers { + pub(crate) conn_buff: Arc>, + pub(crate) ack_buff: Arc>, +} + /// Helper to build `num_tx_workers` number of tx workers to run. 
#[inline] fn build_and_start_tx( - num_tx_workers: u8, + num_tx_workers: u16, socket: Arc, conn_manager: Arc>, pending_accept_ids: Arc>>, @@ -44,16 +58,39 @@ fn build_and_start_tx( } fn build_and_start_writer_rx_channel( - queue: Arc>, + data_queue: Arc>, + ack_queue: Arc>, socket: Arc, - num_rx_workers: u8, + num_ack_consumer_workers: u8, dst_addr: SocketAddr, + ack_buffer: Arc>, + next_packet_num: u64, + src_conn_id: u32, + dst_conn_id: u32, ) { - let rx = WriterRxChannel::new(queue, socket, dst_addr); - for _ in 0..num_rx_workers { - let rx_clone = rx.clone(); + let largest_recv_acked_packet_num = Arc::new(RwLock::new(0)); + let mut rx = WriterRxChannel::new( + data_queue, + ack_queue, + socket, + dst_addr, + next_packet_num, + src_conn_id, + dst_conn_id, + ); + let mut cloned = rx.clone(); + spawn(async move { + cloned.run_data().await; + }); + spawn(async move { + rx.run_ack().await; + }); + let ack_consumer = AckConsumer::new(Arc::clone(&ack_buffer), largest_recv_acked_packet_num); + + for _ in 0..num_ack_consumer_workers { + let ack_consumer_clone = ack_consumer.clone(); spawn(async move { - let _ = rx_clone.run().await; + ack_consumer_clone.run().await; }); } } @@ -124,28 +161,3 @@ pub(crate) fn build_empty_encrypted_packet( header.with_packet_number(packet_number); BluefinPacket::builder().header(header).build() } - -#[inline] -// Type-specific payload contains the # of packets we acked. -// The actual payload contains the base packet number. 
-pub(crate) fn build_ack_packet( - src_conn_id: u32, - dst_conn_id: u32, - base_packet_number_ack: u64, - number_packets_to_ack: u16, - packet_number: u64, -) -> BluefinPacket { - let security_fields = BluefinSecurityFields::new(false, 0x0); - let mut header = BluefinHeader::new( - src_conn_id, - dst_conn_id, - PacketType::Ack, - number_packets_to_ack, - security_fields, - ); - header.with_packet_number(packet_number); - BluefinPacket::builder() - .header(header) - .payload(base_packet_number_ack.to_ne_bytes().to_vec()) - .build() -} diff --git a/src/net/ordered_bytes.rs b/src/net/ordered_bytes.rs index 1b6b8ae..c1d58e8 100644 --- a/src/net/ordered_bytes.rs +++ b/src/net/ordered_bytes.rs @@ -2,12 +2,12 @@ use std::fmt; use crate::{ core::{error::BluefinError, packet::BluefinPacket}, - utils::{common::BluefinResult, window::SlidingWindow}, + utils::common::BluefinResult, }; /// Represents the maximum number of *packets* we can buffer in memory. When bytes are consumed /// via [OrderedBytes::consume()], we can only consume at most [MAX_BUFFER_SIZE] number of packets. -pub const MAX_BUFFER_SIZE: usize = 2500; +pub const MAX_BUFFER_SIZE: usize = 10000000; /// [OrderedBytes] represents the connection's buffered packets. OrderedBytes stores at most /// [MAX_BUFFER_SIZE] number of bluefin packets and maintains their intended consumption @@ -18,7 +18,7 @@ pub(crate) struct OrderedBytes { /// The connection id that owns the ordered bytes. Used for debugging. conn_id: u32, /// Represents the in-ordered buffer of packets. This is a circular buffer. - packets: [Option; MAX_BUFFER_SIZE], + packets: Box<[Option; MAX_BUFFER_SIZE]>, /// Pointer to the where the packet with the smallest packet number is buffered smallest_packet_number_index: usize, /// The packet number of the packet that *should* be buffered at packets\[start_index\] and @@ -27,8 +27,6 @@ pub(crate) struct OrderedBytes { /// Stores any potential carry over bytes from a previous consume. 
These bytes belong to /// a packet we have already consumed. carry_over_bytes: Option>, - /// Holds all of the packet numbers for which we have received acks for. - received_acks: SlidingWindow, } /// The result returned when [OrderedBytes are consumed](OrderedBytes::consume()). This result @@ -36,34 +34,34 @@ pub(crate) struct OrderedBytes { /// ack packets to the sender. Notice that once bytes are returned in a [ConsumeResult] then /// the bytes are no longer available in the [OrderedBytes] for consumption. pub(crate) struct ConsumeResult { - bytes: Vec, num_packets_consumed: usize, base_packet_number: u64, + bytes_consumed: u64, } impl ConsumeResult { - fn new(bytes: Vec, num_packets_consumed: usize, base_packet_number: u64) -> Self { + #[inline] + fn new(num_packets_consumed: usize, base_packet_number: u64, bytes_consumed: u64) -> Self { Self { - bytes, num_packets_consumed, base_packet_number, + bytes_consumed, } } - pub(crate) fn get_bytes(&self) -> &Vec { - return &self.bytes; - } - + #[inline] pub(crate) fn get_num_packets_consumed(&self) -> usize { self.num_packets_consumed } + #[inline] pub(crate) fn get_base_packet_number(&self) -> u64 { self.base_packet_number } - pub(crate) fn take_bytes(self) -> Vec { - return self.bytes; + #[inline] + pub(crate) fn get_bytes_consumed(&self) -> u64 { + self.bytes_consumed } } @@ -95,14 +93,15 @@ impl fmt::Display for OrderedBytes { impl OrderedBytes { pub(crate) fn new(conn_id: u32, start_packet_number: u64) -> Self { const ARRAY_REPEAT_VALUE: Option = None; - let packets = [ARRAY_REPEAT_VALUE; MAX_BUFFER_SIZE]; + let packets = vec![ARRAY_REPEAT_VALUE; MAX_BUFFER_SIZE] + .try_into() + .unwrap(); Self { conn_id, packets, smallest_packet_number_index: 0, smallest_packet_number: start_packet_number, carry_over_bytes: None, - received_acks: SlidingWindow::new(start_packet_number), } } @@ -123,7 +122,7 @@ impl OrderedBytes { /// If [MAX_BUFFER_SIZE] or more number of packets are already buffered, then we cannot /// buffer 
any more packets and will drop packets from the network. #[inline] - pub(crate) fn buffer_in_packet(&mut self, packet: &BluefinPacket) -> BluefinResult<()> { + pub(crate) fn buffer_in_packet(&mut self, packet: BluefinPacket) -> BluefinResult<()> { let packet_num = packet.header.packet_number; // We are expecting a packet with packet number >= start_packet_number @@ -134,7 +133,9 @@ impl OrderedBytes { // We received a packet that cannot fit in the buffer let offset = (packet_num - self.smallest_packet_number) as usize; if offset >= MAX_BUFFER_SIZE { - return Err(BluefinError::BufferFullError); + return Err(BluefinError::BufferFullError( + "Ordered bytes buffer full".to_string(), + )); } let index = (self.smallest_packet_number_index + offset) % MAX_BUFFER_SIZE; @@ -149,12 +150,23 @@ impl OrderedBytes { )); } - self.packets[index] = Some(packet.clone()); + self.packets[index] = Some(packet); Ok(()) } - pub(crate) fn buffer_in_ack_packet(&mut self, packet: &BluefinPacket) -> BluefinResult<()> { - Ok(()) + /// Ok(()) indicates there are bytes to consume. Error otherwise. + pub(crate) fn peek(&self) -> BluefinResult<()> { + // There are at least carry over bytes to consume + if let Some(_) = self.carry_over_bytes.as_ref() { + return Ok(()); + } + + // We have at least one packet buffered + if let Some(_) = self.packets[self.smallest_packet_number_index] { + return Ok(()); + } + + Err(BluefinError::BufferEmptyError) } /// Consumes the buffer, which removes consumable bytes from the buffer in-order and places @@ -171,22 +183,24 @@ impl OrderedBytes { /// [OrderedBytes::consume()] will return [BluefinError::BufferEmptyError] if no bytes can be /// consumed. 
#[inline] - pub(crate) fn consume(&mut self, len: usize) -> BluefinResult { - let mut bytes: Vec = vec![]; + pub(crate) fn consume(&mut self, len: usize, buf: &mut [u8]) -> BluefinResult { let mut num_bytes = 0; + let mut writer_ix = 0; // peek into carry over bytes - if let Some(c_bytes) = self.carry_over_bytes.as_mut() { + if let Some(ref mut c_bytes) = self.carry_over_bytes { // We can take all of the carry over if c_bytes.len() <= len { num_bytes += c_bytes.len(); - bytes.append(c_bytes); + buf[writer_ix..writer_ix + c_bytes.len()].copy_from_slice(c_bytes); + writer_ix += c_bytes.len(); self.carry_over_bytes = None; // We still have some bytes left over in the carry over... } else { - bytes.append(&mut c_bytes[..len].to_vec()); - self.carry_over_bytes = Some(c_bytes[len..].to_vec()); - return Ok(ConsumeResult::new(bytes, 0, 0)); + let drained = c_bytes.drain(len..).collect(); + buf[writer_ix..writer_ix + len].copy_from_slice(&c_bytes); + self.carry_over_bytes = Some(drained); + return Ok(ConsumeResult::new(0, 0, len as u64)); } } @@ -214,12 +228,15 @@ impl OrderedBytes { // We cannot return all of the payload. 
We will partially consume the payload and // store the remaining in the carry over if payload_len > bytes_remaining { - bytes.append(&mut packet.payload[..bytes_remaining].to_vec()); + buf[writer_ix..writer_ix + bytes_remaining] + .copy_from_slice(&packet.payload[..bytes_remaining]); + writer_ix += bytes_remaining; self.carry_over_bytes = Some(packet.payload[bytes_remaining..].to_vec()); - num_bytes = len; + num_bytes += bytes_remaining; // We have enough space left to consume the entirity of this buffer } else { - bytes.append(&mut packet.payload); + buf[writer_ix..writer_ix + payload_len].copy_from_slice(&packet.payload); + writer_ix += payload_len; num_bytes += payload_len; } @@ -233,10 +250,10 @@ impl OrderedBytes { } // Nothing to consume, including any potential carry-over bytes - if bytes.len() == 0 { + if num_bytes == 0 { return Err(BluefinError::BufferEmptyError); } - Ok(ConsumeResult::new(bytes, ix, base_packet_number)) + Ok(ConsumeResult::new(ix, base_packet_number, num_bytes as u64)) } } diff --git a/src/net/server.rs b/src/net/server.rs index d487a0f..7af2723 100644 --- a/src/net/server.rs +++ b/src/net/server.rs @@ -1,4 +1,5 @@ use std::{ + mem, net::SocketAddr, sync::{Arc, Mutex}, time::Duration, @@ -16,9 +17,10 @@ use crate::{ use super::{ build_and_start_tx, connection::{BluefinConnection, ConnectionBuffer, ConnectionManager}, + AckBuffer, ConnectionManagedBuffers, }; - -const NUM_TX_WORKERS_FOR_SERVER: u8 = 10; +use std::os::fd::AsRawFd; +const NUM_TX_WORKERS_FOR_SERVER_DEFAULT: u16 = 10; #[derive(Clone)] pub struct BluefinServer { @@ -26,6 +28,7 @@ pub struct BluefinServer { src_addr: SocketAddr, conn_manager: Arc>, pending_accept_ids: Arc>>, + num_reader_workers: u16, } impl BluefinServer { @@ -35,15 +38,56 @@ impl BluefinServer { conn_manager: Arc::new(RwLock::new(ConnectionManager::new())), pending_accept_ids: Arc::new(Mutex::new(Vec::new())), src_addr, + num_reader_workers: NUM_TX_WORKERS_FOR_SERVER_DEFAULT, + } + } + + #[inline] + pub fn 
set_num_reader_workers(&mut self, num_reader_workers: u16) -> BluefinResult<()> { + if num_reader_workers == 0 { + return Err(BluefinError::Unexpected( + "Cannot have zero reader values".to_string(), + )); } + self.num_reader_workers = num_reader_workers; + Ok(()) } pub async fn bind(&mut self) -> BluefinResult<()> { let socket = UdpSocket::bind(self.src_addr).await?; + let socket_fd = socket.as_raw_fd(); self.socket = Some(Arc::new(socket)); + #[cfg(target_os = "macos")] + { + use sysctl::Sysctl; + if let Ok(ctl) = sysctl::Ctl::new("net.inet.udp.maxdgram") { + match ctl.set_value_string("16000") { + Ok(s) => { + println!("Successfully set net.inet.udp.maxdgram to {}", s) + } + Err(e) => eprintln!("Failed to set net.inet.udp.maxdgram due to err: {:?}", e), + } + } + } + + #[cfg(any(target_os = "linux", target_os = "macos"))] + unsafe { + let optval: libc::c_int = 1; + let ret = libc::setsockopt( + socket_fd, + libc::SOL_SOCKET, + libc::SO_REUSEPORT, + &optval as *const _ as *const libc::c_void, + mem::size_of_val(&optval) as libc::socklen_t, + ); + if ret != 0 { + return Err(BluefinError::InvalidSocketError); + } + } + build_and_start_tx( - NUM_TX_WORKERS_FOR_SERVER, + self.num_reader_workers, Arc::clone(self.socket.as_ref().unwrap()), Arc::clone(&self.conn_manager), Arc::clone(&self.pending_accept_ids), @@ -56,16 +100,24 @@ impl BluefinServer { pub async fn accept(&mut self) -> BluefinResult { // generate random conn id and insert buffer let src_conn_id: u32 = rand::thread_rng().gen(); + // This is the packet number the server will begin using. 
+ let packet_number: u64 = rand::thread_rng().gen(); let conn_buffer = Arc::new(Mutex::new(ConnectionBuffer::new( src_conn_id, BluefinHost::PackLeader, ))); + let ack_buffer = Arc::new(Mutex::new(AckBuffer::new(packet_number + 1))); + let conn_mgr_buffers = ConnectionManagedBuffers { + conn_buff: Arc::clone(&conn_buffer), + ack_buff: Arc::clone(&ack_buffer), + }; + let hello_key = format!("{}_0", src_conn_id); let _ = self .conn_manager .write() .await - .insert(&hello_key, Arc::clone(&conn_buffer)); + .insert(&hello_key, conn_mgr_buffers.clone()); self.pending_accept_ids.lock().unwrap().push(src_conn_id); let handshake_buf = HandshakeConnectionBuffer::new(Arc::clone(&conn_buffer)); @@ -83,11 +135,10 @@ impl BluefinServer { // delete the old hello entry and insert the new connection entry let mut guard = self.conn_manager.write().await; let _ = guard.remove(&hello_key); - let _ = guard.insert(&key, Arc::clone(&conn_buffer)); + let _ = guard.insert(&key, conn_mgr_buffers); drop(guard); // send server hello - let packet_number: u64 = rand::thread_rng().gen(); let packet = build_empty_encrypted_packet( src_conn_id, dst_conn_id, @@ -113,6 +164,7 @@ impl BluefinServer { dst_conn_id, packet_number + 1, Arc::clone(&conn_buffer), + Arc::clone(&ack_buffer), Arc::clone(self.socket.as_ref().unwrap()), addr, )) diff --git a/src/utils/window.rs b/src/utils/window.rs index 05ab597..6c7b414 100644 --- a/src/utils/window.rs +++ b/src/utils/window.rs @@ -4,7 +4,7 @@ use crate::core::error::BluefinError; use super::common::BluefinResult; -pub const MAX_SLIDING_WINDOW_SIZE: usize = 10; +pub const MAX_SLIDING_WINDOW_SIZE: usize = 20000; #[derive(Clone)] pub(crate) struct SlidingWindow { @@ -12,6 +12,15 @@ pub(crate) struct SlidingWindow { ordered_packet_numbers: VecDeque, } +#[derive(PartialEq, Debug)] +pub(crate) struct SlidingWindowConsumeResult { + /// The largest packet number that we have contiguously buffered + pub(crate) largest_packet_number: u64, + /// The number of acks 
consumed. This means that this consume result represents + /// the accumulation of contiguous acks in the range of [largest_packet_number - num_acks_consumed - 1, largest_packet_number] + pub(crate) num_acks_consumed: u64, +} + impl SlidingWindow { pub(crate) fn new(smallest_expected_packet_number: u64) -> Self { Self { @@ -26,6 +35,14 @@ impl SlidingWindow { return Err(BluefinError::UnexpectedPacketNumberError); } + if packet_number - self.smallest_expected_packet_number + >= MAX_SLIDING_WINDOW_SIZE.try_into().unwrap() + { + return Err(BluefinError::BufferFullError( + "Sliding window buffer is full".to_string(), + )); + } + // Find the index to insert into our sorted vector. let index = match self.ordered_packet_numbers.binary_search(&packet_number) { // Ok result means we have already stored this packet number before. Fail here. @@ -35,26 +52,26 @@ impl SlidingWindow { Err(index) => index, }; - // We cannot accomodate this packet number. This means thie packet number is so high - // that we would have to allocate too much memory. - if index >= MAX_SLIDING_WINDOW_SIZE { - return Err(BluefinError::BufferFullError); - } - self.ordered_packet_numbers.insert(index, packet_number); - self.smallest_expected_packet_number = self.ordered_packet_numbers[0]; Ok(()) } - /// If present, returns the largest packet number that we have contigously buffered. For example, + /// If present, returns the largest packet number that we have contiguously buffered. For example, /// if Some(10) were returned, that means we have accounted for all packet numbers 10 and below. /// We may have packet numbers larger than 10 but they are disjointed from the contiguous set. - fn consume(&mut self) -> Option { + #[inline] + pub(crate) fn consume(&mut self) -> Option { // Nothing in the vector. Done. if self.ordered_packet_numbers.is_empty() { return None; } + // There are entries in the vector but we are still missing the smallest expected + // packet number. Nothing to consume. 
+ if self.ordered_packet_numbers[0] > self.smallest_expected_packet_number { + return None; + } + // Vector is not empty so this is safe. let mut last_packet_number = self.ordered_packet_numbers.pop_front().unwrap(); while !self.ordered_packet_numbers.is_empty() { @@ -65,12 +82,98 @@ impl SlidingWindow { last_packet_number = p_number; continue; } else { - // There is a jump. Reinsert the poped val in the front and return. + // There is a jump. Reinsert the popped val in the front and return. self.ordered_packet_numbers.push_front(p_number); break; } } - Some(last_packet_number) + let prev = self.smallest_expected_packet_number; + self.smallest_expected_packet_number = last_packet_number + 1; + + Some(SlidingWindowConsumeResult { + largest_packet_number: last_packet_number, + num_acks_consumed: last_packet_number - prev + 1, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{core::error::BluefinError, utils::window::MAX_SLIDING_WINDOW_SIZE}; + + use super::SlidingWindow; + + #[test] + fn sliding_window_behaves_as_expected() { + // Start with a packet number of 100 + let mut sliding_window = SlidingWindow::new(100); + // Nothing inserted, show return none + assert_eq!(sliding_window.consume(), None); + + // We should fail if we insert a packet number less than 100 + let insert_res = sliding_window.insert_packet_number(99); + assert!(insert_res.is_err()); + assert_eq!( + insert_res.err().unwrap(), + BluefinError::UnexpectedPacketNumberError + ); + // Nothing inserted, should still fail + assert_eq!(sliding_window.consume(), None); + + // Should be able to insert 101, 102, 103, 104 and 106 + assert_eq!(sliding_window.insert_packet_number(101), Ok(())); + assert_eq!(sliding_window.insert_packet_number(102), Ok(())); + assert_eq!(sliding_window.insert_packet_number(103), Ok(())); + assert_eq!(sliding_window.insert_packet_number(104), Ok(())); + assert_eq!(sliding_window.insert_packet_number(106), Ok(())); + // Still nothing to consume since we are still missing 
packet #100 + assert_eq!(sliding_window.consume(), None); + + // Cannot re-insert already inserted numbers + let insert_res = sliding_window.insert_packet_number(103); + assert!(insert_res.is_err()); + assert_eq!( + insert_res.err().unwrap(), + BluefinError::UnexpectedPacketNumberError + ); + + // Should not be able to insert above the window limit + assert!(sliding_window + .insert_packet_number(100 + u64::try_from(MAX_SLIDING_WINDOW_SIZE).unwrap()) + .is_err()); + + // Complete a contiguous sequence [100, 104] inclusive. + assert_eq!(sliding_window.insert_packet_number(100), Ok(())); + let consume_res = sliding_window.consume(); + assert!(consume_res.is_some()); + let consume_res_unwrapped = consume_res.unwrap(); + assert_eq!(consume_res_unwrapped.largest_packet_number, 104); + assert_eq!(consume_res_unwrapped.num_acks_consumed, 5); + + // Consuming again returns none since we are missing #105 + assert!(sliding_window.consume().is_none()); + + // Insert #107 and #110 + assert_eq!(sliding_window.insert_packet_number(107), Ok(())); + assert_eq!(sliding_window.insert_packet_number(110), Ok(())); + assert!(sliding_window.consume().is_none()); + + // Complete contiguous sequence [105, 107] + assert_eq!(sliding_window.insert_packet_number(105), Ok(())); + let consume_res = sliding_window.consume(); + assert!(consume_res.is_some()); + let consume_res_unwrapped = consume_res.unwrap(); + assert_eq!(consume_res_unwrapped.largest_packet_number, 107); + assert_eq!(consume_res_unwrapped.num_acks_consumed, 3); + assert!(sliding_window.consume().is_none()); + + // Should not be able to insert above the window limit + assert!(sliding_window + .insert_packet_number(108 + u64::try_from(MAX_SLIDING_WINDOW_SIZE).unwrap()) + .is_err()); + assert!(sliding_window + .insert_packet_number(107 + u64::try_from(MAX_SLIDING_WINDOW_SIZE).unwrap()) + .is_ok()); } } diff --git a/src/worker/reader.rs b/src/worker/reader.rs index 470aac1..9dd19b8 100644 --- a/src/worker/reader.rs +++ 
b/src/worker/reader.rs @@ -9,15 +9,12 @@ use std::{ use tokio::{net::UdpSocket, sync::RwLock, time::sleep}; use crate::{ - core::{ - context::BluefinHost, error::BluefinError, header::PacketType, packet::BluefinPacket, - Serialisable, - }, + core::{context::BluefinHost, error::BluefinError, header::PacketType, packet::BluefinPacket}, net::{ - build_ack_packet, + ack_handler::AckBuffer, connection::{ConnectionBuffer, ConnectionManager}, - is_client_ack_packet, is_hello_packet, - ordered_bytes::ConsumeResult, + is_client_ack_packet, is_hello_packet, ConnectionManagedBuffers, + MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM, }, utils::common::BluefinResult, }; @@ -32,7 +29,7 @@ use super::writer::WriterTxChannel; /// to buffer in the bytes/packet. In other words, this channel *transmits* bytes *into* the buffer /// and signals any awaiters that data is ready. pub(crate) struct ReaderTxChannel { - pub(crate) id: u8, + pub(crate) id: u16, socket: Arc, conn_manager: Arc>, pending_accept_ids: Arc>>, @@ -45,30 +42,25 @@ pub(crate) struct ReaderTxChannel { /// buffered tuple contents ([ConsumeResult], [SocketAddr]). In other words, this channel /// *receives* bytes *from* the buffer. 
pub(crate) struct ReaderRxChannel { - socket: Arc, future: ReaderRxChannelFuture, - src_conn_id: u32, - dst_conn_id: u32, writer_tx_channel: WriterTxChannel, - packet_num: Arc>, } #[derive(Clone)] struct ReaderRxChannelFuture { buffer: Arc>, - bytes_to_read: usize, } impl Future for ReaderRxChannelFuture { - type Output = (ConsumeResult, SocketAddr); + type Output = (); fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { let mut guard = self.buffer.lock().unwrap(); - if let Ok((consume_res, addr)) = guard.consume(self.bytes_to_read) { - return Poll::Ready((consume_res, addr)); + if let Ok(()) = guard.peek() { + return Poll::Ready(()); } guard.set_waker(cx.waker().clone()); @@ -79,48 +71,36 @@ impl Future for ReaderRxChannelFuture { impl ReaderRxChannel { pub(crate) fn new( buffer: Arc>, - socket: Arc, - src_conn_id: u32, - dst_conn_id: u32, writer_tx_channel: WriterTxChannel, - packet_num: Arc>, ) -> Self { - let future = ReaderRxChannelFuture { - buffer, - bytes_to_read: 0, - }; + let future = ReaderRxChannelFuture { buffer }; Self { - socket, future, - src_conn_id, - dst_conn_id, - packet_num, writer_tx_channel, } } #[inline] - pub(crate) fn set_bytes_to_read(&mut self, bytes_to_read: usize) { - self.future.bytes_to_read = bytes_to_read; - } - - #[inline] - pub(crate) async fn read(&self) -> BluefinResult<(Vec, SocketAddr)> { - let (consume_res, addr) = self.future.clone().await; + pub(crate) async fn read( + &mut self, + bytes_to_read: usize, + buf: &mut [u8], + ) -> BluefinResult<(u64, SocketAddr)> { + let _ = self.future.clone().await; + let (consume_res, addr) = { + let mut guard = self.future.buffer.lock().unwrap(); + guard.consume(bytes_to_read, buf).unwrap() + }; let num_packets_consumed = consume_res.get_num_packets_consumed(); let base_packet_num = consume_res.get_base_packet_number(); - // TODO: Handle packet numbers for sending acks. For now, use zero. // We need to send an ack. 
if num_packets_consumed > 0 { - let ack_packet = build_ack_packet( - self.src_conn_id, - self.dst_conn_id, - base_packet_num, - num_packets_consumed as u16, - 0, - ); - if let Err(e) = self.writer_tx_channel.send(ack_packet).await { + if let Err(e) = self + .writer_tx_channel + .send_ack(base_packet_num, num_packets_consumed) + .await + { eprintln!( "Failed to send ack packet after reads due to error: {:?}", e @@ -128,7 +108,7 @@ impl ReaderRxChannel { } } - Ok((consume_res.take_bytes(), addr)) + Ok((consume_res.get_bytes_consumed(), addr)) } } @@ -151,9 +131,10 @@ impl ReaderTxChannel { #[inline] async fn run_sleep(encountered_err: &mut bool) { if !*encountered_err { + sleep(Duration::from_micros(1)).await; return; } - sleep(Duration::from_millis(100)).await; + sleep(Duration::from_millis(5)).await; *encountered_err = false; } @@ -196,52 +177,67 @@ impl ReaderTxChannel { #[inline] fn build_conn_buff_key(is_hello: bool, src_conn_id: u32, dst_conn_id: u32) -> String { return { - if is_hello { - format!("{}_0", src_conn_id) - } else { + if !is_hello { format!("{}_{}", src_conn_id, dst_conn_id) + } else { + format!("{}_0", src_conn_id) } }; } - #[inline] - fn buffer_in_data( + fn buffer_to_conn_buffer( + conn_buff: &mut ConnectionBuffer, + packet: BluefinPacket, + addr: SocketAddr, is_hello: bool, is_client_ack: bool, - packet: &BluefinPacket, - addr: SocketAddr, - buffer: Arc>, ) -> BluefinResult<()> { - let mut buffer_guard = buffer.lock().unwrap(); - - // Ack packets are buffered differently - if !is_client_ack && !is_hello && packet.header.type_field == PacketType::Ack { - if let Err(e) = buffer_guard.buffer_in_ack_packet(&packet) { - return Err(e); - } - } else if !is_hello && !is_client_ack { + let packet_src_conn_id = packet.header.source_connection_id; + if !is_hello && !is_client_ack { // If not hello, we buffer in the bytes // Could not buffer in packet... buffer is likely full. We will have to discard the // packet. 
- if let Err(e) = buffer_guard.buffer_in_bytes(packet) { - return Err(e); - } + conn_buff.buffer_in_bytes(packet)?; } else { - if let Err(e) = buffer_guard.buffer_in_packet(&packet) { - return Err(e); - } - let _ = buffer_guard.buffer_in_addr(addr); + conn_buff.buffer_in_packet(packet)?; + let _ = conn_buff.buffer_in_addr(addr); } - buffer_guard.set_dst_conn_id(packet.header.source_connection_id); + conn_buff.set_dst_conn_id(packet_src_conn_id); // Wake future that buffered data is available - if let Some(w) = buffer_guard.get_waker() { + if let Some(w) = conn_buff.get_waker() { w.wake_by_ref(); } else { return Err(BluefinError::NoSuchWakerError); } + Ok(()) + } + + #[inline] + fn buffer_to_ack_buffer(ack_buff: &mut AckBuffer, packet: BluefinPacket) -> BluefinResult<()> { + ack_buff.buffer_in_ack_packet(packet)?; + ack_buff.wake() + } + #[inline] + fn buffer_in_data( + is_hello: bool, + host_type: BluefinHost, + packet: BluefinPacket, + addr: SocketAddr, + buffers: &ConnectionManagedBuffers, + ) -> BluefinResult<()> { + let is_client_ack = is_client_ack_packet(host_type, &packet); + if !is_client_ack && !is_hello && packet.header.type_field == PacketType::Ack { + let mut ack_buff = buffers.ack_buff.lock().unwrap(); + Self::buffer_to_ack_buffer(&mut ack_buff, packet)?; + drop(ack_buff); + } else { + let mut conn_buff = buffers.conn_buff.lock().unwrap(); + Self::buffer_to_conn_buffer(&mut conn_buff, packet, addr, is_hello, is_client_ack)?; + drop(conn_buff); + } Ok(()) } @@ -249,46 +245,55 @@ impl ReaderTxChannel { /// from the udp socket into a connection buffer. This method should be run its own asynchronous task. 
pub(crate) async fn run(&mut self) -> BluefinResult<()> { let mut encountered_err = false; + let mut buf = [0u8; MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM]; loop { ReaderTxChannel::run_sleep(&mut encountered_err).await; - let mut buf = vec![0; 1504]; - let (res, addr) = self.socket.recv_from(&mut buf).await?; - let packet_res = BluefinPacket::deserialise(&buf[..res]); + let (size, addr) = self.socket.recv_from(&mut buf).await?; + let packets_res = BluefinPacket::from_bytes(&buf[..size]); - // Not a bluefin packet or it's invalid. - if let Err(e) = packet_res { - eprintln!("{}", e); + // Not bluefin packet(s) or it's invalid. + if let Err(e) = packets_res { encountered_err = true; + eprintln!("Encountered err: {:?}", e); continue; } // Acquire lock and buffer in data - let packet = packet_res.unwrap(); - let mut src_conn_id = packet.header.destination_connection_id; - let dst_conn_id = packet.header.source_connection_id; - let mut is_hello = false; - let mut is_client_ack = false; - - if let Err(e) = self.handle_for_handshake(&packet, &mut is_hello, &mut src_conn_id) { - eprintln!("{}", e); + let packets = packets_res.unwrap(); + if packets.len() == 0 { encountered_err = true; continue; } - if is_client_ack_packet(self.host_type, &packet) { - is_client_ack = true; + // Because all bluefin packets bundled in a datagram must come from the same host, we just peek + // at the first one + let mut src_conn_id = packets[0].header.destination_connection_id; + let dst_conn_id = packets[0].header.source_connection_id; + let mut is_hello = false; + + // If there is only one packet, then it's possible it is a handshake packet. 
Handshakes are sent + // via one udp datagram carries exactly one bluefin packet + if packets.len() == 1 { + if let Err(e) = + self.handle_for_handshake(&packets[0], &mut is_hello, &mut src_conn_id) + { + eprintln!("{}", e); + encountered_err = true; + continue; + } } let key = ReaderTxChannel::build_conn_buff_key(is_hello, src_conn_id, dst_conn_id); - // ACQUIRE LOCK FOR CONN MANAGER - let guard = self.conn_manager.read().await; - let _conn_buf = guard.get(&key); - // We just need the conn buffer, which is behind its own lock. We don't need the - // conn manager anymore. - // RELEASE LOCK FOR CONN MANAGER - drop(guard); + let _conn_buf = { + // ACQUIRE LOCK FOR CONN MANAGER + let guard = self.conn_manager.read().await; + guard.get(&key) + // We just need the conn buffer, which is behind its own lock. We don't need the + // conn manager anymore. + // RELEASE LOCK FOR CONN MANAGER + }; if _conn_buf.is_none() { eprintln!("Could not find connection {}", &key); @@ -296,12 +301,14 @@ impl ReaderTxChannel { continue; } - let buffer = _conn_buf.unwrap(); - if let Err(e) = - ReaderTxChannel::buffer_in_data(is_hello, is_client_ack, &packet, addr, buffer) - { - eprintln!("Failed to buffer in data: {}", e); - encountered_err = true; + let buffers = _conn_buf.unwrap(); + for p in packets { + if let Err(e) = + ReaderTxChannel::buffer_in_data(is_hello, self.host_type, p, addr, &buffers) + { + eprintln!("Failed to buffer in data: {}", e); + encountered_err = true; + } } } } diff --git a/src/worker/writer.rs b/src/worker/writer.rs index a767562..a11f9b8 100644 --- a/src/worker/writer.rs +++ b/src/worker/writer.rs @@ -1,20 +1,36 @@ use std::{ + cmp::min, collections::VecDeque, future::Future, net::SocketAddr, sync::{Arc, Mutex}, task::{Poll, Waker}, + time::Duration, }; -use tokio::net::UdpSocket; +use tokio::{net::UdpSocket, time::sleep}; use crate::{ - core::{packet::BluefinPacket, Serialisable}, + core::{ + header::{BluefinHeader, BluefinSecurityFields, PacketType}, + 
packet::BluefinPacket, + Serialisable, + }, + net::{MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM, MAX_BLUEFIN_PAYLOAD_SIZE_BYTES}, utils::common::BluefinResult, }; +/// Each writer queue holds a queue of `WriterQueueData` +enum WriterQueueData { + Payload(Vec), + Ack { + base_packet_num: u64, + num_packets_consumed: usize, + }, +} + pub(crate) struct WriterQueue { - queue: VecDeque, + queue: VecDeque, waker: Option, } @@ -25,48 +41,296 @@ impl WriterQueue { waker: None, } } + + #[inline] + pub(crate) fn consume_data( + &mut self, + next_packet_num: &mut u64, + src_conn_id: u32, + dst_conn_id: u32, + ) -> Option> { + let mut ans = vec![]; + let mut bytes_remaining = MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM; + let mut running_payload = vec![]; + + let security_fields = BluefinSecurityFields::new(false, 0x0); + let mut header = BluefinHeader::new( + src_conn_id, + dst_conn_id, + PacketType::UnencryptedData, + 0, + security_fields, + ); + + while !self.queue.is_empty() && bytes_remaining > 20 { + // We already have some bytes left over and it's more than we can afford. Take what + // we can and end. 
+ if running_payload.len() >= bytes_remaining - 20 { + // Keep taking as many bytes out of the running payload as we can afford to + while !running_payload.is_empty() && bytes_remaining >= 20 { + let max_bytes_to_take = min( + running_payload.len(), + min(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, bytes_remaining - 20), + ); + header.with_packet_number(*next_packet_num); + header.type_specific_payload = max_bytes_to_take as u16; + let p = BluefinPacket::builder() + .header(header) + .payload(running_payload[..max_bytes_to_take].to_vec()) + .build(); + ans.extend(p.serialise()); + *next_packet_num += 1; + bytes_remaining -= max_bytes_to_take + 20; + running_payload = running_payload[max_bytes_to_take..].to_vec(); + } + + if !running_payload.is_empty() { + self.queue + .push_front(WriterQueueData::Payload(running_payload.to_vec())); + } + return Some(ans); + } + + // We just happen to have a completely full running payload. Let's take as much as we can. + if running_payload.len() >= MAX_BLUEFIN_PAYLOAD_SIZE_BYTES { + let max_bytes_to_take = min( + running_payload.len(), + min(bytes_remaining - 20, MAX_BLUEFIN_PAYLOAD_SIZE_BYTES), + ); + header.with_packet_number(*next_packet_num); + header.type_specific_payload = max_bytes_to_take as u16; + let p = BluefinPacket::builder() + .header(header) + .payload(running_payload[..max_bytes_to_take].to_vec()) + .build(); + ans.extend(p.serialise()); + *next_packet_num += 1; + bytes_remaining -= max_bytes_to_take + 20; + running_payload = running_payload[max_bytes_to_take..].to_vec(); + continue; + } + + // We have room + let data = self.queue.pop_front().unwrap(); + match data { + WriterQueueData::Payload(p) => { + let potential_bytes_len = p.len(); + if potential_bytes_len + running_payload.len() > MAX_BLUEFIN_PAYLOAD_SIZE_BYTES + { + // We cannot simply fit both payloads into this packet. 
+ running_payload.extend(p); + + // Try to take as much as we can + let max_bytes_to_take = min( + running_payload.len(), + min(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, bytes_remaining - 20), + ); + header.with_packet_number(*next_packet_num); + header.type_specific_payload = max_bytes_to_take as u16; + let packet = BluefinPacket::builder() + .header(header) + .payload(running_payload[..max_bytes_to_take].to_vec()) + .build(); + ans.extend(packet.serialise()); + *next_packet_num += 1; + bytes_remaining -= max_bytes_to_take + 20; + running_payload = running_payload[max_bytes_to_take..].to_vec(); + } else { + // We can fit both the payload and the left over bytes + running_payload.extend(p); + } + } + _ => unreachable!(), + } + } + + // Take the remaining amount + while !running_payload.is_empty() && bytes_remaining >= 20 { + let max_bytes_to_take = min( + running_payload.len(), + min(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, bytes_remaining - 20), + ); + header.with_packet_number(*next_packet_num); + header.type_specific_payload = max_bytes_to_take as u16; + let p = BluefinPacket::builder() + .header(header) + .payload(running_payload[..max_bytes_to_take].to_vec()) + .build(); + ans.extend(p.serialise()); + *next_packet_num += 1; + running_payload = running_payload[max_bytes_to_take..].to_vec(); + bytes_remaining -= max_bytes_to_take + 20; + } + + // Re-queue the remaining bytes + if !running_payload.is_empty() { + self.queue + .push_front(WriterQueueData::Payload(running_payload)); + } + + if ans.is_empty() { + return None; + } + Some(ans) + } + + pub(crate) fn consume_acks(&mut self, src_conn_id: u32, dst_conn_id: u32) -> Option> { + let mut bytes = vec![]; + let security_fields = BluefinSecurityFields::new(false, 0x0); + let mut header = BluefinHeader::new( + src_conn_id, + dst_conn_id, + PacketType::Ack, + 0, + security_fields, + ); + while !self.queue.is_empty() && bytes.len() <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { + let data = self.queue.pop_front().unwrap(); + match data { + 
WriterQueueData::Ack { + base_packet_num: b, + num_packets_consumed: c, + } => { + header.packet_number = b; + header.type_specific_payload = c as u16; + if bytes.len() + 20 <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM { + bytes.extend(header.serialise()); + } else { + self.queue.push_front(data); + break; + } + } + _ => unreachable!(), + } + } + + if bytes.len() == 0 { + return None; + } + + Some(bytes) + } } -/// Queues write requests to be sent +/// Queues write requests to be sent. Each connection can have one or more [WriterTxChannel]. #[derive(Clone)] pub(crate) struct WriterTxChannel { - queue: Arc>, + data_queue: Arc>, + ack_queue: Arc>, + num_runs_without_sleep: u32, } impl WriterTxChannel { - pub(crate) fn new(queue: Arc>) -> Self { - Self { queue } + pub(crate) fn new( + data_queue: Arc>, + ack_queue: Arc>, + ) -> Self { + Self { + data_queue, + ack_queue, + num_runs_without_sleep: 0, + } } - pub(crate) async fn send(&self, packet: BluefinPacket) -> BluefinResult { - let bytes = packet.len(); - let mut guard = self.queue.lock().unwrap(); - guard.queue.push_back(packet); + /// ONLY for sending data + pub(crate) async fn send(&mut self, payload: &[u8]) -> BluefinResult { + let bytes = payload.len(); + let data = WriterQueueData::Payload(payload.to_vec()); - // Signal to Rx channel that we have new packets in the queue - if let Some(ref waker) = guard.waker { - waker.wake_by_ref(); + { + let mut guard = self.data_queue.lock().unwrap(); + guard.queue.push_back(data); + + // Signal to Rx channel that we have new packets in the queue + if let Some(ref waker) = guard.waker { + waker.wake_by_ref(); + } } + + self.num_runs_without_sleep += 1; + if self.num_runs_without_sleep >= 137 { + sleep(Duration::from_nanos(10)).await; + self.num_runs_without_sleep = 0; + } + Ok(bytes) } + + pub(crate) async fn send_ack( + &mut self, + base_packet_num: u64, + num_packets_consumed: usize, + ) -> BluefinResult<()> { + let data = WriterQueueData::Ack { + base_packet_num, + 
num_packets_consumed, + }; + + { + let mut guard = self.ack_queue.lock().unwrap(); + guard.queue.push_back(data); + + // Signal to Rx channel that we have new packets in the queue + if let Some(ref waker) = guard.waker { + waker.wake_by_ref(); + } + } + + Ok(()) + } +} + +#[derive(Clone)] +struct WriterRxChannelDataFuture { + data_queue: Arc>, +} + +#[derive(Clone)] +struct WriterRxChannelAckFuture { + ack_queue: Arc>, } -/// Consumes queued requests and sends them across the wire +/// Consumes queued requests and sends them across the wire. For now, each connection +/// has one and only one [WriterRxChannel]. This channel must run two separate jobs: +/// [WriterRxChannel::run_data], which reads out of the data queue and sends bluefin +/// packets w/ payloads across the wire AND [WriterRxChannel::run_ack], which reads +/// acks out of the ack queue and sends bluefin ack packets across the wire. #[derive(Clone)] pub(crate) struct WriterRxChannel { - socket: Arc, - queue: Arc>, + data_future: WriterRxChannelDataFuture, + ack_future: WriterRxChannelAckFuture, dst_addr: SocketAddr, + next_packet_num: u64, + src_conn_id: u32, + dst_conn_id: u32, + socket: Arc, +} + +impl Future for WriterRxChannelDataFuture { + type Output = usize; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let mut guard = self.data_queue.lock().unwrap(); + let num_packets_to_send = guard.queue.len(); + if num_packets_to_send == 0 { + guard.waker = Some(cx.waker().clone()); + return Poll::Pending; + } + Poll::Ready(num_packets_to_send) + } } -impl Future for WriterRxChannel { +impl Future for WriterRxChannelAckFuture { type Output = usize; fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - let mut guard = self.queue.lock().unwrap(); + let mut guard = self.ack_queue.lock().unwrap(); let num_packets_to_send = guard.queue.len(); if num_packets_to_send == 0 { guard.waker = 
Some(cx.waker().clone()); @@ -78,31 +342,396 @@ impl Future for WriterRxChannel { impl WriterRxChannel { pub(crate) fn new( - queue: Arc>, + data_queue: Arc>, + ack_queue: Arc>, socket: Arc, dst_addr: SocketAddr, + next_packet_num: u64, + src_conn_id: u32, + dst_conn_id: u32, ) -> Self { + let data_future = WriterRxChannelDataFuture { + data_queue: Arc::clone(&data_queue), + }; + let ack_future = WriterRxChannelAckFuture { + ack_queue: Arc::clone(&ack_queue), + }; Self { - queue, - socket, + data_future, + ack_future, dst_addr, + next_packet_num, + src_conn_id, + dst_conn_id, + socket: Arc::clone(&socket), + } + } + pub(crate) async fn run_ack(&mut self) { + loop { + let _ = self.ack_future.clone().await; + let mut guard = self.ack_future.ack_queue.lock().unwrap(); + let bytes = guard.consume_acks(self.src_conn_id, self.dst_conn_id); + match bytes { + None => continue, + Some(b) => { + if let Err(e) = self.socket.try_send_to(&b, self.dst_addr) { + eprintln!( + "Encountered error {} while sending ack packet across wire", + e + ); + continue; + } + } + } + guard.waker = None; } } - pub(crate) async fn run(&self) { + pub(crate) async fn run_data(&mut self) { loop { - let mut num_packets_to_send = self.clone().await; - let mut guard = self.queue.lock().unwrap(); - while num_packets_to_send > 0 && !guard.queue.is_empty() { - let packet = guard.queue.pop_front().unwrap(); - if let Err(e) = self.socket.try_send_to(&packet.serialise(), self.dst_addr) { - eprintln!("Encountered error {} while sending packet across wire", e); - guard.queue.push_front(packet); - break; + let _ = self.data_future.clone().await; + let mut guard = self.data_future.data_queue.lock().unwrap(); + let bytes = guard.consume_data( + &mut self.next_packet_num, + self.src_conn_id, + self.dst_conn_id, + ); + match bytes { + None => continue, + Some(b) => { + if let Err(e) = self.socket.try_send_to(&b, self.dst_addr) { + eprintln!( + "Encountered error {} while sending data packet across wire", + e + ); + 
continue; + } } - num_packets_to_send -= 1; } guard.waker = None; } } } + +#[cfg(kani)] +mod verification_tests { + use crate::worker::writer::WriterQueue; + + #[kani::proof] + fn kani_writer_queue_consume_empty_data_behaves_as_expected() { + let mut writer_q = WriterQueue::new(); + let mut next_packet_num = kani::any(); + let prev = next_packet_num; + assert!(writer_q + .consume_data(&mut next_packet_num, kani::any(), kani::any()) + .is_none()); + assert_eq!(next_packet_num, prev); + } + + #[kani::proof] + fn kani_writer_queue_consume_empty_ack_behaves_as_expected() { + let mut writer_q = WriterQueue::new(); + assert!(writer_q.consume_acks(kani::any(), kani::any()).is_none()); + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::{ + core::{header::PacketType, packet::BluefinPacket}, + net::{ + MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM, MAX_BLUEFIN_PACKETS_IN_UDP_DATAGRAM, + MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, + }, + worker::writer::WriterQueue, + }; + + use super::WriterQueueData; + + #[rstest] + #[test] + #[case(550)] + #[case(1)] + #[case(10)] + #[case(760)] + fn writer_queue_consume_ack_for_one_datagram_behaves_as_expected(#[case] num_acks: usize) { + let expected_byte_size = num_acks * 20; + assert_ne!(expected_byte_size, 0); + assert!(expected_byte_size <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + + let mut writer_q = WriterQueue::new(); + for _ in 0..num_acks { + writer_q.queue.push_back(WriterQueueData::Ack { + base_packet_num: 1, + num_packets_consumed: 3, + }); + } + + let consume_res = writer_q.consume_acks(0xbcd, 0x521); + assert!(consume_res.is_some()); + + let consume = consume_res.unwrap(); + assert_eq!(consume.len(), expected_byte_size); + + // Deserialise to get the packets. 
+ let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + + let packets = packets_res.unwrap(); + assert_eq!(packets.len(), num_acks); + + for p in packets { + assert_eq!(p.len(), 20); + assert_eq!(p.header.type_field, PacketType::Ack); + assert_eq!(p.header.source_connection_id, 0xbcd); + assert_eq!(p.header.destination_connection_id, 0x521); + assert_eq!(p.header.type_specific_payload as usize, 3); + assert_eq!(p.header.packet_number, 1); + } + + // Because we are adding at most 1 datagram worth of acks, we get nothing more + assert!(writer_q.consume_acks(0x0, 0x0).is_none()); + } + + #[rstest] + #[test] + #[case(1000)] + #[case(761)] + #[case(1234)] + #[case(763)] + #[case(2000)] + fn writer_queue_consume_ack_for_multiple_datagrams_behaves_as_expected( + #[case] num_acks: usize, + ) { + let expected_byte_size = num_acks * 20; + let num_datagrams = expected_byte_size.div_ceil(MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + assert_ne!(expected_byte_size, 0); + assert!(expected_byte_size > MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + assert!(num_datagrams > 1 && num_datagrams <= 10); + + let mut writer_q = WriterQueue::new(); + for ix in 0..num_acks { + writer_q.queue.push_back(WriterQueueData::Ack { + base_packet_num: ix as u64, + num_packets_consumed: ix + 1, + }); + } + + let consume_res = writer_q.consume_acks(0xbcd, 0x521); + assert!(consume_res.is_some()); + + let consume = consume_res.unwrap(); + assert_eq!(consume.len(), MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + + // Deserialise to get the packets. 
+ let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + let packets = packets_res.unwrap(); + + let mut p_num = 0; + for (ix, p) in packets.iter().enumerate() { + assert!(p.len() <= 20); + assert_eq!(p.header.type_field, PacketType::Ack); + assert_eq!(p.header.source_connection_id, 0xbcd); + assert_eq!(p.header.destination_connection_id, 0x521); + assert_eq!(p.header.packet_number, ix as u64); + assert_eq!(p.header.type_specific_payload as usize, ix + 1); + p_num = ix; + } + assert!(p_num != 0); + + let mut actual_num_acks = 0; + actual_num_acks += packets.len(); + + let mut counter = 0; + let mut consume_res = writer_q.consume_acks(0x0, 0x0); + while counter <= 10 && consume_res.is_some() { + let consume = consume_res.unwrap(); + let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + let packets = packets_res.unwrap(); + for (ix, p) in packets.iter().enumerate() { + assert!(p.len() <= 20); + assert_eq!(p.header.type_field, PacketType::Ack); + assert_eq!(p.header.source_connection_id, 0x0); + assert_eq!(p.header.destination_connection_id, 0x0); + assert_eq!(p.header.packet_number, (ix + p_num + 1) as u64); + assert_eq!(p.header.type_specific_payload as usize, ix + p_num + 2); + } + p_num += packets.len(); + + actual_num_acks += packets.len(); + consume_res = writer_q.consume_acks(0x0, 0x0); + counter += 1; + } + assert_eq!(num_acks, actual_num_acks); + } + + #[rstest] + #[case(6, 550)] + #[case(20, 700)] + #[case(1, 10000)] + #[case(2, 5000)] + #[case(1, 15000)] + #[case(1, 1)] + #[case(10000, 1)] + #[case(5555, 1)] + #[case(5432, 2)] + #[case(100, 100)] + #[case(57, 57)] + #[case(55, 56)] + #[case(3, 2000)] + #[case(10, 123)] + #[test] + fn writer_queue_consume_data_for_one_datagram_behaves_as_expected( + #[case] num_iterations: usize, + #[case] payload_size: usize, + ) { + let payload_size_total = num_iterations * payload_size; + let num_packets_total = 
payload_size_total.div_ceil(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES); + let bytes_total = payload_size_total + (20 * num_packets_total); + assert!(bytes_total <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + + let mut writer_q = WriterQueue::new(); + for ix in 0..num_iterations { + let data = vec![ix as u8; payload_size]; + writer_q + .queue + .push_back(WriterQueueData::Payload(data.to_vec())); + } + + let mut next_packet_num = 0; + let src_conn_id = 0x123; + let dst_conn_id = 0xabc; + let consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + assert!(consume_res.is_some()); + + let consume = consume_res.unwrap(); + let bluefin_packets_created = + (num_iterations * payload_size).div_ceil(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES); + assert_ne!(bluefin_packets_created, 0); + // First next packet num was 0. If I created n packets then the last packet would have packet number n - 1. + // Therefore, the next packet num is n - 1 + 1 = n; + assert_eq!(next_packet_num, bluefin_packets_created as u64); + assert_eq!( + consume.len(), + num_iterations * payload_size + (20 * bluefin_packets_created) + ); + + // Deserialise them to get the packets. + let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + + let packets = packets_res.unwrap(); + assert_eq!(packets.len(), bluefin_packets_created); + + let mut payload_bytes = 0; + for p in packets { + assert!(p.len() <= MAX_BLUEFIN_PAYLOAD_SIZE_BYTES + 20); + assert_eq!(p.header.type_field, PacketType::UnencryptedData); + assert_eq!(p.header.source_connection_id, src_conn_id); + assert_eq!(p.header.destination_connection_id, dst_conn_id); + assert_eq!(p.header.type_specific_payload as usize, p.payload.len()); + payload_bytes += p.payload.len(); + } + + assert_eq!(payload_bytes, num_iterations * payload_size); + + // Since we added less than the max amount of bytes we can stuff in a datagram, one consume consumes + // all of the data. There should be nothing left. 
+ assert!(consume.len() <= MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + assert!(writer_q + .consume_data(&mut next_packet_num, 0x123, 0x456) + .is_none()); + } + + #[rstest] + #[case(30, 550)] + #[case(1, 15200)] + #[case(15200, 1)] + #[case(577, 43)] + #[case(2, 15200)] + #[case(5, 15200)] + #[case(1, 15001)] + #[case(87, 292)] + #[case(9, 15001)] + #[case(1, 150000)] + #[case(432, 234)] + #[test] + fn writer_queue_consume_data_for_multiple_datagram_behaves_as_expected( + #[case] num_iterations: usize, + #[case] payload_size: usize, + ) { + let payload_size_total = num_iterations * payload_size; + let num_packets_total = payload_size_total.div_ceil(MAX_BLUEFIN_PAYLOAD_SIZE_BYTES); + let bytes_total = payload_size_total + (20 * num_packets_total); + assert!(bytes_total > MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + + let num_datagrams = bytes_total.div_ceil(MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + assert!(num_datagrams >= 1 && num_datagrams <= 10); + + let mut expected_data = vec![]; + let mut writer_q = WriterQueue::new(); + for ix in 0..num_iterations { + let data = vec![ix as u8; payload_size]; + expected_data.extend_from_slice(&data); + writer_q + .queue + .push_back(WriterQueueData::Payload(data.to_vec())); + } + + let mut next_packet_num = 0; + let src_conn_id = 0x123; + let dst_conn_id = 0xabc; + let consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + assert!(consume_res.is_some()); + + let consume = consume_res.unwrap(); + assert_eq!(consume.len(), MAX_BLUEFIN_BYTES_IN_UDP_DATAGRAM); + + let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + + let packets = packets_res.unwrap(); + assert_eq!(packets.len(), MAX_BLUEFIN_PACKETS_IN_UDP_DATAGRAM); + + let mut actual_data = vec![]; + for p in packets { + assert!(p.len() <= MAX_BLUEFIN_PAYLOAD_SIZE_BYTES + 20); + assert_eq!(p.header.type_field, PacketType::UnencryptedData); + assert_eq!(p.header.source_connection_id, src_conn_id); + 
assert_eq!(p.header.destination_connection_id, dst_conn_id); + assert_eq!(p.header.type_specific_payload as usize, p.payload.len()); + actual_data.extend_from_slice(&p.payload); + } + + // Fetch the rest of the data. Our tests won't go beyond 10 datagrams worth of data + // so we assert the count here just in case. + let mut counter = 0; + let mut consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + while counter < 10 && consume_res.is_some() { + let consume = consume_res.as_ref().unwrap(); + assert_ne!(consume.len(), 0); + + let packets_res = BluefinPacket::from_bytes(&consume); + assert!(packets_res.is_ok()); + let packets = packets_res.unwrap(); + assert_ne!(packets.len(), 0); + for p in packets { + assert!(p.len() <= MAX_BLUEFIN_PAYLOAD_SIZE_BYTES + 20); + assert_eq!(p.header.type_field, PacketType::UnencryptedData); + assert_eq!(p.header.source_connection_id, src_conn_id); + assert_eq!(p.header.destination_connection_id, dst_conn_id); + assert_eq!(p.header.type_specific_payload as usize, p.payload.len()); + actual_data.extend_from_slice(&p.payload); + } + + counter += 1; + consume_res = writer_q.consume_data(&mut next_packet_num, src_conn_id, dst_conn_id); + } + assert_eq!(num_datagrams, 1 + counter); + assert_eq!(expected_data, actual_data); + } +} diff --git a/tests/basic/basic_handshake.rs b/tests/basic/basic_handshake.rs index e6b26ee..6aec79f 100644 --- a/tests/basic/basic_handshake.rs +++ b/tests/basic/basic_handshake.rs @@ -47,7 +47,7 @@ fn loopback_ip_addr() -> Ipv4Addr { } #[rstest] -#[timeout(Duration::from_secs(5))] +#[timeout(Duration::from_secs(15))] #[case(1318, 1319, 10)] #[case(1320, 1321, 100)] #[case(1322, 1323, 222)] @@ -67,6 +67,7 @@ async fn basic_server_client_connection_send_recv( .bind() .await .expect("Encountered error while binding server"); + let _ = server.set_num_reader_workers(20); let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( *loopback_ip_addr, @@ -78,7 +79,7 @@ 
async fn basic_server_client_connection_send_recv( let mut join_set = JoinSet::new(); join_set.spawn(async move { - let mut conn = timeout(Duration::from_secs(3), server.accept()) + let mut conn = timeout(Duration::from_secs(10), server.accept()) .await .expect("Server timed out waiting to accept connection from client") .expect("Failed to create bluefin connection"); @@ -162,8 +163,8 @@ async fn basic_server_client_connection_send_recv( .await .expect("Client timed out waiting to connect to server"); - // Wait for 250ms for the server to be ready - sleep(Duration::from_millis(250)).await; + // Wait for 100ms for the server to be ready + sleep(Duration::from_millis(100)).await; // Send TOTAL_NUM_BYTES_SENT across the wire let mut total_num_bytes_sent = 0; @@ -264,13 +265,11 @@ async fn basic_server_client_connection_send_recv( } #[rstest] -#[timeout(Duration::from_secs(5))] +#[timeout(Duration::from_secs(15))] #[tokio::test] async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &Ipv4Addr) { use std::sync::Arc; - use rand::Rng; - let mut server = BluefinServer::new(std::net::SocketAddr::V4(SocketAddrV4::new( *loopback_ip_addr, 1419, @@ -281,10 +280,9 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I .expect("Encountered error while binding server"); let mut join_set = JoinSet::new(); - const NUM_CONNECTIONS: usize = 10; + const NUM_CONNECTIONS: usize = 3; const MAX_BYTES_SENT_PER_CONNECTION: usize = 3200; - let client_ports: [u16; NUM_CONNECTIONS] = - [1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429]; + let client_ports: [u16; NUM_CONNECTIONS] = [1420, 1421, 1422]; let loopback_cloned = loopback_ip_addr.clone(); let data = Arc::new(generate_connection_date(NUM_CONNECTIONS)); @@ -292,7 +290,7 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I let mut s = server.clone(); let data_cloned = Arc::clone(&data); join_set.spawn(async move { - let mut conn = 
 timeout(Duration::from_secs(3), s.accept()) + let mut conn = timeout(Duration::from_secs(10), s.accept()) .await .expect(&format!( "Server #{} timed out waiting to accept connection from client", @@ -314,14 +312,14 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I let expected_data = data_cloned.get(key).expect("Could not fetch expected data"); let mut stitched_bytes: Vec = Vec::new(); + let mut buf = [0u8; 1500]; loop { - let mut buf = [0u8; 100]; - let size = timeout(Duration::from_secs(1), conn.recv(&mut buf, 100)) + let size = timeout(Duration::from_secs(1), conn.recv(&mut buf, 1500)) .await .expect("Server timed out while waiting for data") .expect("Server encountered error while receiving data"); assert_ne!(size, 0); - assert!(size <= 100); + assert!(size <= 1500); stitched_bytes.extend_from_slice(&buf[..size]); if stitched_bytes.len() == MAX_BYTES_SENT_PER_CONNECTION { @@ -332,12 +330,9 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I }); } for conn_num in 0..NUM_CONNECTIONS { - // Sleep for a random amount of time before sending data. This will add some variation - // in the order of processing. - let sleep_duration_in_ms = rand::thread_rng().gen_range(0..300); + // Clients connect immediately: the randomized pre-send sleep was removed let data_cloned = Arc::clone(&data); join_set.spawn(async move { - sleep(Duration::from_millis(sleep_duration_in_ms)).await; let mut client = BluefinClient::new(std::net::SocketAddr::V4(SocketAddrV4::new( loopback_cloned, client_ports[conn_num], @@ -362,7 +357,7 @@ async fn basic_server_client_multiple_connections_send_recv(loopback_ip_addr: &I .expect("Client encountered error while sending"); assert_eq!(size, 5); - sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(10)).await; // Now begin sending the actual data in batches of 32 bytes let mut total_bytes_sent = 5;