From 9529d64856c3203df02a36ecc41676e91d6eb8ba Mon Sep 17 00:00:00 2001 From: shenjack <3695888@qq.com> Date: Sat, 30 Mar 2024 19:40:27 +0800 Subject: [PATCH] first step to support custom serializer --- Cargo.lock | 4 +- engineio/Cargo.toml | 4 +- .../asynchronous/async_transports/polling.rs | 17 +--- engineio/src/client/client.rs | 14 ++- engineio/src/header.rs | 1 - engineio/src/packet.rs | 86 ++++++++++++++++++- engineio/src/packet/message_pack.rs | 0 engineio/src/packet/normal.rs | 0 engineio/src/socket.rs | 9 +- engineio/src/transports/polling.rs | 16 +--- socketio/src/asynchronous/client/builder.rs | 29 ++++++- socketio/src/client/builder.rs | 61 ++++++++++++- socketio/src/client/mod.rs | 3 +- socketio/src/lib.rs | 2 +- 14 files changed, 199 insertions(+), 47 deletions(-) create mode 100644 engineio/src/packet/message_pack.rs create mode 100644 engineio/src/packet/normal.rs diff --git a/Cargo.lock b/Cargo.lock index 648c4749..85410a7f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1709,7 +1709,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" dependencies = [ - "base64 0.22.0", + "base64 0.21.5", "bytes", "encoding_rs", "futures-channel", @@ -1755,7 +1755,7 @@ dependencies = [ "adler32", "async-stream", "async-trait", - "base64 0.21.5", + "base64 0.22.0", "bytes", "criterion", "futures-util", diff --git a/engineio/Cargo.toml b/engineio/Cargo.toml index 9bc5433f..d9237184 100644 --- a/engineio/Cargo.toml +++ b/engineio/Cargo.toml @@ -14,7 +14,7 @@ license = "MIT" all-features = true [dependencies] -base64 = "0.21.5" +base64 = "0.22.0" bytes = "1" reqwest = { version = "0.12.3", features = ["blocking", "native-tls", "stream"] } adler32 = "1.2.0" @@ -29,7 +29,7 @@ async-trait = "0.1.79" async-stream = "0.3.5" thiserror = "1.0" native-tls = "0.2.11" -url = "2.4.1" +url = "2.5.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } diff --git a/engineio/src/asynchronous/async_transports/polling.rs b/engineio/src/asynchronous/async_transports/polling.rs index d88f1e89..505f2ea8 100644 --- a/engineio/src/asynchronous/async_transports/polling.rs +++ b/engineio/src/asynchronous/async_transports/polling.rs @@ -101,24 +101,11 @@ impl Stream for PollingTransport { #[async_trait] impl AsyncTransport for PollingTransport { - async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> { - let data_to_send = if is_binary_att { - // the binary attachment gets `base64` encoded - let mut packet_bytes = BytesMut::with_capacity(data.len() + 1); - packet_bytes.put_u8(b'b'); - - let encoded_data = general_purpose::STANDARD.encode(data); - packet_bytes.put(encoded_data.as_bytes()); - - packet_bytes.freeze() - } else { - data - }; - + async fn emit(&self, data: Bytes) -> Result<()> { let status = self .client .post(self.address().await?) - .body(data_to_send) + .body(data) .send() .await? .status() diff --git a/engineio/src/client/client.rs b/engineio/src/client/client.rs index dc22ff77..1006f4a3 100644 --- a/engineio/src/client/client.rs +++ b/engineio/src/client/client.rs @@ -5,7 +5,7 @@ use crate::transport::Transport; use crate::error::{Error, Result}; use crate::header::HeaderMap; -use crate::packet::{HandshakePacket, Packet, PacketId}; +use crate::packet::{HandshakePacket, Packet, PacketId, PacketSerializer}; use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport}; use crate::ENGINE_IO_VERSION; use bytes::Bytes; @@ -32,6 +32,7 @@ pub struct ClientBuilder { url: Url, tls_config: Option, headers: Option, + serializer: PacketSerializer, handshake: Option, on_error: OptionalCallback, on_open: OptionalCallback<()>, @@ -55,6 +56,7 @@ impl ClientBuilder { headers: None, tls_config: None, handshake: None, + serializer: PacketSerializer::default(), on_close: OptionalCallback::default(), on_data: OptionalCallback::default(), on_error: OptionalCallback::default(), @@ -63,6 +65,13 @@ impl ClientBuilder { } } + /// Specify Packet Serializer + pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self { + self.serializer = packet_serializer; + + self + } + /// Specify transport's tls config pub fn tls_config(mut self, tls_config: TlsConnector) -> Self { self.tls_config = Some(tls_config); @@ -183,6 +192,7 @@ impl ClientBuilder { Ok(Client { socket: InnerSocket::new( transport.into(), + self.serializer, self.handshake.unwrap(), self.on_close, self.on_data, @@ -228,6 +238,7 @@ impl ClientBuilder { Ok(Client { socket: InnerSocket::new( transport.into(), + self.serializer, self.handshake.unwrap(), self.on_close, self.on_data, @@ -250,6 +261,7 @@ impl ClientBuilder { Ok(Client { socket: InnerSocket::new( transport.into(), + self.serializer, self.handshake.unwrap(), self.on_close, self.on_data, diff --git a/engineio/src/header.rs b/engineio/src/header.rs index 3756b408..237b0b08 100644 --- a/engineio/src/header.rs +++ b/engineio/src/header.rs @@ -5,7 +5,6 @@ use http::{ HeaderValue as HttpHeaderValue, }; use std::collections::HashMap; -use std::convert::TryFrom; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::str::FromStr; diff --git a/engineio/src/packet.rs b/engineio/src/packet.rs index 9238428f..c63abd58 100644 --- a/engineio/src/packet.rs +++ b/engineio/src/packet.rs @@ -2,12 +2,94 @@ use base64::{engine::general_purpose, Engine as _}; use bytes::{BufMut, Bytes, BytesMut}; use serde::{Deserialize, Serialize}; use std::char; -use std::convert::TryFrom; -use std::convert::TryInto; use std::fmt::{Display, Formatter, Result as FmtResult, Write}; use std::ops::Index; use crate::error::{Error, Result}; + +pub struct PacketSerializer { + decode: Box Result + Send + Sync>, + encode: Box Bytes + Send + Sync>, +} + +fn default_decode(bytes: Bytes) -> Result { + if bytes.is_empty() { + return Err(Error::IncompletePacket()); + } + + let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b'; + + // only 'messages' packets could be encoded + let packet_id = if is_base64 { + PacketId::MessageBinary + } else { + (*bytes.first().ok_or(Error::IncompletePacket())?).try_into()? + }; + + if bytes.len() == 1 && packet_id == PacketId::Message { + return Err(Error::IncompletePacket()); + } + + let data: Bytes = bytes.slice(1..); + + Ok(Packet { + packet_id, + data: if is_base64 { + Bytes::from(general_purpose::STANDARD.decode(data.as_ref())?) + } else { + data + }, + }) +} + +fn default_encode(packet: Packet) -> Bytes { + let mut result = BytesMut::with_capacity(packet.data.len() + 1); + result.put_u8(packet.packet_id.to_string_byte()); + if packet.packet_id == PacketId::MessageBinary { + result.extend(general_purpose::STANDARD.encode(packet.data).into_bytes()); + } else { + result.put(packet.data); + } + result.freeze() +} + + +impl PacketSerializer { + const SEPARATOR: char = '\x1e'; + + pub fn new( + decode: Box Result + Send + Sync>, + encode: Box Bytes + Send + Sync>, + ) -> Self { + Self { + decode, + encode, + } + } + + pub fn default() -> Self { + let decode = Box::new(default_decode); + let encode = Box::new(default_encode); + Self::new(decode, encode) + } + + pub fn decode(&self, datas: Bytes) -> Result { + (self.decode)(datas) + } + + pub fn decode_payload(&self, datas: Bytes) -> Result { + datas + .split(|&c| c as char == PacketSerializer::SEPARATOR) + .map(|slice| self.decode(datas.slice_ref(slice))) + .collect::>>() + .map(Payload) + } + + pub fn encode(&self, packet: Packet) -> Bytes { + (self.encode)(packet) + } +} + /// Enumeration of the `engine.io` `Packet` types. #[derive(Copy, Clone, Eq, PartialEq, Debug)] pub enum PacketId { diff --git a/engineio/src/packet/message_pack.rs b/engineio/src/packet/message_pack.rs new file mode 100644 index 00000000..e69de29b diff --git a/engineio/src/packet/normal.rs b/engineio/src/packet/normal.rs new file mode 100644 index 00000000..e69de29b diff --git a/engineio/src/socket.rs b/engineio/src/socket.rs index b5231ab6..d6d24846 100644 --- a/engineio/src/socket.rs +++ b/engineio/src/socket.rs @@ -2,9 +2,8 @@ use crate::callback::OptionalCallback; use crate::transport::TransportType; use crate::error::{Error, Result}; -use crate::packet::{HandshakePacket, Packet, PacketId, Payload}; +use crate::packet::{HandshakePacket, Packet, PacketId, PacketSerializer, Payload}; use bytes::Bytes; -use std::convert::TryFrom; use std::sync::RwLock; use std::time::Duration; use std::{fmt::Debug, sync::atomic::Ordering}; @@ -23,6 +22,7 @@ pub const DEFAULT_MAX_POLL_TIMEOUT: Duration = Duration::from_secs(45); #[derive(Clone)] pub struct Socket { transport: Arc, + serializer: PacketSerializer, on_close: OptionalCallback<()>, on_data: OptionalCallback, on_error: OptionalCallback, @@ -40,6 +40,7 @@ pub struct Socket { impl Socket { pub(crate) fn new( transport: TransportType, + serializer: PacketSerializer, handshake: HandshakePacket, on_close: OptionalCallback<()>, on_data: OptionalCallback, @@ -56,6 +57,7 @@ impl Socket { on_open, on_packet, transport: Arc::new(transport), + serializer, connected: Arc::new(AtomicBool::default()), last_ping: Arc::new(Mutex::new(Instant::now())), last_pong: Arc::new(Mutex::new(Instant::now())), @@ -148,7 +150,8 @@ impl Socket { continue; } - let payload = Payload::try_from(data)?; + // let payload = Payload::try_from(data)?; + let payload = self.serializer.decode_payload(data)?; let mut iter = payload.into_iter(); if let Some(packet) = iter.next() { diff --git a/engineio/src/transports/polling.rs b/engineio/src/transports/polling.rs index 26aee87b..c4d0d324 100644 --- a/engineio/src/transports/polling.rs +++ b/engineio/src/transports/polling.rs @@ -49,23 +49,11 @@ impl PollingTransport { } impl Transport for PollingTransport { - fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> { - let data_to_send = if is_binary_att { - // the binary attachment gets `base64` encoded - let mut packet_bytes = BytesMut::with_capacity(data.len() + 1); - packet_bytes.put_u8(b'b'); - - let encoded_data = general_purpose::STANDARD.encode(data); - packet_bytes.put(encoded_data.as_bytes()); - - packet_bytes.freeze() - } else { - data - }; + fn emit(&self, data: Bytes, _is_binary_att: bool) -> Result<()> { let status = self .client .post(self.address()?) - .body(data_to_send) + .body(data) .send()? .status() .as_u16(); diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..5f066585 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -8,7 +8,7 @@ use rust_engineio::{ use std::collections::HashMap; use url::Url; -use crate::{error::Result, Event, Payload, TransportType}; +use crate::{error::Result, Event, PacketSerializer, Payload, TransportType}; use super::{ callback::{ @@ -31,6 +31,7 @@ pub struct ClientBuilder { tls_config: Option, opening_headers: Option, transport_type: TransportType, + packet_serializer: PacketSerializer, pub(crate) auth: Option, pub(crate) reconnect: bool, pub(crate) reconnect_on_disconnect: bool, @@ -89,7 +90,8 @@ impl ClientBuilder { namespace: "/".to_owned(), tls_config: None, opening_headers: None, - transport_type: TransportType::Any, + transport_type: TransportType::default(), + packet_serializer: PacketSerializer::default(), auth: None, reconnect: true, reconnect_on_disconnect: false, @@ -395,6 +397,29 @@ impl ClientBuilder { self } + /// Specifies the [`PacketSerializer`] to use for encoding and decoding packets. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::ClientBuilder, PacketSerializer}; + /// + /// #[tokio::main] + /// async fn main() { + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed()) + /// .packet_serializer(PacketSerializer::Normal) + /// .connect() + /// .await + /// .expect("connection failed"); + /// } + /// ``` + pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self { + self.packet_serializer = packet_serializer; + + self + } + /// Connects the socket to a certain endpoint. This returns a connected /// [`Client`] instance. This method returns an [`std::result::Result::Err`] /// value if something goes wrong during connection. Also starts a separate diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 724971f0..2ce44dbb 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -27,6 +27,36 @@ pub enum TransportType { Polling, } +impl Default for TransportType { + fn default() -> Self { + TransportType::Any + } +} + +/// Serializer of Engine.IO packet +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum PacketSerializer { + /// Normal serializer + Normal, + /// MessagePack serializer + MessagePack, +} + +impl Into for PacketSerializer { + fn into(self) -> rust_engineio::packet::PacketSerializer { + match self { + PacketSerializer::Normal => rust_engineio::packet::PacketSerializer::Normal, + PacketSerializer::MessagePack => rust_engineio::packet::PacketSerializer::MessagePack, + } + } +} + +impl Default for PacketSerializer { + fn default() -> Self { + PacketSerializer::Normal + } +} + /// A builder class for a `socket.io` socket. This handles setting up the client and /// configuring the callback, the namespace and metadata of the socket. If no /// namespace is specified, the default namespace `/` is taken. The `connect` method @@ -40,6 +70,7 @@ pub struct ClientBuilder { tls_config: Option, opening_headers: Option, transport_type: TransportType, + packet_serializer: PacketSerializer, auth: Option, pub(crate) reconnect: bool, pub(crate) reconnect_on_disconnect: bool, @@ -90,7 +121,8 @@ impl ClientBuilder { namespace: "/".to_owned(), tls_config: None, opening_headers: None, - transport_type: TransportType::Any, + transport_type: TransportType::default(), + packet_serializer: PacketSerializer::default(), auth: None, reconnect: true, reconnect_on_disconnect: false, @@ -306,6 +338,30 @@ impl ClientBuilder { self } + /// Specifies the [`PacketSerializer`] to use for encoding and decoding packets. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::ClientBuilder, PacketSerializer}; + /// use futures_util::FutureExt; + /// + /// #[tokio::main] + /// async fn main() { + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed()) + /// .packet_serializer(PacketSerializer::Normal) + /// .connect() + /// .await + /// .expect("connection failed"); + /// } + /// ``` + pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self { + self.packet_serializer = packet_serializer; + + self + } + /// Connects the socket to a certain endpoint. This returns a connected /// [`Client`] instance. This method returns an [`std::result::Result::Err`] /// value if something goes wrong during connection. Also starts a separate @@ -341,7 +397,8 @@ impl ClientBuilder { url.set_path("/socket.io/"); } - let mut builder = EngineIoClientBuilder::new(url); + let mut builder = + EngineIoClientBuilder::new(url).packet_serializer(self.packet_serializer.into()); if let Some(tls_config) = self.tls_config { builder = builder.tls_config(tls_config); diff --git a/socketio/src/client/mod.rs b/socketio/src/client/mod.rs index e3884b64..1a1d34a4 100644 --- a/socketio/src/client/mod.rs +++ b/socketio/src/client/mod.rs @@ -1,8 +1,7 @@ mod builder; mod raw_client; -pub use builder::ClientBuilder; -pub use builder::TransportType; +pub use builder::{ClientBuilder, PacketSerializer, TransportType}; pub use client::Client; pub use raw_client::RawClient; diff --git a/socketio/src/lib.rs b/socketio/src/lib.rs index b913eb4d..e2ea94d8 100644 --- a/socketio/src/lib.rs +++ b/socketio/src/lib.rs @@ -193,7 +193,7 @@ pub use error::Error; pub use {event::Event, payload::Payload}; -pub use client::{ClientBuilder, RawClient, TransportType}; +pub use client::{ClientBuilder, PacketSerializer, RawClient, TransportType}; // TODO: 0.4.0 remove #[deprecated(since = "0.3.0-alpha-2", note = "Socket renamed to Client")]