diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 834077caa3d..e5f69a1a91c 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -21,9 +21,8 @@ //! //! Each timer has a state field associated with it. This field contains either //! the current scheduled time, or a special flag value indicating its state. -//! This state can either indicate that the timer is on the 'pending' queue (and -//! thus will be fired with an `Ok(())` result soon) or that it has already been -//! fired/deregistered. +//! This state can either indicate that the timer is firing (and thus will be fired +//! with an `Ok(())` result soon) or that it has already been fired/deregistered. //! //! This single state field allows for code that is firing the timer to //! synchronize with any racing `reset` calls reliably. @@ -49,10 +48,10 @@ //! There is of course a race condition between timer reset and timer //! expiration. If the driver fails to observe the updated expiration time, it //! could trigger expiration of the timer too early. However, because -//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and -//! refuse to mark the timer as pending. +//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and +//! refuse to mark the timer as firing. //! -//! [mark_pending]: TimerHandle::mark_pending +//! 
[mark_firing]: TimerHandle::mark_firing use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicU64; @@ -70,9 +69,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; type TimerResult = Result<(), crate::time::error::Error>; -const STATE_DEREGISTERED: u64 = u64::MAX; -const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; -const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; +pub(super) const STATE_DEREGISTERED: u64 = u64::MAX; +const STATE_FIRING: u64 = STATE_DEREGISTERED - 1; +const STATE_MIN_VALUE: u64 = STATE_FIRING; /// The largest safe integer to use for ticks. /// /// This value should be updated if any other signal values are added above. @@ -123,10 +122,6 @@ impl StateCell { } } - fn is_pending(&self) -> bool { - self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE - } - /// Returns the current expiration time, or None if not currently scheduled. fn when(&self) -> Option { let cur_state = self.state.load(Ordering::Relaxed); @@ -162,26 +157,24 @@ impl StateCell { } } - /// Marks this timer as being moved to the pending list, if its scheduled - /// time is not after `not_after`. + /// Marks this timer firing, if its scheduled time is not after `not_after`. /// /// If the timer is scheduled for a time after `not_after`, returns an Err /// containing the current scheduled time. /// /// SAFETY: Must hold the driver lock. - unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { // Quick initial debug check to see if the timer is already fired. Since // firing the timer can only happen with the driver lock held, we know // we shouldn't be able to "miss" a transition to a fired state, even // with relaxed ordering. 
let mut cur_state = self.state.load(Ordering::Relaxed); - loop { // improve the error message for things like // https://github.com/tokio-rs/tokio/issues/3675 assert!( cur_state < STATE_MIN_VALUE, - "mark_pending called when the timer entry is in an invalid state" + "mark_firing called when the timer entry is in an invalid state" ); if cur_state > not_after { @@ -190,7 +183,7 @@ impl StateCell { match self.state.compare_exchange_weak( cur_state, - STATE_PENDING_FIRE, + STATE_FIRING, Ordering::AcqRel, Ordering::Acquire, ) { @@ -336,12 +329,11 @@ pub(crate) struct TimerShared { /// /// Only accessed under the entry lock. pointers: linked_list::Pointers, - /// The expiration time for which this entry is currently registered. + /// It is used to calculate which slot this entry is stored in. /// Generally owned by the driver, but is accessed by the entry when not /// registered. cached_when: AtomicU64, - /// Current state. This records whether the timer entry is currently under /// the ownership of the driver, and if not, its current state (not /// complete, fired, error, etc). @@ -386,7 +378,6 @@ impl TimerShared { // Cached-when is only accessed under the driver lock, so we can use relaxed self.cached_when.load(Ordering::Relaxed) } - /// Gets the true time-of-expiration value, and copies it into the cached /// time-of-expiration value. /// @@ -394,9 +385,7 @@ impl TimerShared { /// not in any timer wheel lists. pub(super) unsafe fn sync_when(&self) -> u64 { let true_when = self.true_when(); - - self.cached_when.store(true_when, Ordering::Relaxed); - + self.set_cached_when(true_when); true_when } @@ -404,7 +393,7 @@ impl TimerShared { /// /// SAFETY: Must be called with the driver lock held, and when this entry is /// not in any timer wheel lists. 
- unsafe fn set_cached_when(&self, when: u64) { + pub(super) unsafe fn set_cached_when(&self, when: u64) { self.cached_when.store(when, Ordering::Relaxed); } @@ -594,12 +583,12 @@ impl TimerHandle { unsafe { self.inner.as_ref().cached_when() } } - pub(super) unsafe fn sync_when(&self) -> u64 { - unsafe { self.inner.as_ref().sync_when() } + pub(super) unsafe fn set_cached_when(&self, t: u64) { + unsafe { self.inner.as_ref().set_cached_when(t) } } - pub(super) unsafe fn is_pending(&self) -> bool { - unsafe { self.inner.as_ref().state.is_pending() } + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } } /// Forcibly sets the true and cached expiration times to the given tick. @@ -610,27 +599,13 @@ impl TimerHandle { self.inner.as_ref().set_expiration(tick); } - /// Attempts to mark this entry as pending. If the expiration time is after + /// Attempts to mark this entry as firing. If the expiration time is after /// `not_after`, however, returns an Err with the current expiration time. /// - /// If an `Err` is returned, the `cached_when` value will be updated to this - /// new expiration time. - /// /// SAFETY: The caller must ensure that the handle remains valid, the driver /// lock is held, and that the timer is not in any wheel linked lists. - /// After returning Ok, the entry must be added to the pending list. - pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { - match self.inner.as_ref().state.mark_pending(not_after) { - Ok(()) => { - // mark this as being on the pending queue in cached_when - self.inner.as_ref().set_cached_when(u64::MAX); - Ok(()) - } - Err(tick) => { - self.inner.as_ref().set_cached_when(tick); - Err(tick) - } - } + pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { + self.inner.as_ref().state.mark_firing(not_after) } /// Attempts to transition to a terminal state. 
If the state is already a diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index c01a5f2b25e..be1d84c2155 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -8,7 +8,7 @@ mod entry; pub(crate) use entry::TimerEntry; -use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION}; +use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED}; mod handle; pub(crate) use self::handle::Handle; @@ -204,7 +204,6 @@ impl Driver { .inner .next_wake .store(next_wake_time(expiration_time)); - // Safety: After updating the `next_wake`, we drop all the locks. drop(locks); @@ -324,23 +323,54 @@ impl Handle { now = lock.elapsed(); } - while let Some(entry) = lock.poll(now) { - debug_assert!(unsafe { entry.is_pending() }); - - // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. - if let Some(waker) = unsafe { entry.fire(Ok(())) } { - waker_list.push(waker); - - if !waker_list.can_push() { - // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped. - drop(lock); - - waker_list.wake_all(); - - lock = self.inner.lock_sharded_wheel(id); + while let Some(expiration) = lock.poll(now) { + // It is critical for `GuardedLinkedList` safety that the guard node is + // pinned in memory and is not dropped until the guarded list is dropped. + let guard = TimerShared::new(id); + pin!(guard); + let guard_handle = guard.as_ref().get_ref().handle(); + + // * This list will still be guarded by the lock of the Wheel with the specified id. + // `EntryWaitersList` wrapper makes sure we hold the lock to modify it. + // * This wrapper will empty the list on drop. It is critical for safety + // that we will not leave any list entry with a pointer to the local + // guard node after this function returns / panics. + // Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory. 
+ let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) }; + + while let Some(entry) = list.pop_back_locked(&mut lock) { + let deadline = expiration.deadline; + // Try to expire the entry; this is cheap (doesn't synchronize) if + // the timer is not expired, and updates cached_when. + match unsafe { entry.mark_firing(deadline) } { + Ok(()) => { + // Entry was expired. + // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. + if let Some(waker) = unsafe { entry.fire(Ok(())) } { + waker_list.push(waker); + + if !waker_list.can_push() { + lock.occupied_bit_maintain(&expiration); + // Wake a batch of wakers. To avoid deadlock, + // we must do this with the lock temporarily dropped. + drop(lock); + waker_list.wake_all(); + + lock = self.inner.lock_sharded_wheel(id); + } + } + } + Err(state) => { + debug_assert_ne!(state, STATE_DEREGISTERED); + // Safety: This Entry has not expired. + unsafe { lock.reinsert_entry(entry, deadline, state) }; + } } } + lock.set_elapsed(expiration.deadline); + lock.occupied_bit_maintain(&expiration); } + let next_wake_up = lock.poll_at(); drop(lock); diff --git a/tokio/src/runtime/time/tests/mod.rs b/tokio/src/runtime/time/tests/mod.rs index 0e453433691..97557abbf89 100644 --- a/tokio/src/runtime/time/tests/mod.rs +++ b/tokio/src/runtime/time/tests/mod.rs @@ -202,6 +202,26 @@ fn reset_future() { }) } +#[test] +#[cfg(not(loom))] +fn reset_timer_and_drop() { + let rt = rt(false); + let handle = rt.handle(); + + let start = handle.inner.driver().clock().now(); + + for _ in 0..2 { + let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + entry.as_mut().reset(start + Duration::from_secs(1), true); + } +} + #[cfg(not(loom))] fn normal_or_miri(normal: T, miri: T) -> T { if cfg!(miri) { diff --git a/tokio/src/runtime/time/wheel/level.rs 
b/tokio/src/runtime/time/wheel/level.rs index d31eaf46879..6f90811ff11 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -20,7 +20,6 @@ pub(crate) struct Level { } /// Indicates when a slot must be processed next. -#[derive(Debug)] pub(crate) struct Expiration { /// The level containing the slot. pub(crate) level: usize, @@ -81,7 +80,7 @@ impl Level { // pseudo-ring buffer, and we rotate around them indefinitely. If we // compute a deadline before now, and it's the top level, it // therefore means we're actually looking at a slot in the future. - debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX); deadline += level_range; } @@ -132,19 +131,22 @@ impl Level { unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { - // The bit is currently set - debug_assert!(self.occupied & occupied_bit(slot) != 0); - // Unset the bit self.occupied ^= occupied_bit(slot); } } - pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { - self.occupied &= !occupied_bit(slot); - + pub(super) fn take_slot(&mut self, slot: usize) -> EntryList { std::mem::take(&mut self.slot[slot]) } + + pub(super) fn occupied_bit_maintain(&mut self, slot: usize) { + if self.slot[slot].is_empty() { + self.occupied &= !occupied_bit(slot); + } else { + self.occupied |= occupied_bit(slot); + } + } } impl fmt::Debug for Level { diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index f2b4228514c..fd260d54adb 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -1,5 +1,6 @@ use crate::runtime::time::{TimerHandle, TimerShared}; use crate::time::error::InsertError; +use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; mod level; pub(crate) use self::level::Expiration; @@ -7,7 +8,59 @@ use self::level::Level; use std::{array, ptr::NonNull}; -use super::EntryList; +use 
super::entry::MAX_SAFE_MILLIS_DURATION; +use super::Handle; + +/// List used in `Handle::process_at_sharded_time`. It wraps a guarded linked list +/// and gates the access to it on the lock of the `Wheel` with the specified `wheel_id`. +/// It also empties the list on drop. +pub(super) struct EntryWaitersList<'a> { + // GuardedLinkedList ensures that the concurrent drop of Entry in this slot is safe. + list: GuardedLinkedList::Target>, + is_empty: bool, + wheel_id: u32, + handle: &'a Handle, +} + +impl<'a> Drop for EntryWaitersList<'a> { + fn drop(&mut self) { + // If the list is not empty, we unlink all waiters from it. + // We do not wake the waiters to avoid double panics. + if !self.is_empty { + let _lock = self.handle.inner.lock_sharded_wheel(self.wheel_id); + while self.list.pop_back().is_some() {} + } + } +} + +impl<'a> EntryWaitersList<'a> { + fn new( + unguarded_list: LinkedList::Target>, + guard_handle: TimerHandle, + wheel_id: u32, + handle: &'a Handle, + ) -> Self { + let list = unguarded_list.into_guarded(guard_handle); + Self { + list, + is_empty: false, + wheel_id, + handle, + } + } + + /// Removes the last element from the guarded list. Modifying this list + /// requires an exclusive access to the Wheel with the specified `wheel_id`. + pub(super) fn pop_back_locked(&mut self, _wheel: &mut Wheel) -> Option { + let result = self.list.pop_back(); + if result.is_none() { + // Save information about emptiness to avoid waiting for lock + // in the destructor. + self.is_empty = true; + } + result + } +} /// Timing wheel implementation. /// @@ -36,9 +89,6 @@ pub(crate) struct Wheel { /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Box<[Level; NUM_LEVELS]>, - - /// Entries queued for firing - pending: EntryList, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots @@ -46,6 +96,9 @@ pub(crate) struct Wheel { /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; +/// The max level index. 
+pub(super) const MAX_LEVEL_INDEX: usize = NUM_LEVELS - 1; + /// The maximum duration of a `Sleep`. pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; @@ -55,7 +108,6 @@ impl Wheel { Wheel { elapsed: 0, levels: Box::new(array::from_fn(Level::new)), - pending: EntryList::new(), } } @@ -96,12 +148,8 @@ impl Wheel { return Err((item, InsertError::Elapsed)); } - // Get the level at which the entry should be stored - let level = self.level_for(when); - - unsafe { - self.levels[level].add_entry(item); - } + // Safety: The `cached_when` of this item has been updated by calling the `sync_when`. + let level = self.inner_insert(item, self.elapsed, when); debug_assert!({ self.levels[level] @@ -117,9 +165,7 @@ impl Wheel { pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { let when = item.as_ref().cached_when(); - if when == u64::MAX { - self.pending.remove(item); - } else { + if when <= MAX_SAFE_MILLIS_DURATION { debug_assert!( self.elapsed <= when, "elapsed={}; when={}", @@ -128,54 +174,54 @@ impl Wheel { ); let level = self.level_for(when); + // If the entry is not contained in the `slot` list, + // then it is contained by a guarded list. self.levels[level].remove_entry(item); } } } + /// Reinserts `entry` into the timing wheel. + /// Safety: This entry must not have expired. + pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { + entry.set_cached_when(when); + // Safety: The `cached_when` of this entry has been updated by calling the `set_cached_when`. + let _level = self.inner_insert(entry, elapsed, when); + } + + /// Inserts the `entry` to the `Wheel`. + /// Returns the level where the entry is inserted. + /// + /// Safety: The `cached_when` of this `entry` must have been updated to `when`. + unsafe fn inner_insert(&mut self, entry: TimerHandle, elapsed: u64, when: u64) -> usize { + // Get the level at which the entry should be stored. 
+ let level = level_for(elapsed, when); + unsafe { self.levels[level].add_entry(entry) }; + level + } + /// Instant at which to poll. pub(crate) fn poll_at(&self) -> Option { self.next_expiration().map(|expiration| expiration.deadline) } /// Advances the timer up to the instant represented by `now`. - pub(crate) fn poll(&mut self, now: u64) -> Option { - loop { - if let Some(handle) = self.pending.pop_back() { - return Some(handle); - } - - match self.next_expiration() { - Some(ref expiration) if expiration.deadline <= now => { - self.process_expiration(expiration); - - self.set_elapsed(expiration.deadline); - } - _ => { - // in this case the poll did not indicate an expiration - // _and_ we were not able to find a next expiration in - // the current list of timers. advance to the poll's - // current time and do nothing else. - self.set_elapsed(now); - break; - } + pub(crate) fn poll(&mut self, now: u64) -> Option { + match self.next_expiration() { + Some(expiration) if expiration.deadline <= now => Some(expiration), + _ => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + None } } - - self.pending.pop_back() } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option { - if !self.pending.is_empty() { - // Expire immediately as we have things pending firing - return Some(Expiration { - level: 0, - slot: 0, - deadline: self.elapsed, - }); - } - // Check all levels for (level_num, level) in self.levels.iter().enumerate() { if let Some(expiration) = level.next_expiration(self.elapsed) { @@ -211,46 +257,7 @@ impl Wheel { res } - /// iteratively find entries that are between the wheel's current - /// time and the expiration time. 
for each in that population either - /// queue it for notification (in the case of the last level) or tier - /// it down to the next level (in all other cases). - pub(crate) fn process_expiration(&mut self, expiration: &Expiration) { - // Note that we need to take _all_ of the entries off the list before - // processing any of them. This is important because it's possible that - // those entries might need to be reinserted into the same slot. - // - // This happens only on the highest level, when an entry is inserted - // more than MAX_DURATION into the future. When this happens, we wrap - // around, and process some entries a multiple of MAX_DURATION before - // they actually need to be dropped down a level. We then reinsert them - // back into the same position; we must make sure we don't then process - // those entries again or we'll end up in an infinite loop. - let mut entries = self.take_entries(expiration); - - while let Some(item) = entries.pop_back() { - if expiration.level == 0 { - debug_assert_eq!(unsafe { item.cached_when() }, expiration.deadline); - } - - // Try to expire the entry; this is cheap (doesn't synchronize) if - // the timer is not expired, and updates cached_when. - match unsafe { item.mark_pending(expiration.deadline) } { - Ok(()) => { - // Item was expired - self.pending.push_front(item); - } - Err(expiration_tick) => { - let level = level_for(expiration.deadline, expiration_tick); - unsafe { - self.levels[level].add_entry(item); - } - } - } - } - } - - fn set_elapsed(&mut self, when: u64) { + pub(super) fn set_elapsed(&mut self, when: u64) { assert!( self.elapsed <= when, "elapsed={:?}; when={:?}", @@ -263,9 +270,31 @@ impl Wheel { } } - /// Obtains the list of entries that need processing for the given expiration. - fn take_entries(&mut self, expiration: &Expiration) -> EntryList { - self.levels[expiration.level].take_slot(expiration.slot) + /// Obtains the guarded list of entries that need processing for the given expiration. 
+ /// Safety: The `TimerShared` inside `guard_handle` must be pinned in the memory. + pub(super) unsafe fn get_waiters_list<'a>( + &mut self, + expiration: &Expiration, + guard_handle: TimerHandle, + wheel_id: u32, + handle: &'a Handle, + ) -> EntryWaitersList<'a> { + // Note that we need to take _all_ of the entries off the list before + // processing any of them. This is important because it's possible that + // those entries might need to be reinserted into the same slot. + // + // This happens only on the highest level, when an entry is inserted + // more than MAX_DURATION into the future. When this happens, we wrap + // around, and process some entries a multiple of MAX_DURATION before + // they actually need to be dropped down a level. We then reinsert them + // back into the same position; we must make sure we don't then process + // those entries again or we'll end up in an infinite loop. + let unguarded_list = self.levels[expiration.level].take_slot(expiration.slot); + EntryWaitersList::new(unguarded_list, guard_handle, wheel_id, handle) + } + + pub(super) fn occupied_bit_maintain(&mut self, expiration: &Expiration) { + self.levels[expiration.level].occupied_bit_maintain(expiration.slot); } fn level_for(&self, when: u64) -> usize { diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 3650f87fbb0..382d9ee1978 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -334,6 +334,7 @@ feature! { feature = "sync", feature = "rt", feature = "signal", + feature = "time", )] /// An intrusive linked list, but instead of keeping pointers to the head