Skip to content

Commit f4e6f8b

Browse files
committed
fix: prevent reconnecting of client after invoking disconnect (1c3t3a#374)
1 parent 2ef32ec commit f4e6f8b

File tree

2 files changed

+104
-28
lines changed

2 files changed

+104
-28
lines changed

socketio/src/client/client.rs

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{
33
time::Duration,
44
};
55

6-
use super::{ClientBuilder, RawClient};
6+
use super::{raw_client::DisconnectReason, ClientBuilder, RawClient};
77
use crate::{
88
error::Result,
99
packet::{Packet, PacketId},
@@ -165,6 +165,11 @@ impl Client {
165165
client.disconnect()
166166
}
167167

168+
fn do_disconnect(&self) -> Result<()> {
169+
let client = self.client.read()?;
170+
client.do_disconnect()
171+
}
172+
168173
fn reconnect(&mut self) -> Result<()> {
169174
let mut reconnect_attempts = 0;
170175
let (reconnect, max_reconnect_attempts) = {
@@ -174,6 +179,17 @@ impl Client {
174179

175180
if reconnect {
176181
loop {
182+
// Check if disconnect_reason is Manual
183+
{
184+
let disconnect_reason = {
185+
let client = self.client.read()?;
186+
client.get_disconnect_reason()
187+
};
188+
if disconnect_reason == DisconnectReason::Manual {
189+
// Exit the loop, stop reconnecting
190+
break;
191+
}
192+
}
177193
if let Some(max_reconnect_attempts) = max_reconnect_attempts {
178194
reconnect_attempts += 1;
179195
if reconnect_attempts > max_reconnect_attempts {
@@ -186,6 +202,12 @@ impl Client {
186202
}
187203

188204
if self.do_reconnect().is_ok() {
205+
// Reset disconnect_reason to Unknown after successful reconnection
206+
{
207+
let client = self.client.read()?;
208+
let mut reason = client.disconnect_reason.write()?;
209+
*reason = DisconnectReason::Unknown;
210+
}
189211
break;
190212
}
191213
}
@@ -213,29 +235,43 @@ impl Client {
213235
let mut self_clone = self.clone();
214236
// Use thread to consume items in iterator in order to call callbacks
215237
std::thread::spawn(move || {
216-
// tries to restart a poll cycle whenever a 'normal' error occurs,
217-
// it just panics on network errors, in case the poll cycle returned
218-
// `Result::Ok`, the server receives a close frame so it's safe to
219-
// terminate
220-
for packet in self_clone.iter() {
221-
let should_reconnect = match packet {
222-
Err(Error::IncompleteResponseFromEngineIo(_)) => {
223-
//TODO: 0.3.X handle errors
224-
//TODO: logging error
225-
true
238+
loop {
239+
let next_item = self_clone.iter().next();
240+
match next_item {
241+
Some(Ok(_packet)) => {
242+
// Process packet normally
243+
continue;
244+
}
245+
Some(Err(_)) => {
246+
let should_reconnect = {
247+
let disconnect_reason = {
248+
let client = self_clone.client.read().unwrap();
249+
client.get_disconnect_reason()
250+
};
251+
match disconnect_reason {
252+
DisconnectReason::Unknown => {
253+
let builder = self_clone.builder.lock().unwrap();
254+
builder.reconnect
255+
}
256+
DisconnectReason::Manual => false,
257+
DisconnectReason::Server => {
258+
let builder = self_clone.builder.lock().unwrap();
259+
builder.reconnect_on_disconnect
260+
}
261+
}
262+
};
263+
if should_reconnect {
264+
let _ = self_clone.do_disconnect();
265+
let _ = self_clone.reconnect();
266+
} else {
267+
// No reconnection needed, exit the loop
268+
break;
269+
}
270+
}
271+
None => {
272+
// Iterator has ended, exit the loop
273+
break;
226274
}
227-
Ok(Packet {
228-
packet_type: PacketId::Disconnect,
229-
..
230-
}) => match self_clone.builder.lock() {
231-
Ok(builder) => builder.reconnect_on_disconnect,
232-
Err(_) => false,
233-
},
234-
_ => false,
235-
};
236-
if should_reconnect {
237-
let _ = self_clone.disconnect();
238-
let _ = self_clone.reconnect();
239275
}
240276
}
241277
});

socketio/src/client/raw_client.rs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,23 @@ use crate::client::callback::{SocketAnyCallback, SocketCallback};
99
use crate::error::Result;
1010
use std::collections::HashMap;
1111
use std::ops::DerefMut;
12-
use std::sync::{Arc, Mutex};
12+
use std::sync::{Arc, Mutex, RwLock};
1313
use std::time::Duration;
1414
use std::time::Instant;
1515

1616
use crate::socket::Socket as InnerSocket;
1717

18+
#[derive(Default, Clone, Copy, PartialEq)]
19+
pub enum DisconnectReason {
20+
/// There is no known reason for the disconnect; likely a network error
21+
#[default]
22+
Unknown,
23+
/// The user disconnected manually
24+
Manual,
25+
/// The server disconnected
26+
Server,
27+
}
28+
1829
/// Represents an `Ack` as given back to the caller. Holds the internal `id` as
1930
/// well as the current ack'ed state. Holds data which will be accessible as
2031
/// soon as the ack'ed state is set to true. An `Ack` that didn't get ack'ed
@@ -41,6 +52,7 @@ pub struct RawClient {
4152
nsp: String,
4253
// Data send in the opening packet (commonly used as for auth)
4354
auth: Option<Value>,
55+
pub(crate) disconnect_reason: Arc<RwLock<DisconnectReason>>,
4456
}
4557

4658
impl RawClient {
@@ -62,6 +74,7 @@ impl RawClient {
6274
on_any,
6375
outstanding_acks: Arc::new(Mutex::new(Vec::new())),
6476
auth,
77+
disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())),
6578
})
6679
}
6780

@@ -142,7 +155,14 @@ impl RawClient {
142155
///
143156
/// ```
144157
pub fn disconnect(&self) -> Result<()> {
145-
let disconnect_packet =
158+
*(self.disconnect_reason.write()?) = DisconnectReason::Manual;
159+
self.do_disconnect()
160+
}
161+
162+
/// Disconnects this client the same way as `disconnect()` but
163+
/// without setting the `DisconnectReason` to `DisconnectReason::Manual`
164+
pub fn do_disconnect(&self) -> Result<()> {
165+
let disconnect_packet =
146166
Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None);
147167

148168
// TODO: logging
@@ -153,6 +173,10 @@ impl RawClient {
153173
Ok(())
154174
}
155175

176+
pub fn get_disconnect_reason(&self) -> DisconnectReason {
177+
*self.disconnect_reason.read().unwrap()
178+
}
179+
156180
/// Sends a message to the server but `alloc`s an `ack` to check whether the
157181
/// server responded in a given time span. This message takes an event, which
158182
/// could either be one of the common events like "message" or "error" or a
@@ -222,18 +246,32 @@ impl RawClient {
222246
}
223247

224248
pub(crate) fn poll(&self) -> Result<Option<Packet>> {
249+
{
250+
let disconnect_reason = *self.disconnect_reason.read()?;
251+
if disconnect_reason == DisconnectReason::Manual {
252+
// If disconnected manually, return Ok(None) to end iterator
253+
return Ok(None);
254+
}
255+
}
225256
loop {
226257
match self.socket.poll() {
227258
Err(err) => {
228-
self.callback(&Event::Error, err.to_string())?;
229-
return Err(err);
259+
// Check if the disconnection was manual
260+
let disconnect_reason = *self.disconnect_reason.read()?;
261+
if disconnect_reason == DisconnectReason::Manual {
262+
// Return Ok(None) to signal the end of the iterator
263+
return Ok(None);
264+
} else {
265+
self.callback(&Event::Error, err.to_string())?;
266+
return Err(err);
267+
}
230268
}
231269
Ok(Some(packet)) => {
232270
if packet.nsp == self.nsp {
233271
self.handle_socketio_packet(&packet)?;
234272
return Ok(Some(packet));
235273
} else {
236-
// Not our namespace continue polling
274+
// Not our namespace, continue polling
237275
}
238276
}
239277
Ok(None) => return Ok(None),
@@ -369,9 +407,11 @@ impl RawClient {
369407
}
370408
}
371409
PacketId::Connect => {
410+
*(self.disconnect_reason.write()?) = DisconnectReason::default();
372411
self.callback(&Event::Connect, "")?;
373412
}
374413
PacketId::Disconnect => {
414+
*(self.disconnect_reason.write()?) = DisconnectReason::Server;
375415
self.callback(&Event::Close, "")?;
376416
}
377417
PacketId::ConnectError => {

0 commit comments

Comments
 (0)