diff --git a/Cargo.lock b/Cargo.lock index 2efed53fe..e6f8648cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1951,6 +1951,8 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-stream", + "tokio-util", "url", "ws_stream_tungstenite", ] @@ -2543,6 +2545,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/benchmarks/parsers/v4.rs b/benchmarks/parsers/v4.rs index 8a97bf2f2..4fbde7bab 100644 --- a/benchmarks/parsers/v4.rs +++ b/benchmarks/parsers/v4.rs @@ -1,6 +1,7 @@ use bytes::{Buf, BytesMut}; use rumqttc::mqttbytes::v4; use rumqttc::mqttbytes::QoS; +use rumqttc::Packet; use std::time::Instant; mod common; @@ -31,7 +32,7 @@ fn main() { let start = Instant::now(); let mut packets = Vec::with_capacity(count); while output.has_remaining() { - let packet = v4::read(&mut output, 10 * 1024).unwrap(); + let packet = Packet::read(&mut output, 10 * 1024).unwrap(); packets.push(packet); } diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index cc5c0c1e3..70cce1df6 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +* Refactor `Network`, simplify with `Framed` + ### Deprecated ### Removed diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 551dcab53..bba64822c 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -23,8 +23,9 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] [dependencies] -futures-util = { version = "0.3", default_features = false, features = ["std"] } +futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] } tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } +tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } @@ -47,6 +48,7 @@ native-tls = { version = "0.2.11", optional = true } url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } +tokio-stream = "0.1.15" [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index a8aee76c7..5317e437a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -2,7 +2,7 @@ use crate::{framed::Network, Transport}; use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::mqttbytes::v4::*; use flume::{bounded, Receiver, Sender}; use tokio::net::{lookup_host, TcpSocket, TcpStream}; @@ -38,8 +38,6 @@ pub enum ConnectionError { MqttState(#[from] StateError), #[error("Network timeout")] NetworkTimeout, - #[error("Flush timeout")] - FlushTimeout, #[cfg(feature = "websocket")] #[error("Websocket: {0}")] Websocket(#[from] async_tungstenite::tungstenite::error::Error), @@ -81,7 +79,7 @@ pub struct EventLoop { /// Pending packets from last session pub pending: VecDeque, /// Network connection to the broker - network: Option, + pub network: Option, /// Keep alive time keepalive_timeout: Option>>, pub network_options: NetworkOptions, @@ -104,11 +102,10 @@ impl EventLoop { let pending = VecDeque::new(); let max_inflight = mqtt_options.inflight; let manual_acks = mqtt_options.manual_acks; - let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size; EventLoop { mqtt_options, - state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size), + state: MqttState::new(max_inflight, manual_acks), requests_tx, requests_rx, pending, @@ -174,7 +171,6 @@ impl EventLoop { // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.mqtt_options.inflight; let collision = self.state.collision.is_some(); - let network_timeout = Duration::from_secs(self.network_options.connection_timeout()); // Read buffered events from previous polls before calling a new poll if let Some(event) = self.state.events.pop_front() { @@ -186,13 +182,11 @@ impl EventLoop { // instead of returning a None event, we try again. select! { // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; + o = network.read() => { + let incoming = o?; + if let Some(packet) = self.state.handle_incoming_packet(incoming)? { + network.write(packet).await?; + } Ok(self.state.events.pop_front().unwrap()) }, // Handles pending and new requests. @@ -229,11 +223,10 @@ impl EventLoop { self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.write(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), @@ -245,11 +238,10 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq(PingReq))?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + network.write(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } } @@ -351,12 +343,19 @@ async fn network_connect( options: &MqttOptions, network_options: NetworkOptions, ) -> Result { + let network_timeout = Duration::from_secs(network_options.connection_timeout()); // Process Unix files early, as proxy is not supported for them. #[cfg(unix)] if matches!(options.transport(), Transport::Unix) { let file = options.broker_addr.as_str(); let socket = UnixStream::connect(Path::new(file)).await?; - let network = Network::new(socket, options.max_incoming_packet_size); + let network = Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + options.network_buffer_capacity, + ); return Ok(network); } @@ -369,7 +368,7 @@ async fn network_connect( _ => options.broker_address(), }; - let tcp_stream: Box = { + let tcp_stream: Box = { #[cfg(feature = "proxy")] match options.proxy() { Some(proxy) => proxy.connect(&domain, port, network_options).await?, @@ -388,13 +387,25 @@ async fn network_connect( }; let network = match options.transport() { - Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size), + Transport::Tcp => Network::new( + tcp_stream, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + options.network_buffer_capacity, + ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, options.max_incoming_packet_size) + Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + options.network_buffer_capacity, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), @@ -413,7 +424,13 @@ async fn network_connect( async_tungstenite::tokio::client_async(request, tcp_stream).await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + options.network_buffer_capacity, + ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] Transport::Wss(tls_config) => { @@ -436,7 +453,13 @@ async fn network_connect( .await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + options.network_buffer_capacity, + ) } }; @@ -462,7 +485,7 @@ async fn mqtt_connect( } // send mqtt connect packet - network.connect(connect).await?; + network.write(Packet::Connect(connect)).await?; // validate connack match network.read().await? { diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index b0a536e78..336de5e5d 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,121 +1,83 @@ -use bytes::BytesMut; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use futures_util::{SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::timeout; +use tokio_util::codec::Framed; use crate::mqttbytes::{self, v4::*}; -use crate::{Incoming, MqttState, StateError}; -use std::io; +use crate::{Incoming, StateError}; +use std::time::Duration; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: usize, - /// Maximum readv count - max_readb_count: usize, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, + /// Time within which network write operations should complete + timeout: Duration, + /// Capacity upto which buffering is good + buffer_capacity: usize, } impl Network { - pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { - let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: usize, + max_outgoing_size: usize, + timeout: Duration, + buffer_capacity: usize, + ) -> Network { + let socket = Box::new(socket) as Box; + let codec = Codec { max_incoming_size, - max_readb_count: 10, - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } + max_outgoing_size, + }; + let framed = Framed::with_capacity(socket, codec, buffer_capacity); - total_read += read; - if total_read >= required { - return Ok(total_read); - } + Network { + framed, + timeout, + buffer_capacity, } } - pub async fn read(&mut self) -> io::Result { - loop { - let required = match read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), + None => Err(StateError::ConnectionClosed), } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; - loop { - match read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; + /// Write packets into buffer, flushes `Connect`/`PingReq`/`PingResp` packets instantly, + /// or on breaching buffer capacity + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + let packet_size = packet.size(); + let should_flush = match packet { + Packet::Connect(_) | Packet::PingReq | Packet::PingResp => true, + _ => false, + }; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; - count += 1; - if count >= self.max_readb_count { - return Ok(()); - } - } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(e) => return Err(StateError::Deserialization(e)), - }; + if should_flush || self.framed.write_buffer().len() + packet_size >= self.buffer_capacity { + self.flush().await?; } - } - - pub async fn connect(&mut self, connect: Connect) -> io::Result { - let mut write = BytesMut::new(); - let len = match connect.write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - self.socket.write_all(&write[..]).await?; - Ok(len) + Ok(()) } - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); + /// Force flush all packets in buffer, reset count + pub async fn flush(&mut self) -> Result<(), StateError> { + match timeout(self.timeout, self.framed.flush()).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(_) => Err(StateError::FlushTimeout), } - - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) } } -pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} -impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {} +impl AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 43dbb3bed..4c9acf3e8 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -200,25 +200,6 @@ pub enum Request { Disconnect(Disconnect), } -impl Request { - fn size(&self) -> usize { - match &self { - Request::Publish(publish) => publish.size(), - Request::PubAck(puback) => puback.size(), - Request::PubRec(pubrec) => pubrec.size(), - Request::PubComp(pubcomp) => pubcomp.size(), - Request::PubRel(pubrel) => pubrel.size(), - Request::PingReq(pingreq) => pingreq.size(), - Request::PingResp(pingresp) => pingresp.size(), - Request::Subscribe(subscribe) => subscribe.size(), - Request::SubAck(suback) => suback.size(), - Request::Unsubscribe(unsubscribe) => unsubscribe.size(), - Request::UnsubAck(unsuback) => unsuback.size(), - Request::Disconnect(disconn) => disconn.size(), - } - } -} - impl From for Request { fn from(publish: Publish) -> Request { Request::Publish(publish) @@ -461,8 +442,8 @@ pub struct MqttOptions { max_outgoing_packet_size: usize, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Network buffer capacity in memory + network_buffer_capacity: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -502,7 +483,7 @@ impl MqttOptions { max_incoming_packet_size: 10 * 1024, max_outgoing_packet_size: 10 * 1024, request_channel_capacity: 10, - max_request_batch: 0, + network_buffer_capacity: 10 * 1024, pending_throttle: Duration::from_micros(0), inflight: 100, last_will: None, @@ -661,6 +642,12 @@ impl MqttOptions { self.request_channel_capacity } + /// Maximum buffer capacity before network flush + pub fn set_network_buffer_capacity(&mut self, network_buffer_capacity: usize) -> &mut Self { + self.network_buffer_capacity = network_buffer_capacity; + self + } + /// Enables throttling and sets outoing message rate to the specified 'rate' pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { self.pending_throttle = duration; @@ -861,12 +848,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(network_buffer_capacity) = queries + .remove("network_buffer_capacity_num") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.network_buffer_capacity = network_buffer_capacity; } if let Some(pending_throttle) = queries @@ -906,7 +893,7 @@ impl Debug for MqttOptions { .field("credentials", &self.credentials) .field("max_packet_size", &self.max_incoming_packet_size) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("network_buffer_capacity", &self.network_buffer_capacity) .field("pending_throttle", &self.pending_throttle) .field("inflight", &self.inflight) .field("last_will", &self.last_will) @@ -989,7 +976,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&network_buffer_capacity_num=foo"), OptionError::MaxRequestBatch ); assert_eq!( diff --git a/rumqttc/src/mqttbytes/mod.rs b/rumqttc/src/mqttbytes/mod.rs index 69858d80f..3345b897f 100644 --- a/rumqttc/src/mqttbytes/mod.rs +++ b/rumqttc/src/mqttbytes/mod.rs @@ -13,7 +13,7 @@ pub mod v4; pub use topic::*; /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Expected Connect, received: {0:?}")] NotConnect(PacketType), @@ -60,6 +60,10 @@ pub enum Error { /// proceed further #[error("At least {0} more bytes required to frame packet")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: usize, max: usize }, } /// MQTT packet type diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs new file mode 100644 index 000000000..c8fb91739 --- /dev/null +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -0,0 +1,73 @@ +use bytes::BytesMut; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: usize, + /// Maximum packet size allowed by broker + pub max_outgoing_size: usize, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + // NOTE: not enough bytes to construct packet, reserve enough in src buffer + Err(Error::InsufficientBytes(b)) => { + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), + } + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::{mqttbytes::Error, Packet, Publish, QoS}; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: 100, + max_outgoing_size: 200, + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 281, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/mqttbytes/v4/connack.rs b/rumqttc/src/mqttbytes/v4/connack.rs index 453919515..65a0da48b 100644 --- a/rumqttc/src/mqttbytes/v4/connack.rs +++ b/rumqttc/src/mqttbytes/v4/connack.rs @@ -61,6 +61,13 @@ impl ConnAck { Ok(1 + count + len) } + + pub fn size(&self) -> usize { + let len = self.len(); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } /// Connection return code type diff --git a/rumqttc/src/mqttbytes/v4/connect.rs b/rumqttc/src/mqttbytes/v4/connect.rs index cdba10140..8732384fc 100644 --- a/rumqttc/src/mqttbytes/v4/connect.rs +++ b/rumqttc/src/mqttbytes/v4/connect.rs @@ -132,6 +132,13 @@ impl Connect { buffer[flags_index] = connect_flags; Ok(1 + count + len) } + + pub fn size(&self) -> usize { + let len = self.len(); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } /// LastWill that broker forwards on behalf of the client diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index abe456127..3621945de 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -1,5 +1,6 @@ use super::*; +mod codec; mod connack; mod connect; mod disconnect; @@ -14,6 +15,7 @@ mod subscribe; mod unsuback; mod unsubscribe; +pub use codec::*; pub use connack::*; pub use connect::*; pub use disconnect::*; @@ -47,43 +49,93 @@ pub enum Packet { Disconnect, } -/// Reads a stream of bytes and extracts next MQTT packet out of it -pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { - let fixed_header = check(stream.iter(), max_size)?; +impl Packet { + pub fn size(&self) -> usize { + match self { + Self::Publish(publish) => publish.size(), + Self::Subscribe(subscription) => subscription.size(), + Self::Unsubscribe(unsubscribe) => unsubscribe.size(), + Self::ConnAck(ack) => ack.size(), + Self::PubAck(ack) => ack.size(), + Self::SubAck(ack) => ack.size(), + Self::UnsubAck(unsuback) => unsuback.size(), + Self::PubRec(pubrec) => pubrec.size(), + Self::PubRel(pubrel) => pubrel.size(), + Self::PubComp(pubcomp) => pubcomp.size(), + Self::Connect(connect) => connect.size(), + Self::PingReq => PingReq.size(), + Self::PingResp => PingResp.size(), + Self::Disconnect => Disconnect.size(), + } + } + + /// Reads a stream of bytes and extracts next MQTT packet out of it + pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; - // Test with a stream with exactly the size to check border panics - let packet = stream.split_to(fixed_header.frame_length()); - let packet_type = fixed_header.packet_type()?; + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } - if fixed_header.remaining_len == 0 { - // no payload packets - return match packet_type { - PacketType::PingReq => Ok(Packet::PingReq), - PacketType::PingResp => Ok(Packet::PingResp), - PacketType::Disconnect => Ok(Packet::Disconnect), - _ => Err(Error::PayloadRequired), + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => { + Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?) + } + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, }; + + Ok(packet) } - let packet = packet.freeze(); - let packet = match packet_type { - PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), - PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), - PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), - PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), - PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), - PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), - PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), - PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), - PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), - PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), - PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), - PacketType::PingReq => Packet::PingReq, - PacketType::PingResp => Packet::PingResp, - PacketType::Disconnect => Packet::Disconnect, - }; + /// Serializes the MQTT packet into a stream of bytes + pub fn write(&self, stream: &mut BytesMut, max_size: usize) -> Result { + if self.size() > max_size { + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size(), + max: max_size, + }); + } - Ok(packet) + match self { + Packet::Connect(c) => c.write(stream), + Packet::ConnAck(c) => c.write(stream), + Packet::Publish(p) => p.write(stream), + Packet::PubAck(p) => p.write(stream), + Packet::PubRec(p) => p.write(stream), + Packet::PubRel(p) => p.write(stream), + Packet::PubComp(p) => p.write(stream), + Packet::Subscribe(s) => s.write(stream), + Packet::SubAck(s) => s.write(stream), + Packet::Unsubscribe(u) => u.write(stream), + Packet::UnsubAck(u) => u.write(stream), + Packet::PingReq => PingReq.write(stream), + Packet::PingResp => PingResp.write(stream), + Packet::Disconnect => Disconnect.write(stream), + } + } } /// Return number of remaining length bytes required for encoding length diff --git a/rumqttc/src/proxy.rs b/rumqttc/src/proxy.rs index 3dbe741cf..e7f84cd37 100644 --- a/rumqttc/src/proxy.rs +++ b/rumqttc/src/proxy.rs @@ -1,5 +1,5 @@ use crate::eventloop::socket_connect; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::NetworkOptions; use std::io; @@ -46,10 +46,10 @@ impl Proxy { broker_addr: &str, broker_port: u16, network_options: NetworkOptions, - ) -> Result, ProxyError> { + ) -> Result, ProxyError> { let proxy_addr = format!("{}:{}", self.addr, self.port); - let tcp: Box = Box::new(socket_connect(proxy_addr, network_options).await?); + let tcp: Box = Box::new(socket_connect(proxy_addr, network_options).await?); let mut tcp = match self.ty { ProxyType::Http => tcp, #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] @@ -67,7 +67,7 @@ impl ProxyAuth { self, host: &str, port: u16, - tcp_stream: &mut Box, + tcp_stream: &mut Box, ) -> Result<(), ProxyError> { match self { Self::None => async_http_proxy::http_connect_tokio(tcp_stream, host, port).await?, diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index da33bd2f2..06a75c1ba 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -2,7 +2,6 @@ use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; -use bytes::BytesMut; use std::collections::VecDeque; use std::{io, time::Instant}; @@ -30,8 +29,10 @@ pub enum StateError { EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] Deserialization(#[from] mqttbytes::Error), - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: usize, max: usize }, + #[error("Flush timeout")] + FlushTimeout, + #[error("Connection Closed")] + ConnectionClosed, } /// State of the mqtt connection. @@ -70,19 +71,15 @@ pub struct MqttState { pub collision: Option, /// Buffered incoming packets pub events: VecDeque, - /// Write buffer - pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Maximum outgoing packet size, set via MqttOptions - pub max_outgoing_packet_size: usize, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, max_outgoing_packet_size: usize) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -99,9 +96,7 @@ impl MqttState { collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), - write: BytesMut::with_capacity(10 * 1024), manual_acks, - max_outgoing_packet_size, } } @@ -135,7 +130,6 @@ impl MqttState { self.await_pingresp = false; self.collision_ping_count = 0; self.inflight = 0; - self.write.clear(); pending } @@ -145,10 +139,11 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - // Enforce max outgoing packet size - self.check_size(request.size())?; - match request { + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { Request::Publish(publish) => self.outgoing_publish(publish)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, @@ -161,56 +156,58 @@ impl MqttState { }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { - let out = match &packet { - Incoming::PingResp => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(_suback) => self.handle_incoming_suback(), - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + pub fn handle_incoming_packet( + &mut self, + packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(_suback) => self.handle_incoming_suback()?, + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + + Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_suback(&mut self) -> Result, StateError> { + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { let qos = publish.qos; match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); - self.outgoing_puback(puback)?; + return self.outgoing_puback(puback); } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -218,45 +215,41 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - self.outgoing_pubrec(pubrec)?; + return self.outgoing_pubrec(pubrec); } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) .ok_or(StateError::Unsolicited(puback.pkid))?; self.last_puback = puback.pkid; - let v = match publish.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited puback packet: {:?}", puback.pkid); - Err(StateError::Unsolicited(puback.pkid)) - } - }; + publish.take().ok_or({ + error!("Unsolicited puback packet: {:?}", puback.pkid); + StateError::Unsolicited(puback.pkid) + })?; - if let Some(publish) = self.check_collision(puback.pkid) { + self.inflight -= 1; + let packet = self.check_collision(puback.pkid).map(|publish| { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - } - v + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -265,11 +258,11 @@ impl MqttState { Some(_) => { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - PubRel::new(pubrec.pkid).write(&mut self.write)?; - + let pubrel = PubRel { pkid: pubrec.pkid }; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -278,17 +271,18 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) .ok_or(StateError::Unsolicited(pubrel.pkid))?; match publish.take() { Some(_) => { - PubComp::new(pubrel.pkid).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); + let pubcomp = PubComp { pkid: pubrel.pkid }; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubComp(pubcomp))) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -297,38 +291,37 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { - if let Some(publish) = self.check_collision(pubcomp.pkid) { - publish.write(&mut self.write)?; + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + self.outgoing_rel + .get_mut(pubcomp.pkid as usize) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + .take() + .ok_or({ + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + StateError::Unsolicited(pubcomp.pkid) + })?; + + self.inflight -= 1; + let packet = self.check_collision(pubcomp.pkid).map(|publish| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - } - let pubrel = self - .outgoing_rel - .get_mut(pubcomp.pkid as usize) - .ok_or(StateError::Unsolicited(pubcomp.pkid))?; - match pubrel.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - Err(StateError::Unsolicited(pubcomp.pkid)) - } - } + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -345,7 +338,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -361,41 +354,40 @@ impl MqttState { publish.payload.len() ); - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - PubRel::new(pubrel.pkid).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { - puback.write(&mut self.write)?; + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { - pubrec.write(&mut self.write)?; + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRec(pubrec))) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -421,13 +413,16 @@ impl MqttState { elapsed_out.as_millis() ); - PingReq.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PingReq)) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -440,13 +435,16 @@ impl MqttState { subscription.filters, subscription.pkid ); - subscription.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Subscribe(subscription))) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -455,19 +453,19 @@ impl MqttState { unsub.topics, unsub.pkid ); - unsub.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Unsubscribe(unsub))) } - fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + fn outgoing_disconnect(&mut self) -> Result, StateError> { debug!("Disconnect"); - Disconnect.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Disconnect)) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -480,17 +478,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - if pkt_size > self.max_outgoing_packet_size { - Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: self.max_outgoing_packet_size, - }) - } else { - Ok(()) - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets @@ -532,7 +519,6 @@ mod test { use crate::mqttbytes::v4::*; use crate::mqttbytes::*; use crate::{Event, Incoming, Outgoing, Request}; - use bytes::BufMut; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -554,7 +540,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(100, false, usize::MAX) + MqttState::new(100, false) } #[test] @@ -574,25 +560,6 @@ mod test { } } - #[test] - fn outgoing_max_packet_size_check() { - let mut mqtt = MqttState::new(100, false, 200); - - let small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(small_publish)) - .is_ok(), - true - ); - - let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(large_publish)) - .is_ok(), - false - ); - } - #[test] fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { let mut mqtt = build_mqttstate(); @@ -702,8 +669,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -769,15 +735,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrec(&PubRec::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -789,15 +756,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrel(&PubRel::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -848,15 +816,6 @@ mod test { mqtt.outgoing_ping().unwrap(); } - #[test] - fn state_should_be_clean_properly() { - let mut mqtt = build_mqttstate(); - mqtt.write.put(&b"test"[..]); - // After this clean state.write should be empty - mqtt.clean(); - assert!(mqtt.write.is_empty()); - } - #[test] fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); diff --git a/rumqttc/src/tls.rs b/rumqttc/src/tls.rs index c8e775712..f80dba641 100644 --- a/rumqttc/src/tls.rs +++ b/rumqttc/src/tls.rs @@ -16,7 +16,7 @@ use std::io::{BufReader, Cursor}; #[cfg(feature = "use-rustls")] use std::sync::Arc; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::TlsConfiguration; #[cfg(feature = "use-native-tls")] @@ -166,9 +166,9 @@ pub async fn tls_connect( addr: &str, _port: u16, tls_config: &TlsConfiguration, - tcp: Box, -) -> Result, Error> { - let tls: Box = match tls_config { + tcp: Box, +) -> Result, Error> { + let tls: Box = match tls_config { #[cfg(feature = "use-rustls")] TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => { let connector = rustls_connector(tls_config).await?; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 36c10971d..c0ba2ee72 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -2,7 +2,7 @@ use super::framed::Network; use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -210,17 +210,21 @@ impl EventLoop { self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.write(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), }, // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; + o = network.read() => { + let incoming = o?; + if let Some(packet) = self.state.handle_incoming_packet(incoming)? { + network.write(packet).await?; + } + Ok(self.state.events.pop_front().unwrap()) }, // We generate pings irrespective of network activity. This keeps the ping logic @@ -229,8 +233,10 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + network.write(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } } @@ -276,7 +282,9 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne } async fn network_connect(options: &MqttOptions) -> Result { - let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); + let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); // incoming == outgoing + let max_outgoing_pkt_size = Some(options.default_max_incoming_size); + let network_timeout = Duration::from_secs(options.network_options.connection_timeout()); // Override default value if max_packet_size is set on `connect_properties` if let Some(connect_props) = &options.connect_properties { @@ -291,7 +299,13 @@ async fn network_connect(options: &MqttOptions) -> Result Result options.broker_address(), }; - let tcp_stream: Box = { + let tcp_stream: Box = { #[cfg(feature = "proxy")] match options.proxy() { Some(proxy) => { @@ -327,13 +341,25 @@ async fn network_connect(options: &MqttOptions) -> Result Network::new(tcp_stream, max_incoming_pkt_size), + Transport::Tcp => Network::new( + tcp_stream, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + options.network_buffer_capacity, + ), #[cfg(any(feature = "use-native-tls", feature = "use-rustls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, max_incoming_pkt_size) + Network::new( + socket, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + options.network_buffer_capacity, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), @@ -352,7 +378,13 @@ async fn network_connect(options: &MqttOptions) -> Result { @@ -375,7 +407,13 @@ async fn network_connect(options: &MqttOptions) -> Result frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: Option, - /// Maximum readv count - max_readb_count: usize, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, + /// Time within which network write operations should complete + timeout: Duration, + /// Capacity upto which buffering is good + buffer_capacity: usize, } impl Network { - pub fn new(socket: impl N + 'static, max_incoming_size: Option) -> Network { - let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: Option, + max_outgoing_size: Option, + timeout: Duration, + buffer_capacity: usize, + ) -> Network { + let socket = Box::new(socket) as Box; + let codec = Codec { max_incoming_size, - max_readb_count: 10, - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } + max_outgoing_size, + }; + let framed = Framed::with_capacity(socket, codec, buffer_capacity); - total_read += read; - if total_read >= required { - return Ok(total_read); - } + Network { + framed, + timeout, + buffer_capacity, } } - pub async fn read(&mut self) -> io::Result { - loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), + None => Err(StateError::ConnectionClosed), } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; - loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; + /// Write packets into buffer, flushes `Connect`/`PingReq`/`PingResp` packets instantly, + /// or on breaching buffer capacity + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + let packet_size = packet.size(); + let should_flush = match packet { + Packet::Connect(..) | Packet::PingReq(_) | Packet::PingResp(_) => true, + _ => false, + }; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; - count += 1; - if count >= self.max_readb_count { - return Ok(()); - } - } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => { - state.handle_protocol_error()?; - return Err(StateError::IncomingPacketTooLarge { pkt_size, max }); - } - Err(e) => return Err(StateError::Deserialization(e)), - }; + if should_flush || self.framed.write_buffer().len() + packet_size >= self.buffer_capacity { + self.flush().await?; } - } - - pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result { - let mut write = BytesMut::new(); - let last_will = options.last_will(); - let login = options.credentials().map(|l| Login { - username: l.0, - password: l.1, - }); - - let len = match Packet::Connect(connect, last_will, login).write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - self.socket.write_all(&write[..]).await?; - Ok(len) + Ok(()) } - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); + /// Force flush all packets in buffer, reset count + pub async fn flush(&mut self) -> Result<(), StateError> { + match timeout(self.timeout, self.framed.flush()).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(e) => Err(StateError::Timeout(e)), } - - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) } } - -pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} -impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 663cfd278..8fce8b49b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -78,8 +78,8 @@ pub struct MqttOptions { credentials: Option<(String, String)>, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Network buffer capacity in memory + network_buffer_capacity: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -126,7 +126,7 @@ impl MqttOptions { client_id: id.into(), credentials: None, request_channel_capacity: 10, - max_request_batch: 0, + network_buffer_capacity: 10 * 1024, pending_throttle: Duration::from_micros(0), last_will: None, conn_timeout: 5, @@ -274,6 +274,12 @@ impl MqttOptions { self } + /// Maximum buffer capacity before network flush + pub fn set_network_buffer_capacity(&mut self, network_buffer_capacity: usize) -> &mut Self { + self.network_buffer_capacity = network_buffer_capacity; + self + } + /// Request channel capacity pub fn request_channel_capacity(&self) -> usize { self.request_channel_capacity @@ -654,12 +660,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(network_buffer_capacity) = queries + .remove("network_buffer_capacity_num") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.network_buffer_capacity = network_buffer_capacity; } if let Some(pending_throttle) = queries @@ -704,7 +710,7 @@ impl Debug for MqttOptions { .field("client_id", &self.client_id) .field("credentials", &self.credentials) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("network_buffer_capacity", &self.network_buffer_capacity) .field("pending_throttle", &self.pending_throttle) .field("last_will", &self.last_will) .field("conn_timeout", &self.conn_timeout) @@ -785,7 +791,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&network_buffer_capacity_num=foo"), OptionError::MaxRequestBatch ); assert_eq!( diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index 231c68067..42c1fcdd7 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -130,7 +130,7 @@ pub fn matches(topic: &str, filter: &str) -> bool { } /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Invalid return code received as response for connect = {0}")] InvalidConnectReturnCode(u8), @@ -183,4 +183,8 @@ pub enum Error { /// proceed further #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs new file mode 100644 index 000000000..2bd3272be --- /dev/null +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -0,0 +1,76 @@ +use bytes::BytesMut; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: Option, + /// Maximum packet size allowed by broker + pub max_outgoing_size: Option, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + // NOTE: not enough bytes to construct packet, reserve enough in src buffer + Err(Error::InsufficientBytes(b)) => { + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), + } + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::v5::{ + mqttbytes::{Error, QoS}, + Packet, Publish, + }; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: Some(100), + max_outgoing_size: Some(200), + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 282, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/connect.rs b/rumqttc/src/v5/mqttbytes/v5/connect.rs index 83918b871..a351c411d 100644 --- a/rumqttc/src/v5/mqttbytes/v5/connect.rs +++ b/rumqttc/src/v5/mqttbytes/v5/connect.rs @@ -127,6 +127,13 @@ impl Connect { buffer[flags_index] = connect_flags; Ok(1 + count + len) } + + pub fn size(&self, will: &Option, login: &Option) -> usize { + let len = self.len(will, login); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index bf4dcb425..d98ab475c 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,6 +1,7 @@ use std::slice::Iter; pub use self::{ + codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, disconnect::{Disconnect, DisconnectReasonCode}, @@ -19,6 +20,7 @@ pub use self::{ use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; +mod codec; mod connack; mod connect; mod disconnect; @@ -126,7 +128,17 @@ impl Packet { Ok(packet) } - pub fn write(&self, write: &mut BytesMut) -> Result { + pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + if let Some(max_size) = max_size { + if self.size() > max_size { + dbg!(); + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size() as u32, + max: max_size as u32, + }); + } + } + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), @@ -144,6 +156,25 @@ impl Packet { Self::Disconnect(disconnect) => disconnect.write(write), } } + + pub fn size(&self) -> usize { + match self { + Self::Publish(publish) => publish.size(), + Self::Subscribe(subscription) => subscription.size(), + Self::Unsubscribe(unsubscribe) => unsubscribe.size(), + Self::ConnAck(ack) => ack.size(), + Self::PubAck(ack) => ack.size(), + Self::SubAck(ack) => ack.size(), + Self::UnsubAck(unsuback) => unsuback.size(), + Self::PubRec(pubrec) => pubrec.size(), + Self::PubRel(pubrel) => pubrel.size(), + Self::PubComp(pubcomp) => pubcomp.size(), + Self::Connect(connect, will, login) => connect.size(will, login), + Self::PingReq(req) => req.size(), + Self::PingResp(resp) => resp.size(), + Self::Disconnect(disconnect) => disconnect.size(), + } + } } /// MQTT packet type diff --git a/rumqttc/src/v5/mqttbytes/v5/ping.rs b/rumqttc/src/v5/mqttbytes/v5/ping.rs index 086311ed9..d69ead6f4 100644 --- a/rumqttc/src/v5/mqttbytes/v5/ping.rs +++ b/rumqttc/src/v5/mqttbytes/v5/ping.rs @@ -9,6 +9,10 @@ impl PingReq { payload.put_slice(&[0xC0, 0x00]); Ok(2) } + + pub fn size(&self) -> usize { + 2 + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -19,4 +23,8 @@ impl PingResp { payload.put_slice(&[0xD0, 0x00]); Ok(2) } + + pub fn size(&self) -> usize { + 2 + } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8473f1f4c..7d41da1c1 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -9,8 +9,8 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::{Bytes, BytesMut}; use std::collections::{HashMap, VecDeque}; -use std::convert::TryInto; use std::{io, time::Instant}; +use tokio::time::error::Elapsed; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -42,10 +42,6 @@ pub enum StateError { "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'." )] InvalidAlias { alias: u16, max: u16 }, - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, - #[error("Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'")] - IncomingPacketTooLarge { pkt_size: usize, max: usize }, #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")] ServerDisconnect { reason_code: DisconnectReasonCode, @@ -65,6 +61,10 @@ pub enum StateError { PubCompFail { reason: PubCompReason }, #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, + #[error("Timeout")] + Timeout(#[from] Elapsed), + #[error("Connection Closed")] + ConnectionClosed, } /// State of the mqtt connection. @@ -108,7 +108,7 @@ pub struct MqttState { /// `topic_alias_maximum` RECEIVED via connack packet pub broker_topic_alias_max: u16, /// The broker's `max_packet_size` received via connack - pub max_outgoing_packet_size: Option, + pub max_outgoing_packet_size: Option, /// Maximum number of allowed inflight QoS1 & QoS2 requests pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests @@ -181,77 +181,67 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - match request { - Request::Publish(publish) => { - self.check_size(publish.size())?; - self.outgoing_publish(publish)? - } - Request::PubRel(pubrel) => { - self.check_size(pubrel.size())?; - self.outgoing_pubrel(pubrel)? - } - Request::Subscribe(subscribe) => { - self.check_size(subscribe.size())?; - self.outgoing_subscribe(subscribe)? - } - Request::Unsubscribe(unsubscribe) => { - self.check_size(unsubscribe.size())?; - self.outgoing_unsubscribe(unsubscribe)? - } + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => { - self.check_size(puback.size())?; - self.outgoing_puback(puback)? - } - Request::PubRec(pubrec) => { - self.check_size(pubrec.size())?; - self.outgoing_pubrec(pubrec)? - } + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, _ => unimplemented!(), }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, mut packet: Incoming) -> Result<(), StateError> { - let out = match &mut packet { - Incoming::PingResp(_) => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(suback) => self.handle_incoming_suback(suback), - Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), - Incoming::ConnAck(connack) => self.handle_incoming_connack(connack), - Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn), + pub fn handle_incoming_packet( + &mut self, + mut packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &mut packet { + Incoming::PingResp(_) => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, + Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, + Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + Ok(outgoing) } - pub fn handle_protocol_error(&mut self) -> Result<(), StateError> { + pub fn handle_protocol_error(&mut self) -> Result, StateError> { // send DISCONNECT packet with REASON_CODE 0x82 self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn handle_incoming_suback(&mut self, suback: &mut SubAck) -> Result<(), StateError> { + fn handle_incoming_suback( + &mut self, + suback: &mut SubAck, + ) -> Result, StateError> { for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -260,19 +250,25 @@ impl MqttState { _ => return Err(StateError::SubFail { reason: *reason }), } } - Ok(()) + Ok(None) } - fn handle_incoming_unsuback(&mut self, unsuback: &mut UnsubAck) -> Result<(), StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &mut UnsubAck, + ) -> Result, StateError> { for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { return Err(StateError::UnsubFail { reason: *reason }); } } - Ok(()) + Ok(None) } - fn handle_incoming_connack(&mut self, connack: &mut ConnAck) -> Result<(), StateError> { + fn handle_incoming_connack( + &mut self, + connack: &mut ConnAck, + ) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -291,12 +287,15 @@ impl MqttState { // to save some space. } - self.max_outgoing_packet_size = props.max_packet_size; + self.max_outgoing_packet_size = props.max_packet_size.map(|i| i as usize); } - Ok(()) + Ok(None) } - fn handle_incoming_disconn(&mut self, disconn: &mut Disconnect) -> Result<(), StateError> { + fn handle_incoming_disconn( + &mut self, + disconn: &mut Disconnect, + ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { props.reason_string.clone() @@ -311,7 +310,10 @@ impl MqttState { /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &mut Publish) -> Result<(), StateError> { + fn handle_incoming_publish( + &mut self, + publish: &mut Publish, + ) -> Result, StateError> { let qos = publish.qos; let topic_alias = match &publish.properties { @@ -332,13 +334,13 @@ impl MqttState { } match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid, None); self.outgoing_puback(puback)?; } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -348,12 +350,12 @@ impl MqttState { let pubrec = PubRec::new(pkid, None); self.outgoing_pubrec(pubrec)?; } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -361,7 +363,7 @@ impl MqttState { let v = match publish.take() { Some(_) => { self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited puback packet: {:?}", puback.pkid); @@ -382,7 +384,7 @@ impl MqttState { self.inflight += 1; let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -391,7 +393,7 @@ impl MqttState { v } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -408,11 +410,12 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - Packet::PubRel(PubRel::new(pubrec.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrec.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -421,7 +424,7 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -434,10 +437,11 @@ impl MqttState { }); } - Packet::PubComp(PubComp::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubComp(PubComp::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -446,10 +450,10 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -468,7 +472,7 @@ impl MqttState { } self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); @@ -477,14 +481,14 @@ impl MqttState { } } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -501,7 +505,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -532,43 +536,44 @@ impl MqttState { } }; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - Packet::PubRel(PubRel::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let pkid = puback.pkid; - Packet::PubAck(puback).write(&mut self.write)?; + Packet::PubAck(puback).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubAck(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let pkid = pubrec.pkid; - Packet::PubRec(pubrec).write(&mut self.write)?; + Packet::PubRec(pubrec).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -591,13 +596,16 @@ impl MqttState { elapsed_in, elapsed_out, ); - Packet::PingReq(PingReq).write(&mut self.write)?; + Packet::PingReq(PingReq).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -611,13 +619,16 @@ impl MqttState { ); let pkid = subscription.pkid; - Packet::Subscribe(subscription).write(&mut self.write)?; + Packet::Subscribe(subscription).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -627,19 +638,23 @@ impl MqttState { ); let pkid = unsub.pkid; - Packet::Unsubscribe(unsub).write(&mut self.write)?; + Packet::Unsubscribe(unsub).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_disconnect(&mut self, reason: DisconnectReasonCode) -> Result<(), StateError> { + fn outgoing_disconnect( + &mut self, + reason: DisconnectReasonCode, + ) -> Result, StateError> { debug!("Disconnect with {:?}", reason); - Packet::Disconnect(Disconnect::new(reason)).write(&mut self.write)?; + Packet::Disconnect(Disconnect::new(reason)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + Ok(None) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -652,18 +667,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - let pkt_size = pkt_size.try_into()?; - - match self.max_outgoing_packet_size { - Some(max_size) if pkt_size > max_size => Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: max_size, - }), - _ => Ok(()), - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index a6ebacc82..cbf781ed6 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -147,6 +147,9 @@ impl Broker { /// Selects between outgoing and incoming packets pub async fn tick(&mut self) -> Event { + if let Some(incoming) = self.incoming.pop_front() { + return Event::Incoming(incoming); + } select! { request = self.outgoing_rx.recv_async() => { let request = request.unwrap(); @@ -232,7 +235,7 @@ impl Network { pub async fn readb(&mut self, incoming: &mut VecDeque) -> io::Result<()> { let mut count = 0; loop { - match read(&mut self.read, self.max_incoming_size) { + match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => { incoming.push_back(packet); count += 1; diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 3e7acd1e3..79b121f6a 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -292,7 +292,7 @@ async fn requests_are_blocked_after_max_inflight_queue_size() { #[tokio::test] async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 1888); - options.set_inflight(3); + options.set_inflight(3).set_network_buffer_capacity(0); // NOTE: to instantly flush let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -474,7 +474,9 @@ async fn next_poll_after_connect_failure_reconnects() { #[tokio::test] async fn reconnection_resumes_from_the_previous_state() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3001); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_network_buffer_capacity(0); // NOTE: to instantly flush // start sending qos0 publishes. Makes sure that there is out activity but no in activity let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -514,7 +516,9 @@ async fn reconnection_resumes_from_the_previous_state() { #[tokio::test] async fn reconnection_resends_unacked_packets_from_the_previous_connection_first() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3002); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_network_buffer_capacity(0); // NOTE: to instantly flush // start sending qos0 publishes. this makes sure that there is // outgoing activity but no incoming activity @@ -569,8 +573,8 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly let res = run(&mut eventloop, false).await; if let Err(e) = res { match e { - ConnectionError::FlushTimeout => { - assert!(eventloop.state.write.is_empty()); + ConnectionError::MqttState(StateError::FlushTimeout) => { + assert!(eventloop.network.is_none()); println!("State is being clean properly"); } _ => {