Skip to content

Commit

Permalink
runtime: implement initial set of task hooks (#6742)
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah-Kennedy authored Aug 27, 2024
1 parent c9fad08 commit b37f0de
Show file tree
Hide file tree
Showing 19 changed files with 384 additions and 16 deletions.
12 changes: 11 additions & 1 deletion tokio/src/runtime/blocking/schedule.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(feature = "test-util")]
use crate::runtime::scheduler;
use crate::runtime::task::{self, Task};
use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks};
use crate::runtime::Handle;

/// `task::Schedule` implementation that does nothing (except some bookkeeping
Expand All @@ -12,6 +12,7 @@ use crate::runtime::Handle;
pub(crate) struct BlockingSchedule {
#[cfg(feature = "test-util")]
handle: Handle,
hooks: TaskHarnessScheduleHooks,
}

impl BlockingSchedule {
Expand All @@ -32,6 +33,9 @@ impl BlockingSchedule {
BlockingSchedule {
#[cfg(feature = "test-util")]
handle: handle.clone(),
hooks: TaskHarnessScheduleHooks {
task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(),
},
}
}
}
Expand All @@ -57,4 +61,10 @@ impl task::Schedule for BlockingSchedule {
fn schedule(&self, _task: task::Notified<Self>) {
unreachable!();
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.hooks.task_terminate_callback.clone(),
}
}
}
106 changes: 105 additions & 1 deletion tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#![cfg_attr(loom, allow(unused_imports))]

use crate::runtime::handle::Handle;
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime};
#[cfg(tokio_unstable)]
use crate::runtime::TaskMeta;
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback};
use crate::util::rand::{RngSeed, RngSeedGenerator};

use std::fmt;
Expand Down Expand Up @@ -78,6 +82,12 @@ pub struct Builder {
/// To run after each thread is unparked.
pub(super) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(super) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(super) after_termination: Option<TaskCallback>,

/// Customizable keep alive timeout for `BlockingPool`
pub(super) keep_alive: Option<Duration>,

Expand Down Expand Up @@ -290,6 +300,9 @@ impl Builder {
before_park: None,
after_unpark: None,

before_spawn: None,
after_termination: None,

keep_alive: None,

// Defaults for these values depend on the scheduler kind, so we get them
Expand Down Expand Up @@ -677,6 +690,91 @@ impl Builder {
self
}

/// Executes function `f` just before a task is spawned.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
/// invoked immediately.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one spawn callback for a runtime; calling this function more
/// than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_spawn(|_| {
/// println!("spawning task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(all(not(loom), tokio_unstable))]
pub fn on_task_spawn<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.before_spawn = Some(std::sync::Arc::new(f));
self
}

/// Executes function `f` just after a task is terminated.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one task termination callback for a runtime; calling this
/// function more than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_terminate(|_| {
/// println!("killing task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(all(not(loom), tokio_unstable))]
pub fn on_task_terminate<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.after_termination = Some(std::sync::Arc::new(f));
self
}

/// Creates the configured `Runtime`.
///
/// The returned `Runtime` instance is ready to spawn tasks.
Expand Down Expand Up @@ -1118,6 +1216,8 @@ impl Builder {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down Expand Up @@ -1269,6 +1369,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down Expand Up @@ -1316,6 +1418,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down
8 changes: 7 additions & 1 deletion tokio/src/runtime/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"),
allow(dead_code)
)]
use crate::runtime::Callback;
use crate::runtime::{Callback, TaskCallback};
use crate::util::RngSeedGenerator;

pub(crate) struct Config {
Expand All @@ -21,6 +21,12 @@ pub(crate) struct Config {
/// Callback for a worker unparking itself
pub(crate) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(crate) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(crate) after_termination: Option<TaskCallback>,

/// The multi-threaded scheduler includes a per-worker LIFO slot used to
/// store the last scheduled task. This can improve certain usage patterns,
/// especially message passing between tasks. However, this LIFO slot is not
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ cfg_rt! {
pub use dump::Dump;
}

mod task_hooks;
pub(crate) use task_hooks::{TaskHooks, TaskCallback};
#[cfg(tokio_unstable)]
pub use task_hooks::TaskMeta;
#[cfg(not(tokio_unstable))]
pub(crate) use task_hooks::TaskMeta;

mod handle;
pub use handle::{EnterGuard, Handle, TryCurrentError};

Expand Down
27 changes: 25 additions & 2 deletions tokio/src/runtime/scheduler/current_thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Arc;
use crate::runtime::driver::{self, Driver};
use crate::runtime::scheduler::{self, Defer, Inject};
use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics};
use crate::runtime::task::{
self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks,
};
use crate::runtime::{
blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics,
};
use crate::sync::notify::Notify;
use crate::util::atomic_cell::AtomicCell;
use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef};
Expand Down Expand Up @@ -41,6 +45,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,
}

/// Data required for executing the scheduler. The struct is passed around to
Expand Down Expand Up @@ -131,6 +138,10 @@ impl CurrentThread {
.unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL);

let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
inject: Inject::new(),
owned: OwnedTasks::new(1),
Expand Down Expand Up @@ -436,6 +447,12 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.task_hooks.spawn(&TaskMeta {
#[cfg(tokio_unstable)]
id,
_phantom: Default::default(),
});

if let Some(notified) = notified {
me.schedule(notified);
}
Expand Down Expand Up @@ -600,6 +617,12 @@ impl Schedule for Arc<Handle> {
});
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.task_hooks.task_terminate_callback.clone(),
}
}

cfg_unstable! {
fn unhandled_panic(&self) {
use crate::runtime::UnhandledPanic;
Expand Down
12 changes: 12 additions & 0 deletions tokio/src/runtime/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ cfg_rt! {

pub(crate) mod inject;
pub(crate) use inject::Inject;

use crate::runtime::TaskHooks;
}

cfg_rt_multi_thread! {
Expand Down Expand Up @@ -151,6 +153,16 @@ cfg_rt! {
}
}

pub(crate) fn hooks(&self) -> &TaskHooks {
match self {
Handle::CurrentThread(h) => &h.task_hooks,
#[cfg(feature = "rt-multi-thread")]
Handle::MultiThread(h) => &h.task_hooks,
#[cfg(all(tokio_unstable, feature = "rt-multi-thread"))]
Handle::MultiThreadAlt(h) => &h.task_hooks,
}
}

cfg_rt_multi_thread! {
cfg_unstable! {
pub(crate) fn expect_multi_thread_alt(&self) -> &Arc<multi_thread_alt::Handle> {
Expand Down
10 changes: 10 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::runtime::scheduler::multi_thread::worker;
use crate::runtime::{
blocking, driver,
task::{self, JoinHandle},
TaskHooks, TaskMeta,
};
use crate::util::RngSeedGenerator;

Expand All @@ -28,6 +29,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,
}

impl Handle {
Expand All @@ -51,6 +55,12 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.task_hooks.spawn(&TaskMeta {
#[cfg(tokio_unstable)]
id,
_phantom: Default::default(),
});

me.schedule_option_task_without_yield(notified);

handle
Expand Down
14 changes: 12 additions & 2 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@
use crate::loom::sync::{Arc, Mutex};
use crate::runtime;
use crate::runtime::context;
use crate::runtime::scheduler::multi_thread::{
idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker,
};
use crate::runtime::scheduler::{inject, Defer, Lock};
use crate::runtime::task::OwnedTasks;
use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks};
use crate::runtime::{
blocking, coop, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics,
};
use crate::runtime::{context, TaskHooks};
use crate::util::atomic_cell::AtomicCell;
use crate::util::rand::{FastRand, RngSeedGenerator};

Expand Down Expand Up @@ -284,6 +284,10 @@ pub(super) fn create(

let remotes_len = remotes.len();
let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
Expand Down Expand Up @@ -1037,6 +1041,12 @@ impl task::Schedule for Arc<Handle> {
self.schedule_task(task, false);
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.task_hooks.task_terminate_callback.clone(),
}
}

fn yield_now(&self, task: Notified) {
self.schedule_task(task, true);
}
Expand Down
Loading

0 comments on commit b37f0de

Please sign in to comment.