Skip to content

Commit

Permalink
Add timer mask as parameter to update timers
Browse files Browse the repository at this point in the history
  • Loading branch information
Hasan6979 committed Dec 18, 2024
1 parent 09834fe commit e64f7d8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 50 deletions.
3 changes: 2 additions & 1 deletion neptun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,8 @@ impl Device {

let res = {
let mut tun = peer.tunnel.lock();
tun.update_timers(&mut t.dst_buf[..])
let timer_mask = tun.fetch_timer_mask();
tun.update_timers(&mut t.dst_buf[..], timer_mask)
};
match res {
TunnResult::Done => {}
Expand Down
24 changes: 18 additions & 6 deletions neptun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ mod tests {
#[cfg(feature = "mock-instant")]
fn update_timer_results_in_handshake(tun: &mut Tunn) {
let mut dst = vec![0u8; 2048];
let result = tun.update_timers(&mut dst);
let result = tun.update_timers(&mut dst, tun.fetch_timer_mask());
assert!(matches!(result, TunnResult::WriteToNetwork(_)));
let packet_data = if let TunnResult::WriteToNetwork(data) = result {
data
Expand Down Expand Up @@ -777,8 +777,14 @@ mod tests {
fn full_handshake_plus_timers() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
// Time has not yet advanced so their is nothing to do
assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(
my_tun.update_timers(&mut [], my_tun.fetch_timer_mask()),
TunnResult::Done
));
assert!(matches!(
their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()),
TunnResult::Done
));
}

#[test]
Expand All @@ -790,9 +796,12 @@ mod tests {
// Advance time 1 second and "send" 1 packet so that we send a handshake
// after the timeout
mock_instant::MockClock::advance(Duration::from_secs(1));
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(
my_tun.update_timers(&mut my_dst),
their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()),
TunnResult::Done
));
assert!(matches!(
my_tun.update_timers(&mut my_dst, my_tun.fetch_timer_mask()),
TunnResult::Done
));
let sent_packet_buf = create_ipv4_udp_packet();
Expand All @@ -801,7 +810,10 @@ mod tests {

//Advance to timeout
mock_instant::MockClock::advance(REKEY_AFTER_TIME.into());
assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done));
assert!(matches!(
their_tun.update_timers(&mut [], their_tun.fetch_timer_mask()),
TunnResult::Done
));
update_timer_results_in_handshake(&mut my_tun);
}

Expand Down
56 changes: 13 additions & 43 deletions neptun/src/noise/timers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ impl Tunn {
.fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed);
}

pub fn fetch_timer_mask(&self) -> u16 {
self.timers
.timers_to_update_mask
.swap(0, std::sync::atomic::Ordering::Relaxed)
}

fn timer_tick(&mut self, timer_name: TimerName) {
let time = self.timers[TimeCurrent];
match timer_name {
Expand Down Expand Up @@ -222,7 +228,7 @@ impl Tunn {
}
}

pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> {
pub fn update_timers<'a>(&mut self, dst: &'a mut [u8], timer_mask: u16) -> TunnResult<'a> {
let mut handshake_initiation_required = false;
let mut keepalive_required = false;

Expand All @@ -238,10 +244,6 @@ impl Tunn {
self.timers[TimeCurrent] = now;

// Check which timers to update, and update them
let timer_mask = self
.timers
.timers_to_update_mask
.swap(0, std::sync::atomic::Ordering::Relaxed);
self.tick_marked_timers(timer_mask);

self.update_session_timers(now);
Expand Down Expand Up @@ -444,7 +446,7 @@ mod tests {
my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive);

// Update timers
my_tun.update_timers(&mut [0]);
my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask());

// Only those timers marked should be udpated
assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
Expand All @@ -461,7 +463,7 @@ mod tests {
my_tun.timers[TimerName::TimeLastDataPacketReceived] = SafeDuration::from_millis(0);
my_tun.timers[TimerName::TimePersistentKeepalive] = SafeDuration::from_millis(0);

my_tun.update_timers(&mut [0]);
my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask());

// Now the timers should not update
assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
Expand All @@ -470,7 +472,7 @@ mod tests {
}

#[test]
fn test_fetching_timers_by_swap_vs_load() {
fn test_marking_timers_after_fetch() {
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let my_idx = OsRng.next_u32();

Expand All @@ -484,51 +486,19 @@ mod tests {
my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent);

// Fetch by swap
let timer_mask = my_tun
.timers
.timers_to_update_mask
.swap(0, std::sync::atomic::Ordering::Relaxed);
let timer_mask = my_tun.fetch_timer_mask();

// Timer marked after fetch
my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived);

my_tun.tick_marked_timers(timer_mask);
my_tun.update_timers(&mut [0], timer_mask);

// Only those timers marked should be udpated
assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
assert!(my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero());

my_tun.update_timers(&mut [0]);

my_tun.update_timers(&mut [0], my_tun.fetch_timer_mask());
// Previously marked TimeLastDataPacketReceivedOnly updated in next cycle
assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero());

// Reset the timers
my_tun.timers[TimerName::TimeLastDataPacketSent] = SafeDuration::from_millis(0);

// Fetch timers by load
let timer_mask = my_tun
.timers
.timers_to_update_mask
.load(std::sync::atomic::Ordering::Relaxed);

// Timer marked after fetch
my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent);

my_tun.tick_marked_timers(timer_mask);

my_tun
.timers
.timers_to_update_mask
.store(0, std::sync::atomic::Ordering::Relaxed);

// Only TimeLastDataPacketReceived udpated
assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero());

my_tun.update_timers(&mut [0]);

// TimeLastDataPacketSent timer still not updated
assert!(my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
}
}

0 comments on commit e64f7d8

Please sign in to comment.