diff --git a/Cargo.lock b/Cargo.lock index 2efed53fe..e6f8648cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1951,6 +1951,8 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-stream", + "tokio-util", "url", "ws_stream_tungstenite", ] @@ -2543,6 +2545,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index a7c36b6f1..88bc2d319 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed * rename `N` as `AsyncReadWrite` to describe usage. +* use `Framed` to encode/decode MQTT packets. ### Deprecated diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 551dcab53..bba64822c 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -23,8 +23,9 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] [dependencies] -futures-util = { version = "0.3", default_features = false, features = ["std"] } +futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] } tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } +tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } @@ -47,6 +48,7 @@ native-tls = { version = "0.2.11", optional = true } url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } +tokio-stream = "0.1.15" [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index fe971a6fa..1493b12eb 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -81,7 +81,7 @@ pub struct EventLoop { /// Pending packets from last session pub pending: VecDeque, /// Network connection to the broker - network: Option, + pub network: Option, /// Keep alive time keepalive_timeout: Option>>, pub network_options: NetworkOptions, @@ -104,11 +104,10 @@ impl EventLoop { let pending = VecDeque::new(); let max_inflight = mqtt_options.inflight; let manual_acks = mqtt_options.manual_acks; - let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size; EventLoop { mqtt_options, - state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size), + state: MqttState::new(max_inflight, manual_acks), requests_tx, requests_rx, pending, @@ -189,7 +188,7 @@ impl EventLoop { o = network.readb(&mut self.state) => { o?; // flush all the acks and return first incoming packet - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; @@ -229,8 +228,10 @@ impl EventLoop { self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.write(outgoing).await?; + } + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; @@ -245,8 +246,10 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq(PingReq))?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + network.write(outgoing).await?; + } + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; @@ -356,7 +359,11 @@ async fn network_connect( if matches!(options.transport(), Transport::Unix) { let file = options.broker_addr.as_str(); let socket = UnixStream::connect(Path::new(file)).await?; - let network = Network::new(socket, options.max_incoming_packet_size); + let network = Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ); return Ok(network); } @@ -388,13 +395,21 @@ async fn network_connect( }; let network = match options.transport() { - Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size), + Transport::Tcp => Network::new( + tcp_stream, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, options.max_incoming_packet_size) + Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), @@ -413,7 +428,11 @@ async fn network_connect( async_tungstenite::tokio::client_async(request, tcp_stream).await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] Transport::Wss(tls_config) => { @@ -436,7 +455,11 @@ async fn network_connect( .await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } }; diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index d2ec73674..4ccfcdad2 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,119 +1,98 @@ -use bytes::BytesMut; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use futures_util::{FutureExt, SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; use crate::mqttbytes::{self, v4::*}; use crate::{Incoming, MqttState, StateError}; -use std::io; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: usize, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, /// Maximum readv count max_readb_count: usize, } impl Network { - pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: usize) -> Network { + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: usize, + max_outgoing_size: usize, + ) -> Network { let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + let codec = Codec { max_incoming_size, + max_outgoing_size, + }; + let framed = Framed::new(socket, codec); + + Network { + framed, max_readb_count: 10, } } - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + /// Reads and returns a single packet from network + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } /// Read packets in bulk. This allow replies to be in bulk. This method is used /// after the connection is established to read a bunch of incoming packets pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; + // wait for the first read + let mut res = self.framed.next().await; + let mut count = 1; loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; + match res { + Some(Ok(packet)) => { + if let Some(outgoing) = state.handle_incoming_packet(packet)? { + self.write(outgoing).await?; + } count += 1; if count >= self.max_readb_count { - return Ok(()); + break; } } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(e) => return Err(StateError::Deserialization(e)), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => return Err(StateError::Deserialization(e)), + } + // do not wait for subsequent reads + match self.framed.next().now_or_never() { + Some(r) => res = r, + _ => break, }; } - } - pub async fn connect(&mut self, connect: Connect) -> io::Result { - let mut write = BytesMut::new(); - let len = match connect.write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; + Ok(()) + } - self.socket.write_all(&write[..]).await?; - Ok(len) + /// Serializes packet into write buffer + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization) } - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); - } + pub async fn connect(&mut self, connect: Connect) -> Result<(), StateError> { + self.write(Packet::Connect(connect)).await?; - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) + self.flush().await + } + + pub async fn flush(&mut self) -> Result<(), crate::state::StateError> { + self.framed + .flush() + .await + .map_err(StateError::Deserialization) } } diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 43dbb3bed..9c30d46a3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -200,25 +200,6 @@ pub enum Request { Disconnect(Disconnect), } -impl Request { - fn size(&self) -> usize { - match &self { - Request::Publish(publish) => publish.size(), - Request::PubAck(puback) => puback.size(), - Request::PubRec(pubrec) => pubrec.size(), - Request::PubComp(pubcomp) => pubcomp.size(), - Request::PubRel(pubrel) => pubrel.size(), - Request::PingReq(pingreq) => pingreq.size(), - Request::PingResp(pingresp) => pingresp.size(), - Request::Subscribe(subscribe) => subscribe.size(), - Request::SubAck(suback) => suback.size(), - Request::Unsubscribe(unsubscribe) => unsubscribe.size(), - Request::UnsubAck(unsuback) => unsuback.size(), - Request::Disconnect(disconn) => disconn.size(), - } - } -} - impl From for Request { fn from(publish: Publish) -> Request { Request::Publish(publish) diff --git a/rumqttc/src/mqttbytes/mod.rs b/rumqttc/src/mqttbytes/mod.rs index 69858d80f..3345b897f 100644 --- a/rumqttc/src/mqttbytes/mod.rs +++ b/rumqttc/src/mqttbytes/mod.rs @@ -13,7 +13,7 @@ pub mod v4; pub use topic::*; /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Expected Connect, received: {0:?}")] NotConnect(PacketType), @@ -60,6 +60,10 @@ pub enum Error { /// proceed further #[error("At least {0} more bytes required to frame packet")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: usize, max: usize }, } /// MQTT packet type diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs new file mode 100644 index 000000000..f605a1f5a --- /dev/null +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -0,0 +1,73 @@ +use bytes::BytesMut; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: usize, + /// Maximum packet size allowed by broker + pub max_outgoing_size: usize, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + Err(Error::InsufficientBytes(b)) => { + // Get more packets to construct the incomplete packet + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), + } + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::{mqttbytes::Error, Packet, Publish, QoS}; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: 100, + max_outgoing_size: 200, + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 281, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index 3c9225e82..ed438dd0f 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -1,5 +1,6 @@ use super::*; +mod codec; mod connack; mod connect; mod disconnect; @@ -27,6 +28,7 @@ pub use suback::*; pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; +pub use codec::*; /// Encapsulates all MQTT packet types #[derive(Debug, Clone, PartialEq, Eq)] @@ -109,7 +111,14 @@ impl Packet { } /// Serializes the MQTT packet into a stream of bytes - pub fn write(&self, stream: &mut BytesMut) -> Result { + pub fn write(&self, stream: &mut BytesMut, max_size: usize) -> Result { + if self.size() > max_size { + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size(), + max: max_size, + }) + } + match self { Packet::Connect(c) => c.write(stream), Packet::ConnAck(c) => c.write(stream), diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index acee6f1da..f6ffc5de0 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -2,7 +2,6 @@ use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; -use bytes::BytesMut; use std::collections::VecDeque; use std::{io, time::Instant}; @@ -30,8 +29,6 @@ pub enum StateError { EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] Deserialization(#[from] mqttbytes::Error), - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: usize, max: usize }, } /// State of the mqtt connection. @@ -70,19 +67,15 @@ pub struct MqttState { pub collision: Option, /// Buffered incoming packets pub events: VecDeque, - /// Write buffer - pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Maximum outgoing packet size, set via MqttOptions - pub max_outgoing_packet_size: usize, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, max_outgoing_packet_size: usize) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -99,9 +92,7 @@ impl MqttState { collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), - write: BytesMut::with_capacity(10 * 1024), manual_acks, - max_outgoing_packet_size, } } @@ -135,7 +126,6 @@ impl MqttState { self.await_pingresp = false; self.collision_ping_count = 0; self.inflight = 0; - self.write.clear(); pending } @@ -145,10 +135,11 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - // Enforce max outgoing packet size - self.check_size(request.size())?; - match request { + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { Request::Publish(publish) => self.outgoing_publish(publish)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, @@ -161,56 +152,58 @@ impl MqttState { }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { - let out = match &packet { - Incoming::PingResp => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(_suback) => self.handle_incoming_suback(), - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + pub fn handle_incoming_packet( + &mut self, + packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(_suback) => self.handle_incoming_suback()?, + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + + Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_suback(&mut self) -> Result, StateError> { + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { let qos = publish.qos; match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); - self.outgoing_puback(puback)?; + return self.outgoing_puback(puback); } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -218,45 +211,41 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - self.outgoing_pubrec(pubrec)?; + return self.outgoing_pubrec(pubrec); } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) .ok_or(StateError::Unsolicited(puback.pkid))?; self.last_puback = puback.pkid; - let v = match publish.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited puback packet: {:?}", puback.pkid); - Err(StateError::Unsolicited(puback.pkid)) - } - }; + publish.take().ok_or({ + error!("Unsolicited puback packet: {:?}", puback.pkid); + StateError::Unsolicited(puback.pkid) + })?; - if let Some(publish) = self.check_collision(puback.pkid) { + self.inflight -= 1; + let packet = self.check_collision(puback.pkid).map(|publish| { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - } - v + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -265,11 +254,11 @@ impl MqttState { Some(_) => { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - PubRel::new(pubrec.pkid).write(&mut self.write)?; - + let pubrel = PubRel { pkid: pubrec.pkid }; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -278,17 +267,18 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) .ok_or(StateError::Unsolicited(pubrel.pkid))?; match publish.take() { Some(_) => { - PubComp::new(pubrel.pkid).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); + let pubcomp = PubComp { pkid: pubrel.pkid }; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubComp(pubcomp))) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -297,38 +287,37 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { - if let Some(publish) = self.check_collision(pubcomp.pkid) { - publish.write(&mut self.write)?; + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + self.outgoing_rel + .get_mut(pubcomp.pkid as usize) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + .take() + .ok_or({ + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + StateError::Unsolicited(pubcomp.pkid) + })?; + + self.inflight -= 1; + let packet = self.check_collision(pubcomp.pkid).map(|publish| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; - } - let pubrel = self - .outgoing_rel - .get_mut(pubcomp.pkid as usize) - .ok_or(StateError::Unsolicited(pubcomp.pkid))?; - match pubrel.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - Err(StateError::Unsolicited(pubcomp.pkid)) - } - } + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -345,7 +334,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -361,41 +350,40 @@ impl MqttState { publish.payload.len() ); - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - PubRel::new(pubrel.pkid).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { - puback.write(&mut self.write)?; + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { - pubrec.write(&mut self.write)?; + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRec(pubrec))) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -421,13 +409,16 @@ impl MqttState { elapsed_out.as_millis() ); - PingReq.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PingReq)) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -440,13 +431,16 @@ impl MqttState { subscription.filters, subscription.pkid ); - subscription.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Subscribe(subscription))) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -455,19 +449,19 @@ impl MqttState { unsub.topics, unsub.pkid ); - unsub.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Unsubscribe(unsub))) } - fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + fn outgoing_disconnect(&mut self) -> Result, StateError> { debug!("Disconnect"); - Disconnect.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Disconnect)) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -480,17 +474,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - if pkt_size > self.max_outgoing_packet_size { - Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: self.max_outgoing_packet_size, - }) - } else { - Ok(()) - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets @@ -532,7 +515,6 @@ mod test { use crate::mqttbytes::v4::*; use crate::mqttbytes::*; use crate::{Event, Incoming, Outgoing, Request}; - use bytes::BufMut; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -554,7 +536,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(100, false, usize::MAX) + MqttState::new(100, false) } #[test] @@ -574,25 +556,6 @@ mod test { } } - #[test] - fn outgoing_max_packet_size_check() { - let mut mqtt = MqttState::new(100, false, 200); - - let small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(small_publish)) - .is_ok(), - true - ); - - let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(large_publish)) - .is_ok(), - false - ); - } - #[test] fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { let mut mqtt = build_mqttstate(); @@ -702,8 +665,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -769,15 +731,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrec(&PubRec::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -789,15 +752,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrel(&PubRel::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -848,15 +812,6 @@ mod test { mqtt.outgoing_ping().unwrap(); } - #[test] - fn state_should_be_clean_properly() { - let mut mqtt = build_mqttstate(); - mqtt.write.put(&b"test"[..]); - // After this clean state.write should be empty - mqtt.clean(); - assert!(mqtt.write.is_empty()); - } - #[test] fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 27c26f29d..ab1edb17c 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -9,7 +9,6 @@ use tokio::select; use tokio::time::{self, error::Elapsed, Instant, Sleep}; use std::collections::VecDeque; -use std::convert::TryInto; use std::io; use std::pin::Pin; use std::time::Duration; @@ -211,7 +210,7 @@ impl EventLoop { ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; + network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), @@ -220,7 +219,7 @@ impl EventLoop { o = network.readb(&mut self.state) => { o?; // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; + network.flush().await?; Ok(self.state.events.pop_front().unwrap()) }, // We generate pings irrespective of network activity. This keeps the ping logic @@ -230,7 +229,7 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.options.keep_alive); self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; + network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } } @@ -281,7 +280,6 @@ async fn network_connect(options: &MqttOptions) -> Result frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: Option, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, /// Maximum readv count max_readb_count: usize, } - impl Network { - pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: Option) -> Network { + pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: Option) -> Network { let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + let codec = Codec { max_incoming_size, + max_outgoing_size: None, + }; + let framed = Framed::new(socket, codec); + + Network { + framed, max_readb_count: 10, } } - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } + pub fn set_max_outgoing_size(&mut self, max_outgoing_size: Option) { + self.framed.codec_mut().max_outgoing_size = max_outgoing_size; } - pub async fn read(&mut self) -> io::Result { - loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + /// Reads and returns a single packet from network + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } /// Read packets in bulk. This allow replies to be in bulk. This method is used /// after the connection is established to read a bunch of incoming packets pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; + // wait for the first read + let mut res = self.framed.next().await; + let mut count = 1; loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; + match res { + Some(Ok(packet)) => { + if let Some(outgoing) = state.handle_incoming_packet(packet)? { + self.write(outgoing).await?; + } count += 1; if count >= self.max_readb_count { - return Ok(()); + break; } } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => { - state.handle_protocol_error()?; - return Err(StateError::IncomingPacketTooLarge { pkt_size, max }); - } - Err(e) => return Err(StateError::Deserialization(e)), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => return Err(StateError::Deserialization(e)), + } + // do not wait for subsequent reads + match self.framed.next().now_or_never() { + Some(r) => res = r, + _ => break, }; } + + Ok(()) } - pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result { - let mut write = BytesMut::new(); + /// Serializes packet into write buffer + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization) + } + + pub async fn connect( + &mut self, + connect: Connect, + options: &MqttOptions, + ) -> Result<(), StateError> { let last_will = options.last_will(); let login = options.credentials().map(|l| Login { username: l.0, password: l.1, }); + self.write(Packet::Connect(connect, last_will, login)) + .await?; - let len = match Packet::Connect(connect, last_will, login).write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - self.socket.write_all(&write[..]).await?; - Ok(len) + self.flush().await } - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); - } - - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) + pub async fn flush(&mut self) -> Result<(), StateError> { + self.framed + .flush() + .await + .map_err(StateError::Deserialization) } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 663cfd278..2ed75a5d2 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -89,7 +89,7 @@ pub struct MqttOptions { conn_timeout: u64, /// Default value of for maximum incoming packet size. /// Used when `max_incomming_size` in `connect_properties` is NOT available. - default_max_incoming_size: usize, + default_max_incoming_size: u32, /// Connect Properties connect_properties: Option, /// If set to `true` MQTT acknowledgements are not sent automatically. diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index 231c68067..c205aaa91 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -130,7 +130,7 @@ pub fn matches(topic: &str, filter: &str) -> bool { } /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Invalid return code received as response for connect = {0}")] InvalidConnectReturnCode(u8), @@ -163,7 +163,7 @@ pub enum Error { #[error("Payload is too long")] PayloadTooLong, #[error("Max Payload size of {max:?} has been exceeded by packet of {pkt_size:?} bytes")] - PayloadSizeLimitExceeded { pkt_size: usize, max: usize }, + PayloadSizeLimitExceeded { pkt_size: usize, max: u32 }, #[error("Payload is required")] PayloadRequired, #[error("Payload is required = {0}")] @@ -183,4 +183,8 @@ pub enum Error { /// proceed further #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs new file mode 100644 index 000000000..76909d62d --- /dev/null +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -0,0 +1,76 @@ +use bytes::BytesMut; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: Option, + /// Maximum packet size allowed by broker + pub max_outgoing_size: Option, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + Err(Error::InsufficientBytes(b)) => { + // Get more packets to construct the incomplete packet + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), + } + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::v5::{ + mqttbytes::{Error, QoS}, + Packet, Publish, + }; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: Some(100), + max_outgoing_size: Some(200), + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 282, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 01ddef992..342278596 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,6 +1,7 @@ use std::slice::Iter; pub use self::{ + codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, disconnect::{Disconnect, DisconnectReasonCode}, @@ -19,6 +20,7 @@ pub use self::{ use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; +mod codec; mod connack; mod connect; mod disconnect; @@ -53,7 +55,7 @@ pub enum Packet { impl Packet { /// Reads a stream of bytes and extracts next MQTT packet out of it - pub fn read(stream: &mut BytesMut, max_size: Option) -> Result { + pub fn read(stream: &mut BytesMut, max_size: Option) -> Result { let fixed_header = check(stream.iter(), max_size)?; // Test with a stream with exactly the size to check border panics @@ -126,7 +128,16 @@ impl Packet { Ok(packet) } - pub fn write(&self, write: &mut BytesMut) -> Result { + pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + if let Some(max_size) = max_size { + if self.size() > max_size as usize { + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size() as u32, + max: max_size, + }); + } + } + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), @@ -320,7 +331,7 @@ fn property(num: u8) -> Result { /// The passed stream doesn't modify parent stream's cursor. If this function /// returned an error, next `check` on the same parent stream is forced start /// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) -pub fn check(stream: Iter, max_packet_size: Option) -> Result { +pub fn check(stream: Iter, max_packet_size: Option) -> Result { // Create fixed header if there are enough bytes in the stream // to frame full packet let stream_len = stream.len(); @@ -329,7 +340,7 @@ pub fn check(stream: Iter, max_packet_size: Option) -> Result max_size { + if fixed_header.remaining_len > max_size as usize { return Err(Error::PayloadSizeLimitExceeded { pkt_size: fixed_header.remaining_len, max: max_size, diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8473f1f4c..456272b4d 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -3,13 +3,12 @@ use super::mqttbytes::v5::{ PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, }; -use super::mqttbytes::{self, QoS}; +use super::mqttbytes::{self, Error as MqttError, QoS}; use super::{Event, Incoming, Outgoing, Request}; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use std::collections::{HashMap, VecDeque}; -use std::convert::TryInto; use std::{io, time::Instant}; /// Errors during state handling @@ -37,7 +36,7 @@ pub enum StateError { #[error("A Subscribe packet must contain atleast one filter")] EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] - Deserialization(#[from] mqttbytes::Error), + Deserialization(MqttError), #[error( "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'." )] @@ -67,6 +66,17 @@ pub enum StateError { ConnFail { reason: ConnectReturnCode }, } +impl From for StateError { + fn from(value: MqttError) -> Self { + match value { + MqttError::OutgoingPacketTooLarge { pkt_size, max } => { + StateError::OutgoingPacketTooLarge { pkt_size, max } + } + e => StateError::Deserialization(e), + } + } +} + /// State of the mqtt connection. // Design: Methods will just modify the state of the object without doing any network operations // Design: All inflight queues are maintained in a pre initialized vec with index as packet id. @@ -99,16 +109,12 @@ pub struct MqttState { pub collision: Option, /// Buffered incoming packets pub events: VecDeque, - /// Write buffer - pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, /// Map of alias_id->topic topic_alises: HashMap, /// `topic_alias_maximum` RECEIVED via connack packet pub broker_topic_alias_max: u16, - /// The broker's `max_packet_size` received via connack - pub max_outgoing_packet_size: Option, /// Maximum number of allowed inflight QoS1 & QoS2 requests pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests @@ -134,12 +140,10 @@ impl MqttState { collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), - write: BytesMut::with_capacity(10 * 1024), manual_acks, topic_alises: HashMap::new(), // Set via CONNACK broker_topic_alias_max: 0, - max_outgoing_packet_size: None, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, } @@ -181,77 +185,67 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - match request { - Request::Publish(publish) => { - self.check_size(publish.size())?; - self.outgoing_publish(publish)? - } - Request::PubRel(pubrel) => { - self.check_size(pubrel.size())?; - self.outgoing_pubrel(pubrel)? - } - Request::Subscribe(subscribe) => { - self.check_size(subscribe.size())?; - self.outgoing_subscribe(subscribe)? - } - Request::Unsubscribe(unsubscribe) => { - self.check_size(unsubscribe.size())?; - self.outgoing_unsubscribe(unsubscribe)? - } + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => { - self.check_size(puback.size())?; - self.outgoing_puback(puback)? - } - Request::PubRec(pubrec) => { - self.check_size(pubrec.size())?; - self.outgoing_pubrec(pubrec)? - } + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, _ => unimplemented!(), }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, mut packet: Incoming) -> Result<(), StateError> { - let out = match &mut packet { - Incoming::PingResp(_) => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(suback) => self.handle_incoming_suback(suback), - Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), - Incoming::ConnAck(connack) => self.handle_incoming_connack(connack), - Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn), + pub fn handle_incoming_packet( + &mut self, + mut packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &mut packet { + Incoming::PingResp(_) => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, + Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, + Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + Ok(outgoing) } - pub fn handle_protocol_error(&mut self) -> Result<(), StateError> { + pub fn handle_protocol_error(&mut self) -> Result, StateError> { // send DISCONNECT packet with REASON_CODE 0x82 self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn handle_incoming_suback(&mut self, suback: &mut SubAck) -> Result<(), StateError> { + fn handle_incoming_suback( + &mut self, + suback: &mut SubAck, + ) -> Result, StateError> { for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -260,19 +254,25 @@ impl MqttState { _ => return Err(StateError::SubFail { reason: *reason }), } } - Ok(()) + Ok(None) } - fn handle_incoming_unsuback(&mut self, unsuback: &mut UnsubAck) -> Result<(), StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &mut UnsubAck, + ) -> Result, StateError> { for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { return Err(StateError::UnsubFail { reason: *reason }); } } - Ok(()) + Ok(None) } - fn handle_incoming_connack(&mut self, connack: &mut ConnAck) -> Result<(), StateError> { + fn handle_incoming_connack( + &mut self, + connack: &mut ConnAck, + ) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -290,13 +290,14 @@ impl MqttState { // FIXME: Maybe resize the pubrec and pubrel queues here // to save some space. } - - self.max_outgoing_packet_size = props.max_packet_size; } - Ok(()) + Ok(None) } - fn handle_incoming_disconn(&mut self, disconn: &mut Disconnect) -> Result<(), StateError> { + fn handle_incoming_disconn( + &mut self, + disconn: &mut Disconnect, + ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { props.reason_string.clone() @@ -311,7 +312,10 @@ impl MqttState { /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &mut Publish) -> Result<(), StateError> { + fn handle_incoming_publish( + &mut self, + publish: &mut Publish, + ) -> Result, StateError> { let qos = publish.qos; let topic_alias = match &publish.properties { @@ -332,13 +336,13 @@ impl MqttState { } match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid, None); - self.outgoing_puback(puback)?; + return self.outgoing_puback(puback); } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -346,14 +350,14 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid, None); - self.outgoing_pubrec(pubrec)?; + return self.outgoing_pubrec(pubrec); } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -361,7 +365,8 @@ impl MqttState { let v = match publish.take() { Some(_) => { self.inflight -= 1; - Ok(()) + + Ok(None) } None => { error!("Unsolicited puback packet: {:?}", puback.pkid); @@ -382,16 +387,17 @@ impl MqttState { self.inflight += 1; let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; + + return Ok(Some(Packet::Publish(publish))); } v } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -408,11 +414,10 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - Packet::PubRel(PubRel::new(pubrec.pkid, None)).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None)))) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -421,7 +426,7 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -434,10 +439,10 @@ impl MqttState { }); } - Packet::PubComp(PubComp::new(pubrel.pkid, None)).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None)))) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -446,14 +451,15 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { - if let Some(publish) = self.check_collision(pubcomp.pkid) { + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + let outgoing = self.check_collision(pubcomp.pkid).map(|publish| { let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; - } + + Packet::Publish(publish) + }); let pubrel = self .outgoing_rel @@ -468,7 +474,7 @@ impl MqttState { } self.inflight -= 1; - Ok(()) + Ok(outgoing) } None => { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); @@ -477,14 +483,14 @@ impl MqttState { } } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -501,7 +507,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -532,43 +538,43 @@ impl MqttState { } }; - Packet::Publish(publish).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - Packet::PubRel(PubRel::new(pubrel.pkid, None)).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(PubRel::new(pubrel.pkid, None)))) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let pkid = puback.pkid; - Packet::PubAck(puback).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubAck(pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let pkid = pubrec.pkid; - Packet::PubRec(pubrec).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRec(pubrec))) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -591,13 +597,16 @@ impl MqttState { elapsed_in, elapsed_out, ); - Packet::PingReq(PingReq).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PingReq(PingReq))) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -611,13 +620,16 @@ impl MqttState { ); let pkid = subscription.pkid; - Packet::Subscribe(subscription).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Subscribe(subscription))) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -627,19 +639,21 @@ impl MqttState { ); let pkid = unsub.pkid; - Packet::Unsubscribe(unsub).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Unsubscribe(unsub))) } - fn outgoing_disconnect(&mut self, reason: DisconnectReasonCode) -> Result<(), StateError> { + fn outgoing_disconnect( + &mut self, + reason: DisconnectReasonCode, + ) -> Result, StateError> { debug!("Disconnect with {:?}", reason); - - Packet::Disconnect(Disconnect::new(reason)).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -652,18 +666,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - let pkt_size = pkt_size.try_into()?; - - match self.max_outgoing_packet_size { - Some(max_size) if pkt_size > max_size => Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: max_size, - }), - _ => Ok(()), - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets @@ -887,11 +889,9 @@ mod test { let mut mqtt = build_mqttstate(); let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&mut publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap(); - match packet { + match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), - _ => panic!("Invalid network request: {:?}", packet), + packet => panic!("Invalid network request: {:?}", packet), } } @@ -956,16 +956,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap(); - match packet { + match mqtt.outgoing_publish(publish).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); - let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap(); - match packet { + match mqtt + .handle_incoming_pubrec(&PubRec::new(1, None)) + .unwrap() + .unwrap() + { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -976,16 +976,16 @@ mod test { let mut mqtt = build_mqttstate(); let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&mut publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap(); - match packet { + match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrel(&PubRel::new(1, None)).unwrap(); - let packet = Packet::read(&mut mqtt.write, Some(10 * 1024)).unwrap(); - match packet { + match mqtt + .handle_incoming_pubrel(&PubRel::new(1, None)) + .unwrap() + .unwrap() + { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 3e7acd1e3..0a83d57ce 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -570,7 +570,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly if let Err(e) = res { match e { ConnectionError::FlushTimeout => { - assert!(eventloop.state.write.is_empty()); + assert!(eventloop.network.is_none()); println!("State is being clean properly"); } _ => {