Skip to content

Commit

Permalink
GETSIMN command inclusion
Browse files Browse the repository at this point in the history
  • Loading branch information
deven96 committed Sep 26, 2024
1 parent 6cb1abe commit b365c20
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 269 deletions.
20 changes: 20 additions & 0 deletions ahnlich/dsl/src/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use ahnlich_types::similarity::{Algorithm, NonLinearAlgorithm};

use crate::error::DslError;

/// Maps a textual algorithm name onto a [`NonLinearAlgorithm`].
///
/// The name is lowercased and stripped of leading/trailing whitespace before
/// comparison. Returns `None` when the normalized name is not `"kdtree"`.
pub(crate) fn to_non_linear(input: &str) -> Option<NonLinearAlgorithm> {
    let lowered = input.to_lowercase();
    (lowered.trim() == "kdtree").then_some(NonLinearAlgorithm::KDTree)
}

/// Maps a textual algorithm name onto an [`Algorithm`].
///
/// The name is lowercased and stripped of leading/trailing whitespace before
/// comparison.
///
/// # Errors
///
/// Returns [`DslError::UnsupportedAlgorithm`] carrying the normalized
/// (lowercased, trimmed) name when it matches none of the known algorithms.
pub(crate) fn to_algorithm(input: &str) -> Result<Algorithm, DslError> {
    let lowered = input.to_lowercase();
    let name = lowered.trim();
    let algorithm = match name {
        "kdtree" => Algorithm::KDTree,
        "cosinesimilarity" => Algorithm::CosineSimilarity,
        "dotproductsimilarity" => Algorithm::DotProductSimilarity,
        "euclideandistance" => Algorithm::EuclideanDistance,
        _ => return Err(DslError::UnsupportedAlgorithm(name.to_string())),
    };
    Ok(algorithm)
}
308 changes: 41 additions & 267 deletions ahnlich/dsl/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
use crate::parser::{QueryParser, Rule};
use std::num::NonZeroUsize;

use crate::{
algorithm::{to_algorithm, to_non_linear},
parser::{QueryParser, Rule},
};
use ahnlich_types::{
db::DBQuery,
keyval::{StoreKey, StoreName},
metadata::MetadataKey,
similarity::NonLinearAlgorithm,
};
use ndarray::Array1;
use pest::iterators::Pair;
use pest::Parser;

use crate::{error::DslError, predicate::parse_predicate_expression};

/// Resolves a textual name to a [`NonLinearAlgorithm`], ignoring letter case
/// and surrounding whitespace. Yields `None` for anything other than `kdtree`.
fn to_non_linear(input: &str) -> Option<NonLinearAlgorithm> {
    let normalized = input.to_lowercase();
    if normalized.trim() == "kdtree" {
        Some(NonLinearAlgorithm::KDTree)
    } else {
        None
    }
}

fn parse_multi_f32_array(f32_arrays_pair: Pair<Rule>) -> Vec<StoreKey> {
f32_arrays_pair.into_inner().map(parse_f32_array).collect()
}
Expand Down Expand Up @@ -46,11 +43,11 @@ fn parse_f32_array(pair: Pair<Rule>) -> StoreKey {
// GETKEY ((1.0, 2.0), (3.0, 4.0)) IN my_store
// DELKEY ((1.2, 3.0), (5.6, 7.8)) IN my_store
// GETPRED ((author = dickens) OR (country != Nigeria)) IN my_store
// GETSIMN 4 WITH (0.65, 2.78) USING cosinesimilarity IN my_store WHERE (author = dickens)
//
// #TODO
// SET
// CREATESTORE
// GETSIMN
pub fn parse_db_query(input: &str) -> Result<Vec<DBQuery>, DslError> {
let pairs = QueryParser::parse(Rule::db_query, input).map_err(Box::new)?;
let statements = pairs.into_iter().collect::<Vec<_>>();
Expand All @@ -63,6 +60,40 @@ pub fn parse_db_query(input: &str) -> Result<Vec<DBQuery>, DslError> {
Rule::list_clients => DBQuery::ListClients,
Rule::list_stores => DBQuery::ListStores,
Rule::info_server => DBQuery::InfoServer,
Rule::get_sim_n => {
let mut inner_pairs = statement.into_inner();
let closest_n = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str()
.parse::<NonZeroUsize>()?;
let f32_array = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?;
let search_input = parse_f32_array(f32_array);
let algorithm = to_algorithm(
inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str(),
)?;
let store = inner_pairs
.next()
.ok_or(DslError::UnexpectedSpan((start_pos, end_pos)))?
.as_str();
let condition = if let Some(predicate_conditions) = inner_pairs.next() {
Some(parse_predicate_expression(predicate_conditions)?)
} else {
None
};
DBQuery::GetSimN {
store: StoreName(store.to_string()),
search_input,
closest_n,
algorithm,
condition,
}
}
Rule::get_pred => {
let mut inner_pairs = statement.into_inner();
let predicate_conditions = inner_pairs
Expand Down Expand Up @@ -225,260 +256,3 @@ pub fn parse_db_query(input: &str) -> Result<Vec<DBQuery>, DslError> {
}
Ok(queries)
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use ahnlich_types::{
metadata::MetadataValue,
predicate::{Predicate, PredicateCondition},
};

use super::*;

#[test]
fn test_single_query_parse() {
    // Single statements in upper case, lower case, and padded with whitespace.
    let cases = [
        (r#"LISTCLIENTS"#, DBQuery::ListClients),
        (r#"listclients"#, DBQuery::ListClients),
        (r#" Ping "#, DBQuery::Ping),
    ];
    for (input, expected) in cases {
        let parsed = parse_db_query(input).expect("Could not parse query input");
        assert_eq!(parsed, vec![expected]);
    }
}

#[test]
fn test_multi_query_parse() {
    // Two `;`-separated statements with mixed casing and a trailing `;`.
    let parsed =
        parse_db_query(r#" INFOSERVER ; listSTORES;"#).expect("Could not parse query input");
    let expected = vec![DBQuery::InfoServer, DBQuery::ListStores];
    assert_eq!(parsed, expected);
}

#[test]
fn test_no_valid_input_in_query() {
    // Extracts the span from the expected `UnexpectedSpan` error; any other
    // error variant fails the test.
    let span_of = |query: &str| match parse_db_query(query).unwrap_err() {
        DslError::UnexpectedSpan(span) => span,
        _ => panic!("Unexpected error pattern found"),
    };
    // The unknown statement comes first.
    assert_eq!(span_of(r#" random ; listSTORES;"#), (0, 8));
    // The unknown statement sits between two valid ones.
    assert_eq!(span_of(r#" INfoSERVER ; random; ping"#), (13, 20));
}

#[test]
fn test_drop_store_parse() {
    // Plain form: `error_if_not_exists` is expected to be true.
    let strict = parse_db_query(r#"DROPSTORE random"#).expect("Could not parse query input");
    assert_eq!(
        strict,
        vec![DBQuery::DropStore {
            store: StoreName("random".to_string()),
            error_if_not_exists: true
        }]
    );
    // `IF exists` form: `error_if_not_exists` is expected to be false.
    let lenient =
        parse_db_query(r#"dropstore yeezy_store IF exists"#).expect("Could not parse query input");
    assert_eq!(
        lenient,
        vec![DBQuery::DropStore {
            store: StoreName("yeezy_store".to_string()),
            error_if_not_exists: false,
        }]
    );
    // IF NOT EXISTS is not valid syntax
    match parse_db_query(r#"dropstore yeezy IF NOT exists"#).unwrap_err() {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (15, 29)),
        _ => panic!("Unexpected error pattern found"),
    }
}

#[test]
fn test_create_predicate_index_parse() {
    // Lower-case `in` and a numeric-looking key (`3`) are both part of the input.
    let parsed = parse_db_query(r#"CREATEPREDINDEX (one, two, 3) in tapHstore1"#)
        .expect("Could not parse query input");
    let expected_keys: HashSet<_> = ["one", "two", "3"]
        .iter()
        .map(|key| MetadataKey::new(key.to_string()))
        .collect();
    assert_eq!(
        parsed,
        vec![DBQuery::CreatePredIndex {
            store: StoreName("tapHstore1".to_string()),
            predicates: expected_keys,
        }]
    );
}

#[test]
fn test_drop_pred_index_parse() {
    // Plain form: `error_if_not_exists` is expected to be true.
    let strict = parse_db_query(r#"DROPPREDINDEX (here, th2) in store2"#)
        .expect("Could not parse query input");
    let strict_keys: HashSet<_> = ["here", "th2"]
        .iter()
        .map(|key| MetadataKey::new(key.to_string()))
        .collect();
    assert_eq!(
        strict,
        vec![DBQuery::DropPredIndex {
            store: StoreName("store2".to_string()),
            predicates: strict_keys,
            error_if_not_exists: true,
        }]
    );
    // `IF EXISTS` form: `error_if_not_exists` is expected to be false.
    let lenient = parse_db_query(r#"DROPPREDINDEX IF EXISTS (off) in storememe"#)
        .expect("Could not parse query input");
    let mut lenient_keys = HashSet::new();
    lenient_keys.insert(MetadataKey::new("off".to_string()));
    assert_eq!(
        lenient,
        vec![DBQuery::DropPredIndex {
            store: StoreName("storememe".to_string()),
            predicates: lenient_keys,
            error_if_not_exists: false,
        }]
    );
}

#[test]
fn test_create_non_linear_algorithm_parse() {
    // An unrecognised algorithm name must fail; the expected span (0, 46)
    // covers the full length of the input.
    match parse_db_query(r#"createnonlinearalgorithmindex (fake) in store2"#).unwrap_err() {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (0, 46)),
        _ => panic!("Unexpected error pattern found"),
    }
    // `kdtree` is accepted.
    let parsed = parse_db_query(r#"createnonlinearalgorithmindex (kdtree) in store2"#)
        .expect("Could not parse query input");
    let mut expected_indices = HashSet::new();
    expected_indices.insert(NonLinearAlgorithm::KDTree);
    assert_eq!(
        parsed,
        vec![DBQuery::CreateNonLinearAlgorithmIndex {
            store: StoreName("store2".to_string()),
            non_linear_indices: expected_indices,
        }]
    );
}

#[test]
fn test_drop_non_linear_algorithm_parse() {
    // An unrecognised algorithm name must fail; the expected span (0, 42)
    // covers the full length of the input.
    match parse_db_query(r#"DROPNONLINEARALGORITHMINDEX (fake) in 1234"#).unwrap_err() {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (0, 42)),
        _ => panic!("Unexpected error pattern found"),
    }
    // Valid statements paired with the `error_if_not_exists` flag each one
    // is expected to produce: true for the plain form, false with `IF EXISTS`.
    let accepted = [
        (r#"DROPNONLINEARALGORITHMINDEX (kdtree) in 1234"#, true),
        (
            r#"DROPNONLINEARALGORITHMINDEX IF EXISTS (kdtree) in 1234"#,
            false,
        ),
    ];
    for (input, error_if_not_exists) in accepted {
        let parsed = parse_db_query(input).expect("Could not parse query input");
        let mut expected_indices = HashSet::new();
        expected_indices.insert(NonLinearAlgorithm::KDTree);
        assert_eq!(
            parsed,
            vec![DBQuery::DropNonLinearAlgorithmIndex {
                store: StoreName("1234".to_string()),
                non_linear_indices: expected_indices,
                error_if_not_exists,
            }]
        );
    }
}

#[test]
fn test_get_key_parse() {
    // An array with non-numeric members (`a, b, c`) must be rejected.
    let failure = parse_db_query(r#"getkey ((a, b, c), (3.0, 4.0)) in 1234"#).unwrap_err();
    match failure {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (0, 38)),
        _ => panic!("Unexpected error pattern found"),
    }
    // The first array is written with integer literals, the second with decimals.
    let parsed = parse_db_query(r#"getkey ((1, 2, 3), (3.0, 4.0)) in 1234"#)
        .expect("Could not parse query input");
    let expected_keys = vec![
        StoreKey(Array1::from_iter([1.0, 2.0, 3.0])),
        StoreKey(Array1::from_iter([3.0, 4.0])),
    ];
    assert_eq!(
        parsed,
        vec![DBQuery::GetKey {
            store: StoreName("1234".to_string()),
            keys: expected_keys,
        }]
    );
}

#[test]
fn test_del_key_parse() {
    // An array with non-numeric members (`a, b, c`) must be rejected.
    let failure = parse_db_query(r#"DELKEY ((a, b, c), (3.0, 4.0)) in 1234"#).unwrap_err();
    match failure {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (0, 38)),
        _ => panic!("Unexpected error pattern found"),
    }
    // The first array is written with integer literals, the second with decimals.
    let parsed = parse_db_query(r#"DELKEY ((1, 2, 3), (3.0, 4.0)) in 1234"#)
        .expect("Could not parse query input");
    let expected_keys = vec![
        StoreKey(Array1::from_iter([1.0, 2.0, 3.0])),
        StoreKey(Array1::from_iter([3.0, 4.0])),
    ];
    assert_eq!(
        parsed,
        vec![DBQuery::DelKey {
            store: StoreName("1234".to_string()),
            keys: expected_keys,
        }]
    );
}

#[test]
fn test_get_pred_parse() {
    // Array syntax in place of a predicate expression must be rejected.
    match parse_db_query(r#"GETPRED ((a, b, c), (3.0, 4.0)) in 1234"#).unwrap_err() {
        DslError::UnexpectedSpan(span) => assert_eq!(span, (0, 39)),
        _ => panic!("Unexpected error pattern found"),
    }

    // Two simple predicates joined by OR.
    let parsed = parse_db_query(r#"GETPRED ((firstname = king) OR (surname != charles)) in store2"#)
        .expect("Could not parse query input");
    let firstname_is_king = PredicateCondition::Value(Predicate::Equals {
        key: MetadataKey::new("firstname".into()),
        value: MetadataValue::RawString("king".to_string()),
    });
    let surname_not_charles = PredicateCondition::Value(Predicate::NotEquals {
        key: MetadataKey::new("surname".into()),
        value: MetadataValue::RawString("charles".to_string()),
    });
    assert_eq!(
        parsed,
        vec![DBQuery::GetPred {
            store: StoreName("store2".to_string()),
            condition: firstname_is_king.or(surname_not_charles),
        }]
    );

    // IN / NOT IN sets mixed with AND and OR; the expected tree is
    // `pages AND (author != dickens OR author NOT IN ...)`.
    let parsed = parse_db_query(r#"GETPRED ((pages in (0, 1, 2)) AND (author != dickens) OR (author NOT in (jk-rowlins, rick-riodan)) ) in bookshelf"#)
        .expect("Could not parse query input");
    let pages_in_set = PredicateCondition::Value(Predicate::In {
        key: MetadataKey::new("pages".into()),
        value: HashSet::from_iter([
            MetadataValue::RawString("0".to_string()),
            MetadataValue::RawString("1".to_string()),
            MetadataValue::RawString("2".to_string()),
        ]),
    });
    let author_not_dickens = PredicateCondition::Value(Predicate::NotEquals {
        key: MetadataKey::new("author".into()),
        value: MetadataValue::RawString("dickens".to_string()),
    });
    let author_not_in_set = PredicateCondition::Value(Predicate::NotIn {
        key: MetadataKey::new("author".into()),
        value: HashSet::from_iter([
            MetadataValue::RawString("jk-rowlins".to_string()),
            MetadataValue::RawString("rick-riodan".to_string()),
        ]),
    });
    assert_eq!(
        parsed,
        vec![DBQuery::GetPred {
            store: StoreName("bookshelf".to_string()),
            condition: pages_in_set.and(author_not_dickens.or(author_not_in_set)),
        }]
    );
}
}
6 changes: 6 additions & 0 deletions ahnlich/dsl/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::ParseIntError;

use crate::parser::Rule;
use thiserror::Error;

Expand All @@ -9,4 +11,8 @@ pub enum DslError {
UnexpectedSpan((usize, usize)),
#[error("Could not parse Hex string into image {0:?}")]
UnexpectedHex(String),
#[error("Could not parse string into nonzerousize {0:?}")]
NonZeroUsizeParse(#[from] ParseIntError),
#[error("Found unsupported algorithm {0}")]
UnsupportedAlgorithm(String),
}
3 changes: 3 additions & 0 deletions ahnlich/dsl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
mod algorithm;
pub mod db;
pub mod error;
mod metadata;
mod parser;
mod predicate;
#[cfg(test)]
mod tests;
Loading

0 comments on commit b365c20

Please sign in to comment.