diff --git a/Cargo.lock b/Cargo.lock
index 4ca6ed82..08937df4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -505,9 +505,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
 
 [[package]]
 name = "bytesize"
-version = "2.3.1"
+version = "2.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bd91ee7b2422bcb158d90ef4d14f75ef67f340943fc4149891dcce8f8b972a3"
+checksum = "f5c434ae3cf0089ca203e9019ebe529c47ff45cefe8af7c85ecb734ef541822f"
 dependencies = [
  "serde_core",
 ]
@@ -2241,6 +2241,7 @@ dependencies = [
  "anyhow",
  "argh",
  "async-stream",
+ "async-trait",
  "axum",
  "axum-extra",
  "bytes",
@@ -2262,6 +2263,7 @@
  "objectstore-service",
  "objectstore-test",
  "objectstore-types",
+ "pin-project",
  "rand 0.9.2",
  "reqwest",
  "rustls",
diff --git a/objectstore-server/Cargo.toml b/objectstore-server/Cargo.toml
index 65243f24..bfe22232 100644
--- a/objectstore-server/Cargo.toml
+++ b/objectstore-server/Cargo.toml
@@ -14,6 +14,7 @@ publish = false
 anyhow = { workspace = true }
 argh = "0.1.13"
 async-stream = "0.3.6"
+async-trait = { workspace = true }
 axum = { version = "0.8.4", features = ["multipart"] }
 axum-extra = { version = "0.12.2", features = ["multipart"] }
 bytes = { workspace = true }
@@ -33,6 +34,7 @@ multer = "3.1.0"
 num_cpus = "1.17.0"
 objectstore-service = { workspace = true }
 objectstore-types = { workspace = true }
+pin-project = "1.1.10"
 rand = { workspace = true }
 reqwest = { workspace = true }
 rustls = { version = "0.23.31", default-features = false }
diff --git a/objectstore-server/src/endpoints/batch.rs b/objectstore-server/src/endpoints/batch.rs
index 47448cd0..bcf55511 100644
--- a/objectstore-server/src/endpoints/batch.rs
+++ b/objectstore-server/src/endpoints/batch.rs
@@ -1,12 +1,23 @@
+use std::pin::Pin;
+
+use async_trait::async_trait;
 use axum::Router;
-use axum::http::StatusCode;
+use axum::body::Body;
 use axum::response::{IntoResponse, Response};
 use axum::routing;
-use objectstore_service::id::ObjectContext;
+use bytes::BytesMut;
+use futures::stream::BoxStream;
+use futures::{Stream, StreamExt, TryStreamExt};
+use http::header::CONTENT_TYPE;
+use http::{HeaderMap, HeaderValue};
+use objectstore_service::id::{ObjectContext, ObjectId};
+use objectstore_service::{DeleteResult, GetResult, InsertResult};
 
 use crate::auth::AuthAwareService;
 use crate::endpoints::common::ApiResult;
+use crate::extractors::Operation;
 use crate::extractors::{BatchRequest, Xt};
+use crate::multipart::{IntoBytesStream, Part};
 use crate::state::ServiceState;
 
 pub fn router() -> Router<ServiceState> {
@@ -14,9 +25,333 @@
 }
 
 async fn batch(
-    _service: AuthAwareService,
-    Xt(_context): Xt<ObjectContext>,
-    _request: BatchRequest,
+    service: AuthAwareService,
+    Xt(context): Xt<ObjectContext>,
+    mut request: BatchRequest,
 ) -> ApiResult<Response> {
-    Ok(StatusCode::NOT_IMPLEMENTED.into_response())
+    let r = rand::random::<u128>();
+    let boundary = format!("os-boundary-{r:032x}");
+    let mut headers = HeaderMap::new();
+    headers.insert(
+        CONTENT_TYPE,
+        HeaderValue::from_str(&format!("multipart/mixed; boundary={boundary}")).unwrap(),
+    );
+
+    let parts: BoxStream<'static, Result<Part, anyhow::Error>> = async_stream::try_stream! {
+        while let Some(operation) = request.operations.next().await {
+            let res = match operation {
+                Ok(operation) => match operation {
+                    Operation::Get(get) => {
+                        let res = service
+                            .get_object(&ObjectId::new(context.clone(), get.key))
+                            .await;
+                        res.into_part().await
+                    }
+                    Operation::Insert(insert) => {
+                        let stream = futures_util::stream::once(async { Ok(insert.payload) }).boxed();
+                        let res = service
+                            .insert_object(context.clone(), insert.key, &insert.metadata, stream)
+                            .await;
+                        res.into_part().await
+                    }
+                    Operation::Delete(delete) => {
+                        let res = service
+                            .delete_object(&ObjectId::new(context.clone(), delete.key))
+                            .await;
+                        res.into_part().await
+                    }
+                },
+                Err(_) => todo!()
+            };
+            yield res;
+        }
+    }.boxed();
+
+    Ok((
+        headers,
+        Body::from_stream(parts.into_bytes_stream(boundary)),
+    )
+        .into_response())
+}
+
+const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";
+const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
+
+#[async_trait]
+pub trait IntoPart {
+    async fn into_part(mut self) -> Part;
+}
+
+#[async_trait]
+impl IntoPart for GetResult {
+    async fn into_part(mut self) -> Part {
+        match self {
+            Ok(Some((metadata, payload))) => {
+                let payload = payload
+                    .try_fold(BytesMut::new(), |mut acc, chunk| async move {
+                        acc.extend_from_slice(&chunk);
+                        Ok(acc)
+                    })
+                    .await
+                    .unwrap()
+                    .freeze();
+
+                let mut headers = metadata.to_headers("", false).unwrap();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("200"),
+                );
+
+                Part::new(headers, payload)
+            }
+            Ok(None) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("404"),
+                );
+                Part::headers_only(headers)
+            }
+            Err(_) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("500"),
+                );
+                Part::headers_only(headers)
+            }
+        }
+    }
+}
+
+#[async_trait]
+impl IntoPart for InsertResult {
+    async fn into_part(mut self) -> Part {
+        match self {
+            Ok(id) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_KEY,
+                    HeaderValue::from_str(id.key()).unwrap(),
+                );
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("200"),
+                );
+                Part::headers_only(headers)
+            }
+            Err(_) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("500"),
+                );
+                Part::headers_only(headers)
+            }
+        }
+    }
+}
+
+#[async_trait]
+impl IntoPart for DeleteResult {
+    async fn into_part(mut self) -> Part {
+        match self {
+            Ok(()) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("200"),
+                );
+                Part::headers_only(headers)
+            }
+            Err(_) => {
+                let mut headers = HeaderMap::new();
+                headers.insert(
+                    HEADER_BATCH_OPERATION_STATUS,
+                    HeaderValue::from_static("500"),
+                );
+                Part::headers_only(headers)
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::auth::PublicKeyDirectory;
+
+    use super::*;
+    use axum::body::Body;
+    use axum::http::{Request, StatusCode};
+    use bytes::Bytes;
+    use objectstore_service::StorageConfig;
+    use std::collections::BTreeMap;
+    use std::sync::Arc;
+    use tower::ServiceExt;
+
+    /// Tests the batch endpoint end-to-end with insert, get, and delete operations
+    #[tokio::test]
+    async fn test_batch_endpoint_basic() {
+        // Set up temporary filesystem storage
+        let tempdir = tempfile::tempdir().unwrap();
+        let config = StorageConfig::FileSystem {
+            path: tempdir.path(),
+        };
+        let storage_service = objectstore_service::StorageService::new(config.clone(), config)
+            .await
+            .unwrap();
+
+        // Create application state
+        let state = Arc::new(crate::state::Services {
+            config: crate::config::Config::default(),
+            service: storage_service,
+            key_directory: PublicKeyDirectory {
+                keys: BTreeMap::new(),
+            },
+        });
+
+        // Build the router with state
+        let app = router().with_state(state);
+
+        // Create a batch request with insert, get, delete, and get non-existing key
+        let insert_data = b"test data";
+        let request_body = format!(
+            "--boundary\r\n\
+            {HEADER_BATCH_OPERATION_KEY}: testkey\r\n\
+            x-sn-batch-operation-kind: insert\r\n\
+            Content-Type: application/octet-stream\r\n\
+            \r\n\
+            {data}\r\n\
+            --boundary\r\n\
+            {HEADER_BATCH_OPERATION_KEY}: testkey\r\n\
+            x-sn-batch-operation-kind: get\r\n\
+            \r\n\
+            \r\n\
+            --boundary\r\n\
+            {HEADER_BATCH_OPERATION_KEY}: testkey\r\n\
+            x-sn-batch-operation-kind: delete\r\n\
+            \r\n\
+            \r\n\
+            --boundary\r\n\
+            {HEADER_BATCH_OPERATION_KEY}: nonexistent\r\n\
+            x-sn-batch-operation-kind: get\r\n\
+            \r\n\
+            \r\n\
+            --boundary--\r\n",
+            data = String::from_utf8_lossy(insert_data),
+        );
+
+        let request = Request::builder()
+            .uri("/objects:batch/testing/scope=value/")
+            .method("POST")
+            .header("Content-Type", "multipart/mixed; boundary=boundary")
+            .body(Body::from(request_body))
+            .unwrap();
+
+        // Call the endpoint
+        let response = app.oneshot(request).await.unwrap();
+
+        // Verify response status
+        let status = response.status();
+        if status != StatusCode::OK {
+            let body = response.into_body();
+            let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+            let error_msg = String::from_utf8_lossy(&body_bytes);
+            panic!("Expected 200 OK, got {}: {}", status, error_msg);
+        }
+
+        // Get the content type and extract boundary
+        let content_type = response
+            .headers()
+            .get("content-type")
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert!(content_type.starts_with("multipart/mixed"));
+
+        // Swap content type from multipart/mixed to multipart/form-data for multer
+        let content_type = content_type.replace("multipart/mixed", "multipart/form-data");
+        let boundary = multer::parse_boundary(&content_type).unwrap();
+
+        // Parse the multipart response using multer
+        let body = response.into_body();
+        let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+
+        // Create a stream for multer
+        use futures::stream;
+        let chunks: Vec<_> = body_bytes
+            .chunks(64)
+            .map(|chunk| Ok::<_, multer::Error>(Bytes::copy_from_slice(chunk)))
+            .collect();
+        let body_stream = stream::iter(chunks);
+        let mut multipart = multer::Multipart::new(body_stream, boundary);
+
+        // Collect all parts
+        let mut parts = vec![];
+        loop {
+            match multipart.next_field().await {
+                Ok(Some(field)) => {
+                    let headers = field.headers().clone();
+                    match field.bytes().await {
+                        Ok(data) => parts.push((headers, data)),
+                        Err(e) => panic!("Failed to read field bytes: {:?}", e),
+                    }
+                }
+                Ok(None) => break,
+                Err(e) => panic!("Failed to get next field: {:?}", e),
+            }
+        }
+
+        // Should have exactly 4 parts
+        assert_eq!(parts.len(), 4);
+
+        // First part: insert response
+        let (insert_headers, insert_body) = &parts[0];
+        let insert_status = insert_headers
+            .get(HEADER_BATCH_OPERATION_STATUS)
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert_eq!(insert_status, "200");
+
+        let insert_key = insert_headers
+            .get(HEADER_BATCH_OPERATION_KEY)
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert_eq!(insert_key, "testkey");
+        assert!(insert_body.is_empty()); // Insert response has no body
+
+        // Second part: get response
+        let (get_headers, get_body) = &parts[1];
+        let get_status = get_headers
+            .get(HEADER_BATCH_OPERATION_STATUS)
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert_eq!(get_status, "200");
+
+        // Verify the retrieved data matches what we inserted
+        assert_eq!(get_body.as_ref(), insert_data);
+
+        // Third part: delete response
+        let (delete_headers, delete_body) = &parts[2];
+        let delete_status = delete_headers
+            .get(HEADER_BATCH_OPERATION_STATUS)
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert_eq!(delete_status, "200");
+        assert!(delete_body.is_empty()); // Delete response has no body
+
+        // Fourth part: get non-existing key (should be 404)
+        let (not_found_headers, not_found_body) = &parts[3];
+        let not_found_status = not_found_headers
+            .get(HEADER_BATCH_OPERATION_STATUS)
+            .unwrap()
+            .to_str()
+            .unwrap();
+        assert_eq!(not_found_status, "404");
+        assert!(not_found_body.is_empty()); // Not found response has no body
+    }
 }
diff --git a/objectstore-server/src/lib.rs b/objectstore-server/src/lib.rs
index 541d9a3d..a9fc5f27 100644
--- a/objectstore-server/src/lib.rs
+++ b/objectstore-server/src/lib.rs
@@ -11,6 +11,7 @@ pub mod endpoints;
 pub mod extractors;
 pub mod healthcheck;
 pub mod killswitches;
+pub mod multipart;
 pub mod observability;
 pub mod state;
 pub mod web;
diff --git a/objectstore-server/src/multipart.rs b/objectstore-server/src/multipart.rs
new file mode 100644
index 00000000..5944d92f
--- /dev/null
+++ b/objectstore-server/src/multipart.rs
@@ -0,0 +1,134 @@
+//! Utilities to represent and serialize multipart parts.
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::{BufMut, Bytes, BytesMut};
+use futures::Stream;
+use http::HeaderMap;
+use pin_project::pin_project;
+
+/// A Multipart part.
+#[derive(Debug)]
+pub struct Part {
+    headers: HeaderMap,
+    body: Bytes,
+}
+
+impl Part {
+    /// Creates a new Multipart part with headers and body.
+    pub fn new(headers: HeaderMap, body: Bytes) -> Self {
+        Part { headers, body }
+    }
+
+    /// Creates a new Multipart part with headers only.
+    pub fn headers_only(headers: HeaderMap) -> Self {
+        Part {
+            headers,
+            body: Bytes::new(),
+        }
+    }
+}
+
+pub trait IntoBytesStream {
+    fn into_bytes_stream(self, boundary: String) -> impl Stream<Item = Result<Bytes, anyhow::Error>>;
+}
+
+impl<S> IntoBytesStream for S
+where
+    S: Stream<Item = Result<Part, anyhow::Error>> + Send,
+{
+    fn into_bytes_stream(self, boundary: String) -> impl Stream<Item = Result<Bytes, anyhow::Error>> {
+        let mut b = BytesMut::with_capacity(boundary.len() + 4);
+        b.put(&b"--"[..]);
+        b.put(boundary.as_bytes());
+        b.put(&b"\r\n"[..]);
+        PartsSerializer {
+            parts: self,
+            boundary: b.freeze(),
+            state: State::Waiting,
+        }
+    }
+}
+
+#[pin_project]
+struct PartsSerializer<S>
+where
+    S: Stream<Item = Result<Part, anyhow::Error>>,
+{
+    #[pin]
+    parts: S,
+    boundary: Bytes,
+    state: State,
+}
+
+enum State {
+    Waiting,
+    SendHeaders(Part),
+    SendBody(Bytes),
+    SendClosingBoundary,
+    Done,
+}
+
+impl<S> Stream for PartsSerializer<S>
+where
+    S: Stream<Item = Result<Part, anyhow::Error>>,
+{
+    type Item = Result<Bytes, anyhow::Error>;
+
+    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut this = self.project();
+        match std::mem::replace(this.state, State::Waiting) {
+            State::Waiting => match this.parts.as_mut().poll_next(ctx) {
+                Poll::Pending => Poll::Pending,
+                Poll::Ready(None) => {
+                    *this.state = State::SendClosingBoundary;
+                    ctx.waker().wake_by_ref();
+                    Poll::Pending
+                }
+                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+                Poll::Ready(Some(Ok(p))) => {
+                    *this.state = State::SendHeaders(p);
+                    Poll::Ready(Some(Ok(this.boundary.clone())))
+                }
+            },
+            State::SendHeaders(part) => {
+                *this.state = State::SendBody(part.body);
+                let headers = serialize_headers(part.headers);
+                Poll::Ready(Some(Ok(headers)))
+            }
+            State::SendBody(body) => {
+                // Add \r\n after the body
+                let mut body_with_newline = BytesMut::with_capacity(body.len() + 2);
+                body_with_newline.put(body);
+                body_with_newline.put(&b"\r\n"[..]);
+                Poll::Ready(Some(Ok(body_with_newline.freeze())))
+            }
+            State::SendClosingBoundary => {
+                *this.state = State::Done;
+                // Create closing boundary: --boundary-- (without \r\n in between)
+                // The boundary already has --boundary\r\n, so we need to strip the \r\n
+                // and add --\r\n instead
+                let boundary_str = std::str::from_utf8(this.boundary).unwrap();
+                let boundary_without_crlf = boundary_str.trim_end_matches("\r\n");
+                let mut closing = BytesMut::with_capacity(boundary_without_crlf.len() + 4);
+                closing.put(boundary_without_crlf.as_bytes());
+                closing.put(&b"--\r\n"[..]);
+                Poll::Ready(Some(Ok(closing.freeze())))
+            }
+            State::Done => Poll::Ready(None),
+        }
+    }
+}
+
+fn serialize_headers(headers: HeaderMap) -> Bytes {
+    let mut b = BytesMut::with_capacity(30 + 30 * headers.len());
+    for (name, value) in &headers {
+        b.put(name.as_str().as_bytes());
+        b.put(&b": "[..]);
+        b.put(value.as_bytes());
+        b.put(&b"\r\n"[..]);
+    }
+    b.put(&b"\r\n"[..]);
+    b.freeze()
+}
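
Note (not part of the patch): a minimal sketch of how the new Part / IntoBytesStream pieces can be driven outside the endpoint, useful when reviewing the PartsSerializer state machine. It assumes the objectstore_server::multipart module path from this diff and an anyhow::Error item error for the part stream; the exact error type is not pinned down by the diff, so treat this as an illustration rather than the endpoint's implementation.

use bytes::Bytes;
use futures::{stream, TryStreamExt};
use http::{HeaderMap, HeaderValue};
use objectstore_server::multipart::{IntoBytesStream, Part};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Build two parts by hand: one with headers and a body, one headers-only,
    // mirroring what the batch endpoint emits per operation.
    let mut headers = HeaderMap::new();
    headers.insert("x-sn-batch-operation-status", HeaderValue::from_static("200"));
    let parts = stream::iter(vec![
        // anyhow::Error as the stream error type is an assumption (see note above).
        Ok::<_, anyhow::Error>(Part::new(headers, Bytes::from_static(b"payload"))),
        Ok(Part::headers_only(HeaderMap::new())),
    ]);

    // Serialize into the multipart/mixed wire format: "--example\r\n" followed by
    // the header lines, a blank line, the body, and "\r\n" for each part, then a
    // final "--example--\r\n" terminator.
    let chunks: Vec<Bytes> = parts
        .into_bytes_stream("example".to_string())
        .try_collect()
        .await?;

    let mut body = Vec::new();
    for chunk in &chunks {
        body.extend_from_slice(chunk);
    }
    println!("{}", String::from_utf8_lossy(&body));
    Ok(())
}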