From 3467944b22699571dcd1a772db277b41dabe7473 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Fri, 14 Jul 2023 13:48:57 -0700 Subject: [PATCH] encrypt bearer tokens at rest --- Cargo.lock | 2 + Cargo.toml | 2 + README.md | 1 + client/tests/aggregators.rs | 3 +- migration/src/lib.rs | 2 + ...1_181722_rename_aggregator_bearer_token.rs | 64 ++++++++ src/bin.rs | 6 +- src/clients/aggregator_client.rs | 8 +- src/config.rs | 4 +- src/crypter.rs | 142 ++++++++++++++++++ src/entity/aggregator.rs | 27 +++- src/entity/aggregator/new_aggregator.rs | 12 +- src/entity/aggregator/update_aggregator.rs | 10 +- src/entity/task/provisionable_task.rs | 17 ++- src/handler.rs | 5 + src/handler/error.rs | 12 ++ src/lib.rs | 2 + src/routes.rs | 10 +- src/routes/aggregators.rs | 42 +++--- src/routes/tasks.rs | 35 +++-- test-support/src/fixtures.rs | 12 +- test-support/src/lib.rs | 3 +- tests/aggregator_client.rs | 26 ++-- tests/aggregators.rs | 10 +- tests/crypter.rs | 63 ++++++++ 25 files changed, 444 insertions(+), 76 deletions(-) create mode 100644 migration/src/m20230731_181722_rename_aggregator_bearer_token.rs create mode 100644 src/crypter.rs create mode 100644 tests/crypter.rs diff --git a/Cargo.lock b/Cargo.lock index 4e7721e1..17db1ab7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1042,6 +1042,7 @@ dependencies = [ name = "divviup-api" version = "0.0.16" dependencies = [ + "aes-gcm", "async-lock", "async-session", "base64 0.21.2", @@ -1088,6 +1089,7 @@ dependencies = [ "trillium-static-compiled", "trillium-testing", "trillium-tokio", + "typenum", "url", "uuid", "validator", diff --git a/Cargo.toml b/Cargo.toml index 8c9d1e1b..2d3f97ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ integration-testing = [] members = [".", "migration", "client", "test-support", "cli"] [dependencies] +aes-gcm = "0.10.2" async-lock = "2.7.0" async-session = "3.0.0" base64 = "0.21.2" @@ -60,6 +61,7 @@ trillium-sessions = "0.4.2" trillium-static-compiled = "0.5.0" trillium-testing = { version = "0.5.0", optional = true } trillium-tokio = "0.3.1" +typenum = "1.16.0" url = "2.4.0" uuid = { version = "1.4.1", features = ["v4", "fast-rng", "serde"] } validator = { version = "0.16.1", features = ["derive"] } diff --git a/README.md b/README.md index f917a3e6..2308574f 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ An example `.envrc` is provided for optional but recommended use with [`direnv`] * `DATABASE_URL` -- A [libpq-compatible postgres uri](https://www.postgresql.org/docs/current/libpq-connect.html#id-1.7.3.8.3.6) * `POSTMARK_TOKEN` -- the token from the transactional stream from a [postmark](https://postmarkapp.com) account * `EMAIL_ADDRESS` -- the address this deployment should send from +* `DATABASE_ENCRYPTION_KEYS` -- Comma-joined url-safe-no-pad base64'ed 16 byte cryptographically-random keys. The first one will be used to encrypt aggregator API authentication tokens at rest in the database ### Optional binding environment variables diff --git a/client/tests/aggregators.rs b/client/tests/aggregators.rs index e0380639..99a117af 100644 --- a/client/tests/aggregators.rs +++ b/client/tests/aggregators.rs @@ -95,7 +95,8 @@ async fn rotate_bearer_token( .one(app.db()) .await? .unwrap() - .bearer_token, + .bearer_token(app.crypter()) + .unwrap(), new_bearer_token ); Ok(()) diff --git a/migration/src/lib.rs b/migration/src/lib.rs index 38d009bb..28cf64a8 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -15,6 +15,7 @@ mod m20230626_183248_add_is_first_party_to_aggregators; mod m20230630_175314_create_api_tokens; mod m20230703_201332_add_additional_fields_to_api_tokens; mod m20230725_220134_add_vdafs_and_query_types_to_aggregators; +mod m20230731_181722_rename_aggregator_bearer_token; pub struct Migrator; @@ -37,6 +38,7 @@ impl MigratorTrait for Migrator { Box::new(m20230630_175314_create_api_tokens::Migration), Box::new(m20230703_201332_add_additional_fields_to_api_tokens::Migration), Box::new(m20230725_220134_add_vdafs_and_query_types_to_aggregators::Migration), + Box::new(m20230731_181722_rename_aggregator_bearer_token::Migration), ] } } diff --git a/migration/src/m20230731_181722_rename_aggregator_bearer_token.rs b/migration/src/m20230731_181722_rename_aggregator_bearer_token.rs new file mode 100644 index 00000000..2f865d67 --- /dev/null +++ b/migration/src/m20230731_181722_rename_aggregator_bearer_token.rs @@ -0,0 +1,64 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .exec_stmt(Table::truncate().table(Task::Table).to_owned()) + .await?; + + manager + .exec_stmt(Table::truncate().table(Aggregator::Table).to_owned()) + .await?; + + manager + .alter_table( + Table::alter() + .table(Aggregator::Table) + .add_column( + ColumnDef::new(Aggregator::EncryptedBearerToken) + .binary() + .not_null(), + ) + .drop_column(Aggregator::BearerToken) + .to_owned(), + ) + .await?; + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .exec_stmt(Table::truncate().table(Task::Table).to_owned()) + .await?; + manager + .exec_stmt(Table::truncate().table(Aggregator::Table).to_owned()) + .await?; + + manager + .alter_table( + Table::alter() + .table(Aggregator::Table) + .add_column(ColumnDef::new(Aggregator::BearerToken).string().not_null()) + .drop_column(Aggregator::EncryptedBearerToken) + .to_owned(), + ) + .await?; + Ok(()) + } +} + +#[derive(Iden)] +enum Aggregator { + Table, + BearerToken, + EncryptedBearerToken, +} + +#[derive(Iden)] +enum Task { + Table, +} diff --git a/src/bin.rs b/src/bin.rs index f0058b84..248bde55 100644 --- a/src/bin.rs +++ b/src/bin.rs @@ -9,8 +9,12 @@ async fn main() { let config = match Config::from_env() { Ok(config) => config, - Err(e) => panic!("{e}"), + Err(e) => { + eprintln!("{e}"); + std::process::exit(-1); + } }; + let stopper = Stopper::new(); let observer = CloneCounterObserver::default(); diff --git a/src/clients/aggregator_client.rs b/src/clients/aggregator_client.rs index 15a38de2..1821bb3f 100644 --- a/src/clients/aggregator_client.rs +++ b/src/clients/aggregator_client.rs @@ -15,9 +15,9 @@ const CONTENT_TYPE: &str = "application/vnd.janus.aggregator+json;version=0.1"; #[derive(Debug, Clone)] pub struct AggregatorClient { client: Client, - base_url: Url, auth_header: HeaderValue, aggregator: Aggregator, + base_url: Url, } impl AsRef for AggregatorClient { @@ -27,7 +27,7 @@ impl AsRef for AggregatorClient { } impl AggregatorClient { - pub fn new(client: Client, aggregator: Aggregator) -> Self { + pub fn new(client: Client, aggregator: Aggregator, bearer_token: &str) -> Self { let mut base_url: Url = aggregator.api_url.clone().into(); if !base_url.path().ends_with('/') { base_url.set_path(&format!("{}/", base_url.path())); @@ -35,9 +35,9 @@ impl AggregatorClient { Self { client, - base_url, - auth_header: HeaderValue::from(format!("Bearer {}", &aggregator.bearer_token)), + auth_header: format!("Bearer {bearer_token}").into(), aggregator, + base_url, } } diff --git a/src/config.rs b/src/config.rs index 4a38e52e..75272960 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use crate::handler::oauth2::Oauth2Config; +use crate::{handler::oauth2::Oauth2Config, Crypter}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use email_address::EmailAddress; use std::{collections::VecDeque, env::VarError, error::Error, str::FromStr}; @@ -17,6 +17,7 @@ pub struct Config { pub auth_client_secret: String, pub auth_url: Url, pub client: Client, + pub crypter: Crypter, pub database_url: Url, pub email_address: EmailAddress, pub postmark_token: String, @@ -95,6 +96,7 @@ impl Config { auth_client_secret: var("AUTH_CLIENT_SECRET")?, auth_url: var("AUTH_URL")?, client: build_client(), + crypter: var("DATABASE_ENCRYPTION_KEYS")?, database_url: var("DATABASE_URL")?, email_address: var("EMAIL_ADDRESS")?, postmark_token: var("POSTMARK_TOKEN")?, diff --git a/src/crypter.rs b/src/crypter.rs new file mode 100644 index 00000000..29f1daa0 --- /dev/null +++ b/src/crypter.rs @@ -0,0 +1,142 @@ +use aes_gcm::{ + aead::{AeadCore, AeadInPlace, KeyInit, OsRng}, + Aes128Gcm as AesGcm, Error, KeySizeUser, Nonce, +}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine}; +use std::{ + collections::VecDeque, + fmt::{self, Debug, Formatter}, + iter, + str::FromStr, + sync::Arc, +}; +use typenum::marker_traits::Unsigned; + +pub type Key = aes_gcm::Key; + +#[derive(Clone)] +pub struct Crypter(Arc); + +#[derive(thiserror::Error, Debug, Clone)] +pub enum CrypterParseError { + #[error(transparent)] + Base64(#[from] DecodeError), + + #[error("incorrect key length, must be {} bytes after base64 decode", ::KeySize::USIZE)] + IncorrectLength, + + #[error("at least one key needed")] + Missing, +} + +impl FromStr for Crypter { + type Err = CrypterParseError; + fn from_str(s: &str) -> Result { + let mut keys = s + .split(',') + .map(|s| { + URL_SAFE_NO_PAD + .decode(s) + .map_err(CrypterParseError::Base64) + .and_then(|v| Key::from_exact_iter(v).ok_or(CrypterParseError::IncorrectLength)) + }) + .collect::, _>>()?; + let current_key = keys.pop_front().ok_or(CrypterParseError::Missing)?; + Ok(Self::new(current_key, keys)) + } +} + +impl Debug for Crypter { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Crypter") + .field("current_ciphers", &"..") + .field("past_ciphers", &self.0.past_ciphers.len()) + .finish() + } +} + +#[derive(Clone)] +struct CrypterInner { + current_cipher: AesGcm, + past_ciphers: Vec, +} + +impl From for Crypter { + fn from(key: Key) -> Self { + Self::new(key, []) + } +} + +impl Crypter { + pub fn new(current_key: Key, past_keys: impl IntoIterator) -> Self { + Self(Arc::new(CrypterInner { + current_cipher: AesGcm::new(¤t_key), + past_ciphers: past_keys.into_iter().map(|key| AesGcm::new(&key)).collect(), + })) + } + + pub fn generate_key() -> Key { + AesGcm::generate_key(OsRng) + } + + pub fn encrypt(&self, associated_data: &[u8], plaintext: &[u8]) -> Result, Error> { + self.0.encrypt(associated_data, plaintext) + } + + pub fn decrypt( + &self, + associated_data: &[u8], + nonce_and_ciphertext: &[u8], + ) -> Result, Error> { + self.0.decrypt(associated_data, nonce_and_ciphertext) + } +} + +impl CrypterInner { + fn encrypt(&self, associated_data: &[u8], plaintext: &[u8]) -> Result, Error> { + let nonce = AesGcm::generate_nonce(&mut OsRng); + let mut buffer = plaintext.to_vec(); + self.current_cipher + .encrypt_in_place(&nonce, associated_data, &mut buffer)?; + let mut nonce_and_ciphertext = nonce.to_vec(); + nonce_and_ciphertext.append(&mut buffer); + Ok(nonce_and_ciphertext) + } + + fn decrypt( + &self, + associated_data: &[u8], + nonce_and_ciphertext: &[u8], + ) -> Result, Error> { + let nonce_size = ::NonceSize::USIZE; + if nonce_and_ciphertext.len() < nonce_size { + return Err(Error); + } + + let (nonce, ciphertext) = nonce_and_ciphertext.split_at(nonce_size); + + self.cipher_iter() + .find_map(|cipher| { + self.decrypt_with_key(cipher, associated_data, nonce, ciphertext) + .ok() + }) + .ok_or(Error) + } + + fn cipher_iter(&self) -> impl Iterator { + iter::once(&self.current_cipher).chain(self.past_ciphers.iter()) + } + + fn decrypt_with_key( + &self, + cipher: &AesGcm, + associated_data: &[u8], + nonce: &[u8], + ciphertext: &[u8], + ) -> Result, Error> { + let nonce = Nonce::from_slice(nonce); + let mut bytes = ciphertext.to_vec(); + cipher.decrypt_in_place(nonce, associated_data, &mut bytes)?; + Ok(bytes) + } +} diff --git a/src/entity/aggregator.rs b/src/entity/aggregator.rs index 27f29ea8..cbbaf92f 100644 --- a/src/entity/aggregator.rs +++ b/src/entity/aggregator.rs @@ -5,12 +5,13 @@ mod update_aggregator; mod vdaf_name; use super::{url::Url, AccountColumn, AccountRelation, Accounts, Memberships}; -use crate::clients::AggregatorClient; +use crate::{clients::AggregatorClient, Crypter, Error}; use sea_orm::{ ActiveModelBehavior, ActiveValue, DeriveEntityModel, DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, IntoActiveModel, PrimaryKeyTrait, Related, RelationDef, RelationTrait, }; use serde::{Deserialize, Serialize}; + use time::OffsetDateTime; use uuid::Uuid; @@ -39,10 +40,10 @@ pub struct Model { pub dap_url: Url, pub api_url: Url, pub is_first_party: bool, - #[serde(skip)] - pub bearer_token: String, pub query_types: QueryTypeNameSet, pub vdafs: VdafNameSet, + #[serde(skip)] + pub encrypted_bearer_token: Vec, } impl Model { @@ -57,8 +58,24 @@ impl Model { self.deleted_at.is_some() } - pub fn client(&self, http_client: trillium_client::Client) -> AggregatorClient { - AggregatorClient::new(http_client, self.clone()) + pub fn client( + &self, + http_client: trillium_client::Client, + crypter: &Crypter, + ) -> Result { + Ok(AggregatorClient::new( + http_client, + self.clone(), + &self.bearer_token(crypter)?, + )) + } + + pub fn bearer_token(&self, crypter: &Crypter) -> Result { + let bearer_token_bytes = crypter.decrypt( + self.api_url.as_ref().as_bytes(), + &self.encrypted_bearer_token, + )?; + String::from_utf8(bearer_token_bytes).map_err(Into::into) } } diff --git a/src/entity/aggregator/new_aggregator.rs b/src/entity/aggregator/new_aggregator.rs index 7dd6b8a6..5733ab2f 100644 --- a/src/entity/aggregator/new_aggregator.rs +++ b/src/entity/aggregator/new_aggregator.rs @@ -1,7 +1,7 @@ use super::ActiveModel; use crate::{ clients::{AggregatorClient, ClientError}, - entity::{Account, Aggregator}, + entity::{url::Url, Account, Aggregator}, handler::Error, }; use sea_orm::IntoActiveModel; @@ -37,6 +37,7 @@ impl NewAggregator { self, account: Option<&Account>, client: Client, + crypter: &crate::Crypter, ) -> Result { self.validate()?; let aggregator_config = AggregatorClient::get_config( @@ -80,12 +81,19 @@ impl NewAggregator { // of the scope of this repository, we work around this by // double-checking these Options -- once in validate, and // again in the conversion to non-optional fields. + + let api_url: Url = self.api_url.as_ref().unwrap().parse()?; + let encrypted_bearer_token = crypter.encrypt( + api_url.as_ref().as_bytes(), + self.bearer_token.as_ref().unwrap().as_bytes(), + )?; + Ok(Aggregator { role: aggregator_config.role, name: self.name.unwrap(), api_url: self.api_url.unwrap().parse()?, dap_url: aggregator_config.dap_url.into(), - bearer_token: self.bearer_token.unwrap(), + encrypted_bearer_token, id: Uuid::new_v4(), account_id: account.map(|account| account.id), created_at: OffsetDateTime::now_utc(), diff --git a/src/entity/aggregator/update_aggregator.rs b/src/entity/aggregator/update_aggregator.rs index bfaf2907..deac03ca 100644 --- a/src/entity/aggregator/update_aggregator.rs +++ b/src/entity/aggregator/update_aggregator.rs @@ -1,7 +1,7 @@ use crate::{ clients::{AggregatorClient, ClientError}, entity::Aggregator, - Error, + Crypter, Error, }; use sea_orm::{ActiveModelTrait, ActiveValue, IntoActiveModel}; use serde::Deserialize; @@ -22,7 +22,8 @@ impl UpdateAggregator { self, aggregator: Aggregator, client: Client, - ) -> Result { + crypter: &Crypter, + ) -> Result { self.validate()?; let api_url = aggregator.api_url.clone().into(); let mut aggregator = aggregator.into_active_model(); @@ -49,7 +50,10 @@ impl UpdateAggregator { aggregator.query_types = ActiveValue::Set(aggregator_config.query_types); aggregator.vdafs = ActiveValue::Set(aggregator_config.vdafs); - aggregator.bearer_token = ActiveValue::Set(bearer_token); + aggregator.encrypted_bearer_token = ActiveValue::Set(crypter.encrypt( + aggregator.api_url.as_ref().as_ref().as_bytes(), + bearer_token.as_bytes(), + )?); } if aggregator.is_changed() { diff --git a/src/entity/task/provisionable_task.rs b/src/entity/task/provisionable_task.rs index 81f1cfcc..e2ba11e7 100644 --- a/src/entity/task/provisionable_task.rs +++ b/src/entity/task/provisionable_task.rs @@ -3,6 +3,7 @@ use crate::{ clients::aggregator_client::api_types::AuthenticationToken, entity::{Account, Aggregator}, handler::Error, + Crypter, }; use janus_messages::HpkeConfig; use trillium_client::Client; @@ -49,8 +50,12 @@ impl ProvisionableTask { &self, http_client: Client, aggregator: Aggregator, + crypter: &Crypter, ) -> Result { - let response = aggregator.client(http_client).create_task(self).await?; + let response = aggregator + .client(http_client, crypter)? + .create_task(self) + .await?; assert_same(&self.vdaf, &response.vdaf.clone().into(), "vdaf")?; assert_same( @@ -80,15 +85,19 @@ impl ProvisionableTask { Ok(response) } - pub async fn provision(mut self, client: Client) -> Result { + pub async fn provision( + mut self, + client: Client, + crypter: &Crypter, + ) -> Result { let helper = self - .provision_aggregator(client.clone(), self.helper_aggregator.clone()) + .provision_aggregator(client.clone(), self.helper_aggregator.clone(), crypter) .await?; self.aggregator_auth_token = helper.aggregator_auth_token.map(AuthenticationToken::token); let _leader = self - .provision_aggregator(client, self.leader_aggregator.clone()) + .provision_aggregator(client, self.leader_aggregator.clone(), crypter) .await?; Ok(super::Model { diff --git a/src/handler.rs b/src/handler.rs index e599833d..213b1e01 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -80,6 +80,10 @@ impl DivviupApi { pub fn config(&self) -> &Config { &self.config } + + pub fn crypter(&self) -> &crate::Crypter { + &self.config.crypter + } } impl AsRef for DivviupApi { @@ -101,6 +105,7 @@ fn api(db: &Db, config: &Config) -> impl Handler { .with_cookie_name("divviup.sid") .with_older_secrets(&config.session_secrets.older), state(config.client.clone()), + state(config.crypter.clone()), cors_headers(config), cache_control([Private, MustRevalidate]), db.clone(), diff --git a/src/handler/error.rs b/src/handler/error.rs index e4902cb7..e3b59675 100644 --- a/src/handler/error.rs +++ b/src/handler/error.rs @@ -78,6 +78,18 @@ pub enum Error { TaskProvisioning(#[from] crate::entity::task::TaskProvisioningError), #[error(transparent)] Uuid(#[from] uuid::Error), + #[error("encryption error")] + Encryption, + #[error(transparent)] + Utf8Error(#[from] std::string::FromUtf8Error), + #[error("{0}")] + String(&'static str), +} + +impl From for Error { + fn from(_: aes_gcm::Error) -> Self { + Self::Encryption + } } impl From> for Error { diff --git a/src/lib.rs b/src/lib.rs index b28ff1e9..5c1d8c68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ mod config; mod db; #[macro_use] pub mod entity; +mod crypter; pub mod handler; pub mod permissions; pub mod queue; @@ -20,6 +21,7 @@ pub mod telemetry; mod user; pub use config::{Config, ConfigError}; +pub use crypter::Crypter; pub use db::Db; pub use handler::{custom_mime_types::CONTENT_TYPE, DivviupApi, Error}; pub use permissions::{Permissions, PermissionsActor}; diff --git a/src/routes.rs b/src/routes.rs index 54377f82..2f646439 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -50,11 +50,11 @@ pub fn routes(config: &Config) -> impl Handler { .any( &[Get, Post, Delete, Patch], "/api/*", - (state(auth0_client), api_routes(config)), + (state(auth0_client), api_routes()), ) } -fn api_routes(config: &Config) -> impl Handler { +fn api_routes() -> impl Handler { ( ReplaceMimeTypes, api(actor_required), @@ -82,20 +82,20 @@ fn api_routes(config: &Config) -> impl Handler { .any( &[Patch, Get, Post], "/accounts/:account_id/*", - accounts_routes(config), + accounts_routes(), ) .all("/admin/*", admin::routes()), ) } -fn accounts_routes(config: &Config) -> impl Handler { +fn accounts_routes() -> impl Handler { router() .patch("/", api(accounts::update)) .get("/", api(accounts::show)) .get("/memberships", api(memberships::index)) .post("/memberships", api(memberships::create)) .get("/tasks", api(tasks::index)) - .post("/tasks", (state(config.clone()), api(tasks::create))) + .post("/tasks", api(tasks::create)) .post("/aggregators", api(aggregators::create)) .get("/aggregators", api(aggregators::index)) .post("/api_tokens", api(api_tokens::create)) diff --git a/src/routes/aggregators.rs b/src/routes/aggregators.rs index 090bb4cb..b2ddc82f 100644 --- a/src/routes/aggregators.rs +++ b/src/routes/aggregators.rs @@ -1,15 +1,14 @@ use crate::{ entity::{Account, Aggregator, AggregatorColumn, Aggregators, NewAggregator, UpdateAggregator}, - handler::Error, - Db, Permissions, PermissionsActor, + Db, Error, Permissions, PermissionsActor, }; use sea_orm::{ sea_query::{all, any}, ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, }; use trillium::{Conn, Handler, Status}; -use trillium_api::{FromConn, Json, State}; -use trillium_client::Client; +use trillium_api::{FromConn, Json}; + use trillium_router::RouterConnExt; use uuid::Uuid; @@ -82,16 +81,14 @@ pub async fn index( } pub async fn create( - _: &mut Conn, - (db, account, Json(new_aggregator), State(client)): ( - Db, - Account, - Json, - State, - ), + conn: &mut Conn, + (db, account, Json(new_aggregator)): (Db, Account, Json), ) -> Result { + let client = conn.take_state().unwrap(); + let crypter = conn.take_state().unwrap(); + new_aggregator - .build(Some(&account), client) + .build(Some(&account), client, &crypter) .await? .insert(&db) .await @@ -105,16 +102,13 @@ pub async fn delete(_: &mut Conn, (db, aggregator): (Db, Aggregator)) -> Result< } pub async fn update( - _: &mut Conn, - (db, aggregator, Json(update_aggregator), State(client)): ( - Db, - Aggregator, - Json, - State, - ), + conn: &mut Conn, + (db, aggregator, Json(update_aggregator)): (Db, Aggregator, Json), ) -> Result, Error> { + let client = conn.take_state().unwrap(); + let crypter = conn.state().unwrap(); update_aggregator - .build(aggregator, client) + .build(aggregator, client, crypter) .await? .update(&db) .await @@ -123,11 +117,13 @@ pub async fn update( } pub async fn admin_create( - _: &mut Conn, - (db, Json(new_aggregator), State(client)): (Db, Json, State), + conn: &mut Conn, + (db, Json(new_aggregator)): (Db, Json), ) -> Result { + let client = conn.take_state().unwrap(); + let crypter = conn.state().unwrap(); new_aggregator - .build(None, client) + .build(None, client, crypter) .await? .insert(&db) .await diff --git a/src/routes/tasks.rs b/src/routes/tasks.rs index d9875738..810822ef 100644 --- a/src/routes/tasks.rs +++ b/src/routes/tasks.rs @@ -1,6 +1,6 @@ use crate::{ entity::{Account, NewTask, Task, Tasks, UpdateTask}, - Db, Error, Permissions, PermissionsActor, + Crypter, Db, Error, Permissions, PermissionsActor, }; use sea_orm::{ActiveModelTrait, EntityTrait, ModelTrait}; use std::time::Duration; @@ -46,12 +46,13 @@ impl FromConn for Task { type CreateArgs = (Account, Json, State, Db); pub async fn create( - _: &mut Conn, + conn: &mut Conn, (account, task, State(client), db): CreateArgs, ) -> Result { + let crypter = conn.state().unwrap(); task.validate(account, &db) .await? - .provision(client) + .provision(client, crypter) .await? .insert(&db) .await @@ -59,12 +60,20 @@ pub async fn create( .map(|task| (Status::Created, Json(task))) } -async fn refresh_metrics_if_needed(task: Task, db: Db, client: Client) -> Result { +async fn refresh_metrics_if_needed( + task: Task, + db: Db, + client: Client, + crypter: &Crypter, +) -> Result { if OffsetDateTime::now_utc() - task.updated_at <= Duration::from_secs(5 * 60) { return Ok(task); } if let Some(aggregator) = task.first_party_aggregator(&db).await? { - let metrics = aggregator.client(client).get_task_metrics(&task.id).await?; + let metrics = aggregator + .client(client, crypter)? + .get_task_metrics(&task.id) + .await?; task.update_metrics(metrics, db).await.map_err(Into::into) } else { Ok(task) @@ -75,7 +84,8 @@ pub async fn show( conn: &mut Conn, (task, db, State(client)): (Task, Db, State), ) -> Result, Error> { - let task = refresh_metrics_if_needed(task, db, client).await?; + let crypter = conn.state().unwrap(); + let task = refresh_metrics_if_needed(task, db, client, crypter).await?; conn.set_last_modified(task.updated_at.into()); Ok(Json(task)) } @@ -84,18 +94,23 @@ pub async fn update( _: &mut Conn, (task, Json(update), db): (Task, Json, Db), ) -> Result { - let task = update.build(task)?.update(&db).await?; - Ok(Json(task)) + update + .build(task)? + .update(&db) + .await + .map(Json) + .map_err(Error::from) } pub mod collector_auth_tokens { use super::*; pub async fn index( - _: &mut Conn, + conn: &mut Conn, (task, db, State(client)): (Task, Db, State), ) -> Result { + let crypter = conn.state().unwrap(); let leader = task.leader_aggregator(&db).await?; - let client = leader.client(client); + let client = leader.client(client, crypter)?; let task_response = client.get_task(&task.id).await?; Ok(Json([task_response.collector_auth_token])) } diff --git a/test-support/src/fixtures.rs b/test-support/src/fixtures.rs index 22e574a1..7bec2a88 100644 --- a/test-support/src/fixtures.rs +++ b/test-support/src/fixtures.rs @@ -120,15 +120,19 @@ pub async fn aggregator_pair(app: &DivviupApi, account: &Account) -> (Aggregator } pub async fn aggregator(app: &DivviupApi, account: Option<&Account>) -> Aggregator { + let api_url: Url = format!("https://api.{}.divviup.org/", random_name()) + .parse() + .unwrap(); Aggregator { account_id: account.map(|a| a.id), - api_url: format!("https://api.{}.divviup.org/", random_name()) - .parse() - .unwrap(), + api_url: api_url.clone().into(), dap_url: format!("https://dap.{}.divviup.org/", random_name()) .parse() .unwrap(), - bearer_token: random_name(), + encrypted_bearer_token: app + .crypter() + .encrypt(api_url.as_ref().as_bytes(), random_name().as_bytes()) + .unwrap(), created_at: OffsetDateTime::now_utc(), updated_at: OffsetDateTime::now_utc(), deleted_at: None, diff --git a/test-support/src/lib.rs b/test-support/src/lib.rs index 40271395..adf0d5d1 100644 --- a/test-support/src/lib.rs +++ b/test-support/src/lib.rs @@ -1,6 +1,6 @@ use divviup_api::{ clients::aggregator_client::api_types::{Encode, HpkeConfig}, - Config, Db, + Config, Crypter, Db, }; use serde::{de::DeserializeOwned, Serialize}; use std::{error::Error, future::Future, iter::repeat_with}; @@ -87,6 +87,7 @@ pub fn config(api_mocks: impl Handler) -> Config { email_address: "test@example.test".parse().unwrap(), postmark_url: POSTMARK_URL.parse().unwrap(), client: Client::new(trillium_testing::connector(api_mocks)), + crypter: Crypter::from(Crypter::generate_key()), } } diff --git a/tests/aggregator_client.rs b/tests/aggregator_client.rs index 4d01ae3e..a73b5cab 100644 --- a/tests/aggregator_client.rs +++ b/tests/aggregator_client.rs @@ -8,7 +8,7 @@ use trillium::Handler; #[test(harness = with_client_logs)] async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult { let aggregator = fixtures::aggregator(&app, None).await; - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; let task_ids = client.get_task_ids().await?; assert_eq!(task_ids.len(), 25); // two pages of 10 plus a final page of 5 @@ -24,7 +24,7 @@ async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult { log.request_headers .get_str(KnownHeaderName::Authorization) .unwrap() - == &format!("Bearer {}", aggregator.bearer_token) + == &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()) })); let queries = logs @@ -46,7 +46,7 @@ async fn get_task_ids(app: DivviupApi, client_logs: ClientLogs) -> TestResult { #[test(harness = with_client_logs)] async fn get_task_metrics(app: DivviupApi, client_logs: ClientLogs) -> TestResult { let aggregator = fixtures::aggregator(&app, None).await; - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; assert!(client.get_task_metrics("fake-task-id").await.is_ok()); let log = client_logs.last(); @@ -60,7 +60,7 @@ async fn get_task_metrics(app: DivviupApi, client_logs: ClientLogs) -> TestResul log.request_headers .get_str(KnownHeaderName::Authorization) .unwrap(), - &format!("Bearer {}", aggregator.bearer_token) + &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()) ); assert_eq!( @@ -74,7 +74,7 @@ async fn get_task_metrics(app: DivviupApi, client_logs: ClientLogs) -> TestResul #[test(harness = with_client_logs)] async fn delete_task(app: DivviupApi, client_logs: ClientLogs) -> TestResult { let aggregator = fixtures::aggregator(&app, None).await; - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; assert!(client.delete_task("fake-task-id").await.is_ok()); let log = client_logs.last(); @@ -88,7 +88,7 @@ async fn delete_task(app: DivviupApi, client_logs: ClientLogs) -> TestResult { log.request_headers .get_str(KnownHeaderName::Authorization) .unwrap(), - &format!("Bearer {}", aggregator.bearer_token) + &format!("Bearer {}", aggregator.bearer_token(app.crypter()).unwrap()) ); assert_eq!( @@ -167,6 +167,14 @@ mod prefixes { let mut app = DivviupApi::new(config(api_mocks)).await; set_up_schema(app.db()).await; let mut aggregator = fixtures::aggregator(&app, None).await.into_active_model(); + aggregator.encrypted_bearer_token = ActiveValue::Set( + app.crypter() + .encrypt( + api_url.as_ref().as_bytes(), + fixtures::random_name().as_bytes(), + ) + .unwrap(), + ); aggregator.api_url = ActiveValue::Set(api_url.into()); let aggregator = aggregator.update(app.db()).await.unwrap(); let mut info = "testing".into(); @@ -182,7 +190,7 @@ mod prefixes { client_logs: ClientLogs, aggregator: Aggregator, ) -> TestResult { - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; let task_ids = client.get_task_ids().await?; assert_eq!( @@ -208,7 +216,7 @@ mod prefixes { client_logs: ClientLogs, aggregator: Aggregator, ) -> TestResult { - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; let metrics = client.get_task_metrics("fake-task-id").await?; assert_eq!(client_logs.last().response_json::(), metrics); assert_eq!(client_logs.last().method, Method::Get); @@ -226,7 +234,7 @@ mod prefixes { client_logs: ClientLogs, aggregator: Aggregator, ) -> TestResult { - let client = aggregator.client(app.config().client.clone()); + let client = aggregator.client(app.config().client.clone(), app.crypter())?; client.delete_task("fake-task-id").await?; assert_eq!(client_logs.last().url.path_segments().unwrap().count(), 3); assert_eq!( diff --git a/tests/aggregators.rs b/tests/aggregators.rs index dc5bb49b..aacc5e3d 100644 --- a/tests/aggregators.rs +++ b/tests/aggregators.rs @@ -737,7 +737,7 @@ mod update { assert_eq!(response_aggregator.name, new_name); let reloaded = aggregator.reload(app.db()).await?.unwrap(); assert_eq!(reloaded.name, new_name); - assert_eq!(reloaded.bearer_token, new_bearer_token); + assert_eq!(reloaded.bearer_token(app.crypter())?, new_bearer_token); Ok(()) } @@ -747,7 +747,7 @@ mod update { let (user, account, ..) = fixtures::member(&app).await; let aggregator = fixtures::aggregator(&app, Some(&account)).await; - let original_bearer_token = aggregator.bearer_token.clone(); + let original_bearer_token = aggregator.encrypted_bearer_token.clone(); let mut conn = patch(format!("/api/aggregators/{}", aggregator.id)) .with_api_headers() .with_request_json(json!({ "bearer_token": &BAD_BEARER_TOKEN })) @@ -766,7 +766,11 @@ mod update { ); assert_eq!( - aggregator.reload(app.db()).await?.unwrap().bearer_token, + aggregator + .reload(app.db()) + .await? + .unwrap() + .encrypted_bearer_token, original_bearer_token ); assert_eq!(client_logs.last().response_status, Status::Unauthorized); diff --git a/tests/crypter.rs b/tests/crypter.rs new file mode 100644 index 00000000..674f5bdb --- /dev/null +++ b/tests/crypter.rs @@ -0,0 +1,63 @@ +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use divviup_api::Crypter; +use test_support::{assert_eq, test}; + +const AAD: &[u8] = b"aad"; +const PLAINTEXT: &[u8] = b"plaintext"; + +#[test] +fn round_trip_with_current_key() { + let crypter = Crypter::from(Crypter::generate_key()); + let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); + assert_eq!(crypter.decrypt(AAD, &encrypted).unwrap(), PLAINTEXT); +} + +#[test] +fn round_trip_with_old_key() { + let old_key = Crypter::generate_key(); + let crypter = Crypter::from(old_key.clone()); + let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); + + let crypter = Crypter::new(Crypter::generate_key(), [old_key]); + assert_eq!(crypter.decrypt(AAD, &encrypted).unwrap(), PLAINTEXT); +} + +#[test] +fn wrong_key() { + let crypter = Crypter::from(Crypter::generate_key()); + let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); + + let crypter = Crypter::from(Crypter::generate_key()); + assert!(crypter.decrypt(AAD, &encrypted).is_err()); +} + +#[test] +fn wrong_aad() { + let crypter = Crypter::from(Crypter::generate_key()); + let encrypted = crypter.encrypt(AAD, PLAINTEXT).unwrap(); + assert!(crypter.decrypt(b"different aad", &encrypted).is_err()); +} + +#[test] +fn short_input_does_not_panic() { + let crypter = Crypter::from(Crypter::generate_key()); + assert!(crypter.decrypt(AAD, b"x").is_err()); +} + +#[test] +fn parsing() { + let keys = std::iter::repeat_with(Crypter::generate_key) + .take(5) + .collect::>(); + let encrypted = Crypter::from(keys[0].clone()) + .encrypt(AAD, PLAINTEXT) + .unwrap(); + let crypter = keys + .iter() + .map(|k| URL_SAFE_NO_PAD.encode(&k)) + .collect::>() + .join(",") + .parse::() + .unwrap(); + assert_eq!(crypter.decrypt(AAD, &encrypted).unwrap(), PLAINTEXT); +}