Skip to content

Commit

Permalink
add proxy unit tests for retry connections (neondatabase#4721)
Browse files Browse the repository at this point in the history
Given now we've refactored `connect_to_compute` as a generic, we can
test it with mock backends. In this PR, we mock the error API and
connect_once API to test the retry behavior of `connect_to_compute`. In
the next PR, I'll add mock for credentials so that we can also test
behavior with `wake_compute`.

ref neondatabase#4709

---------

Signed-off-by: Alex Chi Z <[email protected]>
  • Loading branch information
skyzh committed Jul 24, 2023
1 parent e5b7ddf commit 407a20c
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
8 changes: 8 additions & 0 deletions proxy/src/auth/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ impl ClientCredentials<'_> {
}

impl<'a> ClientCredentials<'a> {
#[cfg(test)]
pub fn new_noop() -> Self {
ClientCredentials {
user: "",
project: None,
}
}

pub fn parse(
params: &'a StartupMessageParams,
sni: Option<&str>,
Expand Down
149 changes: 149 additions & 0 deletions proxy/src/proxy/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! A group of high-level tests for connection establishing logic and auth.
use std::borrow::Cow;

use super::*;
use crate::auth::ClientCredentials;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::{auth, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
Expand Down Expand Up @@ -304,3 +308,148 @@ fn connect_compute_total_wait() {
assert!(total_wait < tokio::time::Duration::from_secs(12));
assert!(total_wait > tokio::time::Duration::from_secs(10));
}

#[derive(Clone, Copy)]
enum ConnectAction {
Connect,
Retry,
Fail,
}

struct TestConnectMechanism {
counter: Arc<std::sync::Mutex<usize>>,
sequence: Vec<ConnectAction>,
}

impl TestConnectMechanism {
fn new(sequence: Vec<ConnectAction>) -> Self {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
}
}
}

#[derive(Debug)]
struct TestConnection;

#[derive(Debug)]
struct TestConnectError {
retryable: bool,
}

impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}

impl std::error::Error for TestConnectError {}

impl ShouldRetry for TestConnectError {
fn could_retry(&self) -> bool {
self.retryable
}
}

#[async_trait]
impl ConnectMechanism for TestConnectMechanism {
type Connection = TestConnection;
type ConnectError = TestConnectError;
type Error = anyhow::Error;

async fn connect_once(
&self,
_node_info: &console::CachedNodeInfo,
_timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
}
}

fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
}

fn helper_create_connect_info() -> (
CachedNodeInfo,
console::ConsoleReqExtra<'static>,
auth::BackendType<'static, ClientCredentials<'static>>,
) {
let node = NodeInfo {
config: compute::ConnCfg::new(),
aux: Default::default(),
allow_self_signed_compute: false,
};
let cache = CachedNodeInfo::new_uncached(node);
let extra = console::ConsoleReqExtra {
session_id: uuid::Uuid::new_v4(),
application_name: Some("TEST"),
};
let url = "https://TEST_URL".parse().unwrap();
let api = console::provider::mock::Api::new(url);
let creds = auth::BackendType::Postgres(Cow::Owned(api), ClientCredentials::new_noop());
(cache, extra, creds)
}

#[tokio::test]
async fn connect_to_compute_success() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, extra, creds) = helper_create_connect_info();
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
}

#[tokio::test]
async fn connect_to_compute_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info();
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
}

/// Test that we don't retry if the error is not retryable.
#[tokio::test]
async fn connect_to_compute_non_retry_1() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Retry, Fail]);
let (cache, extra, creds) = helper_create_connect_info();
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap_err();
}

/// Even for non-retryable errors, we should retry at least once.
#[tokio::test]
async fn connect_to_compute_non_retry_2() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Fail, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info();
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
}

/// Retry for at most `NUM_RETRIES_CONNECT` times.
#[tokio::test]
async fn connect_to_compute_non_retry_3() {
assert_eq!(NUM_RETRIES_CONNECT, 10);
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![
Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
/* the 11th time */ Retry,
]);
let (cache, extra, creds) = helper_create_connect_info();
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap_err();
}

0 comments on commit 407a20c

Please sign in to comment.