Skip to content

loan_cell: add primitive for lending thread-local data #743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3567,6 +3567,13 @@ dependencies = [
"zerocopy",
]

[[package]]
name = "loan_cell"
version = "0.0.0"
dependencies = [
"static_assertions",
]

[[package]]
name = "local_clock"
version = "0.0.0"
Expand Down Expand Up @@ -4876,6 +4883,7 @@ dependencies = [
"futures",
"getrandom",
"libc",
"loan_cell",
"once_cell",
"pal",
"pal_async_test",
Expand Down Expand Up @@ -4919,6 +4927,7 @@ dependencies = [
"inspect",
"io-uring",
"libc",
"loan_cell",
"once_cell",
"pal",
"pal_async",
Expand Down Expand Up @@ -7204,6 +7213,7 @@ version = "0.0.0"
dependencies = [
"fs-err",
"inspect",
"loan_cell",
"pal",
"pal_async",
"pal_uring",
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ inspect_proto = { path = "support/inspect_proto" }
inspect_rlimit = { path = "support/inspect_rlimit" }
inspect_task = { path = "support/inspect_task" }
kmsg = { path = "support/kmsg" }
loan_cell = { path = "support/loan_cell" }
local_clock = { path = "support/local_clock" }
mesh = { path = "support/mesh" }
mesh_build = { path = "support/mesh/mesh_build" }
Expand Down
11 changes: 6 additions & 5 deletions openhcl/underhill_core/src/vp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ impl VpSpawner {
let thread = underhill_threadpool::Thread::current().unwrap();
// TODO propagate this error back earlier. This is easiest if
// set_idle_task is fixed to take a non-Send fn.
let mut vp = self
.vp
.bind_processor::<T>(thread.driver(), control)
.context("failed to initialize VP")?;
let mut vp = thread.with_driver(|driver| {
self.vp
.bind_processor::<T>(driver, control)
.context("failed to initialize VP")
})?;

if let Some(saved_state) = saved_state {
vmcore::save_restore::ProtobufSaveRestore::restore(&mut vp, saved_state)
Expand Down Expand Up @@ -166,7 +167,7 @@ impl VpSpawner {
self.vp.set_sidecar_exit_due_to_task(
thread
.first_task()
.map_or_else(|| "<unknown>".into(), |t| t.name.clone()),
.map_or_else(|| "<unknown>".into(), |t| t.name),
);
}

Expand Down
3 changes: 1 addition & 2 deletions openhcl/underhill_mem/src/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,7 @@ async fn apply_vtl2_protections(
tracing::debug!(
cpu = underhill_threadpool::Thread::current()
.unwrap()
.driver()
.target_cpu(),
.with_driver(|driver| driver.target_cpu()),
%range,
"applying protections"
);
Expand Down
1 change: 1 addition & 0 deletions openhcl/underhill_threadpool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rust-version.workspace = true

[target.'cfg(target_os = "linux")'.dependencies]
inspect = { workspace = true, features = ["std"] }
loan_cell.workspace = true
pal.workspace = true
pal_async.workspace = true
pal_uring.workspace = true
Expand Down
120 changes: 53 additions & 67 deletions openhcl/underhill_threadpool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
//! This is built on top of [`pal_uring`] and [`pal_async`].

#![warn(missing_docs)]
// UNSAFETY: needed for saving per-thread state.
#![expect(unsafe_code)]
#![forbid(unsafe_code)]

use inspect::Inspect;
use loan_cell::LoanCell;
use pal::unix::affinity::CpuSet;
use pal_async::fd::FdReadyDriver;
use pal_async::task::Runnable;
Expand All @@ -30,7 +30,6 @@ use pal_uring::IoUringPool;
use pal_uring::PoolClient;
use pal_uring::Timer;
use parking_lot::Mutex;
use std::cell::Cell;
use std::future::poll_fn;
use std::io;
use std::marker::PhantomData;
Expand Down Expand Up @@ -215,15 +214,12 @@ impl ThreadpoolBuilder {
send.send(Ok(pool.client().clone())).ok();

// Store the current thread's driver so that spawned tasks can
// find it via `Thread::current()`.
CURRENT_THREADPOOL_CPU.with(|current| {
current.set(std::ptr::from_ref(&driver));
// find it via `Thread::current()`. Do this via a loan instead
// of storing it directly in TLS to avoid the overhead of
// registering a destructor.
CURRENT_THREAD_DRIVER.with(|current| {
current.lend(&driver, || pool.run());
});
pool.run();
CURRENT_THREADPOOL_CPU.with(|current| {
current.set(std::ptr::null());
});
drop(driver);
})?;

// Wait for the pool to be initialized.
Expand Down Expand Up @@ -360,33 +356,27 @@ impl Initiate for AffinitizedThreadpool {
/// The state for the thread pool thread for the currently running CPU.
#[derive(Debug, Copy, Clone)]
pub struct Thread {
driver: &'static ThreadpoolDriver,
_not_send_sync: PhantomData<*const ()>,
}

impl Thread {
/// Returns a new driver for the current CPU.
/// Returns an instance for the current CPU.
pub fn current() -> Option<Self> {
let inner = CURRENT_THREADPOOL_CPU.with(|current| {
let p = current.get();
// SAFETY: the `ThreadpoolDriver` is on the current thread's stack
// and so is guaranteed to be valid. And since `Thread` is not
// `Send` or `Sync`, this reference cannot be accessed after the
// driver has been dropped, since any task that can construct a
// `Thread` will have been completed by that time. So it's OK for
// this reference to live as long as `Thread`.
(!p.is_null()).then(|| unsafe { &*p })
})?;
if !CURRENT_THREAD_DRIVER.with(|current| current.is_lent()) {
return None;
}
Some(Self {
driver: inner,
_not_send_sync: PhantomData,
})
}

fn once(&self) -> &ThreadpoolDriverOnce {
// Since we are on the thread, the thread is guaranteed to have been
// initialized.
self.driver.inner.once.get().unwrap()
/// Calls `f` with the driver for the current thread.
pub fn with_driver<R>(&self, f: impl FnOnce(&ThreadpoolDriver) -> R) -> R {
CURRENT_THREAD_DRIVER.with(|current| current.borrow(|driver| f(driver.unwrap())))
}

fn with_once<R>(&self, f: impl FnOnce(&ThreadpoolDriver, &ThreadpoolDriverOnce) -> R) -> R {
self.with_driver(|driver| f(driver, driver.inner.once.get().unwrap()))
}

/// Sets the idle task to run. The task is returned by `f`, which receives
Expand All @@ -400,56 +390,52 @@ impl Thread {
F: 'static + Send + FnOnce(IdleControl) -> Fut,
Fut: std::future::Future<Output = ()>,
{
self.once().client.set_idle_task(f)
}

/// Returns the driver for the current thread.
pub fn driver(&self) -> &ThreadpoolDriver {
self.driver
self.with_once(|_, once| once.client.set_idle_task(f))
}

/// Tries to set the affinity to this thread's intended CPU, if it has not
/// already been set. Returns `Ok(false)` if the intended CPU is still
/// offline.
pub fn try_set_affinity(&self) -> Result<bool, SetAffinityError> {
let mut state = self.driver.inner.state.lock();
if matches!(state.affinity, AffinityState::Set) {
return Ok(true);
}
if !is_cpu_online(self.driver.inner.cpu).map_err(SetAffinityError::Online)? {
return Ok(false);
}
self.with_once(|driver, once| {
let mut state = driver.inner.state.lock();
if matches!(state.affinity, AffinityState::Set) {
return Ok(true);
}
if !is_cpu_online(driver.inner.cpu).map_err(SetAffinityError::Online)? {
return Ok(false);
}

let mut affinity = CpuSet::new();
affinity.set(self.driver.inner.cpu);

pal::unix::affinity::set_current_thread_affinity(&affinity)
.map_err(SetAffinityError::Thread)?;
self.once()
.client
.set_iowq_affinity(&affinity)
.map_err(SetAffinityError::Ring)?;

let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
self.driver.inner.affinity_set.store(true, Relaxed);
drop(state);

match old_affinity_state {
AffinityState::Waiting(wakers) => {
for waker in wakers {
waker.wake();
let mut affinity = CpuSet::new();
affinity.set(driver.inner.cpu);

pal::unix::affinity::set_current_thread_affinity(&affinity)
.map_err(SetAffinityError::Thread)?;
once.client
.set_iowq_affinity(&affinity)
.map_err(SetAffinityError::Ring)?;

let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
driver.inner.affinity_set.store(true, Relaxed);
drop(state);

match old_affinity_state {
AffinityState::Waiting(wakers) => {
for waker in wakers {
waker.wake();
}
}
AffinityState::Set => unreachable!(),
}
AffinityState::Set => unreachable!(),
}
Ok(true)
Ok(true)
})
}

/// Returns the that caused this thread to spawn.
///
/// Returns `None` if the thread was spawned to issue IO.
pub fn first_task(&self) -> Option<&TaskInfo> {
self.once().first_task.as_ref()
pub fn first_task(&self) -> Option<TaskInfo> {
self.with_once(|_, once| once.first_task.clone())
}
}

Expand All @@ -468,12 +454,12 @@ pub enum SetAffinityError {
}

thread_local! {
static CURRENT_THREADPOOL_CPU: Cell<*const ThreadpoolDriver> = const { Cell::new(std::ptr::null()) };
static CURRENT_THREAD_DRIVER: LoanCell<ThreadpoolDriver> = const { LoanCell::new() };
}

impl SpawnLocal for Thread {
fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
self.driver.scheduler(metadata).clone()
self.with_driver(|driver| driver.scheduler(metadata).clone())
}
}

Expand Down Expand Up @@ -506,7 +492,7 @@ struct ThreadpoolDriverOnce {
}

/// Information about a task that caused a thread to spawn.
#[derive(Debug, Inspect)]
#[derive(Debug, Clone, Inspect)]
pub struct TaskInfo {
/// The name of the task.
pub name: Arc<str>,
Expand Down
15 changes: 15 additions & 0 deletions support/loan_cell/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

[package]
name = "loan_cell"
rust-version.workspace = true
edition.workspace = true

[dependencies]

[dev-dependencies]
static_assertions.workspace = true

[lints]
workspace = true
Loading