From cd70a2497eb38e9e6a057c85aec7adfc3cba9871 Mon Sep 17 00:00:00 2001 From: wathenjiang Date: Wed, 24 Jul 2024 19:57:50 +0800 Subject: [PATCH] time: revert set_cached_when --- tokio/src/runtime/time/entry.rs | 47 +++++++++++++++++++++++---- tokio/src/runtime/time/mod.rs | 1 - tokio/src/runtime/time/tests/mod.rs | 21 ++++++++++++ tokio/src/runtime/time/wheel/level.rs | 8 +++-- tokio/src/runtime/time/wheel/mod.rs | 29 +++++++++++++---- 5 files changed, 89 insertions(+), 17 deletions(-) diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 0bd15a74f8b..f52944fb1dc 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -333,7 +333,11 @@ pub(crate) struct TimerShared { /// /// Only accessed under the entry lock. pointers: linked_list::Pointers, - + /// The expiration time for which this entry is currently registered. + /// It is used to calculate which slot this entry is stored in. + /// Generally owned by the driver, but is accessed by the entry when not + /// registered. + cached_when: AtomicU64, /// Current state. This records whether the timer entry is currently under /// the ownership of the driver, and if not, its current state (not /// complete, fired, error, etc). @@ -348,6 +352,7 @@ unsafe impl Sync for TimerShared {} impl std::fmt::Debug for TimerShared { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TimerShared") + .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) .field("state", &self.state) .finish() } @@ -365,12 +370,37 @@ impl TimerShared { pub(super) fn new(shard_id: u32) -> Self { Self { shard_id, + cached_when: AtomicU64::new(0), pointers: linked_list::Pointers::new(), state: StateCell::default(), _p: PhantomPinned, } } + /// Gets the cached time-of-expiration value. + pub(super) fn cached_when(&self) -> u64 { + // Cached-when is only accessed under the driver lock, so we can use relaxed + self.cached_when.load(Ordering::Relaxed) + } + /// Gets the true time-of-expiration value, and copies it into the cached + /// time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn sync_when(&self) -> u64 { + let true_when = self.true_when(); + self.set_cached_when(true_when); + true_when + } + + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn set_cached_when(&self, when: u64) { + self.cached_when.store(when, Ordering::Relaxed); + } + /// Returns the true time-of-expiration value, with relaxed memory ordering. pub(super) fn true_when(&self) -> u64 { self.state.when().expect("Timer already fired") @@ -552,8 +582,16 @@ impl TimerEntry { } impl TimerHandle { - pub(super) unsafe fn true_when(&self) -> u64 { - unsafe { self.inner.as_ref().true_when() } + pub(super) unsafe fn cached_when(&self) -> u64 { + unsafe { self.inner.as_ref().cached_when() } + } + + pub(super) unsafe fn set_cached_when(&self, t: u64) { + unsafe { self.inner.as_ref().set_cached_when(t) } + } + + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } } /// Forcibly sets the true and cached expiration times to the given tick. @@ -567,9 +605,6 @@ impl TimerHandle { /// Attempts to mark this entry as firing. If the expiration time is after /// `not_after`, however, returns an Err with the current expiration time. /// - /// If an `Err` is returned, the `cached_when` value will be updated to this - /// new expiration time. - /// /// SAFETY: The caller must ensure that the handle remains valid, the driver /// lock is held, and that the timer is not in any wheel linked lists. pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 0e4c995dcd0..3d1dc92f68f 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -204,7 +204,6 @@ impl Driver { .inner .next_wake .store(next_wake_time(expiration_time)); - // Safety: After updating the `next_wake`, we drop all the locks. drop(locks); diff --git a/tokio/src/runtime/time/tests/mod.rs b/tokio/src/runtime/time/tests/mod.rs index 0e453433691..7084c2d7519 100644 --- a/tokio/src/runtime/time/tests/mod.rs +++ b/tokio/src/runtime/time/tests/mod.rs @@ -202,6 +202,27 @@ fn reset_future() { }) } +#[test] +fn reset_timer_and_drop() { + model(|| { + let rt = rt(false); + let handle = rt.handle(); + + let start = handle.inner.driver().clock().now(); + + for _ in 0..2 { + let entry = TimerEntry::new(handle.inner.clone(), start + Duration::from_millis(10)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + entry.as_mut().reset(start + Duration::from_secs(1), true); + } + }); +} + #[cfg(not(loom))] fn normal_or_miri(normal: T, miri: T) -> T { if cfg!(miri) { diff --git a/tokio/src/runtime/time/wheel/level.rs b/tokio/src/runtime/time/wheel/level.rs index 6539f47b4fa..cccff6d8827 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -118,16 +118,18 @@ impl Level { Some(slot) } - pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { - let slot = slot_for(item.true_when(), self.level); + pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) -> usize { + let slot = slot_for(item.cached_when(), self.level); self.slot[slot].push_front(item); self.occupied |= occupied_bit(slot); + + slot } pub(crate) unsafe fn remove_entry(&mut self, item: NonNull) { - let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level); + let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index a4034053134..7424c4892cb 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -142,16 +142,14 @@ impl Wheel { &mut self, item: TimerHandle, ) -> Result { - let when = item.true_when(); + let when = item.sync_when(); if when <= self.elapsed { return Err((item, InsertError::Elapsed)); } - // Get the level at which the entry should be stored - let level = self.level_for(when); - - unsafe { self.levels[level].add_entry(item) }; + // Safety: The `cached_when` of this item has been updated by calling the `sync_when`. + let (level, _slot) = self.inner_insert(item, self.elapsed, when); debug_assert!({ self.levels[level] @@ -166,7 +164,7 @@ impl Wheel { /// Removes `item` from the timing wheel. pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { - let when = item.as_ref().true_when(); + let when = item.as_ref().cached_when(); if when <= MAX_SAFE_MILLIS_DURATION { debug_assert!( self.elapsed <= when, @@ -186,8 +184,25 @@ impl Wheel { /// Reinserts `item` to the timing wheel. /// Safety: This entry must not have expired. pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { + entry.set_cached_when(when); + // Safety: The `cached_when` of this entry has been updated by calling the `set_cached_when`. + let (_level, _slot) = self.inner_insert(entry, elapsed, when); + } + + /// Inserts the `entry` to the `Wheel`. + /// Returns the level and the slot which `entry` insert into. + /// + /// Safety: The `cached_when` of this `entry`` must have been updated to `when`. + unsafe fn inner_insert( + &mut self, + entry: TimerHandle, + elapsed: u64, + when: u64, + ) -> (usize, usize) { + // Get the level at which the entry should be stored let level = level_for(elapsed, when); - unsafe { self.levels[level].add_entry(entry) }; + let slot = unsafe { self.levels[level].add_entry(entry) }; + (level, slot) } /// Instant at which to poll.