From 7fa7deb7a259256cd18a826ac7ed19115d13db5a Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Sun, 13 Oct 2024 23:42:24 +0200 Subject: [PATCH] take _some_ headers into account when computing the cache key --- Cargo.lock | 7 +++ Cargo.toml | 1 + README.md | 1 - src/main.rs | 122 +++++++++++++++++++++++++++++++++++++++++++--------- 4 files changed, 109 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 54785a3..20b5723 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,7 @@ dependencies = [ "http-serde", "hyper", "hyper-util", + "md5", "postcard", "pretty-hex", "rcgen", @@ -607,6 +608,12 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.4" diff --git a/Cargo.toml b/Cargo.toml index b79f81f..4673519 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ hyper = { version = "1.4.1", default-features = false, features = [ "server", ] } hyper-util = { version = "0.1.9", features = ["tokio"] } +md5 = "0.7.0" postcard = { version = "1.0.10", features = ["use-std"] } pretty-hex = "0.4.1" rcgen = { version = "0.13.1" } diff --git a/README.md b/README.md index 510d2c7..4fbaa27 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ An proof-of-concept(TM) caching HTTP forward proxy ## Limitations - * Will only accept to negotiate http/2 over TLS (via CONNECT) right now * Very naive rules to decide if something is cachable (see sources) specifically, **fopro DOES NOT RESPECT `cache-control`, `vary`, ETC**. * The cache is boundless (both in memory and on disk) diff --git a/src/main.rs b/src/main.rs index 5086dc2..1be552f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use argh::FromArgs; use color_eyre::eyre::{self, Context}; use futures_util::future::BoxFuture; use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http_serde::http::uri::Scheme; use hyper::{ body::{Body, Bytes}, server::conn, @@ -78,7 +79,7 @@ impl CertAuth { } // just the output of the 'date' on macOS Sequoia -static CACHE_VERSION: &str = "Sun Oct 13 22:09:06 CEST 2024"; +static CACHE_VERSION: &str = "Sun Oct 13 23:40:34 CEST 2024"; #[derive(FromArgs)] /// A caching HTTP forward proxy @@ -198,6 +199,7 @@ async fn main() -> eyre::Result<()> { .with_no_client_auth() .with_cert_resolver(cert_cache); server_conf.alpn_protocols.push(b"h2".to_vec()); + server_conf.alpn_protocols.push(b"http/1.1".to_vec()); server_conf.max_early_data_size = 4 * 1024; server_conf.send_half_rtt_data = true; server_conf.session_storage = ServerSessionMemoryCache::new(16 * 1024); @@ -289,8 +291,19 @@ where } }; + let scheme = if req.uri().port_u16().unwrap_or_default() == 443 { + Scheme::HTTPS + } else { + Scheme::HTTP + }; + if req.method() != Method::CONNECT { - let service = ProxyService { host, settings }; + let service = ProxyService { + host, + settings, + scheme, + }; + return match service.proxy_request(req).await { Ok(resp) => Ok(resp), Err(e) => { @@ -305,7 +318,7 @@ where let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - if let Err(e) = handle_upgraded_conn(on_upgrade, host, settings).await { + if let Err(e) = handle_upgraded_conn(on_upgrade, host, scheme, settings).await { tracing::error!("Error handling upgraded conn: {e:?}"); } }); @@ -321,6 +334,7 @@ where async fn handle_upgraded_conn( on_upgrade: OnUpgrade, host: String, + scheme: Scheme, settings: ProxySettings, ) -> eyre::Result<()> { let c = on_upgrade.await.unwrap(); @@ -330,7 +344,12 @@ async fn handle_upgraded_conn( let acceptor = tokio_rustls::TlsAcceptor::from(settings.server_conf.clone()); let tls_stream = acceptor.accept(c).await?; - { + enum Mode { + H1, + H2, + } + + let mode = { let (_stream, server_conn) = tls_stream.get_ref(); tracing::trace!( @@ -338,11 +357,34 @@ async fn handle_upgraded_conn( before_accept.elapsed(), pretty_hex::pretty_hex(&server_conn.alpn_protocol().unwrap_or_default()) ); - } - let service = ProxyService { host, settings }; - let conn = conn::http2::Builder::new(TokioExecutor::new()) - .serve_connection(TokioIo::new(tls_stream), service); + if server_conn.alpn_protocol().unwrap_or_default() == b"h2" { + Mode::H2 + } else { + Mode::H1 + } + }; + + let service = ProxyService { + host, + settings, + scheme, + }; + + let conn = tokio::spawn(async move { + match mode { + Mode::H1 => { + conn::http1::Builder::new() + .serve_connection(TokioIo::new(tls_stream), service) + .await + } + Mode::H2 => { + conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(tls_stream), service) + .await + } + } + }); match conn.await { Ok(_) => (), Err(e) => { @@ -449,8 +491,8 @@ impl std::fmt::Debug for ProxySettings { #[derive(Debug, Clone)] struct ProxyService { host: String, - settings: ProxySettings, + scheme: Scheme, } impl Service> for ProxyService @@ -494,9 +536,13 @@ impl ProxyService { let uri = req.uri().clone(); tracing::trace!(settings = ?self.settings, %uri, "Should proxy request"); + let method = req.method().clone(); + let (part, body) = req.into_parts(); + let uri_host = uri .host() - .ok_or_else(|| eyre::eyre!("expected host in CONNECT request"))?; + .or_else(|| part.headers.get("host").and_then(|h| h.to_str().ok())) + .ok_or_else(|| eyre::eyre!("expected host in URI or host header"))?; if uri_host != self.host { return Ok(Response::builder() @@ -507,28 +553,53 @@ impl ProxyService { let before_req = Instant::now(); - let method = req.method().clone(); - let (part, body) = req.into_parts(); - let mut cachable = true; - let cache_key = format!( + let mut cache_key = format!( "k/{}{}", uri.authority().map(|a| a.as_str()).unwrap_or_default(), uri.path_and_query() .map(|pq| pq.as_str()) .unwrap_or_default() ); - let cache_key = cache_key.replace(':', "_COLON_"); - let cache_key = cache_key.replace("//", "_SLASHSLASH_"); + cache_key = cache_key.replace(':', "_COLON_"); + cache_key = cache_key.replace("//", "_SLASHSLASH_"); if cache_key.contains("..") { cachable = false; } - let cache_key = if cache_key.ends_with('/') { - format!("{cache_key}_INDEX_") - } else { - cache_key.to_string() + + if cache_key.ends_with('/') { + cache_key = format!("{cache_key}_INDEX_"); }; + + if let Some(authorization) = part.headers.get(hyper::header::AUTHORIZATION) { + let authorization = authorization.to_str().unwrap(); + let hash = md5::compute(authorization); + let hash = format!("{:x}", hash); + cache_key = format!("{cache_key}_AUTH_{hash}"); + } + + if let Some(accept) = part.headers.get(hyper::header::ACCEPT) { + let accept = accept.to_str().unwrap(); + let hash = md5::compute(accept); + let hash = format!("{:x}", hash); + cache_key = format!("{cache_key}_ACCEPT_{hash}"); + } + + if let Some(accept_encoding) = part.headers.get(hyper::header::ACCEPT_ENCODING) { + let accept_encoding = accept_encoding.to_str().unwrap(); + let hash = md5::compute(accept_encoding); + let hash = format!("{:x}", hash); + cache_key = format!("{cache_key}_ACCEPT_ENCODING_{hash}"); + } + + if let Some(accept_language) = part.headers.get(hyper::header::ACCEPT_LANGUAGE) { + let accept_language = accept_language.to_str().unwrap(); + let hash = md5::compute(accept_language); + let hash = format!("{:x}", hash); + cache_key = format!("{cache_key}_ACCEPT_LANGUAGE_{hash}"); + } + tracing::debug!("Cache key: {}", cache_key); if let Some(host) = uri.host() { @@ -640,7 +711,16 @@ impl ProxyService { } } - tracing::debug!("Proxying {method} {uri}"); + tracing::debug!("Proxying {method} {uri}: {part:#?}"); + + let uri = if uri.host().is_none() { + let mut parts = uri.into_parts(); + parts.scheme = Some(self.scheme.clone()); + parts.authority = Some(format!("{}", self.host).parse().unwrap()); + hyper::Uri::from_parts(parts).unwrap() + } else { + uri + }; let upstream_res = match self .settings