diff --git a/Cargo.lock b/Cargo.lock index 959112f82..70a16b5e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2037,6 +2037,7 @@ dependencies = [ "rand", "reqwest", "ring", + "rstest", "serde", "serde_json", "serde_test", diff --git a/Cargo.toml b/Cargo.toml index a2b7cbbb7..f7bfe1e76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,8 @@ opentelemetry = { version = "0.19", features = ["metrics"] } prio = { version = "0.12.2", features = ["multithreaded"] } serde = { version = "1.0.175", features = ["derive"] } rstest = "0.17.0" +thiserror = "1.0" +tokio = { version = "1.29", features = ["full", "tracing"] } trillium = "0.2.9" trillium-api = { version = "0.2.0-rc.3", default-features = false } trillium-caching-headers = "0.2.1" diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index 5ce35ecd0..56d722c90 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -82,7 +82,7 @@ signal-hook = "0.3.17" signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } testcontainers = { version = "0.14.0", optional = true } thiserror = "1.0" -tokio = { version = "1.29", features = ["full", "tracing"] } +tokio.workspace = true tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1", "array-impls"] } tonic = { version = "0.8", optional = true, features = ["tls", "tls-webpki-roots"] } # keep this version in sync with what opentelemetry-otlp uses tracing = "0.1.37" diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 0b663d4b6..cbcc3f2ea 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use janus_aggregator_core::{datastore::Datastore, instrumented}; use janus_core::{ http::extract_bearer_token, - task::{AuthenticationToken, DapAuthToken, DAP_AUTH_HEADER}, + task::{AuthenticationToken, DAP_AUTH_HEADER}, time::Clock, }; use janus_messages::{ @@ -538,14 +538,13 @@ fn parse_auth_token(task_id: &TaskId, conn: &Conn) -> Result( .iter() .filter(|token| !token.is_empty()) .map(|token| { - let token_bytes = STANDARD - .decode(token) - .context("couldn't base64-decode aggregator API auth token")?; - - Ok(SecretBytes::new(token_bytes)) + // Aggregator API auth tokens are always bearer tokens + AuthenticationToken::new_bearer_token_from_string(token) + .context("invalid aggregator API auth token") }) .collect::>>()?; diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index 291ce1fd9..04523444c 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -10,10 +10,8 @@ use janus_aggregator_core::{ SecretBytes, }; use janus_core::{ - hpke::generate_hpke_config_and_private_key, - http::extract_bearer_token, - task::{AuthenticationToken, DapAuthToken}, - time::Clock, + hpke::generate_hpke_config_and_private_key, http::extract_bearer_token, + task::AuthenticationToken, time::Clock, }; use janus_messages::{ query_type::Code as SupportedQueryType, Duration, HpkeAeadId, HpkeKdfId, HpkeKemId, Role, @@ -42,7 +40,7 @@ use url::Url; /// Represents the configuration for an instance of the Aggregator API. #[derive(Clone)] pub struct Config { - pub auth_tokens: Vec, + pub auth_tokens: Vec, pub public_dap_url: Url, } @@ -217,6 +215,7 @@ async fn post_task( // struct `aggregator_core::task::Task` expects to get two aggregator endpoint URLs, but only // the one for the peer aggregator is in the incoming request (or for that matter, is ever used // by Janus), so we insert a fake URL for "self". + // TODO(#1524): clean this up with `aggregator_core::task::Task` changes // unwrap safety: this fake URL is valid let fake_aggregator_url = Url::parse("http://never-used.example.com").unwrap(); let aggregator_endpoints = match req.role { @@ -252,33 +251,16 @@ async fn post_task( let vdaf_verify_keys = Vec::from([SecretBytes::new(vdaf_verify_key_bytes)]); - let aggregator_auth_tokens = match req.role { + let (aggregator_auth_tokens, collector_auth_tokens) = match req.role { Role::Leader => { - let encoded = req.aggregator_auth_token.as_ref().ok_or_else(|| { + let aggregator_auth_token = req.aggregator_auth_token.ok_or_else(|| { Error::new( "aggregator acting in leader role must be provided an aggregator auth token" .to_string(), Status::BadRequest, ) })?; - let token_bytes = URL_SAFE_NO_PAD.decode(encoded).map_err(|err| { - Error::new( - format!("Invalid base64 value for aggregator_auth_token: {err}"), - Status::BadRequest, - ) - })?; - Vec::from([ - // TODO(#472): Each token in the PostTaskReq should indicate whether it is a bearer - // token or a DAP-Auth-Token. For now, assume the latter. - DapAuthToken::try_from(token_bytes) - .map(AuthenticationToken::DapAuth) - .map_err(|_| { - Error::new( - "Invalid HTTP header value in aggregator_auth_token".to_string(), - Status::BadRequest, - ) - })?, - ]) + (Vec::from([aggregator_auth_token]), Vec::from([random()])) } Role::Helper => { @@ -289,19 +271,13 @@ async fn post_task( Status::BadRequest, )); } - // TODO(#472): switch to generating bearer tokens by default - Vec::from([AuthenticationToken::DapAuth(random())]) + + (Vec::from([random()]), Vec::new()) } _ => unreachable!(), }; - let collector_auth_tokens = match req.role { - // TODO(#472): switch to generating bearer tokens by default - Role::Leader => Vec::from([AuthenticationToken::DapAuth(random())]), - Role::Helper => Vec::new(), - _ => unreachable!(), - }; let hpke_keys = Vec::from([generate_hpke_config_and_private_key( random(), HpkeKemId::X25519HkdfSha256, @@ -426,7 +402,7 @@ async fn get_task_metrics( mod models { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use janus_aggregator_core::task::{QueryType, Task}; - use janus_core::task::VdafInstance; + use janus_core::task::{AuthenticationToken, VdafInstance}; use janus_messages::{ query_type::Code as SupportedQueryType, Duration, HpkeConfig, Role, TaskId, Time, }; @@ -493,10 +469,9 @@ mod models { pub(crate) time_precision: Duration, /// HPKE configuration for the collector. pub(crate) collector_hpke_config: HpkeConfig, - /// If this aggregator is the leader, this is the bearer token to use to authenticate - /// requests to the helper, as Base64 encoded bytes. If this aggregator is the helper, the - /// value is `None`. - pub(crate) aggregator_auth_token: Option, + /// If this aggregator is the leader, this is the token to use to authenticate requests to + /// the helper. If this aggregator is the helper, the value is `None`. + pub(crate) aggregator_auth_token: Option, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -529,20 +504,18 @@ mod models { /// How much clock skew to allow between client and aggregator. Reports from /// farther than this duration into the future will be rejected. pub(crate) tolerable_clock_skew: Duration, - /// The authentication token for inter-aggregator communication in this task, as Base64 - /// encoded bytes. + /// The authentication token for inter-aggregator communication in this task. /// If `role` is Leader, this token is used by the aggregator to authenticate requests to /// the Helper. If `role` is Helper, this token is used by the aggregator to authenticate /// requests from the Leader. // TODO(#1509): This field will have to change as Janus helpers will only store a salted // hash of aggregator auth tokens. - pub(crate) aggregator_auth_token: String, - /// The authentication token used by the task's Collector to authenticate to the Leader, as - /// Base64 encoded bytes. + pub(crate) aggregator_auth_token: AuthenticationToken, + /// The authentication token used by the task's Collector to authenticate to the Leader. /// `Some` if `role` is Leader, `None` otherwise. // TODO(#1509) This field will have to change as Janus leaders will only store a salted hash // of collector auth tokens. - pub(crate) collector_auth_token: Option, + pub(crate) collector_auth_token: Option, /// HPKE configuration used by the collector to decrypt aggregate shares. pub(crate) collector_hpke_config: HpkeConfig, /// HPKE configuration(s) used by this aggregator to decrypt report shares. @@ -578,9 +551,8 @@ mod models { Role::Leader => { if task.collector_auth_tokens().len() != 1 { return Err("illegal number of collector auth tokens in task"); - } else { - Some(URL_SAFE_NO_PAD.encode(task.collector_auth_tokens()[0].as_ref())) } + Some(task.primary_collector_auth_token().clone()) } Role::Helper => None, _ => return Err("illegal aggregator role in task"), @@ -606,8 +578,7 @@ mod models { min_batch_size: task.min_batch_size(), time_precision: *task.time_precision(), tolerable_clock_skew: *task.tolerable_clock_skew(), - aggregator_auth_token: URL_SAFE_NO_PAD - .encode(task.aggregator_auth_tokens()[0].as_ref()), + aggregator_auth_token: task.primary_aggregator_auth_token().clone(), collector_auth_token, collector_hpke_config: task.collector_hpke_config().clone(), aggregator_hpke_configs, @@ -646,10 +617,7 @@ mod tests { models::{GetTaskIdsResp, GetTaskMetricsResp, PostTaskReq, TaskResp}, Config, CONTENT_TYPE, }; - use base64::{ - engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}, - Engine, - }; + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use futures::future::try_join_all; use janus_aggregator_core::{ datastore::{ @@ -665,7 +633,7 @@ mod tests { }; use janus_core::{ hpke::{generate_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, DapAuthToken, VdafInstance}, + task::{AuthenticationToken, VdafInstance}, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, @@ -685,7 +653,7 @@ mod tests { prelude::{delete, get, post}, }; - const AUTH_TOKEN: &str = "auth_token"; + const AUTH_TOKEN: &str = "Y29sbGVjdG9yLWFiY2RlZjAw"; async fn setup_api_test() -> (impl Handler, EphemeralDatastore, Arc>) { install_test_trace_subscriber(); @@ -694,7 +662,10 @@ mod tests { let handler = aggregator_api_handler( Arc::clone(&datastore), Config { - auth_tokens: Vec::from([SecretBytes::new(AUTH_TOKEN.as_bytes().to_vec())]), + auth_tokens: Vec::from([AuthenticationToken::new_bearer_token_from_string( + AUTH_TOKEN, + ) + .unwrap()]), public_dap_url: "https://dap.url".parse().unwrap(), }, ); @@ -707,10 +678,7 @@ mod tests { let (handler, ..) = setup_api_test().await; assert_response!( get("/") - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {}", AUTH_TOKEN)) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -758,10 +726,7 @@ mod tests { // Verify: we can get the task IDs we wrote back from the API. assert_response!( get("/task_ids") - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -775,10 +740,7 @@ mod tests { "/task_ids?pagination_token={}", task_ids.first().unwrap() )) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -793,10 +755,7 @@ mod tests { "/task_ids?pagination_token={}", task_ids.last().unwrap() )) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -817,10 +776,7 @@ mod tests { // Verify: requests without the Accept header are denied. assert_response!( get("/task_ids") - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .run_async(&handler) .await, Status::NotAcceptable, @@ -855,15 +811,12 @@ mod tests { ) .config() .clone(), - aggregator_auth_token: Some(URL_SAFE_NO_PAD.encode(&aggregator_auth_token)), + aggregator_auth_token: Some(aggregator_auth_token), }; assert_response!( post("/tasks") .with_request_body(serde_json::to_vec(&req).unwrap()) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .with_request_header("Content-Type", CONTENT_TYPE) .run_async(&handler) @@ -899,7 +852,7 @@ mod tests { ) .config() .clone(), - aggregator_auth_token: Some(URL_SAFE_NO_PAD.encode(&aggregator_auth_token)), + aggregator_auth_token: Some(aggregator_auth_token), }; assert_response!( post("/tasks") @@ -945,10 +898,7 @@ mod tests { }; let mut conn = post("/tasks") .with_request_body(serde_json::to_vec(&req).unwrap()) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)), - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .with_request_header("Content-Type", CONTENT_TYPE) .run_async(&handler) @@ -1021,15 +971,12 @@ mod tests { ) .config() .clone(), - aggregator_auth_token: Some(URL_SAFE_NO_PAD.encode(&aggregator_auth_token)), + aggregator_auth_token: Some(aggregator_auth_token), }; assert_response!( post("/tasks") .with_request_body(serde_json::to_vec(&req).unwrap()) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)), - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .with_request_header("Content-Type", CONTENT_TYPE) .run_async(&handler) @@ -1067,14 +1014,11 @@ mod tests { ) .config() .clone(), - aggregator_auth_token: Some(URL_SAFE_NO_PAD.encode(&aggregator_auth_token)), + aggregator_auth_token: Some(aggregator_auth_token.clone()), }; let mut conn = post("/tasks") .with_request_body(serde_json::to_vec(&req).unwrap()) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)), - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .with_request_header("Content-Type", CONTENT_TYPE) .run_async(&handler) @@ -1163,10 +1107,7 @@ mod tests { assert_response!( post("/tasks") .with_request_body(serde_json::to_vec(&req).unwrap()) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)), - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .with_request_header("Content-Type", CONTENT_TYPE) .run_async(&handler) @@ -1198,10 +1139,7 @@ mod tests { // Verify: getting the task returns the expected result. let want_task_resp = TaskResp::try_from(&task).unwrap(); let mut conn = get(&format!("/tasks/{}", task.id())) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)), - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await; @@ -1220,10 +1158,7 @@ mod tests { // Verify: getting a nonexistent task returns NotFound. assert_response!( get(&format!("/tasks/{}", random::())) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1265,10 +1200,7 @@ mod tests { // Verify: deleting a task succeeds (and actually deletes the task). assert_response!( delete(&format!("/tasks/{}", &task_id)) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1288,10 +1220,7 @@ mod tests { // Verify: deleting a task twice returns NotFound. assert_response!( delete(&format!("/tasks/{}", &task_id)) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1302,10 +1231,7 @@ mod tests { // Verify: deleting an arbitrary nonexistent task ID returns NotFound. assert_response!( delete(&format!("/tasks/{}", &random::())) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1399,10 +1325,7 @@ mod tests { // Verify: requesting metrics on a task returns the correct result. assert_response!( get(&format!("/tasks/{}/metrics", &task_id)) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1417,10 +1340,7 @@ mod tests { // Verify: requesting metrics on a nonexistent task returns NotFound. assert_response!( delete(&format!("/tasks/{}", &random::())) - .with_request_header( - "Authorization", - format!("Bearer {}", STANDARD.encode(AUTH_TOKEN)) - ) + .with_request_header("Authorization", format!("Bearer {AUTH_TOKEN}")) .with_request_header("Accept", CONTENT_TYPE) .run_async(&handler) .await, @@ -1601,7 +1521,9 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from([0u8; 32].to_vec()), ), - aggregator_auth_token: Some("encoded".to_owned()), + aggregator_auth_token: Some( + AuthenticationToken::new_dap_auth_token_from_string("ZW5jb2RlZA").unwrap(), + ), }, &[ Token::Struct { @@ -1676,7 +1598,15 @@ mod tests { Token::StructEnd, Token::Str("aggregator_auth_token"), Token::Some, - Token::Str("encoded"), + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("DapAuth"), + Token::Str("token"), + Token::Str("ZW5jb2RlZA"), + Token::StructEnd, Token::StructEnd, ], ); @@ -1709,12 +1639,14 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from([0u8; 32].to_vec()), ), - Vec::from([AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"aggregator-12345678".to_vec()).unwrap(), - )]), - Vec::from([AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"collector-abcdef00".to_vec()).unwrap(), - )]), + Vec::from([AuthenticationToken::new_dap_auth_token_from_string( + "Y29sbGVjdG9yLWFiY2RlZjAw", + ) + .unwrap()]), + Vec::from([AuthenticationToken::new_dap_auth_token_from_string( + "Y29sbGVjdG9yLWFiY2RlZjAw", + ) + .unwrap()]), [(HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(13), @@ -1778,10 +1710,26 @@ mod tests { Token::NewtypeStruct { name: "Duration" }, Token::U64(60), Token::Str("aggregator_auth_token"), - Token::Str("YWdncmVnYXRvci0xMjM0NTY3OA"), + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("DapAuth"), + Token::Str("token"), + Token::Str("Y29sbGVjdG9yLWFiY2RlZjAw"), + Token::StructEnd, Token::Str("collector_auth_token"), Token::Some, + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("DapAuth"), + Token::Str("token"), Token::Str("Y29sbGVjdG9yLWFiY2RlZjAw"), + Token::StructEnd, Token::Str("collector_hpke_config"), Token::Struct { name: "HpkeConfig", diff --git a/aggregator_core/Cargo.toml b/aggregator_core/Cargo.toml index b4009a090..b0e5455b5 100644 --- a/aggregator_core/Cargo.toml +++ b/aggregator_core/Cargo.toml @@ -53,7 +53,7 @@ serde_yaml = "0.9.25" sqlx = { version = "0.6.3", optional = true, features = ["runtime-tokio-rustls", "migrate", "postgres"] } testcontainers = { version = "0.14.0", optional = true } thiserror = "1.0" -tokio = { version = "1.29", features = ["full", "tracing"] } +tokio.workspace = true tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1", "array-impls"] } tracing = "0.1.37" tracing-log = "0.1.3" diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index 833145160..13a79c607 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -7,7 +7,7 @@ use derivative::Derivative; use janus_core::{ hpke::HpkeKeypair, report_id::ReportIdChecksumExt, - task::{AuthenticationToken, DapAuthToken, VdafInstance}, + task::{AuthenticationToken, VdafInstance}, time::{DurationExt, IntervalExt, TimeExt}, }; use janus_messages::{ @@ -47,13 +47,14 @@ pub enum AuthenticationTokenType { } impl AuthenticationTokenType { - pub fn as_authentication(&self, token: &[u8]) -> Result { + pub fn as_authentication(&self, token_bytes: &[u8]) -> Result { match self { - Self::DapAuthToken => DapAuthToken::try_from(token.to_vec()) - .map(AuthenticationToken::DapAuth) - .map_err(|e| Error::DbState(format!("invalid DAP auth token in database: {e:?}"))), - Self::AuthorizationBearerToken => Ok(AuthenticationToken::Bearer(token.into())), + Self::DapAuthToken => AuthenticationToken::new_dap_auth_token_from_bytes(token_bytes), + Self::AuthorizationBearerToken => { + AuthenticationToken::new_bearer_token_from_bytes(token_bytes) + } } + .map_err(|e| Error::DbState(format!("invalid DAP auth token in database: {e:?}"))) } } diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index bfbdb255b..eed696ce0 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -784,7 +784,7 @@ mod tests { }; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, DapAuthToken, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, test_util::roundtrip_encoding, time::DurationExt, }; @@ -1017,10 +1017,14 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"aggregator token".to_vec()).unwrap(), - )]), - Vec::from([AuthenticationToken::Bearer(b"collector token".to_vec())]), + Vec::from([AuthenticationToken::new_dap_auth_token_from_string( + "YWdncmVnYXRvciB0b2tlbg", + ) + .unwrap()]), + Vec::from([AuthenticationToken::new_bearer_token_from_string( + "Y29sbGVjdG9yIHRva2Vu", + ) + .unwrap()]), [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1197,7 +1201,10 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::Bearer(b"aggregator token".to_vec())]), + Vec::from([AuthenticationToken::new_bearer_token_from_string( + "YWdncmVnYXRvciB0b2tlbg", + ) + .unwrap()]), Vec::new(), [HpkeKeypair::new( HpkeConfig::new( @@ -1304,7 +1311,7 @@ mod tests { Token::Str("type"), Token::Str("Bearer"), Token::Str("token"), - Token::Str("YWdncmVnYXRvciB0b2tlbg=="), + Token::Str("YWdncmVnYXRvciB0b2tlbg"), Token::StructEnd, Token::SeqEnd, Token::Str("collector_auth_tokens"), diff --git a/client/Cargo.toml b/client/Cargo.toml index 83c318aa7..0aa6b4305 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -19,8 +19,8 @@ janus_messages.workspace = true prio.workspace = true rand = "0.8" reqwest = { version = "0.11.18", default-features = false, features = ["rustls-tls", "json"] } -thiserror = "1.0" -tokio = { version = "1.29", features = ["full"] } +thiserror.workspace = true +tokio.workspace = true tracing = "0.1.37" url = "2.4.0" diff --git a/collector/Cargo.toml b/collector/Cargo.toml index 0b1ee4fbb..a13bd5384 100644 --- a/collector/Cargo.toml +++ b/collector/Cargo.toml @@ -30,8 +30,8 @@ prio.workspace = true rand = { version = "0.8", features = ["min_const_gen"] } reqwest = { version = "0.11.18", default-features = false, features = ["rustls-tls", "json"] } retry-after = "0.3.1" -thiserror = "1.0" -tokio = { version = "1.29", features = ["full"] } +thiserror.workspace = true +tokio.workspace = true tracing = "0.1.37" url = "2.4.0" diff --git a/collector/src/lib.rs b/collector/src/lib.rs index 0aa17dcd1..b80282e00 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -29,7 +29,7 @@ //! let parameters = CollectorParameters::new( //! task_id, //! "https://example.com/dap/".parse().unwrap(), -//! AuthenticationToken::Bearer(b"my-authentication-token".to_vec()), +//! AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu").unwrap(), //! hpke_keypair.config().clone(), //! hpke_keypair.private_key().clone(), //! ); @@ -704,7 +704,7 @@ mod tests { let parameters = CollectorParameters::new( random(), server_url, - AuthenticationToken::Bearer(b"token".to_vec()), + AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu").unwrap(), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ) @@ -805,7 +805,7 @@ mod tests { let collector_parameters = CollectorParameters::new( random(), "http://example.com/dap".parse().unwrap(), - AuthenticationToken::Bearer(b"token".to_vec()), + AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu").unwrap(), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ); @@ -818,7 +818,7 @@ mod tests { let collector_parameters = CollectorParameters::new( random(), "http://example.com".parse().unwrap(), - AuthenticationToken::Bearer(b"token".to_vec()), + AuthenticationToken::new_bearer_token_from_string("Y29sbGVjdG9yIHRva2Vu").unwrap(), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ); @@ -838,7 +838,6 @@ mod tests { let collector = setup_collector(&mut server, vdaf); let (auth_header, auth_value) = collector.parameters.authentication.request_authentication(); - let auth_value = String::from_utf8(auth_value).unwrap(); let batch_interval = Interval::new( Time::from_seconds_since_epoch(1_000_000), @@ -1224,7 +1223,7 @@ mod tests { let parameters = CollectorParameters::new( random(), server_url, - AuthenticationToken::Bearer(Vec::from([0x41u8; 16])), + AuthenticationToken::new_bearer_token_from_bytes(Vec::from([0x41u8; 16])).unwrap(), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ) @@ -1247,7 +1246,7 @@ mod tests { CONTENT_TYPE.as_str(), CollectionReq::::MEDIA_TYPE, ) - .match_header(AUTHORIZATION.as_str(), "Bearer QUFBQUFBQUFBQUFBQUFBQQ==") + .match_header(AUTHORIZATION.as_str(), "Bearer AAAAAAAAAAAAAAAA") .with_status(201) .expect(1) .create_async() @@ -1255,15 +1254,15 @@ mod tests { let job = collector .start_collection(Query::new_time_interval(batch_interval), &()) - .await - .unwrap(); - assert_eq!(job.query.batch_interval(), &batch_interval); + .await; mocked_collect_start_success.assert_async().await; + let job = job.unwrap(); + assert_eq!(job.query.batch_interval(), &batch_interval); let mocked_collect_complete = server .mock("POST", job.collection_job_url.path()) - .match_header(AUTHORIZATION.as_str(), "Bearer QUFBQUFBQUFBQUFBQUFBQQ==") + .match_header(AUTHORIZATION.as_str(), "Bearer AAAAAAAAAAAAAAAA") .with_status(200) .with_header( CONTENT_TYPE.as_str(), diff --git a/core/Cargo.toml b/core/Cargo.toml index d1fa78288..f76565011 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -57,8 +57,8 @@ serde_yaml = "0.9.25" stopper = { version = "0.2.0", optional = true } tempfile = { version = "3", optional = true } testcontainers = { version = "0.14", optional = true } -thiserror = "1.0" -tokio = { version = "1.29", features = ["macros", "net", "rt"] } +thiserror.workspace = true +tokio.workspace = true tokio-stream = { version = "0.1.14", features = ["net"], optional = true } tracing = "0.1.37" tracing-log = { version = "0.1.3", optional = true } @@ -71,5 +71,6 @@ hex = { version = "0.4", features = ["serde"] } # ensure this remains compatibl janus_core = { path = ".", features = ["test-util"] } kube.workspace = true mockito = "1.1.0" +rstest.workspace = true serde_test = "1.0.175" url = "2.4.0" diff --git a/core/src/http.rs b/core/src/http.rs index 1ca6c6ad6..bb7eca117 100644 --- a/core/src/http.rs +++ b/core/src/http.rs @@ -1,5 +1,5 @@ +use crate::task::AuthenticationToken; use anyhow::{anyhow, Context}; -use base64::{engine::general_purpose::STANDARD, Engine}; use http_api_problem::{HttpApiProblem, PROBLEM_JSON_MEDIA_TYPE}; use reqwest::{header::CONTENT_TYPE, Response}; use tracing::warn; @@ -27,13 +27,13 @@ pub async fn response_to_problem_details(response: Response) -> HttpApiProblem { /// If the request in `conn` has an `authorization` header, returns the bearer token in the header /// value. Returns `None` if there is no `authorization` header, and an error if there is an /// `authorization` header whose value is not a bearer token. -pub fn extract_bearer_token(conn: &Conn) -> Result>, anyhow::Error> { +pub fn extract_bearer_token(conn: &Conn) -> Result, anyhow::Error> { if let Some(authorization_value) = conn.headers().get("authorization") { - if let Some(received_token) = authorization_value.as_ref().strip_prefix(b"Bearer ") { - let decoded = STANDARD - .decode(received_token) - .context("bearer token cannot be decoded from Base64")?; - return Ok(Some(decoded)); + if let Some(received_token) = authorization_value.to_string().strip_prefix("Bearer ") { + return Ok(Some( + AuthenticationToken::new_bearer_token_from_string(received_token) + .context("invalid bearer token")?, + )); } else { return Err(anyhow!("authorization header value is not a bearer token")); } diff --git a/core/src/task.rs b/core/src/task.rs index 3bb7fa626..c159c7179 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -1,13 +1,10 @@ -use base64::{ - engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}, - Engine, -}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use derivative::Derivative; -use http::header::{HeaderValue, AUTHORIZATION}; +use http::header::AUTHORIZATION; use rand::{distributions::Standard, prelude::Distribution}; use reqwest::Url; use ring::constant_time; -use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; +use serde::{de::Error, Deserialize, Deserializer, Serialize}; use std::{fmt, str}; /// HTTP header where auth tokens are provided in messages between participants. @@ -547,21 +544,18 @@ macro_rules! vdaf_dispatch { /// Different modes of authentication supported by Janus for either sending requests (e.g., leader /// to helper) or receiving them (e.g., collector to leader). -#[derive(Clone, Derivative, Serialize, Deserialize)] +#[derive(Clone, Derivative, Serialize, Deserialize, PartialEq, Eq)] #[derivative(Debug)] #[serde(tag = "type", content = "token")] #[non_exhaustive] pub enum AuthenticationToken { - /// A bearer token. The value is an opaque byte string. Its Base64 encoding is inserted into - /// HTTP requests as specified in [RFC 6750 section 2.1][1]. The token is not necessarily an - /// OAuth token. + /// A bearer token, presented as the value of the "Authorization" HTTP header as specified in + /// [RFC 6750 section 2.1][1]. + /// + /// The token is not necessarily an OAuth token. /// /// [1]: https://datatracker.ietf.org/doc/html/rfc6750#section-2.1 - Bearer( - #[derivative(Debug = "ignore")] - #[serde(serialize_with = "as_base64", deserialize_with = "from_base64")] - Vec, - ), + Bearer(TokenInner), /// Token presented as the value of the "DAP-Auth-Token" HTTP header. Conforms to /// [draft-dcook-ppm-dap-interop-test-design-03][1], sections [4.3.3][2] and [4.4.2][3], and @@ -571,45 +565,51 @@ pub enum AuthenticationToken { /// [2]: https://datatracker.ietf.org/doc/html/draft-dcook-ppm-dap-interop-test-design-03#section-4.3.3 /// [3]: https://datatracker.ietf.org/doc/html/draft-dcook-ppm-dap-interop-test-design-03#section-4.4.2 /// [4]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-01#name-https-sender-authentication - DapAuth(DapAuthToken), + DapAuth(TokenInner), } impl AuthenticationToken { + /// Attempts to create a new bearer token from the provided bytes. + pub fn new_bearer_token_from_bytes>(bytes: T) -> Result { + TokenInner::try_from(bytes.as_ref().to_vec()).map(AuthenticationToken::Bearer) + } + + /// Attempts to create a new bearer token from the provided string + pub fn new_bearer_token_from_string>(string: T) -> Result { + TokenInner::try_from_str(string.into()).map(AuthenticationToken::Bearer) + } + + /// Attempts to create a new DAP auth token from the provided bytes. + pub fn new_dap_auth_token_from_bytes>(bytes: T) -> Result { + TokenInner::try_from(bytes.as_ref().to_vec()).map(AuthenticationToken::DapAuth) + } + + /// Attempts to create a new DAP auth token from the provided string. + pub fn new_dap_auth_token_from_string>( + string: T, + ) -> Result { + TokenInner::try_from_str(string.into()).map(AuthenticationToken::DapAuth) + } + /// Returns an HTTP header and value that should be used to authenticate an HTTP request with /// this credential. - pub fn request_authentication(&self) -> (&'static str, Vec) { + pub fn request_authentication(&self) -> (&'static str, String) { match self { - Self::Bearer(token) => ( - AUTHORIZATION.as_str(), - // When encoding into a request, we use Base64 standard encoding - format!("Bearer {}", STANDARD.encode(token.as_slice())).into_bytes(), - ), - // A DAP-Auth-Token is already HTTP header-safe, so no encoding is needed. Cloning is - // unfortunate but necessary since other arms must allocate. - Self::DapAuth(token) => (DAP_AUTH_HEADER, token.as_ref().to_vec()), + Self::Bearer(token) => (AUTHORIZATION.as_str(), format!("Bearer {}", token.as_str())), + // Cloning is unfortunate but necessary since other arms must allocate. + Self::DapAuth(token) => (DAP_AUTH_HEADER, token.as_str().to_string()), } } -} -impl PartialEq for AuthenticationToken { - fn eq(&self, other: &Self) -> bool { - let (own, other) = match (self, other) { - (Self::Bearer(own), Self::Bearer(other)) => (own.as_slice(), other.as_slice()), - (Self::DapAuth(own), Self::DapAuth(other)) => (own.as_ref(), other.as_ref()), - _ => { - return false; - } - }; - // We attempt constant-time comparisons of the token data to mitigate timing attacks. Note - // that this function still eaks whether the lengths of the tokens are equal -- this is - // acceptable because we expec the content of the tokens to provide enough randomness that - // needs to be guessed even if the length is known. - constant_time::verify_slices_are_equal(own, other).is_ok() + /// Returns the token as a string. + pub fn as_str(&self) -> &str { + match self { + Self::DapAuth(token) => token.as_str(), + Self::Bearer(token) => token.as_str(), + } } } -impl Eq for AuthenticationToken {} - impl AsRef<[u8]> for AuthenticationToken { fn as_ref(&self) -> &[u8] { match self { @@ -619,80 +619,77 @@ impl AsRef<[u8]> for AuthenticationToken { } } -/// Serialize bytes into format suitable for bearer tokens in Janus configuration files. -fn as_base64>(key: &T, serializer: S) -> Result { - let bytes: &[u8] = key.as_ref(); - serializer.serialize_str(&STANDARD.encode(bytes)) -} - -/// Deserialize bytes from Janus configuration files into a bearer token. -fn from_base64<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { - String::deserialize(deserializer).and_then(|s| { - STANDARD - .decode(s) - .map_err(|e| D::Error::custom(format!("cannot decode value from Base64: {e:?}"))) - }) -} - impl Distribution for Standard { fn sample(&self, rng: &mut R) -> AuthenticationToken { - AuthenticationToken::Bearer(Vec::from(rng.gen::<[u8; 16]>())) + AuthenticationToken::Bearer(Standard::sample(self, rng)) } } -/// Token presented as the value of the "DAP-Auth-Token" HTTP header. The token is used directly in -/// the HTTP request without further encoding and so must be a legal HTTP header value. Conforms to -/// [draft-ietf-dap-ppm-01 section 3.2][1]. +/// A token value used to authenticate HTTP requests. +/// +/// The token is used directly in HTTP request headers without further encoding and so much be a +/// legal HTTP header value. More specifically, the token is restricted to the unpadded, URL-safe +/// Base64 alphabet, as specified in [RFC 4648 section 5][1]. The unpadded, URL-safe Base64 string +/// is the canonical form of the token and is used in configuration files, Janus aggregator API +/// requests and HTTP authentication headers. /// -/// This opaque type ensures it's impossible to construct an [`AuthenticationToken::DapAuth`] whose -/// contents are invalid. +/// This opaque type ensures it's impossible to construct an [`AuthenticationToken`] whose contents +/// are invalid. /// -/// [1]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-01#name-https-sender-authentication -#[derive(Clone, Derivative)] +/// [1]: https://datatracker.ietf.org/doc/html/rfc4648#section-5 +#[derive(Clone, Derivative, Serialize)] #[derivative(Debug)] -pub struct DapAuthToken(#[derivative(Debug = "ignore")] Vec); +#[serde(transparent)] +pub struct TokenInner(#[derivative(Debug = "ignore")] String); -impl DapAuthToken {} - -impl AsRef<[u8]> for DapAuthToken { - fn as_ref(&self) -> &[u8] { +impl TokenInner { + /// Returns the token as a string. + pub fn as_str(&self) -> &str { &self.0 } -} -impl TryFrom> for DapAuthToken { - type Error = anyhow::Error; + fn try_from_str(value: String) -> Result { + // Verify that the string is legal unpadded, URL-safe Base64 + URL_SAFE_NO_PAD.decode(&value)?; + Ok(Self(value)) + } - fn try_from(token: Vec) -> Result { - HeaderValue::try_from(token.as_slice())?; - Ok(Self(token)) + fn try_from(value: Vec) -> Result { + Self::try_from_str(String::from_utf8(value)?) } } -impl Serialize for DapAuthToken { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&URL_SAFE_NO_PAD.encode(self.as_ref())) +impl AsRef<[u8]> for TokenInner { + fn as_ref(&self) -> &[u8] { + self.0.as_bytes() } } -impl<'de> Deserialize<'de> for DapAuthToken { - fn deserialize>(deserializer: D) -> Result { - // Verify that the string is a safe HTTP header value +impl<'de> Deserialize<'de> for TokenInner { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { String::deserialize(deserializer) - .and_then(|string| { - URL_SAFE_NO_PAD.decode(string).map_err(|e| { - D::Error::custom(format!( - "cannot decode value from unpadded Base64URL: {e:?}" - )) - }) - }) - .and_then(|bytes| Self::try_from(bytes).map_err(D::Error::custom)) + .and_then(|string| Self::try_from_str(string).map_err(D::Error::custom)) } } -impl Distribution for Standard { - fn sample(&self, rng: &mut R) -> DapAuthToken { - DapAuthToken(Vec::from(hex::encode(rng.gen::<[u8; 16]>()))) +impl PartialEq for TokenInner { + fn eq(&self, other: &Self) -> bool { + // We attempt constant-time comparisons of the token data to mitigate timing attacks. Note + // that this function still eaks whether the lengths of the tokens are equal -- this is + // acceptable because we expec the content of the tokens to provide enough randomness that + // needs to be guessed even if the length is known. + constant_time::verify_slices_are_equal(self.0.as_bytes(), other.0.as_bytes()).is_ok() + } +} + +impl Eq for TokenInner {} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> TokenInner { + TokenInner(URL_SAFE_NO_PAD.encode(rng.gen::<[u8; 16]>())) } } @@ -805,12 +802,14 @@ mod tests { ); } + #[rstest::rstest] + #[case::dap_auth("DapAuth")] + #[case::bearer("Bearer")] #[test] - fn reject_invalid_dap_auth_token() { - let err = serde_yaml::from_str::( - "{type: \"DapAuth\", token: \"AAAAAAAAAAAAAA\"}", - ) + fn reject_invalid_auth_token(#[case] token_type: &str) { + serde_yaml::from_str::(&format!( + "{{type: \"{token_type}\", token: \"é\"}}" + )) .unwrap_err(); - assert!(err.to_string().contains("failed to parse header value")); } } diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index 1d381cd03..c83958604 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -75,9 +75,9 @@ # to "aggregator-235242f99406c4fd28b820c32eab0f68". - type: "DapAuth" token: "YWdncmVnYXRvci0yMzUyNDJmOTk0MDZjNGZkMjhiODIwYzMyZWFiMGY2OA" - # Bearer token values are encoded in base64 with padding. + # Bearer token values are encoded in unpadded base64url. - type: "Bearer" - token: "YWdncmVnYXRvci04NDc1NjkwZjJmYzQzMDBmYjE0NmJiMjk1NDIzNDk1NA==" + token: "YWdncmVnYXRvci04NDc1NjkwZjJmYzQzMDBmYjE0NmJiMjk1NDIzNDk1NA" # Authentication tokens shared between the leader and the collector, and used # to authenticate collector-to-leader requests. For leader tasks, this has the @@ -127,7 +127,7 @@ public_key: KHRLcWgfWxli8cdOLPsgsZPttHXh0ho3vLVLrW-63lE aggregator_auth_tokens: - type: "Bearer" - token: "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ==" + token: "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ" # Note that this task does not have any collector authentication tokens, since # it is a helper role task. collector_auth_tokens: [] diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index 206142965..43d21afdc 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -33,7 +33,7 @@ reqwest = { version = "0.11", default-features = false, features = ["rustls-tls" serde.workspace = true serde_json = "1.0.103" testcontainers = "0.14.0" -tokio = { version = "1", features = ["full", "tracing"] } +tokio.workspace = true url = { version = "2.4.0", features = ["serde"] } [dev-dependencies] diff --git a/integration_tests/src/divviup_api_client.rs b/integration_tests/src/divviup_api_client.rs index 06de50233..1e9136c8e 100644 --- a/integration_tests/src/divviup_api_client.rs +++ b/integration_tests/src/divviup_api_client.rs @@ -81,6 +81,15 @@ pub struct DivviUpAggregator { pub dap_url: Url, } +/// Representation of a collector auth token in divviup-api. +#[derive(Deserialize)] +pub struct CollectorAuthToken { + /// Type of the authentication token. Always "Bearer" in divviup-api. + pub r#type: String, + /// Encoded value of the token. The encoding is opaque to divviup-api. + pub token: String, +} + const DIVVIUP_CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1"; pub struct DivviupApiClient { @@ -176,7 +185,10 @@ impl DivviupApiClient { .await } - pub async fn list_collector_auth_tokens(&self, task: &DivviUpApiTask) -> Vec { + pub async fn list_collector_auth_tokens( + &self, + task: &DivviUpApiTask, + ) -> Vec { // Hack: we must choose some specialization for the B type despite the request having no // Body self.make_request::( diff --git a/integration_tests/tests/in_cluster.rs b/integration_tests/tests/in_cluster.rs index 869284cbe..f45418d0b 100644 --- a/integration_tests/tests/in_cluster.rs +++ b/integration_tests/tests/in_cluster.rs @@ -1,14 +1,11 @@ #![cfg(feature = "in-cluster")] -use base64::engine::{ - general_purpose::{STANDARD, URL_SAFE_NO_PAD}, - Engine, -}; +use base64::engine::{general_purpose::STANDARD, Engine}; use common::{submit_measurements_and_verify_aggregate, test_task_builders}; use janus_aggregator_core::task::QueryType; -use janus_collector::AuthenticationToken; use janus_core::{ - task::{DapAuthToken, VdafInstance}, + task::AuthenticationToken, + task::VdafInstance, test_util::{ install_test_trace_subscriber, kubernetes::{Cluster, PortForward}, @@ -163,17 +160,14 @@ impl InClusterJanusPair { let collector_auth_tokens = divviup_api .list_collector_auth_tokens(&provisioned_task) .await; + assert_eq!(collector_auth_tokens[0].r#type, "Bearer"); // Update the task parameters with the ID and collector auth token from divviup-api. task_parameters.task_id = TaskId::from_str(provisioned_task.id.as_ref()).unwrap(); - task_parameters.collector_auth_token = AuthenticationToken::DapAuth( - DapAuthToken::try_from( - URL_SAFE_NO_PAD - .decode(collector_auth_tokens[0].clone()) - .unwrap(), - ) - .unwrap(), - ); + task_parameters.collector_auth_token = AuthenticationToken::new_bearer_token_from_string( + collector_auth_tokens[0].token.clone(), + ) + .unwrap(); Self { task_parameters, diff --git a/interop_binaries/Cargo.toml b/interop_binaries/Cargo.toml index bbe600bc1..8de5cfab5 100644 --- a/interop_binaries/Cargo.toml +++ b/interop_binaries/Cargo.toml @@ -48,7 +48,7 @@ serde.workspace = true serde_json = "1.0.103" sqlx = { version = "0.6.3", features = ["runtime-tokio-rustls", "migrate", "postgres"] } testcontainers = { version = "0.14" } -tokio = { version = "1.29", features = ["full", "tracing"] } +tokio.workspace = true tracing = "0.1.37" tracing-log = "0.1.3" tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"] } diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index ac4c29125..a3e85643c 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -11,10 +11,7 @@ use janus_aggregator_core::{ task::{self, Task}, SecretBytes, }; -use janus_core::{ - task::{AuthenticationToken, DapAuthToken}, - time::RealClock, -}; +use janus_core::{task::AuthenticationToken, time::RealClock}; use janus_interop_binaries::{ status::{ERROR, SUCCESS}, AddTaskResponse, AggregatorAddTaskRequest, AggregatorRole, HpkeConfigRegistry, Keyring, @@ -42,10 +39,9 @@ async fn handle_add_task( request: AggregatorAddTaskRequest, ) -> anyhow::Result<()> { let vdaf = request.vdaf.into(); - let leader_authentication_token = AuthenticationToken::DapAuth( - DapAuthToken::try_from(request.leader_authentication_token.into_bytes()) - .context("invalid header value in \"leader_authentication_token\"")?, - ); + let leader_authentication_token = + AuthenticationToken::new_dap_auth_token_from_string(request.leader_authentication_token) + .context("invalid header value in \"leader_authentication_token\"")?; let vdaf_verify_key = SecretBytes::new( URL_SAFE_NO_PAD .decode(request.vdaf_verify_key) @@ -64,10 +60,10 @@ async fn handle_add_task( return Err(anyhow::anyhow!("collector authentication token is missing")) } (AggregatorRole::Leader, Some(collector_authentication_token)) => { - Vec::from([AuthenticationToken::DapAuth( - DapAuthToken::try_from(collector_authentication_token.into_bytes()) - .context("invalid header value in \"collector_authentication_token\"")?, - )]) + Vec::from([AuthenticationToken::new_dap_auth_token_from_string( + collector_authentication_token, + ) + .context("invalid header value in \"collector_authentication_token\"")?]) } (AggregatorRole::Helper, _) => Vec::new(), }; diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index 94f64c7d1..d25f674d5 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -7,7 +7,6 @@ use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; use janus_collector::{Collector, CollectorParameters}; -use janus_core::task::DapAuthToken; use janus_core::{ hpke::HpkeKeypair, task::{AuthenticationToken, VdafInstance}, @@ -167,10 +166,9 @@ async fn handle_add_task( let keypair = keyring.lock().await.get_random_keypair(); let hpke_config = keypair.config().clone(); - let auth_token = AuthenticationToken::DapAuth( - DapAuthToken::try_from(request.collector_authentication_token.into_bytes()) - .context("invalid header value in \"collector_authentication_token\"")?, - ); + let auth_token = + AuthenticationToken::new_dap_auth_token_from_string(request.collector_authentication_token) + .context("invalid header value in \"collector_authentication_token\"")?; entry.or_insert(TaskState { keypair, diff --git a/tools/Cargo.toml b/tools/Cargo.toml index ead742f3e..47725a871 100644 --- a/tools/Cargo.toml +++ b/tools/Cargo.toml @@ -23,7 +23,7 @@ janus_messages.workspace = true prio.workspace = true reqwest = { version = "0.11.18", default-features = false, features = ["rustls-tls", "json"] } serde_yaml = "0.9.25" -tokio = { version = "1.29", features = ["full"] } +tokio.workspace = true tracing = "0.1.37" tracing-log = "0.1.3" tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"] } diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index 80991f553..0864b339d 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -1,7 +1,4 @@ -use base64::{ - engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}, - Engine, -}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{ builder::{NonEmptyStringValueParser, StringValueParser, TypedValueParser}, error::ErrorKind, @@ -13,7 +10,7 @@ use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; use janus_collector::{default_http_client, AuthenticationToken, Collector, CollectorParameters}; -use janus_core::{hpke::HpkePrivateKey, task::DapAuthToken}; +use janus_core::hpke::HpkePrivateKey; use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, BatchId, Duration, FixedSizeQuery, HpkeConfig, Interval, PartialBatchSelector, Query, TaskId, @@ -145,14 +142,6 @@ impl TypedValueParser for BatchIdValueParser { } } -fn parse_authentication_token(value: String) -> Result { - DapAuthToken::try_from(value.into_bytes()).map(AuthenticationToken::DapAuth) -} - -fn parse_authentication_token_base64(value: String) -> Result { - Ok(AuthenticationToken::Bearer(STANDARD.decode(value)?)) -} - #[derive(Clone)] struct HpkeConfigValueParser { inner: NonEmptyStringValueParser, @@ -257,7 +246,7 @@ struct AuthenticationOptions { #[clap( long, required = false, - value_parser = StringValueParser::new().try_map(parse_authentication_token), + value_parser = StringValueParser::new().try_map(AuthenticationToken::new_dap_auth_token_from_string), env, help_heading = "Authorization", display_order = 0, @@ -266,11 +255,11 @@ struct AuthenticationOptions { #[derivative(Debug = "ignore")] dap_auth_token: Option, - /// Authentication token for the "Authorization: Bearer ..." HTTP header, in base64 + /// Authentication token for the "Authorization: Bearer ..." HTTP header #[clap( long, required = false, - value_parser = StringValueParser::new().try_map(parse_authentication_token_base64), + value_parser = StringValueParser::new().try_map(AuthenticationToken::new_bearer_token_from_string), env, help_heading = "Authorization", display_order = 1, @@ -605,13 +594,14 @@ impl QueryTypeExt for FixedSize { #[cfg(test)] mod tests { use crate::{ - run, AuthenticationOptions, AuthenticationToken, DapAuthToken, Error, Options, - QueryOptions, VdafType, + run, AuthenticationOptions, AuthenticationToken, Error, Options, QueryOptions, VdafType, }; use assert_matches::assert_matches; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{error::ErrorKind, CommandFactory, Parser}; - use janus_core::hpke::test_util::generate_test_hpke_config_and_private_key; + use janus_core::{ + hpke::test_util::generate_test_hpke_config_and_private_key, task::TokenInner, + }; use janus_messages::{BatchId, TaskId}; use prio::codec::Encode; use rand::random; @@ -630,14 +620,13 @@ mod tests { let task_id = random(); let leader = Url::parse("https://example.com/dap/").unwrap(); + let auth_token = AuthenticationToken::DapAuth(random()); let expected = Options { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: Some(AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap(), - )), + dap_auth_token: Some(auth_token.clone()), authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -660,7 +649,7 @@ mod tests { "--leader", leader.as_str(), "--dap-auth-token", - "collector-authentication-token", + auth_token.as_str(), &format!("--hpke-config={encoded_hpke_config}"), &format!("--hpke-private-key={encoded_private_key}"), "--vdaf", @@ -723,7 +712,7 @@ mod tests { "--leader".to_string(), leader.to_string(), "--dap-auth-token".to_string(), - "collector-authentication-token".to_string(), + auth_token.as_str().to_string(), format!("--hpke-config={encoded_hpke_config}"), format!("--hpke-private-key={encoded_private_key}"), "--batch-interval-start".to_string(), @@ -887,16 +876,14 @@ mod tests { let hpke_keypair = generate_test_hpke_config_and_private_key(); let encoded_hpke_config = URL_SAFE_NO_PAD.encode(hpke_keypair.config().get_encoded()); let encoded_private_key = URL_SAFE_NO_PAD.encode(hpke_keypair.private_key().as_ref()); - let auth_token = Some(AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap(), - )); + let auth_token = AuthenticationToken::DapAuth(random()); // Check parsing arguments for a current batch query. let expected = Options { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: auth_token.clone(), + dap_auth_token: Some(auth_token.clone()), authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -918,7 +905,7 @@ mod tests { "--leader", leader.as_str(), "--dap-auth-token", - "collector-authentication-token", + auth_token.as_str(), &format!("--hpke-config={encoded_hpke_config}"), &format!("--hpke-private-key={encoded_private_key}"), "--vdaf", @@ -937,7 +924,7 @@ mod tests { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: auth_token, + dap_auth_token: Some(auth_token.clone()), authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -959,7 +946,7 @@ mod tests { "--leader", leader.as_str(), "--dap-auth-token", - "collector-authentication-token", + auth_token.as_str(), &format!("--hpke-config={encoded_hpke_config}"), &format!("--hpke-private-key={encoded_private_key}"), "--vdaf", @@ -977,7 +964,7 @@ mod tests { "--leader".to_string(), "https://example.com/dap/".to_string(), "--dap-auth-token".to_string(), - "collector-authentication-token".to_string(), + auth_token.as_str().to_string(), format!("--hpke-config={encoded_hpke_config}"), format!("--hpke-private-key={encoded_private_key}"), "--vdaf=count".to_string(), @@ -1096,13 +1083,17 @@ mod tests { "1000".to_string(), "--vdaf=count".to_string(), ]); + + let dap_auth_token: TokenInner = random(); + let bearer_token: TokenInner = random(); + let dap_auth_token_arguments = Vec::from([ "--dap-auth-token".to_string(), - "collector-authentication-token".to_string(), + dap_auth_token.as_str().to_string(), ]); let authorization_bearer_token_arguments = Vec::from([ "--authorization-bearer-token".to_string(), - "/////////////////////w==".to_string(), + bearer_token.as_str().to_string(), ]); let mut case_1_arguments = base_arguments.clone(); @@ -1112,9 +1103,7 @@ mod tests { .unwrap() .authentication, AuthenticationOptions { - dap_auth_token: Some(AuthenticationToken::DapAuth( - DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap() - )), + dap_auth_token: Some(AuthenticationToken::DapAuth(dap_auth_token)), authorization_bearer_token: None, } ); @@ -1127,9 +1116,7 @@ mod tests { .authentication, AuthenticationOptions { dap_auth_token: None, - authorization_bearer_token: Some(AuthenticationToken::Bearer(Vec::from( - [0xff; 16] - )),) + authorization_bearer_token: Some(AuthenticationToken::Bearer(bearer_token)), } ); diff --git a/tools/tests/cmd/collect.trycmd b/tools/tests/cmd/collect.trycmd index 70e4bae7d..18e911e4f 100644 --- a/tools/tests/cmd/collect.trycmd +++ b/tools/tests/cmd/collect.trycmd @@ -33,7 +33,7 @@ Authorization: [env: DAP_AUTH_TOKEN=] --authorization-bearer-token - Authentication token for the "Authorization: Bearer ..." HTTP header, in base64 + Authentication token for the "Authorization: Bearer ..." HTTP header [env: AUTHORIZATION_BEARER_TOKEN=]