task: Drop the join waker of a task eagerly when the JoinHandle gets dropped or the task completes #6986

Merged
merged 8 commits into from
Dec 29, 2024
1 change: 1 addition & 0 deletions spellcheck.dic
@@ -274,6 +274,7 @@ unparks
Unparks
unreceived
unsafety
unsets
Unsets
unsynchronized
untrusted
41 changes: 37 additions & 4 deletions tokio/src/runtime/task/harness.rs
@@ -284,9 +284,11 @@ where
}

pub(super) fn drop_join_handle_slow(self) {
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in
// case the task concurrently completed.
if self.state().unset_join_interested().is_err() {
let transition = self.state().transition_to_join_handle_dropped();

if transition.drop_output {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
@@ -301,6 +303,23 @@ where
}));
}

if transition.drop_waker {
// If the JOIN_WAKER flag is unset at this point, the task is either
// already terminal or not complete so the `JoinHandle` is responsible
// for dropping the waker.
// Safety:
// If the JOIN_WAKER bit is not set the join handle has exclusive
// access to the waker as per rule 2 in task/mod.rs.
// This can only be the case at this point in two scenarios:
// 1. The task completed and the runtime unset `JOIN_WAKER` flag
// after accessing the waker during task completion. So the
// `JoinHandle` is the only one to access the join waker here.
// 2. The task is not completed so the `JoinHandle` was able to unset
// `JOIN_WAKER` bit itself to get mutable access to the waker.
// The runtime will not access the waker when this flag is unset.
unsafe { self.trailer().set_waker(None) };
}

// Drop the `JoinHandle` reference, possibly deallocating the task
self.drop_reference();
}
@@ -311,7 +330,6 @@ where
fn complete(self) {
// The future has completed and its output has been written to the task
// stage. We transition from running to complete.

let snapshot = self.state().transition_to_complete();

// We catch panics here in case dropping the future or waking the
@@ -320,13 +338,28 @@ where
if !snapshot.is_join_interested() {
// The `JoinHandle` is not interested in the output of
// this task. It is our responsibility to drop the
// output.
// output. The join waker was already dropped by the
// `JoinHandle` before.
self.core().drop_future_or_output();
} else if snapshot.is_join_waker_set() {
// Notify the waker. Reading the waker field is safe per rule 4
// in task/mod.rs, since the JOIN_WAKER bit is set and the call
// to transition_to_complete() above set the COMPLETE bit.
self.trailer().wake_join();

// Inform the `JoinHandle` that we are done waking the waker by
// unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has
// already been dropped and `JOIN_INTEREST` is unset, then we must
// drop the waker ourselves.
if !self
.state()
.unset_waker_after_complete()
.is_join_interested()
{
// SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so
// we have exclusive access to the waker.
unsafe { self.trailer().set_waker(None) };
}
}
}));

18 changes: 16 additions & 2 deletions tokio/src/runtime/task/mod.rs
@@ -94,16 +94,30 @@
//! `JoinHandle` needs to (i) successfully set `JOIN_WAKER` to zero if it is
//! not already zero to gain exclusive access to the waker field per rule
//! 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` to one.
//! If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
//! to clear the waker field, only steps (i) and (ii) are relevant.
//!
//! 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
//! the task hasn't yet completed).
//! the task hasn't yet completed). The runtime can change `JOIN_WAKER` only
//! if COMPLETE is one.
//!
//! 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the runtime has
//! exclusive (mutable) access to the waker field. This might happen if the
//! `JoinHandle` gets dropped right after the task completes and the runtime
//! sets the `COMPLETE` bit. In this case the runtime needs the mutable access
//! to the waker field to drop it.
//!
//! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
//! race. If step (i) fails, then the attempt to write a waker is aborted. If
//! step (iii) fails because COMPLETE is set to one by another thread after
//! step (i), then the waker field is cleared. Once COMPLETE is one (i.e.
//! task has completed), the `JoinHandle` will not modify `JOIN_WAKER`. After the
//! runtime sets COMPLETE to one, it invokes the waker if there is one.
//! runtime sets COMPLETE to one, it invokes the waker if there is one. In this
//! case, when a task completes, the `JOIN_WAKER` bit indicates to the runtime
//! whether it should invoke the waker or not. After the runtime is done
//! using the waker during task completion, it unsets the `JOIN_WAKER` bit to give
//! the `JoinHandle` exclusive access again so that it is able to drop the waker
//! at a later point.
//!
//! All other fields are immutable and can be accessed immutably without
//! synchronization by anyone.
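
The three-step waker registration described in rule 5, together with the race described under rule 6, can be modeled with a single atomic word. The following is only a rough sketch: the bit values, the `register_waker` helper, and the `&'static str` stand-in for a real `Waker` are assumptions made for illustration and do not reproduce Tokio's actual layout or code.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical bit positions, chosen for illustration only.
const COMPLETE: usize = 0b001;
const JOIN_INTEREST: usize = 0b010;
const JOIN_WAKER: usize = 0b100;

/// Step (i): clear JOIN_WAKER to gain exclusive access to the waker slot.
/// Per rule 6 this fails once COMPLETE is set.
fn unset_join_waker(state: &AtomicUsize) -> Result<usize, usize> {
    state.fetch_update(Ordering::AcqRel, Ordering::Acquire, |curr| {
        if curr & COMPLETE != 0 {
            None
        } else {
            Some(curr & !JOIN_WAKER)
        }
    })
}

/// Step (iii): set JOIN_WAKER to publish the waker to the runtime.
/// Per rule 6 this fails once COMPLETE is set.
fn set_join_waker(state: &AtomicUsize) -> Result<usize, usize> {
    state.fetch_update(Ordering::AcqRel, Ordering::Acquire, |curr| {
        if curr & COMPLETE != 0 {
            None
        } else {
            Some(curr | JOIN_WAKER)
        }
    })
}

/// Runs steps (i) to (iii). Returns `true` if the waker was published.
fn register_waker(
    state: &AtomicUsize,
    slot: &mut Option<&'static str>,
    waker: &'static str,
) -> bool {
    // (i) Only needed if JOIN_WAKER is currently set.
    if state.load(Ordering::Acquire) & JOIN_WAKER != 0 && unset_join_waker(state).is_err() {
        // The task completed first: abort without touching the slot.
        return false;
    }
    // (ii) We have exclusive access, so writing the slot is fine.
    *slot = Some(waker);
    // (iii) Publish. If the task completed in between, clear the slot again.
    if set_join_waker(state).is_err() {
        *slot = None;
        return false;
    }
    true
}

fn main() {
    let state = AtomicUsize::new(JOIN_INTEREST);
    let mut slot = None;

    assert!(register_waker(&state, &mut slot, "first"));
    assert!(register_waker(&state, &mut slot, "second"));
    assert_eq!(slot, Some("second"));

    // Simulate the runtime completing the task.
    state.fetch_or(COMPLETE, Ordering::AcqRel);
    assert!(!register_waker(&state, &mut slot, "third"));
    assert_eq!(slot, Some("second"));
}
```

Both fallible steps refuse to change `JOIN_WAKER` once `COMPLETE` is set, which is what lets the runtime read the waker without further synchronization after completion.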
63 changes: 51 additions & 12 deletions tokio/src/runtime/task/state.rs
@@ -89,6 +89,12 @@ pub(crate) enum TransitionToNotifiedByRef {
Submit,
}

#[must_use]
pub(super) struct TransitionToJoinHandleDrop {
pub(super) drop_waker: bool,
pub(super) drop_output: bool,
}

/// All transitions are performed via RMW operations. This establishes an
/// unambiguous modification order.
impl State {
@@ -371,22 +377,45 @@ impl State {
.map_err(|_| ())
}

/// Tries to unset the `JOIN_INTEREST` flag.
///
/// Returns `Ok` if the operation happens before the task transitions to a
/// completed state, `Err` otherwise.
pub(super) fn unset_join_interested(&self) -> UpdateResult {
self.fetch_update(|curr| {
assert!(curr.is_join_interested());
/// Unsets the `JOIN_INTEREST` flag. If `COMPLETE` is not set, the `JOIN_WAKER`
/// flag is also unset.
/// The returned `TransitionToJoinHandleDrop` indicates whether the `JoinHandle` should drop
/// the output of the future or the join waker after the transition.
pub(super) fn transition_to_join_handle_dropped(&self) -> TransitionToJoinHandleDrop {
self.fetch_update_action(|mut snapshot| {
assert!(snapshot.is_join_interested());

if curr.is_complete() {
return None;
let mut transition = TransitionToJoinHandleDrop {
drop_waker: false,
drop_output: false,
};

snapshot.unset_join_interested();

if !snapshot.is_complete() {
// If `COMPLETE` is unset we also unset `JOIN_WAKER` to give the
// `JoinHandle` exclusive access to the waker following rule 6 in task/mod.rs.
// The `JoinHandle` will drop the waker if it has exclusive access
// to drop it.
snapshot.unset_join_waker();
} else {
// If `COMPLETE` is set the task is completed so the `JoinHandle` is responsible
// for dropping the output.
transition.drop_output = true;
}

let mut next = curr;
next.unset_join_interested();
if !snapshot.is_join_waker_set() {
// If the `JOIN_WAKER` bit is unset, the `JoinHandle` has exclusive access to
// the join waker and should drop it following this transition.
// This might happen in two situations:
// 1. The task is not completed and we just unset the `JOIN_WAKER` bit above in this
// function.
// 2. The task is completed. In that case the `JOIN_WAKER` bit was already unset
// by the runtime during completion.
transition.drop_waker = true;
}

Some(next)
(transition, Some(snapshot))
})
}

@@ -430,6 +459,16 @@ impl State {
})
}

/// Unsets the `JOIN_WAKER` bit unconditionally after task completion.
///
/// This operation requires the task to be completed.
pub(super) fn unset_waker_after_complete(&self) -> Snapshot {
let prev = Snapshot(self.val.fetch_and(!JOIN_WAKER, AcqRel));
assert!(prev.is_complete());
assert!(prev.is_join_waker_set());
Snapshot(prev.0 & !JOIN_WAKER)
}

pub(super) fn ref_inc(&self) {
use std::process;
use std::sync::atomic::Ordering::Relaxed;
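
To see how the two new state transitions split responsibility for dropping the waker, here is a simplified standalone model. It is a sketch under assumptions: the bit values, the `HandleDrop` struct, and the free functions are illustrative stand-ins, not the `State`/`Snapshot` types from `state.rs`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical bit positions, chosen for illustration only.
const COMPLETE: usize = 0b001;
const JOIN_INTEREST: usize = 0b010;
const JOIN_WAKER: usize = 0b100;

/// Illustrative counterpart of `TransitionToJoinHandleDrop`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HandleDrop {
    drop_waker: bool,
    drop_output: bool,
}

/// Models the `JoinHandle` side: clear JOIN_INTEREST and decide what to drop.
fn join_handle_dropped(state: &AtomicUsize) -> HandleDrop {
    let mut curr = state.load(Ordering::Acquire);
    loop {
        assert!(curr & JOIN_INTEREST != 0);
        let mut next = curr & !JOIN_INTEREST;
        let mut action = HandleDrop { drop_waker: false, drop_output: false };

        if next & COMPLETE == 0 {
            // Not complete: also take the waker back from the runtime.
            next &= !JOIN_WAKER;
        } else {
            // Complete: the handle must drop the output.
            action.drop_output = true;
        }
        if next & JOIN_WAKER == 0 {
            // JOIN_WAKER unset means the handle owns the waker slot.
            action.drop_waker = true;
        }

        match state.compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire) {
            Ok(_) => return action,
            Err(actual) => curr = actual, // Lost a race: recompute.
        }
    }
}

/// Models the runtime side after it has woken the waker.
fn unset_waker_after_complete(state: &AtomicUsize) -> usize {
    let prev = state.fetch_and(!JOIN_WAKER, Ordering::AcqRel);
    assert!(prev & COMPLETE != 0);
    assert!(prev & JOIN_WAKER != 0);
    prev & !JOIN_WAKER
}

fn main() {
    // Handle dropped before completion: the handle drops the waker.
    let s = AtomicUsize::new(JOIN_INTEREST | JOIN_WAKER);
    assert_eq!(join_handle_dropped(&s), HandleDrop { drop_waker: true, drop_output: false });

    // Handle dropped while the runtime is still using the waker:
    // the handle drops only the output, the runtime drops the waker.
    let s = AtomicUsize::new(COMPLETE | JOIN_INTEREST | JOIN_WAKER);
    assert_eq!(join_handle_dropped(&s), HandleDrop { drop_waker: false, drop_output: true });
    assert_eq!(unset_waker_after_complete(&s) & JOIN_INTEREST, 0);

    // Handle dropped after the runtime released the waker: the handle drops both.
    let s = AtomicUsize::new(COMPLETE | JOIN_INTEREST);
    assert_eq!(join_handle_dropped(&s), HandleDrop { drop_waker: true, drop_output: true });
}
```

In each of the three interleavings exactly one side ends up responsible for the waker, which is the property the PR relies on to drop it eagerly without a double drop.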
58 changes: 56 additions & 2 deletions tokio/src/runtime/tests/loom_current_thread.rs
@@ -1,6 +1,6 @@
mod yield_now;

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::{AtomicUsize, Ordering};
use crate::loom::sync::Arc;
use crate::loom::thread;
use crate::runtime::{Builder, Runtime};
@@ -9,7 +9,7 @@ use crate::task;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::task::{Context, Poll};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) {
let (tx, rx) = oneshot::channel();
@@ -106,6 +106,60 @@ fn assert_no_unnecessary_polls() {
});
}

#[test]
fn drop_jh_during_schedule() {
unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
atomic.fetch_add(1, Ordering::Relaxed);
RawWaker::new(ptr, &VTABLE)
}
unsafe fn waker_drop(ptr: *const ()) {
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
atomic.fetch_sub(1, Ordering::Relaxed);
}
unsafe fn waker_nop(_ptr: *const ()) {}

static VTABLE: RawWakerVTable =
RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop);

loom::model(|| {
let rt = Builder::new_current_thread().build().unwrap();

let mut jh = rt.spawn(async {});
// Using AbortHandle to increment task refcount. This ensures that the waker is not
// destroyed due to the refcount hitting zero.
let task_refcnt = jh.abort_handle();

let waker_refcnt = AtomicUsize::new(1);
{
// Set up the join waker.
use std::future::Future;
use std::pin::Pin;

// SAFETY: Before `waker_refcnt` goes out of scope, this test asserts that the refcnt
// has dropped to zero.
let join_waker = unsafe {
Waker::from_raw(RawWaker::new(
(&waker_refcnt) as *const AtomicUsize as *const (),
&VTABLE,
))
};

assert!(Pin::new(&mut jh)
.poll(&mut Context::from_waker(&join_waker))
.is_pending());
}
assert_eq!(waker_refcnt.load(Ordering::Relaxed), 1);

let bg_thread = loom::thread::spawn(move || drop(jh));
rt.block_on(crate::task::yield_now());
bg_thread.join().unwrap();

assert_eq!(waker_refcnt.load(Ordering::Relaxed), 0);
drop(task_refcnt);
});
}

struct BlockedFuture {
rx: Receiver<()>,
num_polls: Arc<AtomicUsize>,
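
The `drop_jh_during_schedule` test detects a leaked join waker by counting live waker instances through a custom `RawWakerVTable`. The same counting technique can be tried outside of loom and Tokio with only the standard library. This is a minimal sketch: the `counting_waker` helper and the leaked counter are assumptions added here so the example is self-contained, and it uses plain `std` atomics rather than loom's.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{RawWaker, RawWakerVTable, Waker};

// Each clone increments the counter behind the data pointer.
unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
    let count = unsafe { &*(ptr as *const AtomicUsize) };
    count.fetch_add(1, Ordering::Relaxed);
    RawWaker::new(ptr, &VTABLE)
}

// Dropping (or waking by value, which consumes the waker) decrements it.
unsafe fn waker_drop(ptr: *const ()) {
    let count = unsafe { &*(ptr as *const AtomicUsize) };
    count.fetch_sub(1, Ordering::Relaxed);
}

// Waking by reference leaves the number of live wakers unchanged.
unsafe fn waker_nop(_ptr: *const ()) {}

// Argument order: clone, wake, wake_by_ref, drop.
static VTABLE: RawWakerVTable =
    RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop);

/// Hypothetical helper: creates the first waker and counts it as live.
fn counting_waker(count: &'static AtomicUsize) -> Waker {
    count.fetch_add(1, Ordering::Relaxed);
    // SAFETY: `count` is `'static`, so the pointer outlives every waker,
    // and the vtable functions only perform atomic operations on it.
    unsafe { Waker::from_raw(RawWaker::new(count as *const AtomicUsize as *const (), &VTABLE)) }
}

fn main() {
    // Leak the counter so that it is valid for the `'static` lifetime.
    let count: &'static AtomicUsize = Box::leak(Box::new(AtomicUsize::new(0)));

    let waker = counting_waker(count);
    let clone = waker.clone();
    assert_eq!(count.load(Ordering::Relaxed), 2);

    clone.wake(); // consumes the clone
    waker.wake_by_ref(); // does not consume
    assert_eq!(count.load(Ordering::Relaxed), 1);

    drop(waker);
    assert_eq!(count.load(Ordering::Relaxed), 0);
}
```

A counter that stays above zero after the last handle is gone would indicate that some component still holds a waker, which is the condition the loom test asserts against.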
39 changes: 38 additions & 1 deletion tokio/src/runtime/tests/task.rs
@@ -1,7 +1,9 @@
use crate::runtime::task::{
self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks,
};
use crate::runtime::tests::NoopSchedule;
use crate::runtime::{self, tests::NoopSchedule};
use crate::spawn;
use crate::sync::{mpsc, Barrier};

use std::collections::VecDeque;
use std::future::Future;
@@ -45,6 +47,41 @@ impl Drop for AssertDrop {
}
}

#[test]
fn drop_tasks_with_reference_cycle() {
Contributor

This does not use any Tokio internals, so it does not need to be inside src/. It should be in tokio/tests/ instead.

Contributor

I'll move it for you.

Contributor Author

Oh thanks, I wasn't aware of this but will keep it in mind for the future.

Contributor

No worries. Thanks a lot for the PR!

let rt = runtime::Builder::new_current_thread().build().unwrap();

rt.block_on(async {
let (tx, mut rx) = mpsc::channel(1);

let barrier = Arc::new(Barrier::new(3));
let barrier_a = barrier.clone();
let barrier_b = barrier.clone();

let a = spawn(async move {
let b = rx.recv().await.unwrap();

// Poll the JoinHandle once. This registers the waker.
// The other task cannot have finished at this point due to the barrier below.
futures::future::select(b, std::future::ready(())).await;

barrier_a.wait().await;
});

let b = spawn(async move {
// Poll the JoinHandle once. This registers the waker.
// The other task cannot have finished at this point due to the barrier below.
futures::future::select(a, std::future::ready(())).await;

barrier_b.wait().await;
});

tx.send(b).await.unwrap();

barrier.wait().await;
});
}

// A Notified does not shut down on drop, but it is dropped once the ref-count
// hits zero.
#[test]