diff --git a/Cargo.toml b/Cargo.toml index 732558c..1dae594 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ description = "Async API wrapper for DRACOON in Rust." [dependencies] # HTTP client reqwest = {version = "0.11.14", features = ["json", "stream"]} +reqwest-middleware = "0.2.2" +reqwest-retry = "0.2.2" # crypto dco3_crypto = "0.5.0" @@ -35,6 +37,7 @@ serde_json = "1.0.95" # error handling thiserror = "1.0.2" +retry-policies = "0.1.0" # logging and tracing tracing = "0.1.37" diff --git a/src/auth/errors.rs b/src/auth/errors.rs index 46454c8..777426b 100644 --- a/src/auth/errors.rs +++ b/src/auth/errors.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use dco3_crypto::DracoonCryptoError; -use reqwest::{Error as ReqError, Response}; +use reqwest_middleware::{Error as ReqError}; +use reqwest::{Error as ClientError, Response}; use thiserror::Error; use crate::{nodes::models::S3ErrorResponse, utils::FromResponse}; @@ -43,14 +44,47 @@ pub enum DracoonClientError { impl From for DracoonClientError { fn from(value: ReqError) -> Self { - if value.is_builder() { - return DracoonClientError::Internal; + + match value { + ReqError::Middleware(error) => { + DracoonClientError::ConnectionFailed + + }, + ReqError::Reqwest(error) => { + if error.is_timeout() { + return DracoonClientError::ConnectionFailed + } + + if error.is_connect() { + return DracoonClientError::ConnectionFailed + } + + + DracoonClientError::Unknown + + }, + } + } +} + + +impl From for DracoonClientError { + fn from(value: ClientError) -> Self { + + if value.is_timeout() { + return DracoonClientError::ConnectionFailed; + } + + if value.is_connect() { + return DracoonClientError::ConnectionFailed; } DracoonClientError::Unknown } } + + #[async_trait] impl FromResponse for DracoonClientError { async fn from_response(value: Response) -> Result { diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 321e835..ffe7973 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,8 +1,10 @@ -//! This module is responsible for the authentication with DRACOON and implements +//! This module is responsible for the authentication with DRACOON and implements //! the [DracoonClient] struct which is used to interact with the DRACOON API. use chrono::{DateTime, Utc}; use reqwest::{Client, Url}; -use std::marker::PhantomData; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; +use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use std::{marker::PhantomData, time::Duration}; use tracing::{debug, error}; use base64::{ @@ -20,7 +22,10 @@ use crate::{ auth::models::{ OAuth2AuthCodeFlow, OAuth2PasswordFlow, OAuth2TokenResponse, OAuth2TokenRevoke, }, - constants::{DRACOON_TOKEN_REVOKE_URL, DRACOON_TOKEN_URL, TOKEN_TYPE_HINT_ACCESS_TOKEN}, + constants::{ + DRACOON_TOKEN_REVOKE_URL, DRACOON_TOKEN_URL, EXPONENTIAL_BACKOFF_BASE, MAX_RETRIES, + MAX_RETRY_DELAY, MIN_RETRY_DELAY, TOKEN_TYPE_HINT_ACCESS_TOKEN, + }, }; use self::{errors::DracoonClientError, models::OAuth2RefreshTokenFlow}; @@ -56,7 +61,7 @@ pub struct DracoonClient { redirect_uri: Option, client_id: String, client_secret: String, - pub http: Client, + pub http: ClientWithMiddleware, connection: Option, connected: PhantomData, } @@ -68,6 +73,10 @@ pub struct DracoonClientBuilder { redirect_uri: Option, client_id: Option, client_secret: Option, + user_agent: Option, + max_retries: Option, + min_retry_delay: Option, + max_retry_delay: Option, } impl DracoonClientBuilder { @@ -78,6 +87,10 @@ impl DracoonClientBuilder { redirect_uri: None, client_id: None, client_secret: None, + user_agent: None, + max_retries: None, + min_retry_delay: None, + max_retry_delay: None, } } @@ -105,10 +118,60 @@ impl DracoonClientBuilder { self } + pub fn with_user_agent(mut self, user_agent: impl Into) -> Self { + self.user_agent = Some(user_agent.into()); + self + } + + pub fn with_max_retries(mut self, max_retries: u32) -> Self { + self.max_retries = Some(max_retries); + self + } + + pub fn with_min_retry_delay(mut self, min_retry_delay: u64) -> Self { + self.min_retry_delay = Some(min_retry_delay); + self + } + + pub fn with_max_retry_delay(mut self, max_retry_delay: u64) -> Self { + self.max_retry_delay = Some(max_retry_delay); + self + } + /// Builds the [DracoonClient] struct - returns an error if any of the required fields are missing pub fn build(self) -> Result, DracoonClientError> { + let max_retries = self + .max_retries + .unwrap_or(MAX_RETRIES) + .clamp(1, MAX_RETRIES); + let min_retry_delay = self + .min_retry_delay + .unwrap_or(MIN_RETRY_DELAY) + .clamp(300, MIN_RETRY_DELAY); + let max_retry_delay = self + .max_retry_delay + .unwrap_or(MAX_RETRY_DELAY) + .clamp(min_retry_delay, MAX_RETRY_DELAY); + + let retry_policy: ExponentialBackoff = ExponentialBackoff::builder() + .backoff_exponent(EXPONENTIAL_BACKOFF_BASE) + .retry_bounds( + Duration::from_millis(min_retry_delay), + Duration::from_millis(max_retry_delay), + ) + .build_with_max_retries(max_retries); + + let user_agent = match self.user_agent { + Some(user_agent) => format!("{}|{}", user_agent, APP_USER_AGENT), + None => APP_USER_AGENT.to_string(), + }; + let http = Client::builder().user_agent(APP_USER_AGENT).build()?; + let http = ClientBuilder::new(http) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + let Some(base_url) = self.base_url.clone() else { error!("Missing base url"); return Err(DracoonClientError::MissingBaseUrl) @@ -241,7 +304,8 @@ impl DracoonClient { .header("Authorization", auth_header) .form(&auth) .send() - .await.map_err(|err| { + .await + .map_err(|err| { error!("Error connecting with password flow: {}", err); err })?; @@ -262,10 +326,16 @@ impl DracoonClient { .as_str(), ); - let res = self.http.post(token_url).form(&auth).send().await.map_err(|err| { - error!("Error connecting with auth code flow: {}", err); - err - })?; + let res = self + .http + .post(token_url) + .form(&auth) + .send() + .await + .map_err(|err| { + error!("Error connecting with auth code flow: {}", err); + err + })?; Ok(OAuth2TokenResponse::from_response(res).await?.into()) } @@ -278,10 +348,16 @@ impl DracoonClient { let auth = OAuth2RefreshTokenFlow::new(&self.client_id, &self.client_secret, refresh_token); - let res = self.http.post(token_url).form(&auth).send().await.map_err(|err| { - error!("Error connecting with refresh token flow: {}", err); - err - })?; + let res = self + .http + .post(token_url) + .form(&auth) + .send() + .await + .map_err(|err| { + error!("Error connecting with refresh token flow: {}", err); + err + })?; Ok(OAuth2TokenResponse::from_response(res).await?.into()) } } @@ -318,7 +394,7 @@ impl DracoonClient { http: self.http, }) } - + /// Returns the base url of the DRACOON instance pub fn get_base_url(&self) -> &Url { &self.base_url @@ -450,12 +526,16 @@ mod tests { .with_base_url(url) .with_client_id("client_id") .with_client_secret("client_secret") + .with_user_agent("test_client") + .with_max_retries(1) + .with_max_retry_delay(600) + .with_min_retry_delay(300) .build() .expect("valid client config") } - #[test] - fn test_auth_code_authentication() { + #[tokio::test] + async fn test_auth_code_authentication() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -477,7 +557,7 @@ mod tests { let auth_code = OAuth2Flow::AuthCodeFlow("hello world".to_string()); - let res = tokio_test::block_on(dracoon.connect(auth_code)); + let res = dracoon.connect(auth_code).await; auth_mock.assert(); assert_ok!(&res); @@ -485,8 +565,8 @@ mod tests { assert!(res.unwrap().connection.is_some()); } - #[test] - fn test_refresh_token_authentication() { + #[tokio::test] + async fn test_refresh_token_authentication() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -503,7 +583,7 @@ mod tests { let refresh_token_auth = OAuth2Flow::RefreshToken("hello world".to_string()); - let res = tokio_test::block_on(dracoon.connect(refresh_token_auth)); + let res = dracoon.connect(refresh_token_auth).await; auth_mock.assert(); assert_ok!(&res); @@ -533,8 +613,8 @@ mod tests { assert_eq!(expires_in, 3600); } - #[test] - fn test_auth_error_handling() { + #[tokio::test] + async fn test_auth_error_handling() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -551,15 +631,15 @@ mod tests { let auth_code = OAuth2Flow::AuthCodeFlow("hello world".to_string()); - let res = tokio_test::block_on(dracoon.connect(auth_code)); + let res = dracoon.connect(auth_code).await; auth_mock.assert(); assert!(res.is_err()); } - #[test] - fn test_get_auth_header() { + #[tokio::test] + async fn test_get_auth_header() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -575,17 +655,17 @@ mod tests { let dracoon = get_test_client(base_url.as_str()); let refresh_token_auth = OAuth2Flow::RefreshToken("hello world".to_string()); - let res = tokio_test::block_on(dracoon.connect(refresh_token_auth)); + let res = dracoon.connect(refresh_token_auth).await; let connected_client = res.unwrap(); - let access_token = tokio_test::block_on(connected_client.get_auth_header()).unwrap(); + let access_token = connected_client.get_auth_header().await.unwrap(); auth_mock.assert(); assert_eq!(access_token, "Bearer access_token"); } - #[test] - fn test_get_token_url() { + #[tokio::test] + async fn test_get_token_url() { let base_url = "https://dracoon.team"; let dracoon = get_test_client(base_url); @@ -595,8 +675,8 @@ mod tests { assert_eq!(token_url.as_str(), "https://dracoon.team/oauth/token"); } - #[test] - fn test_get_base_url() { + #[tokio::test] + async fn test_get_base_url() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -610,10 +690,10 @@ mod tests { .create(); let dracoon = get_test_client(&base_url); - let dracoon = tokio_test::block_on( - dracoon.connect(OAuth2Flow::AuthCodeFlow("hello world".to_string())), - ) - .unwrap(); + let dracoon = dracoon + .connect(OAuth2Flow::AuthCodeFlow("hello world".to_string())) + .await + .unwrap(); let base_url = dracoon.get_base_url(); @@ -621,8 +701,8 @@ mod tests { assert_eq!(base_url.as_str(), format!("{}/", mock_server.url())); } - #[test] - fn test_get_refresh_token() { + #[tokio::test] + async fn test_get_refresh_token() { let mut mock_server = mockito::Server::new(); let base_url = mock_server.url(); @@ -636,14 +716,39 @@ mod tests { .create(); let dracoon = get_test_client(&base_url); - let dracoon = tokio_test::block_on( - dracoon.connect(OAuth2Flow::AuthCodeFlow("hello world".to_string())), - ) - .unwrap(); + let dracoon = dracoon + .connect(OAuth2Flow::AuthCodeFlow("hello world".to_string())) + .await + .unwrap(); let refresh_token = dracoon.get_refresh_token(); auth_mock.assert(); assert_eq!(refresh_token, "refresh_token"); } + + #[tokio::test] + async fn test_retry_policy() { + let mut mock_server = mockito::Server::new(); + let base_url = mock_server.url(); + + let auth_mock = mock_server + .mock("POST", "/oauth/token") + .with_status(429) + .with_header("content-type", "application/json") + .with_body("Internal Server Error") + .create(); + + let dracoon = get_test_client(&base_url); + let dracoon = dracoon + .connect(OAuth2Flow::AuthCodeFlow("hello world".to_string())) + .await; + + // client retry set to 1 retry for testing + let req_count = 2; + let req_count: usize = req_count.try_into().unwrap(); + + auth_mock.expect_at_least(req_count); + assert!(dracoon.is_err()); + } } diff --git a/src/constants.rs b/src/constants.rs index 8b19f0f..b67924f 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -71,5 +71,10 @@ pub const USERS_BASE: &str = "users"; pub const USERS_LAST_ADMIN_ROOMS: &str = "last_admin_rooms"; /// user agent header -pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); +pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "|", env!("CARGO_PKG_VERSION")); +// retry config +pub const MAX_RETRIES: u32 = 5; +pub const EXPONENTIAL_BACKOFF_BASE: u32 = 3; +pub const MIN_RETRY_DELAY: u64 = 600; // in milliseconds (0.6 seconds) +pub const MAX_RETRY_DELAY: u64 = 20 * 1000; // in milliseconds (20 seconds) diff --git a/src/groups/mod.rs b/src/groups/mod.rs index 4f46fe7..6dc71d5 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; +#[allow(clippy::module_inception)] mod groups; mod models; diff --git a/src/lib.rs b/src/lib.rs index 2d28ec2..2cf0148 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ //! //! ### Example //! ```no_run -//! use dco3::{Dracoon, auth::OAuth2Flow, user::User}; +//! use dco3::{Dracoon, OAuth2Flow, User}; //! //! #[tokio::main] //! async fn main() { @@ -55,7 +55,7 @@ //! //! ### Password Flow //! ```no_run -//! use dco3::{Dracoon, auth::OAuth2Flow}; +//! use dco3::{Dracoon, OAuth2Flow}; //! //! #[tokio::main] //! async fn main() { @@ -76,7 +76,7 @@ //!``` //! ### Authorization Code Flow //! ```no_run -//! use dco3::{Dracoon, auth::OAuth2Flow}; +//! use dco3::{Dracoon, OAuth2Flow}; //! //! #[tokio::main] //! async fn main() { @@ -105,7 +105,7 @@ //! ### Refresh Token //! //! ```no_run -//! use dco3::{Dracoon, auth::OAuth2Flow}; +//! use dco3::{Dracoon, OAuth2Flow}; //! //! #[tokio::main] //! async fn main() { @@ -137,7 +137,7 @@ //! You can check if the underlying error message if a specific API error by using the `is_*` methods. //! //! ```no_run -//! use dco3::{Dracoon, auth::OAuth2Flow, Nodes}; +//! use dco3::{Dracoon, OAuth2Flow, Nodes}; //! //! #[tokio::main] //! @@ -169,6 +169,33 @@ //! //! ``` //! +//! ### Retries +//! The client will automatically retry failed requests. +//! You can configure the retry behavior by passing your config during client creation. +//! +//! Default values are: 5 retries, min delay 600ms, max delay 20s. +//! Keep in mind that you cannot set arbitrary values - for all values, minimum and maximum values are defined. +//! +//! ``` +//! +//! use dco3::{Dracoon, OAuth2Flow}; +//! +//! #[tokio::main] +//! async fn main() { +//! +//! let dracoon = Dracoon::builder() +//! .with_base_url("https://dracoon.team") +//! .with_client_id("client_id") +//! .with_client_secret("client_secret") +//! .with_max_retries(3) +//! .with_min_retry_delay(400) +//! .with_max_retry_delay(1000) +//! .build(); +//! +//! } +//! +//! ``` +//! //! ## Building requests //! //! All API calls are implemented as traits. @@ -176,7 +203,7 @@ //! To access the builder, you can call the builder() method. //! //! ```no_run -//! # use dco3::{Dracoon, auth::OAuth2Flow, Rooms, nodes::CreateRoomRequest}; +//! # use dco3::{Dracoon, OAuth2Flow, Rooms, nodes::CreateRoomRequest}; //! # #[tokio::main] //! # async fn main() { //! # let dracoon = Dracoon::builder() @@ -199,7 +226,7 @@ //! ``` //! Some requests do not have any complicated fields - in these cases, use the `new()` method. //! ```no_run -//! # use dco3::{Dracoon, auth::OAuth2Flow, Groups, groups::CreateGroupRequest}; +//! # use dco3::{Dracoon, OAuth2Flow, Groups, groups::CreateGroupRequest}; //! # #[tokio::main] //! # async fn main() { //! # let dracoon = Dracoon::builder() @@ -364,6 +391,26 @@ impl DracoonBuilder { self } + pub fn with_user_agent(mut self, user_agent: impl Into) -> Self { + self.client_builder = self.client_builder.with_user_agent(user_agent); + self + } + + pub fn with_max_retries(mut self, max_retries: u32) -> Self { + self.client_builder = self.client_builder.with_max_retries(max_retries); + self + } + + pub fn with_min_retry_delay(mut self, min_retry_delay: u64) -> Self { + self.client_builder = self.client_builder.with_min_retry_delay(min_retry_delay); + self + } + + pub fn with_max_retry_delay(mut self, max_retry_delay: u64) -> Self { + self.client_builder = self.client_builder.with_max_retry_delay(max_retry_delay); + self + } + /// Builds the `Dracoon` struct - fails, if any of the required fields are missing pub fn build(self) -> Result, DracoonClientError> { let dracoon = self.client_builder.build()?; diff --git a/src/nodes/models/filters.rs b/src/nodes/models/filters.rs index 93cd5d8..3fe2ca2 100644 --- a/src/nodes/models/filters.rs +++ b/src/nodes/models/filters.rs @@ -53,11 +53,11 @@ impl FilterQuery for NodesFilter { impl NodesFilter { pub fn name_equals(val: impl Into) -> Self { - NodesFilter::Name(FilterOperator::Eq, String::from(val.into())) + NodesFilter::Name(FilterOperator::Eq, val.into()) } pub fn name_contains(val: impl Into) -> Self { - NodesFilter::Name(FilterOperator::Cn, String::from(val.into())) + NodesFilter::Name(FilterOperator::Cn, val.into()) } pub fn is_encrypted(val: bool) -> Self { @@ -69,19 +69,19 @@ impl NodesFilter { } pub fn created_before(val: impl Into) -> Self { - NodesFilter::TimestampCreation(FilterOperator::Le, String::from(val.into())) + NodesFilter::TimestampCreation(FilterOperator::Le, val.into()) } pub fn created_after(val: impl Into) -> Self { - NodesFilter::TimestampCreation(FilterOperator::Ge, String::from(val.into())) + NodesFilter::TimestampCreation(FilterOperator::Ge, val.into()) } pub fn modified_before(val: impl Into) -> Self { - NodesFilter::TimestampModification(FilterOperator::Le, String::from(val.into())) + NodesFilter::TimestampModification(FilterOperator::Le, val.into()) } pub fn modified_after(val: impl Into) -> Self { - NodesFilter::TimestampModification(FilterOperator::Ge, String::from(val.into())) + NodesFilter::TimestampModification(FilterOperator::Ge, val.into()) } pub fn branch_version_before(val: u64) -> Self { @@ -152,11 +152,11 @@ impl NodesSearchFilter { } pub fn parent_path_equals(val: impl Into) -> Self { - NodesSearchFilter::ParentPath(FilterOperator::Eq, String::from(val.into())) + NodesSearchFilter::ParentPath(FilterOperator::Eq, val.into()) } pub fn parent_path_contains(val: impl Into) -> Self { - NodesSearchFilter::ParentPath(FilterOperator::Cn, String::from(val.into())) + NodesSearchFilter::ParentPath(FilterOperator::Cn, val.into()) } pub fn size_greater_equals(val: u64) -> Self { @@ -176,27 +176,27 @@ impl NodesSearchFilter { } pub fn created_at_before(val: impl Into) -> Self { - NodesSearchFilter::CreatedAt(FilterOperator::Le, String::from(val.into())) + NodesSearchFilter::CreatedAt(FilterOperator::Le, val.into()) } pub fn created_at_after(val: impl Into) -> Self { - NodesSearchFilter::CreatedAt(FilterOperator::Ge, String::from(val.into())) + NodesSearchFilter::CreatedAt(FilterOperator::Ge, val.into()) } pub fn updated_at_before(val: impl Into) -> Self { - NodesSearchFilter::UpdatedAt(FilterOperator::Le, String::from(val.into())) + NodesSearchFilter::UpdatedAt(FilterOperator::Le, val.into()) } pub fn updated_at_after(val: impl Into) -> Self { - NodesSearchFilter::UpdatedAt(FilterOperator::Ge, String::from(val.into())) + NodesSearchFilter::UpdatedAt(FilterOperator::Ge, val.into()) } pub fn expire_at_before(val: impl Into) -> Self { - NodesSearchFilter::ExpireAt(FilterOperator::Le, String::from(val.into())) + NodesSearchFilter::ExpireAt(FilterOperator::Le, val.into()) } pub fn expire_at_after(val: impl Into) -> Self { - NodesSearchFilter::ExpireAt(FilterOperator::Ge, String::from(val.into())) + NodesSearchFilter::ExpireAt(FilterOperator::Ge, val.into()) } pub fn classification_equals(val: u8) -> Self { @@ -204,11 +204,11 @@ impl NodesSearchFilter { } pub fn file_type_equals(val: impl Into) -> Self { - NodesSearchFilter::FileType(FilterOperator::Eq, String::from(val.into())) + NodesSearchFilter::FileType(FilterOperator::Eq, val.into()) } pub fn file_type_contains(val: impl Into) -> Self { - NodesSearchFilter::FileType(FilterOperator::Cn, String::from(val.into())) + NodesSearchFilter::FileType(FilterOperator::Cn, val.into()) } } diff --git a/src/users/mod.rs b/src/users/mod.rs index d9462c5..a6e8d5e 100644 --- a/src/users/mod.rs +++ b/src/users/mod.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; mod models; +#[allow(clippy::module_inception)] mod users; pub use models::*; diff --git a/src/users/models.rs b/src/users/models.rs index 3069152..cb82f1c 100644 --- a/src/users/models.rs +++ b/src/users/models.rs @@ -217,6 +217,7 @@ impl UpdateUserRequest { } } +#[derive(Default)] pub struct UpdateUserRequestBuilder { first_name: Option, last_name: Option, @@ -379,7 +380,7 @@ impl UserAuthData { login: None, ad_config_id: None, oid_config_id: None, - password: password.map(|p| p.into()), + password, must_change_password: Some(must_change_password), } } @@ -388,7 +389,7 @@ impl UserAuthData { let login: String = login.into(); Self { method: AuthMethod::OpenIdConnect{ login: login.clone(), oid_config_id }.into(), - login: Some(login.into()), + login: Some(login), ad_config_id: None, oid_config_id: Some(oid_config_id), password: None, @@ -400,7 +401,7 @@ impl UserAuthData { let login: String = login.into(); Self { method: AuthMethod::ActiveDirectory{ login: login.clone(), ad_config_id }.into(), - login: Some(login.into()), + login: Some(login), ad_config_id: Some(ad_config_id), oid_config_id: None, password: None,