Skip to content

Commit

Permalink
Fix: deadlocking when calling close_ns from inside a `disconnect_ha…
Browse files Browse the repository at this point in the history
…ndler` (#316)

* fix(socketio/ns): close ns before removing it and avoid holding guard between handler calls.

* fix(socketio/socket): drop the `disconnect_handler` lock before calling the handler

* feat(socketio/ns): refactor the `close` fn for closing namespaces.

* doc(socketio): adapt documentation

* test(socketio): fix tests
  • Loading branch information
Totodore authored May 6, 2024
1 parent 867f2b5 commit 2a2e3b9
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 19 deletions.
27 changes: 22 additions & 5 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use bytes::Bytes;
use engineioxide::handler::EngineIoHandler;
use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Socket as EIoSocket};
use engineioxide::Str;
use futures_util::TryFutureExt;
use futures_util::{FutureExt, TryFutureExt};

use engineioxide::sid::Sid;
use tokio::sync::oneshot;

use crate::adapter::Adapter;
use crate::handler::ConnectHandler;
use crate::socket::DisconnectReason;
use crate::ProtocolVersion;
use crate::{
errors::Error,
Expand Down Expand Up @@ -121,11 +122,19 @@ impl<A: Adapter> Client<A> {
self.ns.write().unwrap().insert(path, ns);
}

/// Deletes a namespace handler
/// Deletes a namespace handler and closes all the connections to it
pub fn delete_ns(&self, path: &str) {
#[cfg(feature = "v4")]
if path == "/" {
panic!("the root namespace \"/\" cannot be deleted for the socket.io v4 protocol. See https://socket.io/docs/v3/namespaces/#main-namespace for more info");
}

#[cfg(feature = "tracing")]
tracing::debug!("deleting namespace {}", path);
self.ns.write().unwrap().remove(path);
if let Some(ns) = self.ns.write().unwrap().remove(path) {
ns.close(DisconnectReason::ServerNSDisconnect)
.now_or_never();
}
}

pub fn get_ns(&self, path: &str) -> Option<Arc<Namespace<A>>> {
Expand All @@ -138,7 +147,11 @@ impl<A: Adapter> Client<A> {
#[cfg(feature = "tracing")]
tracing::debug!("closing all namespaces");
let ns = self.ns.read().unwrap().clone();
futures_util::future::join_all(ns.values().map(|ns| ns.close())).await;
futures_util::future::join_all(
ns.values()
.map(|ns| ns.close(DisconnectReason::ClosingServer)),
)
.await;
#[cfg(feature = "tracing")]
tracing::debug!("all namespaces closed");
}
Expand Down Expand Up @@ -230,12 +243,16 @@ impl<A: Adapter> EngineIoHandler for Client<A> {
fn on_disconnect(&self, socket: Arc<EIoSocket<SocketData>>, reason: EIoDisconnectReason) {
#[cfg(feature = "tracing")]
tracing::debug!("eio socket disconnected");
let _res: Result<Vec<_>, _> = self
let socks: Vec<_> = self
.ns
.read()
.unwrap()
.values()
.filter_map(|ns| ns.get_socket(socket.id).ok())
.collect();

let _res: Result<Vec<_>, _> = socks
.into_iter()
.map(|s| s.close(reason.clone().into()))
.collect();

Expand Down
9 changes: 8 additions & 1 deletion socketioxide/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,14 @@ impl<A: Adapter> SocketIo<A> {
self.0.add_ns(path.into(), callback);
}

/// Deletes the namespace with the given path
/// Deletes the namespace with the given path.
///
/// This will disconnect all sockets connected to this
/// namespace in a deferred way.
///
/// # Panics
/// If the v4 protocol (legacy) is enabled and the namespace to delete is the default namespace "/".
/// For v4, the default namespace cannot be deleted. See [official doc](https://socket.io/docs/v3/namespaces/#main-namespace) for more informations.
#[inline]
pub fn delete_ns<'a>(&self, path: impl Into<&'a str>) {
self.0.delete_ns(path.into());
Expand Down
53 changes: 43 additions & 10 deletions socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
errors::{ConnectFail, Error},
handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler},
packet::{Packet, PacketData},
socket::Socket,
socket::{DisconnectReason, Socket},
SocketIoConfig,
};
use crate::{client::SocketData, errors::AdapterError};
Expand Down Expand Up @@ -85,6 +85,9 @@ impl<A: Adapter> Namespace<A> {

/// Removes a socket from a namespace and propagate the event to the adapter
pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> {
#[cfg(feature = "tracing")]
tracing::trace!(?sid, "removing socket from namespace");

self.sockets.write().unwrap().remove(&sid);
self.adapter
.del_all(sid)
Expand Down Expand Up @@ -118,18 +121,40 @@ impl<A: Adapter> Namespace<A> {

/// Closes the entire namespace :
/// * Closes the adapter
/// * Closes all the sockets and their underlying connections
/// * Closes all the sockets and
/// their underlying connections in case of [`DisconnectReason::ClosingServer`]
/// * Removes all the sockets from the namespace
pub async fn close(&self) {
self.adapter.close().ok();
#[cfg(feature = "tracing")]
tracing::debug!("closing all sockets in namespace {}", self.path);
///
/// This function is using .await points only when called with [`DisconnectReason::ClosingServer`]
pub async fn close(&self, reason: DisconnectReason) {
use futures_util::future;
let sockets = self.sockets.read().unwrap().clone();
futures_util::future::join_all(sockets.values().map(|s| s.close_underlying_transport()))
.await;
self.sockets.write().unwrap().shrink_to_fit();

#[cfg(feature = "tracing")]
tracing::debug!(?self.path, "closing {} sockets in namespace", sockets.len());

if reason == DisconnectReason::ClosingServer {
// When closing the underlying transport, this will indirectly close the socket
// Therefore there is no need to manually call `s.close()`.
future::join_all(sockets.values().map(|s| s.close_underlying_transport())).await;
} else {
for s in sockets.into_values() {
let _sid = s.id;
let _err = s.close(reason);
#[cfg(feature = "tracing")]
if let Err(err) = _err {
tracing::debug!(?_sid, ?err, "error closing socket");
}
}
}
#[cfg(feature = "tracing")]
tracing::debug!(?self.path, "all sockets in namespace closed");

let _err = self.adapter.close();
#[cfg(feature = "tracing")]
tracing::debug!("all sockets in namespace {} closed", self.path);
if let Err(err) = _err {
tracing::debug!(?err, "could not close adapter");
}
}
}

Expand Down Expand Up @@ -160,3 +185,11 @@ impl<A: Adapter + std::fmt::Debug> std::fmt::Debug for Namespace<A> {
.finish()
}
}

#[cfg(feature = "tracing")]
impl<A: Adapter> Drop for Namespace<A> {
fn drop(&mut self) {
#[cfg(feature = "tracing")]
tracing::debug!("dropping namespace {}", self.path);
}
}
8 changes: 6 additions & 2 deletions socketioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub enum DisconnectReason {
/// The client has manually disconnected the socket using [`socket.disconnect()`](https://socket.io/fr/docs/v4/client-api/#socketdisconnect)
ClientNSDisconnect,

/// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`]
/// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] or with [`SocketIo::delete_ns`](crate::io::SocketIo::delete_ns)
ServerNSDisconnect,

/// The server is being closed
Expand Down Expand Up @@ -694,7 +694,11 @@ impl<A: Adapter> Socket<A> {
pub(crate) fn close(self: Arc<Self>, reason: DisconnectReason) -> Result<(), AdapterError> {
self.set_connected(false);

if let Some(handler) = self.disconnect_handler.lock().unwrap().take() {
let handler = { self.disconnect_handler.lock().unwrap().take() };
if let Some(handler) = handler {
#[cfg(feature = "tracing")]
tracing::trace!(?reason, ?self.id, "spawning disconnect handler");

handler.call(self.clone(), reason);
}

Expand Down
108 changes: 107 additions & 1 deletion socketioxide/tests/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ mod utils;

use bytes::Bytes;
use engineioxide::Packet::*;
use socketioxide::{extract::SocketRef, handler::ConnectHandler, SendError, SocketError, SocketIo};
use socketioxide::{
extract::SocketRef, handler::ConnectHandler, packet::Packet, SendError, SocketError, SocketIo,
};
use tokio::sync::mpsc;

fn create_msg(ns: &str, event: &str, data: impl Into<serde_json::Value>) -> engineioxide::Packet {
let packet: String = Packet::event(ns, event, data.into()).into();
Message(packet.into())
}
async fn timeout_rcv<T: std::fmt::Debug>(srx: &mut tokio::sync::mpsc::Receiver<T>) -> T {
tokio::time::timeout(std::time::Duration::from_millis(500), srx.recv())
.await
.unwrap()
.unwrap()
}

#[tokio::test]
pub async fn connect_middleware() {
let (_svc, io) = SocketIo::new_svc();
Expand Down Expand Up @@ -97,3 +110,96 @@ pub async fn connect_middleware_error() {
rx.recv().await.unwrap();
assert_err!(rx.try_recv());
}

#[tokio::test]
async fn remove_ns_from_connect_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test1", move || {
tx.try_send(()).unwrap();
io_clone.delete_ns("/test1");
});

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_middleware() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
let middleware = move || {
tx.try_send(()).unwrap();
io_clone.delete_ns("/test1");
Ok::<(), std::convert::Infallible>(())
};
fn handler() {}
io.ns("/test1", handler.with(middleware));

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_event_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test1", move |s: SocketRef| {
s.on("delete_ns", move || {
io_clone.delete_ns("/test1");
tx.try_send(()).unwrap();
});
});

let (stx, mut srx) = io.new_dummy_sock("/test1", ()).await;
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
timeout_rcv(&mut rx).await;
assert_ok!(stx.try_send(create_msg("/test1", "delete_ns", ())));
// No response since ns is already deleted
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}

#[tokio::test]
async fn remove_ns_from_disconnect_handler() {
let (tx, mut rx) = tokio::sync::mpsc::channel::<&'static str>(2);
let (_svc, io) = SocketIo::new_svc();

let io_clone = io.clone();
io.ns("/test2", move |s: SocketRef| {
tx.try_send("connect").unwrap();
s.on_disconnect(move || {
io_clone.delete_ns("/test2");
tx.try_send("disconnect").unwrap();
})
});

let (stx, mut srx) = io.new_dummy_sock("/test2", ()).await;
assert_eq!(timeout_rcv(&mut rx).await, "connect");
timeout_rcv(&mut srx).await;
assert_ok!(stx.try_send(Close));
assert_eq!(timeout_rcv(&mut rx).await, "disconnect");

let (_stx, mut _srx) = io.new_dummy_sock("/test2", ()).await;
let elapsed = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()).await;
assert!(elapsed.is_err() || elapsed.unwrap().is_none());
}
21 changes: 21 additions & 0 deletions socketioxide/tests/disconnect_reason.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ pub async fn server_ns_disconnect() {
assert_eq!(data, DisconnectReason::ServerNSDisconnect);
}

#[tokio::test]
pub async fn server_ns_close() {
let (tx, mut rx) = mpsc::channel::<DisconnectReason>(1);
let io = create_server(12353).await;
let io2 = io.clone();
io.ns("/test", move |socket: SocketRef| {
socket.on_disconnect(move |reason: DisconnectReason| tx.try_send(reason).unwrap());
io2.delete_ns("/test");
});

let mut ws = create_ws_connection(12353).await;
ws.send(Message::Text("40/test,{}".to_string()))
.await
.unwrap();
let data = tokio::time::timeout(Duration::from_millis(20), rx.recv())
.await
.expect("timeout waiting for DisconnectReason::ServerNSDisconnect")
.unwrap();
assert_eq!(data, DisconnectReason::ServerNSDisconnect);
}

#[tokio::test]
pub async fn server_ws_closing() {
let io = create_server(12350).await;
Expand Down

0 comments on commit 2a2e3b9

Please sign in to comment.