Skip to content

Commit

Permalink
Drop the join waker of a task eagerly when the task completes and the…
Browse files Browse the repository at this point in the history
…re is no

join interest
  • Loading branch information
tglane committed Nov 13, 2024
1 parent bb7ca75 commit 16d9a86
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 46 deletions.
72 changes: 56 additions & 16 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,39 @@ where
}

pub(super) fn drop_join_handle_slow(self) {
use super::state::TransitionToJoinHandleDrop;
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.
if self.state().unset_join_interested().is_err() {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
let transition = self.state().transition_to_join_handle_drop();
match transition {
TransitionToJoinHandleDrop::DropOutput => {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
// task structure until the task is deallocated, it may be dropped
// by a Waker on any arbitrary thread.
//
// Panics are delivered to the user via the `JoinHandle`. Given that
// they are dropping the `JoinHandle`, we assume they are not
// interested in the panic and swallow it.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
}
TransitionToJoinHandleDrop::DropJoinWaker => unsafe {
// The task has not completed yet but a join waker is registered. The state transition
// already cleared the `JOIN_WAKER` bit, so drop the stored waker together with the
// `JoinHandle`.
self.trailer().set_waker(None);
},
TransitionToJoinHandleDrop::DropBoth => {
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
self.core().drop_future_or_output();
}));
unsafe {
self.trailer().set_waker(None);
}
}
TransitionToJoinHandleDrop::DoNothing => (),
}

// Drop the `JoinHandle` reference, possibly deallocating the task
Expand All @@ -309,6 +327,7 @@ where

/// Completes the task. This method assumes that the state is RUNNING.
fn complete(self) {
use super::state::TransitionToTerminal;
// The future has completed and its output has been written to the task
// stage. We transition from running to complete.

Expand Down Expand Up @@ -346,8 +365,29 @@ where
// The task has completed execution and will no longer be scheduled.
let num_release = self.release();

if self.state().transition_to_terminal(num_release) {
self.dealloc();
match self.state().transition_to_terminal(num_release) {
TransitionToTerminal::OkDoNothing => (),
TransitionToTerminal::OkDealloc => {
self.dealloc();
}
TransitionToTerminal::FailedDropJoinWaker => {
// Safety: In this case we are the only one referencing the task and the active
// waker is the only thing preventing the task from being deallocated, so no one
// else will try to access the waker here.
unsafe {
self.trailer().set_waker(None);
}

// We do not expect this to happen since `TransitionToTerminal::FailedDropJoinWaker`
// will only be returned when, after dropping the join waker, the task can be
// safely deallocated. Because the COMPLETE bit is still set after this failed
// transition, it is fine to transition to terminal in two steps here.
if let TransitionToTerminal::OkDealloc =
self.state().transition_to_terminal(num_release)
{
self.dealloc();
}
}
}
}

Expand Down Expand Up @@ -387,7 +427,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {

debug_assert!(snapshot.is_join_interested());

if !snapshot.is_complete() {
if !snapshot.is_complete() && !snapshot.is_terminal() {
// If the task is not complete, try storing the provided waker in the
// task's waker field.

Expand Down
126 changes: 96 additions & 30 deletions tokio/src/runtime/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000;
/// The task has been forcibly cancelled.
const CANCELLED: usize = 0b100_000;

/// The task has gone through the `Complete` -> `Terminal` transition. This bit
/// is set by `transition_to_terminal` together with the ref-count decrement.
const TERMINAL: usize = 0b1_000_000;

/// All bits.
const STATE_MASK: usize =
LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL;

/// Bits used by the ref count portion of the state.
const REF_COUNT_MASK: usize = !STATE_MASK;
Expand Down Expand Up @@ -89,6 +93,21 @@ pub(crate) enum TransitionToNotifiedByRef {
Submit,
}

/// Outcome of `State::transition_to_join_handle_drop`, telling the code that
/// drops the `JoinHandle` which task resources it has to release.
#[must_use]
pub(crate) enum TransitionToJoinHandleDrop {
/// The task is not complete and no join waker is registered; nothing has to
/// be released.
DoNothing,
/// The task is complete and either not yet terminal or without a join
/// waker; the caller must drop the task output. No state update is
/// requested for this outcome.
DropOutput,
/// The task is not complete but a join waker is registered; the
/// `JOIN_WAKER` bit has been cleared and the caller must drop the waker.
DropJoinWaker,
/// The task is terminal and a join waker is registered; the `JOIN_WAKER`
/// bit has been cleared and the caller must drop both the output and the
/// waker.
DropBoth,
}

/// Outcome of `State::transition_to_terminal`.
#[must_use]
pub(crate) enum TransitionToTerminal {
/// The transition succeeded but other references to the task remain, so it
/// must not be deallocated yet.
OkDoNothing,
/// The transition succeeded and released the last references; the caller
/// should deallocate the task.
OkDealloc,
/// The transition was not performed: the ref count was left untouched and
/// `TERMINAL` was not set, but the `JOIN_WAKER` bit was cleared. The caller
/// has to drop the join waker and attempt the transition again.
FailedDropJoinWaker,
}

/// All transitions are performed via RMW operations. This establishes an
/// unambiguous modification order.
impl State {
Expand Down Expand Up @@ -174,30 +193,92 @@ impl State {
})
}

/// Computes the state transition that accompanies dropping the `JoinHandle`
/// and reports which task resources the caller has to release.
///
/// For `DropOutput` the closure returns `None` as the next state, i.e. no
/// state update is requested. Every other outcome proposes a next state with
/// `JOIN_INTEREST` cleared, and the two outcomes that drop the join waker
/// clear `JOIN_WAKER` as well.
///
/// # Panics
///
/// Panics if `JOIN_INTEREST` is not set when the transition is attempted.
pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop {
    self.fetch_update_action(|mut snapshot| {
        assert!(snapshot.is_join_interested());

        // The handle is going away, so the proposed next state never carries
        // `JOIN_INTEREST`.
        snapshot.unset_join_interested();

        let complete = snapshot.is_complete();
        let terminal = snapshot.is_terminal();
        let has_join_waker = snapshot.is_join_waker_set();

        match (complete, terminal, has_join_waker) {
            // Terminal with a registered join waker: release the output and
            // the waker along with the handle.
            (_, true, true) => {
                snapshot.unset_join_waker();
                (TransitionToJoinHandleDrop::DropBoth, Some(snapshot))
            }
            // Complete, but either not terminal yet or without a join waker:
            // only the output is dropped and the state is left as it is.
            (true, _, _) => (TransitionToJoinHandleDrop::DropOutput, None),
            // Still running with a registered join waker: only the waker is
            // dropped with the handle.
            (false, _, true) => {
                snapshot.unset_join_waker();
                (TransitionToJoinHandleDrop::DropJoinWaker, Some(snapshot))
            }
            // Still running and no join waker: nothing extra to release.
            (false, _, false) => (TransitionToJoinHandleDrop::DoNothing, Some(snapshot)),
        }
    })
}

/// Transitions the task from `Running` -> `Complete`.
///
/// Returns the snapshot of the state as it is after the transition.
///
/// # Panics
///
/// Panics if the previous state was not running, or was already complete or
/// terminal.
pub(super) fn transition_to_complete(&self) -> Snapshot {
// XOR-ing with both bits flips `RUNNING` and `COMPLETE` in one RMW operation.
const DELTA: usize = RUNNING | COMPLETE;

let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel));
// With `RUNNING` set and `COMPLETE` clear beforehand, the flip above cleared
// `RUNNING` and set `COMPLETE`.
assert!(prev.is_running());
assert!(!prev.is_complete());
assert!(!prev.is_terminal());

// Apply the same flip to the previous value to get the post-transition snapshot.
Snapshot(prev.0 ^ DELTA)
}

/// Transitions from `Complete` -> `Terminal`, decrementing the reference
/// count the specified number of times.
///
/// Returns true if the task should be deallocated.
pub(super) fn transition_to_terminal(&self, count: usize) -> bool {
let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel));
assert!(
prev.ref_count() >= count,
"current: {}, sub: {}",
prev.ref_count(),
count
);
prev.ref_count() == count
/// Returns:
/// - `TransitionToTerminal::OkDoNothing` if the transition succeeded but other
///   references remain, so the task cannot be deallocated yet.
/// - `TransitionToTerminal::OkDealloc` if the transition succeeded and the task
///   should be deallocated.
/// - `TransitionToTerminal::FailedDropJoinWaker` if the transition was not
///   performed because there is no join interest left while a join waker is
///   still registered. In this case the reference count is not decremented and
///   `TERMINAL` is not set, but the `JOIN_WAKER` bit is unset.
///
/// # Panics
///
/// Panics if the task is running, not complete, already terminal, or holds
/// fewer than `count` references.
pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal {
    self.fetch_update_action(|mut snapshot| {
        assert!(!snapshot.is_running());
        assert!(snapshot.is_complete());
        assert!(!snapshot.is_terminal());
        assert!(
            snapshot.ref_count() >= count,
            "current: {}, sub: {}",
            snapshot.ref_count(),
            count
        );

        let releases_last_ref = snapshot.ref_count() == count;

        // References remain, the join handle is gone, and a join waker is still
        // registered: hand the waker to the caller to drop and leave the ref
        // count and the `TERMINAL` bit untouched.
        if !releases_last_ref && !snapshot.is_join_interested() && snapshot.is_join_waker_set() {
            snapshot.unset_join_waker();
            return (TransitionToTerminal::FailedDropJoinWaker, Some(snapshot));
        }

        // Regular transition: release the references and mark the task terminal.
        snapshot.0 -= count * REF_ONE;
        snapshot.0 |= TERMINAL;

        let outcome = if releases_last_ref {
            TransitionToTerminal::OkDealloc
        } else {
            TransitionToTerminal::OkDoNothing
        };
        (outcome, Some(snapshot))
    })
}

/// Transitions the state to `NOTIFIED`.
Expand Down Expand Up @@ -371,25 +452,6 @@ impl State {
.map_err(|_| ())
}

/// Attempts to clear the `JOIN_INTEREST` flag.
///
/// Returns `Ok` when the flag is cleared before the task has transitioned to a
/// completed state, and `Err` when the task is already complete.
///
/// # Panics
///
/// Panics if `JOIN_INTEREST` is not set.
pub(super) fn unset_join_interested(&self) -> UpdateResult {
    self.fetch_update(|curr| {
        assert!(curr.is_join_interested());

        if curr.is_complete() {
            // Too late: the task already completed, so decline the update.
            None
        } else {
            let mut cleared = curr;
            cleared.unset_join_interested();
            Some(cleared)
        }
    })
}

/// Sets the `JOIN_WAKER` bit.
///
/// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if
Expand Down Expand Up @@ -557,6 +619,10 @@ impl Snapshot {
self.0 & COMPLETE == COMPLETE
}

/// Returns `true` if the `TERMINAL` bit is set in this snapshot.
pub(super) fn is_terminal(self) -> bool {
self.0 & TERMINAL == TERMINAL
}

pub(super) fn is_join_interested(self) -> bool {
self.0 & JOIN_INTEREST == JOIN_INTEREST
}
Expand Down
30 changes: 30 additions & 0 deletions tokio/src/runtime/tests/loom_multi_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod yield_now;
/// In order to speed up the C
use crate::runtime::tests::loom_oneshot as oneshot;
use crate::runtime::{self, Runtime};
use crate::sync::mpsc::channel;
use crate::{spawn, task};
use tokio_test::assert_ok;

Expand Down Expand Up @@ -459,3 +460,32 @@ impl<T: Future> Future for Track<T> {
})
}
}

/// Spawns two tasks that each end up owning the other's `JoinHandle` and
/// drives both to completion under the loom model.
#[test]
fn drop_tasks_with_reference_cycle() {
loom::model(|| {
let pool = mk_pool(2);

pool.block_on(async move {
// Carries the `JoinHandle` of task `b` into task `a`.
let (tx, mut rx) = channel(1);

// Each task sends on its closer channel; the outer future waits on both receivers.
let (a_closer, mut wait_for_close_a) = channel::<()>(1);
let (b_closer, mut wait_for_close_b) = channel::<()>(1);

// Task `a` receives the handle of `b`, then finishes as soon as either `b`
// completes or the send on its closer channel completes.
let a = spawn(async move {
let b = rx.recv().await.unwrap();

futures::future::select(std::pin::pin!(b), std::pin::pin!(a_closer.send(()))).await;
});

// Task `b` owns the handle of `a` and awaits it before signalling.
let b = spawn(async move {
let _ = a.await;
let _ = b_closer.send(()).await;
});

// Moving `b`'s handle into `a` closes the cycle: `a` holds `b`, `b` holds `a`.
tx.send(b).await.unwrap();

// Wait for both closer channels before leaving `block_on`.
futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await;
});
});
}

0 comments on commit 16d9a86

Please sign in to comment.