diff --git a/neptun/src/device/mod.rs b/neptun/src/device/mod.rs index 0913aaa..12a0e28 100644 --- a/neptun/src/device/mod.rs +++ b/neptun/src/device/mod.rs @@ -700,7 +700,8 @@ impl Device { let res = { let mut tun = peer.tunnel.lock(); - tun.update_timers(&mut t.dst_buf[..]) + let timer_mask = tun.fetch_timer_mask(); + tun.update_timers(&mut t.dst_buf[..], timer_mask) }; match res { TunnResult::Done => {} diff --git a/neptun/src/noise/mod.rs b/neptun/src/noise/mod.rs index 5e7269a..7aef8a6 100644 --- a/neptun/src/noise/mod.rs +++ b/neptun/src/noise/mod.rs @@ -730,7 +730,7 @@ mod tests { #[cfg(feature = "mock-instant")] fn update_timer_results_in_handshake(tun: &mut Tunn) { let mut dst = vec![0u8; 2048]; - let result = tun.update_timers(&mut dst); + let result = tun.update_timers(&mut dst, tun.fetch_timer_mask()); assert!(matches!(result, TunnResult::WriteToNetwork(_))); let packet_data = if let TunnResult::WriteToNetwork(data) = result { data @@ -777,8 +777,14 @@ mod tests { fn full_handshake_plus_timers() { let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); // Time has not yet advanced so their is nothing to do - assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!( + my_tun.update_timers(&mut [], my_tun.fetch_timer_mask()), + TunnResult::Done + )); + assert!(matches!( + their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()), + TunnResult::Done + )); } #[test] @@ -790,9 +796,12 @@ mod tests { // Advance time 1 second and "send" 1 packet so that we send a handshake // after the timeout mock_instant::MockClock::advance(Duration::from_secs(1)); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); assert!(matches!( - my_tun.update_timers(&mut my_dst), + their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()), + TunnResult::Done + )); + assert!(matches!( + my_tun.update_timers(&mut my_dst, my_tun.fetch_timer_mask()), TunnResult::Done )); let sent_packet_buf = create_ipv4_udp_packet(); @@ -801,7 +810,10 @@ mod tests { //Advance to timeout mock_instant::MockClock::advance(REKEY_AFTER_TIME.into()); - assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!( + their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()), + TunnResult::Done + )); update_timer_results_in_handshake(&mut my_tun); } diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 89c90f5..ac13b60 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -150,6 +150,12 @@ impl Tunn { .fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed); } + pub fn fetch_timer_mask(&self) -> u16 { + self.timers + .timers_to_update_mask + .swap(0, std::sync::atomic::Ordering::Relaxed) + } + fn timer_tick(&mut self, timer_name: TimerName) { let time = self.timers[TimeCurrent]; match timer_name { @@ -222,7 +228,7 @@ impl Tunn { } } - pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8], timer_mask: u16) -> TunnResult<'a> { let mut handshake_initiation_required = false; let mut keepalive_required = false; @@ -238,10 +244,6 @@ impl Tunn { self.timers[TimeCurrent] = now; // Check which timers to update, and update them - let timer_mask = self - .timers - .timers_to_update_mask - .swap(0, std::sync::atomic::Ordering::Relaxed); self.tick_marked_timers(timer_mask); self.update_session_timers(now); @@ -444,7 +446,7 @@ mod tests { my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive); // Update timers - my_tun.update_timers(&mut [0]); + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Only those timers marked should be udpated assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); @@ -461,7 +463,7 @@ mod tests { my_tun.timers[TimerName::TimeLastDataPacketReceived] = SafeDuration::from_millis(0); my_tun.timers[TimerName::TimePersistentKeepalive] = SafeDuration::from_millis(0); - my_tun.update_timers(&mut [0]); + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Now the timers should not update assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); @@ -470,7 +472,7 @@ mod tests { } #[test] - fn test_fetching_timers_by_swap_vs_load() { + fn test_marking_timers_after_fetch() { let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); let my_idx = OsRng.next_u32(); @@ -484,51 +486,19 @@ mod tests { my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); // Fetch by swap - let timer_mask = my_tun - .timers - .timers_to_update_mask - .swap(0, std::sync::atomic::Ordering::Relaxed); + let timer_mask = my_tun.fetch_timer_mask(); // Timer marked after fetch my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); - my_tun.tick_marked_timers(timer_mask); + my_tun.update_timers(&mut [0], timer_mask); // Only those timers marked should be udpated assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - my_tun.update_timers(&mut [0]); - + my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask()); // Previously marked TimeLastDataPacketReceivedOnly updated in next cycle assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - - // Reset the timers - my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); - - // Fetch timers by load - let timer_mask = my_tun - .timers - .timers_to_update_mask - .load(std::sync::atomic::Ordering::Relaxed); - - // Timer marked after fetch - my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); - - my_tun.tick_marked_timers(timer_mask); - - my_tun - .timers - .timers_to_update_mask - .store(0, std::sync::atomic::Ordering::Relaxed); - - // Only TimeLastDataPacketReceived udpated - assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); - assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); - - my_tun.update_timers(&mut [0]); - - // TimeLastDataPacketSent timer still not updated - assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); } }