Skip to content

Commit

Permalink
Add Brave-P3A-Revision header support, ignore requests with older c…
Browse files Browse the repository at this point in the history
…lient revision
  • Loading branch information
DJAndries committed Dec 1, 2023
1 parent 889d4d9 commit 2b62b0a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ The format for an individual data channel setting is `<data channel name>=<value
| EPOCH_LIFETIMES | `typical=3` | No | The amount of current & recent previous epochs considered to be 'active'. Epochs older than this lifetime will be consider 'expired', and all partial measurements will be reported at the end of aggregation, if any. |
| EPOCH_DATE_FIELD_NAMES | `typical=wos` | No | The name of the date fields to inject into the aggregated measurements. The injected field will include the survey date, inferred via the measurement epoch. |
| RANDOMNESS_INSTANCE_NAMES | `typical=typical` | No | Randomness server instance names, for retrieving relevant server info. |
| MIN_CHANNEL_REVISIONS | | No | The minimum `Brave-P3A-Version` header value for measurements submitted to the server. |

The main channel name can be selected by using the `--main-channel-name` switch. Using this switch will have the following effects:

Expand Down
2 changes: 1 addition & 1 deletion src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub fn get_data_channel_map_from_env(env_key: &str, default: &str) -> HashMap<St
let env_encoded = env::var(env_key).unwrap_or_else(|_| default.to_string());

let mut map = HashMap::new();
for encoded_channel in env_encoded.split(",") {
for encoded_channel in env_encoded.split(",").filter(|v| !v.is_empty()) {
let mut encoded_channel_split = encoded_channel.split("=");
let channel_name = encoded_channel_split
.next()
Expand Down
45 changes: 43 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::channel::get_data_channel_map_from_env;
use crate::prometheus::{
create_metric_server, InflightMetricLabels, TotalMetricLabels, WebMetrics,
};
use crate::record_stream::{
get_data_channel_topic_map_from_env, KafkaRecordStream, KafkaRecordStreamConfig, RecordStream,
};
use crate::star::{parse_message, AppSTARError};
use actix_web::HttpRequest;
use actix_web::{
dev::Service,
error::ResponseError,
Expand All @@ -18,11 +20,15 @@ use base64::{engine::general_purpose as base64_engine, Engine as _};
use derive_more::{Display, Error, From};
use futures::{future::try_join, FutureExt};
use prometheus_client::registry::Registry;
use reqwest::header::HeaderName;
use std::collections::HashMap;
use std::str::{from_utf8, Utf8Error};
use std::sync::Arc;
use std::time::Instant;

const MIN_CHANNEL_REVISIONS_ENV_KEY: &str = "MIN_CHANNEL_REVISIONS";
const REVISION_HEADER: &str = "brave-p3a-version";

#[derive(From, Error, Display, Debug)]
pub enum WebError {
#[display(fmt = "failed to decode base64")]
Expand All @@ -39,6 +45,7 @@ pub struct ServerState {
pub channel_rec_streams: HashMap<String, KafkaRecordStream>,
pub web_metrics: Arc<WebMetrics>,
pub main_channel: String,
pub min_revision_map: HashMap<String, usize>,
}

impl ResponseError for WebError {
Expand Down Expand Up @@ -67,6 +74,7 @@ async fn ident_handler() -> Result<impl Responder, WebError> {

async fn handle_measurement_submit(
body: web::Bytes,
request: HttpRequest,
state: &ServerState,
channel_name: &String,
) -> Result<impl Responder, WebError> {
Expand All @@ -76,6 +84,24 @@ async fn handle_measurement_submit(
let body_str = from_utf8(&body)?.trim();
let bincode_msg = base64_engine::STANDARD.decode(body_str)?;
parse_message(&bincode_msg)?;

if let Some(min_revision) = state.min_revision_map.get(channel_name) {
let req_revision = request
.headers()
.get(HeaderName::from_static(REVISION_HEADER))
.map(|v| {
v.to_str()
.unwrap_or_default()
.parse::<usize>()
.unwrap_or_default()
})
.unwrap_or_default();
if req_revision < *min_revision {
// Just ignore older requests gracefully
return Ok(HttpResponse::NoContent().finish());
}
}

match rec_stream.produce(&bincode_msg).await {
Err(e) => {
error!("Failed to push message: {}", e);
Expand All @@ -90,18 +116,20 @@ async fn handle_measurement_submit(
#[post("/{channel}")]
async fn channel_handler(
body: web::Bytes,
request: HttpRequest,
state: Data<ServerState>,
channel: web::Path<String>,
) -> Result<impl Responder, WebError> {
handle_measurement_submit(body, state.as_ref(), channel.as_ref()).await
handle_measurement_submit(body, request, state.as_ref(), channel.as_ref()).await
}

#[post("/")]
async fn main_handler(
body: web::Bytes,
request: HttpRequest,
state: Data<ServerState>,
) -> Result<impl Responder, WebError> {
handle_measurement_submit(body, state.as_ref(), &state.main_channel).await
handle_measurement_submit(body, request, state.as_ref(), &state.main_channel).await
}

pub async fn start_server(worker_count: usize, main_channel: String) -> std::io::Result<()> {
Expand All @@ -120,10 +148,23 @@ pub async fn start_server(worker_count: usize, main_channel: String) -> std::io:
})
.collect();

let min_revision_map = get_data_channel_map_from_env(MIN_CHANNEL_REVISIONS_ENV_KEY, "")
.into_iter()
.map(|(channel, value)| {
(
channel,
value
.parse::<usize>()
.expect("minimum channel revision should be non-negative integer"),
)
})
.collect();

let state = Data::new(ServerState {
channel_rec_streams,
web_metrics: Arc::new(WebMetrics::new()),
main_channel,
min_revision_map,
});

let mut registry = <Registry>::default();
Expand Down

0 comments on commit 2b62b0a

Please sign in to comment.