Skip to content

Commit c9b59a1

Browse files
authored
loan_cell: add primitive for lending thread-local data (#743)
Add a safe abstraction for temporarily lending on-stack data into thread-local storage. Use it in various places across the stack. This fixes a use-after-free in `pal_async`, and it reduces the overhead of TLS in `pal_async` and `underhill_threadpool`.
1 parent 1b9ba4e commit c9b59a1

File tree

14 files changed

+374
-204
lines changed

14 files changed

+374
-204
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,6 +3567,13 @@ dependencies = [
35673567
"zerocopy",
35683568
]
35693569

3570+
[[package]]
3571+
name = "loan_cell"
3572+
version = "0.0.0"
3573+
dependencies = [
3574+
"static_assertions",
3575+
]
3576+
35703577
[[package]]
35713578
name = "local_clock"
35723579
version = "0.0.0"
@@ -4876,6 +4883,7 @@ dependencies = [
48764883
"futures",
48774884
"getrandom",
48784885
"libc",
4886+
"loan_cell",
48794887
"once_cell",
48804888
"pal",
48814889
"pal_async_test",
@@ -4919,6 +4927,7 @@ dependencies = [
49194927
"inspect",
49204928
"io-uring",
49214929
"libc",
4930+
"loan_cell",
49224931
"once_cell",
49234932
"pal",
49244933
"pal_async",
@@ -7204,6 +7213,7 @@ version = "0.0.0"
72047213
dependencies = [
72057214
"fs-err",
72067215
"inspect",
7216+
"loan_cell",
72077217
"pal",
72087218
"pal_async",
72097219
"pal_uring",

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ inspect_proto = { path = "support/inspect_proto" }
103103
inspect_rlimit = { path = "support/inspect_rlimit" }
104104
inspect_task = { path = "support/inspect_task" }
105105
kmsg = { path = "support/kmsg" }
106+
loan_cell = { path = "support/loan_cell" }
106107
local_clock = { path = "support/local_clock" }
107108
mesh = { path = "support/mesh" }
108109
mesh_build = { path = "support/mesh/mesh_build" }

openhcl/underhill_core/src/vp.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ impl VpSpawner {
8484
let thread = underhill_threadpool::Thread::current().unwrap();
8585
// TODO propagate this error back earlier. This is easiest if
8686
// set_idle_task is fixed to take a non-Send fn.
87-
let mut vp = self
88-
.vp
89-
.bind_processor::<T>(thread.driver(), control)
90-
.context("failed to initialize VP")?;
87+
let mut vp = thread.with_driver(|driver| {
88+
self.vp
89+
.bind_processor::<T>(driver, control)
90+
.context("failed to initialize VP")
91+
})?;
9192

9293
if let Some(saved_state) = saved_state {
9394
vmcore::save_restore::ProtobufSaveRestore::restore(&mut vp, saved_state)
@@ -166,7 +167,7 @@ impl VpSpawner {
166167
self.vp.set_sidecar_exit_due_to_task(
167168
thread
168169
.first_task()
169-
.map_or_else(|| "<unknown>".into(), |t| t.name.clone()),
170+
.map_or_else(|| "<unknown>".into(), |t| t.name),
170171
);
171172
}
172173

openhcl/underhill_mem/src/init.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,7 @@ async fn apply_vtl2_protections(
469469
tracing::debug!(
470470
cpu = underhill_threadpool::Thread::current()
471471
.unwrap()
472-
.driver()
473-
.target_cpu(),
472+
.with_driver(|driver| driver.target_cpu()),
474473
%range,
475474
"applying protections"
476475
);

openhcl/underhill_threadpool/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ rust-version.workspace = true
88

99
[target.'cfg(target_os = "linux")'.dependencies]
1010
inspect = { workspace = true, features = ["std"] }
11+
loan_cell.workspace = true
1112
pal.workspace = true
1213
pal_async.workspace = true
1314
pal_uring.workspace = true

openhcl/underhill_threadpool/src/lib.rs

Lines changed: 53 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
//! This is built on top of [`pal_uring`] and [`pal_async`].
99
1010
#![warn(missing_docs)]
11-
// UNSAFETY: needed for saving per-thread state.
12-
#![expect(unsafe_code)]
11+
#![forbid(unsafe_code)]
1312

1413
use inspect::Inspect;
14+
use loan_cell::LoanCell;
1515
use pal::unix::affinity::CpuSet;
1616
use pal_async::fd::FdReadyDriver;
1717
use pal_async::task::Runnable;
@@ -30,7 +30,6 @@ use pal_uring::IoUringPool;
3030
use pal_uring::PoolClient;
3131
use pal_uring::Timer;
3232
use parking_lot::Mutex;
33-
use std::cell::Cell;
3433
use std::future::poll_fn;
3534
use std::io;
3635
use std::marker::PhantomData;
@@ -215,15 +214,12 @@ impl ThreadpoolBuilder {
215214
send.send(Ok(pool.client().clone())).ok();
216215

217216
// Store the current thread's driver so that spawned tasks can
218-
// find it via `Thread::current()`.
219-
CURRENT_THREADPOOL_CPU.with(|current| {
220-
current.set(std::ptr::from_ref(&driver));
217+
// find it via `Thread::current()`. Do this via a loan instead
218+
// of storing it directly in TLS to avoid the overhead of
219+
// registering a destructor.
220+
CURRENT_THREAD_DRIVER.with(|current| {
221+
current.lend(&driver, || pool.run());
221222
});
222-
pool.run();
223-
CURRENT_THREADPOOL_CPU.with(|current| {
224-
current.set(std::ptr::null());
225-
});
226-
drop(driver);
227223
})?;
228224

229225
// Wait for the pool to be initialized.
@@ -360,33 +356,27 @@ impl Initiate for AffinitizedThreadpool {
360356
/// The state for the thread pool thread for the currently running CPU.
361357
#[derive(Debug, Copy, Clone)]
362358
pub struct Thread {
363-
driver: &'static ThreadpoolDriver,
364359
_not_send_sync: PhantomData<*const ()>,
365360
}
366361

367362
impl Thread {
368-
/// Returns a new driver for the current CPU.
363+
/// Returns an instance for the current CPU.
369364
pub fn current() -> Option<Self> {
370-
let inner = CURRENT_THREADPOOL_CPU.with(|current| {
371-
let p = current.get();
372-
// SAFETY: the `ThreadpoolDriver` is on the current thread's stack
373-
// and so is guaranteed to be valid. And since `Thread` is not
374-
// `Send` or `Sync`, this reference cannot be accessed after the
375-
// driver has been dropped, since any task that can construct a
376-
// `Thread` will have been completed by that time. So it's OK for
377-
// this reference to live as long as `Thread`.
378-
(!p.is_null()).then(|| unsafe { &*p })
379-
})?;
365+
if !CURRENT_THREAD_DRIVER.with(|current| current.is_lent()) {
366+
return None;
367+
}
380368
Some(Self {
381-
driver: inner,
382369
_not_send_sync: PhantomData,
383370
})
384371
}
385372

386-
fn once(&self) -> &ThreadpoolDriverOnce {
387-
// Since we are on the thread, the thread is guaranteed to have been
388-
// initialized.
389-
self.driver.inner.once.get().unwrap()
373+
/// Calls `f` with the driver for the current thread.
374+
pub fn with_driver<R>(&self, f: impl FnOnce(&ThreadpoolDriver) -> R) -> R {
375+
CURRENT_THREAD_DRIVER.with(|current| current.borrow(|driver| f(driver.unwrap())))
376+
}
377+
378+
fn with_once<R>(&self, f: impl FnOnce(&ThreadpoolDriver, &ThreadpoolDriverOnce) -> R) -> R {
379+
self.with_driver(|driver| f(driver, driver.inner.once.get().unwrap()))
390380
}
391381

392382
/// Sets the idle task to run. The task is returned by `f`, which receives
@@ -400,56 +390,52 @@ impl Thread {
400390
F: 'static + Send + FnOnce(IdleControl) -> Fut,
401391
Fut: std::future::Future<Output = ()>,
402392
{
403-
self.once().client.set_idle_task(f)
404-
}
405-
406-
/// Returns the driver for the current thread.
407-
pub fn driver(&self) -> &ThreadpoolDriver {
408-
self.driver
393+
self.with_once(|_, once| once.client.set_idle_task(f))
409394
}
410395

411396
/// Tries to set the affinity to this thread's intended CPU, if it has not
412397
/// already been set. Returns `Ok(false)` if the intended CPU is still
413398
/// offline.
414399
pub fn try_set_affinity(&self) -> Result<bool, SetAffinityError> {
415-
let mut state = self.driver.inner.state.lock();
416-
if matches!(state.affinity, AffinityState::Set) {
417-
return Ok(true);
418-
}
419-
if !is_cpu_online(self.driver.inner.cpu).map_err(SetAffinityError::Online)? {
420-
return Ok(false);
421-
}
400+
self.with_once(|driver, once| {
401+
let mut state = driver.inner.state.lock();
402+
if matches!(state.affinity, AffinityState::Set) {
403+
return Ok(true);
404+
}
405+
if !is_cpu_online(driver.inner.cpu).map_err(SetAffinityError::Online)? {
406+
return Ok(false);
407+
}
422408

423-
let mut affinity = CpuSet::new();
424-
affinity.set(self.driver.inner.cpu);
425-
426-
pal::unix::affinity::set_current_thread_affinity(&affinity)
427-
.map_err(SetAffinityError::Thread)?;
428-
self.once()
429-
.client
430-
.set_iowq_affinity(&affinity)
431-
.map_err(SetAffinityError::Ring)?;
432-
433-
let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
434-
self.driver.inner.affinity_set.store(true, Relaxed);
435-
drop(state);
436-
437-
match old_affinity_state {
438-
AffinityState::Waiting(wakers) => {
439-
for waker in wakers {
440-
waker.wake();
409+
let mut affinity = CpuSet::new();
410+
affinity.set(driver.inner.cpu);
411+
412+
pal::unix::affinity::set_current_thread_affinity(&affinity)
413+
.map_err(SetAffinityError::Thread)?;
414+
once.client
415+
.set_iowq_affinity(&affinity)
416+
.map_err(SetAffinityError::Ring)?;
417+
418+
let old_affinity_state = std::mem::replace(&mut state.affinity, AffinityState::Set);
419+
driver.inner.affinity_set.store(true, Relaxed);
420+
drop(state);
421+
422+
match old_affinity_state {
423+
AffinityState::Waiting(wakers) => {
424+
for waker in wakers {
425+
waker.wake();
426+
}
441427
}
428+
AffinityState::Set => unreachable!(),
442429
}
443-
AffinityState::Set => unreachable!(),
444-
}
445-
Ok(true)
430+
Ok(true)
431+
})
446432
}
447433

448434
/// Returns the task that caused this thread to spawn.
449435
///
450436
/// Returns `None` if the thread was spawned to issue IO.
451-
pub fn first_task(&self) -> Option<&TaskInfo> {
452-
self.once().first_task.as_ref()
437+
pub fn first_task(&self) -> Option<TaskInfo> {
438+
self.with_once(|_, once| once.first_task.clone())
453439
}
454440
}
455441

@@ -468,12 +454,12 @@ pub enum SetAffinityError {
468454
}
469455

470456
thread_local! {
471-
static CURRENT_THREADPOOL_CPU: Cell<*const ThreadpoolDriver> = const { Cell::new(std::ptr::null()) };
457+
static CURRENT_THREAD_DRIVER: LoanCell<ThreadpoolDriver> = const { LoanCell::new() };
472458
}
473459

474460
impl SpawnLocal for Thread {
475461
fn scheduler_local(&self, metadata: &TaskMetadata) -> Arc<dyn Schedule> {
476-
self.driver.scheduler(metadata).clone()
462+
self.with_driver(|driver| driver.scheduler(metadata).clone())
477463
}
478464
}
479465

@@ -506,7 +492,7 @@ struct ThreadpoolDriverOnce {
506492
}
507493

508494
/// Information about a task that caused a thread to spawn.
509-
#[derive(Debug, Inspect)]
495+
#[derive(Debug, Clone, Inspect)]
510496
pub struct TaskInfo {
511497
/// The name of the task.
512498
pub name: Arc<str>,

support/loan_cell/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
[package]
5+
name = "loan_cell"
6+
rust-version.workspace = true
7+
edition.workspace = true
8+
9+
[dependencies]
10+
11+
[dev-dependencies]
12+
static_assertions.workspace = true
13+
14+
[lints]
15+
workspace = true

0 commit comments

Comments
 (0)