diff --git a/Cargo.lock b/Cargo.lock
index f231125..e0ba60d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -785,9 +785,9 @@ dependencies = [

 [[package]]
 name = "futures-channel"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2"
+checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888"
 dependencies = [
  "futures-core",
  "futures-sink",
@@ -795,9 +795,9 @@ dependencies = [

 [[package]]
 name = "futures-core"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1"
+checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d"

 [[package]]
 name = "futures-executor"
@@ -812,15 +812,15 @@ dependencies = [

 [[package]]
 name = "futures-io"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1"
+checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377"

 [[package]]
 name = "futures-macro"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121"
+checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb"
 dependencies = [
  "autocfg",
  "proc-macro-hack",
@@ -831,21 +831,21 @@ dependencies = [

 [[package]]
 name = "futures-sink"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282"
+checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11"

 [[package]]
 name = "futures-task"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae"
+checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99"

 [[package]]
 name = "futures-util"
-version = "0.3.15"
+version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967"
+checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481"
 dependencies = [
  "autocfg",
  "futures-channel",
@@ -2721,6 +2721,8 @@ dependencies = [
  "docker_credential",
  "env-file-reader",
  "futures",
+ "futures-core",
+ "futures-util",
  "hyper",
  "indexmap",
  "oci-distribution",
diff --git a/Cargo.toml b/Cargo.toml
index 0163913..95e5520 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,7 +6,7 @@
 
 [dependencies]
 anyhow = "1.0"
-async-stream = "0.3"
+async-stream = "0.3.2"
 async-trait = "0.1"
 bindle = { version = "0.3", default-features = false, features = ["client", "server", "caching"] }
 cap-std = "^0.22"
@@ -36,3 +36,5 @@
 wasmtime-cache = "0.33"
 wat = "1.0.37"
 chrono = "0.4.19"
+futures-util = "0.3.17"
+futures-core = "0.3.17"
diff --git a/src/dispatcher.rs b/src/dispatcher.rs
index ed81615..f69eca1 100644
--- a/src/dispatcher.rs
+++ b/src/dispatcher.rs
@@ -58,7 +58,7 @@ impl RoutingTable {
                 let request_context = RequestContext {
                     client_addr,
                 };
-                let response = rte.handle_request(&parts, data, &request_context, &self.global_context);
+                let response = rte.handle_request(&parts, data, &request_context, &self.global_context).await;
                 Ok(response)
             },
             Err(_) => Ok(not_found()),
@@ -149,7 +149,7 @@ impl RoutingTableEntry {
     // TODO: I don't think this rightly belongs here. But
     // reasonable place to at least understand the decomposition and
     // dependencies.
-    pub fn handle_request(
+    pub async fn handle_request(
         &self,
         req: &Parts,
         body: Vec<u8>,
@@ -159,7 +159,7 @@ impl RoutingTableEntry {
         match &self.handler_info {
             RouteHandler::HealthCheck => Response::new(Body::from("OK")),
             RouteHandler::Wasm(w) => {
-                let response = w.handle_request(&self.route_pattern, req, body, request_context, global_context, self.unique_key());
+                let response = w.handle_request(&self.route_pattern, req, body, request_context, global_context, self.unique_key()).await;
                 match response {
                     Ok(res) => res,
                     Err(e) => {
@@ -352,7 +352,7 @@ fn append_one_dynamic_route(routing_table_entry: &RoutingTableEntry, wasm_route_
     }
 }
 
-fn build_wasi_context_for_dynamic_route_query(redirects: crate::wasm_module::IOStreamRedirects) -> wasi_common::WasiCtx {
+fn build_wasi_context_for_dynamic_route_query(redirects: crate::wasm_module::IOStreamRedirects<Vec<u8>>) -> wasi_common::WasiCtx {
     let builder = wasi_cap_std_sync::WasiCtxBuilder::new()
         .stderr(Box::new(redirects.stderr))
         .stdout(Box::new(redirects.stdout));
diff --git a/src/handlers.rs b/src/handlers.rs
index 9c9e174..1f4eb02 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -1,5 +1,4 @@
 use std::{collections::HashMap};
-use std::sync::{Arc, RwLock};
 use wasi_cap_std_sync::Dir;
 
 use hyper::{
@@ -16,8 +15,9 @@
 use crate::dispatcher::RoutePattern;
 use crate::http_util::{internal_error, parse_cgi_headers};
 use crate::request::{RequestContext, RequestGlobalContext};
+use crate::stream_writer::StreamWriter;
 use crate::wasm_module::WasmModuleSource;
-use crate::wasm_runner::{prepare_stdio_streams, prepare_wasm_instance, run_prepared_wasm_instance, WasmLinkOptions};
+use crate::wasm_runner::{prepare_stdio_streams_for_http, prepare_wasm_instance, run_prepared_wasm_instance, WasmLinkOptions};
 
 #[derive(Clone, Debug)]
 pub enum RouteHandler {
@@ -36,41 +36,79 @@ pub struct WasmRouteHandler {
 }
 
 impl WasmRouteHandler {
-    pub fn handle_request(
+    pub async fn handle_request(
         &self,
         matched_route: &RoutePattern,
         req: &Parts,
-        body: Vec<u8>,
+        request_body: Vec<u8>,
         request_context: &RequestContext,
         global_context: &RequestGlobalContext,
         logging_key: String,
     ) -> Result<Response<Body>, anyhow::Error> {
+
+        // These broken-out functions are slightly artificial but help solve some lifetime
+        // issues (where otherwise you get errors about things not being Send across an
+        // await).
+        let (stream_writer, instance, store) =
+            self.set_up_runtime_environment(matched_route, req, request_body, request_context, global_context, logging_key)?;
+        self.spawn_wasm_instance(instance, store, stream_writer.clone());
+
+        let response = match compose_response(stream_writer).await {
+            Ok(r) => r,
+            Err(e) => {
+                tracing::error!("Error parsing guest output into HTTP response: {}", e);
+                internal_error("internal error calling application")
+            }
+        };
+
+        tokio::task::yield_now().await;
+
+        Ok(response)
+    }
+
+    fn set_up_runtime_environment(&self, matched_route: &RoutePattern, req: &Parts, request_body: Vec<u8>, request_context: &RequestContext, global_context: &RequestGlobalContext, logging_key: String) -> anyhow::Result<(crate::stream_writer::StreamWriter, Instance, Store<WasiCtx>)> {
         let startup_span = tracing::info_span!("module instantiation").entered();
+
         let headers = crate::http_util::build_headers(
             matched_route,
             req,
-            body.len(),
+            request_body.len(),
             request_context.client_addr,
             global_context.default_host.as_str(),
             global_context.use_tls,
             &global_context.global_env_vars,
         );
 
-        let redirects = prepare_stdio_streams(body, global_context, logging_key)?;
-
+        let stream_writer = crate::stream_writer::StreamWriter::new();
+        let redirects = prepare_stdio_streams_for_http(request_body, stream_writer.clone(), global_context, logging_key)?;
         let ctx = self.build_wasi_context_for_request(req, headers, redirects.streams)?;
-
         let (store, instance) = self.prepare_wasm_instance(global_context, ctx)?;
-
-        // Drop manually to get instantiation time
+
         drop(startup_span);
+
+        Ok((stream_writer, instance, store))
+    }
 
-        run_prepared_wasm_instance(instance, store, &self.entrypoint, &self.wasm_module_name)?;
-
-        compose_response(redirects.stdout_mutex)
+    fn spawn_wasm_instance(&self, instance: Instance, store: Store<WasiCtx>, mut stream_writer: StreamWriter) {
+        let entrypoint = self.entrypoint.clone();
+        let wasm_module_name = self.wasm_module_name.clone();
+
+        // This is fire and forget, so there's a limited amount of error handling we
+        // can do.
+        tokio::spawn(async move {
+            match run_prepared_wasm_instance(instance, store, &entrypoint, &wasm_module_name) {
+                Ok(()) => (),
+                Err(e) => tracing::error!("Error running Wasm module: {}", e),
+            };
+            // TODO: should we attempt to write an error response to the StreamWriter here?
+            match stream_writer.done() {
+                Ok(()) => (),
+                Err(e) => tracing::error!("Error marking Wasm output as done: {}", e),
+            }
+        });
     }
 
-    fn build_wasi_context_for_request(&self, req: &Parts, headers: HashMap<String, String>, redirects: crate::wasm_module::IOStreamRedirects) -> Result<WasiCtx, Error> {
+    fn build_wasi_context_for_request(&self, req: &Parts, headers: HashMap<String, String>, redirects: crate::wasm_module::IOStreamRedirects<StreamWriter>) -> Result<WasiCtx, Error> {
         let uri_path = req.uri.path();
         let mut args = vec![uri_path.to_string()];
         req.uri
@@ -110,34 +148,12 @@ impl WasmRouteHandler {
     }
 }
 
-pub fn compose_response(stdout_mutex: Arc<RwLock<Vec<u8>>>) -> Result<Response<Body>, Error> {
-    // Okay, once we get here, all the information we need to send back in the response
-    // should be written to the STDOUT buffer. We fetch that, format it, and send
-    // it back. In the process, we might need to alter the status code of the result.
-    //
-    // This is a little janky, but basically we are looping through the output once,
-    // looking for the double-newline that distinguishes the headers from the body.
-    // The headers can then be parsed separately, while the body can be sent back
-    // to the client.
-
-    let out = stdout_mutex.read().unwrap();
-    let mut last = 0;
-    let mut scan_headers = true;
-    let mut buffer: Vec<u8> = Vec::new();
-    let mut out_headers: Vec<u8> = Vec::new();
-    out.iter().for_each(|i| {
-        if scan_headers && *i == 10 && last == 10 {
-            out_headers.append(&mut buffer);
-            buffer = Vec::new();
-            scan_headers = false;
-            return; // Consume the linefeed
-        }
-        last = *i;
-        buffer.push(*i)
-    });
-    let mut res = Response::new(Body::from(buffer));
+pub async fn compose_response(mut stream_writer: StreamWriter) -> anyhow::Result<Response<Body>> {
+    let header_block = stream_writer.header_block().await?;
+    let mut res = Response::new(Body::wrap_stream(stream_writer.as_stream()));
+
     let mut sufficient_response = false;
-    parse_cgi_headers(String::from_utf8(out_headers)?)
+    parse_cgi_headers(String::from_utf8(header_block)?)
         .iter()
         .for_each(|h| {
             use hyper::header::{CONTENT_TYPE, LOCATION};
diff --git a/src/lib.rs b/src/lib.rs
index 88353de..404ce8c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ pub(crate) mod handlers;
 mod http_util;
 pub (crate) mod module_loader;
 mod request;
+pub (crate) mod stream_writer;
 mod tls;
 pub mod version;
 pub mod wagi_app;
diff --git a/src/stream_writer.rs b/src/stream_writer.rs
new file mode 100644
index 0000000..d8fed8d
--- /dev/null
+++ b/src/stream_writer.rs
@@ -0,0 +1,239 @@
+use std::{io::Write, sync::{Arc, RwLock}};
+
+use async_stream::stream;
+
+#[derive(Clone)]
+pub struct StreamWriter {
+    pending: Arc<RwLock<Vec<u8>>>,
+    done: Arc<RwLock<bool>>,
+    // A way for the write side to signal new data to the stream side
+    write_index: Arc<RwLock<i64>>,
+    write_index_sender: Arc<tokio::sync::watch::Sender<i64>>,
+    write_index_receiver: tokio::sync::watch::Receiver<i64>,
+}
+
+impl StreamWriter {
+    pub fn new() -> Self {
+        let write_index = 0;
+        let (tx, rx) = tokio::sync::watch::channel(write_index);
+        Self {
+            pending: Arc::new(RwLock::new(vec![])),
+            done: Arc::new(RwLock::new(false)),
+            write_index: Arc::new(RwLock::new(write_index)),
+            write_index_sender: Arc::new(tx),
+            write_index_receiver: rx,
+        }
+    }
+
+    fn append(&mut self, buf: &[u8]) -> anyhow::Result<()> {
+        let result = match self.pending.write().as_mut() {
+            Ok(pending) => {
+                pending.extend_from_slice(buf);
+                Ok(())
+            },
+            Err(e) =>
+                Err(anyhow::anyhow!("Internal error: StreamWriter::append can't take lock: {}", e))
+        };
+        {
+            let mut write_index = self.write_index.write().unwrap();
+            *write_index = *write_index + 1;
+            self.write_index_sender.send(*write_index).unwrap();
+            drop(write_index);
+        }
+        result
+    }
+
+    pub fn done(&mut self) -> anyhow::Result<()> {
+        match self.done.write().as_deref_mut() {
+            Ok(d) => {
+                *d = true;
+                Ok(())
+            },
+            Err(e) =>
+                Err(anyhow::anyhow!("Internal error: StreamWriter::done can't take lock: {}", e))
+
+        }
+    }
+
+    pub async fn header_block(&mut self) -> anyhow::Result<Vec<u8>> {
+        loop {
+            match self.pending.write().as_deref_mut() {
+                Ok(pending) => match split_at_two_newlines(&pending) {
+                    None => (),
+                    Some((header_block, rest)) => {
+                        *pending = rest;
+                        return Ok(header_block);
+                    }
+                },
+                Err(e) => {
+                    return Err(anyhow::anyhow!("Internal error: StreamWriter::header_block can't take lock: {}", e));
+                },
+            }
+            // See comments on the as_stream loop, though using the change signal
+            // blocked this *completely* until end of writing! (And everything else
+            // waits on this.)
+            tokio::time::sleep(tokio::time::Duration::from_micros(1)).await;
+        }
+    }
+
+    pub fn as_stream(mut self) -> impl futures_core::stream::Stream<Item = anyhow::Result<Vec<u8>>> {
+        stream! {
+            loop {
+                let data = self.pop();
+                match data {
+                    Ok(v) => {
+                        if v.is_empty() {
+                            if self.is_done() {
+                                return;
+                            } else {
+                                // This tiny wait seems to help the write-stream pipeline to flow more smmoothly.
+                                // If we go straight to the 'changed().await' then the pipeline seems to stall after
+                                // a few dozen writes, and everything else gets held up until the entire output
+                                // has been written. There may be better ways of doing this; I haven't found them
+                                // yet.
+                                //
+                                // (By the way, having the timer but not the change notification also worked. But if
+                                // writes came slowly, that would result in very aggressive polling. So hopefully this
+                                // gives us the best of both worlds.)
+                                tokio::time::sleep(tokio::time::Duration::from_nanos(10)).await;
+
+                                match self.write_index_receiver.changed().await {
+                                    Ok(_) => continue,
+                                    Err(e) => {
+                                        // If this ever happens (which it, cough, shouldn't), it means all senders have
+                                        // closed, which _should_ mean we are done. Log the error
+                                        // but don't return it to the stream: the response as streamed so far
+                                        // _should_ be okay!
+                                        tracing::error!("StreamWriter::as_stream: error receiving write updates: {}", e);
+                                        return;
+                                    }
+                                }
+                            }
+                        } else {
+                            // This tiny wait seems to help the write-stream pipeline to flow more smmoothly.
+                            // See the comment on the 'empty buffer' case.
+                            tokio::time::sleep(tokio::time::Duration::from_nanos(10)).await;
+                            yield Ok(v);
+                        }
+                    },
+                    Err(e) => {
+                        if self.is_done() {
+                            return;
+                        } else {
+                            yield Err(e);
+                            return;
+                        }
+                    },
+                }
+            }
+        }
+    }
+
+    fn is_done(&self) -> bool {
+        match self.done.read() {
+            Ok(d) => *d,
+            Err(_) => false,
+        }
+    }
+
+    fn pop(&mut self) -> anyhow::Result<Vec<u8>> {
+        let data = match self.pending.write().as_mut() {
+            Ok(pending) => {
+                let res = pending.clone();
+                pending.clear();
+                Ok(res)
+            },
+            Err(e) => {
+                Err(anyhow::anyhow!("Internal error: StreamWriter::pop can't take lock: {}", e))
+            }
+        };
+        data
+    }
+}
+
+impl Write for StreamWriter {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        self.append(buf).map_err(
+            |e| std::io::Error::new(std::io::ErrorKind::Other, e)
+        )?;
+
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        Ok(())
+    }
+}
+
+fn split_at_two_newlines(source: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
+    let mut buffer = vec![];
+    let mut last: u8 = 0;
+    for value in source {
+        if *value == 10 && last == 10 {
+            let rest_slice = &source[(buffer.len() + 1)..];
+            let rest = Vec::from(rest_slice);
+            return Some((buffer, rest));
+        } else {
+            buffer.push(*value);
+            last = *value;
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod test {
+    use futures::StreamExt;
+
+    use super::*;
+
+    #[test]
+    fn splits_at_two_newlines_if_pair_only() {
+        let source: Vec<u8> = vec![0x41, 0x42, 0x0a, 0x0a, 0x43, 0x44];
+        let result = split_at_two_newlines(&source).expect("did not split at all");
+        assert_eq!(vec![0x41, 0x42, 0x0a], result.0);
+        assert_eq!(vec![0x43, 0x44], result.1);
+    }
+
+    #[test]
+    fn doesnt_splits_at_two_newlines_if_no_pair() {
+        let source: Vec<u8> = vec![0x41, 0x42, 0x0a, 0x43, 0x44, 0x0a, 0x45, 0x46];
+        let result = split_at_two_newlines(&source);
+        assert_eq!(None, result);
+    }
+
+    #[test]
+    fn splits_at_two_newlines_empty_rest_if_at_end() {
+        let source: Vec<u8> = vec![0x41, 0x42, 0x0a, 0x43, 0x44, 0x0a, 0x0a];
+        let result = split_at_two_newlines(&source).expect("did not split at all");
+        assert_eq!(vec![0x41, 0x42, 0x0a, 0x43, 0x44, 0x0a], result.0);
+        assert!(result.1.is_empty());
+    }
+
+    #[tokio::test]
+    async fn streaming_splits_out_headers() {
+        let mut sw = StreamWriter::new();
+        let mut sw2 = sw.clone();
+        tokio::spawn(async move {
+            write!(sw2, "Header 1\n").unwrap();
+            write!(sw2, "Header 2\n").unwrap();
+            write!(sw2, "\n").unwrap();
+            write!(sw2, "Body 1\n").unwrap();
+            write!(sw2, "Body 2\n").unwrap();
+            sw2.done().unwrap();
+        });
+        let header = sw.header_block().await.unwrap();
+        let header_text = String::from_utf8(header).unwrap();
+        assert!(header_text.contains("Header 1\n"));
+        assert!(header_text.contains("Header 2\n"));
+
+        let mut stm = Box::pin(sw.as_stream());
+        let mut body = vec![];
+        while let Some(Ok(v)) = stm.next().await {
+            body.extend_from_slice(&v);
+        }
+        let body_text = String::from_utf8(body).unwrap();
+        assert!(body_text.contains("Body 1\n"));
+        assert!(body_text.contains("Body 2\n"));
+    }
+}
diff --git a/src/wasm_module.rs b/src/wasm_module.rs
index 7bad77c..ed76759 100644
--- a/src/wasm_module.rs
+++ b/src/wasm_module.rs
@@ -1,4 +1,4 @@
-use std::{fmt::Debug, sync::{Arc, RwLock}};
+use std::{fmt::Debug, io::Write, sync::{Arc, RwLock}};
 
 use wasi_common::pipe::{ReadPipe, WritePipe};
 use wasmtime::*;
@@ -31,13 +31,13 @@ impl Debug for WasmModuleSource {
 // constraints from the stdout_mutex. Not sure how to do this better.
 // (I don't want to .clone() the fields even though that would work,
 // because that is misleading about the semantics.)
-pub struct IOStreamRedirects {
+pub struct IOStreamRedirects<W: Write> {
     pub stdin: ReadPipe<std::io::Cursor<Vec<u8>>>,
-    pub stdout: WritePipe<Vec<u8>>,
+    pub stdout: WritePipe<W>,
     pub stderr: wasi_cap_std_sync::file::File,
 }
 
-pub struct IORedirectionInfo {
-    pub streams: IOStreamRedirects,
-    pub stdout_mutex: Arc<RwLock<Vec<u8>>>,
+pub struct IORedirectionInfo<W: Write> {
+    pub streams: IOStreamRedirects<W>,
+    pub stdout_mutex: Arc<RwLock<W>>,
 }
diff --git a/src/wasm_runner.rs b/src/wasm_runner.rs
index 3b0152a..b8bc939 100644
--- a/src/wasm_runner.rs
+++ b/src/wasm_runner.rs
@@ -47,7 +47,7 @@ pub fn prepare_stdio_streams(
     body: Vec<u8>,
     global_context: &RequestGlobalContext,
     handler_id: String,
-) -> Result<IORedirectionInfo, Error> {
+) -> Result<crate::wasm_module::IORedirectionInfo<Vec<u8>>, Error> {
     let stdin = ReadPipe::from(body);
     let stdout_buf: Vec<u8> = vec![];
     let stdout_mutex = Arc::new(RwLock::new(stdout_buf));
@@ -78,6 +78,40 @@ pub fn prepare_stdio_streams(
     })
 }
 
+pub fn prepare_stdio_streams_for_http(
+    body: Vec<u8>,
+    stream_writer: crate::stream_writer::StreamWriter,
+    global_context: &RequestGlobalContext,
+    handler_id: String,
+) -> Result<crate::wasm_module::IORedirectionInfo<crate::stream_writer::StreamWriter>, Error> {
+    let stdin = ReadPipe::from(body);
+    let stdout_mutex = Arc::new(RwLock::new(stream_writer));
+    let stdout = WritePipe::from_shared(stdout_mutex.clone());
+    let log_dir = global_context.base_log_dir.join(handler_id);
+
+    // The spec does not say what to do with STDERR.
+    // See specifically sections 4.2 and 6.1 of RFC 3875.
+    // Currently, we will attach to wherever logs go.
+    tracing::info!(log_dir = %log_dir.display(), "Using log dir");
+    std::fs::create_dir_all(&log_dir)?;
+    let stderr = cap_std::fs::File::from_std(
+        std::fs::OpenOptions::new()
+            .append(true)
+            .create(true)
+            .open(log_dir.join(STDERR_FILE))?,
+    );
+    let stderr = wasi_cap_std_sync::file::File::from_cap_std(stderr);
+
+    Ok(crate::wasm_module::IORedirectionInfo {
+        streams: crate::wasm_module::IOStreamRedirects {
+            stdin,
+            stdout,
+            stderr,
+        },
+        stdout_mutex,
+    })
+}
+
 pub fn new_store_and_engine(
     cache_config_path: &Path,
     ctx: WasiCtx,