Skip to content

Commit

Permalink
Don't constantly regenerate certs
Browse files Browse the repository at this point in the history
  • Loading branch information
fasterthanlime committed Oct 13, 2024
1 parent 63ff085 commit 3310935
Showing 1 changed file with 99 additions and 25 deletions.
124 changes: 99 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_rustls::rustls::{pki_types::PrivateKeyDer, ServerConfig};
use tokio_rustls::rustls::{
crypto::ring::sign::any_ecdsa_type,
pki_types::PrivateKeyDer,
server::{ClientHello, ResolvesServerCert, ServerSessionMemoryCache},
sign::CertifiedKey,
ServerConfig,
};

use tracing_subscriber::{
filter::Targets, layer::SubscriberExt, util::SubscriberInitExt, Layer, Registry,
Expand Down Expand Up @@ -151,10 +157,21 @@ async fn main() -> eyre::Result<()> {
num_entries += 1;
}
}

fn format_size(size: u64) -> String {
if size < 1024 {
format!("{}B", size)
} else if size < 1024 * 1024 {
format!("{:.1}KiB", size as f64 / 1024.0)
} else {
format!("{:.1}MiB", size as f64 / 1024.0 / 1024.0)
}
}

tracing::info!(
"📊 Cache stats: {} entries, {} bytes total",
"📊 Cache stats: {} entries, {} total",
num_entries,
cache_size
format_size(cache_size)
);
} else {
tracing::warn!(
Expand All @@ -172,12 +189,27 @@ async fn main() -> eyre::Result<()> {
}
let cache_dir = cache_dir.canonicalize()?;

let cert_cache = Arc::new(CertCache {
certs_by_host: Mutex::new(HashMap::new()),
ca,
});

let mut server_conf = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(cert_cache);
server_conf.alpn_protocols.push(b"h2".to_vec());
server_conf.max_early_data_size = 4 * 1024;
server_conf.send_half_rtt_data = true;
server_conf.session_storage = ServerSessionMemoryCache::new(16 * 1024);
let server_conf = Arc::new(server_conf);

let settings = ProxySettings {
client,
ca,
imch,
cache_dir,
server_conf,
};

let service = UpgradeService { settings };

while let Ok((stream, remote_addr)) = ln.accept().await {
Expand Down Expand Up @@ -208,6 +240,7 @@ type InMemoryCacheHandle = Arc<InMemoryCache>;

#[derive(Default)]
struct InMemoryCache {
// TODO: use https://lib.rs/crates/papaya?
entries: Mutex<HashMap<String, (CacheEntry, Bytes)>>,
}

Expand Down Expand Up @@ -293,30 +326,16 @@ async fn handle_upgraded_conn(
let c = on_upgrade.await.unwrap();
let c = TokioIo::new(c);

let mut srv_params = rcgen::CertificateParams::new(vec![host.clone()]).unwrap();
srv_params.is_ca = rcgen::IsCa::NoCa;

let srv_keypair = rcgen::KeyPair::generate()?;
let srv_cert = srv_params
.signed_by(&srv_keypair, &settings.ca.cert, &settings.ca.keypair)
.unwrap();

let mut server_conf = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![srv_cert.into()],
PrivateKeyDer::Pkcs8(srv_keypair.serialize_der().into()),
)?;
server_conf.alpn_protocols.push(b"h2".to_vec());

let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_conf));
let before_accept = Instant::now();
let acceptor = tokio_rustls::TlsAcceptor::from(settings.server_conf.clone());
let tls_stream = acceptor.accept(c).await?;

{
let (_stream, server_conn) = tls_stream.get_ref();

tracing::trace!(
"Negotiated TLS session, ALPN proto:\n{}",
"Negotiated TLS session in {:?}, ALPN proto:\n{}",
before_accept.elapsed(),
pretty_hex::pretty_hex(&server_conn.alpn_protocol().unwrap_or_default())
);
}
Expand Down Expand Up @@ -351,19 +370,74 @@ async fn handle_upgraded_conn(

type OurBody = BoxBody<Bytes, Infallible>;

struct CertCache {
/// generated certificates
// TODO: use https://lib.rs/crates/papaya?
certs_by_host: Mutex<HashMap<String, Arc<CertifiedKey>>>,

/// the shared certificate authority
ca: Arc<CertAuth>,
}

impl std::fmt::Debug for CertCache {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CertCache").finish_non_exhaustive()
}
}

impl ResolvesServerCert for CertCache {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let server_name = match client_hello.server_name() {
Some(server_name) => server_name,
None => {
tracing::debug!("No server name in client hello, aborting handshake");
return None;
}
};

let mut certs_by_host = self.certs_by_host.lock().unwrap();
if let Some(cert) = certs_by_host.get(server_name) {
return Some(Arc::clone(cert));
}

let before_gen = Instant::now();
let mut srv_params = rcgen::CertificateParams::new(vec![server_name.to_string()]).unwrap();
srv_params.is_ca = rcgen::IsCa::NoCa;

let srv_keypair = rcgen::KeyPair::generate().unwrap();
let srv_cert = srv_params
.signed_by(&srv_keypair, &self.ca.cert, &self.ca.keypair)
.unwrap();

let cert = Arc::new(CertifiedKey::new(
vec![srv_cert.into()],
any_ecdsa_type(&PrivateKeyDer::Pkcs8(srv_keypair.serialize_der().into()))
.expect("Failed to create ECDSA signing key"),
));

certs_by_host.insert(server_name.to_string(), Arc::clone(&cert));
tracing::info!(
"Generated cert for {server_name} in {:?}",
before_gen.elapsed()
);

Some(cert)
}
}

#[derive(Clone)]
struct ProxySettings {
/// the shared reqwest client for upstream
client: reqwest::Client,

/// the shared certificate authority
ca: Arc<CertAuth>,

/// the shared in-memory cache
imch: InMemoryCacheHandle,

/// the cache directory
cache_dir: PathBuf,

/// TLS server config
server_conf: Arc<ServerConfig>,
}

impl std::fmt::Debug for ProxySettings {
Expand Down

0 comments on commit 3310935

Please sign in to comment.