From 73d65601953a796d4f5f8b7b07765f76938b2853 Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 24 Jul 2024 19:45:50 +0800 Subject: [PATCH 1/8] Revert "time: revert "avoid traversing entries in the time wheel twice" (#6715)" This reverts commit 47210a8e6eeb82b51aa778074fdc4d757b953b8c. --- tokio/src/runtime/time/entry.rs | 99 +++---------- tokio/src/runtime/time/mod.rs | 60 ++++++-- tokio/src/runtime/time/wheel/level.rs | 22 +-- tokio/src/runtime/time/wheel/mod.rs | 193 ++++++++++++++------------ tokio/src/util/linked_list.rs | 1 + 5 files changed, 185 insertions(+), 190 deletions(-) diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 834077caa3d..0bd15a74f8b 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -21,9 +21,8 @@ //! //! Each timer has a state field associated with it. This field contains either //! the current scheduled time, or a special flag value indicating its state. -//! This state can either indicate that the timer is on the 'pending' queue (and -//! thus will be fired with an `Ok(())` result soon) or that it has already been -//! fired/deregistered. +//! This state can either indicate that the timer is firing (and thus will be fired +//! with an `Ok(())` result soon) or that it has already been fired/deregistered. //! //! This single state field allows for code that is firing the timer to //! synchronize with any racing `reset` calls reliably. @@ -49,10 +48,10 @@ //! There is of course a race condition between timer reset and timer //! expiration. If the driver fails to observe the updated expiration time, it //! could trigger expiration of the timer too early. However, because -//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and -//! refuse to mark the timer as pending. +//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and +//! refuse to mark the timer as firing. //! -//! [mark_pending]: TimerHandle::mark_pending +//! [mark_firing]: TimerHandle::mark_firing use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicU64; @@ -70,9 +69,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; type TimerResult = Result<(), crate::time::error::Error>; -const STATE_DEREGISTERED: u64 = u64::MAX; -const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; -const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; +pub(super) const STATE_DEREGISTERED: u64 = u64::MAX; +const STATE_FIRING: u64 = STATE_DEREGISTERED - 1; +const STATE_MIN_VALUE: u64 = STATE_FIRING; /// The largest safe integer to use for ticks. /// /// This value should be updated if any other signal values are added above. @@ -123,10 +122,6 @@ impl StateCell { } } - fn is_pending(&self) -> bool { - self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE - } - /// Returns the current expiration time, or None if not currently scheduled. fn when(&self) -> Option { let cur_state = self.state.load(Ordering::Relaxed); @@ -162,26 +157,28 @@ impl StateCell { } } - /// Marks this timer as being moved to the pending list, if its scheduled - /// time is not after `not_after`. + /// Marks this timer firing, if its scheduled time is not after `not_after`. /// /// If the timer is scheduled for a time after `not_after`, returns an Err /// containing the current scheduled time. /// /// SAFETY: Must hold the driver lock. - unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { // Quick initial debug check to see if the timer is already fired. Since // firing the timer can only happen with the driver lock held, we know // we shouldn't be able to "miss" a transition to a fired state, even // with relaxed ordering. let mut cur_state = self.state.load(Ordering::Relaxed); - loop { + // Because its state is STATE_DEREGISTERED, it has been fired. + if cur_state == STATE_DEREGISTERED { + break Err(cur_state); + } // improve the error message for things like // https://github.com/tokio-rs/tokio/issues/3675 assert!( cur_state < STATE_MIN_VALUE, - "mark_pending called when the timer entry is in an invalid state" + "mark_firing called when the timer entry is in an invalid state" ); if cur_state > not_after { @@ -190,7 +187,7 @@ impl StateCell { match self.state.compare_exchange_weak( cur_state, - STATE_PENDING_FIRE, + STATE_FIRING, Ordering::AcqRel, Ordering::Acquire, ) { @@ -337,11 +334,6 @@ pub(crate) struct TimerShared { /// Only accessed under the entry lock. pointers: linked_list::Pointers, - /// The expiration time for which this entry is currently registered. - /// Generally owned by the driver, but is accessed by the entry when not - /// registered. - cached_when: AtomicU64, - /// Current state. This records whether the timer entry is currently under /// the ownership of the driver, and if not, its current state (not /// complete, fired, error, etc). @@ -356,7 +348,6 @@ unsafe impl Sync for TimerShared {} impl std::fmt::Debug for TimerShared { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TimerShared") - .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) .field("state", &self.state) .finish() } @@ -374,40 +365,12 @@ impl TimerShared { pub(super) fn new(shard_id: u32) -> Self { Self { shard_id, - cached_when: AtomicU64::new(0), pointers: linked_list::Pointers::new(), state: StateCell::default(), _p: PhantomPinned, } } - /// Gets the cached time-of-expiration value. - pub(super) fn cached_when(&self) -> u64 { - // Cached-when is only accessed under the driver lock, so we can use relaxed - self.cached_when.load(Ordering::Relaxed) - } - - /// Gets the true time-of-expiration value, and copies it into the cached - /// time-of-expiration value. - /// - /// SAFETY: Must be called with the driver lock held, and when this entry is - /// not in any timer wheel lists. - pub(super) unsafe fn sync_when(&self) -> u64 { - let true_when = self.true_when(); - - self.cached_when.store(true_when, Ordering::Relaxed); - - true_when - } - - /// Sets the cached time-of-expiration value. - /// - /// SAFETY: Must be called with the driver lock held, and when this entry is - /// not in any timer wheel lists. - unsafe fn set_cached_when(&self, when: u64) { - self.cached_when.store(when, Ordering::Relaxed); - } - /// Returns the true time-of-expiration value, with relaxed memory ordering. pub(super) fn true_when(&self) -> u64 { self.state.when().expect("Timer already fired") @@ -420,7 +383,6 @@ impl TimerShared { /// in the timer wheel. pub(super) unsafe fn set_expiration(&self, t: u64) { self.state.set_expiration(t); - self.cached_when.store(t, Ordering::Relaxed); } /// Sets the true time-of-expiration only if it is after the current. @@ -590,16 +552,8 @@ impl TimerEntry { } impl TimerHandle { - pub(super) unsafe fn cached_when(&self) -> u64 { - unsafe { self.inner.as_ref().cached_when() } - } - - pub(super) unsafe fn sync_when(&self) -> u64 { - unsafe { self.inner.as_ref().sync_when() } - } - - pub(super) unsafe fn is_pending(&self) -> bool { - unsafe { self.inner.as_ref().state.is_pending() } + pub(super) unsafe fn true_when(&self) -> u64 { + unsafe { self.inner.as_ref().true_when() } } /// Forcibly sets the true and cached expiration times to the given tick. @@ -610,7 +564,7 @@ impl TimerHandle { self.inner.as_ref().set_expiration(tick); } - /// Attempts to mark this entry as pending. If the expiration time is after + /// Attempts to mark this entry as firing. If the expiration time is after /// `not_after`, however, returns an Err with the current expiration time. /// /// If an `Err` is returned, the `cached_when` value will be updated to this @@ -618,19 +572,8 @@ impl TimerHandle { /// /// SAFETY: The caller must ensure that the handle remains valid, the driver /// lock is held, and that the timer is not in any wheel linked lists. - /// After returning Ok, the entry must be added to the pending list. - pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { - match self.inner.as_ref().state.mark_pending(not_after) { - Ok(()) => { - // mark this as being on the pending queue in cached_when - self.inner.as_ref().set_cached_when(u64::MAX); - Ok(()) - } - Err(tick) => { - self.inner.as_ref().set_cached_when(tick); - Err(tick) - } - } + pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { + self.inner.as_ref().state.mark_firing(not_after) } /// Attempts to transition to a terminal state. If the state is already a diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index c01a5f2b25e..0e4c995dcd0 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -8,7 +8,7 @@ mod entry; pub(crate) use entry::TimerEntry; -use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION}; +use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED}; mod handle; pub(crate) use self::handle::Handle; @@ -324,23 +324,53 @@ impl Handle { now = lock.elapsed(); } - while let Some(entry) = lock.poll(now) { - debug_assert!(unsafe { entry.is_pending() }); - - // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. - if let Some(waker) = unsafe { entry.fire(Ok(())) } { - waker_list.push(waker); - - if !waker_list.can_push() { - // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped. - drop(lock); - - waker_list.wake_all(); - - lock = self.inner.lock_sharded_wheel(id); + while let Some(expiration) = lock.poll(now) { + lock.set_elapsed(expiration.deadline); + // It is critical for `GuardedLinkedList` safety that the guard node is + // pinned in memory and is not dropped until the guarded list is dropped. + let guard = TimerShared::new(id); + pin!(guard); + let guard_handle = guard.as_ref().get_ref().handle(); + + // * This list will be still guarded by the lock of the Wheel with the specefied id. + // `EntryWaitersList` wrapper makes sure we hold the lock to modify it. + // * This wrapper will empty the list on drop. It is critical for safety + // that we will not leave any list entry with a pointer to the local + // guard node after this function returns / panics. + // Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory. + let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) }; + + while let Some(entry) = list.pop_back_locked(&mut lock) { + let deadline = expiration.deadline; + // Try to expire the entry; this is cheap (doesn't synchronize) if + // the timer is not expired, and updates cached_when. + match unsafe { entry.mark_firing(deadline) } { + Ok(()) => { + // Entry was expired. + // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. + if let Some(waker) = unsafe { entry.fire(Ok(())) } { + waker_list.push(waker); + + if !waker_list.can_push() { + // Wake a batch of wakers. To avoid deadlock, + // we must do this with the lock temporarily dropped. + drop(lock); + waker_list.wake_all(); + + lock = self.inner.lock_sharded_wheel(id); + } + } + } + Err(state) if state == STATE_DEREGISTERED => {} + Err(state) => { + // Safety: This Entry has not expired. + unsafe { lock.reinsert_entry(entry, deadline, state) }; + } } } + lock.occupied_bit_maintain(&expiration); } + let next_wake_up = lock.poll_at(); drop(lock); diff --git a/tokio/src/runtime/time/wheel/level.rs b/tokio/src/runtime/time/wheel/level.rs index d31eaf46879..6539f47b4fa 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -20,7 +20,6 @@ pub(crate) struct Level { } /// Indicates when a slot must be processed next. -#[derive(Debug)] pub(crate) struct Expiration { /// The level containing the slot. pub(crate) level: usize, @@ -81,7 +80,7 @@ impl Level { // pseudo-ring buffer, and we rotate around them indefinitely. If we // compute a deadline before now, and it's the top level, it // therefore means we're actually looking at a slot in the future. - debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX); deadline += level_range; } @@ -120,7 +119,7 @@ impl Level { } pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { - let slot = slot_for(item.cached_when(), self.level); + let slot = slot_for(item.true_when(), self.level); self.slot[slot].push_front(item); @@ -128,23 +127,26 @@ impl Level { } pub(crate) unsafe fn remove_entry(&mut self, item: NonNull) { - let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); + let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level); unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { - // The bit is currently set - debug_assert!(self.occupied & occupied_bit(slot) != 0); - // Unset the bit self.occupied ^= occupied_bit(slot); } } - pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { - self.occupied &= !occupied_bit(slot); - + pub(super) fn take_slot(&mut self, slot: usize) -> EntryList { std::mem::take(&mut self.slot[slot]) } + + pub(super) fn occupied_bit_maintain(&mut self, slot: usize) { + if self.slot[slot].is_empty() { + self.occupied &= !occupied_bit(slot); + } else { + self.occupied |= occupied_bit(slot); + } + } } impl fmt::Debug for Level { diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index f2b4228514c..a4034053134 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -1,5 +1,6 @@ use crate::runtime::time::{TimerHandle, TimerShared}; use crate::time::error::InsertError; +use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; mod level; pub(crate) use self::level::Expiration; @@ -7,7 +8,59 @@ use self::level::Level; use std::{array, ptr::NonNull}; -use super::EntryList; +use super::entry::MAX_SAFE_MILLIS_DURATION; +use super::Handle; + +/// List used in `Handle::process_at_sharded_time`. It wraps a guarded linked list +/// and gates the access to it on the lock of the `Wheel` with the specified `wheel_id`. +/// It also empties the list on drop. +pub(super) struct EntryWaitersList<'a> { + // GuardedLinkedList ensures that the concurrent drop of Entry in this slot is safe. + list: GuardedLinkedList::Target>, + is_empty: bool, + wheel_id: u32, + handle: &'a Handle, +} + +impl<'a> Drop for EntryWaitersList<'a> { + fn drop(&mut self) { + // If the list is not empty, we unlink all waiters from it. + // We do not wake the waiters to avoid double panics. + if !self.is_empty { + let _lock = self.handle.inner.lock_sharded_wheel(self.wheel_id); + while self.list.pop_back().is_some() {} + } + } +} + +impl<'a> EntryWaitersList<'a> { + fn new( + unguarded_list: LinkedList::Target>, + guard_handle: TimerHandle, + wheel_id: u32, + handle: &'a Handle, + ) -> Self { + let list = unguarded_list.into_guarded(guard_handle); + Self { + list, + is_empty: false, + wheel_id, + handle, + } + } + + /// Removes the last element from the guarded list. Modifying this list + /// requires an exclusive access to the Wheel with the specified `wheel_id`. + pub(super) fn pop_back_locked(&mut self, _wheel: &mut Wheel) -> Option { + let result = self.list.pop_back(); + if result.is_none() { + // Save information about emptiness to avoid waiting for lock + // in the destructor. + self.is_empty = true; + } + result + } +} /// Timing wheel implementation. /// @@ -36,9 +89,6 @@ pub(crate) struct Wheel { /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Box<[Level; NUM_LEVELS]>, - - /// Entries queued for firing - pending: EntryList, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots @@ -46,6 +96,9 @@ pub(crate) struct Wheel { /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; +/// The max level index. +pub(super) const MAX_LEVEL_INDEX: usize = NUM_LEVELS - 1; + /// The maximum duration of a `Sleep`. pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; @@ -55,7 +108,6 @@ impl Wheel { Wheel { elapsed: 0, levels: Box::new(array::from_fn(Level::new)), - pending: EntryList::new(), } } @@ -90,7 +142,7 @@ impl Wheel { &mut self, item: TimerHandle, ) -> Result { - let when = item.sync_when(); + let when = item.true_when(); if when <= self.elapsed { return Err((item, InsertError::Elapsed)); @@ -99,9 +151,7 @@ impl Wheel { // Get the level at which the entry should be stored let level = self.level_for(when); - unsafe { - self.levels[level].add_entry(item); - } + unsafe { self.levels[level].add_entry(item) }; debug_assert!({ self.levels[level] @@ -116,10 +166,8 @@ impl Wheel { /// Removes `item` from the timing wheel. pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { - let when = item.as_ref().cached_when(); - if when == u64::MAX { - self.pending.remove(item); - } else { + let when = item.as_ref().true_when(); + if when <= MAX_SAFE_MILLIS_DURATION { debug_assert!( self.elapsed <= when, "elapsed={}; when={}", @@ -128,54 +176,42 @@ impl Wheel { ); let level = self.level_for(when); + // If the entry is not contained in the `slot` list, + // then it is contained by a guarded list. self.levels[level].remove_entry(item); } } } + /// Reinserts `item` to the timing wheel. + /// Safety: This entry must not have expired. + pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { + let level = level_for(elapsed, when); + unsafe { self.levels[level].add_entry(entry) }; + } + /// Instant at which to poll. pub(crate) fn poll_at(&self) -> Option { self.next_expiration().map(|expiration| expiration.deadline) } /// Advances the timer up to the instant represented by `now`. - pub(crate) fn poll(&mut self, now: u64) -> Option { - loop { - if let Some(handle) = self.pending.pop_back() { - return Some(handle); - } - - match self.next_expiration() { - Some(ref expiration) if expiration.deadline <= now => { - self.process_expiration(expiration); - - self.set_elapsed(expiration.deadline); - } - _ => { - // in this case the poll did not indicate an expiration - // _and_ we were not able to find a next expiration in - // the current list of timers. advance to the poll's - // current time and do nothing else. - self.set_elapsed(now); - break; - } + pub(crate) fn poll(&mut self, now: u64) -> Option { + match self.next_expiration() { + Some(expiration) if expiration.deadline <= now => Some(expiration), + _ => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + None } } - - self.pending.pop_back() } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option { - if !self.pending.is_empty() { - // Expire immediately as we have things pending firing - return Some(Expiration { - level: 0, - slot: 0, - deadline: self.elapsed, - }); - } - // Check all levels for (level_num, level) in self.levels.iter().enumerate() { if let Some(expiration) = level.next_expiration(self.elapsed) { @@ -211,46 +247,7 @@ impl Wheel { res } - /// iteratively find entries that are between the wheel's current - /// time and the expiration time. for each in that population either - /// queue it for notification (in the case of the last level) or tier - /// it down to the next level (in all other cases). - pub(crate) fn process_expiration(&mut self, expiration: &Expiration) { - // Note that we need to take _all_ of the entries off the list before - // processing any of them. This is important because it's possible that - // those entries might need to be reinserted into the same slot. - // - // This happens only on the highest level, when an entry is inserted - // more than MAX_DURATION into the future. When this happens, we wrap - // around, and process some entries a multiple of MAX_DURATION before - // they actually need to be dropped down a level. We then reinsert them - // back into the same position; we must make sure we don't then process - // those entries again or we'll end up in an infinite loop. - let mut entries = self.take_entries(expiration); - - while let Some(item) = entries.pop_back() { - if expiration.level == 0 { - debug_assert_eq!(unsafe { item.cached_when() }, expiration.deadline); - } - - // Try to expire the entry; this is cheap (doesn't synchronize) if - // the timer is not expired, and updates cached_when. - match unsafe { item.mark_pending(expiration.deadline) } { - Ok(()) => { - // Item was expired - self.pending.push_front(item); - } - Err(expiration_tick) => { - let level = level_for(expiration.deadline, expiration_tick); - unsafe { - self.levels[level].add_entry(item); - } - } - } - } - } - - fn set_elapsed(&mut self, when: u64) { + pub(super) fn set_elapsed(&mut self, when: u64) { assert!( self.elapsed <= when, "elapsed={:?}; when={:?}", @@ -263,9 +260,31 @@ impl Wheel { } } - /// Obtains the list of entries that need processing for the given expiration. - fn take_entries(&mut self, expiration: &Expiration) -> EntryList { - self.levels[expiration.level].take_slot(expiration.slot) + /// Obtains the guarded list of entries that need processing for the given expiration. + /// Safety: The `TimerShared` inside `guard_handle` must be pinned in the memory. + pub(super) unsafe fn get_waiters_list<'a>( + &mut self, + expiration: &Expiration, + guard_handle: TimerHandle, + wheel_id: u32, + handle: &'a Handle, + ) -> EntryWaitersList<'a> { + // Note that we need to take _all_ of the entries off the list before + // processing any of them. This is important because it's possible that + // those entries might need to be reinserted into the same slot. + // + // This happens only on the highest level, when an entry is inserted + // more than MAX_DURATION into the future. When this happens, we wrap + // around, and process some entries a multiple of MAX_DURATION before + // they actually need to be dropped down a level. We then reinsert them + // back into the same position; we must make sure we don't then process + // those entries again or we'll end up in an infinite loop. + let unguarded_list = self.levels[expiration.level].take_slot(expiration.slot); + EntryWaitersList::new(unguarded_list, guard_handle, wheel_id, handle) + } + + pub(super) fn occupied_bit_maintain(&mut self, expiration: &Expiration) { + self.levels[expiration.level].occupied_bit_maintain(expiration.slot); } fn level_for(&self, when: u64) -> usize { diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 3650f87fbb0..382d9ee1978 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -334,6 +334,7 @@ feature! { feature = "sync", feature = "rt", feature = "signal", + feature = "time", )] /// An intrusive linked list, but instead of keeping pointers to the head From cd70a2497eb38e9e6a057c85aec7adfc3cba9871 Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 24 Jul 2024 19:57:50 +0800 Subject: [PATCH 2/8] time: revert set_cached_when --- tokio/src/runtime/time/entry.rs | 47 +++++++++++++++++++++++---- tokio/src/runtime/time/mod.rs | 1 - tokio/src/runtime/time/tests/mod.rs | 21 ++++++++++++ tokio/src/runtime/time/wheel/level.rs | 8 +++-- tokio/src/runtime/time/wheel/mod.rs | 29 +++++++++++++---- 5 files changed, 89 insertions(+), 17 deletions(-) diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 0bd15a74f8b..f52944fb1dc 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -333,7 +333,11 @@ pub(crate) struct TimerShared { /// /// Only accessed under the entry lock. pointers: linked_list::Pointers, - + /// The expiration time for which this entry is currently registered. + /// It is used to calculate which slot this entry is stored in. + /// Generally owned by the driver, but is accessed by the entry when not + /// registered. + cached_when: AtomicU64, /// Current state. This records whether the timer entry is currently under /// the ownership of the driver, and if not, its current state (not /// complete, fired, error, etc). @@ -348,6 +352,7 @@ unsafe impl Sync for TimerShared {} impl std::fmt::Debug for TimerShared { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TimerShared") + .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) .field("state", &self.state) .finish() } @@ -365,12 +370,37 @@ impl TimerShared { pub(super) fn new(shard_id: u32) -> Self { Self { shard_id, + cached_when: AtomicU64::new(0), pointers: linked_list::Pointers::new(), state: StateCell::default(), _p: PhantomPinned, } } + /// Gets the cached time-of-expiration value. + pub(super) fn cached_when(&self) -> u64 { + // Cached-when is only accessed under the driver lock, so we can use relaxed + self.cached_when.load(Ordering::Relaxed) + } + /// Gets the true time-of-expiration value, and copies it into the cached + /// time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn sync_when(&self) -> u64 { + let true_when = self.true_when(); + self.set_cached_when(true_when); + true_when + } + + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn set_cached_when(&self, when: u64) { + self.cached_when.store(when, Ordering::Relaxed); + } + /// Returns the true time-of-expiration value, with relaxed memory ordering. pub(super) fn true_when(&self) -> u64 { self.state.when().expect("Timer already fired") @@ -552,8 +582,16 @@ impl TimerEntry { } impl TimerHandle { - pub(super) unsafe fn true_when(&self) -> u64 { - unsafe { self.inner.as_ref().true_when() } + pub(super) unsafe fn cached_when(&self) -> u64 { + unsafe { self.inner.as_ref().cached_when() } + } + + pub(super) unsafe fn set_cached_when(&self, t: u64) { + unsafe { self.inner.as_ref().set_cached_when(t) } + } + + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } } /// Forcibly sets the true and cached expiration times to the given tick. @@ -567,9 +605,6 @@ impl TimerHandle { /// Attempts to mark this entry as firing. If the expiration time is after /// `not_after`, however, returns an Err with the current expiration time. /// - /// If an `Err` is returned, the `cached_when` value will be updated to this - /// new expiration time. - /// /// SAFETY: The caller must ensure that the handle remains valid, the driver /// lock is held, and that the timer is not in any wheel linked lists. pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 0e4c995dcd0..3d1dc92f68f 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -204,7 +204,6 @@ impl Driver { .inner .next_wake .store(next_wake_time(expiration_time)); - // Safety: After updating the `next_wake`, we drop all the locks. drop(locks); diff --git a/tokio/src/runtime/time/tests/mod.rs b/tokio/src/runtime/time/tests/mod.rs index 0e453433691..7084c2d7519 100644 --- a/tokio/src/runtime/time/tests/mod.rs +++ b/tokio/src/runtime/time/tests/mod.rs @@ -202,6 +202,27 @@ fn reset_future() { }) } +#[test] +fn reset_timer_and_drop() { + model(|| { + let rt = rt(false); + let handle = rt.handle(); + + let start = handle.inner.driver().clock().now(); + + for _ in 0..2 { + let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + entry.as_mut().reset(start + Duration::from_secs(1), true); + } + }); +} + #[cfg(not(loom))] fn normal_or_miri(normal: T, miri: T) -> T { if cfg!(miri) { diff --git a/tokio/src/runtime/time/wheel/level.rs b/tokio/src/runtime/time/wheel/level.rs index 6539f47b4fa..cccff6d8827 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -118,16 +118,18 @@ impl Level { Some(slot) } - pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { - let slot = slot_for(item.true_when(), self.level); + pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) -> usize { + let slot = slot_for(item.cached_when(), self.level); self.slot[slot].push_front(item); self.occupied |= occupied_bit(slot); + + slot } pub(crate) unsafe fn remove_entry(&mut self, item: NonNull) { - let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level); + let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index a4034053134..7424c4892cb 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -142,16 +142,14 @@ impl Wheel { &mut self, item: TimerHandle, ) -> Result { - let when = item.true_when(); + let when = item.sync_when(); if when <= self.elapsed { return Err((item, InsertError::Elapsed)); } - // Get the level at which the entry should be stored - let level = self.level_for(when); - - unsafe { self.levels[level].add_entry(item) }; + // Safety: The `cached_when` of this item has been updated by calling the `sync_when`. + let (level, _slot) = self.inner_insert(item, self.elapsed, when); debug_assert!({ self.levels[level] @@ -166,7 +164,7 @@ impl Wheel { /// Removes `item` from the timing wheel. pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { - let when = item.as_ref().true_when(); + let when = item.as_ref().cached_when(); if when <= MAX_SAFE_MILLIS_DURATION { debug_assert!( self.elapsed <= when, @@ -186,8 +184,25 @@ impl Wheel { /// Reinserts `item` to the timing wheel. /// Safety: This entry must not have expired. pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { + entry.set_cached_when(when); + // Safety: The `cached_when` of this entry has been updated by calling the `set_cached_when`. + let (_level, _slot) = self.inner_insert(entry, elapsed, when); + } + + /// Inserts the `entry` to the `Wheel`. + /// Returns the level and the slot which `entry` insert into. + /// + /// Safety: The `cached_when` of this `entry`` must have been updated to `when`. + unsafe fn inner_insert( + &mut self, + entry: TimerHandle, + elapsed: u64, + when: u64, + ) -> (usize, usize) { + // Get the level at which the entry should be stored let level = level_for(elapsed, when); - unsafe { self.levels[level].add_entry(entry) }; + let slot = unsafe { self.levels[level].add_entry(entry) }; + (level, slot) } /// Instant at which to poll. From 1410a9fa63e7bb247c4a9582942117a97dc10611 Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 24 Jul 2024 22:16:42 +0800 Subject: [PATCH 3/8] add #[cfg(not(loom))] for the reset_timer_and_drop test --- tokio/src/runtime/time/tests/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokio/src/runtime/time/tests/mod.rs b/tokio/src/runtime/time/tests/mod.rs index 7084c2d7519..35da14d8d82 100644 --- a/tokio/src/runtime/time/tests/mod.rs +++ b/tokio/src/runtime/time/tests/mod.rs @@ -203,6 +203,7 @@ fn reset_future() { } #[test] +#[cfg(not(loom))] fn reset_timer_and_drop() { model(|| { let rt = rt(false); From 71f60d516ad154a3e6de99a05295f74c045177dc Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 24 Jul 2024 23:37:04 +0800 Subject: [PATCH 4/8] rm model in the reset_timer_and_drop test --- tokio/src/runtime/time/tests/mod.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tokio/src/runtime/time/tests/mod.rs b/tokio/src/runtime/time/tests/mod.rs index 35da14d8d82..97557abbf89 100644 --- a/tokio/src/runtime/time/tests/mod.rs +++ b/tokio/src/runtime/time/tests/mod.rs @@ -205,23 +205,21 @@ fn reset_future() { #[test] #[cfg(not(loom))] fn reset_timer_and_drop() { - model(|| { - let rt = rt(false); - let handle = rt.handle(); + let rt = rt(false); + let handle = rt.handle(); - let start = handle.inner.driver().clock().now(); + let start = handle.inner.driver().clock().now(); - for _ in 0..2 { - let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10)); - pin!(entry); + for _ in 0..2 { + let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10)); + pin!(entry); - let _ = entry - .as_mut() - .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); - entry.as_mut().reset(start + Duration::from_secs(1), true); - } - }); + entry.as_mut().reset(start + Duration::from_secs(1), true); + } } #[cfg(not(loom))] From 004367cdf5e433d21d20cad602c48835cb7a3beb Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Sun, 28 Jul 2024 14:04:14 +0800 Subject: [PATCH 5/8] revert some unnecessary modification && update some comments --- tokio/src/runtime/time/entry.rs | 1 + tokio/src/runtime/time/wheel/level.rs | 4 +--- tokio/src/runtime/time/wheel/mod.rs | 19 +++++++------------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index f52944fb1dc..3f803b52cea 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -413,6 +413,7 @@ impl TimerShared { /// in the timer wheel. pub(super) unsafe fn set_expiration(&self, t: u64) { self.state.set_expiration(t); + self.cached_when.store(t, Ordering::Relaxed); } /// Sets the true time-of-expiration only if it is after the current. diff --git a/tokio/src/runtime/time/wheel/level.rs b/tokio/src/runtime/time/wheel/level.rs index cccff6d8827..6f90811ff11 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -118,14 +118,12 @@ impl Level { Some(slot) } - pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) -> usize { + pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { let slot = slot_for(item.cached_when(), self.level); self.slot[slot].push_front(item); self.occupied |= occupied_bit(slot); - - slot } pub(crate) unsafe fn remove_entry(&mut self, item: NonNull) { diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index 7424c4892cb..fd260d54adb 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -149,7 +149,7 @@ impl Wheel { } // Safety: The `cached_when` of this item has been updated by calling the `sync_when`. - let (level, _slot) = self.inner_insert(item, self.elapsed, when); + let level = self.inner_insert(item, self.elapsed, when); debug_assert!({ self.levels[level] @@ -186,23 +186,18 @@ impl Wheel { pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { entry.set_cached_when(when); // Safety: The `cached_when` of this entry has been updated by calling the `set_cached_when`. - let (_level, _slot) = self.inner_insert(entry, elapsed, when); + let _level = self.inner_insert(entry, elapsed, when); } /// Inserts the `entry` to the `Wheel`. - /// Returns the level and the slot which `entry` insert into. + /// Returns the level where the entry is inserted. /// /// Safety: The `cached_when` of this `entry`` must have been updated to `when`. - unsafe fn inner_insert( - &mut self, - entry: TimerHandle, - elapsed: u64, - when: u64, - ) -> (usize, usize) { - // Get the level at which the entry should be stored + unsafe fn inner_insert(&mut self, entry: TimerHandle, elapsed: u64, when: u64) -> usize { + // Get the level at which the entry should be stored. let level = level_for(elapsed, when); - let slot = unsafe { self.levels[level].add_entry(entry) }; - (level, slot) + unsafe { self.levels[level].add_entry(entry) }; + level } /// Instant at which to poll. From b1cf26d196a713c94a6dcfe295a4cfc55affd8a5 Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 31 Jul 2024 11:48:27 +0800 Subject: [PATCH 6/8] test: assert_ne MAX_SAFE_MILLIS_DURATION for test --- tokio/src/runtime/time/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 3d1dc92f68f..68605938366 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -360,8 +360,10 @@ impl Handle { } } } - Err(state) if state == STATE_DEREGISTERED => {} + // TODO: The following code is commented out for running the test. Just for test. + // Err(state) if state == STATE_DEREGISTERED => {} Err(state) => { + assert_ne!(state, STATE_DEREGISTERED); // Safety: This Entry has not expired. unsafe { lock.reinsert_entry(entry, deadline, state) }; } From 08a688a7477562783407eea55427571546a6d8ef Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 31 Jul 2024 12:29:18 +0800 Subject: [PATCH 7/8] time: remove MAX_SAFE_MILLIS_DURATION check --- tokio/src/runtime/time/entry.rs | 4 ---- tokio/src/runtime/time/mod.rs | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 3f803b52cea..e5f69a1a91c 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -170,10 +170,6 @@ impl StateCell { // with relaxed ordering. let mut cur_state = self.state.load(Ordering::Relaxed); loop { - // Because its state is STATE_DEREGISTERED, it has been fired. - if cur_state == STATE_DEREGISTERED { - break Err(cur_state); - } // improve the error message for things like // https://github.com/tokio-rs/tokio/issues/3675 assert!( diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 68605938366..f227e4f6cc3 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -360,10 +360,8 @@ impl Handle { } } } - // TODO: The following code is commented out for running the test. Just for test. - // Err(state) if state == STATE_DEREGISTERED => {} Err(state) => { - assert_ne!(state, STATE_DEREGISTERED); + debug_assert_ne!(state, STATE_DEREGISTERED); // Safety: This Entry has not expired. unsafe { lock.reinsert_entry(entry, deadline, state) }; } From 36b8c07ea3e9b6cbecca9e3c57cb0ab732b400dc Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Mon, 5 Aug 2024 23:37:57 +0800 Subject: [PATCH 8/8] time: move wheel state maintenance method --- tokio/src/runtime/time/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index f227e4f6cc3..be1d84c2155 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -324,7 +324,6 @@ impl Handle { } while let Some(expiration) = lock.poll(now) { - lock.set_elapsed(expiration.deadline); // It is critical for `GuardedLinkedList` safety that the guard node is // pinned in memory and is not dropped until the guarded list is dropped. let guard = TimerShared::new(id); @@ -351,6 +350,7 @@ impl Handle { waker_list.push(waker); if !waker_list.can_push() { + lock.occupied_bit_maintain(&expiration); // Wake a batch of wakers. To avoid deadlock, // we must do this with the lock temporarily dropped. drop(lock); @@ -367,6 +367,7 @@ impl Handle { } } } + lock.set_elapsed(expiration.deadline); lock.occupied_bit_maintain(&expiration); }