Skip to content

Commit

Permalink
refactor ProtocolVersion to enum
Browse files Browse the repository at this point in the history
  • Loading branch information
yngrtc committed Aug 30, 2023
1 parent ef06da9 commit 5a518de
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 21 deletions.
33 changes: 28 additions & 5 deletions rmls/src/framing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,32 @@ use crate::serde::*;
use crate::tree::math::LeafIndex;
use crate::tree::secret::RatchetSecret;

pub(crate) type ProtocolVersion = u16;
/// [RFC9420 Sec.6](https://www.rfc-editor.org/rfc/rfc9420.html#section-6) ProtocolVersion
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u16)]
pub enum ProtocolVersion {
#[default]
MLS10 = 1,
Unsupported(u16),
}

pub(crate) const PROTOCOL_VERSION_MLS10: ProtocolVersion = 1;
impl From<u16> for ProtocolVersion {
fn from(v: u16) -> Self {
match v {
1 => ProtocolVersion::MLS10,
_ => ProtocolVersion::Unsupported(v),
}
}
}

impl From<ProtocolVersion> for u16 {
fn from(val: ProtocolVersion) -> u16 {
match val {
ProtocolVersion::MLS10 => 1,
ProtocolVersion::Unsupported(v) => v,
}
}
}

#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
Expand Down Expand Up @@ -367,7 +390,7 @@ impl AuthenticatedContent {

fn framed_content_tbs(&self, ctx: &GroupContext) -> FramedContentTBS {
FramedContentTBS {
version: PROTOCOL_VERSION_MLS10,
version: ProtocolVersion::MLS10,
wire_format: self.wire_format,
content: self.content.clone(),
context: Some(ctx.clone()),
Expand Down Expand Up @@ -491,7 +514,7 @@ impl Deserializer for FramedContentTBS {
if buf.remaining() < 2 {
return Err(Error::BufferTooSmall);
}
let version = buf.get_u16();
let version = buf.get_u16().into();
let wire_format = WireFormat::deserialize(buf)?;
let content = FramedContent::deserialize(buf)?;
let context = match &content.sender {
Expand All @@ -514,7 +537,7 @@ impl Serializer for FramedContentTBS {
Self: Sized,
B: BufMut,
{
buf.put_u16(self.version);
buf.put_u16(self.version.into());
self.wire_format.serialize(buf)?;
self.content.serialize(buf)?;
match &self.content.sender {
Expand Down
2 changes: 1 addition & 1 deletion rmls/src/group/group_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ fn message_protection_test(
tc: &MessageProtectionTest,
) -> Result<()> {
let ctx = GroupContext {
version: PROTOCOL_VERSION_MLS10,
version: ProtocolVersion::MLS10,
cipher_suite,
group_id: tc.group_id.clone().into(),
epoch: tc.epoch,
Expand Down
8 changes: 4 additions & 4 deletions rmls/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,10 @@ impl Deserializer for Message {
if buf.remaining() < 2 {
return Err(Error::BufferTooSmall);
}
let version = buf.get_u16();
let version: ProtocolVersion = buf.get_u16().into();

if version != PROTOCOL_VERSION_MLS10 {
return Err(Error::InvalidProtocolVersion(version));
if version != ProtocolVersion::MLS10 {
return Err(Error::InvalidProtocolVersion(version.into()));
}

let wire_format = WireFormat::deserialize(buf)?;
Expand Down Expand Up @@ -392,7 +392,7 @@ impl Serializer for Message {
Self: Sized,
B: BufMut,
{
buf.put_u16(self.version);
buf.put_u16(self.version.into());
self.wire_format.serialize(buf)?;
match &self.message {
WireFormatMessage::PublicMessage(message) => {
Expand Down
4 changes: 2 additions & 2 deletions rmls/src/group/proposal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ impl Deserializer for ReInitProposal {
if buf.remaining() < 4 {
return Err(Error::BufferTooSmall);
}
let version = buf.get_u16();
let version = buf.get_u16().into();
let cipher_suite = buf.get_u16().try_into()?;

let extensions = deserialize_extensions(buf)?;
Expand All @@ -316,7 +316,7 @@ impl Serializer for ReInitProposal {
B: BufMut,
{
serialize_opaque_vec(&self.group_id, buf)?;
buf.put_u16(self.version);
buf.put_u16(self.version.into());
buf.put_u16(self.cipher_suite as u16);
serialize_extensions(&self.extensions, buf)
}
Expand Down
2 changes: 1 addition & 1 deletion rmls/src/key/key_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ fn key_schedule_test(
println!("epoch {}", i);

let ctx = GroupContext {
version: PROTOCOL_VERSION_MLS10,
version: ProtocolVersion::MLS10,
cipher_suite,
group_id: tc.group_id.clone().into(),
epoch: i as u64,
Expand Down
4 changes: 2 additions & 2 deletions rmls/src/key/package.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl Deserializer for KeyPackage {
return Err(Error::BufferTooSmall);
}

let version = buf.get_u16();
let version = buf.get_u16().into();
let cipher_suite = buf.get_u16().try_into()?;
let init_key = deserialize_opaque_vec(buf)?;
let leaf_node = LeafNode::deserialize(buf)?;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl KeyPackage {
Self: Sized,
B: BufMut,
{
buf.put_u16(self.version);
buf.put_u16(self.version.into());
buf.put_u16(self.cipher_suite as u16);
serialize_opaque_vec(&self.init_key, buf)?;
self.leaf_node.serialize(buf)?;
Expand Down
8 changes: 4 additions & 4 deletions rmls/src/key/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl Deserializer for GroupContext {
return Err(Error::BufferTooSmall);
}

let version = buf.get_u16();
let version: ProtocolVersion = buf.get_u16().into();
let cipher_suite = buf.get_u16().try_into()?;
let group_id = deserialize_opaque_vec(buf)?;
if buf.remaining() < 8 {
Expand All @@ -37,8 +37,8 @@ impl Deserializer for GroupContext {
let tree_hash = deserialize_opaque_vec(buf)?;
let confirmed_transcript_hash = deserialize_opaque_vec(buf)?;

if version != PROTOCOL_VERSION_MLS10 {
return Err(Error::InvalidProposalTypeValue(version));
if version != ProtocolVersion::MLS10 {
return Err(Error::InvalidProposalTypeValue(version.into()));
}

let extensions = deserialize_extensions(buf)?;
Expand All @@ -60,7 +60,7 @@ impl Serializer for GroupContext {
Self: Sized,
B: BufMut,
{
buf.put_u16(self.version);
buf.put_u16(self.version.into());
buf.put_u16(self.cipher_suite as u16);
serialize_opaque_vec(&self.group_id, buf)?;
buf.put_u64(self.epoch);
Expand Down
4 changes: 2 additions & 2 deletions rmls/src/tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl Deserializer for Capabilities {
if b.remaining() < 2 {
return Err(Error::BufferTooSmall);
}
let ver: ProtocolVersion = b.get_u16();
let ver: ProtocolVersion = b.get_u16().into();
versions.push(ver);
Ok(())
})?;
Expand Down Expand Up @@ -236,7 +236,7 @@ impl Serializer for Capabilities {
self.versions.len(),
buf,
|i: usize, b: &mut BytesMut| -> Result<()> {
b.put_u16(self.versions[i]);
b.put_u16(self.versions[i].into());
Ok(())
},
)?;
Expand Down

0 comments on commit 5a518de

Please sign in to comment.