From 07f5fff99dfe420ae352e15398de23b01756c098 Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Thu, 14 Nov 2024 05:09:10 +0100 Subject: [PATCH 01/10] Fix: Inconsistencies between ai-proxy commands and db --- ahnlich/ai/src/server/task.rs | 17 ++++++++--------- ahnlich/ai/src/tests/aiproxy_test.rs | 1 + ahnlich/client/src/ai.rs | 4 ++++ ahnlich/dsl/src/ai.rs | 7 +++++++ ahnlich/dsl/src/tests/ai.rs | 4 +++- ahnlich/typegen/src/tracers/query/ai.rs | 1 + ahnlich/types/src/ai/query.rs | 6 ++++++ 7 files changed, 30 insertions(+), 10 deletions(-) diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 65a99021..5033fc04 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -320,22 +320,15 @@ impl AhnlichProtocol for AIProxyTask { condition, closest_n, algorithm, + preprocess_action, } => { - // TODO: Replace this with calls to self.model_manager.handle_request - // TODO (HAKSOAT): Shouldn't preprocess action also be in the params? - let preprocess = match search_input { - StoreInput::RawString(_) => { - PreprocessAction::RawString(StringAction::TruncateIfTokensExceed) - } - StoreInput::Image(_) => PreprocessAction::Image(ImageAction::ResizeImage), - }; let repr = self .store_handler .get_ndarray_repr_for_store( &store, search_input, &self.model_manager, - preprocess, + preprocess_action, ) .await; if let Ok(store_key) = repr { @@ -384,6 +377,12 @@ impl AhnlichProtocol for AIProxyTask { let destoryed = self.store_handler.purge_stores(); Ok(AIServerResponse::Del(destoryed)) } + AIQuery::ListClients => Ok(AIServerResponse::ClientList(self.client_handler.list())), + AIQuery::GetKey { store, keys } => self + .store_handler + .get_key_in_store(&store, keys) + .map(ServerResponse::Get) + .map_err(|e| format!("{e}")), }) } result diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index ccedba81..c7226741 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -375,6 +375,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() { condition: None, closest_n: NonZeroUsize::new(1).unwrap(), algorithm: Algorithm::DotProductSimilarity, + preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), }]); let mut expected = AIServerResult::with_capacity(1); diff --git a/ahnlich/client/src/ai.rs b/ahnlich/client/src/ai.rs index f7e2ef05..864b2ec5 100644 --- a/ahnlich/client/src/ai.rs +++ b/ahnlich/client/src/ai.rs @@ -84,6 +84,7 @@ impl AIPipeline { condition: Option, closest_n: NonZeroUsize, algorithm: Algorithm, + preprocess_action: PreprocessAction, ) { self.queries.push(AIQuery::GetSimN { store, @@ -91,6 +92,7 @@ impl AIPipeline { condition, closest_n, algorithm, + preprocess_action }) } @@ -256,6 +258,7 @@ impl AIClient { closest_n: NonZeroUsize, algorithm: Algorithm, tracing_id: Option, + preprocess_action: PreprocessAction, ) -> Result { self.exec( AIQuery::GetSimN { @@ -264,6 +267,7 @@ impl AIClient { condition, closest_n, algorithm, + preprocess_action }, tracing_id, ) diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index f2967880..fb2bce34 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -190,12 +190,19 @@ pub fn parse_ai_query(input: &str) -> Result, DslError> { } else { None }; + let preprocess_action = parse_to_preprocess_action( + inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? + .as_str(), + ); AIQuery::GetSimN { store: StoreName(store.to_string()), search_input, closest_n, algorithm, condition, + preprocess_action } } Rule::get_pred => { diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index 24026b32..338444a0 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -207,7 +207,8 @@ fn test_get_sim_n_parse() { search_input: StoreInput::RawString("hi my name is carter".to_string()), closest_n: NonZeroUsize::new(5).unwrap(), algorithm: Algorithm::CosineSimilarity, - condition: None + condition: None, + preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), }] ); let input = r#"GETSIMN 8 with [testing the limits of life] using euclideandistance in other where ((year != 2012) AND (month not in (december, october)))"#; @@ -231,6 +232,7 @@ fn test_get_sim_n_parse() { ]), })) ), + preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), }] ); } diff --git a/ahnlich/typegen/src/tracers/query/ai.rs b/ahnlich/typegen/src/tracers/query/ai.rs index 07a58925..88357427 100644 --- a/ahnlich/typegen/src/tracers/query/ai.rs +++ b/ahnlich/typegen/src/tracers/query/ai.rs @@ -66,6 +66,7 @@ pub fn trace_ai_query_enum() -> Registry { condition: Some(test_predicate_condition.clone()), closest_n: NonZeroUsize::new(4).unwrap(), algorithm: Algorithm::CosineSimilarity, + preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), }; let create_index = AIQuery::CreatePredIndex { diff --git a/ahnlich/types/src/ai/query.rs b/ahnlich/types/src/ai/query.rs index eeed6317..337c7ca9 100644 --- a/ahnlich/types/src/ai/query.rs +++ b/ahnlich/types/src/ai/query.rs @@ -30,6 +30,7 @@ pub enum AIQuery { condition: Option, closest_n: NonZeroUsize, algorithm: Algorithm, + preprocess_action: PreprocessAction, }, CreatePredIndex { store: StoreName, @@ -62,7 +63,12 @@ pub enum AIQuery { store: StoreName, error_if_not_exists: bool, }, + GetKey { + store: StoreName, + keys: Vec, + }, InfoServer, + ListClients, ListStores, PurgeStores, Ping, From 9bead87fe0b7006ed1f798ea4292d39857f03c4b Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Sat, 30 Nov 2024 01:17:04 +0100 Subject: [PATCH 02/10] Chore: Implemented tests for changes to AI queries --- ahnlich/ai/src/engine/store.rs | 1 + ahnlich/ai/src/server/task.rs | 40 ++++++++++++---- ahnlich/ai/src/tests/aiproxy_test.rs | 71 +++++++++++++++++++++++++--- ahnlich/dsl/src/syntax/syntax.pest | 2 +- 4 files changed, 98 insertions(+), 16 deletions(-) diff --git a/ahnlich/ai/src/engine/store.rs b/ahnlich/ai/src/engine/store.rs index d2f02215..f9569b78 100644 --- a/ahnlich/ai/src/engine/store.rs +++ b/ahnlich/ai/src/engine/store.rs @@ -328,6 +328,7 @@ impl AIStoreHandler { self.stores.clear(&guard); store_length } + } #[derive(Debug, Serialize, Deserialize)] diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 5033fc04..448357b8 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,12 +1,10 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::db::DbClient; use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult, ImageAction, PreprocessAction, - StringAction, + AIQuery, AIServerQuery, AIServerResponse, AIServerResult }; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; -use ahnlich_types::keyval::StoreInput; use ahnlich_types::metadata::MetadataValue; use ahnlich_types::predicate::{Predicate, PredicateCondition}; use ahnlich_types::version::VERSION; @@ -378,17 +376,43 @@ impl AhnlichProtocol for AIProxyTask { Ok(AIServerResponse::Del(destoryed)) } AIQuery::ListClients => Ok(AIServerResponse::ClientList(self.client_handler.list())), - AIQuery::GetKey { store, keys } => self - .store_handler - .get_key_in_store(&store, keys) - .map(ServerResponse::Get) - .map_err(|e| format!("{e}")), + AIQuery::GetKey { store, keys } => { + let metadata_values: HashSet = keys.into_iter().map( + |value| value.into() + ).collect(); + let get_key_condition = + PredicateCondition::Value(Predicate::In { + key: AHNLICH_AI_RESERVED_META_KEY.clone(), + value: metadata_values, + }); + + match self + .db_client + .get_pred(store, get_key_condition, parent_id.clone()) + .await + { + Ok(res) => { + if let ServerResponse::Get(response) = res { + // conversion to store input here + let output = self + .store_handler + .store_key_val_to_store_input_val(response); + Ok(AIServerResponse::Get(output)) + } else { + Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res)) + .to_string()) + } + } + Err(err) => Err(format!("{err}")), + } + } }) } result } } + impl AIProxyTask { #[tracing::instrument(skip(self))] fn server_info(&self) -> ServerInfo { diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index c7226741..f556a76f 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -4,18 +4,14 @@ use ahnlich_types::{ ai::{ AIModel, AIQuery, AIServerQuery, AIServerResponse, AIServerResult, AIStoreInfo, ImageAction, PreprocessAction, StringAction, - }, - db::StoreUpsert, - keyval::{StoreInput, StoreName, StoreValue}, - metadata::{MetadataKey, MetadataValue}, - predicate::{Predicate, PredicateCondition}, - similarity::Algorithm, + }, client::ConnectedClient, db::StoreUpsert, keyval::{StoreInput, StoreName, StoreValue}, metadata::{MetadataKey, MetadataValue}, predicate::{Predicate, PredicateCondition}, similarity::Algorithm }; +// use flurry::HashMap; use utils::server::AhnlichServerUtils; use once_cell::sync::Lazy; use pretty_assertions::assert_eq; -use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering}; +use std::{collections::{HashMap, HashSet}, num::NonZeroUsize, sync::atomic::Ordering}; use crate::{ cli::{server::SupportedModels, AIProxyConfig}, @@ -162,6 +158,67 @@ async fn test_ai_proxy_create_store_success() { query_server_assert_result(&mut reader, message, expected.clone()).await; } + +#[tokio::test] +async fn test_ai_store_get_key_works() { + let address = provision_test_servers().await; + let first_stream = TcpStream::connect(address).await.unwrap(); + let second_stream = TcpStream::connect(address).await.unwrap(); + let store_name = StoreName(String::from("Deven Kicks")); + let store_input = StoreInput::RawString(String::from("Jordan 3")); + let store_data: (StoreInput, HashMap) = ( + store_input.clone(), + HashMap::new() + ); + + let message = AIServerQuery::from_queries(&[ + AIQuery::CreateStore { + store: store_name.clone(), + query_model: AIModel::AllMiniLML6V2, + index_model: AIModel::AllMiniLML6V2, + predicates: HashSet::new(), + non_linear_indices: HashSet::new(), + error_if_exists: true, + store_original: false, + }, + AIQuery::Set { + store: store_name.clone(), + inputs: vec![store_data.clone()], + preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + }, + ]); + let mut reader = BufReader::new(first_stream); + + let _ = get_server_response(&mut reader, message).await; + let message = AIServerQuery::from_queries(&[AIQuery::GetKey { store: store_name, keys: vec![store_input.clone()] }]); + + let mut expected = AIServerResult::with_capacity(1); + + expected.push(Ok(AIServerResponse::Get(vec![( + Some(store_input), HashMap::new() + )]))); + + let mut reader = BufReader::new(second_stream); + let response = get_server_response(&mut reader, message).await; + assert!(response.len() == expected.len()) + +} + + +#[tokio::test] +async fn test_list_clients_works() { + let address = provision_test_servers().await; + let _first_stream = TcpStream::connect(address).await.unwrap(); + let second_stream = TcpStream::connect(address).await.unwrap(); + let message = AIServerQuery::from_queries(&[AIQuery::ListClients]); + let mut reader = BufReader::new(second_stream); + let response = get_server_response(&mut reader, message).await; + let inner = response.into_inner(); + + // only two clients are connected + assert!(inner.len() == 2) +} + // TODO: Same issues with random storekeys, changing the order of expected response #[tokio::test] async fn test_ai_store_no_original() { diff --git a/ahnlich/dsl/src/syntax/syntax.pest b/ahnlich/dsl/src/syntax/syntax.pest index 863089a5..de76b1f3 100644 --- a/ahnlich/dsl/src/syntax/syntax.pest +++ b/ahnlich/dsl/src/syntax/syntax.pest @@ -55,7 +55,7 @@ del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ in_ ai_del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ store_inputs ~ ")" ~ in_ignored ~ store_name } get_pred = { whitespace* ~ ^"getpred" ~ whitespace* ~ predicate_condition ~ in_ignored ~ store_name } // GETSIMN 2 WITH store-key USING algorithm IN store (WHERE predicate_condition) -get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? } +get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? ~ whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action } ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? } // CREATESTORE IF NOT EXISTS store-name DIMENSION non-zero-size PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree) create_store = { whitespace* ~ ^"createstore" ~ whitespace* ~ (if_not_exists)? ~ whitespace* ~ store_name ~ whitespace* ~ ^"dimension" ~ whitespace* ~ non_zero ~ whitespace* ~ (^"predicates" ~ whitespace* ~ "(" ~ whitespace* ~ metadata_keys ~ whitespace* ~ ")" )? ~ (whitespace* ~ ^"nonlinearalgorithmindex" ~ whitespace* ~ "(" ~ whitespace* ~ non_linear_algorithms ~ whitespace* ~ ")")? } From 94e4c4521460cc329dbae665b1fa785b434467ab Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Sat, 30 Nov 2024 01:25:02 +0100 Subject: [PATCH 03/10] Chore: Ran formatting --- ahnlich/ai/src/engine/store.rs | 1 - ahnlich/ai/src/server/task.rs | 23 ++++++++---------- ahnlich/ai/src/tests/aiproxy_test.rs | 32 +++++++++++++++---------- ahnlich/client/src/ai.rs | 5 ++-- ahnlich/dsl/src/ai.rs | 2 +- ahnlich/typegen/src/tracers/query/ai.rs | 2 +- 6 files changed, 35 insertions(+), 30 deletions(-) diff --git a/ahnlich/ai/src/engine/store.rs b/ahnlich/ai/src/engine/store.rs index f9569b78..d2f02215 100644 --- a/ahnlich/ai/src/engine/store.rs +++ b/ahnlich/ai/src/engine/store.rs @@ -328,7 +328,6 @@ impl AIStoreHandler { self.stores.clear(&guard); store_length } - } #[derive(Debug, Serialize, Deserialize)] diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 448357b8..8db29fbc 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,8 +1,6 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::db::DbClient; -use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult -}; +use ahnlich_types::ai::{AIQuery, AIServerQuery, AIServerResponse, AIServerResult}; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; use ahnlich_types::metadata::MetadataValue; @@ -375,16 +373,16 @@ impl AhnlichProtocol for AIProxyTask { let destoryed = self.store_handler.purge_stores(); Ok(AIServerResponse::Del(destoryed)) } - AIQuery::ListClients => Ok(AIServerResponse::ClientList(self.client_handler.list())), + AIQuery::ListClients => { + Ok(AIServerResponse::ClientList(self.client_handler.list())) + } AIQuery::GetKey { store, keys } => { - let metadata_values: HashSet = keys.into_iter().map( - |value| value.into() - ).collect(); - let get_key_condition = - PredicateCondition::Value(Predicate::In { - key: AHNLICH_AI_RESERVED_META_KEY.clone(), - value: metadata_values, - }); + let metadata_values: HashSet = + keys.into_iter().map(|value| value.into()).collect(); + let get_key_condition = PredicateCondition::Value(Predicate::In { + key: AHNLICH_AI_RESERVED_META_KEY.clone(), + value: metadata_values, + }); match self .db_client @@ -412,7 +410,6 @@ impl AhnlichProtocol for AIProxyTask { } } - impl AIProxyTask { #[tracing::instrument(skip(self))] fn server_info(&self) -> ServerInfo { diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index f556a76f..41c69276 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -4,14 +4,23 @@ use ahnlich_types::{ ai::{ AIModel, AIQuery, AIServerQuery, AIServerResponse, AIServerResult, AIStoreInfo, ImageAction, PreprocessAction, StringAction, - }, client::ConnectedClient, db::StoreUpsert, keyval::{StoreInput, StoreName, StoreValue}, metadata::{MetadataKey, MetadataValue}, predicate::{Predicate, PredicateCondition}, similarity::Algorithm + }, + db::StoreUpsert, + keyval::{StoreInput, StoreName, StoreValue}, + metadata::{MetadataKey, MetadataValue}, + predicate::{Predicate, PredicateCondition}, + similarity::Algorithm, }; // use flurry::HashMap; use utils::server::AhnlichServerUtils; use once_cell::sync::Lazy; use pretty_assertions::assert_eq; -use std::{collections::{HashMap, HashSet}, num::NonZeroUsize, sync::atomic::Ordering}; +use std::{ + collections::{HashMap, HashSet}, + num::NonZeroUsize, + sync::atomic::Ordering, +}; use crate::{ cli::{server::SupportedModels, AIProxyConfig}, @@ -158,7 +167,6 @@ async fn test_ai_proxy_create_store_success() { query_server_assert_result(&mut reader, message, expected.clone()).await; } - #[tokio::test] async fn test_ai_store_get_key_works() { let address = provision_test_servers().await; @@ -166,10 +174,8 @@ async fn test_ai_store_get_key_works() { let second_stream = TcpStream::connect(address).await.unwrap(); let store_name = StoreName(String::from("Deven Kicks")); let store_input = StoreInput::RawString(String::from("Jordan 3")); - let store_data: (StoreInput, HashMap) = ( - store_input.clone(), - HashMap::new() - ); + let store_data: (StoreInput, HashMap) = + (store_input.clone(), HashMap::new()); let message = AIServerQuery::from_queries(&[ AIQuery::CreateStore { @@ -190,21 +196,23 @@ async fn test_ai_store_get_key_works() { let mut reader = BufReader::new(first_stream); let _ = get_server_response(&mut reader, message).await; - let message = AIServerQuery::from_queries(&[AIQuery::GetKey { store: store_name, keys: vec![store_input.clone()] }]); + let message = AIServerQuery::from_queries(&[AIQuery::GetKey { + store: store_name, + keys: vec![store_input.clone()], + }]); let mut expected = AIServerResult::with_capacity(1); expected.push(Ok(AIServerResponse::Get(vec![( - Some(store_input), HashMap::new() + Some(store_input), + HashMap::new(), )]))); let mut reader = BufReader::new(second_stream); let response = get_server_response(&mut reader, message).await; assert!(response.len() == expected.len()) - } - #[tokio::test] async fn test_list_clients_works() { let address = provision_test_servers().await; @@ -214,7 +222,7 @@ async fn test_list_clients_works() { let mut reader = BufReader::new(second_stream); let response = get_server_response(&mut reader, message).await; let inner = response.into_inner(); - + // only two clients are connected assert!(inner.len() == 2) } diff --git a/ahnlich/client/src/ai.rs b/ahnlich/client/src/ai.rs index 864b2ec5..cdcf68c8 100644 --- a/ahnlich/client/src/ai.rs +++ b/ahnlich/client/src/ai.rs @@ -92,7 +92,7 @@ impl AIPipeline { condition, closest_n, algorithm, - preprocess_action + preprocess_action, }) } @@ -250,6 +250,7 @@ impl AIClient { .await } + #[allow(clippy::too_many_arguments)] pub async fn get_sim_n( &self, store: StoreName, @@ -267,7 +268,7 @@ impl AIClient { condition, closest_n, algorithm, - preprocess_action + preprocess_action, }, tracing_id, ) diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index fb2bce34..bf12ddee 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -202,7 +202,7 @@ pub fn parse_ai_query(input: &str) -> Result, DslError> { closest_n, algorithm, condition, - preprocess_action + preprocess_action, } } Rule::get_pred => { diff --git a/ahnlich/typegen/src/tracers/query/ai.rs b/ahnlich/typegen/src/tracers/query/ai.rs index 88357427..bd7f5121 100644 --- a/ahnlich/typegen/src/tracers/query/ai.rs +++ b/ahnlich/typegen/src/tracers/query/ai.rs @@ -66,7 +66,7 @@ pub fn trace_ai_query_enum() -> Registry { condition: Some(test_predicate_condition.clone()), closest_n: NonZeroUsize::new(4).unwrap(), algorithm: Algorithm::CosineSimilarity, - preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), + preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), }; let create_index = AIQuery::CreatePredIndex { From c2b26a1c0aaae1a21826bb4a96d2dae0d738681f Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Sat, 30 Nov 2024 01:34:09 +0100 Subject: [PATCH 04/10] Fix: Usage of Preprocess action removed --- ahnlich/ai/src/server/task.rs | 14 ++++++++------ ahnlich/ai/src/tests/aiproxy_test.rs | 4 ++-- ahnlich/dsl/src/tests/ai.rs | 4 ++-- ahnlich/typegen/src/tracers/query/ai.rs | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 53704b6b..17db43c7 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,7 +1,7 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::{builders::db as db_params, db::DbClient}; use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction, + AIQuery, AIServerQuery, AIServerResponse, AIServerResult, }; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; @@ -410,11 +410,13 @@ impl AhnlichProtocol for AIProxyTask { value: metadata_values, }); - match self - .db_client - .get_pred(store, get_key_condition, parent_id.clone()) - .await - { + let get_pred_params = db_params::GetPredParams::builder() + .store(store.to_string()) + .condition(get_key_condition) + .tracing_id(parent_id.clone()) + .build(); + + match self.db_client.get_pred(get_pred_params).await { Ok(res) => { if let ServerResponse::Get(response) = res { // conversion to store input here diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index 8e3bc2a0..27fb5fde 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -190,7 +190,7 @@ async fn test_ai_store_get_key_works() { AIQuery::Set { store: store_name.clone(), inputs: vec![store_data.clone()], - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, ]); let mut reader = BufReader::new(first_stream); @@ -440,7 +440,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() { condition: None, closest_n: NonZeroUsize::new(1).unwrap(), algorithm: Algorithm::DotProductSimilarity, - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::ModelPreprocessing, }]); let mut expected = AIServerResult::with_capacity(1); diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index f15b6ec4..d0209f70 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -208,7 +208,7 @@ fn test_get_sim_n_parse() { closest_n: NonZeroUsize::new(5).unwrap(), algorithm: Algorithm::CosineSimilarity, condition: None, - preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), + preprocess_action: PreprocessAction::ModelPreprocessing, }] ); let input = r#"GETSIMN 8 with [testing the limits of life] using euclideandistance in other where ((year != 2012) AND (month not in (december, october)))"#; @@ -232,7 +232,7 @@ fn test_get_sim_n_parse() { ]), })) ), - preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), + preprocess_action: PreprocessAction::ModelPreprocessing, }] ); } diff --git a/ahnlich/typegen/src/tracers/query/ai.rs b/ahnlich/typegen/src/tracers/query/ai.rs index cab245dd..22cc2969 100644 --- a/ahnlich/typegen/src/tracers/query/ai.rs +++ b/ahnlich/typegen/src/tracers/query/ai.rs @@ -66,7 +66,7 @@ pub fn trace_ai_query_enum() -> Registry { condition: Some(test_predicate_condition.clone()), closest_n: NonZeroUsize::new(4).unwrap(), algorithm: Algorithm::CosineSimilarity, - preprocess_action: PreprocessAction::RawString(StringAction::TruncateIfTokensExceed), + preprocess_action: PreprocessAction::ModelPreprocessing, }; let create_index = AIQuery::CreatePredIndex { From 79cfcd46fd60c7bfc8144d0845c9dbd5828e1870 Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Sat, 30 Nov 2024 02:44:50 +0100 Subject: [PATCH 05/10] Chore: Ran typegen for clients --- .../ahnlich_client_py/internals/ai_query.py | 67 ++++++++++--------- .../internals/ai_response.py | 56 ++++++---------- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 29 +++----- .../internals/db_response.py | 41 +++++------- .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- type_specs/query/ai_query.json | 30 ++++++++- 8 files changed, 115 insertions(+), 119 deletions(-) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index 368bddbb..88f753a2 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,7 +11,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -61,7 +59,6 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass - AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -80,7 +77,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIQuery": + def bincode_deserialize(input: bytes) -> 'AIQuery': v, buffer = bincode.deserialize(input, AIQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -114,6 +111,7 @@ class AIQuery__GetSimN(AIQuery): condition: typing.Optional["PredicateCondition"] closest_n: st.uint64 algorithm: "Algorithm" + preprocess_action: "PreprocessAction" @dataclass(frozen=True) @@ -150,9 +148,7 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[ - typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] - ] + inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] preprocess_action: "PreprocessAction" @@ -171,29 +167,41 @@ class AIQuery__DropStore(AIQuery): @dataclass(frozen=True) -class AIQuery__InfoServer(AIQuery): +class AIQuery__GetKey(AIQuery): INDEX = 10 # type: int - pass + store: str + keys: typing.Sequence["StoreInput"] @dataclass(frozen=True) -class AIQuery__ListStores(AIQuery): +class AIQuery__InfoServer(AIQuery): INDEX = 11 # type: int pass @dataclass(frozen=True) -class AIQuery__PurgeStores(AIQuery): +class AIQuery__ListClients(AIQuery): INDEX = 12 # type: int pass @dataclass(frozen=True) -class AIQuery__Ping(AIQuery): +class AIQuery__ListStores(AIQuery): INDEX = 13 # type: int pass +@dataclass(frozen=True) +class AIQuery__PurgeStores(AIQuery): + INDEX = 14 # type: int + pass + + +@dataclass(frozen=True) +class AIQuery__Ping(AIQuery): + INDEX = 15 # type: int + pass + AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -205,7 +213,9 @@ class AIQuery__Ping(AIQuery): AIQuery__Set, AIQuery__DelKey, AIQuery__DropStore, + AIQuery__GetKey, AIQuery__InfoServer, + AIQuery__ListClients, AIQuery__ListStores, AIQuery__PurgeStores, AIQuery__Ping, @@ -221,7 +231,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerQuery": + def bincode_deserialize(input: bytes) -> 'AIServerQuery': v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -235,7 +245,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -253,7 +263,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -267,7 +276,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -297,7 +306,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -313,7 +321,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -331,7 +339,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -345,7 +352,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -357,7 +364,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -370,7 +376,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -404,7 +410,6 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -420,7 +425,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -444,7 +449,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -459,7 +463,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> "PreprocessAction": + def bincode_deserialize(input: bytes) -> 'PreprocessAction': v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -477,7 +481,6 @@ class PreprocessAction__ModelPreprocessing(PreprocessAction): INDEX = 1 # type: int pass - PreprocessAction.VARIANTS = [ PreprocessAction__NoPreprocessing, PreprocessAction__ModelPreprocessing, @@ -491,7 +494,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -509,8 +512,8 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, ] + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index e71b20ca..3312d280 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,7 +11,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -61,7 +59,6 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass - AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -80,7 +77,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResponse": + def bincode_deserialize(input: bytes) -> 'AIServerResponse': v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -126,21 +123,13 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[ - typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple[ - typing.Optional["StoreInput"], - typing.Dict[str, "MetadataValue"], - "Similarity", - ] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -154,7 +143,6 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 - AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -177,7 +165,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResult": + def bincode_deserialize(input: bytes) -> 'AIServerResult': v, buffer = bincode.deserialize(input, AIServerResult) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -195,7 +183,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInfo": + def bincode_deserialize(input: bytes) -> 'AIStoreInfo': v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -209,7 +197,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -227,7 +215,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -243,7 +230,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -257,7 +244,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -275,7 +262,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -289,7 +275,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -307,7 +293,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -326,7 +311,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -340,7 +325,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -358,7 +343,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -373,7 +357,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, Similarity) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -387,7 +371,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -405,7 +389,6 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -421,7 +404,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -437,7 +420,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -454,8 +437,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: raise st.DeserializationError("Some input bytes were not read") return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 38cbd7ff..4e5e0837 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import collections import dataclasses +import collections import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_binary as sb from ahnlich_client_py.internals import serde_types as st +from ahnlich_client_py.internals import serde_binary as sb # Maximum length in practice for sequences (e.g. in Java). MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index b281f346..30713685 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -13,7 +11,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -43,7 +41,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -62,7 +59,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -76,7 +73,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -94,7 +91,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -108,7 +104,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -120,7 +116,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -133,7 +128,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -167,7 +162,6 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -183,7 +177,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -207,7 +201,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -222,7 +215,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> "Query": + def bincode_deserialize(input: bytes) -> 'Query': v, buffer = bincode.deserialize(input, Query) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -344,7 +337,6 @@ class Query__Ping(Query): INDEX = 15 # type: int pass - Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -374,8 +366,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerQuery": + def bincode_deserialize(input: bytes) -> 'ServerQuery': v, buffer = bincode.deserialize(input, ServerQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index d1d0a6c4..f9826aff 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode @dataclass(frozen=True) class Array: @@ -16,7 +14,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -32,7 +30,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -46,7 +44,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -64,7 +62,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -78,7 +75,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -96,7 +93,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -115,7 +111,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -129,7 +125,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResponse": + def bincode_deserialize(input: bytes) -> 'ServerResponse': v, buffer = bincode.deserialize(input, ServerResponse) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -181,9 +177,7 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] - ] + value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -197,7 +191,6 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 - ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -220,7 +213,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResult": + def bincode_deserialize(input: bytes) -> 'ServerResult': v, buffer = bincode.deserialize(input, ServerResult) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -234,7 +227,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -252,7 +245,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -267,7 +259,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, Similarity) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -284,7 +276,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInfo": + def bincode_deserialize(input: bytes) -> 'StoreInfo': v, buffer = bincode.deserialize(input, StoreInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -300,7 +292,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -316,7 +308,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -333,8 +325,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: raise st.DeserializationError("Some input bytes were not read") return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index a71b03f5..0730bd23 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. """ -import collections import dataclasses +import collections import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 1c85909c..6d72f027 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,10 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import typing -from dataclasses import dataclass - import numpy as np +from dataclasses import dataclass +import typing class SerializationError(ValueError): diff --git a/type_specs/query/ai_query.json b/type_specs/query/ai_query.json index 8325591b..99adda0e 100644 --- a/type_specs/query/ai_query.json +++ b/type_specs/query/ai_query.json @@ -102,6 +102,11 @@ "algorithm": { "TYPENAME": "Algorithm" } + }, + { + "preprocess_action": { + "TYPENAME": "PreprocessAction" + } } ] } @@ -232,15 +237,34 @@ } }, "10": { - "InfoServer": "UNIT" + "GetKey": { + "STRUCT": [ + { + "store": "STR" + }, + { + "keys": { + "SEQ": { + "TYPENAME": "StoreInput" + } + } + } + ] + } }, "11": { - "ListStores": "UNIT" + "InfoServer": "UNIT" }, "12": { - "PurgeStores": "UNIT" + "ListClients": "UNIT" }, "13": { + "ListStores": "UNIT" + }, + "14": { + "PurgeStores": "UNIT" + }, + "15": { "Ping": "UNIT" } } From f5f7a70992fb5597869ce940a8e3733937f60f70 Mon Sep 17 00:00:00 2001 From: lordsarcastic Date: Sat, 30 Nov 2024 22:05:25 +0100 Subject: [PATCH 06/10] Chore: Implemented tests for get_key and list_clients --- sdk/ahnlich-client-py/Makefile | 5 ++ .../ahnlich_client_py/builders/ai.py | 8 +++ .../ahnlich_client_py/clients/ai.py | 18 ++++++ .../clients/non_blocking/ai.py | 20 +++++++ .../ahnlich_client_py/internals/ai_query.py | 43 +++++++++----- .../internals/ai_response.py | 56 ++++++++++++------- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 29 ++++++---- .../internals/db_response.py | 41 ++++++++------ .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- .../test_ai_client_store_commands.py | 48 ++++++++++++++++ 12 files changed, 211 insertions(+), 68 deletions(-) create mode 100644 sdk/ahnlich-client-py/Makefile diff --git a/sdk/ahnlich-client-py/Makefile b/sdk/ahnlich-client-py/Makefile new file mode 100644 index 00000000..6510d334 --- /dev/null +++ b/sdk/ahnlich-client-py/Makefile @@ -0,0 +1,5 @@ +install: + @poetry install + +test: + @poetry run pytest . -s -vv diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py b/sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py index 4ed83756..b86afa24 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py @@ -52,6 +52,7 @@ def get_sim_n( closest_n: st.uint64 = 1, algorithm: ai_query.Algorithm = ai_query.Algorithm__CosineSimilarity, condition: typing.Optional[ai_query.PredicateCondition] = None, + preprocess_action: ai_query.PreprocessAction = ai_query.PreprocessAction__ModelPreprocessing, ): nonzero_n = NonZeroSizeInteger(closest_n) self.queries.append( @@ -61,6 +62,7 @@ def get_sim_n( closest_n=nonzero_n.value, algorithm=algorithm, condition=condition, + preprocess_action=preprocess_action, ) ) @@ -127,6 +129,9 @@ def set( def del_key(self, store_name: str, key: ai_query.StoreInput): self.queries.append(ai_query.AIQuery__DelKey(store=store_name, key=key)) + def get_key(self, store_name: str, keys: typing.Sequence[ai_query.StoreInput]): + self.queries.append(ai_query.AIQuery__GetKey(store=store_name, keys=keys)) + def drop_store(self, store_name: str, error_if_not_exists: bool = True): self.queries.append( ai_query.AIQuery__DropStore( @@ -143,6 +148,9 @@ def info_server(self): def list_stores(self): self.queries.append(ai_query.AIQuery__ListStores()) + def list_clients(self): + self.queries.append(ai_query.AIQuery__ListClients()) + def ping(self): self.queries.append(ai_query.AIQuery__Ping()) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py b/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py index 765627b8..20ad31cd 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py @@ -154,6 +154,16 @@ def del_key( builder.del_key(store_name=store_name, key=key) return self.process_request(builder.to_server_query()) + def get_key( + self, + store_name: str, + keys: typing.Sequence[ai_query.StoreInput], + tracing_id: typing.Optional[str] = None, + ): + builder = builders.AhnlichAIRequestBuilder(tracing_id) + builder.get_key(store_name=store_name, keys=keys) + return self.process_request(builder.to_server_query()) + def drop_store( self, store_name: str, @@ -189,6 +199,14 @@ def list_stores( builder = builders.AhnlichAIRequestBuilder(tracing_id) builder.list_stores() return self.process_request(builder.to_server_query()) + + def list_clients( + self, + tracing_id: typing.Optional[str] = None, + ): + builder = builders.AhnlichAIRequestBuilder(tracing_id) + builder.list_clients() + return self.process_request(builder.to_server_query()) def ping( self, diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/clients/non_blocking/ai.py b/sdk/ahnlich-client-py/ahnlich_client_py/clients/non_blocking/ai.py index 9168791f..44d0eb7f 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/clients/non_blocking/ai.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/clients/non_blocking/ai.py @@ -65,6 +65,7 @@ async def get_sim_n( closest_n: st.uint64 = 1, algorithm: ai_query.Algorithm = ai_query.Algorithm__CosineSimilarity, condition: typing.Optional[ai_query.PredicateCondition] = None, + preprocess_action: ai_query.PreprocessAction = ai_query.PreprocessAction__ModelPreprocessing, tracing_id: typing.Optional[str] = None, ): builder = AsyncAhnlichAIRequestBuilder(tracing_id) @@ -74,6 +75,7 @@ async def get_sim_n( closest_n=closest_n, algorithm=algorithm, condition=condition, + preprocess_action=preprocess_action, ) return await self.process_request(builder.to_server_query()) @@ -154,6 +156,16 @@ async def del_key( builder.del_key(store_name=store_name, key=key) return await self.process_request(builder.to_server_query()) + async def get_key( + self, + store_name: str, + keys: typing.Sequence[ai_query.StoreInput], + tracing_id: typing.Optional[str] = None, + ): + builder = AsyncAhnlichAIRequestBuilder(tracing_id) + builder.get_key(store_name=store_name, keys=keys) + return await self.process_request(builder.to_server_query()) + async def drop_store( self, store_name: str, @@ -190,6 +202,14 @@ async def list_stores( builder.list_stores() return await self.process_request(builder.to_server_query()) + async def list_clients( + self, + tracing_id: typing.Optional[str] = None, + ): + builder = AsyncAhnlichAIRequestBuilder(tracing_id) + builder.list_clients() + return await self.process_request(builder.to_server_query()) + async def ping( self, tracing_id: typing.Optional[str] = None, diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index 88f753a2..40326f01 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,7 +13,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -59,6 +61,7 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -77,7 +80,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIQuery': + def bincode_deserialize(input: bytes) -> "AIQuery": v, buffer = bincode.deserialize(input, AIQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -148,7 +151,9 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] + inputs: typing.Sequence[ + typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] + ] preprocess_action: "PreprocessAction" @@ -202,6 +207,7 @@ class AIQuery__Ping(AIQuery): INDEX = 15 # type: int pass + AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -231,7 +237,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerQuery': + def bincode_deserialize(input: bytes) -> "AIServerQuery": v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -245,7 +251,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -263,6 +269,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -276,7 +283,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -306,6 +313,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -321,7 +329,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -339,6 +347,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -352,7 +361,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -364,6 +373,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -376,7 +386,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -410,6 +420,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -425,7 +436,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -449,6 +460,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -463,7 +475,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> 'PreprocessAction': + def bincode_deserialize(input: bytes) -> "PreprocessAction": v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -481,6 +493,7 @@ class PreprocessAction__ModelPreprocessing(PreprocessAction): INDEX = 1 # type: int pass + PreprocessAction.VARIANTS = [ PreprocessAction__NoPreprocessing, PreprocessAction__ModelPreprocessing, @@ -494,7 +507,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = bincode.deserialize(input, StoreInput) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -512,8 +525,8 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, ] - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 3312d280..e71b20ca 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,7 +13,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -59,6 +61,7 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -77,7 +80,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResponse': + def bincode_deserialize(input: bytes) -> "AIServerResponse": v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -123,13 +126,21 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] + value: typing.Sequence[ + typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] + ] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple[ + typing.Optional["StoreInput"], + typing.Dict[str, "MetadataValue"], + "Similarity", + ] + ] @dataclass(frozen=True) @@ -143,6 +154,7 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 + AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -165,7 +177,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResult': + def bincode_deserialize(input: bytes) -> "AIServerResult": v, buffer = bincode.deserialize(input, AIServerResult) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -183,7 +195,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInfo': + def bincode_deserialize(input: bytes) -> "AIStoreInfo": v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -197,7 +209,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -215,6 +227,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -230,7 +243,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -244,7 +257,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -262,6 +275,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -275,7 +289,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -293,6 +307,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -311,7 +326,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -325,7 +340,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -343,6 +358,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -357,7 +373,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -371,7 +387,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = bincode.deserialize(input, StoreInput) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -389,6 +405,7 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -404,7 +421,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -420,7 +437,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -437,9 +454,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 4e5e0837..38cbd7ff 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import dataclasses import collections +import dataclasses import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_types as st from ahnlich_client_py.internals import serde_binary as sb +from ahnlich_client_py.internals import serde_types as st # Maximum length in practice for sequences (e.g. in Java). MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index 30713685..b281f346 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -11,7 +13,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -41,6 +43,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -59,7 +62,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, Array) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -73,7 +76,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -91,6 +94,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -104,7 +108,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -116,6 +120,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -128,7 +133,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -162,6 +167,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -177,7 +183,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -201,6 +207,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -215,7 +222,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> 'Query': + def bincode_deserialize(input: bytes) -> "Query": v, buffer = bincode.deserialize(input, Query) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -337,6 +344,7 @@ class Query__Ping(Query): INDEX = 15 # type: int pass + Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -366,9 +374,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerQuery': + def bincode_deserialize(input: bytes) -> "ServerQuery": v, buffer = bincode.deserialize(input, ServerQuery) if buffer: raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index f9826aff..d1d0a6c4 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + @dataclass(frozen=True) class Array: @@ -14,7 +16,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, Array) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -30,7 +32,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -44,7 +46,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -62,6 +64,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -75,7 +78,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -93,6 +96,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -111,7 +115,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -125,7 +129,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResponse': + def bincode_deserialize(input: bytes) -> "ServerResponse": v, buffer = bincode.deserialize(input, ServerResponse) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -177,7 +181,9 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] + ] @dataclass(frozen=True) @@ -191,6 +197,7 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 + ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -213,7 +220,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResult': + def bincode_deserialize(input: bytes) -> "ServerResult": v, buffer = bincode.deserialize(input, ServerResult) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -227,7 +234,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -245,6 +252,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -259,7 +267,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -276,7 +284,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInfo': + def bincode_deserialize(input: bytes) -> "StoreInfo": v, buffer = bincode.deserialize(input, StoreInfo) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -292,7 +300,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -308,7 +316,7 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: raise st.DeserializationError("Some input bytes were not read") @@ -325,9 +333,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index 0730bd23..a71b03f5 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. """ -import dataclasses import collections +import dataclasses import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 6d72f027..1c85909c 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import numpy as np -from dataclasses import dataclass import typing +from dataclasses import dataclass + +import numpy as np class SerializationError(ValueError): diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py index 4342f51b..9fe8c7d2 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py @@ -218,6 +218,38 @@ def test_ai_client_del_key(spin_up_ahnlich_ai): ai_client.cleanup() +def test_ai_client_get_key(spin_up_ahnlich_ai): + port = spin_up_ahnlich_ai + + ai_client = AhnlichAIClient(address="127.0.0.1", port=port) + store_inputs = [(ai_query.StoreInput__RawString("Jordan One"), {})] + + builder = ai_client.pipeline() + builder.create_store(**ai_store_payload_with_predicates) + builder.set( + store_name=ai_store_payload_with_predicates["store_name"], + inputs=store_inputs, + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), + ) + expected = ai_response.Result__Ok( + value=ai_response.AIServerResponse__Get([(ai_query.StoreInput__RawString(value="Jordan One"), {})]) + ) + + try: + builder.exec() + response = ai_client.get_key( + ai_store_payload_with_predicates["store_name"], + keys=[ai_query.StoreInput__RawString("Jordan One")], + ) + assert str(expected) == str(response.results[0]) + except Exception as e: + print(f"Exception: {e}") + ai_client.cleanup() + raise e + finally: + ai_client.cleanup() + + def test_ai_client_drop_store_succeeds(spin_up_ahnlich_ai): port = spin_up_ahnlich_ai @@ -272,3 +304,19 @@ def test_ai_client_purge_stores_succeeds(spin_up_ahnlich_ai): raise e finally: ai_client.cleanup() + + +def test_ai_client_list_clients_succeeds(spin_up_ahnlich_ai): + port = spin_up_ahnlich_ai + + ai_client = AhnlichAIClient(address="127.0.0.1", port=port) + + try: + response = ai_client.list_clients() + assert len(response.results) == 1 + except Exception as e: + print(f"Exception: {e}") + ai_client.cleanup() + raise e + finally: + ai_client.cleanup() From 59c9ca3af249b02d90f6904eb2ac3f62fb813405 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Mon, 2 Dec 2024 10:19:55 +0100 Subject: [PATCH 07/10] Condensing arguments for clippy sake --- ahnlich/similarity/src/kdtree.rs | 73 ++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/ahnlich/similarity/src/kdtree.rs b/ahnlich/similarity/src/kdtree.rs index 82ec0af1..42ddb7c5 100644 --- a/ahnlich/similarity/src/kdtree.rs +++ b/ahnlich/similarity/src/kdtree.rs @@ -218,6 +218,16 @@ impl<'de> Deserialize<'de> for KDTree { } } +struct NearestRecuriveArgs<'a> { + node: &'a Atomic, + reference_point: &'a Array1, + depth: usize, + n: NonZeroUsize, + guard: &'a Guard, + heap: &'a mut BinaryHeap>, + accept_list: &'a Option>, +} + impl KDTree { /// initialize KDTree with a specified nonzero dimension /// dimension: The dimension of the 1-D arrays to be inserted in the tree @@ -475,15 +485,15 @@ impl KDTree { if matches!(accept_list.as_ref(), Some(a) if a.is_empty()) { return Ok(vec![]); } - self.n_nearest_recursive( - &self.root, + self.n_nearest_recursive(NearestRecuriveArgs { + node: &self.root, reference_point, - 0, + depth: 0, n, - &guard, - &mut heap, - &accept_list, - ); + guard: &guard, + heap: &mut heap, + accept_list: &accept_list, + }); let mut results = Vec::with_capacity(n.get()); while let Some(Reverse(OrderedArray(val, distance))) = heap.pop() { results.push((val, distance)); @@ -506,17 +516,18 @@ impl KDTree { true } - #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all)] fn n_nearest_recursive( &self, - node: &Atomic, - reference_point: &Array1, - depth: usize, - n: NonZeroUsize, - guard: &Guard, - heap: &mut BinaryHeap>, - accept_list: &Option>, + NearestRecuriveArgs { + node, + reference_point, + depth, + n, + guard, + heap, + accept_list, + }: NearestRecuriveArgs, ) { if let Some(shared) = unsafe { node.load(Ordering::Acquire, guard).as_ref() } { let distance = self.squared_distance(reference_point, &shared.point); @@ -534,52 +545,52 @@ impl KDTree { let dim = depth % self.depth.get(); let go_left_first = reference_point[dim] < shared.point[dim]; if go_left_first { - self.n_nearest_recursive( - &shared.left, + self.n_nearest_recursive(NearestRecuriveArgs { + node: &shared.left, reference_point, - depth + 1, + depth: depth + 1, n, guard, heap, accept_list, - ); + }); if heap.len() < n.get() || (reference_point[dim] - shared.point[dim]).abs() < heap.peek().map_or(f32::INFINITY, |x| x.0 .1) { - self.n_nearest_recursive( - &shared.right, + self.n_nearest_recursive(NearestRecuriveArgs { + node: &shared.right, reference_point, - depth + 1, + depth: depth + 1, n, guard, heap, accept_list, - ); + }); } } else { - self.n_nearest_recursive( - &shared.right, + self.n_nearest_recursive(NearestRecuriveArgs { + node: &shared.right, reference_point, - depth + 1, + depth: depth + 1, n, guard, heap, accept_list, - ); + }); if heap.len() < n.get() || (reference_point[dim] - shared.point[dim]).abs() < heap.peek().map_or(f32::INFINITY, |x| x.0 .1) { - self.n_nearest_recursive( - &shared.left, + self.n_nearest_recursive(NearestRecuriveArgs { + node: &shared.left, reference_point, - depth + 1, + depth: depth + 1, n, guard, heap, accept_list, - ); + }); } } } From 8aa1334e4634f798cf4988f397ff849a32ff2852 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Mon, 2 Dec 2024 10:24:09 +0100 Subject: [PATCH 08/10] Adding predicate docs and fixing dsl for ai_get_sim_n --- ahnlich/ai/src/server/task.rs | 4 +--- ahnlich/client/src/ai.rs | 1 - ahnlich/db/src/engine/predicate.rs | 23 ++++++++++++++++++++ ahnlich/dsl/src/ai.rs | 34 ++++++++++++++++++------------ ahnlich/dsl/src/error.rs | 2 ++ ahnlich/dsl/src/syntax/syntax.pest | 5 +++-- ahnlich/dsl/src/tests/ai.rs | 4 ++-- 7 files changed, 51 insertions(+), 22 deletions(-) diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 17db43c7..c661b539 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,8 +1,6 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::{builders::db as db_params, db::DbClient}; -use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult, -}; +use ahnlich_types::ai::{AIQuery, AIServerQuery, AIServerResponse, AIServerResult}; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; use ahnlich_types::metadata::MetadataValue; diff --git a/ahnlich/client/src/ai.rs b/ahnlich/client/src/ai.rs index 116b1d63..50e98a65 100644 --- a/ahnlich/client/src/ai.rs +++ b/ahnlich/client/src/ai.rs @@ -224,7 +224,6 @@ impl AIClient { .await } - #[allow(clippy::too_many_arguments)] pub async fn get_sim_n( &self, params: ai_params::GetSimNParams, diff --git a/ahnlich/db/src/engine/predicate.rs b/ahnlich/db/src/engine/predicate.rs index ff2f25a7..d29e3d76 100644 --- a/ahnlich/db/src/engine/predicate.rs +++ b/ahnlich/db/src/engine/predicate.rs @@ -19,6 +19,29 @@ use std::collections::HashSet as StdHashSet; use std::mem::size_of_val; use utils::parallel; +/// Predicates are essentially nested hashmaps that let us retrieve original keys that match a +/// precise value. Take the following example +/// +/// { +/// "Country": { +/// "Nigeria": [StoreKeyId(1), StoreKeyId(2)], +/// "Australia": .., +/// }, +/// "Author": { +/// ... +/// } +/// } +/// +/// where `allowed_predicates` = ["Country", "Author"] +/// +/// It takes less time to retrieve "where country = 'Nigeria'" by traversing the nested hashmap to +/// obtain StoreKeyId(1) and StoreKeyId(2) than it would be to make a linear pass over an entire +/// Store of size N comparing their metadata "country" along the way. Given that StoreKeyId is +/// computed via blake hash, it is typically fast to compute and also of a fixed size which means +/// predicate indices don't balloon with large metadata +/// +/// Whichever key is not expressly included in `allowed_predicates` goes through the linear +/// pass in order to obtain keys that satisfy the condition type InnerPredicateIndexVal = ConcurrentHashSet; type InnerPredicateIndex = ConcurrentHashMap; type InnerPredicateIndices = ConcurrentHashMap; diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index 988b37ad..f7b27804 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -18,11 +18,11 @@ use pest::Parser; use crate::{error::DslError, predicate::parse_predicate_expression}; -fn parse_to_preprocess_action(input: &str) -> PreprocessAction { +fn parse_to_preprocess_action(input: &str) -> Result { match input.to_lowercase().trim() { - "nopreprocessing" => PreprocessAction::NoPreprocessing, - "modelpreprocessing" => PreprocessAction::ModelPreprocessing, - _ => panic!("Unexpected preprocess action"), + "nopreprocessing" => Ok(PreprocessAction::NoPreprocessing), + "modelpreprocessing" => Ok(PreprocessAction::ModelPreprocessing), + a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())), } } @@ -53,7 +53,7 @@ pub const COMMANDS: &[&str] = &[ "dropnonlinearalgorithmindex", // if exists (kdtree) in store_name "delkey", // ([input 1 text], [input 2 text]) in my_store "getpred", // ((author = dickens) or (country != Nigeria)) in my_store - "getsimn", // 4 with [random text inserted here] using cosinesimilarity in my_store where (author = dickens) + "getsimn", // 4 with [random text inserted here] using cosinesimilarity preprocessaction nopreprocessing in my_store where (author = dickens) "createstore", // if not exists my_store querymodel resnet-50 indexmodel resnet-50 predicates (author, country) nonlinearalgorithmindex (kdtree) "set", // (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in store ]; @@ -83,9 +83,9 @@ pub fn parse_ai_query(input: &str) -> Result, DslError> { let preprocess_action = parse_to_preprocess_action( inner_pairs .next() - .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? - .as_str(), - ); + .map(|a| a.as_str()) + .unwrap_or("nopreprocessing"), + )?; AIQuery::Set { store: StoreName(store.to_string()), @@ -175,6 +175,18 @@ pub fn parse_ai_query(input: &str) -> Result, DslError> { .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? .as_str(), )?; + let mut preprocess_action = PreprocessAction::NoPreprocessing; + if let Some(next_pair) = inner_pairs.peek() { + if next_pair.as_rule() == Rule::preprocess_optional { + let mut pair = inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? + .into_inner(); + preprocess_action = parse_to_preprocess_action( + pair.next().map(|a| a.as_str()).unwrap_or("nopreprocessing"), + )?; + } + }; let store = inner_pairs .next() .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? @@ -184,12 +196,6 @@ pub fn parse_ai_query(input: &str) -> Result, DslError> { } else { None }; - let preprocess_action = parse_to_preprocess_action( - inner_pairs - .next() - .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? - .as_str(), - ); AIQuery::GetSimN { store: StoreName(store.to_string()), search_input, diff --git a/ahnlich/dsl/src/error.rs b/ahnlich/dsl/src/error.rs index 94154506..6c654acd 100644 --- a/ahnlich/dsl/src/error.rs +++ b/ahnlich/dsl/src/error.rs @@ -19,4 +19,6 @@ pub enum DslError { UnsupportedAIModel(String), #[error("Unsupported rule used in parse fn {0:?}")] UnsupportedRule(Rule), + #[error("Unexpected preprocessing {0:?}")] + UnsupportedPreprocessingMode(String), } diff --git a/ahnlich/dsl/src/syntax/syntax.pest b/ahnlich/dsl/src/syntax/syntax.pest index ed8b081c..17f1f37a 100644 --- a/ahnlich/dsl/src/syntax/syntax.pest +++ b/ahnlich/dsl/src/syntax/syntax.pest @@ -55,8 +55,8 @@ del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ in_ ai_del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ store_inputs ~ ")" ~ in_ignored ~ store_name } get_pred = { whitespace* ~ ^"getpred" ~ whitespace* ~ predicate_condition ~ in_ignored ~ store_name } // GETSIMN 2 WITH store-key USING algorithm IN store (WHERE predicate_condition) -get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? ~ whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action } -ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? } +get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? } +ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ (preprocess_optional)? ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? } // CREATESTORE IF NOT EXISTS store-name DIMENSION non-zero-size PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree) create_store = { whitespace* ~ ^"createstore" ~ whitespace* ~ (if_not_exists)? ~ whitespace* ~ store_name ~ whitespace* ~ ^"dimension" ~ whitespace* ~ non_zero ~ whitespace* ~ (^"predicates" ~ whitespace* ~ "(" ~ whitespace* ~ metadata_keys ~ whitespace* ~ ")" )? ~ (whitespace* ~ ^"nonlinearalgorithmindex" ~ whitespace* ~ "(" ~ whitespace* ~ non_linear_algorithms ~ whitespace* ~ ")")? } // CREATESTORE IF NOT EXISTS store-name QUERYMODEL model INDEXMODEL model PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree) @@ -66,6 +66,7 @@ ai_set_in_store = { whitespace* ~ ^"set" ~ whitespace* ~ store_inputs_to_store_v if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* } if_not_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"not" ~ whitespace* ~ ^"exists" ~ whitespace* } +preprocess_optional = { whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action} store_original = { whitespace* ~ ^"storeoriginal" ~ whitespace* } // stores and predicates can be alphanumeric diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index d0209f70..1a971c6f 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -199,7 +199,7 @@ fn test_get_sim_n_parse() { panic!("Unexpected error pattern found") }; assert_eq!((start, end), (0, 68)); - let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity in random"#; + let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity preprocessaction MODELPREPROCESSING in random"#; assert_eq!( parse_ai_query(input).expect("Could not parse query input"), vec![AIQuery::GetSimN { @@ -232,7 +232,7 @@ fn test_get_sim_n_parse() { ]), })) ), - preprocess_action: PreprocessAction::ModelPreprocessing, + preprocess_action: PreprocessAction::NoPreprocessing, }] ); } From d2da7a72a01d22b8ba721ca4d9e463a4cc6627f7 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Mon, 2 Dec 2024 10:25:49 +0100 Subject: [PATCH 09/10] Fixing python lint --- sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py | 2 +- .../tests/ai_client/test_ai_client_store_commands.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py b/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py index 20ad31cd..bc426c85 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py @@ -199,7 +199,7 @@ def list_stores( builder = builders.AhnlichAIRequestBuilder(tracing_id) builder.list_stores() return self.process_request(builder.to_server_query()) - + def list_clients( self, tracing_id: typing.Optional[str] = None, diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py index 9fe8c7d2..6a633ac3 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py @@ -232,7 +232,9 @@ def test_ai_client_get_key(spin_up_ahnlich_ai): preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) expected = ai_response.Result__Ok( - value=ai_response.AIServerResponse__Get([(ai_query.StoreInput__RawString(value="Jordan One"), {})]) + value=ai_response.AIServerResponse__Get( + [(ai_query.StoreInput__RawString(value="Jordan One"), {})] + ) ) try: From aecf0170248538e5f8d853fa2a32534baaec8af6 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Mon, 2 Dec 2024 10:37:33 +0100 Subject: [PATCH 10/10] Fixing test for new client list --- ahnlich/ai/src/tests/aiproxy_test.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index 27fb5fde..47d93524 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -224,7 +224,14 @@ async fn test_list_clients_works() { let inner = response.into_inner(); // only two clients are connected - assert!(inner.len() == 2) + match inner.as_slice() { + [Ok(AIServerResponse::ClientList(connected_clients))] => { + assert!(connected_clients.len() == 2) + } + a => { + assert!(false, "Unexpected result for client list {:?}", a); + } + }; } // TODO: Same issues with random storekeys, changing the order of expected response