Skip to content

Commit 9f1723b

Browse files
committed
sync::watch: Use Acquire/Release memory ordering instead of SeqCst
1 parent ad7f988 commit 9f1723b

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

tokio/src/sync/watch.rs

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@
114114
use crate::sync::notify::Notify;
115115

116116
use crate::loom::sync::atomic::AtomicUsize;
117-
use crate::loom::sync::atomic::Ordering::Relaxed;
117+
use crate::loom::sync::atomic::Ordering;
118118
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
119119
use std::fmt;
120120
use std::mem;
@@ -247,7 +247,8 @@ struct Shared<T> {
247247

248248
impl<T: fmt::Debug> fmt::Debug for Shared<T> {
249249
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250-
let state = self.state.load();
250+
// Using `Relaxed` ordering is sufficient for this purpose.
251+
let state = self.state.load(Ordering::Relaxed);
251252
f.debug_struct("Shared")
252253
.field("value", &self.value)
253254
.field("version", &state.version())
@@ -341,7 +342,7 @@ mod big_notify {
341342
/// This function implements the case where randomness is not available.
342343
#[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))]
343344
pub(super) fn notified(&self) -> Notified<'_> {
344-
let i = self.next.fetch_add(1, Relaxed) % 8;
345+
let i = self.next.fetch_add(1, Ordering::Relaxed) % 8;
345346
self.inner[i].notified()
346347
}
347348

@@ -357,7 +358,7 @@ mod big_notify {
357358
use self::state::{AtomicState, Version};
358359
mod state {
359360
use crate::loom::sync::atomic::AtomicUsize;
360-
use crate::loom::sync::atomic::Ordering::SeqCst;
361+
use crate::loom::sync::atomic::Ordering;
361362

362363
const CLOSED_BIT: usize = 1;
363364

@@ -377,6 +378,11 @@ mod state {
377378
pub(super) struct StateSnapshot(usize);
378379

379380
/// The state stored in an atomic integer.
381+
///
382+
/// The `Sender` uses `Release` ordering for storing a new state
383+
/// and the `Receiver`s use `Acquire` ordering for loading the
384+
/// current state. This ensures that written values are seen by
385+
/// the `Receiver`s for a proper handover.
380386
#[derive(Debug)]
381387
pub(super) struct AtomicState(AtomicUsize);
382388

@@ -412,18 +418,32 @@ mod state {
412418
}
413419

414420
/// Load the current value of the state.
415-
pub(super) fn load(&self) -> StateSnapshot {
416-
StateSnapshot(self.0.load(SeqCst))
421+
pub(super) fn load(&self, ordering: Ordering) -> StateSnapshot {
422+
StateSnapshot(self.0.load(ordering))
423+
}
424+
425+
/// Load the current value of the state.
426+
///
427+
/// The receiver side (read-only) uses `Acquire` ordering for a proper handover
428+
/// with the sender side (single writer).
429+
pub(super) fn load_receiver(&self) -> StateSnapshot {
430+
StateSnapshot(self.0.load(Ordering::Acquire))
417431
}
418432

419433
/// Increment the version counter.
420434
pub(super) fn increment_version(&self) {
421-
self.0.fetch_add(STEP_SIZE, SeqCst);
435+
// Use `Release` ordering to ensure that storing the version
436+
// state is seen by the receiver side that uses `Acquire` for
437+
// loading the state.
438+
self.0.fetch_add(STEP_SIZE, Ordering::Release);
422439
}
423440

424441
/// Set the closed bit in the state.
425442
pub(super) fn set_closed(&self) {
426-
self.0.fetch_or(CLOSED_BIT, SeqCst);
443+
// Use `Release` ordering to ensure that storing the version
444+
// state is seen by the receiver side that uses `Acquire` for
445+
// loading the state.
446+
self.0.fetch_or(CLOSED_BIT, Ordering::Release);
427447
}
428448
}
429449
}
@@ -489,7 +509,7 @@ impl<T> Receiver<T> {
489509
fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self {
490510
// No synchronization necessary as this is only used as a counter and
491511
// not memory access.
492-
shared.ref_count_rx.fetch_add(1, Relaxed);
512+
shared.ref_count_rx.fetch_add(1, Ordering::Relaxed);
493513

494514
Self { shared, version }
495515
}
@@ -543,7 +563,7 @@ impl<T> Receiver<T> {
543563

544564
// After obtaining a read-lock no concurrent writes could occur
545565
// and the loaded version matches that of the borrowed reference.
546-
let new_version = self.shared.state.load().version();
566+
let new_version = self.shared.state.load_receiver().version();
547567
let has_changed = self.version != new_version;
548568

549569
Ref { inner, has_changed }
@@ -590,7 +610,7 @@ impl<T> Receiver<T> {
590610

591611
// After obtaining a read-lock no concurrent writes could occur
592612
// and the loaded version matches that of the borrowed reference.
593-
let new_version = self.shared.state.load().version();
613+
let new_version = self.shared.state.load_receiver().version();
594614
let has_changed = self.version != new_version;
595615

596616
// Mark the shared value as seen by updating the version
@@ -631,7 +651,7 @@ impl<T> Receiver<T> {
631651
/// ```
632652
pub fn has_changed(&self) -> Result<bool, error::RecvError> {
633653
// Load the version from the state
634-
let state = self.shared.state.load();
654+
let state = self.shared.state.load_receiver();
635655
if state.is_closed() {
636656
// The sender has dropped.
637657
return Err(error::RecvError(()));
@@ -768,7 +788,7 @@ impl<T> Receiver<T> {
768788
{
769789
let inner = self.shared.value.read().unwrap();
770790

771-
let new_version = self.shared.state.load().version();
791+
let new_version = self.shared.state.load_receiver().version();
772792
let has_changed = self.version != new_version;
773793
self.version = new_version;
774794

@@ -814,7 +834,7 @@ fn maybe_changed<T>(
814834
version: &mut Version,
815835
) -> Option<Result<(), error::RecvError>> {
816836
// Load the version from the state
817-
let state = shared.state.load();
837+
let state = shared.state.load_receiver();
818838
let new_version = state.version();
819839

820840
if *version != new_version {
@@ -865,7 +885,7 @@ impl<T> Drop for Receiver<T> {
865885
fn drop(&mut self) {
866886
// No synchronization necessary as this is only used as a counter and
867887
// not memory access.
868-
if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) {
888+
if 1 == self.shared.ref_count_rx.fetch_sub(1, Ordering::Relaxed) {
869889
// This is the last `Receiver` handle, tasks waiting on `Sender::closed()`
870890
self.shared.notify_tx.notify_waiters();
871891
}
@@ -1228,7 +1248,7 @@ impl<T> Sender<T> {
12281248
/// ```
12291249
pub fn subscribe(&self) -> Receiver<T> {
12301250
let shared = self.shared.clone();
1231-
let version = shared.state.load().version();
1251+
let version = shared.state.load_receiver().version();
12321252

12331253
// The CLOSED bit in the state tracks only whether the sender is
12341254
// dropped, so we do not need to unset it if this reopens the channel.
@@ -1254,7 +1274,7 @@ impl<T> Sender<T> {
12541274
/// }
12551275
/// ```
12561276
pub fn receiver_count(&self) -> usize {
1257-
self.shared.ref_count_rx.load(Relaxed)
1277+
self.shared.ref_count_rx.load(Ordering::Relaxed)
12581278
}
12591279
}
12601280

0 commit comments

Comments
 (0)