diff --git a/neptun/src/noise/timers.rs b/neptun/src/noise/timers.rs index 6cce8df..a231128 100644 --- a/neptun/src/noise/timers.rs +++ b/neptun/src/noise/timers.rs @@ -67,8 +67,7 @@ pub enum TimerName { } impl TimerName { - pub const VALUES: [Self; TimerName::Top as usize] = [ - Self::TimeCurrent, + pub const VALUES: [Self; TimerName::Top as usize - 1] = [ Self::TimeSessionEstablished, Self::TimeLastHandshakeStarted, Self::TimeLastPacketReceived, @@ -215,6 +214,14 @@ impl Tunn { } } + fn tick_marked_timers(&mut self, timer_mask: u16) { + for timer_name in TimerName::VALUES { + if (timer_mask & (1 << (timer_name as u16))) != 0 { + self.timer_tick(timer_name); + } + } + } + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { let mut handshake_initiation_required = false; let mut keepalive_required = false; @@ -234,16 +241,8 @@ impl Tunn { let timer_mask = self .timers .timers_to_update_mask - .load(std::sync::atomic::Ordering::Relaxed); - for timer_name in TimerName::VALUES { - if (timer_mask & (1 << (timer_name as u16))) != 0 { - self.timer_tick(timer_name); - } - } - // Reset all marked bits - self.timers - .timers_to_update_mask - .store(0, std::sync::atomic::Ordering::Relaxed); + .swap(0, std::sync::atomic::Ordering::Relaxed); + self.tick_marked_timers(timer_mask); self.update_session_timers(now); @@ -429,7 +428,7 @@ mod tests { use super::TimerName; #[test] - fn create_two_tuns() { + fn test_update_marked_timers() { let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); let my_idx = OsRng.next_u32(); @@ -469,4 +468,41 @@ mod tests { assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); assert!(my_tun.timers[TimerName::TimePersistentKeepalive].is_zero()); } + + #[test] + fn test_mark_timers_during_update() { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + + let mut my_tun = + Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap(); + + // Mark timers to update + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent); + + let timer_mask = my_tun + .timers + .timers_to_update_mask + .swap(0, std::sync::atomic::Ordering::Relaxed); + + my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived); + + my_tun.tick_marked_timers(timer_mask); + + // Only those timers marked should be udpated + assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + + // Reset the timers + my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0); + + my_tun.update_timers(&mut [0]); + + // Now the timers should not update + assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero()); + assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero()); + } }