diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index b88ddeb3..c9de6366 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -68,6 +68,7 @@ full = [ "map-response-body", "metrics", "normalize-path", + "propagate-extension", "propagate-header", "redirect", "request-id", @@ -91,6 +92,7 @@ map-request-body = [] map-response-body = [] metrics = ["tokio/time"] normalize-path = [] +propagate-extension = [] propagate-header = [] redirect = [] request-id = ["uuid"] diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 2cb4f94a..5784274a 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -54,6 +54,16 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized { header: HeaderName, ) -> ServiceBuilder<Stack<crate::propagate_header::PropagateHeaderLayer, L>>; + /// Propagate an extension from the request to the response. + /// + /// See [`tower_http::propagate_extension`] for more details. + /// + /// [`tower_http::propagate_extension`]: crate::propagate_extension + #[cfg(feature = "propagate-extension")] + fn propagate_extension<T>( + self + ) -> ServiceBuilder<Stack<crate::propagate_extension::PropagateExtensionLayer<T>, L>>; + /// Add some shareable value to [request extensions]. /// /// See [`tower_http::add_extension`] for more details. @@ -380,6 +390,13 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> { self.layer(crate::propagate_header::PropagateHeaderLayer::new(header)) } + #[cfg(feature = "propagate-extension")] + fn propagate_extension<X>( + self, + ) -> ServiceBuilder<Stack<crate::propagate_extension::PropagateExtensionLayer<X>, L>> { + self.layer(crate::propagate_extension::PropagateExtensionLayer::<X>::new()) + } + #[cfg(feature = "add-extension")] fn add_extension<T>( self, diff --git a/tower-http/src/compression/future.rs b/tower-http/src/compression/future.rs index 426bb161..bfccff52 100644 --- a/tower-http/src/compression/future.rs +++ b/tower-http/src/compression/future.rs @@ -73,6 +73,7 @@ where CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality))) } #[cfg(feature = "fs")] + #[allow(unreachable_patterns)] (true, _) => { // This should never happen because the `AcceptEncoding` struct which is used to determine // `self.encoding` will only enable the different compression algorithms if the diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 6719ddbd..a5777b24 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -231,6 +231,9 @@ pub mod auth; #[cfg(feature = "set-header")] pub mod set_header; +#[cfg(feature = "propagate-extension")] +pub mod propagate_extension; + #[cfg(feature = "propagate-header")] pub mod propagate_header; diff --git a/tower-http/src/propagate_extension.rs b/tower-http/src/propagate_extension.rs new file mode 100644 index 00000000..3438cf2d --- /dev/null +++ b/tower-http/src/propagate_extension.rs @@ -0,0 +1,248 @@ +//! Propagate an extension from the request to the response. +//! +//! This middleware is intended to wrap a Request->Response service handler that is _unaware_ of the +//! extension. Consequently it _removes_ the extension from the request before forwarding the request, and then +//! inserts it into the response when the response is ready. As a usage example, if you have pre-service mappers +//! that need to share state with post-service mappers, you can store the state in the Request extensions, +//! and this middleware will ensure that it is available to the post service mappers via the Response extensions. +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::add_extension::AddExtensionLayer; +//! use tower_http::propagate_extension::PropagateExtensionLayer; +//! use tower_http::ServiceBuilderExt; +//! use hyper::Body; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box<dyn std::error::Error>> { +//! async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> { +//! // ... +//! # Ok(Response::new(Body::empty())) +//! } +//! +//! // +//! // Note that while the state object must _implement_ Clone, it should never actually +//! // _be_ cloned due to the manner in which it is used within the middleware. +//! // +//! #[derive(Clone)] +//! struct MyState { +//! state_message: String +//! }; +//! +//! let my_state = MyState { state_message: "propagated state".to_string() }; +//! +//! let mut svc = ServiceBuilder::new() +//! .add_extension(my_state) // any other way of adding the extension to the request is OK too +//! .propagate_extension::<MyState>() +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::builder() +//! .body(Body::empty())?; +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.extensions().get::<MyState>().unwrap().state_message, "propagated state"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use futures_util::ready; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::{ + pin::Pin, + task::{Context, Poll}, + marker::PhantomData, +}; +use tower_layer::Layer; +use tower_service::Service; + +#[allow(unused_imports)] +use tracing::{ + trace, + debug, + info, + warn, + error, +}; + +/// Layer that applies [`PropagateExtension`] which propagates an extension from the request to the response. +/// +/// This middleware is intended to wrap a Request->Response service handler that is _unaware_ of the +/// extension. Consequently it _removes_ the extension from the request before forwarding the request, and then +/// inserts it into the response when the response is ready. As a usage example, if you have pre-service mappers +/// that need to share state with post-service mappers, you can store the state in the Request extensions, +/// and this middleware will ensure that it is available to the post service mappers via the Response extensions. +/// +/// See the [module docs](crate::propagate_extension) for more details. +#[derive(Clone, Debug)] +pub struct PropagateExtensionLayer<X> { + _phantom: PhantomData<X> +} + +impl<X> PropagateExtensionLayer<X> { + /// Create a new [`PropagateExtensionLayer`]. + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl<S,X> Layer<S> for PropagateExtensionLayer<X> { + type Service = PropagateExtension<S,X>; + + fn layer(&self, inner: S) -> Self::Service { + PropagateExtension::<S,X> { + inner, + _phantom: PhantomData + } + } +} + +/// Middleware that propagates extensions from requests to responses. +/// +/// If the extension is present on the request it'll be removed from the request and +/// inserted into the response. +/// +/// See the [module docs](crate::propagate_extension) for more details. +#[derive(Clone,Debug)] +pub struct PropagateExtension<S,X> { + inner: S, + _phantom: PhantomData<X> +} + +impl<S,X> PropagateExtension<S,X> { + /// Create a new [`PropagateExtension`] that propagates the given extension type. + pub fn new(inner: S) -> Self { + Self { inner, _phantom: PhantomData } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `PropagateExtension` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> PropagateExtensionLayer<X> { + PropagateExtensionLayer::<X>::new() + } +} + +impl<ReqBody, ResBody, S, X> Service<Request<ReqBody>> for PropagateExtension<S,X> +where + X: Sync + Send + 'static, + S: Service<Request<ReqBody>, Response = Response<ResBody>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture<S::Future,X>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + let extension: Option<X> = req.extensions_mut().remove(); + debug!("Removed state from request extensions. is_some? {}", extension.is_some()); + + ResponseFuture { + future: self.inner.call(req), + extension, + } + } +} + +pin_project! { + /// Response future for [`PropagateExtension`]. + #[derive(Debug)] + pub struct ResponseFuture<F,X> { + #[pin] + future: F, + extension: Option<X>, + } +} + +impl<F, ResBody, E, X> Future for ResponseFuture<F,X> +where + X: Sync + Send + 'static, + F: Future<Output = Result<Response<ResBody>, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + let mut res = ready!(this.future.poll(cx)?); + + if let Some(extension) = this.extension.take() { + debug!("Inserting state into response extensions"); + res.extensions_mut().insert(extension); + } else { + debug!("No state to insert into response"); + } + + Poll::Ready(Ok(res)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use http::{Request, Response}; + use std::convert::Infallible; + use tower::{Service, ServiceExt, ServiceBuilder}; + use crate::add_extension::AddExtensionLayer; + use crate::builder::ServiceBuilderExt; + use hyper::Body; + + async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> { + Ok(Response::new(Body::empty())) + } + + #[derive(Clone)] + struct MyState { + state_message: String + } + + #[test] + fn basic_test() { + + let my_state = MyState { state_message: "propagated state".to_string() }; + + let mut svc = ServiceBuilder::new() + .layer(AddExtensionLayer::new(my_state)) // any other way of adding the extension to the request is OK too + .layer(PropagateExtensionLayer::<MyState>::new()) + .service_fn(handle); + + let request = Request::builder().body(Body::empty()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.extensions().get::<MyState>().unwrap().state_message, "propagated state"); + } + + #[test] + fn test_server_builder_ext() { + + let my_state = MyState { state_message: "propagated state".to_string() }; + + let mut svc = ServiceBuilder::new() + .add_extension(my_state) // any other way of adding the extension to the request is OK too + .propagate_extension::<MyState>() + .service_fn(handle); + + let request = Request::builder().body(Body::empty()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.extensions().get::<MyState>().unwrap().state_message, "propagated state"); + } +}