diff --git a/ahnlich/dsl/Cargo.toml b/ahnlich/dsl/Cargo.toml index a4321f6b..d3cc0942 100644 --- a/ahnlich/dsl/Cargo.toml +++ b/ahnlich/dsl/Cargo.toml @@ -8,4 +8,5 @@ ahnlich_types = { path = "../types", version = "*" } pest = "2.7.13" pest_derive = "2.7.13" thiserror.workspace = true +ndarray.workspace = true diff --git a/ahnlich/dsl/src/db.rs b/ahnlich/dsl/src/db.rs index 9d0d2bd1..9b3161ea 100644 --- a/ahnlich/dsl/src/db.rs +++ b/ahnlich/dsl/src/db.rs @@ -1,6 +1,10 @@ use ahnlich_types::{ - db::DBQuery, keyval::StoreName, metadata::MetadataKey, similarity::NonLinearAlgorithm, + db::DBQuery, + keyval::{StoreKey, StoreName}, + metadata::MetadataKey, + similarity::NonLinearAlgorithm, }; +use ndarray::Array1; use pest::Parser; use pest_derive::Parser; @@ -17,6 +21,19 @@ fn to_non_linear(input: &str) -> Option { } } +fn parse_multi_f32_array(f32_arrays_pair: pest::iterators::Pair) -> Vec { + f32_arrays_pair.into_inner().map(parse_f32_array).collect() +} + +fn parse_f32_array(pair: pest::iterators::Pair) -> StoreKey { + StoreKey(Array1::from_iter(pair.into_inner().map(|f32_pair| { + f32_pair + .as_str() + .parse::() + .expect("Cannot parse single f32 num") + }))) +} + // Parse raw strings separated by ; into a Vec. Examples include but are not restricted // to // @@ -29,12 +46,12 @@ fn to_non_linear(input: &str) -> Option { // DROPPREDINDEX IF EXISTS (key1, key2) in store_name // CREATENONLINEARALGORITHMINDEX (kdtree) in store_name // DROPNONLINEARALGORITHMINDEX IF EXISTS (kdtree) in store_name +// GETKEY ((1.0, 2.0), (3.0, 4.0)) IN my_store +// DELKEY ((1.2, 3.0), (5.6, 7.8)) IN my_store // // #TODO // SET -// DELKEY // CREATESTORE -// GETKEY // GETPRED // GETSIMN pub fn parse_db_query(input: &str) -> Result, DslError> { @@ -49,6 +66,38 @@ pub fn parse_db_query(input: &str) -> Result, DslError> { Rule::list_clients => DBQuery::ListClients, Rule::list_stores => DBQuery::ListStores, Rule::info_server => DBQuery::InfoServer, + Rule::get_key => { + let mut inner_pairs = statement.into_inner(); + let f32_arrays_pair = inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?; + let keys = parse_multi_f32_array(f32_arrays_pair); + + let store = inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? + .as_str(); + DBQuery::GetKey { + store: StoreName(store.to_string()), + keys, + } + } + Rule::del_key => { + let mut inner_pairs = statement.into_inner(); + let f32_arrays_pair = inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?; + let keys = parse_multi_f32_array(f32_arrays_pair); + + let store = inner_pairs + .next() + .ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))? + .as_str(); + DBQuery::DelKey { + store: StoreName(store.to_string()), + keys, + } + } Rule::create_non_linear_algorithm_index => { let mut inner_pairs = statement.into_inner(); let index_name_pairs = inner_pairs @@ -324,4 +373,44 @@ mod tests { }] ); } + + #[test] + fn test_get_key_parse() { + let input = r#"getkey ((a, b, c), (3.0, 4.0)) in 1234"#; + let DslError::UnexpectedSpan((start, end)) = parse_db_query(input).unwrap_err() else { + panic!("Unexpected error pattern found") + }; + assert_eq!((start, end), (0, 38)); + let input = r#"getkey ((1, 2, 3), (3.0, 4.0)) in 1234"#; + assert_eq!( + parse_db_query(input).expect("Could not parse query input"), + vec![DBQuery::GetKey { + store: StoreName("1234".to_string()), + keys: vec![ + StoreKey(Array1::from_iter([1.0, 2.0, 3.0])), + StoreKey(Array1::from_iter([3.0, 4.0])), + ], + }] + ); + } + + #[test] + fn test_del_key_parse() { + let input = r#"DELKEY ((a, b, c), (3.0, 4.0)) in 1234"#; + let DslError::UnexpectedSpan((start, end)) = parse_db_query(input).unwrap_err() else { + panic!("Unexpected error pattern found") + }; + assert_eq!((start, end), (0, 38)); + let input = r#"DELKEY ((1, 2, 3), (3.0, 4.0)) in 1234"#; + assert_eq!( + parse_db_query(input).expect("Could not parse query input"), + vec![DBQuery::DelKey { + store: StoreName("1234".to_string()), + keys: vec![ + StoreKey(Array1::from_iter([1.0, 2.0, 3.0])), + StoreKey(Array1::from_iter([3.0, 4.0])), + ], + }] + ); + } } diff --git a/ahnlich/dsl/src/syntax/db.pest b/ahnlich/dsl/src/syntax/db.pest index fd1f2c5f..22ed938b 100644 --- a/ahnlich/dsl/src/syntax/db.pest +++ b/ahnlich/dsl/src/syntax/db.pest @@ -3,15 +3,17 @@ whitespace = _{ " " | "\t" } query = _{ statement ~ (";" ~ statement) * } // Matches multiple statements separated by ; statement = _{ - ping | - info_server | - list_stores | - list_clients | - drop_store | - create_pred_index | - drop_pred_index | - create_non_linear_algorithm_index | - drop_non_linear_algorithm_index | + ping | + info_server | + list_stores | + list_clients | + drop_store | + create_pred_index | + drop_pred_index | + create_non_linear_algorithm_index | + drop_non_linear_algorithm_index | + get_key | + del_key | invalid_statement } @@ -24,6 +26,8 @@ create_pred_index = { whitespace* ~ ^"createpredindex" ~ whitespace* ~ "(" ~ ind create_non_linear_algorithm_index = { whitespace* ~ ^"createnonlinearalgorithmindex" ~ whitespace* ~ "(" ~ non_linear_algorithms ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name} drop_pred_index = { whitespace* ~ ^"droppredindex" ~ whitespace* ~ (if_exists)? ~ "(" ~ index_names ~ ")" ~ whitespace* ~ ^"in" ~whitespace* ~ store_name } drop_non_linear_algorithm_index = { whitespace* ~ ^"dropnonlinearalgorithmindex" ~ whitespace* ~ (if_exists)? ~ "(" ~ non_linear_algorithms ~ ")" ~ whitespace* ~ ^"in" ~whitespace* ~ store_name } +get_key = { whitespace* ~ ^"getkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name } +del_key = { whitespace* ~ ^"delkey" ~ whitespace* ~ "(" ~ f32_arrays ~ ")" ~ whitespace* ~ ^"in" ~ whitespace* ~ store_name } if_exists = { whitespace* ~ ^"if" ~ whitespace* ~ ^"exists" ~ whitespace* } @@ -34,6 +38,15 @@ non_linear_algorithm = { ^"kdtree" } non_linear_algorithms = { non_linear_algorithm ~ (whitespace* ~ "," ~ whitespace* ~ non_linear_algorithm)* } index_names = { index_name ~ (whitespace* ~ "," ~ whitespace* ~ index_name)* } +// Floating point number +f32 = { ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? } + +// Array of floating-point numbers +f32_array = { "(" ~ f32 ~ (whitespace* ~ "," ~ whitespace* ~ f32)* ~ ")"} + +// List of f32 arrays (comma-separated) +f32_arrays = { f32_array ~ (whitespace* ~ "," ~ whitespace* ~ f32_array)* } + // Catch-all rule for invalid statements invalid_statement = { whitespace* ~ (!";" ~ ANY)+ } // Match anything that isn't a valid statement