Skip to content

Commit

Permalink
message: switch to bytes::Bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-abramov committed Dec 11, 2024
1 parent b31e88f commit a17a4c2
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ __rustls-tls = ["rustls", "rustls-pki-types"]
[dependencies]
data-encoding = { version = "2", optional = true }
byteorder = "1.3.2"
bytes = "1.0"
bytes = "1.9.0"
http = { version = "1.0", optional = true }
httparse = { version = "1.3.4", optional = true }
log = "0.4.8"
Expand Down
96 changes: 69 additions & 27 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use byteorder::{NetworkEndian, ReadBytesExt};
use log::*;
use std::{
borrow::Cow,
default::Default,
fmt,
io::{Cursor, ErrorKind, Read, Write},
mem,
result::Result as StdResult,
str::Utf8Error,
string::{FromUtf8Error, String},
};

use byteorder::{NetworkEndian, ReadBytesExt};
use bytes::Bytes;
use log::*;

use super::{
coding::{CloseCode, Control, Data, OpCode},
mask::{apply_mask, generate_mask},
Expand Down Expand Up @@ -203,19 +206,55 @@ impl FrameHeader {
}
}

#[derive(Debug, Clone, Eq, PartialEq)]
enum Payload {
Owned(Vec<u8>),
Shared(Bytes),
}

impl Payload {
pub fn as_slice(&self) -> &[u8] {
match self {
Payload::Owned(v) => v,
Payload::Shared(v) => v,
}
}

pub fn as_mut_slice(&mut self) -> &mut [u8] {
match self {
Payload::Owned(v) => &mut *v,
Payload::Shared(v) => {
// Using `Bytes::to_vec()` or `Vec::from(bytes.as_ref())` would mean making a copy.
// `Bytes::into()` would not make a copy if our `Bytes` instance is the only one.
let data = mem::take(v).into();
*self = Payload::Owned(data);
let Payload::Owned(v) = self else { unreachable!() };
v
}
}
}

pub fn into_data(self) -> Vec<u8> {
match self {
Payload::Owned(v) => v,
Payload::Shared(v) => v.into(),
}
}
}

/// A struct representing a WebSocket frame.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
header: FrameHeader,
payload: Vec<u8>,
payload: Payload,
}

impl Frame {
/// Get the length of the frame.
/// This is the length of the header + the length of the payload.
#[inline]
pub fn len(&self) -> usize {
let length = self.payload.len();
let length = self.payload.as_slice().len();
self.header.len(length as u64) + length
}

Expand All @@ -239,14 +278,14 @@ impl Frame {

/// Get a reference to the frame's payload.
#[inline]
pub fn payload(&self) -> &Vec<u8> {
&self.payload
pub fn payload(&self) -> &[u8] {
self.payload.as_slice()
}

/// Get a mutable reference to the frame's payload.
#[inline]
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload
pub fn payload_mut(&mut self) -> &mut [u8] {
self.payload.as_mut_slice()
}

/// Test whether the frame is masked.
Expand All @@ -269,36 +308,36 @@ impl Frame {
#[inline]
pub(crate) fn apply_mask(&mut self) {
if let Some(mask) = self.header.mask.take() {
apply_mask(&mut self.payload, mask);
apply_mask(self.payload.as_mut_slice(), mask);
}
}

/// Consume the frame into its payload as binary.
#[inline]
pub fn into_data(self) -> Vec<u8> {
self.payload
self.payload.into_data()
}

/// Consume the frame into its payload as string.
#[inline]
pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
String::from_utf8(self.payload)
String::from_utf8(self.payload.into_data())
}

/// Get frame payload as `&str`.
#[inline]
pub fn to_text(&self) -> Result<&str, Utf8Error> {
std::str::from_utf8(&self.payload)
std::str::from_utf8(self.payload.as_slice())
}

/// Consume the frame into a closing frame.
#[inline]
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() {
match self.payload.as_slice().len() {
0 => Ok(None),
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => {
let mut data = self.payload;
let mut data = self.payload.into_data();
let code = u16::from_be_bytes([data[0], data[1]]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Expand All @@ -309,33 +348,36 @@ impl Frame {

/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");

Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
Frame {
header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
payload: Payload::Shared(data.into()),
}
}

/// Create a new Pong control frame.
#[inline]
pub fn pong(data: Vec<u8>) -> Frame {
pub fn pong(data: impl Into<Bytes>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
..FrameHeader::default()
},
payload: data,
payload: Payload::Shared(data.into()),
}
}

/// Create a new Ping control frame.
#[inline]
pub fn ping(data: Vec<u8>) -> Frame {
pub fn ping(data: impl Into<Bytes>) -> Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
..FrameHeader::default()
},
payload: data,
payload: Payload::Shared(data.into()),
}
}

Expand All @@ -351,17 +393,17 @@ impl Frame {
Vec::new()
};

Frame { header: FrameHeader::default(), payload }
Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) }
}

/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame { header, payload }
pub fn from_payload(header: FrameHeader, payload: impl Into<Bytes>) -> Self {
Frame { header, payload: Payload::Shared(payload.into()) }
}

/// Write a frame out to a buffer
pub fn format(mut self, output: &mut impl Write) -> Result<()> {
self.header.format(self.payload.len() as u64, output)?;
self.header.format(self.payload.as_slice().len() as u64, output)?;
self.apply_mask();
output.write_all(self.payload())?;
Ok(())
Expand Down Expand Up @@ -390,8 +432,8 @@ payload: 0x{}
self.header.opcode,
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().fold(String::new(), |mut output, byte| {
self.payload.as_slice().len(),
self.payload.as_slice().iter().fold(String::new(), |mut output, byte| {
_ = write!(output, "{byte:02x}");
output
})
Expand Down Expand Up @@ -479,7 +521,7 @@ mod tests {

#[test]
fn display() {
let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
let f = Frame::message("hi there", OpCode::Data(Data::Text), true);
let view = format!("{f}");
assert!(view.contains("payload:"));
}
Expand Down
18 changes: 10 additions & 8 deletions src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{fmt, result::Result as StdResult, str};

use bytes::Bytes;

use super::frame::{CloseFrame, Frame};
use crate::error::{CapacityError, Error, Result};

Expand Down Expand Up @@ -135,7 +137,7 @@ impl IncompleteMessage {
/// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> {
match self.collector {
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(Bytes::from(v))),
IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?;
Ok(Message::Text(text))
Expand All @@ -156,15 +158,15 @@ pub enum Message {
/// A text WebSocket message
Text(String),
/// A binary WebSocket message
Binary(Vec<u8>),
Binary(Bytes),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes
Ping(Vec<u8>),
/// A pong message with the specified payload
Ping(Bytes),
/// A pong message with the specified payloadVec<u8>
///
/// The payload here must have a length less than 125 bytes
Pong(Vec<u8>),
Pong(Bytes),
/// A close message with the optional close frame.
Close(Option<CloseFrame<'static>>),
/// Raw frame. Note, that you're not going to get this value while reading the message.
Expand All @@ -185,7 +187,7 @@ impl Message {
where
B: Into<Vec<u8>>,
{
Message::Binary(bin.into())
Message::Binary(Bytes::from(bin.into()))
}

/// Indicates whether a message is a text message.
Expand Down Expand Up @@ -235,7 +237,7 @@ impl Message {
pub fn into_data(self) -> Vec<u8> {
match self {
Message::Text(string) => string.into_bytes(),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data.into(),
Message::Close(None) => Vec::new(),
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
Message::Frame(frame) => frame.into_data(),
Expand All @@ -247,7 +249,7 @@ impl Message {
match self {
Message::Text(string) => Ok(string),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
Ok(String::from_utf8(data)?)
Ok(String::from_utf8(data.into())?)
}
Message::Close(None) => Ok(String::new()),
Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
Expand Down
12 changes: 6 additions & 6 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ impl WebSocketContext {
}

let frame = match message {
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Text(data) => Frame::message(data, OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
Expand Down Expand Up @@ -608,9 +608,9 @@ impl WebSocketContext {
if self.state.is_active() {
self.set_additional(Frame::pong(data.clone()));
}
Ok(Some(Message::Ping(data)))
Ok(Some(Message::Ping(data.into())))
}
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data().into()))),
}
}

Expand Down Expand Up @@ -826,10 +826,10 @@ mod tests {
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read().unwrap(), Message::Ping(vec![1, 2].into()));
assert_eq!(socket.read().unwrap(), Message::Pong(vec![3].into()));
assert_eq!(socket.read().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
assert_eq!(socket.read().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03].into()));
}

#[test]
Expand Down

0 comments on commit a17a4c2

Please sign in to comment.