diff --git a/Cargo.toml b/Cargo.toml index 5538b29..3f29084 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,45 +10,45 @@ build = "build.rs" atomic = "0.6.0" base64 = "0.22.1" bcrypt = "0.15.1" -bigdecimal = "0.4.3" -bitflags = { version = "2.5.0", features = ["serde"] } -chrono = { version = "0.4.38", features = ["serde"] } +bigdecimal = "0.4.7" +bitflags = { version = "2.7.0", features = ["serde"] } +chrono = { version = "0.4.39", features = ["serde"] } dotenv = "0.15.0" -futures = "0.3.30" +futures = "0.3.31" hostname = "0.4.0" jsonwebtoken = "9.3.0" -lazy_static = "1.4.0" -log = "0.4.21" +lazy_static = "1.5.0" +log = "0.4.25" log4rs = { version = "1.3.0", features = [ "rolling_file_appender", "compound_policy", "size_trigger", "gzip", ] } -num-bigint = "0.4.5" +num-bigint = "0.4.6" num-traits = "0.2.19" -openssl = "0.10.64" -poem = "3.0.1" -utoipa = { version = "5.0.0-alpha.0", features = [] } +openssl = "0.10.68" +poem = "3.1.6" +utoipa = { version = "5.3.1", features = [] } rand = "0.8.5" -regex = "1.10.4" -reqwest = { version = "0.12.5", default-features = false, features = [ +regex = "1.11.1" +reqwest = { version = "0.12.12", default-features = false, features = [ "http2", "macos-system-configuration", "charset", "rustls-tls-webpki-roots", ] } -serde = { version = "1.0.203", features = ["derive"] } -serde_json = { version = "1.0.117", features = ["raw_value"] } -sqlx = { version = "0.8.2", features = [ +serde = { version = "1.0.217", features = ["derive"] } +serde_json = { version = "1.0.135", features = ["raw_value"] } +sqlx = { version = "0.8.3", features = [ "json", "chrono", "ipnetwork", "runtime-tokio-rustls", "any", ] } -thiserror = "1.0.61" -tokio = { version = "1.38.0", features = ["full"] } +thiserror = "1.0.69" +tokio = { version = "1.43.0", features = ["full"] } sentry = { version = "0.34.0", default-features = false, features = [ "backtrace", "contexts", @@ -57,7 +57,7 @@ sentry = { version = "0.34.0", default-features = false, features = [ "reqwest", "rustls", ] } -clap = { version = "4.5.4", features = ["derive"] } +clap = { version = "4.5.26", features = ["derive"] } chorus = { features = [ "backend", ], default-features = false, version = "0.18.0" } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 4e00e91..a77396d 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -186,3 +186,89 @@ async fn purge_expired_disconnects(connected_users: ConnectedUsers) { } } } + +/// Tells every user-/client specific tokio task spawned by the symfonia binary to yield so that the +/// server may shut down in an orderly fashion. +/// +/// ## #\[allow(clippy::await_holding_lock)] +/// +/// We hold the lock of `inner.write()` across an await point `user_mutex.lock().await`. The lock +/// is held across the await point right before shutting down the application. +/// It should be okay to do this. The inevitable shutdown of the application should guarantee no contention. +/// +/// TODO: This is currently unused. We cannot use this, as the future created by this function is not +/// `Send`. This is because of the whole "holding mutex across await" thing. We need to find a better +/// solution for this. +#[allow(clippy::await_holding_lock)] +pub async fn tokio_task_killer(connected_users: ConnectedUsers) { + exit_signal_detected().await; + log::debug!("Exit signal detected!"); + let inner = connected_users.inner(); + let mut users = &mut inner.write().users; + for (_, mut user_mutex) in users.iter() { + let mut user = user_mutex.lock().await; + user.kill().await; + } +} + +/// Detects when an exit signal is sent by the operating system. The future will complete when an +/// exit signal is detected. +async fn exit_signal_detected() { + #[cfg(all(unix, windows))] + { + panic!("Unsupported platform; How did you get here?"); + } + + #[cfg(unix)] + { + // All these signals should shut down an application on UNIX-like systems + use tokio::signal::unix::{signal, SignalKind}; + let mut sig_alarm = signal(SignalKind::alarm()).unwrap(); + let mut sig_hangup = signal(SignalKind::hangup()).unwrap(); + let mut sig_interrupt = signal(SignalKind::interrupt()).unwrap(); + let mut sig_pipe = signal(SignalKind::interrupt()).unwrap(); + let mut sig_quit = signal(SignalKind::quit()).unwrap(); + let mut sig_terminate = signal(SignalKind::terminate()).unwrap(); + let mut sig_user_defined1 = signal(SignalKind::user_defined1()).unwrap(); + let mut sig_user_defined2 = signal(SignalKind::user_defined2()).unwrap(); + let ctrl_c = tokio::signal::ctrl_c(); + + tokio::select! { + // If we receive any of these signals, yield + _ = sig_alarm.recv() => (), + _ = sig_hangup.recv() => (), + _ = sig_interrupt.recv() => (), + _ = sig_pipe.recv() => (), + _ = sig_quit.recv() => (), + _ = sig_terminate.recv() => (), + _ = sig_user_defined1.recv() => (), + _ = sig_user_defined2.recv() => (), + event = ctrl_c => event.expect("Failed to listen to CTRL-c event"), + } + } + + #[cfg(windows)] + { + // All these signals should shut down an application on Windows + use tokio::signal::windows::{ctrl_break, ctrl_close, ctrl_logoff, ctrl_shutdown}; + let mut sig_break = ctrl_break().unwrap(); + let ctrl_c = tokio::signal::ctrl_c(); + let mut sig_close = ctrl_close().unwrap(); + let mut sig_logoff = ctrl_logoff().unwrap(); + let mut sig_shutdown = ctrl_shutdown().unwrap(); + + tokio::select! { + // If we receive any of these signals, yield + _ = sig_break.recv() => (), + event = ctrl_c => event.expect("Failed to listen to CTRL-c event"), + _ = sig_close.recv() => (), + _ = sig_logoff.recv() => (), + _ = sig_shutdown.recv() => (), + } + } + + #[cfg(not(any(unix, windows)))] + { + panic!("Unsupported platform"); + } +} diff --git a/src/gateway/types/mod.rs b/src/gateway/types/mod.rs index d2dbfb6..0b14df8 100644 --- a/src/gateway/types/mod.rs +++ b/src/gateway/types/mod.rs @@ -155,6 +155,16 @@ pub struct GatewayUser { connected_users: ConnectedUsers, } +impl GatewayUser { + /// Kills a user by ending all of their clients' sessions. + pub async fn kill(&mut self) { + for (_, client_mutex) in self.clients.iter() { + let mut client = client_mutex.lock().await; + client.die(self.connected_users.clone()).await + } + } +} + /// A concrete session, that a [GatewayUser] is connected to the Gateway with. pub struct GatewayClient { connection: WebSocketConnection, @@ -338,7 +348,7 @@ impl Eq for GatewayUser {} impl GatewayClient { /// Disconnects a [GatewayClient] properly, including un-registering it from the memory store /// and creating a resumeable session. - pub async fn die(mut self, connected_users: ConnectedUsers) { + pub async fn die(&mut self, connected_users: ConnectedUsers) { self.connection.kill_send.send(()).unwrap(); let disconnect_info = DisconnectInfo { session_token: self.session_token.clone(), diff --git a/src/main.rs b/src/main.rs index 17e440c..f572430 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use std::{ use chorus::types::Snowflake; use clap::Parser; +use sqlx::PgPool; use crate::configuration::SymfoniaConfiguration; use gateway::{ConnectedUsers, Event}; @@ -35,7 +36,7 @@ use log4rs::{ use logo::print_logo; use parking_lot::RwLock; use pubserve::Publisher; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, OnceCell}; mod api; mod cdn; @@ -66,6 +67,8 @@ pub fn eq_shared_event_publisher(a: &SharedEventPublisher, b: &SharedEventPublis /// The maximum number of rows that can be returned in most queries static QUERY_UPPER_LIMIT: i32 = 10000; +static DATABASE: OnceCell = OnceCell::const_new(); + #[derive(Debug)] struct LogFilter; @@ -200,11 +203,15 @@ async fn main() { }; log::info!(target: "symfonia::db", "Establishing database connection"); - let db = database::establish_connection() - .await - .expect("Failed to establish database connection"); + let db = DATABASE + .get_or_init(|| async { + database::establish_connection() + .await + .expect("Could not establish a connection to the database") + }) + .await; - if database::check_migrating_from_spacebar(&db) + if database::check_migrating_from_spacebar(db) .await .expect("Failed to check migrating from spacebar") { @@ -213,40 +220,40 @@ async fn main() { std::process::exit(0); } else { log::warn!(target: "symfonia::db", "Migrating from spacebar to symfonia"); - database::delete_spacebar_migrations(&db) + database::delete_spacebar_migrations(db) .await .expect("Failed to delete spacebar migrations table"); log::info!(target: "symfonia::db", "Running migrations"); sqlx::migrate!("./spacebar-migrations") - .run(&db) + .run(db) .await .expect("Failed to run migrations"); } } else { sqlx::migrate!() - .run(&db) + .run(db) .await .expect("Failed to run migrations"); } - if database::check_fresh_db(&db) + if database::check_fresh_db(db) .await .expect("Failed to check fresh db") { log::info!(target: "symfonia::db", "Fresh database detected. Seeding database with config data"); - database::seed_config(&db) + database::seed_config(db) .await .expect("Failed to seed config"); } - let symfonia_config = crate::database::entities::Config::init(&db) + let symfonia_config = crate::database::entities::Config::init(db) .await .unwrap_or_default(); let connected_users = ConnectedUsers::default(); log::debug!(target: "symfonia", "Initializing Role->User map..."); connected_users - .init_role_user_map(&db) + .init_role_user_map(db) .await .expect("Failed to init role user map"); log::trace!(target: "symfonia", "Role->User map initialized with {} entries", connected_users.role_user_map.lock().await.len());