Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
name: CI
on:
pull_request:
branches: [ main ]
push:
branches: [ main ]

# Concurrency strategy:
# github.workflow: distinguish this workflow from others
# github.event_name: distinguish `push` event from `pull_request` event
# github.event.number: set to the number of the pull request if `pull_request` event
# github.run_id: otherwise, it's a `push` event, only cancel if we rerun the workflow
#
# Reference:
# https://docs.github.com/en/actions/using-jobs/using-concurrency
# https://docs.github.com/en/actions/learn-github-actions/contexts#github-context
concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.number || github.run_id }}
cancel-in-progress: true

jobs:
check:
name: Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- name: Check Clippy
run: cargo clippy --tests --all-features --all-targets --workspace -- -D warnings
- name: Check format
run: cargo fmt --all --check

test:
name: Build and test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- run: cargo build --workspace --all-features --tests --examples --benches
- name: Run tests
run: cargo test --workspace -- --nocapture

examples:
name: Validate examples
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- run: cargo build --workspace --all-features --tests --examples --benches
working-directory: examples/hello-world
- name: Run tests
run: cargo test --workspace -- --nocapture
working-directory: examples/hello-world

required:
name: Required
runs-on: ubuntu-latest
if: ${{ always() }}
needs:
- check
- examples
- test
steps:
- name: Guardian
run: |
if [[ ! ( \
"${{ needs.check.result }}" == "success" \
&& "${{ needs.test.result }}" == "success" \
&& "${{ needs.examples.result }}" == "success" \
) ]]; then
echo "Required jobs haven't been completed successfully."
exit -1
fi
14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "axum-test-helper"
version = "0.3.0"
version = "0.4.0"
edition = "2021"
categories = ["development-tools::testing"]
description = "Extra utilities for axum"
Expand All @@ -11,16 +11,16 @@ readme = "README.md"
repository = "https://github.com/cloudwalk/axum-test-helper"

[dependencies]
axum = "0.6"
axum = "0.7"
reqwest = { version = "0.11", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false }
http = "0.2"
http-body = "0.4"
bytes = "1.4.0"
tower = "0.4.13"
http = "1.1"
http-body = "1.0"
bytes = "1.4"
tower = "0.4"
tower-service = "0.3"
serde = "1.0"
tokio = "1"
hyper = "0.14"
hyper = "1.0"

[dev-dependencies]
serde = { version = "1", features = ["serde_derive"] }
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ More information about this crate can be found in the [crate documentation][docs

## High level features

- Provide an easy to use interface
- Provide an easy-to-use interface
- Start a server in a different port for each call
- Deal with JSON, text and files response/requests

Expand All @@ -16,7 +16,7 @@ Add this crate as a dev-dependency:

```
[dev-dependencies]
axum-test-helper = "0.*" # alternatively specify the version as "0.3.0"
axum-test-helper = "0.4"
```

Use the TestClient on your own Router:
Expand Down
4 changes: 2 additions & 2 deletions examples/hello-world/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ edition = "2021"
publish = false

[dependencies]
axum = { version = "0.6.*" }
axum-test-helper = { version = "0.*" }
axum = { version = "0.7" }
axum-test-helper = { path = "../.." }
tokio = { version = "1.0", features = ["full"] }
8 changes: 4 additions & 4 deletions examples/hello-world/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use axum::{response::Html, routing::get, Router};
use std::net::SocketAddr;
use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
Expand All @@ -15,10 +16,9 @@ async fn main() {
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(router.into_make_service())
.await
.unwrap();

let listener = TcpListener::bind(addr).await.unwrap();
axum::serve(listener, router.into_make_service()).await.unwrap();
}

fn router() -> Router {
Expand Down
3 changes: 3 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[toolchain]
channel = "stable"
components = ["rustfmt", "clippy", "rust-analyzer"]
126 changes: 80 additions & 46 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,87 @@
//! - `withouttrace` - Disables tracing for the test client.
//!
//! ## Example
//!
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use axum_test_helper::TestClient;
//!
//! fn main() {
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//!
//! // initiate the TestClient with the previous declared Router
//! let client = TestClient::new(app);
//! // initiate the TestClient with the previous declared Router
//! let client = TestClient::new(app);
//!
//! let res = client.get("/").send().await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//! let res = client.get("/").send().await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//!
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//!
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! }

use axum::{body::HttpBody, BoxError};
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! ```

use axum::extract::Request;
use axum::response::Response;
use axum::serve;
use axum::ServiceExt;
use bytes::Bytes;
use http::{
header::{HeaderName, HeaderValue},
Request, StatusCode,
StatusCode,
};
use hyper::{Body, Server};
use std::convert::TryFrom;
use std::convert::{Infallible, TryFrom};
use std::net::{SocketAddr, TcpListener};
use tower::make::Shared;
use std::str::FromStr;
use tower_service::Service;

pub struct TestClient {
client: reqwest::Client,
addr: SocketAddr,
}

pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
where
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
let std_listener = TcpListener::bind("127.0.0.1:0").unwrap();
std_listener.set_nonblocking(true).unwrap();
let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();

let addr = listener.local_addr().unwrap();

#[cfg(feature = "withtrace")]
println!("Listening on {addr}");

tokio::spawn(async move {
serve(
listener,
svc.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.expect("server error")
});

addr
}

impl TestClient {
pub fn new<S, ResBody>(svc: S) -> Self
pub fn new<S>(svc: S) -> Self
where
S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static,
ResBody: HttpBody + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
S::Error: Into<BoxError>,
{
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
#[cfg(feature = "withtrace")]
println!("Listening on {}", addr);

tokio::spawn(async move {
let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
server.await.expect("server error");
});
let addr = spawn_service(svc);

#[cfg(feature = "cookies")]
let client = reqwest::Client::builder()
Expand Down Expand Up @@ -169,6 +185,13 @@ impl RequestBuilder {
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
// reqwest still uses http 0.2
let key: HeaderName = key.try_into().map_err(Into::into).unwrap();
let key = reqwest::header::HeaderName::from_bytes(key.as_ref()).unwrap();

let value: HeaderValue = value.try_into().map_err(Into::into).unwrap();
let value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap();

self.builder = self.builder.header(key, value);
self
}
Expand Down Expand Up @@ -206,11 +229,18 @@ impl TestResponse {
}

pub fn status(&self) -> StatusCode {
self.response.status()
StatusCode::from_u16(self.response.status().as_u16()).unwrap()
}

pub fn headers(&self) -> &http::HeaderMap {
self.response.headers()
pub fn headers(&self) -> http::HeaderMap {
// reqwest still uses http 0.2 so have to convert into http 1.0
let mut headers = http::HeaderMap::new();
for (key, value) in self.response.headers() {
let key = HeaderName::from_str(key.as_str()).unwrap();
let value = HeaderValue::from_bytes(value.as_bytes()).unwrap();
headers.insert(key, value);
}
headers
}

pub async fn chunk(&mut self) -> Option<Bytes> {
Expand All @@ -237,9 +267,9 @@ impl AsRef<reqwest::Response> for TestResponse {
#[cfg(test)]
mod tests {
use axum::response::Html;
use serde::{Deserialize, Serialize};
use axum::{routing::get, routing::post, Router, Json};
use axum::{routing::get, routing::post, Json, Router};
use http::StatusCode;
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct FooForm {
Expand Down Expand Up @@ -276,13 +306,17 @@ mod tests {

#[tokio::test]
async fn test_post_request_with_json() {
let app = Router::new().route("/", post(|json_value: Json<serde_json::Value>| async {json_value}));
let app = Router::new().route(
"/",
post(|json_value: Json<serde_json::Value>| async { json_value }),
);
let client = super::TestClient::new(app);
let payload = TestPayload {
name: "Alice".to_owned(),
age: 30,
};
let res = client.post("/")
let res = client
.post("/")
.header("Content-Type", "application/json")
.json(&payload)
.send()
Expand Down