diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 6380b3b35..b06311bbc 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -38,7 +38,7 @@ use janus_core::test_util::dummy_vdaf; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, http::response_to_problem_details, - task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH}, time::{Clock, DurationExt, IntervalExt, TimeExt}, }; use janus_messages::{ @@ -2438,16 +2438,13 @@ async fn send_request_to_helper( ) -> Result { let domain = url.domain().unwrap_or_default().to_string(); let request_body = request.get_encoded(); + let (auth_header, auth_value) = auth_token.request_authentication(); let start = Instant::now(); let response_result = http_client .request(method, url) .header(CONTENT_TYPE, content_type) - // TODO(#472): We want to be able to communicate with new Janus (prefers bearer token but - // supports `DAP-Auth-Token`) as well as older Janus and Daphne (which require - // `DAP-Auth-Token`) so for the moment, we send `DAP-Auth-Token`. But eventually we should - // determine the appropriate token header to send for a given task. - .header(DAP_AUTH_HEADER, auth_token.as_ref()) + .header(auth_header, auth_value) .body(request_body) .send() .await; diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs index 344a9e04d..8090fc16f 100644 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ b/aggregator/src/aggregator/aggregate_init_tests.rs @@ -9,7 +9,7 @@ use janus_aggregator_core::{ task::{test_util::TaskBuilder, QueryType, Task}, }; use janus_core::{ - task::{VdafInstance, DAP_AUTH_HEADER}, + task::{AuthenticationToken, VdafInstance, DAP_AUTH_HEADER}, test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf, VdafTranscript}, time::{Clock, MockClock, TimeExt as _}, }; @@ -176,22 +176,23 @@ pub(crate) async fn put_aggregation_job( } #[tokio::test] -async fn aggregation_job_init_authorization_bearer_header() { +async fn aggregation_job_init_authorization_dap_auth_token() { let test_case = setup_aggregate_init_test_without_sending_request().await; + // Find a DapAuthToken among the task's aggregator auth tokens + let (auth_header, auth_value) = test_case + .task + .aggregator_auth_tokens() + .iter() + .find(|auth| matches!(auth, AuthenticationToken::DapAuth(_))) + .unwrap() + .request_authentication(); let response = put(test_case .task .aggregation_job_uri(&test_case.aggregation_job_id) .unwrap() .path()) - // Authenticate using an "Authorization: Bearer " header instead of "DAP-Auth-Token" - .with_request_header( - KnownHeaderName::Authorization, - test_case - .task - .primary_aggregator_auth_token() - .bearer_token(), - ) + .with_request_header(auth_header, auth_value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 7a0122f77..477352c98 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -56,8 +56,9 @@ impl CollectionJobTestCase { .collection_job_uri(collection_job_id) .unwrap() .path()); - if let Some(token) = auth_token { - test_conn = test_conn.with_request_header("DAP-Auth-Token", token.as_ref().to_owned()) + if let Some(auth) = auth_token { + let (header, value) = auth.request_authentication(); + test_conn = test_conn.with_request_header(header, value); } test_conn @@ -94,8 +95,9 @@ impl CollectionJobTestCase { .unwrap() .path(), ); - if let Some(token) = auth_token { - test_conn = test_conn.with_request_header("DAP-Auth-Token", token.as_ref().to_owned()) + if let Some(auth) = auth_token { + let (header, value) = auth.request_authentication(); + test_conn = test_conn.with_request_header(header, value); } test_conn.run_async(&self.handler).await } diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 8fd1fc8c6..223adffad 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -5,7 +5,7 @@ use janus_aggregator_api::instrumented; use janus_aggregator_core::datastore::Datastore; use janus_core::{ http::extract_bearer_token, - task::{AuthenticationToken, DAP_AUTH_HEADER}, + task::{AuthenticationToken, DapAuthToken, DAP_AUTH_HEADER}, time::Clock, }; use janus_messages::{ @@ -526,32 +526,21 @@ fn parse_collection_job_id(captures: &Captures) -> Result Result, Error> { // Prefer a bearer token, then fall back to DAP-Auth-Token - let bearer_token = - extract_bearer_token(conn).map_err(|_| Error::UnauthorizedRequest(*task_id))?; - if bearer_token.is_some() { - return bearer_token - .map(AuthenticationToken::try_from) - .transpose() - .map_err(|_| { - Error::BadRequest( - "Authorization: Bearer value decodes to an authentication token containing \ - unsafe bytes" - .to_string(), - ) - }); + if let Some(bearer_token) = + extract_bearer_token(conn).map_err(|_| Error::UnauthorizedRequest(*task_id))? + { + return Ok(Some(AuthenticationToken::Bearer(bearer_token))); } conn.request_headers() .get(DAP_AUTH_HEADER) .map(|value| { - value.as_ref().to_owned().try_into().map_err(|_| { - Error::BadRequest( - "DAP-Auth-Header value is not a valid HTTP header value".to_string(), - ) - }) + DapAuthToken::try_from(value.as_ref().to_vec()) + .map(AuthenticationToken::DapAuth) + .map_err(|e| Error::BadRequest(format!("bad DAP-Auth-Token header: {e}"))) }) .transpose() } @@ -1211,6 +1200,8 @@ mod tests { datastore.put_task(&task).await.unwrap(); + let (wrong_auth_header, wrong_auth_value) = + random::().request_authentication(); let request = AggregationJobInitializeReq::new( Vec::new(), PartialBatchSelector::new_time_interval(), @@ -1225,10 +1216,7 @@ mod tests { .aggregation_job_uri(&aggregation_job_id) .unwrap() .path()) - .with_request_header( - "DAP-Auth-Token", - random::().as_ref().to_owned(), - ) + .with_request_header(wrong_auth_header, wrong_auth_value) .with_request_header( KnownHeaderName::ContentType, AggregationJobInitializeReq::::MEDIA_TYPE, diff --git a/aggregator/src/aggregator/problem_details.rs b/aggregator/src/aggregator/problem_details.rs index 74daeff19..f9225b5c7 100644 --- a/aggregator/src/aggregator/problem_details.rs +++ b/aggregator/src/aggregator/problem_details.rs @@ -66,10 +66,7 @@ mod tests { use assert_matches::assert_matches; use futures::future::join_all; use http::Method; - use janus_core::{ - task::AuthenticationToken, - time::{Clock, RealClock}, - }; + use janus_core::time::{Clock, RealClock}; use janus_messages::{ problem_type::{DapProblemType, DapProblemTypeParseError}, Duration, HpkeConfigId, Interval, ReportIdChecksum, @@ -242,7 +239,7 @@ mod tests { "test", "text/plain", (), - &AuthenticationToken::try_from("auth".as_bytes().to_vec()).unwrap(), + &random(), &request_histogram, ) .await diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index 69b57bb01..50aac5868 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -9,8 +9,10 @@ use janus_aggregator_core::{ SecretBytes, }; use janus_core::{ - hpke::generate_hpke_config_and_private_key, http::extract_bearer_token, - task::AuthenticationToken, time::Clock, + hpke::generate_hpke_config_and_private_key, + http::extract_bearer_token, + task::{AuthenticationToken, DapAuthToken}, + time::Clock, }; use janus_messages::{Duration, HpkeAeadId, HpkeKdfId, HpkeKemId, Role, TaskId, Time}; use models::{GetTaskMetricsResp, TaskResp}; @@ -170,17 +172,25 @@ async fn post_task( Status::BadRequest, ) })?; - Vec::from([AuthenticationToken::try_from(token_bytes).map_err(|_| { - Error::new( - "Invalid HTTP header value in aggregator_auth_token".to_string(), - Status::BadRequest, - ) - })?]) + Vec::from([ + // TODO(#472): Each token in the PostTaskReq should indicate whether it is a bearer + // token or a DAP-Auth-Token. For now, assume the latter. + DapAuthToken::try_from(token_bytes) + .map(AuthenticationToken::DapAuth) + .map_err(|_| { + Error::new( + "Invalid HTTP header value in aggregator_auth_token".to_string(), + Status::BadRequest, + ) + })?, + ]) } else { - Vec::from([random()]) + // TODO(#472): switch to generating bearer tokens by default + Vec::from([AuthenticationToken::DapAuth(random())]) }; let collector_auth_tokens = match (req.role, req.collector_auth_token) { - (Role::Leader, None) => Vec::from([random()]), + // TODO(#472): switch to generating bearer tokens by default + (Role::Leader, None) => Vec::from([AuthenticationToken::DapAuth(random())]), (Role::Leader, Some(encoded)) => { let token_bytes = URL_SAFE_NO_PAD.decode(encoded).map_err(|err| { Error::new( @@ -188,12 +198,16 @@ async fn post_task( Status::BadRequest, ) })?; - Vec::from([AuthenticationToken::try_from(token_bytes).map_err(|_| { - Error::new( - "Invalid HTTP header value in collector_auth_token".to_string(), - Status::BadRequest, - ) - })?]) + // TODO(#472): Each token in the PostTaskReq should indicate whether it is a bearer + // token or a DAP-Auth-Token. For now, assume the latter. + Vec::from([DapAuthToken::try_from(token_bytes) + .map(AuthenticationToken::DapAuth) + .map_err(|_| { + Error::new( + "Invalid HTTP header value in collector_auth_token".to_string(), + Status::BadRequest, + ) + })?]) } (Role::Helper, None) => Vec::new(), (Role::Helper, Some(_)) => { @@ -486,7 +500,7 @@ mod tests { }; use janus_core::{ hpke::{generate_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, VdafInstance}, + task::{AuthenticationToken, DapAuthToken, VdafInstance}, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, @@ -712,8 +726,8 @@ mod tests { let vdaf_verify_key = SecretBytes::new(thread_rng().sample_iter(Standard).take(16).collect()); - let aggregator_auth_token = random::(); - let collector_auth_token = random::(); + let aggregator_auth_token = AuthenticationToken::DapAuth(random()); + let collector_auth_token = AuthenticationToken::DapAuth(random()); // Verify: posting a task creates a new task which matches the request. let req = PostTaskReq { @@ -1354,12 +1368,12 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from([0u8; 32].to_vec()), ), - Vec::from([ - AuthenticationToken::try_from("aggregator-12345678".as_bytes().to_vec()).unwrap(), - ]), - Vec::from([ - AuthenticationToken::try_from("collector-abcdef00".as_bytes().to_vec()).unwrap(), - ]), + Vec::from([AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"aggregator-12345678".to_vec()).unwrap(), + )]), + Vec::from([AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"collector-abcdef00".to_vec()).unwrap(), + )]), [(HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(13), diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 64effdeed..41d7a789c 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -2,9 +2,10 @@ use self::models::{ AcquiredAggregationJob, AcquiredCollectionJob, AggregateShareJob, AggregationJob, - AggregatorRole, Batch, BatchAggregation, CollectionJob, CollectionJobState, - CollectionJobStateCode, LeaderStoredReport, Lease, LeaseToken, OutstandingBatch, - ReportAggregation, ReportAggregationState, ReportAggregationStateCode, SqlInterval, + AggregatorRole, AuthenticationTokenType, Batch, BatchAggregation, CollectionJob, + CollectionJobState, CollectionJobStateCode, LeaderStoredReport, Lease, LeaseToken, + OutstandingBatch, ReportAggregation, ReportAggregationState, ReportAggregationStateCode, + SqlInterval, }; #[cfg(feature = "test-util")] use crate::VdafHasAggregationParameter; @@ -18,7 +19,7 @@ use chrono::NaiveDateTime; use futures::future::try_join_all; use janus_core::{ hpke::{HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, VdafInstance}, + task::VdafInstance, time::{Clock, TimeExt}, }; use janus_messages::{ @@ -86,7 +87,7 @@ macro_rules! supported_schema_versions { // List of schema versions that this version of Janus can safely run on. If any other schema // version is seen, [`Datastore::new`] fails. -supported_schema_versions!(8); +supported_schema_versions!(9); /// Datastore represents a datastore for Janus, with support for transactional reads and writes. /// In practice, Datastore instances are currently backed by a PostgreSQL database. @@ -450,6 +451,7 @@ impl Transaction<'_, C> { // Aggregator auth tokens. let mut aggregator_auth_token_ords = Vec::new(); + let mut aggregator_auth_token_types = Vec::new(); let mut aggregator_auth_tokens = Vec::new(); for (ord, token) in task.aggregator_auth_tokens().iter().enumerate() { let ord = i64::try_from(ord)?; @@ -466,25 +468,28 @@ impl Transaction<'_, C> { )?; aggregator_auth_token_ords.push(ord); + aggregator_auth_token_types.push(AuthenticationTokenType::from(token)); aggregator_auth_tokens.push(encrypted_aggregator_auth_token); } let stmt = self .prepare_cached( - "INSERT INTO task_aggregator_auth_tokens (task_id, ord, token) + "INSERT INTO task_aggregator_auth_tokens (task_id, ord, type, token) SELECT (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::BYTEA[])", + * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", ) .await?; let aggregator_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ /* task_id */ &task.id().as_ref(), /* ords */ &aggregator_auth_token_ords, + /* token_types */ &aggregator_auth_token_types, /* tokens */ &aggregator_auth_tokens, ]; let aggregator_auth_tokens_future = self.execute(&stmt, aggregator_auth_tokens_params); // Collector auth tokens. let mut collector_auth_token_ords = Vec::new(); + let mut collector_auth_token_types = Vec::new(); let mut collector_auth_tokens = Vec::new(); for (ord, token) in task.collector_auth_tokens().iter().enumerate() { let ord = i64::try_from(ord)?; @@ -501,19 +506,21 @@ impl Transaction<'_, C> { )?; collector_auth_token_ords.push(ord); + collector_auth_token_types.push(AuthenticationTokenType::from(token)); collector_auth_tokens.push(encrypted_collector_auth_token); } let stmt = self .prepare_cached( - "INSERT INTO task_collector_auth_tokens (task_id, ord, token) + "INSERT INTO task_collector_auth_tokens (task_id, ord, type, token) SELECT (SELECT id FROM tasks WHERE task_id = $1), - * FROM UNNEST($2::BIGINT[], $3::BYTEA[])", + * FROM UNNEST($2::BIGINT[], $3::AUTH_TOKEN_TYPE[], $4::BYTEA[])", ) .await?; let collector_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ /* task_id */ &task.id().as_ref(), /* ords */ &collector_auth_token_ords, + /* token_types */ &collector_auth_token_types, /* tokens */ &collector_auth_tokens, ]; let collector_auth_tokens_future = self.execute(&stmt, collector_auth_tokens_params); @@ -619,7 +626,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT ord, token FROM task_aggregator_auth_tokens + "SELECT ord, type, token FROM task_aggregator_auth_tokens WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", ) .await?; @@ -627,7 +634,7 @@ impl Transaction<'_, C> { let stmt = self .prepare_cached( - "SELECT ord, token FROM task_collector_auth_tokens + "SELECT ord, type, token FROM task_collector_auth_tokens WHERE task_id = (SELECT id FROM tasks WHERE task_id = $1) ORDER BY ord ASC", ) .await?; @@ -693,7 +700,7 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_aggregator_auth_tokens.task_id), - ord, token FROM task_aggregator_auth_tokens ORDER BY ord ASC", + ord, type, token FROM task_aggregator_auth_tokens ORDER BY ord ASC", ) .await?; let aggregator_auth_token_rows = self.query(&stmt, &[]); @@ -702,7 +709,7 @@ impl Transaction<'_, C> { .prepare_cached( "SELECT (SELECT tasks.task_id FROM tasks WHERE tasks.id = task_collector_auth_tokens.task_id), - ord, token FROM task_collector_auth_tokens ORDER BY ord ASC", + ord, type, token FROM task_collector_auth_tokens ORDER BY ord ASC", ) .await?; let collector_auth_token_rows = self.query(&stmt, &[]); @@ -843,42 +850,42 @@ impl Transaction<'_, C> { let mut aggregator_auth_tokens = Vec::new(); for row in aggregator_auth_token_rows { let ord: i64 = row.get("ord"); + let auth_token_type: AuthenticationTokenType = row.get("type"); let encrypted_aggregator_auth_token: Vec = row.get("token"); let mut row_id = [0u8; TaskId::LEN + size_of::()]; row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - aggregator_auth_tokens.push( - AuthenticationToken::try_from(self.crypter.decrypt( + aggregator_auth_tokens.push(auth_token_type.as_authentication( + &self.crypter.decrypt( "task_aggregator_auth_tokens", &row_id, "token", &encrypted_aggregator_auth_token, - )?) - .map_err(|e| Error::DbState(e.to_string()))?, - ); + )?, + )?); } // Collector authentication tokens. let mut collector_auth_tokens = Vec::new(); for row in collector_auth_token_rows { let ord: i64 = row.get("ord"); + let auth_token_type: AuthenticationTokenType = row.get("type"); let encrypted_collector_auth_token: Vec = row.get("token"); let mut row_id = [0u8; TaskId::LEN + size_of::()]; row_id[..TaskId::LEN].copy_from_slice(task_id.as_ref()); row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); - collector_auth_tokens.push( - AuthenticationToken::try_from(self.crypter.decrypt( + collector_auth_tokens.push(auth_token_type.as_authentication( + &self.crypter.decrypt( "task_collector_auth_tokens", &row_id, "token", &encrypted_collector_auth_token, - )?) - .map_err(|e| Error::DbState(e.to_string()))?, - ); + )?, + )?); } // HPKE keys. @@ -4020,7 +4027,7 @@ pub mod models { use derivative::Derivative; use janus_core::{ report_id::ReportIdChecksumExt, - task::VdafInstance, + task::{AuthenticationToken, DapAuthToken, VdafInstance}, time::{DurationExt, IntervalExt, TimeExt}, }; use janus_messages::{ @@ -4048,6 +4055,40 @@ pub mod models { // implementations don't play nice with generic fields, even if those fields are constrained to // themselves implement [Partial]Eq. + /// AuthenticationTokenType represents the type of an authentication token. It corresponds to enum + /// `AUTH_TOKEN_TYPE` in the schema. + #[derive(Copy, Clone, Debug, PartialEq, Eq, ToSql, FromSql)] + #[postgres(name = "auth_token_type")] + pub enum AuthenticationTokenType { + #[postgres(name = "DAP_AUTH")] + DapAuthToken, + #[postgres(name = "BEARER")] + AuthorizationBearerToken, + } + + impl AuthenticationTokenType { + pub fn as_authentication(&self, token: &[u8]) -> Result { + match self { + Self::DapAuthToken => DapAuthToken::try_from(token.to_vec()) + .map(AuthenticationToken::DapAuth) + .map_err(|e| { + Error::DbState(format!("invalid DAP auth token in database: {e:?}")) + }), + Self::AuthorizationBearerToken => Ok(AuthenticationToken::Bearer(token.into())), + } + } + } + + impl From<&AuthenticationToken> for AuthenticationTokenType { + fn from(value: &AuthenticationToken) -> Self { + match value { + AuthenticationToken::DapAuth(_) => Self::DapAuthToken, + AuthenticationToken::Bearer(_) => Self::AuthorizationBearerToken, + _ => unreachable!(), + } + } + } + /// Represents a report as it is stored in the leader's database, corresponding to a row in /// `client_reports`, where `leader_input_share` and `helper_encrypted_input_share` are required /// to be populated. diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index b697c1186..9a7e86942 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -107,10 +107,8 @@ pub struct Task { /// HPKE configuration for the collector. collector_hpke_config: HpkeConfig, /// Tokens used to authenticate messages sent to or received from the other aggregator. - #[derivative(Debug = "ignore")] aggregator_auth_tokens: Vec, /// Tokens used to authenticate messages sent to or received from the collector. - #[derivative(Debug = "ignore")] collector_auth_tokens: Vec, /// HPKE configurations & private keys used by this aggregator to decrypt client reports. hpke_keys: HashMap, @@ -326,6 +324,7 @@ impl Task { /// Returns the [`AuthenticationToken`] currently used by the collector to authenticate itself /// to the aggregators. pub fn primary_collector_auth_token(&self) -> &AuthenticationToken { + // Unwrap safety: self.collector_auth_tokens is never empty self.collector_auth_tokens.iter().rev().next().unwrap() } @@ -417,9 +416,9 @@ pub struct SerializedTask { time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, - aggregator_auth_tokens: Vec, // in unpadded base64url - collector_auth_tokens: Vec, // in unpadded base64url - hpke_keys: Vec, // uses unpadded base64url + aggregator_auth_tokens: Vec, + collector_auth_tokens: Vec, + hpke_keys: Vec, // uses unpadded base64url } impl SerializedTask { @@ -455,13 +454,11 @@ impl SerializedTask { } if self.aggregator_auth_tokens.is_empty() { - self.aggregator_auth_tokens = - Vec::from([URL_SAFE_NO_PAD.encode(random::())]); + self.aggregator_auth_tokens = Vec::from([random()]); } if self.collector_auth_tokens.is_empty() && self.role == Role::Leader { - self.collector_auth_tokens = - Vec::from([URL_SAFE_NO_PAD.encode(random::())]); + self.collector_auth_tokens = Vec::from([random()]); } if self.hpke_keys.is_empty() { @@ -484,16 +481,6 @@ impl Serialize for Task { .iter() .map(|key| URL_SAFE_NO_PAD.encode(key.as_ref())) .collect(); - let aggregator_auth_tokens = self - .aggregator_auth_tokens - .iter() - .map(|token| URL_SAFE_NO_PAD.encode(token)) - .collect(); - let collector_auth_tokens = self - .collector_auth_tokens - .iter() - .map(|token| URL_SAFE_NO_PAD.encode(token)) - .collect(); let hpke_keys = self.hpke_keys.values().cloned().collect(); SerializedTask { @@ -510,8 +497,8 @@ impl Serialize for Task { time_precision: self.time_precision, tolerable_clock_skew: self.tolerable_clock_skew, collector_hpke_config: self.collector_hpke_config.clone(), - aggregator_auth_tokens, - collector_auth_tokens, + aggregator_auth_tokens: self.aggregator_auth_tokens.clone(), + collector_auth_tokens: self.collector_auth_tokens.clone(), hpke_keys, } .serialize(serializer) @@ -534,34 +521,6 @@ impl TryFrom for Task { .map(|key| Ok(SecretBytes::new(URL_SAFE_NO_PAD.decode(key)?))) .collect::>()?; - // aggregator_auth_tokens - let aggregator_auth_tokens = serialized_task - .aggregator_auth_tokens - .into_iter() - .map(|token| { - AuthenticationToken::try_from(URL_SAFE_NO_PAD.decode(token)?).map_err(|_| { - Error::InvalidParameter(concat!( - "value in aggregator_auth_tokens does not base64url-decode to a valid ", - "HTTP header value" - )) - }) - }) - .collect::, Self::Error>>()?; - - // collector_auth_tokens - let collector_auth_tokens = serialized_task - .collector_auth_tokens - .into_iter() - .map(|token| { - AuthenticationToken::try_from(URL_SAFE_NO_PAD.decode(token)?).map_err(|_| { - Error::InvalidParameter(concat!( - "value in collector_auth_tokens does not base64url-decode to a valid ", - "HTTP header value" - )) - }) - }) - .collect::, Self::Error>>()?; - Task::new( task_id, serialized_task.aggregator_endpoints, @@ -576,8 +535,8 @@ impl TryFrom for Task { serialized_task.time_precision, serialized_task.tolerable_clock_skew, serialized_task.collector_hpke_config, - aggregator_auth_tokens, - collector_auth_tokens, + serialized_task.aggregator_auth_tokens, + serialized_task.collector_auth_tokens, serialized_task.hpke_keys, ) } @@ -653,7 +612,7 @@ pub mod test_util { ); let collector_auth_tokens = if role == Role::Leader { - Vec::from([random(), random()]) + Vec::from([random(), AuthenticationToken::DapAuth(random())]) } else { Vec::new() }; @@ -676,7 +635,7 @@ pub mod test_util { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random(), random()]), + Vec::from([random(), AuthenticationToken::DapAuth(random())]), collector_auth_tokens, Vec::from([aggregator_keypair_0, aggregator_keypair_1]), ) @@ -814,14 +773,12 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::{ - task::{test_util::TaskBuilder, Error, QueryType, SerializedTask, Task, VdafInstance}, + task::{test_util::TaskBuilder, QueryType, Task, VdafInstance}, SecretBytes, }; - use assert_matches::assert_matches; - use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkeKeypair, HpkePrivateKey}, - task::{AuthenticationToken, PRIO3_VERIFY_KEY_LENGTH}, + task::{AuthenticationToken, DapAuthToken, PRIO3_VERIFY_KEY_LENGTH}, test_util::roundtrip_encoding, time::DurationExt, }; @@ -894,8 +851,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random::()]), - Vec::from([random::()]), + Vec::from([random()]), + Vec::from([random()]), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -918,7 +875,7 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random::()]), + Vec::from([random()]), Vec::new(), Vec::from([generate_test_hpke_config_and_private_key()]), ) @@ -942,8 +899,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random::()]), - Vec::from([random::()]), + Vec::from([random()]), + Vec::from([random()]), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -968,8 +925,8 @@ mod tests { Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), generate_test_hpke_config_and_private_key().config().clone(), - Vec::from([random::()]), - Vec::from([random::()]), + Vec::from([random()]), + Vec::from([random()]), Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -1054,8 +1011,10 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::try_from(b"aggregator token".to_vec()).unwrap()]), - Vec::from([AuthenticationToken::try_from(b"collector token".to_vec()).unwrap()]), + Vec::from([AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"aggregator token".to_vec()).unwrap(), + )]), + Vec::from([AuthenticationToken::Bearer(b"collector token".to_vec())]), [HpkeKeypair::new( HpkeConfig::new( HpkeConfigId::from(255), @@ -1144,11 +1103,27 @@ mod tests { Token::StructEnd, Token::Str("aggregator_auth_tokens"), Token::Seq { len: Some(1) }, + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("DapAuth"), + Token::Str("token"), Token::Str("YWdncmVnYXRvciB0b2tlbg"), + Token::StructEnd, Token::SeqEnd, Token::Str("collector_auth_tokens"), Token::Seq { len: Some(1) }, + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("Bearer"), + Token::Str("token"), Token::Str("Y29sbGVjdG9yIHRva2Vu"), + Token::StructEnd, Token::SeqEnd, Token::Str("hpke_keys"), Token::Seq { len: Some(1) }, @@ -1216,7 +1191,7 @@ mod tests { HpkeAeadId::Aes128Gcm, HpkePublicKey::from(b"collector hpke public key".to_vec()), ), - Vec::from([AuthenticationToken::try_from(b"aggregator token".to_vec()).unwrap()]), + Vec::from([AuthenticationToken::Bearer(b"aggregator token".to_vec())]), Vec::new(), [HpkeKeypair::new( HpkeConfig::new( @@ -1316,7 +1291,15 @@ mod tests { Token::StructEnd, Token::Str("aggregator_auth_tokens"), Token::Seq { len: Some(1) }, - Token::Str("YWdncmVnYXRvciB0b2tlbg"), + Token::Struct { + name: "AuthenticationToken", + len: 2, + }, + Token::Str("type"), + Token::Str("Bearer"), + Token::Str("token"), + Token::Str("YWdncmVnYXRvciB0b2tlbg=="), + Token::StructEnd, Token::SeqEnd, Token::Str("collector_auth_tokens"), Token::Seq { len: Some(0) }, @@ -1363,64 +1346,4 @@ mod tests { ], ); } - - #[test] - fn reject_invalid_auth_tokens() { - let aggregator_keypair = generate_test_hpke_config_and_private_key(); - let collector_keypair = generate_test_hpke_config_and_private_key(); - - let bad_agg_auth_token = SerializedTask { - task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), - query_type: QueryType::TimeInterval, - vdaf: VdafInstance::Prio3Count, - role: Role::Helper, - vdaf_verify_keys: Vec::from([]), - max_batch_query_count: 1, - task_expiration: None, - report_expiry_age: None, - min_batch_size: 100, - time_precision: Duration::from_seconds(3600), - tolerable_clock_skew: Duration::from_seconds(15), - collector_hpke_config: collector_keypair.config().clone(), - aggregator_auth_tokens: Vec::from(["AAAAAAAAAAAAAA".to_string()]), - collector_auth_tokens: Vec::new(), - hpke_keys: Vec::from([aggregator_keypair.clone()]), - }; - let err = Task::try_from(bad_agg_auth_token).unwrap_err(); - assert_matches!(err, Error::InvalidParameter(message) => { - assert!(message.contains("aggregator") && message.contains("HTTP header value"), "{}", message); - }); - - let bad_collector_auth_token = SerializedTask { - task_id: Some(random()), - aggregator_endpoints: Vec::from([ - "https://www.example.com/".parse().unwrap(), - "https://www.example.net/".parse().unwrap(), - ]), - query_type: QueryType::TimeInterval, - vdaf: VdafInstance::Prio3Count, - role: Role::Leader, - vdaf_verify_keys: Vec::from([]), - max_batch_query_count: 1, - task_expiration: None, - report_expiry_age: None, - min_batch_size: 100, - time_precision: Duration::from_seconds(3600), - tolerable_clock_skew: Duration::from_seconds(15), - collector_hpke_config: collector_keypair.config().clone(), - aggregator_auth_tokens: Vec::from([ - URL_SAFE_NO_PAD.encode(random::()) - ]), - collector_auth_tokens: Vec::from(["AAAAAAAAAAAAAA".to_string()]), - hpke_keys: Vec::from([aggregator_keypair]), - }; - let err = Task::try_from(bad_collector_auth_token).unwrap_err(); - assert_matches!(err, Error::InvalidParameter(message) => { - assert!(message.contains("collector") && message.contains("HTTP header value"), "{}", message); - }); - } } diff --git a/collector/src/lib.rs b/collector/src/lib.rs index bfd72811f..04c721fb6 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -7,8 +7,8 @@ //! # Examples //! //! ```no_run -//! use janus_collector::{Authentication, Collector, CollectorParameters, default_http_client}; -//! use janus_core::{hpke::generate_hpke_config_and_private_key, task::AuthenticationToken}; +//! use janus_collector::{AuthenticationToken, Collector, CollectorParameters, default_http_client}; +//! use janus_core::{hpke::generate_hpke_config_and_private_key}; //! use janus_messages::{ //! Duration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, Interval, TaskId, //! Time, Query, @@ -26,12 +26,10 @@ //! HpkeKdfId::HkdfSha256, //! HpkeAeadId::Aes128Gcm, //! ); -//! let authentication_token = -//! AuthenticationToken::try_from(b"my-authentication-token".to_vec()).unwrap(); //! let parameters = CollectorParameters::new_with_authentication( //! task_id, //! "https://example.com/dap/".parse().unwrap(), -//! Authentication::DapAuthToken(authentication_token), +//! AuthenticationToken::Bearer(b"my-authentication-token".to_vec()), //! hpke_keypair.config().clone(), //! hpke_keypair.private_key().clone(), //! ); @@ -59,11 +57,12 @@ use backoff::{backoff::Backoff, ExponentialBackoff}; use chrono::{DateTime, Duration, Utc}; use derivative::Derivative; use http_api_problem::HttpApiProblem; +pub use janus_core::task::AuthenticationToken; use janus_core::{ hpke::{self, HpkeApplicationInfo, HpkePrivateKey}, http::response_to_problem_details, retries::{http_request_exponential_backoff, retry_http_request}, - task::{url_ensure_trailing_slash, AuthenticationToken, DAP_AUTH_HEADER}, + task::url_ensure_trailing_slash, time::{DurationExt, TimeExt}, }; use janus_messages::{ @@ -78,7 +77,7 @@ use prio::{ }; use rand::random; use reqwest::{ - header::{HeaderValue, ToStrError, AUTHORIZATION, CONTENT_TYPE, RETRY_AFTER}, + header::{HeaderValue, ToStrError, CONTENT_TYPE, RETRY_AFTER}, Response, StatusCode, }; use retry_after::FromHeaderValueError; @@ -152,18 +151,6 @@ static COLLECTOR_USER_AGENT: &str = concat!( "collector" ); -/// Authentication configuration for communication with the leader aggregator. -#[derive(Derivative)] -#[derivative(Debug)] -#[non_exhaustive] -pub enum Authentication { - /// Bearer token authentication, via the `DAP-Auth-Token` header. - DapAuthToken(#[derivative(Debug = "ignore")] AuthenticationToken), - /// Bearer token authentication, via a header of the form `Authorization: Bearer `. - AuthorizationBearerToken(#[derivative(Debug = "ignore")] AuthenticationToken), -} - /// The DAP collector's view of task parameters. #[derive(Derivative)] #[derivative(Debug)] @@ -174,7 +161,7 @@ pub struct CollectorParameters { #[derivative(Debug(format_with = "std::fmt::Display::fmt"))] leader_endpoint: Url, /// The authentication information needed to communicate with the leader aggregator. - authentication: Authentication, + authentication: AuthenticationToken, /// HPKE configuration and public key used for encryption of aggregate shares. #[derivative(Debug = "ignore")] hpke_config: HpkeConfig, @@ -199,14 +186,14 @@ impl CollectorParameters { pub fn new( task_id: TaskId, leader_endpoint: Url, - authentication_token: AuthenticationToken, + authentication: AuthenticationToken, hpke_config: HpkeConfig, hpke_private_key: HpkePrivateKey, ) -> CollectorParameters { Self::new_with_authentication( task_id, leader_endpoint, - Authentication::DapAuthToken(authentication_token), + authentication, hpke_config, hpke_private_key, ) @@ -216,7 +203,7 @@ impl CollectorParameters { pub fn new_with_authentication( task_id: TaskId, mut leader_endpoint: Url, - authentication: Authentication, + authentication: AuthenticationToken, hpke_config: HpkeConfig, hpke_private_key: HpkePrivateKey, ) -> CollectorParameters { @@ -435,20 +422,15 @@ impl Collector { let response_res = retry_http_request( self.parameters.http_request_retry_parameters.clone(), || async { - let mut request = self - .http_client + let (auth_header, auth_value) = + self.parameters.authentication.request_authentication(); + self.http_client .put(collection_job_url.clone()) .header(CONTENT_TYPE, CollectionReq::::MEDIA_TYPE) - .body(collect_request.get_encoded()); - match &self.parameters.authentication { - Authentication::DapAuthToken(token) => { - request = request.header(DAP_AUTH_HEADER, token.as_ref()) - } - Authentication::AuthorizationBearerToken(token) => { - request = request.header(AUTHORIZATION, token.bearer_token()) - } - } - request.send().await + .body(collect_request.get_encoded()) + .header(auth_header, auth_value) + .send() + .await }, ) .await; @@ -490,16 +472,13 @@ impl Collector { let response_res = retry_http_request( self.parameters.http_request_retry_parameters.clone(), || async { - let mut request = self.http_client.post(job.collection_job_url.clone()); - match &self.parameters.authentication { - Authentication::DapAuthToken(token) => { - request = request.header(DAP_AUTH_HEADER, token.as_ref()) - } - Authentication::AuthorizationBearerToken(token) => { - request = request.header(AUTHORIZATION, token.bearer_token()) - } - } - request.send().await + let (auth_header, auth_value) = + self.parameters.authentication.request_authentication(); + self.http_client + .post(job.collection_job_url.clone()) + .header(auth_header, auth_value) + .send() + .await }, ) .await; @@ -703,8 +682,8 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::{ - default_http_client, Authentication, Collection, CollectionJob, Collector, - CollectorParameters, Error, PollResult, + default_http_client, Collection, CollectionJob, Collector, CollectorParameters, Error, + PollResult, }; use assert_matches::assert_matches; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; @@ -747,7 +726,7 @@ mod tests { let parameters = CollectorParameters::new_with_authentication( random(), server_url, - Authentication::DapAuthToken(AuthenticationToken::try_from(b"token".to_vec()).unwrap()), + AuthenticationToken::Bearer(b"token".to_vec()), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ) @@ -848,7 +827,7 @@ mod tests { let collector_parameters = CollectorParameters::new_with_authentication( random(), "http://example.com/dap".parse().unwrap(), - Authentication::DapAuthToken(AuthenticationToken::try_from(b"token".to_vec()).unwrap()), + AuthenticationToken::Bearer(b"token".to_vec()), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ); @@ -861,7 +840,7 @@ mod tests { let collector_parameters = CollectorParameters::new_with_authentication( random(), "http://example.com".parse().unwrap(), - Authentication::DapAuthToken(AuthenticationToken::try_from(b"token".to_vec()).unwrap()), + AuthenticationToken::Bearer(b"token".to_vec()), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ); @@ -879,6 +858,9 @@ mod tests { let vdaf = Prio3::new_count(2).unwrap(); let transcript = run_vdaf(&vdaf, &random(), &(), &random(), &1); let collector = setup_collector(&mut server, vdaf); + let (auth_header, auth_value) = + collector.parameters.authentication.request_authentication(); + let auth_value = String::from_utf8(auth_value).unwrap(); let batch_interval = Interval::new( Time::from_seconds_since_epoch(1_000_000), @@ -905,7 +887,7 @@ mod tests { CONTENT_TYPE.as_str(), CollectionReq::::MEDIA_TYPE, ) - .match_header("DAP-Auth-Token", "token") + .match_header(auth_header, auth_value.as_str()) .with_status(201) .expect(1) .create_async() @@ -913,13 +895,14 @@ mod tests { let job = collector .start_collection(Query::new_time_interval(batch_interval), &()) - .await - .unwrap(); - assert_eq!(job.query.batch_interval(), &batch_interval); + .await; mocked_collect_start_error.assert_async().await; mocked_collect_start_success.assert_async().await; + let job = job.unwrap(); + assert_eq!(job.query.batch_interval(), &batch_interval); + let mocked_collect_error = server .mock("POST", job.collection_job_url.path()) .with_status(500) @@ -934,7 +917,7 @@ mod tests { .await; let mocked_collect_complete = server .mock("POST", job.collection_job_url.path()) - .match_header("DAP-Auth-Token", "token") + .match_header(auth_header, auth_value.as_str()) .with_status(200) .with_header( CONTENT_TYPE.as_str(), @@ -1263,9 +1246,7 @@ mod tests { let parameters = CollectorParameters::new_with_authentication( random(), server_url, - Authentication::AuthorizationBearerToken( - AuthenticationToken::try_from([0x41u8; 16].to_vec()).unwrap(), - ), + AuthenticationToken::Bearer(Vec::from([0x41u8; 16])), hpke_keypair.config().clone(), hpke_keypair.private_key().clone(), ) diff --git a/core/src/task.rs b/core/src/task.rs index 5405daa14..af7406845 100644 --- a/core/src/task.rs +++ b/core/src/task.rs @@ -1,9 +1,13 @@ -use base64::{engine::general_purpose::STANDARD, Engine}; -use http::header::HeaderValue; +use base64::{ + engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}, + Engine, +}; +use derivative::Derivative; +use http::header::{HeaderValue, AUTHORIZATION}; use rand::{distributions::Standard, prelude::Distribution}; use reqwest::Url; use ring::constant_time; -use serde::{Deserialize, Serialize}; +use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; use std::str; /// HTTP header where auth tokens are provided in messages between participants. @@ -533,48 +537,150 @@ macro_rules! vdaf_dispatch { }; } -/// An authentication (bearer) token used by aggregators for aggregator-to-aggregator and -/// collector-to-aggregator authentication. -#[derive(Clone)] -pub struct AuthenticationToken(Vec); +/// Different modes of authentication supported by Janus for either sending requests (e.g., leader +/// to helper) or receiving them (e.g., collector to leader). +#[derive(Clone, Derivative, Serialize, Deserialize)] +#[derivative(Debug)] +#[serde(tag = "type", content = "token")] +#[non_exhaustive] +pub enum AuthenticationToken { + /// A bearer token. The value is an opaque byte string. Its Base64 encoding is inserted into + /// HTTP requests as specified in [RFC 6750 section 2.1][1]. The token is not necessarily an + /// OAuth token. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc6750#section-2.1 + Bearer( + #[derivative(Debug = "ignore")] + #[serde(serialize_with = "as_base64", deserialize_with = "from_base64")] + Vec, + ), + + /// Token presented as the value of the "DAP-Auth-Token" HTTP header. Conforms to + /// [draft-ietf-dap-ppm-01 section 3.2][1]. + /// + /// [1]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-01#name-https-sender-authentication + DapAuth(DapAuthToken), +} impl AuthenticationToken { - /// Constructs a bearer token string suitable for use as the value in an HTTP `Authorization` - /// header. - pub fn bearer_token(&self) -> String { - format!("Bearer {}", STANDARD.encode(self.as_ref())) + /// Returns an HTTP header and value that should be used to authenticate an HTTP request with + /// this credential. + pub fn request_authentication(&self) -> (&'static str, Vec) { + match self { + Self::Bearer(token) => ( + AUTHORIZATION.as_str(), + // When encoding into a request, we use Base64 standard encoding + format!("Bearer {}", STANDARD.encode(token.as_slice())).into_bytes(), + ), + // A DAP-Auth-Token is already HTTP header-safe, so no encoding is needed. Cloning is + // unfortunate but necessary since other arms must allocate. + Self::DapAuth(token) => (DAP_AUTH_HEADER, token.as_ref().to_vec()), + } } } + +impl PartialEq for AuthenticationToken { + fn eq(&self, other: &Self) -> bool { + let (own, other) = match (self, other) { + (Self::Bearer(own), Self::Bearer(other)) => (own.as_slice(), other.as_slice()), + (Self::DapAuth(own), Self::DapAuth(other)) => (own.as_ref(), other.as_ref()), + _ => { + return false; + } + }; + // We attempt constant-time comparisons of the token data to mitigate timing attacks. Note + // that this function still eaks whether the lengths of the tokens are equal -- this is + // acceptable because we expec the content of the tokens to provide enough randomness that + // needs to be guessed even if the length is known. + constant_time::verify_slices_are_equal(own, other).is_ok() + } +} + +impl Eq for AuthenticationToken {} + impl AsRef<[u8]> for AuthenticationToken { + fn as_ref(&self) -> &[u8] { + match self { + Self::DapAuth(token) => token.as_ref(), + Self::Bearer(token) => token.as_ref(), + } + } +} + +/// Serialize bytes into format suitable for bearer tokens in Janus configuration files. +fn as_base64>(key: &T, serializer: S) -> Result { + let bytes: &[u8] = key.as_ref(); + serializer.serialize_str(&STANDARD.encode(bytes)) +} + +/// Deserialize bytes from Janus configuration files into a bearer token. +fn from_base64<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + String::deserialize(deserializer).and_then(|s| { + STANDARD + .decode(s) + .map_err(|e| D::Error::custom(format!("cannot decode value from Base64: {e:?}"))) + }) +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> AuthenticationToken { + AuthenticationToken::Bearer(Vec::from(rng.gen::<[u8; 16]>())) + } +} + +/// Token presented as the value of the "DAP-Auth-Token" HTTP header. The token is used directly in +/// the HTTP request without further encoding and so must be a legal HTTP header value. Conforms to +/// [draft-ietf-dap-ppm-01 section 3.2][1]. +/// +/// This opaque type ensures it's impossible to construct an [`AuthenticationToken::DapAuth`] whose +/// contents are invalid. +/// +/// [1]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap-01#name-https-sender-authentication +#[derive(Clone, Derivative)] +#[derivative(Debug)] +pub struct DapAuthToken(#[derivative(Debug = "ignore")] Vec); + +impl DapAuthToken {} + +impl AsRef<[u8]> for DapAuthToken { fn as_ref(&self) -> &[u8] { &self.0 } } -impl TryFrom> for AuthenticationToken { +impl TryFrom> for DapAuthToken { type Error = anyhow::Error; - fn try_from(token: Vec) -> Result { + fn try_from(token: Vec) -> Result { HeaderValue::try_from(token.as_slice())?; Ok(Self(token)) } } -impl PartialEq for AuthenticationToken { - fn eq(&self, other: &Self) -> bool { - // We attempt constant-time comparisons of the token data. Note that this function still - // leaks whether the lengths of the tokens are equal -- this is acceptable because we expect - // the content of the tokens to provide enough randomness that needs to be guessed even if - // the length is known. - constant_time::verify_slices_are_equal(&self.0, &other.0).is_ok() +impl Serialize for DapAuthToken { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&URL_SAFE_NO_PAD.encode(self.as_ref())) } } -impl Eq for AuthenticationToken {} +impl<'de> Deserialize<'de> for DapAuthToken { + fn deserialize>(deserializer: D) -> Result { + // Verify that the string is a safe HTTP header value + String::deserialize(deserializer) + .and_then(|string| { + URL_SAFE_NO_PAD.decode(string).map_err(|e| { + D::Error::custom(format!( + "cannot decode value from unpadded Base64URL: {e:?}" + )) + }) + }) + .and_then(|bytes| Self::try_from(bytes).map_err(D::Error::custom)) + } +} -impl Distribution for Standard { - fn sample(&self, rng: &mut R) -> AuthenticationToken { - AuthenticationToken(Vec::from(hex::encode(rng.gen::<[u8; 16]>()))) +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> DapAuthToken { + DapAuthToken(Vec::from(hex::encode(rng.gen::<[u8; 16]>()))) } } @@ -591,7 +697,7 @@ pub fn url_ensure_trailing_slash(url: &mut Url) { #[cfg(test)] mod tests { - use super::VdafInstance; + use super::{AuthenticationToken, VdafInstance}; use serde_test::{assert_tokens, Token}; #[test] @@ -686,4 +792,13 @@ mod tests { }], ); } + + #[test] + fn reject_invalid_dap_auth_token() { + let err = serde_yaml::from_str::( + "{type: \"DapAuth\", token: \"AAAAAAAAAAAAAA\"}", + ) + .unwrap_err(); + assert!(err.to_string().contains("failed to parse header value")); + } } diff --git a/db/00000000000009_auth_token_kind.down.sql b/db/00000000000009_auth_token_kind.down.sql new file mode 100644 index 000000000..b84dda094 --- /dev/null +++ b/db/00000000000009_auth_token_kind.down.sql @@ -0,0 +1,4 @@ +ALTER TABLE task_collector_auth_tokens DROP COLUMN type AUTH_TOKEN_TYPE; +ALTER TABLE task_aggregator_auth_tokens DROP COLUMN type AUTH_TOKEN_TYPE; + +DROP TYPE AUTH_TOKEN_TYPE; diff --git a/db/00000000000009_auth_token_kind.up.sql b/db/00000000000009_auth_token_kind.up.sql new file mode 100644 index 000000000..2bbb22b66 --- /dev/null +++ b/db/00000000000009_auth_token_kind.up.sql @@ -0,0 +1,7 @@ +CREATE TYPE AUTH_TOKEN_TYPE AS ENUM( + 'DAP_AUTH', -- DAP-01 style DAP-Auth-Token header + 'BEARER' -- RFC 6750 bearer token +); + +ALTER TABLE task_aggregator_auth_tokens ADD COLUMN type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'DAP_AUTH'; +ALTER TABLE task_collector_auth_tokens ADD COLUMN type AUTH_TOKEN_TYPE NOT NULL DEFAULT 'DAP_AUTH'; diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index e9f0c801a..1d381cd03 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -65,20 +65,28 @@ # authenticate leader-to-helper requests. In the case of a leader-role task, # the leader will include the first token in a header when making requests to # the helper. In the case of a helper-role task, the helper will accept - # requests with any of the listed authentication tokens. Each token is encoded - # in base64url, and the decoded value is sent as an HTTP header value. For - # example, this value decodes to - # "aggregator-235242f99406c4fd28b820c32eab0f68". + # requests with any of the listed authentication tokens. + # + # Each token's `type` governs how it is inserted into HTTP requests if used by + # the leader to authenticate a request to the helper. aggregator_auth_tokens: - - "YWdncmVnYXRvci0yMzUyNDJmOTk0MDZjNGZkMjhiODIwYzMyZWFiMGY2OA" + # DAP-Auth-Token values are encoded in unpadded base64url, and the decoded + # value is sent in an HTTP header. For example, this token's value decodes + # to "aggregator-235242f99406c4fd28b820c32eab0f68". + - type: "DapAuth" + token: "YWdncmVnYXRvci0yMzUyNDJmOTk0MDZjNGZkMjhiODIwYzMyZWFiMGY2OA" + # Bearer token values are encoded in base64 with padding. + - type: "Bearer" + token: "YWdncmVnYXRvci04NDc1NjkwZjJmYzQzMDBmYjE0NmJiMjk1NDIzNDk1NA==" # Authentication tokens shared between the leader and the collector, and used # to authenticate collector-to-leader requests. For leader tasks, this has the # same format as `aggregator_auth_tokens` above. For helper tasks, this will - # be an empty list instead. This example decodes to - # "collector-abf5408e2b1601831625af3959106458". + # be an empty list instead. + # This example decodes to "collector-abf5408e2b1601831625af3959106458". collector_auth_tokens: - - "Y29sbGVjdG9yLWFiZjU0MDhlMmIxNjAxODMxNjI1YWYzOTU5MTA2NDU4" + - type: "Bearer" + token: "Y29sbGVjdG9yLWFiZjU0MDhlMmIxNjAxODMxNjI1YWYzOTU5MTA2NDU4" # This aggregator's HPKE keypairs. The first keypair's HPKE configuration will # be served via the `hpke_config` DAP endpoint. All keypairs will be tried @@ -118,7 +126,8 @@ aead_id: Aes128Gcm public_key: KHRLcWgfWxli8cdOLPsgsZPttHXh0ho3vLVLrW-63lE aggregator_auth_tokens: - - "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ" + - type: "Bearer" + token: "YWdncmVnYXRvci1jZmE4NDMyZjdkMzllMjZiYjU3OGUzMzY5Mzk1MWQzNQ==" # Note that this task does not have any collector authentication tokens, since # it is a helper role task. collector_auth_tokens: [] diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index b8cf83c2d..ef2df7417 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -2,8 +2,7 @@ use backoff::{future::retry, ExponentialBackoffBuilder}; use itertools::Itertools; use janus_aggregator_core::task::{test_util::TaskBuilder, QueryType, Task}; use janus_collector::{ - test_util::collect_with_rewritten_url, Authentication, Collection, Collector, - CollectorParameters, + test_util::collect_with_rewritten_url, Collection, Collector, CollectorParameters, }; use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkePrivateKey}, @@ -126,7 +125,7 @@ pub async fn submit_measurements_and_verify_aggregate_generic<'a, V>( let collector_params = CollectorParameters::new_with_authentication( *leader_task.id(), aggregator_endpoints[Role::Leader.index().unwrap()].clone(), - Authentication::DapAuthToken(leader_task.primary_collector_auth_token().clone()), + leader_task.primary_collector_auth_token().clone(), leader_task.collector_hpke_config().clone(), collector_private_key.clone(), ) diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index 432823b4c..1f440a156 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -11,7 +11,10 @@ use janus_aggregator_core::{ task::{self, Task}, SecretBytes, }; -use janus_core::{task::AuthenticationToken, time::RealClock}; +use janus_core::{ + task::{AuthenticationToken, DapAuthToken}, + time::RealClock, +}; use janus_interop_binaries::{ status::{ERROR, SUCCESS}, AddTaskResponse, AggregatorAddTaskRequest, AggregatorRole, HpkeConfigRegistry, Keyring, @@ -38,9 +41,10 @@ async fn handle_add_task( request: AggregatorAddTaskRequest, ) -> anyhow::Result<()> { let vdaf = request.vdaf.into(); - let leader_authentication_token = - AuthenticationToken::try_from(request.leader_authentication_token.into_bytes()) - .context("invalid header value in \"leader_authentication_token\"")?; + let leader_authentication_token = AuthenticationToken::DapAuth( + DapAuthToken::try_from(request.leader_authentication_token.into_bytes()) + .context("invalid header value in \"leader_authentication_token\"")?, + ); let vdaf_verify_key = SecretBytes::new( URL_SAFE_NO_PAD .decode(request.vdaf_verify_key) @@ -59,10 +63,10 @@ async fn handle_add_task( return Err(anyhow::anyhow!("collector authentication token is missing")) } (AggregatorRole::Leader, Some(collector_authentication_token)) => { - Vec::from([AuthenticationToken::try_from( - collector_authentication_token.into_bytes(), - ) - .context("invalid header value in \"collector_authentication_token\"")?]) + Vec::from([AuthenticationToken::DapAuth( + DapAuthToken::try_from(collector_authentication_token.into_bytes()) + .context("invalid header value in \"collector_authentication_token\"")?, + )]) } (AggregatorRole::Helper, _) => Vec::new(), }; diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs index 127894717..e52aa335e 100644 --- a/interop_binaries/src/bin/janus_interop_collector.rs +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -6,7 +6,8 @@ use clap::{value_parser, Arg, Command}; use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; -use janus_collector::{Authentication, Collector, CollectorParameters}; +use janus_collector::{Collector, CollectorParameters}; +use janus_core::task::DapAuthToken; use janus_core::{ hpke::HpkeKeypair, task::{AuthenticationToken, VdafInstance}, @@ -166,9 +167,10 @@ async fn handle_add_task( let keypair = keyring.lock().await.get_random_keypair(); let hpke_config = keypair.config().clone(); - let auth_token = - AuthenticationToken::try_from(request.collector_authentication_token.into_bytes()) - .context("invalid header value in \"collector_authentication_token\"")?; + let auth_token = AuthenticationToken::DapAuth( + DapAuthToken::try_from(request.collector_authentication_token.into_bytes()) + .context("invalid header value in \"collector_authentication_token\"")?, + ); entry.or_insert(TaskState { keypair, @@ -237,7 +239,7 @@ async fn handle_collection_start( let collector_params = CollectorParameters::new_with_authentication( task_id, task_state.leader_url.clone(), - Authentication::DapAuthToken(task_state.auth_token.clone()), + task_state.auth_token.clone(), task_state.keypair.config().clone(), task_state.keypair.private_key().clone(), ) diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index 7d4e7545b..68de4dfc0 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -12,8 +12,8 @@ use derivative::Derivative; use fixed::types::extra::{U15, U31, U63}; #[cfg(feature = "fpvec_bounded_l2")] use fixed::{FixedI16, FixedI32, FixedI64}; -use janus_collector::{default_http_client, Authentication, Collector, CollectorParameters}; -use janus_core::{hpke::HpkePrivateKey, task::AuthenticationToken}; +use janus_collector::{default_http_client, AuthenticationToken, Collector, CollectorParameters}; +use janus_core::{hpke::HpkePrivateKey, task::DapAuthToken}; use janus_messages::{ query_type::{FixedSize, QueryType, TimeInterval}, BatchId, Duration, FixedSizeQuery, HpkeConfig, Interval, PartialBatchSelector, Query, TaskId, @@ -146,11 +146,11 @@ impl TypedValueParser for BatchIdValueParser { } fn parse_authentication_token(value: String) -> Result { - value.into_bytes().try_into() + DapAuthToken::try_from(value.into_bytes()).map(AuthenticationToken::DapAuth) } fn parse_authentication_token_base64(value: String) -> Result { - STANDARD.decode(value)?.try_into() + Ok(AuthenticationToken::Bearer(STANDARD.decode(value)?)) } #[derive(Clone)] @@ -447,8 +447,8 @@ where options.authentication.dap_auth_token, options.authentication.authorization_bearer_token, ) { - (None, Some(token)) => Authentication::AuthorizationBearerToken(token), - (Some(token), None) => Authentication::DapAuthToken(token), + (None, Some(token)) => token, + (Some(token), None) => token, (None, None) | (Some(_), Some(_)) => unreachable!(), }; let parameters = CollectorParameters::new_with_authentication( @@ -604,13 +604,14 @@ impl QueryTypeExt for FixedSize { #[cfg(test)] mod tests { - use crate::{run, AuthenticationOptions, Error, Options, QueryOptions, VdafType}; + use crate::{ + run, AuthenticationOptions, AuthenticationToken, DapAuthToken, Error, Options, + QueryOptions, VdafType, + }; use assert_matches::assert_matches; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{error::ErrorKind, CommandFactory, Parser}; - use janus_core::{ - hpke::test_util::generate_test_hpke_config_and_private_key, task::AuthenticationToken, - }; + use janus_core::hpke::test_util::generate_test_hpke_config_and_private_key; use janus_messages::{BatchId, TaskId}; use prio::codec::Encode; use rand::random; @@ -629,14 +630,14 @@ mod tests { let task_id = random(); let leader = Url::parse("https://example.com/dap/").unwrap(); - let auth_token = - AuthenticationToken::try_from(b"collector-authentication-token".to_vec()).unwrap(); let expected = Options { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: Some(auth_token.clone()), + dap_auth_token: Some(AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap(), + )), authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -886,16 +887,16 @@ mod tests { let hpke_keypair = generate_test_hpke_config_and_private_key(); let encoded_hpke_config = URL_SAFE_NO_PAD.encode(hpke_keypair.config().get_encoded()); let encoded_private_key = URL_SAFE_NO_PAD.encode(hpke_keypair.private_key().as_ref()); - - let auth_token = - AuthenticationToken::try_from(b"collector-authentication-token".to_vec()).unwrap(); + let auth_token = Some(AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap(), + )); // Check parsing arguments for a current batch query. let expected = Options { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: Some(auth_token.clone()), + dap_auth_token: auth_token.clone(), authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -936,7 +937,7 @@ mod tests { task_id, leader: leader.clone(), authentication: AuthenticationOptions { - dap_auth_token: Some(auth_token), + dap_auth_token: auth_token, authorization_bearer_token: None, }, hpke_config: hpke_keypair.config().clone(), @@ -1111,10 +1112,9 @@ mod tests { .unwrap() .authentication, AuthenticationOptions { - dap_auth_token: Some( - AuthenticationToken::try_from(b"collector-authentication-token".to_vec()) - .unwrap() - ), + dap_auth_token: Some(AuthenticationToken::DapAuth( + DapAuthToken::try_from(b"collector-authentication-token".to_vec()).unwrap() + )), authorization_bearer_token: None, } ); @@ -1127,9 +1127,9 @@ mod tests { .authentication, AuthenticationOptions { dap_auth_token: None, - authorization_bearer_token: Some( - AuthenticationToken::try_from(Vec::from([0xff; 16])).unwrap(), - ) + authorization_bearer_token: Some(AuthenticationToken::Bearer(Vec::from( + [0xff; 16] + )),) } );