Skip to content

Commit

Permalink
Merge pull request #83 from polyphony-chat/static-ref-database
Browse files Browse the repository at this point in the history
static `DATABASE` and kill for GatewayUser
  • Loading branch information
bitfl0wer authored Jan 14, 2025
2 parents 737d77d + b168d2b commit 1bc3d8c
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 31 deletions.
36 changes: 18 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,45 @@ build = "build.rs"
atomic = "0.6.0"
base64 = "0.22.1"
bcrypt = "0.15.1"
bigdecimal = "0.4.3"
bitflags = { version = "2.5.0", features = ["serde"] }
chrono = { version = "0.4.38", features = ["serde"] }
bigdecimal = "0.4.7"
bitflags = { version = "2.7.0", features = ["serde"] }
chrono = { version = "0.4.39", features = ["serde"] }
dotenv = "0.15.0"
futures = "0.3.30"
futures = "0.3.31"
hostname = "0.4.0"
jsonwebtoken = "9.3.0"
lazy_static = "1.4.0"
log = "0.4.21"
lazy_static = "1.5.0"
log = "0.4.25"
log4rs = { version = "1.3.0", features = [
"rolling_file_appender",
"compound_policy",
"size_trigger",
"gzip",
] }
num-bigint = "0.4.5"
num-bigint = "0.4.6"
num-traits = "0.2.19"
openssl = "0.10.64"
poem = "3.0.1"
utoipa = { version = "5.0.0-alpha.0", features = [] }
openssl = "0.10.68"
poem = "3.1.6"
utoipa = { version = "5.3.1", features = [] }
rand = "0.8.5"
regex = "1.10.4"
reqwest = { version = "0.12.5", default-features = false, features = [
regex = "1.11.1"
reqwest = { version = "0.12.12", default-features = false, features = [
"http2",
"macos-system-configuration",
"charset",
"rustls-tls-webpki-roots",
] }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = { version = "1.0.117", features = ["raw_value"] }
sqlx = { version = "0.8.2", features = [
serde = { version = "1.0.217", features = ["derive"] }
serde_json = { version = "1.0.135", features = ["raw_value"] }
sqlx = { version = "0.8.3", features = [
"json",
"chrono",
"ipnetwork",
"runtime-tokio-rustls",
"any",
] }
thiserror = "1.0.61"
tokio = { version = "1.38.0", features = ["full"] }
thiserror = "1.0.69"
tokio = { version = "1.43.0", features = ["full"] }
sentry = { version = "0.34.0", default-features = false, features = [
"backtrace",
"contexts",
Expand All @@ -57,7 +57,7 @@ sentry = { version = "0.34.0", default-features = false, features = [
"reqwest",
"rustls",
] }
clap = { version = "4.5.4", features = ["derive"] }
clap = { version = "4.5.26", features = ["derive"] }
chorus = { features = [
"backend",
], default-features = false, version = "0.18.0" }
Expand Down
86 changes: 86 additions & 0 deletions src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,89 @@ async fn purge_expired_disconnects(connected_users: ConnectedUsers) {
}
}
}

/// Tells every user-/client specific tokio task spawned by the symfonia binary to yield so that the
/// server may shut down in an orderly fashion.
///
/// ## #\[allow(clippy::await_holding_lock)]
///
/// We hold the lock of `inner.write()` across an await point `user_mutex.lock().await`. The lock
/// is held across the await point right before shutting down the application.
/// It should be okay to do this. The inevitable shutdown of the application should guarantee no contention.
///
/// TODO: This is currently unused. We cannot use this, as the future created by this function is not
/// `Send`. This is because of the whole "holding mutex across await" thing. We need to find a better
/// solution for this.
#[allow(clippy::await_holding_lock)]
pub async fn tokio_task_killer(connected_users: ConnectedUsers) {
exit_signal_detected().await;
log::debug!("Exit signal detected!");
let inner = connected_users.inner();
let mut users = &mut inner.write().users;
for (_, mut user_mutex) in users.iter() {
let mut user = user_mutex.lock().await;
user.kill().await;
}
}

/// Detects when an exit signal is sent by the operating system. The future will complete when an
/// exit signal is detected.
async fn exit_signal_detected() {
#[cfg(all(unix, windows))]
{
panic!("Unsupported platform; How did you get here?");
}

#[cfg(unix)]
{
// All these signals should shut down an application on UNIX-like systems
use tokio::signal::unix::{signal, SignalKind};
let mut sig_alarm = signal(SignalKind::alarm()).unwrap();
let mut sig_hangup = signal(SignalKind::hangup()).unwrap();
let mut sig_interrupt = signal(SignalKind::interrupt()).unwrap();
let mut sig_pipe = signal(SignalKind::interrupt()).unwrap();
let mut sig_quit = signal(SignalKind::quit()).unwrap();
let mut sig_terminate = signal(SignalKind::terminate()).unwrap();
let mut sig_user_defined1 = signal(SignalKind::user_defined1()).unwrap();
let mut sig_user_defined2 = signal(SignalKind::user_defined2()).unwrap();
let ctrl_c = tokio::signal::ctrl_c();

tokio::select! {
// If we receive any of these signals, yield
_ = sig_alarm.recv() => (),
_ = sig_hangup.recv() => (),
_ = sig_interrupt.recv() => (),
_ = sig_pipe.recv() => (),
_ = sig_quit.recv() => (),
_ = sig_terminate.recv() => (),
_ = sig_user_defined1.recv() => (),
_ = sig_user_defined2.recv() => (),
event = ctrl_c => event.expect("Failed to listen to CTRL-c event"),
}
}

#[cfg(windows)]
{
// All these signals should shut down an application on Windows
use tokio::signal::windows::{ctrl_break, ctrl_close, ctrl_logoff, ctrl_shutdown};
let mut sig_break = ctrl_break().unwrap();
let ctrl_c = tokio::signal::ctrl_c();
let mut sig_close = ctrl_close().unwrap();
let mut sig_logoff = ctrl_logoff().unwrap();
let mut sig_shutdown = ctrl_shutdown().unwrap();

tokio::select! {
// If we receive any of these signals, yield
_ = sig_break.recv() => (),
event = ctrl_c => event.expect("Failed to listen to CTRL-c event"),
_ = sig_close.recv() => (),
_ = sig_logoff.recv() => (),
_ = sig_shutdown.recv() => (),
}
}

#[cfg(not(any(unix, windows)))]
{
panic!("Unsupported platform");
}
}
12 changes: 11 additions & 1 deletion src/gateway/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ pub struct GatewayUser {
connected_users: ConnectedUsers,
}

impl GatewayUser {
/// Kills a user by ending all of their clients' sessions.
pub async fn kill(&mut self) {
for (_, client_mutex) in self.clients.iter() {
let mut client = client_mutex.lock().await;
client.die(self.connected_users.clone()).await
}
}
}

/// A concrete session, that a [GatewayUser] is connected to the Gateway with.
pub struct GatewayClient {
connection: WebSocketConnection,
Expand Down Expand Up @@ -338,7 +348,7 @@ impl Eq for GatewayUser {}
impl GatewayClient {
/// Disconnects a [GatewayClient] properly, including un-registering it from the memory store
/// and creating a resumeable session.
pub async fn die(mut self, connected_users: ConnectedUsers) {
pub async fn die(&mut self, connected_users: ConnectedUsers) {
self.connection.kill_send.send(()).unwrap();
let disconnect_info = DisconnectInfo {
session_token: self.session_token.clone(),
Expand Down
31 changes: 19 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::{

use chorus::types::Snowflake;
use clap::Parser;
use sqlx::PgPool;

use crate::configuration::SymfoniaConfiguration;
use gateway::{ConnectedUsers, Event};
Expand All @@ -35,7 +36,7 @@ use log4rs::{
use logo::print_logo;
use parking_lot::RwLock;
use pubserve::Publisher;
use tokio::sync::Mutex;
use tokio::sync::{Mutex, OnceCell};

mod api;
mod cdn;
Expand Down Expand Up @@ -66,6 +67,8 @@ pub fn eq_shared_event_publisher(a: &SharedEventPublisher, b: &SharedEventPublis
/// The maximum number of rows that can be returned in most queries
static QUERY_UPPER_LIMIT: i32 = 10000;

static DATABASE: OnceCell<PgPool> = OnceCell::const_new();

#[derive(Debug)]
struct LogFilter;

Expand Down Expand Up @@ -200,11 +203,15 @@ async fn main() {
};

log::info!(target: "symfonia::db", "Establishing database connection");
let db = database::establish_connection()
.await
.expect("Failed to establish database connection");
let db = DATABASE
.get_or_init(|| async {
database::establish_connection()
.await
.expect("Could not establish a connection to the database")
})
.await;

if database::check_migrating_from_spacebar(&db)
if database::check_migrating_from_spacebar(db)
.await
.expect("Failed to check migrating from spacebar")
{
Expand All @@ -213,40 +220,40 @@ async fn main() {
std::process::exit(0);
} else {
log::warn!(target: "symfonia::db", "Migrating from spacebar to symfonia");
database::delete_spacebar_migrations(&db)
database::delete_spacebar_migrations(db)
.await
.expect("Failed to delete spacebar migrations table");
log::info!(target: "symfonia::db", "Running migrations");
sqlx::migrate!("./spacebar-migrations")
.run(&db)
.run(db)
.await
.expect("Failed to run migrations");
}
} else {
sqlx::migrate!()
.run(&db)
.run(db)
.await
.expect("Failed to run migrations");
}

if database::check_fresh_db(&db)
if database::check_fresh_db(db)
.await
.expect("Failed to check fresh db")
{
log::info!(target: "symfonia::db", "Fresh database detected. Seeding database with config data");
database::seed_config(&db)
database::seed_config(db)
.await
.expect("Failed to seed config");
}

let symfonia_config = crate::database::entities::Config::init(&db)
let symfonia_config = crate::database::entities::Config::init(db)
.await
.unwrap_or_default();

let connected_users = ConnectedUsers::default();
log::debug!(target: "symfonia", "Initializing Role->User map...");
connected_users
.init_role_user_map(&db)
.init_role_user_map(db)
.await
.expect("Failed to init role user map");
log::trace!(target: "symfonia", "Role->User map initialized with {} entries", connected_users.role_user_map.lock().await.len());
Expand Down

0 comments on commit 1bc3d8c

Please sign in to comment.