Skip to content

Commit f66734d

Browse files
feat: integrate authentication middleware with core AuthenticationClient
This commit integrates the authentication middleware implementation with the AuthenticationClient that was moved to the core keylime library, completing the authentication integration. Signed-off-by: Sergio Correia <[email protected]>
1 parent b50d621 commit f66734d

File tree

1 file changed

+85
-42
lines changed

1 file changed

+85
-42
lines changed

keylime/src/resilient_client.rs

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::auth::{AuthConfig, SessionToken};
1+
use crate::auth::{AuthConfig, AuthenticationClient, SessionToken};
22
use anyhow;
33
use async_trait::async_trait;
44
use chrono::Utc;
@@ -197,17 +197,23 @@ struct TokenState {
197197
token: RwLock<Option<SessionToken>>,
198198
/// Mutex for refresh operations - ensures single writer
199199
refresh_lock: Mutex<()>,
200-
/// Authentication configuration
201-
auth_config: AuthConfig,
200+
/// Raw authentication client (no middleware to avoid loops)
201+
auth_client: AuthenticationClient,
202202
}
203203

204204
impl TokenState {
205-
fn new(auth_config: AuthConfig) -> Self {
206-
Self {
205+
fn new(
206+
auth_config: AuthConfig,
207+
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
208+
// Create a raw authentication client to avoid middleware loops
209+
let auth_client = AuthenticationClient::new_raw(auth_config)
210+
.map_err(|e| format!("Failed to create auth client: {}", e))?;
211+
212+
Ok(Self {
207213
token: RwLock::new(None),
208214
refresh_lock: Mutex::new(()),
209-
auth_config,
210-
}
215+
auth_client,
216+
})
211217
}
212218

213219
async fn get_valid_token(
@@ -217,9 +223,9 @@ impl TokenState {
217223
{
218224
let token_guard = self.token.read().await;
219225
if let Some(ref token) = *token_guard {
220-
if token
221-
.is_valid(self.auth_config.token_refresh_buffer_minutes)
222-
{
226+
if token.is_valid(
227+
self.auth_client.config().token_refresh_buffer_minutes,
228+
) {
223229
debug!("Using existing valid token from middleware");
224230
return Ok(token.token.clone());
225231
}
@@ -240,19 +246,39 @@ impl TokenState {
240246
{
241247
let token_guard = self.token.read().await;
242248
if let Some(ref token) = *token_guard {
243-
if token
244-
.is_valid(self.auth_config.token_refresh_buffer_minutes)
245-
{
249+
if token.is_valid(
250+
self.auth_client.config().token_refresh_buffer_minutes,
251+
) {
246252
debug!("Token was refreshed by another request");
247253
return Ok(token.token.clone());
248254
}
249255
}
250256
}
251257

252-
// TODO: Next, we'll integrate with the actual AuthenticationClient
253-
// For now, this is a placeholder that will be replaced
254-
warn!("Token refresh not yet implemented");
255-
Err("Authentication not yet integrated".into())
258+
// Use the raw authentication client to get a new token with metadata
259+
debug!("Performing token refresh using raw authentication client");
260+
match self.auth_client.get_auth_token_with_metadata().await {
261+
Ok((token_string, expires_at, session_id)) => {
262+
let new_token = SessionToken {
263+
token: token_string.clone(),
264+
expires_at,
265+
session_id,
266+
};
267+
268+
// Store the new token
269+
{
270+
let mut token_guard = self.token.write().await;
271+
*token_guard = Some(new_token);
272+
}
273+
274+
debug!("Token refresh completed successfully");
275+
Ok(token_string)
276+
}
277+
Err(e) => {
278+
warn!("Token refresh failed: {}", e);
279+
Err(format!("Authentication failed: {}", e).into())
280+
}
281+
}
256282
}
257283

258284
async fn clear_token(&self) {
@@ -269,9 +295,11 @@ pub struct AuthenticationMiddleware {
269295
}
270296

271297
impl AuthenticationMiddleware {
272-
pub fn new(auth_config: AuthConfig) -> Self {
273-
let token_state = Arc::new(TokenState::new(auth_config));
274-
Self { token_state }
298+
pub fn new(
299+
auth_config: AuthConfig,
300+
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
301+
let token_state = Arc::new(TokenState::new(auth_config)?);
302+
Ok(Self { token_state })
275303
}
276304

277305
fn is_auth_endpoint(&self, req: &reqwest::Request) -> bool {
@@ -384,7 +412,7 @@ impl ResilientClient {
384412
max_retries: u32,
385413
success_codes: &[StatusCode],
386414
max_delay: Option<std::time::Duration>,
387-
) -> Self {
415+
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
388416
let base_client = client.unwrap_or_default();
389417
let final_max_delay = max_delay.unwrap_or(DEFAULT_MAX_DELAY);
390418

@@ -405,15 +433,15 @@ impl ResilientClient {
405433
// Add authentication middleware if config is provided
406434
if let Some(auth_cfg) = auth_config {
407435
debug!("Adding authentication middleware to client");
408-
let auth_middleware = AuthenticationMiddleware::new(auth_cfg);
436+
let auth_middleware = AuthenticationMiddleware::new(auth_cfg)?;
409437
builder = builder.with(auth_middleware);
410438
}
411439

412440
let client_with_middleware = builder.with(LoggingMiddleware).build();
413441

414-
Self {
442+
Ok(Self {
415443
client: client_with_middleware,
416-
}
444+
})
417445
}
418446

419447
/// Generates a six-character lowercase alphanumeric request ID.
@@ -1084,7 +1112,8 @@ mod tests {
10841112
3,
10851113
&[StatusCode::OK],
10861114
None,
1087-
);
1115+
)
1116+
.unwrap(); //#[allow_ci]
10881117

10891118
// Verify the client was created successfully
10901119
// (We can't easily test the middleware behavior without a mock server,
@@ -1101,7 +1130,8 @@ mod tests {
11011130
3,
11021131
&[StatusCode::OK],
11031132
None,
1104-
);
1133+
)
1134+
.unwrap(); //#[allow_ci]
11051135

11061136
// Verify the client was created successfully
11071137
}
@@ -1119,7 +1149,7 @@ mod tests {
11191149
max_auth_retries: 3,
11201150
};
11211151

1122-
let middleware = AuthenticationMiddleware::new(auth_config);
1152+
let middleware = AuthenticationMiddleware::new(auth_config).unwrap(); //#[allow_ci]
11231153

11241154
// Mock a request to a sessions endpoint (should be detected as auth endpoint)
11251155
let mock_request = reqwest::Request::new(
@@ -1157,15 +1187,21 @@ mod tests {
11571187
max_auth_retries: 3,
11581188
};
11591189

1160-
let token_state = TokenState::new(auth_config);
1190+
let token_state = TokenState::new(auth_config).unwrap(); //#[allow_ci]
11611191

1162-
// Test initially no token
1192+
// Test initially no token - should trigger authentication
11631193
let result = token_state.get_valid_token().await;
1164-
assert!(result.is_err());
1165-
assert!(result
1166-
.unwrap_err() //#[allow_ci]
1167-
.to_string()
1168-
.contains("Authentication not yet integrated"));
1194+
assert!(
1195+
result.is_err(),
1196+
"Should fail when no auth server available"
1197+
);
1198+
// Since we're using a real auth client, we expect authentication-related errors
1199+
let error_msg = result.unwrap_err().to_string(); //#[allow_ci]
1200+
assert!(
1201+
error_msg.contains("Authentication failed"),
1202+
"Error: {}",
1203+
error_msg
1204+
);
11691205

11701206
// Test clear token when no token exists (should not panic)
11711207
token_state.clear_token().await;
@@ -1206,7 +1242,7 @@ mod tests {
12061242
max_auth_retries: 3,
12071243
};
12081244

1209-
let token_state = TokenState::new(auth_config);
1245+
let token_state = TokenState::new(auth_config).unwrap(); //#[allow_ci]
12101246

12111247
// Insert token that expires within buffer time (should be considered invalid)
12121248
{
@@ -1220,11 +1256,17 @@ mod tests {
12201256

12211257
// Should try to refresh because token is within buffer time
12221258
let result = token_state.get_valid_token().await;
1223-
assert!(result.is_err());
1224-
assert!(result
1225-
.unwrap_err() //#[allow_ci]
1226-
.to_string()
1227-
.contains("Authentication not yet integrated"));
1259+
assert!(
1260+
result.is_err(),
1261+
"Should fail due to token expiring within buffer"
1262+
);
1263+
// Since we're using a real auth client, we expect authentication-related errors
1264+
let error_msg = result.unwrap_err().to_string(); //#[allow_ci]
1265+
assert!(
1266+
error_msg.contains("Authentication failed"),
1267+
"Error: {}",
1268+
error_msg
1269+
);
12281270
}
12291271

12301272
#[tokio::test]
@@ -1238,7 +1280,8 @@ mod tests {
12381280
max_auth_retries: 3,
12391281
};
12401282

1241-
let middleware = AuthenticationMiddleware::new(auth_config);
1283+
let middleware =
1284+
AuthenticationMiddleware::new(auth_config).unwrap(); //#[allow_ci]
12421285

12431286
// Test different auth endpoint patterns
12441287
let test_cases = vec![
@@ -1276,7 +1319,7 @@ mod tests {
12761319
max_auth_retries: 3,
12771320
};
12781321

1279-
let token_state = Arc::new(TokenState::new(auth_config));
1322+
let token_state = Arc::new(TokenState::new(auth_config).unwrap()); //#[allow_ci]
12801323

12811324
// Test concurrent access to token state (should not deadlock)
12821325
let mut handles = vec![];

0 commit comments

Comments
 (0)