diff --git a/aggregator/src/bin/aggregator.rs b/aggregator/src/bin/aggregator.rs index 515b36a62..c642f4335 100644 --- a/aggregator/src/bin/aggregator.rs +++ b/aggregator/src/bin/aggregator.rs @@ -17,6 +17,7 @@ use std::{iter::Iterator, net::SocketAddr, sync::Arc, time::Duration}; use tokio::join; use tracing::info; use trillium::Headers; +use trillium_router::router; use trillium_tokio::Stopper; #[tokio::main] @@ -31,62 +32,84 @@ async fn main() -> Result<()> { .response_headers() .context("failed to parse response headers")?; - let aggregator_handler = aggregator_handler( - Arc::clone(&datastore), - ctx.clock, - ctx.config.aggregator_config(), - )?; + let mut handlers = ( + aggregator_handler( + Arc::clone(&datastore), + ctx.clock, + ctx.config.aggregator_config(), + )?, + None, + ); - let (aggregator_bound_address, aggregator_server) = setup_server( - ctx.config.listen_address, - response_headers.clone(), - stopper.clone(), - aggregator_handler, - ) - .await - .context("failed to create aggregator server")?; + let aggregator_api_auth_tokens = ctx + .options + .aggregator_api_auth_tokens + .iter() + .filter(|token| !token.is_empty()) + .map(|token| { + let token_bytes = STANDARD + .decode(token) + .context("couldn't base64-decode aggregator API auth token")?; - info!(?aggregator_bound_address, "Running aggregator"); + Ok(SecretBytes::new(token_bytes)) + }) + .collect::>>()?; - let aggregator_api_server = - if let Some(aggregator_api_listen_address) = ctx.config.aggregator_api_listen_address { - let auth_tokens = ctx - .options - .aggregator_api_auth_tokens - .iter() - .filter(|token| !token.is_empty()) - .map(|token| { - let token_bytes = STANDARD - .decode(token) - .context("couldn't base64-decode aggregator API auth token")?; - - Ok(SecretBytes::new(token_bytes)) - }) - .collect::>>()?; - - let aggregator_api_handler = aggregator_api_handler( - Arc::clone(&datastore), - janus_aggregator_api::Config { auth_tokens }, - ); + let inner_aggregator_api_handler = aggregator_api_handler( + Arc::clone(&datastore), + janus_aggregator_api::Config { + auth_tokens: aggregator_api_auth_tokens, + }, + ); + + // No-op closure to unconditionally pass to tokio::join! + let mut aggregator_api_future = Box::pin(async {}) as Pin>>; + match ctx.config.aggregator_api { + Some(AggregatorApi::ListenAddress { listen_address }) => { + // Bind the requested address and spawn a future that serves the aggregator API on + // it, which we'll `tokio::join!` on below let (aggregator_api_bound_address, aggregator_api_server) = setup_server( - aggregator_api_listen_address, - response_headers, + listen_address, + response_headers.clone(), stopper.clone(), - aggregator_api_handler, + inner_aggregator_api_handler, ) .await .context("failed to create aggregator API server")?; info!(?aggregator_api_bound_address, "Running aggregator API"); - Box::pin(aggregator_api_server) as Pin>> - } else { - // No-op closure to unconditionally pass to tokio::join! - Box::pin(async {}) as Pin>> - }; + aggregator_api_future = + Box::pin(aggregator_api_server) as Pin>> + } + Some(AggregatorApi::PathPrefix { path_prefix }) => { + // Create a Trillium handler under the requested path prefix, which we'll add to the + // DAP API handler in the setup_server call below + info!( + aggregator_bound_address = ?ctx.config.listen_address, + ?path_prefix, + "Serving aggregator API relative to DAP API" + ); + // Append wildcard so that this handler will match anything under the prefix + let path_prefix = format!("{path_prefix}/*"); + handlers.1 = Some(router().all(path_prefix, inner_aggregator_api_handler)); + } + None => { /* Do nothing */ } + } + + let (aggregator_bound_address, aggregator_server) = setup_server( + ctx.config.listen_address, + response_headers, + stopper.clone(), + handlers, + ) + .await + .context("failed to create aggregator server")?; + + info!(?aggregator_bound_address, "Running aggregator"); - join!(aggregator_server, aggregator_api_server); + join!(aggregator_server, aggregator_api_future); Ok(()) }) .await @@ -128,15 +151,55 @@ pub struct HeaderEntry { value: String, } +/// Options for serving the aggregator API. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum AggregatorApi { + /// Address on which this server should listen for connections to the Janus aggregator API and + /// serve its API endpoints, independently from the address on which the DAP API is served. + ListenAddress { listen_address: SocketAddr }, + /// The Janus aggregator API will be served on the same address as the DAP API, but relative to + /// the provided prefix. e.g., if `path_prefix` is `aggregator-api`, then the DAP API's uploads + /// endpoint would be `{listen-address}/tasks/{task-id}/reports`, while task IDs could be + /// obtained from the aggregator API at `{listen-address}/aggregator-api/task_ids`. + PathPrefix { path_prefix: String }, +} + /// Non-secret configuration options for a Janus aggregator, deserialized from YAML. /// /// # Examples /// +/// Configuration serving the aggregator API on its own port, distinct from the DAP API: +/// +/// ``` +/// let yaml_config = r#" +/// --- +/// listen_address: "0.0.0.0:8080" +/// aggregator_api: +/// listen_address: "0.0.0.0:8081" +/// response_headers: +/// - name: "Example" +/// value: "header value" +/// database: +/// url: "postgres://postgres:postgres@localhost:5432/postgres" +/// logging_config: # logging_config is optional +/// force_json_output: true +/// max_upload_batch_size: 100 +/// max_upload_batch_write_delay_ms: 250 +/// batch_aggregation_shard_count: 32 +/// "#; +/// +/// let _decoded: Config = serde_yaml::from_str(yaml_config).unwrap(); +/// ``` +/// +/// Configuration serving the aggregator API relative to the DAP API: +/// /// ``` /// let yaml_config = r#" /// --- /// listen_address: "0.0.0.0:8080" -/// aggregator_api_listen_address: "0.0.0.0:8081" +/// aggregator_api: +/// path_prefix: "aggregator-api" /// response_headers: /// - name: "Example" /// value: "header value" @@ -161,9 +224,8 @@ struct Config { // TODO(#232): options for terminating TLS, unless that gets handled in a load balancer? listen_address: SocketAddr, - /// Address on which this server should listen for connections to the Janus aggregator API and - /// serve its API endpoints. If not set, the aggregator API is not served. - aggregator_api_listen_address: Option, + /// How to serve the Janus aggregator API. If not set, the aggregator API is not served. + aggregator_api: Option, /// Additional headers that will be added to all responses. #[serde(default)] @@ -219,7 +281,7 @@ impl BinaryConfig for Config { #[cfg(test)] mod tests { - use super::{Config, HeaderEntry, Options}; + use super::{AggregatorApi, Config, HeaderEntry, Options}; use clap::CommandFactory; use janus_aggregator::{ aggregator, @@ -244,11 +306,16 @@ mod tests { Options::command().debug_assert() } + #[rstest::rstest] + #[case::listen_address(AggregatorApi::ListenAddress { + listen_address: SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8081)), + })] + #[case::path_prefix(AggregatorApi::PathPrefix { path_prefix: "prefix".to_string() })] #[test] - fn roundtrip_config() { + fn roundtrip_config(#[case] aggregator_api: AggregatorApi) { roundtrip_encoding(Config { listen_address: SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080)), - aggregator_api_listen_address: Some(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8081))), + aggregator_api: Some(aggregator_api), common_config: CommonConfig { database: generate_db_config(), logging_config: generate_trace_config(), @@ -265,6 +332,74 @@ mod tests { }) } + #[test] + fn config_no_aggregator_api() { + assert_eq!( + serde_yaml::from_str::( + r#"--- + listen_address: "0.0.0.0:8080" + database: + url: "postgres://postgres:postgres@localhost:5432/postgres" + connection_pool_timeouts_secs: 60 + max_upload_batch_size: 100 + max_upload_batch_write_delay_ms: 250 + batch_aggregation_shard_count: 32 + "# + ) + .unwrap() + .aggregator_api, + None + ); + } + + #[test] + fn config_aggregator_api_listen_address() { + assert_eq!( + serde_yaml::from_str::( + r#"--- + listen_address: "0.0.0.0:8080" + database: + url: "postgres://postgres:postgres@localhost:5432/postgres" + connection_pool_timeouts_secs: 60 + max_upload_batch_size: 100 + max_upload_batch_write_delay_ms: 250 + batch_aggregation_shard_count: 32 + aggregator_api: + listen_address: "0.0.0.0:8081" + "# + ) + .unwrap() + .aggregator_api, + Some(AggregatorApi::ListenAddress { + listen_address: SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8081)) + }) + ); + } + + #[test] + fn config_aggregator_api_path_prefix() { + assert_eq!( + serde_yaml::from_str::( + r#"--- + listen_address: "0.0.0.0:8080" + database: + url: "postgres://postgres:postgres@localhost:5432/postgres" + connection_pool_timeouts_secs: 60 + max_upload_batch_size: 100 + max_upload_batch_write_delay_ms: 250 + batch_aggregation_shard_count: 32 + aggregator_api: + path_prefix: "aggregator-api" + "# + ) + .unwrap() + .aggregator_api, + Some(AggregatorApi::PathPrefix { + path_prefix: "aggregator-api".to_string() + }) + ); + } + /// Check that configuration fragments in the README and other documentation can be parsed /// correctly. #[test] diff --git a/aggregator/tests/graceful_shutdown.rs b/aggregator/tests/graceful_shutdown.rs index b6f9158fd..a6a389cd1 100644 --- a/aggregator/tests/graceful_shutdown.rs +++ b/aggregator/tests/graceful_shutdown.rs @@ -11,7 +11,7 @@ use janus_aggregator_core::{ use janus_core::{task::VdafInstance, test_util::install_test_trace_subscriber, time::RealClock}; use janus_messages::Role; use reqwest::Url; -use serde_yaml::Mapping; +use serde_yaml::{Mapping, Value}; use std::{ future::Future, io::{ErrorKind, Write}, @@ -243,10 +243,12 @@ async fn aggregator_shutdown() { "listen_address".into(), format!("{aggregator_listen_address}").into(), ); - config.insert( - "aggregator_api_listen_address".into(), + let mut aggregator_api = Mapping::new(); + aggregator_api.insert( + "listen_address".into(), format!("{aggregator_api_listen_address}").into(), ); + config.insert("aggregator_api".into(), Value::Mapping(aggregator_api)); config.insert("max_upload_batch_size".into(), 100.into()); config.insert("max_upload_batch_write_delay_ms".into(), 250.into()); config.insert("batch_aggregation_shard_count".into(), 32u64.into()); diff --git a/docs/samples/advanced_config/aggregator.yaml b/docs/samples/advanced_config/aggregator.yaml index 637766ecd..b54435bce 100644 --- a/docs/samples/advanced_config/aggregator.yaml +++ b/docs/samples/advanced_config/aggregator.yaml @@ -67,9 +67,12 @@ metrics_config: # Socket address for DAP requests. (required) listen_address: "0.0.0.0:80" -# Socket address for Janus aggregator API requests. If not set, Janus aggregator API is not served. -# (optional) -aggregator_api_listen_address: "0.0.0.0:8081" +# How to serve the Janus aggregator API. If not set, Janus aggregator API is not served. (optional) +aggregator_api: + # Socket address on which to listen for requests. + listen_address: "0.0.0.0:8081" + # Alternately, the aggregator API may be served on `listen_address`, at an arbitrary path prefix. + # path_prefix: "aggregator-api" # Maximum number of uploaded reports per batching transaction. (required) max_upload_batch_size: 100