Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

combine_dtls_fragments helper #48

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added assets/dtls_cert_frag01.bin
Binary file not shown.
Binary file added assets/dtls_cert_frag02.bin
Binary file not shown.
Binary file added assets/dtls_cert_frag03.bin
Binary file not shown.
Binary file added assets/dtls_cert_frag04.bin
Binary file not shown.
8 changes: 7 additions & 1 deletion src/dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ pub struct DTLSMessageHandshake<'a> {
pub body: DTLSMessageHandshakeBody<'a>,
}

impl<'a> DTLSMessageHandshake<'a> {
pub fn is_fragment(&self) -> bool {
matches!(self.body, DTLSMessageHandshakeBody::Fragment(_))
}
}

/// DTLS Generic handshake message
#[derive(Debug, PartialEq)]
pub enum DTLSMessageHandshakeBody<'a> {
Expand Down Expand Up @@ -133,7 +139,7 @@ impl<'a> DTLSMessage<'a> {
/// fragments to be a complete message.
pub fn is_fragment(&self) -> bool {
match self {
DTLSMessage::Handshake(h) => matches!(h.body, DTLSMessageHandshakeBody::Fragment(_)),
DTLSMessage::Handshake(h) => h.is_fragment(),
_ => false,
}
}
Expand Down
219 changes: 219 additions & 0 deletions src/dtls_combine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
use core::array;

use nom::error::{make_error, ErrorKind};
use nom::{Err, IResult};

use crate::{
parse_dtls_message_handshake, DTLSMessage, DTLSMessageHandshake, DTLSMessageHandshakeBody,
};

const MAX_FRAGMENTS: usize = 50;

/// Combine the given fragments into one. Returns true if the fragments made a complete output.
/// The fragments are combined in such a way that the output constitutes a complete DTLSMessage.
///
/// Returns `None` if the fragments are not complete.
///
/// Errors if:
///
/// 1. The output is not big enough to hold the reconstituted messages
/// 2. Fragments are not of the same type (for example ClientHello mixed with Certificate)
/// 3. (Total) length field differs between the fragments
/// 4. Fragment offset/length are not consistent with total length
/// 5. The DTLSMessageHandshakeBody in the message is not a Fragment
/// 6. The message_seq differs between the fragments.
///
/// Panics if there are more than 50 fragments.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor) This comment is not correct now, it will return an error (and not panic)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep will fix. Following the logic (if we keep that), this should be error case 7 or something.

pub fn combine_dtls_fragments<'a>(
fragments: &[DTLSMessageHandshake],
out: &'a mut [u8],
) -> IResult<&'a [u8], Option<DTLSMessage<'a>>> {
if fragments.is_empty() {
return Ok((&[], None));
}

if fragments.len() > MAX_FRAGMENTS {
return Err(Err::Error(make_error(&*out, ErrorKind::TooLarge)));
}

const MESSAGE_HEADER_OFFSET: usize = 12;

// The header all of the fragments share the same DTLSMessage start, apart from the
// fragment information. This goes into the first 12 bytes.
if out.len() < MESSAGE_HEADER_OFFSET {
// Error case 1
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}

// Helper to iterate the fragments in order.
let ordered = Ordered::new(fragments);

// Investigate each fragment_offset + fragment_length to figure out
// the max contiguous range over the fragments.
let max = ordered.max_contiguous();

// Unwrap is OK, because we have at least one item (checked above).
let first_handshake = ordered.iter().next().unwrap();
Comment on lines +55 to +56
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor) The test is not obvious, since it relies on Ordered keeping the same number of items as fragments. A test here (even if redundant) would not be bad to avoid unwrap, and would not cause performance problems.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. It's an invariant, in the sense that if I got this assumption wrong, there's something more serious wrong (bug). I generally don't like masking those as errors, as it would appear the user provided incorrect data when it's actually a bug in this code.

But happy to change to error.


if first_handshake.fragment_offset != 0 {
// The first fragment must start at 0, or we might have
// missing packets or arriving out of order.
return Ok((&[], None));
}

let msg_type = first_handshake.msg_type;
let message_seq = first_handshake.message_seq;
let length = first_handshake.length;

#[allow(clippy::comparison_chain)]
if max > length {
// Error case 4
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
} else if max < length {
// We do not have all fragments yet
return Ok((&[], None));
}

// Write the header into output.
{
out[0] = msg_type.into(); // The type.
out[1..4].copy_from_slice(&length.to_be_bytes()[1..]); // 24 bit length
out[4..6].copy_from_slice(&message_seq.to_be_bytes()); // 16 bit message sequence
out[6..9].copy_from_slice(&[0, 0, 0]); // 24 bit fragment_offset, which is 0 for the entire message.
out[9..12].copy_from_slice(&length.to_be_bytes()[1..]); // 24 bit fragment_length, which is entire length.
}

let data = &mut out[MESSAGE_HEADER_OFFSET..];

if data.len() < length as usize {
// Error case 1
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}

// Loop the fragments, in order and output the data.
for handshake in ordered.iter() {
if msg_type != handshake.msg_type {
// Error case 2
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}

if handshake.length != length {
// Error case 3
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}

if handshake.message_seq != message_seq {
// Error case 6
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}
Comment on lines +95 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method will return Fail for every possible error, which will make debugging complicated. Instead, this could justify returning a different error kind. Maybe the context feature could help?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replying to the context suggestion: this may not easy to use here. Maybe a custom error type would be better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if a made a custom error type with a subenum for the errors? Something like:

ErrorKind::DtlsCombine(DtlsCombineError)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usual method in nom is to create a custom error type, an implement From<ErrorKind> and ParseError. See x509-parser errors for an example.

But, please wait before changing errors, because this will add code, and this can be done in a second step


let from = handshake.fragment_offset as usize;
let to = from + handshake.fragment_length as usize;

let body = match &handshake.body {
DTLSMessageHandshakeBody::Fragment(v) => v,
_ => {
// Error case 5
return Err(Err::Error(make_error(&*out, ErrorKind::Fail)));
}
};

// Copy into output.
data[from..to].copy_from_slice(&body[..]);
}

// This parse should succeed now and produce a complete message.
let (rest, message) = parse_dtls_message_handshake(out)?;

Ok((rest, Some(message)))
}

struct Ordered<'a, 'b>([usize; MAX_FRAGMENTS], &'a [DTLSMessageHandshake<'b>]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am missing the entire logic of this structure, this may require some more explanations.
As far as I understand, the field 0 is the sorted list of indices, the fragments being in the field 1.

Why not create and use directly a list of sorted fragments ? Or, to better deal with lifetimes, references to fragments ?

You would just have to iterate the list of fragments, and insert them in a sorted structure (list, tree, etc).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal here is to avoid an extra allocation. The user provides &[DTLSMessageHandshake], which could have been received in any order (given this is UDP). The order we want is in the fragment_offset field of each message. Thus, the Ordered iterator iterates the messages without any extra heap allocation.

Personally I think this is better than allocating, but I should provide this motivation as code doc. However if you prefer new Vec, I can make that too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for clarifying. The methods seems globally fine, and there is no need to switch to an allocation.

That said, I think the code needs some comments to be more readable later. This is my main concern.
Also, the fields could be named (for ex indices/fragments) for clarity, and avoid .0 and .1.


impl<'a, 'b> Ordered<'a, 'b> {
fn new(fragments: &'a [DTLSMessageHandshake<'b>]) -> Self {
// Indexes that will point into handshakes
let mut order: [usize; MAX_FRAGMENTS] = array::from_fn(|i| i);

// Sort the index for fragment_offset starts.
order.sort_by_key(|idx| {
fragments
.get(*idx)
.map(|h| h.fragment_offset)
// Somewhere outside the fragments length.
.unwrap_or(*idx as u32 + 1000)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this value? It seems dangerous to return a random offset if the fragment with this ID does not exist.
Should this return an error? Or is this a magic value?

Also, note that Option has a method map_or_else that could be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Poor logic here. It's another invariant. idx comes from order which means it's impossible for .get(*idx) to return None. Probably better to change this to a direct array reference fragments[idx].fragment_offset

I fix.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a consequence of not using a sorted list.
Well, even if there is an invariant, it is not described (and a minimum would be to use assert/debug_assert to state it).

});

Self(order, fragments)
}

fn iter(&self) -> impl Iterator<Item = &'a DTLSMessageHandshake<'b>> + '_ {
let len = self.1.len();
self.0
.iter()
.take_while(move |idx| **idx < len)
.map(move |idx| &self.1[*idx])
}

// Find the max contiguous fragment_offset/fragment_length.
fn max_contiguous(&self) -> u32 {
let mut max = 0;

for h in self.iter() {
// DTLS fragments can overlap, which means the offset might not start at the previous end.
if h.fragment_offset <= max {
let start = h.fragment_offset;
max = start + h.fragment_length;
} else {
// Not contiguous.
return 0;
}
}

max
}
}

#[cfg(test)]
mod test {
use crate::parse_dtls_plaintext_record;

use super::*;

#[test]
fn read_dtls_certifiate_fragments() {
// These are complete packets dumped with wireshark.
const DTLS_CERT01: &[u8] = include_bytes!("../assets/dtls_cert_frag01.bin");
const DTLS_CERT02: &[u8] = include_bytes!("../assets/dtls_cert_frag02.bin");
const DTLS_CERT03: &[u8] = include_bytes!("../assets/dtls_cert_frag03.bin");
const DTLS_CERT04: &[u8] = include_bytes!("../assets/dtls_cert_frag04.bin");

let mut fragments = vec![];

for c in &[DTLS_CERT01, DTLS_CERT02, DTLS_CERT03, DTLS_CERT04] {
let (_, record) = parse_dtls_plaintext_record(c).expect("parsing failed");

for message in record.messages {
// All of these should be fragments.
assert!(message.is_fragment());

let handshake = match message {
DTLSMessage::Handshake(v) => v,
_ => panic!("Expected Handshake"),
};

assert!(handshake.is_fragment());
fragments.push(handshake);
}
}

// Temporary output to combine the fragments into.
let mut out = vec![0_u8; 4192];
let (_, message) = combine_dtls_fragments(&fragments, &mut out).expect("combine fragments");

// This optional should hold Some(DTLSMessage) indicating a complete parse.
let message = message.expect("Combined fragments");

println!("{:02x?}", message);
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ extern crate alloc;

mod certificate_transparency;
mod dtls;
mod dtls_combine;
mod tls;
mod tls_alert;
mod tls_ciphers;
Expand All @@ -154,6 +155,7 @@ mod tls_states;

pub use certificate_transparency::*;
pub use dtls::*;
pub use dtls_combine::*;
pub use tls::*;
pub use tls_alert::*;
pub use tls_ciphers::*;
Expand Down