diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 1baa893555..761ea7fc0f 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -13,11 +13,16 @@ pub mod rejection; mod default_body_limit; mod from_ref; +mod option; mod request_parts; mod tuple; pub(crate) use self::default_body_limit::DefaultBodyLimitKind; -pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef}; +pub use self::{ + default_body_limit::DefaultBodyLimit, + from_ref::FromRef, + option::{OptionalFromRequest, OptionalFromRequestParts}, +}; /// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body /// type used with axum. @@ -102,33 +107,6 @@ where } } -impl FromRequestParts for Option -where - T: FromRequestParts, - S: Send + Sync, -{ - type Rejection = Infallible; - - async fn from_request_parts( - parts: &mut Parts, - state: &S, - ) -> Result, Self::Rejection> { - Ok(T::from_request_parts(parts, state).await.ok()) - } -} - -impl FromRequest for Option -where - T: FromRequest, - S: Send + Sync, -{ - type Rejection = Infallible; - - async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { - Ok(T::from_request(req, state).await.ok()) - } -} - impl FromRequestParts for Result where T: FromRequestParts, diff --git a/axum-core/src/extract/option.rs b/axum-core/src/extract/option.rs new file mode 100644 index 0000000000..f628c8bf94 --- /dev/null +++ b/axum-core/src/extract/option.rs @@ -0,0 +1,62 @@ +use std::future::Future; + +use http::request::Parts; + +use crate::response::IntoResponse; + +use super::{private, FromRequest, FromRequestParts, Request}; + +/// TODO: DOCS +pub trait OptionalFromRequestParts: Sized { + /// If the extractor fails, it will use this "rejection" type. + /// + /// A rejection is a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> impl Future, Self::Rejection>> + Send; +} + +/// TODO: DOCS +pub trait OptionalFromRequest: Sized { + /// If the extractor fails, it will use this "rejection" type. + /// + /// A rejection is a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + fn from_request( + req: Request, + state: &S, + ) -> impl Future, Self::Rejection>> + Send; +} + +impl FromRequestParts for Option +where + T: OptionalFromRequestParts, + S: Send + Sync, +{ + type Rejection = T::Rejection; + + fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> impl Future, Self::Rejection>> { + T::from_request_parts(parts, state) + } +} + +impl FromRequest for Option +where + T: OptionalFromRequest, + S: Send + Sync, +{ + type Rejection = T::Rejection; + + async fn from_request(req: Request, state: &S) -> Result, Self::Rejection> { + T::from_request(req, state).await + } +} diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 7d2a5b2433..f858338d9c 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -24,9 +24,9 @@ pub mod multipart; #[cfg(feature = "scheme")] mod scheme; -pub use self::{ - cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection, -}; +#[allow(deprecated)] +pub use self::optional_path::OptionalPath; +pub use self::{cached::Cached, host::Host, with_rejection::WithRejection}; #[cfg(feature = "cookie")] pub use self::cookie::CookieJar; @@ -41,7 +41,10 @@ pub use self::cookie::SignedCookieJar; pub use self::form::{Form, FormRejection}; #[cfg(feature = "query")] -pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection}; +#[allow(deprecated)] +pub use self::query::OptionalQuery; +#[cfg(feature = "query")] +pub use self::query::{OptionalQueryRejection, Query, QueryRejection}; #[cfg(feature = "multipart")] pub use self::multipart::Multipart; diff --git a/axum-extra/src/extract/optional_path.rs b/axum-extra/src/extract/optional_path.rs index 53443f1952..466944ff55 100644 --- a/axum-extra/src/extract/optional_path.rs +++ b/axum-extra/src/extract/optional_path.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path}, + extract::{rejection::PathRejection, FromRequestParts, Path}, RequestPartsExt, }; use serde::de::DeserializeOwned; @@ -31,9 +31,11 @@ use serde::de::DeserializeOwned; /// .route("/blog/{page}", get(render_blog)); /// # let app: Router = app; /// ``` +#[deprecated = "Use Option> instead"] #[derive(Debug)] pub struct OptionalPath(pub Option); +#[allow(deprecated)] impl FromRequestParts for OptionalPath where T: DeserializeOwned + Send + 'static, @@ -45,19 +47,15 @@ where parts: &mut http::request::Parts, _: &S, ) -> Result { - match parts.extract::>().await { - Ok(Path(params)) => Ok(Self(Some(params))), - Err(PathRejection::FailedToDeserializePathParams(e)) - if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) => - { - Ok(Self(None)) - } - Err(e) => Err(e), - } + parts + .extract::>>() + .await + .map(|opt| Self(opt.map(|Path(x)| x))) } } #[cfg(test)] +#[allow(deprecated)] mod tests { use std::num::NonZeroU32; diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 695ea9576b..7c73cec38c 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,5 +1,5 @@ use axum::{ - extract::FromRequestParts, + extract::{FromRequestParts, OptionalFromRequestParts}, response::{IntoResponse, Response}, Error, }; @@ -96,6 +96,27 @@ where } } +impl OptionalFromRequestParts for Query +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = QueryRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + if let Some(query) = parts.uri.query() { + let value = serde_html_form::from_str(query) + .map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?; + Ok(Some(Self(value))) + } else { + Ok(None) + } + } +} + axum_core::__impl_deref!(Query); /// Rejection used for [`Query`]. @@ -182,9 +203,11 @@ impl std::error::Error for QueryRejection { /// /// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs #[cfg_attr(docsrs, doc(cfg(feature = "query")))] +#[deprecated = "Use Option> instead"] #[derive(Debug, Clone, Copy, Default)] pub struct OptionalQuery(pub Option); +#[allow(deprecated)] impl FromRequestParts for OptionalQuery where T: DeserializeOwned, @@ -204,6 +227,7 @@ where } } +#[allow(deprecated)] impl std::ops::Deref for OptionalQuery { type Target = Option; @@ -213,6 +237,7 @@ impl std::ops::Deref for OptionalQuery { } } +#[allow(deprecated)] impl std::ops::DerefMut for OptionalQuery { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { @@ -260,6 +285,7 @@ impl std::error::Error for OptionalQueryRejection { } #[cfg(test)] +#[allow(deprecated)] mod tests { use super::*; use crate::test_helpers::*; diff --git a/axum-extra/src/typed_header.rs b/axum-extra/src/typed_header.rs index ef94c3779c..820c423577 100644 --- a/axum-extra/src/typed_header.rs +++ b/axum-extra/src/typed_header.rs @@ -1,7 +1,7 @@ //! Extractor and response for typed headers. use axum::{ - extract::FromRequestParts, + extract::{FromRequestParts, OptionalFromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use headers::{Header, HeaderMapExt}; @@ -78,6 +78,30 @@ where } } +impl OptionalFromRequestParts for TypedHeader +where + T: Header, + S: Send + Sync, +{ + type Rejection = TypedHeaderRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + let mut values = parts.headers.get_all(T::name()).iter(); + let is_missing = values.size_hint() == (0, Some(0)); + match T::decode(&mut values) { + Ok(res) => Ok(Some(Self(res))), + Err(_) if is_missing => Ok(None), + Err(err) => Err(TypedHeaderRejection { + name: T::name(), + reason: TypedHeaderRejectionReason::Error(err), + }), + } + } +} + axum_core::__impl_deref!(TypedHeader); impl IntoResponseParts for TypedHeader diff --git a/axum-macros/tests/typed_path/pass/option_result.rs b/axum-macros/tests/typed_path/pass/result_handler.rs similarity index 87% rename from axum-macros/tests/typed_path/pass/option_result.rs rename to axum-macros/tests/typed_path/pass/result_handler.rs index 81cfb29482..2053c1a56c 100644 --- a/axum-macros/tests/typed_path/pass/option_result.rs +++ b/axum-macros/tests/typed_path/pass/result_handler.rs @@ -8,8 +8,6 @@ struct UsersShow { id: String, } -async fn option_handler(_: Option) {} - async fn result_handler(_: Result) {} #[derive(TypedPath, Deserialize)] @@ -20,7 +18,6 @@ async fn result_handler_unit_struct(_: Result) {} fn main() { _ = axum::Router::<()>::new() - .typed_get(option_handler) .typed_post(result_handler) .typed_post(result_handler_unit_struct); } diff --git a/axum/Cargo.toml b/axum/Cargo.toml index ef113c335b..0b2954f054 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -116,6 +116,7 @@ features = [ [dev-dependencies] anyhow = "1.0" +axum-extra = { path = "../axum-extra", features = ["typed-header"] } axum-macros = { path = "../axum-macros", features = ["__private"] } hyper = { version = "1.1.0", features = ["client"] } quickcheck = "1.0" diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 244528d6a8..b5e8c93eda 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -202,26 +202,22 @@ and all others implement [`FromRequestParts`]. # Optional extractors -All extractors defined in axum will reject the request if it doesn't match. -If you wish to make an extractor optional you can wrap it in `Option`: +TODO: Docs, more realistic example ```rust,no_run -use axum::{ - extract::Json, - routing::post, - Router, -}; +use axum::{routing::post, Router}; +use axum_extra::{headers::UserAgent, TypedHeader}; use serde_json::Value; -async fn create_user(payload: Option>) { - if let Some(payload) = payload { - // We got a valid JSON payload +async fn foo(user_agent: Option>) { + if let Some(TypedHeader(user_agent)) = user_agent { + // The client sent a user agent } else { - // Payload wasn't valid JSON + // No user agent header } } -let app = Router::new().route("/users", post(create_user)); +let app = Router::new().route("/foo", post(foo)); # let _: Router = app; ``` diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 124154f7ef..c9c08c768e 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,5 +1,6 @@ use super::{rejection::*, FromRequestParts}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; +use axum_core::extract::OptionalFromRequestParts; use http::request::Parts; use std::{collections::HashMap, sync::Arc}; @@ -79,6 +80,20 @@ where } } +impl OptionalFromRequestParts for MatchedPath +where + S: Send + Sync, +{ + type Rejection = MatchedPathRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + #[derive(Clone, Debug)] struct MatchedNestedPath(Arc); diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 6d5e8de857..61b57418a2 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -17,7 +17,10 @@ mod request_parts; mod state; #[doc(inline)] -pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, Request}; +pub use axum_core::extract::{ + DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, OptionalFromRequest, + OptionalFromRequestParts, Request, +}; #[cfg(feature = "macros")] pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 427db8f20d..2ddcbccfbd 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -8,7 +8,11 @@ use crate::{ routing::url_params::UrlParams, util::PercentDecodedStr, }; -use axum_core::response::{IntoResponse, Response}; +use axum_core::{ + extract::OptionalFromRequestParts, + response::{IntoResponse, Response}, + RequestPartsExt as _, +}; use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; use std::{fmt, sync::Arc}; @@ -176,6 +180,29 @@ where } } +impl OptionalFromRequestParts for Path +where + T: DeserializeOwned + Send + 'static, + S: Send + Sync, +{ + type Rejection = PathRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + match parts.extract::().await { + Ok(Self(params)) => Ok(Some(Self(params))), + Err(PathRejection::FailedToDeserializePathParams(e)) + if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) => + { + Ok(None) + } + Err(e) => Err(e), + } + } +} + // this wrapper type is used as the deserializer error to hide the `serde::de::Error` impl which // would otherwise be public if we used `ErrorKind` as the error directly #[derive(Debug)] diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 371612b71a..68f5bd4ef1 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,4 +1,4 @@ -use super::{rejection::*, FromRequestParts}; +use super::{rejection::*, FromRequestParts, OptionalFromRequestParts}; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; @@ -62,6 +62,27 @@ where } } +impl OptionalFromRequestParts for Query +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = QueryRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result, Self::Rejection> { + if let Some(query) = parts.uri.query() { + let value = serde_urlencoded::from_str(query) + .map_err(FailedToDeserializeQueryString::from_err)?; + Ok(Some(Self(value))) + } else { + Ok(None) + } + } +} + impl Query where T: DeserializeOwned, diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 4241db2a7c..8bea2c29b0 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -11,7 +11,7 @@ use anyhow::{anyhow, Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ - extract::{FromRef, FromRequestParts, Query, State}, + extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State}, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, @@ -24,7 +24,7 @@ use oauth2::{ ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; -use std::env; +use std::{convert::Infallible, env}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; static COOKIE_NAME: &str = "SESSION"; @@ -351,6 +351,24 @@ where } } +impl OptionalFromRequestParts for User +where + MemoryStore: FromRef, + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result, Self::Rejection> { + match >::from_request_parts(parts, state).await { + Ok(res) => Ok(Some(res)), + Err(AuthRedirect) => Ok(None), + } + } +} + // Use anyhow, define error and enable '?' // For a simplified example of using anyhow in axum check /examples/anyhow-error-response #[derive(Debug)]