Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tokio: distinguish LocalSet::enter() with being polled #6016

Merged
merged 18 commits into from
Oct 15, 2023
Merged
95 changes: 54 additions & 41 deletions tokio/src/task/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,41 @@ pin_project! {

tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
ctx: RcCell::new(),
entered: Cell::new(false),
} });

struct LocalData {
ctx: RcCell<Context>,
entered: Cell<bool>,
}

impl LocalData {
/// Should be called except when we call `LocalSet::enter`.
/// Especially when we poll a LocalSet.
#[must_use = "dropping this guard will reset the entered state"]
fn enter(&self, ctx: Rc<Context>) -> LocalDataEnterGuard<'_> {
hawkw marked this conversation as resolved.
Show resolved Hide resolved
let ctx = self.ctx.replace(Some(ctx));
let entered = self.entered.replace(false);
LocalDataEnterGuard {
local_data_ref: self,
ctx,
entered,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was reading over this code, and I'm confused about this. Why does calling enter set entered to false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of the name of the user-facing level functions. So this LocalData::enter should be called when we poll, and the value entered should be set only if user called the LocalSer::enter, which is public.
I agree that it's confused, but could not find suitable naming..

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename the field to something like wake_on_spawn_local?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @Darksonn, the entered field could be renamed to something clearer.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much! I've renamed it!
But... wake_on_spawn_local vs wake_on_schedule.. which would be better?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wake_on_schedule also makes sense to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I've renamed it to wake_on_schdule!

}
}
}

hawkw marked this conversation as resolved.
Show resolved Hide resolved
/// A guard for `LocalData::enter()`
struct LocalDataEnterGuard<'a> {
inq marked this conversation as resolved.
Show resolved Hide resolved
local_data_ref: &'a LocalData,
ctx: Option<Rc<Context>>,
entered: bool,
}

impl<'a> Drop for LocalDataEnterGuard<'a> {
fn drop(&mut self) {
self.local_data_ref.ctx.set(self.ctx.take());
self.local_data_ref.entered.set(self.entered)
}
}

cfg_rt! {
Expand Down Expand Up @@ -360,12 +391,20 @@ const MAX_TASKS_PER_TICK: usize = 61;
const REMOTE_FIRST_INTERVAL: u8 = 31;

/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);
pub struct LocalEnterGuard {
ctx: Option<Rc<Context>>,

/// Distinguishes whether the context was entered or being polled.
/// When we enter it, the value `entered` is set. In this case
/// `spawn_local` refers the context, whereas it is not being polled now.
entered: bool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a comment here describing how entered is used and what it represents? thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added! Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if this is overly nitpicky, but I think this comment could maybe be improved a bit --- it would be nice if it explained why we need to differentiate between enter and polling the LocalSet, and what behavior is controlled based on that. I don't think that's currelty all that clear from reading the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry! I've corrected it!

}

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|LocalData { ctx, .. }| {
ctx.set(self.0.take());
CURRENT.with(|LocalData { ctx, entered }| {
ctx.set(self.ctx.take());
entered.set(self.entered);
})
}
}
Expand Down Expand Up @@ -408,9 +447,10 @@ impl LocalSet {
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|LocalData { ctx, .. }| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
CURRENT.with(|LocalData { ctx, entered, .. }| {
let ctx = ctx.replace(Some(self.context.clone()));
let entered = entered.replace(true);
LocalEnterGuard { ctx, entered }
})
}

Expand Down Expand Up @@ -667,23 +707,8 @@ impl LocalSet {
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.set(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

CURRENT.with(|local_data| {
let _guard = local_data.enter(self.context.clone());
f()
})
}
Expand All @@ -693,23 +718,8 @@ impl LocalSet {
fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
let mut f = Some(f);

let res = CURRENT.try_with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));

let _reset = Reset {
ctx_ref: ctx,
val: old,
};

let res = CURRENT.try_with(|local_data| {
let _guard = local_data.enter(self.context.clone());
(f.take().unwrap())()
});

Expand Down Expand Up @@ -967,7 +977,10 @@ impl Shared {
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|localdata| {
match localdata.ctx.get() {
Some(cx) if cx.shared.ptr_eq(self) => unsafe {
// If the current `LocalSet` is being polled, we don't need to wake it.
// When we `enter` it, then the value `entered` is set to be true.
// In this case it is not being polled, so we need to wake it.
Some(cx) if cx.shared.ptr_eq(self) && !localdata.entered.get() => unsafe {
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
// Safety: if the current `LocalSet` context points to this
// `LocalSet`, then we are on the thread that owns it.
cx.shared.local_state.task_push_back(task);
Expand Down
21 changes: 21 additions & 0 deletions tokio/tests/task_local_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,27 @@ async fn spawn_wakes_localset() {
}
}

/// Checks that the task wakes up with `enter`.
/// Reproduces <https://github.com/tokio-rs/tokio/issues/5020>.
#[tokio::test]
async fn sleep_with_local_enter_guard() {
hawkw marked this conversation as resolved.
Show resolved Hide resolved
let local = LocalSet::new();
let _guard = local.enter();

let (tx, rx) = oneshot::channel();

local
.run_until(async move {
tokio::task::spawn_local(async move {
time::sleep(Duration::ZERO).await;

tx.send(()).expect("failed to send");
});
assert_eq!(rx.await, Ok(()));
})
.await;
}

#[test]
fn store_local_set_in_thread_local_with_runtime() {
use tokio::runtime::Runtime;
Expand Down
Loading