From 2b112fbe445423e0350b70059bda3a8d9527d4f5 Mon Sep 17 00:00:00 2001
From: John Howard <howardjohn@google.com>
Date: Thu, 6 Jun 2024 20:06:36 -0700
Subject: [PATCH] identity: triple check proxies can only request appropriate
 certificates (#1114)

This is not fixing a bug, but rather adding a 3rd line of defense
against one of the worst *potentional* vulnerabilities in ztunnel: a
confused deputy causing incorrect certificates to be used.

We know have many checks:
* We check the IP of the request, and give it an identity matching the
  workload associate with that IP
* We check the socket the request landed on is running in the pod
  matching ^
* (new) We check any request identities by a proxy are of the pod we
  are running the proxy in (for inpod). This means if there were a
coding error accidentally requesting the wrong cert, it would be denied.
---
 src/identity.rs         |  3 +++
 src/identity/manager.rs |  7 ++++++
 src/proxy.rs            | 55 ++++++++++++++++++++++++++++++++++++++---
 src/proxy/inbound.rs    |  6 ++---
 src/proxy/outbound.rs   |  6 +++--
 src/proxy/pool.rs       | 14 ++++++-----
 6 files changed, 77 insertions(+), 14 deletions(-)

diff --git a/src/identity.rs b/src/identity.rs
index e3ee15675..f611e6c74 100644
--- a/src/identity.rs
+++ b/src/identity.rs
@@ -23,6 +23,7 @@ pub mod manager;
 pub use manager::*;
 
 mod auth;
+use crate::state::WorkloadInfo;
 pub use auth::*;
 
 #[cfg(any(test, feature = "testing"))]
@@ -49,6 +50,8 @@ pub enum Error {
     Spiffe(String),
     #[error("the identity is no longer needed")]
     Forgotten,
+    #[error("BUG: identity requested {0}, but only allowed {1:?}")]
+    BugInvalidIdentityRequest(Identity, Arc<WorkloadInfo>),
 }
 
 impl From<tls::Error> for Error {
diff --git a/src/identity/manager.rs b/src/identity/manager.rs
index 7701e4387..3a35d3636 100644
--- a/src/identity/manager.rs
+++ b/src/identity/manager.rs
@@ -102,6 +102,13 @@ impl fmt::Display for Identity {
 }
 
 impl Identity {
+    pub fn from_parts(td: Strng, ns: Strng, sa: Strng) -> Identity {
+        Identity::Spiffe {
+            trust_domain: td,
+            namespace: ns,
+            service_account: sa,
+        }
+    }
     pub fn to_strng(self: &Identity) -> Strng {
         match self {
             Identity::Spiffe {
diff --git a/src/proxy.rs b/src/proxy.rs
index ac83d3cf1..12a45e34e 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -104,10 +104,55 @@ pub struct Proxy {
     policy_watcher: PolicyWatcher,
 }
 
+/// ScopedSecretManager provides an extra check against certificate lookups to ensure only appropriate certificates
+/// are requested for a given workload.
+/// This acts as a second line of defense against *coding errors* that could cause incorrect identity assignment.
+#[derive(Clone)]
+pub struct ScopedSecretManager {
+    cert_manager: Arc<SecretManager>,
+    allowed: Option<Arc<WorkloadInfo>>,
+}
+
+impl ScopedSecretManager {
+    #[cfg(any(test, feature = "testing"))]
+    pub fn new(cert_manager: Arc<SecretManager>) -> Self {
+        Self {
+            cert_manager,
+            allowed: None,
+        }
+    }
+
+    pub async fn fetch_certificate(
+        &self,
+        id: &Identity,
+    ) -> Result<Arc<tls::WorkloadCertificate>, identity::Error> {
+        if let Some(allowed) = &self.allowed {
+            match &id {
+                Identity::Spiffe {
+                    namespace,
+                    service_account,
+                    ..
+                } => {
+                    // We cannot compare trust domain, since we don't get this from WorkloadInfo
+                    if namespace != &allowed.namespace
+                        || service_account != &allowed.service_account
+                    {
+                        let err =
+                            identity::Error::BugInvalidIdentityRequest(id.clone(), allowed.clone());
+                        debug_assert!(false, "{err}");
+                        return Err(err);
+                    }
+                }
+            }
+        }
+        self.cert_manager.fetch_certificate(id).await
+    }
+}
+
 #[derive(Clone)]
 pub(super) struct ProxyInputs {
     cfg: Arc<config::Config>,
-    cert_manager: Arc<SecretManager>,
+    cert_manager: ScopedSecretManager,
     connection_manager: ConnectionManager,
     pub state: DemandProxyState,
     metrics: Arc<Metrics>,
@@ -128,14 +173,18 @@ impl ProxyInputs {
         proxy_workload_info: Option<WorkloadInfo>,
         resolver: Option<Arc<dyn Resolver + Send + Sync>>,
     ) -> Arc<Self> {
+        let proxy_workload_info = proxy_workload_info.map(Arc::new);
         Arc::new(Self {
             cfg,
             state,
-            cert_manager,
+            cert_manager: ScopedSecretManager {
+                cert_manager,
+                allowed: proxy_workload_info.clone(),
+            },
             metrics,
             connection_manager,
             socket_factory,
-            proxy_workload_info: proxy_workload_info.map(Arc::new),
+            proxy_workload_info,
             resolver,
         })
     }
diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs
index c68926094..2bcef5a2f 100644
--- a/src/proxy/inbound.rs
+++ b/src/proxy/inbound.rs
@@ -27,9 +27,9 @@ use tokio::net::TcpStream;
 
 use tracing::{debug, info, instrument, trace_span, Instrument};
 
-use super::Error;
+use super::{Error, ScopedSecretManager};
 use crate::baggage::parse_baggage_header;
-use crate::identity::{Identity, SecretManager};
+use crate::identity::Identity;
 
 use crate::proxy::h2::server::H2Request;
 use crate::proxy::metrics::{ConnectionOpen, Reporter};
@@ -447,7 +447,7 @@ impl<'a, T: Display> Display for OptionDisplay<'a, T> {
 
 #[derive(Clone)]
 struct InboundCertProvider {
-    cert_manager: Arc<SecretManager>,
+    cert_manager: ScopedSecretManager,
     state: DemandProxyState,
     network: Strng,
 }
diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs
index 9de94efff..b36c639b0 100644
--- a/src/proxy/outbound.rs
+++ b/src/proxy/outbound.rs
@@ -572,11 +572,13 @@ mod tests {
         };
 
         let sock_fact = std::sync::Arc::new(crate::proxy::DefaultSocketFactory);
-        let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10));
+        let cert_mgr = proxy::ScopedSecretManager::new(identity::mock::new_secret_manager(
+            Duration::from_secs(10),
+        ));
         let original_src = false; // for testing, not needed
         let outbound = OutboundConnection {
             pi: Arc::new(ProxyInputs {
-                cert_manager: identity::mock::new_secret_manager(Duration::from_secs(10)),
+                cert_manager: cert_mgr.clone(),
                 state,
                 cfg: cfg.clone(),
                 metrics: test_proxy_metrics(),
diff --git a/src/proxy/pool.rs b/src/proxy/pool.rs
index fdf0968cb..45a230230 100644
--- a/src/proxy/pool.rs
+++ b/src/proxy/pool.rs
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #![warn(clippy::cast_lossless)]
-use super::h2;
+use super::{h2, ScopedSecretManager};
 use super::{Error, SocketFactory};
 use std::time::Duration;
 
@@ -35,7 +35,7 @@ use tokio::sync::Mutex;
 use tracing::{debug, trace};
 
 use crate::config;
-use crate::identity::{Identity, SecretManager};
+use crate::identity::Identity;
 
 use flurry;
 
@@ -78,7 +78,7 @@ struct ConnSpawner {
     cfg: Arc<config::Config>,
     original_source: bool,
     socket_factory: Arc<dyn SocketFactory + Send + Sync>,
-    cert_manager: Arc<SecretManager>,
+    cert_manager: ScopedSecretManager,
     timeout_rx: watch::Receiver<bool>,
 }
 
@@ -337,7 +337,7 @@ impl WorkloadHBONEPool {
         cfg: Arc<crate::config::Config>,
         original_source: bool,
         socket_factory: Arc<dyn SocketFactory + Send + Sync>,
-        cert_manager: Arc<SecretManager>,
+        cert_manager: ScopedSecretManager,
     ) -> WorkloadHBONEPool {
         let (timeout_tx, timeout_rx) = watch::channel(false);
         let (timeout_send, timeout_recv) = watch::channel(false);
@@ -561,7 +561,7 @@ mod test {
     use std::net::SocketAddr;
     use std::time::Instant;
 
-    use crate::identity;
+    use crate::{identity, proxy};
 
     use drain::Watch;
     use futures_util::{future, StreamExt};
@@ -991,7 +991,9 @@ mod test {
             ..crate::config::parse_config().unwrap()
         };
         let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory);
-        let cert_mgr = identity::mock::new_secret_manager(Duration::from_secs(10));
+        let cert_mgr = proxy::ScopedSecretManager::new(identity::mock::new_secret_manager(
+            Duration::from_secs(10),
+        ));
         let original_src = false; // for testing, not needed
         let pool = WorkloadHBONEPool::new(Arc::new(cfg), original_src, sock_fact, cert_mgr);
         let server = TestServer {