Skip to content

Commit

Permalink
Add headers to ReconnectSettings
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielChabrowski authored and 1c3t3a committed Nov 24, 2024
1 parent cf9e93e commit ba9abe3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
8 changes: 5 additions & 3 deletions ci/socket-io-restart.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
const { Socket } = require("socket.io");

let createServer = require("http").createServer;
let server = createServer();
const io = require("socket.io")(server);
Expand All @@ -8,8 +6,12 @@ const timeout = 2000;

console.log("Started");
var callback = (client) => {
const headers = client.request.headers;
console.log("headers", headers);
const message = headers.message_back || "test";

console.log("Connected!");
client.emit("message", "test");
client.emit("message", message);
client.on("restart_server", () => {
console.log("will restart in ", timeout, "ms");
io.close();
Expand Down
3 changes: 2 additions & 1 deletion socketio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct ClientBuilder {
pub(crate) on_reconnect: Option<Callback<DynAsyncReconnectSettingsCallback>>,
pub(crate) namespace: String,
tls_config: Option<TlsConnector>,
opening_headers: Option<HeaderMap>,
pub(crate) opening_headers: Option<HeaderMap>,
transport_type: TransportType,
pub(crate) auth: Option<serde_json::Value>,
pub(crate) reconnect: bool,
Expand Down Expand Up @@ -214,6 +214,7 @@ impl ClientBuilder {
/// let mut settings = ReconnectSettings::new();
/// settings.address("http://server?test=123");
/// settings.auth(json!({ "token": "abc" }));
/// settings.opening_header("TRAIL", "abc-123");
/// settings
/// }.boxed()
/// })
Expand Down
50 changes: 48 additions & 2 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use futures_util::{future::BoxFuture, stream, Stream, StreamExt};
use log::{error, trace};
use rand::{thread_rng, Rng};
use rust_engineio::header::{HeaderMap, HeaderValue};
use serde_json::Value;
use tokio::{
sync::RwLock,
Expand Down Expand Up @@ -38,6 +39,7 @@ enum DisconnectReason {
pub struct ReconnectSettings {
address: Option<String>,
auth: Option<serde_json::Value>,
headers: Option<HeaderMap>,
}

impl ReconnectSettings {
Expand All @@ -58,6 +60,19 @@ impl ReconnectSettings {
pub fn auth(&mut self, auth: serde_json::Value) {
self.auth = Some(auth);
}

/// Adds an http header to a container that is going to completely replace opening headers on reconnect.
/// If there are no headers set in `ReconnectSettings`, client will use headers initially set via the builder.
pub fn opening_header<T: Into<HeaderValue>, K: Into<String>>(
&mut self,
key: K,
val: T,
) -> &mut Self {
self.headers
.get_or_insert_with(|| HeaderMap::default())
.insert(key.into(), val.into());
self
}
}

/// A socket which handles communication with the server. It's initialized with
Expand Down Expand Up @@ -112,13 +127,18 @@ impl Client {

if let Some(config) = builder.on_reconnect.as_mut() {
let reconnect_settings = config().await;

if let Some(address) = reconnect_settings.address {
builder.address = address;
}

if let Some(auth) = reconnect_settings.auth {
self.auth = Some(auth);
}

if reconnect_settings.headers.is_some() {
builder.opening_headers = reconnect_settings.headers;
}
}

let socket = builder.inner_create().await?;
Expand Down Expand Up @@ -594,7 +614,7 @@ mod test {
use serde_json::json;
use serial_test::serial;
use tokio::{
sync::mpsc,
sync::{mpsc, Mutex},
time::{sleep, timeout},
};

Expand Down Expand Up @@ -755,6 +775,8 @@ mod test {
static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0);
static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0);
static ON_RECONNECT_CALLED: AtomicUsize = AtomicUsize::new(0);
let latest_message = Arc::new(Mutex::new(String::new()));
let handler_latest_message = latest_message.clone();

let url = crate::test::socket_io_restart_server();

Expand All @@ -772,6 +794,7 @@ mod test {
// Try setting the address to what we already have, just
// to test. This is not strictly necessary in real usage.
settings.address(url.to_string());
settings.opening_header("MESSAGE_BACK", "updated");
settings
}
.boxed()
Expand All @@ -789,11 +812,24 @@ mod test {
}
.boxed()
})
.on("message", |_, _socket| {
.on("message", move |payload, _socket| {
let latest_message = handler_latest_message.clone();
async move {
// test the iterator implementation and make sure there is a constant
// stream of packets, even when reconnecting
MESSAGE_NUM.fetch_add(1, Ordering::Release);

let msg = match payload {
Payload::Text(msg) => msg
.into_iter()
.next()
.expect("there should be one text payload"),
_ => panic!(),
};

let msg = serde_json::from_value(msg).expect("payload should be json string");

*latest_message.lock().await = msg;
}
.boxed()
})
Expand All @@ -808,6 +844,11 @@ mod test {

assert_eq!(load(&CONNECT_NUM), 1, "should connect once");
assert_eq!(load(&MESSAGE_NUM), 1, "should receive one");
assert_eq!(
*latest_message.lock().await,
"test",
"should receive test message"
);

let r = socket.emit("restart_server", json!("")).await;
assert!(r.is_ok(), "should emit restart success");
Expand All @@ -826,6 +867,11 @@ mod test {
load(&ON_RECONNECT_CALLED) > 1,
"should call on_reconnect at least once"
);
assert_eq!(
*latest_message.lock().await,
"updated",
"should receive updated message"
);

socket.disconnect().await?;
Ok(())
Expand Down

0 comments on commit ba9abe3

Please sign in to comment.