diff --git a/ahnlich/Cargo.toml b/ahnlich/Cargo.toml
index 9c47429a..ff17122c 100644
--- a/ahnlich/Cargo.toml
+++ b/ahnlich/Cargo.toml
@@ -41,3 +41,4 @@ deadpool = { version = "0.10", features = ["rt_tokio_1"]}
 opentelemetry = { version = "0.23.0", features = ["trace"] }
 tracing-opentelemetry = "0.24.0"
 log = "0.4"
+fallible_collections = "0.4.9"
diff --git a/ahnlich/ai/Cargo.toml b/ahnlich/ai/Cargo.toml
index 0bb1a884..5809cfe8 100644
--- a/ahnlich/ai/Cargo.toml
+++ b/ahnlich/ai/Cargo.toml
@@ -35,6 +35,7 @@ serde_json.workspace = true
 termcolor = "1.4.1"
 strum = { version = "0.26", features = ["derive"] }
 log.workspace = true
+fallible_collections.workspace = true
 
 [dev-dependencies]
 db = { path = "../db", version = "*" }
diff --git a/ahnlich/ai/src/engine/store.rs b/ahnlich/ai/src/engine/store.rs
index 9ce1e66b..0fd7ae33 100644
--- a/ahnlich/ai/src/engine/store.rs
+++ b/ahnlich/ai/src/engine/store.rs
@@ -9,6 +9,7 @@ use ahnlich_types::keyval::StoreKey;
 use ahnlich_types::keyval::StoreName;
 use ahnlich_types::keyval::StoreValue;
 use ahnlich_types::metadata::MetadataValue;
+use fallible_collections::FallibleVec;
 use flurry::HashMap as ConcurrentHashMap;
 use serde::Deserialize;
 use serde::Serialize;
@@ -162,7 +163,7 @@ impl AIStoreHandler {
     ) -> Result {
         let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
         let store = self.get(store_name)?;
-        let mut output = Vec::with_capacity(inputs.len());
+        let mut output: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
         let mut delete_hashset = StdHashSet::new();
 
         for (store_input, mut store_value) in inputs {
diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs
index e81f29a6..59ec6f1a 100644
--- a/ahnlich/ai/src/error.rs
+++ b/ahnlich/ai/src/error.rs
@@ -2,6 +2,7 @@ use ahnlich_types::{
     ai::{AIStoreInputType, PreprocessAction},
     keyval::StoreName,
 };
+use fallible_collections::TryReserveError;
 use thiserror::Error;
 use tokio::sync::oneshot::error::RecvError;
 
@@ -67,4 +68,12 @@ pub enum AIProxyError {
         index_model_dim: usize,
         query_model_dim: usize,
     },
+    #[error("allocation error {0:?}")]
+    Allocation(TryReserveError),
+}
+
+impl From<TryReserveError> for AIProxyError {
+    fn from(input: TryReserveError) -> Self {
+        Self::Allocation(input)
+    }
 }
diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs
index ba403ae4..93f3c166 100644
--- a/ahnlich/ai/src/manager/mod.rs
+++ b/ahnlich/ai/src/manager/mod.rs
@@ -6,6 +6,7 @@ use crate::engine::ai::models::Model;
 use crate::error::AIProxyError;
 use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction};
 use ahnlich_types::keyval::{StoreInput, StoreKey};
+use fallible_collections::FallibleVec;
 use std::collections::HashMap;
 use task_manager::Task;
 use task_manager::TaskManager;
@@ -46,7 +47,7 @@ impl ModelThread {
         process_action: PreprocessAction,
     ) -> ModelThreadResponse {
         let model: Model = (&self.model).into();
-        let mut response = Vec::with_capacity(inputs.len());
+        let mut response: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
         // move this from for loop into vec of inputs
         for input in inputs {
             let processed_input = self.preprocess_store_input(process_action, input)?;
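The pattern above repeats throughout the PR: swap `Vec::with_capacity` for `FallibleVec::try_with_capacity` and let `?` convert the resulting `TryReserveError` into the crate's own error enum. A minimal, self-contained sketch of that shape, assuming `fallible_collections` 0.4 as added to the workspace; `ProxyError` and `batch_outputs` are illustrative stand-ins rather than types from this repository:

```rust
use fallible_collections::{FallibleVec, TryReserveError};

#[derive(Debug)]
enum ProxyError {
    // mirrors AIProxyError::Allocation above
    Allocation(TryReserveError),
}

impl From<TryReserveError> for ProxyError {
    fn from(input: TryReserveError) -> Self {
        Self::Allocation(input)
    }
}

fn batch_outputs(inputs: &[u32]) -> Result<Vec<u32>, ProxyError> {
    // `?` turns a failed reservation into ProxyError::Allocation instead of aborting
    let mut output: Vec<u32> = FallibleVec::try_with_capacity(inputs.len())?;
    for input in inputs {
        output.push(input * 2); // capacity is already reserved, so no reallocation here
    }
    Ok(output)
}

fn main() {
    match batch_outputs(&[1, 2, 3]) {
        Ok(out) => println!("{out:?}"),
        Err(err) => eprintln!("{err:?}"),
    }
}
```

Because the constructors and builders now return `Result`, an out-of-memory condition surfaces as an ordinary error response instead of a process abort.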
diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs
index 5c6a79a5..8efed3f8 100644
--- a/ahnlich/ai/src/server/task.rs
+++ b/ahnlich/ai/src/server/task.rs
@@ -7,6 +7,7 @@ use ahnlich_types::keyval::{StoreInput, StoreValue};
 use ahnlich_types::metadata::MetadataValue;
 use ahnlich_types::predicate::{Predicate, PredicateCondition};
 use ahnlich_types::version::VERSION;
+use fallible_collections::vec::TryFromIterator;
 use std::collections::HashSet;
 use std::net::SocketAddr;
 use std::sync::Arc;
@@ -334,22 +335,24 @@ impl AhnlichProtocol for AIProxyTask {
                     Ok(res) => {
                         if let ServerResponse::GetSimN(response) = res {
                             // conversion to store input here
-                            let mut output = Vec::new();
-
-                            // TODO: Can Parallelize
-                            for (store_key, store_value, sim) in response.into_iter() {
-                                let temp =
-                                    self.store_handler.store_key_val_to_store_input_val(
-                                        vec![(store_key, store_value)],
-                                    );
-
-                                if let Some(valid_result) = temp.first().take() {
-                                    let valid_result = valid_result.to_owned();
-                                    output.push((valid_result.0, valid_result.1, sim))
-                                }
+                            match TryFromIterator::try_from_iterator(
+                                response.into_iter().flat_map(
+                                    |(store_key, store_value, sim)| {
+                                        self.store_handler
+                                            .store_key_val_to_store_input_val(vec![(
+                                                store_key,
+                                                store_value,
+                                            )])
+                                            .into_iter()
+                                            .map(move |v| (v.0, v.1, sim))
+                                    },
+                                ),
+                            )
+                            .map_err(|e| AIProxyError::from(e).to_string())
+                            {
+                                Ok(output) => Ok(AIServerResponse::GetSimN(output)),
+                                Err(err) => Err(err),
                             }
-
-                            Ok(AIServerResponse::GetSimN(output))
                         } else {
                             Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
                                 .to_string())
diff --git a/ahnlich/client/Cargo.toml b/ahnlich/client/Cargo.toml
index befc722c..c2ea27a6 100644
--- a/ahnlich/client/Cargo.toml
+++ b/ahnlich/client/Cargo.toml
@@ -22,6 +22,7 @@ bincode.workspace = true
 async-trait.workspace = true
 tokio.workspace = true
 deadpool.workspace = true
+fallible_collections.workspace = true
 [dev-dependencies]
 db = { path = "../db", version = "*" }
 ai = { path = "../ai", version = "*" }
diff --git a/ahnlich/client/src/conn/db.rs b/ahnlich/client/src/conn/db.rs
index 9f6b87d2..2b113334 100644
--- a/ahnlich/client/src/conn/db.rs
+++ b/ahnlich/client/src/conn/db.rs
@@ -27,7 +27,7 @@ impl Connection for DBConn {
     }
 
     async fn is_conn_valid(&mut self) -> Result<(), AhnlichError> {
-        let mut queries = Self::ServerQuery::with_capacity(1);
+        let mut queries = Self::ServerQuery::with_capacity(1)?;
         queries.push(DBQuery::Ping);
         let response = self.send_query(queries).await?;
         let mut expected_response = ServerResult::with_capacity(1);
diff --git a/ahnlich/client/src/db.rs b/ahnlich/client/src/db.rs
index 3f5bcf15..925976cb 100644
--- a/ahnlich/client/src/db.rs
+++ b/ahnlich/client/src/db.rs
@@ -215,7 +215,7 @@ impl DbClient {
         tracing_id: Option<String>,
     ) -> Result {
         Ok(DbPipeline {
-            queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id),
+            queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id)?,
             conn: self.pool.get().await?,
         })
     }
@@ -420,7 +420,7 @@ impl DbClient {
         tracing_id: Option<String>,
     ) -> Result {
         let mut conn = self.pool.get().await?;
-        let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id);
+        let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id)?;
         queries.push(query);
         let res = conn
             .send_query(queries)
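The GetSimN branch above replaces the manual collection loop with `fallible_collections::vec::TryFromIterator`. A reduced sketch of that collection pattern, assuming `try_from_iterator` returns `Result<Vec<_>, TryReserveError>` as the call in task.rs implies; the `(String, f32)` rows stand in for the real `(StoreInput, StoreValue, Similarity)` triples:

```rust
use fallible_collections::vec::TryFromIterator;
use fallible_collections::TryReserveError;

fn scale_scores(rows: Vec<(String, f32)>) -> Result<Vec<(String, f32)>, TryReserveError> {
    // Collects the mapped iterator into a Vec, reporting allocation failure
    // to the caller instead of aborting on out-of-memory.
    TryFromIterator::try_from_iterator(rows.into_iter().map(|(key, sim)| (key, sim * 100.0)))
}

fn main() {
    match scale_scores(vec![("a".into(), 0.91), ("b".into(), 0.42)]) {
        Ok(scored) => println!("{scored:?}"),
        Err(err) => eprintln!("allocation error {err:?}"),
    }
}
```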
diff --git a/ahnlich/client/src/error.rs b/ahnlich/client/src/error.rs
index 68852d5f..4e66177c 100644
--- a/ahnlich/client/src/error.rs
+++ b/ahnlich/client/src/error.rs
@@ -1,11 +1,17 @@
+use ahnlich_types::bincode::BincodeSerError;
+use fallible_collections::TryReserveError;
 use thiserror::Error;
 
 #[derive(Error, Debug)]
 pub enum AhnlichError {
     #[error("std io error {0}")]
     Standard(#[from] std::io::Error),
-    #[error("bincode serialize error {0}")]
-    BinCode(#[from] bincode::Error),
+    #[error("{0}")]
+    BinCodeSerAndDeser(#[from] BincodeSerError),
+    #[error("allocation error {0:?}")]
+    Allocation(TryReserveError),
+    #[error("bincode deserialize error {0}")]
+    Bincode(#[from] bincode::Error),
     #[error("db error {0}")]
     DbError(String),
     #[error("empty response")]
@@ -27,3 +33,9 @@ impl From for AhnlichError {
         Self::PoolError(format!("{input}"))
     }
 }
+
+impl From<TryReserveError> for AhnlichError {
+    fn from(input: TryReserveError) -> Self {
+        Self::Allocation(input)
+    }
+}
diff --git a/ahnlich/db/Cargo.toml b/ahnlich/db/Cargo.toml
index 766af681..4a82b6d1 100644
--- a/ahnlich/db/Cargo.toml
+++ b/ahnlich/db/Cargo.toml
@@ -37,6 +37,7 @@ serde_json.workspace = true
 async-trait.workspace = true
 rayon.workspace = true
 log.workspace = true
+fallible_collections.workspace = true
 
 [dev-dependencies]
diff --git a/ahnlich/db/src/engine/store.rs b/ahnlich/db/src/engine/store.rs
index f62966bf..2a0f407c 100644
--- a/ahnlich/db/src/engine/store.rs
+++ b/ahnlich/db/src/engine/store.rs
@@ -14,6 +14,7 @@ use ahnlich_types::predicate::PredicateCondition;
 use ahnlich_types::similarity::Algorithm;
 use ahnlich_types::similarity::NonLinearAlgorithm;
 use ahnlich_types::similarity::Similarity;
+use fallible_collections::FallibleVec;
 use flurry::HashMap as ConcurrentHashMap;
 use serde::Deserialize;
 use serde::Serialize;
@@ -588,7 +589,7 @@ impl Store {
             .collect();
         let pinned = self.id_to_value.pin();
         let (mut inserted, mut updated) = (0, 0);
-        let mut inserted_keys = Vec::new();
+        let mut inserted_keys: Vec<_> = FallibleVec::try_with_capacity(res.len())?;
         for (key, val) in res {
             if pinned.insert(key, val.clone()).is_some() {
                 updated += 1;
diff --git a/ahnlich/db/src/errors.rs b/ahnlich/db/src/errors.rs
index 66ee50d2..97dfaf34 100644
--- a/ahnlich/db/src/errors.rs
+++ b/ahnlich/db/src/errors.rs
@@ -1,10 +1,11 @@
 use ahnlich_types::keyval::StoreName;
 use ahnlich_types::metadata::MetadataKey;
 use ahnlich_types::similarity::NonLinearAlgorithm;
+use fallible_collections::TryReserveError;
 use thiserror::Error;
 
 /// TODO: Move to shared rust types so library can deserialize it from the TCP response
-#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
+#[derive(Error, Debug, Eq, PartialEq)]
 pub enum ServerError {
     #[error("Predicate {0} not found in store, attempt CREATEPREDINDEX with predicate")]
     PredicateNotFound(MetadataKey),
@@ -21,4 +22,12 @@ pub enum ServerError {
     },
     #[error("Could not deserialize query, error is {0}")]
     QueryDeserializeError(String),
+    #[error("allocation error {0:?}")]
+    Allocation(TryReserveError),
+}
+
+impl From<TryReserveError> for ServerError {
+    fn from(input: TryReserveError) -> Self {
+        Self::Allocation(input)
+    }
 }
diff --git a/ahnlich/typegen/src/tracers/query/db.rs b/ahnlich/typegen/src/tracers/query/db.rs
index 617e149a..55d7a00b 100644
--- a/ahnlich/typegen/src/tracers/query/db.rs
+++ b/ahnlich/typegen/src/tracers/query/db.rs
@@ -88,7 +88,8 @@ pub fn trace_db_query_enum() -> Registry {
     let server_query =
         ServerDBQuery::from_queries(&[deletepred_variant.clone(), set_query.clone()]);
     let trace_id = "00-djf9039023r3-1er".to_string();
-    let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id));
+    let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id))
+        .expect("Could not create server query");
 
     let _ = tracer
         .trace_value(&mut samples, &create_store)
diff --git a/ahnlich/types/Cargo.toml b/ahnlich/types/Cargo.toml
index 5d829539..1499643d 100644
--- a/ahnlich/types/Cargo.toml
+++ b/ahnlich/types/Cargo.toml
@@ -10,3 +10,5 @@ ndarray.workspace = true
 serde.workspace = true
 bincode.workspace = true
 once_cell.workspace = true
+fallible_collections.workspace = true
+thiserror.workspace = true
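With the constructors made fallible, callers either propagate the error (the client does this with `?`, now that `AhnlichError` carries an `Allocation` variant) or fail loudly (the typegen tracer above uses `.expect`). A sketch of both call styles; `QueryBatch` is a hypothetical stand-in for `ServerDBQuery`, whose actual change appears in types/src/db/query.rs below:

```rust
use fallible_collections::{FallibleVec, TryReserveError};

struct QueryBatch {
    queries: Vec<String>,
    trace_id: Option<String>,
}

impl QueryBatch {
    // Reserve the backing Vec fallibly, so construction reports OOM instead of aborting.
    fn with_capacity_and_tracing_id(
        len: usize,
        trace_id: Option<String>,
    ) -> Result<Self, TryReserveError> {
        Ok(Self {
            queries: FallibleVec::try_with_capacity(len)?,
            trace_id,
        })
    }
}

fn main() -> Result<(), TryReserveError> {
    // Propagate the allocation error to the caller...
    let _pipeline = QueryBatch::with_capacity_and_tracing_id(2, Some("00-trace".into()))?;
    // ...or fail loudly where recovery makes no sense, as the typegen tracer does.
    let _traced = QueryBatch::with_capacity_and_tracing_id(2, None)
        .expect("Could not create server query");
    Ok(())
}
```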
diff --git a/ahnlich/types/src/bincode.rs b/ahnlich/types/src/bincode.rs
index 87e13fea..e3efed95 100644
--- a/ahnlich/types/src/bincode.rs
+++ b/ahnlich/types/src/bincode.rs
@@ -1,6 +1,7 @@
 use crate::version::VERSION;
 use bincode::config::DefaultOptions;
 use bincode::config::Options;
+use fallible_collections::vec::FallibleVec;
 use serde::de::DeserializeOwned;
 use serde::Serialize;
 
@@ -25,7 +26,7 @@ pub trait BinCodeSerAndDeser
 where
     Self: Serialize + DeserializeOwned + Send,
 {
-    fn serialize(&self) -> Result<Vec<u8>, bincode::Error> {
+    fn serialize(&self) -> Result<Vec<u8>, BincodeSerError> {
         let config = DefaultOptions::new()
             .with_fixint_encoding()
             .with_little_endian();
@@ -33,9 +34,10 @@ where
         let serialized_data = config.serialize(self)?;
         let data_length = serialized_data.len() as u64;
         // serialization appends the length buffer to be read first
-        let mut buffer = Vec::with_capacity(
+        let mut buffer: Vec<_> = FallibleVec::try_with_capacity(
             MAGIC_BYTES.len() + VERSION_LENGTH + LENGTH_HEADER_SIZE + serialized_data.len(),
-        );
+        )
+        .map_err(BincodeSerError::Allocation)?;
         buffer.extend(MAGIC_BYTES);
         buffer.extend(serialized_version_data);
         buffer.extend(&data_length.to_le_bytes());
@@ -63,3 +65,11 @@ where
 pub trait BinCodeSerAndDeserResponse: BinCodeSerAndDeser {
     fn from_error(err: String) -> Self;
 }
+
+#[derive(thiserror::Error, Debug)]
+pub enum BincodeSerError {
+    #[error("bincode serialize error {0}")]
+    BinCode(#[from] bincode::Error),
+    #[error("allocation error {0:?}")]
+    Allocation(fallible_collections::TryReserveError),
+}
diff --git a/ahnlich/types/src/db/query.rs b/ahnlich/types/src/db/query.rs
index 5f4c62e3..5d3a159e 100644
--- a/ahnlich/types/src/db/query.rs
+++ b/ahnlich/types/src/db/query.rs
@@ -1,3 +1,5 @@
+use fallible_collections::FallibleVec;
+use fallible_collections::TryReserveError;
 use std::collections::HashSet;
 use std::num::NonZeroUsize;
 
@@ -87,17 +89,20 @@ pub struct ServerQuery {
 }
 
 impl ServerQuery {
-    pub fn with_capacity(len: usize) -> Self {
-        Self {
-            queries: Vec::with_capacity(len),
+    pub fn with_capacity(len: usize) -> Result<Self, TryReserveError> {
+        Ok(Self {
+            queries: FallibleVec::try_with_capacity(len)?,
             trace_id: None,
-        }
+        })
     }
 
-    pub fn with_capacity_and_tracing_id(len: usize, trace_id: Option<String>) -> Self {
-        Self {
-            queries: Vec::with_capacity(len),
+    pub fn with_capacity_and_tracing_id(
+        len: usize,
+        trace_id: Option<String>,
+    ) -> Result<Self, TryReserveError> {
+        Ok(Self {
+            queries: FallibleVec::try_with_capacity(len)?,
             trace_id,
-        }
+        })
     }
 
     pub fn push(&mut self, entry: Query) {
diff --git a/ahnlich/utils/Cargo.toml b/ahnlich/utils/Cargo.toml
index b416b71c..d1321338 100644
--- a/ahnlich/utils/Cargo.toml
+++ b/ahnlich/utils/Cargo.toml
@@ -20,3 +20,4 @@ serde_json.workspace = true
 log.workspace = true
 cap = "0.1.2"
 tokio-util.workspace = true
+fallible_collections.workspace = true
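The serialize path in types/src/bincode.rs now reserves the whole response frame fallibly before laying out header and payload. A compact sketch of that framing, with a placeholder `MAGIC_BYTES` value and a raw payload instead of real bincode output (in the crate those come from its own constants and `config.serialize`):

```rust
use fallible_collections::vec::FallibleVec;
use fallible_collections::TryReserveError;

// Placeholder values for the sketch; the real constants live in types/src/bincode.rs.
const MAGIC_BYTES: &[u8] = b"MAGIC";
const LENGTH_HEADER_SIZE: usize = 8;

fn frame_message(version: &[u8], payload: &[u8]) -> Result<Vec<u8>, TryReserveError> {
    let data_length = payload.len() as u64;
    // One fallible reservation for the full frame: magic + version + length header + payload.
    let mut buffer: Vec<u8> = FallibleVec::try_with_capacity(
        MAGIC_BYTES.len() + version.len() + LENGTH_HEADER_SIZE + payload.len(),
    )?;
    buffer.extend(MAGIC_BYTES);
    buffer.extend(version);
    buffer.extend(&data_length.to_le_bytes());
    buffer.extend(payload);
    Ok(buffer)
}

fn main() -> Result<(), TryReserveError> {
    let frame = frame_message(b"0.1.0", b"{\"ping\":true}")?;
    println!("framed {} bytes", frame.len());
    Ok(())
}
```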
diff --git a/ahnlich/utils/src/protocol.rs b/ahnlich/utils/src/protocol.rs
index e4a58028..3e7693c1 100644
--- a/ahnlich/utils/src/protocol.rs
+++ b/ahnlich/utils/src/protocol.rs
@@ -7,6 +7,7 @@ use ahnlich_types::bincode::VERSION_LENGTH;
 use ahnlich_types::client::ConnectedClient;
 use ahnlich_types::version::Version;
 use ahnlich_types::version::VERSION;
+use fallible_collections::vec::FallibleVec;
 use std::fmt::Debug;
 use std::io::Error;
 use std::io::ErrorKind;
@@ -44,31 +45,26 @@ where
         let mut reader = reader.lock().await;
         match reader.read_exact(&mut magic_bytes_buf).await {
             Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
-                let error = self.prefix_log("Hung up on buffered stream");
-                log::error!("{error}");
-                return TaskState::Break;
+                let error = "Hung up on buffered stream";
+                return self.handle_error(error, false).await;
             }
             Err(e) => {
-                let error = self.prefix_log(format!("Error reading from task buffered stream {e}"));
-                log::error!("{error}");
-                return TaskState::Break;
+                let error = format!("Error reading from task buffered stream {e}");
+                return self.handle_error(error, false).await;
             }
             Ok(_) => {
                 if magic_bytes_buf != MAGIC_BYTES {
                     let error = "Invalid request stream".to_string();
-                    log::error!("{error}");
-                    return TaskState::Break;
+                    return self.handle_error(error, false).await;
                 }
-                if let Err(e) = reader.read_exact(&mut version_buf).await {
-                    log::error!("{}", e.to_string());
-                    return TaskState::Break;
+                if let Err(error) = reader.read_exact(&mut version_buf).await {
+                    return self.handle_error(error, false).await;
                 }
                 let version = match Version::deserialize_magic_bytes(&version_buf) {
                     Ok(version) => version,
                     Err(error) => {
                         let error = format!("Unable to parse version chunk {error}");
-                        log::error!("{error}");
-                        return TaskState::Break;
+                        return self.handle_error(error, false).await;
                     }
                 };
                 if !VERSION.is_compatible(&version) {
@@ -76,36 +72,43 @@ where
                         "Incompatible versions, Server: {:?}, Client {version:?}",
                         *VERSION
                     );
-                    log::error!("{error}");
-                    return TaskState::Break;
+                    return self.handle_error(error, false).await;
                 }
                 // cap the message size to be of length 1MiB
-                if let Err(e) = reader.read_exact(&mut length_buf).await {
-                    let error = format!("Could not read length buffer {e}");
-                    log::error!("{error}");
-                    return TaskState::Break;
+                if let Err(error) = reader.read_exact(&mut length_buf).await {
+                    return self.handle_error(error, false).await;
                 };
                 let data_length = u64::from_le_bytes(length_buf);
                 if data_length > self.maximum_message_size() {
-                    let error = self.prefix_log(format!(
+                    let error = format!(
                         "Message cannot exceed {} bytes, configure `message_size` for higher",
                         self.maximum_message_size()
-                    ));
-                    log::error!("{error}");
-                    return TaskState::Break;
-                }
-                let mut data = Vec::new();
-                if data.try_reserve(data_length as usize).is_err() {
-                    let error = self
-                        .prefix_log(format!("failed to reserve buffer of length {data_length}"));
-                    log::error!("{error}");
-                    return TaskState::Break;
+                    );
+                    return self.handle_error(error, false).await;
+                };
+
+                let mut data: Vec<_> = match FallibleVec::try_with_capacity(data_length as usize) {
+                    Err(error) => {
+                        return self
+                            .handle_error(
+                                format!("Could not allocate buffer for message body {:?}", error),
+                                true,
+                            )
+                            .await;
+                    }
+                    Ok(data) => data,
+                };
+                if let Err(error) = data.try_resize(data_length as usize, 0u8) {
+                    return self
+                        .handle_error(
+                            format!("Could not resize buffer for message body {:?}", error),
+                            true,
+                        )
+                        .await;
                 };
-                data.resize(data_length as usize, 0u8);
                 if let Err(e) = reader.read_exact(&mut data).await {
                     let error = format!("Could not read data buffer {e}");
-                    log::error!("{error}");
-                    return TaskState::Break;
+                    return self.handle_error(error.to_string(), false).await;
                 };
                 match Self::ServerQuery::deserialize(&data) {
                     Ok(queries) => {
@@ -116,18 +119,14 @@
                                 .map_err(|err| Error::new(ErrorKind::Other, err))
                             {
                                 Ok(parent_context) => parent_context,
-                                Err(error) => {
-                                    log::error!("{error}");
-                                    return TaskState::Break;
-                                }
+                                Err(error) => return self.handle_error(error, false).await,
                             };
                             span.set_parent(parent_context);
                         }
                         let results = self.handle(queries.into_inner()).instrument(span).await;
                         if let Ok(binary_results) = results.serialize() {
                             if let Err(error) = reader.get_mut().write_all(&binary_results).await {
-                                log::error!("{error}");
-                                return TaskState::Break;
+                                return self.handle_error(error, false).await;
                             };
                             log::debug!(
                                 "Sent Response of length {}, {:?}",
@@ -137,19 +136,7 @@
                             )
                         }
                     }
                     Err(error) => {
-                        let error = self.prefix_log(format!(
-                            "Could not deserialize client message as server query {error}"
-                        ));
-                        log::error!("{error}");
-                        let deserialize_error = Self::ServerResponse::from_error(
-                            "Could not deserialize query, error is {e}".to_string(),
-                        )
-                        .serialize()
-                        .expect("Could not serialize deserialize error");
-                        if let Err(error) = reader.get_mut().write_all(&deserialize_error).await {
-                            log::error!("{error}");
-                            return TaskState::Break;
-                        };
+                        return self.handle_error(error, true).await;
                     }
                 }
             }
@@ -157,6 +144,37 @@
         TaskState::Continue
     }
 
+    async fn handle_error(
+        &self,
+        error: impl ToString + Send,
+        respond_with_error: bool,
+    ) -> TaskState {
+        let error = self.prefix_log(error.to_string());
+        log::error!("{error}");
+        if respond_with_error {
+            let reader = self.reader();
+            let mut reader = reader.lock().await;
+            match Self::ServerResponse::from_error(format!(
+                "Could not deserialize query, error is {error}"
+            ))
+            .serialize()
+            {
+                Err(e) => log::error!(
+                    "{}",
+                    self.prefix_log(format!("Could not deserialize error response, {}", e))
+                ),
+                Ok(deserialize_error) => {
+                    if let Err(error) = reader.get_mut().write_all(&deserialize_error).await {
+                        log::error!("{}", self.prefix_log(format!("{error}")));
+                    } else {
+                        return TaskState::Continue;
+                    }
+                }
+            };
+        }
+        TaskState::Break
+    }
+
     async fn handle(
         &self,
         queries: <Self::ServerQuery as BinCodeSerAndDeserQuery>::Inner,
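The protocol loop above combines the two ideas in this PR: build the read buffer with `try_with_capacity`/`try_resize` so an oversized or hostile length header becomes an error path rather than an abort, and funnel every failure through the single `handle_error` helper. A standalone sketch of just the buffer step, assuming a tokio `AsyncRead` source and the 1 MiB cap mentioned in the comment; error strings echo the ones used above:

```rust
use fallible_collections::vec::FallibleVec;
use tokio::io::{AsyncRead, AsyncReadExt};

const MAX_MESSAGE_SIZE: u64 = 1_048_576; // assumed 1 MiB cap for this sketch

async fn read_body<R: AsyncRead + Unpin>(
    reader: &mut R,
    data_length: u64,
) -> Result<Vec<u8>, String> {
    if data_length > MAX_MESSAGE_SIZE {
        return Err(format!("Message cannot exceed {MAX_MESSAGE_SIZE} bytes"));
    }
    // Fallible allocation plus zero-fill replaces the infallible Vec::new + resize.
    let mut data: Vec<u8> = FallibleVec::try_with_capacity(data_length as usize)
        .map_err(|e| format!("Could not allocate buffer for message body {e:?}"))?;
    data.try_resize(data_length as usize, 0u8)
        .map_err(|e| format!("Could not resize buffer for message body {e:?}"))?;
    reader
        .read_exact(&mut data)
        .await
        .map_err(|e| format!("Could not read data buffer {e}"))?;
    Ok(data)
}
```

Returning the error string from this helper leaves the caller free to decide, as `handle_error` does, whether to just log and break the task or to serialize an error response back to the client first.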