Skip to content

Commit e844425

Browse files
#105 AI command inconsistencies (#143)
1 parent 5b34a26 commit e844425

File tree

19 files changed

+363
-61
lines changed

19 files changed

+363
-61
lines changed

ahnlich/ai/src/server/task.rs

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
use crate::engine::ai::models::Model;
22
use ahnlich_client_rs::{builders::db as db_params, db::DbClient};
3-
use ahnlich_types::ai::{
4-
AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction,
5-
};
3+
use ahnlich_types::ai::{AIQuery, AIServerQuery, AIServerResponse, AIServerResult};
64
use ahnlich_types::client::ConnectedClient;
75
use ahnlich_types::db::{ServerInfo, ServerResponse};
8-
use ahnlich_types::keyval::StoreInput;
96
use ahnlich_types::metadata::MetadataValue;
107
use ahnlich_types::predicate::{Predicate, PredicateCondition};
118
use ahnlich_types::version::VERSION;
@@ -345,20 +342,15 @@ impl AhnlichProtocol for AIProxyTask {
345342
condition,
346343
closest_n,
347344
algorithm,
345+
preprocess_action,
348346
} => {
349-
// TODO: Replace this with calls to self.model_manager.handle_request
350-
// TODO (HAKSOAT): Shouldn't preprocess action also be in the params?
351-
let preprocess = match search_input {
352-
StoreInput::RawString(_) => PreprocessAction::ModelPreprocessing,
353-
StoreInput::Image(_) => PreprocessAction::ModelPreprocessing,
354-
};
355347
let repr = self
356348
.store_handler
357349
.get_ndarray_repr_for_store(
358350
&store,
359351
search_input,
360352
&self.model_manager,
361-
preprocess,
353+
preprocess_action,
362354
)
363355
.await;
364356
match repr {
@@ -405,6 +397,39 @@ impl AhnlichProtocol for AIProxyTask {
405397
let destoryed = self.store_handler.purge_stores();
406398
Ok(AIServerResponse::Del(destoryed))
407399
}
400+
AIQuery::ListClients => {
401+
Ok(AIServerResponse::ClientList(self.client_handler.list()))
402+
}
403+
AIQuery::GetKey { store, keys } => {
404+
let metadata_values: HashSet<MetadataValue> =
405+
keys.into_iter().map(|value| value.into()).collect();
406+
let get_key_condition = PredicateCondition::Value(Predicate::In {
407+
key: AHNLICH_AI_RESERVED_META_KEY.clone(),
408+
value: metadata_values,
409+
});
410+
411+
let get_pred_params = db_params::GetPredParams::builder()
412+
.store(store.to_string())
413+
.condition(get_key_condition)
414+
.tracing_id(parent_id.clone())
415+
.build();
416+
417+
match self.db_client.get_pred(get_pred_params).await {
418+
Ok(res) => {
419+
if let ServerResponse::Get(response) = res {
420+
// conversion to store input here
421+
let output = self
422+
.store_handler
423+
.store_key_val_to_store_input_val(response);
424+
Ok(AIServerResponse::Get(output))
425+
} else {
426+
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
427+
.to_string())
428+
}
429+
}
430+
Err(err) => Err(format!("{err}")),
431+
}
432+
}
408433
})
409434
}
410435
result

ahnlich/ai/src/tests/aiproxy_test.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@ use ahnlich_types::{
1111
predicate::{Predicate, PredicateCondition},
1212
similarity::Algorithm,
1313
};
14+
// use flurry::HashMap;
1415
use utils::server::AhnlichServerUtils;
1516

1617
use once_cell::sync::Lazy;
1718
use pretty_assertions::assert_eq;
18-
use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering};
19+
use std::{
20+
collections::{HashMap, HashSet},
21+
num::NonZeroUsize,
22+
sync::atomic::Ordering,
23+
};
1924

2025
use crate::{
2126
cli::{server::SupportedModels, AIProxyConfig},
@@ -162,6 +167,73 @@ async fn test_ai_proxy_create_store_success() {
162167
query_server_assert_result(&mut reader, message, expected.clone()).await;
163168
}
164169

170+
#[tokio::test]
171+
async fn test_ai_store_get_key_works() {
172+
let address = provision_test_servers().await;
173+
let first_stream = TcpStream::connect(address).await.unwrap();
174+
let second_stream = TcpStream::connect(address).await.unwrap();
175+
let store_name = StoreName(String::from("Deven Kicks"));
176+
let store_input = StoreInput::RawString(String::from("Jordan 3"));
177+
let store_data: (StoreInput, HashMap<MetadataKey, MetadataValue>) =
178+
(store_input.clone(), HashMap::new());
179+
180+
let message = AIServerQuery::from_queries(&[
181+
AIQuery::CreateStore {
182+
store: store_name.clone(),
183+
query_model: AIModel::AllMiniLML6V2,
184+
index_model: AIModel::AllMiniLML6V2,
185+
predicates: HashSet::new(),
186+
non_linear_indices: HashSet::new(),
187+
error_if_exists: true,
188+
store_original: false,
189+
},
190+
AIQuery::Set {
191+
store: store_name.clone(),
192+
inputs: vec![store_data.clone()],
193+
preprocess_action: PreprocessAction::NoPreprocessing,
194+
},
195+
]);
196+
let mut reader = BufReader::new(first_stream);
197+
198+
let _ = get_server_response(&mut reader, message).await;
199+
let message = AIServerQuery::from_queries(&[AIQuery::GetKey {
200+
store: store_name,
201+
keys: vec![store_input.clone()],
202+
}]);
203+
204+
let mut expected = AIServerResult::with_capacity(1);
205+
206+
expected.push(Ok(AIServerResponse::Get(vec![(
207+
Some(store_input),
208+
HashMap::new(),
209+
)])));
210+
211+
let mut reader = BufReader::new(second_stream);
212+
let response = get_server_response(&mut reader, message).await;
213+
assert!(response.len() == expected.len())
214+
}
215+
216+
#[tokio::test]
217+
async fn test_list_clients_works() {
218+
let address = provision_test_servers().await;
219+
let _first_stream = TcpStream::connect(address).await.unwrap();
220+
let second_stream = TcpStream::connect(address).await.unwrap();
221+
let message = AIServerQuery::from_queries(&[AIQuery::ListClients]);
222+
let mut reader = BufReader::new(second_stream);
223+
let response = get_server_response(&mut reader, message).await;
224+
let inner = response.into_inner();
225+
226+
// only two clients are connected
227+
match inner.as_slice() {
228+
[Ok(AIServerResponse::ClientList(connected_clients))] => {
229+
assert!(connected_clients.len() == 2)
230+
}
231+
a => {
232+
assert!(false, "Unexpected result for client list {:?}", a);
233+
}
234+
};
235+
}
236+
165237
// TODO: Same issues with random storekeys, changing the order of expected response
166238
#[tokio::test]
167239
async fn test_ai_store_no_original() {
@@ -375,6 +447,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() {
375447
condition: None,
376448
closest_n: NonZeroUsize::new(1).unwrap(),
377449
algorithm: Algorithm::DotProductSimilarity,
450+
preprocess_action: PreprocessAction::ModelPreprocessing,
378451
}]);
379452

380453
let mut expected = AIServerResult::with_capacity(1);

ahnlich/client/src/ai.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ impl AIPipeline {
7676
condition: params.condition,
7777
closest_n: params.closest_n,
7878
algorithm: params.algorithm,
79+
preprocess_action: params.preprocess_action,
7980
})
8081
}
8182

@@ -234,6 +235,7 @@ impl AIClient {
234235
condition: params.condition,
235236
closest_n: params.closest_n,
236237
algorithm: params.algorithm,
238+
preprocess_action: params.preprocess_action,
237239
},
238240
params.tracing_id,
239241
)

ahnlich/client/src/builders/ai.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ pub struct GetSimNParams {
6363

6464
#[builder(default = None)]
6565
pub tracing_id: Option<String>,
66+
#[builder(default = PreprocessAction::NoPreprocessing)]
67+
pub preprocess_action: PreprocessAction,
6668
}
6769

6870
#[derive(TypedBuilder)]

ahnlich/db/src/engine/predicate.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,29 @@ use std::collections::HashSet as StdHashSet;
1919
use std::mem::size_of_val;
2020
use utils::parallel;
2121

22+
/// Predicates are essentially nested hashmaps that let us retrieve original keys that match a
23+
/// precise value. Take the following example
24+
///
25+
/// {
26+
/// "Country": {
27+
/// "Nigeria": [StoreKeyId(1), StoreKeyId(2)],
28+
/// "Australia": ..,
29+
/// },
30+
/// "Author": {
31+
/// ...
32+
/// }
33+
/// }
34+
///
35+
/// where `allowed_predicates` = ["Country", "Author"]
36+
///
37+
/// It takes less time to retrieve "where country = 'Nigeria'" by traversing the nested hashmap to
38+
/// obtain StoreKeyId(1) and StoreKeyId(2) than it would be to make a linear pass over an entire
39+
/// Store of size N comparing their metadata "country" along the way. Given that StoreKeyId is
40+
/// computed via blake hash, it is typically fast to compute and also of a fixed size which means
41+
/// predicate indices don't balloon with large metadata
42+
///
43+
/// Whichever key is not expressly included in `allowed_predicates` goes through the linear
44+
/// pass in order to obtain keys that satisfy the condition
2245
type InnerPredicateIndexVal = ConcurrentHashSet<StoreKeyId>;
2346
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, InnerPredicateIndexVal>;
2447
type InnerPredicateIndices = ConcurrentHashMap<MetadataKey, PredicateIndex>;

ahnlich/dsl/src/ai.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ use pest::Parser;
1818

1919
use crate::{error::DslError, predicate::parse_predicate_expression};
2020

21-
fn parse_to_preprocess_action(input: &str) -> PreprocessAction {
21+
fn parse_to_preprocess_action(input: &str) -> Result<PreprocessAction, DslError> {
2222
match input.to_lowercase().trim() {
23-
"nopreprocessing" => PreprocessAction::NoPreprocessing,
24-
"modelpreprocessing" => PreprocessAction::ModelPreprocessing,
25-
_ => panic!("Unexpected preprocess action"),
23+
"nopreprocessing" => Ok(PreprocessAction::NoPreprocessing),
24+
"modelpreprocessing" => Ok(PreprocessAction::ModelPreprocessing),
25+
a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())),
2626
}
2727
}
2828

@@ -53,7 +53,7 @@ pub const COMMANDS: &[&str] = &[
5353
"dropnonlinearalgorithmindex", // if exists (kdtree) in store_name
5454
"delkey", // ([input 1 text], [input 2 text]) in my_store
5555
"getpred", // ((author = dickens) or (country != Nigeria)) in my_store
56-
"getsimn", // 4 with [random text inserted here] using cosinesimilarity in my_store where (author = dickens)
56+
"getsimn", // 4 with [random text inserted here] using cosinesimilarity preprocessaction nopreprocessing in my_store where (author = dickens)
5757
"createstore", // if not exists my_store querymodel resnet-50 indexmodel resnet-50 predicates (author, country) nonlinearalgorithmindex (kdtree)
5858
"set", // (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in store
5959
];
@@ -83,9 +83,9 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
8383
let preprocess_action = parse_to_preprocess_action(
8484
inner_pairs
8585
.next()
86-
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
87-
.as_str(),
88-
);
86+
.map(|a| a.as_str())
87+
.unwrap_or("nopreprocessing"),
88+
)?;
8989

9090
AIQuery::Set {
9191
store: StoreName(store.to_string()),
@@ -175,6 +175,18 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
175175
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
176176
.as_str(),
177177
)?;
178+
let mut preprocess_action = PreprocessAction::NoPreprocessing;
179+
if let Some(next_pair) = inner_pairs.peek() {
180+
if next_pair.as_rule() == Rule::preprocess_optional {
181+
let mut pair = inner_pairs
182+
.next()
183+
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
184+
.into_inner();
185+
preprocess_action = parse_to_preprocess_action(
186+
pair.next().map(|a| a.as_str()).unwrap_or("nopreprocessing"),
187+
)?;
188+
}
189+
};
178190
let store = inner_pairs
179191
.next()
180192
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
@@ -190,6 +202,7 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
190202
closest_n,
191203
algorithm,
192204
condition,
205+
preprocess_action,
193206
}
194207
}
195208
Rule::get_pred => {

ahnlich/dsl/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@ pub enum DslError {
1919
UnsupportedAIModel(String),
2020
#[error("Unsupported rule used in parse fn {0:?}")]
2121
UnsupportedRule(Rule),
22+
#[error("Unexpected preprocessing {0:?}")]
23+
UnsupportedPreprocessingMode(String),
2224
}

ahnlich/dsl/src/syntax/syntax.pest

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ ai_del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ store_inputs ~ ")"
5656
get_pred = { whitespace* ~ ^"getpred" ~ whitespace* ~ predicate_condition ~ in_ignored ~ store_name }
5757
// GETSIMN 2 WITH store-key USING algorithm IN store (WHERE predicate_condition)
5858
get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
59-
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
59+
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ (preprocess_optional)? ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
6060
// CREATESTORE IF NOT EXISTS store-name DIMENSION non-zero-size PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
6161
create_store = { whitespace* ~ ^"createstore" ~ whitespace* ~ (if_not_exists)? ~ whitespace* ~ store_name ~ whitespace* ~ ^"dimension" ~ whitespace* ~ non_zero ~ whitespace* ~ (^"predicates" ~ whitespace* ~ "(" ~ whitespace* ~ metadata_keys ~ whitespace* ~ ")" )? ~ (whitespace* ~ ^"nonlinearalgorithmindex" ~ whitespace* ~ "(" ~ whitespace* ~ non_linear_algorithms ~ whitespace* ~ ")")? }
6262
// CREATESTORE IF NOT EXISTS store-name QUERYMODEL model INDEXMODEL model PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
@@ -66,6 +66,7 @@ ai_set_in_store = { whitespace* ~ ^"set" ~ whitespace* ~ store_inputs_to_store_v
6666

6767
if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* }
6868
if_not_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"not" ~ whitespace* ~ ^"exists" ~ whitespace* }
69+
preprocess_optional = { whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action}
6970
store_original = { whitespace* ~ ^"storeoriginal" ~ whitespace* }
7071

7172
// stores and predicates can be alphanumeric

ahnlich/dsl/src/tests/ai.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,16 @@ fn test_get_sim_n_parse() {
199199
panic!("Unexpected error pattern found")
200200
};
201201
assert_eq!((start, end), (0, 68));
202-
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity in random"#;
202+
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity preprocessaction MODELPREPROCESSING in random"#;
203203
assert_eq!(
204204
parse_ai_query(input).expect("Could not parse query input"),
205205
vec![AIQuery::GetSimN {
206206
store: StoreName("random".to_string()),
207207
search_input: StoreInput::RawString("hi my name is carter".to_string()),
208208
closest_n: NonZeroUsize::new(5).unwrap(),
209209
algorithm: Algorithm::CosineSimilarity,
210-
condition: None
210+
condition: None,
211+
preprocess_action: PreprocessAction::ModelPreprocessing,
211212
}]
212213
);
213214
let input = r#"GETSIMN 8 with [testing the limits of life] using euclideandistance in other where ((year != 2012) AND (month not in (december, october)))"#;
@@ -231,6 +232,7 @@ fn test_get_sim_n_parse() {
231232
]),
232233
}))
233234
),
235+
preprocess_action: PreprocessAction::NoPreprocessing,
234236
}]
235237
);
236238
}

0 commit comments

Comments
 (0)