Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Moving vectors to fallible vectors #103

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ahnlich/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ deadpool = { version = "0.10", features = ["rt_tokio_1"]}
opentelemetry = { version = "0.23.0", features = ["trace"] }
tracing-opentelemetry = "0.24.0"
log = "0.4"
fallible_collections = "0.4.9"
2 changes: 2 additions & 0 deletions ahnlich/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ serde_json.workspace = true
termcolor = "1.4.1"
strum = { version = "0.26", features = ["derive"] }
log.workspace = true
fallible_collections.workspace = true
rayon.workspace = true

[dev-dependencies]
db = { path = "../db", version = "*" }
Expand Down
3 changes: 2 additions & 1 deletion ahnlich/ai/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use ahnlich_types::keyval::StoreKey;
use ahnlich_types::keyval::StoreName;
use ahnlich_types::keyval::StoreValue;
use ahnlich_types::metadata::MetadataValue;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -162,7 +163,7 @@ impl AIStoreHandler {
) -> Result<StoreValidateResponse, AIProxyError> {
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
let store = self.get(store_name)?;
let mut output = Vec::with_capacity(inputs.len());
let mut output: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
let mut delete_hashset = StdHashSet::new();

for (store_input, mut store_value) in inputs {
Expand Down
9 changes: 9 additions & 0 deletions ahnlich/ai/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ahnlich_types::{
ai::{AIStoreInputType, PreprocessAction},
keyval::StoreName,
};
use fallible_collections::TryReserveError;
use thiserror::Error;
use tokio::sync::oneshot::error::RecvError;

Expand Down Expand Up @@ -67,4 +68,12 @@ pub enum AIProxyError {
index_model_dim: usize,
query_model_dim: usize,
},
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
}

impl From<TryReserveError> for AIProxyError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
3 changes: 2 additions & 1 deletion ahnlich/ai/src/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::engine::ai::models::Model;
use crate::error::AIProxyError;
use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction};
use ahnlich_types::keyval::{StoreInput, StoreKey};
use fallible_collections::FallibleVec;
use std::collections::HashMap;
use task_manager::Task;
use task_manager::TaskManager;
Expand Down Expand Up @@ -46,7 +47,7 @@ impl ModelThread {
process_action: PreprocessAction,
) -> ModelThreadResponse {
let model: Model = (&self.model).into();
let mut response = Vec::with_capacity(inputs.len());
let mut response: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
// move this from for loop into vec of inputs
for input in inputs {
let processed_input = self.preprocess_store_input(process_action, input)?;
Expand Down
34 changes: 17 additions & 17 deletions ahnlich/ai/src/server/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use ahnlich_types::keyval::{StoreInput, StoreValue};
use ahnlich_types::metadata::MetadataValue;
use ahnlich_types::predicate::{Predicate, PredicateCondition};
use ahnlich_types::version::VERSION;
use fallible_collections::vec::TryFromIterator;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
Expand Down Expand Up @@ -333,23 +334,22 @@ impl AhnlichProtocol for AIProxyTask {
{
Ok(res) => {
if let ServerResponse::GetSimN(response) = res {
// conversion to store input here
let mut output = Vec::new();

// TODO: Can Parallelize
for (store_key, store_value, sim) in response.into_iter() {
let temp =
self.store_handler.store_key_val_to_store_input_val(
vec![(store_key, store_value)],
);

if let Some(valid_result) = temp.first().take() {
let valid_result = valid_result.to_owned();
output.push((valid_result.0, valid_result.1, sim))
}
}

Ok(AIServerResponse::GetSimN(output))
TryFromIterator::try_from_iterator(
response.into_iter().flat_map(
|(store_key, store_value, sim)| {
// TODO: Can parallelize
self.store_handler
.store_key_val_to_store_input_val(vec![(
store_key,
store_value,
)])
.into_iter()
.map(move |v| (v.0, v.1, sim))
},
),
)
.map_err(|e| AIProxyError::from(e).to_string())
.map(AIServerResponse::GetSimN)
} else {
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
.to_string())
Expand Down
1 change: 1 addition & 0 deletions ahnlich/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ bincode.workspace = true
async-trait.workspace = true
tokio.workspace = true
deadpool.workspace = true
fallible_collections.workspace = true
[dev-dependencies]
db = { path = "../db", version = "*" }
ai = { path = "../ai", version = "*" }
Expand Down
2 changes: 1 addition & 1 deletion ahnlich/client/src/conn/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl Connection for DBConn {
}

async fn is_conn_valid(&mut self) -> Result<(), AhnlichError> {
let mut queries = Self::ServerQuery::with_capacity(1);
let mut queries = Self::ServerQuery::with_capacity(1)?;
queries.push(DBQuery::Ping);
let response = self.send_query(queries).await?;
let mut expected_response = ServerResult::with_capacity(1);
Expand Down
4 changes: 2 additions & 2 deletions ahnlich/client/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl DbClient {
tracing_id: Option<String>,
) -> Result<DbPipeline, AhnlichError> {
Ok(DbPipeline {
queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id),
queries: ServerDBQuery::with_capacity_and_tracing_id(capacity, tracing_id)?,
conn: self.pool.get().await?,
})
}
Expand Down Expand Up @@ -420,7 +420,7 @@ impl DbClient {
tracing_id: Option<String>,
) -> Result<ServerResponse, AhnlichError> {
let mut conn = self.pool.get().await?;
let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id);
let mut queries = ServerDBQuery::with_capacity_and_tracing_id(1, tracing_id)?;
queries.push(query);
let res = conn
.send_query(queries)
Expand Down
16 changes: 14 additions & 2 deletions ahnlich/client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use ahnlich_types::bincode::BincodeSerError;
use fallible_collections::TryReserveError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum AhnlichError {
#[error("std io error {0}")]
Standard(#[from] std::io::Error),
#[error("bincode serialize error {0}")]
BinCode(#[from] bincode::Error),
#[error("{0}")]
BinCodeSerAndDeser(#[from] BincodeSerError),
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
#[error("bincode deserialize error {0}")]
Bincode(#[from] bincode::Error),
#[error("db error {0}")]
DbError(String),
#[error("empty response")]
Expand All @@ -27,3 +33,9 @@ impl From<deadpool::managed::BuildError> for AhnlichError {
Self::PoolError(format!("{input}"))
}
}

impl From<TryReserveError> for AhnlichError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
1 change: 1 addition & 0 deletions ahnlich/db/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ serde_json.workspace = true
async-trait.workspace = true
rayon.workspace = true
log.workspace = true
fallible_collections.workspace = true


[dev-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion ahnlich/db/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ahnlich_types::predicate::PredicateCondition;
use ahnlich_types::similarity::Algorithm;
use ahnlich_types::similarity::NonLinearAlgorithm;
use ahnlich_types::similarity::Similarity;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -588,7 +589,7 @@ impl Store {
.collect();
let pinned = self.id_to_value.pin();
let (mut inserted, mut updated) = (0, 0);
let mut inserted_keys = Vec::new();
let mut inserted_keys: Vec<_> = FallibleVec::try_with_capacity(res.len())?;
for (key, val) in res {
if pinned.insert(key, val.clone()).is_some() {
updated += 1;
Expand Down
11 changes: 10 additions & 1 deletion ahnlich/db/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use ahnlich_types::keyval::StoreName;
use ahnlich_types::metadata::MetadataKey;
use ahnlich_types::similarity::NonLinearAlgorithm;
use fallible_collections::TryReserveError;
use thiserror::Error;

/// TODO: Move to shared rust types so library can deserialize it from the TCP response
#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
#[derive(Error, Debug, Eq, PartialEq)]
pub enum ServerError {
#[error("Predicate {0} not found in store, attempt CREATEPREDINDEX with predicate")]
PredicateNotFound(MetadataKey),
Expand All @@ -21,4 +22,12 @@ pub enum ServerError {
},
#[error("Could not deserialize query, error is {0}")]
QueryDeserializeError(String),
#[error("allocation error {0:?}")]
Allocation(TryReserveError),
}

impl From<TryReserveError> for ServerError {
fn from(input: TryReserveError) -> Self {
Self::Allocation(input)
}
}
3 changes: 2 additions & 1 deletion ahnlich/typegen/src/tracers/query/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ pub fn trace_db_query_enum() -> Registry {
let server_query =
ServerDBQuery::from_queries(&[deletepred_variant.clone(), set_query.clone()]);
let trace_id = "00-djf9039023r3-1er".to_string();
let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id));
let server_query_with_trace_id = ServerDBQuery::with_capacity_and_tracing_id(2, Some(trace_id))
.expect("Could not create server query");

let _ = tracer
.trace_value(&mut samples, &create_store)
Expand Down
2 changes: 2 additions & 0 deletions ahnlich/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ ndarray.workspace = true
serde.workspace = true
bincode.workspace = true
once_cell.workspace = true
fallible_collections.workspace = true
thiserror.workspace = true
16 changes: 13 additions & 3 deletions ahnlich/types/src/bincode.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::version::VERSION;
use bincode::config::DefaultOptions;
use bincode::config::Options;
use fallible_collections::vec::FallibleVec;
use serde::de::DeserializeOwned;
use serde::Serialize;

Expand All @@ -25,17 +26,18 @@ pub trait BinCodeSerAndDeser
where
Self: Serialize + DeserializeOwned + Send,
{
fn serialize(&self) -> Result<Vec<u8>, bincode::Error> {
fn serialize(&self) -> Result<Vec<u8>, BincodeSerError> {
let config = DefaultOptions::new()
.with_fixint_encoding()
.with_little_endian();
let serialized_version_data = config.serialize(&*VERSION)?;
let serialized_data = config.serialize(self)?;
let data_length = serialized_data.len() as u64;
// serialization appends the length buffer to be read first
let mut buffer = Vec::with_capacity(
let mut buffer: Vec<_> = FallibleVec::try_with_capacity(
MAGIC_BYTES.len() + VERSION_LENGTH + LENGTH_HEADER_SIZE + serialized_data.len(),
);
)
.map_err(BincodeSerError::Allocation)?;
buffer.extend(MAGIC_BYTES);
buffer.extend(serialized_version_data);
buffer.extend(&data_length.to_le_bytes());
Expand Down Expand Up @@ -63,3 +65,11 @@ where
pub trait BinCodeSerAndDeserResponse: BinCodeSerAndDeser {
fn from_error(err: String) -> Self;
}

#[derive(thiserror::Error, Debug)]
pub enum BincodeSerError {
#[error("bincode serialize error {0}")]
BinCode(#[from] bincode::Error),
#[error("allocation error {0:?}")]
Allocation(fallible_collections::TryReserveError),
}
21 changes: 13 additions & 8 deletions ahnlich/types/src/db/query.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use fallible_collections::FallibleVec;
use fallible_collections::TryReserveError;
use std::collections::HashSet;
use std::num::NonZeroUsize;

Expand Down Expand Up @@ -87,17 +89,20 @@ pub struct ServerQuery {
}

impl ServerQuery {
pub fn with_capacity(len: usize) -> Self {
Self {
queries: Vec::with_capacity(len),
pub fn with_capacity(len: usize) -> Result<Self, TryReserveError> {
Ok(Self {
queries: FallibleVec::try_with_capacity(len)?,
trace_id: None,
}
})
}
pub fn with_capacity_and_tracing_id(len: usize, trace_id: Option<String>) -> Self {
Self {
queries: Vec::with_capacity(len),
pub fn with_capacity_and_tracing_id(
len: usize,
trace_id: Option<String>,
) -> Result<Self, TryReserveError> {
Ok(Self {
queries: FallibleVec::try_with_capacity(len)?,
trace_id,
}
})
}

pub fn push(&mut self, entry: Query) {
Expand Down
1 change: 1 addition & 0 deletions ahnlich/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ serde_json.workspace = true
log.workspace = true
cap = "0.1.2"
tokio-util.workspace = true
fallible_collections.workspace = true
Loading
Loading