From d6b4d7520e53308e7f5bfb3a2bb0a66ae408f991 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Wed, 13 Nov 2024 11:19:33 -0500 Subject: [PATCH] Make the WebSocket timeouts configurable --- ui/src/main.rs | 30 ++++++++++++++++++++++++++++++ ui/src/server_axum.rs | 8 +++++--- ui/src/server_axum/websocket.rs | 12 ++++++------ 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/ui/src/main.rs b/ui/src/main.rs index 8decaaad..42eb0b05 100644 --- a/ui/src/main.rs +++ b/ui/src/main.rs @@ -8,12 +8,17 @@ use std::{ net::SocketAddr, path::{Path, PathBuf}, sync::Arc, + time::Duration, }; use tracing::{error, info, warn}; use tracing_subscriber::EnvFilter; const DEFAULT_ADDRESS: &str = "127.0.0.1"; const DEFAULT_PORT: u16 = 5000; + +const DEFAULT_WEBSOCKET_IDLE_TIMEOUT: Duration = Duration::from_secs(60); +const DEFAULT_WEBSOCKET_SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60); + const DEFAULT_COORDINATORS_LIMIT: usize = 25; const DEFAULT_PROCESSES_LIMIT: usize = 10; @@ -50,6 +55,7 @@ struct Config { metrics_token: Option, feature_flags: FeatureFlags, request_db_path: Option, + websocket_config: WebSocketConfig, limits: Arc, port: u16, root: PathBuf, @@ -108,6 +114,23 @@ impl Config { let request_db_path = env::var_os("PLAYGROUND_REQUEST_DATABASE").map(Into::into); + let websocket_config = { + let idle_timeout = env::var("PLAYGROUND_WEBSOCKET_IDLE_TIMEOUT_S") + .ok() + .and_then(|l| l.parse().map(Duration::from_secs).ok()) + .unwrap_or(DEFAULT_WEBSOCKET_IDLE_TIMEOUT); + + let session_timeout = env::var("PLAYGROUND_WEBSOCKET_SESSION_TIMEOUT_S") + .ok() + .and_then(|l| l.parse().map(Duration::from_secs).ok()) + .unwrap_or(DEFAULT_WEBSOCKET_SESSION_TIMEOUT); + + WebSocketConfig { + idle_timeout, + session_timeout, + } + }; + let coordinators_limit = env::var("PLAYGROUND_COORDINATORS_LIMIT") .ok() .and_then(|l| l.parse().ok()) @@ -131,6 +154,7 @@ impl Config { metrics_token, feature_flags, request_db_path, + websocket_config, limits, port, root, @@ -232,3 +256,9 @@ impl limits::Lifecycle for LifecycleMetrics { metrics::PROCESS_ACTIVE.dec(); } } + +#[derive(Debug, Copy, Clone)] +struct WebSocketConfig { + idle_timeout: Duration, + session_timeout: Duration, +} diff --git a/ui/src/server_axum.rs b/ui/src/server_axum.rs index bbe2cad8..6d01422b 100644 --- a/ui/src/server_axum.rs +++ b/ui/src/server_axum.rs @@ -5,7 +5,7 @@ use crate::{ UNAVAILABLE_WS, }, request_database::Handle, - Config, GhToken, MetricsToken, + Config, GhToken, MetricsToken, WebSocketConfig, }; use async_trait::async_trait; use axum::{ @@ -111,7 +111,8 @@ pub(crate) async fn serve(config: Config) { .layer(Extension(db_handle)) .layer(Extension(Arc::new(SandboxCache::default()))) .layer(Extension(config.github_token())) - .layer(Extension(config.feature_flags)); + .layer(Extension(config.feature_flags)) + .layer(Extension(config.websocket_config)); if let Some(token) = config.metrics_token() { app = app.layer(Extension(token)) @@ -652,11 +653,12 @@ async fn metrics(_: MetricsAuthorization) -> Result, StatusCode> { async fn websocket( ws: WebSocketUpgrade, + Extension(config): Extension, Extension(factory): Extension, Extension(feature_flags): Extension, Extension(db): Extension, ) -> impl IntoResponse { - ws.on_upgrade(move |s| websocket::handle(s, factory.0, feature_flags.into(), db)) + ws.on_upgrade(move |s| websocket::handle(s, config, factory.0, feature_flags.into(), db)) } #[derive(Debug, serde::Deserialize)] diff --git a/ui/src/server_axum/websocket.rs b/ui/src/server_axum/websocket.rs index 31c40581..1895c8ac 100644 --- a/ui/src/server_axum/websocket.rs +++ b/ui/src/server_axum/websocket.rs @@ -2,6 +2,7 @@ use crate::{ metrics::{self, record_metric, Endpoint, HasLabelsCore, Outcome}, request_database::Handle, server_axum::api_orchestrator_integration_impls::*, + WebSocketConfig, }; use axum::extract::ws::{Message, WebSocket}; @@ -199,6 +200,7 @@ struct ExecuteResponse { #[instrument(skip_all, fields(ws_id))] pub(crate) async fn handle( socket: WebSocket, + config: WebSocketConfig, factory: Arc, feature_flags: FeatureFlags, db: Handle, @@ -212,7 +214,7 @@ pub(crate) async fn handle( tracing::Span::current().record("ws_id", &id); info!("WebSocket started"); - handle_core(socket, factory, feature_flags, db).await; + handle_core(socket, config, factory, feature_flags, db).await; info!("WebSocket ending"); metrics::LIVE_WS.dec(); @@ -242,9 +244,6 @@ struct CoordinatorManager { } impl CoordinatorManager { - const IDLE_TIMEOUT: Duration = Duration::from_secs(60); - const SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60); - const N_PARALLEL: usize = 2; const N_KINDS: usize = 1; @@ -343,6 +342,7 @@ type CoordinatorManagerResult = std::result::Res async fn handle_core( mut socket: WebSocket, + config: WebSocketConfig, factory: Arc, feature_flags: FeatureFlags, db: Handle, @@ -363,7 +363,7 @@ async fn handle_core( } let mut manager = CoordinatorManager::new(&factory); - let mut session_timeout = pin!(time::sleep(CoordinatorManager::SESSION_TIMEOUT)); + let mut session_timeout = pin!(time::sleep(config.session_timeout)); let mut idle_timeout = pin!(Fuse::terminated()); let mut active_executions = BTreeMap::new(); @@ -409,7 +409,7 @@ async fn handle_core( // The last task has completed which means we are a // candidate for idling in a little while. if manager.is_empty() { - idle_timeout.set(time::sleep(CoordinatorManager::IDLE_TIMEOUT).fuse()); + idle_timeout.set(time::sleep(config.idle_timeout).fuse()); } let (error, meta) = match task {