From 2a2e3b93f35ac3d71b7747b55cff8790fe43b976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9odore=20Pr=C3=A9vot?= Date: Mon, 6 May 2024 20:38:11 +0200 Subject: [PATCH] Fix: deadlocking when calling `close_ns` from inside a `disconnect_handler` (#316) * fix(socketio/ns): close ns before removing it and avoid holding guard between handler calls. * fix(socketio/socket): drop the `disconnect_handler` lock before calling the handler * feat(socketio/ns): refactor the `close` fn for closing namespaces. * doc(socketio): adapt documentation * test(socketio): fix tests --- socketioxide/src/client.rs | 27 ++++-- socketioxide/src/io.rs | 9 +- socketioxide/src/ns.rs | 53 +++++++++--- socketioxide/src/socket.rs | 8 +- socketioxide/tests/connect.rs | 108 +++++++++++++++++++++++- socketioxide/tests/disconnect_reason.rs | 21 +++++ 6 files changed, 207 insertions(+), 19 deletions(-) diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index ac127360..ef60cf7e 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -6,13 +6,14 @@ use bytes::Bytes; use engineioxide::handler::EngineIoHandler; use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket}; use engineioxide::Str; -use futures_util::TryFutureExt; +use futures_util::{FutureExt, TryFutureExt}; use engineioxide::sid::Sid; use tokio::sync::oneshot; use crate::adapter::Adapter; use crate::handler::ConnectHandler; +use crate::socket::DisconnectReason; use crate::ProtocolVersion; use crate::{ errors::Error, @@ -121,11 +122,19 @@ impl Client { self.ns.write().unwrap().insert(path, ns); } - /// Deletes a namespace handler + /// Deletes a namespace handler and closes all the connections to it pub fn delete_ns(&self, path: &str) { + #[cfg(feature = "v4")] + if path == "/" { + panic!("the root namespace \"/\" cannot be deleted for the socket.io v4 protocol. See https://socket.io/docs/v3/namespaces/#main-namespace for more info"); + } + #[cfg(feature = "tracing")] tracing::debug!("deleting namespace {}", path); - self.ns.write().unwrap().remove(path); + if let Some(ns) = self.ns.write().unwrap().remove(path) { + ns.close(DisconnectReason::ServerNSDisconnect) + .now_or_never(); + } } pub fn get_ns(&self, path: &str) -> Option>> { @@ -138,7 +147,11 @@ impl Client { #[cfg(feature = "tracing")] tracing::debug!("closing all namespaces"); let ns = self.ns.read().unwrap().clone(); - futures_util::future::join_all(ns.values().map(|ns| ns.close())).await; + futures_util::future::join_all( + ns.values() + .map(|ns| ns.close(DisconnectReason::ClosingServer)), + ) + .await; #[cfg(feature = "tracing")] tracing::debug!("all namespaces closed"); } @@ -230,12 +243,16 @@ impl EngineIoHandler for Client { fn on_disconnect(&self, socket: Arc>, reason: EIoDisconnectReason) { #[cfg(feature = "tracing")] tracing::debug!("eio socket disconnected"); - let _res: Result, _> = self + let socks: Vec<_> = self .ns .read() .unwrap() .values() .filter_map(|ns| ns.get_socket(socket.id).ok()) + .collect(); + + let _res: Result, _> = socks + .into_iter() .map(|s| s.close(reason.clone().into())) .collect(); diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 1df99f7f..3f4ef4c4 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -342,7 +342,14 @@ impl SocketIo { self.0.add_ns(path.into(), callback); } - /// Deletes the namespace with the given path + /// Deletes the namespace with the given path. + /// + /// This will disconnect all sockets connected to this + /// namespace in a deferred way. + /// + /// # Panics + /// If the v4 protocol (legacy) is enabled and the namespace to delete is the default namespace "/". + /// For v4, the default namespace cannot be deleted. See [official doc](https://socket.io/docs/v3/namespaces/#main-namespace) for more informations. #[inline] pub fn delete_ns<'a>(&self, path: impl Into<&'a str>) { self.0.delete_ns(path.into()); diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 6a968be8..7f810b8a 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -9,7 +9,7 @@ use crate::{ errors::{ConnectFail, Error}, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, packet::{Packet, PacketData}, - socket::Socket, + socket::{DisconnectReason, Socket}, SocketIoConfig, }; use crate::{client::SocketData, errors::AdapterError}; @@ -85,6 +85,9 @@ impl Namespace { /// Removes a socket from a namespace and propagate the event to the adapter pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> { + #[cfg(feature = "tracing")] + tracing::trace!(?sid, "removing socket from namespace"); + self.sockets.write().unwrap().remove(&sid); self.adapter .del_all(sid) @@ -118,18 +121,40 @@ impl Namespace { /// Closes the entire namespace : /// * Closes the adapter - /// * Closes all the sockets and their underlying connections + /// * Closes all the sockets and + /// their underlying connections in case of [`DisconnectReason::ClosingServer`] /// * Removes all the sockets from the namespace - pub async fn close(&self) { - self.adapter.close().ok(); - #[cfg(feature = "tracing")] - tracing::debug!("closing all sockets in namespace {}", self.path); + /// + /// This function is using .await points only when called with [`DisconnectReason::ClosingServer`] + pub async fn close(&self, reason: DisconnectReason) { + use futures_util::future; let sockets = self.sockets.read().unwrap().clone(); - futures_util::future::join_all(sockets.values().map(|s| s.close_underlying_transport())) - .await; - self.sockets.write().unwrap().shrink_to_fit(); + + #[cfg(feature = "tracing")] + tracing::debug!(?self.path, "closing {} sockets in namespace", sockets.len()); + + if reason == DisconnectReason::ClosingServer { + // When closing the underlying transport, this will indirectly close the socket + // Therefore there is no need to manually call `s.close()`. + future::join_all(sockets.values().map(|s| s.close_underlying_transport())).await; + } else { + for s in sockets.into_values() { + let _sid = s.id; + let _err = s.close(reason); + #[cfg(feature = "tracing")] + if let Err(err) = _err { + tracing::debug!(?_sid, ?err, "error closing socket"); + } + } + } + #[cfg(feature = "tracing")] + tracing::debug!(?self.path, "all sockets in namespace closed"); + + let _err = self.adapter.close(); #[cfg(feature = "tracing")] - tracing::debug!("all sockets in namespace {} closed", self.path); + if let Err(err) = _err { + tracing::debug!(?err, "could not close adapter"); + } } } @@ -160,3 +185,11 @@ impl std::fmt::Debug for Namespace { .finish() } } + +#[cfg(feature = "tracing")] +impl Drop for Namespace { + fn drop(&mut self) { + #[cfg(feature = "tracing")] + tracing::debug!("dropping namespace {}", self.path); + } +} diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 24bf9b8f..b082d8c0 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -64,7 +64,7 @@ pub enum DisconnectReason { /// The client has manually disconnected the socket using [`socket.disconnect()`](https://socket.io/fr/docs/v4/client-api/#socketdisconnect) ClientNSDisconnect, - /// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] + /// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] or with [`SocketIo::delete_ns`](crate::io::SocketIo::delete_ns) ServerNSDisconnect, /// The server is being closed @@ -694,7 +694,11 @@ impl Socket { pub(crate) fn close(self: Arc, reason: DisconnectReason) -> Result<(), AdapterError> { self.set_connected(false); - if let Some(handler) = self.disconnect_handler.lock().unwrap().take() { + let handler = { self.disconnect_handler.lock().unwrap().take() }; + if let Some(handler) = handler { + #[cfg(feature = "tracing")] + tracing::trace!(?reason, ?self.id, "spawning disconnect handler"); + handler.call(self.clone(), reason); } diff --git a/socketioxide/tests/connect.rs b/socketioxide/tests/connect.rs index dd3f03d7..379ce65e 100644 --- a/socketioxide/tests/connect.rs +++ b/socketioxide/tests/connect.rs @@ -2,9 +2,22 @@ mod utils; use bytes::Bytes; use engineioxide::Packet::*; -use socketioxide::{extract::SocketRef, handler::ConnectHandler, SendError, SocketError, SocketIo}; +use socketioxide::{ + extract::SocketRef, handler::ConnectHandler, packet::Packet, SendError, SocketError, SocketIo, +}; use tokio::sync::mpsc; +fn create_msg(ns: &str, event: &str, data: impl Into) -> engineioxide::Packet { + let packet: String = Packet::event(ns, event, data.into()).into(); + Message(packet.into()) +} +async fn timeout_rcv(srx: &mut tokio::sync::mpsc::Receiver) -> T { + tokio::time::timeout(std::time::Duration::from_millis(500), srx.recv()) + .await + .unwrap() + .unwrap() +} + #[tokio::test] pub async fn connect_middleware() { let (_svc, io) = SocketIo::new_svc(); @@ -97,3 +110,96 @@ pub async fn connect_middleware_error() { rx.recv().await.unwrap(); assert_err!(rx.try_recv()); } + +#[tokio::test] +async fn remove_ns_from_connect_handler() { + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2); + let (_svc, io) = SocketIo::new_svc(); + + let io_clone = io.clone(); + io.ns("/test1", move || { + tx.try_send(()).unwrap(); + io_clone.delete_ns("/test1"); + }); + + let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await; + timeout_rcv(&mut srx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + timeout_rcv(&mut rx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + // No response since ns is already deleted + let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await; + assert!(elapsed.is_err() || elapsed.unwrap().is_none()); +} + +#[tokio::test] +async fn remove_ns_from_middleware() { + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2); + let (_svc, io) = SocketIo::new_svc(); + + let io_clone = io.clone(); + let middleware = move || { + tx.try_send(()).unwrap(); + io_clone.delete_ns("/test1"); + Ok::<(), std::convert::Infallible>(()) + }; + fn handler() {} + io.ns("/test1", handler.with(middleware)); + + let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await; + timeout_rcv(&mut srx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + timeout_rcv(&mut rx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + // No response since ns is already deleted + let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await; + assert!(elapsed.is_err() || elapsed.unwrap().is_none()); +} + +#[tokio::test] +async fn remove_ns_from_event_handler() { + let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2); + let (_svc, io) = SocketIo::new_svc(); + + let io_clone = io.clone(); + io.ns("/test1", move |s: SocketRef| { + s.on("delete_ns", move || { + io_clone.delete_ns("/test1"); + tx.try_send(()).unwrap(); + }); + }); + + let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await; + timeout_rcv(&mut srx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + timeout_rcv(&mut rx).await; + assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ()))); + // No response since ns is already deleted + let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await; + assert!(elapsed.is_err() || elapsed.unwrap().is_none()); +} + +#[tokio::test] +async fn remove_ns_from_disconnect_handler() { + let (tx, mut rx) = tokio::sync::mpsc::channel::<&'static str>(2); + let (_svc, io) = SocketIo::new_svc(); + + let io_clone = io.clone(); + io.ns("/test2", move |s: SocketRef| { + tx.try_send("connect").unwrap(); + s.on_disconnect(move || { + io_clone.delete_ns("/test2"); + tx.try_send("disconnect").unwrap(); + }) + }); + + let (stx, mut srx) = io.new_dummy_sock("/test2", ()).await; + assert_eq!(timeout_rcv(&mut rx).await, "connect"); + timeout_rcv(&mut srx).await; + assert_ok!(stx.try_send(Close)); + assert_eq!(timeout_rcv(&mut rx).await, "disconnect"); + + let (_stx, mut _srx) = io.new_dummy_sock("/test2", ()).await; + let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await; + assert!(elapsed.is_err() || elapsed.unwrap().is_none()); +} diff --git a/socketioxide/tests/disconnect_reason.rs b/socketioxide/tests/disconnect_reason.rs index 1c38b402..3df5053d 100644 --- a/socketioxide/tests/disconnect_reason.rs +++ b/socketioxide/tests/disconnect_reason.rs @@ -228,6 +228,27 @@ pub async fn server_ns_disconnect() { assert_eq!(data, DisconnectReason::ServerNSDisconnect); } +#[tokio::test] +pub async fn server_ns_close() { + let (tx, mut rx) = mpsc::channel::(1); + let io = create_server(12353).await; + let io2 = io.clone(); + io.ns("/test", move |socket: SocketRef| { + socket.on_disconnect(move |reason: DisconnectReason| tx.try_send(reason).unwrap()); + io2.delete_ns("/test"); + }); + + let mut ws = create_ws_connection(12353).await; + ws.send(Message::Text("40/test,{}".to_string())) + .await + .unwrap(); + let data = tokio::time::timeout(Duration::from_millis(20), rx.recv()) + .await + .expect("timeout waiting for DisconnectReason::ServerNSDisconnect") + .unwrap(); + assert_eq!(data, DisconnectReason::ServerNSDisconnect); +} + #[tokio::test] pub async fn server_ws_closing() { let io = create_server(12350).await;