diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index c483e9e76..827dc9219 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -8,6 +8,7 @@ use crate::{
         query_type::{CollectableQueryType, UploadableQueryType},
         report_writer::{ReportWriteBatcher, WritableReport},
     },
+    cache::GlobalHpkeKeypairCache,
     config::TaskprovConfig,
     Operation,
 };
@@ -26,7 +27,7 @@ use janus_aggregator_core::{
         self,
         models::{
             AggregateShareJob, AggregationJob, AggregationJobState, Batch, BatchAggregation,
-            BatchAggregationState, BatchState, CollectionJob, CollectionJobState, HpkeKeyState,
+            BatchAggregationState, BatchState, CollectionJob, CollectionJobState,
             LeaderStoredReport, ReportAggregation, ReportAggregationState,
         },
         Datastore, Transaction,
@@ -37,7 +38,7 @@ use janus_aggregator_core::{
 #[cfg(feature = "test-util")]
 use janus_core::test_util::dummy_vdaf;
 use janus_core::{
-    hpke::{self, HpkeApplicationInfo, Label},
+    hpke::{self, HpkeApplicationInfo, HpkeKeypair, Label},
     http::response_to_problem_details,
     task::{AuthenticationToken, VdafInstance, PRIO3_VERIFY_KEY_LENGTH},
     time::{Clock, DurationExt, IntervalExt, TimeExt},
@@ -48,7 +49,7 @@ use janus_messages::{
     AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq,
     AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobRound,
     BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, HpkeCiphertext,
-    HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
+    HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
     PrepareStep, PrepareStepResult, Report, ReportIdChecksum, ReportShare, ReportShareError, Role,
     TaskId,
 };
@@ -74,11 +75,11 @@ use std::{
     collections::{hash_map::Entry, HashMap, HashSet},
     fmt::Debug,
     panic,
-    sync::{Arc, Mutex as StdMutex},
+    sync::Arc,
     time::{Duration as StdDuration, Instant},
 };
-use tokio::{spawn, sync::Mutex, task::JoinHandle, time::sleep, try_join};
-use tracing::{debug, error, info, trace_span, warn};
+use tokio::{sync::Mutex, try_join};
+use tracing::{debug, info, trace_span, warn};
 use url::Url;
 
 pub mod accumulator;
@@ -161,8 +162,8 @@ pub struct Aggregator<C: Clock> {
     /// process.
     aggregate_step_failure_counter: Counter,
 
-    /// Cache of global HPKE configs.
-    global_hpke_configs: GlobalHpkeConfigCache,
+    /// Cache of global HPKE keypairs and configs.
+    global_hpke_keypairs: GlobalHpkeKeypairCache,
 }
 
 /// Config represents a configuration for an Aggregator.
@@ -195,89 +196,12 @@ impl Default for Config {
             max_upload_batch_size: 1,
             max_upload_batch_write_delay: StdDuration::ZERO,
             batch_aggregation_shard_count: 1,
-            global_hpke_configs_refresh_interval: GlobalHpkeConfigCache::DEFAULT_REFRESH_INTERVAL,
+            global_hpke_configs_refresh_interval: GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL,
             taskprov_config: TaskprovConfig::default(),
         }
     }
 }
 
-#[derive(Debug)]
-pub struct GlobalHpkeConfigCache {
-    /// Cache of HPKE configs. We use a [`std::sync::Mutex`] since we won't hold
-    /// locks across `.await`s and it is lighter weight than [`tokio::sync::Mutex`].
-    configs: Arc<StdMutex<Arc<Vec<HpkeConfig>>>>,
-
-    /// Handle for task responsible for periodically refreshing the cache.
-    refresh_handle: JoinHandle<()>,
-
-}
-
-impl GlobalHpkeConfigCache {
-    pub const DEFAULT_REFRESH_INTERVAL: StdDuration =
-        StdDuration::from_secs(60 * 60 /* hourly */);
-
-    async fn new<C: Clock>(
-        datastore: Arc<Datastore<C>>,
-        refresh_interval: StdDuration,
-    ) -> Result<Self, Error> {
-        // Initial cache load.
-        let configs = Arc::new(StdMutex::new(Arc::new(
-            Self::get_configs_from_datastore(&datastore).await?,
-        )));
-
-        let refresh_configs = configs.clone();
-        let refresh_handle = spawn(async move {
-            loop {
-                sleep(refresh_interval).await;
-
-                match Self::get_configs_from_datastore(&datastore).await {
-                    Ok(new_configs) => {
-                        let mut values = refresh_configs.lock().unwrap();
-                        *values = Arc::new(new_configs);
-                    }
-                    Err(err) => {
-                        error!(?err, "failed to refresh HPKE config cache");
-                    }
-                }
-            }
-        });
-
-        Ok(Self {
-            configs,
-            refresh_handle,
-        })
-    }
-
-    async fn get_configs_from_datastore<C: Clock>(
-        datastore: &Datastore<C>,
-    ) -> Result<Vec<HpkeConfig>, Error> {
-        Ok(datastore
-            .run_tx_with_name("refresh_global_hpke_configs_cache", |tx| {
-                Box::pin(async move {
-                    Ok(tx
-                        .get_global_hpke_keypairs()
-                        .await?
-                        .iter()
-                        .filter(|keypair| matches!(keypair.state(), HpkeKeyState::Active))
-                        .map(|keypair| keypair.hpke_keypair().config().clone())
-                        .collect::<Vec<_>>())
-                })
-            })
-            .await?)
-    }
-
-    // Retrieve currently cached configs.
-    fn configs(&self) -> Arc<Vec<HpkeConfig>> {
-        let configs = self.configs.lock().unwrap();
-        configs.clone()
-    }
-}
-
-impl Drop for GlobalHpkeConfigCache {
-    fn drop(&mut self) {
-        self.refresh_handle.abort()
-    }
-}
-
 impl<C: Clock> Aggregator<C> {
     async fn new(
         datastore: Arc<Datastore<C>>,
@@ -306,9 +230,11 @@ impl<C: Clock> Aggregator<C> {
         let aggregate_step_failure_counter = aggregate_step_failure_counter(meter);
         aggregate_step_failure_counter.add(&Context::current(), 0, &[]);
 
-        let global_hpke_configs =
-            GlobalHpkeConfigCache::new(datastore.clone(), cfg.global_hpke_configs_refresh_interval)
-                .await?;
+        let global_hpke_keypairs = GlobalHpkeKeypairCache::new(
+            datastore.clone(),
+            cfg.global_hpke_configs_refresh_interval,
+        )
+        .await?;
 
         Ok(Self {
             datastore,
@@ -319,7 +245,7 @@ impl<C: Clock> Aggregator<C> {
             upload_decrypt_failure_counter,
             upload_decode_failure_counter,
             aggregate_step_failure_counter,
-            global_hpke_configs,
+            global_hpke_keypairs,
         })
     }
 
@@ -338,7 +264,7 @@ impl<C: Clock> Aggregator<C> {
                 Ok(task_aggregator.handle_hpke_config())
             }
             None => {
-                let configs = self.global_hpke_configs.configs();
+                let configs = self.global_hpke_keypairs.configs();
                 if configs.is_empty() {
                     if self.cfg.taskprov_config.enabled {
                         // A global HPKE configuration is only _required_ when taskprov
@@ -366,6 +292,7 @@ impl<C: Clock> Aggregator<C> {
         task_aggregator
             .handle_upload(
                 &self.clock,
+                &self.global_hpke_keypairs,
                 &self.upload_decrypt_failure_counter,
                 &self.upload_decode_failure_counter,
                 report,
@@ -394,6 +321,7 @@ impl<C: Clock> Aggregator<C> {
         task_aggregator
             .handle_aggregate_init(
                 &self.datastore,
+                &self.global_hpke_keypairs,
                 &self.aggregate_step_failure_counter,
                 aggregation_job_id,
                 req_bytes,
@@ -692,6 +620,7 @@ impl<C: Clock> TaskAggregator<C> {
     async fn handle_upload(
         &self,
         clock: &C,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         upload_decrypt_failure_counter: &Counter,
         upload_decode_failure_counter: &Counter,
         report: Report,
@@ -699,6 +628,7 @@ impl<C: Clock> TaskAggregator<C> {
     ) -> Result<(), Arc<Error>> {
         self.vdaf_ops
             .handle_upload(
                 clock,
+                global_hpke_keypairs,
                 upload_decrypt_failure_counter,
                 upload_decode_failure_counter,
                 &self.task,
@@ -711,6 +641,7 @@ impl<C: Clock> TaskAggregator<C> {
     async fn handle_aggregate_init(
         &self,
         datastore: &Datastore<C>,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter,
         aggregation_job_id: &AggregationJobId,
         req_bytes: &[u8],
@@ -718,6 +649,7 @@ impl<C: Clock> TaskAggregator<C> {
     ) -> Result<AggregationJobResp, Error> {
         self.vdaf_ops
             .handle_aggregate_init(
                 datastore,
+                global_hpke_keypairs,
                 aggregate_step_failure_counter,
                 Arc::clone(&self.task),
                 aggregation_job_id,
@@ -944,6 +876,7 @@ impl VdafOps {
     async fn handle_upload<C: Clock>(
         &self,
         clock: &C,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         upload_decrypt_failure_counter: &Counter,
         upload_decode_failure_counter: &Counter,
         task: &Task,
@@ -956,6 +889,7 @@ impl VdafOps {
                     Self::handle_upload_generic::<VERIFY_KEY_LENGTH, TimeInterval, VdafType, _>(
                         Arc::clone(vdaf),
                         clock,
+                        global_hpke_keypairs,
                         upload_decrypt_failure_counter,
                         upload_decode_failure_counter,
                         task,
@@ -970,6 +904,7 @@ impl VdafOps {
                     Self::handle_upload_generic::<VERIFY_KEY_LENGTH, FixedSize, VdafType, _>(
                         Arc::clone(vdaf),
                         clock,
+                        global_hpke_keypairs,
                         upload_decrypt_failure_counter,
                         upload_decode_failure_counter,
                         task,
@@ -992,6 +927,7 @@ impl VdafOps {
     async fn handle_aggregate_init<C: Clock>(
         &self,
         datastore: &Datastore<C>,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter,
         task: Arc<Task>,
         aggregation_job_id: &AggregationJobId,
@@ -1002,6 +938,7 @@ impl VdafOps {
             vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => {
                 Self::handle_aggregate_init_generic::<VERIFY_KEY_LENGTH, TimeInterval, VdafType, _>(
                     datastore,
+                    global_hpke_keypairs,
                     vdaf,
                     aggregate_step_failure_counter,
                     task,
@@ -1016,6 +953,7 @@ impl VdafOps {
             vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => {
                 Self::handle_aggregate_init_generic::<VERIFY_KEY_LENGTH, FixedSize, VdafType, _>(
                     datastore,
+                    global_hpke_keypairs,
                     vdaf,
                     aggregate_step_failure_counter,
                     task,
@@ -1081,6 +1019,7 @@ impl VdafOps {
     async fn handle_upload_generic<const SEED_SIZE: usize, Q, A, C>(
         vdaf: Arc<A>,
         clock: &C,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         upload_decrypt_failure_counter: &Counter,
         upload_decode_failure_counter: &Counter,
         task: &Task,
@@ -1107,15 +1046,6 @@ impl VdafOps {
         let leader_encrypted_input_share =
             &report.encrypted_input_shares()[Role::Leader.index().unwrap()];
 
-        // Verify that the report's HPKE config ID is known.
-        // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2
-        let hpke_keypair = task
-            .hpke_keys()
-            .get(leader_encrypted_input_share.config_id())
-            .ok_or_else(|| {
-                Error::OutdatedHpkeConfig(*task.id(), *leader_encrypted_input_share.config_id())
-            })?;
-
         let report_deadline = clock
             .now()
             .add(task.tolerable_clock_skew())
@@ -1179,19 +1109,51 @@ impl VdafOps {
             }
         };
 
-        let encoded_leader_plaintext_input_share = match hpke::open(
-            hpke_keypair.config(),
-            hpke_keypair.private_key(),
-            &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()),
-            leader_encrypted_input_share,
-            &InputShareAad::new(
-                *task.id(),
-                report.metadata().clone(),
-                report.public_share().to_vec(),
+        let try_hpke_open = |hpke_keypair: &HpkeKeypair| {
+            hpke::open(
+                hpke_keypair.config(),
+                hpke_keypair.private_key(),
+                &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()),
+                leader_encrypted_input_share,
+                &InputShareAad::new(
+                    *task.id(),
+                    report.metadata().clone(),
+                    report.public_share().to_vec(),
+                )
+                .get_encoded(),
             )
-            .get_encoded(),
-        ) {
-            Ok(encoded_leader_plaintext_input_share) => encoded_leader_plaintext_input_share,
+        };
+
+        let global_hpke_keypair =
+            global_hpke_keypairs.keypair(leader_encrypted_input_share.config_id());
+
+        let task_hpke_keypair = task
+            .hpke_keys()
+            .get(leader_encrypted_input_share.config_id());
+
+        let decryption_result = match (task_hpke_keypair, global_hpke_keypair) {
+            // Verify that the report's HPKE config ID is known.
+            // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.3.2
+            (None, None) => {
+                return Err(Arc::new(Error::OutdatedHpkeConfig(
+                    *task.id(),
+                    *leader_encrypted_input_share.config_id(),
+                )));
+            }
+            (None, Some(global_hpke_keypair)) => try_hpke_open(&global_hpke_keypair),
+            (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair),
+            (Some(task_hpke_keypair), Some(global_hpke_keypair)) => {
+                try_hpke_open(task_hpke_keypair).or_else(|error| match error {
+                    // Only attempt the second trial if _decryption_ fails, not if there's
+                    // some error in server-side HPKE configuration.
+                    hpke::Error::Hpke(_) => try_hpke_open(&global_hpke_keypair),
+                    error => Err(error),
+                })
+            }
+        };
+
+        let encoded_leader_plaintext_input_share = match decryption_result {
+            Ok(plaintext) => plaintext,
             Err(error) => {
                 info!(
                     report.task_id = %task.id(),
@@ -1345,6 +1307,7 @@ impl VdafOps {
     /// helper, described in §4.4.4.1 of draft-gpew-priv-ppm.
     async fn handle_aggregate_init_generic<const SEED_SIZE: usize, Q, A, C>(
         datastore: &Datastore<C>,
+        global_hpke_keypairs: &GlobalHpkeKeypairCache,
         vdaf: &A,
         aggregate_step_failure_counter: &Counter,
         task: Arc<Task>,
@@ -1400,24 +1363,15 @@ impl VdafOps {
                 }
             }
 
-            let hpke_keypair = task
+            let task_hpke_keypair = task
                 .hpke_keys()
-                .get(report_share.encrypted_input_share().config_id())
-                .ok_or_else(|| {
-                    info!(
-                        config_id = %report_share.encrypted_input_share().config_id(),
-                        "Helper encrypted input share references unknown HPKE config ID"
-                    );
-                    aggregate_step_failure_counter.add(
-                        &Context::current(),
-                        1,
-                        &[KeyValue::new("type", "unknown_hpke_config_id")],
-                    );
-                    ReportShareError::HpkeUnknownConfigId
-                });
+                .get(report_share.encrypted_input_share().config_id());
+
+            let global_hpke_keypair =
+                global_hpke_keypairs.keypair(report_share.encrypted_input_share().config_id());
 
             // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. (§4.4.2.2)
-            let plaintext = hpke_keypair.and_then(|hpke_keypair| {
+            let try_hpke_open = |hpke_keypair: &HpkeKeypair| {
                 hpke::open(
                     hpke_keypair.config(),
                     hpke_keypair.private_key(),
@@ -1430,6 +1384,37 @@ impl VdafOps {
                     )
                     .get_encoded(),
                 )
+            };
+
+            let check_keypairs = if task_hpke_keypair.is_none() && global_hpke_keypair.is_none() {
+                info!(
+                    config_id = %report_share.encrypted_input_share().config_id(),
+                    "Helper encrypted input share references unknown HPKE config ID"
+                );
+                aggregate_step_failure_counter.add(
+                    &Context::current(),
+                    1,
+                    &[KeyValue::new("type", "unknown_hpke_config_id")],
+                );
+                Err(ReportShareError::HpkeUnknownConfigId)
+            } else {
+                Ok(())
+            };
+
+            let plaintext = check_keypairs.and_then(|_| {
+                match (task_hpke_keypair, global_hpke_keypair) {
+                    (None, None) => unreachable!("already checked this condition"),
+                    (None, Some(global_hpke_keypair)) => try_hpke_open(&global_hpke_keypair),
+                    (Some(task_hpke_keypair), None) => try_hpke_open(task_hpke_keypair),
+                    (Some(task_hpke_keypair), Some(global_hpke_keypair)) => {
+                        try_hpke_open(task_hpke_keypair).or_else(|error| match error {
+                            // Only attempt the second trial if _decryption_ fails, not if there's
+                            // some error in server-side HPKE configuration.
+                            hpke::Error::Hpke(_) => try_hpke_open(&global_hpke_keypair),
+                            error => Err(error),
+                        })
+                    }
+                }
                 .map_err(|error| {
                     info!(
                         task_id = %task.id(),
@@ -2819,7 +2804,10 @@ mod tests {
         test_util::noop_meter,
     };
     use janus_core::{
-        hpke::{self, HpkeApplicationInfo, Label},
+        hpke::{
+            self, test_util::generate_test_hpke_config_and_private_key_with_id,
+            HpkeApplicationInfo, HpkeKeypair, Label,
+        },
         task::{VdafInstance, PRIO3_VERIFY_KEY_LENGTH},
         test_util::install_test_trace_subscriber,
         time::{Clock, MockClock, TimeExt},
@@ -2835,6 +2823,7 @@ mod tests {
     };
     use rand::random;
     use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration};
+    use tokio::time::sleep;
 
     pub(crate) const BATCH_AGGREGATION_SHARD_COUNT: u64 = 32;
 
@@ -2853,11 +2842,11 @@ mod tests {
         task: &Task,
         report_timestamp: Time,
         id: ReportId,
+        hpke_key: &HpkeKeypair,
     ) -> Report {
         assert_eq!(task.vdaf(), &VdafInstance::Prio3Count);
         let vdaf = Prio3Count::new_count(2).unwrap();
 
-        let hpke_key = task.current_hpke_key();
         let report_metadata = ReportMetadata::new(id, report_timestamp);
         let (public_share, measurements) = vdaf.shard(&1, id.as_ref()).unwrap();
@@ -2891,7 +2880,7 @@ mod tests {
     }
 
     pub(super) fn create_report(task: &Task, report_timestamp: Time) -> Report {
-        create_report_with_id(task, report_timestamp, random())
+        create_report_with_id(task, report_timestamp, random(), task.current_hpke_key())
     }
 
     async fn setup_upload_test(
@@ -2969,7 +2958,12 @@ mod tests {
             .unwrap();
 
         // Reports may not be mutated
-        let mutated_report = create_report_with_id(&task, clock.now(), *report.metadata().id());
+        let mutated_report = create_report_with_id(
+            &task,
+            clock.now(),
+            *report.metadata().id(),
+            task.current_hpke_key(),
+        );
         let error = aggregator
             .handle_upload(task.id(), &mutated_report.get_encoded())
             .await
@@ -3182,6 +3176,80 @@ mod tests {
         });
     }
 
+    #[tokio::test]
+    async fn upload_report_encrypted_with_global_key() {
+        install_test_trace_subscriber();
+
+        let (vdaf, aggregator, clock, task, datastore, _ephemeral_datastore) =
+            setup_upload_test(Config {
+                max_upload_batch_size: 1000,
+                max_upload_batch_write_delay: StdDuration::from_millis(500),
+                global_hpke_configs_refresh_interval: StdDuration::from_millis(500),
+                ..Default::default()
+            })
+            .await;
+
+        // Same ID as the task to test having both keys to choose from.
+        let global_hpke_keypair_same_id = generate_test_hpke_config_and_private_key_with_id(
+            (*task.current_hpke_key().config().id()).into(),
+        );
+        // Different ID to test misses on the task key.
+        let global_hpke_keypair_different_id = generate_test_hpke_config_and_private_key_with_id(
+            (0..)
+                .map(HpkeConfigId::from)
+                .find(|id| !task.hpke_keys().contains_key(id))
+                .unwrap()
+                .into(),
+        );
+
+        datastore
+            .run_tx(|tx| {
+                let global_hpke_keypair_same_id = global_hpke_keypair_same_id.clone();
+                let global_hpke_keypair_different_id = global_hpke_keypair_different_id.clone();
+                Box::pin(async move {
+                    // Leave these in the PENDING state--they should still be decryptable.
+                    tx.put_global_hpke_keypair(&global_hpke_keypair_same_id)
+                        .await?;
+                    tx.put_global_hpke_keypair(&global_hpke_keypair_different_id)
+                        .await?;
+                    Ok(())
+                })
+            })
+            .await
+            .unwrap();
+
+        // Let keypair cache refresh.
+        sleep(StdDuration::from_millis(750)).await;
+
+        for report in [
+            create_report(&task, clock.now()),
+            create_report_with_id(&task, clock.now(), random(), &global_hpke_keypair_same_id),
+            create_report_with_id(
+                &task,
+                clock.now(),
+                random(),
+                &global_hpke_keypair_different_id,
+            ),
+        ] {
+            aggregator
+                .handle_upload(task.id(), &report.get_encoded())
+                .await
+                .unwrap();
+
+            let got_report = datastore
+                .run_tx(|tx| {
+                    let (vdaf, task_id, report_id) =
+                        (vdaf.clone(), *task.id(), *report.metadata().id());
+                    Box::pin(async move { tx.get_client_report(&vdaf, &task_id, &report_id).await })
+                })
+                .await
+                .unwrap()
+                .unwrap();
+            assert_eq!(task.id(), got_report.task_id());
+            assert_eq!(report.metadata(), got_report.metadata());
+        }
+    }
+
     pub(crate) fn generate_helper_report_share<V: vdaf::Client<16>>(
         task_id: TaskId,
         report_metadata: ReportMetadata,
diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs
index 20a1f997d..c8abf8e1a 100644
--- a/aggregator/src/aggregator/http_handlers.rs
+++ b/aggregator/src/aggregator/http_handlers.rs
@@ -976,7 +976,12 @@ mod tests {
         let accepted_report_id = report.metadata().id();
 
         // Verify that new reports using an existing report ID are rejected with reportRejected
-        let duplicate_id_report = create_report_with_id(&task, clock.now(), *accepted_report_id);
+        let duplicate_id_report = create_report_with_id(
+            &task,
+            clock.now(),
+            *accepted_report_id,
+            task.current_hpke_key(),
+        );
         let mut test_conn = put(task.report_upload_uri().unwrap().path())
             .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE)
             .with_request_body(duplicate_id_report.get_encoded())
@@ -1892,6 +1897,250 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    #[allow(clippy::unit_arg)]
+    async fn aggregate_init_with_reports_encrypted_by_global_key() {
+        install_test_trace_subscriber();
+
+        let task =
+            TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Helper).build();
+        let clock = MockClock::default();
+        let ephemeral_datastore = ephemeral_datastore().await;
+        let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
+
+        // Insert some global HPKE keys.
+        // Same ID as the task to test having both keys to choose from.
+        let global_hpke_keypair_same_id = generate_test_hpke_config_and_private_key_with_id(
+            (*task.current_hpke_key().config().id()).into(),
+        );
+        // Different ID to test misses on the task key.
+        let global_hpke_keypair_different_id = generate_test_hpke_config_and_private_key_with_id(
+            (0..)
+                .map(HpkeConfigId::from)
+                .find(|id| !task.hpke_keys().contains_key(id))
+                .unwrap()
+                .into(),
+        );
+        datastore
+            .run_tx(|tx| {
+                let global_hpke_keypair_same_id = global_hpke_keypair_same_id.clone();
+                let global_hpke_keypair_different_id = global_hpke_keypair_different_id.clone();
+                Box::pin(async move {
+                    // Leave these in the PENDING state--they should still be decryptable.
+                    tx.put_global_hpke_keypair(&global_hpke_keypair_same_id)
+                        .await?;
+                    tx.put_global_hpke_keypair(&global_hpke_keypair_different_id)
+                        .await?;
+                    Ok(())
+                })
+            })
+            .await
+            .unwrap();
+
+        datastore.put_task(&task).await.unwrap();
+
+        let vdaf = dummy_vdaf::Vdaf::new();
+        let verify_key: VerifyKey<0> = task.primary_vdaf_verify_key().unwrap();
+
+        // This report was encrypted with a global HPKE config that has the same config
+        // ID as the task's HPKE config.
+        let report_metadata_same_id = ReportMetadata::new(
+            random(),
+            clock
+                .now()
+                .to_batch_interval_start(task.time_precision())
+                .unwrap(),
+        );
+        let transcript = run_vdaf(
+            &vdaf,
+            verify_key.as_bytes(),
+            &dummy_vdaf::AggregationParam(0),
+            report_metadata_same_id.id(),
+            &(),
+        );
+        let report_share_same_id = generate_helper_report_share::<dummy_vdaf::Vdaf>(
+            *task.id(),
+            report_metadata_same_id,
+            global_hpke_keypair_same_id.config(),
+            &transcript.public_share,
+            Vec::new(),
+            &transcript.input_shares[1],
+        );
+
+        // This report was encrypted with a global HPKE config that has the same config
+        // ID as the task's HPKE config, but will fail to decrypt.
+        let report_metadata_same_id_corrupted = ReportMetadata::new(
+            random(),
+            clock
+                .now()
+                .to_batch_interval_start(task.time_precision())
+                .unwrap(),
+        );
+        let transcript = run_vdaf(
+            &vdaf,
+            verify_key.as_bytes(),
+            &dummy_vdaf::AggregationParam(0),
+            report_metadata_same_id_corrupted.id(),
+            &(),
+        );
+        let report_share_same_id_corrupted = generate_helper_report_share::<dummy_vdaf::Vdaf>(
+            *task.id(),
+            report_metadata_same_id_corrupted.clone(),
+            global_hpke_keypair_same_id.config(),
+            &transcript.public_share,
+            Vec::new(),
+            &transcript.input_shares[1],
+        );
+        let encrypted_input_share = report_share_same_id_corrupted.encrypted_input_share();
+        let mut corrupted_payload = encrypted_input_share.payload().to_vec();
+        corrupted_payload[0] ^= 0xFF;
+        let corrupted_input_share = HpkeCiphertext::new(
+            *encrypted_input_share.config_id(),
+            encrypted_input_share.encapsulated_key().to_vec(),
+            corrupted_payload,
+        );
+        let encoded_public_share = transcript.public_share.get_encoded();
+        let report_share_same_id_corrupted = ReportShare::new(
+            report_metadata_same_id_corrupted,
+            encoded_public_share.clone(),
+            corrupted_input_share,
+        );
+
+        // This report was encrypted with a global HPKE config that doesn't collide
+        // with the task HPKE config's ID.
+        let report_metadata_different_id = ReportMetadata::new(
+            random(),
+            clock
+                .now()
+                .to_batch_interval_start(task.time_precision())
+                .unwrap(),
+        );
+        let transcript = run_vdaf(
+            &vdaf,
+            verify_key.as_bytes(),
+            &dummy_vdaf::AggregationParam(0),
+            report_metadata_different_id.id(),
+            &(),
+        );
+        let report_share_different_id = generate_helper_report_share::<dummy_vdaf::Vdaf>(
+            *task.id(),
+            report_metadata_different_id,
+            global_hpke_keypair_different_id.config(),
+            &transcript.public_share,
+            Vec::new(),
+            &transcript.input_shares[1],
+        );
+
+        // This report was encrypted with a global HPKE config that doesn't collide
+        // with the task HPKE config's ID, but will fail decryption.
+        let report_metadata_different_id_corrupted = ReportMetadata::new(
+            random(),
+            clock
+                .now()
+                .to_batch_interval_start(task.time_precision())
+                .unwrap(),
+        );
+        let transcript = run_vdaf(
+            &vdaf,
+            verify_key.as_bytes(),
+            &dummy_vdaf::AggregationParam(0),
+            report_metadata_different_id_corrupted.id(),
+            &(),
+        );
+        let report_share_different_id_corrupted = generate_helper_report_share::<dummy_vdaf::Vdaf>(
+            *task.id(),
+            report_metadata_different_id_corrupted.clone(),
+            global_hpke_keypair_different_id.config(),
+            &transcript.public_share,
+            Vec::new(),
+            &transcript.input_shares[1],
+        );
+        let encrypted_input_share = report_share_different_id_corrupted.encrypted_input_share();
+        let mut corrupted_payload = encrypted_input_share.payload().to_vec();
+        corrupted_payload[0] ^= 0xFF;
+        let corrupted_input_share = HpkeCiphertext::new(
+            *encrypted_input_share.config_id(),
+            encrypted_input_share.encapsulated_key().to_vec(),
+            corrupted_payload,
+        );
+        let encoded_public_share = transcript.public_share.get_encoded();
+        let report_share_different_id_corrupted = ReportShare::new(
+            report_metadata_different_id_corrupted,
+            encoded_public_share.clone(),
+            corrupted_input_share,
+        );
+
+        let handler = aggregator_handler(
+            Arc::clone(&datastore),
+            clock,
+            &noop_meter(),
+            default_aggregator_config(),
+        )
+        .await
+        .unwrap();
+        let aggregation_job_id: AggregationJobId = random();
+
+        let request = AggregationJobInitializeReq::new(
+            dummy_vdaf::AggregationParam(0).get_encoded(),
+            PartialBatchSelector::new_time_interval(),
+            Vec::from([
+                report_share_same_id.clone(),
+                report_share_different_id.clone(),
+                report_share_same_id_corrupted.clone(),
+                report_share_different_id_corrupted.clone(),
+            ]),
+        );
+
+        let mut test_conn =
+            put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await;
+        assert_eq!(test_conn.status(), Some(Status::Ok));
+        let body_bytes = take_response_body(&mut test_conn).await;
+        let aggregate_resp = AggregationJobResp::get_decoded(&body_bytes).unwrap();
+
+        // Validate response.
+        assert_eq!(aggregate_resp.prepare_steps().len(), 4);
+
+        let prepare_step_same_id = aggregate_resp.prepare_steps().get(0).unwrap();
+        assert_eq!(
+            prepare_step_same_id.report_id(),
+            report_share_same_id.metadata().id()
+        );
+        assert_matches!(
+            prepare_step_same_id.result(),
+            &PrepareStepResult::Continued(..)
+        );
+
+        let prepare_step_different_id = aggregate_resp.prepare_steps().get(1).unwrap();
+        assert_eq!(
+            prepare_step_different_id.report_id(),
+            report_share_different_id.metadata().id()
+        );
+        assert_matches!(
+            prepare_step_different_id.result(),
+            &PrepareStepResult::Continued(..)
+ ); + + let prepare_step_same_id_corrupted = aggregate_resp.prepare_steps().get(2).unwrap(); + assert_eq!( + prepare_step_same_id_corrupted.report_id(), + report_share_same_id_corrupted.metadata().id() + ); + assert_matches!( + prepare_step_same_id_corrupted.result(), + &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + ); + + let prepare_step_different_id_corrupted = aggregate_resp.prepare_steps().get(3).unwrap(); + assert_eq!( + prepare_step_different_id_corrupted.report_id(), + report_share_different_id_corrupted.metadata().id() + ); + assert_matches!( + prepare_step_different_id_corrupted.result(), + &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) + ); + } + #[allow(clippy::unit_arg)] #[tokio::test] async fn aggregate_init_change_report_timestamp() { diff --git a/aggregator/src/bin/aggregator.rs b/aggregator/src/bin/aggregator.rs index 25b3490fe..56f340579 100644 --- a/aggregator/src/bin/aggregator.rs +++ b/aggregator/src/bin/aggregator.rs @@ -2,13 +2,11 @@ use anyhow::{Context, Result}; use base64::{engine::general_purpose::STANDARD, Engine}; use clap::Parser; use janus_aggregator::{ - aggregator::{ - self, garbage_collector::GarbageCollector, http_handlers::aggregator_handler, - GlobalHpkeConfigCache, - }, + aggregator::{self, garbage_collector::GarbageCollector, http_handlers::aggregator_handler}, binary_utils::{ janus_main, setup_server, setup_signal_handler, BinaryOptions, CommonBinaryOptions, }, + cache::GlobalHpkeKeypairCache, config::{BinaryConfig, CommonConfig, TaskprovConfig}, }; use janus_aggregator_api::{self, aggregator_api_handler}; @@ -284,7 +282,7 @@ struct Config { /// Defines how often to refresh the global HPKE configs cache in milliseconds. This affects /// how often an aggregator becomes aware of key state changes. If unspecified, default is - /// defined by [`GlobalHpkeConfigCache::DEFAULT_REFRESH_INTERVAL`]. You shouldn't normally + /// defined by [`GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL`]. You shouldn't normally /// have to specify this. #[serde(default)] global_hpke_configs_refresh_interval: Option, @@ -331,7 +329,7 @@ impl Config { taskprov_config: self.taskprov_config.clone(), global_hpke_configs_refresh_interval: match self.global_hpke_configs_refresh_interval { Some(duration) => Duration::from_millis(duration), - None => GlobalHpkeConfigCache::DEFAULT_REFRESH_INTERVAL, + None => GlobalHpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, }, } } diff --git a/aggregator/src/cache.rs b/aggregator/src/cache.rs new file mode 100644 index 000000000..855e8d96b --- /dev/null +++ b/aggregator/src/cache.rs @@ -0,0 +1,134 @@ +//! Various in-memory caches that can be used by an aggregator. + +use crate::aggregator::Error; +use janus_aggregator_core::datastore::{ + models::{GlobalHpkeKeypair, HpkeKeyState}, + Datastore, +}; +use janus_core::{hpke::HpkeKeypair, time::Clock}; +use janus_messages::{HpkeConfig, HpkeConfigId}; +use std::{ + collections::HashMap, + fmt::Debug, + sync::{Arc, Mutex as StdMutex}, + time::Duration as StdDuration, +}; +use tokio::{spawn, task::JoinHandle, time::sleep}; +use tracing::error; + +type HpkeConfigs = Arc>; +type HpkeKeypairs = HashMap>; + +#[derive(Debug)] +pub struct GlobalHpkeKeypairCache { + // We use a std::sync::Mutex in this cache because we won't hold locks across + // `.await` boundaries. StdMutex is lighter weight than `tokio::sync::Mutex`. + /// Cache of HPKE configs for advertisement. + configs: Arc>, + + /// Cache of HPKE keypairs for report decryption. 
+    keypairs: Arc<StdMutex<HpkeKeypairs>>,
+
+    /// Handle for task responsible for periodically refreshing the cache.
+    refresh_handle: JoinHandle<()>,
+}
+
+impl GlobalHpkeKeypairCache {
+    pub const DEFAULT_REFRESH_INTERVAL: StdDuration =
+        StdDuration::from_secs(60 * 30 /* 30 minutes */);
+
+    pub async fn new<C: Clock>(
+        datastore: Arc<Datastore<C>>,
+        refresh_interval: StdDuration,
+    ) -> Result<Self, Error> {
+        // Initial cache load.
+        let global_keypairs = Self::get_global_keypairs(&datastore).await?;
+        let configs = Arc::new(StdMutex::new(Self::filter_active_configs(&global_keypairs)));
+        let keypairs = Arc::new(StdMutex::new(Self::map_keypairs(&global_keypairs)));
+
+        // Start refresh task.
+        let refresh_configs = configs.clone();
+        let refresh_keypairs = keypairs.clone();
+        let refresh_handle = spawn(async move {
+            loop {
+                sleep(refresh_interval).await;
+
+                match Self::get_global_keypairs(&datastore).await {
+                    Ok(global_keypairs) => {
+                        let new_configs = Self::filter_active_configs(&global_keypairs);
+                        let new_keypairs = Self::map_keypairs(&global_keypairs);
+                        {
+                            let mut configs = refresh_configs.lock().unwrap();
+                            *configs = new_configs;
+                        }
+                        {
+                            let mut keypairs = refresh_keypairs.lock().unwrap();
+                            *keypairs = new_keypairs;
+                        }
+                    }
+                    Err(err) => {
+                        error!(?err, "failed to refresh HPKE config cache");
+                    }
+                }
+            }
+        });
+
+        Ok(Self {
+            configs,
+            keypairs,
+            refresh_handle,
+        })
+    }
+
+    fn filter_active_configs(global_keypairs: &[GlobalHpkeKeypair]) -> HpkeConfigs {
+        Arc::new(
+            global_keypairs
+                .iter()
+                .filter_map(|keypair| match keypair.state() {
+                    HpkeKeyState::Active => Some(keypair.hpke_keypair().config().clone()),
+                    _ => None,
+                })
+                .collect(),
+        )
+    }
+
+    fn map_keypairs(global_keypairs: &[GlobalHpkeKeypair]) -> HpkeKeypairs {
+        global_keypairs
+            .iter()
+            .map(|keypair| {
+                let keypair = keypair.hpke_keypair().clone();
+                (*keypair.config().id(), Arc::new(keypair))
+            })
+            .collect()
+    }
+
+    async fn get_global_keypairs<C: Clock>(
+        datastore: &Datastore<C>,
+    ) -> Result<Vec<GlobalHpkeKeypair>, Error> {
+        Ok(datastore
+            .run_tx_with_name("refresh_global_hpke_configs_cache", |tx| {
+                Box::pin(async move { tx.get_global_hpke_keypairs().await })
+            })
+            .await?)
+    }
+
+    /// Retrieve active configs for config advertisement. This only returns configs
+    /// for keypairs that are in the [`HpkeKeyState::Active`] state.
+    pub fn configs(&self) -> HpkeConfigs {
+        let configs = self.configs.lock().unwrap();
+        configs.clone()
+    }
+
+    /// Retrieve a keypair by ID for report decryption. This retrieves keypairs that
+    /// are in any state.
+    pub fn keypair(&self, id: &HpkeConfigId) -> Option<Arc<HpkeKeypair>> {
+        let keypairs = self.keypairs.lock().unwrap();
+        keypairs.get(id).cloned()
+    }
+}
+
+impl Drop for GlobalHpkeKeypairCache {
+    fn drop(&mut self) {
+        self.refresh_handle.abort()
+    }
+}
diff --git a/aggregator/src/lib.rs b/aggregator/src/lib.rs
index cef20d96d..652ff2dc4 100644
--- a/aggregator/src/lib.rs
+++ b/aggregator/src/lib.rs
@@ -3,6 +3,7 @@
 pub mod aggregator;
 pub mod binary_utils;
+pub mod cache;
 pub mod config;
 pub mod metrics;
 pub mod trace;
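Both decryption paths in this change share one resolution order: try the task-specific keypair first, fall back to the cached global keypair, and retry only when the first attempt failed to decrypt. A minimal standalone sketch of that order (hypothetical `open_with_fallback` helper with simplified signatures, not the exact janus-internal API):

```rust
use janus_core::hpke::{self, HpkeKeypair};

/// Hypothetical helper mirroring the task-key-first, global-key-fallback
/// order used by `handle_upload_generic` and `handle_aggregate_init_generic`.
/// Returns `None` when the config ID matches no known keypair.
fn open_with_fallback(
    task_keypair: Option<&HpkeKeypair>,
    global_keypair: Option<&HpkeKeypair>,
    try_open: impl Fn(&HpkeKeypair) -> Result<Vec<u8>, hpke::Error>,
) -> Option<Result<Vec<u8>, hpke::Error>> {
    match (task_keypair, global_keypair) {
        // Unknown HPKE config ID: the caller maps this case to its own error
        // (outdatedConfig on upload, HpkeUnknownConfigId on aggregate-init).
        (None, None) => None,
        (None, Some(global)) => Some(try_open(global)),
        (Some(task), None) => Some(try_open(task)),
        (Some(task), Some(global)) => Some(try_open(task).or_else(|error| match error {
            // Fall back to the global keypair only on a decryption failure,
            // not on a server-side HPKE configuration error.
            hpke::Error::Hpke(_) => try_open(global),
            error => Err(error),
        })),
    }
}
```

Keeping the unknown-ID case as `None` preserves the asymmetry visible in the diff: the upload path returns `Error::OutdatedHpkeConfig`, while the aggregate-init path records the `unknown_hpke_config_id` metric and fails the report share with `ReportShareError::HpkeUnknownConfigId`.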