Skip to content

Commit

Permalink
sync: add CancellationToken::run_until_cancelled (#6618)
Browse files Browse the repository at this point in the history
  • Loading branch information
tglane authored Jun 13, 2024
1 parent a865ca1 commit 53ea44b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tokio-util/src/sync/cancellation_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,52 @@ impl CancellationToken {
pub fn drop_guard(self) -> DropGuard {
DropGuard { inner: Some(self) }
}

/// Runs a future to completion and returns its result wrapped inside of an `Option`
/// unless the `CancellationToken` is cancelled. In that case the function returns
/// `None` and the future gets dropped.
///
/// # Cancel safety
///
/// This method is only cancel safe if `fut` is cancel safe.
pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
where
F: Future,
{
pin_project! {
/// A Future that is resolved once the corresponding [`CancellationToken`]
/// is cancelled or a given Future gets resolved. It is biased towards the
/// Future completion.
#[must_use = "futures do nothing unless polled"]
struct RunUntilCancelledFuture<'a, F: Future> {
#[pin]
cancellation: WaitForCancellationFuture<'a>,
#[pin]
future: F,
}
}

impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
type Output = Option<F::Output>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(res) = this.future.poll(cx) {
Poll::Ready(Some(res))
} else if this.cancellation.poll(cx).is_ready() {
Poll::Ready(None)
} else {
Poll::Pending
}
}
}

RunUntilCancelledFuture {
cancellation: self.cancelled(),
future: fut,
}
.await
}
}

// ===== impl WaitForCancellationFuture =====
Expand Down
48 changes: 48 additions & 0 deletions tokio-util/tests/sync_cancellation_token.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![warn(rust_2018_idioms)]

use tokio::pin;
use tokio::sync::oneshot;
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};

use core::future::Future;
Expand Down Expand Up @@ -445,3 +446,50 @@ fn derives_send_sync() {
assert_send::<WaitForCancellationFuture<'static>>();
assert_sync::<WaitForCancellationFuture<'static>>();
}

#[test]
fn run_until_cancelled_test() {
let (waker, _) = new_count_waker();

{
let token = CancellationToken::new();

let fut = token.run_until_cancelled(std::future::pending::<()>());
pin!(fut);

assert_eq!(
Poll::Pending,
fut.as_mut().poll(&mut Context::from_waker(&waker))
);

token.cancel();

assert_eq!(
Poll::Ready(None),
fut.as_mut().poll(&mut Context::from_waker(&waker))
);
}

{
let (tx, rx) = oneshot::channel::<()>();

let token = CancellationToken::new();
let fut = token.run_until_cancelled(async move {
rx.await.unwrap();
42
});
pin!(fut);

assert_eq!(
Poll::Pending,
fut.as_mut().poll(&mut Context::from_waker(&waker))
);

tx.send(()).unwrap();

assert_eq!(
Poll::Ready(Some(42)),
fut.as_mut().poll(&mut Context::from_waker(&waker))
);
}
}

0 comments on commit 53ea44b

Please sign in to comment.