Skip to content

Commit

Permalink
refactor: v5 implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Devdutt Shenoi committed Mar 19, 2024
1 parent 041c6c8 commit 5dbe5a1
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 242 deletions.
66 changes: 51 additions & 15 deletions rumqttc/src/v5/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,21 @@ impl EventLoop {
self.options.pending_throttle
), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
Ok(request) => {
self.state.handle_outgoing_packet(request)?;
network.flush(&mut self.state.write).await?;
if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
network.send(outgoing).await?;
}

Ok(self.state.events.pop_front().unwrap())
}
Err(_) => Err(ConnectionError::RequestsDone),
},
// Pull a bunch of packets from network, reply in bunch and yield the first item
o = network.readb(&mut self.state) => {
o?;
// flush all the acks and return first incoming packet
network.flush(&mut self.state.write).await?;
o = network.read() => {
let incoming = o?;
if let Some(packet) = self.state.handle_incoming_packet(incoming)? {
network.send(packet).await?;
}

Ok(self.state.events.pop_front().unwrap())
},
// We generate pings irrespective of network activity. This keeps the ping logic
Expand All @@ -229,8 +233,10 @@ impl EventLoop {
let timeout = self.keepalive_timeout.as_mut().unwrap();
timeout.as_mut().reset(Instant::now() + self.options.keep_alive);

self.state.handle_outgoing_packet(Request::PingReq)?;
network.flush(&mut self.state.write).await?;
if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? {
network.send(outgoing).await?;
}

Ok(self.state.events.pop_front().unwrap())
}
}
Expand Down Expand Up @@ -276,7 +282,9 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne
}

async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
let mut max_incoming_pkt_size = Some(options.default_max_incoming_size);
let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); // incoming == outgoing
let max_outgoing_pkt_size = Some(options.default_max_incoming_size);
let network_timeout = Duration::from_secs(options.network_options.connection_timeout());

// Override default value if max_packet_size is set on `connect_properties`
if let Some(connect_props) = &options.connect_properties {
Expand All @@ -291,7 +299,12 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
if matches!(options.transport(), Transport::Unix) {
let file = options.broker_addr.as_str();
let socket = UnixStream::connect(Path::new(file)).await?;
let network = Network::new(socket, max_incoming_pkt_size);
let network = Network::new(
socket,
max_incoming_pkt_size,
max_outgoing_pkt_size,
network_timeout,
);
return Ok(network);
}

Expand Down Expand Up @@ -327,13 +340,23 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
};

let network = match options.transport() {
Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
Transport::Tcp => Network::new(
tcp_stream,
max_incoming_pkt_size,
max_outgoing_pkt_size,
network_timeout,
),
#[cfg(any(feature = "use-native-tls", feature = "use-rustls"))]
Transport::Tls(tls_config) => {
let socket =
tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
.await?;
Network::new(socket, max_incoming_pkt_size)
Network::new(
socket,
max_incoming_pkt_size,
max_outgoing_pkt_size,
network_timeout,
)
}
#[cfg(unix)]
Transport::Unix => unreachable!(),
Expand All @@ -352,7 +375,12 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
async_tungstenite::tokio::client_async(request, tcp_stream).await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), max_incoming_pkt_size)
Network::new(
WsStream::new(socket),
max_incoming_pkt_size,
max_outgoing_pkt_size,
network_timeout,
)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(tls_config) => {
Expand All @@ -375,7 +403,12 @@ async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionErr
.await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), max_incoming_pkt_size)
Network::new(
WsStream::new(socket),
max_incoming_pkt_size,
max_outgoing_pkt_size,
network_timeout,
)
}
};

Expand All @@ -390,6 +423,7 @@ async fn mqtt_connect(
let clean_start = options.clean_start();
let client_id = options.client_id();
let properties = options.connect_properties();
let last_will = options.last_will();

let connect = Connect {
keep_alive,
Expand All @@ -399,7 +433,9 @@ async fn mqtt_connect(
};

// send mqtt connect packet
network.connect(connect, options).await?;
network
.send(Packet::Connect(connect, last_will, None))
.await?;

// validate connack
match network.read().await? {
Expand Down
149 changes: 35 additions & 114 deletions rumqttc/src/v5/framed.rs
Original file line number Diff line number Diff line change
@@ -1,132 +1,53 @@
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures_util::SinkExt;
use tokio::time::timeout;
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;

use super::mqttbytes;
use super::mqttbytes::v5::{Connect, Login, Packet};
use super::{Incoming, MqttOptions, MqttState, StateError};
use std::io;
use crate::framed::AsyncReadWrite;

use super::mqttbytes::v5::Packet;
use super::{mqttbytes, Codec};
use super::{Incoming, StateError};
use std::time::Duration;

/// Network transforms packets <-> frames efficiently. It takes
/// advantage of pre-allocation, buffering and vectorization when
/// appropriate to achieve performance
pub struct Network {
/// Socket for IO
socket: Box<dyn N>,
/// Buffered reads
read: BytesMut,
/// Maximum packet size
max_incoming_size: Option<usize>,
/// Maximum readv count
max_readb_count: usize,
/// Frame MQTT packets from network connection
framed: Framed<Box<dyn AsyncReadWrite>, Codec>,
/// Time within which network operations should complete
timeout: Duration,
}

impl Network {
pub fn new(socket: impl N + 'static, max_incoming_size: Option<usize>) -> Network {
let socket = Box::new(socket) as Box<dyn N>;
Network {
socket,
read: BytesMut::with_capacity(10 * 1024),
pub fn new(
socket: impl AsyncReadWrite + 'static,
max_incoming_size: Option<usize>,
max_outgoing_size: Option<usize>,
timeout: Duration,
) -> Network {
let socket = Box::new(socket) as Box<dyn AsyncReadWrite>;
let codec = Codec {
max_incoming_size,
max_readb_count: 10,
}
}

/// Reads more than 'required' bytes to frame a packet into self.read buffer
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
let mut total_read = 0;
loop {
let read = self.socket.read_buf(&mut self.read).await?;
if 0 == read {
return if self.read.is_empty() {
Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"connection closed by peer",
))
} else {
Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"connection reset by peer",
))
};
}
max_outgoing_size,
};
let framed = Framed::new(socket, codec);

total_read += read;
if total_read >= required {
return Ok(total_read);
}
}
Network { framed, timeout }
}

pub async fn read(&mut self) -> io::Result<Incoming> {
loop {
let required = match Packet::read(&mut self.read, self.max_incoming_size) {
Ok(packet) => return Ok(packet),
Err(mqttbytes::Error::InsufficientBytes(required)) => required,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};

// read more packets until a frame can be created. This function
// blocks until a frame can be created. Use this in a select! branch
self.read_bytes(required).await?;
pub async fn read(&mut self) -> Result<Incoming, StateError> {
match self.framed.next().await {
Some(Ok(packet)) => Ok(packet),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(),
Some(Err(e)) => Err(StateError::Deserialization(e)),
}
}

/// Read packets in bulk. This allow replies to be in bulk. This method is used
/// after the connection is established to read a bunch of incoming packets
pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
let mut count = 0;
loop {
match Packet::read(&mut self.read, self.max_incoming_size) {
Ok(packet) => {
state.handle_incoming_packet(packet)?;

count += 1;
if count >= self.max_readb_count {
return Ok(());
}
}
// If some packets are already framed, return those
Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
// Wait for more bytes until a frame can be created
Err(mqttbytes::Error::InsufficientBytes(required)) => {
self.read_bytes(required).await?;
}
Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => {
state.handle_protocol_error()?;
return Err(StateError::IncomingPacketTooLarge { pkt_size, max });
}
Err(e) => return Err(StateError::Deserialization(e)),
};
pub async fn send(&mut self, packet: Packet) -> Result<(), StateError> {
match timeout(self.timeout, self.framed.send(packet)).await {
Ok(inner) => inner.map_err(StateError::Deserialization),
Err(e) => Err(StateError::Timeout(e)),
}
}

pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result<usize> {
let mut write = BytesMut::new();
let last_will = options.last_will();
let login = options.credentials().map(|l| Login {
username: l.0,
password: l.1,
});

let len = match Packet::Connect(connect, last_will, login).write(&mut write) {
Ok(size) => size,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};

self.socket.write_all(&write[..]).await?;
Ok(len)
}

pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> {
if write.is_empty() {
return Ok(());
}

self.socket.write_all(&write[..]).await?;
write.clear();
Ok(())
}
}

pub trait N: AsyncRead + AsyncWrite + Send + Unpin {}
impl<T> N for T where T: AsyncRead + AsyncWrite + Send + Unpin {}
6 changes: 5 additions & 1 deletion rumqttc/src/v5/mqttbytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub fn matches(topic: &str, filter: &str) -> bool {
}

/// Error during serialization and deserialization
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Invalid return code received as response for connect = {0}")]
InvalidConnectReturnCode(u8),
Expand Down Expand Up @@ -183,4 +183,8 @@ pub enum Error {
/// proceed further
#[error("Insufficient number of bytes to frame packet, {0} more bytes required")]
InsufficientBytes(usize),
#[error("IO: {0}")]
Io(#[from] std::io::Error),
#[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")]
OutgoingPacketTooLarge { pkt_size: u32, max: u32 },
}
73 changes: 73 additions & 0 deletions rumqttc/src/v5/mqttbytes/v5/codec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder};

use super::{Error, Packet};

/// MQTT v4 codec
#[derive(Debug, Clone)]
pub struct Codec {
/// Maximum packet size allowed by client
pub max_incoming_size: Option<usize>,
/// Maximum packet size allowed by broker
pub max_outgoing_size: Option<usize>,
}

impl Decoder for Codec {
type Item = Packet;
type Error = Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.remaining() == 0 {
return Ok(None);
}

let packet = Packet::read(src, self.max_incoming_size)?;
Ok(Some(packet))
}
}

impl Encoder<Packet> for Codec {
type Error = Error;

fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
item.write(dst, self.max_outgoing_size)?;

Ok(())
}
}

#[cfg(test)]
mod tests {
use bytes::BytesMut;
use tokio_util::codec::Encoder;

use super::Codec;
use crate::v5::{
mqttbytes::{Error, QoS},
Packet, Publish,
};

#[test]
fn outgoing_max_packet_size_check() {
let mut buf = BytesMut::new();
let mut codec = Codec {
max_incoming_size: Some(100),
max_outgoing_size: Some(200),
};

let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None);
small_publish.pkid = 1;
codec
.encode(Packet::Publish(small_publish), &mut buf)
.unwrap();

let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None);
match codec.encode(Packet::Publish(large_publish), &mut buf) {
Err(Error::OutgoingPacketTooLarge {
pkt_size: 282,
max: 200,
}) => {}
_ => unreachable!(),
}
}
}
Loading

0 comments on commit 5dbe5a1

Please sign in to comment.