-
Notifications
You must be signed in to change notification settings - Fork 255
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Devdutt Shenoi
committed
Mar 19, 2024
1 parent
041c6c8
commit 5dbe5a1
Showing
6 changed files
with
289 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,132 +1,53 @@ | ||
use bytes::BytesMut; | ||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; | ||
use futures_util::SinkExt; | ||
use tokio::time::timeout; | ||
use tokio_stream::StreamExt; | ||
use tokio_util::codec::Framed; | ||
|
||
use super::mqttbytes; | ||
use super::mqttbytes::v5::{Connect, Login, Packet}; | ||
use super::{Incoming, MqttOptions, MqttState, StateError}; | ||
use std::io; | ||
use crate::framed::AsyncReadWrite; | ||
|
||
use super::mqttbytes::v5::Packet; | ||
use super::{mqttbytes, Codec}; | ||
use super::{Incoming, StateError}; | ||
use std::time::Duration; | ||
|
||
/// Network transforms packets <-> frames efficiently. It takes | ||
/// advantage of pre-allocation, buffering and vectorization when | ||
/// appropriate to achieve performance | ||
pub struct Network { | ||
/// Socket for IO | ||
socket: Box<dyn N>, | ||
/// Buffered reads | ||
read: BytesMut, | ||
/// Maximum packet size | ||
max_incoming_size: Option<usize>, | ||
/// Maximum readv count | ||
max_readb_count: usize, | ||
/// Frame MQTT packets from network connection | ||
framed: Framed<Box<dyn AsyncReadWrite>, Codec>, | ||
/// Time within which network operations should complete | ||
timeout: Duration, | ||
} | ||
|
||
impl Network { | ||
pub fn new(socket: impl N + 'static, max_incoming_size: Option<usize>) -> Network { | ||
let socket = Box::new(socket) as Box<dyn N>; | ||
Network { | ||
socket, | ||
read: BytesMut::with_capacity(10 * 1024), | ||
pub fn new( | ||
socket: impl AsyncReadWrite + 'static, | ||
max_incoming_size: Option<usize>, | ||
max_outgoing_size: Option<usize>, | ||
timeout: Duration, | ||
) -> Network { | ||
let socket = Box::new(socket) as Box<dyn AsyncReadWrite>; | ||
let codec = Codec { | ||
max_incoming_size, | ||
max_readb_count: 10, | ||
} | ||
} | ||
|
||
/// Reads more than 'required' bytes to frame a packet into self.read buffer | ||
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> { | ||
let mut total_read = 0; | ||
loop { | ||
let read = self.socket.read_buf(&mut self.read).await?; | ||
if 0 == read { | ||
return if self.read.is_empty() { | ||
Err(io::Error::new( | ||
io::ErrorKind::ConnectionAborted, | ||
"connection closed by peer", | ||
)) | ||
} else { | ||
Err(io::Error::new( | ||
io::ErrorKind::ConnectionReset, | ||
"connection reset by peer", | ||
)) | ||
}; | ||
} | ||
max_outgoing_size, | ||
}; | ||
let framed = Framed::new(socket, codec); | ||
|
||
total_read += read; | ||
if total_read >= required { | ||
return Ok(total_read); | ||
} | ||
} | ||
Network { framed, timeout } | ||
} | ||
|
||
pub async fn read(&mut self) -> io::Result<Incoming> { | ||
loop { | ||
let required = match Packet::read(&mut self.read, self.max_incoming_size) { | ||
Ok(packet) => return Ok(packet), | ||
Err(mqttbytes::Error::InsufficientBytes(required)) => required, | ||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), | ||
}; | ||
|
||
// read more packets until a frame can be created. This function | ||
// blocks until a frame can be created. Use this in a select! branch | ||
self.read_bytes(required).await?; | ||
pub async fn read(&mut self) -> Result<Incoming, StateError> { | ||
match self.framed.next().await { | ||
Some(Ok(packet)) => Ok(packet), | ||
Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), | ||
Some(Err(e)) => Err(StateError::Deserialization(e)), | ||
} | ||
} | ||
|
||
/// Read packets in bulk. This allow replies to be in bulk. This method is used | ||
/// after the connection is established to read a bunch of incoming packets | ||
pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { | ||
let mut count = 0; | ||
loop { | ||
match Packet::read(&mut self.read, self.max_incoming_size) { | ||
Ok(packet) => { | ||
state.handle_incoming_packet(packet)?; | ||
|
||
count += 1; | ||
if count >= self.max_readb_count { | ||
return Ok(()); | ||
} | ||
} | ||
// If some packets are already framed, return those | ||
Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), | ||
// Wait for more bytes until a frame can be created | ||
Err(mqttbytes::Error::InsufficientBytes(required)) => { | ||
self.read_bytes(required).await?; | ||
} | ||
Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => { | ||
state.handle_protocol_error()?; | ||
return Err(StateError::IncomingPacketTooLarge { pkt_size, max }); | ||
} | ||
Err(e) => return Err(StateError::Deserialization(e)), | ||
}; | ||
pub async fn send(&mut self, packet: Packet) -> Result<(), StateError> { | ||
match timeout(self.timeout, self.framed.send(packet)).await { | ||
Ok(inner) => inner.map_err(StateError::Deserialization), | ||
Err(e) => Err(StateError::Timeout(e)), | ||
} | ||
} | ||
|
||
pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result<usize> { | ||
let mut write = BytesMut::new(); | ||
let last_will = options.last_will(); | ||
let login = options.credentials().map(|l| Login { | ||
username: l.0, | ||
password: l.1, | ||
}); | ||
|
||
let len = match Packet::Connect(connect, last_will, login).write(&mut write) { | ||
Ok(size) => size, | ||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), | ||
}; | ||
|
||
self.socket.write_all(&write[..]).await?; | ||
Ok(len) | ||
} | ||
|
||
pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { | ||
if write.is_empty() { | ||
return Ok(()); | ||
} | ||
|
||
self.socket.write_all(&write[..]).await?; | ||
write.clear(); | ||
Ok(()) | ||
} | ||
} | ||
|
||
pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} | ||
impl<T> N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use bytes::{Buf, BytesMut}; | ||
use tokio_util::codec::{Decoder, Encoder}; | ||
|
||
use super::{Error, Packet}; | ||
|
||
/// MQTT v4 codec | ||
#[derive(Debug, Clone)] | ||
pub struct Codec { | ||
/// Maximum packet size allowed by client | ||
pub max_incoming_size: Option<usize>, | ||
/// Maximum packet size allowed by broker | ||
pub max_outgoing_size: Option<usize>, | ||
} | ||
|
||
impl Decoder for Codec { | ||
type Item = Packet; | ||
type Error = Error; | ||
|
||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { | ||
if src.remaining() == 0 { | ||
return Ok(None); | ||
} | ||
|
||
let packet = Packet::read(src, self.max_incoming_size)?; | ||
Ok(Some(packet)) | ||
} | ||
} | ||
|
||
impl Encoder<Packet> for Codec { | ||
type Error = Error; | ||
|
||
fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { | ||
item.write(dst, self.max_outgoing_size)?; | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use bytes::BytesMut; | ||
use tokio_util::codec::Encoder; | ||
|
||
use super::Codec; | ||
use crate::v5::{ | ||
mqttbytes::{Error, QoS}, | ||
Packet, Publish, | ||
}; | ||
|
||
#[test] | ||
fn outgoing_max_packet_size_check() { | ||
let mut buf = BytesMut::new(); | ||
let mut codec = Codec { | ||
max_incoming_size: Some(100), | ||
max_outgoing_size: Some(200), | ||
}; | ||
|
||
let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None); | ||
small_publish.pkid = 1; | ||
codec | ||
.encode(Packet::Publish(small_publish), &mut buf) | ||
.unwrap(); | ||
|
||
let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None); | ||
match codec.encode(Packet::Publish(large_publish), &mut buf) { | ||
Err(Error::OutgoingPacketTooLarge { | ||
pkt_size: 282, | ||
max: 200, | ||
}) => {} | ||
_ => unreachable!(), | ||
} | ||
} | ||
} |
Oops, something went wrong.