Commit
#105 AI command inconsistencies (#143)
lordsarcastic authored Dec 2, 2024
1 parent 5b34a26 commit e844425
Showing 19 changed files with 363 additions and 61 deletions.
47 changes: 36 additions & 11 deletions ahnlich/ai/src/server/task.rs
@@ -1,11 +1,8 @@
use crate::engine::ai::models::Model;
use ahnlich_client_rs::{builders::db as db_params, db::DbClient};
use ahnlich_types::ai::{
AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction,
};
use ahnlich_types::ai::{AIQuery, AIServerQuery, AIServerResponse, AIServerResult};
use ahnlich_types::client::ConnectedClient;
use ahnlich_types::db::{ServerInfo, ServerResponse};
use ahnlich_types::keyval::StoreInput;
use ahnlich_types::metadata::MetadataValue;
use ahnlich_types::predicate::{Predicate, PredicateCondition};
use ahnlich_types::version::VERSION;
@@ -345,20 +342,15 @@ impl AhnlichProtocol for AIProxyTask {
condition,
closest_n,
algorithm,
preprocess_action,
} => {
// TODO: Replace this with calls to self.model_manager.handle_request
// TODO (HAKSOAT): Shouldn't preprocess action also be in the params?
let preprocess = match search_input {
StoreInput::RawString(_) => PreprocessAction::ModelPreprocessing,
StoreInput::Image(_) => PreprocessAction::ModelPreprocessing,
};
let repr = self
.store_handler
.get_ndarray_repr_for_store(
&store,
search_input,
&self.model_manager,
preprocess,
preprocess_action,
)
.await;
match repr {
@@ -405,6 +397,39 @@ impl AhnlichProtocol for AIProxyTask {
let destoryed = self.store_handler.purge_stores();
Ok(AIServerResponse::Del(destoryed))
}
AIQuery::ListClients => {
Ok(AIServerResponse::ClientList(self.client_handler.list()))
}
AIQuery::GetKey { store, keys } => {
let metadata_values: HashSet<MetadataValue> =
keys.into_iter().map(|value| value.into()).collect();
let get_key_condition = PredicateCondition::Value(Predicate::In {
key: AHNLICH_AI_RESERVED_META_KEY.clone(),
value: metadata_values,
});

let get_pred_params = db_params::GetPredParams::builder()
.store(store.to_string())
.condition(get_key_condition)
.tracing_id(parent_id.clone())
.build();

match self.db_client.get_pred(get_pred_params).await {
Ok(res) => {
if let ServerResponse::Get(response) = res {
// conversion to store input here
let output = self
.store_handler
.store_key_val_to_store_input_val(response);
Ok(AIServerResponse::Get(output))
} else {
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
.to_string())
}
}
Err(err) => Err(format!("{err}")),
}
}
})
}
result
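
With the preprocess action now carried on the query itself, a GetSimN request is built explicitly by the caller rather than inferred from the input type. A minimal sketch (store name, input text, and surrounding setup are illustrative; the types and variants are the ones used elsewhere in this diff, with imports mirroring the test module below):

// Sketch only: builds the updated GetSimN variant with an explicit
// preprocess_action instead of the proxy choosing one from the StoreInput.
// "my_store" and the input text are placeholder values.
let message = AIServerQuery::from_queries(&[AIQuery::GetSimN {
    store: StoreName("my_store".to_string()),
    search_input: StoreInput::RawString("example query".to_string()),
    condition: None,
    closest_n: NonZeroUsize::new(3).unwrap(),
    algorithm: Algorithm::CosineSimilarity,
    preprocess_action: PreprocessAction::ModelPreprocessing,
}]);
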
75 changes: 74 additions & 1 deletion ahnlich/ai/src/tests/aiproxy_test.rs
@@ -11,11 +11,16 @@ use ahnlich_types::{
predicate::{Predicate, PredicateCondition},
similarity::Algorithm,
};
// use flurry::HashMap;
use utils::server::AhnlichServerUtils;

use once_cell::sync::Lazy;
use pretty_assertions::assert_eq;
use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering};
use std::{
collections::{HashMap, HashSet},
num::NonZeroUsize,
sync::atomic::Ordering,
};

use crate::{
cli::{server::SupportedModels, AIProxyConfig},
@@ -162,6 +167,73 @@ async fn test_ai_proxy_create_store_success() {
query_server_assert_result(&mut reader, message, expected.clone()).await;
}

#[tokio::test]
async fn test_ai_store_get_key_works() {
let address = provision_test_servers().await;
let first_stream = TcpStream::connect(address).await.unwrap();
let second_stream = TcpStream::connect(address).await.unwrap();
let store_name = StoreName(String::from("Deven Kicks"));
let store_input = StoreInput::RawString(String::from("Jordan 3"));
let store_data: (StoreInput, HashMap<MetadataKey, MetadataValue>) =
(store_input.clone(), HashMap::new());

let message = AIServerQuery::from_queries(&[
AIQuery::CreateStore {
store: store_name.clone(),
query_model: AIModel::AllMiniLML6V2,
index_model: AIModel::AllMiniLML6V2,
predicates: HashSet::new(),
non_linear_indices: HashSet::new(),
error_if_exists: true,
store_original: false,
},
AIQuery::Set {
store: store_name.clone(),
inputs: vec![store_data.clone()],
preprocess_action: PreprocessAction::NoPreprocessing,
},
]);
let mut reader = BufReader::new(first_stream);

let _ = get_server_response(&mut reader, message).await;
let message = AIServerQuery::from_queries(&[AIQuery::GetKey {
store: store_name,
keys: vec![store_input.clone()],
}]);

let mut expected = AIServerResult::with_capacity(1);

expected.push(Ok(AIServerResponse::Get(vec![(
Some(store_input),
HashMap::new(),
)])));

let mut reader = BufReader::new(second_stream);
let response = get_server_response(&mut reader, message).await;
assert!(response.len() == expected.len())
}

#[tokio::test]
async fn test_list_clients_works() {
let address = provision_test_servers().await;
let _first_stream = TcpStream::connect(address).await.unwrap();
let second_stream = TcpStream::connect(address).await.unwrap();
let message = AIServerQuery::from_queries(&[AIQuery::ListClients]);
let mut reader = BufReader::new(second_stream);
let response = get_server_response(&mut reader, message).await;
let inner = response.into_inner();

// only two clients are connected
match inner.as_slice() {
[Ok(AIServerResponse::ClientList(connected_clients))] => {
assert!(connected_clients.len() == 2)
}
a => {
assert!(false, "Unexpected result for client list {:?}", a);
}
};
}

// TODO: Same issues with random storekeys, changing the order of expected response
#[tokio::test]
async fn test_ai_store_no_original() {
@@ -375,6 +447,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() {
condition: None,
closest_n: NonZeroUsize::new(1).unwrap(),
algorithm: Algorithm::DotProductSimilarity,
preprocess_action: PreprocessAction::ModelPreprocessing,
}]);

let mut expected = AIServerResult::with_capacity(1);
2 changes: 2 additions & 0 deletions ahnlich/client/src/ai.rs
@@ -76,6 +76,7 @@ impl AIPipeline {
condition: params.condition,
closest_n: params.closest_n,
algorithm: params.algorithm,
preprocess_action: params.preprocess_action,
})
}

@@ -234,6 +235,7 @@ impl AIClient {
condition: params.condition,
closest_n: params.closest_n,
algorithm: params.algorithm,
preprocess_action: params.preprocess_action,
},
params.tracing_id,
)
2 changes: 2 additions & 0 deletions ahnlich/client/src/builders/ai.rs
@@ -63,6 +63,8 @@ pub struct GetSimNParams {

#[builder(default = None)]
pub tracing_id: Option<String>,
#[builder(default = PreprocessAction::NoPreprocessing)]
pub preprocess_action: PreprocessAction,
}

#[derive(TypedBuilder)]
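
On the client side the new field is set through the GetSimNParams builder and defaults to NoPreprocessing, so existing callers keep compiling. A minimal sketch, assuming the params struct also exposes store and input fields (their setter names are not visible in this hunk and are assumptions):

// Sketch only: overrides the default preprocess action on a GetSimN request.
// NOTE: the `store` and `search_input` setter names are assumed for illustration.
let params = GetSimNParams::builder()
    .store("my_store".to_string())
    .search_input(StoreInput::RawString("example".to_string()))
    .closest_n(NonZeroUsize::new(3).unwrap())
    .algorithm(Algorithm::CosineSimilarity)
    .preprocess_action(PreprocessAction::ModelPreprocessing)
    .build();
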
23 changes: 23 additions & 0 deletions ahnlich/db/src/engine/predicate.rs
@@ -19,6 +19,29 @@ use std::collections::HashSet as StdHashSet;
use std::mem::size_of_val;
use utils::parallel;

/// Predicates are essentially nested hashmaps that let us retrieve original keys that match a
/// precise value. Take the following example
///
/// {
/// "Country": {
/// "Nigeria": [StoreKeyId(1), StoreKeyId(2)],
/// "Australia": ..,
/// },
/// "Author": {
/// ...
/// }
/// }
///
/// where `allowed_predicates` = ["Country", "Author"]
///
/// It takes less time to retrieve "where country = 'Nigeria'" by traversing the nested hashmap to
/// obtain StoreKeyId(1) and StoreKeyId(2) than to make a linear pass over an entire
/// Store of size N comparing their metadata "country" along the way. Given that StoreKeyId is
/// computed via blake hash, it is typically fast to compute and also of a fixed size which means
/// predicate indices don't balloon with large metadata
///
/// Whichever key is not expressly included in `allowed_predicates` goes through the linear
/// pass in order to obtain keys that satisfy the condition
type InnerPredicateIndexVal = ConcurrentHashSet<StoreKeyId>;
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, InnerPredicateIndexVal>;
type InnerPredicateIndices = ConcurrentHashMap<MetadataKey, PredicateIndex>;
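
The doc comment above describes the predicate index as nested maps from metadata key to metadata value to the set of matching store keys. A simplified sketch of that lookup, using plain std collections and illustrative type names in place of the concurrent maps and blake-hashed StoreKeyId the engine actually uses:

use std::collections::{HashMap, HashSet};

// Illustrative stand-ins for the engine's MetadataKey/MetadataValue/StoreKeyId.
type StoreKeyId = u64;
type PredicateIndex = HashMap<String, HashMap<String, HashSet<StoreKeyId>>>;

// "where country = 'Nigeria'" becomes two hash lookups instead of a linear
// scan over every entry in the store.
fn matching_keys<'a>(
    index: &'a PredicateIndex,
    key: &str,
    value: &str,
) -> Option<&'a HashSet<StoreKeyId>> {
    index.get(key).and_then(|by_value| by_value.get(value))
}
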
29 changes: 21 additions & 8 deletions ahnlich/dsl/src/ai.rs
@@ -18,11 +18,11 @@ use pest::Parser;

use crate::{error::DslError, predicate::parse_predicate_expression};

fn parse_to_preprocess_action(input: &str) -> PreprocessAction {
fn parse_to_preprocess_action(input: &str) -> Result<PreprocessAction, DslError> {
match input.to_lowercase().trim() {
"nopreprocessing" => PreprocessAction::NoPreprocessing,
"modelpreprocessing" => PreprocessAction::ModelPreprocessing,
_ => panic!("Unexpected preprocess action"),
"nopreprocessing" => Ok(PreprocessAction::NoPreprocessing),
"modelpreprocessing" => Ok(PreprocessAction::ModelPreprocessing),
a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())),
}
}

@@ -53,7 +53,7 @@ pub const COMMANDS: &[&str] = &[
"dropnonlinearalgorithmindex", // if exists (kdtree) in store_name
"delkey", // ([input 1 text], [input 2 text]) in my_store
"getpred", // ((author = dickens) or (country != Nigeria)) in my_store
"getsimn", // 4 with [random text inserted here] using cosinesimilarity in my_store where (author = dickens)
"getsimn", // 4 with [random text inserted here] using cosinesimilarity preprocessaction nopreprocessing in my_store where (author = dickens)
"createstore", // if not exists my_store querymodel resnet-50 indexmodel resnet-50 predicates (author, country) nonlinearalgorithmindex (kdtree)
"set", // (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in store
];
@@ -83,9 +83,9 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
let preprocess_action = parse_to_preprocess_action(
inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str(),
);
.map(|a| a.as_str())
.unwrap_or("nopreprocessing"),
)?;

AIQuery::Set {
store: StoreName(store.to_string()),
@@ -175,6 +175,18 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str(),
)?;
let mut preprocess_action = PreprocessAction::NoPreprocessing;
if let Some(next_pair) = inner_pairs.peek() {
if next_pair.as_rule() == Rule::preprocess_optional {
let mut pair = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.into_inner();
preprocess_action = parse_to_preprocess_action(
pair.next().map(|a| a.as_str()).unwrap_or("nopreprocessing"),
)?;
}
};
let store = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
@@ -190,6 +202,7 @@
closest_n,
algorithm,
condition,
preprocess_action,
}
}
Rule::get_pred => {
2 changes: 2 additions & 0 deletions ahnlich/dsl/src/error.rs
@@ -19,4 +19,6 @@ pub enum DslError {
UnsupportedAIModel(String),
#[error("Unsupported rule used in parse fn {0:?}")]
UnsupportedRule(Rule),
#[error("Unexpected preprocessing {0:?}")]
UnsupportedPreprocessingMode(String),
}
3 changes: 2 additions & 1 deletion ahnlich/dsl/src/syntax/syntax.pest
@@ -56,7 +56,7 @@ ai_del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ store_inputs ~ ")"
get_pred = { whitespace* ~ ^"getpred" ~ whitespace* ~ predicate_condition ~ in_ignored ~ store_name }
// GETSIMN 2 WITH store-key USING algorithm IN store (WHERE predicate_condition)
get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ (preprocess_optional)? ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
// CREATESTORE IF NOT EXISTS store-name DIMENSION non-zero-size PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
create_store = { whitespace* ~ ^"createstore" ~ whitespace* ~ (if_not_exists)? ~ whitespace* ~ store_name ~ whitespace* ~ ^"dimension" ~ whitespace* ~ non_zero ~ whitespace* ~ (^"predicates" ~ whitespace* ~ "(" ~ whitespace* ~ metadata_keys ~ whitespace* ~ ")" )? ~ (whitespace* ~ ^"nonlinearalgorithmindex" ~ whitespace* ~ "(" ~ whitespace* ~ non_linear_algorithms ~ whitespace* ~ ")")? }
// CREATESTORE IF NOT EXISTS store-name QUERYMODEL model INDEXMODEL model PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
Expand All @@ -66,6 +66,7 @@ ai_set_in_store = { whitespace* ~ ^"set" ~ whitespace* ~ store_inputs_to_store_v

if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* }
if_not_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"not" ~ whitespace* ~ ^"exists" ~ whitespace* }
preprocess_optional = { whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action}
store_original = { whitespace* ~ ^"storeoriginal" ~ whitespace* }

// stores and predicates can be alphanumeric
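
With the optional preprocess clause in the grammar, both of the following queries parse; omitting the clause falls back to NoPreprocessing, as handled in parse_ai_query above. A small sketch (store name and input text are illustrative):

// Both forms are accepted by the updated ai_get_sim_n rule.
let without_clause =
    parse_ai_query("GETSIMN 4 WITH [hello there] USING cosinesimilarity IN my_store")
        .expect("Could not parse query input");
let with_clause = parse_ai_query(
    "GETSIMN 4 WITH [hello there] USING cosinesimilarity PREPROCESSACTION modelpreprocessing IN my_store",
)
.expect("Could not parse query input");
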
6 changes: 4 additions & 2 deletions ahnlich/dsl/src/tests/ai.rs
@@ -199,15 +199,16 @@ fn test_get_sim_n_parse() {
panic!("Unexpected error pattern found")
};
assert_eq!((start, end), (0, 68));
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity in random"#;
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity preprocessaction MODELPREPROCESSING in random"#;
assert_eq!(
parse_ai_query(input).expect("Could not parse query input"),
vec![AIQuery::GetSimN {
store: StoreName("random".to_string()),
search_input: StoreInput::RawString("hi my name is carter".to_string()),
closest_n: NonZeroUsize::new(5).unwrap(),
algorithm: Algorithm::CosineSimilarity,
condition: None
condition: None,
preprocess_action: PreprocessAction::ModelPreprocessing,
}]
);
let input = r#"GETSIMN 8 with [testing the limits of life] using euclideandistance in other where ((year != 2012) AND (month not in (december, october)))"#;
Expand All @@ -231,6 +232,7 @@ fn test_get_sim_n_parse() {
]),
}))
),
preprocess_action: PreprocessAction::NoPreprocessing,
}]
);
}