From 398a26d603457a5d1d89c908548275e30c387e0b Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 28 Nov 2024 15:57:04 +0100 Subject: [PATCH 1/4] Only mark timers to update Timers are updated on every single packet call, but their accuracy is limited to 250ms where currentTime is updated. So just mark timers to update, and update them only once in update_timers. --- neptun/src/noise/mod.rs | 22 ++++----- neptun/src/noise/timers.rs | 94 +++++++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 13 deletions(-) diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 5b91302..5e7269a 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -276,10 +276,10 @@ impl Tunn { if let Some(ref session) = self.sessions[current % N_SESSIONS] { // Send the packet using an established session let packet = session.format_packet_data(src, dst); - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); // Exclude Keepalive packets from timer update. if !src.is_empty() { - self.timer_tick(TimerName::TimeLastDataPacketSent); + self.mark_timer_to_update(TimerName::TimeLastDataPacketSent); } self.tx_bytes += packet.len(); return TunnResult::WriteToNetwork(packet); @@ -365,8 +365,8 @@ impl Tunn { let index = session.local_index(); self.sessions[index % N_SESSIONS] = Some(session); - self.timer_tick(TimerName::TimeLastPacketReceived); - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); self.timer_tick_session_established(false, index); // New session established, we are not the initiator tracing::debug!(message = "Sending handshake_response", local_idx = index); @@ -399,7 +399,7 @@ impl Tunn { let index = l_idx % N_SESSIONS; self.sessions[index] = Some(session); - self.timer_tick(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); self.timer_tick_session_established(true, index); // New session established, we are the initiator self.set_current_session(l_idx); @@ -424,8 +424,8 @@ impl Tunn { // Increase the rx_bytes accordingly self.rx_bytes += COOKIE_REPLY_SZ; - self.timer_tick(TimerName::TimeLastPacketReceived); - self.timer_tick(TimerName::TimeCookieReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeCookieReceived); tracing::debug!("Did set cookie"); @@ -469,7 +469,7 @@ impl Tunn { self.set_current_session(r_idx); - self.timer_tick(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); Ok(self.validate_decapsulated_packet(decapsulated_packet)) } @@ -496,9 +496,9 @@ impl Tunn { tracing::debug!("Sending handshake_initiation"); if starting_new_handshake { - self.timer_tick(TimerName::TimeLastHandshakeStarted); + self.mark_timer_to_update(TimerName::TimeLastHandshakeStarted); } - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); self.tx_bytes += packet.len(); TunnResult::WriteToNetwork(packet) @@ -548,7 +548,7 @@ impl Tunn { return TunnResult::Err(WireGuardError::InvalidPacket); } - self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastDataPacketReceived); self.rx_bytes += message_data_len(computed_len); match src_ip_address { diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 07b4bfd..6cce8df 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -6,6 +6,7 @@ use super::errors::WireGuardError; use crate::noise::{safe_duration::SafeDuration as Duration, Tunn, TunnResult}; use std::mem; use std::ops::{Index, IndexMut}; +use std::sync::atomic::AtomicU16; use std::time::SystemTime; #[cfg(feature = "mock-instant")] @@ -42,7 +43,7 @@ pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5); const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120); -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum TimerName { /// Current time, updated each call to `update_timers` TimeCurrent, @@ -65,6 +66,20 @@ pub enum TimerName { Top, } +impl TimerName { + pub const VALUES: [Self; TimerName::Top as usize] = [ + Self::TimeCurrent, + Self::TimeSessionEstablished, + Self::TimeLastHandshakeStarted, + Self::TimeLastPacketReceived, + Self::TimeLastPacketSent, + Self::TimeLastDataPacketReceived, + Self::TimeLastDataPacketSent, + Self::TimeCookieReceived, + Self::TimePersistentKeepalive, + ]; +} + use self::TimerName::*; #[derive(Debug)] @@ -82,6 +97,7 @@ pub struct Timers { persistent_keepalive: usize, /// Should this timer call reset rr function (if not a shared rr instance) pub(super) should_reset_rr: bool, + timers_to_update_mask: AtomicU16, } impl Timers { @@ -95,6 +111,7 @@ impl Timers { want_handshake_since: Default::default(), persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)), should_reset_rr: reset_rr, + timers_to_update_mask: Default::default(), } } @@ -128,7 +145,13 @@ impl IndexMut for Timers { } impl Tunn { - pub(super) fn timer_tick(&mut self, timer_name: TimerName) { + pub(super) fn mark_timer_to_update(&self, timer_name: TimerName) { + self.timers + .timers_to_update_mask + .fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed); + } + + fn timer_tick(&mut self, timer_name: TimerName) { let time = self.timers[TimeCurrent]; match timer_name { TimeLastPacketReceived => { @@ -207,6 +230,21 @@ impl Tunn { let now = time.duration_since(self.timers.time_started).into(); self.timers[TimeCurrent] = now; + // Check which timers to update, and update them + let timer_mask = self + .timers + .timers_to_update_mask + .load(std::sync::atomic::Ordering::Relaxed); + for timer_name in TimerName::VALUES { + if (timer_mask & (1 << (timer_name as u16))) != 0 { + self.timer_tick(timer_name); + } + } + // Reset all marked bits + self.timers + .timers_to_update_mask + .store(0, std::sync::atomic::Ordering::Relaxed); + self.update_session_timers(now); // Load timers only once: @@ -380,3 +418,55 @@ impl Tunn { self.timers.persistent_keepalive = keepalive as usize; } } + +#[cfg(test)] +mod tests { + use rand::RngCore; + use rand_core::OsRng; + + use crate::noise::{safe_duration::SafeDuration, Tunn}; + + use super::TimerName; + + #[test] + fn create_two_tuns() { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + + let mut my_tun = + Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap(); + + // Mark timers to update + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); + my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive); + + // Update timers + my_tun.update_timers(&mut [0]); + + // Only those timers marked should be udpated + assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + assert!(!my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); + + // Unmarked timers should still be 0 + assert!(my_tun.timers[TimerName::TimeCookieReceived].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastHandshakeStarted].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastPacketReceived].is_zero()); + + // Reset the timers + my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); + my_tun.timers[TimerName::TimeLastDataPacketReceived] = SafeDuration::from_millis(0); + my_tun.timers[TimerName::TimePersistentKeepalive] = SafeDuration::from_millis(0); + + my_tun.update_timers(&mut [0]); + + // Now the timers should not update + assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + assert!(my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); + } +} From 4ba9f5e7a11294dd41b3ac55016fa546045cf913 Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 5 Dec 2024 11:32:33 +0100 Subject: [PATCH 2/4] Add test + refactor --- neptun/src/noise/timers.rs | 62 ++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 6cce8df..a231128 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -67,8 +67,7 @@ pub enum TimerName { } impl TimerName { - pub const VALUES: [Self; TimerName::Top as usize] = [ - Self::TimeCurrent, + pub const VALUES: [Self; TimerName::Top as usize - 1] = [ Self::TimeSessionEstablished, Self::TimeLastHandshakeStarted, Self::TimeLastPacketReceived, @@ -215,6 +214,14 @@ impl Tunn { } } + fn tick_marked_timers(&mut self, timer_mask: u16) { + for timer_name in TimerName::VALUES { + if (timer_mask & (1 << (timer_name as u16))) != 0 { + self.timer_tick(timer_name); + } + } + } + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { let mut handshake_initiation_required = false; let mut keepalive_required = false; @@ -234,16 +241,8 @@ impl Tunn { let timer_mask = self .timers .timers_to_update_mask - .load(std::sync::atomic::Ordering::Relaxed); - for timer_name in TimerName::VALUES { - if (timer_mask & (1 << (timer_name as u16))) != 0 { - self.timer_tick(timer_name); - } - } - // Reset all marked bits - self.timers - .timers_to_update_mask - .store(0, std::sync::atomic::Ordering::Relaxed); + .swap(0, std::sync::atomic::Ordering::Relaxed); + self.tick_marked_timers(timer_mask); self.update_session_timers(now); @@ -429,7 +428,7 @@ mod tests { use super::TimerName; #[test] - fn create_two_tuns() { + fn test_update_marked_timers() { let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); let my_idx = OsRng.next_u32(); @@ -469,4 +468,41 @@ mod tests { assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); assert!(my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); } + + #[test] + fn test_mark_timers_during_update() { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + + let mut my_tun = + Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap(); + + // Mark timers to update + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); + + let timer_mask = my_tun + .timers + .timers_to_update_mask + .swap(0, std::sync::atomic::Ordering::Relaxed); + + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); + + my_tun.tick_marked_timers(timer_mask); + + // Only those timers marked should be udpated + assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + + // Reset the timers + my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); + + my_tun.update_timers(&mut [0]); + + // Now the timers should not update + assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + } } From 082c0653e0ecea480d80d31411738dc0d31d17a4 Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 5 Dec 2024 19:49:26 +0100 Subject: [PATCH 3/4] Encrypt in-place --- neptun/src/device/mod.rs | 12 +++++++----- neptun/src/noise/mod.rs | 25 +++++++++++++++---------- neptun/src/noise/session.rs | 22 ++++++++++++---------- neptun/src/noise/timers.rs | 2 +- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index 0913aaa..3315df5 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -40,6 +40,7 @@ use std::thread; use crate::noise::errors::WireGuardError; use crate::noise::handshake::parse_handshake_anon; use crate::noise::rate_limiter::RateLimiter; +use crate::noise::DATA_OFFSET; use crate::noise::{Packet, Tunn, TunnResult}; use crate::x25519; use allowed_ips::AllowedIps; @@ -1025,8 +1026,9 @@ impl Device { let peers = &d.peers_by_ip; for _ in 0..MAX_ITR { - let src = match iface.read(&mut t.src_buf[..mtu]) { - Ok(src) => src, + let src_buf = &mut t.src_buf[DATA_OFFSET..]; + let src_len = match iface.read(&mut src_buf[..mtu]) { + Ok(src) => src.len(), Err(Error::IfaceRead(e)) => { let ek = e.kind(); if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock { @@ -1045,7 +1047,7 @@ impl Device { } }; - let dst_addr = match Tunn::dst_address(src) { + let dst_addr = match Tunn::dst_address(&src_buf[..src_len]) { Some(addr) => addr, None => continue, }; @@ -1057,14 +1059,14 @@ impl Device { if let Some(callback) = &d.config.firewall_process_outbound_callback { - if !callback(&peer.public_key.0, src) { + if !callback(&peer.public_key.0, &src_buf[..src_len]) { continue; } } let res = { let mut tun = peer.tunnel.lock(); - tun.encapsulate(src, &mut t.dst_buf[..]) + tun.encapsulate(&mut t.src_buf[..], src_len) }; match res { TunnResult::Done => {} diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 5e7269a..9745c6e 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -46,6 +46,9 @@ const MAX_QUEUE_DEPTH: usize = 256; /// number of sessions in the ring, better keep a PoT const N_SESSIONS: usize = 8; +/// Where encrypted data resides in a data packet +pub const DATA_OFFSET: usize = 16; + #[derive(Debug)] pub enum TunnResult<'a> { Done, @@ -271,29 +274,29 @@ impl Tunn { /// # Panics /// Panics if dst buffer is too small. /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. - pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + pub fn encapsulate<'a>(&mut self, buffer: &'a mut [u8], data_len: usize) -> TunnResult<'a> { let current = self.current; if let Some(ref session) = self.sessions[current % N_SESSIONS] { // Send the packet using an established session - let packet = session.format_packet_data(src, dst); + let packet = session.format_packet_data(buffer, data_len); self.mark_timer_to_update(TimerName::TimeLastPacketSent); // Exclude Keepalive packets from timer update. - if !src.is_empty() { + if !data_len != 0 { self.mark_timer_to_update(TimerName::TimeLastDataPacketSent); } self.tx_bytes += packet.len(); return TunnResult::WriteToNetwork(packet); } - if !src.is_empty() { + if !data_len != 0 { // If there is no session, queue the packet for future retry, // except if it's keepalive packet, new keepalive packets will be sent when session is created. // This prevents double keepalive packets on initiation - self.queue_packet(src); + self.queue_packet(&buffer[..data_len]); } // Initiate a new handshake if none is in progress - self.format_handshake_initiation(dst, false) + self.format_handshake_initiation(buffer, false) } /// Receives a UDP datagram from the network and parses it. @@ -393,7 +396,7 @@ impl Tunn { // Increase the rx_bytes accordingly self.rx_bytes += HANDSHAKE_RESP_SZ; - let keepalive_packet = session.format_packet_data(&[], dst); + let keepalive_packet = session.format_packet_data(dst, 0); // Store new session in ring buffer let l_idx = session.local_index(); let index = l_idx % N_SESSIONS; @@ -560,7 +563,9 @@ impl Tunn { /// Get a packet from the queue, and try to encapsulate it fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { if let Some(packet) = self.dequeue_packet() { - match self.encapsulate(&packet, dst) { + let len = packet.len(); + dst[..len].copy_from_slice(&packet); + match self.encapsulate(dst, len) { TunnResult::Err(_) => { // On error, return packet to the queue self.requeue_packet(packet); @@ -824,9 +829,9 @@ mod tests { let mut my_dst = [0u8; 1024]; let mut their_dst = [0u8; 1024]; - let sent_packet_buf = create_ipv4_udp_packet(); + let mut sent_packet_buf = create_ipv4_udp_packet(); - let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + let data = my_tun.encapsulate(&mut sent_packet_buf, sent_packet_buf.len()); assert!(matches!(data, TunnResult::WriteToNetwork(_))); let data = if let TunnResult::WriteToNetwork(sent) = data { sent diff --git a/neptun/src/noise/session.rs b/neptun/src/noise/session.rs index 3aa6408..a1d6cd3 100644 --- a/neptun/src/noise/session.rs +++ b/neptun/src/noise/session.rs @@ -3,6 +3,7 @@ // SPDX-License-Identifier: BSD-3-Clause use super::PacketData; +use super::DATA_OFFSET; use crate::noise::errors::WireGuardError; use parking_lot::Mutex; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; @@ -27,8 +28,6 @@ impl std::fmt::Debug for Session { } } -/// Where encrypted data resides in a data packet -const DATA_OFFSET: usize = 16; /// The overhead of the AEAD const AEAD_SIZE: usize = 16; @@ -194,14 +193,18 @@ impl Session { /// src - an IP packet from the interface /// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network /// returns the size of the formatted packet - pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] { - if dst.len() < src.len() + super::DATA_OVERHEAD_SZ { + pub(super) fn format_packet_data<'a>( + &self, + buffer: &'a mut [u8], + data_len: usize, + ) -> &'a mut [u8] { + if buffer.len() < data_len + super::DATA_OVERHEAD_SZ { panic!("The destination buffer is too small"); } let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64; - let (message_type, rest) = dst.split_at_mut(4); + let (message_type, rest) = buffer.split_at_mut(4); let (receiver_index, rest) = rest.split_at_mut(4); let (counter, data) = rest.split_at_mut(8); @@ -213,21 +216,20 @@ impl Session { let n = { let mut nonce = [0u8; 12]; nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); - data[..src.len()].copy_from_slice(src); self.sender .seal_in_place_separate_tag( Nonce::assume_unique_for_key(nonce), Aad::from(&[]), - &mut data[..src.len()], + &mut data[..data_len], ) .map(|tag| { - data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref()); - src.len() + AEAD_SIZE + data[data_len..data_len + AEAD_SIZE].copy_from_slice(tag.as_ref()); + data_len + AEAD_SIZE }) .unwrap() }; - &mut dst[..DATA_OFFSET + n] + &mut buffer[..DATA_OFFSET + n] } /// packet - a data packet we received from the network diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index a231128..1ac0550 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -374,7 +374,7 @@ impl Tunn { } if keepalive_required { - return self.encapsulate(&[], dst); + return self.encapsulate(dst, 0); } TunnResult::Done From fa0f75f1f66a79edd57acd0816065a95ff97ba6c Mon Sep 17 00:00:00 2001 From: Hasan Date: Mon, 9 Dec 2024 20:47:04 +0100 Subject: [PATCH 4/4] Decrypt packets in place --- neptun/src/device/mod.rs | 26 +++-- neptun/src/noise/integration_tests/mod.rs | 11 +- neptun/src/noise/mod.rs | 123 +++++++++++++--------- neptun/src/noise/rate_limiter.rs | 19 +++- neptun/src/noise/session.rs | 9 +- xtask/src/perf.rs | 2 +- 6 files changed, 108 insertions(+), 82 deletions(-) diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index 3315df5..1e8a146 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -35,6 +35,7 @@ use std::os::fd::RawFd; use std::os::unix::io::AsRawFd; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))] use std::thread; use crate::noise::errors::WireGuardError; @@ -775,10 +776,9 @@ impl Device { let src_buf = unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit]) }; while let Ok((packet_len, addr)) = udp.recv_from(src_buf) { - let packet = &t.src_buf[..packet_len]; // The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie let parsed_packet = - match rate_limiter.verify_packet(Some(addr.as_socket().unwrap().ip()), packet, &mut t.dst_buf) { + match rate_limiter.verify_packet(Some(addr.as_socket().unwrap().ip()), &mut t.src_buf, packet_len, &mut t.dst_buf) { Ok(packet) => packet, Err(TunnResult::WriteToNetwork(cookie)) => { if let Err(err) = udp.send_to(cookie, &addr) { @@ -866,7 +866,7 @@ impl Device { loop { let res = { let mut tun = peer.tunnel.lock(); - tun.decapsulate(None, &[], &mut t.dst_buf[..]) + tun.decapsulate(None, &mut [], 0, &mut t.dst_buf[..]) }; let TunnResult::WriteToNetwork(packet) = res else { @@ -926,11 +926,7 @@ impl Device { let res = { let mut tun = peer.tunnel.lock(); - tun.decapsulate( - Some(peer_addr), - &t.src_buf[..read_bytes], - &mut t.dst_buf[..], - ) + tun.decapsulate(Some(peer_addr), &mut t.src_buf, read_bytes, &mut t.dst_buf) }; match res { @@ -987,7 +983,7 @@ impl Device { loop { let res = { let mut tun = peer.tunnel.lock(); - tun.decapsulate(None, &[], &mut t.dst_buf[..]) + tun.decapsulate(None, &mut [], 0, &mut t.dst_buf[..]) }; let TunnResult::WriteToNetwork(packet) = res else { break; @@ -1026,8 +1022,8 @@ impl Device { let peers = &d.peers_by_ip; for _ in 0..MAX_ITR { - let src_buf = &mut t.src_buf[DATA_OFFSET..]; - let src_len = match iface.read(&mut src_buf[..mtu]) { + let data_buf = &mut t.src_buf[DATA_OFFSET..]; + let data_len = match iface.read(&mut data_buf[..mtu]) { Ok(src) => src.len(), Err(Error::IfaceRead(e)) => { let ek = e.kind(); @@ -1047,7 +1043,7 @@ impl Device { } }; - let dst_addr = match Tunn::dst_address(&src_buf[..src_len]) { + let dst_addr = match Tunn::dst_address(&data_buf[..data_len]) { Some(addr) => addr, None => continue, }; @@ -1059,14 +1055,16 @@ impl Device { if let Some(callback) = &d.config.firewall_process_outbound_callback { - if !callback(&peer.public_key.0, &src_buf[..src_len]) { + if !callback(&peer.public_key.0, &data_buf[..data_len]) { continue; } } let res = { let mut tun = peer.tunnel.lock(); - tun.encapsulate(&mut t.src_buf[..], src_len) + // Pass complete buffer as it contains space for headers as well + // Encryption is to be done in-place + tun.encapsulate(&mut t.src_buf[..], data_len) }; match res { TunnResult::Done => {} diff --git a/neptun/src/noise/integration_tests/mod.rs b/neptun/src/noise/integration_tests/mod.rs index f3558cc..68d7999 100644 --- a/neptun/src/noise/integration_tests/mod.rs +++ b/neptun/src/noise/integration_tests/mod.rs @@ -74,7 +74,7 @@ mod tests { let mut receiving_buffer = vec![0u8; MAX_PACKET]; // Initiate handshake from a side - match a.tunnel.lock().encapsulate(&[], &mut sending_buffer) { + match a.tunnel.lock().encapsulate(&mut sending_buffer, 0) { TunnResult::WriteToNetwork(msg) => { a.client_socket .send_to(msg, b.client_address) @@ -95,7 +95,8 @@ mod tests { match b.tunnel.lock().decapsulate( None, - &receiving_buffer[..bytes_read], + &mut receiving_buffer, + bytes_read, &mut sending_buffer, ) { TunnResult::WriteToNetwork(msg) => { @@ -118,7 +119,8 @@ mod tests { match a.tunnel.lock().decapsulate( None, - &receiving_buffer[..bytes_read], + &mut receiving_buffer, + bytes_read, &mut sending_buffer, ) { TunnResult::WriteToNetwork(msg) => { @@ -142,7 +144,8 @@ mod tests { match b.tunnel.lock().decapsulate( None, - &receiving_buffer[..bytes_read], + &mut receiving_buffer, + bytes_read, &mut sending_buffer, ) { TunnResult::Done => (), diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 9745c6e..a7e3f55 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -96,6 +96,7 @@ const DATA_OVERHEAD_SZ: usize = 32; #[derive(Debug)] pub struct HandshakeInit<'a> { + msg_buffer: &'a [u8], sender_idx: u32, unencrypted_ephemeral: &'a [u8; 32], encrypted_static: &'a [u8], @@ -104,6 +105,7 @@ pub struct HandshakeInit<'a> { #[derive(Debug)] pub struct HandshakeResponse<'a> { + msg_buffer: &'a [u8], sender_idx: u32, pub receiver_idx: u32, unencrypted_ephemeral: &'a [u8; 32], @@ -121,7 +123,8 @@ pub struct PacketCookieReply<'a> { pub struct PacketData<'a> { pub receiver_idx: u32, counter: u64, - encrypted_encapsulated_packet: &'a [u8], + encrypted_packet_buffer: &'a mut [u8], + data_len: usize, } /// Describes a packet from network @@ -139,16 +142,20 @@ impl Tunn { } #[inline(always)] - pub fn parse_incoming_packet(src: &[u8]) -> Result { + pub fn parse_incoming_packet<'a>( + src: &'a mut [u8], + data_len: usize, + ) -> Result, WireGuardError> { if src.len() < 4 { return Err(WireGuardError::InvalidPacket); } // Checks the type, as well as the reserved zero fields let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); - - Ok(match (packet_type, src.len()) { + Ok(match (packet_type, data_len) { (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { + // Keeping a reference to entire buffer for verifying macs + msg_buffer: &src[..data_len], sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40]) .expect("length already checked above"), @@ -156,6 +163,8 @@ impl Tunn { encrypted_timestamp: &src[88..116], }), (HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse { + // Keeping a reference to entire buffer for verifying macs + msg_buffer: &src[..data_len], sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()), unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44]) @@ -170,7 +179,8 @@ impl Tunn { (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), - encrypted_encapsulated_packet: &src[16..], + encrypted_packet_buffer: &mut src[16..], + data_len: data_len - 16, }), _ => return Err(WireGuardError::InvalidPacket), }) @@ -308,42 +318,44 @@ impl Tunn { pub fn decapsulate<'a>( &mut self, src_addr: Option, - datagram: &[u8], + data_buffer: &'a mut [u8], + data_len: usize, dst: &'a mut [u8], ) -> TunnResult<'a> { - if datagram.is_empty() { + if data_len == 0 { // Indicates a repeated call return self.send_queued_packet(dst); } let mut cookie = [0u8; COOKIE_REPLY_SZ]; - let packet = match self - .rate_limiter - .verify_packet(src_addr, datagram, &mut cookie) - { - Ok(packet) => packet, - Err(TunnResult::WriteToNetwork(cookie)) => { - dst[..cookie.len()].copy_from_slice(cookie); - self.tx_bytes += cookie.len(); - return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); - } - Err(TunnResult::Err(e)) => return TunnResult::Err(e), - _ => unreachable!(), - }; + let packet = + match self + .rate_limiter + .verify_packet(src_addr, data_buffer, data_len, &mut cookie) + { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + dst[..cookie.len()].copy_from_slice(cookie); + self.tx_bytes += cookie.len(); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); + } + Err(TunnResult::Err(e)) => return TunnResult::Err(e), + _ => unreachable!(), + }; self.handle_verified_packet(packet, dst) } pub(crate) fn handle_verified_packet<'a>( &mut self, - packet: Packet, + packet: Packet<'a>, dst: &'a mut [u8], ) -> TunnResult<'a> { match packet { Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), - Packet::PacketData(p) => self.handle_data(p, dst), + Packet::PacketData(p) => self.handle_data(p), } .unwrap_or_else(TunnResult::from) } @@ -454,8 +466,7 @@ impl Tunn { /// Decrypts a data packet, and stores the decapsulated packet in dst. fn handle_data<'a>( &mut self, - packet: PacketData, - dst: &'a mut [u8], + packet: PacketData<'a>, ) -> Result, WireGuardError> { let r_idx = packet.receiver_idx as usize; let idx = r_idx % N_SESSIONS; @@ -467,7 +478,7 @@ impl Tunn { tracing::trace!(message = "No current session available", remote_idx = r_idx); WireGuardError::NoCurrentSession })?; - session.receive_packet_data(packet, dst)? + session.receive_packet_data(packet)? }; self.set_current_session(r_idx); @@ -679,9 +690,9 @@ mod tests { handshake_init.into() } - fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec { + fn create_handshake_response(tun: &mut Tunn, handshake_init: &mut [u8]) -> Vec { let mut dst = vec![0u8; 2048]; - let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst); + let handshake_resp = tun.decapsulate(None, handshake_init, handshake_init.len(), &mut dst); assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_))); let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp { @@ -693,9 +704,9 @@ mod tests { handshake_resp.into() } - fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { + fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &mut [u8]) -> Vec { let mut dst = vec![0u8; 2048]; - let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + let keepalive = tun.decapsulate(None, handshake_resp, handshake_resp.len(), &mut dst); assert!(matches!(keepalive, TunnResult::WriteToNetwork(_))); let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { @@ -707,18 +718,18 @@ mod tests { keepalive.into() } - fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) { + fn parse_keepalive(tun: &mut Tunn, keepalive: &mut [u8]) { let mut dst = vec![0u8; 2048]; - let keepalive = tun.decapsulate(None, keepalive, &mut dst); + let keepalive = tun.decapsulate(None, keepalive, 0, &mut dst); assert!(matches!(keepalive, TunnResult::Done)); } fn create_two_tuns_and_handshake() -> (Tunn, Tunn) { let (mut my_tun, mut their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let keepalive = parse_handshake_resp(&mut my_tun, &resp); - parse_keepalive(&mut their_tun, &keepalive); + let mut init = create_handshake_init(&mut my_tun); + let mut resp = create_handshake_response(&mut their_tun, &mut init); + let mut keepalive = parse_handshake_resp(&mut my_tun, &mut resp); + parse_keepalive(&mut their_tun, &mut keepalive); (my_tun, their_tun) } @@ -742,7 +753,8 @@ mod tests { } else { unreachable!(); }; - let packet = Tunn::parse_incoming_packet(packet_data).unwrap(); + let len = packet_data.len(); + let packet = Tunn::parse_incoming_packet(packet_data, len).unwrap(); assert!(matches!(packet, Packet::HandshakeInit(_))); } @@ -754,27 +766,30 @@ mod tests { #[test] fn handshake_init() { let (mut my_tun, _their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let packet = Tunn::parse_incoming_packet(&init).unwrap(); + let mut init = create_handshake_init(&mut my_tun); + let init_len = init.len(); + let packet = Tunn::parse_incoming_packet(&mut init, init_len).unwrap(); assert!(matches!(packet, Packet::HandshakeInit(_))); } #[test] fn handshake_init_and_response() { let (mut my_tun, mut their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let packet = Tunn::parse_incoming_packet(&resp).unwrap(); + let mut init = create_handshake_init(&mut my_tun); + let mut resp = create_handshake_response(&mut their_tun, &mut init); + let resp_len = resp.len(); + let packet = Tunn::parse_incoming_packet(&mut resp, resp_len).unwrap(); assert!(matches!(packet, Packet::HandshakeResponse(_))); } #[test] fn full_handshake() { let (mut my_tun, mut their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let resp = create_handshake_response(&mut their_tun, &init); - let keepalive = parse_handshake_resp(&mut my_tun, &resp); - let packet = Tunn::parse_incoming_packet(&keepalive).unwrap(); + let mut init = create_handshake_init(&mut my_tun); + let mut resp = create_handshake_response(&mut their_tun, &mut init); + let mut keepalive = parse_handshake_resp(&mut my_tun, &mut resp); + let kepalive_len = keepalive.len(); + let packet = Tunn::parse_incoming_packet(&mut keepalive, kepalive_len).unwrap(); assert!(matches!(packet, Packet::PacketData(_))); } @@ -801,7 +816,9 @@ mod tests { TunnResult::Done )); let sent_packet_buf = create_ipv4_udp_packet(); - let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + let packet_len = sent_packet_buf.len(); + my_dst[16..16 + packet_len].copy_from_slice(&sent_packet_buf[..packet_len]); + let data = my_tun.encapsulate(&mut my_dst, packet_len); assert!(matches!(data, TunnResult::WriteToNetwork(_))); //Advance to timeout @@ -815,8 +832,9 @@ mod tests { fn handshake_no_resp_rekey_timeout() { let (mut my_tun, _their_tun) = create_two_tuns(); - let init = create_handshake_init(&mut my_tun); - let packet = Tunn::parse_incoming_packet(&init).unwrap(); + let mut init = create_handshake_init(&mut my_tun); + let init_len = init.len(); + let packet = Tunn::parse_incoming_packet(&mut init, init_len).unwrap(); assert!(matches!(packet, Packet::HandshakeInit(_))); mock_instant::MockClock::advance(REKEY_TIMEOUT.into()); @@ -829,9 +847,10 @@ mod tests { let mut my_dst = [0u8; 1024]; let mut their_dst = [0u8; 1024]; - let mut sent_packet_buf = create_ipv4_udp_packet(); - - let data = my_tun.encapsulate(&mut sent_packet_buf, sent_packet_buf.len()); + let sent_packet_buf = create_ipv4_udp_packet(); + let packet_buf_len = sent_packet_buf.len(); + my_dst[16..16 + packet_buf_len].copy_from_slice(&sent_packet_buf[..packet_buf_len]); + let data = my_tun.encapsulate(&mut my_dst, packet_buf_len); assert!(matches!(data, TunnResult::WriteToNetwork(_))); let data = if let TunnResult::WriteToNetwork(sent) = data { sent @@ -839,7 +858,7 @@ mod tests { unreachable!(); }; - let data = their_tun.decapsulate(None, data, &mut their_dst); + let data = their_tun.decapsulate(None, data, data.len(), &mut their_dst); assert!(matches!(data, TunnResult::WriteToTunnelV4(..))); let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data { recv diff --git a/neptun/src/noise/rate_limiter.rs b/neptun/src/noise/rate_limiter.rs index db3db81..18751d3 100644 --- a/neptun/src/noise/rate_limiter.rs +++ b/neptun/src/noise/rate_limiter.rs @@ -156,16 +156,25 @@ impl RateLimiter { pub fn verify_packet<'a, 'b>( &self, src_addr: Option, - src: &'a [u8], + src: &'a mut [u8], + message_len: usize, dst: &'b mut [u8], ) -> Result, TunnResult<'b>> { - let packet = Tunn::parse_incoming_packet(src)?; + let packet = Tunn::parse_incoming_packet(src, message_len)?; // Verify and rate limit handshake messages only - if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) - | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet + if let Packet::HandshakeInit(HandshakeInit { + msg_buffer, + sender_idx, + .. + }) + | Packet::HandshakeResponse(HandshakeResponse { + msg_buffer, + sender_idx, + .. + }) = packet { - let (msg, macs) = src.split_at(src.len() - 32); + let (msg, macs) = msg_buffer.split_at(message_len - 32); let (mac1, mac2) = macs.split_at(16); let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg); diff --git a/neptun/src/noise/session.rs b/neptun/src/noise/session.rs index a1d6cd3..092c974 100644 --- a/neptun/src/noise/session.rs +++ b/neptun/src/noise/session.rs @@ -238,11 +238,9 @@ impl Session { /// return the size of the encapsulated packet on success pub(super) fn receive_packet_data<'a>( &self, - packet: PacketData, - dst: &'a mut [u8], + packet: PacketData<'a>, ) -> Result<&'a mut [u8], WireGuardError> { - let ct_len = packet.encrypted_encapsulated_packet.len(); - if dst.len() < ct_len { + if packet.encrypted_packet_buffer.len() < packet.data_len { // This is a very incorrect use of the library, therefore panic and not error panic!("The destination buffer is too small"); } @@ -255,12 +253,11 @@ impl Session { let ret = { let mut nonce = [0u8; 12]; nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); - dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet); self.receiver .open_in_place( Nonce::assume_unique_for_key(nonce), Aad::from(&[]), - &mut dst[..ct_len], + &mut packet.encrypted_packet_buffer[..packet.data_len], ) .map_err(|_| WireGuardError::InvalidAeadTag)? }; diff --git a/xtask/src/perf.rs b/xtask/src/perf.rs index 3bb94ca..376f99b 100644 --- a/xtask/src/perf.rs +++ b/xtask/src/perf.rs @@ -22,7 +22,7 @@ impl GitWorktree { .expect("Failed to create base worktree"); GitWorktree { name: name.to_string(), - sh: sh, + sh, } } }