diff --git a/src/mqtt/packet/mod.rs b/src/mqtt/packet/mod.rs index daacd63..e09cb9b 100644 --- a/src/mqtt/packet/mod.rs +++ b/src/mqtt/packet/mod.rs @@ -38,7 +38,7 @@ pub use self::variable_byte_integer::{DecodeResult, VariableByteInteger}; mod packet_type; pub use self::packet_type::{FixedHeader, PacketType}; mod packet_id; -pub use self::packet_id::IsPacketId; +pub use self::packet_id::{IntoPacketId, IsPacketId}; pub mod v3_1_1; pub mod v5_0; pub use self::enum_packet::{GenericPacket, GenericPacketDisplay, GenericPacketTrait, Packet}; diff --git a/src/mqtt/packet/packet_id.rs b/src/mqtt/packet/packet_id.rs index 67b99bb..951893f 100644 --- a/src/mqtt/packet/packet_id.rs +++ b/src/mqtt/packet/packet_id.rs @@ -70,3 +70,70 @@ impl IsPacketId for u32 { } } } + +/// Trait for types that can be converted into an optional packet ID +/// +/// This trait enables the packet_id() builder method to accept both direct values +/// (e.g., `packet_id(42)`) and optional values (e.g., `packet_id(Some(42))` or `packet_id(None)`). +/// +/// # Examples +/// +/// ``` +/// use mqtt_protocol_core::mqtt::packet::IntoPacketId; +/// +/// // Direct value +/// let id1: Option = 42u16.into_packet_id(); +/// assert_eq!(id1, Some(42)); +/// +/// // Optional value +/// let id2: Option = Some(42u16).into_packet_id(); +/// assert_eq!(id2, Some(42)); +/// +/// // None value +/// let id3: Option = None::.into_packet_id(); +/// assert_eq!(id3, None); +/// ``` +pub trait IntoPacketId { + /// Convert self into an optional packet ID + fn into_packet_id(self) -> Option; +} + +// Implementations for u16 + +/// Implementation for direct u16 packet ID values +/// +/// Allows direct u16 values like `42u16` to be converted to `Some(42)` +impl IntoPacketId for u16 { + fn into_packet_id(self) -> Option { + Some(self) + } +} + +/// Implementation for optional u16 packet ID values +/// +/// Allows optional values like `Some(42u16)` or `None::` to be passed through +impl IntoPacketId for Option { + fn into_packet_id(self) -> Option { + self + } +} + +// Implementations for u32 + +/// Implementation for direct u32 packet ID values +/// +/// Allows direct u32 values like `42u32` to be converted to `Some(42)` +impl IntoPacketId for u32 { + fn into_packet_id(self) -> Option { + Some(self) + } +} + +/// Implementation for optional u32 packet ID values +/// +/// Allows optional values like `Some(42u32)` or `None::` to be passed through +impl IntoPacketId for Option { + fn into_packet_id(self) -> Option { + self + } +} diff --git a/src/mqtt/packet/v3_1_1/publish.rs b/src/mqtt/packet/v3_1_1/publish.rs index cab716b..3d28550 100644 --- a/src/mqtt/packet/v3_1_1/publish.rs +++ b/src/mqtt/packet/v3_1_1/publish.rs @@ -39,7 +39,7 @@ use crate::mqtt::packet::qos::Qos; use crate::mqtt::packet::variable_byte_integer::VariableByteInteger; use crate::mqtt::packet::GenericPacketDisplay; use crate::mqtt::packet::GenericPacketTrait; -use crate::mqtt::packet::IsPacketId; +use crate::mqtt::packet::{IntoPacketId, IsPacketId}; use crate::mqtt::result_code::MqttError; use crate::mqtt::{Arc, ArcPayload, IntoPayload}; @@ -843,9 +843,14 @@ where /// and must be a non-zero value. It is used to match the packet with its /// corresponding acknowledgment packets (PUBACK, PUBREC, etc.). /// + /// This method accepts both direct values and `Option`: + /// - `packet_id(42)` - Sets packet ID to 42 (for QoS 1/2, backward compatible) + /// - `packet_id(Some(42))` - Sets packet ID to 42 (for QoS 1/2) + /// - `packet_id(None)` - No packet ID (for QoS 0) + /// /// # Parameters /// - /// * `id` - The packet identifier (must be non-zero for QoS > 0) + /// * `id` - The packet identifier value or Option (must be non-zero for QoS > 0) /// /// # Returns /// @@ -857,14 +862,32 @@ where /// use mqtt_protocol_core::mqtt; /// use mqtt_protocol_core::mqtt::packet::qos::Qos; /// + /// // Direct value (backward compatible) /// let builder = mqtt::packet::v3_1_1::Publish::builder() /// .topic_name("test/topic") /// .unwrap() /// .qos(Qos::AtLeastOnce) /// .packet_id(123); + /// + /// // Explicit Some + /// let builder = mqtt::packet::v3_1_1::Publish::builder() + /// .topic_name("test/topic") + /// .unwrap() + /// .qos(Qos::AtLeastOnce) + /// .packet_id(Some(123)); + /// + /// // Explicit None for QoS 0 + /// let builder = mqtt::packet::v3_1_1::Publish::builder() + /// .topic_name("test/topic") + /// .unwrap() + /// .qos(Qos::AtMostOnce) + /// .packet_id(None); /// ``` - pub fn packet_id(mut self, id: PacketIdType) -> Self { - self.packet_id_buf = Some(Some(id.to_buffer())); + pub fn packet_id(mut self, id: T) -> Self + where + T: IntoPacketId, + { + self.packet_id_buf = Some(id.into_packet_id().map(|i| i.to_buffer())); self } diff --git a/src/mqtt/packet/v5_0/publish.rs b/src/mqtt/packet/v5_0/publish.rs index fcf7417..7223f68 100644 --- a/src/mqtt/packet/v5_0/publish.rs +++ b/src/mqtt/packet/v5_0/publish.rs @@ -42,9 +42,9 @@ use crate::mqtt::packet::topic_alias_send::TopicAliasType; use crate::mqtt::packet::variable_byte_integer::VariableByteInteger; use crate::mqtt::packet::GenericPacketDisplay; use crate::mqtt::packet::GenericPacketTrait; -use crate::mqtt::packet::IsPacketId; #[cfg(feature = "std")] use crate::mqtt::packet::PropertiesToBuffers; +use crate::mqtt::packet::{IntoPacketId, IsPacketId}; use crate::mqtt::packet::{Properties, PropertiesParse, PropertiesSize, Property}; use crate::mqtt::result_code::MqttError; use crate::mqtt::{Arc, ArcPayload, IntoPayload}; @@ -1141,9 +1141,14 @@ where /// to match acknowledgment packets (PUBACK, PUBREC, etc.) with the original /// PUBLISH packet. /// + /// This method accepts both direct values and `Option`: + /// - `packet_id(42)` - Sets packet ID to 42 (for QoS 1/2, backward compatible) + /// - `packet_id(Some(42))` - Sets packet ID to 42 (for QoS 1/2) + /// - `packet_id(None)` - No packet ID (for QoS 0) + /// /// # Parameters /// - /// - `id`: The packet identifier (must be non-zero for QoS > 0) + /// - `id`: The packet identifier value or Option (must be non-zero for QoS > 0) /// /// # Returns /// @@ -1154,11 +1159,23 @@ where /// ```ignore /// use mqtt_protocol_core::mqtt; /// + /// // Direct value (backward compatible) /// let builder = mqtt::packet::v5_0::Publish::builder() /// .packet_id(42); + /// + /// // Explicit Some + /// let builder = mqtt::packet::v5_0::Publish::builder() + /// .packet_id(Some(42)); + /// + /// // Explicit None for QoS 0 + /// let builder = mqtt::packet::v5_0::Publish::builder() + /// .packet_id(None); /// ``` - pub fn packet_id(mut self, id: PacketIdType) -> Self { - self.packet_id_buf = Some(Some(id.to_buffer())); + pub fn packet_id(mut self, id: T) -> Self + where + T: IntoPacketId, + { + self.packet_id_buf = Some(id.into_packet_id().map(|i| i.to_buffer())); self } diff --git a/tests/packet-v3_1_1-publish.rs b/tests/packet-v3_1_1-publish.rs index d09305b..2693f9c 100644 --- a/tests/packet-v3_1_1-publish.rs +++ b/tests/packet-v3_1_1-publish.rs @@ -668,3 +668,95 @@ fn test_packet_type() { let packet_type = mqtt::packet::v3_1_1::Publish::packet_type(); assert_eq!(packet_type, mqtt::packet::PacketType::Publish); } + +// Tests for packet_id() with Option interface + +#[test] +fn test_qos1_with_some_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtLeastOnce) + .packet_id(Some(42u16)) + .build() + .unwrap(); + assert_eq!(result.packet_id(), Some(42u16)); +} + +#[test] +fn test_qos2_with_some_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::ExactlyOnce) + .packet_id(Some(123u16)) + .build() + .unwrap(); + assert_eq!(result.packet_id(), Some(123u16)); +} + +#[test] +fn test_qos1_with_none_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtLeastOnce) + .packet_id(None::) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos2_with_none_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::ExactlyOnce) + .packet_id(None::) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos0_with_none_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .packet_id(None::) + .build() + .unwrap(); + assert_eq!(result.packet_id(), None); +} + +#[test] +fn test_qos0_with_some_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .packet_id(Some(42u16)) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos0_without_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v3_1_1::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .build() + .unwrap(); + assert_eq!(result.packet_id(), None); +} diff --git a/tests/packet-v5_0-publish.rs b/tests/packet-v5_0-publish.rs index a76575c..5f7fc66 100644 --- a/tests/packet-v5_0-publish.rs +++ b/tests/packet-v5_0-publish.rs @@ -1504,3 +1504,95 @@ fn test_packet_type() { let packet_type = mqtt::packet::v5_0::Publish::packet_type(); assert_eq!(packet_type, mqtt::packet::PacketType::Publish); } + +// Tests for packet_id() with Option interface + +#[test] +fn test_qos1_with_some_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtLeastOnce) + .packet_id(Some(42u16)) + .build() + .unwrap(); + assert_eq!(result.packet_id(), Some(42u16)); +} + +#[test] +fn test_qos2_with_some_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::ExactlyOnce) + .packet_id(Some(123u16)) + .build() + .unwrap(); + assert_eq!(result.packet_id(), Some(123u16)); +} + +#[test] +fn test_qos1_with_none_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtLeastOnce) + .packet_id(None::) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos2_with_none_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::ExactlyOnce) + .packet_id(None::) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos0_with_none_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .packet_id(None::) + .build() + .unwrap(); + assert_eq!(result.packet_id(), None); +} + +#[test] +fn test_qos0_with_some_packet_id_error() { + common::init_tracing(); + let err = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .packet_id(Some(42u16)) + .build() + .unwrap_err(); + assert_eq!(err, mqtt::result_code::MqttError::MalformedPacket); +} + +#[test] +fn test_qos0_without_packet_id_success() { + common::init_tracing(); + let result = mqtt::packet::v5_0::Publish::builder() + .topic_name("test/topic") + .unwrap() + .qos(mqtt::packet::Qos::AtMostOnce) + .build() + .unwrap(); + assert_eq!(result.packet_id(), None); +}