diff --git a/saphir/Cargo.toml b/saphir/Cargo.toml index bf172bd..bb6caba 100644 --- a/saphir/Cargo.toml +++ b/saphir/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "saphir" -version = "2.6.12" +version = "2.7.0" edition = "2018" authors = ["Richer Archambault "] description = "Fully async-await http server framework" @@ -37,13 +37,13 @@ http-body = "0.3" parking_lot = "0.11" regex = "1.3" uuid = { version = "0.8", features = ["serde", "v4"], optional = true } -rustls = { version = "0.17", optional = true } -tokio-rustls = { version = "0.13", optional = true } +rustls = { version = "0.18", optional = true } +tokio-rustls = { version = "0.14", optional = true } base64 = { version = "0.12", optional = true } serde = { version = "1.0", optional = true } serde_json = { version = "1.0", optional = true } serde_urlencoded = { version = "0.6", optional = true } -saphir_macro = { path = "../saphir_macro", version = "2.0.10", optional = true } +saphir_macro = { path = "../saphir_macro", version = "2.1", optional = true } mime = { version = "0.3", optional = true } nom = { version = "5", optional = true } mime_guess = { version = "2.0.3", optional = true } diff --git a/saphir/examples/macro.rs b/saphir/examples/macro.rs index acbdab7..8882e17 100644 --- a/saphir/examples/macro.rs +++ b/saphir/examples/macro.rs @@ -101,6 +101,7 @@ impl ApiKeyMiddleware { info!("Not Authenticated"); } + info!("Handler {} will be used", ctx.metadata.name.unwrap_or("unknown")); chain.next(ctx).await } } diff --git a/saphir/src/controller.rs b/saphir/src/controller.rs index 9c8c5a9..5d3382e 100644 --- a/saphir/src/controller.rs +++ b/saphir/src/controller.rs @@ -42,7 +42,13 @@ use futures_util::future::{Future, FutureExt}; use http::Method; /// Type definition to represent a endpoint within a controller -pub type ControllerEndpoint = (Method, &'static str, Box + Send + Sync>, Box); +pub type ControllerEndpoint = ( + Option<&'static str>, + Method, + &'static str, + Box + Send + Sync>, + Box, +); /// Trait that defines how a controller handles its requests pub trait Controller { @@ -124,7 +130,7 @@ impl EndpointsBuilder { where H: 'static + DynControllerHandler + Send + Sync, { - self.handlers.push((method, route, Box::new(handler), GuardBuilder::default().build())); + self.handlers.push((None, method, route, Box::new(handler), GuardBuilder::default().build())); self } @@ -136,7 +142,32 @@ impl EndpointsBuilder { F: FnOnce(GuardBuilder) -> GuardBuilder, Chain: GuardChain + 'static, { - self.handlers.push((method, route, Box::new(handler), guards(GuardBuilder::default()).build())); + self.handlers + .push((None, method, route, Box::new(handler), guards(GuardBuilder::default()).build())); + self + } + + /// Add but with a handler name + #[inline] + pub fn add_with_name(mut self, handler_name: &'static str, method: Method, route: &'static str, handler: H) -> Self + where + H: 'static + DynControllerHandler + Send + Sync, + { + self.handlers + .push((Some(handler_name), method, route, Box::new(handler), GuardBuilder::default().build())); + self + } + + /// Add with guard but with a handler name + #[inline] + pub fn add_with_guards_and_name(mut self, handler_name: &'static str, method: Method, route: &'static str, handler: H, guards: F) -> Self + where + H: 'static + DynControllerHandler + Send + Sync, + F: FnOnce(GuardBuilder) -> GuardBuilder, + Chain: GuardChain + 'static, + { + self.handlers + .push((Some(handler_name), method, route, Box::new(handler), guards(GuardBuilder::default()).build())); self } diff --git a/saphir/src/http_context.rs b/saphir/src/http_context.rs index c0ffd12..745253a 100644 --- a/saphir/src/http_context.rs +++ b/saphir/src/http_context.rs @@ -147,6 +147,13 @@ impl State { } } +/// MetaData of the resolved request handler +#[derive(Debug, Clone, Eq, PartialEq, Default)] +pub struct HandlerMetadata { + pub route_id: u64, + pub name: Option<&'static str>, +} + /// Context representing the relationship between a request and a response /// This structure only appears inside Middleware since the act before and after /// the request @@ -161,16 +168,17 @@ pub struct HttpContext { #[cfg(feature = "operation")] /// Unique Identifier of the current request->response chain pub operation_id: crate::http_context::operation::OperationId, + pub metadata: HandlerMetadata, pub(crate) router: Option, } impl HttpContext { - pub(crate) fn new(request: Request, router: Router) -> Self { + pub(crate) fn new(request: Request, router: Router, metadata: HandlerMetadata) -> Self { #[cfg(not(feature = "operation"))] { let state = State::Before(Box::new(request)); let router = Some(router); - HttpContext { state, router } + HttpContext { state, metadata, router } } #[cfg(feature = "operation")] @@ -186,7 +194,12 @@ impl HttpContext { *request.operation_id_mut() = operation_id; let state = State::Before(Box::new(request)); let router = Some(router); - HttpContext { state, router, operation_id } + HttpContext { + state, + router, + operation_id, + metadata, + } } } diff --git a/saphir/src/middleware.rs b/saphir/src/middleware.rs index 728b56e..ffa3483 100644 --- a/saphir/src/middleware.rs +++ b/saphir/src/middleware.rs @@ -172,7 +172,7 @@ impl MiddlewareChain for MiddleChainEnd { fn next(&self, mut ctx: HttpContext) -> BoxFuture<'static, Result> { async { let router = ctx.router.take().ok_or_else(|| SaphirError::Internal(InternalError::Stack))?; - router.handle(ctx).await + router.dispatch(ctx).await } .boxed() } diff --git a/saphir/src/router.rs b/saphir/src/router.rs index 8608e9b..04c2035 100644 --- a/saphir/src/router.rs +++ b/saphir/src/router.rs @@ -15,7 +15,7 @@ use crate::{ error::SaphirError, guard::{Builder as GuardBuilder, GuardChain, GuardChainEnd}, handler::DynHandler, - http_context::{HttpContext, State}, + http_context::{HandlerMetadata, HttpContext, State}, request::Request, responder::{DynResponder, Responder}, utils::{EndpointResolver, EndpointResolverResult}, @@ -140,13 +140,14 @@ impl Builder(mut self, controller: C) -> Builder> { let mut handlers = HashMap::new(); - for (method, subroute, handler, guard_chain) in controller.handlers() { + for (name, method, subroute, handler, guard_chain) in controller.handlers() { let route = format!("{}{}", C::BASE_PATH, subroute); + let meta = name.map(|name| HandlerMetadata { route_id: 0, name: Some(name) }); let endpoint_id = if let Some(er) = self.resolver.get_mut(&route) { - er.add_method(method.clone()); + er.add_method_with_metadata(method.clone(), meta); er.id() } else { - let er = EndpointResolver::new(&route, method.clone()).expect("Unable to construct endpoint resolver"); + let er = EndpointResolver::new_with_metadata(&route, method.clone(), meta).expect("Unable to construct endpoint resolver"); let er_id = er.id(); self.resolver.insert(route, er); er_id @@ -168,9 +169,12 @@ impl Builder Router { let Builder { resolver, chain: controllers } = self; + let mut resolvers: Vec<_> = resolver.into_iter().map(|(_, e)| e).collect(); + resolvers.sort_unstable(); + Router { inner: Arc::new(RouterInner { - resolvers: resolver.into_iter().map(|(_, e)| e).collect(), + resolvers, chain: Box::new(controllers), }), } @@ -199,7 +203,7 @@ impl Router { match endpoint_resolver.resolve(req) { EndpointResolverResult::InvalidPath => continue, EndpointResolverResult::MethodNotAllowed => method_not_allowed = true, - EndpointResolverResult::Match => return Ok(endpoint_resolver.id()), + EndpointResolverResult::Match(_) => return Ok(endpoint_resolver.id()), } } @@ -210,23 +214,31 @@ impl Router { } } - pub async fn handle(self, mut ctx: HttpContext) -> Result { - let mut req = ctx.state.take_request().ok_or_else(|| SaphirError::RequestMovedBeforeHandler)?; - match self.resolve(&mut req) { - Ok(id) => self.dispatch(id, req, ctx).await, - Err(status) => { - ctx.state = State::After(Box::new(status.respond_with_builder(crate::response::Builder::new(), &ctx).build()?)); - Ok(ctx) + pub fn resolve_metadata(&self, req: &mut Request) -> Result<&HandlerMetadata, u16> { + let mut method_not_allowed = false; + + for endpoint_resolver in &self.inner.resolvers { + match endpoint_resolver.resolve(req) { + EndpointResolverResult::InvalidPath => continue, + EndpointResolverResult::MethodNotAllowed => method_not_allowed = true, + EndpointResolverResult::Match(meta) => return Ok(meta), } } + + if method_not_allowed { + Err(405) + } else { + Err(404) + } } - pub async fn dispatch(&self, resolver_id: u64, req: Request, mut ctx: HttpContext) -> Result { + pub async fn dispatch(&self, mut ctx: HttpContext) -> Result { + let req = ctx.state.take_request().ok_or_else(|| SaphirError::RequestMovedBeforeHandler)?; // # SAFETY # // The router is initialized in static memory when calling run on Server. let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) }; let b = crate::response::Builder::new(); - let res = if let Some(responder) = static_self.inner.chain.dispatch(resolver_id, req) { + let res = if let Some(responder) = static_self.inner.chain.dispatch(ctx.metadata.route_id, req) { responder.await.dyn_respond(b, &ctx) } else { 404.respond_with_builder(b, &ctx) diff --git a/saphir/src/server.rs b/saphir/src/server.rs index c54a384..2e7ef45 100644 --- a/saphir/src/server.rs +++ b/saphir/src/server.rs @@ -27,7 +27,7 @@ use crate::{ response::Response, router::{Builder as RouterBuilder, Router, RouterChain, RouterChainEnd}, }; -use http::{HeaderValue, Request as RawRequest, Response as RawResponse}; +use http::{HeaderValue, Request as RawRequest, Response as RawResponse, StatusCode}; /// Default time for request handling is 30 seconds pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000; @@ -416,8 +416,17 @@ impl Stack { StackHandler { stack: self, peer_addr } } - async fn invoke(&self, req: Request) -> Result, SaphirError> { - let ctx = HttpContext::new(req, self.router.clone()); + async fn invoke(&self, mut req: Request) -> Result, SaphirError> { + let meta = match self.router.resolve_metadata(&mut req) { + Ok(m) => m, + Err(e) => { + let mut r = Response::new(Body::empty()); + *r.status_mut() = StatusCode::from_u16(e).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + return Ok(r); + } + }; + let ctx = HttpContext::new(req, self.router.clone(), meta.clone()); + self.middlewares .next(ctx) .await diff --git a/saphir/src/utils.rs b/saphir/src/utils.rs index 6bb5072..b01645e 100644 --- a/saphir/src/utils.rs +++ b/saphir/src/utils.rs @@ -1,8 +1,9 @@ -use crate::{body::Body, error::SaphirError, request::Request}; +use crate::{body::Body, error::SaphirError, http_context::HandlerMetadata, request::Request}; use http::Method; use regex::Regex; use std::{ - collections::{HashMap, HashSet, VecDeque}, + cmp::{min, Ordering}, + collections::{HashMap, VecDeque}, iter::FromIterator, str::FromStr, sync::atomic::AtomicU64, @@ -17,50 +18,120 @@ use std::{ static ENDPOINT_ID: AtomicU64 = AtomicU64::new(0); -pub enum EndpointResolverResult { +pub enum EndpointResolverResult<'a> { InvalidPath, MethodNotAllowed, - Match, + Match(&'a HandlerMetadata), } +#[derive(Debug, Eq, PartialEq)] +pub enum EndpointResolverMethods { + Specific(HashMap), + Any(HandlerMetadata), +} + +#[derive(Debug, Eq)] pub struct EndpointResolver { - path_matcher: UriPathMatcher, - methods: HashSet, id: u64, - allow_any_method: bool, + path_matcher: UriPathMatcher, + methods: EndpointResolverMethods, +} + +impl Ord for EndpointResolver { + fn cmp(&self, other: &Self) -> Ordering { + self.path_matcher.cmp(&other.path_matcher) + } +} + +impl PartialOrd for EndpointResolver { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for EndpointResolver { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } } impl EndpointResolver { pub fn new(path_str: &str, method: Method) -> Result { - let mut methods = HashSet::new(); - let allow_any_method = method.is_any(); - if !allow_any_method { - methods.insert(method); - } + let id = ENDPOINT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let meta = HandlerMetadata { route_id: id, name: None }; + let methods = if method.is_any() { + EndpointResolverMethods::Any(meta) + } else { + let mut methods = HashMap::new(); + methods.insert(method, meta); + EndpointResolverMethods::Specific(methods) + }; Ok(EndpointResolver { path_matcher: UriPathMatcher::new(path_str).map_err(SaphirError::Other)?, methods, - id: ENDPOINT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst), - allow_any_method, + id, }) } - pub fn add_method(&mut self, m: Method) { - if !self.allow_any_method && m.is_any() { - self.allow_any_method = true; + pub fn new_with_metadata>>(path_str: &str, method: Method, meta: I) -> Result { + let id = ENDPOINT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let mut meta = meta.into().unwrap_or_default(); + meta.route_id = id; + let methods = if method.is_any() { + EndpointResolverMethods::Any(meta) } else { - self.methods.insert(m); + let mut methods = HashMap::new(); + methods.insert(method, meta); + EndpointResolverMethods::Specific(methods) + }; + + Ok(EndpointResolver { + path_matcher: UriPathMatcher::new(path_str).map_err(SaphirError::Other)?, + methods, + id, + }) + } + + pub fn add_method(&mut self, m: Method) { + match &mut self.methods { + EndpointResolverMethods::Specific(inner) => { + if m.is_any() { + panic!("Adding ANY method but an Handler already defines specific methods, This is fatal") + } + let meta = HandlerMetadata { route_id: self.id, name: None }; + inner.insert(m, meta); + } + EndpointResolverMethods::Any(_) => panic!("Adding a specific endpoint method but an Handler already defines ANY method, This is fatal"), + } + } + + pub fn add_method_with_metadata>>(&mut self, m: Method, meta: I) { + match &mut self.methods { + EndpointResolverMethods::Specific(inner) => { + if m.is_any() { + panic!("Adding ANY method but an Handler already defines specific methods, This is fatal") + } + let mut meta = meta.into().unwrap_or_default(); + meta.route_id = self.id; + inner.insert(m, meta); + } + EndpointResolverMethods::Any(_) => panic!("Adding a specific endpoint method but an Handler already defines ANY method, This is fatal"), } } pub fn resolve(&self, req: &mut Request) -> EndpointResolverResult { let path = req.uri().path().to_string(); if self.path_matcher.match_all_and_capture(path, req.captures_mut()) { - if self.allow_any_method || self.methods.contains(req.method()) { - EndpointResolverResult::Match - } else { - EndpointResolverResult::MethodNotAllowed + match &self.methods { + EndpointResolverMethods::Specific(methods) => { + if let Some(meta) = methods.get(req.method()) { + EndpointResolverResult::Match(meta) + } else { + EndpointResolverResult::MethodNotAllowed + } + } + EndpointResolverMethods::Any(meta) => EndpointResolverResult::Match(meta), } } else { EndpointResolverResult::InvalidPath @@ -72,7 +143,7 @@ impl EndpointResolver { } } -#[derive(Debug)] +#[derive(Debug, Eq)] pub(crate) enum UriPathMatcher { Simple { inner: Vec, @@ -84,6 +155,78 @@ pub(crate) enum UriPathMatcher { }, } +impl Ord for UriPathMatcher { + fn cmp(&self, other: &Self) -> Ordering { + let (start_self, end_self, simple_self) = match self { + UriPathMatcher::Simple { inner } => (inner, None, true), + UriPathMatcher::Wildcard { start, end, .. } => (start, Some(end), false), + }; + let (start_other, end_other, simple_other) = match other { + UriPathMatcher::Simple { inner } => (inner, None, true), + UriPathMatcher::Wildcard { start, end, .. } => (start, Some(end), false), + }; + + let i_self = start_self.len(); + let i_other = start_other.len(); + let min_len = min(i_self, i_other); + for i in 0..min_len { + let cmp = start_self[i].cmp(&start_other[i]); + if cmp != Ordering::Equal { + return cmp; + } + } + + if i_self > i_other { + for start in start_self.iter().take(i_self).skip(min_len) { + if let UriPathSegmentMatcher::Static { .. } = start { + return Ordering::Less; + } + } + } + if i_other > i_self { + for start in start_other.iter().take(i_other).skip(min_len) { + if let UriPathSegmentMatcher::Static { .. } = start { + return Ordering::Greater; + } + } + } + + match (end_self, end_other) { + (Some(end_self), Some(end_other)) => { + let j_self = end_self.len(); + let j_other = end_other.len(); + let min_len = min(j_self, j_other); + for j in 0..min_len { + let cmp = end_self[j].cmp(&end_other[j]); + if cmp != Ordering::Equal { + return cmp; + } + } + j_other.cmp(&j_self) + } + (Some(e), None) if !e.is_empty() => Ordering::Less, + (None, Some(e)) if !e.is_empty() => Ordering::Greater, + _ => match (simple_self, simple_other) { + (true, false) => Ordering::Less, + (false, true) => Ordering::Greater, + _ => i_other.cmp(&i_self), + }, + } + } +} + +impl PartialOrd for UriPathMatcher { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for UriPathMatcher { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + impl UriPathMatcher { pub fn new(path_str: &str) -> Result { let uri_path_matcher = if path_str.contains("**") || path_str.contains("..") { @@ -278,6 +421,25 @@ pub(crate) enum UriPathSegmentMatcher { Custom { name: Option, segment: Regex }, Wildcard { prefix: Option, suffix: Option }, } +impl Eq for UriPathSegmentMatcher {} + +impl Ord for UriPathSegmentMatcher { + fn cmp(&self, other: &Self) -> Ordering { + self.ord_index().cmp(&other.ord_index()) + } +} + +impl PartialOrd for UriPathSegmentMatcher { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for UriPathSegmentMatcher { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} impl UriPathSegmentMatcher { const SEGMENT_VARIABLE_CLOSING_CHARS: &'static [char] = &['}', '>']; @@ -340,6 +502,16 @@ impl UriPathSegmentMatcher { UriPathSegmentMatcher::Wildcard { .. } => None, } } + + #[inline] + fn ord_index(&self) -> u16 { + match self { + UriPathSegmentMatcher::Static { .. } => 1, + UriPathSegmentMatcher::Variable { .. } => 3, + UriPathSegmentMatcher::Custom { .. } => 2, + UriPathSegmentMatcher::Wildcard { .. } => 3, + } + } } pub trait MethodExtension { @@ -371,3 +543,49 @@ where { serde_urlencoded::from_str::(query_str) } + +#[cfg(test)] +mod tests { + use super::{EndpointResolver, Method}; + use std::{collections::HashMap, str::FromStr}; + + #[test] + fn test_simple_endpoint_resolver_ordering() { + let paths = vec![ + "/api/v1/users", + "/api/v1/users/keys", + "/api/v1/users/keys/", + "/api/v1/users/keys/first", + "/api/v1/users/keys/**", + "/api/v1/users/keys/**/delete", + "/api/v1/users/**/delete", + "/api/v1/users//keys", + "/api/v1/users//", + "/api/v1/users/", + ]; + let mut resolvers = HashMap::new(); + let mut ids = HashMap::new(); + for path in &paths { + let resolver = EndpointResolver::new(path, Method::from_str("GET").unwrap()).unwrap(); + ids.insert(path, resolver.id()); + resolvers.insert(path, resolver); + } + + assert!(resolvers.get(&"/api/v1/users/keys/first") < resolvers.get(&"/api/v1/users/keys/")); + assert!(resolvers.get(&"/api/v1/users/keys/") < resolvers.get(&"/api/v1/users/keys/**")); + + let mut resolvers_vec: Vec<_> = resolvers.into_iter().map(|(_, r)| r).collect(); + resolvers_vec.sort_unstable(); + + assert_eq!(&resolvers_vec[0].id(), ids.get(&"/api/v1/users/keys/first").unwrap()); + assert_eq!(&resolvers_vec[1].id(), ids.get(&"/api/v1/users/keys/**/delete").unwrap()); + assert_eq!(&resolvers_vec[2].id(), ids.get(&"/api/v1/users/keys/").unwrap()); + assert_eq!(&resolvers_vec[3].id(), ids.get(&"/api/v1/users/keys").unwrap()); + assert_eq!(&resolvers_vec[4].id(), ids.get(&"/api/v1/users/keys/**").unwrap()); + assert_eq!(&resolvers_vec[5].id(), ids.get(&"/api/v1/users//keys").unwrap()); + assert_eq!(&resolvers_vec[6].id(), ids.get(&"/api/v1/users/**/delete").unwrap()); + assert_eq!(&resolvers_vec[7].id(), ids.get(&"/api/v1/users//").unwrap()); + assert_eq!(&resolvers_vec[8].id(), ids.get(&"/api/v1/users/").unwrap()); + assert_eq!(&resolvers_vec[9].id(), ids.get(&"/api/v1/users").unwrap()); + } +} diff --git a/saphir_macro/Cargo.toml b/saphir_macro/Cargo.toml index 263dbc7..e19d923 100644 --- a/saphir_macro/Cargo.toml +++ b/saphir_macro/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "saphir_macro" -version = "2.0.10" +version = "2.1.0" authors = ["Richer Archambault "] edition = "2018" description = "Macro generation for http server framework" diff --git a/saphir_macro/src/controller/controller_attr.rs b/saphir_macro/src/controller/controller_attr.rs index 89b142e..2b4d406 100644 --- a/saphir_macro/src/controller/controller_attr.rs +++ b/saphir_macro/src/controller/controller_attr.rs @@ -111,9 +111,10 @@ fn gen_controller_handlers_fn(attr: &ControllerAttr, handlers: &[HandlerRepr]) - for (method, path) in methods_paths { let method = method.as_str(); + let handler_name = handler_ident.to_string(); if guards.is_empty() { (quote! { - .add(Method::from_str(#method).expect("Method was validated by the macro expansion"), #path, #ctrl_ident::#handler_ident) + .add_with_name(#handler_name, Method::from_str(#method).expect("Method was validated by the macro expansion"), #path, #ctrl_ident::#handler_ident) }) .to_tokens(&mut handler_stream); } else { @@ -127,7 +128,7 @@ fn gen_controller_handlers_fn(attr: &ControllerAttr, handlers: &[HandlerRepr]) - } (quote! { - .add_with_guards(Method::from_str(#method).expect("Method was validated the macro expansion"), #path, #ctrl_ident::#handler_ident, |g| { + .add_with_guards_and_name(#handler_name, Method::from_str(#method).expect("Method was validated the macro expansion"), #path, #ctrl_ident::#handler_ident, |g| { g #guard_stream }) })