diff --git a/Cargo.toml b/Cargo.toml index b02b6aa0..7b67b333 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ xmltree = "0.11" http-body-util = "0.1" hyper = { package = "hyper", version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio"] } +paste = "1.0.15" simplelog = "0.9" test-log = "0.2" tokio = {version = "1", features = ["full"]} diff --git a/src/common/tests.rs b/src/common/tests.rs index 0f9f4eca..c9af0e5c 100644 --- a/src/common/tests.rs +++ b/src/common/tests.rs @@ -2,6 +2,7 @@ use crate::{search_gateway, SearchError, SearchOptions}; use http_body_util::Full; use hyper::{body::Bytes, service::service_fn, Request, Response}; use hyper_util::rt::TokioIo; +use paste::paste; use std::{ convert::Infallible, future::Future, @@ -143,15 +144,49 @@ async fn start_http_server(responses: Vec) -> u16 { local_port } -#[test(tokio::test)] -async fn ip_spoofing_in_broadcast_response() { - async fn aux(search_gateway: F) - where - Fut: Future>, - F: Fn(SearchOptions) -> Fut, - { - let local_free_port = start_broadcast_reply_sender("http://1.2.3.4:5".to_owned()).await; +/// Runs given test body with both the blocking and async implementation +macro_rules! run_tests { + ($(fn $name:ident($sg:ident) { $($test_body:tt)* })*) => { + $( + paste! { + #[test(tokio::test)] + async fn []() { + async fn aux($sg: F) + where + Fut: Future>, + F: Fn(SearchOptions) -> Fut, + { + $($test_body)+ + } + aux(|opt| async { + tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ())) + .await + .unwrap() + }) + .await; + } + + #[cfg(feature = "aio")] + #[test(tokio::test)] + async fn []() { + async fn aux($sg: F) + where + Fut: Future>, + F: Fn(SearchOptions) -> Fut, + { + $($test_body)+ + } + + aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await; + } + } + )* + }; +} +run_tests! { + fn ip_spoofing_in_broadcast_response(search_gateway) { + let local_free_port = start_broadcast_reply_sender("http://1.2.3.4:5".to_owned()).await; let options = default_options_with_using_free_port(local_free_port); let result = search_gateway(options).await; @@ -163,29 +198,10 @@ async fn ip_spoofing_in_broadcast_response() { } } - aux(|opt| async { - tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ())) - .await - .unwrap() - }) - .await; - #[cfg(feature = "aio")] - aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await; -} - -#[test(tokio::test)] -async fn ip_spoofing_in_getxml_body() { - async fn aux(search_gateway: F) - where - Fut: Future>, - F: Fn(SearchOptions) -> Fut, - { + fn ip_spoofing_in_getxml_body(search_gateway) { let http_port = start_http_server(vec![RESP_SPOOFED_SCPDURL.to_owned()]).await; - let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - println!("http server port: {http_port}, udp port: {local_free_port}"); - let options = default_options_with_using_free_port(local_free_port); let result = search_gateway(options).await; @@ -196,23 +212,8 @@ async fn ip_spoofing_in_getxml_body() { panic!("Unexpected result: {result:?}"); } } - aux(|opt| async { - tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ())) - .await - .unwrap() - }) - .await; - #[cfg(feature = "aio")] - aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await; -} -#[test(tokio::test)] -async fn ip_spoofing_in_getxml_body_control_url() { - async fn aux(search_gateway: F) - where - Fut: Future>, - F: Fn(SearchOptions) -> Fut, - { + fn ip_spoofing_in_getxml_body_control_url(search_gateway) { let http_port = start_http_server(vec![ RESP_SPOOFED_CONTROL_URL.to_owned(), RESP_CONTROL_SCHEMA.to_owned(), @@ -220,9 +221,7 @@ async fn ip_spoofing_in_getxml_body_control_url() { .await; let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - let options = default_options_with_using_free_port(local_free_port); - let result = search_gateway(options).await; if let Err(SearchError::SpoofedUrl { src_ip, url_host }) = result { @@ -232,37 +231,11 @@ async fn ip_spoofing_in_getxml_body_control_url() { panic!("Unexpected result: {result:?}"); } } - aux(|opt| async { - tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ())) - .await - .unwrap() - }) - .await; - #[cfg(feature = "aio")] - aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await; -} -#[test(tokio::test)] -async fn non_spoofed_urls_result_in_search_gateway_success() { - async fn aux(search_gateway: F) - where - Fut: Future>, - F: Fn(SearchOptions) -> Fut, - { + fn non_spoofed_urls_result_in_search_gateway_success(search_gateway) { let http_port = start_http_server(vec![RESP.to_owned(), RESP_CONTROL_SCHEMA.to_owned()]).await; - let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - let options = default_options_with_using_free_port(local_free_port); - assert!(search_gateway(options).await.is_ok()); } - aux(|opt| async { - tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ())) - .await - .unwrap() - }) - .await; - #[cfg(feature = "aio")] - aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await; }