Skip to content

Commit

Permalink
Only mark timers to update
Browse files Browse the repository at this point in the history
Timers are updated on every single packet call, but their
accuracy is limited to 250ms where currentTime is updated.
So just mark timers to update, and update them only once in
update_timers.
  • Loading branch information
Hasan6979 committed Nov 28, 2024
1 parent db16cad commit 5e1292c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
22 changes: 11 additions & 11 deletions neptun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ impl Tunn {
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
// Send the packet using an established session
let packet = session.format_packet_data(src, dst);
self.timer_tick(TimerName::TimeLastPacketSent);
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
// Exclude Keepalive packets from timer update.
if !src.is_empty() {
self.timer_tick(TimerName::TimeLastDataPacketSent);
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
}
self.tx_bytes += packet.len();
return TunnResult::WriteToNetwork(packet);
Expand Down Expand Up @@ -365,8 +365,8 @@ impl Tunn {
let index = session.local_index();
self.sessions[index % N_SESSIONS] = Some(session);

self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeLastPacketSent);
self.mark_timer_to_update(TimerName::TimeLastPacketReceived);
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
self.timer_tick_session_established(false, index); // New session established, we are not the initiator

tracing::debug!(message = "Sending handshake_response", local_idx = index);
Expand Down Expand Up @@ -399,7 +399,7 @@ impl Tunn {
let index = l_idx % N_SESSIONS;
self.sessions[index] = Some(session);

self.timer_tick(TimerName::TimeLastPacketReceived);
self.mark_timer_to_update(TimerName::TimeLastPacketReceived);
self.timer_tick_session_established(true, index); // New session established, we are the initiator
self.set_current_session(l_idx);

Expand All @@ -424,8 +424,8 @@ impl Tunn {
// Increase the rx_bytes accordingly
self.rx_bytes += COOKIE_REPLY_SZ;

self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeCookieReceived);
self.mark_timer_to_update(TimerName::TimeLastPacketReceived);
self.mark_timer_to_update(TimerName::TimeCookieReceived);

tracing::debug!("Did set cookie");

Expand Down Expand Up @@ -469,7 +469,7 @@ impl Tunn {

self.set_current_session(r_idx);

self.timer_tick(TimerName::TimeLastPacketReceived);
self.mark_timer_to_update(TimerName::TimeLastPacketReceived);

Ok(self.validate_decapsulated_packet(decapsulated_packet))
}
Expand All @@ -496,9 +496,9 @@ impl Tunn {
tracing::debug!("Sending handshake_initiation");

if starting_new_handshake {
self.timer_tick(TimerName::TimeLastHandshakeStarted);
self.mark_timer_to_update(TimerName::TimeLastHandshakeStarted);
}
self.timer_tick(TimerName::TimeLastPacketSent);
self.mark_timer_to_update(TimerName::TimeLastPacketSent);
self.tx_bytes += packet.len();

TunnResult::WriteToNetwork(packet)
Expand Down Expand Up @@ -548,7 +548,7 @@ impl Tunn {
return TunnResult::Err(WireGuardError::InvalidPacket);
}

self.timer_tick(TimerName::TimeLastDataPacketReceived);
self.mark_timer_to_update(TimerName::TimeLastDataPacketReceived);
self.rx_bytes += message_data_len(computed_len);

match src_ip_address {
Expand Down
74 changes: 72 additions & 2 deletions neptun/src/noise/timers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::errors::WireGuardError;
use crate::noise::{safe_duration::SafeDuration as Duration, Tunn, TunnResult};
use std::mem;
use std::ops::{Index, IndexMut};
use std::sync::atomic::AtomicU16;
use std::time::SystemTime;

#[cfg(feature = "mock-instant")]
Expand Down Expand Up @@ -42,7 +43,7 @@ pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120);

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum TimerName {
/// Current time, updated each call to `update_timers`
TimeCurrent,
Expand All @@ -65,6 +66,20 @@ pub enum TimerName {
Top,
}

impl TimerName {
pub const VALUES: [Self; TimerName::Top as usize] = [
Self::TimeCurrent,
Self::TimeSessionEstablished,
Self::TimeLastHandshakeStarted,
Self::TimeLastPacketReceived,
Self::TimeLastPacketSent,
Self::TimeLastDataPacketReceived,
Self::TimeLastDataPacketSent,
Self::TimeCookieReceived,
Self::TimePersistentKeepalive,
];
}

use self::TimerName::*;

#[derive(Debug)]
Expand All @@ -82,6 +97,7 @@ pub struct Timers {
persistent_keepalive: usize,
/// Should this timer call reset rr function (if not a shared rr instance)
pub(super) should_reset_rr: bool,
timers_to_update_mask: AtomicU16,
}

impl Timers {
Expand All @@ -95,6 +111,7 @@ impl Timers {
want_handshake_since: Default::default(),
persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)),
should_reset_rr: reset_rr,
timers_to_update_mask: Default::default(),
}
}

Expand Down Expand Up @@ -128,7 +145,13 @@ impl IndexMut<TimerName> for Timers {
}

impl Tunn {
pub(super) fn timer_tick(&mut self, timer_name: TimerName) {
pub(super) fn mark_timer_to_update(&self, timer_name: TimerName) {
self.timers
.timers_to_update_mask
.fetch_or(1 << timer_name as u16, std::sync::atomic::Ordering::Relaxed);
}

fn timer_tick(&mut self, timer_name: TimerName) {
let time = self.timers[TimeCurrent];
match timer_name {
TimeLastPacketReceived => {
Expand Down Expand Up @@ -207,6 +230,17 @@ impl Tunn {
let now = time.duration_since(self.timers.time_started).into();
self.timers[TimeCurrent] = now;

// Check which timers to update, and update them
let timer_mask = self
.timers
.timers_to_update_mask
.load(std::sync::atomic::Ordering::Relaxed);
for timer_name in TimerName::VALUES {
if (timer_mask & (1 << (timer_name as u16))) != 0 {
self.timer_tick(timer_name);
}
}

self.update_session_timers(now);

// Load timers only once:
Expand Down Expand Up @@ -380,3 +414,39 @@ impl Tunn {
self.timers.persistent_keepalive = keepalive as usize;
}
}

#[cfg(test)]
mod tests {
use rand::RngCore;
use rand_core::OsRng;

use crate::noise::Tunn;

use super::TimerName;

#[test]
fn create_two_tuns() {
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let my_idx = OsRng.next_u32();

let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);

let mut my_tun =
Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None).unwrap();

my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketSent);
my_tun.mark_timer_to_update(super::TimerName::TimeLastDataPacketReceived);
my_tun.mark_timer_to_update(super::TimerName::TimePersistentKeepalive);

my_tun.update_timers(&mut [0]);

assert!(!my_tun.timers[TimerName::TimeLastDataPacketSent].is_zero());
assert!(!my_tun.timers[TimerName::TimeLastDataPacketReceived].is_zero());
assert!(!my_tun.timers[TimerName::TimePersistentKeepalive].is_zero());

assert!(my_tun.timers[TimerName::TimeCookieReceived].is_zero());
assert!(my_tun.timers[TimerName::TimeLastHandshakeStarted].is_zero());
assert!(my_tun.timers[TimerName::TimeLastPacketReceived].is_zero());
}
}

0 comments on commit 5e1292c

Please sign in to comment.