diff --git a/Cargo.toml b/Cargo.toml index 97bd9c1..7f62852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stop-token" -version = "0.1.2" +version = "0.2.0" authors = ["Aleksey Kladov "] edition = "2018" license = "MIT OR Apache-2.0" @@ -10,7 +10,7 @@ description = "Experimental cooperative cancellation for async-std" [dependencies] pin-project-lite = "0.1.0" -async-std = "1.0" +async-std = { version = "1.0", features = ["unstable"] } [features] unstable = ["async-std/unstable"] diff --git a/src/lib.rs b/src/lib.rs index 3a7026d..94ae95d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,19 @@ use pin_project_lite::pin_project; enum Never {} +#[derive(Debug)] +pub enum Error { + Stopped, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for Error {} + /// `StopSource` produces `StopToken` and cancels all of its tokens on drop. /// /// # Example: @@ -190,16 +203,44 @@ pin_project! { } impl Future for StopFuture { - type Output = Option; + type Output = Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); if let Poll::Ready(()) = this.stop_token.poll(cx) { - return Poll::Ready(None); + return Poll::Ready(Err(Error::Stopped)); } match this.future.poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(it) => Poll::Ready(Some(it)), + Poll::Ready(it) => Poll::Ready(Ok(it)), + } + } +} + +impl WithStopTokenExt for F where F: Future {} + +pub trait WithStopTokenExt: Future { + fn with_stop_token(self, stop_token: &StopToken) -> StopFuture + where + Self: Sized, + { + StopFuture { + stop_token: stop_token.clone(), + future: self, + } + } +} + +impl WithStopTokenStreamExt for S where S: Stream {} + +pub trait WithStopTokenStreamExt: Stream { + fn with_stop_token(self, stop_token: &StopToken) -> StopStream + where + Self: Sized, + { + StopStream { + stop_token: stop_token.clone(), + stream: self, } } } diff --git a/tests/tests.rs b/tests/tests.rs index 8fe28f3..d132543 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,8 +1,8 @@ use std::time::Duration; -use async_std::{prelude::*, task, sync::channel}; +use async_std::{prelude::*, sync::channel, task}; -use stop_token::StopSource; +use stop_token::{Error, StopSource, StopToken, WithStopTokenExt as _}; #[test] fn smoke() { @@ -13,13 +13,14 @@ fn smoke() { let stop_token = stop_source.stop_token(); let receiver = receiver.clone(); async move { - let mut xs = Vec::new(); - let mut stream = stop_token.stop_stream(receiver); - while let Some(x) = stream.next().await { - xs.push(x) + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs } - xs - }}); + }); sender.send(1).await; sender.send(2).await; sender.send(3).await; @@ -34,3 +35,28 @@ fn smoke() { assert_eq!(task.await, vec![1, 2, 3]); }) } + +#[test] +fn extension_methods() { + async fn long_running(stop_token: StopToken) -> Result<(), Error> { + loop { + task::sleep(Duration::from_secs(10)) + .with_stop_token(&stop_token) + .await?; + } + } + + task::block_on(async { + let stop_source = StopSource::new(); + let stop_token = stop_source.stop_token(); + + task::spawn(async { + task::sleep(Duration::from_millis(250)).await; + drop(stop_source); + }); + + if let Ok(_) = long_running(stop_token).await { + panic!("expected to have been stopped"); + } + }) +}