diff --git a/Cargo.lock b/Cargo.lock index 3f348166d..6d33477e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -989,10 +989,12 @@ dependencies = [ "sel4", "sel4-async-block-io", "sel4-async-block-io-fat", + "sel4-async-io", "sel4-async-network", "sel4-async-single-threaded-executor", "sel4-async-time", "sel4-async-unsync", + "sel4-atomic-ptr", "sel4-bounce-buffer-allocator", "sel4-config", "sel4-externally-shared", @@ -1091,10 +1093,10 @@ dependencies = [ "rustls-pemfile", "sel4-async-block-io", "sel4-async-block-io-fat", + "sel4-async-io", "sel4-async-network", "sel4-async-network-rustls", "sel4-async-network-rustls-utils", - "sel4-async-network-traits", "sel4-async-single-threaded-executor", "sel4-async-time", "sel4-async-unsync", @@ -1844,13 +1846,16 @@ dependencies = [ "sel4-async-block-io", ] +[[package]] +name = "sel4-async-io" +version = "0.1.0" + [[package]] name = "sel4-async-network" version = "0.1.0" dependencies = [ - "futures", "log", - "sel4-async-network-traits", + "sel4-async-io", "smoltcp", ] @@ -1858,10 +1863,9 @@ dependencies = [ name = "sel4-async-network-rustls" version = "0.1.0" dependencies = [ - "futures", "log", "rustls", - "sel4-async-network-traits", + "sel4-async-io", ] [[package]] @@ -1875,13 +1879,6 @@ dependencies = [ "sel4-async-time", ] -[[package]] -name = "sel4-async-network-traits" -version = "0.1.0" -dependencies = [ - "futures", -] - [[package]] name = "sel4-async-single-threaded-executor" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 6404330de..f4a9ba661 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,10 +61,10 @@ members = [ "crates/sel4", "crates/sel4-async/block-io", "crates/sel4-async/block-io/fat", + "crates/sel4-async/io", "crates/sel4-async/network", "crates/sel4-async/network/rustls", "crates/sel4-async/network/rustls/utils", - "crates/sel4-async/network/traits", "crates/sel4-async/single-threaded-executor", "crates/sel4-async/time", "crates/sel4-async/unsync", diff --git a/crates/examples/microkit/http-server/pds/server/core/Cargo.nix b/crates/examples/microkit/http-server/pds/server/core/Cargo.nix index a92d16881..ee4e2932b 100644 --- a/crates/examples/microkit/http-server/pds/server/core/Cargo.nix +++ b/crates/examples/microkit/http-server/pds/server/core/Cargo.nix @@ -33,7 +33,7 @@ mk { sel4-async-unsync sel4-async-time sel4-async-network - sel4-async-network-traits + sel4-async-io sel4-async-network-rustls sel4-async-network-rustls-utils sel4-panicking-env diff --git a/crates/examples/microkit/http-server/pds/server/core/Cargo.toml b/crates/examples/microkit/http-server/pds/server/core/Cargo.toml index d5851b2e5..4a32fe243 100644 --- a/crates/examples/microkit/http-server/pds/server/core/Cargo.toml +++ b/crates/examples/microkit/http-server/pds/server/core/Cargo.toml @@ -23,10 +23,10 @@ log = "0.4.17" rustls-pemfile = { version = "2.0.0", default-features = false } sel4-async-block-io = { path = "../../../../../../sel4-async/block-io" } sel4-async-block-io-fat = { path = "../../../../../../sel4-async/block-io/fat" } +sel4-async-io = { path = "../../../../../../sel4-async/io" } sel4-async-network = { path = "../../../../../../sel4-async/network" } sel4-async-network-rustls = { path = "../../../../../../sel4-async/network/rustls" } sel4-async-network-rustls-utils = { path = "../../../../../../sel4-async/network/rustls/utils" } -sel4-async-network-traits = { path = "../../../../../../sel4-async/network/traits" } sel4-async-time = { path = "../../../../../../sel4-async/time" } sel4-async-unsync = { path = "../../../../../../sel4-async/unsync" } sel4-panicking-env = { path = "../../../../../../sel4-panicking/env" } diff --git a/crates/examples/microkit/http-server/pds/server/core/src/lib.rs b/crates/examples/microkit/http-server/pds/server/core/src/lib.rs index 8da6b7e5e..737eee447 100644 --- a/crates/examples/microkit/http-server/pds/server/core/src/lib.rs +++ b/crates/examples/microkit/http-server/pds/server/core/src/lib.rs @@ -23,10 +23,10 @@ use rustls::ServerConfig; use sel4_async_block_io::{access::ReadOnly, constant_block_sizes, BlockIO}; use sel4_async_block_io_fat as fat; +use sel4_async_io::ReadExactError; use sel4_async_network::{ManagedInterface, TcpSocket, TcpSocketError}; use sel4_async_network_rustls::{Error as AsyncRustlsError, ServerConnector}; use sel4_async_network_rustls_utils::GetCurrentTimeImpl; -use sel4_async_network_traits::ClosedError; use sel4_async_single_threaded_executor::LocalSpawner; use sel4_async_time::{Instant, TimerManager}; @@ -117,10 +117,10 @@ type SocketUser = Box< async fn use_socket_for_http( server: Server, mut socket: TcpSocket, -) -> Result<(), ClosedError> { +) -> Result<(), ReadExactError> { socket.accept(HTTP_PORT).await?; server.handle_connection(&mut socket).await?; - socket.close().await?; + socket.close(); Ok(()) } @@ -128,7 +128,7 @@ async fn use_socket_for_https, tls_config: Arc, mut socket: TcpSocket, -) -> Result<(), ClosedError>> { +) -> Result<(), ReadExactError>> { socket .accept(HTTPS_PORT) .await @@ -138,11 +138,7 @@ async fn use_socket_for_https Server { } } - pub(crate) async fn handle_connection( + pub(crate) async fn handle_connection( &self, conn: &mut U, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { loop { let mut buf = vec![0; 1024 * 16]; let mut i = 0; loop { let n = conn.read(&mut buf[i..]).await?; if n == 0 { - return Err(ClosedError::Closed); + return Err(ReadExactError::UnexpectedEof); } i += n; if is_request_complete(&buf[..i]).unwrap_or(false) { @@ -68,11 +68,11 @@ impl Server { Ok(()) } - async fn handle_request( + async fn handle_request( &self, conn: &mut U, request_path: &str, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { match self.lookup_request_path(request_path).await { RequestPathStatus::Ok { file_name, file } => { let content_type = content_type_from_name(&file_name); @@ -88,12 +88,12 @@ impl Server { Ok(()) } - async fn serve_file( + async fn serve_file( &self, conn: &mut U, content_type: &str, file: fat::File, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { let file_len: usize = self .volume_manager .lock() @@ -133,11 +133,11 @@ impl Server { Ok(()) } - async fn serve_moved_permanently( + async fn serve_moved_permanently( &self, conn: &mut U, location: &str, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { let phrase = "Moved Permanently"; self.start_response_headers(conn, 301, phrase).await?; self.send_response_header(conn, "Content-Type", b"text/plain") @@ -151,10 +151,10 @@ impl Server { Ok(()) } - async fn serve_not_found( + async fn serve_not_found( &self, conn: &mut U, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { let phrase = "Not Found"; self.start_response_headers(conn, 404, phrase).await?; self.send_response_header(conn, "Content-Type", b"text/plain") @@ -166,12 +166,12 @@ impl Server { Ok(()) } - async fn start_response_headers( + async fn start_response_headers( &self, conn: &mut U, status_code: usize, reason_phrase: &str, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { conn.write_all(b"HTTP/1.1 ").await?; conn.write_all(status_code.to_string().as_bytes()).await?; conn.write_all(b" ").await?; @@ -180,12 +180,12 @@ impl Server { Ok(()) } - async fn send_response_header( + async fn send_response_header( &self, conn: &mut U, name: &str, value: &[u8], - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { conn.write_all(name.as_bytes()).await?; conn.write_all(b": ").await?; conn.write_all(value).await?; @@ -193,10 +193,10 @@ impl Server { Ok(()) } - async fn finish_response_headers( + async fn finish_response_headers( &self, conn: &mut U, - ) -> Result<(), ClosedError> { + ) -> Result<(), ReadExactError> { conn.write_all(b"\r\n").await?; Ok(()) } diff --git a/crates/private/meta/Cargo.nix b/crates/private/meta/Cargo.nix index 23504a629..d4ff855e9 100644 --- a/crates/private/meta/Cargo.nix +++ b/crates/private/meta/Cargo.nix @@ -18,10 +18,12 @@ mk { sel4-async-block-io sel4-async-block-io-fat + sel4-async-io sel4-async-network sel4-async-single-threaded-executor sel4-async-time sel4-async-unsync + sel4-atomic-ptr sel4-bounce-buffer-allocator sel4-immediate-sync-once-cell sel4-immutable-cell diff --git a/crates/private/meta/Cargo.toml b/crates/private/meta/Cargo.toml index aaf6b992d..e116b05c9 100644 --- a/crates/private/meta/Cargo.toml +++ b/crates/private/meta/Cargo.toml @@ -26,10 +26,12 @@ log = "0.4.17" sel4 = { path = "../../sel4" } sel4-async-block-io = { path = "../../sel4-async/block-io" } sel4-async-block-io-fat = { path = "../../sel4-async/block-io/fat" } +sel4-async-io = { path = "../../sel4-async/io" } sel4-async-network = { path = "../../sel4-async/network" } sel4-async-single-threaded-executor = { path = "../../sel4-async/single-threaded-executor" } sel4-async-time = { path = "../../sel4-async/time" } sel4-async-unsync = { path = "../../sel4-async/unsync" } +sel4-atomic-ptr = { path = "../../sel4-atomic-ptr" } sel4-bounce-buffer-allocator = { path = "../../sel4-bounce-buffer-allocator" } sel4-config = { path = "../../sel4/config" } sel4-externally-shared = { path = "../../sel4-externally-shared", features = ["unstable"] } diff --git a/crates/private/meta/src/lib.rs b/crates/private/meta/src/lib.rs index c527fe9ae..21140f83b 100644 --- a/crates/private/meta/src/lib.rs +++ b/crates/private/meta/src/lib.rs @@ -82,10 +82,12 @@ definitely! { sel4 sel4_async_block_io sel4_async_block_io_fat + sel4_async_io sel4_async_network sel4_async_single_threaded_executor sel4_async_time sel4_async_unsync + sel4_atomic_ptr sel4_bounce_buffer_allocator sel4_config sel4_externally_shared diff --git a/crates/sel4-async/io/Cargo.nix b/crates/sel4-async/io/Cargo.nix new file mode 100644 index 000000000..3af5bb001 --- /dev/null +++ b/crates/sel4-async/io/Cargo.nix @@ -0,0 +1,15 @@ +# +# Copyright 2023, Colias Group, LLC +# +# SPDX-License-Identifier: BSD-2-Clause +# + +{ mk, mkDefaultFrontmatterWithReuseArgs, defaultReuseFrontmatterArgs, versions }: + +mk rec { + nix.frontmatter = mkDefaultFrontmatterWithReuseArgs (defaultReuseFrontmatterArgs // { + licenseID = package.license; + }); + package.name = "sel4-async-io"; + package.license = "MIT OR Apache-2.0"; +} diff --git a/crates/sel4-async/network/traits/Cargo.toml b/crates/sel4-async/io/Cargo.toml similarity index 65% rename from crates/sel4-async/network/traits/Cargo.toml rename to crates/sel4-async/io/Cargo.toml index d82f33fd4..bf0818008 100644 --- a/crates/sel4-async/network/traits/Cargo.toml +++ b/crates/sel4-async/io/Cargo.toml @@ -1,7 +1,7 @@ # # Copyright 2023, Colias Group, LLC # -# SPDX-License-Identifier: BSD-2-Clause +# SPDX-License-Identifier: MIT OR Apache-2.0 # # # This file is generated from './Cargo.nix'. You can edit this file directly @@ -10,11 +10,8 @@ # [package] -name = "sel4-async-network-traits" +name = "sel4-async-io" version = "0.1.0" authors = ["Nick Spinale "] edition = "2021" -license = "BSD-2-Clause" - -[dependencies] -futures = { version = "0.3.28", default-features = false, features = ["alloc"] } +license = "MIT OR Apache-2.0" diff --git a/crates/sel4-async/io/src/lib.rs b/crates/sel4-async/io/src/lib.rs new file mode 100644 index 000000000..acf2a8b74 --- /dev/null +++ b/crates/sel4-async/io/src/lib.rs @@ -0,0 +1,118 @@ +// +// Copyright 2024, Colias Group, LLC +// Copyright 2024, Embedded devices Working Group +// +// SPDX-License-Identifier: MIT OR Apache-2.0 +// + +// TODO use Pin + +#![no_std] + +use core::fmt; +use core::future::poll_fn; +use core::pin::Pin; +use core::task::{Context, Poll}; + +pub trait ErrorType { + type Error: fmt::Debug; +} + +pub trait Read: ErrorType { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll>; + + // // // + + #[allow(async_fn_in_trait)] + async fn read(&mut self, buf: &mut [u8]) -> Result + where + Self: Unpin, + { + let mut pin = Pin::new(self); + poll_fn(move |cx| pin.as_mut().poll_read(cx, buf)).await + } + + #[allow(async_fn_in_trait)] + async fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), ReadExactError> + where + Self: Unpin, + { + while !buf.is_empty() { + match self.read(buf).await { + Ok(0) => break, + Ok(n) => buf = &mut buf[n..], + Err(e) => return Err(ReadExactError::Other(e)), + } + } + if buf.is_empty() { + Ok(()) + } else { + Err(ReadExactError::UnexpectedEof) + } + } +} + +pub trait Write: ErrorType { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + // // // + + #[allow(async_fn_in_trait)] + async fn write(&mut self, buf: &[u8]) -> Result + where + Self: Unpin, + { + let mut pin = Pin::new(self); + poll_fn(|cx| pin.as_mut().poll_write(cx, buf)).await + } + + #[allow(async_fn_in_trait)] + async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> + where + Self: Unpin, + { + let mut buf = buf; + while !buf.is_empty() { + match self.write(buf).await { + Ok(0) => panic!("write() returned Ok(0)"), + Ok(n) => buf = &buf[n..], + Err(e) => return Err(e), + } + } + Ok(()) + } + + #[allow(async_fn_in_trait)] + async fn flush(&mut self) -> Result<(), Self::Error> + where + Self: Unpin, + { + let mut pin = Pin::new(self); + poll_fn(|cx| pin.as_mut().poll_flush(cx)).await + } +} + +/// Error returned by [`Read::read_exact`] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum ReadExactError { + /// An EOF error was encountered before reading the exact amount of requested bytes. + UnexpectedEof, + /// Error returned by the inner Read. + Other(E), +} + +impl From for ReadExactError { + fn from(err: E) -> Self { + Self::Other(err) + } +} diff --git a/crates/sel4-async/network/Cargo.nix b/crates/sel4-async/network/Cargo.nix index d1b3ae6d0..1a3c5ff15 100644 --- a/crates/sel4-async/network/Cargo.nix +++ b/crates/sel4-async/network/Cargo.nix @@ -9,15 +9,8 @@ mk { package.name = "sel4-async-network"; dependencies = { - inherit (localCrates) sel4-async-network-traits; + inherit (localCrates) sel4-async-io; inherit (versions) log; - futures = { - version = versions.futures; - default-features = false; - features = [ - "alloc" - ]; - }; smoltcp = smoltcpWith [ "async" "alloc" diff --git a/crates/sel4-async/network/Cargo.toml b/crates/sel4-async/network/Cargo.toml index ce9bfc974..753536c04 100644 --- a/crates/sel4-async/network/Cargo.toml +++ b/crates/sel4-async/network/Cargo.toml @@ -17,9 +17,8 @@ edition = "2021" license = "BSD-2-Clause" [dependencies] -futures = { version = "0.3.28", default-features = false, features = ["alloc"] } log = "0.4.17" -sel4-async-network-traits = { path = "traits" } +sel4-async-io = { path = "../io" } [dependencies.smoltcp] version = "0.10.0" diff --git a/crates/sel4-async/network/rustls/Cargo.nix b/crates/sel4-async/network/rustls/Cargo.nix index 36e74e9b2..3721a241b 100644 --- a/crates/sel4-async/network/rustls/Cargo.nix +++ b/crates/sel4-async/network/rustls/Cargo.nix @@ -14,16 +14,9 @@ mk rec { package.license = "Apache-2.0 OR ISC OR MIT"; dependencies = { inherit (localCrates) - sel4-async-network-traits + sel4-async-io ; inherit (versions) log; rustls = rustlsWith [] // (localCrates.rustls or {}); - futures = { - version = versions.futures; - default-features = false; - features = [ - "alloc" - ]; - }; }; } diff --git a/crates/sel4-async/network/rustls/Cargo.toml b/crates/sel4-async/network/rustls/Cargo.toml index b3d602294..27096d99b 100644 --- a/crates/sel4-async/network/rustls/Cargo.toml +++ b/crates/sel4-async/network/rustls/Cargo.toml @@ -17,9 +17,8 @@ edition = "2021" license = "Apache-2.0 OR ISC OR MIT" [dependencies] -futures = { version = "0.3.28", default-features = false, features = ["alloc"] } log = "0.4.17" -sel4-async-network-traits = { path = "../traits" } +sel4-async-io = { path = "../../io" } [dependencies.rustls] git = "https://github.com/coliasgroup/rustls.git" diff --git a/crates/sel4-async/network/rustls/src/conn.rs b/crates/sel4-async/network/rustls/src/conn.rs index 763f563aa..5499ae471 100644 --- a/crates/sel4-async/network/rustls/src/conn.rs +++ b/crates/sel4-async/network/rustls/src/conn.rs @@ -6,6 +6,7 @@ // Derived from https://github.com/rustls/rustls/pull/1648 by https://github.com/japaric +use core::future::Future; use core::marker::PhantomData; use core::mem; use core::ops::DerefMut; @@ -14,7 +15,6 @@ use core::task::{self, Poll}; use alloc::sync::Arc; -use futures::Future; use rustls::client::{ClientConnectionData, UnbufferedClientConnection}; use rustls::pki_types::ServerName; use rustls::server::{ServerConnectionData, UnbufferedServerConnection}; @@ -23,7 +23,7 @@ use rustls::unbuffered::{ }; use rustls::{ClientConfig, ServerConfig, SideData, UnbufferedConnectionCommon}; -use sel4_async_network_traits::AsyncIO; +use sel4_async_io::{Read, Write}; use crate::{ utils::{poll_read, poll_write, try_or_resize_and_retry, Buffer, WriteCursor}, @@ -42,7 +42,7 @@ impl ClientConnector { // FIXME should not return an error but instead hoist it into a `Connect` variant ) -> Result, Error> where - IO: AsyncIO, + IO: Read + Write, { let conn = UnbufferedClientConnection::new(self.config.clone(), domain)?; @@ -67,7 +67,7 @@ impl ServerConnector { // FIXME should not return an error but instead hoist it into a `Connect` variant ) -> Result, Error> where - IO: AsyncIO, + IO: Read + Write, { let conn = UnbufferedServerConnection::new(self.config.clone())?; @@ -117,7 +117,7 @@ impl Future for Connect where D: Unpin + SideDataAugmented, T: Unpin + DerefMut>, - IO: Unpin + AsyncIO, + IO: Unpin + Read + Write, { type Output = Result, Error>; @@ -215,7 +215,7 @@ impl SideDataAugmented for ServerConnectionData { impl ConnectInner where T: DerefMut>, - IO: AsyncIO, + IO: Read + Write, D: SideDataAugmented, { fn advance(&mut self, updates: &mut Updates) -> Result> { @@ -288,58 +288,19 @@ impl TlsStream { } } -impl AsyncIO for TlsStream +impl sel4_async_io::ErrorType for TlsStream where - T: DerefMut> + Unpin, - IO: AsyncIO + Unpin, - D: SideDataAugmented + Unpin, + IO: sel4_async_io::ErrorType, { type Error = Error; +} - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut outgoing = mem::take(&mut self.outgoing); - - // no IO here; just in-memory writes - match SideDataAugmented::process_tls_records_generic(&mut self.conn, &mut []).state? { - ConnectionState::WriteTraffic(mut state) => { - try_or_resize_and_retry( - |out_buffer| state.encrypt(buf, out_buffer), - |e| { - if let EncryptError::InsufficientSize(is) = &e { - Ok(*is) - } else { - Err(e.into()) - } - }, - &mut outgoing, - )?; - } - - ConnectionState::Closed => { - return Poll::Ready(Err(Error::ConnectionAborted)); - } - - state => unreachable!("{state:?}"), - } - - // opportunistically try to write data into the socket - // XXX should this be a loop? - while !outgoing.is_empty() { - let would_block = poll_write(&mut self.io, &mut outgoing, cx)?; - if would_block { - break; - } - } - - self.outgoing = outgoing; - - Poll::Ready(Ok(buf.len())) - } - +impl Read for TlsStream +where + T: DerefMut> + Unpin, + IO: Read + Unpin, + D: SideDataAugmented + Unpin, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, @@ -398,6 +359,57 @@ where Poll::Ready(Ok(cursor.into_used())) } +} + +impl Write for TlsStream +where + T: DerefMut> + Unpin, + IO: Write + Unpin, + D: SideDataAugmented + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut outgoing = mem::take(&mut self.outgoing); + + // no IO here; just in-memory writes + match SideDataAugmented::process_tls_records_generic(&mut self.conn, &mut []).state? { + ConnectionState::WriteTraffic(mut state) => { + try_or_resize_and_retry( + |out_buffer| state.encrypt(buf, out_buffer), + |e| { + if let EncryptError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }, + &mut outgoing, + )?; + } + + ConnectionState::Closed => { + return Poll::Ready(Err(Error::ConnectionAborted)); + } + + state => unreachable!("{state:?}"), + } + + // opportunistically try to write data into the socket + // XXX should this be a loop? + while !outgoing.is_empty() { + let would_block = poll_write(&mut self.io, &mut outgoing, cx)?; + if would_block { + break; + } + } + + self.outgoing = outgoing; + + Poll::Ready(Ok(buf.len())) + } fn poll_flush( mut self: Pin<&mut Self>, @@ -421,15 +433,4 @@ where .poll_flush(cx) .map_err(Error::TransitError) } - - #[allow(unused_mut)] - fn poll_close( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> Poll>> { - // XXX write out close_notify here? - Pin::new(&mut self.io) - .poll_close(cx) - .map_err(Error::TransitError) - } } diff --git a/crates/sel4-async/network/rustls/src/utils.rs b/crates/sel4-async/network/rustls/src/utils.rs index 8a6cc4637..44d1eed50 100644 --- a/crates/sel4-async/network/rustls/src/utils.rs +++ b/crates/sel4-async/network/rustls/src/utils.rs @@ -12,7 +12,7 @@ use core::task::{self, Poll}; use rustls::unbuffered::InsufficientSizeError; -use sel4_async_network_traits::AsyncIO; +use sel4_async_io::{Read, Write}; use crate::Error; @@ -143,7 +143,7 @@ pub(crate) fn poll_read( cx: &mut task::Context, ) -> Result> where - IO: AsyncIO + Unpin, + IO: Read + Unpin, { if incoming.unfilled().is_empty() { // XXX should this be user configurable? @@ -157,7 +157,6 @@ where incoming.advance(read); false } - Poll::Pending => true, }; @@ -171,9 +170,9 @@ pub(crate) fn poll_write( cx: &mut task::Context, ) -> Result> where - IO: AsyncIO + Unpin, + IO: Write + Unpin, { - let pending = match Pin::new(io).poll_write(cx, outgoing.filled()) { + let would_block = match Pin::new(io).poll_write(cx, outgoing.filled()) { Poll::Ready(res) => { let written = res.map_err(Error::TransitError)?; log::trace!("wrote {written}B into socket"); @@ -181,8 +180,8 @@ where log::trace!("{}B remain in the outgoing buffer", outgoing.len()); false } - Poll::Pending => true, }; - Ok(pending) + + Ok(would_block) } diff --git a/crates/sel4-async/network/src/lib.rs b/crates/sel4-async/network/src/lib.rs index d5a54f933..ded7b52db 100644 --- a/crates/sel4-async/network/src/lib.rs +++ b/crates/sel4-async/network/src/lib.rs @@ -4,6 +4,9 @@ // SPDX-License-Identifier: BSD-2-Clause // +// Ideas for implementing operations on TCP sockets taken from: +// https://github.com/embassy-rs/embassy/blob/main/embassy-net/src/tcp.rs + #![no_std] extern crate alloc; @@ -12,11 +15,11 @@ use alloc::rc::Rc; use alloc::vec; use alloc::vec::Vec; use core::cell::RefCell; +use core::future::poll_fn; use core::marker::PhantomData; use core::pin::Pin; use core::task::{self, Poll}; -use futures::prelude::*; use log::info; use smoltcp::{ iface::{Config, Context, Interface, SocketHandle, SocketSet}, @@ -26,7 +29,7 @@ use smoltcp::{ wire::{DnsQueryType, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, Ipv4Address, Ipv4Cidr}, }; -use sel4_async_network_traits::AsyncIO; +use sel4_async_io::{Read, Write}; pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000; pub(crate) const DEFAULT_TCP_SOCKET_BUFFER_SIZE: usize = 65535; @@ -59,12 +62,24 @@ pub struct Socket { _phantom: PhantomData, } +impl Drop for Socket { + fn drop(&mut self) { + self.shared + .inner + .borrow_mut() + .socket_set + .remove(self.handle); + } +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum TcpSocketError { InvalidState(tcp::State), // TODO just use InvalidState variants of below errors? RecvError(tcp::RecvError), SendError(tcp::SendError), + ListenError(tcp::ListenError), ConnectError(tcp::ConnectError), + ConnectionResetDuringConnect, } #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -82,10 +97,8 @@ impl ManagedInterface { ) -> Self { let iface = Interface::new(config, device, instant); let mut socket_set = SocketSet::new(vec![]); - let dns_socket = dns::Socket::new(&[], vec![]); - let dns_socket_handle = socket_set.add(dns_socket); - let dhcp_socket = dhcpv4::Socket::new(); - let dhcp_socket_handle = socket_set.add(dhcp_socket); + let dns_socket_handle = socket_set.add(dns::Socket::new(&[], vec![])); + let dhcp_socket_handle = socket_set.add(dhcpv4::Socket::new()); let mut this = ManagedInterfaceShared { iface, @@ -157,7 +170,7 @@ impl ManagedInterface { .start_query(inner.iface.context(), name, query_type) .map_err(DnsError::StartQueryError)? }; - future::poll_fn(|cx| { + poll_fn(|cx| { let inner = &mut *self.inner().borrow_mut(); let socket = inner .socket_set @@ -190,7 +203,7 @@ impl> Socket { f(socket) } - pub fn with_and_context_mut(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R { + pub fn with_context_mut(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R { let network = &mut *self.shared.inner().borrow_mut(); let context = network.iface.context(); let socket = network.socket_set.get_mut(self.handle); @@ -208,20 +221,17 @@ impl Socket> { T: Into, U: Into, { - self.with_and_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint)) + self.with_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint)) .map_err(TcpSocketError::ConnectError)?; - future::poll_fn(|cx| { + poll_fn(|cx| { self.with_mut(|socket| { let state = socket.state(); match state { tcp::State::Closed | tcp::State::TimeWait => { - Poll::Ready(Err(TcpSocketError::InvalidState(state))) - } - tcp::State::Listen => { - // TODO handle differently - Poll::Ready(Err(TcpSocketError::InvalidState(state))) + Poll::Ready(Err(TcpSocketError::ConnectionResetDuringConnect)) } + tcp::State::Listen => unreachable!(), // because future holds &mut self tcp::State::SynSent | tcp::State::SynReceived => { socket.register_send_waker(cx.waker()); Poll::Pending @@ -235,43 +245,22 @@ impl Socket> { pub async fn accept_with_keep_alive( &mut self, - port: u16, + local_endpoint: impl Into, keep_alive_interval: Option, ) -> Result<(), TcpSocketError> { - future::poll_fn(|cx| { + self.with_mut(|socket| { + socket + .listen(local_endpoint) + .map_err(TcpSocketError::ListenError) + })?; + + poll_fn(|cx| { self.with_mut(|socket| match socket.state() { - tcp::State::Closed => { - socket.listen(port).unwrap(); - Poll::Ready(()) - } - tcp::State::Listen => Poll::Ready(()), - _ => { + tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => { socket.register_recv_waker(cx.waker()); Poll::Pending } - }) - }) - .await; - - future::poll_fn(|cx| { - self.with_mut(|socket| { - if socket.is_active() { - Poll::Ready(Ok(())) - } else { - let state = socket.state(); - match state { - tcp::State::Closed - | tcp::State::Closing - | tcp::State::FinWait1 - | tcp::State::FinWait2 => { - Poll::Ready(Err(TcpSocketError::InvalidState(state))) - } - _ => { - socket.register_recv_waker(cx.waker()); - Poll::Pending - } - } - } + _ => Poll::Ready(Ok(())), }) }) .await?; @@ -281,136 +270,19 @@ impl Socket> { Ok(()) } - pub async fn accept(&mut self, port: u16) -> Result<(), TcpSocketError> { + pub async fn accept( + &mut self, + local_endpoint: impl Into, + ) -> Result<(), TcpSocketError> { self.accept_with_keep_alive( - port, + local_endpoint, Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)), ) .await } - #[allow(clippy::needless_pass_by_ref_mut)] - pub async fn recv(&mut self, buffer: &mut [u8]) -> Result { - future::poll_fn(|cx| self.poll_recv(cx, buffer)).await - } - - #[allow(clippy::needless_pass_by_ref_mut)] - pub fn poll_recv( - &mut self, - cx: &mut task::Context<'_>, - buffer: &mut [u8], - ) -> Poll> { - self.with_mut(|socket| { - if socket.can_recv() { - Poll::Ready( - socket - .recv_slice(buffer) - .map_err(TcpSocketError::RecvError) - .map(|n| { - assert!(n > 0); // check assumption about smoltcp - n - }), - ) - } else { - let state = socket.state(); - match state { - tcp::State::FinWait1 - | tcp::State::FinWait2 - | tcp::State::Closed - | tcp::State::Closing - | tcp::State::CloseWait - | tcp::State::TimeWait => Poll::Ready(Err(TcpSocketError::InvalidState(state))), - _ => { - socket.register_recv_waker(cx.waker()); - Poll::Pending - } - } - } - }) - } - - pub async fn send_all(&mut self, buffer: &[u8]) -> Result<(), TcpSocketError> { - let mut pos = 0; - while pos < buffer.len() { - let n = self.send(&buffer[pos..]).await?; - assert!(n > 0); - pos += n; - } - assert_eq!(pos, buffer.len()); - Ok(()) - } - - pub async fn send(&mut self, buffer: &[u8]) -> Result { - future::poll_fn(|cx| self.poll_send(cx, buffer)).await - } - - #[allow(clippy::needless_pass_by_ref_mut)] - pub fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buffer: &[u8], - ) -> Poll> { - self.with_mut(|socket| { - if socket.can_send() { - Poll::Ready(socket.send_slice(buffer).map_err(TcpSocketError::SendError)) - } else { - let state = socket.state(); - match state { - tcp::State::FinWait1 - | tcp::State::FinWait2 - | tcp::State::Closed - | tcp::State::Closing - | tcp::State::CloseWait - | tcp::State::TimeWait => Poll::Ready(Err(TcpSocketError::InvalidState(state))), - _ => { - socket.register_send_waker(cx.waker()); - Poll::Pending - } - } - } - }) - } - - pub async fn close(&mut self) -> Result<(), TcpSocketError> { - future::poll_fn(|cx| { - self.with_mut(|socket| { - let state = socket.state(); - match state { - tcp::State::FinWait1 - | tcp::State::FinWait2 - | tcp::State::Closed - | tcp::State::Closing - | tcp::State::TimeWait => Poll::Ready(Err(TcpSocketError::InvalidState(state))), - _ => { - if socket.send_queue() > 0 { - socket.register_send_waker(cx.waker()); - Poll::Pending - } else { - socket.close(); - Poll::Ready(Ok(())) - } - } - } - }) - }) - .await?; - - future::poll_fn(|cx| { - self.with_mut(|socket| match socket.state() { - tcp::State::FinWait1 - | tcp::State::FinWait2 - | tcp::State::Closed - | tcp::State::Closing - | tcp::State::TimeWait => Poll::Ready(()), - _ => { - socket.register_send_waker(cx.waker()); - Poll::Pending - } - }) - }) - .await; - - Ok(()) + pub fn close(&mut self) { + self.with_mut(|socket| socket.close()) } pub fn abort(&mut self) { @@ -418,49 +290,60 @@ impl Socket> { } } -impl AsyncIO for Socket> { +impl sel4_async_io::ErrorType for Socket> { type Error = TcpSocketError; +} +impl Read for Socket> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8], ) -> Poll> { - self.poll_recv(cx, buf) + self.with_mut(|socket| match socket.recv_slice(buf) { + Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)), + Ok(0) => { + socket.register_recv_waker(cx.waker()); + Poll::Pending + } + Ok(n) => Poll::Ready(Ok(n)), + Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)), + Err(err) => Poll::Ready(Err(TcpSocketError::RecvError(err))), + }) } +} +impl Write for Socket> { fn poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { - self.poll_send(cx, buf) + self.with_mut(|socket| match socket.send_slice(buf) { + Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)), + Ok(0) => { + socket.register_send_waker(cx.waker()); + Poll::Pending + } + Ok(n) => Poll::Ready(Ok(n)), + Err(err) => Poll::Ready(Err(TcpSocketError::SendError(err))), + }) } fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, - ) -> Poll> { - // TODO - Poll::Ready(Ok(())) - } - - fn poll_close( - self: Pin<&mut Self>, - _cx: &mut task::Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, ) -> Poll> { - // TODO - Poll::Ready(Ok(())) - } -} - -impl Drop for Socket { - fn drop(&mut self) { - self.shared - .inner - .borrow_mut() - .socket_set - .remove(self.handle); + self.with_mut(|socket| { + let waiting_close = + socket.state() == tcp::State::Closed && socket.remote_endpoint().is_some(); + if socket.send_queue() > 0 || waiting_close { + socket.register_send_waker(cx.waker()); + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + }) } } @@ -503,7 +386,7 @@ impl ManagedInterfaceShared { self.set_router(config.router); } if self.dhcp_overrides.dns_servers.is_none() { - self.set_dns_servers(&config.dns_servers); + self.set_dns_servers(&convert_dns_servers(&config.dns_servers)); } } dhcpv4::Event::Deconfigured => { @@ -560,16 +443,11 @@ impl ManagedInterfaceShared { self.iface.routes_mut().remove_default_ipv4_route(); } - fn set_dns_servers(&mut self, dns_servers: &[Ipv4Address]) { + fn set_dns_servers(&mut self, dns_servers: &[IpAddress]) { for (i, s) in dns_servers.iter().enumerate() { info!("DNS server {}: {}", i, s); } - let dns_servers = dns_servers - .iter() - .copied() - .map(From::from) - .collect::>(); - self.dns_socket_mut().update_servers(&dns_servers); + self.dns_socket_mut().update_servers(dns_servers); } fn clear_dns_servers(&mut self) { @@ -583,8 +461,12 @@ impl ManagedInterfaceShared { if let Some(router) = self.dhcp_overrides.router { self.set_router(router); } - if let Some(dns_servers) = self.dhcp_overrides.dns_servers.clone() { - // lazy, appease borrow checker + if let Some(dns_servers) = self + .dhcp_overrides + .dns_servers + .as_deref() + .map(convert_dns_servers) + { self.set_dns_servers(&dns_servers); } } @@ -606,3 +488,11 @@ fn free_dhcp_config(config: dhcpv4::Config) -> dhcpv4::Config<'static> { packet: None, } } + +fn convert_dns_servers(dns_servers: &[Ipv4Address]) -> Vec { + dns_servers + .iter() + .copied() + .map(From::from) + .collect::>() +} diff --git a/crates/sel4-async/network/traits/Cargo.nix b/crates/sel4-async/network/traits/Cargo.nix deleted file mode 100644 index 207aea943..000000000 --- a/crates/sel4-async/network/traits/Cargo.nix +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2023, Colias Group, LLC -# -# SPDX-License-Identifier: BSD-2-Clause -# - -{ mk, versions }: - -mk { - package.name = "sel4-async-network-traits"; - dependencies = { - futures = { - version = versions.futures; - default-features = false; - features = [ - "alloc" - ]; - }; - }; -} diff --git a/crates/sel4-async/network/traits/src/lib.rs b/crates/sel4-async/network/traits/src/lib.rs deleted file mode 100644 index 6b0c89f57..000000000 --- a/crates/sel4-async/network/traits/src/lib.rs +++ /dev/null @@ -1,111 +0,0 @@ -// -// Copyright 2023, Colias Group, LLC -// -// SPDX-License-Identifier: BSD-2-Clause -// - -// TODO use Pin - -#![no_std] - -use core::pin::Pin; -use core::task::{Context, Poll}; - -use futures::future; - -pub trait AsyncIO { - type Error; - - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll>; - - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll>; - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; -} - -#[derive(Copy, Clone, Debug)] -pub enum ClosedError { - Other(E), - Closed, -} - -impl From for ClosedError { - fn from(err: E) -> Self { - Self::Other(err) - } -} - -pub trait AsyncIOExt: AsyncIO { - #[allow(async_fn_in_trait)] - async fn read(&mut self, buf: &mut [u8]) -> Result - where - Self: Unpin, - { - let mut pin = Pin::new(self); - future::poll_fn(move |cx| pin.as_mut().poll_read(cx, buf)).await - } - - #[allow(async_fn_in_trait)] - async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ClosedError> - where - Self: Unpin, - { - let mut pos = 0; - while pos < buf.len() { - let n = self.read(&mut buf[pos..]).await?; - if n == 0 { - return Err(ClosedError::Closed); - } - pos += n; - } - assert_eq!(pos, buf.len()); - Ok(()) - } - - #[allow(async_fn_in_trait)] - async fn write(&mut self, buf: &[u8]) -> Result - where - Self: Unpin, - { - let mut pin = Pin::new(self); - future::poll_fn(|cx| pin.as_mut().poll_write(cx, buf)).await - } - - #[allow(async_fn_in_trait)] - async fn write_all(&mut self, buf: &[u8]) -> Result<(), ClosedError> - where - Self: Unpin, - { - let mut pos = 0; - while pos < buf.len() { - let n = self.write(&buf[pos..]).await?; - if n == 0 { - return Err(ClosedError::Closed); - } - pos += n; - } - assert_eq!(pos, buf.len()); - Ok(()) - } - - #[allow(async_fn_in_trait)] - async fn flush(&mut self) -> Result<(), Self::Error> - where - Self: Unpin, - { - let mut pin = Pin::new(self); - future::poll_fn(|cx| pin.as_mut().poll_flush(cx)).await - } -} - -impl AsyncIOExt for T {}