Skip to content

Commit

Permalink
refactor: use Framed to handle reads/writes from network
Browse files Browse the repository at this point in the history
  • Loading branch information
Devdutt Shenoi committed Mar 25, 2024
1 parent 1fd8825 commit 200f865
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 104 deletions.
2 changes: 1 addition & 1 deletion rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"]
proxy = ["dep:async-http-proxy"]

[dependencies]
futures-util = { version = "0.3", default_features = false, features = ["std"] }
futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] }
tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] }
tokio-util = { version = "0.7", features = ["codec"] }
bytes = "1.5"
Expand Down
16 changes: 12 additions & 4 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl EventLoop {
), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
Ok(request) => {
if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
network.write(outgoing)?;
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Expand All @@ -247,7 +247,7 @@ impl EventLoop {
timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);

if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? {
network.write(outgoing)?;
network.write(outgoing).await?;
}
match time::timeout(network_timeout, network.flush()).await {
Ok(inner) => inner?,
Expand Down Expand Up @@ -428,7 +428,11 @@ async fn network_connect(
async_tungstenite::tokio::client_async(request, tcp_stream).await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(tls_config) => {
Expand All @@ -451,7 +455,11 @@ async fn network_connect(
.await?;
validate_response_headers(response)?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
Network::new(
WsStream::new(socket),
options.max_incoming_packet_size,
options.max_outgoing_packet_size,
)
}
};

Expand Down
137 changes: 45 additions & 92 deletions rumqttc/src/framed.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{Decoder, Encoder};
use futures_util::{FutureExt, SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;

use crate::mqttbytes::{self, v4::*};
use crate::{Incoming, MqttState, StateError};
use std::io;

/// Network transforms packets <-> frames efficiently. It takes
/// advantage of pre-allocation, buffering and vectorization when
/// appropriate to achieve performance
pub struct Network {
/// Socket for IO
socket: Box<dyn AsyncReadWrite>,
/// Buffered reads
read: BytesMut,
/// Buffered writes
pub write: BytesMut,
/// Use to decode MQTT packets
codec: Codec,
/// Frame MQTT packets from network connection
framed: Framed<Box<dyn AsyncReadWrite>, Codec>,
/// Maximum readv count
max_readb_count: usize,
}
Expand All @@ -29,117 +22,77 @@ impl Network {
max_outgoing_size: usize,
) -> Network {
let socket = Box::new(socket) as Box<dyn AsyncReadWrite>;
let codec = Codec {
max_incoming_size,
max_outgoing_size,
};
let framed = Framed::new(socket, codec);

Network {
socket,
read: BytesMut::with_capacity(10 * 1024),
write: BytesMut::with_capacity(10 * 1024),
codec: Codec {
max_incoming_size,
max_outgoing_size,
},
framed,
max_readb_count: 10,
}
}

/// Reads more than 'required' bytes to frame a packet into self.read buffer
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
let mut total_read = 0;
loop {
let read = self.socket.read_buf(&mut self.read).await?;
if 0 == read {
return if self.read.is_empty() {
Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"connection closed by peer",
))
} else {
Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"connection reset by peer",
))
};
}

total_read += read;
if total_read >= required {
return Ok(total_read);
}
}
}

pub async fn read(&mut self) -> io::Result<Incoming> {
loop {
let required = match self.codec.decode(&mut self.read) {
Ok(Some(packet)) => return Ok(packet),
// TODO: figure out how not to block
Ok(_) => 2,
Err(mqttbytes::Error::InsufficientBytes(required)) => required,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};

// read more packets until a frame can be created. This function
// blocks until a frame can be created. Use this in a select! branch
self.read_bytes(required).await?;
/// Reads and returns a single packet from network
pub async fn read(&mut self) -> Result<Incoming, StateError> {
match self.framed.next().await {
Some(Ok(packet)) => Ok(packet),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(),
Some(Err(e)) => Err(StateError::Deserialization(e)),
}
}

/// Read packets in bulk. This allow replies to be in bulk. This method is used
/// after the connection is established to read a bunch of incoming packets
pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
let mut count = 0;
// wait for the first read
let mut res = self.framed.next().await;
let mut count = 1;
loop {
match self.codec.decode(&mut self.read) {
Ok(Some(packet)) => {
match res {
Some(Ok(packet)) => {
if let Some(packet) = state.handle_incoming_packet(packet)? {
self.write(packet)?;
self.write(packet).await?;
}

count += 1;
if count >= self.max_readb_count {
break;
}
}
// If some packets are already framed, return those
Err(mqttbytes::Error::InsufficientBytes(_)) | Ok(_) if count > 0 => break,
// NOTE: read atleast 1 packet
Ok(_) => {
self.read_bytes(2).await?;
}
// Wait for more bytes until a frame can be created
Err(mqttbytes::Error::InsufficientBytes(required)) => {
self.read_bytes(required).await?;
}
Err(e) => return Err(StateError::Deserialization(e)),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(),
Some(Err(e)) => return Err(StateError::Deserialization(e)),
}
// do not wait for subsequent reads
match self.framed.next().now_or_never() {
Some(r) => res = r,
_ => break,
};
}

Ok(())
}

pub fn write(&mut self, packet: Packet) -> Result<(), crate::state::StateError> {
self.codec
.encode(packet, &mut self.write)
.map_err(Into::into)
/// Serializes packet into write buffer
pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> {
self.framed
.feed(packet)
.await
.map_err(StateError::Deserialization)
}

pub async fn connect(&mut self, connect: Connect) -> io::Result<()> {
let mut write = BytesMut::new();
self.codec
.encode(Packet::Connect(connect), &mut write)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
pub async fn connect(&mut self, connect: Connect) -> Result<(), StateError> {
self.write(Packet::Connect(connect)).await?;

self.socket.write_all(&write[..]).await?;
Ok(())
self.flush().await
}

pub async fn flush(&mut self) -> io::Result<()> {
if self.write.is_empty() {
return Ok(());
}

self.socket.write_all(&self.write[..]).await?;
self.write.clear();
Ok(())
pub async fn flush(&mut self) -> Result<(), crate::state::StateError> {
self.framed
.flush()
.await
.map_err(StateError::Deserialization)
}
}

Expand Down
15 changes: 9 additions & 6 deletions rumqttc/src/mqttbytes/v4/codec.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bytes::{Buf, BytesMut};
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};

use super::{Error, Packet};
Expand All @@ -17,12 +17,15 @@ impl Decoder for Codec {
type Error = Error;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.remaining() == 0 {
return Ok(None);
match Packet::read(src, self.max_incoming_size) {
Ok(packet) => Ok(Some(packet)),
Err(Error::InsufficientBytes(b)) => {
// Get more packets to construct the incomplete packet
src.reserve(b);
Ok(None)
}
Err(e) => Err(e),
}

let packet = Packet::read(src, self.max_incoming_size)?;
Ok(Some(packet))
}
}

Expand Down
2 changes: 1 addition & 1 deletion rumqttc/tests/reliability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly
if let Err(e) = res {
match e {
ConnectionError::FlushTimeout => {
assert!(eventloop.network.as_ref().unwrap().write.is_empty());
assert!(eventloop.network.is_none());
println!("State is being clean properly");
}
_ => {
Expand Down

0 comments on commit 200f865

Please sign in to comment.