diff --git a/Cargo.toml b/Cargo.toml index 11ae514a9..73784860a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ harness = false [workspace] members = [ + "http", "tools", + "benches/timers_container", ] diff --git a/Makefile b/Makefile index e29736c2d..b01a9bcf5 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ TARGETS ?= x86_64-apple-darwin x86_64-unknown-linux-gnu x86_64-unknown-freebsd RUN ?= test test: - cargo test --all-features + cargo test --all-features --workspace # NOTE: Keep `RUSTFLAGS` and `RUSTDOCFLAGS` in sync to ensure the doc tests # compile correctly. @@ -27,14 +27,14 @@ test_sanitiser: @if [ -z $${SAN+x} ]; then echo "Required '\$$SAN' variable is not set" 1>&2; exit 1; fi RUSTFLAGS="-Z sanitizer=$$SAN -Z sanitizer-memory-track-origins" \ RUSTDOCFLAGS="-Z sanitizer=$$SAN -Z sanitizer-memory-track-origins" \ - cargo test -Z build-std --all-features --target $(RUSTUP_TARGET) + cargo test -Z build-std --all-features --workspace --target $(RUSTUP_TARGET) check: - cargo check --all-features --all-targets + cargo check --all-features --workspace --all-targets check_all_targets: $(TARGETS) $(TARGETS): - cargo check --all-features --all-targets --target $@ + cargo check --all-features --workspace --all-targets --target $@ # NOTE: when using this command you might want to change the `test` target to # only run a subset of the tests you're actively working on. @@ -47,7 +47,7 @@ dev: # multiple-crate-versions: socket2 is included twice? But `cargo tree` disagrees. clippy: lint lint: - cargo clippy --all-features -- \ + cargo clippy --all-features --workspace -- \ --deny clippy::all \ --deny clippy::correctness \ --deny clippy::style \ @@ -110,10 +110,10 @@ install_llvm_tools: rustup component add llvm-tools-preview doc: - cargo doc --all-features + cargo doc --all-features --workspace doc_private: - cargo doc --all-features --document-private-items + cargo doc --all-features --workspace --document-private-items clean: cargo clean diff --git a/http/Cargo.toml b/http/Cargo.toml new file mode 100644 index 000000000..d72cf2707 --- /dev/null +++ b/http/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "heph-http" +version = "0.1.0" +edition = "2018" + +[dependencies] +heph = { version = "0.3.0", path = "../", default-features = false } +httparse = { version = "1.4.0", default-features = false } +httpdate = { version = "1.0.0", default-features = false } +log = { version = "0.4.8", default-features = false } +itoa = { version = "0.4.7", default-features = false } + +[dev-dependencies] +# Enable logging panics via `std-logger`. +std-logger = { version = "0.4.0", default-features = false, features = ["log-panic", "nightly"] } + +[dev-dependencies.heph] +path = "../" +features = ["test"] diff --git a/http/LICENSE b/http/LICENSE new file mode 100644 index 000000000..1cc94c7c8 --- /dev/null +++ b/http/LICENSE @@ -0,0 +1,20 @@ +Copyright (C) 2021 Thomas de Zeeuw + + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/http/Makefile b/http/Makefile new file mode 120000 index 000000000..d0b0e8e00 --- /dev/null +++ b/http/Makefile @@ -0,0 +1 @@ +../Makefile \ No newline at end of file diff --git a/http/examples/my_ip.rs b/http/examples/my_ip.rs new file mode 100644 index 000000000..0a3a1509e --- /dev/null +++ b/http/examples/my_ip.rs @@ -0,0 +1,141 @@ +#![feature(never_type)] + +use std::borrow::Cow; +use std::io; +use std::net::SocketAddr; +use std::time::Duration; + +use heph::actor::{self, Actor, NewActor}; +use heph::net::TcpStream; +use heph::rt::{self, Runtime, ThreadLocal}; +use heph::spawn::options::{ActorOptions, Priority}; +use heph::supervisor::{Supervisor, SupervisorStrategy}; +use heph::timer::Deadline; +use heph_http::body::OneshotBody; +use heph_http::{self as http, Header, HeaderName, Headers, HttpServer, Method, StatusCode}; +use log::{debug, error, info, warn}; + +fn main() -> Result<(), rt::Error> { + std_logger::init(); + + let actor = http_actor as fn(_, _, _) -> _; + let address = "127.0.0.1:7890".parse().unwrap(); + let server = HttpServer::setup(address, conn_supervisor, actor, ActorOptions::default()) + .map_err(rt::Error::setup)?; + + let mut runtime = Runtime::setup().use_all_cores().build()?; + runtime.run_on_workers(move |mut runtime_ref| -> io::Result<()> { + let options = ActorOptions::default().with_priority(Priority::LOW); + let server_ref = runtime_ref.try_spawn_local(ServerSupervisor, server, (), options)?; + + runtime_ref.receive_signals(server_ref.try_map()); + Ok(()) + })?; + info!("listening on {}", address); + runtime.start() +} + +/// Our supervisor for the TCP server. 
+#[derive(Copy, Clone, Debug)] +struct ServerSupervisor; + +impl Supervisor for ServerSupervisor +where + NA: NewActor, + NA::Actor: Actor>, +{ + fn decide(&mut self, err: http::server::Error) -> SupervisorStrategy<()> { + use http::server::Error::*; + match err { + Accept(err) => { + error!("error accepting new connection: {}", err); + SupervisorStrategy::Restart(()) + } + NewActor(_) => unreachable!(), + } + } + + fn decide_on_restart_error(&mut self, err: io::Error) -> SupervisorStrategy<()> { + error!("error restarting the TCP server: {}", err); + SupervisorStrategy::Stop + } + + fn second_restart_error(&mut self, err: io::Error) { + error!("error restarting the actor a second time: {}", err); + } +} + +fn conn_supervisor(err: io::Error) -> SupervisorStrategy<(TcpStream, SocketAddr)> { + error!("error handling connection: {}", err); + SupervisorStrategy::Stop +} + +const READ_TIMEOUT: Duration = Duration::from_secs(10); +const ALIVE_TIMEOUT: Duration = Duration::from_secs(120); +const WRITE_TIMEOUT: Duration = Duration::from_secs(10); + +async fn http_actor( + mut ctx: actor::Context, + mut connection: http::Connection, + address: SocketAddr, +) -> io::Result<()> { + info!("accepted connection: source={}", address); + connection.set_nodelay(true)?; + + let mut read_timeout = READ_TIMEOUT; + let mut headers = Headers::EMPTY; + loop { + let fut = Deadline::after(&mut ctx, read_timeout, connection.next_request()); + let (code, body, should_close) = match fut.await? { + Ok(Some(request)) => { + info!("received request: {:?}: source={}", request, address); + if request.path() != "/" { + (StatusCode::NOT_FOUND, "Not found".into(), false) + } else if !matches!(request.method(), Method::Get | Method::Head) { + headers.add(Header::new(HeaderName::ALLOW, b"GET, HEAD")); + let body = "Method not allowed".into(); + (StatusCode::METHOD_NOT_ALLOWED, body, false) + } else if !request.body().is_empty() { + let body = Cow::from("Not expecting a body"); + (StatusCode::PAYLOAD_TOO_LARGE, body, true) + } else { + // This will allocate a new string which isn't the most + // efficient way to do this, but it's the easiest so we'll + // keep this for sake of example. + let body = Cow::from(address.ip().to_string()); + (StatusCode::OK, body, false) + } + } + // No more requests. + Ok(None) => return Ok(()), + Err(err) => { + warn!("error reading request: {}: source={}", err, address); + let code = err.proper_status_code(); + let body = Cow::from(format!("Bad request: {}", err)); + (code, body, err.should_close()) + } + }; + + if should_close { + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + } + + debug!( + "sending response: code={}, body='{}', source={}", + code, body, address + ); + let body = OneshotBody::new(body.as_bytes()); + let write_response = connection.respond(code, &headers, body); + Deadline::after(&mut ctx, WRITE_TIMEOUT, write_response).await?; + + if should_close { + warn!("closing connection: source={}", address); + return Ok(()); + } + + // Now that we've read a single request we can wait a little for the + // next one so that we can reuse the resources for the next request. + read_timeout = ALIVE_TIMEOUT; + headers.clear(); + } +} diff --git a/http/src/body.rs b/http/src/body.rs new file mode 100644 index 000000000..bc707a44f --- /dev/null +++ b/http/src/body.rs @@ -0,0 +1,447 @@ +//! Module with HTTP body related types. +//! +//! See the [`Body`] trait. 
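For illustration only, not part of the patch: a minimal sketch of how the body types introduced in this module are meant to be used, based on the `my_ip.rs` example above and the `heph_http::body` imports it uses.

use heph_http::body::{Body, BodyLength, EmptyBody, OneshotBody};

fn example_bodies() {
    // A fixed payload: the length is known up front, so it is sent in a
    // single payload (no chunked encoding).
    let pong = OneshotBody::new(b"pong");
    assert_eq!(pong.length(), BodyLength::Known(4));

    // No body at all; only the HTTP head is written to the stream.
    assert_eq!(EmptyBody.length(), BodyLength::Known(0));

    // Inside a server actor the body is handed to `Connection::respond`,
    // as in `examples/my_ip.rs`:
    // connection.respond(StatusCode::OK, &headers, pong).await?;
}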
+
+use std::io::{self, IoSlice};
+use std::marker::PhantomData;
+use std::num::NonZeroUsize;
+use std::stream::Stream;
+
+use heph::net::tcp::stream::{FileSend, SendAll, TcpStream};
+
+/// Trait that defines an HTTP body.
+///
+/// The trait can't be implemented outside of this crate and is implemented by
+/// the following types:
+///
+/// * [`EmptyBody`]: no/empty body.
+/// * [`OneshotBody`]: body consisting of a single slice of bytes (`&[u8]`).
+/// * [`StreamingBody`]: body that is streaming, with a known length.
+/// * [`ChunkedBody`]: body that is streaming, with an *un*known length. This
+///   uses HTTP chunked encoding to transfer the body.
+/// * [`FileBody`]: uses a file as body, sending its content using the
+///   `sendfile(2)` system call.
+pub trait Body<'a>: PrivateBody<'a> {
+    /// Length of the body, or the body will be chunked.
+    fn length(&self) -> BodyLength;
+}
+
+/// Length of a body.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum BodyLength {
+    /// Body length is known.
+    Known(usize),
+    /// Body length is unknown and the body will be transferred using chunked
+    /// encoding.
+    Chunked,
+}
+
+mod private {
+    use std::future::Future;
+    use std::io::{self, IoSlice};
+    use std::num::NonZeroUsize;
+    use std::pin::Pin;
+    use std::stream::Stream;
+    use std::task::{self, Poll};
+
+    use heph::net::tcp::stream::FileSend;
+    use heph::net::TcpStream;
+
+    /// Private extension of [`Body`].
+    ///
+    /// [`Body`]: super::Body
+    pub trait PrivateBody<'body> {
+        type WriteBody<'stream, 'head>: Future<Output = io::Result<()>>;
+
+        /// Write an HTTP message to `stream`.
+        ///
+        /// The `http_head` buffer contains the HTTP header (i.e. request/status
+        /// line and all headers), which must also be written to the `stream`.
+        fn write_message<'stream, 'head>(
+            self,
+            stream: &'stream mut TcpStream,
+            http_head: &'head [u8],
+        ) -> Self::WriteBody<'stream, 'head>
+        where
+            'body: 'head;
+    }
+
+    /// See [`OneshotBody`].
+    #[derive(Debug)]
+    pub struct SendOneshotBody<'s, 'b> {
+        pub(super) stream: &'s mut TcpStream,
+        // HTTP head and body.
+        pub(super) bufs: [IoSlice<'b>; 2],
+    }
+
+    impl<'s, 'b> Future for SendOneshotBody<'s, 'b> {
+        type Output = io::Result<()>;
+
+        fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
+            let SendOneshotBody { stream, bufs } = Pin::into_inner(self);
+            loop {
+                match stream.try_send_vectored(bufs) {
+                    Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
+                    Ok(n) => {
+                        let head_len = bufs[0].len();
+                        let body_len = bufs[1].len();
+                        if n >= head_len + body_len {
+                            // Written everything.
+                            return Poll::Ready(Ok(()));
+                        } else if n <= head_len {
+                            // Only written part of the head, advance the head
+                            // buffer.
+                            let _ = IoSlice::advance(&mut bufs[..1], n);
+                        } else {
+                            // Written entire head.
+                            bufs[0] = IoSlice::new(&[]);
+                            let _ = IoSlice::advance(&mut bufs[1..], n - head_len);
+                        }
+                    }
+                    Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
+                        return Poll::Pending
+                    }
+                    Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
+                    Err(err) => return Poll::Ready(Err(err)),
+                }
+            }
+        }
+    }
+
+    /// See [`StreamingBody`].
+    #[derive(Debug)]
+    pub struct SendStreamingBody<'s, 'h, 'b, B> {
+        pub(super) stream: &'s mut TcpStream,
+        pub(super) head: &'h [u8],
+        /// Bytes left to write from `body`, not counting the HTTP head.
+        pub(super) left: usize,
+        pub(super) body: B,
+        /// Slice of bytes from `body`.
+ pub(super) body_bytes: Option<&'b [u8]>, + } + + impl<'s, 'h, 'b, B> Future for SendStreamingBody<'s, 'h, 'b, B> + where + B: Stream>, + { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll { + // SAFETY: not moving `body: B`, ensuring it's still pinned. + #[rustfmt::skip] + let SendStreamingBody { stream, head, left, body, body_bytes } = unsafe { Pin::into_inner_unchecked(self) }; + let mut body = unsafe { Pin::new_unchecked(body) }; + + // Send the HTTP head first. + // TODO: try to use vectored I/O on first call. + while !head.is_empty() { + match stream.try_send(head) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => *head = &head[n..], + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + + while *left != 0 { + // We have bytes we need to send. + if let Some(bytes) = body_bytes.as_mut() { + // TODO: check `bytes.len()` <= `left`. + match stream.try_send(*bytes) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => { + *left -= n; + if n >= bytes.len() { + *body_bytes = None; + } else { + *bytes = &bytes[n..]; + continue; + } + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + + // Read some bytes from the `body` stream. + match body.as_mut().poll_next(ctx) { + Poll::Ready(Some(Ok(bytes))) => *body_bytes = Some(bytes), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), + Poll::Ready(None) => { + // NOTE: this shouldn't happend. + debug_assert!(*left == 0, "short body provided to `StreamingBody`"); + return Poll::Ready(Ok(())); + } + Poll::Pending => return Poll::Pending, + } + } + + Poll::Ready(Ok(())) + } + } + + /// See [`FileBody`]. + #[derive(Debug)] + pub struct SendFileBody<'s, 'h, 'f, F> { + pub(super) stream: &'s mut TcpStream, + pub(super) head: &'h [u8], + pub(super) file: &'f F, + pub(super) offset: usize, + pub(super) end: NonZeroUsize, + } + + impl<'s, 'h, 'f, F> Future for SendFileBody<'s, 'h, 'f, F> + where + F: FileSend, + { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll { + #[rustfmt::skip] + let SendFileBody { stream, head, file, offset, end } = Pin::into_inner(self); + + // Send the HTTP head first. + while !head.is_empty() { + match stream.try_send(head) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => *head = &head[n..], + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + + while end.get() > *offset { + let length = NonZeroUsize::new(end.get() - *offset); + match stream.try_send_file(*file, *offset, length) { + // All bytes were send. + Ok(0) => return Poll::Ready(Ok(())), + Ok(n) => *offset += n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + + Poll::Ready(Ok(())) + } + } +} + +pub(crate) use private::{PrivateBody, SendStreamingBody}; +use private::{SendFileBody, SendOneshotBody}; + +/// An empty body. 
+#[derive(Debug)] +pub struct EmptyBody; + +impl<'b> Body<'b> for EmptyBody { + fn length(&self) -> BodyLength { + BodyLength::Known(0) + } +} + +impl<'b> PrivateBody<'b> for EmptyBody { + type WriteBody<'s, 'h> = SendAll<'s, 'h>; + + fn write_message<'s, 'h>( + self, + stream: &'s mut TcpStream, + http_head: &'h [u8], + ) -> Self::WriteBody<'s, 'h> + where + 'b: 'h, + { + // Just need to write the HTTP head as we don't have a body. + stream.send_all(http_head) + } +} + +/// Body length and content is known in advance. Send in a single payload (i.e. +/// not chunked). +#[derive(Debug)] +pub struct OneshotBody<'b> { + bytes: &'b [u8], +} + +impl<'b> OneshotBody<'b> { + /// Create a new one-shot body. + pub const fn new(body: &'b [u8]) -> OneshotBody<'b> { + OneshotBody { bytes: body } + } +} + +impl<'b> Body<'b> for OneshotBody<'b> { + fn length(&self) -> BodyLength { + BodyLength::Known(self.bytes.len()) + } +} + +impl<'b> PrivateBody<'b> for OneshotBody<'b> { + type WriteBody<'s, 'h> = SendOneshotBody<'s, 'h>; + + fn write_message<'s, 'h>( + self, + stream: &'s mut TcpStream, + http_head: &'h [u8], + ) -> Self::WriteBody<'s, 'h> + where + 'b: 'h, + { + let head = IoSlice::new(http_head); + let body = IoSlice::new(self.bytes); + SendOneshotBody { + stream, + bufs: [head, body], + } + } +} + +impl<'b> From<&'b [u8]> for OneshotBody<'b> { + fn from(body: &'b [u8]) -> Self { + OneshotBody::new(body) + } +} + +impl<'b> From<&'b str> for OneshotBody<'b> { + fn from(body: &'b str) -> Self { + OneshotBody::new(body.as_bytes()) + } +} + +/// Streaming body with a known length. Send in a single payload (i.e. not +/// chunked). +#[derive(Debug)] +pub struct StreamingBody<'b, B> { + length: usize, + body: B, + _body_lifetime: PhantomData<&'b [u8]>, +} + +impl<'b, B> StreamingBody<'b, B> +where + B: Stream>, +{ + /// Use a [`Stream`] as HTTP body with a known length. + pub const fn new(length: usize, stream: B) -> StreamingBody<'b, B> { + StreamingBody { + length, + body: stream, + _body_lifetime: PhantomData, + } + } +} + +impl<'b, B> Body<'b> for StreamingBody<'b, B> +where + B: Stream>, +{ + fn length(&self) -> BodyLength { + BodyLength::Known(self.length) + } +} + +impl<'b, B> PrivateBody<'b> for StreamingBody<'b, B> +where + B: Stream>, +{ + type WriteBody<'s, 'h> = SendStreamingBody<'s, 'h, 'b, B>; + + fn write_message<'s, 'h>( + self, + stream: &'s mut TcpStream, + head: &'h [u8], + ) -> Self::WriteBody<'s, 'h> + where + 'b: 'h, + { + SendStreamingBody { + stream, + body: self.body, + head, + left: self.length, + body_bytes: None, + } + } +} + +/// Streaming body with an unknown length. Send in multiple chunks. +#[derive(Debug)] +pub struct ChunkedBody<'b, B> { + stream: B, + _body_lifetime: PhantomData<&'b [u8]>, +} + +// TODO: implement `Body` for `ChunkedBody`. + +/// Body that sends the entire file `F`. +#[derive(Debug)] +pub struct FileBody<'f, F> { + file: &'f F, + /// Start offset into the `file`. + offset: usize, + /// Length of the file, or the maximum number of bytes to send (minus + /// `offset`). + /// Always: `end >= offset`. + end: NonZeroUsize, +} + +impl<'f, F> FileBody<'f, F> +where + F: FileSend, +{ + /// Use a file as HTTP body. + /// + /// This uses the bytes `offset..end` from `file` as HTTP body and sends + /// them using `sendfile(2)` (using [`TcpStream::send_file`]). 
+    pub const fn new(file: &'f F, offset: usize, end: NonZeroUsize) -> FileBody<'f, F> {
+        debug_assert!(end.get() >= offset);
+        FileBody { file, offset, end }
+    }
+}
+
+impl<'f, F> Body<'f> for FileBody<'f, F>
+where
+    F: FileSend,
+{
+    fn length(&self) -> BodyLength {
+        // NOTE: per the comment on `end`: `end >= offset`, so this can't
+        // underflow.
+        BodyLength::Known(self.end.get() - self.offset)
+    }
+}
+
+impl<'f, F> PrivateBody<'f> for FileBody<'f, F>
+where
+    F: FileSend,
+{
+    type WriteBody<'s, 'h> = SendFileBody<'s, 'h, 'f, F>;
+
+    fn write_message<'s, 'h>(
+        self,
+        stream: &'s mut TcpStream,
+        head: &'h [u8],
+    ) -> Self::WriteBody<'s, 'h>
+    where
+        'f: 'h,
+    {
+        SendFileBody {
+            stream,
+            head,
+            file: self.file,
+            offset: self.offset,
+            end: self.end,
+        }
+    }
+}
diff --git a/http/src/client.rs b/http/src/client.rs
new file mode 100644
index 000000000..e197b059a
--- /dev/null
+++ b/http/src/client.rs
@@ -0,0 +1,732 @@
+//! Module with the HTTP client implementation.
+
+// FIXME: remove.
+#![allow(missing_docs)]
+
+use std::cmp::min;
+use std::future::Future;
+use std::net::SocketAddr;
+use std::pin::Pin;
+use std::task::{self, Poll};
+use std::{fmt, io};
+
+use heph::net::tcp::stream::{self, TcpStream};
+use heph::{actor, rt};
+
+use crate::body::{BodyLength, EmptyBody};
+use crate::header::{FromHeaderValue, HeaderName, Headers};
+use crate::{
+    map_version_byte, trim_ws, Method, Response, StatusCode, BUF_SIZE, MAX_HEADERS, MAX_HEAD_SIZE,
+    MIN_READ_SIZE,
+};
+
+#[derive(Debug)]
+pub struct Client {
+    stream: TcpStream,
+    buf: Vec<u8>,
+    /// Number of bytes of `buf` that are already parsed.
+    /// NOTE: this may be larger than `buf.len()`, in which case a `Body` was
+    /// dropped without reading it entirely.
+    parsed_bytes: usize,
+}
+
+impl Client {
+    /// Create a new HTTP client, connected to `address`.
+    pub fn connect(
+        ctx: &mut actor::Context,
+        address: SocketAddr,
+    ) -> io::Result
+    where
+        RT: rt::Access,
+    {
+        TcpStream::connect(ctx, address).map(|connect| Connect { connect })
+    }
+
+    /// Send a GET request.
+    ///
+    /// # Notes
+    ///
+    /// Any [`ResponseError`] is turned into an [`io::Error`]. If you want to
+    /// handle the `ResponseError`s separately use [`Client::request`].
+    pub async fn get<'c, 'p>(&'c mut self, path: &'p str) -> io::Result>> {
+        let res = self
+            .request(Method::Get, path, &Headers::EMPTY, EmptyBody)
+            .await;
+        match res {
+            Ok(Ok(response)) => Ok(response),
+            Ok(Err(err)) => Err(err.into()),
+            Err(err) => Err(err),
+        }
+    }
+
+    /// Make a [`Request`] and wait (non-blocking) for a [`Response`].
+    ///
+    /// [`Request`]: crate::Request
+    ///
+    /// # Notes
+    ///
+    /// This always uses HTTP/1.1 to make the requests.
+    ///
+    /// If the server doesn't respond this returns an [`io::Error`] with
+    /// [`io::ErrorKind::UnexpectedEof`].
+ pub async fn request<'c, 'b, B>( + &'c mut self, + method: Method, + path: &str, + headers: &Headers, + body: B, + ) -> io::Result>, ResponseError>> + where + B: crate::Body<'b>, + { + self.send_request(method, path, headers, body).await?; + match self.read_response(method).await { + Ok(Ok(Some(request))) => Ok(Ok(request)), + Ok(Ok(None)) => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "no HTTP response", + )), + Ok(Err(err)) => Ok(Err(err)), + Err(err) => Err(err), + } + } + + pub async fn send_request<'b, B>( + &mut self, + method: Method, + path: &str, + headers: &Headers, + body: B, + ) -> io::Result<()> + where + B: crate::Body<'b>, + { + // Clear bytes from the previous request, keeping the bytes of the + // response. + self.clear_buffer(); + let ignore_end = self.buf.len(); + + // Request line. + self.buf.extend_from_slice(method.as_str().as_bytes()); + self.buf.push(b' '); + self.buf.extend_from_slice(path.as_bytes()); + self.buf.extend_from_slice(b" HTTP/1.1\r\n"); + + // Headers. + let mut set_user_agent_header = false; + let mut set_content_length_header = false; + let mut set_transfer_encoding_header = false; + for header in headers.iter() { + let name = header.name(); + // Field-name: + self.buf.extend_from_slice(name.as_ref().as_bytes()); + // NOTE: spacing after the colon (`:`) is optional. + self.buf.extend_from_slice(b": "); + // Append the header's value. + // NOTE: `header.value` shouldn't contain CRLF (`\r\n`). + self.buf.extend_from_slice(header.value()); + self.buf.extend_from_slice(b"\r\n"); + + if name == &HeaderName::USER_AGENT { + set_user_agent_header = true; + } else if name == &HeaderName::CONTENT_LENGTH { + set_content_length_header = true; + } else if name == &HeaderName::TRANSFER_ENCODING { + set_transfer_encoding_header = true; + } + } + + /* TODO: set "Host" header. + // Provide the "Host" header if the user didn't. + if !set_host_header { + write!(&mut self.buf, "Host: {}\r\n", self.host).unwrap(); + } + */ + + // Provide the "User-Agent" header if the user didn't. + if !set_user_agent_header { + self.buf.extend_from_slice( + concat!("User-Agent: Heph-HTTP/", env!("CARGO_PKG_VERSION"), "\r\n").as_bytes(), + ); + } + + if !set_content_length_header && !set_transfer_encoding_header { + match body.length() { + BodyLength::Known(0) => {} // No need for a "Content-Length" header. + BodyLength::Known(length) => { + let mut itoa_buf = itoa::Buffer::new(); + self.buf.extend_from_slice(b"Content-Length: "); + self.buf + .extend_from_slice(itoa_buf.format(length).as_bytes()); + self.buf.extend_from_slice(b"\r\n"); + } + BodyLength::Chunked => { + self.buf + .extend_from_slice(b"Transfer-Encoding: chunked\r\n"); + } + } + } + + // End of the HTTP head. + self.buf.extend_from_slice(b"\r\n"); + + // Write the request to the stream. + let http_head = &self.buf[ignore_end..]; + body.write_message(&mut self.stream, http_head).await?; + + // Remove the request from the buffer. + self.buf.truncate(ignore_end); + Ok(()) + } + + pub async fn read_response<'a>( + &'a mut self, + request_method: Method, + ) -> io::Result>>, ResponseError>> { + let mut too_short = 0; + loop { + // In case of pipelined responses it could be that while reading a + // previous response's body it partially read the head of the next + // (this) response. To handle this we first attempt to parse the + // response if we have more than zero bytes (of the next response) + // in the first iteration of the loop. 
+ while self.parsed_bytes >= self.buf.len() || self.buf.len() <= too_short { + // While we didn't read the entire previous response body, or + // while we have less than `too_short` bytes we try to receive + // some more bytes. + + self.clear_buffer(); + self.buf.reserve(MIN_READ_SIZE); + if self.stream.recv(&mut self.buf).await? == 0 { + return if self.buf.is_empty() { + // Read the entire stream, so we're done. + Ok(Ok(None)) + } else { + // Couldn't read any more bytes, but we still have bytes + // in the buffer. This means it contains a partial + // response. + Ok(Err(ResponseError::IncompleteResponse)) + }; + } + } + + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut response = httparse::Response::new(&mut headers); + // SAFETY: because we received until at least `self.parsed_bytes >= + // self.buf.len()` above, we can safely slice the buffer.. + match response.parse(&self.buf[self.parsed_bytes..]) { + Ok(httparse::Status::Complete(head_length)) => { + self.parsed_bytes += head_length; + + // SAFETY: all these unwraps are safe because `parse` above + // ensures there all `Some`. + let version = map_version_byte(response.version.unwrap()); + let status = StatusCode(response.code.unwrap()); + // NOTE: don't care about the reason. + + // RFC 7230 section 3.3.3 Message Body Length. + let mut body_length: Option = None; + let res = Headers::from_httparse_headers(response.headers, |name, value| { + if *name == HeaderName::CONTENT_LENGTH { + // RFC 7230 section 3.3.3 point 4: + // > If a message is received without + // > Transfer-Encoding and with either multiple + // > Content-Length header fields having differing + // > field-values or a single Content-Length header + // > field having an invalid value, then the message + // > framing is invalid and the recipient MUST treat + // > it as an unrecoverable error. [..] If this is a + // > response message received by a user agent, the + // > user agent MUST close the connection to the + // > server and discard the received response. + if let Ok(length) = FromHeaderValue::from_bytes(value) { + match body_length.as_mut() { + Some(ResponseBodyLength::Known(body_length)) + if *body_length == length => {} + Some(ResponseBodyLength::Known(_)) => { + return Err(ResponseError::DifferentContentLengths) + } + Some( + ResponseBodyLength::Chunked | ResponseBodyLength::ReadToEnd, + ) => { + return Err(ResponseError::ContentLengthAndTransferEncoding) + } + // RFC 7230 section 3.3.3 point 5: + // > If a valid Content-Length header field + // > is present without Transfer-Encoding, + // > its decimal value defines the expected + // > message body length in octets. + None => body_length = Some(ResponseBodyLength::Known(length)), + } + } else { + return Err(ResponseError::InvalidContentLength); + } + } else if *name == HeaderName::TRANSFER_ENCODING { + let mut encodings = value.split(|b| *b == b',').peekable(); + while let Some(encoding) = encodings.next() { + match trim_ws(encoding) { + b"chunked" => { + // RFC 7230 section 3.3.3 point 3: + // > If a message is received with both + // > a Transfer-Encoding and a + // > Content-Length header field, the + // > Transfer-Encoding overrides the + // > Content-Length. Such a message + // > might indicate an attempt to + // > perform request smuggling (Section + // > 9.5) or response splitting (Section + // > 9.4) and ought to be handled as an + // > error. 
+ if body_length.is_some() { + return Err( + ResponseError::ContentLengthAndTransferEncoding, + ); + } + + // RFC 7230 section 3.3.3 point 3: + // > If a Transfer-Encoding header field + // > is present in a response and the + // > chunked transfer coding is not the + // > final encoding, the message body + // > length is determined by reading the + // > connection until it is closed by + // > the server. + if encodings.peek().is_some() { + body_length = Some(ResponseBodyLength::ReadToEnd) + } else { + body_length = Some(ResponseBodyLength::Chunked); + } + } + b"identity" => {} // No changes. + // TODO: support "compress", "deflate" and + // "gzip". + _ => return Err(ResponseError::UnsupportedTransferEncoding), + } + } + } + Ok(()) + }); + let headers = match res { + Ok(headers) => headers, + Err(err) => return Ok(Err(err)), + }; + + let kind = match body_length { + // RFC 7230 section 3.3.3 point 2: + // > Any 2xx (Successful) response to a CONNECT request + // > implies that the connection will become a tunnel + // > immediately after the empty line that concludes the + // > header fields. A client MUST ignore any + // > Content-Length or Transfer-Encoding header fields + // > received in such a message. + _ if matches!(request_method, Method::Connect) + && status.is_successful() => + { + BodyKind::Known { left: 0 } + } + Some(ResponseBodyLength::Known(left)) => BodyKind::Known { left }, + Some(ResponseBodyLength::Chunked) => { + #[allow(clippy::cast_possible_truncation)] // For truncate below. + match httparse::parse_chunk_size(&self.buf[self.parsed_bytes..]) { + Ok(httparse::Status::Complete((idx, chunk_size))) => { + self.parsed_bytes += idx; + BodyKind::Chunked { + // FIXME: add check here. It's fine on + // 64 bit (only currently supported). + left_in_chunk: chunk_size as usize, + read_complete: chunk_size == 0, + } + } + Ok(httparse::Status::Partial) => BodyKind::Chunked { + left_in_chunk: 0, + read_complete: false, + }, + Err(_) => return Ok(Err(ResponseError::InvalidChunkSize)), + } + } + Some(ResponseBodyLength::ReadToEnd) => BodyKind::Unknown, + // RFC 7230 section 3.3.3 point 1: + // > Any response to a HEAD request and any response + // > with a 1xx (Informational), 204 (No Content), or + // > 304 (Not Modified) status code is always terminated + // > by the first empty line after the header fields, + // > regardless of the header fields present in the + // > message, and thus cannot contain a message body. + // NOTE: we don't follow this strictly as a server might + // not be implemented correctly, in which case we follow + // the "Content-Length"/"Transfer-Encoding" header + // instead (above). + None if !request_method.expects_body() || !status.includes_body() => { + BodyKind::Known { left: 0 } + } + // RFC 7230 section 3.3.3 point 7: + // > Otherwise, this is a response message without a + // > declared message body length, so the message body + // > length is determined by the number of octets + // > received prior to the server closing the + // > connection. + None => BodyKind::Unknown, + }; + let body = Body { client: self, kind }; + return Ok(Ok(Some(Response::new(version, status, headers, body)))); + } + Ok(httparse::Status::Partial) => { + // Buffer doesn't include the entire response head, try + // reading more bytes (in the next iteration). 
+ too_short = self.buf.len(); + if too_short >= MAX_HEAD_SIZE { + return Ok(Err(ResponseError::HeadTooLarge)); + } + + continue; + } + Err(err) => return Ok(Err(ResponseError::from_httparse(err))), + } + } + } + + async fn read_chunk( + &mut self, + // Fields of `BodyKind::Chunked`: + left_in_chunk: &mut usize, + read_complete: &mut bool, + ) -> io::Result<()> { + loop { + match httparse::parse_chunk_size(&self.buf[self.parsed_bytes..]) { + #[allow(clippy::cast_possible_truncation)] // For truncate below. + Ok(httparse::Status::Complete((idx, chunk_size))) => { + self.parsed_bytes += idx; + if chunk_size == 0 { + *read_complete = true; + } + // FIXME: add check here. It's fine on 64 bit (only currently + // supported). + *left_in_chunk = chunk_size as usize; + return Ok(()); + } + Ok(httparse::Status::Partial) => {} // Read some more data below. + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk size", + )) + } + } + + // Ensure we have space in the buffer to read into. + self.clear_buffer(); + self.buf.reserve(MIN_READ_SIZE); + + if self.stream.recv(&mut self.buf).await? == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + } + + /// Clear parsed request(s) from the buffer. + fn clear_buffer(&mut self) { + let buf_len = self.buf.len(); + if self.parsed_bytes >= buf_len { + // Parsed all bytes in the buffer, so we can clear it. + self.buf.clear(); + self.parsed_bytes -= buf_len; + } + + // TODO: move bytes to the start. + } +} + +enum ResponseBodyLength { + /// Body length is known. + Known(usize), + /// Body length is unknown and the body will be transfered using chunked + /// encoding. + Chunked, + /// Body length is unknown, but the response is not chunked. Read until the + /// connection is closed. + ReadToEnd, +} + +/// [`Future`] behind [`Client::connect`]. +#[derive(Debug)] +pub struct Connect { + connect: stream::Connect, +} + +impl Future for Connect { + type Output = io::Result; + + #[track_caller] + fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll { + match Pin::new(&mut self.connect).poll(ctx) { + Poll::Ready(Ok(mut stream)) => { + stream.set_nodelay(true)?; + Poll::Ready(Ok(Client { + stream, + buf: Vec::with_capacity(BUF_SIZE), + parsed_bytes: 0, + })) + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + } +} + +#[derive(Debug)] +pub struct Body<'c> { + client: &'c mut Client, + kind: BodyKind, +} + +#[derive(Debug)] +enum BodyKind { + /// Known body length. + Known { + /// Number of unread (by the user) bytes. + left: usize, + }, + /// Chunked transfer encoding. + Chunked { + /// Number of unread (by the user) bytes in this chunk. + left_in_chunk: usize, + /// Read all chunks. + read_complete: bool, + }, + /// Body length is not known, read the body until the server closes the + /// connection. + Unknown, +} + +impl<'c> Body<'c> { + /* + /// Returns `true` if the body is completely read (or was empty to begin + /// with). + /// + /// # Notes + /// + /// This can return `false` for empty bodies using chunked encoding if not + /// enough bytes have been read yet. Using chunked encoding we don't know + /// the length upfront as it it's determined by reading the length of each + /// chunk. If the send request only contained the HTTP head (i.e. no body) + /// and uses chunked encoding this would return `false`, as body length is + /// unknown and thus not empty. 
However if the body would then send a single + /// empty chunk (signaling the end of the body), this would return `true` as + /// it turns out the body is indeed empty. + pub fn is_empty(&self) -> bool { + match self.kind { + BodyKind::Known { left } => left == 0, + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => read_complete && left_in_chunk == 0, + } + } + + /// Returns `true` if the body is chunked. + pub fn is_chunked(&self) -> bool { + matches!(self.kind, BodyKind::Chunked { .. }) + } + */ + + /* + TODO: RFC 7230 section 3.3.3 point 5: + [..] If the sender closes the connection or the recipient times out + before the indicated number of octets are received, the recipient MUST + consider the message to be incomplete and close the connection. + */ + + pub async fn read_all(&mut self, buf: &mut Vec, limit: usize) -> io::Result<()> { + let mut total = 0; + loop { + // Copy bytes in our buffer. + let bytes = self.buf_bytes(); + let len = bytes.len(); + if limit < total + len { + return Err(io::Error::new(io::ErrorKind::Other, "body too large")); + } + + buf.extend_from_slice(bytes); + self.processed(len); + total += len; + + match &mut self.kind { + // Read all the bytes from the body. + BodyKind::Known { left: 0 } => return Ok(()), + // Read all the bytes in the chunk, so need to read another + // chunk. + BodyKind::Chunked { + left_in_chunk, + read_complete, + } if *left_in_chunk == 0 => { + if *read_complete { + return Ok(()); + } + + self.client.read_chunk(left_in_chunk, read_complete).await?; + // Copy read bytes again. + continue; + } + // Continue to reading below. + BodyKind::Known { .. } | BodyKind::Chunked { .. } | BodyKind::Unknown => break, + } + } + + loop { + // Limit the read until the end of the chunk/body. + let chunk_len = match self.kind { + BodyKind::Known { left } => Some(left), + BodyKind::Chunked { left_in_chunk, .. } => Some(left_in_chunk), + BodyKind::Unknown => None, + }; + + if let Some(chunk_len) = chunk_len { + if chunk_len == 0 { + return Ok(()); + } else if total + chunk_len > limit { + return Err(io::Error::new(io::ErrorKind::Other, "body too large")); + } + } + + let capacity = chunk_len + .unwrap_or_else(|| min(MIN_READ_SIZE, limit.saturating_sub(buf.capacity()))); + (&mut *buf).reserve(capacity); + if let Some(chunk_len) = chunk_len { + // FIXME: doesn't deal with chunked bodies. + return self.client.stream.recv_n(&mut *buf, chunk_len).await; + } else { + let n = self.client.stream.recv(&mut *buf).await?; + if n == 0 { + return Ok(()); + } + total += n; + if total > limit { + return Err(io::Error::new(io::ErrorKind::Other, "body too large")); + } + } + } + } + + /// Returns the bytes currently in the buffer. + /// + /// This is limited to the bytes of this request/chunk, i.e. it doesn't + /// contain the next request/chunk. + fn buf_bytes(&self) -> &[u8] { + let bytes = &self.client.buf[self.client.parsed_bytes..]; + match self.kind { + BodyKind::Known { left } + | BodyKind::Chunked { + left_in_chunk: left, + .. + } if bytes.len() > left => &bytes[..left], + _ => bytes, + } + } + + /// Mark `n` bytes are processed. + fn processed(&mut self, n: usize) { + // TODO: should this be `unsafe`? We don't do underflow checks... + match &mut self.kind { + BodyKind::Known { left } => *left -= n, + BodyKind::Chunked { left_in_chunk, .. } => *left_in_chunk -= n, + BodyKind::Unknown => {} + } + self.client.parsed_bytes += n; + } +} + +// FIXME: remove body from `Client` if it's dropped before it's fully read. + +/// Error parsing HTTP response. 
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum ResponseError {
+    /// Missing part of response.
+    IncompleteResponse,
+    /// HTTP Head (start line and headers) is too large.
+    ///
+    /// Limit is defined by [`MAX_HEAD_SIZE`].
+    HeadTooLarge,
+    /// Value in the "Content-Length" header is invalid.
+    InvalidContentLength,
+    /// Multiple "Content-Length" headers were present with differing values.
+    DifferentContentLengths,
+    /// Invalid byte in header name.
+    InvalidHeaderName,
+    /// Invalid byte in header value.
+    InvalidHeaderValue,
+    /// Number of headers sent in the request is larger than [`MAX_HEADERS`].
+    TooManyHeaders,
+    /// Unsupported "Transfer-Encoding" header.
+    UnsupportedTransferEncoding,
+    /// Response contains both "Content-Length" and "Transfer-Encoding" headers.
+    ///
+    /// An attacker might attempt to "smuggle a request" ("HTTP Response
+    /// Smuggling", Linhart et al., June 2005) or "split a response" ("Divide
+    /// and Conquer - HTTP Response Splitting, Web Cache Poisoning Attacks, and
+    /// Related Topics", Klein, March 2004). RFC 7230 (see section 3.3.3 point
+    /// 3) says that this "ought to be handled as an error", and so we do.
+    ContentLengthAndTransferEncoding,
+    /// Invalid byte in new line.
+    InvalidNewLine,
+    /// Invalid byte in HTTP version.
+    InvalidVersion,
+    /// Invalid byte in status code.
+    InvalidStatus,
+    /// Chunk size is invalid.
+    InvalidChunkSize,
+}
+
+impl ResponseError {
+    /// Returns `true` if the connection should be closed based on the error
+    /// (after sending an error response).
+    pub const fn should_close(self) -> bool {
+        // Currently all errors are fatal for the connection.
+        true
+    }
+
+    fn from_httparse(err: httparse::Error) -> ResponseError {
+        use httparse::Error::*;
+        match err {
+            HeaderName => ResponseError::InvalidHeaderName,
+            HeaderValue => ResponseError::InvalidHeaderValue,
+            Token => unreachable!(),
+            NewLine => ResponseError::InvalidNewLine,
+            Version => ResponseError::InvalidVersion,
+            TooManyHeaders => ResponseError::TooManyHeaders,
+            Status => ResponseError::InvalidStatus,
+        }
+    }
+
+    fn as_str(self) -> &'static str {
+        use ResponseError::*;
+        match self {
+            IncompleteResponse => "incomplete response",
+            HeadTooLarge => "response head too large",
+            InvalidContentLength => "invalid response Content-Length header",
+            DifferentContentLengths => "response has different Content-Length headers",
+            InvalidHeaderName => "invalid response header name",
+            InvalidHeaderValue => "invalid response header value",
+            TooManyHeaders => "too many response headers",
+            UnsupportedTransferEncoding => "response has unsupported Transfer-Encoding header",
+            ContentLengthAndTransferEncoding => {
+                "response contained both Content-Length and Transfer-Encoding headers"
+            }
+            InvalidNewLine => "invalid response syntax",
+            InvalidVersion => "invalid HTTP response version",
+            InvalidStatus => "invalid HTTP response status",
+            InvalidChunkSize => "invalid response chunk size",
+        }
+    }
+}
+
+impl From<ResponseError> for io::Error {
+    fn from(err: ResponseError) -> io::Error {
+        io::Error::new(io::ErrorKind::InvalidData, err.as_str())
+    }
+}
+
+impl fmt::Display for ResponseError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
diff --git a/http/src/header.rs b/http/src/header.rs
new file mode 100644
index 000000000..b77a4e830
--- /dev/null
+++ b/http/src/header.rs
@@ -0,0 +1,1036 @@
+//! Module with HTTP header related types.
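For illustration only, not part of the patch: a short sketch of how the `Headers` and `Header` types added in this file are used, relying on the crate-root re-exports seen in `examples/my_ip.rs`.

use heph_http::{Header, HeaderName, Headers};

fn example_headers() {
    // Start empty and append headers; duplicates are not checked for.
    let mut headers = Headers::EMPTY;
    headers.add(Header::new(HeaderName::ALLOW, b"GET, HEAD"));
    headers.add(Header::new(HeaderName::CONTENT_TYPE, b"text/plain"));
    assert_eq!(headers.len(), 2);

    // Look up a header value by name (names are stored lower case).
    assert_eq!(headers.get_value(&HeaderName::ALLOW), Some(&b"GET, HEAD"[..]));

    // Clear the list so the allocation can be reused, as the example server
    // does between requests.
    headers.clear();
    assert!(headers.is_empty());
}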
+ +use std::borrow::Cow; +use std::convert::AsRef; +use std::iter::{FromIterator, FusedIterator}; +use std::time::SystemTime; +use std::{fmt, str}; + +use httpdate::parse_http_date; + +use crate::{cmp_lower_case, is_lower_case}; + +/// List of headers. +/// +/// A complete list can be found at the "Message Headers" registry: +/// +pub struct Headers { + /// All values appended in a single allocation. + values: Vec, + /// All parts of the headers. + parts: Vec, +} + +struct HeaderPart { + name: HeaderName<'static>, + /// Indices into `Headers.values`. + start: usize, + end: usize, +} + +impl Headers { + /// Empty list of headers. + pub const EMPTY: Headers = Headers { + values: Vec::new(), + parts: Vec::new(), + }; + + /// Creates new `Headers` from `headers`. + /// + /// Calls `F` for each header. + pub(crate) fn from_httparse_headers( + raw_headers: &[httparse::Header<'_>], + mut f: F, + ) -> Result + where + F: FnMut(&HeaderName<'_>, &[u8]) -> Result<(), E>, + { + let values_len = raw_headers.iter().map(|h| h.value.len()).sum(); + let mut headers = Headers { + values: Vec::with_capacity(values_len), + parts: Vec::with_capacity(raw_headers.len()), + }; + for header in raw_headers { + let name = HeaderName::from_str(header.name); + let value = header.value; + f(&name, value)?; + headers._add(name, value); + } + Ok(headers) + } + + /// Returns the number of headers. + pub fn len(&self) -> usize { + self.parts.len() + } + + /// Returns `true` if this is empty. + pub fn is_empty(&self) -> bool { + self.parts.is_empty() + } + + /// Clear the headers. + /// + /// Removes all headers from the list. + pub fn clear(&mut self) { + self.parts.clear(); + self.values.clear(); + } + + /// Add a new `header`. + /// + /// # Notes + /// + /// This doesn't check for duplicate headers, it just adds it to the list of + /// headers. + pub fn add(&mut self, header: Header<'static, '_>) { + self._add(header.name, header.value) + } + + fn _add(&mut self, name: HeaderName<'static>, value: &[u8]) { + let start = self.values.len(); + self.values.extend_from_slice(value); + let end = self.values.len(); + self.parts.push(HeaderPart { name, start, end }); + } + + /// Get the header with `name`, if any. + /// + /// # Notes + /// + /// If all you need is the header value you can use [`Headers::get_value`]. + pub fn get(&self, name: &HeaderName<'_>) -> Option> { + for part in &self.parts { + if part.name == *name { + return Some(Header { + name: part.name.borrow(), + value: &self.values[part.start..part.end], + }); + } + } + None + } + + /// Get the header's value with `name`, if any. + pub fn get_value<'a>(&'a self, name: &HeaderName<'_>) -> Option<&'a [u8]> { + for part in &self.parts { + if part.name == *name { + return Some(&self.values[part.start..part.end]); + } + } + None + } + + // TODO: remove header? + + /// Returns an iterator over all headers. + /// + /// The order is unspecified. 
+ pub const fn iter<'a>(&'a self) -> Iter<'a> { + Iter { + headers: self, + pos: 0, + } + } +} + +impl From> for Headers { + fn from(header: Header<'static, '_>) -> Headers { + Headers { + values: header.value.to_vec(), + parts: vec![HeaderPart { + name: header.name, + start: 0, + end: header.value.len(), + }], + } + } +} + +impl From<[Header<'static, '_>; N]> for Headers { + fn from(raw_headers: [Header<'static, '_>; N]) -> Headers { + let values_len = raw_headers.iter().map(|h| h.value.len()).sum(); + let mut headers = Headers { + values: Vec::with_capacity(values_len), + parts: Vec::with_capacity(raw_headers.len()), + }; + for header in raw_headers { + headers._add(header.name.clone(), header.value); + } + headers + } +} + +impl From<&[Header<'static, '_>]> for Headers { + fn from(raw_headers: &[Header<'static, '_>]) -> Headers { + let values_len = raw_headers.iter().map(|h| h.value.len()).sum(); + let mut headers = Headers { + values: Vec::with_capacity(values_len), + parts: Vec::with_capacity(raw_headers.len()), + }; + for header in raw_headers { + headers._add(header.name.clone(), header.value); + } + headers + } +} + +impl<'v> FromIterator> for Headers { + fn from_iter(iter: I) -> Headers + where + I: IntoIterator>, + { + let mut headers = Headers::EMPTY; + headers.extend(iter); + headers + } +} + +impl<'v> Extend> for Headers { + fn extend(&mut self, iter: I) + where + I: IntoIterator>, + { + let iter = iter.into_iter(); + let (iter_len, _) = iter.size_hint(); + // Make a guess of 10 bytes per header value on average. + self.values.reserve(iter_len * 10); + self.parts.reserve(iter_len); + for header in iter { + self._add(header.name, header.value); + } + } +} + +impl fmt::Debug for Headers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_map(); + for part in &self.parts { + let value = &self.values[part.start..part.end]; + if let Ok(str) = std::str::from_utf8(value) { + let _ = f.entry(&part.name, &str); + } else { + let _ = f.entry(&part.name, &value); + } + } + f.finish() + } +} + +/// Iterator for [`Headers`], see [`Headers::iter`]. +#[derive(Debug)] +pub struct Iter<'a> { + headers: &'a Headers, + pos: usize, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Header<'a, 'a>; + + fn next(&mut self) -> Option { + self.headers.parts.get(self.pos).map(|part| { + let header = Header { + name: part.name.borrow(), + value: &self.headers.values[part.start..part.end], + }; + self.pos += 1; + header + }) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } + + fn count(self) -> usize { + self.len() + } +} + +impl<'a> ExactSizeIterator for Iter<'a> { + fn len(&self) -> usize { + self.headers.len() - self.pos + } +} + +impl<'a> FusedIterator for Iter<'a> {} + +/// HTTP header. +/// +/// RFC 7230 section 3.2. +#[derive(Clone)] +pub struct Header<'n, 'v> { + name: HeaderName<'n>, + value: &'v [u8], +} + +impl<'n, 'v> Header<'n, 'v> { + /// Create a new `Header`. + /// + /// # Notes + /// + /// `value` MUST NOT contain `\r\n`. + pub const fn new(name: HeaderName<'n>, value: &'v [u8]) -> Header<'n, 'v> { + debug_assert!(no_crlf(value), "header value contains CRLF ('\\r\\n')"); + Header { name, value } + } + + /// Returns the name of the header. + pub const fn name(&self) -> &HeaderName<'n> { + &self.name + } + + /// Returns the value of the header. + pub const fn value(&self) -> &'v [u8] { + self.value + } + + /// Parse the value of the header using `T`'s [`FromHeaderValue`] + /// implementation. 
+ pub fn parse(&self) -> Result + where + T: FromHeaderValue<'v>, + { + FromHeaderValue::from_bytes(self.value) + } +} + +/// Returns `true` if `value` does not contain `\r\n`. +const fn no_crlf(value: &[u8]) -> bool { + let mut i = 1; + while i < value.len() { + if value[i - 1] == b'\r' && value[i] == b'\n' { + return false; + } + i += 1; + } + true +} + +impl<'n, 'v> fmt::Debug for Header<'n, 'v> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_struct("Header"); + let _ = f.field("name", &self.name); + if let Ok(str) = std::str::from_utf8(self.value) { + let _ = f.field("value", &str); + } else { + let _ = f.field("value", &self.value); + } + f.finish() + } +} + +/// HTTP header name. +#[derive(Clone, PartialEq, Eq)] +pub struct HeaderName<'a> { + /// The value MUST be lower case. + inner: Cow<'a, str>, +} + +/// Macro to create [`Name`] constants. +macro_rules! known_headers { + ($( + $length: tt: [ + $( $(#[$meta: meta])* ( $const_name: ident, $http_name: expr ) $(,)* ),+ + ], + )+) => { + $($( + $( #[$meta] )* + pub const $const_name: HeaderName<'static> = HeaderName::from_lowercase($http_name); + )+)+ + + /// Create a new HTTP `HeaderName`. + /// + /// # Notes + /// + /// If `name` is static prefer to use [`HeaderName::from_lowercase`]. + #[allow(clippy::should_implement_trait)] + pub fn from_str(name: &str) -> HeaderName<'static> { + // This first matches on the length of the `name`, then does a + // case-insensitive compare of the name with all known headers with + // the same length, returning a static version if a match is found. + match name.len() { + $( + $length => { + $( + if cmp_lower_case($http_name, name) { + return HeaderName::$const_name; + } + )+ + } + )+ + _ => {} + } + // If it's not a known header return a custom (heap-allocated) + // header name. + HeaderName::from(name.to_string()) + } + } +} + +impl<'n> HeaderName<'n> { + // NOTE: these are automatically generated by the `parse_headers.bash` + // script. + // NOTE: we adding here also add to the + // `functional::header::from_str_known_headers` test. 
+ known_headers!( + 2: [ + #[doc = "IM.\n\nRFC 4229."] + (IM, "im"), + #[doc = "If.\n\nRFC 4918."] + (IF, "if"), + #[doc = "TE.\n\nRFC 7230 section 4.3."] + (TE, "te"), + ], + 3: [ + #[doc = "Age.\n\nRFC 7234 section 5.1."] + (AGE, "age"), + #[doc = "DAV.\n\nRFC 4918."] + (DAV, "dav"), + #[doc = "Ext.\n\nRFC 4229."] + (EXT, "ext"), + #[doc = "Man.\n\nRFC 4229."] + (MAN, "man"), + #[doc = "Opt.\n\nRFC 4229."] + (OPT, "opt"), + #[doc = "P3P.\n\nRFC 4229."] + (P3P, "p3p"), + #[doc = "PEP.\n\nRFC 4229."] + (PEP, "pep"), + #[doc = "TCN.\n\nRFC 4229."] + (TCN, "tcn"), + #[doc = "TTL.\n\nRFC 8030 section 5.2."] + (TTL, "ttl"), + #[doc = "URI.\n\nRFC 4229."] + (URI, "uri"), + #[doc = "Via.\n\nRFC 7230 section 5.7.1."] + (VIA, "via"), + ], + 4: [ + #[doc = "A-IM.\n\nRFC 4229."] + (A_IM, "a-im"), + #[doc = "ALPN.\n\nRFC 7639 section 2."] + (ALPN, "alpn"), + #[doc = "DASL.\n\nRFC 5323."] + (DASL, "dasl"), + #[doc = "Date.\n\nRFC 7231 section 7.1.1.2."] + (DATE, "date"), + #[doc = "ETag.\n\nRFC 7232 section 2.3."] + (ETAG, "etag"), + #[doc = "From.\n\nRFC 7231 section 5.5.1."] + (FROM, "from"), + #[doc = "Host.\n\nRFC 7230 section 5.4."] + (HOST, "host"), + #[doc = "Link.\n\nRFC 8288."] + (LINK, "link"), + #[doc = "Safe.\n\nRFC 4229."] + (SAFE, "safe"), + #[doc = "SLUG.\n\nRFC 5023."] + (SLUG, "slug"), + #[doc = "Vary.\n\nRFC 7231 section 7.1.4."] + (VARY, "vary"), + #[doc = "Cost.\n\nRFC 4229."] + (COST, "cost"), + ], + 5: [ + #[doc = "Allow.\n\nRFC 7231 section 7.4.1."] + (ALLOW, "allow"), + #[doc = "C-Ext.\n\nRFC 4229."] + (C_EXT, "c-ext"), + #[doc = "C-Man.\n\nRFC 4229."] + (C_MAN, "c-man"), + #[doc = "C-Opt.\n\nRFC 4229."] + (C_OPT, "c-opt"), + #[doc = "C-PEP.\n\nRFC 4229."] + (C_PEP, "c-pep"), + #[doc = "Close.\n\nRFC 7230 section 8.1."] + (CLOSE, "close"), + #[doc = "Depth.\n\nRFC 4918."] + (DEPTH, "depth"), + #[doc = "Label.\n\nRFC 4229."] + (LABEL, "label"), + #[doc = "Meter.\n\nRFC 4229."] + (METER, "meter"), + #[doc = "Range.\n\nRFC 7233 section 3.1."] + (RANGE, "range"), + #[doc = "Topic.\n\nRFC 8030 section 5.4."] + (TOPIC, "topic"), + #[doc = "SubOK.\n\nRFC 4229."] + (SUBOK, "subok"), + #[doc = "Subst.\n\nRFC 4229."] + (SUBST, "subst"), + #[doc = "Title.\n\nRFC 4229."] + (TITLE, "title"), + ], + 6: [ + #[doc = "Accept.\n\nRFC 7231 section 5.3.2."] + (ACCEPT, "accept"), + #[doc = "Cookie.\n\nRFC 6265."] + (COOKIE, "cookie"), + #[doc = "Digest.\n\nRFC 4229."] + (DIGEST, "digest"), + #[doc = "Expect.\n\nRFC 7231 section 5.1.1."] + (EXPECT, "expect"), + #[doc = "Origin.\n\nRFC 6454."] + (ORIGIN, "origin"), + #[doc = "OSCORE.\n\nRFC 8613 section 11.1."] + (OSCORE, "oscore"), + #[doc = "Pragma.\n\nRFC 7234 section 5.4."] + (PRAGMA, "pragma"), + #[doc = "Prefer.\n\nRFC 7240."] + (PREFER, "prefer"), + #[doc = "Public.\n\nRFC 4229."] + (PUBLIC, "public"), + #[doc = "Server.\n\nRFC 7231 section 7.4.2."] + (SERVER, "server"), + #[doc = "Sunset.\n\nRFC 8594."] + (SUNSET, "sunset"), + ], + 7: [ + #[doc = "Alt-Svc.\n\nRFC 7838."] + (ALT_SVC, "alt-svc"), + #[doc = "Cookie2.\n\nRFC 2965, RFC 6265."] + (COOKIE2, "cookie2"), + #[doc = "Expires.\n\nRFC 7234 section 5.3."] + (EXPIRES, "expires"), + #[doc = "Hobareg.\n\nRFC 7486 section 6.1.1."] + (HOBAREG, "hobareg"), + #[doc = "Referer.\n\nRFC 7231 section 5.5.2."] + (REFERER, "referer"), + #[doc = "Timeout.\n\nRFC 4918."] + (TIMEOUT, "timeout"), + #[doc = "Trailer.\n\nRFC 7230 section 4.4."] + (TRAILER, "trailer"), + #[doc = "Urgency.\n\nRFC 8030 section 5.3."] + (URGENCY, "urgency"), + #[doc = "Upgrade.\n\nRFC 7230 section 6.7."] + (UPGRADE, "upgrade"), 
+ #[doc = "Warning.\n\nRFC 7234 section 5.5."] + (WARNING, "warning"), + #[doc = "Version.\n\nRFC 4229."] + (VERSION, "version"), + ], + 8: [ + #[doc = "Alt-Used.\n\nRFC 7838."] + (ALT_USED, "alt-used"), + #[doc = "CDN-Loop.\n\nRFC 8586."] + (CDN_LOOP, "cdn-loop"), + #[doc = "If-Match.\n\nRFC 7232 section 3.1."] + (IF_MATCH, "if-match"), + #[doc = "If-Range.\n\nRFC 7233 section 3.2."] + (IF_RANGE, "if-range"), + #[doc = "Location.\n\nRFC 7231 section 7.1.2."] + (LOCATION, "location"), + #[doc = "Pep-Info.\n\nRFC 4229."] + (PEP_INFO, "pep-info"), + #[doc = "Position.\n\nRFC 4229."] + (POSITION, "position"), + #[doc = "Protocol.\n\nRFC 4229."] + (PROTOCOL, "protocol"), + #[doc = "Optional.\n\nRFC 4229."] + (OPTIONAL, "optional"), + #[doc = "UA-Color.\n\nRFC 4229."] + (UA_COLOR, "ua-color"), + #[doc = "UA-Media.\n\nRFC 4229."] + (UA_MEDIA, "ua-media"), + ], + 9: [ + #[doc = "Accept-CH.\n\nRFC 8942 section 3.1."] + (ACCEPT_CH, "accept-ch"), + #[doc = "Expect-CT.\n\nRFC -ietf-httpbis-expect-ct-08."] + (EXPECT_CT, "expect-ct"), + #[doc = "Forwarded.\n\nRFC 7239."] + (FORWARDED, "forwarded"), + #[doc = "Negotiate.\n\nRFC 4229."] + (NEGOTIATE, "negotiate"), + #[doc = "Overwrite.\n\nRFC 4918."] + (OVERWRITE, "overwrite"), + #[doc = "Isolation.\n\nOData Version 4.01 Part 1: Protocol, OASIS, Chet_Ensign."] + (ISOLATION, "isolation"), + #[doc = "UA-Pixels.\n\nRFC 4229."] + (UA_PIXELS, "ua-pixels"), + ], + 10: [ + #[doc = "Alternates.\n\nRFC 4229."] + (ALTERNATES, "alternates"), + #[doc = "C-PEP-Info.\n\nRFC 4229."] + (C_PEP_INFO, "c-pep-info"), + #[doc = "Connection.\n\nRFC 7230 section 6.1."] + (CONNECTION, "connection"), + #[doc = "Content-ID.\n\nRFC 4229."] + (CONTENT_ID, "content-id"), + #[doc = "Delta-Base.\n\nRFC 4229."] + (DELTA_BASE, "delta-base"), + #[doc = "Early-Data.\n\nRFC 8470."] + (EARLY_DATA, "early-data"), + #[doc = "GetProfile.\n\nRFC 4229."] + (GETPROFILE, "getprofile"), + #[doc = "Keep-Alive.\n\nRFC 4229."] + (KEEP_ALIVE, "keep-alive"), + #[doc = "Lock-Token.\n\nRFC 4918."] + (LOCK_TOKEN, "lock-token"), + #[doc = "PICS-Label.\n\nRFC 4229."] + (PICS_LABEL, "pics-label"), + #[doc = "Set-Cookie.\n\nRFC 6265."] + (SET_COOKIE, "set-cookie"), + #[doc = "SetProfile.\n\nRFC 4229."] + (SETPROFILE, "setprofile"), + #[doc = "SoapAction.\n\nRFC 4229."] + (SOAPACTION, "soapaction"), + #[doc = "Status-URI.\n\nRFC 4229."] + (STATUS_URI, "status-uri"), + #[doc = "User-Agent.\n\nRFC 7231 section 5.5.3."] + (USER_AGENT, "user-agent"), + #[doc = "Compliance.\n\nRFC 4229."] + (COMPLIANCE, "compliance"), + #[doc = "Message-ID.\n\nRFC 4229."] + (MESSAGE_ID, "message-id"), + #[doc = "Tracestate.\n\n."] + (TRACESTATE, "tracestate"), + ], + 11: [ + #[doc = "Accept-Post.\n\n."] + (ACCEPT_POST, "accept-post"), + #[doc = "Content-MD5.\n\nRFC 4229."] + (CONTENT_MD5, "content-md5"), + #[doc = "Destination.\n\nRFC 4918."] + (DESTINATION, "destination"), + #[doc = "Retry-After.\n\nRFC 7231 section 7.1.3."] + (RETRY_AFTER, "retry-after"), + #[doc = "Set-Cookie2.\n\nRFC 2965, RFC 6265."] + (SET_COOKIE2, "set-cookie2"), + #[doc = "Want-Digest.\n\nRFC 4229."] + (WANT_DIGEST, "want-digest"), + #[doc = "Traceparent.\n\n."] + (TRACEPARENT, "traceparent"), + ], + 12: [ + #[doc = "Accept-Patch.\n\nRFC 5789."] + (ACCEPT_PATCH, "accept-patch"), + #[doc = "Content-Base.\n\nRFC 2068, RFC 2616."] + (CONTENT_BASE, "content-base"), + #[doc = "Content-Type.\n\nRFC 7231 section 3.1.1.5."] + (CONTENT_TYPE, "content-type"), + #[doc = "Derived-From.\n\nRFC 4229."] + (DERIVED_FROM, "derived-from"), + #[doc = "Max-Forwards.\n\nRFC 7231 
section 5.1.2."] + (MAX_FORWARDS, "max-forwards"), + #[doc = "MIME-Version.\n\nRFC 7231, Appendix A.1."] + (MIME_VERSION, "mime-version"), + #[doc = "Redirect-Ref.\n\nRFC 4437."] + (REDIRECT_REF, "redirect-ref"), + #[doc = "Replay-Nonce.\n\nRFC 8555 section 6.5.1."] + (REPLAY_NONCE, "replay-nonce"), + #[doc = "Schedule-Tag.\n\nRFC 6638."] + (SCHEDULE_TAG, "schedule-tag"), + #[doc = "Variant-Vary.\n\nRFC 4229."] + (VARIANT_VARY, "variant-vary"), + #[doc = "Method-Check.\n\nW3C Web Application Formats Working Group."] + (METHOD_CHECK, "method-check"), + #[doc = "Referer-Root.\n\nW3C Web Application Formats Working Group."] + (REFERER_ROOT, "referer-root"), + #[doc = "X-Request-ID."] + (X_REQUEST_ID, "x-request-id"), + ], + 13: [ + #[doc = "Accept-Ranges.\n\nRFC 7233 section 2.3."] + (ACCEPT_RANGES, "accept-ranges"), + #[doc = "Authorization.\n\nRFC 7235 section 4.2."] + (AUTHORIZATION, "authorization"), + #[doc = "Cache-Control.\n\nRFC 7234 section 5.2."] + (CACHE_CONTROL, "cache-control"), + #[doc = "Content-Range.\n\nRFC 7233 section 4.2."] + (CONTENT_RANGE, "content-range"), + #[doc = "Default-Style.\n\nRFC 4229."] + (DEFAULT_STYLE, "default-style"), + #[doc = "If-None-Match.\n\nRFC 7232 section 3.2."] + (IF_NONE_MATCH, "if-none-match"), + #[doc = "Last-Modified.\n\nRFC 7232 section 2.2."] + (LAST_MODIFIED, "last-modified"), + #[doc = "OData-Version.\n\nOData Version 4.01 Part 1: Protocol, OASIS, Chet_Ensign."] + (ODATA_VERSION, "odata-version"), + #[doc = "Ordering-Type.\n\nRFC 4229."] + (ORDERING_TYPE, "ordering-type"), + #[doc = "ProfileObject.\n\nRFC 4229."] + (PROFILEOBJECT, "profileobject"), + #[doc = "Protocol-Info.\n\nRFC 4229."] + (PROTOCOL_INFO, "protocol-info"), + #[doc = "UA-Resolution.\n\nRFC 4229."] + (UA_RESOLUTION, "ua-resolution"), + ], + 14: [ + #[doc = "Accept-Charset.\n\nRFC 7231 section 5.3.3."] + (ACCEPT_CHARSET, "accept-charset"), + #[doc = "Cal-Managed-ID.\n\nRFC 8607 section 5.1."] + (CAL_MANAGED_ID, "cal-managed-id"), + #[doc = "Cert-Not-After.\n\nRFC 8739 section 3.3."] + (CERT_NOT_AFTER, "cert-not-after"), + #[doc = "Content-Length.\n\nRFC 7230 section 3.3.2."] + (CONTENT_LENGTH, "content-length"), + #[doc = "HTTP2-Settings.\n\nRFC 7540 section 3.2.1."] + (HTTP2_SETTINGS, "http2-settings"), + #[doc = "OData-EntityId.\n\nOData Version 4.01 Part 1: Protocol, OASIS, Chet_Ensign."] + (ODATA_ENTITYID, "odata-entityid"), + #[doc = "Protocol-Query.\n\nRFC 4229."] + (PROTOCOL_QUERY, "protocol-query"), + #[doc = "Proxy-Features.\n\nRFC 4229."] + (PROXY_FEATURES, "proxy-features"), + #[doc = "Schedule-Reply.\n\nRFC 6638."] + (SCHEDULE_REPLY, "schedule-reply"), + #[doc = "Access-Control.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL, "access-control"), + #[doc = "Non-Compliance.\n\nRFC 4229."] + (NON_COMPLIANCE, "non-compliance"), + ], + 15: [ + #[doc = "Accept-Datetime.\n\nRFC 7089."] + (ACCEPT_DATETIME, "accept-datetime"), + #[doc = "Accept-Encoding.\n\nRFC 7231 section 5.3.4, RFC 7694 section 3."] + (ACCEPT_ENCODING, "accept-encoding"), + #[doc = "Accept-Features.\n\nRFC 4229."] + (ACCEPT_FEATURES, "accept-features"), + #[doc = "Accept-Language.\n\nRFC 7231 section 5.3.5."] + (ACCEPT_LANGUAGE, "accept-language"), + #[doc = "Cert-Not-Before.\n\nRFC 8739 section 3.3."] + (CERT_NOT_BEFORE, "cert-not-before"), + #[doc = "Content-Version.\n\nRFC 4229."] + (CONTENT_VERSION, "content-version"), + #[doc = "Differential-ID.\n\nRFC 4229."] + (DIFFERENTIAL_ID, "differential-id"), + #[doc = "OData-Isolation.\n\nOData Version 4.01 Part 1: Protocol, OASIS, 
Chet_Ensign."] + (ODATA_ISOLATION, "odata-isolation"), + #[doc = "Public-Key-Pins.\n\nRFC 7469."] + (PUBLIC_KEY_PINS, "public-key-pins"), + #[doc = "Security-Scheme.\n\nRFC 4229."] + (SECURITY_SCHEME, "security-scheme"), + #[doc = "X-Frame-Options.\n\nRFC 7034."] + (X_FRAME_OPTIONS, "x-frame-options"), + #[doc = "EDIINT-Features.\n\nRFC 6017."] + (EDIINT_FEATURES, "ediint-features"), + #[doc = "Resolution-Hint.\n\nRFC 4229."] + (RESOLUTION_HINT, "resolution-hint"), + #[doc = "UA-Windowpixels.\n\nRFC 4229."] + (UA_WINDOWPIXELS, "ua-windowpixels"), + #[doc = "X-Device-Accept.\n\nW3C Mobile Web Best Practices Working Group."] + (X_DEVICE_ACCEPT, "x-device-accept"), + ], + 16: [ + #[doc = "Accept-Additions.\n\nRFC 4229."] + (ACCEPT_ADDITIONS, "accept-additions"), + #[doc = "CalDAV-Timezones.\n\nRFC 7809 section 7.1."] + (CALDAV_TIMEZONES, "caldav-timezones"), + #[doc = "Content-Encoding.\n\nRFC 7231 section 3.1.2.2."] + (CONTENT_ENCODING, "content-encoding"), + #[doc = "Content-Language.\n\nRFC 7231 section 3.1.3.2."] + (CONTENT_LANGUAGE, "content-language"), + #[doc = "Content-Location.\n\nRFC 7231 section 3.1.4.2."] + (CONTENT_LOCATION, "content-location"), + #[doc = "Memento-Datetime.\n\nRFC 7089."] + (MEMENTO_DATETIME, "memento-datetime"), + #[doc = "OData-MaxVersion.\n\nOData Version 4.01 Part 1: Protocol, OASIS, Chet_Ensign."] + (ODATA_MAXVERSION, "odata-maxversion"), + #[doc = "Protocol-Request.\n\nRFC 4229."] + (PROTOCOL_REQUEST, "protocol-request"), + #[doc = "WWW-Authenticate.\n\nRFC 7235 section 4.1."] + (WWW_AUTHENTICATE, "www-authenticate"), + ], + 17: [ + #[doc = "If-Modified-Since.\n\nRFC 7232 section 3.3."] + (IF_MODIFIED_SINCE, "if-modified-since"), + #[doc = "Proxy-Instruction.\n\nRFC 4229."] + (PROXY_INSTRUCTION, "proxy-instruction"), + #[doc = "Sec-Token-Binding.\n\nRFC 8473."] + (SEC_TOKEN_BINDING, "sec-token-binding"), + #[doc = "Sec-WebSocket-Key.\n\nRFC 6455."] + (SEC_WEBSOCKET_KEY, "sec-websocket-key"), + #[doc = "Surrogate-Control.\n\nRFC 4229."] + (SURROGATE_CONTROL, "surrogate-control"), + #[doc = "Transfer-Encoding.\n\nRFC 7230 section 3.3.1."] + (TRANSFER_ENCODING, "transfer-encoding"), + #[doc = "OSLC-Core-Version.\n\nOASIS Project Specification 01, OASIS, Chet_Ensign."] + (OSLC_CORE_VERSION, "oslc-core-version"), + #[doc = "Resolver-Location.\n\nRFC 4229."] + (RESOLVER_LOCATION, "resolver-location"), + ], + 18: [ + #[doc = "Content-Style-Type.\n\nRFC 4229."] + (CONTENT_STYLE_TYPE, "content-style-type"), + #[doc = "Preference-Applied.\n\nRFC 7240."] + (PREFERENCE_APPLIED, "preference-applied"), + #[doc = "Proxy-Authenticate.\n\nRFC 7235 section 4.3."] + (PROXY_AUTHENTICATE, "proxy-authenticate"), + ], + 19: [ + #[doc = "Authentication-Info.\n\nRFC 7615 section 3."] + (AUTHENTICATION_INFO, "authentication-info"), + #[doc = "Content-Disposition.\n\nRFC 6266."] + (CONTENT_DISPOSITION, "content-disposition"), + #[doc = "Content-Script-Type.\n\nRFC 4229."] + (CONTENT_SCRIPT_TYPE, "content-script-type"), + #[doc = "If-Unmodified-Since.\n\nRFC 7232 section 3.4."] + (IF_UNMODIFIED_SINCE, "if-unmodified-since"), + #[doc = "Proxy-Authorization.\n\nRFC 7235 section 4.4."] + (PROXY_AUTHORIZATION, "proxy-authorization"), + #[doc = "AMP-Cache-Transform.\n\n."] + (AMP_CACHE_TRANSFORM, "amp-cache-transform"), + #[doc = "Timing-Allow-Origin.\n\n."] + (TIMING_ALLOW_ORIGIN, "timing-allow-origin"), + #[doc = "X-Device-User-Agent.\n\nW3C Mobile Web Best Practices Working Group."] + (X_DEVICE_USER_AGENT, "x-device-user-agent"), + ], + 20: [ + #[doc = "Sec-WebSocket-Accept.\n\nRFC 
6455."] + (SEC_WEBSOCKET_ACCEPT, "sec-websocket-accept"), + #[doc = "Surrogate-Capability.\n\nRFC 4229."] + (SURROGATE_CAPABILITY, "surrogate-capability"), + #[doc = "Method-Check-Expires.\n\nW3C Web Application Formats Working Group."] + (METHOD_CHECK_EXPIRES, "method-check-expires"), + #[doc = "Repeatability-Result.\n\nRepeatable Requests Version 1.0, OASIS, Chet_Ensign."] + (REPEATABILITY_RESULT, "repeatability-result"), + ], + 21: [ + #[doc = "Apply-To-Redirect-Ref.\n\nRFC 4437."] + (APPLY_TO_REDIRECT_REF, "apply-to-redirect-ref"), + #[doc = "If-Schedule-Tag-Match.\n\nRFC 6638."] + (IF_SCHEDULE_TAG_MATCH, "if-schedule-tag-match"), + #[doc = "Sec-WebSocket-Version.\n\nRFC 6455."] + (SEC_WEBSOCKET_VERSION, "sec-websocket-version"), + ], + 22: [ + #[doc = "Authentication-Control.\n\nRFC 8053 section 4."] + (AUTHENTICATION_CONTROL, "authentication-control"), + #[doc = "Sec-WebSocket-Protocol.\n\nRFC 6455."] + (SEC_WEBSOCKET_PROTOCOL, "sec-websocket-protocol"), + #[doc = "X-Content-Type-Options.\n\n."] + (X_CONTENT_TYPE_OPTIONS, "x-content-type-options"), + #[doc = "Access-Control-Max-Age.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_MAX_AGE, "access-control-max-age"), + ], + 23: [ + #[doc = "Repeatability-Client-ID.\n\nRepeatable Requests Version 1.0, OASIS, Chet_Ensign."] + (REPEATABILITY_CLIENT_ID, "repeatability-client-id"), + #[doc = "X-Device-Accept-Charset.\n\nW3C Mobile Web Best Practices Working Group."] + (X_DEVICE_ACCEPT_CHARSET, "x-device-accept-charset"), + ], + 24: [ + #[doc = "Sec-WebSocket-Extensions.\n\nRFC 6455."] + (SEC_WEBSOCKET_EXTENSIONS, "sec-websocket-extensions"), + #[doc = "Repeatability-First-Sent.\n\nRepeatable Requests Version 1.0, OASIS, Chet_Ensign."] + (REPEATABILITY_FIRST_SENT, "repeatability-first-sent"), + #[doc = "Repeatability-Request-ID.\n\nRepeatable Requests Version 1.0, OASIS, Chet_Ensign."] + (REPEATABILITY_REQUEST_ID, "repeatability-request-id"), + #[doc = "X-Device-Accept-Encoding.\n\nW3C Mobile Web Best Practices Working Group."] + (X_DEVICE_ACCEPT_ENCODING, "x-device-accept-encoding"), + #[doc = "X-Device-Accept-Language.\n\nW3C Mobile Web Best Practices Working Group."] + (X_DEVICE_ACCEPT_LANGUAGE, "x-device-accept-language"), + ], + 25: [ + #[doc = "Optional-WWW-Authenticate.\n\nRFC 8053 section 3."] + (OPTIONAL_WWW_AUTHENTICATE, "optional-www-authenticate"), + #[doc = "Proxy-Authentication-Info.\n\nRFC 7615 section 4."] + (PROXY_AUTHENTICATION_INFO, "proxy-authentication-info"), + #[doc = "Strict-Transport-Security.\n\nRFC 6797."] + (STRICT_TRANSPORT_SECURITY, "strict-transport-security"), + #[doc = "Content-Transfer-Encoding.\n\nRFC 4229."] + (CONTENT_TRANSFER_ENCODING, "content-transfer-encoding"), + ], + 27: [ + #[doc = "Public-Key-Pins-Report-Only.\n\nRFC 7469."] + (PUBLIC_KEY_PINS_REPORT_ONLY, "public-key-pins-report-only"), + #[doc = "Access-Control-Allow-Origin.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_ALLOW_ORIGIN, "access-control-allow-origin"), + ], + 28: [ + #[doc = "Access-Control-Allow-Headers.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_ALLOW_HEADERS, "access-control-allow-headers"), + #[doc = "Access-Control-Allow-Methods.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_ALLOW_METHODS, "access-control-allow-methods"), + ], + 29: [ + #[doc = "Access-Control-Request-Method.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_REQUEST_METHOD, "access-control-request-method"), + ], + 30: [ + #[doc = 
"Access-Control-Request-Headers.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_REQUEST_HEADERS, "access-control-request-headers"), + ], + 32: [ + #[doc = "Access-Control-Allow-Credentials.\n\nW3C Web Application Formats Working Group."] + (ACCESS_CONTROL_ALLOW_CREDENTIALS, "access-control-allow-credentials"), + ], + 33: [ + #[doc = "Include-Referred-Token-Binding-ID.\n\nRFC 8473."] + (INCLUDE_REFERRED_TOKEN_BINDING_ID, "include-referred-token-binding-id"), + ], + ); + + /// Create a new HTTP `HeaderName`. + /// + /// # Panics + /// + /// Panics if `name` is not all ASCII lowercase. + pub const fn from_lowercase(name: &'n str) -> HeaderName<'n> { + assert!(is_lower_case(name), "header name not lowercase"); + HeaderName { + inner: Cow::Borrowed(name), + } + } + + /// Borrow the header name for a shorter lifetime. + /// + /// This is used in things like [`Headers::get`] and [`Iter`] for `Headers` + /// to avoid clone heap-allocated `HeaderName`s. + fn borrow<'b>(&'b self) -> HeaderName<'b> { + HeaderName { + inner: Cow::Borrowed(self.as_ref()), + } + } + + /// Returns `true` if `self` is heap allocated. + /// + /// # Notes + /// + /// This is only header to test [`HeaderName::from_str`], not part of the + /// stable API. + #[doc(hidden)] + pub const fn is_heap_allocated(&self) -> bool { + matches!(self.inner, Cow::Owned(_)) + } +} + +impl From for HeaderName<'static> { + fn from(mut name: String) -> HeaderName<'static> { + name.make_ascii_lowercase(); + HeaderName { + inner: Cow::Owned(name), + } + } +} + +impl<'a> AsRef for HeaderName<'a> { + fn as_ref(&self) -> &str { + self.inner.as_ref() + } +} + +impl<'a> PartialEq for HeaderName<'a> { + fn eq(&self, other: &str) -> bool { + // NOTE: `self` is always lowercase, per the comment on the `inner` + // field. + cmp_lower_case(self.inner.as_ref(), other) + } +} + +impl<'a> PartialEq<&'_ str> for HeaderName<'a> { + fn eq(&self, other: &&str) -> bool { + self.eq(*other) + } +} + +impl<'a> fmt::Debug for HeaderName<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl<'a> fmt::Display for HeaderName<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_ref()) + } +} + +/// Analogous trait to [`FromStr`]. +/// +/// The main use case for this trait in [`Header::parse`]. Because of this the +/// implementations should expect the `value`s passed to be ASCII/UTF-8, but +/// this not true in all cases. +/// +/// [`FromStr`]: std::str::FromStr +pub trait FromHeaderValue<'a>: Sized { + /// Error returned by parsing the bytes. + type Err; + + /// Parse the `value`. + fn from_bytes(value: &'a [u8]) -> Result; +} + +/// Error returned by the [`FromHeaderValue`] implementation for numbers, e.g. +/// `usize`. +#[derive(Debug)] +pub struct ParseIntError; + +impl fmt::Display for ParseIntError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("invalid integer") + } +} + +macro_rules! int_impl { + ($( $ty: ty ),+) => { + $( + impl FromHeaderValue<'_> for $ty { + type Err = ParseIntError; + + fn from_bytes(src: &[u8]) -> Result { + if src.is_empty() { + return Err(ParseIntError); + } + + let mut value: $ty = 0; + for b in src.iter().copied() { + if (b'0'..=b'9').contains(&b) { + match value.checked_mul(10) { + Some(v) => value = v, + None => return Err(ParseIntError), + } + #[allow(trivial_numeric_casts)] // For `u8 as u8`. 
+ match value.checked_add((b - b'0') as $ty) { + Some(v) => value = v, + None => return Err(ParseIntError), + } + } else { + return Err(ParseIntError); + } + } + Ok(value) + } + } + )+ + }; +} + +int_impl!(u8, u16, u32, u64, usize); + +impl<'a> FromHeaderValue<'a> for &'a str { + type Err = str::Utf8Error; + + fn from_bytes(value: &'a [u8]) -> Result { + str::from_utf8(value) + } +} + +/// Error returned by the [`FromHeaderValue`] implementation for [`SystemTime`]. +#[derive(Debug)] +pub struct ParseTimeError; + +impl fmt::Display for ParseTimeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("invalid time") + } +} + +/// Parses the value following RFC7231 section 7.1.1.1. +impl FromHeaderValue<'_> for SystemTime { + type Err = ParseTimeError; + + fn from_bytes(value: &[u8]) -> Result { + match str::from_utf8(value) { + Ok(value) => match parse_http_date(value) { + Ok(time) => Ok(time), + Err(_) => Err(ParseTimeError), + }, + Err(_) => Err(ParseTimeError), + } + } +} diff --git a/http/src/lib.rs b/http/src/lib.rs new file mode 100644 index 000000000..8e8d8ec1c --- /dev/null +++ b/http/src/lib.rs @@ -0,0 +1,204 @@ +//! HTTP/1.1 implementation for Heph. + +#![feature( + async_stream, + const_fn_trait_bound, + const_mut_refs, + const_panic, + generic_associated_types, + io_slice_advance, + maybe_uninit_write_slice, + ready_macro, + stmt_expr_attributes +)] +#![allow(incomplete_features)] // NOTE: for `generic_associated_types`. +#![warn( + anonymous_parameters, + bare_trait_objects, + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + unused_qualifications, + unused_results, + variant_size_differences +)] + +pub mod body; +pub mod client; +pub mod header; +pub mod method; +mod request; +mod response; +pub mod server; +mod status_code; +pub mod version; + +#[doc(no_inline)] +pub use body::Body; +#[doc(no_inline)] +pub use client::Client; +#[doc(no_inline)] +pub use header::{Header, HeaderName, Headers}; +#[doc(no_inline)] +pub use method::Method; +pub use request::Request; +pub use response::Response; +#[doc(no_inline)] +pub use server::{Connection, HttpServer}; +pub use status_code::StatusCode; +#[doc(no_inline)] +pub use version::Version; + +/// Maximum size of the HTTP head (the start line and the headers). +/// +/// RFC 7230 section 3.1.1 recommends "all HTTP senders and recipients support, +/// at a minimum, request-line lengths of 8000 octets." +pub const MAX_HEAD_SIZE: usize = 16384; + +/// Maximum number of headers parsed from a single [`Request`]/[`Response`]. +pub const MAX_HEADERS: usize = 64; + +/// Minimum amount of bytes read from the connection or the buffer will be +/// grown. +const MIN_READ_SIZE: usize = 4096; + +/// Size of the buffer used in [`server::Connection`] and [`Client`]. +const BUF_SIZE: usize = 8192; + +/// Map a `version` byte to a [`Version`]. +const fn map_version_byte(version: u8) -> Version { + match version { + 0 => Version::Http10, + // RFC 7230 section 2.6: + // > A server SHOULD send a response version equal to + // > the highest version to which the server is + // > conformant that has a major version less than or + // > equal to the one received in the request. + // HTTP/1.1 is the highest we support. + _ => Version::Http11, + } +} + +/// Trim whitespace from `value`. 
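+///
+/// Both leading and trailing ASCII whitespace is removed, e.g. trimming
+/// `b" gzip, chunked "` leaves `b"gzip, chunked"` (also see the tests below).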
+fn trim_ws(value: &[u8]) -> &[u8] { + let len = value.len(); + if len == 0 { + return value; + } + let mut start = 0; + while start < len { + if !value[start].is_ascii_whitespace() { + break; + } + start += 1; + } + let mut end = len - 1; + while end > start { + if !value[end].is_ascii_whitespace() { + break; + } + end -= 1; + } + // TODO: make this `const`. + &value[start..=end] +} + +/// Returns `true` if `lower_case` and `right` are a case-insensitive match. +/// +/// # Notes +/// +/// `lower_case` must be lower case! +const fn cmp_lower_case(lower_case: &str, right: &str) -> bool { + debug_assert!(is_lower_case(lower_case)); + + let left = lower_case.as_bytes(); + let right = right.as_bytes(); + let len = left.len(); + if len != right.len() { + return false; + } + + let mut i = 0; + while i < len { + if left[i] != right[i].to_ascii_lowercase() { + return false; + } + i += 1; + } + true +} + +/// Returns `true` if `value` is all ASCII lowercase. +const fn is_lower_case(value: &str) -> bool { + let value = value.as_bytes(); + let mut i = 0; + while i < value.len() { + // NOTE: allows `-` because it's used in header names. + if !matches!(value[i], b'0'..=b'9' | b'a'..=b'z' | b'-') { + return false; + } + i += 1; + } + true +} + +#[cfg(test)] +mod tests { + use super::{cmp_lower_case, is_lower_case, trim_ws}; + + #[test] + fn test_trim_ws() { + let tests = &[ + ("", ""), + ("abc", "abc"), + (" abc", "abc"), + (" abc ", "abc"), + (" gzip, chunked ", "gzip, chunked"), + ]; + for (input, expected) in tests { + let got = trim_ws(input.as_bytes()); + assert_eq!(got, expected.as_bytes(), "input: {}", input); + } + } + + #[test] + fn test_is_lower_case() { + let tests = &[ + ("", true), + ("abc", true), + ("Abc", false), + ("aBc", false), + ("AbC", false), + ("ABC", false), + ]; + for (input, expected) in tests { + let got = is_lower_case(input); + assert_eq!(got, *expected, "input: {}", input); + } + } + + #[test] + fn test_cmp_lower_case() { + let tests = &[ + ("", "", true), + ("abc", "abc", true), + ("abc", "Abc", true), + ("abc", "aBc", true), + ("abc", "abC", true), + ("abc", "ABC", true), + ("a", "", false), + ("", "a", false), + ("abc", "", false), + ("abc", "d", false), + ("abc", "de", false), + ("abc", "def", false), + ]; + for (lower_case, right, expected) in tests { + let got = cmp_lower_case(lower_case, right); + assert_eq!(got, *expected, "input: '{}', '{}'", lower_case, right); + } + } +} diff --git a/http/src/method.rs b/http/src/method.rs new file mode 100644 index 000000000..2ed81a179 --- /dev/null +++ b/http/src/method.rs @@ -0,0 +1,155 @@ +//! Module with HTTP method related types. + +use std::fmt; +use std::str::FromStr; + +use crate::cmp_lower_case; + +/// HTTP method. +/// +/// RFC 7231 section 4. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Method { + /// GET method. + /// + /// RFC 7231 section 4.3.1. + Get, + /// HEAD method. + /// + /// RFC 7231 section 4.3.2. + Head, + /// POST method. + /// + /// RFC 7231 section 4.3.3. + Post, + /// PUT method. + /// + /// RFC 7231 section 4.3.4. + Put, + /// DELETE method. + /// + /// RFC 7231 section 4.3.5. + Delete, + /// CONNECT method. + /// + /// RFC 7231 section 4.3.6. + Connect, + /// OPTIONS method. + /// + /// RFC 7231 section 4.3.7. + Options, + /// TRACE method. + /// + /// RFC 7231 section 4.3.8. + Trace, + /// PATCH method. + /// + /// RFC 5789. + Patch, +} + +impl Method { + /// Returns `true` if the method is safe. + /// + /// RFC 7321 section 4.2.1. 
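+    ///
+    /// Safe methods are effectively read-only: GET, HEAD, OPTIONS and TRACE.
+    ///
+    /// A small usage example:
+    ///
+    /// ```
+    /// use heph_http::Method;
+    ///
+    /// assert!(Method::Get.is_safe());
+    /// assert!(!Method::Post.is_safe());
+    /// ```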
+ pub const fn is_safe(self) -> bool { + use Method::*; + matches!(self, Get | Head | Options | Trace) + } + + /// Returns `true` if the method is idempotent. + /// + /// RFC 7321 section 4.2.2. + pub const fn is_idempotent(self) -> bool { + matches!(self, Method::Put | Method::Delete) || self.is_safe() + } + + /// Returns `false` if a response to this method MUST NOT include a body. + /// + /// This is only the case for the HEAD method. + /// + /// RFC 7230 section 3.3 and RFC 7321 section 4.3.2. + pub const fn expects_body(self) -> bool { + // RFC 7231 section 4.3.2: + // > The HEAD method is identical to GET except that the server MUST NOT + // > send a message body in the response (i.e., the response terminates + // > at the end of the header section). + !matches!(self, Method::Head) + } + + /// Returns the method as string. + pub const fn as_str(self) -> &'static str { + use Method::*; + match self { + Options => "OPTIONS", + Get => "GET", + Post => "POST", + Put => "PUT", + Delete => "DELETE", + Head => "HEAD", + Trace => "TRACE", + Connect => "CONNECT", + Patch => "PATCH", + } + } +} + +impl fmt::Display for Method { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Error returned by the [`FromStr`] implementation for [`Method`]. +#[derive(Copy, Clone, Debug)] +pub struct UnknownMethod; + +impl fmt::Display for UnknownMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("unknown HTTP method") + } +} + +impl FromStr for Method { + type Err = UnknownMethod; + + fn from_str(method: &str) -> Result { + match method.len() { + 3 => { + if cmp_lower_case("get", method) { + return Ok(Method::Get); + } else if cmp_lower_case("put", method) { + return Ok(Method::Put); + } + } + 4 => { + if cmp_lower_case("head", method) { + return Ok(Method::Head); + } else if cmp_lower_case("post", method) { + return Ok(Method::Post); + } + } + 5 => { + if cmp_lower_case("trace", method) { + return Ok(Method::Trace); + } else if cmp_lower_case("patch", method) { + return Ok(Method::Patch); + } + } + 6 => { + if cmp_lower_case("delete", method) { + return Ok(Method::Delete); + } + } + 7 => { + if cmp_lower_case("connect", method) { + return Ok(Method::Connect); + } else if cmp_lower_case("options", method) { + return Ok(Method::Options); + } + } + _ => {} + } + Err(UnknownMethod) + } +} diff --git a/http/src/parse_headers.bash b/http/src/parse_headers.bash new file mode 100755 index 000000000..a4ce63b5c --- /dev/null +++ b/http/src/parse_headers.bash @@ -0,0 +1,90 @@ +#!/usr/bin/env bash + +# Get the two csv files (permanent and provisional) from: +# https://www.iana.org/assignments/message-headers/message-headers.xhtml +# Remove the header from both file and run: +# $ cat perm-headers.csv prov-headers.csv | ./parse.bash + +set -eu + +clean_reference_partial() { + local reference="$1" + + # Remove '[' .. ']'. + if [[ "${reference:0:1}" == "[" ]]; then + if [[ "${reference: -1}" == "]" ]]; then + reference="${reference:1:-1}" + else + reference="${reference:1}" + fi + fi + + # Wrap links in '<' .. '>'. + if [[ "$reference" == http* ]]; then + reference="<$reference>" + fi + + if [[ "${reference:0:3}" == "RFC" ]]; then + # Add a space after 'RFC'. + reference="RFC ${reference:3}" + # Remove comma and lower case section. + reference="${reference/, S/ s}" + fi + + echo -n "$reference" +} + +clean_reference() { + local reference="$1" + local partial="${2:-false}" + + # Remove '"' .. '"'. 
+ if [[ "${reference:0:1}" == "\"" ]]; then + reference="${reference:1:-1}" + fi + + # Some references are actually multiple references inside '[' .. ']'. + # Clean them one by one. + IFS="]" read -ra refs <<< "$reference" + reference="" + for ref in "${refs[@]}"; do + reference+="$(clean_reference_partial "$ref")" + reference+=', ' + done + + echo "${reference:0:-2}" # Remove last ', '. +} + +# Collect all known header name by length in `header_names`. +declare -a header_names +while IFS=$',' read -r name template protocol status reference; do + # We're only interested in HTTP headers. + if [[ "http" != "$protocol" ]]; then + continue + fi + + reference="$(clean_reference "$reference")" + const_name="${name^^}" # To uppercase. + const_name="${const_name//\-/\_}" # '-' -> '_'. + const_name="${const_name// /\_}" # ' ' -> '_'. + const_value="${name,,}" # To lowercase. + value_length="${#const_value}" # Value length. + docs="#[doc = \"$name.\\\\n\\\\n$reference.\"]" + + header_names[$value_length]+="$docs|$const_name|$const_value +" +done + +# Add non-standard headers. +# X-Request-ID. +header_names[12]+="#[doc = \"X-Request-ID.\"]|X_REQUEST_ID|x-request-id +" + +for value_length in "${!header_names[@]}"; do + values="${header_names[$value_length]}" + echo " $value_length: [" + while IFS=$'|' read -r docs const_name const_value; do + printf " $docs\n ($const_name, \"$const_value\"),\n" + done <<< "${values:0:-1}" # Remove last new line. + echo " ],"; +done diff --git a/http/src/request.rs b/http/src/request.rs new file mode 100644 index 000000000..5f4030c64 --- /dev/null +++ b/http/src/request.rs @@ -0,0 +1,86 @@ +use std::fmt; + +use crate::{Headers, Method, Version}; + +/// HTTP request. +pub struct Request { + method: Method, + path: String, + version: Version, + headers: Headers, + body: B, +} + +impl Request { + /// Create a new request. + pub const fn new( + method: Method, + path: String, + version: Version, + headers: Headers, + body: B, + ) -> Request { + Request { + method, + path, + version, + headers, + body, + } + } + + /// Returns the HTTP version of this request. + /// + /// # Notes + /// + /// Requests from the [`HttpServer`] will return the highest version it + /// understand, e.g. if a client used HTTP/1.2 (which doesn't exists) the + /// version would be set to HTTP/1.1 (the highest version this crate + /// understands) per RFC 7230 section 2.6. + /// + /// [`HttpServer`]: crate::HttpServer + pub const fn version(&self) -> Version { + self.version + } + + /// Returns the HTTP method of this request. + pub const fn method(&self) -> Method { + self.method + } + + /// Returns the path of this request. + pub fn path(&self) -> &str { + &self.path + } + + /// Returns the headers. + pub const fn headers(&self) -> &Headers { + &self.headers + } + + /// Returns mutable access to the headers. + pub const fn headers_mut(&mut self) -> &mut Headers { + &mut self.headers + } + + /// The request body. + pub const fn body(&self) -> &B { + &self.body + } + + /// Mutable access to the request body. 
+ pub const fn body_mut(&mut self) -> &mut B { + &mut self.body + } +} + +impl fmt::Debug for Request { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Request") + .field("method", &self.method) + .field("path", &self.path) + .field("version", &self.version) + .field("headers", &self.headers) + .finish() + } +} diff --git a/http/src/response.rs b/http/src/response.rs new file mode 100644 index 000000000..affafa607 --- /dev/null +++ b/http/src/response.rs @@ -0,0 +1,73 @@ +use std::fmt; + +use crate::{Headers, StatusCode, Version}; + +/// HTTP response. +pub struct Response { + version: Version, + status: StatusCode, + headers: Headers, + body: B, +} + +impl Response { + /// Create a new HTTP response. + pub const fn new( + version: Version, + status: StatusCode, + headers: Headers, + body: B, + ) -> Response { + Response { + version, + status, + headers, + body, + } + } + + /// Returns the HTTP version of this response. + pub const fn version(&self) -> Version { + self.version + } + + /// Returns the response code. + pub const fn status(&self) -> StatusCode { + self.status + } + + /// Returns the headers. + pub const fn headers(&self) -> &Headers { + &self.headers + } + + /// Returns mutable access to the headers. + pub const fn headers_mut(&mut self) -> &mut Headers { + &mut self.headers + } + + /// Returns a reference to the body. + pub const fn body(&self) -> &B { + &self.body + } + + /// Returns a mutable reference to the body. + pub const fn body_mut(&mut self) -> &mut B { + &mut self.body + } + + /// Returns the body of the response. + pub fn into_body(self) -> B { + self.body + } +} + +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Request") + .field("version", &self.version) + .field("status", &self.status) + .field("headers", &self.headers) + .finish() + } +} diff --git a/http/src/server.rs b/http/src/server.rs new file mode 100644 index 000000000..70441d5e8 --- /dev/null +++ b/http/src/server.rs @@ -0,0 +1,1594 @@ +// TODO: `S: Supervisor` currently uses `TcpStream` as argument due to `ArgMap`. +// Maybe disconnect `S` from `NA`? +// +// TODO: Continue reading RFC 7230 section 4 Transfer Codings. +// +// TODO: RFC 7230 section 3.3.3 point 5: +// > If the sender closes the connection or the recipient +// > times out before the indicated number of octets are +// > received, the recipient MUST consider the message to be +// > incomplete and close the connection. + +//! Module with the HTTP server implementation. + +use std::cmp::min; +use std::fmt; +use std::future::Future; +use std::io::{self, Write}; +use std::mem::MaybeUninit; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::ready; +use std::task::{self, Poll}; +use std::time::SystemTime; + +use heph::net::{tcp, Bytes, BytesVectored, TcpServer, TcpStream}; +use heph::spawn::{ActorOptions, Spawn}; +use heph::{actor, rt, Actor, NewActor, Supervisor}; +use httpdate::HttpDate; + +use crate::body::BodyLength; +use crate::header::{FromHeaderValue, HeaderName, Headers}; +use crate::{ + map_version_byte, trim_ws, Method, Request, StatusCode, Version, BUF_SIZE, MAX_HEADERS, + MAX_HEAD_SIZE, MIN_READ_SIZE, +}; + +/// A intermediate structure that implements [`NewActor`], creating +/// [`HttpServer`]. +/// +/// See [`HttpServer::setup`] to create this and [`HttpServer`] for examples. +#[derive(Debug)] +pub struct Setup { + inner: tcp::server::Setup>, +} + +impl Setup { + /// Returns the address the server is bound to. 
+ pub fn local_addr(&self) -> SocketAddr { + self.inner.local_addr() + } +} + +impl NewActor for Setup +where + S: Supervisor> + Clone + 'static, + NA: NewActor + Clone + 'static, + NA::RuntimeAccess: rt::Access + Spawn, NA::RuntimeAccess>, +{ + type Message = Message; + type Argument = (); + type Actor = HttpServer; + type Error = io::Error; + type RuntimeAccess = NA::RuntimeAccess; + + fn new( + &mut self, + ctx: actor::Context, + arg: Self::Argument, + ) -> Result { + self.inner.new(ctx, arg).map(|inner| HttpServer { inner }) + } +} + +impl Clone for Setup { + fn clone(&self) -> Setup { + Setup { + inner: self.inner.clone(), + } + } +} + +/// An actor that starts a new actor for each accepted HTTP [`Connection`]. +/// +/// `HttpServer` has the same design as [`TcpServer`]. It accept `TcpStream`s +/// and converts those into HTTP [`Connection`]s, from which HTTP [`Request`]s +/// can be read and HTTP [`Response`]s can be written. +/// +/// Similar to `TcpServer` this type works with thread-safe and thread-local +/// actors. +/// +/// [`Response`]: crate::Response +/// +/// # Graceful shutdown +/// +/// Graceful shutdown is done by sending it a [`Terminate`] message. The HTTP +/// server can also handle (shutdown) process signals, see below for an example. +/// +/// [`Terminate`]: heph::actor::messages::Terminate +/// +/// # Examples +/// +/// ```rust +/// # #![feature(never_type)] +/// use std::borrow::Cow; +/// use std::io; +/// use std::net::SocketAddr; +/// use std::time::Duration; +/// +/// use heph::actor::{self, Actor, NewActor}; +/// use heph::net::TcpStream; +/// use heph::rt::{self, Runtime, ThreadLocal}; +/// use heph::spawn::options::{ActorOptions, Priority}; +/// use heph::supervisor::{Supervisor, SupervisorStrategy}; +/// use heph::timer::Deadline; +/// use heph_http::body::OneshotBody; +/// use heph_http::{self as http, Header, HeaderName, Headers, HttpServer, Method, StatusCode}; +/// use log::error; +/// +/// fn main() -> Result<(), rt::Error> { +/// // Setup the HTTP server. +/// let actor = http_actor as fn(_, _, _) -> _; +/// let address = "127.0.0.1:7890".parse().unwrap(); +/// let server = HttpServer::setup(address, conn_supervisor, actor, ActorOptions::default()) +/// .map_err(rt::Error::setup)?; +/// +/// // Build the runtime. +/// let mut runtime = Runtime::setup().use_all_cores().build()?; +/// // On each worker thread start our HTTP server. +/// runtime.run_on_workers(move |mut runtime_ref| -> io::Result<()> { +/// let options = ActorOptions::default().with_priority(Priority::LOW); +/// let server_ref = runtime_ref.try_spawn_local(ServerSupervisor, server, (), options)?; +/// +/// # server_ref.try_send(heph::actor::messages::Terminate).unwrap(); +/// +/// // Allow graceful shutdown by responding to process signals. +/// runtime_ref.receive_signals(server_ref.try_map()); +/// Ok(()) +/// })?; +/// runtime.start() +/// } +/// +/// /// Our supervisor for the TCP server. 
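+/// /// It restarts the server if accepting an incoming connection fails and
+/// /// stops it if (re)starting the server itself fails.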
+/// #[derive(Copy, Clone, Debug)] +/// struct ServerSupervisor; +/// +/// impl Supervisor for ServerSupervisor +/// where +/// NA: NewActor, +/// NA::Actor: Actor>, +/// { +/// fn decide(&mut self, err: http::server::Error) -> SupervisorStrategy<()> { +/// use http::server::Error::*; +/// match err { +/// Accept(err) => { +/// error!("error accepting new connection: {}", err); +/// SupervisorStrategy::Restart(()) +/// } +/// NewActor(_) => unreachable!(), +/// } +/// } +/// +/// fn decide_on_restart_error(&mut self, err: io::Error) -> SupervisorStrategy<()> { +/// error!("error restarting the TCP server: {}", err); +/// SupervisorStrategy::Stop +/// } +/// +/// fn second_restart_error(&mut self, err: io::Error) { +/// error!("error restarting the actor a second time: {}", err); +/// } +/// } +/// +/// fn conn_supervisor(err: io::Error) -> SupervisorStrategy<(TcpStream, SocketAddr)> { +/// error!("error handling connection: {}", err); +/// SupervisorStrategy::Stop +/// } +/// +/// /// Our actor that handles a single HTTP connection. +/// async fn http_actor( +/// mut ctx: actor::Context, +/// mut connection: http::Connection, +/// address: SocketAddr, +/// ) -> io::Result<()> { +/// // Set `TCP_NODELAY` on the `TcpStream`. +/// connection.set_nodelay(true)?; +/// +/// let mut headers = Headers::EMPTY; +/// loop { +/// // Read the next request. +/// let (code, body, should_close) = match connection.next_request().await? { +/// Ok(Some(request)) => { +/// // Only support GET/HEAD to "/", with an empty body. +/// if request.path() != "/" { +/// (StatusCode::NOT_FOUND, "Not found".into(), false) +/// } else if !matches!(request.method(), Method::Get | Method::Head) { +/// // Add the "Allow" header to show the HTTP methods we do +/// // support. +/// headers.add(Header::new(HeaderName::ALLOW, b"GET, HEAD")); +/// let body = "Method not allowed".into(); +/// (StatusCode::METHOD_NOT_ALLOWED, body, false) +/// } else if !request.body().is_empty() { +/// (StatusCode::PAYLOAD_TOO_LARGE, "Not expecting a body".into(), true) +/// } else { +/// // Use the IP address as body. +/// let body = Cow::from(address.ip().to_string()); +/// (StatusCode::OK, body, false) +/// } +/// } +/// // No more requests. +/// Ok(None) => return Ok(()), +/// // Error parsing request. +/// Err(err) => { +/// // Determine the correct status code to return. +/// let code = err.proper_status_code(); +/// // Create a useful error message as body. +/// let body = Cow::from(format!("Bad request: {}", err)); +/// (code, body, err.should_close()) +/// } +/// }; +/// +/// // If we want to close the connection add the "Connection: close" +/// // header. +/// if should_close { +/// headers.add(Header::new(HeaderName::CONNECTION, b"close")); +/// } +/// +/// // Send the body as a single payload. +/// let body = OneshotBody::new(body.as_bytes()); +/// // Respond to the request. +/// connection.respond(code, &headers, body).await?; +/// +/// if should_close { +/// return Ok(()); +/// } +/// headers.clear(); +/// } +/// } +/// ``` +pub struct HttpServer> { + inner: TcpServer>, +} + +impl HttpServer +where + S: Supervisor> + Clone + 'static, + NA: NewActor + Clone + 'static, +{ + /// Create a new [server setup]. + /// + /// Arguments: + /// * `address`: the address to listen on. + /// * `supervisor`: the [`Supervisor`] used to supervise each started actor, + /// * `new_actor`: the [`NewActor`] implementation to start each actor, + /// and + /// * `options`: the actor options used to spawn the new actors. 
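+    ///
+    /// The returned [server setup] implements [`NewActor`] and is typically
+    /// spawned on each worker thread, see the example on [`HttpServer`].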
+ /// + /// [server setup]: Setup + pub fn setup( + address: SocketAddr, + supervisor: S, + new_actor: NA, + options: ActorOptions, + ) -> io::Result> { + let new_actor = ArgMap { new_actor }; + TcpServer::setup(address, supervisor, new_actor, options).map(|inner| Setup { inner }) + } +} + +impl Actor for HttpServer +where + S: Supervisor> + Clone + 'static, + NA: NewActor + Clone + 'static, + NA::RuntimeAccess: rt::Access + Spawn, NA::RuntimeAccess>, +{ + type Error = Error; + + fn try_poll( + self: Pin<&mut Self>, + ctx: &mut task::Context<'_>, + ) -> Poll> { + let this = unsafe { self.map_unchecked_mut(|s| &mut s.inner) }; + this.try_poll(ctx) + } +} + +impl fmt::Debug for HttpServer +where + S: fmt::Debug, + NA: NewActor + fmt::Debug, + NA::RuntimeAccess: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpServer") + .field("inner", &self.inner) + .finish() + } +} + +// TODO: better name. Like `TcpStreamToConnection`? +/// Maps `NA` to accept `(TcpStream, SocketAddr)` as argument, creating a +/// [`Connection`]. +#[derive(Debug, Clone)] +pub struct ArgMap { + new_actor: NA, +} + +impl NewActor for ArgMap +where + NA: NewActor, +{ + type Message = NA::Message; + type Argument = (TcpStream, SocketAddr); + type Actor = NA::Actor; + type Error = NA::Error; + type RuntimeAccess = NA::RuntimeAccess; + + fn new( + &mut self, + ctx: actor::Context, + (stream, address): Self::Argument, + ) -> Result { + let conn = Connection::new(stream); + self.new_actor.new(ctx, (conn, address)) + } + + fn name(&self) -> &'static str { + self.new_actor.name() + } +} + +/// HTTP connection. +/// +/// This wraps a TCP stream from which [HTTP requests] are read and [HTTP +/// responses] are send to. +/// +/// It's advisable to set `TCP_NODELAY` using [`Connection::set_nodelay`] as the +/// `Connection` uses internally buffering, meaning only bodies with small +/// chunks would benefit from `TCP_NODELAY`. +/// +/// [HTTP requests]: Request +/// [HTTP responses]: crate::Response +#[derive(Debug)] +pub struct Connection { + stream: TcpStream, + buf: Vec, + /// Number of bytes of `buf` that are already parsed. + /// NOTE: this may be larger then `buf.len()`, in which case a `Body` was + /// dropped without reading it entirely. + parsed_bytes: usize, + /// The HTTP version of the last request. + last_version: Option, + /// The HTTP method of the last request. + last_method: Option, +} + +impl Connection { + /// Create a new `Connection`. + fn new(stream: TcpStream) -> Connection { + Connection { + stream, + buf: Vec::with_capacity(BUF_SIZE), + parsed_bytes: 0, + last_version: None, + last_method: None, + } + } + + /// Parse the next request from the connection. + /// + /// The return is a bit complex so let's break it down. The outer type is an + /// [`io::Result`], which often needs to be handled seperately from errors + /// in the request, e.g. by using `?`. + /// + /// Next is a `Result, `[`RequestError`]`>`. + /// `Ok(None)` is returned if the connection contains no more requests, i.e. + /// when all bytes are read. If the connection contains a request it will + /// return `Ok(Some(`[`Request`]`)`. If the request is somehow invalid it + /// will return an `Err(`[`RequestError`]`)`. + /// + /// # Notes + /// + /// Most [`RequestError`]s can't be receover from and the connection should + /// be closed when hitting them, see [`RequestError::should_close`]. 
If the + /// connection is not closed and `next_request` is called again it will + /// likely return the same error (but this is not guaranteed). + /// + /// Also see the [`Connection::last_request_version`] and + /// [`Connection::last_request_method`] functions to properly respond to + /// request errors. + #[allow(clippy::too_many_lines)] // TODO. + pub async fn next_request<'a>( + &'a mut self, + ) -> io::Result>>, RequestError>> { + // NOTE: not resetting the version as that doesn't change between + // requests. + self.last_method = None; + + let mut too_short = 0; + loop { + // In case of pipelined requests it could be that while reading a + // previous request's body it partially read the head of the next + // (this) request. To handle this we first attempt to parse the + // request if we have more than zero bytes (of the next request) in + // the first iteration of the loop. + while self.parsed_bytes >= self.buf.len() || self.buf.len() <= too_short { + // While we didn't read the entire previous request body, or + // while we have less than `too_short` bytes we try to receive + // some more bytes. + + self.clear_buffer(); + self.buf.reserve(MIN_READ_SIZE); + if self.stream.recv(&mut self.buf).await? == 0 { + return if self.buf.is_empty() { + // Read the entire stream, so we're done. + Ok(Ok(None)) + } else { + // Couldn't read any more bytes, but we still have bytes + // in the buffer. This means it contains a partial + // request. + Ok(Err(RequestError::IncompleteRequest)) + }; + } + } + + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut request = httparse::Request::new(&mut headers); + // SAFETY: because we received until at least `self.parsed_bytes >= + // self.buf.len()` above, we can safely slice the buffer.. + match request.parse(&self.buf[self.parsed_bytes..]) { + Ok(httparse::Status::Complete(head_length)) => { + self.parsed_bytes += head_length; + + // SAFETY: all these unwraps are safe because `parse` above + // ensures there all `Some`. + let method = match request.method.unwrap().parse() { + Ok(method) => method, + Err(_) => return Ok(Err(RequestError::UnknownMethod)), + }; + self.last_method = Some(method); + let path = request.path.unwrap().to_string(); + let version = map_version_byte(request.version.unwrap()); + self.last_version = Some(version); + + // RFC 7230 section 3.3.3 Message Body Length. + let mut body_length: Option = None; + let res = Headers::from_httparse_headers(request.headers, |name, value| { + if *name == HeaderName::CONTENT_LENGTH { + // RFC 7230 section 3.3.3 point 4: + // > If a message is received without + // > Transfer-Encoding and with either multiple + // > Content-Length header fields having differing + // > field-values or a single Content-Length header + // > field having an invalid value, then the message + // > framing is invalid and the recipient MUST treat + // > it as an unrecoverable error. If this is a + // > request message, the server MUST respond with a + // > 400 (Bad Request) status code and then close + // > the connection. 
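+                        // Repeated Content-Length headers are fine, as long
+                        // as they all have the same value.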
+ if let Ok(length) = FromHeaderValue::from_bytes(value) { + match body_length.as_mut() { + Some(BodyLength::Known(body_length)) + if *body_length == length => {} + Some(BodyLength::Known(_)) => { + return Err(RequestError::DifferentContentLengths) + } + Some(BodyLength::Chunked) => { + return Err(RequestError::ContentLengthAndTransferEncoding) + } + // RFC 7230 section 3.3.3 point 5: + // > If a valid Content-Length header field + // > is present without Transfer-Encoding, + // > its decimal value defines the expected + // > message body length in octets. + None => body_length = Some(BodyLength::Known(length)), + } + } else { + return Err(RequestError::InvalidContentLength); + } + } else if *name == HeaderName::TRANSFER_ENCODING { + let mut encodings = value.split(|b| *b == b',').peekable(); + while let Some(encoding) = encodings.next() { + match trim_ws(encoding) { + b"chunked" => { + // RFC 7230 section 3.3.3 point 3: + // > If a Transfer-Encoding header field + // > is present in a request and the + // > chunked transfer coding is not the + // > final encoding, the message body + // > length cannot be determined + // > reliably; the server MUST respond + // > with the 400 (Bad Request) status + // > code and then close the connection. + if encodings.peek().is_some() { + return Err( + RequestError::ChunkedNotLastTransferEncoding, + ); + } + + // RFC 7230 section 3.3.3 point 3: + // > If a message is received with both + // > a Transfer-Encoding and a + // > Content-Length header field, the + // > Transfer-Encoding overrides the + // > Content-Length. Such a message + // > might indicate an attempt to + // > perform request smuggling (Section + // > 9.5) or response splitting (Section + // > 9.4) and ought to be handled as an + // > error. + if body_length.is_some() { + return Err( + RequestError::ContentLengthAndTransferEncoding, + ); + } + + body_length = Some(BodyLength::Chunked); + } + b"identity" => {} // No changes. + // TODO: support "compress", "deflate" and + // "gzip". + _ => return Err(RequestError::UnsupportedTransferEncoding), + } + } + } + Ok(()) + }); + let headers = match res { + Ok(headers) => headers, + Err(err) => return Ok(Err(err)), + }; + + let kind = match body_length { + Some(BodyLength::Known(left)) => BodyKind::Oneshot { left }, + Some(BodyLength::Chunked) => { + #[allow(clippy::cast_possible_truncation)] // For truncate below. + match httparse::parse_chunk_size(&self.buf[self.parsed_bytes..]) { + Ok(httparse::Status::Complete((idx, chunk_size))) => { + self.parsed_bytes += idx; + BodyKind::Chunked { + // FIXME: add check here. It's fine on + // 64 bit (only currently supported). + left_in_chunk: chunk_size as usize, + read_complete: chunk_size == 0, + } + } + Ok(httparse::Status::Partial) => BodyKind::Chunked { + left_in_chunk: 0, + read_complete: false, + }, + Err(_) => return Ok(Err(RequestError::InvalidChunkSize)), + } + } + // RFC 7230 section 3.3.3 point 6: + // > If this is a request message and none of the above + // > are true, then the message body length is zero (no + // > message body is present). + None => BodyKind::Oneshot { left: 0 }, + }; + let body = Body { conn: self, kind }; + return Ok(Ok(Some(Request::new(method, path, version, headers, body)))); + } + Ok(httparse::Status::Partial) => { + // Buffer doesn't include the entire request head, try + // reading more bytes (in the next iteration). 
+ too_short = self.buf.len(); + self.last_method = request.method.and_then(|m| m.parse().ok()); + if let Some(version) = request.version { + self.last_version = Some(map_version_byte(version)); + } + + if too_short >= MAX_HEAD_SIZE { + return Ok(Err(RequestError::HeadTooLarge)); + } + + continue; + } + Err(err) => return Ok(Err(RequestError::from_httparse(err))), + } + } + } + + /// Returns the HTTP version of the last (partial) request. + /// + /// This can be used in cases where [`Connection::next_request`] returns a + /// [`RequestError`]. + /// + /// # Examples + /// + /// Responding to a [`RequestError`]. + /// + /// ``` + /// use heph_http::{Response, Headers, StatusCode, Version, Method}; + /// use heph_http::server::{Connection, RequestError}; + /// use heph_http::body::OneshotBody; + /// + /// # return; + /// # #[allow(unreachable_code)] + /// # { + /// let mut conn: Connection = /* From HttpServer. */ + /// # panic!("can't actually run example"); + /// + /// // Reading a request returned this error. + /// let err = RequestError::IncompleteRequest; + /// + /// // We can use `last_request_method` to determine the method of the last + /// // request, which is used to determine if we need to send a body. + /// let request_method = conn.last_request_method().unwrap_or(Method::Get); + /// + /// // We can use `last_request_version` to determine the client preferred + /// // HTTP version, or default to the server's preferred version (HTTP/1.1 + /// // here). + /// let version = conn.last_request_version().unwrap_or(Version::Http11); + /// + /// let msg = format!("Bad request: {}", err); + /// let body = OneshotBody::new(msg.as_bytes()); + /// + /// // Respond with the response. + /// conn.send_response(request_method, version, StatusCode::BAD_REQUEST, &Headers::EMPTY, body); + /// + /// // Close the connection if the error is fatal. + /// if err.should_close() { + /// return; + /// } + /// # } + /// ``` + pub fn last_request_version(&self) -> Option { + self.last_version + } + + /// Returns the HTTP method of the last (partial) request. + /// + /// This can be used in cases where [`Connection::next_request`] returns a + /// [`RequestError`]. + /// + /// # Examples + /// + /// See [`Connection::last_request_version`] for an example that responds to + /// a [`RequestError`], which uses `last_request_method`. + pub fn last_request_method(&self) -> Option { + self.last_method + } + + /// Respond to the last parsed request. + /// + /// # Notes + /// + /// This uses information from the last call to [`Connection::next_request`] + /// to respond to the request correctly. For example it uses the HTTP + /// [`Method`] to determine whether or not to send the body (as HEAD request + /// don't expect a body). When reading multiple requests from the connection + /// before responding use [`Connection::send_response`] directly. + /// + /// See the notes for [`Connection::send_response`], they apply to this + /// function also. + #[allow(clippy::future_not_send)] + pub async fn respond<'b, B>( + &mut self, + status: StatusCode, + headers: &Headers, + body: B, + ) -> io::Result<()> + where + B: crate::Body<'b>, + { + let req_method = self.last_method.unwrap_or(Method::Get); + let version = self.last_version.unwrap_or(Version::Http11).highest_minor(); + self.send_response(req_method, version, status, headers, body) + .await + } + + /// Send a [`Response`]. + /// + /// Arguments: + /// * `request_method` is the method used by the [`Request`], used to + /// determine if a body needs to be send. 
+ /// * `version`, `status`, `headers` and `body` make up the HTTP + /// [`Response`]. + /// + /// In most cases it's easier to use [`Connection::respond`], only when + /// reading two requests before responding is this function useful. + /// + /// [`Response`]: crate::Response + /// + /// # Notes + /// + /// This automatically sets the "Content-Length" or "Transfer-Encoding", + /// "Connection" and "Date" headers if not provided in `headers`. + /// + /// If `request_method.`[`expects_body()`] or `status.`[`includes_body()`] + /// returns `false` this will not write the body to the connection. + /// + /// [`expects_body()`]: Method::expects_body + /// [`includes_body()`]: StatusCode::includes_body + #[allow(clippy::future_not_send)] + pub async fn send_response<'b, B>( + &mut self, + request_method: Method, + // Response data: + version: Version, + status: StatusCode, + headers: &Headers, + body: B, + ) -> io::Result<()> + where + B: crate::Body<'b>, + { + let mut itoa_buf = itoa::Buffer::new(); + + // Clear bytes from the previous request, keeping the bytes of the + // request. + self.clear_buffer(); + let ignore_end = self.buf.len(); + + // Format the status-line (RFC 7230 section 3.1.2). + self.buf.extend_from_slice(version.as_str().as_bytes()); + self.buf.push(b' '); + self.buf + .extend_from_slice(itoa_buf.format(status.0).as_bytes()); + // NOTE: we're not sending a reason-phrase, but the space is required + // before \r\n. + self.buf.extend_from_slice(b" \r\n"); + + // Format the headers (RFC 7230 section 3.2). + let mut set_connection_header = false; + let mut set_content_length_header = false; + let mut set_transfer_encoding_header = false; + let mut set_date_header = false; + for header in headers.iter() { + let name = header.name(); + // Field-name: + self.buf.extend_from_slice(name.as_ref().as_bytes()); + // NOTE: spacing after the colon (`:`) is optional. + self.buf.extend_from_slice(b": "); + // Append the header's value. + // NOTE: `header.value` shouldn't contain CRLF (`\r\n`). + self.buf.extend_from_slice(header.value()); + self.buf.extend_from_slice(b"\r\n"); + + if name == &HeaderName::CONNECTION { + set_connection_header = true; + } else if name == &HeaderName::CONTENT_LENGTH { + set_content_length_header = true; + } else if name == &HeaderName::TRANSFER_ENCODING { + set_transfer_encoding_header = true; + } else if name == &HeaderName::DATE { + set_date_header = true; + } + } + + // Provide the "Connection" header if the user didn't. + if !set_connection_header && matches!(version, Version::Http10) { + // Per RFC 7230 section 6.3, HTTP/1.0 needs the "Connection: + // keep-alive" header to persistent the connection. Connections + // using HTTP/1.1 persistent by default. + self.buf.extend_from_slice(b"Connection: keep-alive\r\n"); + } + + // Provide the "Date" header if the user didn't. + if !set_date_header { + let now = HttpDate::from(SystemTime::now()); + write!(&mut self.buf, "Date: {}\r\n", now).unwrap(); + } + + // Provide the "Conent-Length" or "Transfer-Encoding" header if the user + // didn't. 
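+        // No body is written for HEAD requests or for status codes that
+        // don't allow a body, in which case "Content-Length: 0" is sent
+        // instead.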
+ let mut send_body = true; + if !set_content_length_header && !set_transfer_encoding_header { + match body.length() { + _ if !request_method.expects_body() || !status.includes_body() => { + send_body = false; + extend_content_length_header(&mut self.buf, &mut itoa_buf, 0) + } + BodyLength::Known(length) => { + extend_content_length_header(&mut self.buf, &mut itoa_buf, length) + } + BodyLength::Chunked => { + self.buf + .extend_from_slice(b"Transfer-Encoding: chunked\r\n"); + } + } + } + + // End of the HTTP head. + self.buf.extend_from_slice(b"\r\n"); + + // Write the response to the stream. + let http_head = &self.buf[ignore_end..]; + if send_body { + body.write_message(&mut self.stream, http_head).await?; + } else { + self.stream.send_all(http_head).await?; + } + + // Remove the response head from the buffer. + self.buf.truncate(ignore_end); + Ok(()) + } + + /// See [`TcpStream::peer_addr`]. + pub fn peer_addr(&mut self) -> io::Result { + self.stream.peer_addr() + } + + /// See [`TcpStream::local_addr`]. + pub fn local_addr(&mut self) -> io::Result { + self.stream.local_addr() + } + + /// See [`TcpStream::set_ttl`]. + pub fn set_ttl(&mut self, ttl: u32) -> io::Result<()> { + self.stream.set_ttl(ttl) + } + + /// See [`TcpStream::ttl`]. + pub fn ttl(&mut self) -> io::Result { + self.stream.ttl() + } + + /// See [`TcpStream::set_nodelay`]. + pub fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.stream.set_nodelay(nodelay) + } + + /// See [`TcpStream::nodelay`]. + pub fn nodelay(&mut self) -> io::Result { + self.stream.nodelay() + } + + /// See [`TcpStream::keepalive`]. + pub fn keepalive(&self) -> io::Result { + self.stream.keepalive() + } + + /// See [`TcpStream::set_keepalive`]. + pub fn set_keepalive(&self, enable: bool) -> io::Result<()> { + self.stream.set_keepalive(enable) + } + + /// Clear parsed request(s) from the buffer. + fn clear_buffer(&mut self) { + let buf_len = self.buf.len(); + if self.parsed_bytes >= buf_len { + // Parsed all bytes in the buffer, so we can clear it. + self.buf.clear(); + self.parsed_bytes -= buf_len; + } + + // TODO: move bytes to the start. + } + + /// Recv bytes from the underlying stream, reading into `self.buf`. + /// + /// Returns an `UnexpectedEof` error if zero bytes are received. + fn try_recv(&mut self) -> Poll> { + // Ensure we have space in the buffer to read into. + self.clear_buffer(); + self.buf.reserve(MIN_READ_SIZE); + + loop { + match self.stream.try_recv(&mut self.buf) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + Ok(n) => return Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + } + + /// Read a HTTP body chunk. + /// + /// Returns an I/O error, or an `InvalidData` error if the chunk size is + /// invalid. + fn try_read_chunk( + &mut self, + // Fields of `BodyKind::Chunked`: + left_in_chunk: &mut usize, + read_complete: &mut bool, + ) -> Poll> { + loop { + match httparse::parse_chunk_size(&self.buf[self.parsed_bytes..]) { + #[allow(clippy::cast_possible_truncation)] // For truncate below. + Ok(httparse::Status::Complete((idx, chunk_size))) => { + self.parsed_bytes += idx; + if chunk_size == 0 { + *read_complete = true; + } + // FIXME: add check here. It's fine on 64 bit (only currently + // supported). 
+ *left_in_chunk = chunk_size as usize; + return Poll::Ready(Ok(())); + } + Ok(httparse::Status::Partial) => {} // Read some more data below. + Err(_) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk size", + ))) + } + } + + let _ = ready!(self.try_recv())?; + } + } + + async fn read_chunk( + &mut self, + // Fields of `BodyKind::Chunked`: + left_in_chunk: &mut usize, + read_complete: &mut bool, + ) -> io::Result<()> { + loop { + match httparse::parse_chunk_size(&self.buf[self.parsed_bytes..]) { + #[allow(clippy::cast_possible_truncation)] // For truncate below. + Ok(httparse::Status::Complete((idx, chunk_size))) => { + self.parsed_bytes += idx; + if chunk_size == 0 { + *read_complete = true; + } + // FIXME: add check here. It's fine on 64 bit (only currently + // supported). + *left_in_chunk = chunk_size as usize; + return Ok(()); + } + Ok(httparse::Status::Partial) => {} // Read some more data below. + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid chunk size", + )) + } + } + + // Ensure we have space in the buffer to read into. + self.clear_buffer(); + self.buf.reserve(MIN_READ_SIZE); + + if self.stream.recv(&mut self.buf).await? == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + } +} + +/// Add "Content-Length" header to `buf`. +fn extend_content_length_header( + buf: &mut Vec, + itoa_buf: &mut itoa::Buffer, + content_length: usize, +) { + buf.extend_from_slice(b"Content-Length: "); + buf.extend_from_slice(itoa_buf.format(content_length).as_bytes()); + buf.extend_from_slice(b"\r\n"); +} + +/// Body of HTTP [`Request`] read from a [`Connection`]. +/// +/// # Notes +/// +/// If the body is not (completely) read before this is dropped it will still +/// removed from the `Connection`. +#[derive(Debug)] +pub struct Body<'a> { + conn: &'a mut Connection, + kind: BodyKind, +} + +#[derive(Debug)] +enum BodyKind { + /// No encoding. + Oneshot { + /// Number of unread (by the user) bytes. + left: usize, + }, + /// Chunked transfer encoding. + Chunked { + /// Number of unread (by the user) bytes in this chunk. + left_in_chunk: usize, + /// Read all chunks. + read_complete: bool, + }, +} + +impl<'a> Body<'a> { + /// Returns the length of the body (in bytes) *left*, or a + /// + /// Calling this before [`recv`] or [`recv_vectored`] will return the + /// original body length, after removing bytes from the body this will + /// return the *remaining* length. + /// + /// The body length is determined by the "Content-Length" or + /// "Transfer-Encoding" header, or 0 if neither are present. + /// + /// [`recv`]: Body::recv + /// [`recv_vectored`]: Body::recv_vectored + pub fn len(&self) -> BodyLength { + match self.kind { + BodyKind::Oneshot { left } => BodyLength::Known(left), + BodyKind::Chunked { .. } => BodyLength::Chunked, + } + } + + /// Return the length of this chunk *left*, or the entire body in case of a + /// oneshot body. + fn chunk_len(&self) -> usize { + match self.kind { + BodyKind::Oneshot { left } => left, + BodyKind::Chunked { left_in_chunk, .. } => left_in_chunk, + } + } + + /// Returns `true` if the body is completely read (or was empty to begin + /// with). + /// + /// # Notes + /// + /// This can return `false` for empty bodies using chunked encoding if not + /// enough bytes have been read yet. Using chunked encoding we don't know + /// the length upfront as it it's determined by reading the length of each + /// chunk. If the send request only contained the HTTP head (i.e. 
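Both chunk readers above lean on `httparse::parse_chunk_size`, which either yields the index just past the size line together with the chunk size, reports that more data is needed, or rejects the input. A small sketch of that behaviour; the `chunk_size` helper is hypothetical, but the `parse_chunk_size` call and the expected results follow the example in the httparse documentation.

```rust
// Returns the index just past the size line and the announced chunk size,
// or `None` if more data is needed (or the size line is invalid; both are
// treated the same here for brevity).
fn chunk_size(buf: &[u8]) -> Option<(usize, u64)> {
    match httparse::parse_chunk_size(buf) {
        Ok(httparse::Status::Complete((idx, size))) => Some((idx, size)),
        Ok(httparse::Status::Partial) => None,
        Err(_) => None,
    }
}

fn main() {
    // "2\r\nOk": the size line is 3 bytes long and announces a 2 byte chunk.
    assert_eq!(chunk_size(b"2\r\nOk"), Some((3, 2)));
    // The final, empty chunk has size 0.
    assert_eq!(chunk_size(b"0\r\n"), Some((3, 0)));
    // Not enough bytes yet to parse a complete size line.
    assert_eq!(chunk_size(b"2"), None);
}
```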
no body) + /// and uses chunked encoding this would return `false`, as body length is + /// unknown and thus not empty. However if the body would then send a single + /// empty chunk (signaling the end of the body), this would return `true` as + /// it turns out the body is indeed empty. + pub fn is_empty(&self) -> bool { + match self.kind { + BodyKind::Oneshot { left } => left == 0, + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => read_complete && left_in_chunk == 0, + } + } + + /// Returns `true` if the body is chunked. + pub fn is_chunked(&self) -> bool { + matches!(self.kind, BodyKind::Chunked { .. }) + } + + /// Receive bytes from the request body, writing them into `buf`. + pub const fn recv(&'a mut self, buf: B) -> Recv<'a, B> + where + B: Bytes, + { + Recv { body: self, buf } + } + + /// Receive bytes from the request body, writing them into `bufs`. + pub const fn recv_vectored(&'a mut self, bufs: B) -> RecvVectored<'a, B> + where + B: BytesVectored, + { + RecvVectored { body: self, bufs } + } + + /// Read the entire body into `buf`, up to `limit` bytes. + /// + /// If the body is larger then `limit` bytes it return an `io::Error`. + pub async fn read_all(&mut self, buf: &mut Vec, limit: usize) -> io::Result<()> { + let mut total = 0; + loop { + // Copy bytes in our buffer. + let bytes = self.buf_bytes(); + let len = bytes.len(); + if limit < total + len { + return Err(io::Error::new(io::ErrorKind::Other, "body too large")); + } + + buf.extend_from_slice(bytes); + self.processed(len); + total += len; + + let chunk_len = self.chunk_len(); + if chunk_len == 0 { + match &mut self.kind { + // Read all the bytes from the oneshot body. + BodyKind::Oneshot { .. } => return Ok(()), + // Read all the bytes in the chunk, so need to read another + // chunk. + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => { + if *read_complete { + return Ok(()); + } + + self.conn.read_chunk(left_in_chunk, read_complete).await?; + // Copy read bytes again. + continue; + } + } + } + // Continue to reading below. + break; + } + + loop { + // Limit the read until the end of the chunk/body. + let chunk_len = self.chunk_len(); + if chunk_len == 0 { + return Ok(()); + } else if total + chunk_len > limit { + return Err(io::Error::new(io::ErrorKind::Other, "body too large")); + } + + (&mut *buf).reserve(chunk_len); + self.conn.stream.recv_n(&mut *buf, chunk_len).await?; + total += chunk_len; + + // FIXME: doesn't deal with chunked bodies. + } + } + + /// Returns the bytes currently in the buffer. + /// + /// This is limited to the bytes of this request/chunk, i.e. it doesn't + /// contain the next request/chunk. + fn buf_bytes(&self) -> &[u8] { + let bytes = &self.conn.buf[self.conn.parsed_bytes..]; + let left = match self.kind { + BodyKind::Oneshot { left } => left, + BodyKind::Chunked { left_in_chunk, .. } => left_in_chunk, + }; + if bytes.len() > left { + &bytes[..left] + } else { + bytes + } + } + + /// Copy already read bytes. + /// + /// Same as [`Body::buf_bytes`] this is limited to the bytes of this + /// request/chunk, i.e. it doesn't contain the next request/chunk. + fn copy_buf_bytes(&mut self, dst: &mut [MaybeUninit]) -> usize { + let bytes = self.buf_bytes(); + let len = min(bytes.len(), dst.len()); + if len != 0 { + let _ = MaybeUninit::write_slice(&mut dst[..len], &bytes[..len]); + self.processed(len); + } + len + } + + /// Mark `n` bytes are processed. + fn processed(&mut self, n: usize) { + // TODO: should this be `unsafe`? We don't do underflow checks... 
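The limit check in `read_all` can also be expressed against a plain `std::io::Read` source; below is a sketch under that assumption, with `read_all_limited` being a hypothetical helper rather than the crate's API.

```rust
use std::io::{self, Read};

fn read_all_limited<R: Read>(mut src: R, limit: usize) -> io::Result<Vec<u8>> {
    let mut buf = Vec::new();
    // Cap the read at one byte over the limit so an oversized body is
    // detectable without reading it in full.
    src.by_ref().take(limit as u64 + 1).read_to_end(&mut buf)?;
    if buf.len() > limit {
        return Err(io::Error::new(io::ErrorKind::Other, "body too large"));
    }
    Ok(buf)
}

fn main() -> io::Result<()> {
    let body = read_all_limited(&mut &b"Hello world"[..], 1024)?;
    assert_eq!(body, b"Hello world");
    // A 4 byte limit rejects the 11 byte body.
    assert!(read_all_limited(&mut &b"Hello world"[..], 4).is_err());
    Ok(())
}
```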
+ match &mut self.kind { + BodyKind::Oneshot { left } => *left -= n, + BodyKind::Chunked { left_in_chunk, .. } => *left_in_chunk -= n, + } + self.conn.parsed_bytes += n; + } +} + +/// The [`Future`] behind [`Body::recv`]. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Recv<'b, B> { + body: &'b mut Body<'b>, + buf: B, +} + +impl<'b, B> Future for Recv<'b, B> +where + B: Bytes + Unpin, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll { + let Recv { body, buf } = Pin::into_inner(self); + + let mut len = 0; + loop { + // Copy bytes in our buffer. + len += body.copy_buf_bytes(buf.as_bytes()); + if len != 0 { + unsafe { buf.update_length(len) }; + } + + let limit = body.chunk_len(); + if limit == 0 { + match &mut body.kind { + // Read all the bytes from the oneshot body. + BodyKind::Oneshot { .. } => return Poll::Ready(Ok(len)), + // Read all the bytes in the chunk, so need to read another + // chunk. + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => { + ready!(body.conn.try_read_chunk(left_in_chunk, read_complete))?; + // Copy read bytes again. + continue; + } + } + } + // Continue to reading below. + break; + } + + // Read from the stream if there is space left. + if buf.has_spare_capacity() { + // Limit the read until the end of the chunk/body. + let limit = body.chunk_len(); + loop { + match body.conn.stream.try_recv(buf.limit(limit)) { + Ok(n) => return Poll::Ready(Ok(len + n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return if len == 0 { + Poll::Pending + } else { + Poll::Ready(Ok(len)) + } + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + } else { + Poll::Ready(Ok(len)) + } + } +} + +/// The [`Future`] behind [`Body::recv_vectored`]. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct RecvVectored<'b, B> { + body: &'b mut Body<'b>, + bufs: B, +} + +impl<'b, B> Future for RecvVectored<'b, B> +where + B: BytesVectored + Unpin, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll { + let RecvVectored { body, bufs } = Pin::into_inner(self); + + let mut len = 0; + loop { + // Copy bytes in our buffer. + for buf in bufs.as_bufs().as_mut() { + match body.copy_buf_bytes(buf) { + 0 => break, + n => len += n, + } + } + if len != 0 { + unsafe { bufs.update_lengths(len) }; + } + + let limit = body.chunk_len(); + if limit == 0 { + match &mut body.kind { + // Read all the bytes from the oneshot body. + BodyKind::Oneshot { .. } => return Poll::Ready(Ok(len)), + // Read all the bytes in the chunk, so need to read another + // chunk. + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => { + ready!(body.conn.try_read_chunk(left_in_chunk, read_complete))?; + // Copy read bytes again. + continue; + } + } + } + // Continue to reading below. + break; + } + + // Read from the stream if there is space left. + if bufs.has_spare_capacity() { + // Limit the read until the end of the chunk/body. 
+ let limit = body.chunk_len(); + loop { + match body.conn.stream.try_recv_vectored(bufs.limit(limit)) { + Ok(n) => return Poll::Ready(Ok(len + n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return if len == 0 { + Poll::Pending + } else { + Poll::Ready(Ok(len)) + } + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + } else { + Poll::Ready(Ok(len)) + } + } +} + +impl<'a> crate::Body<'a> for Body<'a> { + fn length(&self) -> BodyLength { + self.len() + } +} + +mod private { + use std::future::Future; + use std::io; + use std::pin::Pin; + use std::task::{self, ready, Poll}; + + use heph::net::TcpStream; + + use super::{Body, BodyKind}; + + #[derive(Debug)] + pub struct SendBody<'c, 's, 'h> { + pub(super) body: Body<'c>, + /// Stream we're writing the body to. + pub(super) stream: &'s mut TcpStream, + /// HTTP head for the response. + pub(super) head: &'h [u8], + } + + impl<'c, 's, 'h> Future for SendBody<'c, 's, 'h> { + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll { + let SendBody { body, stream, head } = Pin::into_inner(self); + + // Send the HTTP head first. + // TODO: try to use vectored I/O on first call. + while !head.is_empty() { + match stream.try_send(head) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => *head = &head[n..], + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + + while !body.is_empty() { + let limit = body.chunk_len(); + let bytes = body.buf_bytes(); + let bytes = if bytes.len() > limit { + &bytes[..limit] + } else { + bytes + }; + // TODO: maybe read first if we have less then N bytes? + if !bytes.is_empty() { + match stream.try_send(bytes) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => { + body.processed(n); + continue; + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(err) => return Poll::Ready(Err(err)), + } + // NOTE: we don't continue here, we always return on start + // the next iteration of the loop. + } + + // Read some more data, or the next chunk. + match &mut body.kind { + BodyKind::Oneshot { .. } => { + let _ = ready!(body.conn.try_recv())?; + } + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => { + if *left_in_chunk == 0 { + ready!(body.conn.try_read_chunk(left_in_chunk, read_complete))?; + } else { + let _ = ready!(body.conn.try_recv())?; + } + } + } + } + + Poll::Ready(Ok(())) + } + } +} + +impl<'c> crate::body::PrivateBody<'c> for Body<'c> { + type WriteBody<'s, 'h> = private::SendBody<'c, 's, 'h>; + + fn write_message<'s, 'h>( + self, + stream: &'s mut TcpStream, + head: &'h [u8], + ) -> Self::WriteBody<'s, 'h> + where + 'c: 'h, + { + private::SendBody { + body: self, + stream, + head, + } + } +} + +impl<'a> Drop for Body<'a> { + fn drop(&mut self) { + if self.is_empty() { + // Empty body, then we're done quickly. + return; + } + + // Mark the entire body as parsed. + // NOTE: `Connection` handles the case where we didn't read the entire + // body yet. + match self.kind { + BodyKind::Oneshot { left } => self.conn.parsed_bytes += left, + BodyKind::Chunked { + left_in_chunk, + read_complete, + } => { + if read_complete { + // Read all chunks. 
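`SendBody` above writes the response head first and then streams the body, treating a zero-byte write as `WriteZero`. The same structure in blocking form, against any `std::io::Write`; both helpers here are hypothetical sketches, not the patch's code.

```rust
use std::io::{self, Write};

// Write a full buffer, retrying on `Interrupted` and failing on zero-sized
// writes, mirroring the polling loop above.
fn send_all<W: Write>(stream: &mut W, mut bytes: &[u8]) -> io::Result<()> {
    while !bytes.is_empty() {
        match stream.write(bytes) {
            Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
            Ok(n) => bytes = &bytes[n..],
            Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
    Ok(())
}

fn send_response<W: Write>(stream: &mut W, head: &[u8], body: &[u8]) -> io::Result<()> {
    // The head always goes out first, then the (possibly empty) body.
    send_all(stream, head)?;
    send_all(stream, body)
}
```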
+ debug_assert_eq!(left_in_chunk, 0); + } else { + // FIXME: don't panic here. + todo!("remove chunked body from connection"); + } + } + } + } +} + +/// Error parsing HTTP request. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum RequestError { + /// Missing part of request. + IncompleteRequest, + /// HTTP Head (start line and headers) is too large. + /// + /// Limit is defined by [`MAX_HEAD_SIZE`]. + HeadTooLarge, + /// Value in the "Content-Length" header is invalid. + InvalidContentLength, + /// Multiple "Content-Length" headers were present with differing values. + DifferentContentLengths, + /// Invalid byte in header name. + InvalidHeaderName, + /// Invalid byte in header value. + InvalidHeaderValue, + /// Number of headers send in the request is larger than [`MAX_HEADERS`]. + TooManyHeaders, + /// Unsupported "Transfer-Encoding" header. + UnsupportedTransferEncoding, + /// Request has a "Transfer-Encoding" header with a chunked encoding, but + /// it's not the final encoding, then the message body length cannot be + /// determined reliably. + /// + /// See RFC 7230 section 3.3.3 point 3. + ChunkedNotLastTransferEncoding, + /// Request contains both "Content-Length" and "Transfer-Encoding" headers. + /// + /// An attacker might attempt to "smuggle a request" ("HTTP Request + /// Smuggling", Linhart et al., June 2005) or "split a response" ("Divide + /// and Conquer - HTTP Response Splitting, Web Cache Poisoning Attacks, and + /// Related Topics", Klein, March 2004). RFC 7230 (see section 3.3.3 point + /// 3) says that this "ought to be handled as an error", and so we do. + ContentLengthAndTransferEncoding, + /// Invalid byte where token is required. + InvalidToken, + /// Invalid byte in new line. + InvalidNewLine, + /// Invalid byte in HTTP version. + InvalidVersion, + /// Unknown HTTP method, not in [`Method`]. + UnknownMethod, + /// Chunk size is invalid. + InvalidChunkSize, +} + +impl RequestError { + /// Returns the proper status code for a given error. + pub const fn proper_status_code(self) -> StatusCode { + use RequestError::*; + // See the parsing code for various references to the RFC(s) that + // determine the values here. + match self { + IncompleteRequest + | HeadTooLarge + | InvalidContentLength + | DifferentContentLengths + | InvalidHeaderName + | InvalidHeaderValue + | TooManyHeaders + | ChunkedNotLastTransferEncoding + | ContentLengthAndTransferEncoding + | InvalidToken + | InvalidNewLine + | InvalidVersion + | InvalidChunkSize=> StatusCode::BAD_REQUEST, + // RFC 7230 section 3.3.1: + // > A server that receives a request message with a transfer coding + // > it does not understand SHOULD respond with 501 (Not + // > Implemented). + UnsupportedTransferEncoding + // RFC 7231 section 4.1: + // > When a request method is received that is unrecognized or not + // > implemented by an origin server, the origin server SHOULD + // > respond with the 501 (Not Implemented) status code. + | UnknownMethod => StatusCode::NOT_IMPLEMENTED, + } + } + + /// Returns `true` if the connection should be closed based on the error + /// (after sending a error response). + pub const fn should_close(self) -> bool { + use RequestError::*; + // See the parsing code for various references to the RFC(s) that + // determine the values here. 
+        match self {
+            IncompleteRequest
+            | HeadTooLarge
+            | InvalidContentLength
+            | DifferentContentLengths
+            | InvalidHeaderName
+            | InvalidHeaderValue
+            | UnsupportedTransferEncoding
+            | TooManyHeaders
+            | ChunkedNotLastTransferEncoding
+            | ContentLengthAndTransferEncoding
+            | InvalidToken
+            | InvalidNewLine
+            | InvalidVersion
+            | InvalidChunkSize => true,
+            UnknownMethod => false,
+        }
+    }
+
+    fn from_httparse(err: httparse::Error) -> RequestError {
+        use httparse::Error::*;
+        match err {
+            HeaderName => RequestError::InvalidHeaderName,
+            HeaderValue => RequestError::InvalidHeaderValue,
+            Token => RequestError::InvalidToken,
+            NewLine => RequestError::InvalidNewLine,
+            Version => RequestError::InvalidVersion,
+            TooManyHeaders => RequestError::TooManyHeaders,
+            // SAFETY: requests never contain a status, only responses do.
+            Status => unreachable!(),
+        }
+    }
+}
+
+impl fmt::Display for RequestError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        use RequestError::*;
+        f.write_str(match self {
+            IncompleteRequest => "incomplete request",
+            HeadTooLarge => "head too large",
+            InvalidContentLength => "invalid Content-Length header",
+            DifferentContentLengths => "different Content-Length headers",
+            InvalidHeaderName => "invalid header name",
+            InvalidHeaderValue => "invalid header value",
+            TooManyHeaders => "too many headers",
+            UnsupportedTransferEncoding => "unsupported Transfer-Encoding",
+            ChunkedNotLastTransferEncoding => "invalid Transfer-Encoding header",
+            ContentLengthAndTransferEncoding => {
+                "provided both Content-Length and Transfer-Encoding headers"
+            }
+            InvalidToken | InvalidNewLine => "invalid request syntax",
+            InvalidVersion => "invalid version",
+            UnknownMethod => "unknown method",
+            InvalidChunkSize => "invalid chunk size",
+        })
+    }
+}
+
+/// The message type used by [`HttpServer`] (and [`TcpServer`]).
+///
+#[doc(inline)]
+pub use heph::net::tcp::server::Message;
+
+/// Error returned by [`HttpServer`] (and [`TcpServer`]).
+///
+#[doc(inline)]
+pub use heph::net::tcp::server::Error;
diff --git a/http/src/status_code.rs b/http/src/status_code.rs
new file mode 100644
index 000000000..201018172
--- /dev/null
+++ b/http/src/status_code.rs
@@ -0,0 +1,391 @@
+use std::fmt;
+
+/// Response Status Code.
+///
+/// A complete list can be found at the HTTP Status Code Registry:
+/// <https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml>.
+///
+/// RFC 7231 section 6.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub struct StatusCode(pub u16);
+
+impl StatusCode {
+    // 1xx range.
+    /// Continue.
+    ///
+    /// RFC 7231 section 6.2.1.
+    pub const CONTINUE: StatusCode = StatusCode(100);
+    /// Switching Protocols.
+    ///
+    /// RFC 7231 section 6.2.2.
+    pub const SWITCHING_PROTOCOLS: StatusCode = StatusCode(101);
+    /// Processing.
+    ///
+    /// RFC 2518.
+    pub const PROCESSING: StatusCode = StatusCode(102);
+    /// Early Hints.
+    ///
+    /// RFC 8297.
+    pub const EARLY_HINTS: StatusCode = StatusCode(103);
+
+    // 2xx range.
+    /// OK.
+    ///
+    /// RFC 7231 section 6.3.1.
+    pub const OK: StatusCode = StatusCode(200);
+    /// Created.
+    ///
+    /// RFC 7231 section 6.3.2.
+    pub const CREATED: StatusCode = StatusCode(201);
+    /// Accepted.
+    ///
+    /// RFC 7231 section 6.3.3.
+    pub const ACCEPTED: StatusCode = StatusCode(202);
+    /// Non-Authoritative Information.
+    ///
+    /// RFC 7231 section 6.3.4.
+    pub const NON_AUTHORITATIVE_INFORMATION: StatusCode = StatusCode(203);
+    /// No Content.
+    ///
+    /// RFC 7231 section 6.3.5.
+    pub const NO_CONTENT: StatusCode = StatusCode(204);
+    /// Reset Content.
+    ///
+    /// RFC 7231 section 6.3.6.
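Putting `RequestError` to use: a server loop would typically map the error to a status code for the error response and then decide whether to keep the connection open. A small sketch, assuming the `heph_http::server::RequestError` path that the tests later in this patch use; the `report` function is hypothetical.

```rust
use heph_http::server::RequestError;

fn report(err: RequestError) {
    // Map the parse error onto the status code to respond with (400 or 501)
    // and decide whether to close the connection afterwards.
    let status = err.proper_status_code();
    let close = err.should_close();
    println!("bad request: {} -> {} (close connection: {})", err, status, close);
}

fn main() {
    report(RequestError::HeadTooLarge);
    report(RequestError::UnknownMethod);
}
```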
+ pub const RESET_CONTENT: StatusCode = StatusCode(205); + /// Partial Content. + /// + /// RFC 7233 section 4.1. + pub const PARTIAL_CONTENT: StatusCode = StatusCode(206); + /// Multi-Status. + /// + /// RFC 4918. + pub const MULTI_STATUS: StatusCode = StatusCode(207); + /// Already Reported. + /// + /// RFC 5842. + pub const ALREADY_REPORTED: StatusCode = StatusCode(208); + /// IM Used. + /// + /// RFC 3229. + pub const IM_USED: StatusCode = StatusCode(226); + + // 3xx range. + /// Multiple Choices. + /// + /// RFC 7231 section 6.4.1. + pub const MULTIPLE_CHOICES: StatusCode = StatusCode(300); + /// Moved Permanently. + /// + /// RFC 7231 section 6.4.2. + pub const MOVED_PERMANENTLY: StatusCode = StatusCode(301); + /// Found. + /// + /// RFC 7231 section 6.4.3. + pub const FOUND: StatusCode = StatusCode(302); + /// See Other. + /// + /// RFC 7231 section 6.4.4. + pub const SEE_OTHER: StatusCode = StatusCode(303); + /// Not Modified. + /// + /// RFC 7232 section 4.1. + pub const NOT_MODIFIED: StatusCode = StatusCode(304); + // NOTE: 306 is unused, per RFC 7231 section 6.4.6. + /// Use Proxy. + /// + /// RFC 7231 section 6.4.5. + pub const USE_PROXY: StatusCode = StatusCode(305); + /// Temporary Redirect. + /// + /// RFC 7231 section 6.4.7. + pub const TEMPORARY_REDIRECT: StatusCode = StatusCode(307); + /// Permanent Redirect. + /// + /// RFC 7538. + pub const PERMANENT_REDIRECT: StatusCode = StatusCode(308); + + // 4xx range. + /// Bad Request. + /// + /// RFC 7231 section 6.5.1. + pub const BAD_REQUEST: StatusCode = StatusCode(400); + /// Unauthorized. + /// + /// RFC 7235 section 3.1. + pub const UNAUTHORIZED: StatusCode = StatusCode(401); + /// Payment Required. + /// + /// RFC 7231 section 6.5.2. + pub const PAYMENT_REQUIRED: StatusCode = StatusCode(402); + /// Forbidden. + /// + /// RFC 7231 section 6.5.3. + pub const FORBIDDEN: StatusCode = StatusCode(403); + /// Not Found. + /// + /// RFC 7231 section 6.5.4. + pub const NOT_FOUND: StatusCode = StatusCode(404); + /// Method Not Allowed. + /// + /// RFC 7231 section 6.5.5. + pub const METHOD_NOT_ALLOWED: StatusCode = StatusCode(405); + /// Not Acceptable. + /// + /// RFC 7231 section 6.5.6. + pub const NOT_ACCEPTABLE: StatusCode = StatusCode(406); + /// Proxy Authentication Required. + /// + /// RFC 7235 section 3.2. + pub const PROXY_AUTHENTICATION_REQUIRED: StatusCode = StatusCode(407); + /// Request Timeout. + /// + /// RFC 7231 section 6.5.7. + pub const REQUEST_TIMEOUT: StatusCode = StatusCode(408); + /// Conflict. + /// + /// RFC 7231 section 6.5.8. + pub const CONFLICT: StatusCode = StatusCode(409); + /// Gone. + /// + /// RFC 7231 section 6.5.9. + pub const GONE: StatusCode = StatusCode(410); + /// Length Required. + /// + /// RFC 7231 section 6.5.10. + pub const LENGTH_REQUIRED: StatusCode = StatusCode(411); + /// Precondition Failed. + /// + /// RFC 7232 section 4.2 and RFC 8144 section 3.2. + pub const PRECONDITION_FAILED: StatusCode = StatusCode(412); + /// Payload Too Large. + /// + /// RFC 7231 section 6.5.11. + pub const PAYLOAD_TOO_LARGE: StatusCode = StatusCode(413); + /// URI Too Long. + /// + /// RFC 7231 section 6.5.12. + pub const URI_TOO_LONG: StatusCode = StatusCode(414); + /// Unsupported Media Type. + /// + /// RFC 7231 section 6.5.13 and RFC 7694 section 3. + pub const UNSUPPORTED_MEDIA_TYPE: StatusCode = StatusCode(415); + /// Range Not Satisfiable. + /// + /// RFC 7233 section 4.4. + pub const RANGE_NOT_SATISFIABLE: StatusCode = StatusCode(416); + /// Expectation Failed. 
+ /// + /// RFC 7231 section 6.5.14. + pub const EXPECTATION_FAILED: StatusCode = StatusCode(417); + // NOTE: 418-420 are unassigned. + /// Misdirected Request. + /// + /// RFC 7540 section 9.1.2. + pub const MISDIRECTED_REQUEST: StatusCode = StatusCode(421); + /// Unprocessable Entity. + /// + /// RFC 4918. + pub const UNPROCESSABLE_ENTITY: StatusCode = StatusCode(422); + /// Locked. + /// + /// RFC 4918. + pub const LOCKED: StatusCode = StatusCode(423); + /// Failed Dependency. + /// + /// RFC 4918. + pub const FAILED_DEPENDENCY: StatusCode = StatusCode(424); + /// Too Early. + /// + /// RFC 8470. + pub const TOO_EARLY: StatusCode = StatusCode(425); + /// Upgrade Required. + /// + /// RFC 7231 section 6.5.15. + pub const UPGRADE_REQUIRED: StatusCode = StatusCode(426); + // NOTE: 427 is unassigned. + /// Precondition Required. + /// + /// RFC 6585. + pub const PRECONDITION_REQUIRED: StatusCode = StatusCode(428); + /// Too Many Requests. + /// + /// RFC 6585. + pub const TOO_MANY_REQUESTS: StatusCode = StatusCode(429); + // NOTE: 320 is unassigned. + /// Request Header Fields Too Large. + /// + /// RFC 6585. + pub const REQUEST_HEADER_FIELDS_TOO_LARGE: StatusCode = StatusCode(431); + // NOTE: 432-450 are unassigned. + /// Unavailable For Legal Reasons. + /// + /// RFC 7725. + pub const UNAVAILABLE_FOR_LEGAL_REASONS: StatusCode = StatusCode(451); + + // 5xx range. + /// Internal Server Error. + /// + /// RFC 7231 section 6.6.1. + pub const INTERNAL_SERVER_ERROR: StatusCode = StatusCode(500); + /// Not Implemented. + /// + /// RFC 7231 section 6.6.2. + pub const NOT_IMPLEMENTED: StatusCode = StatusCode(501); + /// Bad Gateway. + /// + /// RFC 7231 section 6.6.3. + pub const BAD_GATEWAY: StatusCode = StatusCode(502); + /// Service Unavailable. + /// + /// RFC 7231 section 6.6.4. + pub const SERVICE_UNAVAILABLE: StatusCode = StatusCode(503); + /// Gateway Timeout. + /// + /// RFC 7231 section 6.6.5. + pub const GATEWAY_TIMEOUT: StatusCode = StatusCode(504); + /// HTTP Version Not Supported. + /// + /// RFC 7231 section 6.6.6. + pub const HTTP_VERSION_NOT_SUPPORTED: StatusCode = StatusCode(505); + /// Variant Also Negotiates. + /// + /// RFC 2295. + pub const VARIANT_ALSO_NEGOTIATES: StatusCode = StatusCode(506); + /// Insufficient Storage. + /// + /// RFC 4918. + pub const INSUFFICIENT_STORAGE: StatusCode = StatusCode(507); + /// Loop Detected. + /// + /// RFC 5842. + pub const LOOP_DETECTED: StatusCode = StatusCode(508); + // NOTE: 509 is unassigned. + /// Not Extended. + /// + /// RFC 2774. + pub const NOT_EXTENDED: StatusCode = StatusCode(510); + /// Network Authentication Required. + /// + /// RFC 6585. + pub const NETWORK_AUTHENTICATION_REQUIRED: StatusCode = StatusCode(511); + + /// Returns `true` if the status code is in 1xx range. + pub const fn is_informational(self) -> bool { + self.0 >= 100 && self.0 <= 199 + } + + /// Returns `true` if the status code is in 2xx range. + pub const fn is_successful(self) -> bool { + self.0 >= 200 && self.0 <= 299 + } + + /// Returns `true` if the status code is in 3xx range. + pub const fn is_redirect(self) -> bool { + self.0 >= 300 && self.0 <= 399 + } + + /// Returns `true` if the status code is in 4xx range. + pub const fn is_client_error(self) -> bool { + self.0 >= 400 && self.0 <= 499 + } + + /// Returns `true` if the status code is in 5xx range. + pub const fn is_server_error(self) -> bool { + self.0 >= 500 && self.0 <= 599 + } + + /// Returns `false` if the status code MUST NOT include a body. 
+    ///
+    /// This includes the entire 1xx (Informational) range, 204 (No Content),
+    /// and 304 (Not Modified).
+    ///
+    /// Also see RFC 7230 section 3.3 and RFC 7231 section 6 (the individual
+    /// status codes).
+    pub const fn includes_body(self) -> bool {
+        // RFC 7230 section 3.3:
+        // > All 1xx (Informational), 204 (No Content), and 304 (Not Modified)
+        // > responses do not include a message body. All other responses do
+        // > include a message body, although the body might be of zero length.
+        !matches!(self.0, 100..=199 | 204 | 304)
+    }
+
+    /// Returns the reason phrase for well known status codes.
+    pub const fn phrase(self) -> Option<&'static str> {
+        match self.0 {
+            100 => Some("Continue"),
+            101 => Some("Switching Protocols"),
+            102 => Some("Processing"),
+            103 => Some("Early Hints"),
+
+            200 => Some("OK"),
+            201 => Some("Created"),
+            202 => Some("Accepted"),
+            203 => Some("Non-Authoritative Information"),
+            204 => Some("No Content"),
+            205 => Some("Reset Content"),
+            206 => Some("Partial Content"),
+            207 => Some("Multi-Status"),
+            208 => Some("Already Reported"),
+            226 => Some("IM Used"),
+
+            300 => Some("Multiple Choices"),
+            301 => Some("Moved Permanently"),
+            302 => Some("Found"),
+            303 => Some("See Other"),
+            304 => Some("Not Modified"),
+            305 => Some("Use Proxy"),
+            307 => Some("Temporary Redirect"),
+            308 => Some("Permanent Redirect"),
+
+            400 => Some("Bad Request"),
+            401 => Some("Unauthorized"),
+            402 => Some("Payment Required"),
+            403 => Some("Forbidden"),
+            404 => Some("Not Found"),
+            405 => Some("Method Not Allowed"),
+            406 => Some("Not Acceptable"),
+            407 => Some("Proxy Authentication Required"),
+            408 => Some("Request Timeout"),
+            409 => Some("Conflict"),
+            410 => Some("Gone"),
+            411 => Some("Length Required"),
+            412 => Some("Precondition Failed"),
+            413 => Some("Payload Too Large"),
+            414 => Some("URI Too Long"),
+            415 => Some("Unsupported Media Type"),
+            416 => Some("Range Not Satisfiable"),
+            417 => Some("Expectation Failed"),
+            421 => Some("Misdirected Request"),
+            422 => Some("Unprocessable Entity"),
+            423 => Some("Locked"),
+            424 => Some("Failed Dependency"),
+            425 => Some("Too Early"),
+            426 => Some("Upgrade Required"),
+            428 => Some("Precondition Required"),
+            429 => Some("Too Many Requests"),
+            431 => Some("Request Header Fields Too Large"),
+            451 => Some("Unavailable For Legal Reasons"),
+
+            500 => Some("Internal Server Error"),
+            501 => Some("Not Implemented"),
+            502 => Some("Bad Gateway"),
+            503 => Some("Service Unavailable"),
+            504 => Some("Gateway Timeout"),
+            505 => Some("HTTP Version Not Supported"),
+            506 => Some("Variant Also Negotiates"),
+            507 => Some("Insufficient Storage"),
+            508 => Some("Loop Detected"),
+            510 => Some("Not Extended"),
+            511 => Some("Network Authentication Required"),
+
+            _ => None,
+        }
+    }
+}
+
+impl fmt::Display for StatusCode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
diff --git a/http/src/version.rs b/http/src/version.rs
new file mode 100644
index 000000000..bf8e69a1d
--- /dev/null
+++ b/http/src/version.rs
@@ -0,0 +1,87 @@
+//! Module with HTTP version related types.
+
+use std::fmt;
+use std::str::FromStr;
+
+/// HTTP version.
+///
+/// RFC 7230 section 2.6.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum Version {
+    /// HTTP/1.0.
+    ///
+    /// RFC 1945.
+    Http10,
+    /// HTTP/1.1.
+    ///
+    /// RFC 7230.
+    Http11,
+}
+
+impl Version {
+    /// Returns the major version.
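A short usage sketch of the `StatusCode` API defined above, assuming the crate-root re-export that the tests in this patch import:

```rust
use heph_http::StatusCode;

fn main() {
    // 204 No Content is a success status, but must not carry a body.
    assert!(StatusCode::NO_CONTENT.is_successful());
    assert!(!StatusCode::NO_CONTENT.includes_body());
    // Well known codes have a canonical reason-phrase, custom ones do not.
    assert_eq!(StatusCode::NOT_FOUND.phrase(), Some("Not Found"));
    assert_eq!(StatusCode(799).phrase(), None);
    // `Display` writes only the numeric code.
    println!("{} {}", StatusCode::OK, StatusCode::OK.phrase().unwrap_or(""));
}
```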
+ pub const fn major(self) -> u8 { + match self { + Version::Http10 | Version::Http11 => 1, + } + } + + /// Returns the minor version. + pub const fn minor(self) -> u8 { + match self { + Version::Http10 => 0, + Version::Http11 => 1, + } + } + + /// Returns the highest minor version with the same major version as `self`. + /// + /// According to RFC 7230 section 2.6: + /// > A server SHOULD send a response version equal to the highest version + /// > to which the server is conformant that has a major version less than or + /// > equal to the one received in the request. + /// + /// This function can be used to return the highest version given a major + /// version. + pub const fn highest_minor(self) -> Version { + match self { + Version::Http10 | Version::Http11 => Version::Http11, + } + } + + /// Returns the version as string. + pub const fn as_str(self) -> &'static str { + match self { + Version::Http10 => "HTTP/1.0", + Version::Http11 => "HTTP/1.1", + } + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Error returned by the [`FromStr`] implementation for [`Version`]. +#[derive(Copy, Clone, Debug)] +pub struct UnknownVersion; + +impl fmt::Display for UnknownVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("unknown HTTP version") + } +} + +impl FromStr for Version { + type Err = UnknownVersion; + + fn from_str(method: &str) -> Result { + match method { + "HTTP/1.0" => Ok(Version::Http10), + "HTTP/1.1" => Ok(Version::Http11), + _ => Err(UnknownVersion), + } + } +} diff --git a/http/tests/functional.rs b/http/tests/functional.rs new file mode 100644 index 000000000..c1fe31517 --- /dev/null +++ b/http/tests/functional.rs @@ -0,0 +1,21 @@ +//! Functional tests. + +#![feature(async_stream, never_type, once_cell)] + +use std::mem::size_of; + +#[track_caller] +fn assert_size(expected: usize) { + assert_eq!(size_of::(), expected); +} + +#[path = "functional"] // rustfmt can't find the files. +mod functional { + mod client; + mod from_header_value; + mod header; + mod method; + mod server; + mod status_code; + mod version; +} diff --git a/http/tests/functional/client.rs b/http/tests/functional/client.rs new file mode 100644 index 000000000..a99c3d9d9 --- /dev/null +++ b/http/tests/functional/client.rs @@ -0,0 +1,1478 @@ +#![allow(unused_imports)] + +use std::borrow::Cow; +use std::io::{self, Read, Write}; +use std::lazy::SyncLazy; +use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; +use std::sync::{Arc, Condvar, Mutex, Weak}; +use std::task::Poll; +use std::thread::{self, sleep}; +use std::time::{Duration, SystemTime}; +use std::{fmt, str}; + +use heph::actor::messages::Terminate; +use heph::rt::{self, Runtime, ThreadSafe}; +use heph::spawn::options::{ActorOptions, Priority}; +use heph::test::{init_actor, poll_actor}; +use heph::{actor, Actor, ActorRef, NewActor, Supervisor, SupervisorStrategy}; +use heph_http::body::{EmptyBody, OneshotBody}; +use heph_http::client::{Client, ResponseError}; +use heph_http::server::{HttpServer, RequestError}; +use heph_http::{self as http, Header, HeaderName, Headers, Method, Response, StatusCode, Version}; +use httpdate::fmt_http_date; + +const USER_AGENT: &[u8] = b"Heph-HTTP/0.1.0"; + +/// Macro to run with a test server. +macro_rules! 
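And a similar sketch for `Version`, again assuming the crate-root re-export shown in the test imports:

```rust
use heph_http::Version;

fn main() {
    // Parse the literal HTTP-version from a request line.
    let version: Version = "HTTP/1.0".parse().unwrap();
    assert_eq!(version.major(), 1);
    assert_eq!(version.minor(), 0);
    // Respond with the highest minor version supported for major version 1.
    assert_eq!(version.highest_minor(), Version::Http11);
    assert_eq!(version.highest_minor().as_str(), "HTTP/1.1");
    // Unknown versions are rejected.
    assert!("HTTP/2".parse::<Version>().is_err());
}
```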
with_test_server { + (|$test_server: ident| $test: block) => { + let test_server = TestServer::spawn(); + let $test_server = test_server; + $test + }; +} + +#[test] +fn get() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn get_no_response() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client.get("/").await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(err.to_string(), "no HTTP response"); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // No response. + drop(stream); + + handle.join().unwrap(); + }); +} + +#[test] +fn get_invalid_response() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client.get("/").await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert_eq!(err.to_string(), "invalid HTTP response status"); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 a00\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn request_with_headers() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let headers = Headers::from([Header::new(HeaderName::HOST, b"localhost")]); + let response = client + .request(Method::Get, "/", &headers, EmptyBody) + .await? 
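The test setup above boils down to: accept a raw TCP connection, assert on the request bytes, and write a hand-crafted response. Stripped of the Heph machinery, the server half looks roughly like the sketch below, which uses only the standard library and blocks until a client connects.

```rust
use std::io::Write;
use std::net::TcpListener;

fn main() -> std::io::Result<()> {
    // Bind to an ephemeral port, like the test server does.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    println!("test server on {}", listener.local_addr()?);

    let (mut stream, _) = listener.accept()?;
    // Hand-written response: status line without reason-phrase, a
    // Content-Length header and a two byte body, as in the tests above.
    stream.write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk")?;
    Ok(())
}
```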
+ .unwrap(); + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([ + Header::new(HeaderName::USER_AGENT, USER_AGENT), + Header::new(HeaderName::HOST, b"localhost"), + ]), + b"", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn request_with_user_agent_header() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let headers = Headers::from([Header::new(HeaderName::USER_AGENT, b"my-user-agent")]); + let response = client + .request(Method::Get, "/", &headers, EmptyBody) + .await? + .unwrap(); + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, b"my-user-agent")]), + b"", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +/* FIXME: The following tests have the following problem: +error: implementation of `body::private::PrivateBody` is not general enough + --> http/tests/functional/client.rs:255:48 + | +255 | let (mut stream, handle) = test_server.accept(|address| { + | ^^^^^^ implementation of `body::private::PrivateBody` is not general enough + | + = note: `body::private::PrivateBody<'1>` would have to be implemented for the type `OneshotBody<'0>`, for any two lifetimes `'0` and `'1`... + = note: ...but `body::private::PrivateBody<'2>` is actually implemented for the type `OneshotBody<'2>`, for some specific lifetime `'2` + +#[test] +fn request_with_content_length_header() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let body = OneshotBody::new(b"Hi"); + // NOTE: Content-Length is incorrect for this test! + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"3")]); + let response = client + .request(Method::Get, "/", &headers, body) + .await? 
+ .unwrap(); + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([ + Header::new(HeaderName::USER_AGENT, USER_AGENT), + Header::new(HeaderName::CONTENT_LENGTH, b"3"), + ]), + b"hi", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn request_with_transfer_encoding_header() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let headers = Headers::from([Header::new(HeaderName::TRANSFER_ENCODING, b"identify")]); + let body = OneshotBody::new(b"Hi"); + let response = client + .request(Method::Get, "/", &headers, body) + .await? + .unwrap(); + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([ + Header::new(HeaderName::USER_AGENT, USER_AGENT), + Header::new(HeaderName::TRANSFER_ENCODING, b"identify"), + ]), + b"hi", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn request_sets_content_length_header() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let body = OneshotBody::new(b"Hello"); + let response = client + .request(Method::Get, "/", &Headers::EMPTY, body) + .await? + .unwrap(); + let headers = Headers::from([Header::new(HeaderName::CONTENT_LENGTH, b"2")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([ + Header::new(HeaderName::USER_AGENT, USER_AGENT), + Header::new(HeaderName::CONTENT_LENGTH, b"4"), + ]), + b"Hello", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} +*/ + +// TODO: add test with `ChunkedBody`. + +#[test] +fn partial_response() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::IncompleteResponse); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Partal response, missing last `\r\n`. + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\n") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn same_content_length() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([ + Header::new(HeaderName::CONTENT_LENGTH, b"2"), + Header::new(HeaderName::CONTENT_LENGTH, b"2"), + ]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\nContent-Length: 2\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn different_content_length() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::DifferentContentLengths); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\nContent-Length: 4\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn transfer_encoding_and_content_length_and() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::ContentLengthAndTransferEncoding); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\nContent-Length: 2\r\n\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_content_length() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::InvalidContentLength); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: abc\r\n\r\nOk") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn chunked_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([Header::new(HeaderName::TRANSFER_ENCODING, b"chunked")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\n\r\n2\r\nOk0\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn slow_chunked_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([Header::new(HeaderName::TRANSFER_ENCODING, b"chunked")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\n\r\n") + .unwrap(); + sleep(Duration::from_millis(100)); + stream.write_all(b"2\r\nOk0\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn empty_chunked_transfer_encoding() { + with_test_server!(|test_server| { + async fn 
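Several tests hand-write chunked bodies such as `2\r\nOk0\r\n`, which omit the CRLF that RFC 7230 section 4.1 places after each chunk's data and after the last chunk. For reference, the fully framed form can be produced with a helper like the hypothetical `chunked` below.

```rust
// Frame `chunks` as an RFC 7230 section 4.1 chunked body (without trailers).
fn chunked(chunks: &[&[u8]]) -> Vec<u8> {
    let mut buf = Vec::new();
    for chunk in chunks {
        // chunk-size in hexadecimal, CRLF, chunk-data, CRLF.
        buf.extend_from_slice(format!("{:X}\r\n", chunk.len()).as_bytes());
        buf.extend_from_slice(chunk);
        buf.extend_from_slice(b"\r\n");
    }
    // last-chunk: a zero size, then the CRLF that ends the body.
    buf.extend_from_slice(b"0\r\n\r\n");
    buf
}

fn main() {
    assert_eq!(chunked(&[&b"Ok"[..]]), b"2\r\nOk\r\n0\r\n\r\n".to_vec());
}
```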
http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([Header::new(HeaderName::TRANSFER_ENCODING, b"chunked")]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn content_length_and_identity_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([ + Header::new(HeaderName::CONTENT_LENGTH, b"2"), + Header::new(HeaderName::TRANSFER_ENCODING, b"identity"), + ]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all( + b"HTTP/1.1 200\r\nContent-Length: 2\r\nTransfer-Encoding: identity\r\n\r\nOk", + ) + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn unsupported_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::UnsupportedTransferEncoding); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: gzip\r\n\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn chunked_not_last_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client.get("/").await?; + let headers = Headers::from([Header::new( + HeaderName::TRANSFER_ENCODING, + b"chunked, identity", + )]); + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked, identity\r\n\r\nOk") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn content_length_and_transfer_encoding() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::ContentLengthAndTransferEncoding); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nContent-Length: 2\r\nTransfer-Encoding: chunked\r\n\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_chunk_size() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::InvalidChunkSize); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + stream + .write_all(b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\n\r\nQ\r\nOk0\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn connect() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client + .request(Method::Connect, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap(); + let headers = Headers::EMPTY; + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Connect, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn head() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client + .request(Method::Head, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap(); + let headers = Headers::EMPTY; + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Head, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn response_status_204() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap(); + let headers = Headers::EMPTY; + let status = StatusCode::NO_CONTENT; + expect_response(response, Version::Http11, status, &headers, b"").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. 
+ stream.write_all(b"HTTP/1.1 204\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn no_content_length_no_transfer_encoding_response() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let response = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap(); + let headers = Headers::EMPTY; + expect_response(response, Version::Http11, StatusCode::OK, &headers, b"Ok").await; + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n\r\nOk").unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn response_head_too_large() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::HeadTooLarge); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n").unwrap(); + let buf = [b'a'; http::MAX_HEAD_SIZE]; + stream.write_all(&buf).unwrap(); + stream.write_all(b"\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_header_name() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::InvalidHeaderName); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n\0: \r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_header_value() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::InvalidHeaderValue); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream + .write_all(b"HTTP/1.1 200\r\nAbc: Header\rvalue\r\n\r\n") + .unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_new_line() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::InvalidNewLine); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"\rHTTP/1.1 200\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_version() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::InvalidVersion); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTPS/1.1 200\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn invalid_status() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? + .unwrap_err(); + assert_eq!(err, ResponseError::InvalidStatus); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 2009\r\n\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +#[test] +fn too_many_headers() { + with_test_server!(|test_server| { + async fn http_actor( + mut ctx: actor::Context, + address: SocketAddr, + ) -> io::Result<()> { + let mut client = Client::connect(&mut ctx, address)?.await?; + let err = client + .request(Method::Get, "/", &Headers::EMPTY, EmptyBody) + .await? 
+ .unwrap_err(); + assert_eq!(err, ResponseError::TooManyHeaders); + Ok(()) + } + + let (mut stream, handle) = test_server.accept(|address| { + let http_actor = http_actor as fn(_, _) -> _; + let (actor, _) = init_actor(http_actor, address).unwrap(); + actor + }); + + expect_request( + &mut stream, + Method::Get, + "/", + Version::Http11, + &Headers::from([Header::new(HeaderName::USER_AGENT, USER_AGENT)]), + b"", + ); + + // Write response. + stream.write_all(b"HTTP/1.1 200\r\n").unwrap(); + for _ in 0..=http::MAX_HEADERS { + stream.write_all(b"Some-Header: Abc\r\n").unwrap(); + } + stream.write_all(b"\r\n").unwrap(); + + handle.join().unwrap(); + }); +} + +fn expect_request( + stream: &mut TcpStream, + // Expected values: + method: Method, + path: &str, + version: Version, + headers: &Headers, + body: &[u8], +) { + let mut buf = [0; 1024]; + let n = stream.read(&mut buf).unwrap(); + let buf = &buf[..n]; + + eprintln!("read request: {:?}", str::from_utf8(&buf[..n])); + + let mut h = [httparse::EMPTY_HEADER; 64]; + let mut request = httparse::Request::new(&mut h); + let parsed_n = request.parse(&buf).unwrap().unwrap(); + + assert_eq!(request.method, Some(method.as_str())); + assert_eq!(request.path, Some(path)); + assert_eq!(request.version, Some(version.minor())); + assert_eq!( + request.headers.len(), + headers.len(), + "mismatch headers lengths, got: {:?}, expected: {:?}", + request.headers, + headers + ); + for got_header in request.headers { + let got_header_name = HeaderName::from_str(got_header.name); + let got = headers.get_value(&got_header_name).unwrap(); + assert_eq!( + got_header.value, + got, + "different header values for '{}' header, got: '{:?}', expected: '{:?}'", + got_header_name, + str::from_utf8(got_header.value), + str::from_utf8(got) + ); + } + assert_eq!(&buf[parsed_n..], body, "different bodies"); + assert_eq!(parsed_n, n - body.len(), "unexpected extra bytes"); +} + +async fn expect_response( + mut response: Response>, + // Expected values: + version: Version, + status: StatusCode, + headers: &Headers, + body: &[u8], +) { + eprintln!("read response: {:?}", response); + assert_eq!(response.version(), version); + assert_eq!(response.status(), status); + assert_eq!( + response.headers().len(), + headers.len(), + "mismatch headers lengths, got: {:?}, expected: {:?}", + response.headers(), + headers + ); + for got_header in response.headers().iter() { + let expected = headers.get_value(&got_header.name()).unwrap(); + assert_eq!( + got_header.value(), + expected, + "different header values for '{}' header, got: '{:?}', expected: '{:?}'", + got_header.name(), + str::from_utf8(got_header.value()), + str::from_utf8(expected) + ); + } + let mut got_body = Vec::new(); + response + .body_mut() + .read_all(&mut got_body, 1024) + .await + .unwrap(); + assert_eq!(got_body, body, "different bodies"); +} + +struct TestServer { + address: SocketAddr, + listener: Mutex, +} + +impl TestServer { + fn spawn() -> Arc { + static TEST_SERVER: SyncLazy>> = + SyncLazy::new(|| Mutex::new(Weak::new())); + + let mut test_server = TEST_SERVER.lock().unwrap(); + if let Some(test_server) = test_server.upgrade() { + // Use an existing running server. + test_server + } else { + // Start a new server. 
+ let new_server = Arc::new(TestServer::new()); + *test_server = Arc::downgrade(&new_server); + new_server + } + } + + fn new() -> TestServer { + let address: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let listener = TcpListener::bind(address).unwrap(); + let address = listener.local_addr().unwrap(); + + TestServer { + address, + listener: Mutex::new(listener), + } + } + + #[track_caller] + fn accept(&self, spawn: F) -> (TcpStream, thread::JoinHandle<()>) + where + F: FnOnce(SocketAddr) -> A, + A: Actor + Send + 'static, + A::Error: fmt::Display, + { + let listener = self.listener.lock().unwrap(); + let actor = spawn(self.address); + let mut actor = Box::pin(actor); + // TODO: don't run this on a different thread, use a test Heph runtime. + let handle = thread::spawn(move || { + for _ in 0..100 { + match poll_actor(actor.as_mut()) { + Poll::Pending => {} + Poll::Ready(Ok(())) => return, + Poll::Ready(Err(err)) => panic!("error in actor: {}", err), + } + sleep(Duration::from_millis(10)); + } + panic!("looped too many times"); + }); + let (stream, _) = listener.accept().unwrap(); + drop(listener); + stream.set_nodelay(true).unwrap(); + stream + .set_read_timeout(Some(Duration::from_secs(1))) + .unwrap(); + stream + .set_write_timeout(Some(Duration::from_secs(1))) + .unwrap(); + (stream, handle) + } +} diff --git a/http/tests/functional/from_header_value.rs b/http/tests/functional/from_header_value.rs new file mode 100644 index 000000000..0a3a87247 --- /dev/null +++ b/http/tests/functional/from_header_value.rs @@ -0,0 +1,112 @@ +use std::fmt; +use std::time::SystemTime; + +use heph_http::header::{FromHeaderValue, ParseIntError, ParseTimeError}; + +#[test] +fn str() { + test_parse(b"123", "123"); + test_parse(b"abc", "abc"); +} + +#[test] +fn str_not_utf8() { + test_parse_fail::<&str>(&[0, 255]); +} + +#[test] +fn integers() { + test_parse(b"123", 123u8); + test_parse(b"123", 123u16); + test_parse(b"123", 123u32); + test_parse(b"123", 123u64); + test_parse(b"123", 123usize); + + test_parse(b"255", u8::MAX); + test_parse(b"65535", u16::MAX); + test_parse(b"4294967295", u32::MAX); + test_parse(b"18446744073709551615", u64::MAX); + test_parse(b"18446744073709551615", usize::MAX); +} + +#[test] +fn integers_overflow() { + // In multiplication. + test_parse_fail::(b"300"); + test_parse_fail::(b"70000"); + test_parse_fail::(b"5000000000"); + test_parse_fail::(b"20000000000000000000"); + test_parse_fail::(b"20000000000000000000"); + + // In addition. + test_parse_fail::(b"257"); + test_parse_fail::(b"65537"); + test_parse_fail::(b"4294967297"); + test_parse_fail::(b"18446744073709551616"); + test_parse_fail::(b"18446744073709551616"); +} + +#[test] +fn empty_integers() { + test_parse_fail::(b""); + test_parse_fail::(b""); + test_parse_fail::(b""); + test_parse_fail::(b""); + test_parse_fail::(b""); +} + +#[test] +fn invalid_integers() { + test_parse_fail::(b"abc"); + test_parse_fail::(b"abc"); + test_parse_fail::(b"abc"); + test_parse_fail::(b"abc"); + test_parse_fail::(b"abc"); + + test_parse_fail::(b"2a"); + test_parse_fail::(b"2a"); + test_parse_fail::(b"2a"); + test_parse_fail::(b"2a"); + test_parse_fail::(b"2a"); +} + +#[test] +fn system_time() { + test_parse(b"Thu, 01 Jan 1970 00:00:00 GMT", SystemTime::UNIX_EPOCH); // IMF-fixdate. + test_parse(b"Thursday, 01-Jan-70 00:00:00 GMT", SystemTime::UNIX_EPOCH); // RFC 850. + test_parse(b"Thu Jan 1 00:00:00 1970", SystemTime::UNIX_EPOCH); // ANSI C’s `asctime`. 
+} + +#[test] +fn invalid_system_time() { + test_parse_fail::(b"\xa0\xa1"); // Invalid UTF-8. + test_parse_fail::(b"ABC, 01 Jan 1970 00:00:00 GMT"); // Invalid format. +} + +#[track_caller] +fn test_parse<'a, T>(value: &'a [u8], expected: T) +where + T: FromHeaderValue<'a> + fmt::Debug + PartialEq, + >::Err: fmt::Debug, +{ + assert_eq!(T::from_bytes(value).unwrap(), expected); +} + +#[track_caller] +fn test_parse_fail<'a, T>(value: &'a [u8]) +where + T: FromHeaderValue<'a> + fmt::Debug + PartialEq, + >::Err: fmt::Debug, +{ + assert!(T::from_bytes(value).is_err()); +} + +#[test] +fn parse_int_error_fmt_display() { + assert_eq!(ParseIntError.to_string(), "invalid integer"); +} + +#[test] +fn parse_time_error_fmt_display() { + assert_eq!(ParseTimeError.to_string(), "invalid time"); +} diff --git a/http/tests/functional/header.rs b/http/tests/functional/header.rs new file mode 100644 index 000000000..c6263b382 --- /dev/null +++ b/http/tests/functional/header.rs @@ -0,0 +1,548 @@ +use std::fmt; +use std::iter::FromIterator; + +use heph_http::header::{FromHeaderValue, Header, HeaderName, Headers}; + +use crate::assert_size; + +#[test] +fn sizes() { + assert_size::(48); + assert_size::
(48); + assert_size::>(32); +} + +#[test] +fn headers_add_one_header() { + const VALUE: &[u8] = b"GET"; + + let mut headers = Headers::EMPTY; + headers.add(Header::new(HeaderName::ALLOW, VALUE)); + assert_eq!(headers.len(), 1); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, VALUE, "GET"); + check_iter(&headers, &[(HeaderName::ALLOW, VALUE)]); +} + +#[test] +fn headers_add_multiple_headers() { + const ALLOW: &[u8] = b"GET"; + const CONTENT_LENGTH: &[u8] = b"123"; + const X_REQUEST_ID: &[u8] = b"abc-def"; + + let mut headers = Headers::EMPTY; + headers.add(Header::new(HeaderName::ALLOW, ALLOW)); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, CONTENT_LENGTH)); + headers.add(Header::new(HeaderName::X_REQUEST_ID, X_REQUEST_ID)); + assert_eq!(headers.len(), 3); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, ALLOW, "GET"); + #[rustfmt::skip] + check_header(&headers, &HeaderName::CONTENT_LENGTH, CONTENT_LENGTH, 123usize); + check_header(&headers, &HeaderName::X_REQUEST_ID, X_REQUEST_ID, "abc-def"); + check_iter( + &headers, + &[ + (HeaderName::ALLOW, ALLOW), + (HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + (HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ], + ); +} + +#[test] +fn headers_from_header() { + const VALUE: &[u8] = b"GET"; + let header = Header::new(HeaderName::ALLOW, VALUE); + let headers = Headers::from(header); + assert_eq!(headers.len(), 1); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, VALUE, "GET"); + check_iter(&headers, &[(HeaderName::ALLOW, VALUE)]); +} + +#[test] +fn headers_from_array() { + const ALLOW: &[u8] = b"GET"; + const CONTENT_LENGTH: &[u8] = b"123"; + const X_REQUEST_ID: &[u8] = b"abc-def"; + + let headers = Headers::from([ + Header::new(HeaderName::ALLOW, ALLOW), + Header::new(HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + Header::new(HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ]); + assert_eq!(headers.len(), 3); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, ALLOW, "GET"); + #[rustfmt::skip] + check_header(&headers, &HeaderName::CONTENT_LENGTH, CONTENT_LENGTH, 123usize); + check_header(&headers, &HeaderName::X_REQUEST_ID, X_REQUEST_ID, "abc-def"); + check_iter( + &headers, + &[ + (HeaderName::ALLOW, ALLOW), + (HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + (HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ], + ); +} + +#[test] +fn headers_from_slice() { + const ALLOW: &[u8] = b"GET"; + const CONTENT_LENGTH: &[u8] = b"123"; + const X_REQUEST_ID: &[u8] = b"abc-def"; + + let expected_headers: &[_] = &[ + Header::new(HeaderName::ALLOW, ALLOW), + Header::new(HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + Header::new(HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ]; + + let headers = Headers::from(expected_headers); + assert_eq!(headers.len(), 3); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, ALLOW, "GET"); + #[rustfmt::skip] + check_header(&headers, &HeaderName::CONTENT_LENGTH, CONTENT_LENGTH, 123usize); + check_header(&headers, &HeaderName::X_REQUEST_ID, X_REQUEST_ID, "abc-def"); + check_iter( + &headers, + &[ + (HeaderName::ALLOW, ALLOW), + (HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + (HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ], + ); +} + +#[test] +fn headers_from_iter_and_extend() { + const ALLOW: &[u8] = b"GET"; + const CONTENT_LENGTH: &[u8] = b"123"; + const X_REQUEST_ID: &[u8] = b"abc-def"; + + let mut headers = Headers::from_iter([ + Header::new(HeaderName::ALLOW, ALLOW), + 
Header::new(HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + ]); + assert_eq!(headers.len(), 2); + assert!(!headers.is_empty()); + + headers.extend([Header::new(HeaderName::X_REQUEST_ID, X_REQUEST_ID)]); + assert_eq!(headers.len(), 3); + assert!(!headers.is_empty()); + + check_header(&headers, &HeaderName::ALLOW, ALLOW, "GET"); + #[rustfmt::skip] + check_header(&headers, &HeaderName::CONTENT_LENGTH, CONTENT_LENGTH, 123usize); + check_header(&headers, &HeaderName::X_REQUEST_ID, X_REQUEST_ID, "abc-def"); + check_iter( + &headers, + &[ + (HeaderName::ALLOW, ALLOW), + (HeaderName::CONTENT_LENGTH, CONTENT_LENGTH), + (HeaderName::X_REQUEST_ID, X_REQUEST_ID), + ], + ); +} + +#[test] +fn headers_get_not_found() { + let mut headers = Headers::EMPTY; + assert!(headers.get(&HeaderName::DATE).is_none()); + assert!(headers.get_value(&HeaderName::DATE).is_none()); + + headers.add(Header::new(HeaderName::ALLOW, b"GET")); + assert!(headers.get(&HeaderName::DATE).is_none()); + assert!(headers.get_value(&HeaderName::DATE).is_none()); +} + +#[test] +fn clear_headers() { + const ALLOW: &[u8] = b"GET"; + const CONTENT_LENGTH: &[u8] = b"123"; + + let mut headers = Headers::EMPTY; + headers.add(Header::new(HeaderName::ALLOW, ALLOW)); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, CONTENT_LENGTH)); + assert_eq!(headers.len(), 2); + assert!(!headers.is_empty()); + + headers.clear(); + assert_eq!(headers.len(), 0); + assert!(headers.is_empty()); + assert!(headers.get(&HeaderName::ALLOW).is_none()); + assert!(headers.get(&HeaderName::CONTENT_LENGTH).is_none()); +} + +fn check_header<'a, T>( + headers: &'a Headers, + name: &'_ HeaderName<'_>, + value: &'_ [u8], + parsed_value: T, +) where + T: FromHeaderValue<'a> + PartialEq + fmt::Debug, + >::Err: fmt::Debug, +{ + let got = headers.get(name).unwrap(); + assert_eq!(got.name(), name); + assert_eq!(got.value(), value); + assert_eq!(got.parse::().unwrap(), parsed_value); + + assert_eq!(headers.get_value(name).unwrap(), value); +} + +fn check_iter(headers: &'_ Headers, expected: &[(HeaderName<'_>, &'_ [u8])]) { + let mut len = expected.len(); + let mut iter = headers.iter(); + assert_eq!(iter.len(), len); + assert_eq!(iter.size_hint(), (len, Some(len))); + for (name, value) in expected { + let got = iter.next().unwrap(); + assert_eq!(got.name(), name); + assert_eq!(got.value(), *value); + len -= 1; + assert_eq!(iter.len(), len); + assert_eq!(iter.size_hint(), (len, Some(len))); + } + assert_eq!(iter.count(), 0); + + let iter = headers.iter(); + assert_eq!(iter.count(), expected.len()); +} + +#[test] +fn new_header() { + const _MY_HEADER: Header<'static, 'static> = + Header::new(HeaderName::USER_AGENT, b"Heph-HTTP/0.1"); + let _header = Header::new(HeaderName::USER_AGENT, b""); + // Should be fine. 
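+    // Only the CRLF pair is rejected; a bare CR or LF is accepted by `Header::new`.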
+ let _header = Header::new(HeaderName::USER_AGENT, b"\rabc\n"); +} + +#[test] +#[should_panic = "header value contains CRLF ('\\r\\n')"] +fn new_header_with_crlf_should_panic() { + let _header = Header::new(HeaderName::USER_AGENT, b"\r\n"); +} + +#[test] +#[should_panic = "header value contains CRLF ('\\r\\n')"] +fn new_header_with_crlf_should_panic2() { + let _header = Header::new(HeaderName::USER_AGENT, b"some_text\r\n"); +} + +#[test] +fn parse_header() { + const LENGTH: Header<'static, 'static> = Header::new(HeaderName::CONTENT_LENGTH, b"100"); + assert_eq!(LENGTH.parse::().unwrap(), 100); +} + +#[test] +fn from_str_known_headers() { + let known_headers = &[ + "A-IM", + "ALPN", + "AMP-Cache-Transform", + "Accept", + "Accept-Additions", + "Accept-CH", + "Accept-Charset", + "Accept-Datetime", + "Accept-Encoding", + "Accept-Features", + "Accept-Language", + "Accept-Patch", + "Accept-Post", + "Accept-Ranges", + "Access-Control", + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Headers", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Max-Age", + "Access-Control-Request-Headers", + "Access-Control-Request-Method", + "Age", + "Allow", + "Alt-Svc", + "Alt-Used", + "Alternates", + "Apply-To-Redirect-Ref", + "Authentication-Control", + "Authentication-Info", + "Authorization", + "C-Ext", + "C-Man", + "C-Opt", + "C-PEP", + "C-PEP-Info", + "CDN-Loop", + "Cache-Control", + "Cal-Managed-ID", + "CalDAV-Timezones", + "Cert-Not-After", + "Cert-Not-Before", + "Close", + "Compliance", + "Connection", + "Content-Base", + "Content-Disposition", + "Content-Encoding", + "Content-ID", + "Content-Language", + "Content-Length", + "Content-Location", + "Content-MD5", + "Content-Range", + "Content-Script-Type", + "Content-Style-Type", + "Content-Transfer-Encoding", + "Content-Type", + "Content-Version", + "Cookie", + "Cookie2", + "Cost", + "DASL", + "DAV", + "Date", + "Default-Style", + "Delta-Base", + "Depth", + "Derived-From", + "Destination", + "Differential-ID", + "Digest", + "EDIINT-Features", + "ETag", + "Early-Data", + "Expect", + "Expect-CT", + "Expires", + "Ext", + "Forwarded", + "From", + "GetProfile", + "HTTP2-Settings", + "Hobareg", + "Host", + "IM", + "If", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Range", + "If-Schedule-Tag-Match", + "If-Unmodified-Since", + "Include-Referred-Token-Binding-ID", + "Isolation", + "Keep-Alive", + "Label", + "Last-Modified", + "Link", + "Location", + "Lock-Token", + "MIME-Version", + "Man", + "Max-Forwards", + "Memento-Datetime", + "Message-ID", + "Meter", + "Method-Check", + "Method-Check-Expires", + "Negotiate", + "Non-Compliance", + "OData-EntityId", + "OData-Isolation", + "OData-MaxVersion", + "OData-Version", + "OSCORE", + "OSLC-Core-Version", + "Opt", + "Optional", + "Optional-WWW-Authenticate", + "Ordering-Type", + "Origin", + "Overwrite", + "P3P", + "PEP", + "PICS-Label", + "Pep-Info", + "Position", + "Pragma", + "Prefer", + "Preference-Applied", + "ProfileObject", + "Protocol", + "Protocol-Info", + "Protocol-Query", + "Protocol-Request", + "Proxy-Authenticate", + "Proxy-Authentication-Info", + "Proxy-Authorization", + "Proxy-Features", + "Proxy-Instruction", + "Public", + "Public-Key-Pins", + "Public-Key-Pins-Report-Only", + "Range", + "Redirect-Ref", + "Referer", + "Referer-Root", + "Repeatability-Client-ID", + "Repeatability-First-Sent", + "Repeatability-Request-ID", + "Repeatability-Result", + "Replay-Nonce", + "Resolution-Hint", + "Resolver-Location", + "Retry-After", + "SLUG", + 
"Safe", + "Schedule-Reply", + "Schedule-Tag", + "Sec-Token-Binding", + "Sec-WebSocket-Accept", + "Sec-WebSocket-Extensions", + "Sec-WebSocket-Key", + "Sec-WebSocket-Protocol", + "Sec-WebSocket-Version", + "Security-Scheme", + "Server", + "Set-Cookie", + "Set-Cookie2", + "SetProfile", + "SoapAction", + "Status-URI", + "Strict-Transport-Security", + "SubOK", + "Subst", + "Sunset", + "Surrogate-Capability", + "Surrogate-Control", + "TCN", + "TE", + "TTL", + "Timeout", + "Timing-Allow-Origin", + "Title", + "Topic", + "Traceparent", + "Tracestate", + "Trailer", + "Transfer-Encoding", + "UA-Color", + "UA-Media", + "UA-Pixels", + "UA-Resolution", + "UA-Windowpixels", + "URI", + "Upgrade", + "Urgency", + "User-Agent", + "Variant-Vary", + "Vary", + "Version", + "Via", + "WWW-Authenticate", + "Want-Digest", + "Warning", + "X-Content-Type-Options", + "X-Device-Accept", + "X-Device-Accept-Charset", + "X-Device-Accept-Encoding", + "X-Device-Accept-Language", + "X-Device-User-Agent", + "X-Frame-Options", + "X-Request-ID", + ]; + for name in known_headers { + let header_name = HeaderName::from_str(name); + assert!(!header_name.is_heap_allocated(), "header: {}", name); + + // Matching should be case-insensitive. + let header_name = HeaderName::from_str(&name.to_uppercase()); + assert!(!header_name.is_heap_allocated(), "header: {}", name); + } +} + +#[test] +fn from_str_unknown_header() { + let header_name = HeaderName::from_str("EXTRA_LONG_UNKNOWN_HEADER_NAME_REALLY_LONG"); + assert!(header_name.is_heap_allocated()); +} + +#[test] +fn from_str_custom() { + let unknown_headers = &["my-header", "My-Header"]; + for name in unknown_headers { + let header_name = HeaderName::from_str(name); + assert!(header_name.is_heap_allocated(), "header: {}", name); + assert_eq!(header_name, "my-header"); + assert_eq!(header_name.as_ref(), "my-header"); + } + + let name = "bllow"; // Matches length of "Allow" header. 
+ let header_name = HeaderName::from_str(name); + assert!(header_name.is_heap_allocated(), "header: {}", name); +} + +#[test] +fn from_lowercase() { + let header_name = HeaderName::from_lowercase("my-header"); + assert_eq!(header_name, "my-header"); + assert_eq!(header_name.as_ref(), "my-header"); +} + +#[test] +#[should_panic = "header name not lowercase"] +fn from_lowercase_not_lowercase_should_panic() { + let _name = HeaderName::from_lowercase("My-Header"); +} + +#[test] +fn from_string() { + let header_name = HeaderName::from("my-header".to_owned()); + assert_eq!(header_name, "my-header"); + assert_eq!(header_name.as_ref(), "my-header"); +} + +#[test] +fn from_string_makes_lowercase() { + let header_name = HeaderName::from("My-Header".to_owned()); + assert_eq!(header_name, "my-header"); + assert_eq!(header_name.as_ref(), "my-header"); +} + +#[test] +fn compare_is_case_insensitive() { + let tests = &[ + HeaderName::from_lowercase("my-header"), + HeaderName::from("My-Header".to_owned()), + ]; + for header_name in tests { + assert_eq!(header_name, "my-header"); + assert_eq!(header_name, "My-Header"); + assert_eq!(header_name, "mY-hEaDeR"); + assert_eq!(header_name.as_ref(), "my-header"); + } + assert_eq!(tests[0], tests[1]); +} + +#[test] +fn fmt_display() { + let tests = &[ + HeaderName::from_lowercase("my-header"), + HeaderName::from("My-Header".to_owned()), + ]; + for header_name in tests { + assert_eq!(header_name.to_string(), "my-header"); + } +} diff --git a/http/tests/functional/method.rs b/http/tests/functional/method.rs new file mode 100644 index 000000000..60961e688 --- /dev/null +++ b/http/tests/functional/method.rs @@ -0,0 +1,116 @@ +use heph_http::method::UnknownMethod; +use heph_http::Method::{self, *}; + +use crate::assert_size; + +#[test] +fn size() { + assert_size::(1); +} + +#[test] +fn is_safe() { + let safe = &[Get, Head, Options, Trace]; + for method in safe { + assert!(method.is_safe()); + } + let not_safe = &[Post, Put, Delete, Connect, Patch]; + for method in not_safe { + assert!(!method.is_safe()); + } +} + +#[test] +fn is_idempotent() { + let idempotent = &[Get, Head, Put, Delete, Options, Trace]; + for method in idempotent { + assert!(method.is_idempotent()); + } + let not_idempotent = &[Post, Connect, Patch]; + for method in not_idempotent { + assert!(!method.is_idempotent()); + } +} + +#[test] +fn expects_body() { + let no_body = &[Head]; + for method in no_body { + assert!(!method.expects_body()); + } + let has_body = &[Get, Post, Put, Delete, Connect, Options, Trace, Patch]; + for method in has_body { + assert!(method.expects_body()); + } +} + +#[test] +fn as_str() { + let tests = &[ + (Get, "GET"), + (Head, "HEAD"), + (Post, "POST"), + (Put, "PUT"), + (Delete, "DELETE"), + (Connect, "CONNECT"), + (Options, "OPTIONS"), + (Trace, "TRACE"), + (Patch, "PATCH"), + ]; + for (method, expected) in tests { + assert_eq!(method.as_str(), *expected); + } +} + +#[test] +fn from_str() { + let tests = &[ + (Get, "GET"), + (Head, "HEAD"), + (Post, "POST"), + (Put, "PUT"), + (Delete, "DELETE"), + (Connect, "CONNECT"), + (Options, "OPTIONS"), + (Trace, "TRACE"), + (Patch, "PATCH"), + ]; + for (expected, input) in tests { + let got: Method = input.parse().unwrap(); + assert_eq!(got, *expected); + // Must be case-insensitive. 
+ let got: Method = input.to_lowercase().parse().unwrap(); + assert_eq!(got, *expected); + } +} + +#[test] +fn from_invalid_str() { + let tests = &["abc", "abcd", "abcde", "abcdef", "abcdefg", "abcdefgh"]; + for input in tests { + assert!(input.parse::().is_err()); + } +} + +#[test] +fn fmt_display() { + let tests = &[ + (Get, "GET"), + (Head, "HEAD"), + (Post, "POST"), + (Put, "PUT"), + (Delete, "DELETE"), + (Connect, "CONNECT"), + (Options, "OPTIONS"), + (Trace, "TRACE"), + (Patch, "PATCH"), + ]; + for (method, expected) in tests { + assert_eq!(*method.to_string(), **expected); + } +} + +#[test] +fn unknown_method_fmt_display() { + assert_eq!(UnknownMethod.to_string(), "unknown HTTP method"); +} diff --git a/http/tests/functional/server.rs b/http/tests/functional/server.rs new file mode 100644 index 000000000..029085ac1 --- /dev/null +++ b/http/tests/functional/server.rs @@ -0,0 +1,766 @@ +use std::borrow::Cow; +use std::io::{self, Read, Write}; +use std::lazy::SyncLazy; +use std::net::{Shutdown, SocketAddr, TcpStream}; +use std::str; +use std::sync::{Arc, Condvar, Mutex, Weak}; +use std::thread::{self, sleep}; +use std::time::{Duration, SystemTime}; + +use heph::actor::messages::Terminate; +use heph::rt::{self, Runtime, ThreadLocal}; +use heph::spawn::options::{ActorOptions, Priority}; +use heph::{actor, Actor, ActorRef, NewActor, Supervisor, SupervisorStrategy}; +use heph_http::body::OneshotBody; +use heph_http::server::{HttpServer, RequestError}; +use heph_http::{self as http, Header, HeaderName, Headers, Method, StatusCode, Version}; +use httpdate::fmt_http_date; + +/// Macro to run with a test server. +macro_rules! with_test_server { + (|$stream: ident| $test: block) => { + let test_server = TestServer::spawn(); + // NOTE: we put `test` in a block to ensure all connections to the + // server are dropped before we call `test_server.join()` below (which + // would block a shutdown. 
+ { + let mut $stream = TcpStream::connect(test_server.address).unwrap(); + $stream.set_nodelay(true).unwrap(); + $stream + .set_read_timeout(Some(Duration::from_secs(1))) + .unwrap(); + $stream + .set_write_timeout(Some(Duration::from_secs(1))) + .unwrap(); + $test + } + test_server.join(); + }; +} + +#[test] +fn get() { + with_test_server!(|stream| { + stream.write_all(b"GET / HTTP/1.1\r\n\r\n").unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"2")); + let body = b"OK"; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn head() { + with_test_server!(|stream| { + stream.write_all(b"HEAD / HTTP/1.1\r\n\r\n").unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"0")); + let body = b""; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn post() { + with_test_server!(|stream| { + stream + .write_all(b"POST /echo-body HTTP/1.1\r\nContent-Length: 11\r\n\r\nHello world") + .unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"11")); + let body = b"Hello world"; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn with_request_header() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nUser-Agent:heph-http\r\n\r\n") + .unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"2")); + let body = b"OK"; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn with_multiple_request_headers() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nUser-Agent:heph-http\r\nAccept: */*\r\n\r\n") + .unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"2")); + let body = b"OK"; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn deny_incomplete_request() { + with_test_server!(|stream| { + // NOTE: missing `\r\n`. 
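+        // The final empty line terminating the request head is never sent.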
+ stream.write_all(b"GET / HTTP/1.1\r\n").unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"31")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: incomplete request"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_unknown_method() { + with_test_server!(|stream| { + stream.write_all(b"MY_GET / HTTP/1.1\r\n\r\n").unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::NOT_IMPLEMENTED; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"27")); + let body = b"Bad request: unknown method"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_invalid_method() { + with_test_server!(|stream| { + stream.write_all(b"G\nE\rT / HTTP/1.1\r\n\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"35")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid request syntax"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn accept_same_content_length_headers() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nContent-Length: 0\r\nContent-Length: 0\r\n\r\n") + .unwrap(); + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"2")); + let body = b"OK"; + expect_response(&mut stream, Version::Http11, StatusCode::OK, &headers, body); + }); +} + +#[test] +fn deny_different_content_length_headers() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nContent-Length: 0\r\nContent-Length: 1\r\n\r\nA") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"45")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: different Content-Length headers"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_content_length_and_chunked_transfer_encoding_request() { + // NOTE: similar to + // `deny_chunked_transfer_encoding_and_content_length_request`, but + // Transfer-Encoding goes first. 
+ with_test_server!(|stream| { + stream + .write_all( + b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nContent-Length: 1\r\n\r\nA", + ) + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"71")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: provided both Content-Length and Transfer-Encoding headers"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_invalid_content_length_headers() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nContent-Length: ABC\r\n\r\n") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"42")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid Content-Length header"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn identity_transfer_encoding() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nTransfer-Encoding: identity\r\n\r\n") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::OK; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"2")); + let body = b"OK"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn identity_transfer_encoding_with_content_length() { + with_test_server!(|stream| { + stream + .write_all( + b"POST /echo-body HTTP/1.1\r\nTransfer-Encoding: identity\r\nContent-Length: 1\r\n\r\nA", + ) + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::OK; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"1")); + let body = b"A"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_unsupported_transfer_encoding() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nTransfer-Encoding: Nah\r\n\r\nA") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::NOT_IMPLEMENTED; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"42")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: unsupported Transfer-Encoding"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_empty_transfer_encoding() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nTransfer-Encoding:\r\n\r\n") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::NOT_IMPLEMENTED; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + 
headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"42")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: unsupported Transfer-Encoding"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_chunked_transfer_encoding_not_last() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked, gzip\r\n\r\nA") + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"45")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid Transfer-Encoding header"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_chunked_transfer_encoding_and_content_length_request() { + // NOTE: similar to + // `deny_content_length_and_chunked_transfer_encoding_request`, but + // Content-Length goes first. + with_test_server!(|stream| { + stream + .write_all( + b"GET / HTTP/1.1\r\nContent-Length: 1\r\nTransfer-Encoding: chunked\r\n\r\nA", + ) + .unwrap(); + stream.shutdown(Shutdown::Write).unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"71")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: provided both Content-Length and Transfer-Encoding headers"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn empty_body_chunked_transfer_encoding() { + with_test_server!(|stream| { + stream + .write_all(b"POST /echo-body HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n") + .unwrap(); + let status = StatusCode::OK; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"0")); + let body = b""; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn deny_invalid_chunk_size() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\nZ\r\nAbc0\r\n") + .unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"31")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid chunk size"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn read_partial_chunk_size_chunked_transfer_encoding() { + // Test `Connection::next_request` handling reading the HTTP head, but not + // the chunk size yet. 
+ with_test_server!(|stream| { + stream + .write_all(b"POST /echo-body HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + .unwrap(); + sleep(Duration::from_millis(200)); + stream.write_all(b"0\r\n").unwrap(); + let status = StatusCode::OK; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"0")); + let body = b""; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn too_large_http_head() { + // Tests `heph_http::MAX_HEAD_SIZE`. + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nSOME_HEADER: ") + .unwrap(); + let mut header_value = Vec::with_capacity(heph_http::MAX_HEAD_SIZE); + header_value.resize(heph_http::MAX_HEAD_SIZE, b'a'); + stream.write_all(&header_value).unwrap(); + stream.write_all(b"\r\n\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"27")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: head too large"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn invalid_header_name() { + with_test_server!(|stream| { + stream.write_all(b"GET / HTTP/1.1\r\n\0: \r\n\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"32")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid header name"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn invalid_header_value() { + with_test_server!(|stream| { + stream + .write_all(b"GET / HTTP/1.1\r\nAbc: Header\rvalue\r\n\r\n") + .unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"33")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid header value"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn invalid_carriage_return() { + with_test_server!(|stream| { + stream.write_all(b"\rGET / HTTP/1.1\r\n\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"35")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: invalid request syntax"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn invalid_http_version() { + with_test_server!(|stream| { + stream.write_all(b"GET / HTTPS/1.1\r\n\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"28")); + headers.add(Header::new(HeaderName::CONNECTION, 
b"close")); + let body = b"Bad request: invalid version"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +#[test] +fn too_many_header() { + with_test_server!(|stream| { + stream.write_all(b"GET / HTTP/1.1\r\n").unwrap(); + for _ in 0..=http::MAX_HEADERS { + stream.write_all(b"Some-Header: Abc\r\n").unwrap(); + } + stream.write_all(b"\r\n").unwrap(); + let status = StatusCode::BAD_REQUEST; + let mut headers = Headers::EMPTY; + let now = fmt_http_date(SystemTime::now()); + headers.add(Header::new(HeaderName::DATE, now.as_bytes())); + headers.add(Header::new(HeaderName::CONTENT_LENGTH, b"28")); + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + let body = b"Bad request: too many header"; + expect_response(&mut stream, Version::Http11, status, &headers, body); + }); +} + +fn expect_response( + stream: &mut TcpStream, + // Expected values: + version: Version, + status: StatusCode, + headers: &Headers, + body: &[u8], +) { + let mut buf = [0; 1024]; + let n = stream.read(&mut buf).unwrap(); + let buf = &buf[..n]; + + eprintln!("read response: {:?}", str::from_utf8(&buf[..n])); + + let mut h = [httparse::EMPTY_HEADER; 64]; + let mut response = httparse::Response::new(&mut h); + let parsed_n = response.parse(&buf).unwrap().unwrap(); + + assert_eq!(response.version, Some(version.minor())); + assert_eq!(response.code.unwrap(), status.0); + assert!(response.reason.unwrap().is_empty()); // We don't send a reason-phrase. + assert_eq!( + response.headers.len(), + headers.len(), + "mismatch headers lengths, got: {:?}, expected: {:?}", + response.headers, + headers + ); + for got_header in response.headers { + let got_header_name = HeaderName::from_str(got_header.name); + let got = headers.get_value(&got_header_name).unwrap(); + assert_eq!( + got_header.value, + got, + "different header values for '{}' header, got: '{:?}', expected: '{:?}'", + got_header_name, + str::from_utf8(got_header.value), + str::from_utf8(got) + ); + } + assert_eq!(&buf[parsed_n..], body, "different bodies"); + assert_eq!(parsed_n, n - body.len(), "unexpected extra bytes"); +} + +struct TestServer { + address: SocketAddr, + server_ref: ActorRef, + handle: Option>, +} + +impl TestServer { + fn spawn() -> Arc { + static TEST_SERVER: SyncLazy>> = + SyncLazy::new(|| Mutex::new(Weak::new())); + + let mut test_server = TEST_SERVER.lock().unwrap(); + if let Some(test_server) = test_server.upgrade() { + // Use an existing running server. + test_server + } else { + // Start a new server. 
+ let new_server = Arc::new(TestServer::new()); + *test_server = Arc::downgrade(&new_server); + new_server + } + } + + fn new() -> TestServer { + const TIMEOUT: Duration = Duration::from_secs(1); + + let actor = http_actor as fn(_, _, _) -> _; + let address = "127.0.0.1:0".parse().unwrap(); + let server = HttpServer::setup(address, conn_supervisor, actor, ActorOptions::default()) + .map_err(rt::Error::setup) + .unwrap(); + let address = server.local_addr(); + + let mut runtime = Runtime::setup().num_threads(1).build().unwrap(); + let server_ref = Arc::new(Mutex::new(None)); + let set_ref = Arc::new(Condvar::new()); + let srv_ref = server_ref.clone(); + let set_ref2 = set_ref.clone(); + runtime + .run_on_workers(move |mut runtime_ref| -> Result<(), !> { + let mut server_ref = srv_ref.lock().unwrap(); + let options = ActorOptions::default().with_priority(Priority::LOW); + *server_ref = Some( + runtime_ref + .try_spawn_local(ServerSupervisor, server, (), options) + .unwrap() + .map(), + ); + set_ref2.notify_all(); + Ok(()) + }) + .unwrap(); + + let handle = thread::spawn(move || runtime.start().unwrap()); + let mut server_ref = set_ref + .wait_timeout_while(server_ref.lock().unwrap(), TIMEOUT, |r| r.is_none()) + .unwrap() + .0; + let server_ref = server_ref.take().unwrap(); + TestServer { + address, + server_ref, + handle: Some(handle), + } + } + + fn join(mut self: Arc) { + if let Some(this) = Arc::get_mut(&mut self) { + this.server_ref.try_send(Terminate).unwrap(); + this.handle.take().unwrap().join().unwrap() + } + } +} + +#[derive(Copy, Clone, Debug)] +struct ServerSupervisor; + +impl Supervisor for ServerSupervisor +where + NA: NewActor, + NA::Actor: Actor>, +{ + fn decide(&mut self, err: http::server::Error) -> SupervisorStrategy<()> { + use http::server::Error::*; + match err { + Accept(err) => panic!("error accepting new connection: {}", err), + NewActor(_) => unreachable!(), + } + } + + fn decide_on_restart_error(&mut self, err: io::Error) -> SupervisorStrategy<()> { + panic!("error restarting the TCP server: {}", err); + } + + fn second_restart_error(&mut self, err: io::Error) { + panic!("error restarting the actor a second time: {}", err); + } +} + +fn conn_supervisor(err: io::Error) -> SupervisorStrategy<(heph::net::TcpStream, SocketAddr)> { + panic!("error handling connection: {}", err) +} + +/// Routes: +/// GET / => 200, OK. +/// POST /echo-body => 200, $request_body. +/// * => 404, Not found. +async fn http_actor( + _: actor::Context, + mut connection: http::Connection, + _: SocketAddr, +) -> io::Result<()> { + connection.set_nodelay(true)?; + + let mut headers = Headers::EMPTY; + loop { + let mut got_version = None; + let mut got_method = None; + let (code, body, should_close) = match connection.next_request().await? { + Ok(Some(mut request)) => { + got_version = Some(request.version()); + got_method = Some(request.method()); + + match (request.method(), request.path()) { + (Method::Get | Method::Head, "/") => (StatusCode::OK, "OK".into(), false), + (Method::Post, "/echo-body") => { + let body_len = request.body().len(); + let mut buf = Vec::with_capacity(128); + request.body_mut().read_all(&mut buf, 1024).await?; + assert!(request.body().is_empty()); + if let http::body::BodyLength::Known(length) = body_len { + assert_eq!(length, buf.len()); + } else { + assert!(request.body().is_chunked()); + } + let body = String::from_utf8(buf).unwrap().into(); + (StatusCode::OK, body, false) + } + _ => (StatusCode::NOT_FOUND, "Not found".into(), false), + } + } + // No more requests. 
+ Ok(None) => return Ok(()), + Err(err) => { + let code = err.proper_status_code(); + let body = Cow::from(format!("Bad request: {}", err)); + (code, body, err.should_close()) + } + }; + if let Some(got_version) = got_version { + assert_eq!(connection.last_request_version().unwrap(), got_version); + } + if let Some(got_method) = got_method { + assert_eq!(connection.last_request_method().unwrap(), got_method); + } + + if should_close { + headers.add(Header::new(HeaderName::CONNECTION, b"close")); + } + + let body = OneshotBody::new(body.as_bytes()); + connection.respond(code, &headers, body).await?; + if should_close { + return Ok(()); + } + + headers.clear(); + } +} + +#[test] +fn request_error_proper_status_code() { + use RequestError::*; + let tests = &[ + (IncompleteRequest, StatusCode::BAD_REQUEST), + (HeadTooLarge, StatusCode::BAD_REQUEST), + (InvalidContentLength, StatusCode::BAD_REQUEST), + (DifferentContentLengths, StatusCode::BAD_REQUEST), + (InvalidHeaderName, StatusCode::BAD_REQUEST), + (InvalidHeaderValue, StatusCode::BAD_REQUEST), + (TooManyHeaders, StatusCode::BAD_REQUEST), + (ChunkedNotLastTransferEncoding, StatusCode::BAD_REQUEST), + (ContentLengthAndTransferEncoding, StatusCode::BAD_REQUEST), + (InvalidToken, StatusCode::BAD_REQUEST), + (InvalidNewLine, StatusCode::BAD_REQUEST), + (InvalidVersion, StatusCode::BAD_REQUEST), + (InvalidChunkSize, StatusCode::BAD_REQUEST), + (UnsupportedTransferEncoding, StatusCode::NOT_IMPLEMENTED), + (UnknownMethod, StatusCode::NOT_IMPLEMENTED), + ]; + + for (error, expected) in tests { + assert_eq!(error.proper_status_code(), *expected); + } +} + +#[test] +fn request_should_close() { + use RequestError::*; + let tests = &[ + (IncompleteRequest, true), + (HeadTooLarge, true), + (InvalidContentLength, true), + (DifferentContentLengths, true), + (InvalidHeaderName, true), + (InvalidHeaderValue, true), + (UnsupportedTransferEncoding, true), + (TooManyHeaders, true), + (ChunkedNotLastTransferEncoding, true), + (ContentLengthAndTransferEncoding, true), + (InvalidToken, true), + (InvalidNewLine, true), + (InvalidVersion, true), + (InvalidChunkSize, true), + (UnknownMethod, false), + ]; + + for (error, expected) in tests { + assert_eq!(error.should_close(), *expected); + } +} diff --git a/http/tests/functional/status_code.rs b/http/tests/functional/status_code.rs new file mode 100644 index 000000000..3824a619d --- /dev/null +++ b/http/tests/functional/status_code.rs @@ -0,0 +1,168 @@ +use heph_http::StatusCode; + +#[test] +fn is_informational() { + let informational = &[100, 101, 199]; + for status in informational { + assert!(StatusCode(*status).is_informational()); + } + let not_informational = &[0, 10, 200, 201, 400, 999]; + for status in not_informational { + assert!(!StatusCode(*status).is_informational()); + } +} + +#[test] +fn is_successful() { + let successful = &[200, 201, 299]; + for status in successful { + assert!(StatusCode(*status).is_successful()); + } + let not_successful = &[0, 10, 100, 101, 400, 999]; + for status in not_successful { + assert!(!StatusCode(*status).is_successful()); + } +} + +#[test] +fn is_redirect() { + let redirect = &[300, 301, 399]; + for status in redirect { + assert!(StatusCode(*status).is_redirect()); + } + let not_redirect = &[0, 10, 100, 101, 400, 999]; + for status in not_redirect { + assert!(!StatusCode(*status).is_redirect()); + } +} + +#[test] +fn is_client_error() { + let client_error = &[400, 401, 499]; + for status in client_error { + assert!(StatusCode(*status).is_client_error()); + } + let 
not_client_error = &[0, 10, 100, 101, 300, 500, 999]; + for status in not_client_error { + assert!(!StatusCode(*status).is_client_error()); + } +} + +#[test] +fn is_server_error() { + let server_error = &[500, 501, 599]; + for status in server_error { + assert!(StatusCode(*status).is_server_error()); + } + let not_server_error = &[0, 10, 100, 101, 400, 600, 999]; + for status in not_server_error { + assert!(!StatusCode(*status).is_server_error()); + } +} + +#[test] +fn includes_body() { + let no_body = &[100, 101, 199, 204, 304]; + for status in no_body { + assert!(!StatusCode(*status).includes_body()); + } + let has_body = &[0, 10, 200, 201, 203, 205, 300, 301, 303, 305, 400, 500, 999]; + for status in has_body { + assert!(StatusCode(*status).includes_body()); + } +} + +#[test] +fn phrase() { + #[rustfmt::skip] + let tests = &[ + // 1xx range. + (StatusCode::CONTINUE, "Continue"), + (StatusCode::SWITCHING_PROTOCOLS, "Switching Protocols"), + (StatusCode::PROCESSING, "Processing"), + (StatusCode::EARLY_HINTS, "Early Hints"), + + // 2xx range. + (StatusCode::OK, "OK"), + (StatusCode::CREATED, "Created"), + (StatusCode::ACCEPTED, "Accepted"), + (StatusCode::NON_AUTHORITATIVE_INFORMATION, "Non-Authoritative Information"), + (StatusCode::NO_CONTENT, "No Content"), + (StatusCode::RESET_CONTENT, "Reset Content"), + (StatusCode::PARTIAL_CONTENT, "Partial Content"), + (StatusCode::MULTI_STATUS, "Multi-Status"), + (StatusCode::ALREADY_REPORTED, "Already Reported"), + (StatusCode::IM_USED, "IM Used"), + + // 3xx range. + (StatusCode::MULTIPLE_CHOICES, "Multiple Choices"), + (StatusCode::MOVED_PERMANENTLY, "Moved Permanently"), + (StatusCode::FOUND, "Found"), + (StatusCode::SEE_OTHER, "See Other"), + (StatusCode::NOT_MODIFIED, "Not Modified"), + (StatusCode::USE_PROXY, "Use Proxy"), + (StatusCode::TEMPORARY_REDIRECT, "Temporary Redirect"), + (StatusCode::PERMANENT_REDIRECT, "Permanent Redirect"), + + // 4xx range. + (StatusCode::BAD_REQUEST, "Bad Request"), + (StatusCode::UNAUTHORIZED, "Unauthorized"), + (StatusCode::PAYMENT_REQUIRED, "Payment Required"), + (StatusCode::FORBIDDEN, "Forbidden"), + (StatusCode::NOT_FOUND, "Not Found"), + (StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"), + (StatusCode::NOT_ACCEPTABLE, "Not Acceptable"), + (StatusCode::PROXY_AUTHENTICATION_REQUIRED, "Proxy Authentication Required"), + (StatusCode::REQUEST_TIMEOUT, "Request Timeout"), + (StatusCode::CONFLICT, "Conflict"), + (StatusCode::GONE, "Gone"), + (StatusCode::LENGTH_REQUIRED, "Length Required"), + (StatusCode::PRECONDITION_FAILED, "Precondition Failed"), + (StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"), + (StatusCode::URI_TOO_LONG, "URI Too Long"), + (StatusCode::UNSUPPORTED_MEDIA_TYPE, "Unsupported Media Type"), + (StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable"), + (StatusCode::EXPECTATION_FAILED, "Expectation Failed"), + (StatusCode::MISDIRECTED_REQUEST, "Misdirected Request"), + (StatusCode::UNPROCESSABLE_ENTITY, "Unprocessable Entity"), + (StatusCode::LOCKED, "Locked"), + (StatusCode::FAILED_DEPENDENCY, "Failed Dependency"), + (StatusCode::TOO_EARLY, "Too Early"), + (StatusCode::UPGRADE_REQUIRED, "Upgrade Required"), + (StatusCode::PRECONDITION_REQUIRED, "Precondition Required"), + (StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"), + (StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, "Request Header Fields Too Large"), + (StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, "Unavailable For Legal Reasons"), + + // 5xx range. 
+        (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error"),
+        (StatusCode::NOT_IMPLEMENTED, "Not Implemented"),
+        (StatusCode::BAD_GATEWAY, "Bad Gateway"),
+        (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
+        (StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout"),
+        (StatusCode::HTTP_VERSION_NOT_SUPPORTED, "HTTP Version Not Supported"),
+        (StatusCode::VARIANT_ALSO_NEGOTIATES, "Variant Also Negotiates"),
+        (StatusCode::INSUFFICIENT_STORAGE, "Insufficient Storage"),
+        (StatusCode::LOOP_DETECTED, "Loop Detected"),
+        (StatusCode::NOT_EXTENDED, "Not Extended"),
+        (StatusCode::NETWORK_AUTHENTICATION_REQUIRED, "Network Authentication Required"),
+    ];
+    for (input, expected) in tests {
+        assert_eq!(input.phrase().unwrap(), *expected);
+    }
+
+    assert!(StatusCode(0).phrase().is_none());
+    assert!(StatusCode(999).phrase().is_none());
+}
+
+#[test]
+fn fmt_display() {
+    let tests = &[
+        (StatusCode::OK, "200"),
+        (StatusCode::BAD_REQUEST, "400"),
+        (StatusCode(999), "999"),
+    ];
+    for (status, expected) in tests {
+        assert_eq!(*status.to_string(), **expected);
+    }
+}
diff --git a/http/tests/functional/version.rs b/http/tests/functional/version.rs
new file mode 100644
index 000000000..0569227a8
--- /dev/null
+++ b/http/tests/functional/version.rs
@@ -0,0 +1,65 @@
+use heph_http::version::UnknownVersion;
+use heph_http::Version::{self, *};
+
+use crate::assert_size;
+
+#[test]
+fn size() {
+    assert_size::<Version>(1);
+}
+
+#[test]
+fn major() {
+    let tests = &[(Http10, 1), (Http11, 1)];
+    for (version, expected) in tests {
+        assert_eq!(version.major(), *expected);
+    }
+}
+
+#[test]
+fn minor() {
+    let tests = &[(Http10, 0), (Http11, 1)];
+    for (version, expected) in tests {
+        assert_eq!(version.minor(), *expected);
+    }
+}
+
+#[test]
+fn highest_minor() {
+    let tests = &[(Http10, Http11), (Http11, Http11)];
+    for (version, expected) in tests {
+        assert_eq!(version.highest_minor(), *expected);
+    }
+}
+
+#[test]
+fn from_str() {
+    let tests = &[(Http10, "HTTP/1.0"), (Http11, "HTTP/1.1")];
+    for (expected, input) in tests {
+        let got: Version = input.parse().unwrap();
+        assert_eq!(got, *expected);
+        // NOTE: version (unlike most other types) is matched case-sensitive.
+    }
+    assert!("HTTP/1.2".parse::<Version>().is_err());
+}
+
+#[test]
+fn as_str() {
+    let tests = &[(Http10, "HTTP/1.0"), (Http11, "HTTP/1.1")];
+    for (version, expected) in tests {
+        assert_eq!(version.as_str(), *expected);
+    }
+}
+
+#[test]
+fn fmt_display() {
+    let tests = &[(Http10, "HTTP/1.0"), (Http11, "HTTP/1.1")];
+    for (version, expected) in tests {
+        assert_eq!(*version.to_string(), **expected);
+    }
+}
+
+#[test]
+fn unknown_version_fmt_display() {
+    assert_eq!(UnknownVersion.to_string(), "unknown HTTP version");
+}
diff --git a/src/net/mod.rs b/src/net/mod.rs
index 76a7b8257..8ae8719a4 100644
--- a/src/net/mod.rs
+++ b/src/net/mod.rs
@@ -48,7 +48,7 @@ use std::cmp::min;
 use std::mem::MaybeUninit;
 use std::net::SocketAddr;
 use std::ops::{Deref, DerefMut};
-use std::{fmt, io};
+use std::{fmt, io, slice};

 use socket2::SockAddr;

@@ -93,6 +93,16 @@ pub trait Bytes {
     /// [`update_length`]: Bytes::update_length
     fn as_bytes(&mut self) -> &mut [MaybeUninit<u8>];

+    /// Returns the length of the buffer as returned by [`as_bytes`].
+    ///
+    /// [`as_bytes`]: Bytes::as_bytes
+    fn spare_capacity(&self) -> usize;
+
+    /// Returns `true` if the buffer has spare capacity.
+    fn has_spare_capacity(&self) -> bool {
+        self.spare_capacity() != 0
+    }
+
     /// Update the length of the byte slice, marking `n` bytes as initialised.
     ///
     /// # Safety
@@ -108,6 +118,19 @@ pub trait Bytes {
     /// [`TcpStream::recv_n`] will not work correctly (as the buffer will
     /// overwrite itself on successive reads).
     unsafe fn update_length(&mut self, n: usize);
+
+    /// Wrap the buffer in `LimitedBytes`, which limits the amount of bytes used
+    /// to `limit`.
+    ///
+    /// [`LimitedBytes::into_inner`] can be used to retrieve the buffer again,
+    /// or a mutable reference to the buffer can be used and the limited buffer
+    /// be dropped after usage.
+    fn limit(self, limit: usize) -> LimitedBytes<Self>
+    where
+        Self: Sized,
+    {
+        LimitedBytes { buf: self, limit }
+    }
 }

 impl<B> Bytes for &mut B
@@ -118,6 +141,14 @@ where
         (&mut **self).as_bytes()
     }

+    fn spare_capacity(&self) -> usize {
+        (&**self).spare_capacity()
+    }
+
+    fn has_spare_capacity(&self) -> bool {
+        (&**self).has_spare_capacity()
+    }
+
     unsafe fn update_length(&mut self, n: usize) {
         (&mut **self).update_length(n)
     }
@@ -155,11 +186,18 @@ where
 /// }
 /// ```
 impl Bytes for Vec<u8> {
-    // NOTE: keep this function in sync with the impl below.
     fn as_bytes(&mut self) -> &mut [MaybeUninit<u8>] {
         self.spare_capacity_mut()
     }

+    fn spare_capacity(&self) -> usize {
+        self.capacity() - self.len()
+    }
+
+    fn has_spare_capacity(&self) -> bool {
+        self.capacity() != self.len()
+    }
+
     unsafe fn update_length(&mut self, n: usize) {
         let new = self.len() + n;
         debug_assert!(self.capacity() >= new);
@@ -202,6 +240,18 @@ impl<'a> MaybeUninitSlice<'a> {
         MaybeUninitSlice(socket2::MaybeUninitSlice::new(buf.as_bytes()))
     }

+    fn limit(&mut self, limit: usize) {
+        let len = self.len();
+        assert!(len >= limit);
+        self.0 = unsafe {
+            // SAFETY: this should be the line below, but I couldn't figure out
+            // the lifetime. Since we're only making the slices smaller (as
+            // checked by the assert above) this should be safe.
+            //self.0 = socket2::MaybeUninitSlice::new(&mut self[..limit]);
+            socket2::MaybeUninitSlice::new(slice::from_raw_parts_mut(self.0.as_mut_ptr(), limit))
+        };
+    }
+
     /// Returns `bufs` as [`socket2::MaybeUninitSlice`].
     #[allow(clippy::wrong_self_convention)]
     fn as_socket2<'b>(
@@ -351,6 +401,14 @@ impl<'a> DerefMut for MaybeUninitSlice<'a> {
 ///         &mut self.bytes[self.initialised..]
 ///     }
 ///
+///     fn spare_capacity(&self) -> usize {
+///         self.bytes.len() - self.initialised
+///     }
+///
+///     fn has_spare_capacity(&self) -> bool {
+///         self.bytes.len() != self.initialised
+///     }
+///
 ///     unsafe fn update_length(&mut self, n: usize) {
 ///         self.initialised += n;
 ///     }
@@ -363,6 +421,16 @@ pub trait BytesVectored {
     /// Returns itself as a slice of [`MaybeUninitSlice`].
     fn as_bufs<'b>(&'b mut self) -> Self::Bufs<'b>;

+    /// Returns the total length of the buffers as returned by [`as_bufs`].
+    ///
+    /// [`as_bufs`]: BytesVectored::as_bufs
+    fn spare_capacity(&self) -> usize;
+
+    /// Returns `true` if (one of) the buffers has spare capacity.
+    fn has_spare_capacity(&self) -> bool {
+        self.spare_capacity() != 0
+    }
+
     /// Update the length of the buffers in the slice.
     ///
     /// # Safety
@@ -379,6 +447,19 @@ pub trait BytesVectored {
     /// [`TcpStream::recv_n_vectored`] will not work correctly (as the buffer
     /// will overwrite itself on successive reads).
     unsafe fn update_lengths(&mut self, n: usize);
+
+    /// Wrap the buffer in `LimitedBytes`, which limits the amount of bytes used
+    /// to `limit`.
+    ///
+    /// [`LimitedBytes::into_inner`] can be used to retrieve the buffer again,
+    /// or a mutable reference to the buffer can be used and the limited buffer
+    /// be dropped after usage.
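+    ///
+    /// For example (an illustrative sketch, not compiled as a doc test),
+    /// limiting a vectored read into two `Vec<u8>` buffers to 1024 bytes in
+    /// total:
+    ///
+    /// ```ignore
+    /// let bufs = (Vec::<u8>::with_capacity(512), Vec::<u8>::with_capacity(4096));
+    /// let limited = bufs.limit(1024);
+    /// assert_eq!(limited.spare_capacity(), 1024);
+    /// ```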
+ fn limit(self, limit: usize) -> LimitedBytes + where + Self: Sized, + { + LimitedBytes { buf: self, limit } + } } impl BytesVectored for &mut B @@ -391,6 +472,14 @@ where (&mut **self).as_bufs() } + fn spare_capacity(&self) -> usize { + (&**self).spare_capacity() + } + + fn has_spare_capacity(&self) -> bool { + (&**self).has_spare_capacity() + } + unsafe fn update_lengths(&mut self, n: usize) { (&mut **self).update_lengths(n) } @@ -411,10 +500,18 @@ where unsafe { MaybeUninit::array_assume_init(bufs) } } + fn spare_capacity(&self) -> usize { + self.iter().map(|b| b.spare_capacity()).sum() + } + + fn has_spare_capacity(&self) -> bool { + self.iter().any(|b| b.has_spare_capacity()) + } + unsafe fn update_lengths(&mut self, n: usize) { let mut left = n; for buf in self.iter_mut() { - let n = min(left, buf.as_bytes().len()); + let n = min(left, buf.spare_capacity()); buf.update_length(n); left -= n; if left == 0 { @@ -440,10 +537,18 @@ macro_rules! impl_vectored_bytes_tuple { unsafe { MaybeUninit::array_assume_init(bufs) } } + fn spare_capacity(&self) -> usize { + $( self.$idx.spare_capacity() + )+ 0 + } + + fn has_spare_capacity(&self) -> bool { + $( self.$idx.has_spare_capacity() || )+ false + } + unsafe fn update_lengths(&mut self, n: usize) { let mut left = n; $( - let n = min(left, self.$idx.as_bytes().len()); + let n = min(left, self.$idx.spare_capacity()); self.$idx.update_length(n); left -= n; if left == 0 { @@ -467,6 +572,92 @@ impl_vectored_bytes_tuple! { 4: B0 0, B1 1, B2 2, B3 3 } impl_vectored_bytes_tuple! { 3: B0 0, B1 1, B2 2 } impl_vectored_bytes_tuple! { 2: B0 0, B1 1 } +/// Wrapper to limit the number of bytes `B` can use. +/// +/// See [`Bytes::limit`] and [`BytesVectored::limit`]. +#[derive(Debug)] +pub struct LimitedBytes { + buf: B, + limit: usize, +} + +impl LimitedBytes { + /// Returns the underlying buffer. + pub fn into_inner(self) -> B { + self.buf + } +} + +impl Bytes for LimitedBytes +where + B: Bytes, +{ + fn as_bytes(&mut self) -> &mut [MaybeUninit] { + let bytes = self.buf.as_bytes(); + if bytes.len() > self.limit { + &mut bytes[..self.limit] + } else { + bytes + } + } + + fn spare_capacity(&self) -> usize { + min(self.buf.spare_capacity(), self.limit) + } + + fn has_spare_capacity(&self) -> bool { + self.spare_capacity() > 0 + } + + unsafe fn update_length(&mut self, n: usize) { + self.buf.update_length(n); + self.limit -= n; + } +} + +impl BytesVectored for LimitedBytes +where + B: BytesVectored, +{ + type Bufs<'b> = B::Bufs<'b>; + + fn as_bufs<'b>(&'b mut self) -> Self::Bufs<'b> { + let mut bufs = self.buf.as_bufs(); + let mut left = self.limit; + let mut iter = bufs.as_mut().iter_mut(); + while let Some(buf) = iter.next() { + let len = buf.len(); + if left > len { + left -= len; + } else { + buf.limit(left); + for buf in iter { + *buf = MaybeUninitSlice::new(&mut []); + } + break; + } + } + bufs + } + + fn spare_capacity(&self) -> usize { + if self.limit == 0 { + 0 + } else { + min(self.buf.spare_capacity(), self.limit) + } + } + + fn has_spare_capacity(&self) -> bool { + self.limit != 0 && self.buf.has_spare_capacity() + } + + unsafe fn update_lengths(&mut self, n: usize) { + self.buf.update_lengths(n); + self.limit -= n; + } +} + /// Convert a `socket2:::SockAddr` into a `std::net::SocketAddr`. 
#[allow(clippy::needless_pass_by_value)] fn convert_address(address: SockAddr) -> io::Result { diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 9b9ddd801..fd0336f63 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -236,16 +236,17 @@ impl TcpStream { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `TcpStream::try_recv with an empty buffer" ); - SockRef::from(&self.socket).recv(dst).map(|read| { - // Safety: just read the bytes. - unsafe { buf.update_length(read) } - read - }) + SockRef::from(&self.socket) + .recv(buf.as_bytes()) + .map(|read| { + // Safety: just read the bytes. + unsafe { buf.update_length(read) } + read + }) } /// Receive messages from the stream, writing them into `buf`. @@ -314,12 +315,12 @@ impl TcpStream { /// # /// # drop(actor); // Silent dead code warnings. /// ``` - pub fn recv_n<'a, B>(&'a mut self, mut buf: B, n: usize) -> RecvN<'a, B> + pub fn recv_n<'a, B>(&'a mut self, buf: B, n: usize) -> RecvN<'a, B> where B: Bytes, { debug_assert!( - buf.as_bytes().len() >= n, + buf.spare_capacity() >= n, "called `TcpStream::recv_n` with a buffer smaller then `n`" ); RecvN { @@ -342,16 +343,14 @@ impl TcpStream { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - dst.as_mut().iter().any(|buf| !buf.is_empty()), - "called `UdpSocket::try_recv_vectored` with an empty buffers" + bufs.has_spare_capacity(), + "called `UdpSocket::try_recv_vectored` with empty buffers" ); - let res = - SockRef::from(&self.socket).recv_vectored(MaybeUninitSlice::as_socket2(dst.as_mut())); + let res = SockRef::from(&self.socket) + .recv_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut())); match res { Ok((read, _)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } Ok(read) @@ -365,19 +364,20 @@ impl TcpStream { where B: BytesVectored, { + debug_assert!( + bufs.has_spare_capacity(), + "called `TcpStream::recv_vectored` with empty buffers" + ); RecvVectored { stream: self, bufs } } /// Receive at least `n` bytes from the stream, writing them into `bufs`. - pub fn recv_n_vectored(&mut self, mut bufs: B, n: usize) -> RecvNVectored<'_, B> + pub fn recv_n_vectored(&mut self, bufs: B, n: usize) -> RecvNVectored<'_, B> where B: BytesVectored, { debug_assert!( - { - let mut dst = bufs.as_bufs(); - !dst.as_mut().iter().map(|buf| buf.len()).sum::() >= n - }, + bufs.spare_capacity() >= n, "called `TcpStream::recv_n_vectored` with a buffer smaller then `n`" ); RecvNVectored { @@ -394,16 +394,17 @@ impl TcpStream { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `TcpStream::try_peek with an empty buffer" ); - SockRef::from(&self.socket).peek(dst).map(|read| { - // Safety: just read the bytes. - unsafe { buf.update_length(read) } - read - }) + SockRef::from(&self.socket) + .peek(buf.as_bytes()) + .map(|read| { + // Safety: just read the bytes. 
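+ // (`peek` wrote `read` bytes into the spare capacity handed out by
+ // `as_bytes`, so marking them as initialised is sound.)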
+ unsafe { buf.update_length(read) } + read + }) } /// Receive messages from the stream, writing them into `buf`, without @@ -423,16 +424,16 @@ impl TcpStream { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - dst.as_mut().iter().any(|buf| !buf.is_empty()), - "called `UdpSocket::try_peek_vectored` with an empty buffer" + bufs.has_spare_capacity(), + "called `UdpSocket::try_peek_vectored` with empty buffers" + ); + let res = SockRef::from(&self.socket).recv_vectored_with_flags( + MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()), + libc::MSG_PEEK, ); - let res = SockRef::from(&self.socket) - .recv_vectored_with_flags(MaybeUninitSlice::as_socket2(dst.as_mut()), libc::MSG_PEEK); match res { Ok((read, _)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } Ok(read) diff --git a/src/net/udp.rs b/src/net/udp.rs index b63a9f55b..c380c5f72 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -263,13 +263,12 @@ impl UdpSocket { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `UdpSocket::try_recv_from` with an empty buffer" ); SockRef::from(&self.socket) - .recv_from(dst) + .recv_from(buf.as_bytes()) .and_then(|(read, address)| { // Safety: just read the bytes. unsafe { buf.update_length(read) } @@ -300,16 +299,14 @@ impl UdpSocket { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - !dst.as_mut().first().map_or(true, |buf| buf.is_empty()), - "called `UdpSocket::try_recv_from` with an empty buffer" + bufs.has_spare_capacity(), + "called `UdpSocket::try_recv_from` with empty buffers" ); let res = SockRef::from(&self.socket) - .recv_from_vectored(MaybeUninitSlice::as_socket2(dst.as_mut())); + .recv_from_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut())); match res { Ok((read, _, address)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } let address = convert_address(address)?; @@ -341,13 +338,12 @@ impl UdpSocket { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `UdpSocket::try_peek_from` with an empty buffer" ); SockRef::from(&self.socket) - .peek_from(dst) + .peek_from(buf.as_bytes()) .and_then(|(read, address)| { // Safety: just read the bytes. unsafe { buf.update_length(read) } @@ -379,18 +375,16 @@ impl UdpSocket { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - !dst.as_mut().first().map_or(true, |buf| buf.is_empty()), - "called `UdpSocket::try_peek_from_vectored` with an empty buffer" + bufs.has_spare_capacity(), + "called `UdpSocket::try_peek_from_vectored` with empty buffers" ); let res = SockRef::from(&self.socket).recv_from_vectored_with_flags( - MaybeUninitSlice::as_socket2(dst.as_mut()), + MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()), libc::MSG_PEEK, ); match res { Ok((read, _, address)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } let address = convert_address(address)?; @@ -591,16 +585,17 @@ impl UdpSocket { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `UdpSocket::try_recv` with an empty buffer" ); - SockRef::from(&self.socket).recv(dst).map(|read| { - // Safety: just read the bytes. - unsafe { buf.update_length(read) } - read - }) + SockRef::from(&self.socket) + .recv(buf.as_bytes()) + .map(|read| { + // Safety: just read the bytes. 
+ unsafe { buf.update_length(read) } + read + }) } /// Receives data from the socket. Returns a [`Future`] that on success @@ -624,16 +619,14 @@ impl UdpSocket { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - !dst.as_mut().first().map_or(true, |buf| buf.is_empty()), - "called `UdpSocket::try_recv_vectored` with an empty buffer" + bufs.has_spare_capacity(), + "called `UdpSocket::try_recv_vectored` with empty buffers" ); - let res = - SockRef::from(&self.socket).recv_vectored(MaybeUninitSlice::as_socket2(dst.as_mut())); + let res = SockRef::from(&self.socket) + .recv_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut())); match res { Ok((read, _)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } Ok(read) @@ -663,16 +656,17 @@ impl UdpSocket { where B: Bytes, { - let dst = buf.as_bytes(); debug_assert!( - !dst.is_empty(), + buf.has_spare_capacity(), "called `UdpSocket::try_peek` with an empty buffer" ); - SockRef::from(&self.socket).peek(dst).map(|read| { - // Safety: just read the bytes. - unsafe { buf.update_length(read) } - read - }) + SockRef::from(&self.socket) + .peek(buf.as_bytes()) + .map(|read| { + // Safety: just read the bytes. + unsafe { buf.update_length(read) } + read + }) } /// Receives data from the socket, without removing it from the input queue. @@ -697,16 +691,16 @@ impl UdpSocket { where B: BytesVectored, { - let mut dst = bufs.as_bufs(); debug_assert!( - !dst.as_mut().first().map_or(true, |buf| buf.is_empty()), - "called `UdpSocket::try_peek_vectored` with an empty buffer" + bufs.has_spare_capacity(), + "called `UdpSocket::try_peek_vectored` with empty buffers" + ); + let res = SockRef::from(&self.socket).recv_vectored_with_flags( + MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()), + libc::MSG_PEEK, ); - let res = SockRef::from(&self.socket) - .recv_vectored_with_flags(MaybeUninitSlice::as_socket2(dst.as_mut()), libc::MSG_PEEK); match res { Ok((read, _)) => { - drop(dst); // Safety: just read the bytes. unsafe { bufs.update_lengths(read) } Ok(read) diff --git a/tests/functional/bytes.rs b/tests/functional/bytes.rs index d9c656c90..6718679a7 100644 --- a/tests/functional/bytes.rs +++ b/tests/functional/bytes.rs @@ -12,7 +12,9 @@ fn write_bytes(src: &[u8], mut buf: B) -> usize where B: Bytes, { + let spare_capacity = buf.spare_capacity(); let dst = buf.as_bytes(); + assert_eq!(dst.len(), spare_capacity); let len = min(src.len(), dst.len()); // Safety: both the `src` and `dst` pointers are good. And we've ensured // that the length is correct, not overwriting data we don't own or reading @@ -52,33 +54,69 @@ where #[test] fn impl_for_vec() { let mut buf = Vec::::with_capacity(2 * DATA.len()); + assert_eq!(buf.spare_capacity(), 2 * DATA.len()); + assert!(buf.has_spare_capacity()); let n = write_bytes(DATA, &mut buf); assert_eq!(n, DATA.len()); assert_eq!(buf.len(), DATA.len()); assert_eq!(&*buf, DATA); + assert_eq!(buf.spare_capacity(), DATA.len()); + assert!(buf.has_spare_capacity()); } #[test] fn dont_overwrite_existing_bytes_in_vec() { let mut buf = Vec::::with_capacity(2 * DATA.len()); + assert_eq!(buf.spare_capacity(), 2 * DATA.len()); + assert!(buf.has_spare_capacity()); buf.extend(DATA2); + assert_eq!(buf.spare_capacity(), 2 * DATA.len() - DATA2.len()); + assert!(buf.has_spare_capacity()); let start = buf.len(); let n = write_bytes(DATA, &mut buf); assert_eq!(n, DATA.len()); assert_eq!(buf.len(), DATA2.len() + DATA.len()); assert_eq!(&buf[..start], DATA2); // Original bytes untouched. 
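    // With both `DATA2` and `DATA` written only a single byte of capacity is
    // left, which the new `spare_capacity` assertions below check.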
assert_eq!(&buf[start..start + n], DATA); + assert_eq!(buf.spare_capacity(), 1); + assert!(buf.has_spare_capacity()); + buf.push(b'a'); + assert_eq!(buf.spare_capacity(), 0); + assert!(!buf.has_spare_capacity()); +} + +#[test] +fn limited_bytes() { + const LIMIT: usize = 5; + let mut buf = Vec::::with_capacity(2 * DATA.len()).limit(LIMIT); + assert_eq!(buf.spare_capacity(), 5); + assert!(buf.has_spare_capacity()); + + let n = write_bytes(DATA, &mut buf); + assert_eq!(n, LIMIT); + assert_eq!(buf.spare_capacity(), 0); + assert!(!buf.has_spare_capacity()); + let buf = buf.into_inner(); + assert_eq!(&*buf, &DATA[..LIMIT]); + assert_eq!(buf.len(), LIMIT); } #[test] fn vectored_array() { let mut bufs = [Vec::with_capacity(1), Vec::with_capacity(DATA.len())]; + assert_eq!(bufs.spare_capacity(), 1 + DATA.len()); + assert!(bufs.has_spare_capacity()); let n = write_bytes_vectored(DATA, &mut bufs); assert_eq!(n, DATA.len()); assert_eq!(bufs[0].len(), 1); assert_eq!(bufs[1].len(), DATA.len() - 1); assert_eq!(bufs[0], &DATA[..1]); assert_eq!(bufs[1], &DATA[1..]); + assert_eq!(bufs.spare_capacity(), 1); + assert!(bufs.has_spare_capacity()); + bufs[1].push(b'a'); + assert_eq!(bufs.spare_capacity(), 0); + assert!(!bufs.has_spare_capacity()); } #[test] @@ -88,6 +126,8 @@ fn vectored_tuple() { Vec::with_capacity(3), Vec::with_capacity(DATA.len()), ); + assert_eq!(bufs.spare_capacity(), 1 + 3 + DATA.len()); + assert!(bufs.has_spare_capacity()); let n = write_bytes_vectored(DATA, &mut bufs); assert_eq!(n, DATA.len()); assert_eq!(bufs.0.len(), 1); @@ -96,4 +136,35 @@ fn vectored_tuple() { assert_eq!(bufs.0, &DATA[..1]); assert_eq!(bufs.1, &DATA[1..4]); assert_eq!(bufs.2, &DATA[4..]); + assert_eq!(bufs.spare_capacity(), 4); + assert!(bufs.has_spare_capacity()); + bufs.2.extend_from_slice(b"aaaa"); + assert_eq!(bufs.spare_capacity(), 0); + assert!(!bufs.has_spare_capacity()); +} + +#[test] +fn limited_bytes_vectored() { + const LIMIT: usize = 5; + + let mut bufs = [ + Vec::with_capacity(1), + Vec::with_capacity(DATA.len()), + Vec::with_capacity(10), + ] + .limit(LIMIT); + assert_eq!(bufs.spare_capacity(), LIMIT); + assert!(bufs.has_spare_capacity()); + + let n = write_bytes_vectored(DATA, &mut bufs); + assert_eq!(n, LIMIT); + assert_eq!(bufs.spare_capacity(), 0); + assert!(!bufs.has_spare_capacity()); + let bufs = bufs.into_inner(); + assert_eq!(bufs[0].len(), 1); + assert_eq!(bufs[1].len(), LIMIT - 1); + assert_eq!(bufs[2].len(), 0); + assert_eq!(bufs[0], &DATA[..1]); + assert_eq!(bufs[1], &DATA[1..LIMIT]); + assert_eq!(bufs[2], &[]); } diff --git a/tools/src/bin/convert_trace.rs b/tools/src/bin/convert_trace.rs index 3d41aa25e..e646caafd 100644 --- a/tools/src/bin/convert_trace.rs +++ b/tools/src/bin/convert_trace.rs @@ -19,7 +19,7 @@ fn main() { let output = match args.next() { Some(output) => PathBuf::from(output), None => { - let end_idx = input.rfind('.').unwrap_or(input.len()); + let end_idx = input.rfind('.').unwrap_or_else(|| input.len()); let mut output = PathBuf::from(&input[..end_idx]); // If the input has a single extension this will add `json` to it. // If however it has two extensions, e.g. 
`.bin.log` this will @@ -60,10 +60,11 @@ fn main() { .as_micros(); let mut duration = event.end.duration_since(event.start).unwrap().as_micros(); - let pid = event.stream_id; - let tid = event.substream_id; + let process_id = event.stream_id; + let thread_id = event.substream_id; loop { - match times.entry((pid, tid)).or_default().entry(timestamp) { + let key = (process_id, thread_id); + match times.entry(key).or_default().entry(timestamp) { Entry::Vacant(entry) => { entry.insert(duration); break; @@ -88,8 +89,8 @@ fn main() { output, "{}\t\t{{\"pid\": {}, \"tid\": {}, \"ts\": {}, \"dur\": {}, \"name\": \"{}\"", if first { "" } else { ",\n" }, - pid, - tid, + process_id, + thread_id, timestamp, duration, event.description, @@ -101,7 +102,7 @@ fn main() { output .write_all(b", \"args\": {") .expect("failed to write event to output"); - for (name, value) in event.attributes.iter() { + for (name, value) in &event.attributes { let fmt_args = match value { // NOTE: `format_args!` is useless. Value::Unsigned(value) => format!("\"{}\": {}", name, value), @@ -194,7 +195,9 @@ pub struct TraceEvents<'t, R> { trace: &'t mut Trace, } +#[allow(clippy::unreadable_literal)] const METADATA_MAGIC: u32 = 0x75D11D4D; +#[allow(clippy::unreadable_literal)] const EVENT_MAGIC: u32 = 0xC1FC1FB7; /// Minimum amount of bytes in the buffer before we read again. @@ -237,12 +240,12 @@ where match magic { METADATA_MAGIC => { if let Err(err) = self.apply_metadata_packet() { - Some(Err(err.into())) + Some(Err(err)) } else { self.next() } } - EVENT_MAGIC => Some(self.parse_event_packet().map_err(|e| e.into())), + EVENT_MAGIC => Some(self.parse_event_packet()), magic => Some(Err(ParseError::InvalidMagic(magic))), } }
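// A brief usage sketch (not part of the patch) of the new buffer API: given a
// connected `mut stream: TcpStream` inside an actor returning `io::Result<()>`,
// and with the `heph::net::Bytes` trait in scope, `Bytes::limit` caps a single
// receive at 1024 bytes while the original `Vec` stays usable once the limited
// wrapper is dropped.
let mut buf = Vec::<u8>::with_capacity(4096);
let read = stream.try_recv((&mut buf).limit(1024))?;
assert!(read <= 1024);
assert_eq!(buf.len(), read);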