Skip to content

Commit

Permalink
Make the WebSocket timeouts configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
shepmaster committed Nov 13, 2024
1 parent c79a2c4 commit d6b4d75
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
30 changes: 30 additions & 0 deletions ui/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ use std::{
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter;

const DEFAULT_ADDRESS: &str = "127.0.0.1";
const DEFAULT_PORT: u16 = 5000;

const DEFAULT_WEBSOCKET_IDLE_TIMEOUT: Duration = Duration::from_secs(60);
const DEFAULT_WEBSOCKET_SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60);

const DEFAULT_COORDINATORS_LIMIT: usize = 25;
const DEFAULT_PROCESSES_LIMIT: usize = 10;

Expand Down Expand Up @@ -50,6 +55,7 @@ struct Config {
metrics_token: Option<String>,
feature_flags: FeatureFlags,
request_db_path: Option<PathBuf>,
websocket_config: WebSocketConfig,
limits: Arc<dyn ResourceLimits>,
port: u16,
root: PathBuf,
Expand Down Expand Up @@ -108,6 +114,23 @@ impl Config {

let request_db_path = env::var_os("PLAYGROUND_REQUEST_DATABASE").map(Into::into);

let websocket_config = {
let idle_timeout = env::var("PLAYGROUND_WEBSOCKET_IDLE_TIMEOUT_S")
.ok()
.and_then(|l| l.parse().map(Duration::from_secs).ok())
.unwrap_or(DEFAULT_WEBSOCKET_IDLE_TIMEOUT);

let session_timeout = env::var("PLAYGROUND_WEBSOCKET_SESSION_TIMEOUT_S")
.ok()
.and_then(|l| l.parse().map(Duration::from_secs).ok())
.unwrap_or(DEFAULT_WEBSOCKET_SESSION_TIMEOUT);

WebSocketConfig {
idle_timeout,
session_timeout,
}
};

let coordinators_limit = env::var("PLAYGROUND_COORDINATORS_LIMIT")
.ok()
.and_then(|l| l.parse().ok())
Expand All @@ -131,6 +154,7 @@ impl Config {
metrics_token,
feature_flags,
request_db_path,
websocket_config,
limits,
port,
root,
Expand Down Expand Up @@ -232,3 +256,9 @@ impl limits::Lifecycle for LifecycleMetrics {
metrics::PROCESS_ACTIVE.dec();
}
}

#[derive(Debug, Copy, Clone)]
struct WebSocketConfig {
idle_timeout: Duration,
session_timeout: Duration,
}
8 changes: 5 additions & 3 deletions ui/src/server_axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
UNAVAILABLE_WS,
},
request_database::Handle,
Config, GhToken, MetricsToken,
Config, GhToken, MetricsToken, WebSocketConfig,
};
use async_trait::async_trait;
use axum::{
Expand Down Expand Up @@ -111,7 +111,8 @@ pub(crate) async fn serve(config: Config) {
.layer(Extension(db_handle))
.layer(Extension(Arc::new(SandboxCache::default())))
.layer(Extension(config.github_token()))
.layer(Extension(config.feature_flags));
.layer(Extension(config.feature_flags))
.layer(Extension(config.websocket_config));

if let Some(token) = config.metrics_token() {
app = app.layer(Extension(token))
Expand Down Expand Up @@ -652,11 +653,12 @@ async fn metrics(_: MetricsAuthorization) -> Result<Vec<u8>, StatusCode> {

async fn websocket(
ws: WebSocketUpgrade,
Extension(config): Extension<WebSocketConfig>,
Extension(factory): Extension<Factory>,
Extension(feature_flags): Extension<crate::FeatureFlags>,
Extension(db): Extension<Handle>,
) -> impl IntoResponse {
ws.on_upgrade(move |s| websocket::handle(s, factory.0, feature_flags.into(), db))
ws.on_upgrade(move |s| websocket::handle(s, config, factory.0, feature_flags.into(), db))
}

#[derive(Debug, serde::Deserialize)]
Expand Down
12 changes: 6 additions & 6 deletions ui/src/server_axum/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
metrics::{self, record_metric, Endpoint, HasLabelsCore, Outcome},
request_database::Handle,
server_axum::api_orchestrator_integration_impls::*,
WebSocketConfig,
};

use axum::extract::ws::{Message, WebSocket};
Expand Down Expand Up @@ -199,6 +200,7 @@ struct ExecuteResponse {
#[instrument(skip_all, fields(ws_id))]
pub(crate) async fn handle(
socket: WebSocket,
config: WebSocketConfig,
factory: Arc<CoordinatorFactory>,
feature_flags: FeatureFlags,
db: Handle,
Expand All @@ -212,7 +214,7 @@ pub(crate) async fn handle(
tracing::Span::current().record("ws_id", &id);
info!("WebSocket started");

handle_core(socket, factory, feature_flags, db).await;
handle_core(socket, config, factory, feature_flags, db).await;

info!("WebSocket ending");
metrics::LIVE_WS.dec();
Expand Down Expand Up @@ -242,9 +244,6 @@ struct CoordinatorManager {
}

impl CoordinatorManager {
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
const SESSION_TIMEOUT: Duration = Duration::from_secs(45 * 60);

const N_PARALLEL: usize = 2;

const N_KINDS: usize = 1;
Expand Down Expand Up @@ -343,6 +342,7 @@ type CoordinatorManagerResult<T, E = CoordinatorManagerError> = std::result::Res

async fn handle_core(
mut socket: WebSocket,
config: WebSocketConfig,
factory: Arc<CoordinatorFactory>,
feature_flags: FeatureFlags,
db: Handle,
Expand All @@ -363,7 +363,7 @@ async fn handle_core(
}

let mut manager = CoordinatorManager::new(&factory);
let mut session_timeout = pin!(time::sleep(CoordinatorManager::SESSION_TIMEOUT));
let mut session_timeout = pin!(time::sleep(config.session_timeout));
let mut idle_timeout = pin!(Fuse::terminated());

let mut active_executions = BTreeMap::new();
Expand Down Expand Up @@ -409,7 +409,7 @@ async fn handle_core(
// The last task has completed which means we are a
// candidate for idling in a little while.
if manager.is_empty() {
idle_timeout.set(time::sleep(CoordinatorManager::IDLE_TIMEOUT).fuse());
idle_timeout.set(time::sleep(config.idle_timeout).fuse());
}

let (error, meta) = match task {
Expand Down

0 comments on commit d6b4d75

Please sign in to comment.