Skip to content

Commit

Permalink
sync: add Sender<T>::closed future
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Sep 22, 2024
1 parent a302367 commit b6491c0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tokio/src/loom/std/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::{self, MutexGuard, TryLockError};

/// Adapter for `std::Mutex` that removes the poisoning aspects
/// from its api.
/// from its API.
#[derive(Debug)]
pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>);

Expand Down
59 changes: 57 additions & 2 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
//! }
//! ```
use crate::future::poll_fn;
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
Expand Down Expand Up @@ -163,6 +164,7 @@ use std::task::{ready, Context, Poll, Waker};
/// [`broadcast`]: crate::sync::broadcast
pub struct Sender<T> {
shared: Arc<Shared<T>>,
notify_rx_closed: Arc<Notify>,
}

/// Receiving-half of the [`broadcast`] channel.
Expand Down Expand Up @@ -300,6 +302,8 @@ pub mod error {

use self::error::{RecvError, SendError, TryRecvError};

use super::Notify;

/// Data shared between senders and receivers.
struct Shared<T> {
/// slots in the channel.
Expand All @@ -313,6 +317,9 @@ struct Shared<T> {

/// Number of outstanding Sender handles.
num_tx: AtomicUsize,

/// Notify when a subscribed [`Receiver`] is dropped.
notify_rx_drop: Notify,
}

/// Next position to write a value.
Expand Down Expand Up @@ -527,9 +534,15 @@ impl<T> Sender<T> {
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
notify_rx_drop: Notify::new(),
});

Sender { shared }
let notify_rx_closed = Arc::new(Notify::new());

Sender {
shared,
notify_rx_closed,
}
}

/// Attempts to send a value to all active [`Receiver`] handles, returning
Expand Down Expand Up @@ -804,6 +817,38 @@ impl<T> Sender<T> {
Arc::ptr_eq(&self.shared, &other.shared)
}

/// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches
/// zero.
///
/// # Examples
///
/// ```
/// use futures::FutureExt;
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx1) = broadcast::channel::<u32>(16);
/// let mut rx2 = tx.subscribe();
///
/// tokio::spawn(async move {
/// assert_eq!(rx1.recv().await.unwrap(), 10);
/// });
///
/// let _ = tx.send(10);
/// assert!(tx.closed().now_or_never().is_none());
///
/// let _ = tokio::spawn(async move {
/// assert_eq!(rx2.recv().await.unwrap(), 10);
/// }).await;
///
/// assert!(tx.closed().now_or_never().is_some());
/// }
/// ```
pub async fn closed(&self) {
self.shared.notify_rx_drop.notified().await;
}

fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
Expand Down Expand Up @@ -946,7 +991,12 @@ impl<T> Clone for Sender<T> {
let shared = self.shared.clone();
shared.num_tx.fetch_add(1, SeqCst);

Sender { shared }
let notify_rx_closed = Arc::clone(&self.notify_rx_closed);

Sender {
shared,
notify_rx_closed,
}
}
}

Expand Down Expand Up @@ -1346,9 +1396,14 @@ impl<T> Drop for Receiver<T> {

tail.rx_cnt -= 1;
let until = tail.pos;
let remaining_rx = tail.rx_cnt;

drop(tail);

if remaining_rx == 0 {
self.shared.notify_rx_drop.notify_waiters();
}

while self.next < until {
match self.recv_ref(None) {
Ok(_) => {}
Expand Down
17 changes: 17 additions & 0 deletions tokio/tests/sync_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,20 @@ fn send_in_waker_drop() {
// Shouldn't deadlock.
let _ = tx.send(());
}

#[test]
fn broadcast_sender_closed() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();

let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());

drop(rx);
assert!(!task.is_woken());
assert_pending!(task.poll());

drop(rx2);
assert!(task.is_woken());
assert_ready!(task.poll());
}

0 comments on commit b6491c0

Please sign in to comment.