Skip to content

Commit

Permalink
Extract sync/async test infrastructure into a macro
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaszklak committed Dec 10, 2024
1 parent 3d0c5c7 commit 1bc1638
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 73 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ xmltree = "0.11"
http-body-util = "0.1"
hyper = { package = "hyper", version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio"] }
paste = "1.0.15"
simplelog = "0.9"
test-log = "0.2"
tokio = {version = "1", features = ["full"]}
Expand Down
119 changes: 46 additions & 73 deletions src/common/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{search_gateway, SearchError, SearchOptions};
use http_body_util::Full;
use hyper::{body::Bytes, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use paste::paste;
use std::{
convert::Infallible,
future::Future,
Expand Down Expand Up @@ -143,15 +144,49 @@ async fn start_http_server(responses: Vec<String>) -> u16 {
local_port
}

#[test(tokio::test)]
async fn ip_spoofing_in_broadcast_response() {
async fn aux<F, Fut>(search_gateway: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
let local_free_port = start_broadcast_reply_sender("http://1.2.3.4:5".to_owned()).await;
/// Runs given test body with both the blocking and async implementation
macro_rules! run_tests {
($(fn $name:ident($sg:ident) { $($test_body:tt)* })*) => {
$(
paste! {
#[test(tokio::test)]
async fn [<blocking_ $name>]() {
async fn aux<F, Fut>($sg: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
$($test_body)+
}
aux(|opt| async {
tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ()))
.await
.unwrap()
})
.await;
}

#[cfg(feature = "aio")]
#[test(tokio::test)]
async fn [<async_ $name>]() {
async fn aux<F, Fut>($sg: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
$($test_body)+
}

aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await;
}
}
)*
};
}

run_tests! {
fn ip_spoofing_in_broadcast_response(search_gateway) {
let local_free_port = start_broadcast_reply_sender("http://1.2.3.4:5".to_owned()).await;
let options = default_options_with_using_free_port(local_free_port);

let result = search_gateway(options).await;
Expand All @@ -163,29 +198,10 @@ async fn ip_spoofing_in_broadcast_response() {
}
}

aux(|opt| async {
tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ()))
.await
.unwrap()
})
.await;
#[cfg(feature = "aio")]
aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await;
}

#[test(tokio::test)]
async fn ip_spoofing_in_getxml_body() {
async fn aux<F, Fut>(search_gateway: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
fn ip_spoofing_in_getxml_body(search_gateway) {
let http_port = start_http_server(vec![RESP_SPOOFED_SCPDURL.to_owned()]).await;

let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await;

println!("http server port: {http_port}, udp port: {local_free_port}");

let options = default_options_with_using_free_port(local_free_port);

let result = search_gateway(options).await;
Expand All @@ -196,33 +212,16 @@ async fn ip_spoofing_in_getxml_body() {
panic!("Unexpected result: {result:?}");
}
}
aux(|opt| async {
tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ()))
.await
.unwrap()
})
.await;
#[cfg(feature = "aio")]
aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await;
}

#[test(tokio::test)]
async fn ip_spoofing_in_getxml_body_control_url() {
async fn aux<F, Fut>(search_gateway: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
fn ip_spoofing_in_getxml_body_control_url(search_gateway) {
let http_port = start_http_server(vec![
RESP_SPOOFED_CONTROL_URL.to_owned(),
RESP_CONTROL_SCHEMA.to_owned(),
])
.await;

let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await;

let options = default_options_with_using_free_port(local_free_port);

let result = search_gateway(options).await;

if let Err(SearchError::SpoofedUrl { src_ip, url_host }) = result {
Expand All @@ -232,37 +231,11 @@ async fn ip_spoofing_in_getxml_body_control_url() {
panic!("Unexpected result: {result:?}");
}
}
aux(|opt| async {
tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ()))
.await
.unwrap()
})
.await;
#[cfg(feature = "aio")]
aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await;
}

#[test(tokio::test)]
async fn non_spoofed_urls_result_in_search_gateway_success() {
async fn aux<F, Fut>(search_gateway: F)
where
Fut: Future<Output = Result<(), SearchError>>,
F: Fn(SearchOptions) -> Fut,
{
fn non_spoofed_urls_result_in_search_gateway_success(search_gateway) {
let http_port = start_http_server(vec![RESP.to_owned(), RESP_CONTROL_SCHEMA.to_owned()]).await;

let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await;

let options = default_options_with_using_free_port(local_free_port);

assert!(search_gateway(options).await.is_ok());
}
aux(|opt| async {
tokio::task::spawn_blocking(|| search_gateway(opt).map(|_| ()))
.await
.unwrap()
})
.await;
#[cfg(feature = "aio")]
aux(|opt| async { crate::aio::search_gateway(opt).await.map(|_| ()) }).await;
}

0 comments on commit 1bc1638

Please sign in to comment.