From ba9abe3df104dcb984fb352bca8d949a5aef9c67 Mon Sep 17 00:00:00 2001 From: Daniel Chabrowski Date: Mon, 7 Oct 2024 21:39:33 +0200 Subject: [PATCH] Add headers to ReconnectSettings --- ci/socket-io-restart.js | 8 ++-- socketio/src/asynchronous/client/builder.rs | 3 +- socketio/src/asynchronous/client/client.rs | 50 ++++++++++++++++++++- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/ci/socket-io-restart.js b/ci/socket-io-restart.js index b028442a..406b98b3 100644 --- a/ci/socket-io-restart.js +++ b/ci/socket-io-restart.js @@ -1,5 +1,3 @@ -const { Socket } = require("socket.io"); - let createServer = require("http").createServer; let server = createServer(); const io = require("socket.io")(server); @@ -8,8 +6,12 @@ const timeout = 2000; console.log("Started"); var callback = (client) => { + const headers = client.request.headers; + console.log("headers", headers); + const message = headers.message_back || "test"; + console.log("Connected!"); - client.emit("message", "test"); + client.emit("message", message); client.on("restart_server", () => { console.log("will restart in ", timeout, "ms"); io.close(); diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..1d5f8bbb 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -29,7 +29,7 @@ pub struct ClientBuilder { pub(crate) on_reconnect: Option>, pub(crate) namespace: String, tls_config: Option, - opening_headers: Option, + pub(crate) opening_headers: Option, transport_type: TransportType, pub(crate) auth: Option, pub(crate) reconnect: bool, @@ -214,6 +214,7 @@ impl ClientBuilder { /// let mut settings = ReconnectSettings::new(); /// settings.address("http://server?test=123"); /// settings.auth(json!({ "token": "abc" })); + /// settings.opening_header("TRAIL", "abc-123"); /// settings /// }.boxed() /// }) diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index da1d5f09..01991056 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -4,6 +4,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use futures_util::{future::BoxFuture, stream, Stream, StreamExt}; use log::{error, trace}; use rand::{thread_rng, Rng}; +use rust_engineio::header::{HeaderMap, HeaderValue}; use serde_json::Value; use tokio::{ sync::RwLock, @@ -38,6 +39,7 @@ enum DisconnectReason { pub struct ReconnectSettings { address: Option, auth: Option, + headers: Option, } impl ReconnectSettings { @@ -58,6 +60,19 @@ impl ReconnectSettings { pub fn auth(&mut self, auth: serde_json::Value) { self.auth = Some(auth); } + + /// Adds an http header to a container that is going to completely replace opening headers on reconnect. + /// If there are no headers set in `ReconnectSettings`, client will use headers initially set via the builder. + pub fn opening_header, K: Into>( + &mut self, + key: K, + val: T, + ) -> &mut Self { + self.headers + .get_or_insert_with(|| HeaderMap::default()) + .insert(key.into(), val.into()); + self + } } /// A socket which handles communication with the server. It's initialized with @@ -112,6 +127,7 @@ impl Client { if let Some(config) = builder.on_reconnect.as_mut() { let reconnect_settings = config().await; + if let Some(address) = reconnect_settings.address { builder.address = address; } @@ -119,6 +135,10 @@ impl Client { if let Some(auth) = reconnect_settings.auth { self.auth = Some(auth); } + + if reconnect_settings.headers.is_some() { + builder.opening_headers = reconnect_settings.headers; + } } let socket = builder.inner_create().await?; @@ -594,7 +614,7 @@ mod test { use serde_json::json; use serial_test::serial; use tokio::{ - sync::mpsc, + sync::{mpsc, Mutex}, time::{sleep, timeout}, }; @@ -755,6 +775,8 @@ mod test { static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); static ON_RECONNECT_CALLED: AtomicUsize = AtomicUsize::new(0); + let latest_message = Arc::new(Mutex::new(String::new())); + let handler_latest_message = latest_message.clone(); let url = crate::test::socket_io_restart_server(); @@ -772,6 +794,7 @@ mod test { // Try setting the address to what we already have, just // to test. This is not strictly necessary in real usage. settings.address(url.to_string()); + settings.opening_header("MESSAGE_BACK", "updated"); settings } .boxed() @@ -789,11 +812,24 @@ mod test { } .boxed() }) - .on("message", |_, _socket| { + .on("message", move |payload, _socket| { + let latest_message = handler_latest_message.clone(); async move { // test the iterator implementation and make sure there is a constant // stream of packets, even when reconnecting MESSAGE_NUM.fetch_add(1, Ordering::Release); + + let msg = match payload { + Payload::Text(msg) => msg + .into_iter() + .next() + .expect("there should be one text payload"), + _ => panic!(), + }; + + let msg = serde_json::from_value(msg).expect("payload should be json string"); + + *latest_message.lock().await = msg; } .boxed() }) @@ -808,6 +844,11 @@ mod test { assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); + assert_eq!( + *latest_message.lock().await, + "test", + "should receive test message" + ); let r = socket.emit("restart_server", json!("")).await; assert!(r.is_ok(), "should emit restart success"); @@ -826,6 +867,11 @@ mod test { load(&ON_RECONNECT_CALLED) > 1, "should call on_reconnect at least once" ); + assert_eq!( + *latest_message.lock().await, + "updated", + "should receive updated message" + ); socket.disconnect().await?; Ok(())