Skip to content

Commit

Permalink
Merge branch 'master' into fix-sharded-timer
Browse files Browse the repository at this point in the history
  • Loading branch information
wathenjiang authored Aug 5, 2024
2 parents 36b8c07 + ab53bf0 commit 955ec31
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 16 deletions.
52 changes: 52 additions & 0 deletions tokio/src/future/maybe_done.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pin_project! {
#[derive(Debug)]
#[project = MaybeDoneProj]
#[project_replace = MaybeDoneProjReplace]
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
pub enum MaybeDone<Fut: Future> {
/// A not-yet-completed future.
Future { #[pin] future: Fut },
Expand Down Expand Up @@ -69,3 +70,54 @@ impl<Fut: Future> Future for MaybeDone<Fut> {
Poll::Ready(())
}
}

// Test for https://github.com/tokio-rs/tokio/issues/6729
#[cfg(test)]
mod miri_tests {
use super::maybe_done;

use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll, Wake},
};

struct ThingAdder<'a> {
thing: &'a mut String,
}

impl Future for ThingAdder<'_> {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
unsafe {
*self.get_unchecked_mut().thing += ", world";
}
Poll::Pending
}
}

#[test]
fn maybe_done_miri() {
let mut thing = "hello".to_owned();

// The async block is necessary to trigger the miri failure.
#[allow(clippy::redundant_async_block)]
let fut = async move { ThingAdder { thing: &mut thing }.await };

let mut fut = maybe_done(fut);
let mut fut = unsafe { Pin::new_unchecked(&mut fut) };

let waker = Arc::new(DummyWaker).into();
let mut ctx = Context::from_waker(&waker);
assert_eq!(fut.as_mut().poll(&mut ctx), Poll::Pending);
assert_eq!(fut.as_mut().poll(&mut ctx), Poll::Pending);
}

struct DummyWaker;

impl Wake for DummyWaker {
fn wake(self: Arc<Self>) {}
}
}
12 changes: 10 additions & 2 deletions tokio/src/io/util/write_all_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::io::AsyncWrite;
use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::{self, IoSlice};
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -42,9 +42,17 @@ where
type Output = io::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
const MAX_VECTOR_ELEMENTS: usize = 64;

let me = self.project();
while me.buf.has_remaining() {
let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?);
let n = if me.writer.is_write_vectored() {
let mut slices = [IoSlice::new(&[]); MAX_VECTOR_ELEMENTS];
let cnt = me.buf.chunks_vectored(&mut slices);
ready!(Pin::new(&mut *me.writer).poll_write_vectored(cx, &slices[..cnt]))?
} else {
ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?)
};
me.buf.advance(n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
Expand Down
6 changes: 6 additions & 0 deletions tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@
//! rustflags = ["--cfg", "tokio_unstable"]
//! ```
//!
//! <div class="warning">
//! The <code>[build]</code> section does <strong>not</strong> go in a
//! <code>Cargo.toml</code> file. Instead it must be placed in the Cargo config
//! file <code>.cargo/config.toml</code>.
//! </div>
//!
//! Alternatively, you can specify it with an environment variable:
//!
//! ```sh
Expand Down
1 change: 1 addition & 0 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ generate_addr_of_methods! {
}

/// Either the future or the output.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
pub(super) enum Stage<T: Future> {
Running(T),
Finished(super::Result<T::Output>),
Expand Down
23 changes: 12 additions & 11 deletions tokio/src/runtime/task/id.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::runtime::context;

use std::fmt;
use std::{fmt, num::NonZeroU64};

/// An opaque ID that uniquely identifies a task relative to all other currently
/// running tasks.
Expand All @@ -24,7 +24,7 @@ use std::fmt;
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub struct Id(pub(crate) u64);
pub struct Id(pub(crate) NonZeroU64);

/// Returns the [`Id`] of the currently running task.
///
Expand Down Expand Up @@ -78,21 +78,22 @@ impl Id {
use crate::loom::sync::atomic::StaticAtomicU64;

#[cfg(all(test, loom))]
{
crate::loom::lazy_static! {
static ref NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);
}
Self(NEXT_ID.fetch_add(1, Relaxed))
crate::loom::lazy_static! {
static ref NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);
}

#[cfg(not(all(test, loom)))]
{
static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);
Self(NEXT_ID.fetch_add(1, Relaxed))
static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);

loop {
let id = NEXT_ID.fetch_add(1, Relaxed);
if let Some(id) = NonZeroU64::new(id) {
return Self(id);
}
}
}

pub(crate) fn as_u64(&self) -> u64 {
self.0
self.0.get()
}
}
2 changes: 1 addition & 1 deletion tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,6 @@ unsafe impl<S> sharded_list::ShardedListItem for Task<S> {
unsafe fn get_shard_id(target: NonNull<Self::Target>) -> usize {
// SAFETY: The caller guarantees that `target` points at a valid task.
let task_id = unsafe { Header::get_id(target) };
task_id.0 as usize
task_id.0.get() as usize
}
}
94 changes: 94 additions & 0 deletions tokio/src/runtime/tests/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,100 @@ fn shutdown_immediately() {
})
}

// Test for https://github.com/tokio-rs/tokio/issues/6729
#[test]
fn spawn_niche_in_task() {
use crate::future::poll_fn;
use std::task::{Context, Poll, Waker};

with(|rt| {
let state = Arc::new(Mutex::new(State::new()));

let mut subscriber = Subscriber::new(Arc::clone(&state), 1);
rt.spawn(async move {
subscriber.wait().await;
subscriber.wait().await;
});

rt.spawn(async move {
state.lock().unwrap().set_version(2);
state.lock().unwrap().set_version(0);
});

rt.tick_max(10);
assert!(rt.is_empty());
rt.shutdown();
});

pub(crate) struct Subscriber {
state: Arc<Mutex<State>>,
observed_version: u64,
waker_key: Option<usize>,
}

impl Subscriber {
pub(crate) fn new(state: Arc<Mutex<State>>, version: u64) -> Self {
Self {
state,
observed_version: version,
waker_key: None,
}
}

pub(crate) async fn wait(&mut self) {
poll_fn(|cx| {
self.state
.lock()
.unwrap()
.poll_update(&mut self.observed_version, &mut self.waker_key, cx)
.map(|_| ())
})
.await;
}
}

struct State {
version: u64,
wakers: Vec<Waker>,
}

impl State {
pub(crate) fn new() -> Self {
Self {
version: 1,
wakers: Vec::new(),
}
}

pub(crate) fn poll_update(
&mut self,
observed_version: &mut u64,
waker_key: &mut Option<usize>,
cx: &Context<'_>,
) -> Poll<Option<()>> {
if self.version == 0 {
*waker_key = None;
Poll::Ready(None)
} else if *observed_version < self.version {
*waker_key = None;
*observed_version = self.version;
Poll::Ready(Some(()))
} else {
self.wakers.push(cx.waker().clone());
*waker_key = Some(self.wakers.len());
Poll::Pending
}
}

pub(crate) fn set_version(&mut self, version: u64) {
self.version = version;
for waker in self.wakers.drain(..) {
waker.wake();
}
}
}
}

#[test]
fn spawn_during_shutdown() {
static DID_SPAWN: AtomicBool = AtomicBool::new(false);
Expand Down
1 change: 1 addition & 0 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl Semaphore {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let resource_span = {
let resource_span = tracing::trace_span!(
parent: None,
"runtime.resource",
concrete_type = "Semaphore",
kind = "Sync",
Expand Down
1 change: 1 addition & 0 deletions tokio/src/time/sleep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ impl Sleep {

let location = location.expect("should have location if tracing");
let resource_span = tracing::trace_span!(
parent: None,
"runtime.resource",
concrete_type = "Sleep",
kind = "timer",
Expand Down
49 changes: 49 additions & 0 deletions tokio/tests/io_write_all_buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,52 @@ async fn write_buf_err() {
Bytes::from_static(b"oworld")
);
}

#[tokio::test]
async fn write_all_buf_vectored() {
struct Wr {
buf: BytesMut,
}
impl AsyncWrite for Wr {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
// When executing `write_all_buf` with this writer,
// `poll_write` is not called.
panic!("shouldn't be called")
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
for buf in bufs {
self.buf.extend_from_slice(buf);
}
let n = self.buf.len();
Ok(n).into()
}
fn is_write_vectored(&self) -> bool {
// Enable vectored write.
true
}
}

let mut wr = Wr {
buf: BytesMut::with_capacity(64),
};
let mut buf = Bytes::from_static(b"hello")
.chain(Bytes::from_static(b" "))
.chain(Bytes::from_static(b"world"));

wr.write_all_buf(&mut buf).await.unwrap();
assert_eq!(&wr.buf[..], b"hello world");
}
2 changes: 1 addition & 1 deletion tokio/tests/macros_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ fn join_size() {
let ready2 = future::ready(0i32);
tokio::join!(ready1, ready2)
};
assert_eq!(mem::size_of_val(&fut), 40);
assert_eq!(mem::size_of_val(&fut), 48);
}

async fn non_cooperative_task(permits: Arc<Semaphore>) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion tokio/tests/tracing-instrumentation/tests/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn test_mutex_creates_span() {
.new_span(mutex_span.clone().with_explicit_parent(None))
.enter(mutex_span.clone())
.event(locked_event)
.new_span(batch_semaphore_span.clone())
.new_span(batch_semaphore_span.clone().with_explicit_parent(None))
.enter(batch_semaphore_span.clone())
.event(batch_semaphore_permits_event)
.exit(batch_semaphore_span.clone())
Expand Down
Loading

0 comments on commit 955ec31

Please sign in to comment.