diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c8a577..ae4c197 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,16 +17,15 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@master - - name: Install nightly - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - override: true + - name: Install nightly + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true - - name: tests - uses: actions-rs/cargo@v1 - with: - command: test - args: --features unstable + - name: tests + uses: actions-rs/cargo@v1 + with: + command: test diff --git a/Cargo.toml b/Cargo.toml index 97bd9c1..35b6076 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stop-token" -version = "0.1.2" +version = "0.1.3" authors = ["Aleksey Kladov "] edition = "2018" license = "MIT OR Apache-2.0" @@ -10,7 +10,14 @@ description = "Experimental cooperative cancellation for async-std" [dependencies] pin-project-lite = "0.1.0" -async-std = "1.0" +futures = "0.3.5" +event-listener = "2.2.0" + + +[dev-dependencies] +async-std = { version = "1.0", features = ["unstable"] } [features] -unstable = ["async-std/unstable"] +unstable = [] +# This feature doesn't do anything anymore, +# but is needed for backwards-compatibility diff --git a/src/lib.rs b/src/lib.rs index 3a7026d..25297c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,14 +67,18 @@ //! The cancellation system is a subset of `C#` [`CancellationToken / CancellationTokenSource`](https://docs.microsoft.com/en-us/dotnet/standard/threading/cancellation-in-managed-threads). //! The `StopToken / StopTokenSource` terminology is borrowed from C++ paper P0660: https://wg21.link/p0660. +use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::ptr::null_mut; +use std::sync::{ + atomic::{AtomicPtr, Ordering}, + Arc, +}; -use async_std::prelude::*; -use async_std::sync::{channel, Receiver, Sender}; +use event_listener::{Event, EventListener}; +use futures::stream::Stream; use pin_project_lite::pin_project; - -enum Never {} +use std::task::{Context, Poll}; /// `StopSource` produces `StopToken` and cancels all of its tokens on drop. /// @@ -86,30 +90,163 @@ enum Never {} /// schedule_some_work(stop_token); /// drop(stop_source); // At this point, scheduled work notices that it is canceled. /// ``` + +/// An immutable, atomic option type that store data in Boxed Arcs +struct AtomicOption(AtomicPtr>); + +// TODO: relax orderings on atomic accesses +impl AtomicOption { + fn is_none(&self) -> bool { + self.0.load(Ordering::SeqCst).is_null() + } + + #[allow(dead_code)] + fn is_some(&self) -> bool { + !self.is_none() + } + + fn get(&self) -> Option> { + let ptr = self.0.load(Ordering::SeqCst); + if ptr.is_null() { + None + } else { + // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` + // this means it's safe to turn back into a `Box` + let arc_box = unsafe { Box::from_raw(ptr as *mut Arc) }; + + let arc = *arc_box.clone(); // Clone the Arc + + Box::leak(arc_box); // And make sure rust doesn't drop our inner value + + Some(arc) + } + } + + fn new(value: Option) -> Self { + let ptr = if let Some(value) = value { + Box::into_raw(Box::new(Arc::new(value))) + } else { + null_mut() + }; + + Self(AtomicPtr::new(ptr)) + } + + fn take(&self) -> Option> { + self.replace(None) + } + + fn replace(&self, new: Option) -> Option> { + let new_ptr = if let Some(new) = new { + Box::into_raw(Box::new(Arc::new(new))) + } else { + null_mut() + }; + + let ptr = self.0.swap(new_ptr, Ordering::SeqCst); + + if ptr.is_null() { + None + } else { + // Safety: we know that `ptr` is not null and can only have been created from a `Box` by `new` or `replace` + // this means it's safe to turn back into a `Box` + Some(unsafe { *Box::from_raw(ptr) }) + } + } +} + +impl Drop for AtomicOption { + fn drop(&mut self) { + std::mem::drop(self.take()); + } +} + +impl std::fmt::Debug for AtomicOption +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_none() { + write!(f, "None") + } else { + write!(f, "Some()") + } + } +} + +/// a custom implementation of a CondVar that short-circuits after +/// being signaled once +#[derive(Debug)] +struct ShortCircuitingCondVar(AtomicOption); + +impl ShortCircuitingCondVar { + fn is_done(&self) -> bool { + self.0.is_none() + } + + fn notify(&self, n: usize) -> bool { + self.0.take().map(|x| x.notify(n)).is_some() + } + + fn listen(&self) -> Option { + self.0.get().map(|event| event.listen()) + } +} + #[derive(Debug)] pub struct StopSource { - /// Solely for `Drop`. - _chan: Sender, - stop_token: StopToken, + signal: Arc, } /// `StopToken` is a future which completes when the associated `StopSource` is dropped. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct StopToken { - chan: Receiver, + cond_var: Arc, + cached_listener: Option, +} + +impl StopToken { + fn new(cond_var: Arc) -> Self { + Self { + cond_var, + cached_listener: None, + } + } + + fn listen(&mut self) -> Option<&mut EventListener> { + if self.cond_var.is_done() { + return None; + } + + if self.cached_listener.is_none() { + self.cached_listener = self.cond_var.listen(); + } + self.cached_listener.as_mut() + } +} + +impl Clone for StopToken { + fn clone(&self) -> Self { + Self::new(self.cond_var.clone()) + } } impl Default for StopSource { fn default() -> StopSource { - let (sender, receiver) = channel::(1); - StopSource { - _chan: sender, - stop_token: StopToken { chan: receiver }, + signal: Arc::new(ShortCircuitingCondVar(AtomicOption::new( + Some(Event::new()), + ))), } } } +impl Drop for StopSource { + fn drop(&mut self) { + self.signal.notify(usize::MAX); + } +} + impl StopSource { /// Creates a new `StopSource`. pub fn new() -> StopSource { @@ -120,7 +257,7 @@ impl StopSource { /// /// Once the source is destroyed, `StopToken` future completes. pub fn stop_token(&self) -> StopToken { - self.stop_token.clone() + StopToken::new(self.signal.clone()) } } @@ -128,11 +265,15 @@ impl Future for StopToken { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let chan = Pin::new(&mut self.chan); - match Stream::poll_next(chan, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Some(never)) => match never {}, - Poll::Ready(None) => Poll::Ready(()), + if let Some(mut listener) = self.listen() { + let result = match Future::poll(Pin::new(&mut listener), cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(()), + }; + + return result; + } else { + Poll::Ready(()) } } } diff --git a/tests/tests.rs b/tests/tests.rs index 8fe28f3..98bdbc7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use async_std::{prelude::*, task, sync::channel}; +use async_std::{prelude::*, sync::channel, task}; use stop_token::StopSource; @@ -13,24 +13,136 @@ fn smoke() { let stop_token = stop_source.stop_token(); let receiver = receiver.clone(); async move { - let mut xs = Vec::new(); - let mut stream = stop_token.stop_stream(receiver); - while let Some(x) = stream.next().await { - xs.push(x) + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs } - xs - }}); + }); sender.send(1).await; sender.send(2).await; sender.send(3).await; task::sleep(Duration::from_millis(250)).await; drop(stop_source); + task::sleep(Duration::from_millis(250)).await; sender.send(4).await; sender.send(5).await; sender.send(6).await; + assert_eq!(task.await, vec![1, 2, 3]); }) } + +#[test] +fn multiple_tokens() { + task::block_on(async { + let stop_source = StopSource::new(); + + let (sender_a, receiver_a) = channel::(10); + let task_a = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_a.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + let (sender_b, receiver_b) = channel::(10); + let task_b = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver_b.clone(); + async move { + let mut xs = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(x) = stream.next().await { + xs.push(x) + } + xs + } + }); + + sender_a.send(1).await; + sender_a.send(2).await; + sender_a.send(3).await; + + sender_b.send(101).await; + sender_b.send(102).await; + sender_b.send(103).await; + + task::sleep(Duration::from_millis(250)).await; + + drop(stop_source); + + task::sleep(Duration::from_millis(250)).await; + + sender_a.send(4).await; + sender_a.send(5).await; + sender_a.send(6).await; + + sender_b.send(104).await; + sender_b.send(105).await; + sender_b.send(106).await; + + assert_eq!(task_a.await, vec![1, 2, 3]); + assert_eq!(task_b.await, vec![101, 102, 103]); + }) +} + +#[test] +fn contest_cached_listener() { + task::block_on(async { + let stop_source = StopSource::new(); + + const N: usize = 8; + + let mut recv_tasks = Vec::with_capacity(N); + let mut send_tasks = Vec::with_capacity(N); + + for _ in 0..N { + let (sender, receiver) = channel::(10); + let recv_task = task::spawn({ + let stop_token = stop_source.stop_token(); + let receiver = receiver.clone(); + async move { + let mut messages = Vec::new(); + let mut stream = stop_token.stop_stream(receiver); + while let Some(msg) = stream.next().await { + messages.push(msg) + } + messages + } + }); + + let send_task = task::spawn({ + async move { + for msg in 0.. { + sender.send(msg).await; + } + } + }); + + recv_tasks.push(recv_task); + send_tasks.push(send_task); + } + + task::sleep(Duration::from_millis(500)).await; + + drop(stop_source); + + task::sleep(Duration::from_millis(500)).await; + + for (i, recv_task) in recv_tasks.into_iter().enumerate() { + eprintln!("receiver {} got {} messages", i, recv_task.await.len()); + } + }) +}