Skip to content

Commit

Permalink
Merge pull request #18 from unbekanntes-pferd/features/0.9.0
Browse files Browse the repository at this point in the history
features/0.9.0
  • Loading branch information
unbekanntes-pferd authored Apr 2, 2024
2 parents 83424ae + e8d6cd6 commit bafc800
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dco3"
version = "0.8.0"
version = "0.9.0"
edition = "2021"
authors = ["Octavio Simone"]
repository = "https://github.com/unbekanntes-pferd/dco3"
Expand Down
75 changes: 53 additions & 22 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ impl DracoonClientBuilder {
None => APP_USER_AGENT.to_string(),
};

let http = Client::builder().user_agent(APP_USER_AGENT).build()?;
let http = Client::builder().user_agent(user_agent).build()?;
let upload_http = http.clone();

let http = ClientBuilder::new(http)
Expand Down Expand Up @@ -404,7 +404,9 @@ impl DracoonClient<Disconnected> {
additional_connections.push(new_connection);
}

self.additional_connections.set(additional_connections);
self.additional_connections
.set(additional_connections)
.await;
}

Ok(DracoonClient {
Expand Down Expand Up @@ -596,6 +598,7 @@ impl DracoonClient<Connected> {
let access_token = self
.connection
.get()
.await
.expect("Connected client has no connection")
.access_token
.clone();
Expand All @@ -622,6 +625,7 @@ impl DracoonClient<Connected> {
let refresh_token = self
.connection
.get()
.await
.expect("Connected client has no connection")
.refresh_token
.clone();
Expand Down Expand Up @@ -650,6 +654,7 @@ impl DracoonClient<Connected> {
let refresh_token = self
.connection
.get()
.await
.expect("Connected client has no connection")
.refresh_token
.clone();
Expand All @@ -665,15 +670,17 @@ impl DracoonClient<Connected> {
pub async fn get_auth_header(&self) -> Result<String, DracoonClientError> {
if let Some(token_rotation) = self.token_rotation {
// get the next connection in the rotation
let connection = match self.curr_connection.get() {
let connection = match self.curr_connection.get().await {
Some(CurrentConnection::Main) => self
.connection
.get()
.await
.expect("Connected client has no connection"),
Some(CurrentConnection::Additional(idx)) => {
let additional_connections = self
.additional_connections
.get()
.await
.expect("Connected client has no additional connections");

additional_connections
Expand All @@ -684,6 +691,7 @@ impl DracoonClient<Connected> {
None => self
.connection
.get()
.await
.expect("Connected client has no connection"),
};

Expand All @@ -692,26 +700,29 @@ impl DracoonClient<Connected> {
let new_connection = self.connect_refresh_token().await?;
let access_token = new_connection.access_token.clone();

match self.curr_connection.get() {
Some(CurrentConnection::Main) => self.connection.set(new_connection),
match self.curr_connection.get().await {
Some(CurrentConnection::Main) => self.connection.set(new_connection).await,
Some(CurrentConnection::Additional(idx)) => {
let mut additional_connections = self
.additional_connections
.get()
.await
.expect("Connected client has no additional connections");

additional_connections[idx as usize] = new_connection;
self.additional_connections.set(additional_connections);
self.additional_connections
.set(additional_connections)
.await;
}
None => self.connection.set(new_connection),
None => self.connection.set(new_connection).await,
}

// no need to rotate, there's a new access token
return Ok(format!("Bearer {}", access_token));
}

// rotate the connection
let next_connection = match self.curr_connection.get() {
let next_connection = match self.curr_connection.get().await {
Some(CurrentConnection::Main) => CurrentConnection::Additional(0),
Some(CurrentConnection::Additional(idx)) => {
if idx + 1 < token_rotation - 1 {
Expand All @@ -723,37 +734,40 @@ impl DracoonClient<Connected> {
None => CurrentConnection::Main,
};

self.curr_connection.set(next_connection);
self.curr_connection.set(next_connection).await;

return Ok(format!("Bearer {}", connection.access_token));
}

if self.is_connection_expired() {
if self.is_connection_expired().await {
let new_connection = self.connect_refresh_token().await?;
self.connection.set(new_connection);
self.connection.set(new_connection).await;
}

Ok(format!(
"Bearer {}",
self.connection
.get()
.await
.expect("Connected client has no connection")
.access_token
))
}

/// Returns the refresh token
pub fn get_refresh_token(&self) -> String {
pub async fn get_refresh_token(&self) -> String {
self.connection
.get()
.await
.expect("Connected client has no connection")
.refresh_token()
}

/// Checks if the access token is still valid
fn is_connection_expired(&self) -> bool {
async fn is_connection_expired(&self) -> bool {
self.connection
.get()
.await
.expect("Connected client has no connection")
.is_expired()
}
Expand Down Expand Up @@ -858,7 +872,7 @@ mod tests {
auth_mock.assert();
assert!(&res.is_ok());

assert!(res.unwrap().connection.is_some());
assert!(res.unwrap().connection.is_some().await);
}

#[tokio::test]
Expand All @@ -884,13 +898,14 @@ mod tests {
auth_mock.assert();
assert!(&res.is_ok());

assert!(res.as_ref().unwrap().connection.is_some());
assert!(res.as_ref().unwrap().connection.is_some().await);

let access_token = res
.as_ref()
.unwrap()
.connection
.get()
.await
.unwrap()
.access_token();

Expand All @@ -899,10 +914,18 @@ mod tests {
.unwrap()
.connection
.get()
.await
.unwrap()
.refresh_token();

let expires_in = res.as_ref().unwrap().connection.get().unwrap().expires_in();
let expires_in = res
.as_ref()
.unwrap()
.connection
.get()
.await
.unwrap()
.expires_in();

assert_eq!(access_token, "access_token");
assert_eq!(refresh_token, "refresh_token");
Expand Down Expand Up @@ -1015,15 +1038,16 @@ mod tests {
auth_mock_5.assert();

assert_eq!(dracoon.token_rotation, Some(5));
assert_eq!(dracoon.additional_connections.get().unwrap().len(), 4);
assert_eq!(dracoon.additional_connections.get().await.unwrap().len(), 4);
assert_eq!(
dracoon.curr_connection.get().unwrap(),
dracoon.curr_connection.get().await.unwrap(),
CurrentConnection::Main
);
assert_eq!(
dracoon
.additional_connections
.get()
.await
.unwrap()
.first()
.unwrap()
Expand All @@ -1034,6 +1058,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(1)
.unwrap()
Expand All @@ -1044,6 +1069,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(2)
.unwrap()
Expand All @@ -1054,6 +1080,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(3)
.unwrap()
Expand Down Expand Up @@ -1087,15 +1114,16 @@ mod tests {
auth_mock.assert();

assert_eq!(dracoon.token_rotation, Some(5));
assert_eq!(dracoon.additional_connections.get().unwrap().len(), 4);
assert_eq!(dracoon.additional_connections.get().await.unwrap().len(), 4);
assert_eq!(
dracoon.curr_connection.get().unwrap(),
dracoon.curr_connection.get().await.unwrap(),
CurrentConnection::Main
);
assert_eq!(
dracoon
.additional_connections
.get()
.await
.unwrap()
.first()
.unwrap()
Expand All @@ -1106,6 +1134,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(1)
.unwrap()
Expand All @@ -1116,6 +1145,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(2)
.unwrap()
Expand All @@ -1126,6 +1156,7 @@ mod tests {
dracoon
.additional_connections
.get()
.await
.unwrap()
.get(3)
.unwrap()
Expand Down Expand Up @@ -1157,7 +1188,7 @@ mod tests {
auth_mock.assert();

assert_eq!(dracoon.token_rotation, None);
assert!(dracoon.additional_connections.get().is_none());
assert!(dracoon.additional_connections.is_none().await);
}

#[tokio::test]
Expand Down Expand Up @@ -1217,7 +1248,7 @@ mod tests {
.await
.unwrap();

let refresh_token = dracoon.get_refresh_token();
let refresh_token = dracoon.get_refresh_token().await;

auth_mock.assert();
assert_eq!(refresh_token, "refresh_token");
Expand Down
Loading

0 comments on commit bafc800

Please sign in to comment.