Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

time: avoid traversing entries in the time wheel twice #6718

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
74 changes: 26 additions & 48 deletions tokio/src/runtime/time/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
//!
//! Each timer has a state field associated with it. This field contains either
//! the current scheduled time, or a special flag value indicating its state.
//! This state can either indicate that the timer is on the 'pending' queue (and
//! thus will be fired with an `Ok(())` result soon) or that it has already been
//! fired/deregistered.
//! This state can either indicate that the timer is firing (and thus will be fired
//! with an `Ok(())` result soon) or that it has already been fired/deregistered.
//!
//! This single state field allows for code that is firing the timer to
//! synchronize with any racing `reset` calls reliably.
Expand All @@ -49,10 +48,10 @@
//! There is of course a race condition between timer reset and timer
//! expiration. If the driver fails to observe the updated expiration time, it
//! could trigger expiration of the timer too early. However, because
//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as pending.
//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as firing.
//!
//! [mark_pending]: TimerHandle::mark_pending
//! [mark_firing]: TimerHandle::mark_firing

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
Expand All @@ -70,9 +69,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};

type TimerResult = Result<(), crate::time::error::Error>;

const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
pub(super) const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_FIRING: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_FIRING;
/// The largest safe integer to use for ticks.
///
/// This value should be updated if any other signal values are added above.
Expand Down Expand Up @@ -123,10 +122,6 @@ impl StateCell {
}
}

fn is_pending(&self) -> bool {
self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
}

/// Returns the current expiration time, or None if not currently scheduled.
fn when(&self) -> Option<u64> {
let cur_state = self.state.load(Ordering::Relaxed);
Expand Down Expand Up @@ -162,26 +157,28 @@ impl StateCell {
}
}

/// Marks this timer as being moved to the pending list, if its scheduled
/// time is not after `not_after`.
/// Marks this timer firing, if its scheduled time is not after `not_after`.
///
/// If the timer is scheduled for a time after `not_after`, returns an Err
/// containing the current scheduled time.
///
/// SAFETY: Must hold the driver lock.
unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
// Quick initial debug check to see if the timer is already fired. Since
// firing the timer can only happen with the driver lock held, we know
// we shouldn't be able to "miss" a transition to a fired state, even
// with relaxed ordering.
let mut cur_state = self.state.load(Ordering::Relaxed);

loop {
// Because its state is STATE_DEREGISTERED, it has been fired.
if cur_state == STATE_DEREGISTERED {
break Err(cur_state);
}
// improve the error message for things like
// https://github.com/tokio-rs/tokio/issues/3675
assert!(
cur_state < STATE_MIN_VALUE,
"mark_pending called when the timer entry is in an invalid state"
"mark_firing called when the timer entry is in an invalid state"
);

if cur_state > not_after {
Expand All @@ -190,7 +187,7 @@ impl StateCell {

match self.state.compare_exchange_weak(
cur_state,
STATE_PENDING_FIRE,
STATE_FIRING,
Ordering::AcqRel,
Ordering::Acquire,
) {
Expand Down Expand Up @@ -336,12 +333,11 @@ pub(crate) struct TimerShared {
///
/// Only accessed under the entry lock.
pointers: linked_list::Pointers<TimerShared>,

/// The expiration time for which this entry is currently registered.
/// It is used to calculate which slot this entry is stored in.
/// Generally owned by the driver, but is accessed by the entry when not
/// registered.
cached_when: AtomicU64,

/// Current state. This records whether the timer entry is currently under
/// the ownership of the driver, and if not, its current state (not
/// complete, fired, error, etc).
Expand Down Expand Up @@ -386,25 +382,22 @@ impl TimerShared {
// Cached-when is only accessed under the driver lock, so we can use relaxed
self.cached_when.load(Ordering::Relaxed)
}

/// Gets the true time-of-expiration value, and copies it into the cached
/// time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn sync_when(&self) -> u64 {
let true_when = self.true_when();

self.cached_when.store(true_when, Ordering::Relaxed);

self.set_cached_when(true_when);
true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
unsafe fn set_cached_when(&self, when: u64) {
pub(super) unsafe fn set_cached_when(&self, when: u64) {
self.cached_when.store(when, Ordering::Relaxed);
}

Expand All @@ -420,7 +413,6 @@ impl TimerShared {
/// in the timer wheel.
pub(super) unsafe fn set_expiration(&self, t: u64) {
self.state.set_expiration(t);
self.cached_when.store(t, Ordering::Relaxed);
}

/// Sets the true time-of-expiration only if it is after the current.
Expand Down Expand Up @@ -594,12 +586,12 @@ impl TimerHandle {
unsafe { self.inner.as_ref().cached_when() }
}

pub(super) unsafe fn sync_when(&self) -> u64 {
unsafe { self.inner.as_ref().sync_when() }
pub(super) unsafe fn set_cached_when(&self, t: u64) {
unsafe { self.inner.as_ref().set_cached_when(t) }
}

pub(super) unsafe fn is_pending(&self) -> bool {
unsafe { self.inner.as_ref().state.is_pending() }
pub(super) unsafe fn sync_when(&self) -> u64 {
unsafe { self.inner.as_ref().sync_when() }
}

/// Forcibly sets the true and cached expiration times to the given tick.
Expand All @@ -610,27 +602,13 @@ impl TimerHandle {
self.inner.as_ref().set_expiration(tick);
}

/// Attempts to mark this entry as pending. If the expiration time is after
/// Attempts to mark this entry as firing. If the expiration time is after
/// `not_after`, however, returns an Err with the current expiration time.
///
/// If an `Err` is returned, the `cached_when` value will be updated to this
/// new expiration time.
///
/// SAFETY: The caller must ensure that the handle remains valid, the driver
/// lock is held, and that the timer is not in any wheel linked lists.
/// After returning Ok, the entry must be added to the pending list.
pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
match self.inner.as_ref().state.mark_pending(not_after) {
Ok(()) => {
// mark this as being on the pending queue in cached_when
self.inner.as_ref().set_cached_when(u64::MAX);
Ok(())
}
Err(tick) => {
self.inner.as_ref().set_cached_when(tick);
Err(tick)
}
}
pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
self.inner.as_ref().state.mark_firing(not_after)
}

/// Attempts to transition to a terminal state. If the state is already a
Expand Down
61 changes: 45 additions & 16 deletions tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

mod entry;
pub(crate) use entry::TimerEntry;
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION};
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED};

mod handle;
pub(crate) use self::handle::Handle;
Expand Down Expand Up @@ -204,7 +204,6 @@ impl Driver {
.inner
.next_wake
.store(next_wake_time(expiration_time));

// Safety: After updating the `next_wake`, we drop all the locks.
drop(locks);

Expand Down Expand Up @@ -324,23 +323,53 @@ impl Handle {
now = lock.elapsed();
}

while let Some(entry) = lock.poll(now) {
debug_assert!(unsafe { entry.is_pending() });

// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
drop(lock);

waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
while let Some(expiration) = lock.poll(now) {
lock.set_elapsed(expiration.deadline);
// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = TimerShared::new(id);
pin!(guard);
let guard_handle = guard.as_ref().get_ref().handle();

// * This list will be still guarded by the lock of the Wheel with the specefied id.
// `EntryWaitersList` wrapper makes sure we hold the lock to modify it.
// * This wrapper will empty the list on drop. It is critical for safety
// that we will not leave any list entry with a pointer to the local
// guard node after this function returns / panics.
// Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory.
let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) };

while let Some(entry) = list.pop_back_locked(&mut lock) {
let deadline = expiration.deadline;
// Try to expire the entry; this is cheap (doesn't synchronize) if
// the timer is not expired, and updates cached_when.
match unsafe { entry.mark_firing(deadline) } {
Ok(()) => {
// Entry was expired.
// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock,
// we must do this with the lock temporarily dropped.
drop(lock);
waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
}
}
}
Err(state) if state == STATE_DEREGISTERED => {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can STATE_DEREGISTERED actually happen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. I think it will not happen.

Err(state) => {
// Safety: This Entry has not expired.
unsafe { lock.reinsert_entry(entry, deadline, state) };
}
}
}
lock.occupied_bit_maintain(&expiration);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code doesn't call this. Why is it necessary now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I add occupied_bit_maintain now is that in the past, we would use take_entries to take all items from the linked list at once and then update the occupied bit. But now, we need to traverse the items one by one, which means we need to update the occupied bit after the traversal is completed. I think the difference here is that in the current version, the linked list is not necessarily empty, so I add this method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be accurate to say that this is an operation that must be called when we are done with list? Would it make sense for it to be part of the destructor of EntryWaitersList?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it need to be called before drop(lock) in if !waker_list.can_push() {?

Copy link
Contributor Author

@wathenjiang wathenjiang Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling occupied_bit_maintain before drop(lock) in if !waker_list.can_push() { makes sense to me. That is a good idea.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about set_elapsed?

I would really like to see a loom test where a timer is inserted during the if !waker_list.can_push() { ... } block. It's hard for me to tell whether it's correct.

I think you can do the loom test along these lines:

  1. Start a runtime with paused time.
  2. Create and poll 32 timers with the timer.
  3. Spawn a thread. The background thread will create a timer, register it, and wait for it to complete.
  4. In parallel with that wait for the 32 timers in order.
  5. Join the background thread.

Spawning the thread after registering the 32 timers should reduce the size of the loom test, since the region with more than 1 thread will be shorter. As long as the if !waker_list.can_push() { ... } block happens in the more-than-1-thread region of the test, it should work.

}

let next_wake_up = lock.poll_at();
drop(lock);

Expand Down
21 changes: 21 additions & 0 deletions tokio/src/runtime/time/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,27 @@ fn reset_future() {
})
}

#[test]
fn reset_timer_and_drop() {
model(|| {
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
let rt = rt(false);
let handle = rt.handle();

let start = handle.inner.driver().clock().now();

for _ in 0..2 {
let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10));
pin!(entry);

let _ = entry
.as_mut()
.poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref()));

entry.as_mut().reset(start + Duration::from_secs(1), true);
}
});
}

#[cfg(not(loom))]
fn normal_or_miri<T>(normal: T, miri: T) -> T {
if cfg!(miri) {
Expand Down
22 changes: 13 additions & 9 deletions tokio/src/runtime/time/wheel/level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub(crate) struct Level {
}

/// Indicates when a slot must be processed next.
#[derive(Debug)]
pub(crate) struct Expiration {
/// The level containing the slot.
pub(crate) level: usize,
Expand Down Expand Up @@ -81,7 +80,7 @@ impl Level {
// pseudo-ring buffer, and we rotate around them indefinitely. If we
// compute a deadline before now, and it's the top level, it
// therefore means we're actually looking at a slot in the future.
debug_assert_eq!(self.level, super::NUM_LEVELS - 1);
debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX);

deadline += level_range;
}
Expand Down Expand Up @@ -119,32 +118,37 @@ impl Level {
Some(slot)
}

pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) {
pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) -> usize {
let slot = slot_for(item.cached_when(), self.level);

self.slot[slot].push_front(item);

self.occupied |= occupied_bit(slot);

slot
}

pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) {
let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level);

unsafe { self.slot[slot].remove(item) };
if self.slot[slot].is_empty() {
// The bit is currently set
debug_assert!(self.occupied & occupied_bit(slot) != 0);

// Unset the bit
self.occupied ^= occupied_bit(slot);
}
}

pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList {
self.occupied &= !occupied_bit(slot);

pub(super) fn take_slot(&mut self, slot: usize) -> EntryList {
std::mem::take(&mut self.slot[slot])
}

pub(super) fn occupied_bit_maintain(&mut self, slot: usize) {
if self.slot[slot].is_empty() {
self.occupied &= !occupied_bit(slot);
} else {
self.occupied |= occupied_bit(slot);
}
}
}

impl fmt::Debug for Level {
Expand Down
Loading
Loading