From 872e19c67782ca0e594c70d7b6faac0e882d6436 Mon Sep 17 00:00:00 2001 From: Michael Krasnitski <42564254+mkrasnitski@users.noreply.github.com> Date: Mon, 27 Jan 2025 11:29:49 -0600 Subject: [PATCH] Improve gateway connection/resume logic (#3099) This commit refactors how the gateway connection being closed gets handled, and also reworks how resuming is performed. If a resume fails, or if the session id is invalid/doesn't exist, the shard will fall back to restart + reidentify after a 1 second delay. This behavior was only present in some circumstances before. Also, cleaned up the loop in `ShardRunner::run` by adding a `ShardAction::Dispatch` variant, since event dispatch was already mutually exclusive to hearbeating, identifying, and restarting. The overall effect is less interleaving of control flow. Plus, removed the `Shard::{reconnect, reset}` functions as they were unused. A notable change is that 4006 is no longer considered a valid close code as it is undocumented, and neither is 1000, which tungstenite assigns as `Normal` or "clean". We should stick to the [table of close codes](https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes) provided by Discord. --- src/constants.rs | 88 ++++---- src/gateway/error.rs | 3 + src/gateway/sharding/mod.rs | 305 ++++++++++----------------- src/gateway/sharding/shard_runner.rs | 226 +++++++++----------- src/lib.rs | 2 +- 5 files changed, 261 insertions(+), 363 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 86a9d925cf6..a88dbb8baed 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -36,7 +36,7 @@ pub const USER_AGENT: &str = concat!( ); enum_number! { - /// An enum representing the [gateway opcodes]. + /// An enum representing the gateway opcodes. /// /// [Discord docs](https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes). #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Deserialize, Serialize)] @@ -70,53 +70,43 @@ enum_number! { } } -pub mod close_codes { - /// Unknown error; try reconnecting? - /// - /// Can reconnect. - pub const UNKNOWN_ERROR: u16 = 4000; - /// Invalid Gateway OP Code. - /// - /// Can resume. - pub const UNKNOWN_OPCODE: u16 = 4001; - /// An invalid payload was sent. - /// - /// Can resume. - pub const DECODE_ERROR: u16 = 4002; - /// A payload was sent prior to identifying. - /// - /// Cannot reconnect. - pub const NOT_AUTHENTICATED: u16 = 4003; - /// The account token sent with the identify payload was incorrect. - /// - /// Cannot reconnect. - pub const AUTHENTICATION_FAILED: u16 = 4004; - /// More than one identify payload was sent. - /// - /// Can reconnect. - pub const ALREADY_AUTHENTICATED: u16 = 4005; - /// The sequence sent when resuming the session was invalid. - /// - /// Can reconnect. - pub const INVALID_SEQUENCE: u16 = 4007; - /// Payloads were being sent too quickly. - /// - /// Can resume. - pub const RATE_LIMITED: u16 = 4008; - /// A session timed out. - /// - /// Can reconnect. - pub const SESSION_TIMEOUT: u16 = 4009; - /// An invalid shard when identifying was sent. - /// - /// Cannot reconnect. - pub const INVALID_SHARD: u16 = 4010; - /// The session would have handled too many guilds. +enum_number! { + /// An enum representing the gateway close codes. /// - /// Cannot reconnect. - pub const SHARDING_REQUIRED: u16 = 4011; - /// Undocumented gateway intents have been provided. - pub const INVALID_GATEWAY_INTENTS: u16 = 4013; - /// Disallowed gateway intents have been provided. - pub const DISALLOWED_GATEWAY_INTENTS: u16 = 4014; + /// [Discord docs](https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes) + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Deserialize, Serialize)] + #[non_exhaustive] + pub enum CloseCode { + /// Unknown error; try reconnecting. + UnknownError = 4000, + /// Invalid gateway opcode. + UnknownOpcode = 4001, + /// An invalid payload was sent. + DecodeError = 4002, + /// A payload was sent prior to identifying, or the session was invalidated. + NotAuthenticated = 4003, + /// The account token sent with the identify payload was incorrect. + AuthenticationFailed = 4004, + /// More than one identify payload was sent. + AlreadyAuthenticated = 4005, + /// The sequence sent when resuming the session was invalid. + InvalidSequence = 4007, + /// Payloads were being sent too quickly. + RateLimited = 4008, + /// The gateway session timed out, and a new one must be started. + SessionTimeout = 4009, + /// An invalid shard was sent when identifying. + InvalidShard = 4010, + /// The session would have handled too many guilds; you must use sharding to connect. + ShardingRequired = 4011, + /// An invalid gateway API version was sent. + InvalidApiVersion = 4012, + /// An invalid gateway intent was sent. + InvalidGatewayIntents = 4013, + /// A disallowed gateway intent was sent; you may have it disabled or may not be approved + /// to use it. + DisallowedGatewayIntents = 4014, + + _ => Unknown(u16), + } } diff --git a/src/gateway/error.rs b/src/gateway/error.rs index a3e0d02e5da..7bc0a3c28d0 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -20,6 +20,8 @@ pub enum Error { HeartbeatFailed, /// When invalid authentication (a bad token) was sent in the IDENTIFY. InvalidAuthentication, + /// When an invalid API version was sent to the gateway. + InvalidApiVersion, /// Expected a Ready or an InvalidateSession InvalidHandshake, /// When invalid sharding data was sent in the IDENTIFY. @@ -71,6 +73,7 @@ impl fmt::Display for Error { Self::ExpectedHello => f.write_str("Expected a Hello"), Self::HeartbeatFailed => f.write_str("Failed sending a heartbeat"), Self::InvalidAuthentication => f.write_str("Sent invalid authentication"), + Self::InvalidApiVersion => f.write_str("Sent invalid API version"), Self::InvalidHandshake => f.write_str("Expected a valid Handshake"), Self::InvalidShardData => f.write_str("Sent invalid shard data"), Self::NoAuthentication => f.write_str("Sent no authentication"), diff --git a/src/gateway/sharding/mod.rs b/src/gateway/sharding/mod.rs index 80451a274bf..a892a467c89 100644 --- a/src/gateway/sharding/mod.rs +++ b/src/gateway/sharding/mod.rs @@ -62,7 +62,7 @@ pub use self::shard_messenger::ShardMessenger; pub use self::shard_queuer::{ShardQueue, ShardQueuer, ShardQueuerMessage}; pub use self::shard_runner::{ShardRunner, ShardRunnerMessage, ShardRunnerOptions}; use super::{ActivityData, ChunkGuildFilter, GatewayError, PresenceData, WsClient}; -use crate::constants::{self, close_codes}; +use crate::constants::{self, CloseCode}; use crate::internal::prelude::*; use crate::model::event::{Event, GatewayEvent}; use crate::model::gateway::{GatewayIntents, ShardInfo}; @@ -106,7 +106,6 @@ pub struct Shard { // This must be set to `true` in `Shard::handle_event`'s `Ok(GatewayEvent::HeartbeatAck)` arm. last_heartbeat_acknowledged: bool, seq: u64, - session_id: Option, shard_info: ShardInfo, stage: ConnectionStage, /// Instant of when the shard was started. @@ -115,7 +114,7 @@ pub struct Shard { pub started: Instant, token: Token, ws_url: Arc, - resume_ws_url: Option, + resume_metadata: Option, compression: TransportCompression, pub intents: GatewayIntents, } @@ -188,7 +187,6 @@ impl Shard { let last_heartbeat_acknowledged = true; let seq = 0; let stage = ConnectionStage::Handshake; - let session_id = None; Ok(Shard { client, @@ -202,10 +200,9 @@ impl Shard { stage, started: Instant::now(), token, - session_id, shard_info, ws_url, - resume_ws_url: None, + resume_metadata: None, compression, intents, }) @@ -283,7 +280,7 @@ impl Shard { } pub fn session_id(&self) -> Option<&str> { - self.session_id.as_deref() + self.resume_metadata.as_ref().map(|m| &*m.session_id) } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] @@ -325,7 +322,7 @@ impl Shard { seq: u64, event: JsonMap, original_str: &str, - ) -> Result<(Option, Option)> { + ) -> Result { if seq > self.seq + 1 { warn!("[{:?}] Sequence off; them: {}, us: {}", self.shard_info, seq, self.seq); } @@ -337,8 +334,10 @@ impl Shard { Event::Ready(ready) => { debug!("[{:?}] Received Ready", self.shard_info); - self.resume_ws_url = Some(ready.ready.resume_gateway_url.clone()); - self.session_id = Some(ready.ready.session_id.clone()); + self.resume_metadata = Some(ResumeMetadata { + session_id: ready.ready.session_id.clone(), + resume_ws_url: ready.ready.resume_gateway_url.clone(), + }); self.stage = ConnectionStage::Connected; if let Some(callback) = self.application_id_callback.take() { @@ -356,7 +355,7 @@ impl Shard { _ => {}, } - Ok((None, Some(event))) + Ok(event) } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] @@ -377,95 +376,78 @@ impl Shard { } warn!("[{:?}] Heartbeat during non-Handshake; auto-reconnecting", self.shard_info); - return ShardAction::Reconnect(self.reconnection_type()); + return ShardAction::Reconnect; } ShardAction::Heartbeat } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - fn handle_gateway_closed( - &mut self, - data: Option<&CloseFrame<'static>>, - ) -> Result> { - let num = data.map(|d| d.code.into()); - let clean = num == Some(1000); - - match num { - Some(close_codes::UNKNOWN_OPCODE) => { - warn!("[{:?}] Sent invalid opcode.", self.shard_info); - }, - Some(close_codes::DECODE_ERROR) => { - warn!("[{:?}] Sent invalid message.", self.shard_info); - }, - Some(close_codes::NOT_AUTHENTICATED) => { - warn!("[{:?}] Sent no authentication.", self.shard_info); - - return Err(Error::Gateway(GatewayError::NoAuthentication)); - }, - Some(close_codes::AUTHENTICATION_FAILED) => { - error!( - "[{:?}] Sent invalid authentication, please check the token.", - self.shard_info - ); - - return Err(Error::Gateway(GatewayError::InvalidAuthentication)); - }, - Some(close_codes::ALREADY_AUTHENTICATED) => { - warn!("[{:?}] Already authenticated.", self.shard_info); - }, - Some(close_codes::INVALID_SEQUENCE) => { - warn!("[{:?}] Sent invalid seq: {}.", self.shard_info, self.seq); - - self.seq = 0; - }, - Some(close_codes::RATE_LIMITED) => { - warn!("[{:?}] Gateway ratelimited.", self.shard_info); - }, - Some(close_codes::INVALID_SHARD) => { - warn!("[{:?}] Sent invalid shard data.", self.shard_info); - - return Err(Error::Gateway(GatewayError::InvalidShardData)); - }, - Some(close_codes::SHARDING_REQUIRED) => { - error!("[{:?}] Shard has too many guilds.", self.shard_info); - - return Err(Error::Gateway(GatewayError::OverloadedShard)); - }, - Some(4006 | close_codes::SESSION_TIMEOUT) => { - info!("[{:?}] Invalid session.", self.shard_info); - - self.session_id = None; - }, - Some(close_codes::INVALID_GATEWAY_INTENTS) => { - error!("[{:?}] Invalid gateway intents have been provided.", self.shard_info); - - return Err(Error::Gateway(GatewayError::InvalidGatewayIntents)); - }, - Some(close_codes::DISALLOWED_GATEWAY_INTENTS) => { - error!("[{:?}] Disallowed gateway intents have been provided.", self.shard_info); - - return Err(Error::Gateway(GatewayError::DisallowedGatewayIntents)); - }, - Some(other) if !clean => { - warn!( - "[{:?}] Unknown unclean close {}: {:?}", + fn handle_gateway_closed(&mut self, data: Option<&CloseFrame<'static>>) -> Result<()> { + if let Some(code) = data.map(|d| d.code) { + match CloseCode(code.into()) { + CloseCode::UnknownError => warn!("[{:?}] Unknown gateway error.", self.shard_info), + CloseCode::UnknownOpcode => warn!("[{:?}] Sent invalid opcode.", self.shard_info), + CloseCode::DecodeError => warn!("[{:?}] Sent invalid message.", self.shard_info), + CloseCode::NotAuthenticated => { + warn!( + "[{:?}] Sent no authentication, or session invalidated.", + self.shard_info + ); + return Err(Error::Gateway(GatewayError::NoAuthentication)); + }, + CloseCode::AuthenticationFailed => { + error!( + "[{:?}] Sent invalid authentication, please check the token.", + self.shard_info + ); + + return Err(Error::Gateway(GatewayError::InvalidAuthentication)); + }, + CloseCode::AlreadyAuthenticated => { + warn!("[{:?}] Already authenticated.", self.shard_info); + }, + CloseCode::InvalidSequence => { + warn!("[{:?}] Sent invalid seq: {}.", self.shard_info, self.seq); + self.seq = 0; + }, + CloseCode::RateLimited => warn!("[{:?}] Gateway ratelimited.", self.shard_info), + CloseCode::SessionTimeout => { + info!("[{:?}] Invalid session.", self.shard_info); + self.resume_metadata = None; + }, + CloseCode::InvalidShard => { + warn!("[{:?}] Sent invalid shard data.", self.shard_info); + return Err(Error::Gateway(GatewayError::InvalidShardData)); + }, + CloseCode::ShardingRequired => { + error!("[{:?}] Shard has too many guilds.", self.shard_info); + return Err(Error::Gateway(GatewayError::OverloadedShard)); + }, + CloseCode::InvalidApiVersion => { + error!("[{:?}] Invalid gateway API version provided.", self.shard_info); + return Err(Error::Gateway(GatewayError::InvalidApiVersion)); + }, + CloseCode::InvalidGatewayIntents => { + error!("[{:?}] Invalid gateway intents have been provided.", self.shard_info); + return Err(Error::Gateway(GatewayError::InvalidGatewayIntents)); + }, + CloseCode::DisallowedGatewayIntents => { + error!( + "[{:?}] Disallowed gateway intents have been provided.", + self.shard_info + ); + return Err(Error::Gateway(GatewayError::DisallowedGatewayIntents)); + }, + _ => warn!( + "[{:?}] Unknown close code {}: {:?}", self.shard_info, - other, - data.map(|d| &d.reason), - ); - }, - _ => {}, + code, + data.map(|d| &d.reason) + ), + } } - - let resume = num - .is_none_or(|x| x != close_codes::AUTHENTICATION_FAILED && self.session_id.is_some()); - - Ok(Some(if resume { - ShardAction::Reconnect(ReconnectType::Resume) - } else { - ShardAction::Reconnect(ReconnectType::Reidentify) - })) + Ok(()) } /// Handles an event from the gateway over the receiver, requiring the receiver to be passed if @@ -489,18 +471,15 @@ impl Shard { /// Returns a [`GatewayError::OverloadedShard`] if the shard would have too many guilds /// assigned to it. #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - pub fn handle_event( - &mut self, - event: Result, - ) -> Result<(Option, Option)> { - let action = match event { + pub fn handle_event(&mut self, event: Result) -> Result> { + match event { Ok(GatewayEvent::Dispatch { seq, data, original_str, - }) => { - return self.handle_gateway_dispatch(seq, data, &original_str); - }, + }) => self + .handle_gateway_dispatch(seq, data, &original_str) + .map(|e| Some(ShardAction::Dispatch(e))), Ok(GatewayEvent::Heartbeat(s)) => Ok(Some(self.handle_heartbeat_event(s))), Ok(GatewayEvent::HeartbeatAck) => { self.last_heartbeat_ack = Some(Instant::now()); @@ -514,45 +493,43 @@ impl Shard { debug!("[{:?}] Received a Hello; interval: {}", self.shard_info, interval); if self.stage == ConnectionStage::Resuming { - return Ok((None, None)); - } - - self.heartbeat_interval = Some(std::time::Duration::from_millis(interval)); - - Ok(Some(if self.stage == ConnectionStage::Handshake { - ShardAction::Identify + Ok(None) } else { - debug!("[{:?}] Received late Hello; autoreconnecting", self.shard_info); - - ShardAction::Reconnect(self.reconnection_type()) - })) + self.heartbeat_interval = Some(std::time::Duration::from_millis(interval)); + let action = if self.stage == ConnectionStage::Handshake { + ShardAction::Identify + } else { + debug!("[{:?}] Received late Hello; autoreconnecting", self.shard_info); + ShardAction::Reconnect + }; + + Ok(Some(action)) + } }, Ok(GatewayEvent::InvalidateSession(resumable)) => { info!("[{:?}] Received session invalidation", self.shard_info); + if !resumable { + self.resume_metadata = None; + } - Ok(Some(if resumable { - ShardAction::Reconnect(ReconnectType::Resume) - } else { - ShardAction::Reconnect(ReconnectType::Reidentify) - })) + Ok(Some(ShardAction::Reconnect)) }, - Ok(GatewayEvent::Reconnect) => Ok(Some(ShardAction::Reconnect(ReconnectType::Resume))), + Ok(GatewayEvent::Reconnect) => Ok(Some(ShardAction::Reconnect)), Err(Error::Gateway(GatewayError::Closed(data))) => { - self.handle_gateway_closed(data.as_ref()) + self.handle_gateway_closed(data.as_ref())?; + Ok(Some(ShardAction::Reconnect)) }, Err(Error::Tungstenite(why)) => { info!("[{:?}] Websocket error: {:?}", self.shard_info, why); info!("[{:?}] Will attempt to auto-reconnect", self.shard_info); - Ok(Some(ShardAction::Reconnect(self.reconnection_type()))) + Ok(Some(ShardAction::Reconnect)) }, Err(why) => { warn!("[{:?}] Unhandled error: {:?}", self.shard_info, why); Ok(None) }, - }; - - action.map(|a| (a, None)) + } } /// Does a heartbeat if needed. Returns false if something went wrong and the shard should be @@ -614,30 +591,6 @@ impl Shard { None } - /// Performs a deterministic reconnect. - /// - /// The type of reconnect is deterministic on whether a [`Self::session_id`]. - /// - /// If the `session_id` still exists, then a RESUME is sent. If not, then an IDENTIFY is sent. - /// - /// Note that, if the shard is already in a stage of [`ConnectionStage::Connecting`], then no - /// action will be performed. - pub fn should_reconnect(&mut self) -> Option { - if self.stage == ConnectionStage::Connecting { - return None; - } - - Some(self.reconnection_type()) - } - - pub fn reconnection_type(&self) -> ReconnectType { - if self.session_id().is_some() { - ReconnectType::Resume - } else { - ReconnectType::Reidentify - } - } - /// Requests that one or multiple [`Guild`]s be chunked. /// /// This will ask the gateway to start sending member chunks for large guilds (250 members+). @@ -751,7 +704,10 @@ impl Shard { debug!("[{:?}] Initializing.", self.shard_info); // Reconnect to the resume URL if possible, otherwise use the generic URL. - let ws_url = self.resume_ws_url.as_deref().unwrap_or(&self.ws_url); + let ws_url = self + .resume_metadata + .as_ref() + .map_or(self.ws_url.as_ref(), |m| m.resume_ws_url.as_ref()); // We need to do two, sort of three things here: // - set the stage of the shard as opening the websocket connection @@ -768,17 +724,6 @@ impl Shard { Ok(client) } - #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - pub fn reset(&mut self) { - self.last_heartbeat_sent = Some(Instant::now()); - self.last_heartbeat_ack = None; - self.heartbeat_interval = None; - self.last_heartbeat_acknowledged = true; - self.session_id = None; - self.stage = ConnectionStage::Disconnected; - self.seq = 0; - } - /// # Errors /// /// Errors if unable to re-establish a websocket connection. @@ -789,29 +734,15 @@ impl Shard { self.client = self.reinitialize().await?; self.stage = ConnectionStage::Resuming; - match &self.session_id { - Some(session_id) => { - self.client - .send_resume(&self.shard_info, session_id, self.seq, self.token.expose_secret()) - .await - }, - None => Err(Error::Gateway(GatewayError::NoSessionId)), + if let Some(m) = &self.resume_metadata { + self.client + .send_resume(&self.shard_info, &m.session_id, self.seq, self.token.expose_secret()) + .await + } else { + Err(Error::Gateway(GatewayError::NoSessionId)) } } - /// # Errors - /// - /// Errors if unable to re-establish a websocket connection. - #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - pub async fn reconnect(&mut self) -> Result<()> { - info!("[{:?}] Attempting to reconnect", self.shard_info()); - - self.reset(); - self.client = self.reinitialize().await?; - - Ok(()) - } - /// # Errors /// /// Errors if there is a problem with the WS connection. @@ -852,12 +783,18 @@ fn deserialize_and_log_event(map: JsonMap, original_str: &str) -> Result }) } +struct ResumeMetadata { + session_id: FixedString, + resume_ws_url: FixedString, +} + #[derive(Debug)] #[non_exhaustive] pub enum ShardAction { Heartbeat, Identify, - Reconnect(ReconnectType), + Reconnect, + Dispatch(Event), } /// Information about a [`ShardRunner`]. @@ -959,16 +896,6 @@ impl fmt::Display for ConnectionStage { } } -/// The type of reconnection that should be performed. -#[derive(Debug)] -#[non_exhaustive] -pub enum ReconnectType { - /// Indicator that a new connection should be made by sending an IDENTIFY. - Reidentify, - /// Indicator that a new connection should be made by sending a RESUME. - Resume, -} - /// Newtype around a callback that will be called on every incoming request. As long as this /// collector should still receive events, it should return `true`. Once it returns `false`, it is /// removed. diff --git a/src/gateway/sharding/shard_runner.rs b/src/gateway/sharding/shard_runner.rs index 07c2ce0110e..aaaacca847b 100644 --- a/src/gateway/sharding/shard_runner.rs +++ b/src/gateway/sharding/shard_runner.rs @@ -10,7 +10,7 @@ use tracing::{debug, error, info, trace, warn}; #[cfg(feature = "collector")] use super::CollectorCallback; -use super::{ReconnectType, Shard, ShardAction, ShardId, ShardManager, ShardStageUpdateEvent}; +use super::{Shard, ShardAction, ShardId, ShardManager, ShardStageUpdateEvent}; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "framework")] @@ -23,7 +23,9 @@ use crate::gateway::{ActivityData, ChunkGuildFilter, GatewayError}; use crate::http::Http; use crate::internal::prelude::*; use crate::internal::tokio::spawn_named; -use crate::model::event::{Event, GatewayEvent}; +#[cfg(feature = "voice")] +use crate::model::event::Event; +use crate::model::event::GatewayEvent; use crate::model::id::GuildId; use crate::model::user::OnlineStatus; @@ -116,7 +118,7 @@ impl ShardRunner { } let pre = self.shard.stage(); - let (event, action, successful) = self.recv_event().await?; + let action = self.recv_event().await?; let post = self.shard.stage(); if post != pre { @@ -137,88 +139,83 @@ impl ShardRunner { } } - match action { - Some(ShardAction::Reconnect(ReconnectType::Reidentify)) => { - self.request_restart().await; - return Ok(()); - }, - Some(other) => { - if let Err(e) = self.action(&other).await { - debug!( - "[ShardRunner {:?}] Reconnecting due to error performing {:?}: {:?}", + if let Some(action) = action { + match action { + ShardAction::Reconnect => { + self.reconnect().await; + return Ok(()); + }, + ShardAction::Heartbeat => { + if let Err(e) = self.shard.heartbeat().await { + debug!( + "[ShardRunner {:?}] Reconnecting due to error while heartbeating: {:?}", self.shard.shard_info(), - other, e ); - match self.shard.reconnection_type() { - ReconnectType::Reidentify => { - self.request_restart().await; - return Ok(()); - }, - ReconnectType::Resume => { - if let Err(why) = self.shard.resume().await { - warn!( - "[ShardRunner {:?}] Resume failed, reidentifying: {:?}", - self.shard.shard_info(), - why - ); - - self.request_restart().await; - return Ok(()); - } - }, + self.reconnect().await; + return Ok(()); + } + }, + ShardAction::Identify => { + if let Err(e) = self.shard.identify().await { + debug!( + "[ShardRunner {:?}] Reconnecting due to error while identifying: {:?}", + self.shard.shard_info(), + e + ); + self.reconnect().await; + return Ok(()); + } + }, + ShardAction::Dispatch(event) => { + #[cfg(feature = "voice")] + { + self.handle_voice_event(&event).await; } - } - }, - None => {}, - } - if let Some(event) = event { - let context = self.make_context(); - let can_dispatch = self - .event_handler - .as_ref() - .is_none_or(|handler| handler.filter_event(&context, &event)) - && self - .raw_event_handler - .as_ref() - .is_none_or(|handler| handler.filter_event(&context, &event)); - - if can_dispatch { - #[cfg(feature = "collector")] - { - let read_lock = self.collectors.read(); - // search all collectors to be removed and clone the Arcs - let to_remove: Vec<_> = read_lock - .iter() - .filter(|callback| !callback.0(&event)) - .cloned() - .collect(); - drop(read_lock); - // remove all found arcs from the collection - // this compares the inner pointer of the Arc - if !to_remove.is_empty() { - self.collectors.write().retain(|f| !to_remove.contains(f)); + let context = self.make_context(); + let can_dispatch = self + .event_handler + .as_ref() + .is_none_or(|handler| handler.filter_event(&context, &event)) + && self + .raw_event_handler + .as_ref() + .is_none_or(|handler| handler.filter_event(&context, &event)); + + if can_dispatch { + #[cfg(feature = "collector")] + { + let read_lock = self.collectors.read(); + // search all collectors to be removed and clone the Arcs + let to_remove: Vec<_> = read_lock + .iter() + .filter(|callback| !callback.0(&event)) + .cloned() + .collect(); + drop(read_lock); + // remove all found arcs from the collection + // this compares the inner pointer of the Arc + if !to_remove.is_empty() { + self.collectors.write().retain(|f| !to_remove.contains(f)); + } + } + spawn_named( + "shard_runner::dispatch", + dispatch_model( + event, + context, + #[cfg(feature = "framework")] + self.framework.clone(), + self.event_handler.clone(), + self.raw_event_handler.clone(), + ), + ); } - } - spawn_named( - "shard_runner::dispatch", - dispatch_model( - event, - context, - #[cfg(feature = "framework")] - self.framework.clone(), - self.event_handler.clone(), - self.raw_event_handler.clone(), - ), - ); + }, } } - if !successful && !self.shard.stage().is_connecting() { - self.request_restart().await; - return Ok(()); - } trace!("[ShardRunner {:?}] loop iteration reached the end.", self.shard.shard_info()); } } @@ -228,27 +225,6 @@ impl ShardRunner { self.runner_tx.clone() } - /// Takes an action that a [`Shard`] has determined should happen and then does it. - /// - /// For example, if the shard says that an Identify message needs to be sent, this will do - /// that. - /// - /// # Errors - /// - /// Returns - #[cfg_attr(feature = "tracing_instrument", instrument(skip(self, action)))] - async fn action(&mut self, action: &ShardAction) -> Result<()> { - match *action { - ShardAction::Reconnect(ReconnectType::Reidentify) => { - self.request_restart().await; - Ok(()) - }, - ShardAction::Reconnect(ReconnectType::Resume) => self.shard.resume().await, - ShardAction::Heartbeat => self.shard.heartbeat().await, - ShardAction::Identify => self.shard.identify().await, - } - } - // Checks if the ID received to shutdown is equivalent to the ID of the shard this runner is // responsible. If so, it shuts down the WebSocket client. // @@ -413,39 +389,27 @@ impl ShardRunner { /// Returns a received event, as well as whether reading the potentially present event was /// successful. #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] - async fn recv_event(&mut self) -> Result<(Option, Option, bool)> { + async fn recv_event(&mut self) -> Result> { let gateway_event = match self.shard.client.recv_json().await { Ok(Some(inner)) => Ok(inner), Ok(None) => { - return Ok((None, None, true)); + return Ok(None); }, Err(Error::Tungstenite(tung_err)) if matches!(*tung_err, TungsteniteError::Io(_)) => { debug!("Attempting to auto-reconnect"); + self.reconnect().await; - match self.shard.reconnection_type() { - ReconnectType::Reidentify => return Ok((None, None, false)), - ReconnectType::Resume => { - if let Err(why) = self.shard.resume().await { - warn!("Failed to resume: {:?}", why); - - // Don't spam reattempts on internet connection loss - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - return Ok((None, None, false)); - } - }, - } - - return Ok((None, None, true)); + return Ok(None); }, Err(why) => Err(why), }; let is_ack = matches!(gateway_event, Ok(GatewayEvent::HeartbeatAck)); - let (action, event) = match self.shard.handle_event(gateway_event) { - Ok((action, event)) => (action, event), + let action = match self.shard.handle_event(gateway_event) { + Ok(action) => action, Err(Error::Gateway( why @ (GatewayError::InvalidAuthentication + | GatewayError::InvalidApiVersion | GatewayError::InvalidGatewayIntents | GatewayError::DisallowedGatewayIntents), )) => { @@ -453,6 +417,7 @@ impl ShardRunner { let why_clone = match why { GatewayError::InvalidAuthentication => GatewayError::InvalidAuthentication, + GatewayError::InvalidApiVersion => GatewayError::InvalidApiVersion, GatewayError::InvalidGatewayIntents => GatewayError::InvalidGatewayIntents, GatewayError::DisallowedGatewayIntents => { GatewayError::DisallowedGatewayIntents @@ -463,10 +428,10 @@ impl ShardRunner { self.manager.return_with_value(Err(why_clone)).await; return Err(Error::Gateway(why)); }, - Err(Error::Json(_)) => return Ok((None, None, true)), + Err(Error::Json(_)) => return Ok(None), Err(why) => { error!("Shard handler recieved err: {why:?}"); - return Ok((None, None, true)); + return Ok(None); }, }; @@ -474,14 +439,27 @@ impl ShardRunner { self.update_manager().await; } - #[cfg(feature = "voice")] - { - if let Some(event) = &event { - self.handle_voice_event(event).await; + Ok(action) + } + + #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] + async fn reconnect(&mut self) { + if self.shard.session_id().is_some() { + if let Err(why) = self.shard.resume().await { + warn!( + "[ShardRunner {:?}] Resume failed, reidentifying: {:?}", + self.shard.shard_info(), + why, + ); + + // Don't spam reattempts on internet connection loss + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + self.request_restart().await; } + } else { + self.request_restart().await; } - - Ok((event, action, true)) } #[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))] diff --git a/src/lib.rs b/src/lib.rs index 40897739a28..5ed329bc1dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,7 +127,7 @@ pub mod all { #[doc(no_inline)] pub use crate::collector::*; #[doc(no_inline)] - pub use crate::constants::{close_codes::*, *}; + pub use crate::constants::*; #[cfg(feature = "framework")] #[doc(no_inline)] pub use crate::framework::*;