diff --git a/src/main.rs b/src/main.rs index d89665a..5086dc2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,13 @@ use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, net::{TcpListener, TcpStream}, }; -use tokio_rustls::rustls::{pki_types::PrivateKeyDer, ServerConfig}; +use tokio_rustls::rustls::{ + crypto::ring::sign::any_ecdsa_type, + pki_types::PrivateKeyDer, + server::{ClientHello, ResolvesServerCert, ServerSessionMemoryCache}, + sign::CertifiedKey, + ServerConfig, +}; use tracing_subscriber::{ filter::Targets, layer::SubscriberExt, util::SubscriberInitExt, Layer, Registry, @@ -151,10 +157,21 @@ async fn main() -> eyre::Result<()> { num_entries += 1; } } + + fn format_size(size: u64) -> String { + if size < 1024 { + format!("{}B", size) + } else if size < 1024 * 1024 { + format!("{:.1}KiB", size as f64 / 1024.0) + } else { + format!("{:.1}MiB", size as f64 / 1024.0 / 1024.0) + } + } + tracing::info!( - "📊 Cache stats: {} entries, {} bytes total", + "📊 Cache stats: {} entries, {} total", num_entries, - cache_size + format_size(cache_size) ); } else { tracing::warn!( @@ -172,12 +189,27 @@ async fn main() -> eyre::Result<()> { } let cache_dir = cache_dir.canonicalize()?; + let cert_cache = Arc::new(CertCache { + certs_by_host: Mutex::new(HashMap::new()), + ca, + }); + + let mut server_conf = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(cert_cache); + server_conf.alpn_protocols.push(b"h2".to_vec()); + server_conf.max_early_data_size = 4 * 1024; + server_conf.send_half_rtt_data = true; + server_conf.session_storage = ServerSessionMemoryCache::new(16 * 1024); + let server_conf = Arc::new(server_conf); + let settings = ProxySettings { client, - ca, imch, cache_dir, + server_conf, }; + let service = UpgradeService { settings }; while let Ok((stream, remote_addr)) = ln.accept().await { @@ -208,6 +240,7 @@ type InMemoryCacheHandle = Arc; #[derive(Default)] struct InMemoryCache { + // TODO: use https://lib.rs/crates/papaya? entries: Mutex>, } @@ -293,30 +326,16 @@ async fn handle_upgraded_conn( let c = on_upgrade.await.unwrap(); let c = TokioIo::new(c); - let mut srv_params = rcgen::CertificateParams::new(vec![host.clone()]).unwrap(); - srv_params.is_ca = rcgen::IsCa::NoCa; - - let srv_keypair = rcgen::KeyPair::generate()?; - let srv_cert = srv_params - .signed_by(&srv_keypair, &settings.ca.cert, &settings.ca.keypair) - .unwrap(); - - let mut server_conf = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert( - vec![srv_cert.into()], - PrivateKeyDer::Pkcs8(srv_keypair.serialize_der().into()), - )?; - server_conf.alpn_protocols.push(b"h2".to_vec()); - - let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_conf)); + let before_accept = Instant::now(); + let acceptor = tokio_rustls::TlsAcceptor::from(settings.server_conf.clone()); let tls_stream = acceptor.accept(c).await?; { let (_stream, server_conn) = tls_stream.get_ref(); tracing::trace!( - "Negotiated TLS session, ALPN proto:\n{}", + "Negotiated TLS session in {:?}, ALPN proto:\n{}", + before_accept.elapsed(), pretty_hex::pretty_hex(&server_conn.alpn_protocol().unwrap_or_default()) ); } @@ -351,19 +370,74 @@ async fn handle_upgraded_conn( type OurBody = BoxBody; +struct CertCache { + /// generated certificates + // TODO: use https://lib.rs/crates/papaya? + certs_by_host: Mutex>>, + + /// the shared certificate authority + ca: Arc, +} + +impl std::fmt::Debug for CertCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CertCache").finish_non_exhaustive() + } +} + +impl ResolvesServerCert for CertCache { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let server_name = match client_hello.server_name() { + Some(server_name) => server_name, + None => { + tracing::debug!("No server name in client hello, aborting handshake"); + return None; + } + }; + + let mut certs_by_host = self.certs_by_host.lock().unwrap(); + if let Some(cert) = certs_by_host.get(server_name) { + return Some(Arc::clone(cert)); + } + + let before_gen = Instant::now(); + let mut srv_params = rcgen::CertificateParams::new(vec![server_name.to_string()]).unwrap(); + srv_params.is_ca = rcgen::IsCa::NoCa; + + let srv_keypair = rcgen::KeyPair::generate().unwrap(); + let srv_cert = srv_params + .signed_by(&srv_keypair, &self.ca.cert, &self.ca.keypair) + .unwrap(); + + let cert = Arc::new(CertifiedKey::new( + vec![srv_cert.into()], + any_ecdsa_type(&PrivateKeyDer::Pkcs8(srv_keypair.serialize_der().into())) + .expect("Failed to create ECDSA signing key"), + )); + + certs_by_host.insert(server_name.to_string(), Arc::clone(&cert)); + tracing::info!( + "Generated cert for {server_name} in {:?}", + before_gen.elapsed() + ); + + Some(cert) + } +} + #[derive(Clone)] struct ProxySettings { /// the shared reqwest client for upstream client: reqwest::Client, - /// the shared certificate authority - ca: Arc, - /// the shared in-memory cache imch: InMemoryCacheHandle, /// the cache directory cache_dir: PathBuf, + + /// TLS server config + server_conf: Arc, } impl std::fmt::Debug for ProxySettings {