diff --git a/tokio/src/loom/mocked.rs b/tokio/src/loom/mocked.rs index c25018e7e8c..cfcbb2967e1 100644 --- a/tokio/src/loom/mocked.rs +++ b/tokio/src/loom/mocked.rs @@ -2,7 +2,7 @@ pub(crate) use loom::*; pub(crate) mod sync { - pub(crate) use loom::sync::MutexGuard; + pub(crate) use loom::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard}; #[derive(Debug)] pub(crate) struct Mutex(loom::sync::Mutex); @@ -30,6 +30,38 @@ pub(crate) mod sync { self.0.get_mut().unwrap() } } + + #[derive(Debug)] + pub(crate) struct RwLock(loom::sync::RwLock); + + #[allow(dead_code)] + impl RwLock { + #[inline] + pub(crate) fn new(t: T) -> Self { + Self(loom::sync::RwLock::new(t)) + } + + #[inline] + pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> { + self.0.read().unwrap() + } + + #[inline] + pub(crate) fn try_read(&self) -> Option> { + self.0.try_read().ok() + } + + #[inline] + pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> { + self.0.write().unwrap() + } + + #[inline] + pub(crate) fn try_write(&self) -> Option> { + self.0.try_write().ok() + } + } + pub(crate) use loom::sync::*; pub(crate) mod atomic { diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index 985d8d73aeb..d446f2ee804 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -8,6 +8,7 @@ mod barrier; mod mutex; #[cfg(all(feature = "parking_lot", not(miri)))] mod parking_lot; +mod rwlock; mod unsafe_cell; pub(crate) mod cell { @@ -64,11 +65,14 @@ pub(crate) mod sync { #[cfg(not(all(feature = "parking_lot", not(miri))))] #[allow(unused_imports)] - pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult}; + pub(crate) use std::sync::{Condvar, MutexGuard, RwLockReadGuard, WaitTimeoutResult}; #[cfg(not(all(feature = "parking_lot", not(miri))))] pub(crate) use crate::loom::std::mutex::Mutex; + #[cfg(not(all(feature = "parking_lot", not(miri))))] + pub(crate) use crate::loom::std::rwlock::RwLock; + pub(crate) mod atomic { pub(crate) use crate::loom::std::atomic_u16::AtomicU16; pub(crate) use crate::loom::std::atomic_u32::AtomicU32; diff --git a/tokio/src/loom/std/parking_lot.rs b/tokio/src/loom/std/parking_lot.rs index 9b9a81d35b0..6a8375b0787 100644 --- a/tokio/src/loom/std/parking_lot.rs +++ b/tokio/src/loom/std/parking_lot.rs @@ -96,12 +96,20 @@ impl RwLock { RwLock(PhantomData, parking_lot::RwLock::new(t)) } - pub(crate) fn read(&self) -> LockResult> { - Ok(RwLockReadGuard(PhantomData, self.1.read())) + pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> { + RwLockReadGuard(PhantomData, self.1.read()) } - pub(crate) fn write(&self) -> LockResult> { - Ok(RwLockWriteGuard(PhantomData, self.1.write())) + pub(crate) fn try_read(&self) -> Option> { + Some(RwLockReadGuard(PhantomData, self.1.read())) + } + + pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> { + RwLockWriteGuard(PhantomData, self.1.write()) + } + + pub(crate) fn try_write(&self) -> Option> { + Some(RwLockWriteGuard(PhantomData, self.1.write())) } } diff --git a/tokio/src/loom/std/rwlock.rs b/tokio/src/loom/std/rwlock.rs new file mode 100644 index 00000000000..2b2c5f3fcde --- /dev/null +++ b/tokio/src/loom/std/rwlock.rs @@ -0,0 +1,48 @@ +use std::sync::{self, RwLockReadGuard, RwLockWriteGuard, TryLockError}; + +/// Adapter for `std::sync::RwLock` that removes the poisoning aspects +/// from its api. +#[derive(Debug)] +pub(crate) struct RwLock(sync::RwLock); + +#[allow(dead_code)] +impl RwLock { + #[inline] + pub(crate) fn new(t: T) -> Self { + Self(sync::RwLock::new(t)) + } + + #[inline] + pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> { + match self.0.read() { + Ok(guard) => guard, + Err(p_err) => p_err.into_inner(), + } + } + + #[inline] + pub(crate) fn try_read(&self) -> Option> { + match self.0.try_read() { + Ok(guard) => Some(guard), + Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()), + Err(TryLockError::WouldBlock) => None, + } + } + + #[inline] + pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> { + match self.0.write() { + Ok(guard) => guard, + Err(p_err) => p_err.into_inner(), + } + } + + #[inline] + pub(crate) fn try_write(&self) -> Option> { + match self.0.try_write() { + Ok(guard) => Some(guard), + Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()), + Err(TryLockError::WouldBlock) => None, + } + } +} diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 50603ed9ef4..56e0ba64d9c 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -20,7 +20,7 @@ pub(crate) use source::TimeSource; mod wheel; use crate::loom::sync::atomic::{AtomicBool, Ordering}; -use crate::loom::sync::Mutex; +use crate::loom::sync::{Mutex, RwLock}; use crate::runtime::driver::{self, IoHandle, IoStack}; use crate::time::error::Error; use crate::time::{Clock, Duration}; @@ -28,7 +28,6 @@ use crate::util::WakeList; use crate::loom::sync::atomic::AtomicU64; use std::fmt; -use std::sync::RwLock; use std::{num::NonZeroU64, ptr::NonNull}; struct AtomicOptionNonZeroU64(AtomicU64); @@ -199,12 +198,7 @@ impl Driver { // Finds out the min expiration time to park. let expiration_time = { - let mut wheels_lock = rt_handle - .time() - .inner - .wheels - .write() - .expect("Timer wheel shards poisoned"); + let mut wheels_lock = rt_handle.time().inner.wheels.write(); let expiration_time = wheels_lock .0 .iter_mut() @@ -324,11 +318,7 @@ impl Handle { // Returns the next wakeup time of this shard. pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option { let mut waker_list = WakeList::new(); - let mut wheels_lock = self - .inner - .wheels - .read() - .expect("Timer wheel shards poisoned"); + let mut wheels_lock = self.inner.wheels.read(); let mut lock = wheels_lock.lock_sharded_wheel(id); if now < lock.elapsed() { @@ -355,11 +345,7 @@ impl Handle { waker_list.wake_all(); - wheels_lock = self - .inner - .wheels - .read() - .expect("Timer wheel shards poisoned"); + wheels_lock = self.inner.wheels.read(); lock = wheels_lock.lock_sharded_wheel(id); } } @@ -384,11 +370,7 @@ impl Handle { /// `add_entry` must not be called concurrently. pub(self) unsafe fn clear_entry(&self, entry: NonNull) { unsafe { - let wheels_lock = self - .inner - .wheels - .read() - .expect("Timer wheel shards poisoned"); + let wheels_lock = self.inner.wheels.read(); let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id()); if entry.as_ref().might_be_registered() { @@ -412,11 +394,7 @@ impl Handle { entry: NonNull, ) { let waker = unsafe { - let wheels_lock = self - .inner - .wheels - .read() - .expect("Timer wheel shards poisoned"); + let wheels_lock = self.inner.wheels.read(); let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id()); diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 2161a4794ef..67d67a666e3 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -599,7 +599,7 @@ impl Sender { tail.pos = tail.pos.wrapping_add(1); // Get the slot - let mut slot = self.shared.buffer[idx].write().unwrap(); + let mut slot = self.shared.buffer[idx].write(); // Track the position slot.pos = pos; @@ -695,7 +695,7 @@ impl Sender { while low < high { let mid = low + (high - low) / 2; let idx = base_idx.wrapping_add(mid) & self.shared.mask; - if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 { + if self.shared.buffer[idx].read().rem.load(SeqCst) == 0 { low = mid + 1; } else { high = mid; @@ -737,7 +737,7 @@ impl Sender { let tail = self.shared.tail.lock(); let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize; - self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 + self.shared.buffer[idx].read().rem.load(SeqCst) == 0 } /// Returns the number of active receivers. @@ -1057,7 +1057,7 @@ impl Receiver { let idx = (self.next & self.shared.mask as u64) as usize; // The slot holding the next value to read - let mut slot = self.shared.buffer[idx].read().unwrap(); + let mut slot = self.shared.buffer[idx].read(); if slot.pos != self.next { // Release the `slot` lock before attempting to acquire the `tail` @@ -1074,7 +1074,7 @@ impl Receiver { let mut tail = self.shared.tail.lock(); // Acquire slot lock again - slot = self.shared.buffer[idx].read().unwrap(); + slot = self.shared.buffer[idx].read(); // Make sure the position did not change. This could happen in the // unlikely event that the buffer is wrapped between dropping the diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 366066797f1..490b9e4df88 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -575,7 +575,7 @@ impl Receiver { /// assert_eq!(*rx.borrow(), "hello"); /// ``` pub fn borrow(&self) -> Ref<'_, T> { - let inner = self.shared.value.read().unwrap(); + let inner = self.shared.value.read(); // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. @@ -622,7 +622,7 @@ impl Receiver { /// [`changed`]: Receiver::changed /// [`borrow`]: Receiver::borrow pub fn borrow_and_update(&mut self) -> Ref<'_, T> { - let inner = self.shared.value.read().unwrap(); + let inner = self.shared.value.read(); // After obtaining a read-lock no concurrent writes could occur // and the loaded version matches that of the borrowed reference. @@ -813,7 +813,7 @@ impl Receiver { let mut closed = false; loop { { - let inner = self.shared.value.read().unwrap(); + let inner = self.shared.value.read(); let new_version = self.shared.state.load().version(); let has_changed = self.version != new_version; @@ -1087,7 +1087,7 @@ impl Sender { { { // Acquire the write lock and update the value. - let mut lock = self.shared.value.write().unwrap(); + let mut lock = self.shared.value.write(); // Update the value and catch possible panic inside func. let result = panic::catch_unwind(panic::AssertUnwindSafe(|| modify(&mut lock))); @@ -1164,7 +1164,7 @@ impl Sender { /// assert_eq!(*tx.borrow(), "hello"); /// ``` pub fn borrow(&self) -> Ref<'_, T> { - let inner = self.shared.value.read().unwrap(); + let inner = self.shared.value.read(); // The sender/producer always sees the current version let has_changed = false;