Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mqtt/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub use self::variable_byte_integer::{DecodeResult, VariableByteInteger};
mod packet_type;
pub use self::packet_type::{FixedHeader, PacketType};
mod packet_id;
pub use self::packet_id::IsPacketId;
pub use self::packet_id::{IntoPacketId, IsPacketId};
pub mod v3_1_1;
pub mod v5_0;
pub use self::enum_packet::{GenericPacket, GenericPacketDisplay, GenericPacketTrait, Packet};
Expand Down
67 changes: 67 additions & 0 deletions src/mqtt/packet/packet_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,70 @@ impl IsPacketId for u32 {
}
}
}

/// Trait for types that can be converted into an optional packet ID
///
/// This trait enables the packet_id() builder method to accept both direct values
/// (e.g., `packet_id(42)`) and optional values (e.g., `packet_id(Some(42))` or `packet_id(None)`).
///
/// # Examples
///
/// ```
/// use mqtt_protocol_core::mqtt::packet::IntoPacketId;
///
/// // Direct value
/// let id1: Option<u16> = 42u16.into_packet_id();
/// assert_eq!(id1, Some(42));
///
/// // Optional value
/// let id2: Option<u16> = Some(42u16).into_packet_id();
/// assert_eq!(id2, Some(42));
///
/// // None value
/// let id3: Option<u16> = None::<u16>.into_packet_id();
/// assert_eq!(id3, None);
/// ```
pub trait IntoPacketId<T> {
/// Convert self into an optional packet ID
fn into_packet_id(self) -> Option<T>;
}

// Implementations for u16

/// Implementation for direct u16 packet ID values
///
/// Allows direct u16 values like `42u16` to be converted to `Some(42)`
impl IntoPacketId<u16> for u16 {
fn into_packet_id(self) -> Option<u16> {
Some(self)
}
}

/// Implementation for optional u16 packet ID values
///
/// Allows optional values like `Some(42u16)` or `None::<u16>` to be passed through
impl IntoPacketId<u16> for Option<u16> {
fn into_packet_id(self) -> Option<u16> {
self
}
}

// Implementations for u32

/// Implementation for direct u32 packet ID values
///
/// Allows direct u32 values like `42u32` to be converted to `Some(42)`
impl IntoPacketId<u32> for u32 {
fn into_packet_id(self) -> Option<u32> {
Some(self)
}
}

/// Implementation for optional u32 packet ID values
///
/// Allows optional values like `Some(42u32)` or `None::<u32>` to be passed through
impl IntoPacketId<u32> for Option<u32> {
fn into_packet_id(self) -> Option<u32> {
self
}
}
31 changes: 27 additions & 4 deletions src/mqtt/packet/v3_1_1/publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::mqtt::packet::qos::Qos;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
use crate::mqtt::packet::IsPacketId;
use crate::mqtt::packet::{IntoPacketId, IsPacketId};
use crate::mqtt::result_code::MqttError;
use crate::mqtt::{Arc, ArcPayload, IntoPayload};

Expand Down Expand Up @@ -843,9 +843,14 @@ where
/// and must be a non-zero value. It is used to match the packet with its
/// corresponding acknowledgment packets (PUBACK, PUBREC, etc.).
///
/// This method accepts both direct values and `Option<PacketIdType>`:
/// - `packet_id(42)` - Sets packet ID to 42 (for QoS 1/2, backward compatible)
/// - `packet_id(Some(42))` - Sets packet ID to 42 (for QoS 1/2)
/// - `packet_id(None)` - No packet ID (for QoS 0)
///
/// # Parameters
///
/// * `id` - The packet identifier (must be non-zero for QoS > 0)
/// * `id` - The packet identifier value or Option (must be non-zero for QoS > 0)
///
/// # Returns
///
Expand All @@ -857,14 +862,32 @@ where
/// use mqtt_protocol_core::mqtt;
/// use mqtt_protocol_core::mqtt::packet::qos::Qos;
///
/// // Direct value (backward compatible)
/// let builder = mqtt::packet::v3_1_1::Publish::builder()
/// .topic_name("test/topic")
/// .unwrap()
/// .qos(Qos::AtLeastOnce)
/// .packet_id(123);
///
/// // Explicit Some
/// let builder = mqtt::packet::v3_1_1::Publish::builder()
/// .topic_name("test/topic")
/// .unwrap()
/// .qos(Qos::AtLeastOnce)
/// .packet_id(Some(123));
///
/// // Explicit None for QoS 0
/// let builder = mqtt::packet::v3_1_1::Publish::builder()
/// .topic_name("test/topic")
/// .unwrap()
/// .qos(Qos::AtMostOnce)
/// .packet_id(None);
/// ```
pub fn packet_id(mut self, id: PacketIdType) -> Self {
self.packet_id_buf = Some(Some(id.to_buffer()));
pub fn packet_id<T>(mut self, id: T) -> Self
where
T: IntoPacketId<PacketIdType>,
{
self.packet_id_buf = Some(id.into_packet_id().map(|i| i.to_buffer()));
self
}

Expand Down
25 changes: 21 additions & 4 deletions src/mqtt/packet/v5_0/publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ use crate::mqtt::packet::topic_alias_send::TopicAliasType;
use crate::mqtt::packet::variable_byte_integer::VariableByteInteger;
use crate::mqtt::packet::GenericPacketDisplay;
use crate::mqtt::packet::GenericPacketTrait;
use crate::mqtt::packet::IsPacketId;
#[cfg(feature = "std")]
use crate::mqtt::packet::PropertiesToBuffers;
use crate::mqtt::packet::{IntoPacketId, IsPacketId};
use crate::mqtt::packet::{Properties, PropertiesParse, PropertiesSize, Property};
use crate::mqtt::result_code::MqttError;
use crate::mqtt::{Arc, ArcPayload, IntoPayload};
Expand Down Expand Up @@ -1141,9 +1141,14 @@ where
/// to match acknowledgment packets (PUBACK, PUBREC, etc.) with the original
/// PUBLISH packet.
///
/// This method accepts both direct values and `Option<PacketIdType>`:
/// - `packet_id(42)` - Sets packet ID to 42 (for QoS 1/2, backward compatible)
/// - `packet_id(Some(42))` - Sets packet ID to 42 (for QoS 1/2)
/// - `packet_id(None)` - No packet ID (for QoS 0)
///
/// # Parameters
///
/// - `id`: The packet identifier (must be non-zero for QoS > 0)
/// - `id`: The packet identifier value or Option (must be non-zero for QoS > 0)
///
/// # Returns
///
Expand All @@ -1154,11 +1159,23 @@ where
/// ```ignore
/// use mqtt_protocol_core::mqtt;
///
/// // Direct value (backward compatible)
/// let builder = mqtt::packet::v5_0::Publish::builder()
/// .packet_id(42);
///
/// // Explicit Some
/// let builder = mqtt::packet::v5_0::Publish::builder()
/// .packet_id(Some(42));
///
/// // Explicit None for QoS 0
/// let builder = mqtt::packet::v5_0::Publish::builder()
/// .packet_id(None);
/// ```
pub fn packet_id(mut self, id: PacketIdType) -> Self {
self.packet_id_buf = Some(Some(id.to_buffer()));
pub fn packet_id<T>(mut self, id: T) -> Self
where
T: IntoPacketId<PacketIdType>,
{
self.packet_id_buf = Some(id.into_packet_id().map(|i| i.to_buffer()));
self
}

Expand Down
92 changes: 92 additions & 0 deletions tests/packet-v3_1_1-publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,95 @@ fn test_packet_type() {
let packet_type = mqtt::packet::v3_1_1::Publish::packet_type();
assert_eq!(packet_type, mqtt::packet::PacketType::Publish);
}

// Tests for packet_id() with Option interface

#[test]
fn test_qos1_with_some_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtLeastOnce)
.packet_id(Some(42u16))
.build()
.unwrap();
assert_eq!(result.packet_id(), Some(42u16));
}

#[test]
fn test_qos2_with_some_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::ExactlyOnce)
.packet_id(Some(123u16))
.build()
.unwrap();
assert_eq!(result.packet_id(), Some(123u16));
}

#[test]
fn test_qos1_with_none_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtLeastOnce)
.packet_id(None::<u16>)
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos2_with_none_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::ExactlyOnce)
.packet_id(None::<u16>)
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos0_with_none_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.packet_id(None::<u16>)
.build()
.unwrap();
assert_eq!(result.packet_id(), None);
}

#[test]
fn test_qos0_with_some_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.packet_id(Some(42u16))
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos0_without_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v3_1_1::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.build()
.unwrap();
assert_eq!(result.packet_id(), None);
}
92 changes: 92 additions & 0 deletions tests/packet-v5_0-publish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1504,3 +1504,95 @@ fn test_packet_type() {
let packet_type = mqtt::packet::v5_0::Publish::packet_type();
assert_eq!(packet_type, mqtt::packet::PacketType::Publish);
}

// Tests for packet_id() with Option interface

#[test]
fn test_qos1_with_some_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtLeastOnce)
.packet_id(Some(42u16))
.build()
.unwrap();
assert_eq!(result.packet_id(), Some(42u16));
}

#[test]
fn test_qos2_with_some_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::ExactlyOnce)
.packet_id(Some(123u16))
.build()
.unwrap();
assert_eq!(result.packet_id(), Some(123u16));
}

#[test]
fn test_qos1_with_none_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtLeastOnce)
.packet_id(None::<u16>)
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos2_with_none_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::ExactlyOnce)
.packet_id(None::<u16>)
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos0_with_none_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.packet_id(None::<u16>)
.build()
.unwrap();
assert_eq!(result.packet_id(), None);
}

#[test]
fn test_qos0_with_some_packet_id_error() {
common::init_tracing();
let err = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.packet_id(Some(42u16))
.build()
.unwrap_err();
assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket);
}

#[test]
fn test_qos0_without_packet_id_success() {
common::init_tracing();
let result = mqtt::packet::v5_0::Publish::builder()
.topic_name("test/topic")
.unwrap()
.qos(mqtt::packet::Qos::AtMostOnce)
.build()
.unwrap();
assert_eq!(result.packet_id(), None);
}