diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 1ace73691..bba64822c 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -23,7 +23,7 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] [dependencies] -futures-util = { version = "0.3", default_features = false, features = ["std"] } +futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] } tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 6b3af08a8..1493b12eb 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -229,7 +229,7 @@ impl EventLoop { ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { - network.write(outgoing)?; + network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, @@ -247,7 +247,7 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { - network.write(outgoing)?; + network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, @@ -428,7 +428,11 @@ async fn network_connect( async_tungstenite::tokio::client_async(request, tcp_stream).await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] Transport::Wss(tls_config) => { @@ -451,7 +455,11 @@ async fn network_connect( .await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } }; diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 6efc2a41f..58caec90d 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,23 +1,16 @@ -use bytes::BytesMut; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_util::codec::{Decoder, Encoder}; +use futures_util::{FutureExt, SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; use crate::mqttbytes::{self, v4::*}; use crate::{Incoming, MqttState, StateError}; -use std::io; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Buffered writes - pub write: BytesMut, - /// Use to decode MQTT packets - codec: Codec, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, /// Maximum readv count max_readb_count: usize, } @@ -29,69 +22,38 @@ impl Network { max_outgoing_size: usize, ) -> Network { let socket = Box::new(socket) as Box; + let codec = Codec { + max_incoming_size, + max_outgoing_size, + }; + let framed = Framed::new(socket, codec); + Network { - socket, - read: BytesMut::with_capacity(10 * 1024), - write: BytesMut::with_capacity(10 * 1024), - codec: Codec { - max_incoming_size, - max_outgoing_size, - }, + framed, max_readb_count: 10, } } - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let required = match self.codec.decode(&mut self.read) { - Ok(Some(packet)) => return Ok(packet), - // TODO: figure out how not to block - Ok(_) => 2, - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + /// Reads and returns a single packet from network + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } /// Read packets in bulk. This allow replies to be in bulk. This method is used /// after the connection is established to read a bunch of incoming packets pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; + // wait for the first read + let mut res = self.framed.next().await; + let mut count = 1; loop { - match self.codec.decode(&mut self.read) { - Ok(Some(packet)) => { + match res { + Some(Ok(packet)) => { if let Some(packet) = state.handle_incoming_packet(packet)? { - self.write(packet)?; + self.write(packet).await?; } count += 1; @@ -99,47 +61,38 @@ impl Network { break; } } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) | Ok(_) if count > 0 => break, - // NOTE: read atleast 1 packet - Ok(_) => { - self.read_bytes(2).await?; - } - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(e) => return Err(StateError::Deserialization(e)), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => return Err(StateError::Deserialization(e)), + } + // do not wait for subsequent reads + match self.framed.next().now_or_never() { + Some(r) => res = r, + _ => break, }; } Ok(()) } - pub fn write(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { - self.codec - .encode(packet, &mut self.write) - .map_err(Into::into) + /// Serializes packet into write buffer + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization) } - pub async fn connect(&mut self, connect: Connect) -> io::Result<()> { - let mut write = BytesMut::new(); - self.codec - .encode(Packet::Connect(connect), &mut write) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?; + pub async fn connect(&mut self, connect: Connect) -> Result<(), StateError> { + self.write(Packet::Connect(connect)).await?; - self.socket.write_all(&write[..]).await?; - Ok(()) + self.flush().await } - pub async fn flush(&mut self) -> io::Result<()> { - if self.write.is_empty() { - return Ok(()); - } - - self.socket.write_all(&self.write[..]).await?; - self.write.clear(); - Ok(()) + pub async fn flush(&mut self) -> Result<(), crate::state::StateError> { + self.framed + .flush() + .await + .map_err(StateError::Deserialization) } } diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs index 3e7c73d56..f605a1f5a 100644 --- a/rumqttc/src/mqttbytes/v4/codec.rs +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -1,4 +1,4 @@ -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; @@ -17,12 +17,15 @@ impl Decoder for Codec { type Error = Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if src.remaining() == 0 { - return Ok(None); + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + Err(Error::InsufficientBytes(b)) => { + // Get more packets to construct the incomplete packet + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), } - - let packet = Packet::read(src, self.max_incoming_size)?; - Ok(Some(packet)) } } diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 7d96ae441..0a83d57ce 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -570,7 +570,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly if let Err(e) = res { match e { ConnectionError::FlushTimeout => { - assert!(eventloop.network.as_ref().unwrap().write.is_empty()); + assert!(eventloop.network.is_none()); println!("State is being clean properly"); } _ => {