From 8092b80dc9bedd0eeadbe8bc93f4cdc32c151e4c Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 28 Nov 2024 15:57:04 +0100 Subject: [PATCH] Only mark timers to update Timers are updated on every single packet call, but their accuracy is limited to 250ms where currentTime is updated. So just mark timers to update, and update them only once in update_timers. --- neptun/src/noise/mod.rs | 22 ++++----- neptun/src/noise/timers.rs | 94 +++++++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 13 deletions(-) diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 5b91302..5e7269a 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -276,10 +276,10 @@ impl Tunn { if let Some(ref session) = self.sessions[current % N_SESSIONS] { // Send the packet using an established session let packet = session.format_packet_data(src, dst); - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); // Exclude Keepalive packets from timer update. if !src.is_empty() { - self.timer_tick(TimerName::TimeLastDataPacketSent); + self.mark_timer_to_update(TimerName::TimeLastDataPacketSent); } self.tx_bytes += packet.len(); return TunnResult::WriteToNetwork(packet); @@ -365,8 +365,8 @@ impl Tunn { let index = session.local_index(); self.sessions[index % N_SESSIONS] = Some(session); - self.timer_tick(TimerName::TimeLastPacketReceived); - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); self.timer_tick_session_established(false, index); // New session established, we are not the initiator tracing::debug!(message = "Sending handshake_response", local_idx = index); @@ -399,7 +399,7 @@ impl Tunn { let index = l_idx % N_SESSIONS; self.sessions[index] = Some(session); - self.timer_tick(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); self.timer_tick_session_established(true, index); // New session established, we are the initiator self.set_current_session(l_idx); @@ -424,8 +424,8 @@ impl Tunn { // Increase the rx_bytes accordingly self.rx_bytes += COOKIE_REPLY_SZ; - self.timer_tick(TimerName::TimeLastPacketReceived); - self.timer_tick(TimerName::TimeCookieReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeCookieReceived); tracing::debug!("Did set cookie"); @@ -469,7 +469,7 @@ impl Tunn { self.set_current_session(r_idx); - self.timer_tick(TimerName::TimeLastPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastPacketReceived); Ok(self.validate_decapsulated_packet(decapsulated_packet)) } @@ -496,9 +496,9 @@ impl Tunn { tracing::debug!("Sending handshake_initiation"); if starting_new_handshake { - self.timer_tick(TimerName::TimeLastHandshakeStarted); + self.mark_timer_to_update(TimerName::TimeLastHandshakeStarted); } - self.timer_tick(TimerName::TimeLastPacketSent); + self.mark_timer_to_update(TimerName::TimeLastPacketSent); self.tx_bytes += packet.len(); TunnResult::WriteToNetwork(packet) @@ -548,7 +548,7 @@ impl Tunn { return TunnResult::Err(WireGuardError::InvalidPacket); } - self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.mark_timer_to_update(TimerName::TimeLastDataPacketReceived); self.rx_bytes += message_data_len(computed_len); match src_ip_address { diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 07b4bfd..6cce8df 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -6,6 +6,7 @@ use super::errors::WireGuardError; use crate::noise::{safe_duration::SafeDuration as Duration, Tunn, TunnResult}; use std::mem; use std::ops::{Index, IndexMut}; +use std::sync::atomic::AtomicU16; use std::time::SystemTime; #[cfg(feature = "mock-instant")] @@ -42,7 +43,7 @@ pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5); const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120); -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum TimerName { /// Current time, updated each call to `update_timers` TimeCurrent, @@ -65,6 +66,20 @@ pub enum TimerName { Top, } +impl TimerName { + pub const VALUES: [Self; TimerName::Top as usize] = [ + Self::TimeCurrent, + Self::TimeSessionEstablished, + Self::TimeLastHandshakeStarted, + Self::TimeLastPacketReceived, + Self::TimeLastPacketSent, + Self::TimeLastDataPacketReceived, + Self::TimeLastDataPacketSent, + Self::TimeCookieReceived, + Self::TimePersistentKeepalive, + ]; +} + use self::TimerName::*; #[derive(Debug)] @@ -82,6 +97,7 @@ pub struct Timers { persistent_keepalive: usize, /// Should this timer call reset rr function (if not a shared rr instance) pub(super) should_reset_rr: bool, + timers_to_update_mask: AtomicU16, } impl Timers { @@ -95,6 +111,7 @@ impl Timers { want_handshake_since: Default::default(), persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)), should_reset_rr: reset_rr, + timers_to_update_mask: Default::default(), } } @@ -128,7 +145,13 @@ impl IndexMut for Timers { } impl Tunn { - pub(super) fn timer_tick(&mut self, timer_name: TimerName) { + pub(super) fn mark_timer_to_update(&self, timer_name: TimerName) { + self.timers + .timers_to_update_mask + .fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed); + } + + fn timer_tick(&mut self, timer_name: TimerName) { let time = self.timers[TimeCurrent]; match timer_name { TimeLastPacketReceived => { @@ -207,6 +230,21 @@ impl Tunn { let now = time.duration_since(self.timers.time_started).into(); self.timers[TimeCurrent] = now; + // Check which timers to update, and update them + let timer_mask = self + .timers + .timers_to_update_mask + .load(std::sync::atomic::Ordering::Relaxed); + for timer_name in TimerName::VALUES { + if (timer_mask & (1 << (timer_name as u16))) != 0 { + self.timer_tick(timer_name); + } + } + // Reset all marked bits + self.timers + .timers_to_update_mask + .store(0, std::sync::atomic::Ordering::Relaxed); + self.update_session_timers(now); // Load timers only once: @@ -380,3 +418,55 @@ impl Tunn { self.timers.persistent_keepalive = keepalive as usize; } } + +#[cfg(test)] +mod tests { + use rand::RngCore; + use rand_core::OsRng; + + use crate::noise::{safe_duration::SafeDuration, Tunn}; + + use super::TimerName; + + #[test] + fn create_two_tuns() { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + + let mut my_tun = + Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap(); + + // Mark timers to update + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); + my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive); + + // Update timers + my_tun.update_timers(&mut [0]); + + // Only those timers marked should be udpated + assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + assert!(!my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); + + // Unmarked timers should still be 0 + assert!(my_tun.timers[TimerName::TimeCookieReceived].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastHandshakeStarted].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastPacketReceived].is_zero()); + + // Reset the timers + my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); + my_tun.timers[TimerName::TimeLastDataPacketReceived] = SafeDuration::from_millis(0); + my_tun.timers[TimerName::TimePersistentKeepalive] = SafeDuration::from_millis(0); + + my_tun.update_timers(&mut [0]); + + // Now the timers should not update + assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + assert!(my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); + } +}