Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add miri support to tests suite #6900

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tokio-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ futures-core = "0.3.0"
[dev-dependencies]
tokio = { version = "1.2.0", path = "../tokio", features = ["full"] }
futures-util = "0.3.0"
tokio-test-macros = { path = "./macros" }

[package.metadata.docs.rs]
all-features = true
11 changes: 11 additions & 0 deletions tokio-test/macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "tokio-test-macros"
version = "0.0.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = { version = "2", features = ["full"] }
quote2 = "0.9"
118 changes: 118 additions & 0 deletions tokio-test/macros/src/expend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use quote2::{proc_macro2::TokenStream, quote, utils::quote_rep, Quote, ToTokens};
use syn::{
parse::{Parse, ParseStream, Parser},
punctuated::Punctuated,
Attribute, Meta, MetaNameValue, Signature, Token, Visibility,
};

type AttributeArgs = Punctuated<Meta, Token![,]>;

pub struct ItemFn {
pub attrs: Vec<Attribute>,
pub vis: Visibility,
pub sig: Signature,
pub body: TokenStream,
}

impl Parse for ItemFn {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
Ok(Self {
attrs: input.call(Attribute::parse_outer)?,
vis: input.parse()?,
sig: input.parse()?,
body: input.parse()?,
})
}
}

pub fn tokio_test(args: TokenStream, item_fn: ItemFn) -> TokenStream {
let metadata = match AttributeArgs::parse_terminated.parse2(args) {
Ok(args) => args,
Err(err) => return err.into_compile_error(),
};

let has_miri_cfg = metadata.iter().any(|meta| meta.path().is_ident("miri"));
let id_multi_thread = metadata.iter().any(|meta| match meta {
Meta::NameValue(meta) if meta.path.is_ident("flavor") => {
match meta.value.to_token_stream().to_string().as_str() {
"multi_thread" => true,
"current_thread" => false,
val => panic!("unknown `flavor = {val}`, expected: multi_thread | current_thread"),
}
}
_ => false,
});
let config = quote_rep(metadata, |t, meta| {
for key in ["miri", "flavor"] {
if meta.path().is_ident(key) {
return;
}
}
if let Meta::NameValue(MetaNameValue { path, value, .. }) = &meta {
for key in ["worker_threads", "start_paused"] {
if path.is_ident(key) {
quote!(t, { .#path(#value) });
return;
}
}
}
panic!("unknown config `{}`", meta.path().to_token_stream())
});
let runtime_type = quote(|t| {
if id_multi_thread {
quote!(t, { new_multi_thread });
} else {
quote!(t, { new_current_thread });
}
});
let ignore_miri = quote(|t| {
if !has_miri_cfg {
quote!(t, { #[cfg_attr(miri, ignore)] });
}
});
let miri_test_executor = quote(|t| {
if has_miri_cfg {
quote!(t, {
if cfg!(miri) {
return tokio_test::task::spawn(body).block_on();
}
});
}
});

let ItemFn {
attrs,
vis,
mut sig,
body,
} = item_fn;

let async_keyword = sig.asyncness.take();
let attrs = quote_rep(attrs, |t, attr| {
quote!(t, { #attr });
});

let mut out = TokenStream::new();
quote!(out, {
#attrs
#ignore_miri
#[::core::prelude::v1::test]
#vis #sig {
let body = #async_keyword #body;
let body= ::std::pin::pin!(body);

#miri_test_executor

#[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
{
return tokio::runtime::Builder::#runtime_type()
#config
.enable_all()
.build()
.expect("Failed building the Runtime")
.block_on(body);
}
}
});
out
}
7 changes: 7 additions & 0 deletions tokio-test/macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod expend;
use proc_macro::TokenStream;

#[proc_macro_attribute]
pub fn tokio_test(args: TokenStream, item: TokenStream) -> TokenStream {
expend::tokio_test(args.into(), syn::parse_macro_input!(item)).into()
}
81 changes: 33 additions & 48 deletions tokio-test/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
//! ```

use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::task::{Context, Poll, Wake, Waker};

use tokio_stream::Stream;

Expand Down Expand Up @@ -123,6 +122,34 @@ impl<T: Future> Spawn<T> {
let fut = self.future.as_mut();
self.task.enter(|cx| fut.poll(cx))
}

/// Run a future to completion on the current thread.
///
/// This function will block the caller until the given future has completed.
///
/// Note: This does not create a Tokio runtime, and therefore does not support
/// Tokio-specific asynchronous APIs, such as [tokio::time::sleep].
pub fn block_on(&mut self) -> T::Output {
loop {
match self.poll() {
Poll::Ready(val) => return val,
Poll::Pending => {
let mut guard = self.task.waker.state.lock().unwrap();
let state = *guard;

if state == WAKE {
continue;
}

assert_eq!(state, IDLE);
*guard = SLEEP;
let guard = self.task.waker.condvar.wait(guard).unwrap();
assert_eq!(*guard, WAKE);
drop(guard);
}
};
}
}
}

impl<T: Stream> Spawn<T> {
Expand Down Expand Up @@ -171,9 +198,8 @@ impl MockTask {
F: FnOnce(&mut Context<'_>) -> R,
{
self.waker.clear();
let waker = self.waker();
let waker = Waker::from(self.waker.clone());
let mut cx = Context::from_waker(&waker);

f(&mut cx)
}

Expand All @@ -189,13 +215,6 @@ impl MockTask {
fn waker_ref_count(&self) -> usize {
Arc::strong_count(&self.waker)
}

fn waker(&self) -> Waker {
unsafe {
let raw = to_raw(self.waker.clone());
Waker::from_raw(raw)
}
}
}

impl Default for MockTask {
Expand Down Expand Up @@ -226,8 +245,10 @@ impl ThreadWaker {
_ => unreachable!(),
}
}
}

fn wake(&self) {
impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
let mut state = self.state.lock().unwrap();
let prev = *state;
Expand All @@ -247,39 +268,3 @@ impl ThreadWaker {
self.condvar.notify_one();
}
}

static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);

unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
}

unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
Arc::from_raw(raw as *const ThreadWaker)
}

unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);

// Increment the ref count
mem::forget(waker.clone());

to_raw(waker)
}

unsafe fn wake(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();
}

unsafe fn wake_by_ref(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();

// We don't actually own a reference to the unparker
mem::forget(waker);
}

unsafe fn drop_waker(raw: *const ()) {
let _ = from_raw(raw);
}
3 changes: 3 additions & 0 deletions tokio-test/tests/block_on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use tokio::time::{sleep_until, Duration, Instant};
use tokio_test::block_on;

#[test]
#[cfg_attr(miri, ignore)]
fn async_block() {
assert_eq!(4, block_on(async { 4 }));
}
Expand All @@ -13,11 +14,13 @@ async fn five() -> u8 {
}

#[test]
#[cfg_attr(miri, ignore)]
fn async_fn() {
assert_eq!(5, block_on(five()));
}

#[test]
#[cfg_attr(miri, ignore)]
fn test_sleep() {
let deadline = Instant::now() + Duration::from_millis(100);

Expand Down
22 changes: 14 additions & 8 deletions tokio-test/tests/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, Instant};
use tokio_test::io::Builder;

#[tokio::test]
mod tokio {
pub use ::tokio::*;
pub use ::tokio_test_macros::tokio_test as test;
}

#[tokio::test(miri)]
async fn read() {
let mut mock = Builder::new().read(b"hello ").read(b"world!").build();

Expand All @@ -18,7 +23,7 @@ async fn read() {
assert_eq!(&buf[..n], b"world!");
}

#[tokio::test]
#[tokio::test(miri)]
async fn read_error() {
let error = io::Error::new(io::ErrorKind::Other, "cruel");
let mut mock = Builder::new()
Expand All @@ -43,15 +48,15 @@ async fn read_error() {
assert_eq!(&buf[..n], b"world!");
}

#[tokio::test]
#[tokio::test(miri)]
async fn write() {
let mut mock = Builder::new().write(b"hello ").write(b"world!").build();

mock.write_all(b"hello ").await.expect("write 1");
mock.write_all(b"world!").await.expect("write 2");
}

#[tokio::test]
#[tokio::test(miri)]
async fn write_with_handle() {
let (mut mock, mut handle) = Builder::new().build_with_handle();
handle.write(b"hello ");
Expand All @@ -61,7 +66,7 @@ async fn write_with_handle() {
mock.write_all(b"world!").await.expect("write 2");
}

#[tokio::test]
#[tokio::test(miri)]
async fn read_with_handle() {
let (mut mock, mut handle) = Builder::new().build_with_handle();
handle.read(b"hello ");
Expand All @@ -74,14 +79,15 @@ async fn read_with_handle() {
assert_eq!(&buf[..], b"world!");
}

#[tokio::test]
#[tokio::test(miri)]
async fn write_error() {
let error = io::Error::new(io::ErrorKind::Other, "cruel");
let mut mock = Builder::new()
.write(b"hello ")
.write_error(error)
.write(b"world!")
.build();

mock.write_all(b"hello ").await.expect("write 1");

match mock.write_all(b"whoa").await {
Expand All @@ -95,14 +101,14 @@ async fn write_error() {
mock.write_all(b"world!").await.expect("write 2");
}

#[tokio::test]
#[tokio::test(miri)]
#[should_panic]
async fn mock_panics_read_data_left() {
use tokio_test::io::Builder;
Builder::new().read(b"read").build();
}

#[tokio::test]
#[tokio::test(miri)]
#[should_panic]
async fn mock_panics_write_data_left() {
use tokio_test::io::Builder;
Expand Down
Loading
Loading