From 4326a1f34151752f9124a6b84f46c03ab4be1295 Mon Sep 17 00:00:00 2001 From: Hasan Date: Tue, 17 Dec 2024 17:38:57 +0100 Subject: [PATCH] Test marking timers after fetch --- neptun/src/device/mod.rs | 3 +- neptun/src/noise/mod.rs | 10 +++++-- neptun/src/noise/timers.rs | 56 +++++++++----------------------------- 3 files changed, 23 insertions(+), 46 deletions(-) diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index 0913aaa..12a0e28 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -700,7 +700,8 @@ impl Device { let res = { let mut tun = peer.tunnel.lock(); - tun.update_timers(&mut t.dst_buf[..]) + let timer_mask = tun.fetch_timer_mask(); + tun.update_timers(&mut t.dst_buf[..], timer_mask) }; match res { TunnResult::Done => {} diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 5e7269a..d210281 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -777,8 +777,14 @@ mod tests { fn full_handshake_plus_timers() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); // Time has not yet advanced so their is nothing to do - assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!( + my_tun.update_timers(&mut [], my_tun.fetch_timer_mask()), + TunnResult::Done + )); + assert!(matches!( + their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()), + TunnResult::Done + )); } #[test] diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 89c90f5..ac13b60 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -150,6 +150,12 @@ impl Tunn { .fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed); } + pub fn fetch_timer_mask(&self) -> u16 { + self.timers + .timers_to_update_mask + .swap(0, std::sync::atomic::Ordering::Relaxed) + } + fn timer_tick(&mut self, timer_name: TimerName) { let time = self.timers[TimeCurrent]; match timer_name { @@ -222,7 +228,7 @@ impl Tunn { } } - pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8], timer_mask: u16) -> TunnResult<'a> { let mut handshake_initiation_required = false; let mut keepalive_required = false; @@ -238,10 +244,6 @@ impl Tunn { self.timers[TimeCurrent] = now; // Check which timers to update, and update them - let timer_mask = self - .timers - .timers_to_update_mask - .swap(0, std::sync::atomic::Ordering::Relaxed); self.tick_marked_timers(timer_mask); self.update_session_timers(now); @@ -444,7 +446,7 @@ mod tests { my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive); // Update timers - my_tun.update_timers(&mut [0]); + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Only those timers marked should be udpated assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); @@ -461,7 +463,7 @@ mod tests { my_tun.timers[TimerName::TimeLastDataPacketReceived] = SafeDuration::from_millis(0); my_tun.timers[TimerName::TimePersistentKeepalive] = SafeDuration::from_millis(0); - my_tun.update_timers(&mut [0]); + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Now the timers should not update assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); @@ -470,7 +472,7 @@ mod tests { } #[test] - fn test_fetching_timers_by_swap_vs_load() { + fn test_marking_timers_after_fetch() { let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); let my_idx = OsRng.next_u32(); @@ -484,51 +486,19 @@ mod tests { my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); // Fetch by swap - let timer_mask = my_tun - .timers - .timers_to_update_mask - .swap(0, std::sync::atomic::Ordering::Relaxed); + let timer_mask = my_tun.fetch_timer_mask(); // Timer marked after fetch my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); - my_tun.tick_marked_timers(timer_mask); + my_tun.update_timers(&mut [0], timer_mask); // Only those timers marked should be udpated assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - my_tun.update_timers(&mut [0]); - + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Previously marked TimeLastDataPacketReceivedOnly updated in next cycle assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - - // Reset the timers - my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); - - // Fetch timers by load - let timer_mask = my_tun - .timers - .timers_to_update_mask - .load(std::sync::atomic::Ordering::Relaxed); - - // Timer marked after fetch - my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); - - my_tun.tick_marked_timers(timer_mask); - - my_tun - .timers - .timers_to_update_mask - .store(0, std::sync::atomic::Ordering::Relaxed); - - // Only TimeLastDataPacketReceived udpated - assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); - assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - - my_tun.update_timers(&mut [0]); - - // TimeLastDataPacketSent timer still not updated - assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); } }