diff --git a/rmls/src/extensibility/mod.rs b/rmls/src/extensibility/mod.rs index bd79720..3979aa2 100644 --- a/rmls/src/extensibility/mod.rs +++ b/rmls/src/extensibility/mod.rs @@ -6,7 +6,12 @@ //! to the protocol. use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::ops::Deref; +use crate::crypto::credential::{Credential, CredentialType}; +use crate::crypto::{HPKEPublicKey, SignaturePublicKey}; +use crate::group::proposal::ProposalType; +use crate::ratchet_tree::RatchetTree; use crate::utilities::error::*; use crate::utilities::serde::*; @@ -53,30 +58,57 @@ impl From for u16 { /// [RFC9420 Sec.7.2](https://www.rfc-editor.org/rfc/rfc9420.html#section-7.2) Extension #[derive(Default, Debug, Clone, Eq, PartialEq)] pub struct Extension { - pub extension_type: ExtensionType, - pub extension_data: Bytes, + extension_type: ExtensionType, + extension_data: Bytes, +} + +impl Extension { + pub fn new(extension_type: ExtensionType, extension_data: Bytes) -> Self { + Self { + extension_type, + extension_data, + } + } + + pub fn extension_type(&self) -> ExtensionType { + self.extension_type + } + + pub fn extension_data(&self) -> &Bytes { + &self.extension_data + } } /// [RFC9420 Sec.7.2](https://www.rfc-editor.org/rfc/rfc9420.html#section-7.2) Extensions #[derive(Default, Debug, Clone, Eq, PartialEq)] -pub struct Extensions(pub Vec); +pub struct Extensions(Vec); + +impl Extensions { + pub fn new(extensions: Vec) -> Self { + Self(extensions) + } + + pub fn extensions(&self) -> &[Extension] { + self.0.as_ref() + } +} impl Deserializer for Extensions { fn deserialize(buf: &mut B) -> Result { - let mut exts = vec![]; + let mut extensions = vec![]; deserialize_vector(buf, |b: &mut Bytes| -> Result<()> { if b.remaining() < 2 { return Err(Error::BufferTooSmall); } let extension_type: ExtensionType = b.get_u16().into(); let extension_data = deserialize_opaque_vec(b)?; - exts.push(Extension { + extensions.push(Extension { extension_type, extension_data, }); Ok(()) })?; - Ok(Extensions(exts)) + Ok(Extensions(extensions)) } } @@ -103,3 +135,335 @@ impl Extensions { None } } + +/// Application Id Extension +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct ApplicationIdExtension(Bytes); + +impl ApplicationIdExtension { + /// Creates a new ApplicationIdExtension + pub fn new(id: Bytes) -> Self { + Self(id) + } +} + +impl Deref for ApplicationIdExtension { + type Target = Bytes; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Deserializer for ApplicationIdExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + Ok(ApplicationIdExtension(deserialize_opaque_vec(buf)?)) + } +} + +impl Serializer for ApplicationIdExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + serialize_opaque_vec(&self.0, buf) + } +} + +/// Ratchet Tree Extension +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct RatchetTreeExtension { + ratchet_tree: RatchetTree, +} + +impl RatchetTreeExtension { + /// Creates a new RatchetTreeExtension + pub fn new(ratchet_tree: RatchetTree) -> Self { + Self { ratchet_tree } + } + + /// Returns the RatchetTree from this extension + pub fn ratchet_tree(&self) -> &RatchetTree { + &self.ratchet_tree + } +} + +//FIXME(yngrtc): RatchetTreeExtension deserialize failed framing_test and serde_test +impl Deserializer for RatchetTreeExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + Ok(RatchetTreeExtension { + ratchet_tree: RatchetTree::deserialize(buf)?, + }) + } +} + +impl Serializer for RatchetTreeExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + self.ratchet_tree.serialize(buf) + } +} + +/// Required Capabilities Extension. +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct RequiredCapabilitiesExtension { + extension_types: Vec, + proposal_types: Vec, + credential_types: Vec, +} + +impl RequiredCapabilitiesExtension { + /// Creates a new RequiredCapabilitiesExtension + pub fn new( + extension_types: Vec, + proposal_types: Vec, + credential_types: Vec, + ) -> Self { + Self { + extension_types, + proposal_types, + credential_types, + } + } + + /// Returns the extension_types from this extension + pub fn extension_types(&self) -> &[ExtensionType] { + &self.extension_types + } + + /// Returns the proposal_types from this extension + pub fn proposal_types(&self) -> &[ProposalType] { + &self.proposal_types + } + + /// Returns the credential_types from this extension + pub fn credential_types(&self) -> &[CredentialType] { + &self.credential_types + } +} + +impl Deserializer for RequiredCapabilitiesExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + let mut extension_types = vec![]; + deserialize_vector(buf, |b: &mut Bytes| -> Result<()> { + if b.remaining() < 2 { + return Err(Error::BufferTooSmall); + } + let extension_type: ExtensionType = b.get_u16().into(); + extension_types.push(extension_type); + Ok(()) + })?; + + let mut proposal_types = vec![]; + deserialize_vector(buf, |b: &mut Bytes| -> Result<()> { + if b.remaining() < 2 { + return Err(Error::BufferTooSmall); + } + let proposal_type: ProposalType = b.get_u16().into(); + proposal_types.push(proposal_type); + Ok(()) + })?; + + let mut credential_types = vec![]; + deserialize_vector(buf, |b: &mut Bytes| -> Result<()> { + if b.remaining() < 2 { + return Err(Error::BufferTooSmall); + } + let credential_type: CredentialType = b.get_u16().into(); + credential_types.push(credential_type); + Ok(()) + })?; + + Ok(RequiredCapabilitiesExtension { + extension_types, + proposal_types, + credential_types, + }) + } +} + +impl Serializer for RequiredCapabilitiesExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + serialize_vector( + self.extension_types.len(), + buf, + |i: usize, b: &mut BytesMut| -> Result<()> { + b.put_u16(self.extension_types[i].into()); + Ok(()) + }, + )?; + + serialize_vector( + self.proposal_types.len(), + buf, + |i: usize, b: &mut BytesMut| -> Result<()> { + b.put_u16(self.proposal_types[i].into()); + Ok(()) + }, + )?; + + serialize_vector( + self.credential_types.len(), + buf, + |i: usize, b: &mut BytesMut| -> Result<()> { + b.put_u16(self.credential_types[i].into()); + Ok(()) + }, + )?; + + Ok(()) + } +} + +/// ExternalPub Extension +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct ExternalPubExtension { + external_pub: HPKEPublicKey, +} + +impl ExternalPubExtension { + /// Creates a new ExternalPubExtension + pub fn new(external_pub: HPKEPublicKey) -> Self { + Self { external_pub } + } + + /// Returns the HPKEPublicKey from this extension + pub fn external_pub(&self) -> &HPKEPublicKey { + &self.external_pub + } +} + +impl Deserializer for ExternalPubExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + Ok(ExternalPubExtension { + external_pub: HPKEPublicKey::deserialize(buf)?, + }) + } +} + +impl Serializer for ExternalPubExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + self.external_pub.serialize(buf) + } +} + +/// ExternalSenders Extension +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct ExternalSendersExtension { + signature_key: SignaturePublicKey, + credential: Credential, +} + +impl ExternalSendersExtension { + /// Creates a new ExternalSendersExtension + pub fn new(signature_key: SignaturePublicKey, credential: Credential) -> Self { + Self { + signature_key, + credential, + } + } + + /// Returns the SignaturePublicKey from this extension + pub fn signature_key(&self) -> &SignaturePublicKey { + &self.signature_key + } + + /// Returns the Credential from this extension + pub fn credential(&self) -> &Credential { + &self.credential + } +} + +impl Deserializer for ExternalSendersExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + let signature_key = SignaturePublicKey::deserialize(buf)?; + let credential = Credential::deserialize(buf)?; + + Ok(ExternalSendersExtension { + signature_key, + credential, + }) + } +} + +impl Serializer for ExternalSendersExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + self.signature_key.serialize(buf)?; + self.credential.serialize(buf) + } +} + +/// ExternalSenders Extension +#[derive(Default, Debug, Clone, Eq, PartialEq)] +pub struct UnknownExtension(Bytes); + +impl UnknownExtension { + /// Creates a new UnknownExtension + pub fn new(unknown: Bytes) -> Self { + Self(unknown) + } +} + +impl Deref for UnknownExtension { + type Target = Bytes; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Deserializer for UnknownExtension { + fn deserialize(buf: &mut B) -> Result + where + Self: Sized, + B: Buf, + { + Ok(UnknownExtension(deserialize_opaque_vec(buf)?)) + } +} + +impl Serializer for UnknownExtension { + fn serialize(&self, buf: &mut B) -> Result<()> + where + Self: Sized, + B: BufMut, + { + serialize_opaque_vec(&self.0, buf) + } +} diff --git a/rmls/src/ratchet_tree/mod.rs b/rmls/src/ratchet_tree/mod.rs index 0ff7b8e..686ccf2 100644 --- a/rmls/src/ratchet_tree/mod.rs +++ b/rmls/src/ratchet_tree/mod.rs @@ -526,11 +526,11 @@ impl LeafNode { for et in &self.capabilities.extensions { supported_exts.insert(*et); } - for ext in &self.extensions.0 { - if !supported_exts.contains(&ext.extension_type) { + for ext in self.extensions.extensions() { + if !supported_exts.contains(&ext.extension_type()) { return Err( Error::ExtensionTypeUsedByLeafNodeNotSupportedByThatLeafNode( - ext.extension_type.into(), + ext.extension_type().into(), ), ); }