diff --git a/src/net/client.rs b/src/net/client.rs index 7b3d479..23f8f46 100644 --- a/src/net/client.rs +++ b/src/net/client.rs @@ -41,6 +41,17 @@ impl BluefinClient { } } + #[inline] + pub fn set_num_reader_workers(&mut self, num_reader_workers: u16) -> BluefinResult<()> { + if num_reader_workers == 0 { + return Err(BluefinError::Unexpected( + "Cannot have zero reader values".to_string(), + )); + } + self.num_reader_workers = num_reader_workers; + Ok(()) + } + pub async fn connect(&mut self, dst_addr: SocketAddr) -> BluefinResult { let socket = Arc::new(UdpSocket::bind(self.src_addr).await?); self.socket = Some(Arc::clone(&socket)); diff --git a/src/net/ordered_bytes.rs b/src/net/ordered_bytes.rs index c1d58e8..986492a 100644 --- a/src/net/ordered_bytes.rs +++ b/src/net/ordered_bytes.rs @@ -204,7 +204,6 @@ impl OrderedBytes { } } - let mut ix = 0; let base = self.smallest_packet_number_index; let base_packet_number = { if let Some(ref _p) = self.packets[base] { @@ -214,6 +213,7 @@ impl OrderedBytes { } }; + let mut ix = 0; while ix < MAX_BUFFER_SIZE && self.packets[(base + ix) % MAX_BUFFER_SIZE].is_some() && num_bytes < len @@ -257,3 +257,186 @@ impl OrderedBytes { Ok(ConsumeResult::new(ix, base_packet_number, num_bytes as u64)) } } + +#[cfg(test)] +mod tests { + use crate::{ + core::{ + error::BluefinError, + header::{BluefinHeader, BluefinSecurityFields, PacketType}, + packet::BluefinPacket, + }, + net::MAX_BLUEFIN_PAYLOAD_SIZE_BYTES, + }; + + use super::OrderedBytes; + + #[test] + fn ordered_bytes_carry_over_behaves_as_expected() { + let start_packet_num = rand::random(); + let mut ordered_bytes = OrderedBytes::new(0x0, start_packet_num); + + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + // Buffer in one packet with payload of 1500 bytes + let mut payload = vec![]; + while payload.len() != MAX_BLUEFIN_PAYLOAD_SIZE_BYTES { + let r: [u8; 15] = rand::random(); + payload.extend(r); + } + + let security_fields = BluefinSecurityFields::new(false, 0x0); + let mut header = + BluefinHeader::new(0x0, 0x0, PacketType::UnencryptedData, 0, security_fields); + header.packet_number = start_packet_num; + let packet = BluefinPacket::builder() + .header(header) + .payload(payload.clone()) + .build(); + assert!(ordered_bytes.buffer_in_packet(packet).is_ok()); + + let mut buf = [0u8; 100]; + let consume_res = ordered_bytes.consume(100, &mut buf); + assert!(consume_res.is_ok()); + + // Consumed 100 bytes. This means 1500 - 100 = 1400 bytes are buffered in the left-over + // bytes buffer + let consume = consume_res.unwrap(); + assert_eq!(consume.base_packet_number, start_packet_num); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 100); + assert_eq!(payload[..100], buf[..100]); + + // Insert another packet with 1500 bytes + let mut second_payload = vec![]; + while second_payload.len() != MAX_BLUEFIN_PAYLOAD_SIZE_BYTES { + let r: [u8; 15] = rand::random(); + second_payload.extend(r); + } + header.packet_number = start_packet_num + 1; + let packet = BluefinPacket::builder() + .header(header) + .payload(second_payload.clone()) + .build(); + assert!(ordered_bytes.buffer_in_packet(packet).is_ok()); + + // Consume another 100 bytes. These 100 bytes should still come from the first payload. + let consume_res = ordered_bytes.consume(100, &mut buf); + assert!(consume_res.is_ok()); + + // We now have 1400 - 100 = 1300 bytes left in the carry over. + let consume = consume_res.unwrap(); + // Base packet number should be zero since it's all coming from the carry over + assert_eq!(consume.base_packet_number, 0); + assert_eq!(consume.num_packets_consumed, 0); + assert_eq!(consume.bytes_consumed, 100); + assert_eq!(payload[100..200], buf[..100]); + + // Concume 1400 bytes. + let mut buf = [0u8; 1400]; + let consume_res = ordered_bytes.consume(1400, &mut buf); + assert!(consume_res.is_ok()); + + // 1300 of these bytes come from the carry over. The remaining 100 bytes are from the second + // packet we inserted + let consume = consume_res.unwrap(); + assert_eq!(consume.base_packet_number, start_packet_num + 1); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 1400); + assert_eq!(payload[200..], buf[..1300]); + assert_eq!(second_payload[..100], buf[1300..]); + } + + #[test] + fn ordered_bytes_consume_behaves_as_expected() { + let start_packet_num = rand::random(); + let mut ordered_bytes = OrderedBytes::new(0x0, start_packet_num); + + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + let security_fields = BluefinSecurityFields::new(false, 0x0); + let mut header = + BluefinHeader::new(0x0, 0x0, PacketType::UnencryptedData, 0, security_fields); + header.packet_number = start_packet_num + 1; + let mut packet = BluefinPacket::builder() + .header(header) + .payload([1, 2, 3].to_vec()) + .build(); + + assert!(ordered_bytes.buffer_in_packet(packet.clone()).is_ok()); + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + packet.header.packet_number = start_packet_num + 2; + assert!(ordered_bytes.buffer_in_packet(packet.clone()).is_ok()); + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + packet.header.packet_number = start_packet_num + 3; + assert!(ordered_bytes.buffer_in_packet(packet.clone()).is_ok()); + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + packet.header.packet_number = start_packet_num + 5; + assert!(ordered_bytes.buffer_in_packet(packet.clone()).is_ok()); + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + + packet.header.packet_number = start_packet_num; + assert!(ordered_bytes.buffer_in_packet(packet.clone()).is_ok()); + assert!(ordered_bytes.peek().is_ok()); + + let mut buf = [0u8; 10]; + let consume_res = ordered_bytes.consume(1, &mut buf); + assert!(consume_res.is_ok()); + + let consume = consume_res.unwrap(); + assert_eq!(consume.base_packet_number, start_packet_num); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 1); + assert_eq!(buf, [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + // From carry over, 0 packets + let consume_res = ordered_bytes.consume(1, &mut buf); + assert!(consume_res.is_ok()); + let consume = consume_res.unwrap(); + assert_eq!(consume.num_packets_consumed, 0); + assert_eq!(consume.bytes_consumed, 1); + assert_eq!(buf, [2, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + let consume_res = ordered_bytes.consume(3, &mut buf); + assert!(consume_res.is_ok()); + let consume = consume_res.unwrap(); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 3); + assert_eq!(buf, [3, 1, 2, 0, 0, 0, 0, 0, 0, 0]); + + let consume_res = ordered_bytes.consume(4, &mut buf); + assert!(consume_res.is_ok()); + let consume = consume_res.unwrap(); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 4); + assert_eq!(buf, [3, 1, 2, 3, 0, 0, 0, 0, 0, 0]); + + let mut buf = [0u8; 10]; + let consume_res = ordered_bytes.consume(10, &mut buf); + assert!(consume_res.is_ok()); + let consume = consume_res.unwrap(); + assert_eq!(consume.num_packets_consumed, 1); + assert_eq!(consume.bytes_consumed, 3); + assert_eq!(buf, [1, 2, 3, 0, 0, 0, 0, 0, 0, 0]); + + assert!(ordered_bytes + .peek() + .is_err_and(|e| e == BluefinError::BufferEmptyError)); + assert!(ordered_bytes.consume(1, &mut buf).is_err()); + } +} diff --git a/src/worker/reader.rs b/src/worker/reader.rs index 9dd19b8..33f8254 100644 --- a/src/worker/reader.rs +++ b/src/worker/reader.rs @@ -95,7 +95,7 @@ impl ReaderRxChannel { let base_packet_num = consume_res.get_base_packet_number(); // We need to send an ack. - if num_packets_consumed > 0 { + if num_packets_consumed > 0 && base_packet_num != 0 { if let Err(e) = self .writer_tx_channel .send_ack(base_packet_num, num_packets_consumed) @@ -232,11 +232,9 @@ impl ReaderTxChannel { if !is_client_ack && !is_hello && packet.header.type_field == PacketType::Ack { let mut ack_buff = buffers.ack_buff.lock().unwrap(); Self::buffer_to_ack_buffer(&mut ack_buff, packet)?; - drop(ack_buff); } else { let mut conn_buff = buffers.conn_buff.lock().unwrap(); Self::buffer_to_conn_buffer(&mut conn_buff, packet, addr, is_hello, is_client_ack)?; - drop(conn_buff); } Ok(()) } diff --git a/src/worker/writer.rs b/src/worker/writer.rs index a11f9b8..6305a8f 100644 --- a/src/worker/writer.rs +++ b/src/worker/writer.rs @@ -248,7 +248,7 @@ impl WriterTxChannel { } self.num_runs_without_sleep += 1; - if self.num_runs_without_sleep >= 137 { + if self.num_runs_without_sleep >= 100 { sleep(Duration::from_nanos(10)).await; self.num_runs_without_sleep = 0; }