@@ -6,6 +6,7 @@ use super::errors::WireGuardError;
6
6
use crate :: noise:: { safe_duration:: SafeDuration as Duration , Tunn , TunnResult } ;
7
7
use std:: mem;
8
8
use std:: ops:: { Index , IndexMut } ;
9
+ use std:: sync:: atomic:: AtomicU16 ;
9
10
use std:: time:: SystemTime ;
10
11
11
12
#[ cfg( feature = "mock-instant" ) ]
@@ -42,7 +43,7 @@ pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
42
43
const KEEPALIVE_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
43
44
const COOKIE_EXPIRATION_TIME : Duration = Duration :: from_secs ( 120 ) ;
44
45
45
- #[ derive( Debug ) ]
46
+ #[ derive( Debug , Clone , Copy ) ]
46
47
pub enum TimerName {
47
48
/// Current time, updated each call to `update_timers`
48
49
TimeCurrent ,
@@ -65,6 +66,20 @@ pub enum TimerName {
65
66
Top ,
66
67
}
67
68
69
+ impl TimerName {
70
+ pub const VALUES : [ Self ; TimerName :: Top as usize ] = [
71
+ Self :: TimeCurrent ,
72
+ Self :: TimeSessionEstablished ,
73
+ Self :: TimeLastHandshakeStarted ,
74
+ Self :: TimeLastPacketReceived ,
75
+ Self :: TimeLastPacketSent ,
76
+ Self :: TimeLastDataPacketReceived ,
77
+ Self :: TimeLastDataPacketSent ,
78
+ Self :: TimeCookieReceived ,
79
+ Self :: TimePersistentKeepalive ,
80
+ ] ;
81
+ }
82
+
68
83
use self :: TimerName :: * ;
69
84
70
85
#[ derive( Debug ) ]
@@ -82,6 +97,7 @@ pub struct Timers {
82
97
persistent_keepalive : usize ,
83
98
/// Should this timer call reset rr function (if not a shared rr instance)
84
99
pub ( super ) should_reset_rr : bool ,
100
+ timers_to_update_mask : AtomicU16 ,
85
101
}
86
102
87
103
impl Timers {
@@ -95,6 +111,7 @@ impl Timers {
95
111
want_handshake_since : Default :: default ( ) ,
96
112
persistent_keepalive : usize:: from ( persistent_keepalive. unwrap_or ( 0 ) ) ,
97
113
should_reset_rr : reset_rr,
114
+ timers_to_update_mask : Default :: default ( ) ,
98
115
}
99
116
}
100
117
@@ -128,7 +145,13 @@ impl IndexMut<TimerName> for Timers {
128
145
}
129
146
130
147
impl Tunn {
131
- pub ( super ) fn timer_tick ( & mut self , timer_name : TimerName ) {
148
+ pub ( super ) fn mark_timer_to_update ( & self , timer_name : TimerName ) {
149
+ self . timers
150
+ . timers_to_update_mask
151
+ . fetch_or ( 1 << timer_name as u16 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
152
+ }
153
+
154
+ fn timer_tick ( & mut self , timer_name : TimerName ) {
132
155
let time = self . timers [ TimeCurrent ] ;
133
156
match timer_name {
134
157
TimeLastPacketReceived => {
@@ -207,6 +230,21 @@ impl Tunn {
207
230
let now = time. duration_since ( self . timers . time_started ) . into ( ) ;
208
231
self . timers [ TimeCurrent ] = now;
209
232
233
+ // Check which timers to update, and update them
234
+ let timer_mask = self
235
+ . timers
236
+ . timers_to_update_mask
237
+ . load ( std:: sync:: atomic:: Ordering :: Relaxed ) ;
238
+ for timer_name in TimerName :: VALUES {
239
+ if ( timer_mask & ( 1 << ( timer_name as u16 ) ) ) != 0 {
240
+ self . timer_tick ( timer_name) ;
241
+ }
242
+ }
243
+ // Reset all marked bits
244
+ self . timers
245
+ . timers_to_update_mask
246
+ . store ( 0 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
247
+
210
248
self . update_session_timers ( now) ;
211
249
212
250
// Load timers only once:
@@ -380,3 +418,55 @@ impl Tunn {
380
418
self . timers . persistent_keepalive = keepalive as usize ;
381
419
}
382
420
}
421
+
422
+ #[ cfg( test) ]
423
+ mod tests {
424
+ use rand:: RngCore ;
425
+ use rand_core:: OsRng ;
426
+
427
+ use crate :: noise:: { safe_duration:: SafeDuration , Tunn } ;
428
+
429
+ use super :: TimerName ;
430
+
431
+ #[ test]
432
+ fn create_two_tuns ( ) {
433
+ let my_secret_key = x25519_dalek:: StaticSecret :: random_from_rng ( OsRng ) ;
434
+ let my_idx = OsRng . next_u32 ( ) ;
435
+
436
+ let their_secret_key = x25519_dalek:: StaticSecret :: random_from_rng ( OsRng ) ;
437
+ let their_public_key = x25519_dalek:: PublicKey :: from ( & their_secret_key) ;
438
+
439
+ let mut my_tun =
440
+ Tunn :: new ( my_secret_key, their_public_key, None , None , my_idx, None ) . unwrap ( ) ;
441
+
442
+ // Mark timers to update
443
+ my_tun. mark_timer_to_update ( super :: TimerName :: TimeLastDataPacketSent ) ;
444
+ my_tun. mark_timer_to_update ( super :: TimerName :: TimeLastDataPacketReceived ) ;
445
+ my_tun. mark_timer_to_update ( super :: TimerName :: TimePersistentKeepalive ) ;
446
+
447
+ // Update timers
448
+ my_tun. update_timers ( & mut [ 0 ] ) ;
449
+
450
+ // Only those timers marked should be udpated
451
+ assert ! ( !my_tun. timers[ TimerName :: TimeLastDataPacketSent ] . is_zero( ) ) ;
452
+ assert ! ( !my_tun. timers[ TimerName :: TimeLastDataPacketReceived ] . is_zero( ) ) ;
453
+ assert ! ( !my_tun. timers[ TimerName :: TimePersistentKeepalive ] . is_zero( ) ) ;
454
+
455
+ // Unmarked timers should still be 0
456
+ assert ! ( my_tun. timers[ TimerName :: TimeCookieReceived ] . is_zero( ) ) ;
457
+ assert ! ( my_tun. timers[ TimerName :: TimeLastHandshakeStarted ] . is_zero( ) ) ;
458
+ assert ! ( my_tun. timers[ TimerName :: TimeLastPacketReceived ] . is_zero( ) ) ;
459
+
460
+ // Reset the timers
461
+ my_tun. timers [ TimerName :: TimeLastDataPacketSent ] = SafeDuration :: from_millis ( 0 ) ;
462
+ my_tun. timers [ TimerName :: TimeLastDataPacketReceived ] = SafeDuration :: from_millis ( 0 ) ;
463
+ my_tun. timers [ TimerName :: TimePersistentKeepalive ] = SafeDuration :: from_millis ( 0 ) ;
464
+
465
+ my_tun. update_timers ( & mut [ 0 ] ) ;
466
+
467
+ // Now the timers should not update
468
+ assert ! ( my_tun. timers[ TimerName :: TimeLastDataPacketSent ] . is_zero( ) ) ;
469
+ assert ! ( my_tun. timers[ TimerName :: TimeLastDataPacketReceived ] . is_zero( ) ) ;
470
+ assert ! ( my_tun. timers[ TimerName :: TimePersistentKeepalive ] . is_zero( ) ) ;
471
+ }
472
+ }
0 commit comments