Skip to content

Commit

Permalink
time: revert set_cached_when
Browse files Browse the repository at this point in the history
  • Loading branch information
wathenjiang committed Jul 24, 2024
1 parent 73d6560 commit cd70a24
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 17 deletions.
47 changes: 41 additions & 6 deletions tokio/src/runtime/time/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ pub(crate) struct TimerShared {
///
/// Only accessed under the entry lock.
pointers: linked_list::Pointers<TimerShared>,

/// The expiration time for which this entry is currently registered.
/// It is used to calculate which slot this entry is stored in.
/// Generally owned by the driver, but is accessed by the entry when not
/// registered.
cached_when: AtomicU64,
/// Current state. This records whether the timer entry is currently under
/// the ownership of the driver, and if not, its current state (not
/// complete, fired, error, etc).
Expand All @@ -348,6 +352,7 @@ unsafe impl Sync for TimerShared {}
impl std::fmt::Debug for TimerShared {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TimerShared")
.field("cached_when", &self.cached_when.load(Ordering::Relaxed))
.field("state", &self.state)
.finish()
}
Expand All @@ -365,12 +370,37 @@ impl TimerShared {
pub(super) fn new(shard_id: u32) -> Self {
Self {
shard_id,
cached_when: AtomicU64::new(0),
pointers: linked_list::Pointers::new(),
state: StateCell::default(),
_p: PhantomPinned,
}
}

/// Gets the cached time-of-expiration value.
pub(super) fn cached_when(&self) -> u64 {
// Cached-when is only accessed under the driver lock, so we can use relaxed
self.cached_when.load(Ordering::Relaxed)
}
/// Gets the true time-of-expiration value, and copies it into the cached
/// time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn sync_when(&self) -> u64 {
let true_when = self.true_when();
self.set_cached_when(true_when);
true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn set_cached_when(&self, when: u64) {
self.cached_when.store(when, Ordering::Relaxed);
}

/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
Expand Down Expand Up @@ -552,8 +582,16 @@ impl TimerEntry {
}

impl TimerHandle {
pub(super) unsafe fn true_when(&self) -> u64 {
unsafe { self.inner.as_ref().true_when() }
pub(super) unsafe fn cached_when(&self) -> u64 {
unsafe { self.inner.as_ref().cached_when() }
}

pub(super) unsafe fn set_cached_when(&self, t: u64) {
unsafe { self.inner.as_ref().set_cached_when(t) }
}

pub(super) unsafe fn sync_when(&self) -> u64 {
unsafe { self.inner.as_ref().sync_when() }
}

/// Forcibly sets the true and cached expiration times to the given tick.
Expand All @@ -567,9 +605,6 @@ impl TimerHandle {
/// Attempts to mark this entry as firing. If the expiration time is after
/// `not_after`, however, returns an Err with the current expiration time.
///
/// If an `Err` is returned, the `cached_when` value will be updated to this
/// new expiration time.
///
/// SAFETY: The caller must ensure that the handle remains valid, the driver
/// lock is held, and that the timer is not in any wheel linked lists.
pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
Expand Down
1 change: 0 additions & 1 deletion tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ impl Driver {
.inner
.next_wake
.store(next_wake_time(expiration_time));

// Safety: After updating the `next_wake`, we drop all the locks.
drop(locks);

Expand Down
21 changes: 21 additions & 0 deletions tokio/src/runtime/time/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,27 @@ fn reset_future() {
})
}

#[test]
fn reset_timer_and_drop() {
model(|| {
let rt = rt(false);
let handle = rt.handle();

let start = handle.inner.driver().clock().now();

for _ in 0..2 {
let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10));
pin!(entry);

let _ = entry
.as_mut()
.poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref()));

entry.as_mut().reset(start + Duration::from_secs(1), true);
}
});
}

#[cfg(not(loom))]
fn normal_or_miri<T>(normal: T, miri: T) -> T {
if cfg!(miri) {
Expand Down
8 changes: 5 additions & 3 deletions tokio/src/runtime/time/wheel/level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,18 @@ impl Level {
Some(slot)
}

pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) {
let slot = slot_for(item.true_when(), self.level);
pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) -> usize {
let slot = slot_for(item.cached_when(), self.level);

self.slot[slot].push_front(item);

self.occupied |= occupied_bit(slot);

slot
}

pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) {
let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level);
let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level);

unsafe { self.slot[slot].remove(item) };
if self.slot[slot].is_empty() {
Expand Down
29 changes: 22 additions & 7 deletions tokio/src/runtime/time/wheel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,14 @@ impl Wheel {
&mut self,
item: TimerHandle,
) -> Result<u64, (TimerHandle, InsertError)> {
let when = item.true_when();
let when = item.sync_when();

if when <= self.elapsed {
return Err((item, InsertError::Elapsed));
}

// Get the level at which the entry should be stored
let level = self.level_for(when);

unsafe { self.levels[level].add_entry(item) };
// Safety: The `cached_when` of this item has been updated by calling the `sync_when`.
let (level, _slot) = self.inner_insert(item, self.elapsed, when);

debug_assert!({
self.levels[level]
Expand All @@ -166,7 +164,7 @@ impl Wheel {
/// Removes `item` from the timing wheel.
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) {
unsafe {
let when = item.as_ref().true_when();
let when = item.as_ref().cached_when();
if when <= MAX_SAFE_MILLIS_DURATION {
debug_assert!(
self.elapsed <= when,
Expand All @@ -186,8 +184,25 @@ impl Wheel {
/// Reinserts `item` to the timing wheel.
/// Safety: This entry must not have expired.
pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) {
entry.set_cached_when(when);
// Safety: The `cached_when` of this entry has been updated by calling the `set_cached_when`.
let (_level, _slot) = self.inner_insert(entry, elapsed, when);
}

/// Inserts the `entry` to the `Wheel`.
/// Returns the level and the slot which `entry` insert into.
///
/// Safety: The `cached_when` of this `entry`` must have been updated to `when`.
unsafe fn inner_insert(
&mut self,
entry: TimerHandle,
elapsed: u64,
when: u64,
) -> (usize, usize) {
// Get the level at which the entry should be stored
let level = level_for(elapsed, when);
unsafe { self.levels[level].add_entry(entry) };
let slot = unsafe { self.levels[level].add_entry(entry) };
(level, slot)
}

/// Instant at which to poll.
Expand Down

0 comments on commit cd70a24

Please sign in to comment.