From 5dbe5a18225831c8e30da895c97eb373039df348 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 16:27:06 +0000 Subject: [PATCH] refactor: v5 implementation --- rumqttc/src/v5/eventloop.rs | 66 ++++++-- rumqttc/src/v5/framed.rs | 149 +++++------------- rumqttc/src/v5/mqttbytes/mod.rs | 6 +- rumqttc/src/v5/mqttbytes/v5/codec.rs | 73 +++++++++ rumqttc/src/v5/mqttbytes/v5/mod.rs | 14 +- rumqttc/src/v5/state.rs | 223 ++++++++++++++------------- 6 files changed, 289 insertions(+), 242 deletions(-) create mode 100644 rumqttc/src/v5/mqttbytes/v5/codec.rs diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 27c26f29d..7c8b5e51e 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -210,17 +210,21 @@ impl EventLoop { self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.send(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), }, // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; + o = network.read() => { + let incoming = o?; + if let Some(packet) = self.state.handle_incoming_packet(incoming)? { + network.send(packet).await?; + } + Ok(self.state.events.pop_front().unwrap()) }, // We generate pings irrespective of network activity. This keeps the ping logic @@ -229,8 +233,10 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + network.send(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } } @@ -276,7 +282,9 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne } async fn network_connect(options: &MqttOptions) -> Result { - let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); + let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); // incoming == outgoing + let max_outgoing_pkt_size = Some(options.default_max_incoming_size); + let network_timeout = Duration::from_secs(options.network_options.connection_timeout()); // Override default value if max_packet_size is set on `connect_properties` if let Some(connect_props) = &options.connect_properties { @@ -291,7 +299,12 @@ async fn network_connect(options: &MqttOptions) -> Result Result Network::new(tcp_stream, max_incoming_pkt_size), + Transport::Tcp => Network::new( + tcp_stream, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + ), #[cfg(any(feature = "use-native-tls", feature = "use-rustls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, max_incoming_pkt_size) + Network::new( + socket, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), @@ -352,7 +375,12 @@ async fn network_connect(options: &MqttOptions) -> Result { @@ -375,7 +403,12 @@ async fn network_connect(options: &MqttOptions) -> Result frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: Option, - /// Maximum readv count - max_readb_count: usize, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, + /// Time within which network operations should complete + timeout: Duration, } - impl Network { - pub fn new(socket: impl N + 'static, max_incoming_size: Option) -> Network { - let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: Option, + max_outgoing_size: Option, + timeout: Duration, + ) -> Network { + let socket = Box::new(socket) as Box; + let codec = Codec { max_incoming_size, - max_readb_count: 10, - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } + max_outgoing_size, + }; + let framed = Framed::new(socket, codec); - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } + Network { framed, timeout } } - pub async fn read(&mut self) -> io::Result { - loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; - loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; - - count += 1; - if count >= self.max_readb_count { - return Ok(()); - } - } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => { - state.handle_protocol_error()?; - return Err(StateError::IncomingPacketTooLarge { pkt_size, max }); - } - Err(e) => return Err(StateError::Deserialization(e)), - }; + pub async fn send(&mut self, packet: Packet) -> Result<(), StateError> { + match timeout(self.timeout, self.framed.send(packet)).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(e) => Err(StateError::Timeout(e)), } } - - pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result { - let mut write = BytesMut::new(); - let last_will = options.last_will(); - let login = options.credentials().map(|l| Login { - username: l.0, - password: l.1, - }); - - let len = match Packet::Connect(connect, last_will, login).write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - self.socket.write_all(&write[..]).await?; - Ok(len) - } - - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); - } - - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) - } } - -pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} -impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index 231c68067..42c1fcdd7 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -130,7 +130,7 @@ pub fn matches(topic: &str, filter: &str) -> bool { } /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Invalid return code received as response for connect = {0}")] InvalidConnectReturnCode(u8), @@ -183,4 +183,8 @@ pub enum Error { /// proceed further #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs new file mode 100644 index 000000000..fc24105cd --- /dev/null +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -0,0 +1,73 @@ +use bytes::{Buf, BytesMut}; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: Option, + /// Maximum packet size allowed by broker + pub max_outgoing_size: Option, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.remaining() == 0 { + return Ok(None); + } + + let packet = Packet::read(src, self.max_incoming_size)?; + Ok(Some(packet)) + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::v5::{ + mqttbytes::{Error, QoS}, + Packet, Publish, + }; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: Some(100), + max_outgoing_size: Some(200), + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 282, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 01ddef992..d98ab475c 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,6 +1,7 @@ use std::slice::Iter; pub use self::{ + codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, disconnect::{Disconnect, DisconnectReasonCode}, @@ -19,6 +20,7 @@ pub use self::{ use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; +mod codec; mod connack; mod connect; mod disconnect; @@ -126,7 +128,17 @@ impl Packet { Ok(packet) } - pub fn write(&self, write: &mut BytesMut) -> Result { + pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + if let Some(max_size) = max_size { + if self.size() > max_size { + dbg!(); + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size() as u32, + max: max_size as u32, + }); + } + } + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8473f1f4c..c817191e7 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -9,8 +9,8 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::{Bytes, BytesMut}; use std::collections::{HashMap, VecDeque}; -use std::convert::TryInto; use std::{io, time::Instant}; +use tokio::time::error::Elapsed; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -42,10 +42,6 @@ pub enum StateError { "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'." )] InvalidAlias { alias: u16, max: u16 }, - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, - #[error("Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'")] - IncomingPacketTooLarge { pkt_size: usize, max: usize }, #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")] ServerDisconnect { reason_code: DisconnectReasonCode, @@ -65,6 +61,8 @@ pub enum StateError { PubCompFail { reason: PubCompReason }, #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, + #[error("Timeout")] + Timeout(#[from] Elapsed), } /// State of the mqtt connection. @@ -108,7 +106,7 @@ pub struct MqttState { /// `topic_alias_maximum` RECEIVED via connack packet pub broker_topic_alias_max: u16, /// The broker's `max_packet_size` received via connack - pub max_outgoing_packet_size: Option, + pub max_outgoing_packet_size: Option, /// Maximum number of allowed inflight QoS1 & QoS2 requests pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests @@ -181,77 +179,67 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - match request { - Request::Publish(publish) => { - self.check_size(publish.size())?; - self.outgoing_publish(publish)? - } - Request::PubRel(pubrel) => { - self.check_size(pubrel.size())?; - self.outgoing_pubrel(pubrel)? - } - Request::Subscribe(subscribe) => { - self.check_size(subscribe.size())?; - self.outgoing_subscribe(subscribe)? - } - Request::Unsubscribe(unsubscribe) => { - self.check_size(unsubscribe.size())?; - self.outgoing_unsubscribe(unsubscribe)? - } + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => { - self.check_size(puback.size())?; - self.outgoing_puback(puback)? - } - Request::PubRec(pubrec) => { - self.check_size(pubrec.size())?; - self.outgoing_pubrec(pubrec)? - } + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, _ => unimplemented!(), }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, mut packet: Incoming) -> Result<(), StateError> { - let out = match &mut packet { - Incoming::PingResp(_) => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(suback) => self.handle_incoming_suback(suback), - Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), - Incoming::ConnAck(connack) => self.handle_incoming_connack(connack), - Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn), + pub fn handle_incoming_packet( + &mut self, + mut packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &mut packet { + Incoming::PingResp(_) => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, + Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, + Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + Ok(outgoing) } - pub fn handle_protocol_error(&mut self) -> Result<(), StateError> { + pub fn handle_protocol_error(&mut self) -> Result, StateError> { // send DISCONNECT packet with REASON_CODE 0x82 self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn handle_incoming_suback(&mut self, suback: &mut SubAck) -> Result<(), StateError> { + fn handle_incoming_suback( + &mut self, + suback: &mut SubAck, + ) -> Result, StateError> { for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -260,19 +248,25 @@ impl MqttState { _ => return Err(StateError::SubFail { reason: *reason }), } } - Ok(()) + Ok(None) } - fn handle_incoming_unsuback(&mut self, unsuback: &mut UnsubAck) -> Result<(), StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &mut UnsubAck, + ) -> Result, StateError> { for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { return Err(StateError::UnsubFail { reason: *reason }); } } - Ok(()) + Ok(None) } - fn handle_incoming_connack(&mut self, connack: &mut ConnAck) -> Result<(), StateError> { + fn handle_incoming_connack( + &mut self, + connack: &mut ConnAck, + ) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -291,12 +285,15 @@ impl MqttState { // to save some space. } - self.max_outgoing_packet_size = props.max_packet_size; + self.max_outgoing_packet_size = props.max_packet_size.map(|i| i as usize); } - Ok(()) + Ok(None) } - fn handle_incoming_disconn(&mut self, disconn: &mut Disconnect) -> Result<(), StateError> { + fn handle_incoming_disconn( + &mut self, + disconn: &mut Disconnect, + ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { props.reason_string.clone() @@ -311,7 +308,10 @@ impl MqttState { /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &mut Publish) -> Result<(), StateError> { + fn handle_incoming_publish( + &mut self, + publish: &mut Publish, + ) -> Result, StateError> { let qos = publish.qos; let topic_alias = match &publish.properties { @@ -332,13 +332,13 @@ impl MqttState { } match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid, None); self.outgoing_puback(puback)?; } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -348,12 +348,12 @@ impl MqttState { let pubrec = PubRec::new(pkid, None); self.outgoing_pubrec(pubrec)?; } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -361,7 +361,7 @@ impl MqttState { let v = match publish.take() { Some(_) => { self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited puback packet: {:?}", puback.pkid); @@ -382,7 +382,7 @@ impl MqttState { self.inflight += 1; let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -391,7 +391,7 @@ impl MqttState { v } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -408,11 +408,12 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - Packet::PubRel(PubRel::new(pubrec.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrec.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -421,7 +422,7 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -434,10 +435,11 @@ impl MqttState { }); } - Packet::PubComp(PubComp::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubComp(PubComp::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -446,10 +448,10 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -468,7 +470,7 @@ impl MqttState { } self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); @@ -477,14 +479,14 @@ impl MqttState { } } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -501,7 +503,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -532,43 +534,44 @@ impl MqttState { } }; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - Packet::PubRel(PubRel::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let pkid = puback.pkid; - Packet::PubAck(puback).write(&mut self.write)?; + Packet::PubAck(puback).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubAck(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let pkid = pubrec.pkid; - Packet::PubRec(pubrec).write(&mut self.write)?; + Packet::PubRec(pubrec).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -591,13 +594,16 @@ impl MqttState { elapsed_in, elapsed_out, ); - Packet::PingReq(PingReq).write(&mut self.write)?; + Packet::PingReq(PingReq).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -611,13 +617,16 @@ impl MqttState { ); let pkid = subscription.pkid; - Packet::Subscribe(subscription).write(&mut self.write)?; + Packet::Subscribe(subscription).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -627,19 +636,23 @@ impl MqttState { ); let pkid = unsub.pkid; - Packet::Unsubscribe(unsub).write(&mut self.write)?; + Packet::Unsubscribe(unsub).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_disconnect(&mut self, reason: DisconnectReasonCode) -> Result<(), StateError> { + fn outgoing_disconnect( + &mut self, + reason: DisconnectReasonCode, + ) -> Result, StateError> { debug!("Disconnect with {:?}", reason); - Packet::Disconnect(Disconnect::new(reason)).write(&mut self.write)?; + Packet::Disconnect(Disconnect::new(reason)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + Ok(None) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -652,18 +665,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - let pkt_size = pkt_size.try_into()?; - - match self.max_outgoing_packet_size { - Some(max_size) if pkt_size > max_size => Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: max_size, - }), - _ => Ok(()), - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets