Skip to content

Commit

Permalink
WIP: Moving vectors to fallible vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
deven96 committed Sep 11, 2024
1 parent 7bfd744 commit 1ed9df6
Show file tree
Hide file tree
Showing 19 changed files with 165 additions and 90 deletions.
1 change: 1 addition & 0 deletions ahnlich/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ deadpool = { version = "0.10", features = ["rt_tokio_1"]}
opentelemetry = { version = "0.23.0", features = ["trace"] }
tracing-opentelemetry = "0.24.0"
log = "0.4"
fallible_collections = "0.4.9"
2 changes: 2 additions & 0 deletions ahnlich/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ serde_json.workspace = true
termcolor = "1.4.1"
strum = { version = "0.26", features = ["derive"] }
log.workspace = true
fallible_collections.workspace = true
rayon.workspace = true

[dev-dependencies]
db = { path = "../db", version = "*" }
Expand Down
3 changes: 2 additions & 1 deletion ahnlich/ai/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use ahnlich_types::keyval::StoreKey;
use ahnlich_types::keyval::StoreName;
use ahnlich_types::keyval::StoreValue;
use ahnlich_types::metadata::MetadataValue;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -162,7 +163,7 @@ impl AIStoreHandler {
) -> Result<StoreValidateResponse, AIProxyError> {
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
let store = self.get(store_name)?;
let mut output = Vec::with_capacity(inputs.len());
let mut output: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
let mut delete_hashset = StdHashSet::new();

for (store_input, mut store_value) in inputs {
Expand Down
9 changes: 9 additions & 0 deletions ahnlich/ai/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ahnlich_types::{
ai::{AIStoreInputType, PreprocessAction},
keyval::StoreName,
};
use fallible_collections::TryReserveError;
use thiserror::Error;
use tokio::sync::oneshot::error::RecvError;

Expand Down Expand Up @@ -67,4 +68,12 @@ pub enum AIProxyError {
index_model_dim: usize,
query_model_dim: usize,
},
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
}

impl From<TryReserveError> for AIProxyError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
3 changes: 2 additions & 1 deletion ahnlich/ai/src/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::engine::ai::models::Model;
use crate::error::AIProxyError;
use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction};
use ahnlich_types::keyval::{StoreInput, StoreKey};
use fallible_collections::FallibleVec;
use std::collections::HashMap;
use task_manager::Task;
use task_manager::TaskManager;
Expand Down Expand Up @@ -46,7 +47,7 @@ impl ModelThread {
process_action: PreprocessAction,
) -> ModelThreadResponse {
let model: Model = (&self.model).into();
let mut response = Vec::with_capacity(inputs.len());
let mut response: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
// move this from for loop into vec of inputs
for input in inputs {
let processed_input = self.preprocess_store_input(process_action, input)?;
Expand Down
34 changes: 17 additions & 17 deletions ahnlich/ai/src/server/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use ahnlich_types::keyval::{StoreInput, StoreValue};
use ahnlich_types::metadata::MetadataValue;
use ahnlich_types::predicate::{Predicate, PredicateCondition};
use ahnlich_types::version::VERSION;
use fallible_collections::vec::TryFromIterator;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
Expand Down Expand Up @@ -333,23 +334,22 @@ impl AhnlichProtocol for AIProxyTask {
{
Ok(res) => {
if let ServerResponse::GetSimN(response) = res {
// conversion to store input here
let mut output = Vec::new();

// TODO: Can Parallelize
for (store_key, store_value, sim) in response.into_iter() {
let temp =
self.store_handler.store_key_val_to_store_input_val(
vec![(store_key, store_value)],
);

if let Some(valid_result) = temp.first().take() {
let valid_result = valid_result.to_owned();
output.push((valid_result.0, valid_result.1, sim))
}
}

Ok(AIServerResponse::GetSimN(output))
TryFromIterator::try_from_iterator(
response.into_iter().flat_map(
|(store_key, store_value, sim)| {
// TODO: Can parallelize
self.store_handler
.store_key_val_to_store_input_val(vec![(
store_key,
store_value,
)])
.into_iter()
.map(move |v| (v.0, v.1, sim))
},
),
)
.map_err(|e| AIProxyError::from(e).to_string())
.map(AIServerResponse::GetSimN)
} else {
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
.to_string())
Expand Down
1 change: 1 addition & 0 deletions ahnlich/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ bincode.workspace = true
async-trait.workspace = true
tokio.workspace = true
deadpool.workspace = true
fallible_collections.workspace = true
[dev-dependencies]
db = { path = "../db", version = "*" }
ai = { path = "../ai", version = "*" }
Expand Down
2 changes: 1 addition & 1 deletion ahnlich/client/src/conn/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl Connection for DBConn {
}

async fn is_conn_valid(&mut self) -> Result<(), AhnlichError> {
let mut queries = Self::ServerQuery::with_capacity(1);
let mut queries = Self::ServerQuery::with_capacity(1)?;
queries.push(DBQuery::Ping);
let response = self.send_query(queries).await?;
let mut expected_response = ServerResult::with_capacity(1);
Expand Down
4 changes: 2 additions & 2 deletions ahnlich/client/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl DbClient {
tracing_id: Option<String>,
) -> Result<DbPipeline, AhnlichError> {
Ok(DbPipeline {
queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id),
queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id)?,
conn: self.pool.get().await?,
})
}
Expand Down Expand Up @@ -420,7 +420,7 @@ impl DbClient {
tracing_id: Option<String>,
) -> Result<ServerResponse, AhnlichError> {
let mut conn = self.pool.get().await?;
let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id);
let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id)?;
queries.push(query);
let res = conn
.send_query(queries)
Expand Down
16 changes: 14 additions & 2 deletions ahnlich/client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use ahnlich_types::bincode::BincodeSerError;
use fallible_collections::TryReserveError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum AhnlichError {
#[error("std io error {0}")]
Standard(#[from] std::io::Error),
#[error("bincode serialize error {0}")]
BinCode(#[from] bincode::Error),
#[error("{0}")]
BinCodeSerAndDeser(#[from] BincodeSerError),
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
#[error("bincode deserialize error {0}")]
Bincode(#[from] bincode::Error),
#[error("db error {0}")]
DbError(String),
#[error("empty response")]
Expand All @@ -27,3 +33,9 @@ impl From<deadpool::managed::BuildError> for AhnlichError {
Self::PoolError(format!("{input}"))
}
}

impl From<TryReserveError> for AhnlichError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
1 change: 1 addition & 0 deletions ahnlich/db/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ serde_json.workspace = true
async-trait.workspace = true
rayon.workspace = true
log.workspace = true
fallible_collections.workspace = true


[dev-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion ahnlich/db/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ahnlich_types::predicate::PredicateCondition;
use ahnlich_types::similarity::Algorithm;
use ahnlich_types::similarity::NonLinearAlgorithm;
use ahnlich_types::similarity::Similarity;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -588,7 +589,7 @@ impl Store {
.collect();
let pinned = self.id_to_value.pin();
let (mut inserted, mut updated) = (0, 0);
let mut inserted_keys = Vec::new();
let mut inserted_keys: Vec<_> = FallibleVec::try_with_capacity(res.len())?;
for (key, val) in res {
if pinned.insert(key, val.clone()).is_some() {
updated += 1;
Expand Down
11 changes: 10 additions & 1 deletion ahnlich/db/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use ahnlich_types::keyval::StoreName;
use ahnlich_types::metadata::MetadataKey;
use ahnlich_types::similarity::NonLinearAlgorithm;
use fallible_collections::TryReserveError;
use thiserror::Error;

/// TODO: Move to shared rust types so library can deserialize it from the TCP response
#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
#[derive(Error, Debug, Eq, PartialEq)]
pub enum ServerError {
#[error("Predicate {0} not found in store, attempt CREATEPREDINDEX with predicate")]
PredicateNotFound(MetadataKey),
Expand All @@ -21,4 +22,12 @@ pub enum ServerError {
},
#[error("Could not deserialize query, error is {0}")]
QueryDeserializeError(String),
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
}

impl From<TryReserveError> for ServerError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
3 changes: 2 additions & 1 deletion ahnlich/typegen/src/tracers/query/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ pub fn trace_db_query_enum() -> Registry {
let server_query =
ServerDBQuery::from_queries(&[deletepred_variant.clone(), set_query.clone()]);
let trace_id = "00-djf9039023r3-1er".to_string();
let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id));
let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id))
.expect("Could not create server query");

let _ = tracer
.trace_value(&mut samples, &create_store)
Expand Down
2 changes: 2 additions & 0 deletions ahnlich/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ ndarray.workspace = true
serde.workspace = true
bincode.workspace = true
once_cell.workspace = true
fallible_collections.workspace = true
thiserror.workspace = true
16 changes: 13 additions & 3 deletions ahnlich/types/src/bincode.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::version::VERSION;
use bincode::config::DefaultOptions;
use bincode::config::Options;
use fallible_collections::vec::FallibleVec;
use serde::de::DeserializeOwned;
use serde::Serialize;

Expand All @@ -25,17 +26,18 @@ pub trait BinCodeSerAndDeser
where
Self: Serialize + DeserializeOwned + Send,
{
fn serialize(&self) -> Result<Vec<u8>, bincode::Error> {
fn serialize(&self) -> Result<Vec<u8>, BincodeSerError> {
let config = DefaultOptions::new()
.with_fixint_encoding()
.with_little_endian();
let serialized_version_data = config.serialize(&*VERSION)?;
let serialized_data = config.serialize(self)?;
let data_length = serialized_data.len() as u64;
// serialization appends the length buffer to be read first
let mut buffer = Vec::with_capacity(
let mut buffer: Vec<_> = FallibleVec::try_with_capacity(
MAGIC_BYTES.len() + VERSION_LENGTH + LENGTH_HEADER_SIZE + serialized_data.len(),
);
)
.map_err(BincodeSerError::Allocation)?;
buffer.extend(MAGIC_BYTES);
buffer.extend(serialized_version_data);
buffer.extend(&data_length.to_le_bytes());
Expand Down Expand Up @@ -63,3 +65,11 @@ where
pub trait BinCodeSerAndDeserResponse: BinCodeSerAndDeser {
fn from_error(err: String) -> Self;
}

#[derive(thiserror::Error, Debug)]
pub enum BincodeSerError {
#[error("bincode serialize error {0}")]
BinCode(#[from] bincode::Error),
#[error("allocation error {0:?}")]
Allocation(fallible_collections::TryReserveError),
}
21 changes: 13 additions & 8 deletions ahnlich/types/src/db/query.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use fallible_collections::FallibleVec;
use fallible_collections::TryReserveError;
use std::collections::HashSet;
use std::num::NonZeroUsize;

Expand Down Expand Up @@ -87,17 +89,20 @@ pub struct ServerQuery {
}

impl ServerQuery {
pub fn with_capacity(len: usize) -> Self {
Self {
queries: Vec::with_capacity(len),
pub fn with_capacity(len: usize) -> Result<Self, TryReserveError> {
Ok(Self {
queries: FallibleVec::try_with_capacity(len)?,
trace_id: None,
}
})
}
pub fn with_capacity_and_tracing_id(len: usize, trace_id: Option<String>) -> Self {
Self {
queries: Vec::with_capacity(len),
pub fn with_capacity_and_tracing_id(
len: usize,
trace_id: Option<String>,
) -> Result<Self, TryReserveError> {
Ok(Self {
queries: FallibleVec::try_with_capacity(len)?,
trace_id,
}
})
}

pub fn push(&mut self, entry: Query) {
Expand Down
1 change: 1 addition & 0 deletions ahnlich/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ serde_json.workspace = true
log.workspace = true
cap = "0.1.2"
tokio-util.workspace = true
fallible_collections.workspace = true
Loading

0 comments on commit 1ed9df6

Please sign in to comment.