Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tokio: distinguish LocalSet::enter() with being polled #6016

Merged
merged 18 commits into from
Oct 15, 2023
Merged
115 changes: 72 additions & 43 deletions tokio/src/task/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,43 @@ pin_project! {

tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
ctx: RcCell::new(),
wake_on_schedule: Cell::new(false),
} });

struct LocalData {
ctx: RcCell<Context>,
wake_on_schedule: Cell<bool>,
}

impl LocalData {
/// Should be called except when we call `LocalSet::enter`.
/// Especially when we poll a LocalSet.
#[must_use = "dropping this guard will reset the entered state"]
fn enter(&self, ctx: Rc<Context>) -> LocalDataEnterGuard<'_> {
hawkw marked this conversation as resolved.
Show resolved Hide resolved
let ctx = self.ctx.replace(Some(ctx));
let wake_on_schedule = self.wake_on_schedule.replace(false);
LocalDataEnterGuard {
local_data_ref: self,
ctx,
wake_on_schedule,
}
}
}

hawkw marked this conversation as resolved.
Show resolved Hide resolved
/// A guard for `LocalData::enter()`
struct LocalDataEnterGuard<'a> {
inq marked this conversation as resolved.
Show resolved Hide resolved
local_data_ref: &'a LocalData,
ctx: Option<Rc<Context>>,
wake_on_schedule: bool,
}

impl<'a> Drop for LocalDataEnterGuard<'a> {
fn drop(&mut self) {
self.local_data_ref.ctx.set(self.ctx.take());
self.local_data_ref
.wake_on_schedule
.set(self.wake_on_schedule)
}
}

cfg_rt! {
Expand Down Expand Up @@ -360,13 +393,26 @@ const MAX_TASKS_PER_TICK: usize = 61;
const REMOTE_FIRST_INTERVAL: u8 = 31;

/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);
pub struct LocalEnterGuard {
ctx: Option<Rc<Context>>,

/// Distinguishes whether the context was entered or being polled.
/// When we enter it, the value `wake_on_schedule` is set. In this case
/// `spawn_local` refers the context, whereas it is not being polled now.
wake_on_schedule: bool,
}

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|LocalData { ctx, .. }| {
ctx.set(self.0.take());
})
CURRENT.with(
|LocalData {
ctx,
wake_on_schedule,
}| {
ctx.set(self.ctx.take());
wake_on_schedule.set(self.wake_on_schedule);
},
)
}
}

Expand Down Expand Up @@ -408,10 +454,20 @@ impl LocalSet {
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|LocalData { ctx, .. }| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
CURRENT.with(
|LocalData {
ctx,
wake_on_schedule,
..
}| {
let ctx = ctx.replace(Some(self.context.clone()));
let wake_on_schedule = wake_on_schedule.replace(true);
LocalEnterGuard {
ctx,
wake_on_schedule,
}
},
)
}

/// Spawns a `!Send` task onto the local task set.
Expand Down Expand Up @@ -667,23 +723,8 @@ impl LocalSet {
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.set(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

CURRENT.with(|local_data| {
let _guard = local_data.enter(self.context.clone());
f()
})
}
Expand All @@ -693,23 +734,8 @@ impl LocalSet {
fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
let mut f = Some(f);

let res = CURRENT.try_with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

let res = CURRENT.try_with(|local_data| {
let _guard = local_data.enter(self.context.clone());
(f.take().unwrap())()
});

Expand Down Expand Up @@ -967,7 +993,10 @@ impl Shared {
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|localdata| {
match localdata.ctx.get() {
Some(cx) if cx.shared.ptr_eq(self) => unsafe {
// If the current `LocalSet` is being polled, we don't need to wake it.
// When we `enter` it, then the value `wake_on_schedule` is set to be true.
// In this case it is not being polled, so we need to wake it.
Some(cx) if cx.shared.ptr_eq(self) && !localdata.wake_on_schedule.get() => unsafe {
// Safety: if the current `LocalSet` context points to this
// `LocalSet`, then we are on the thread that owns it.
cx.shared.local_state.task_push_back(task);
Expand Down
21 changes: 21 additions & 0 deletions tokio/tests/task_local_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,27 @@ async fn spawn_wakes_localset() {
}
}

/// Checks that the task wakes up with `enter`.
/// Reproduces <https://github.com/tokio-rs/tokio/issues/5020>.
#[tokio::test]
async fn sleep_with_local_enter_guard() {
hawkw marked this conversation as resolved.
Show resolved Hide resolved
let local = LocalSet::new();
let _guard = local.enter();

let (tx, rx) = oneshot::channel();

local
.run_until(async move {
tokio::task::spawn_local(async move {
time::sleep(Duration::ZERO).await;

tx.send(()).expect("failed to send");
});
assert_eq!(rx.await, Ok(()));
})
.await;
}

#[test]
fn store_local_set_in_thread_local_with_runtime() {
use tokio::runtime::Runtime;
Expand Down
Loading