Skip to content

Commit

Permalink
refactor: simplify Networkwith Framed<.., Codec> (#825)
Browse files Browse the repository at this point in the history
* feat: MQTT `Codec` decoder

* feat: MQTT `Codec` encoder

* refactor: move write buffer into `Network`

* fix: `readb` should block for 1 packet (#824)

* refactor: use `Framed` to handle reads/writes from network

* doc: make it clear

* refactor: v5 implementation

* doc: changelog entry

* fix: state

* ack incoming publishes if required
* remove unused write buffer
* update tests
  • Loading branch information
Devdutt Shenoi authored Mar 25, 2024
1 parent df348cf commit f869eae
Show file tree
Hide file tree
Showing 18 changed files with 615 additions and 507 deletions.
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rumqttc/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

* rename `N` as `AsyncReadWrite` to describe usage.
* use `Framed` to encode/decode MQTT packets.

### Deprecated

Expand Down
4 changes: 3 additions & 1 deletion rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"]
proxy = ["dep:async-http-proxy"]

[dependencies]
futures-util = { version = "0.3", default_features = false, features = ["std"] }
futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] }
tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] }
tokio-util = { version = "0.7", features = ["codec"] }
bytes = "1.5"
log = "0.4"
flume = { version = "0.11", default-features = false, features = ["async"] }
Expand All @@ -47,6 +48,7 @@ native-tls = { version = "0.2.11", optional = true }
url = { version = "2", default-features = false, optional = true }
# proxy
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true }
tokio-stream = "0.1.15"

[dev-dependencies]
bincode = "1.3.3"
Expand Down
49 changes: 36 additions & 13 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub struct EventLoop {
/// Pending packets from last session
pub pending: VecDeque<Request>,
/// Network connection to the broker
network: Option<Network>,
pub network: Option<Network>,
/// Keep alive time
keepalive_timeout: Option<Pin<Box<Sleep>>>,
pub network_options: NetworkOptions,
Expand All @@ -104,11 +104,10 @@ impl EventLoop {
let pending = VecDeque::new();
let max_inflight = mqtt_options.inflight;
let manual_acks = mqtt_options.manual_acks;
let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size;

EventLoop {
mqtt_options,
state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size),
state: MqttState::new(max_inflight, manual_acks),
requests_tx,
requests_rx,
pending,
Expand Down Expand Up @@ -189,7 +188,7 @@ impl EventLoop {
o = network.readb(&mut self.state) => {
o?;
// flush all the acks and return first incoming packet
match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Expand Down Expand Up @@ -229,8 +228,10 @@ impl EventLoop {
self.mqtt_options.pending_throttle
), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
Ok(request) => {
self.state.handle_outgoing_packet(request)?;
match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Expand All @@ -245,8 +246,10 @@ impl EventLoop {
let timeout = self.keepalive_timeout.as_mut().unwrap();
timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);

self.state.handle_outgoing_packet(Request::PingReq(PingReq))?;
match time::timeout(network_timeout, network.flush(&mut self.state.write)).await {
if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? {
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Err(_)=> return Err(ConnectionError::FlushTimeout),
};
Expand Down Expand Up @@ -356,7 +359,11 @@ async fn network_connect(
if matches!(options.transport(), Transport::Unix) {
let file = options.broker_addr.as_str();
let socket = UnixStream::connect(Path::new(file)).await?;
let network = Network::new(socket, options.max_incoming_packet_size);
let network = Network::new(
socket,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
);
return Ok(network);
}

Expand Down Expand Up @@ -388,13 +395,21 @@ async fn network_connect(
};

let network = match options.transport() {
Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size),
Transport::Tcp => Network::new(
tcp_stream,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
),
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
Transport::Tls(tls_config) => {
let socket =
tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
.await?;
Network::new(socket, options.max_incoming_packet_size)
Network::new(
socket,
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
#[cfg(unix)]
Transport::Unix => unreachable!(),
Expand All @@ -413,7 +428,11 @@ async fn network_connect(
async_tungstenite::tokio::client_async(request, tcp_stream).await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(tls_config) => {
Expand All @@ -436,7 +455,11 @@ async fn network_connect(
.await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
};

Expand Down
135 changes: 57 additions & 78 deletions rumqttc/src/framed.rs
Original file line number Diff line number Diff line change
@@ -1,119 +1,98 @@
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures_util::{FutureExt, SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;

use crate::mqttbytes::{self, v4::*};
use crate::{Incoming, MqttState, StateError};
use std::io;

/// Network transforms packets <-> frames efficiently. It takes
/// advantage of pre-allocation, buffering and vectorization when
/// appropriate to achieve performance
pub struct Network {
/// Socket for IO
socket: Box<dyn AsyncReadWrite>,
/// Buffered reads
read: BytesMut,
/// Maximum packet size
max_incoming_size: usize,
/// Frame MQTT packets from network connection
framed: Framed<Box<dyn AsyncReadWrite>, Codec>,
/// Maximum readv count
max_readb_count: usize,
}

impl Network {
pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: usize) -> Network {
pub fn new(
socket: impl AsyncReadWrite + 'static,
max_incoming_size: usize,
max_outgoing_size: usize,
) -> Network {
let socket = Box::new(socket) as Box<dyn AsyncReadWrite>;
Network {
socket,
read: BytesMut::with_capacity(10 * 1024),
let codec = Codec {
max_incoming_size,
max_outgoing_size,
};
let framed = Framed::new(socket, codec);

Network {
framed,
max_readb_count: 10,
}
}

/// Reads more than 'required' bytes to frame a packet into self.read buffer
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
let mut total_read = 0;
loop {
let read = self.socket.read_buf(&mut self.read).await?;
if 0 == read {
return if self.read.is_empty() {
Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"connection closed by peer",
))
} else {
Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"connection reset by peer",
))
};
}

total_read += read;
if total_read >= required {
return Ok(total_read);
}
}
}

pub async fn read(&mut self) -> io::Result<Incoming> {
loop {
let required = match Packet::read(&mut self.read, self.max_incoming_size) {
Ok(packet) => return Ok(packet),
Err(mqttbytes::Error::InsufficientBytes(required)) => required,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};

// read more packets until a frame can be created. This function
// blocks until a frame can be created. Use this in a select! branch
self.read_bytes(required).await?;
/// Reads and returns a single packet from network
pub async fn read(&mut self) -> Result<Incoming, StateError> {
match self.framed.next().await {
Some(Ok(packet)) => Ok(packet),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(),
Some(Err(e)) => Err(StateError::Deserialization(e)),
}
}

/// Read packets in bulk. This allow replies to be in bulk. This method is used
/// after the connection is established to read a bunch of incoming packets
pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
let mut count = 0;
// wait for the first read
let mut res = self.framed.next().await;
let mut count = 1;
loop {
match Packet::read(&mut self.read, self.max_incoming_size) {
Ok(packet) => {
state.handle_incoming_packet(packet)?;
match res {
Some(Ok(packet)) => {
if let Some(outgoing) = state.handle_incoming_packet(packet)? {
self.write(outgoing).await?;
}

count += 1;
if count >= self.max_readb_count {
return Ok(());
break;
}
}
// If some packets are already framed, return those
Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
// Wait for more bytes until a frame can be created
Err(mqttbytes::Error::InsufficientBytes(required)) => {
self.read_bytes(required).await?;
}
Err(e) => return Err(StateError::Deserialization(e)),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(),
Some(Err(e)) => return Err(StateError::Deserialization(e)),
}
// do not wait for subsequent reads
match self.framed.next().now_or_never() {
Some(r) => res = r,
_ => break,
};
}
}

pub async fn connect(&mut self, connect: Connect) -> io::Result<usize> {
let mut write = BytesMut::new();
let len = match connect.write(&mut write) {
Ok(size) => size,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};
Ok(())
}

self.socket.write_all(&write[..]).await?;
Ok(len)
/// Serializes packet into write buffer
pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> {
self.framed
.feed(packet)
.await
.map_err(StateError::Deserialization)
}

pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> {
if write.is_empty() {
return Ok(());
}
pub async fn connect(&mut self, connect: Connect) -> Result<(), StateError> {
self.write(Packet::Connect(connect)).await?;

self.socket.write_all(&write[..]).await?;
write.clear();
Ok(())
self.flush().await
}

pub async fn flush(&mut self) -> Result<(), crate::state::StateError> {
self.framed
.flush()
.await
.map_err(StateError::Deserialization)
}
}

Expand Down
19 changes: 0 additions & 19 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,25 +200,6 @@ pub enum Request {
Disconnect(Disconnect),
}

impl Request {
fn size(&self) -> usize {
match &self {
Request::Publish(publish) => publish.size(),
Request::PubAck(puback) => puback.size(),
Request::PubRec(pubrec) => pubrec.size(),
Request::PubComp(pubcomp) => pubcomp.size(),
Request::PubRel(pubrel) => pubrel.size(),
Request::PingReq(pingreq) => pingreq.size(),
Request::PingResp(pingresp) => pingresp.size(),
Request::Subscribe(subscribe) => subscribe.size(),
Request::SubAck(suback) => suback.size(),
Request::Unsubscribe(unsubscribe) => unsubscribe.size(),
Request::UnsubAck(unsuback) => unsuback.size(),
Request::Disconnect(disconn) => disconn.size(),
}
}
}

impl From<Publish> for Request {
fn from(publish: Publish) -> Request {
Request::Publish(publish)
Expand Down
6 changes: 5 additions & 1 deletion rumqttc/src/mqttbytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub mod v4;
pub use topic::*;

/// Error during serialization and deserialization
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Expected Connect, received: {0:?}")]
NotConnect(PacketType),
Expand Down Expand Up @@ -60,6 +60,10 @@ pub enum Error {
/// proceed further
#[error("At least {0} more bytes required to frame packet")]
InsufficientBytes(usize),
#[error("IO: {0}")]
Io(#[from] std::io::Error),
#[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")]
OutgoingPacketTooLarge { pkt_size: usize, max: usize },
}

/// MQTT packet type
Expand Down
Loading

0 comments on commit f869eae

Please sign in to comment.