From 07b728cd27fbf7c6386b3acdb7e029fafa68f567 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 17 Apr 2024 16:20:18 +0800 Subject: [PATCH 01/31] Add enhanced authentication example. --- Cargo.lock | 51 ++++++- rumqttc/Cargo.toml | 1 + rumqttc/examples/auth.rs | 44 ++++++ rumqttc/src/lib.rs | 2 + rumqttc/src/v5/client.rs | 53 ++++++- rumqttc/src/v5/eventloop.rs | 64 +++++++-- rumqttc/src/v5/mod.rs | 5 +- rumqttc/src/v5/mqttbytes/v5/auth.rs | 214 ++++++++++++++++++++++++++++ rumqttc/src/v5/mqttbytes/v5/mod.rs | 11 ++ rumqttc/src/v5/state.rs | 49 ++++++- 10 files changed, 464 insertions(+), 30 deletions(-) create mode 100644 rumqttc/examples/auth.rs create mode 100644 rumqttc/src/v5/mqttbytes/v5/auth.rs diff --git a/Cargo.lock b/Cargo.lock index 75a7ced4c..d7966b0a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -689,7 +689,7 @@ checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" dependencies = [ "futures-core", "futures-sink", - "spin", + "spin 0.9.8", ] [[package]] @@ -1899,6 +1899,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + [[package]] name = "ring" version = "0.17.8" @@ -1909,8 +1924,8 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin", - "untrusted", + "spin 0.9.8", + "untrusted 0.9.0", "windows-sys 0.52.0", ] @@ -1946,6 +1961,7 @@ dependencies = [ "rustls-native-certs", "rustls-pemfile", "rustls-webpki", + "scram", "serde", "thiserror", "tokio", @@ -2046,7 +2062,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" dependencies = [ "log", - "ring", + "ring 0.17.8", "rustls-pki-types", "rustls-webpki", "subtle", @@ -2088,9 +2104,9 @@ version = "0.102.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" dependencies = [ - "ring", + "ring 0.17.8", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -2120,6 +2136,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scram" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7679a5e6b97bac99b2c208894ba0d34b17d9657f0b728c1cd3bf1c5f7f6ebe88" +dependencies = [ + "base64 0.13.1", + "rand", + "ring 0.16.20", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -2293,6 +2320,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spin" version = "0.9.8" @@ -2767,6 +2800,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index bba64822c..fa57640db 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -30,6 +30,7 @@ bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } thiserror = "1" +scram = "0.6.0" # Optional # rustls diff --git a/rumqttc/examples/auth.rs b/rumqttc/examples/auth.rs new file mode 100644 index 000000000..80bdf050b --- /dev/null +++ b/rumqttc/examples/auth.rs @@ -0,0 +1,44 @@ + +use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::{AsyncClient, MqttOptions}; +use tokio::task; +use std::error::Error; +use std::thread; +use scram::ScramClient; + +#[tokio::main()] +async fn main() -> Result<(), Box> { + + let scram = ScramClient::new("user1", "123456", None); + let (scram, client_first) = scram.client_first(); + + let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); + mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); + mqttoptions.set_authentication_data(Some(client_first.clone().into())); + mqttoptions.set_connection_timeout(20); + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + + task::spawn(async move { + let server_first: String = client.recv_server_auth_data().await.unwrap(); + let scram = scram.handle_server_first(&server_first).unwrap(); + let (scram, client_final) = scram.client_final(); + client.send_client_auth_data(client_final).await.unwrap(); + + client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); + client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); + }); + + loop { + let notification = eventloop.poll().await; + + match notification { + Ok(event) => println!("{:?}", event), + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} \ No newline at end of file diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 29cad1a34..887ef22dd 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -180,6 +180,8 @@ pub enum Outgoing { Disconnect, /// Await for an ack for more outgoing progress AwaitAck(u16), + /// Auth packet + Auth, } /// Requests by the client to mqtt event loop. Request are diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..768d746a2 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -3,7 +3,7 @@ use std::time::Duration; use super::mqttbytes::v5::{ - Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, + Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Auth, AuthProperties, AuthReasonCode, Unsubscribe, UnsubscribeProperties, }; use super::mqttbytes::{valid_filter, QoS}; @@ -11,7 +11,7 @@ use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; use bytes::Bytes; -use flume::{SendError, Sender, TrySendError}; +use flume::{SendError, Sender, Receiver, TrySendError}; use futures_util::FutureExt; use tokio::runtime::{self, Runtime}; use tokio::time::timeout; @@ -47,6 +47,8 @@ impl From> for ClientError { #[derive(Clone, Debug)] pub struct AsyncClient { request_tx: Sender, + auth_tx: Option>, + auth_rx: Option>, } impl AsyncClient { @@ -56,18 +58,45 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_tx = eventloop.requests_tx.clone(); + let auth_tx = eventloop.auth_cdata_tx.clone(); + let auth_rx = eventloop.auth_sdata_rx.clone(); - let client = AsyncClient { request_tx }; + let client = AsyncClient { request_tx, auth_tx, auth_rx}; (client, eventloop) } + pub async fn recv_server_auth_data(&self) -> Result { + if self.auth_rx.is_none() { + return Err(ClientError::Request(Request::Disconnect)); + } + let string = self.auth_rx.as_ref().unwrap().recv_async().await; + + if let Err(_) = string { + return Err(ClientError::Request(Request::Disconnect)); + } + + Ok(string.unwrap()) + } + + pub async fn send_client_auth_data(&self, data: String) -> Result<(), ClientError> { + if self.auth_tx.is_none() { + return Err(ClientError::Request(Request::Disconnect)); + } + + if let Err(_) = self.auth_tx.as_ref().unwrap().send_async(data).await { + return Err(ClientError::Request(Request::Disconnect)); + } + + Ok(()) + } + /// Create a new `AsyncClient` from a channel `Sender`. /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. pub fn from_senders(request_tx: Sender) -> AsyncClient { - AsyncClient { request_tx } + AsyncClient { request_tx, auth_tx: None, auth_rx: None } } /// Sends a MQTT Publish to the `EventLoop`. @@ -196,6 +225,22 @@ impl AsyncClient { Ok(()) } + /// Sends a MQTT AUTH to `EventLoop` for authentication. + pub async fn auth(&self, reason: AuthReasonCode, properties: Option) -> Result<(), ClientError>{ + let auth = Auth::new(reason, properties); + let auth = Request::Auth(auth); + self.request_tx.send_async(auth).await?; + Ok(()) + } + + /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. + pub fn try_auth(&self, reason: AuthReasonCode, properties: Option) -> Result<(), ClientError>{ + let auth = Auth::new(reason, properties); + let auth = Request::Auth(auth); + self.request_tx.try_send(auth)?; + Ok(()) + } + /// Sends a MQTT Publish to the `EventLoop` async fn handle_publish_bytes( &self, diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index a59094807..9927b4762 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -55,6 +55,8 @@ pub enum ConnectionError { NotConnAck(Box), #[error("Requests done")] RequestsDone, + #[error("Auth processing error")] + AuthProcessingError, #[cfg(feature = "websocket")] #[error("Invalid Url: {0}")] InvalidUrl(#[from] UrlError), @@ -82,6 +84,9 @@ pub struct EventLoop { network: Option, /// Keep alive time keepalive_timeout: Option>>, + + pub auth_sdata_rx: Option>, + pub auth_cdata_tx: Option>, } /// Events which can be yielded by the event loop @@ -101,15 +106,33 @@ impl EventLoop { let pending = VecDeque::new(); let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; + + // set state according to authentication method + let mut auth_sdata_tx = None; + let mut auth_sdata_rx = None; + let mut auth_cdata_rx = None; + let mut auth_cdata_tx = None; + + if options.authentication_method().is_some() { + let (auth_ctx, auth_crx) = bounded(1); + let (auth_stx, auth_srx) = bounded(1); + + auth_sdata_tx = Some(auth_stx); + auth_sdata_rx = Some(auth_srx); + auth_cdata_rx = Some(auth_crx); + auth_cdata_tx = Some(auth_ctx); + } EventLoop { options, - state: MqttState::new(inflight_limit, manual_acks), + state: MqttState::new(inflight_limit, manual_acks, auth_cdata_rx, auth_sdata_tx), requests_tx, requests_rx, pending, network: None, keepalive_timeout: None, + auth_sdata_rx, + auth_cdata_tx, } } @@ -138,7 +161,7 @@ impl EventLoop { if self.network.is_none() { let (network, connack) = time::timeout( Duration::from_secs(self.options.connection_timeout()), - connect(&mut self.options), + connect(&mut self.options, &mut self.state), ) .await??; self.network = Some(network); @@ -263,12 +286,12 @@ impl EventLoop { /// the stream. /// This function (for convenience) includes internal delays for users to perform internal sleeps /// between re-connections so that cancel semantics can be used during this sleep -async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), ConnectionError> { +async fn connect(options: &mut MqttOptions, state: &mut MqttState) -> Result<(Network, Incoming), ConnectionError> { // connect to the broker let mut network = network_connect(options).await?; // make MQTT connection request (which internally awaits for ack) - let packet = mqtt_connect(options, &mut network).await?; + let packet = mqtt_connect(options, &mut network, state).await?; // Last session might contain packets which aren't acked. MQTT says these packets should be // republished in the next session @@ -387,12 +410,12 @@ async fn network_connect(options: &MqttOptions) -> Result Result { let keep_alive = options.keep_alive().as_secs() as u16; let clean_start = options.clean_start(); let client_id = options.client_id(); let properties = options.connect_properties(); - let connect = Connect { keep_alive, client_id, @@ -404,18 +427,29 @@ async fn mqtt_connect( network.connect(connect, options).await?; // validate connack - match network.read().await? { - Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - // Override local keep_alive value if set by server. - if let Some(props) = &connack.properties { - if let Some(keep_alive) = props.server_keep_alive { - options.keep_alive = Duration::from_secs(keep_alive as u64); + loop { + match network.read().await? { + Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { + // Override local keep_alive value if set by server. + if let Some(props) = &connack.properties { + if let Some(keep_alive) = props.server_keep_alive { + options.keep_alive = Duration::from_secs(keep_alive as u64); + } + network.set_max_outgoing_size(props.max_packet_size); + } + return Ok(Packet::ConnAck(connack)); + } + Incoming::ConnAck(connack) => return Err(ConnectionError::ConnectionRefused(connack.code)), + Incoming::Auth(auth) => { + if let Some(outgoing) = state.handle_incoming_packet(Incoming::Auth(auth))? { + network.write(outgoing).await?; + network.flush().await?; + } + else { + return Err(ConnectionError::AuthProcessingError); } - network.set_max_outgoing_size(props.max_packet_size); } - Ok(Packet::ConnAck(connack)) + packet => return Err(ConnectionError::NotConnAck(Box::new(packet))), } - Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), - packet => Err(ConnectionError::NotConnAck(Box::new(packet))), } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 44499cde2..c5a06c501 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,11 +1,13 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; use std::time::Duration; +use flume::{Sender, Receiver}; + #[cfg(feature = "websocket")] use std::{ future::{Future, IntoFuture}, pin::Pin, - sync::Arc, + sync::Arc }; mod client; @@ -47,6 +49,7 @@ pub enum Request { Unsubscribe(Unsubscribe), UnsubAck(UnsubAck), Disconnect, + Auth(Auth), } #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v5/mqttbytes/v5/auth.rs b/rumqttc/src/v5/mqttbytes/v5/auth.rs new file mode 100644 index 000000000..eb47fb2fd --- /dev/null +++ b/rumqttc/src/v5/mqttbytes/v5/auth.rs @@ -0,0 +1,214 @@ +use super::*; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +/// Auth packet +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Auth { + pub reason: AuthReasonCode, + pub properties: Option, +} + +impl Auth { + pub fn new(reason: AuthReasonCode, properties: Option) -> Self { + + Self { + reason, + properties, + } + } + + pub fn size(&self) -> usize { + let len = self.len(); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } + + fn len(&self) -> usize { + let mut len = 1; + + if let Some(p) = &self.properties { + let properties_len = p.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let code = read_u8(&mut bytes)?; + let reason = reason(code)?; + let properties = AuthProperties::read(&mut bytes)?; + let auth = Auth { + reason, + properties, + }; + + Ok(auth) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = Self::len(self); + buffer.put_u8(0xF0); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(code(self.reason)); + + if let Some(p) = &self.properties { + p.write(buffer)?; + } else { + write_remaining_length(buffer, 0)?; + } + + Ok(1 + count + len) + } +} + +/// Return code in auth +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthReasonCode { + Success, + ContinueAuthentication, + Reauthenticate +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct AuthProperties { + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, + /// Reason for disconnection + pub reason_string: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, +} + +impl AuthProperties { + fn len(&self) -> usize { + let mut len = 0; + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn read(bytes: &mut Bytes) -> Result, Error> { + let mut authentication_method = None; + let mut authentication_data = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::AuthenticationMethod => { + let method = read_mqtt_string(bytes)?; + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + PropertyType::ReasonString => { + let reason = read_mqtt_string(bytes)?; + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_string(bytes)?; + let value = read_mqtt_string(bytes)?; + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(AuthProperties { + authentication_method, + authentication_data, + reason_string, + user_properties, + })) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } +} + +fn reason(num: u8) -> Result { + let code = match num { + 0x00 => AuthReasonCode::Success, + 0x18 => AuthReasonCode::ContinueAuthentication, + 0x19 => AuthReasonCode::Reauthenticate, + num => return Err(Error::InvalidReason(num)), + }; + + Ok(code) +} + +fn code(value: AuthReasonCode) -> u8 { + match value { + AuthReasonCode::Success => 0x00, + AuthReasonCode::ContinueAuthentication => 0x18, + AuthReasonCode::Reauthenticate => 0x19 + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 342278596..ea433aa51 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -15,6 +15,7 @@ pub use self::{ subscribe::{Filter, RetainForwardRule, Subscribe, SubscribeProperties}, unsuback::{UnsubAck, UnsubAckProperties, UnsubAckReason}, unsubscribe::{Unsubscribe, UnsubscribeProperties}, + auth::{Auth, AuthProperties, AuthReasonCode}, }; use super::*; @@ -34,6 +35,7 @@ mod suback; mod subscribe; mod unsuback; mod unsubscribe; +mod auth; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Packet { @@ -51,6 +53,7 @@ pub enum Packet { Unsubscribe(Unsubscribe), UnsubAck(UnsubAck), Disconnect(Disconnect), + Auth(Auth), } impl Packet { @@ -123,6 +126,10 @@ impl Packet { let disconnect = Disconnect::read(fixed_header, packet)?; Packet::Disconnect(disconnect) } + PacketType::Auth => { + let auth = Auth::read(fixed_header, packet)?; + Packet::Auth(auth) + } }; Ok(packet) @@ -153,6 +160,7 @@ impl Packet { Self::PingReq(_) => PingReq::write(write), Self::PingResp(_) => PingResp::write(write), Self::Disconnect(disconnect) => disconnect.write(write), + Self::Auth(auth) => auth.write(write), } } @@ -172,6 +180,7 @@ impl Packet { Self::PingReq(req) => req.size(), Self::PingResp(resp) => resp.size(), Self::Disconnect(disconnect) => disconnect.size(), + Self::Auth(auth) => auth.size(), } } } @@ -194,6 +203,7 @@ pub enum PacketType { PingReq, PingResp, Disconnect, + Auth, } #[repr(u8)] @@ -280,6 +290,7 @@ impl FixedHeader { 12 => Ok(PacketType::PingReq), 13 => Ok(PacketType::PingResp), 14 => Ok(PacketType::Disconnect), + 15 => Ok(PacketType::Auth), _ => Err(Error::InvalidPacketType(num)), } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 854aa7b0f..8c12d16cc 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,7 +1,7 @@ use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, + SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, Auth, AuthProperties, AuthReasonCode }; use super::mqttbytes::{self, Error as MqttError, QoS}; @@ -10,6 +10,7 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; +use flume::{Receiver, Sender}; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -121,13 +122,15 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, + auth_rx: Option>, + auth_sx: Option>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, auth_rx: Option>, auth_sx: Option>) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -148,6 +151,8 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, + auth_rx, + auth_sx, } } @@ -202,6 +207,7 @@ impl MqttState { } Request::PubAck(puback) => self.outgoing_puback(puback)?, Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + Request::Auth(auth) => self.outgoing_auth(auth)?, _ => unimplemented!(), }; @@ -228,6 +234,7 @@ impl MqttState { Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, + Incoming::Auth(auth) => self.handle_incoming_auth(auth)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); @@ -478,6 +485,32 @@ impl MqttState { self.await_pingresp = false; Ok(None) } + + fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { + let props = auth.properties.clone().unwrap(); + let auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); + + if self.auth_rx.is_none() || self.auth_sx.is_none() { + return Err(StateError::InvalidState); + } + + // Send server authentication data to application. + self.auth_sx.as_ref().unwrap().send(auth_data).unwrap(); + + // Receive client authentication data from application. + let client_auth_data = self.auth_rx.as_ref().unwrap().recv().unwrap(); + + let properties = AuthProperties{ + authentication_method: Some(props.authentication_method.unwrap().to_string()), + authentication_data: Some(client_auth_data.clone().into()), + reason_string: None, + user_properties: Vec::new(), + }; + + let client_auth = Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); + + self.outgoing_auth(client_auth) + } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet @@ -646,6 +679,14 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } + + fn outgoing_auth(&mut self, auth: Auth) -> Result, StateError> { + let props = auth.properties.as_ref().unwrap(); + debug!("Auth packet sent. Auth Method: {:?}. Auth Data: {:?}", props.authentication_method, props.authentication_data); + let event = Event::Outgoing(Outgoing::Auth); + self.events.push_back(event); + Ok(Some(Packet::Auth(auth))) + } fn check_collision(&mut self, pkid: u16) -> Option { if let Some(publish) = &self.collision { @@ -719,7 +760,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(u16::MAX, false) + MqttState::new(u16::MAX, false, None, None) } #[test] @@ -780,7 +821,7 @@ mod test { #[test] fn outgoing_publish_with_max_inflight_is_ok() { - let mut mqtt = MqttState::new(2, false); + let mut mqtt = MqttState::new(2, false, None, None); // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); From 0e5d53a69b324abe47e7a927ffc22f046159ed0c Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 17 Apr 2024 17:16:51 +0800 Subject: [PATCH 02/31] Test re-authentication. --- rumqttc/examples/auth.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/rumqttc/examples/auth.rs b/rumqttc/examples/auth.rs index 80bdf050b..c21477174 100644 --- a/rumqttc/examples/auth.rs +++ b/rumqttc/examples/auth.rs @@ -1,5 +1,5 @@ -use rumqttc::v5::mqttbytes::QoS; +use rumqttc::v5::mqttbytes::{QoS, v5::AuthReasonCode, v5::AuthProperties, v5::Auth}; use rumqttc::v5::{AsyncClient, MqttOptions}; use tokio::task; use std::error::Error; @@ -18,6 +18,9 @@ async fn main() -> Result<(), Box> { mqttoptions.set_connection_timeout(20); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + let scram2 = ScramClient::new("user1", "123456", None); + let (scram2, client_first2) = scram2.client_first(); + task::spawn(async move { let server_first: String = client.recv_server_auth_data().await.unwrap(); let scram = scram.handle_server_first(&server_first).unwrap(); @@ -26,6 +29,21 @@ async fn main() -> Result<(), Box> { client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); + + // Reauthentication + let props = AuthProperties { + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: Some(client_first2.clone().into()), + reason_string: None, + user_properties: vec![], + }; + + client.auth(AuthReasonCode::Reauthenticate, Some(props)).await.unwrap(); + + let server_first: String = client.recv_server_auth_data().await.unwrap(); + let scram2 = scram2.handle_server_first(&server_first).unwrap(); + let (scram2, client_final2) = scram2.client_final(); + client.send_client_auth_data(client_final2).await.unwrap(); }); loop { From c42c4e8e082af28f8efa5bc34c6a2817bda6d613 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Fri, 19 Apr 2024 11:35:32 +0800 Subject: [PATCH 03/31] Update auth user interface. --- rumqttc/examples/auth.rs | 71 ++++++++++++++++++++++--------------- rumqttc/src/v5/client.rs | 34 ++---------------- rumqttc/src/v5/eventloop.rs | 25 ++----------- rumqttc/src/v5/mod.rs | 19 ++++++++++ rumqttc/src/v5/state.rs | 34 ++++++++---------- 5 files changed, 80 insertions(+), 103 deletions(-) diff --git a/rumqttc/examples/auth.rs b/rumqttc/examples/auth.rs index c21477174..fa9e2eee3 100644 --- a/rumqttc/examples/auth.rs +++ b/rumqttc/examples/auth.rs @@ -1,49 +1,62 @@ use rumqttc::v5::mqttbytes::{QoS, v5::AuthReasonCode, v5::AuthProperties, v5::Auth}; -use rumqttc::v5::{AsyncClient, MqttOptions}; +use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait}; use tokio::task; use std::error::Error; -use std::thread; +use std::rc::Rc; +use std::cell::RefCell; use scram::ScramClient; +use scram::client::ServerFirst; -#[tokio::main()] +#[derive(Debug)] +struct AuthManager <'a>{ + scram_client: Option>, + scram_server: Option>, +} + +impl <'a> AuthManager <'a>{ + fn new(user: &'a str, password: &'a str) -> AuthManager <'a>{ + let scram = ScramClient::new(user, password, None); + + AuthManager{ + scram_client: Some(scram), + scram_server: None, + } + } + + fn auth_start(&mut self) -> Result{ + let scram = self.scram_client.take().unwrap(); + let (scram, client_first) = scram.client_first(); + self.scram_server = Some(scram); + + Ok(client_first) + } +} + +impl <'a> AuthManagerTrait for AuthManager<'a> { + fn auth_continue(&mut self, auth_data: String) -> Result { + let scram = self.scram_server.take().unwrap(); + let scram = scram.handle_server_first(&auth_data).unwrap(); + let (_, client_final) = scram.client_final(); + Ok(client_final) + } +} + +#[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - let scram = ScramClient::new("user1", "123456", None); - let (scram, client_first) = scram.client_first(); + let mut authmanager = AuthManager::new("user1", "123456"); + let client_first = authmanager.auth_start().unwrap(); let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); mqttoptions.set_authentication_data(Some(client_first.clone().into())); - mqttoptions.set_connection_timeout(20); + mqttoptions.set_auth_manager(Rc::new(RefCell::new(authmanager))); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); - let scram2 = ScramClient::new("user1", "123456", None); - let (scram2, client_first2) = scram2.client_first(); - task::spawn(async move { - let server_first: String = client.recv_server_auth_data().await.unwrap(); - let scram = scram.handle_server_first(&server_first).unwrap(); - let (scram, client_final) = scram.client_final(); - client.send_client_auth_data(client_final).await.unwrap(); - client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); - - // Reauthentication - let props = AuthProperties { - authentication_method: Some("SCRAM-SHA-256".to_string()), - authentication_data: Some(client_first2.clone().into()), - reason_string: None, - user_properties: vec![], - }; - - client.auth(AuthReasonCode::Reauthenticate, Some(props)).await.unwrap(); - - let server_first: String = client.recv_server_auth_data().await.unwrap(); - let scram2 = scram2.handle_server_first(&server_first).unwrap(); - let (scram2, client_final2) = scram2.client_final(); - client.send_client_auth_data(client_final2).await.unwrap(); }); loop { diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 768d746a2..048082886 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -47,8 +47,6 @@ impl From> for ClientError { #[derive(Clone, Debug)] pub struct AsyncClient { request_tx: Sender, - auth_tx: Option>, - auth_rx: Option>, } impl AsyncClient { @@ -58,45 +56,17 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_tx = eventloop.requests_tx.clone(); - let auth_tx = eventloop.auth_cdata_tx.clone(); - let auth_rx = eventloop.auth_sdata_rx.clone(); - - let client = AsyncClient { request_tx, auth_tx, auth_rx}; + let client = AsyncClient { request_tx}; (client, eventloop) } - pub async fn recv_server_auth_data(&self) -> Result { - if self.auth_rx.is_none() { - return Err(ClientError::Request(Request::Disconnect)); - } - let string = self.auth_rx.as_ref().unwrap().recv_async().await; - - if let Err(_) = string { - return Err(ClientError::Request(Request::Disconnect)); - } - - Ok(string.unwrap()) - } - - pub async fn send_client_auth_data(&self, data: String) -> Result<(), ClientError> { - if self.auth_tx.is_none() { - return Err(ClientError::Request(Request::Disconnect)); - } - - if let Err(_) = self.auth_tx.as_ref().unwrap().send_async(data).await { - return Err(ClientError::Request(Request::Disconnect)); - } - - Ok(()) - } - /// Create a new `AsyncClient` from a channel `Sender`. /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. pub fn from_senders(request_tx: Sender) -> AsyncClient { - AsyncClient { request_tx, auth_tx: None, auth_rx: None } + AsyncClient { request_tx } } /// Sends a MQTT Publish to the `EventLoop`. diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 9927b4762..2a109e0f5 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -84,9 +84,6 @@ pub struct EventLoop { network: Option, /// Keep alive time keepalive_timeout: Option>>, - - pub auth_sdata_rx: Option>, - pub auth_cdata_tx: Option>, } /// Events which can be yielded by the event loop @@ -106,33 +103,17 @@ impl EventLoop { let pending = VecDeque::new(); let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; - - // set state according to authentication method - let mut auth_sdata_tx = None; - let mut auth_sdata_rx = None; - let mut auth_cdata_rx = None; - let mut auth_cdata_tx = None; - - if options.authentication_method().is_some() { - let (auth_ctx, auth_crx) = bounded(1); - let (auth_stx, auth_srx) = bounded(1); - - auth_sdata_tx = Some(auth_stx); - auth_sdata_rx = Some(auth_srx); - auth_cdata_rx = Some(auth_crx); - auth_cdata_tx = Some(auth_ctx); - } + + let auth_manager = options.auth_manager(); EventLoop { options, - state: MqttState::new(inflight_limit, manual_acks, auth_cdata_rx, auth_sdata_tx), + state: MqttState::new(inflight_limit, manual_acks, auth_manager), requests_tx, requests_rx, pending, network: None, keepalive_timeout: None, - auth_sdata_rx, - auth_cdata_tx, } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index c5a06c501..7aa798dfc 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,6 +1,8 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; use std::time::Duration; +use std::rc::Rc; +use std::cell::RefCell; use flume::{Sender, Receiver}; #[cfg(feature = "websocket")] @@ -33,6 +35,11 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +pub trait AuthManagerTrait: std::fmt::Debug { + fn auth_continue(&mut self, auth_data: String) -> Result; +} + + /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Clone, Debug, PartialEq, Eq)] @@ -107,6 +114,8 @@ pub struct MqttOptions { outgoing_inflight_upper_limit: Option, #[cfg(feature = "websocket")] request_modifier: Option, + + auth_manager: Option>>, } impl MqttOptions { @@ -142,6 +151,7 @@ impl MqttOptions { outgoing_inflight_upper_limit: None, #[cfg(feature = "websocket")] request_modifier: None, + auth_manager: None, } } @@ -529,6 +539,15 @@ impl MqttOptions { pub fn get_outgoing_inflight_upper_limit(&self) -> Option { self.outgoing_inflight_upper_limit } + + pub fn set_auth_manager(&mut self, auth_manager: Rc>) -> &mut Self { + self.auth_manager = Some(auth_manager); + self + } + + pub fn auth_manager(&self) -> Option>> { + self.auth_manager.clone() + } } #[cfg(feature = "url")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8c12d16cc..58e88e0ac 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,15 +1,17 @@ use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, Auth, AuthProperties, AuthReasonCode + SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, Auth, AuthProperties, AuthReasonCode, }; use super::mqttbytes::{self, Error as MqttError, QoS}; -use super::{Event, Incoming, Outgoing, Request}; +use super::{Event, Incoming, Outgoing, Request, AuthManagerTrait}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; +use std::rc::Rc; +use std::cell::RefCell; use flume::{Receiver, Sender}; /// Errors during state handling @@ -122,15 +124,15 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, - auth_rx: Option>, - auth_sx: Option>, + /// Authentication manager + auth_manager: Option>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, auth_rx: Option>, auth_sx: Option>) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -151,8 +153,7 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, - auth_rx, - auth_sx, + auth_manager, } } @@ -488,21 +489,14 @@ impl MqttState { fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { let props = auth.properties.clone().unwrap(); - let auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); + let in_auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); - if self.auth_rx.is_none() || self.auth_sx.is_none() { - return Err(StateError::InvalidState); - } - - // Send server authentication data to application. - self.auth_sx.as_ref().unwrap().send(auth_data).unwrap(); - - // Receive client authentication data from application. - let client_auth_data = self.auth_rx.as_ref().unwrap().recv().unwrap(); + let auth_manager = self.auth_manager.clone().unwrap(); + let out_auth_data = auth_manager.borrow_mut().auth_continue(in_auth_data).unwrap(); let properties = AuthProperties{ authentication_method: Some(props.authentication_method.unwrap().to_string()), - authentication_data: Some(client_auth_data.clone().into()), + authentication_data: Some(out_auth_data.clone().into()), reason_string: None, user_properties: Vec::new(), }; @@ -760,7 +754,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(u16::MAX, false, None, None) + MqttState::new(u16::MAX, false, None) } #[test] @@ -821,7 +815,7 @@ mod test { #[test] fn outgoing_publish_with_max_inflight_is_ok() { - let mut mqtt = MqttState::new(2, false, None, None); + let mut mqtt = MqttState::new(2, false, None); // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); From 9b078bcc28946501ae00a0c11b2286a6acdce5c8 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Mon, 22 Apr 2024 13:24:51 +0800 Subject: [PATCH 04/31] Improve logic. --- rumqttc/src/v5/mod.rs | 2 +- rumqttc/src/v5/state.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 7aa798dfc..326ea6edb 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -546,7 +546,7 @@ impl MqttOptions { } pub fn auth_manager(&self) -> Option>> { - self.auth_manager.clone() + Some(Rc::clone(self.auth_manager.as_ref().unwrap())) } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 58e88e0ac..63a5d1a91 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -491,7 +491,7 @@ impl MqttState { let props = auth.properties.clone().unwrap(); let in_auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); - let auth_manager = self.auth_manager.clone().unwrap(); + let auth_manager = Rc::clone(self.auth_manager.as_ref().unwrap()); let out_auth_data = auth_manager.borrow_mut().auth_continue(in_auth_data).unwrap(); let properties = AuthProperties{ From 2d76573a103e4b13d4e63f7550607329c70cf32f Mon Sep 17 00:00:00 2001 From: tinzhu Date: Tue, 23 Apr 2024 16:54:37 +0800 Subject: [PATCH 05/31] Change auth_continue return error type to StateError. --- rumqttc/examples/auth.rs | 6 +++--- rumqttc/src/v5/mod.rs | 5 ++++- rumqttc/src/v5/state.rs | 6 +++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/rumqttc/examples/auth.rs b/rumqttc/examples/auth.rs index fa9e2eee3..46c621169 100644 --- a/rumqttc/examples/auth.rs +++ b/rumqttc/examples/auth.rs @@ -1,6 +1,6 @@ use rumqttc::v5::mqttbytes::{QoS, v5::AuthReasonCode, v5::AuthProperties, v5::Auth}; -use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait}; +use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait, StateError}; use tokio::task; use std::error::Error; use std::rc::Rc; @@ -24,7 +24,7 @@ impl <'a> AuthManager <'a>{ } } - fn auth_start(&mut self) -> Result{ + fn auth_start(&mut self) -> Result{ let scram = self.scram_client.take().unwrap(); let (scram, client_first) = scram.client_first(); self.scram_server = Some(scram); @@ -34,7 +34,7 @@ impl <'a> AuthManager <'a>{ } impl <'a> AuthManagerTrait for AuthManager<'a> { - fn auth_continue(&mut self, auth_data: String) -> Result { + fn auth_continue(&mut self, auth_data: String) -> Result { let scram = self.scram_server.take().unwrap(); let scram = scram.handle_server_first(&auth_data).unwrap(); let (_, client_final) = scram.client_final(); diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 326ea6edb..74ae0078f 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -36,7 +36,7 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub trait AuthManagerTrait: std::fmt::Debug { - fn auth_continue(&mut self, auth_data: String) -> Result; + fn auth_continue(&mut self, auth_data: String) -> Result; } @@ -546,6 +546,9 @@ impl MqttOptions { } pub fn auth_manager(&self) -> Option>> { + if self.auth_manager.is_none() { + return None; + } Some(Rc::clone(self.auth_manager.as_ref().unwrap())) } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 63a5d1a91..102e1fc16 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -492,7 +492,11 @@ impl MqttState { let in_auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); let auth_manager = Rc::clone(self.auth_manager.as_ref().unwrap()); - let out_auth_data = auth_manager.borrow_mut().auth_continue(in_auth_data).unwrap(); + + let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_data) { + Ok(data) => data, + Err(err) => return Err(err), + }; let properties = AuthProperties{ authentication_method: Some(props.authentication_method.unwrap().to_string()), From 95150f680aef3a136863c75e4a0a822c45092ce7 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Tue, 23 Apr 2024 17:39:35 +0800 Subject: [PATCH 06/31] Add length_calculation test. --- rumqttc/src/v5/mqttbytes/v5/auth.rs | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/rumqttc/src/v5/mqttbytes/v5/auth.rs b/rumqttc/src/v5/mqttbytes/v5/auth.rs index eb47fb2fd..8157905f3 100644 --- a/rumqttc/src/v5/mqttbytes/v5/auth.rs +++ b/rumqttc/src/v5/mqttbytes/v5/auth.rs @@ -212,3 +212,33 @@ fn code(value: AuthReasonCode) -> u8 { AuthReasonCode::Reauthenticate => 0x19 } } + +#[cfg(test)] +mod test { + use super::super::test::{USER_PROP_KEY, USER_PROP_VAL}; + use super::*; + use bytes::BytesMut; + use pretty_assertions::assert_eq; + + #[test] + fn length_calculation() { + let mut dummy_bytes = BytesMut::new(); + // Use user_properties to pad the size to exceed ~128 bytes to make the + // remaining_length field in the packet be 2 bytes long. + let auth_props = AuthProperties { + authentication_method: Some("Authentication Method".into()), + authentication_data: Some("Authentication Data".into()), + reason_string: None, + user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())], + }; + + let auth_pkt = Auth::new(AuthReasonCode::ContinueAuthentication, Some(auth_props)); + + let size_from_size = auth_pkt.size(); + let size_from_write = auth_pkt.write(&mut dummy_bytes).unwrap(); + let size_from_bytes = dummy_bytes.len(); + + assert_eq!(size_from_write, size_from_bytes); + assert_eq!(size_from_size, size_from_bytes); + } +} From 133159d5b863326dc7ab7d45efef03c957ce7e8b Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 10:03:39 +0800 Subject: [PATCH 07/31] Improve error handling logic. --- rumqttc/examples/{auth.rs => async_auth.rs} | 2 +- rumqttc/src/v5/mod.rs | 2 +- rumqttc/src/v5/state.rs | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) rename rumqttc/examples/{auth.rs => async_auth.rs} (99%) diff --git a/rumqttc/examples/auth.rs b/rumqttc/examples/async_auth.rs similarity index 99% rename from rumqttc/examples/auth.rs rename to rumqttc/examples/async_auth.rs index 46c621169..7899557b0 100644 --- a/rumqttc/examples/auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -34,7 +34,7 @@ impl <'a> AuthManager <'a>{ } impl <'a> AuthManagerTrait for AuthManager<'a> { - fn auth_continue(&mut self, auth_data: String) -> Result { + fn auth_continue(&mut self, auth_data: String) -> Result { let scram = self.scram_server.take().unwrap(); let scram = scram.handle_server_first(&auth_data).unwrap(); let (_, client_final) = scram.client_final(); diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 74ae0078f..6e9082ddf 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -36,7 +36,7 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub trait AuthManagerTrait: std::fmt::Debug { - fn auth_continue(&mut self, auth_data: String) -> Result; + fn auth_continue(&mut self, auth_data: String) -> Result; } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 102e1fc16..ffa102d6b 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -68,7 +68,11 @@ pub enum StateError { #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, #[error("Connection closed by peer abruptly")] - ConnectionAborted + ConnectionAborted, + #[error("Authentication error: {0}")] + AuthError(String), + #[error("Auth Manager not set")] + AuthManagerNotSet, } impl From for StateError { @@ -491,11 +495,16 @@ impl MqttState { let props = auth.properties.clone().unwrap(); let in_auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); + // Check if auth manager is set + if self.auth_manager.is_none() { + return Err(StateError::AuthManagerNotSet); + } + let auth_manager = Rc::clone(self.auth_manager.as_ref().unwrap()); let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_data) { Ok(data) => data, - Err(err) => return Err(err), + Err(err) => return Err(StateError::AuthError(err)), }; let properties = AuthProperties{ From 7425387d9e574435412383e971050d1a9ba964b4 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 11:52:00 +0800 Subject: [PATCH 08/31] Improve logic. --- rumqttc/examples/async_auth.rs | 19 +++++++++++++------ rumqttc/src/v5/mod.rs | 2 +- rumqttc/src/v5/state.rs | 10 ++++++---- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index 7899557b0..d167def8a 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -5,6 +5,7 @@ use tokio::task; use std::error::Error; use std::rc::Rc; use std::cell::RefCell; +use bytes::Bytes; use scram::ScramClient; use scram::client::ServerFirst; @@ -24,21 +25,27 @@ impl <'a> AuthManager <'a>{ } } - fn auth_start(&mut self) -> Result{ + fn auth_start(&mut self) -> Result, String>{ let scram = self.scram_client.take().unwrap(); let (scram, client_first) = scram.client_first(); self.scram_server = Some(scram); - Ok(client_first) + Ok(Some(client_first.into())) } } impl <'a> AuthManagerTrait for AuthManager<'a> { - fn auth_continue(&mut self, auth_data: String) -> Result { + fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { + + // Check if the authentication method is SCRAM-SHA-256 + if auth_method.unwrap() != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } + let scram = self.scram_server.take().unwrap(); - let scram = scram.handle_server_first(&auth_data).unwrap(); + let scram = scram.handle_server_first(&String::from_utf8(auth_data.unwrap().to_vec()).unwrap()).unwrap(); let (_, client_final) = scram.client_final(); - Ok(client_final) + Ok(Some(client_final.into())) } } @@ -50,7 +57,7 @@ async fn main() -> Result<(), Box> { let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); - mqttoptions.set_authentication_data(Some(client_first.clone().into())); + mqttoptions.set_authentication_data(client_first); mqttoptions.set_auth_manager(Rc::new(RefCell::new(authmanager))); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e9082ddf..40d7edad6 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -36,7 +36,7 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub trait AuthManagerTrait: std::fmt::Debug { - fn auth_continue(&mut self, auth_data: String) -> Result; + fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String>; } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index ffa102d6b..584711ec3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -493,7 +493,8 @@ impl MqttState { fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { let props = auth.properties.clone().unwrap(); - let in_auth_data = String::from_utf8(props.authentication_data.unwrap().to_vec()).unwrap(); + let in_auth_method = props.authentication_method; + let in_auth_data = props.authentication_data; // Check if auth manager is set if self.auth_manager.is_none() { @@ -502,14 +503,15 @@ impl MqttState { let auth_manager = Rc::clone(self.auth_manager.as_ref().unwrap()); - let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_data) { + // Call auth_continue method of auth manager + let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_method.clone(), in_auth_data) { Ok(data) => data, Err(err) => return Err(StateError::AuthError(err)), }; let properties = AuthProperties{ - authentication_method: Some(props.authentication_method.unwrap().to_string()), - authentication_data: Some(out_auth_data.clone().into()), + authentication_method: in_auth_method, + authentication_data: out_auth_data, reason_string: None, user_properties: Vec::new(), }; From fee30d6648d677ab04abd3443e56d3643d624c20 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 13:47:13 +0800 Subject: [PATCH 09/31] Change auth to reauth. --- rumqttc/examples/async_auth.rs | 42 +++++++++++++++++++++++----------- rumqttc/src/v5/client.rs | 8 +++---- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index d167def8a..ac8ac9f40 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -1,6 +1,6 @@ -use rumqttc::v5::mqttbytes::{QoS, v5::AuthReasonCode, v5::AuthProperties, v5::Auth}; -use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait, StateError}; +use rumqttc::v5::mqttbytes::{QoS, v5::AuthProperties}; +use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait}; use tokio::task; use std::error::Error; use std::rc::Rc; @@ -11,30 +11,32 @@ use scram::client::ServerFirst; #[derive(Debug)] struct AuthManager <'a>{ - scram_client: Option>, - scram_server: Option>, + user: &'a str, + password: &'a str, + scram: Option>, } impl <'a> AuthManager <'a>{ fn new(user: &'a str, password: &'a str) -> AuthManager <'a>{ - let scram = ScramClient::new(user, password, None); - AuthManager{ - scram_client: Some(scram), - scram_server: None, + user, + password, + scram: None, } } - + fn auth_start(&mut self) -> Result, String>{ - let scram = self.scram_client.take().unwrap(); + let scram = ScramClient::new(self.user, self.password, None); let (scram, client_first) = scram.client_first(); - self.scram_server = Some(scram); + self.scram = Some(scram); Ok(Some(client_first.into())) } } impl <'a> AuthManagerTrait for AuthManager<'a> { + + fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { // Check if the authentication method is SCRAM-SHA-256 @@ -42,9 +44,23 @@ impl <'a> AuthManagerTrait for AuthManager<'a> { return Err("Invalid authentication method".to_string()); } - let scram = self.scram_server.take().unwrap(); - let scram = scram.handle_server_first(&String::from_utf8(auth_data.unwrap().to_vec()).unwrap()).unwrap(); + if self.scram.is_none() { + return Err("Invalid state".to_string()); + } + + let scram = self.scram.take().unwrap(); + + let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + + // Process the server first message and reassign the SCRAM state. + let scram = match(scram.handle_server_first(&auth_data)){ + Ok(scram) => scram, + Err(e) => return Err(e.to_string()), + }; + + // Get the client final message and reassign the SCRAM state. let (_, client_final) = scram.client_final(); + Ok(Some(client_final.into())) } } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 048082886..0f255655b 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -196,16 +196,16 @@ impl AsyncClient { } /// Sends a MQTT AUTH to `EventLoop` for authentication. - pub async fn auth(&self, reason: AuthReasonCode, properties: Option) -> Result<(), ClientError>{ - let auth = Auth::new(reason, properties); + pub async fn reauth(&self, properties: Option) -> Result<(), ClientError>{ + let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); let auth = Request::Auth(auth); self.request_tx.send_async(auth).await?; Ok(()) } /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. - pub fn try_auth(&self, reason: AuthReasonCode, properties: Option) -> Result<(), ClientError>{ - let auth = Auth::new(reason, properties); + pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError>{ + let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); let auth = Request::Auth(auth); self.request_tx.try_send(auth)?; Ok(()) From 9fdf82c08a76048c9298cd996e8ac6b682acee05 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 16:42:46 +0800 Subject: [PATCH 10/31] Make auth_manager thread safe. --- rumqttc/examples/async_auth.rs | 4 ++-- rumqttc/src/v5/mod.rs | 14 +++++++------- rumqttc/src/v5/state.rs | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index ac8ac9f40..a332b7a8f 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -3,7 +3,7 @@ use rumqttc::v5::mqttbytes::{QoS, v5::AuthProperties}; use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait}; use tokio::task; use std::error::Error; -use std::rc::Rc; +use std::sync::Arc; use std::cell::RefCell; use bytes::Bytes; use scram::ScramClient; @@ -74,7 +74,7 @@ async fn main() -> Result<(), Box> { let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); mqttoptions.set_authentication_data(client_first); - mqttoptions.set_auth_manager(Rc::new(RefCell::new(authmanager))); + mqttoptions.set_auth_manager(Arc::new(RefCell::new(authmanager))); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); task::spawn(async move { diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 40d7edad6..5ca84aad6 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,15 +1,14 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; use std::time::Duration; -use std::rc::Rc; +use std::sync::Arc; use std::cell::RefCell; use flume::{Sender, Receiver}; #[cfg(feature = "websocket")] use std::{ future::{Future, IntoFuture}, - pin::Pin, - sync::Arc + pin::Pin }; mod client; @@ -115,7 +114,7 @@ pub struct MqttOptions { #[cfg(feature = "websocket")] request_modifier: Option, - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttOptions { @@ -540,16 +539,17 @@ impl MqttOptions { self.outgoing_inflight_upper_limit } - pub fn set_auth_manager(&mut self, auth_manager: Rc>) -> &mut Self { + pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { self.auth_manager = Some(auth_manager); self } - pub fn auth_manager(&self) -> Option>> { + pub fn auth_manager(&self) -> Option>> { if self.auth_manager.is_none() { return None; } - Some(Rc::clone(self.auth_manager.as_ref().unwrap())) + + self.auth_manager.clone() } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 584711ec3..51ae79fb8 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -10,7 +10,7 @@ use super::{Event, Incoming, Outgoing, Request, AuthManagerTrait}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; -use std::rc::Rc; +use std::sync::Arc; use std::cell::RefCell; use flume::{Receiver, Sender}; @@ -129,14 +129,14 @@ pub struct MqttState { /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, /// Authentication manager - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -501,7 +501,7 @@ impl MqttState { return Err(StateError::AuthManagerNotSet); } - let auth_manager = Rc::clone(self.auth_manager.as_ref().unwrap()); + let auth_manager = Arc::clone(self.auth_manager.as_ref().unwrap()); // Call auth_continue method of auth manager let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_method.clone(), in_auth_data) { From 83b4be2dbde9c61366b459f3b108cd736e11674c Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 17:30:33 +0800 Subject: [PATCH 11/31] Improve logic. --- rumqttc/examples/async_auth.rs | 31 ++++++++++++++---- rumqttc/src/v5/mod.rs | 1 + rumqttc/src/v5/state.rs | 58 +++++++++++++++++++--------------- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index a332b7a8f..f11f8c81f 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -24,7 +24,10 @@ impl <'a> AuthManager <'a>{ scram: None, } } - +} + +impl <'a> AuthManagerTrait for AuthManager<'a> { + fn auth_start(&mut self) -> Result, String>{ let scram = ScramClient::new(self.user, self.password, None); let (scram, client_first) = scram.client_first(); @@ -32,10 +35,6 @@ impl <'a> AuthManager <'a>{ Ok(Some(client_first.into())) } -} - -impl <'a> AuthManagerTrait for AuthManager<'a> { - fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { @@ -70,12 +69,14 @@ async fn main() -> Result<(), Box> { let mut authmanager = AuthManager::new("user1", "123456"); let client_first = authmanager.auth_start().unwrap(); + let authmanager = Arc::new(RefCell::new(authmanager)); let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); mqttoptions.set_authentication_data(client_first); - mqttoptions.set_auth_manager(Arc::new(RefCell::new(authmanager))); + mqttoptions.set_auth_manager(authmanager.clone()); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + let client2 = client.clone(); task::spawn(async move { client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); @@ -86,7 +87,23 @@ async fn main() -> Result<(), Box> { let notification = eventloop.poll().await; match notification { - Ok(event) => println!("{:?}", event), + Ok(event) => { + println!("{:?}", event); + match(event){ + rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { + // Test re-authentication. + let client_first = authmanager.clone().borrow_mut().auth_start().unwrap(); + let properties = AuthProperties{ + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: client_first, + reason_string: None, + user_properties: Vec::new(), + }; + client2.reauth(Some(properties)).await.unwrap(); + } + _ => {}, + } + } Err(e) => { println!("Error = {:?}", e); break; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 5ca84aad6..f57d2ed2b 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -35,6 +35,7 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub trait AuthManagerTrait: std::fmt::Debug { + fn auth_start(&mut self) -> Result, String>; fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String>; } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 51ae79fb8..8776dd887 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -492,33 +492,39 @@ impl MqttState { } fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { - let props = auth.properties.clone().unwrap(); - let in_auth_method = props.authentication_method; - let in_auth_data = props.authentication_data; - - // Check if auth manager is set - if self.auth_manager.is_none() { - return Err(StateError::AuthManagerNotSet); + match auth.reason { + AuthReasonCode::Success => Ok(None), + AuthReasonCode::ContinueAuthentication => { + let props = auth.properties.clone().unwrap(); + let in_auth_method = props.authentication_method; + let in_auth_data = props.authentication_data; + + // Check if auth manager is set + if self.auth_manager.is_none() { + return Err(StateError::AuthManagerNotSet); + } + + let auth_manager = self.auth_manager.clone().unwrap(); + + // Call auth_continue method of auth manager + let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_method.clone(), in_auth_data) { + Ok(data) => data, + Err(err) => return Err(StateError::AuthError(err)), + }; + + let properties = AuthProperties{ + authentication_method: in_auth_method, + authentication_data: out_auth_data, + reason_string: None, + user_properties: Vec::new(), + }; + + let client_auth = Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); + + self.outgoing_auth(client_auth) + } + _ => return Err(StateError::AuthError("Authentication Failed!".to_string())), } - - let auth_manager = Arc::clone(self.auth_manager.as_ref().unwrap()); - - // Call auth_continue method of auth manager - let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_method.clone(), in_auth_data) { - Ok(data) => data, - Err(err) => return Err(StateError::AuthError(err)), - }; - - let properties = AuthProperties{ - authentication_method: in_auth_method, - authentication_data: out_auth_data, - reason_string: None, - user_properties: Vec::new(), - }; - - let client_auth = Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); - - self.outgoing_auth(client_auth) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns From 4e07ecb1c406d557538d11b156ef77367d4e11b7 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 24 Apr 2024 17:32:53 +0800 Subject: [PATCH 12/31] Improve logic. --- rumqttc/examples/async_auth.rs | 5 ++--- rumqttc/src/v5/mod.rs | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index f11f8c81f..d43fc4f0f 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -24,9 +24,6 @@ impl <'a> AuthManager <'a>{ scram: None, } } -} - -impl <'a> AuthManagerTrait for AuthManager<'a> { fn auth_start(&mut self) -> Result, String>{ let scram = ScramClient::new(self.user, self.password, None); @@ -35,7 +32,9 @@ impl <'a> AuthManagerTrait for AuthManager<'a> { Ok(Some(client_first.into())) } +} +impl <'a> AuthManagerTrait for AuthManager<'a> { fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { // Check if the authentication method is SCRAM-SHA-256 diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index f57d2ed2b..5ca84aad6 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -35,7 +35,6 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; pub trait AuthManagerTrait: std::fmt::Debug { - fn auth_start(&mut self) -> Result, String>; fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String>; } From cb65cb900ca742a3d88e1b254c526351f3a8841e Mon Sep 17 00:00:00 2001 From: tinzhu Date: Thu, 25 Apr 2024 14:11:55 +0800 Subject: [PATCH 13/31] Rename trait name "AuthManagerTrait" to "AuthManager". --- rumqttc/examples/async_auth.rs | 14 +++++++------- rumqttc/src/v5/mod.rs | 21 ++++++++++++++++----- rumqttc/src/v5/state.rs | 6 +++--- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index d43fc4f0f..69933fbf1 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -1,6 +1,6 @@ use rumqttc::v5::mqttbytes::{QoS, v5::AuthProperties}; -use rumqttc::v5::{AsyncClient, MqttOptions, AuthManagerTrait}; +use rumqttc::v5::{AsyncClient, MqttOptions, AuthManager}; use tokio::task; use std::error::Error; use std::sync::Arc; @@ -10,15 +10,15 @@ use scram::ScramClient; use scram::client::ServerFirst; #[derive(Debug)] -struct AuthManager <'a>{ +struct ScramAuthManager <'a>{ user: &'a str, password: &'a str, scram: Option>, } -impl <'a> AuthManager <'a>{ - fn new(user: &'a str, password: &'a str) -> AuthManager <'a>{ - AuthManager{ +impl <'a> ScramAuthManager <'a>{ + fn new(user: &'a str, password: &'a str) -> ScramAuthManager <'a>{ + ScramAuthManager{ user, password, scram: None, @@ -34,7 +34,7 @@ impl <'a> AuthManager <'a>{ } } -impl <'a> AuthManagerTrait for AuthManager<'a> { +impl <'a> AuthManager for ScramAuthManager<'a> { fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { // Check if the authentication method is SCRAM-SHA-256 @@ -66,7 +66,7 @@ impl <'a> AuthManagerTrait for AuthManager<'a> { #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - let mut authmanager = AuthManager::new("user1", "123456"); + let mut authmanager = ScramAuthManager::new("user1", "123456"); let client_first = authmanager.auth_start().unwrap(); let authmanager = Arc::new(RefCell::new(authmanager)); diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 5ca84aad6..e6e48addb 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -3,7 +3,6 @@ use std::fmt::{self, Debug, Formatter}; use std::time::Duration; use std::sync::Arc; use std::cell::RefCell; -use flume::{Sender, Receiver}; #[cfg(feature = "websocket")] use std::{ @@ -34,7 +33,19 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; -pub trait AuthManagerTrait: std::fmt::Debug { +pub trait AuthManager: std::fmt::Debug { + /// Process authentication data received from the server and generate authentication data to be sent back. + /// + /// # Arguments + /// + /// * `auth_method` - The authentication method received from the server. + /// * `auth_data` - The authentication data received from the server. + /// + /// # Returns + /// + /// * `Ok(auth_data)` - The authentication data to be sent back to the server. + /// * `Err(error_message)` - An error indicating that the authentication process has failed or terminated. + fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String>; } @@ -114,7 +125,7 @@ pub struct MqttOptions { #[cfg(feature = "websocket")] request_modifier: Option, - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttOptions { @@ -539,12 +550,12 @@ impl MqttOptions { self.outgoing_inflight_upper_limit } - pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { + pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { self.auth_manager = Some(auth_manager); self } - pub fn auth_manager(&self) -> Option>> { + pub fn auth_manager(&self) -> Option>> { if self.auth_manager.is_none() { return None; } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8776dd887..5f15aec72 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -5,7 +5,7 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::{self, Error as MqttError, QoS}; -use super::{Event, Incoming, Outgoing, Request, AuthManagerTrait}; +use super::{Event, Incoming, Outgoing, Request, AuthManager}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; @@ -129,14 +129,14 @@ pub struct MqttState { /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, /// Authentication manager - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, From f62427640f09a71389aebe9809ff488feaa069a4 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Fri, 26 Apr 2024 14:50:52 +0800 Subject: [PATCH 14/31] Use Mutex instead RefCell. --- rumqttc/examples/async_auth.rs | 25 ++++++++++++++----------- rumqttc/src/v5/mod.rs | 9 ++++----- rumqttc/src/v5/state.rs | 9 ++++----- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index 69933fbf1..7a0aeb591 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -3,8 +3,7 @@ use rumqttc::v5::mqttbytes::{QoS, v5::AuthProperties}; use rumqttc::v5::{AsyncClient, MqttOptions, AuthManager}; use tokio::task; use std::error::Error; -use std::sync::Arc; -use std::cell::RefCell; +use std::sync::{Arc, Mutex}; use bytes::Bytes; use scram::ScramClient; use scram::client::ServerFirst; @@ -68,7 +67,7 @@ async fn main() -> Result<(), Box> { let mut authmanager = ScramAuthManager::new("user1", "123456"); let client_first = authmanager.auth_start().unwrap(); - let authmanager = Arc::new(RefCell::new(authmanager)); + let authmanager = Arc::new(Mutex::new(authmanager)); let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); @@ -80,6 +79,18 @@ async fn main() -> Result<(), Box> { task::spawn(async move { client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); + + // Sleep for 5 seconds for the connection to be established. + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + + let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); + let properties = AuthProperties{ + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: client_first, + reason_string: None, + user_properties: Vec::new(), + }; + client.reauth(Some(properties)).await.unwrap(); }); loop { @@ -91,14 +102,6 @@ async fn main() -> Result<(), Box> { match(event){ rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { // Test re-authentication. - let client_first = authmanager.clone().borrow_mut().auth_start().unwrap(); - let properties = AuthProperties{ - authentication_method: Some("SCRAM-SHA-256".to_string()), - authentication_data: client_first, - reason_string: None, - user_properties: Vec::new(), - }; - client2.reauth(Some(properties)).await.unwrap(); } _ => {}, } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index e6e48addb..40ee63802 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,8 +1,7 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; use std::time::Duration; -use std::sync::Arc; -use std::cell::RefCell; +use std::sync::{Arc, Mutex}; #[cfg(feature = "websocket")] use std::{ @@ -125,7 +124,7 @@ pub struct MqttOptions { #[cfg(feature = "websocket")] request_modifier: Option, - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttOptions { @@ -550,12 +549,12 @@ impl MqttOptions { self.outgoing_inflight_upper_limit } - pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { + pub fn set_auth_manager(&mut self, auth_manager: Arc>) -> &mut Self { self.auth_manager = Some(auth_manager); self } - pub fn auth_manager(&self) -> Option>> { + pub fn auth_manager(&self) -> Option>> { if self.auth_manager.is_none() { return None; } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 5f15aec72..1d25181df 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -10,8 +10,7 @@ use super::{Event, Incoming, Outgoing, Request, AuthManager}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; -use std::sync::Arc; -use std::cell::RefCell; +use std::sync::{Arc, Mutex}; use flume::{Receiver, Sender}; /// Errors during state handling @@ -129,14 +128,14 @@ pub struct MqttState { /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, /// Authentication manager - auth_manager: Option>>, + auth_manager: Option>>, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -507,7 +506,7 @@ impl MqttState { let auth_manager = self.auth_manager.clone().unwrap(); // Call auth_continue method of auth manager - let out_auth_data = match auth_manager.borrow_mut().auth_continue(in_auth_method.clone(), in_auth_data) { + let out_auth_data = match auth_manager.lock().unwrap().auth_continue(in_auth_method.clone(), in_auth_data) { Ok(data) => data, Err(err) => return Err(StateError::AuthError(err)), }; From 213702b0f235375a90a11857655365c2ff3abf64 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Fri, 26 Apr 2024 15:24:26 +0800 Subject: [PATCH 15/31] Improve logic. --- rumqttc/examples/async_auth.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index 7a0aeb591..0922fde0d 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use bytes::Bytes; use scram::ScramClient; use scram::client::ServerFirst; +use flume::bounded; #[derive(Debug)] struct ScramAuthManager <'a>{ @@ -74,14 +75,15 @@ async fn main() -> Result<(), Box> { mqttoptions.set_authentication_data(client_first); mqttoptions.set_auth_manager(authmanager.clone()); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); - let client2 = client.clone(); + + let (tx, rx) = bounded(1); task::spawn(async move { client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); - // Sleep for 5 seconds for the connection to be established. - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + // Wait for the connection to be established. + rx.recv_async().await.unwrap(); let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); let properties = AuthProperties{ @@ -101,7 +103,7 @@ async fn main() -> Result<(), Box> { println!("{:?}", event); match(event){ rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { - // Test re-authentication. + tx.send_async("Connected").await.unwrap(); } _ => {}, } From 7d372910b8542adc237e28ae69ebe34d6be6afff Mon Sep 17 00:00:00 2001 From: tinzhu Date: Sun, 28 Apr 2024 14:06:30 +0800 Subject: [PATCH 16/31] Improve logic. --- rumqttc/examples/async_auth.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index 0922fde0d..f27772847 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -85,6 +85,7 @@ async fn main() -> Result<(), Box> { // Wait for the connection to be established. rx.recv_async().await.unwrap(); + // Reauthenticate using SCRAM-SHA-256 let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); let properties = AuthProperties{ authentication_method: Some("SCRAM-SHA-256".to_string()), @@ -100,7 +101,7 @@ async fn main() -> Result<(), Box> { match notification { Ok(event) => { - println!("{:?}", event); + println!("Event = {:?}", event); match(event){ rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { tx.send_async("Connected").await.unwrap(); From 7cfc99b24853ce651f4f4ed8fd291fe62ca918ce Mon Sep 17 00:00:00 2001 From: tinzhu Date: Sun, 28 Apr 2024 14:57:49 +0800 Subject: [PATCH 17/31] Clean code. --- rumqttc/Cargo.toml | 1 - rumqttc/examples/async_auth.rs | 48 ++++++++++++++++++---------------- rumqttc/src/v5/client.rs | 4 +-- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index fa57640db..bba64822c 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -30,7 +30,6 @@ bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } thiserror = "1" -scram = "0.6.0" # Optional # rustls diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index f27772847..43b24cb2d 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -5,15 +5,15 @@ use tokio::task; use std::error::Error; use std::sync::{Arc, Mutex}; use bytes::Bytes; -use scram::ScramClient; -use scram::client::ServerFirst; +//use scram::ScramClient; +//use scram::client::ServerFirst; use flume::bounded; #[derive(Debug)] struct ScramAuthManager <'a>{ user: &'a str, password: &'a str, - scram: Option>, + //scram: Option>, } impl <'a> ScramAuthManager <'a>{ @@ -21,16 +21,18 @@ impl <'a> ScramAuthManager <'a>{ ScramAuthManager{ user, password, - scram: None, + //scram: None, } } fn auth_start(&mut self) -> Result, String>{ - let scram = ScramClient::new(self.user, self.password, None); - let (scram, client_first) = scram.client_first(); - self.scram = Some(scram); + //let scram = ScramClient::new(self.user, self.password, None); + //let (scram, client_first) = scram.client_first(); + //self.scram = Some(scram); - Ok(Some(client_first.into())) + //Ok(Some(client_first.into())) + + Ok(Some("client first message".into())) } } @@ -38,28 +40,30 @@ impl <'a> AuthManager for ScramAuthManager<'a> { fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { // Check if the authentication method is SCRAM-SHA-256 - if auth_method.unwrap() != "SCRAM-SHA-256" { - return Err("Invalid authentication method".to_string()); - } + //if auth_method.unwrap() != "SCRAM-SHA-256" { + // return Err("Invalid authentication method".to_string()); + //} - if self.scram.is_none() { - return Err("Invalid state".to_string()); - } + //if self.scram.is_none() { + // return Err("Invalid state".to_string()); + //} - let scram = self.scram.take().unwrap(); + //let scram = self.scram.take().unwrap(); - let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + //let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); // Process the server first message and reassign the SCRAM state. - let scram = match(scram.handle_server_first(&auth_data)){ - Ok(scram) => scram, - Err(e) => return Err(e.to_string()), - }; + //let scram = match(scram.handle_server_first(&auth_data)){ + // Ok(scram) => scram, + // Err(e) => return Err(e.to_string()), + //}; // Get the client final message and reassign the SCRAM state. - let (_, client_final) = scram.client_final(); + //let (_, client_final) = scram.client_final(); + + //Ok(Some(client_final.into())) - Ok(Some(client_final.into())) + Ok(Some("client final message".into())) } } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 0f255655b..49d61d36f 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -11,7 +11,7 @@ use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; use bytes::Bytes; -use flume::{SendError, Sender, Receiver, TrySendError}; +use flume::{SendError, Sender, TrySendError}; use futures_util::FutureExt; use tokio::runtime::{self, Runtime}; use tokio::time::timeout; @@ -56,7 +56,7 @@ impl AsyncClient { pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { let eventloop = EventLoop::new(options, cap); let request_tx = eventloop.requests_tx.clone(); - let client = AsyncClient { request_tx}; + let client = AsyncClient { request_tx }; (client, eventloop) } From 3e3b78dc8caac3a355f85420c88238138d647619 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Sun, 28 Apr 2024 15:16:11 +0800 Subject: [PATCH 18/31] Formatted with cargo fmt. --- rumqttc/examples/async_auth.rs | 51 ++++++++++++++++------------- rumqttc/src/v5/client.rs | 8 ++--- rumqttc/src/v5/eventloop.rs | 12 ++++--- rumqttc/src/v5/mod.rs | 11 ++++--- rumqttc/src/v5/mqttbytes/v5/auth.rs | 15 +++------ rumqttc/src/v5/mqttbytes/v5/mod.rs | 4 +-- rumqttc/src/v5/state.rs | 50 +++++++++++++++++----------- 7 files changed, 85 insertions(+), 66 deletions(-) diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index 43b24cb2d..da1a40530 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -1,31 +1,30 @@ - -use rumqttc::v5::mqttbytes::{QoS, v5::AuthProperties}; -use rumqttc::v5::{AsyncClient, MqttOptions, AuthManager}; -use tokio::task; +use bytes::Bytes; +use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; +use rumqttc::v5::{AsyncClient, AuthManager, MqttOptions}; use std::error::Error; use std::sync::{Arc, Mutex}; -use bytes::Bytes; +use tokio::task; //use scram::ScramClient; //use scram::client::ServerFirst; use flume::bounded; #[derive(Debug)] -struct ScramAuthManager <'a>{ +struct ScramAuthManager<'a> { user: &'a str, password: &'a str, //scram: Option>, } -impl <'a> ScramAuthManager <'a>{ - fn new(user: &'a str, password: &'a str) -> ScramAuthManager <'a>{ - ScramAuthManager{ +impl<'a> ScramAuthManager<'a> { + fn new(user: &'a str, password: &'a str) -> ScramAuthManager<'a> { + ScramAuthManager { user, password, //scram: None, } } - fn auth_start(&mut self) -> Result, String>{ + fn auth_start(&mut self) -> Result, String> { //let scram = ScramClient::new(self.user, self.password, None); //let (scram, client_first) = scram.client_first(); //self.scram = Some(scram); @@ -36,9 +35,12 @@ impl <'a> ScramAuthManager <'a>{ } } -impl <'a> AuthManager for ScramAuthManager<'a> { - fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String> { - +impl<'a> AuthManager for ScramAuthManager<'a> { + fn auth_continue( + &mut self, + auth_method: Option, + auth_data: Option, + ) -> Result, String> { // Check if the authentication method is SCRAM-SHA-256 //if auth_method.unwrap() != "SCRAM-SHA-256" { // return Err("Invalid authentication method".to_string()); @@ -69,7 +71,6 @@ impl <'a> AuthManager for ScramAuthManager<'a> { #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - let mut authmanager = ScramAuthManager::new("user1", "123456"); let client_first = authmanager.auth_start().unwrap(); let authmanager = Arc::new(Mutex::new(authmanager)); @@ -79,19 +80,25 @@ async fn main() -> Result<(), Box> { mqttoptions.set_authentication_data(client_first); mqttoptions.set_auth_manager(authmanager.clone()); let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); - + let (tx, rx) = bounded(1); task::spawn(async move { - client.subscribe("rumqtt_auth/topic", QoS::AtLeastOnce).await.unwrap(); - client.publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world").await.unwrap(); + client + .subscribe("rumqtt_auth/topic", QoS::AtLeastOnce) + .await + .unwrap(); + client + .publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world") + .await + .unwrap(); // Wait for the connection to be established. rx.recv_async().await.unwrap(); // Reauthenticate using SCRAM-SHA-256 let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); - let properties = AuthProperties{ + let properties = AuthProperties { authentication_method: Some("SCRAM-SHA-256".to_string()), authentication_data: client_first, reason_string: None, @@ -106,11 +113,11 @@ async fn main() -> Result<(), Box> { match notification { Ok(event) => { println!("Event = {:?}", event); - match(event){ + match (event) { rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { tx.send_async("Connected").await.unwrap(); } - _ => {}, + _ => {} } } Err(e) => { @@ -119,6 +126,6 @@ async fn main() -> Result<(), Box> { } } } - + Ok(()) -} \ No newline at end of file +} diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 49d61d36f..4da5225f4 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -3,8 +3,8 @@ use std::time::Duration; use super::mqttbytes::v5::{ - Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, Auth, AuthProperties, AuthReasonCode, - Unsubscribe, UnsubscribeProperties, + Auth, AuthProperties, AuthReasonCode, Filter, PubAck, PubRec, Publish, PublishProperties, + Subscribe, SubscribeProperties, Unsubscribe, UnsubscribeProperties, }; use super::mqttbytes::{valid_filter, QoS}; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; @@ -196,7 +196,7 @@ impl AsyncClient { } /// Sends a MQTT AUTH to `EventLoop` for authentication. - pub async fn reauth(&self, properties: Option) -> Result<(), ClientError>{ + pub async fn reauth(&self, properties: Option) -> Result<(), ClientError> { let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); let auth = Request::Auth(auth); self.request_tx.send_async(auth).await?; @@ -204,7 +204,7 @@ impl AsyncClient { } /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. - pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError>{ + pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError> { let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); let auth = Request::Auth(auth); self.request_tx.try_send(auth)?; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 2a109e0f5..c3163bc45 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -267,7 +267,10 @@ impl EventLoop { /// the stream. /// This function (for convenience) includes internal delays for users to perform internal sleeps /// between re-connections so that cancel semantics can be used during this sleep -async fn connect(options: &mut MqttOptions, state: &mut MqttState) -> Result<(Network, Incoming), ConnectionError> { +async fn connect( + options: &mut MqttOptions, + state: &mut MqttState, +) -> Result<(Network, Incoming), ConnectionError> { // connect to the broker let mut network = network_connect(options).await?; @@ -420,13 +423,14 @@ async fn mqtt_connect( } return Ok(Packet::ConnAck(connack)); } - Incoming::ConnAck(connack) => return Err(ConnectionError::ConnectionRefused(connack.code)), + Incoming::ConnAck(connack) => { + return Err(ConnectionError::ConnectionRefused(connack.code)) + } Incoming::Auth(auth) => { if let Some(outgoing) = state.handle_incoming_packet(Incoming::Auth(auth))? { network.write(outgoing).await?; network.flush().await?; - } - else { + } else { return Err(ConnectionError::AuthProcessingError); } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 40ee63802..e2a5c1cb3 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -1,12 +1,12 @@ use bytes::Bytes; use std::fmt::{self, Debug, Formatter}; -use std::time::Duration; use std::sync::{Arc, Mutex}; +use std::time::Duration; #[cfg(feature = "websocket")] use std::{ future::{Future, IntoFuture}, - pin::Pin + pin::Pin, }; mod client; @@ -45,10 +45,13 @@ pub trait AuthManager: std::fmt::Debug { /// * `Ok(auth_data)` - The authentication data to be sent back to the server. /// * `Err(error_message)` - An error indicating that the authentication process has failed or terminated. - fn auth_continue(&mut self, auth_method: Option, auth_data: Option) -> Result, String>; + fn auth_continue( + &mut self, + auth_method: Option, + auth_data: Option, + ) -> Result, String>; } - /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/rumqttc/src/v5/mqttbytes/v5/auth.rs b/rumqttc/src/v5/mqttbytes/v5/auth.rs index 8157905f3..07b668528 100644 --- a/rumqttc/src/v5/mqttbytes/v5/auth.rs +++ b/rumqttc/src/v5/mqttbytes/v5/auth.rs @@ -10,11 +10,7 @@ pub struct Auth { impl Auth { pub fn new(reason: AuthReasonCode, properties: Option) -> Self { - - Self { - reason, - properties, - } + Self { reason, properties } } pub fn size(&self) -> usize { @@ -45,10 +41,7 @@ impl Auth { let code = read_u8(&mut bytes)?; let reason = reason(code)?; let properties = AuthProperties::read(&mut bytes)?; - let auth = Auth { - reason, - properties, - }; + let auth = Auth { reason, properties }; Ok(auth) } @@ -75,7 +68,7 @@ impl Auth { pub enum AuthReasonCode { Success, ContinueAuthentication, - Reauthenticate + Reauthenticate, } #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -209,7 +202,7 @@ fn code(value: AuthReasonCode) -> u8 { match value { AuthReasonCode::Success => 0x00, AuthReasonCode::ContinueAuthentication => 0x18, - AuthReasonCode::Reauthenticate => 0x19 + AuthReasonCode::Reauthenticate => 0x19, } } diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index ea433aa51..22f11547d 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,6 +1,7 @@ use std::slice::Iter; pub use self::{ + auth::{Auth, AuthProperties, AuthReasonCode}, codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, @@ -15,12 +16,12 @@ pub use self::{ subscribe::{Filter, RetainForwardRule, Subscribe, SubscribeProperties}, unsuback::{UnsubAck, UnsubAckProperties, UnsubAckReason}, unsubscribe::{Unsubscribe, UnsubscribeProperties}, - auth::{Auth, AuthProperties, AuthReasonCode}, }; use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; +mod auth; mod codec; mod connack; mod connect; @@ -35,7 +36,6 @@ mod suback; mod subscribe; mod unsuback; mod unsubscribe; -mod auth; #[derive(Clone, Debug, PartialEq, Eq)] pub enum Packet { diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 1d25181df..528905189 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,17 +1,17 @@ use super::mqttbytes::v5::{ - ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, - PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, Auth, AuthProperties, AuthReasonCode, + Auth, AuthProperties, AuthReasonCode, ConnAck, ConnectReturnCode, Disconnect, + DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, + PubRecReason, PubRel, PubRelReason, Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, + UnsubAckReason, Unsubscribe, }; use super::mqttbytes::{self, Error as MqttError, QoS}; -use super::{Event, Incoming, Outgoing, Request, AuthManager}; +use super::{AuthManager, Event, Incoming, Outgoing, Request}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; -use std::{io, time::Instant}; use std::sync::{Arc, Mutex}; -use flume::{Receiver, Sender}; +use std::{io, time::Instant}; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -135,7 +135,11 @@ impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, auth_manager: Option>>) -> Self { + pub fn new( + max_inflight: u16, + manual_acks: bool, + auth_manager: Option>>, + ) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -489,7 +493,7 @@ impl MqttState { self.await_pingresp = false; Ok(None) } - + fn handle_incoming_auth(&mut self, auth: &mut Auth) -> Result, StateError> { match auth.reason { AuthReasonCode::Success => Ok(None), @@ -497,29 +501,34 @@ impl MqttState { let props = auth.properties.clone().unwrap(); let in_auth_method = props.authentication_method; let in_auth_data = props.authentication_data; - + // Check if auth manager is set if self.auth_manager.is_none() { return Err(StateError::AuthManagerNotSet); } - + let auth_manager = self.auth_manager.clone().unwrap(); - + // Call auth_continue method of auth manager - let out_auth_data = match auth_manager.lock().unwrap().auth_continue(in_auth_method.clone(), in_auth_data) { + let out_auth_data = match auth_manager + .lock() + .unwrap() + .auth_continue(in_auth_method.clone(), in_auth_data) + { Ok(data) => data, Err(err) => return Err(StateError::AuthError(err)), }; - - let properties = AuthProperties{ + + let properties = AuthProperties { authentication_method: in_auth_method, authentication_data: out_auth_data, reason_string: None, user_properties: Vec::new(), }; - - let client_auth = Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); - + + let client_auth = + Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); + self.outgoing_auth(client_auth) } _ => return Err(StateError::AuthError("Authentication Failed!".to_string())), @@ -693,10 +702,13 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - + fn outgoing_auth(&mut self, auth: Auth) -> Result, StateError> { let props = auth.properties.as_ref().unwrap(); - debug!("Auth packet sent. Auth Method: {:?}. Auth Data: {:?}", props.authentication_method, props.authentication_data); + debug!( + "Auth packet sent. Auth Method: {:?}. Auth Data: {:?}", + props.authentication_method, props.authentication_data + ); let event = Event::Outgoing(Outgoing::Auth); self.events.push_back(event); Ok(Some(Packet::Auth(auth))) From 13e08b1258b6fe06cb0e6b2012fcaf3741919237 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Sun, 28 Apr 2024 15:32:21 +0800 Subject: [PATCH 19/31] Update CHANGELOG.md --- rumqttc/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 1045cfcf1..3a341249c 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `size()` method on `Packet` calculates size once serialized. * `read()` and `write()` methods on `Packet`. * `ConnectionAborted` variant on `StateError` type to denote abrupt end to a connection +* `AUTH` packet support for enhanced authentication. +* `MqttOptions::set_auth_manager` that allows users to set their own authentication manager that implements the `AuthManager` trait. +* `Client::reauth` that enables users to send `AUTH` packet for re-authentication purposes. + ### Changed From c9241f25ca5633c2150f8bb3d75c17f8cb55983a Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 8 May 2024 09:45:38 +0800 Subject: [PATCH 20/31] Add reauth and try_reauth to sync client. --- rumqttc/src/v5/client.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 4da5225f4..cc87f3d79 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -615,6 +615,22 @@ impl Client { Ok(()) } + /// Sends a MQTT AUTH to `EventLoop` for authentication. + pub fn reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); + let auth = Request::Auth(auth); + self.client.request_tx.send(auth)?; + Ok(()) + } + + /// Attempts to send a MQTT AUTH to `EventLoop` for authentication. + pub fn try_reauth(&self, properties: Option) -> Result<(), ClientError> { + let auth = Auth::new(AuthReasonCode::Reauthenticate, properties); + let auth = Request::Auth(auth); + self.client.request_tx.try_send(auth)?; + Ok(()) + } + /// Sends a MQTT Subscribe to the `EventLoop` fn handle_subscribe>( &self, @@ -911,4 +927,14 @@ mod test { .expect("Should be able to publish"); let _ = rx.try_recv().expect("Should have message"); } + + #[test] + fn test_reauth() { + let (client, mut connection) = Client::new(MqttOptions::new("test-1", "localhost", 1883), 10); + let _ = client.reauth(None).expect("Should be able to reauth"); + let _ = connection.iter().next().expect("Should have event"); + + let _ = client.try_reauth(None).expect("Should be able to reauth"); + let _ = connection.iter().next().expect("Should have event"); + } } From 1ab0aac2e2d4eb04bdd5430bef5a9e41f7574ff2 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 8 May 2024 09:59:15 +0800 Subject: [PATCH 21/31] Add sync_auth example. --- rumqttc/examples/sync_auth.rs | 126 ++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 rumqttc/examples/sync_auth.rs diff --git a/rumqttc/examples/sync_auth.rs b/rumqttc/examples/sync_auth.rs new file mode 100644 index 000000000..48969be08 --- /dev/null +++ b/rumqttc/examples/sync_auth.rs @@ -0,0 +1,126 @@ +use bytes::Bytes; +use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; +use rumqttc::v5::{Client, AuthManager, MqttOptions}; +use std::error::Error; +use std::sync::{Arc, Mutex}; +use std::thread; +//use scram::ScramClient; +//use scram::client::ServerFirst; +use flume::bounded; + +#[derive(Debug)] +struct ScramAuthManager<'a> { + user: &'a str, + password: &'a str, + //scram: Option>, +} + +impl<'a> ScramAuthManager<'a> { + fn new(user: &'a str, password: &'a str) -> ScramAuthManager<'a> { + ScramAuthManager { + user, + password, + //scram: None, + } + } + + fn auth_start(&mut self) -> Result, String> { + // let scram = ScramClient::new(self.user, self.password, None); + // let (scram, client_first) = scram.client_first(); + // self.scram = Some(scram); + + // Ok(Some(client_first.into())) + + Ok(Some("client first message".into())) + } +} + +impl<'a> AuthManager for ScramAuthManager<'a> { + fn auth_continue( + &mut self, + auth_method: Option, + auth_data: Option, + ) -> Result, String> { + //Check if the authentication method is SCRAM-SHA-256 + // if auth_method.unwrap() != "SCRAM-SHA-256" { + // return Err("Invalid authentication method".to_string()); + // } + + // if self.scram.is_none() { + // return Err("Invalid state".to_string()); + // } + + // let scram = self.scram.take().unwrap(); + + // let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + + //Process the server first message and reassign the SCRAM state. + // let scram = match(scram.handle_server_first(&auth_data)){ + // Ok(scram) => scram, + // Err(e) => return Err(e.to_string()), + // }; + + // //Get the client final message and reassign the SCRAM state. + // let (_, client_final) = scram.client_final(); + + // Ok(Some(client_final.into())) + + Ok(Some("client final message".into())) + } +} + +fn main() -> Result<(), Box> { + let mut authmanager = ScramAuthManager::new("user1", "123456"); + let client_first = authmanager.auth_start().unwrap(); + let authmanager = Arc::new(Mutex::new(authmanager)); + + let mut mqttoptions = MqttOptions::new("auth_test", "127.0.0.1", 1883); + mqttoptions.set_authentication_method(Some("SCRAM-SHA-256".to_string())); + mqttoptions.set_authentication_data(client_first); + mqttoptions.set_auth_manager(authmanager.clone()); + let (client, mut connection) = Client::new(mqttoptions, 10); + + let (tx, rx) = bounded(1); + + thread::spawn(move || { + client + .subscribe("rumqtt_auth/topic", QoS::AtLeastOnce) + .unwrap(); + client + .publish("rumqtt_auth/topic", QoS::AtLeastOnce, false, "hello world") + .unwrap(); + + // Wait for the connection to be established. + rx.recv().unwrap(); + + // Reauthenticate using SCRAM-SHA-256 + let client_first = authmanager.clone().lock().unwrap().auth_start().unwrap(); + let properties = AuthProperties { + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: client_first, + reason_string: None, + user_properties: Vec::new(), + }; + client.reauth(Some(properties)).unwrap(); + }); + + for (i, notification) in connection.iter().enumerate() { + match notification { + Ok(event) => { + println!("Event = {:?}", event); + match (event) { + rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { + tx.send("Connected").unwrap(); + } + _ => {} + } + } + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} From db2ef375cd3a1401b8ec0aff979bbf935b44a3f8 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Wed, 8 May 2024 11:24:40 +0800 Subject: [PATCH 22/31] Format code. --- rumqttc/examples/sync_auth.rs | 2 +- rumqttc/src/v5/client.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/rumqttc/examples/sync_auth.rs b/rumqttc/examples/sync_auth.rs index 48969be08..2fa47d64c 100644 --- a/rumqttc/examples/sync_auth.rs +++ b/rumqttc/examples/sync_auth.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; -use rumqttc::v5::{Client, AuthManager, MqttOptions}; +use rumqttc::v5::{AuthManager, Client, MqttOptions}; use std::error::Error; use std::sync::{Arc, Mutex}; use std::thread; diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index cc87f3d79..57e2fe0f5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -930,7 +930,8 @@ mod test { #[test] fn test_reauth() { - let (client, mut connection) = Client::new(MqttOptions::new("test-1", "localhost", 1883), 10); + let (client, mut connection) = + Client::new(MqttOptions::new("test-1", "localhost", 1883), 10); let _ = client.reauth(None).expect("Should be able to reauth"); let _ = connection.iter().next().expect("Should have event"); From 14eab60354d0ed52988a87f81a8b6397e66653c3 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Tue, 28 May 2024 17:03:16 +0800 Subject: [PATCH 23/31] Update auth tests. --- rumqttc/Cargo.toml | 3 ++ rumqttc/examples/async_auth.rs | 64 ++++++++++++++++++++-------------- rumqttc/examples/sync_auth.rs | 62 +++++++++++++++++++------------- 3 files changed, 78 insertions(+), 51 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 1e5a4ba93..bfd4fe183 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -21,6 +21,7 @@ use-rustls = ["dep:tokio-rustls", "dep:rustls-webpki", "dep:rustls-pemfile", "de use-native-tls = ["dep:tokio-native-tls", "dep:native-tls"] websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] +auth-scram = ["dep:scram"] [dependencies] futures-util = { version = "0.3", default-features = false, features = ["std", "sink"] } @@ -49,6 +50,8 @@ url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } tokio-stream = "0.1.15" +#auth +scram = { version = "0.6.0", optional = true } [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth.rs index da1a40530..d30610b8a 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth.rs @@ -4,15 +4,18 @@ use rumqttc::v5::{AsyncClient, AuthManager, MqttOptions}; use std::error::Error; use std::sync::{Arc, Mutex}; use tokio::task; -//use scram::ScramClient; -//use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; use flume::bounded; #[derive(Debug)] struct ScramAuthManager<'a> { user: &'a str, password: &'a str, - //scram: Option>, + #[cfg(feature = "auth-scram")] + scram: Option>, } impl<'a> ScramAuthManager<'a> { @@ -20,17 +23,22 @@ impl<'a> ScramAuthManager<'a> { ScramAuthManager { user, password, - //scram: None, + #[cfg(feature = "auth-scram")] + scram: None, } } fn auth_start(&mut self) -> Result, String> { - //let scram = ScramClient::new(self.user, self.password, None); - //let (scram, client_first) = scram.client_first(); - //self.scram = Some(scram); + #[cfg(feature = "auth-scram")] + { + let scram = ScramClient::new(self.user, self.password, None); + let (scram, client_first) = scram.client_first(); + self.scram = Some(scram); - //Ok(Some(client_first.into())) + Ok(Some(client_first.into())) + } + #[cfg(not(feature = "auth-scram"))] Ok(Some("client first message".into())) } } @@ -41,30 +49,34 @@ impl<'a> AuthManager for ScramAuthManager<'a> { auth_method: Option, auth_data: Option, ) -> Result, String> { - // Check if the authentication method is SCRAM-SHA-256 - //if auth_method.unwrap() != "SCRAM-SHA-256" { - // return Err("Invalid authentication method".to_string()); - //} + #[cfg(feature = "auth-scram")] + { + // Check if the authentication method is SCRAM-SHA-256 + if auth_method.unwrap() != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } - //if self.scram.is_none() { - // return Err("Invalid state".to_string()); - //} + if self.scram.is_none() { + return Err("Invalid state".to_string()); + } - //let scram = self.scram.take().unwrap(); + let scram = self.scram.take().unwrap(); - //let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); - // Process the server first message and reassign the SCRAM state. - //let scram = match(scram.handle_server_first(&auth_data)){ - // Ok(scram) => scram, - // Err(e) => return Err(e.to_string()), - //}; + // Process the server first message and reassign the SCRAM state. + let scram = match scram.handle_server_first(&auth_data) { + Ok(scram) => scram, + Err(e) => return Err(e.to_string()), + }; - // Get the client final message and reassign the SCRAM state. - //let (_, client_final) = scram.client_final(); + // Get the client final message and reassign the SCRAM state. + let (_, client_final) = scram.client_final(); - //Ok(Some(client_final.into())) + Ok(Some(client_final.into())) + } + #[cfg(not(feature = "auth-scram"))] Ok(Some("client final message".into())) } } @@ -113,7 +125,7 @@ async fn main() -> Result<(), Box> { match notification { Ok(event) => { println!("Event = {:?}", event); - match (event) { + match event { rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { tx.send_async("Connected").await.unwrap(); } diff --git a/rumqttc/examples/sync_auth.rs b/rumqttc/examples/sync_auth.rs index 2fa47d64c..8421a2f2c 100644 --- a/rumqttc/examples/sync_auth.rs +++ b/rumqttc/examples/sync_auth.rs @@ -4,15 +4,18 @@ use rumqttc::v5::{AuthManager, Client, MqttOptions}; use std::error::Error; use std::sync::{Arc, Mutex}; use std::thread; -//use scram::ScramClient; -//use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; use flume::bounded; #[derive(Debug)] struct ScramAuthManager<'a> { user: &'a str, password: &'a str, - //scram: Option>, + #[cfg(feature = "auth-scram")] + scram: Option>, } impl<'a> ScramAuthManager<'a> { @@ -20,17 +23,22 @@ impl<'a> ScramAuthManager<'a> { ScramAuthManager { user, password, - //scram: None, + #[cfg(feature = "auth-scram")] + scram: None, } } fn auth_start(&mut self) -> Result, String> { - // let scram = ScramClient::new(self.user, self.password, None); - // let (scram, client_first) = scram.client_first(); - // self.scram = Some(scram); + #[cfg(feature = "auth-scram")] + { + let scram = ScramClient::new(self.user, self.password, None); + let (scram, client_first) = scram.client_first(); + self.scram = Some(scram); - // Ok(Some(client_first.into())) + Ok(Some(client_first.into())) + } + #[cfg(not(feature = "auth-scram"))] Ok(Some("client first message".into())) } } @@ -41,30 +49,34 @@ impl<'a> AuthManager for ScramAuthManager<'a> { auth_method: Option, auth_data: Option, ) -> Result, String> { - //Check if the authentication method is SCRAM-SHA-256 - // if auth_method.unwrap() != "SCRAM-SHA-256" { - // return Err("Invalid authentication method".to_string()); - // } + #[cfg(feature = "auth-scram")] + { + //Check if the authentication method is SCRAM-SHA-256 + if auth_method.unwrap() != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } - // if self.scram.is_none() { - // return Err("Invalid state".to_string()); - // } + if self.scram.is_none() { + return Err("Invalid state".to_string()); + } - // let scram = self.scram.take().unwrap(); + let scram = self.scram.take().unwrap(); - // let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); - //Process the server first message and reassign the SCRAM state. - // let scram = match(scram.handle_server_first(&auth_data)){ - // Ok(scram) => scram, - // Err(e) => return Err(e.to_string()), - // }; + //Process the server first message and reassign the SCRAM state. + let scram = match scram.handle_server_first(&auth_data) { + Ok(scram) => scram, + Err(e) => return Err(e.to_string()), + }; - // //Get the client final message and reassign the SCRAM state. - // let (_, client_final) = scram.client_final(); + //Get the client final message and reassign the SCRAM state. + let (_, client_final) = scram.client_final(); - // Ok(Some(client_final.into())) + Ok(Some(client_final.into())) + } + #[cfg(not(feature = "auth-scram"))] Ok(Some("client final message".into())) } } From d5e5a066b2a5498eb2b058bb8578338cf764b493 Mon Sep 17 00:00:00 2001 From: tinzhu Date: Tue, 28 May 2024 17:27:20 +0800 Subject: [PATCH 24/31] Add a new auth test. --- rumqttc/examples/async_auth_oauth.rs | 64 +++++++++++++++++++ .../{async_auth.rs => async_auth_scram.rs} | 10 +-- .../{sync_auth.rs => sync_auth_scram.rs} | 12 ++-- 3 files changed, 75 insertions(+), 11 deletions(-) create mode 100644 rumqttc/examples/async_auth_oauth.rs rename rumqttc/examples/{async_auth.rs => async_auth_scram.rs} (100%) rename rumqttc/examples/{sync_auth.rs => sync_auth_scram.rs} (98%) diff --git a/rumqttc/examples/async_auth_oauth.rs b/rumqttc/examples/async_auth_oauth.rs new file mode 100644 index 000000000..cb115de1a --- /dev/null +++ b/rumqttc/examples/async_auth_oauth.rs @@ -0,0 +1,64 @@ +use rumqttc::v5::mqttbytes::v5::AuthProperties; +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use rumqttc::{TlsConfiguration, Transport}; +use std::error::Error; +use std::sync::Arc; +use tokio::task; +use tokio_rustls::rustls::ClientConfig; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let pubsub_access_token = ""; + + let mut mqttoptions = MqttOptions::new("client1-session1", "MQTT hostname", 8883); + mqttoptions.set_authentication_method(Some("OAUTH2-JWT".to_string())); + mqttoptions.set_authentication_data(Some(pubsub_access_token.into())); + + // Use rustls-native-certs to load root certificates from the operating system. + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + root_cert_store.add_parsable_certificates( + rustls_native_certs::load_native_certs().expect("could not load platform certs"), + ); + + let client_config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let transport = Transport::Tls(TlsConfiguration::Rustls(Arc::new(client_config.into()))); + + mqttoptions.set_transport(transport); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + + task::spawn(async move { + client.subscribe("topic1", QoS::AtLeastOnce).await.unwrap(); + client + .publish("topic1", QoS::AtLeastOnce, false, "hello world") + .await + .unwrap(); + + // Re-authentication test. + let props = AuthProperties { + authentication_method: Some("OAUTH2-JWT".to_string()), + authentication_data: Some(pubsub_access_token.into()), + reason_string: None, + user_properties: Vec::new(), + }; + + client.reauth(Some(props)).await.unwrap(); + }); + + loop { + let notification = eventloop.poll().await; + + match notification { + Ok(event) => println!("{:?}", event), + Err(e) => { + println!("Error = {:?}", e); + break; + } + } + } + + Ok(()) +} diff --git a/rumqttc/examples/async_auth.rs b/rumqttc/examples/async_auth_scram.rs similarity index 100% rename from rumqttc/examples/async_auth.rs rename to rumqttc/examples/async_auth_scram.rs index d30610b8a..296e86708 100644 --- a/rumqttc/examples/async_auth.rs +++ b/rumqttc/examples/async_auth_scram.rs @@ -1,14 +1,14 @@ use bytes::Bytes; +use flume::bounded; use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; use rumqttc::v5::{AsyncClient, AuthManager, MqttOptions}; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; use std::error::Error; use std::sync::{Arc, Mutex}; use tokio::task; -#[cfg(feature = "auth-scram")] -use scram::ScramClient; -#[cfg(feature = "auth-scram")] -use scram::client::ServerFirst; -use flume::bounded; #[derive(Debug)] struct ScramAuthManager<'a> { diff --git a/rumqttc/examples/sync_auth.rs b/rumqttc/examples/sync_auth_scram.rs similarity index 98% rename from rumqttc/examples/sync_auth.rs rename to rumqttc/examples/sync_auth_scram.rs index 8421a2f2c..2ab3cad7b 100644 --- a/rumqttc/examples/sync_auth.rs +++ b/rumqttc/examples/sync_auth_scram.rs @@ -1,14 +1,14 @@ use bytes::Bytes; +use flume::bounded; use rumqttc::v5::mqttbytes::{v5::AuthProperties, QoS}; use rumqttc::v5::{AuthManager, Client, MqttOptions}; +#[cfg(feature = "auth-scram")] +use scram::client::ServerFirst; +#[cfg(feature = "auth-scram")] +use scram::ScramClient; use std::error::Error; use std::sync::{Arc, Mutex}; use std::thread; -#[cfg(feature = "auth-scram")] -use scram::ScramClient; -#[cfg(feature = "auth-scram")] -use scram::client::ServerFirst; -use flume::bounded; #[derive(Debug)] struct ScramAuthManager<'a> { @@ -116,7 +116,7 @@ fn main() -> Result<(), Box> { client.reauth(Some(properties)).unwrap(); }); - for (i, notification) in connection.iter().enumerate() { + for (_, notification) in connection.iter().enumerate() { match notification { Ok(event) => { println!("Event = {:?}", event); From 4fd4e215c2787235a6c0e58732dfa2d245d40de0 Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Thu, 30 May 2024 12:21:46 +0800 Subject: [PATCH 25/31] Update auth_continue API. --- rumqttc/examples/async_auth_scram.rs | 33 ++++++++++++++++++++++------ rumqttc/src/v5/mod.rs | 10 ++++----- rumqttc/src/v5/state.rs | 17 ++++---------- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/rumqttc/examples/async_auth_scram.rs b/rumqttc/examples/async_auth_scram.rs index 296e86708..197b82043 100644 --- a/rumqttc/examples/async_auth_scram.rs +++ b/rumqttc/examples/async_auth_scram.rs @@ -12,7 +12,9 @@ use tokio::task; #[derive(Debug)] struct ScramAuthManager<'a> { + #[allow(dead_code)] user: &'a str, + #[allow(dead_code)] password: &'a str, #[cfg(feature = "auth-scram")] scram: Option>, @@ -46,13 +48,20 @@ impl<'a> ScramAuthManager<'a> { impl<'a> AuthManager for ScramAuthManager<'a> { fn auth_continue( &mut self, - auth_method: Option, - auth_data: Option, - ) -> Result, String> { + #[allow(unused_variables)] + auth_prop: Option, + ) -> Result, String> { #[cfg(feature = "auth-scram")] { + // Unwrap the properties. + let prop = auth_prop.unwrap(); + // Check if the authentication method is SCRAM-SHA-256 - if auth_method.unwrap() != "SCRAM-SHA-256" { + if let Some(auth_method) = &prop.authentication_method { + if auth_method != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } + } else { return Err("Invalid authentication method".to_string()); } @@ -62,7 +71,7 @@ impl<'a> AuthManager for ScramAuthManager<'a> { let scram = self.scram.take().unwrap(); - let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + let auth_data = String::from_utf8(prop.authentication_data.unwrap().to_vec()).unwrap(); // Process the server first message and reassign the SCRAM state. let scram = match scram.handle_server_first(&auth_data) { @@ -73,11 +82,21 @@ impl<'a> AuthManager for ScramAuthManager<'a> { // Get the client final message and reassign the SCRAM state. let (_, client_final) = scram.client_final(); - Ok(Some(client_final.into())) + Ok(Some(AuthProperties{ + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: Some(client_final.into()), + reason_string: None, + user_properties: Vec::new(), + })) } #[cfg(not(feature = "auth-scram"))] - Ok(Some("client final message".into())) + Ok(Some(AuthProperties { + authentication_method: Some("SCRAM-SHA-256".to_string()), + authentication_data: Some("client final message".into()), + reason_string: None, + user_properties: Vec::new(), + })) } } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index e2a5c1cb3..f071d60cf 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -37,19 +37,17 @@ pub trait AuthManager: std::fmt::Debug { /// /// # Arguments /// - /// * `auth_method` - The authentication method received from the server. - /// * `auth_data` - The authentication data received from the server. + /// * `auth_prop` - The authentication Properties received from the server. /// /// # Returns /// - /// * `Ok(auth_data)` - The authentication data to be sent back to the server. + /// * `Ok(auth_prop)` - The authentication Properties to be sent back to the server. /// * `Err(error_message)` - An error indicating that the authentication process has failed or terminated. fn auth_continue( &mut self, - auth_method: Option, - auth_data: Option, - ) -> Result, String>; + auth_prop: Option, + ) -> Result, String>; } /// Requests by the client to mqtt event loop. Request are diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 77df3f24e..2b1be692d 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,5 +1,5 @@ use super::mqttbytes::v5::{ - Auth, AuthProperties, AuthReasonCode, ConnAck, ConnectReturnCode, Disconnect, + Auth, AuthReasonCode, ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, @@ -500,8 +500,6 @@ impl MqttState { AuthReasonCode::Success => Ok(None), AuthReasonCode::ContinueAuthentication => { let props = auth.properties.clone().unwrap(); - let in_auth_method = props.authentication_method; - let in_auth_data = props.authentication_data; // Check if auth manager is set if self.auth_manager.is_none() { @@ -511,24 +509,17 @@ impl MqttState { let auth_manager = self.auth_manager.clone().unwrap(); // Call auth_continue method of auth manager - let out_auth_data = match auth_manager + let out_auth_prop = match auth_manager .lock() .unwrap() - .auth_continue(in_auth_method.clone(), in_auth_data) + .auth_continue(Some(props)) { Ok(data) => data, Err(err) => return Err(StateError::AuthError(err)), }; - let properties = AuthProperties { - authentication_method: in_auth_method, - authentication_data: out_auth_data, - reason_string: None, - user_properties: Vec::new(), - }; - let client_auth = - Auth::new(AuthReasonCode::ContinueAuthentication, Some(properties)); + Auth::new(AuthReasonCode::ContinueAuthentication, out_auth_prop); self.outgoing_auth(client_auth) } From c801b3c2fa3be1447c1fef82526f44c336525259 Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Thu, 30 May 2024 12:41:04 +0800 Subject: [PATCH 26/31] Optimize. --- rumqttc/src/v5/state.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 796e0ec6c..cba317bc6 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -490,7 +490,7 @@ impl MqttState { match auth.code { AuthReasonCode::Success => Ok(None), AuthReasonCode::Continue => { - let props = auth.properties.clone().unwrap(); + let props = auth.properties.clone(); // Check if auth manager is set if self.auth_manager.is_none() { @@ -500,16 +500,16 @@ impl MqttState { let auth_manager = self.auth_manager.clone().unwrap(); // Call auth_continue method of auth manager - let out_auth_prop = match auth_manager + let out_auth_props = match auth_manager .lock() .unwrap() - .auth_continue(Some(props)) + .auth_continue(props) { Ok(data) => data, Err(err) => return Err(StateError::AuthError(err)), }; - let client_auth = Auth::new(AuthReasonCode::Continue, out_auth_prop); + let client_auth = Auth::new(AuthReasonCode::Continue, out_auth_props); self.outgoing_auth(client_auth) } From 000e712dc5c292d4322f9ad9eed848b214be26c7 Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Thu, 30 May 2024 14:03:48 +0800 Subject: [PATCH 27/31] Update sync example. --- rumqttc/examples/sync_auth_scram.rs | 41 +++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/rumqttc/examples/sync_auth_scram.rs b/rumqttc/examples/sync_auth_scram.rs index a42bf24f4..c96b42f1e 100644 --- a/rumqttc/examples/sync_auth_scram.rs +++ b/rumqttc/examples/sync_auth_scram.rs @@ -12,7 +12,9 @@ use std::thread; #[derive(Debug)] struct ScramAuthManager<'a> { + #[allow(dead_code)] user: &'a str, + #[allow(dead_code)] password: &'a str, #[cfg(feature = "auth-scram")] scram: Option>, @@ -46,13 +48,20 @@ impl<'a> ScramAuthManager<'a> { impl<'a> AuthManager for ScramAuthManager<'a> { fn auth_continue( &mut self, - auth_method: Option, - auth_data: Option, - ) -> Result, String> { + #[allow(unused_variables)] + auth_prop: Option, + ) -> Result, String> { #[cfg(feature = "auth-scram")] { - //Check if the authentication method is SCRAM-SHA-256 - if auth_method.unwrap() != "SCRAM-SHA-256" { + // Unwrap the properties. + let prop = auth_prop.unwrap(); + + // Check if the authentication method is SCRAM-SHA-256 + if let Some(auth_method) = &prop.method { + if auth_method != "SCRAM-SHA-256" { + return Err("Invalid authentication method".to_string()); + } + } else { return Err("Invalid authentication method".to_string()); } @@ -62,22 +71,32 @@ impl<'a> AuthManager for ScramAuthManager<'a> { let scram = self.scram.take().unwrap(); - let auth_data = String::from_utf8(auth_data.unwrap().to_vec()).unwrap(); + let auth_data = String::from_utf8(prop.data.unwrap().to_vec()).unwrap(); - //Process the server first message and reassign the SCRAM state. + // Process the server first message and reassign the SCRAM state. let scram = match scram.handle_server_first(&auth_data) { Ok(scram) => scram, Err(e) => return Err(e.to_string()), }; - //Get the client final message and reassign the SCRAM state. + // Get the client final message and reassign the SCRAM state. let (_, client_final) = scram.client_final(); - Ok(Some(client_final.into())) + Ok(Some(AuthProperties{ + method: Some("SCRAM-SHA-256".to_string()), + data: Some(client_final.into()), + reason: None, + user_properties: Vec::new(), + })) } #[cfg(not(feature = "auth-scram"))] - Ok(Some("client final message".into())) + Ok(Some(AuthProperties { + method: Some("SCRAM-SHA-256".to_string()), + data: Some("client final message".into()), + reason: None, + user_properties: Vec::new(), + })) } } @@ -120,7 +139,7 @@ fn main() -> Result<(), Box> { match notification { Ok(event) => { println!("Event = {:?}", event); - match (event) { + match event { rumqttc::v5::Event::Incoming(rumqttc::v5::Incoming::ConnAck(_)) => { tx.send("Connected").unwrap(); } From 16cc05e6e9c708b734c58f5ca824caaad95323c9 Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Tue, 4 Jun 2024 15:14:51 +0800 Subject: [PATCH 28/31] Fix Send issue between threads. --- rumqttc/src/v5/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 29071cad2..d0195948c 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -32,7 +32,7 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; -pub trait AuthManager: std::fmt::Debug { +pub trait AuthManager: std::fmt::Debug + Send { /// Process authentication data received from the server and generate authentication data to be sent back. /// /// # Arguments From c7f0dc6e4a515839a68b21d2889a2141183d16ce Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Wed, 5 Jun 2024 09:21:58 +0800 Subject: [PATCH 29/31] Fix testing. --- rumqttc/src/v5/client.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 2492505c5..b1786e140 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -932,10 +932,16 @@ mod test { fn test_reauth() { let (client, mut connection) = Client::new(MqttOptions::new("test-1", "localhost", 1883), 10); - let _ = client.reauth(None).expect("Should be able to reauth"); + let props = AuthProperties { + method: Some("test".to_string()), + data: Some(Bytes::from("test")), + reason: None, + user_properties: vec![], + }; + let _ = client.reauth(Some(props.clone())).expect("Should be able to reauth"); let _ = connection.iter().next().expect("Should have event"); - let _ = client.try_reauth(None).expect("Should be able to reauth"); + let _ = client.try_reauth(Some(props.clone())).expect("Should be able to reauth"); let _ = connection.iter().next().expect("Should have event"); } } From 2c0a896544cec5641e5a4e0083356e54ebe38ed6 Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Fri, 14 Jun 2024 11:06:25 +0800 Subject: [PATCH 30/31] Fix compile warnings. --- rumqttc/src/v5/mod.rs | 4 +--- rumqttc/src/v5/state.rs | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index d0195948c..d6280def5 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -577,9 +577,7 @@ impl MqttOptions { } pub fn auth_manager(&self) -> Option>> { - if self.auth_manager.is_none() { - return None; - } + self.auth_manager.as_ref()?; self.auth_manager.clone() } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index cba317bc6..22b7d8e5d 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -513,7 +513,7 @@ impl MqttState { self.outgoing_auth(client_auth) } - _ => return Err(StateError::AuthError("Authentication Failed!".to_string())), + _ => Err(StateError::AuthError("Authentication Failed!".to_string())), } } From 4906aeec889914bf978069fec769883d4b04894b Mon Sep 17 00:00:00 2001 From: CQ Xiao Date: Mon, 17 Jun 2024 09:36:30 +0800 Subject: [PATCH 31/31] Fix testing issue. --- rumqttc/Cargo.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 110847842..63e2032ca 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -62,6 +62,11 @@ pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +[[example]] +name = "async_auth_oauth" +path = "examples/async_auth_oauth.rs" +required-features = ["use-rustls"] + [[example]] name = "tls" path = "examples/tls.rs"