Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#105 AI command inconsistencies #143

Merged
merged 12 commits into from
Dec 2, 2024
Merged
47 changes: 36 additions & 11 deletions ahnlich/ai/src/server/task.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use crate::engine::ai::models::Model;
use ahnlich_client_rs::{builders::db as db_params, db::DbClient};
use ahnlich_types::ai::{
AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction,
};
use ahnlich_types::ai::{AIQuery, AIServerQuery, AIServerResponse, AIServerResult};
use ahnlich_types::client::ConnectedClient;
use ahnlich_types::db::{ServerInfo, ServerResponse};
use ahnlich_types::keyval::StoreInput;
use ahnlich_types::metadata::MetadataValue;
use ahnlich_types::predicate::{Predicate, PredicateCondition};
use ahnlich_types::version::VERSION;
Expand Down Expand Up @@ -345,20 +342,15 @@ impl AhnlichProtocol for AIProxyTask {
condition,
closest_n,
algorithm,
preprocess_action,
} => {
// TODO: Replace this with calls to self.model_manager.handle_request
// TODO (HAKSOAT): Shouldn't preprocess action also be in the params?
let preprocess = match search_input {
StoreInput::RawString(_) => PreprocessAction::ModelPreprocessing,
StoreInput::Image(_) => PreprocessAction::ModelPreprocessing,
};
let repr = self
.store_handler
.get_ndarray_repr_for_store(
&store,
search_input,
&self.model_manager,
preprocess,
preprocess_action,
)
.await;
match repr {
Expand Down Expand Up @@ -405,6 +397,39 @@ impl AhnlichProtocol for AIProxyTask {
let destoryed = self.store_handler.purge_stores();
Ok(AIServerResponse::Del(destoryed))
}
AIQuery::ListClients => {
Ok(AIServerResponse::ClientList(self.client_handler.list()))
}
AIQuery::GetKey { store, keys } => {
let metadata_values: HashSet<MetadataValue> =
keys.into_iter().map(|value| value.into()).collect();
let get_key_condition = PredicateCondition::Value(Predicate::In {
key: AHNLICH_AI_RESERVED_META_KEY.clone(),
value: metadata_values,
});

let get_pred_params = db_params::GetPredParams::builder()
.store(store.to_string())
.condition(get_key_condition)
.tracing_id(parent_id.clone())
.build();

match self.db_client.get_pred(get_pred_params).await {
Ok(res) => {
if let ServerResponse::Get(response) = res {
// conversion to store input here
let output = self
.store_handler
.store_key_val_to_store_input_val(response);
Ok(AIServerResponse::Get(output))
} else {
Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res))
.to_string())
}
}
Err(err) => Err(format!("{err}")),
}
}
})
}
result
Expand Down
75 changes: 74 additions & 1 deletion ahnlich/ai/src/tests/aiproxy_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ use ahnlich_types::{
predicate::{Predicate, PredicateCondition},
similarity::Algorithm,
};
// use flurry::HashMap;
use utils::server::AhnlichServerUtils;

use once_cell::sync::Lazy;
use pretty_assertions::assert_eq;
use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering};
use std::{
collections::{HashMap, HashSet},
num::NonZeroUsize,
sync::atomic::Ordering,
};

use crate::{
cli::{server::SupportedModels, AIProxyConfig},
Expand Down Expand Up @@ -162,6 +167,73 @@ async fn test_ai_proxy_create_store_success() {
query_server_assert_result(&mut reader, message, expected.clone()).await;
}

#[tokio::test]
/// Verifies that `AIQuery::GetKey` round-trips a value previously stored via
/// `AIQuery::Set`: creates a store, sets one raw-string input, then fetches it
/// back by key on a second connection.
async fn test_ai_store_get_key_works() {
    let address = provision_test_servers().await;
    let first_stream = TcpStream::connect(address).await.unwrap();
    let second_stream = TcpStream::connect(address).await.unwrap();
    let store_name = StoreName(String::from("Deven Kicks"));
    let store_input = StoreInput::RawString(String::from("Jordan 3"));
    let store_data: (StoreInput, HashMap<MetadataKey, MetadataValue>) =
        (store_input.clone(), HashMap::new());

    // Seed the store: create it and set one entry on the first connection.
    let message = AIServerQuery::from_queries(&[
        AIQuery::CreateStore {
            store: store_name.clone(),
            query_model: AIModel::AllMiniLML6V2,
            index_model: AIModel::AllMiniLML6V2,
            predicates: HashSet::new(),
            non_linear_indices: HashSet::new(),
            error_if_exists: true,
            store_original: false,
        },
        AIQuery::Set {
            store: store_name.clone(),
            inputs: vec![store_data.clone()],
            preprocess_action: PreprocessAction::NoPreprocessing,
        },
    ]);
    let mut reader = BufReader::new(first_stream);

    let _ = get_server_response(&mut reader, message).await;
    let message = AIServerQuery::from_queries(&[AIQuery::GetKey {
        store: store_name,
        keys: vec![store_input.clone()],
    }]);

    let mut expected = AIServerResult::with_capacity(1);

    expected.push(Ok(AIServerResponse::Get(vec![(
        Some(store_input),
        HashMap::new(),
    )])));

    let mut reader = BufReader::new(second_stream);
    let response = get_server_response(&mut reader, message).await;
    // NOTE(review): only the lengths are compared, not the contents —
    // presumably because embedding-derived store keys are non-deterministic.
    // TODO: confirm and tighten to a full equality check if contents are stable.
    assert_eq!(response.len(), expected.len());
}

#[tokio::test]
/// Verifies that `AIQuery::ListClients` reports every live connection:
/// two clients connect, and the list returned to the second one has length 2.
async fn test_list_clients_works() {
    let address = provision_test_servers().await;
    // Bound to a named variable (not `_`) so the connection stays open for the
    // whole test and the server keeps counting it as a client.
    let _first_stream = TcpStream::connect(address).await.unwrap();
    let second_stream = TcpStream::connect(address).await.unwrap();
    let message = AIServerQuery::from_queries(&[AIQuery::ListClients]);
    let mut reader = BufReader::new(second_stream);
    let response = get_server_response(&mut reader, message).await;
    let inner = response.into_inner();

    // only two clients are connected
    match inner.as_slice() {
        [Ok(AIServerResponse::ClientList(connected_clients))] => {
            // assert_eq! (rather than assert!(a == b)) prints both sides on failure.
            assert_eq!(connected_clients.len(), 2);
        }
        unexpected => panic!("Unexpected result for client list {:?}", unexpected),
    }
}

// TODO: Same issues with random storekeys, changing the order of expected response
#[tokio::test]
async fn test_ai_store_no_original() {
Expand Down Expand Up @@ -375,6 +447,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() {
condition: None,
closest_n: NonZeroUsize::new(1).unwrap(),
algorithm: Algorithm::DotProductSimilarity,
preprocess_action: PreprocessAction::ModelPreprocessing,
}]);

let mut expected = AIServerResult::with_capacity(1);
Expand Down
2 changes: 2 additions & 0 deletions ahnlich/client/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl AIPipeline {
condition: params.condition,
closest_n: params.closest_n,
algorithm: params.algorithm,
preprocess_action: params.preprocess_action,
})
}

Expand Down Expand Up @@ -234,6 +235,7 @@ impl AIClient {
condition: params.condition,
closest_n: params.closest_n,
algorithm: params.algorithm,
preprocess_action: params.preprocess_action,
},
params.tracing_id,
)
Expand Down
2 changes: 2 additions & 0 deletions ahnlich/client/src/builders/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub struct GetSimNParams {

#[builder(default = None)]
pub tracing_id: Option<String>,
#[builder(default = PreprocessAction::NoPreprocessing)]
pub preprocess_action: PreprocessAction,
}

#[derive(TypedBuilder)]
Expand Down
23 changes: 23 additions & 0 deletions ahnlich/db/src/engine/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@ use std::collections::HashSet as StdHashSet;
use std::mem::size_of_val;
use utils::parallel;

/// Predicates are essentially nested hashmaps that let us retrieve original keys that match a
/// precise value. Take the following example:
///
/// {
///    "Country": {
///         "Nigeria": [StoreKeyId(1), StoreKeyId(2)],
///         "Australia": ..,
///     },
///     "Author": {
///         ...
///     }
/// }
///
/// where `allowed_predicates` = ["Country", "Author"]
///
/// It takes less time to retrieve "where country = 'Nigeria'" by traversing the nested hashmap to
/// obtain StoreKeyId(1) and StoreKeyId(2) than it would take to make a linear pass over an entire
/// Store of size N, comparing each entry's metadata "country" along the way. Given that StoreKeyId
/// is computed via a blake hash, it is typically fast to compute and also of a fixed size, which
/// means predicate indices don't balloon with large metadata values.
///
/// Whichever key is not expressly included in `allowed_predicates` goes through the linear
/// pass in order to obtain keys that satisfy the condition.
type InnerPredicateIndexVal = ConcurrentHashSet<StoreKeyId>;
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, InnerPredicateIndexVal>;
type InnerPredicateIndices = ConcurrentHashMap<MetadataKey, PredicateIndex>;
Expand Down
29 changes: 21 additions & 8 deletions ahnlich/dsl/src/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ use pest::Parser;

use crate::{error::DslError, predicate::parse_predicate_expression};

fn parse_to_preprocess_action(input: &str) -> PreprocessAction {
fn parse_to_preprocess_action(input: &str) -> Result<PreprocessAction, DslError> {
match input.to_lowercase().trim() {
"nopreprocessing" => PreprocessAction::NoPreprocessing,
"modelpreprocessing" => PreprocessAction::ModelPreprocessing,
_ => panic!("Unexpected preprocess action"),
"nopreprocessing" => Ok(PreprocessAction::NoPreprocessing),
"modelpreprocessing" => Ok(PreprocessAction::ModelPreprocessing),
a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())),
}
}

Expand Down Expand Up @@ -53,7 +53,7 @@ pub const COMMANDS: &[&str] = &[
"dropnonlinearalgorithmindex", // if exists (kdtree) in store_name
"delkey", // ([input 1 text], [input 2 text]) in my_store
"getpred", // ((author = dickens) or (country != Nigeria)) in my_store
"getsimn", // 4 with [random text inserted here] using cosinesimilarity in my_store where (author = dickens)
"getsimn", // 4 with [random text inserted here] using cosinesimilarity preprocessaction nopreprocessing in my_store where (author = dickens)
"createstore", // if not exists my_store querymodel resnet-50 indexmodel resnet-50 predicates (author, country) nonlinearalgorithmindex (kdtree)
"set", // (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in store
];
Expand Down Expand Up @@ -83,9 +83,9 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
let preprocess_action = parse_to_preprocess_action(
inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str(),
);
.map(|a| a.as_str())
.unwrap_or("nopreprocessing"),
)?;

AIQuery::Set {
store: StoreName(store.to_string()),
Expand Down Expand Up @@ -175,6 +175,18 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str(),
)?;
let mut preprocess_action = PreprocessAction::NoPreprocessing;
if let Some(next_pair) = inner_pairs.peek() {
if next_pair.as_rule() == Rule::preprocess_optional {
let mut pair = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.into_inner();
preprocess_action = parse_to_preprocess_action(
pair.next().map(|a| a.as_str()).unwrap_or("nopreprocessing"),
)?;
}
};
let store = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
Expand All @@ -190,6 +202,7 @@ pub fn parse_ai_query(input: &str) -> Result<Vec<AIQuery>, DslError> {
closest_n,
algorithm,
condition,
preprocess_action,
}
}
Rule::get_pred => {
Expand Down
2 changes: 2 additions & 0 deletions ahnlich/dsl/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ pub enum DslError {
UnsupportedAIModel(String),
#[error("Unsupported rule used in parse fn {0:?}")]
UnsupportedRule(Rule),
#[error("Unexpected preprocessing {0:?}")]
UnsupportedPreprocessingMode(String),
}
3 changes: 2 additions & 1 deletion ahnlich/dsl/src/syntax/syntax.pest
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ai_del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ store_inputs ~ ")"
get_pred = { whitespace* ~ ^"getpred" ~ whitespace* ~ predicate_condition ~ in_ignored ~ store_name }
// GETSIMN 2 WITH store-key USING algorithm IN store (WHERE predicate_condition)
get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ f32_array ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
ai_get_sim_n = { whitespace* ~ ^"getsimn" ~ whitespace* ~ non_zero ~ whitespace* ~ ^"with" ~ whitespace* ~ "[" ~ whitespace* ~ metadata_value ~ whitespace* ~ "]" ~ whitespace* ~ ^"using" ~ whitespace* ~ algorithm ~ whitespace* ~ (preprocess_optional)? ~ whitespace* ~ in_ignored ~ whitespace* ~ store_name ~ whitespace* ~ (^"where" ~ whitespace* ~ predicate_condition)? }
// CREATESTORE IF NOT EXISTS store-name DIMENSION non-zero-size PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
create_store = { whitespace* ~ ^"createstore" ~ whitespace* ~ (if_not_exists)? ~ whitespace* ~ store_name ~ whitespace* ~ ^"dimension" ~ whitespace* ~ non_zero ~ whitespace* ~ (^"predicates" ~ whitespace* ~ "(" ~ whitespace* ~ metadata_keys ~ whitespace* ~ ")" )? ~ (whitespace* ~ ^"nonlinearalgorithmindex" ~ whitespace* ~ "(" ~ whitespace* ~ non_linear_algorithms ~ whitespace* ~ ")")? }
// CREATESTORE IF NOT EXISTS store-name QUERYMODEL model INDEXMODEL model PREDICATES (key1, key2) NONLINEARALGORITHMINDEX (kdtree)
Expand All @@ -66,6 +66,7 @@ ai_set_in_store = { whitespace* ~ ^"set" ~ whitespace* ~ store_inputs_to_store_v

if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* }
if_not_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"not" ~ whitespace* ~ ^"exists" ~ whitespace* }
preprocess_optional = { whitespace* ~ ^"preprocessaction" ~ whitespace* ~ preprocess_action}
store_original = { whitespace* ~ ^"storeoriginal" ~ whitespace* }

// stores and predicates can be alphanumeric
Expand Down
6 changes: 4 additions & 2 deletions ahnlich/dsl/src/tests/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,16 @@ fn test_get_sim_n_parse() {
panic!("Unexpected error pattern found")
};
assert_eq!((start, end), (0, 68));
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity in random"#;
let input = r#"GETSIMN 5 with [hi my name is carter] using cosinesimilarity preprocessaction MODELPREPROCESSING in random"#;
assert_eq!(
parse_ai_query(input).expect("Could not parse query input"),
vec![AIQuery::GetSimN {
store: StoreName("random".to_string()),
search_input: StoreInput::RawString("hi my name is carter".to_string()),
closest_n: NonZeroUsize::new(5).unwrap(),
algorithm: Algorithm::CosineSimilarity,
condition: None
condition: None,
preprocess_action: PreprocessAction::ModelPreprocessing,
}]
);
let input = r#"GETSIMN 8 with [testing the limits of life] using euclideandistance in other where ((year != 2012) AND (month not in (december, october)))"#;
Expand All @@ -231,6 +232,7 @@ fn test_get_sim_n_parse() {
]),
}))
),
preprocess_action: PreprocessAction::NoPreprocessing,
}]
);
}
Expand Down
Loading
Loading