diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 6787d82b716c..7dc304d7ac7f 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -48,6 +48,14 @@ impl ClientCredentials<'_> { } impl<'a> ClientCredentials<'a> { + #[cfg(test)] + pub fn new_noop() -> Self { + ClientCredentials { + user: "", + project: None, + } + } + pub fn parse( params: &'a StartupMessageParams, sni: Option<&str>, diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index b9215cd90e0e..565f86eecc9e 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -1,5 +1,9 @@ //! A group of high-level tests for connection establishing logic and auth. +use std::borrow::Cow; + use super::*; +use crate::auth::ClientCredentials; +use crate::console::{CachedNodeInfo, NodeInfo}; use crate::{auth, sasl, scram}; use async_trait::async_trait; use rstest::rstest; @@ -304,3 +308,148 @@ fn connect_compute_total_wait() { assert!(total_wait < tokio::time::Duration::from_secs(12)); assert!(total_wait > tokio::time::Duration::from_secs(10)); } + +#[derive(Clone, Copy)] +enum ConnectAction { + Connect, + Retry, + Fail, +} + +struct TestConnectMechanism { + counter: Arc>, + sequence: Vec, +} + +impl TestConnectMechanism { + fn new(sequence: Vec) -> Self { + Self { + counter: Arc::new(std::sync::Mutex::new(0)), + sequence, + } + } +} + +#[derive(Debug)] +struct TestConnection; + +#[derive(Debug)] +struct TestConnectError { + retryable: bool, +} + +impl std::fmt::Display for TestConnectError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for TestConnectError {} + +impl ShouldRetry for TestConnectError { + fn could_retry(&self) -> bool { + self.retryable + } +} + +#[async_trait] +impl ConnectMechanism for TestConnectMechanism { + type Connection = TestConnection; + type ConnectError = TestConnectError; + type Error = anyhow::Error; + + async fn connect_once( + &self, + _node_info: &console::CachedNodeInfo, + _timeout: time::Duration, + ) -> Result { + let mut counter = self.counter.lock().unwrap(); + let action = self.sequence[*counter]; + *counter += 1; + match action { + ConnectAction::Connect => Ok(TestConnection), + ConnectAction::Retry => Err(TestConnectError { retryable: true }), + ConnectAction::Fail => Err(TestConnectError { retryable: false }), + } + } + + fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {} +} + +fn helper_create_connect_info() -> ( + CachedNodeInfo, + console::ConsoleReqExtra<'static>, + auth::BackendType<'static, ClientCredentials<'static>>, +) { + let node = NodeInfo { + config: compute::ConnCfg::new(), + aux: Default::default(), + allow_self_signed_compute: false, + }; + let cache = CachedNodeInfo::new_uncached(node); + let extra = console::ConsoleReqExtra { + session_id: uuid::Uuid::new_v4(), + application_name: Some("TEST"), + }; + let url = "https://TEST_URL".parse().unwrap(); + let api = console::provider::mock::Api::new(url); + let creds = auth::BackendType::Postgres(Cow::Owned(api), ClientCredentials::new_noop()); + (cache, extra, creds) +} + +#[tokio::test] +async fn connect_to_compute_success() { + use ConnectAction::*; + let mechanism = TestConnectMechanism::new(vec![Connect]); + let (cache, extra, creds) = helper_create_connect_info(); + connect_to_compute(&mechanism, cache, &extra, &creds) + .await + .unwrap(); +} + +#[tokio::test] +async fn connect_to_compute_retry() { + use ConnectAction::*; + let mechanism = TestConnectMechanism::new(vec![Retry, Retry, Connect]); + let (cache, extra, creds) = helper_create_connect_info(); + connect_to_compute(&mechanism, cache, &extra, &creds) + .await + .unwrap(); +} + +/// Test that we don't retry if the error is not retryable. +#[tokio::test] +async fn connect_to_compute_non_retry_1() { + use ConnectAction::*; + let mechanism = TestConnectMechanism::new(vec![Retry, Retry, Fail]); + let (cache, extra, creds) = helper_create_connect_info(); + connect_to_compute(&mechanism, cache, &extra, &creds) + .await + .unwrap_err(); +} + +/// Even for non-retryable errors, we should retry at least once. +#[tokio::test] +async fn connect_to_compute_non_retry_2() { + use ConnectAction::*; + let mechanism = TestConnectMechanism::new(vec![Fail, Retry, Connect]); + let (cache, extra, creds) = helper_create_connect_info(); + connect_to_compute(&mechanism, cache, &extra, &creds) + .await + .unwrap(); +} + +/// Retry for at most `NUM_RETRIES_CONNECT` times. +#[tokio::test] +async fn connect_to_compute_non_retry_3() { + assert_eq!(NUM_RETRIES_CONNECT, 10); + use ConnectAction::*; + let mechanism = TestConnectMechanism::new(vec![ + Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, + /* the 11th time */ Retry, + ]); + let (cache, extra, creds) = helper_create_connect_info(); + connect_to_compute(&mechanism, cache, &extra, &creds) + .await + .unwrap_err(); +}