diff --git a/ahnlich/Cargo.toml b/ahnlich/Cargo.toml index 7fe9ce17..6cff0cc6 100644 --- a/ahnlich/Cargo.toml +++ b/ahnlich/Cargo.toml @@ -8,9 +8,11 @@ members = [ resolver = "2" [workspace.dependencies] -serde = { version = "1.0.*", default-features = false } +serde = { version = "1.0.*", features = ["derive"] } +bincode = "1.3.3" ndarray = { version = "0.15.6", features = ["serde"] } itertools = "0.10.0" clap = { version = "4.5.4", features = ["derive"] } tracing = "0.1" once_cell = "1.19.0" +futures = "0.3.30" diff --git a/ahnlich/server/Cargo.toml b/ahnlich/server/Cargo.toml index 92e96638..543163f3 100644 --- a/ahnlich/server/Cargo.toml +++ b/ahnlich/server/Cargo.toml @@ -16,17 +16,19 @@ path = "src/lib.rs" flurry = "0.5.1" sha2 = "0.10.3" ndarray.workspace = true +bincode.workspace = true itertools.workspace = true clap.workspace = true once_cell.workspace = true tracing.workspace = true types = { path = "../types", version = "*" } -tracer = { path = "../tracer", version = "*" } - tokio = { version = "1.37.0", features = ["net", "macros", "io-util", "rt-multi-thread", "sync"] } +tracer = { path = "../tracer", version = "*" } +thiserror = "1.0" [dev-dependencies] loom = "0.7.2" reqwest = "0.12.4" serde_json = "1.0.116" tokio = { version = "1.37.0", features = ["io-util", "sync"] } +futures.workspace = true diff --git a/ahnlich/server/src/algorithm/mod.rs b/ahnlich/server/src/algorithm/mod.rs index 5b75d377..6b8262fa 100644 --- a/ahnlich/server/src/algorithm/mod.rs +++ b/ahnlich/server/src/algorithm/mod.rs @@ -3,7 +3,7 @@ mod similarity; use std::num::NonZeroUsize; -use types::keyval::{SearchInput, StoreKey}; +use types::keyval::StoreKey; use types::similarity::Algorithm; use self::{heap::AlgorithmHeapType, similarity::SimilarityFunc}; @@ -48,7 +48,7 @@ impl Ord for SimilarityVector<'_> { pub(crate) trait FindSimilarN { fn find_similar_n<'a>( &'a self, - search_vector: &SearchInput, + search_vector: &StoreKey, search_list: impl Iterator, n: NonZeroUsize, ) -> 
Vec<(&'a StoreKey, f64)>; @@ -57,17 +57,16 @@ pub(crate) trait FindSimilarN { impl FindSimilarN for Algorithm { fn find_similar_n<'a>( &'a self, - search_vector: &SearchInput, + search_vector: &StoreKey, search_list: impl Iterator, n: NonZeroUsize, ) -> Vec<(&'a StoreKey, f64)> { let mut heap: AlgorithmHeapType = (self, n).into(); let similarity_function: SimilarityFunc = self.into(); - let search_vector = StoreKey(search_vector.clone()); for second_vector in search_list { - let similarity = similarity_function(&search_vector, second_vector); + let similarity = similarity_function(search_vector, second_vector); let heap_value: SimilarityVector = (second_vector, similarity).into(); heap.push(heap_value) @@ -79,7 +78,7 @@ impl FindSimilarN for Algorithm { #[cfg(test)] mod tests { use super::*; - use crate::tests::*; + use crate::fixtures::*; #[test] fn test_teststore_find_top_3_similar_words_using_find_nearest_n() { @@ -100,7 +99,7 @@ mod tests { let cosine_algorithm = Algorithm::CosineSimilarity; let similar_n_search = cosine_algorithm.find_similar_n( - &first_vector.0, + &first_vector, search_list.iter(), NonZeroUsize::new(no_similar_values).unwrap(), ); diff --git a/ahnlich/server/src/algorithm/similarity.rs b/ahnlich/server/src/algorithm/similarity.rs index 0461d4c9..e3769c3b 100644 --- a/ahnlich/server/src/algorithm/similarity.rs +++ b/ahnlich/server/src/algorithm/similarity.rs @@ -128,7 +128,7 @@ fn euclidean_distance(first: &StoreKey, second: &StoreKey) -> f64 { #[cfg(test)] mod tests { use super::*; - use crate::tests::*; + use crate::fixtures::*; #[test] fn test_find_top_3_similar_words_using_cosine_similarity() { diff --git a/ahnlich/server/src/cli/server.rs b/ahnlich/server/src/cli/server.rs index c248e3be..364a6a10 100644 --- a/ahnlich/server/src/cli/server.rs +++ b/ahnlich/server/src/cli/server.rs @@ -46,11 +46,15 @@ pub struct ServerConfig { pub(crate) log_level: String, } -impl ServerConfig { - fn new() -> Self { +impl Default for ServerConfig { + fn 
default() -> Self { Self { host: String::from("127.0.0.1"), - port: 1396, + #[cfg(not(test))] + port: 1369, + // allow OS to pick a port + #[cfg(test)] + port: 0, enable_persistence: false, persist_location: None, persistence_intervals: 1000 * 60 * 5, @@ -61,8 +65,3 @@ impl ServerConfig { } } } -impl Default for ServerConfig { - fn default() -> Self { - Self::new() - } -} diff --git a/ahnlich/server/src/engine/mod.rs b/ahnlich/server/src/engine/mod.rs index 4583b4e1..7eddd8c6 100644 --- a/ahnlich/server/src/engine/mod.rs +++ b/ahnlich/server/src/engine/mod.rs @@ -1,2 +1,2 @@ mod predicate; -mod store; +pub(crate) mod store; diff --git a/ahnlich/server/src/engine/store.rs b/ahnlich/server/src/engine/store.rs index 0abe6a7d..9897e4ef 100644 --- a/ahnlich/server/src/engine/store.rs +++ b/ahnlich/server/src/engine/store.rs @@ -11,7 +11,6 @@ use std::fmt::Write; use std::mem::size_of_val; use std::num::NonZeroUsize; use std::sync::Arc; -use types::keyval::SearchInput; use types::keyval::StoreKey; use types::keyval::StoreName; use types::keyval::StoreValue; @@ -133,7 +132,7 @@ impl StoreHandler { pub(crate) fn get_sim_in_store( &self, store_name: &StoreName, - search_input: SearchInput, + search_input: StoreKey, closest_n: NonZeroUsize, algorithm: Algorithm, condition: Option, @@ -383,6 +382,7 @@ impl Store { /// TODO: Fix nested calculation of sizes using size_of_val fn size(&self) -> usize { size_of_val(&self) + + size_of_val(&self.dimension) + size_of_val(&self.id_to_value) + self .id_to_value @@ -404,7 +404,7 @@ impl Store { #[cfg(test)] mod tests { - use crate::tests::*; + use crate::fixtures::*; use std::num::NonZeroUsize; use super::*; @@ -863,12 +863,12 @@ mod tests { StoreInfo { name: odd_store, len: 2, - size_in_bytes: 2096, + size_in_bytes: 2104, }, StoreInfo { name: even_store, len: 0, - size_in_bytes: 1728, + size_in_bytes: 1736, }, ]) ) @@ -929,7 +929,7 @@ mod tests { value: MetadataValue::new("Chunin".into()), op: PredicateOp::Equals, }); - let 
search_input = vectors.get(SEACH_TEXT).unwrap().0.clone(); + let search_input = StoreKey(vectors.get(SEACH_TEXT).unwrap().0.clone()); let algorithm = Algorithm::CosineSimilarity; let closest_n = NonZeroUsize::new(3).unwrap(); diff --git a/ahnlich/server/src/errors.rs b/ahnlich/server/src/errors.rs index 9ff0ab8e..36ea3803 100644 --- a/ahnlich/server/src/errors.rs +++ b/ahnlich/server/src/errors.rs @@ -1,14 +1,21 @@ +use thiserror::Error; use types::keyval::StoreName; use types::metadata::MetadataKey; /// TODO: Move to shared rust types so library can deserialize it from the TCP response -#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)] pub enum ServerError { + #[error("Predicate {0} not found in store, attempt reindexing with predicate")] PredicateNotFound(MetadataKey), + #[error("Store {0} not found")] StoreNotFound(StoreName), + #[error("Store {0} already exists")] StoreAlreadyExists(StoreName), + #[error("Store dimension is [{store_dimension}], input dimension of [{input_dimension}] was specified")] StoreDimensionMismatch { store_dimension: usize, input_dimension: usize, }, + #[error("Could not deserialize query, error is {0}")] + QueryDeserializeError(String), } diff --git a/ahnlich/server/src/lib.rs b/ahnlich/server/src/lib.rs index 1244001e..c3a347f1 100644 --- a/ahnlich/server/src/lib.rs +++ b/ahnlich/server/src/lib.rs @@ -6,51 +6,122 @@ mod engine; mod errors; mod network; mod storage; -pub use crate::cli::ServerConfig; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use crate::cli::ServerConfig; +use crate::engine::store::StoreHandler; +use std::io::Result as IoResult; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; use tracer::init_tracing; +use types::bincode::BinCodeSerAndDeser; +use types::bincode::LENGTH_HEADER_SIZE; +use types::query::Query; +use 
types::query::ServerQuery; +use types::server::ServerResponse; +use types::server::ServerResult; -pub async fn run_server(config: ServerConfig) -> std::io::Result<()> { - // setup tracing - if config.enable_tracing { - let otel_url = &config - .otel_endpoint - .unwrap_or("http://127.0.0.1:4317".to_string()); - let log_level = &config.log_level; - init_tracing("ahnlich-db", Some(log_level), otel_url); +#[derive(Debug)] +pub struct Server { + listener: TcpListener, + store_handler: Arc, +} + +impl Server { + /// initializes a server using server configuration + pub async fn new(config: &ServerConfig) -> IoResult { + let listener = + tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?; + // Setup tracing + if config.enable_tracing { + let otel_url = (config.otel_endpoint) + .to_owned() + .unwrap_or("http://127.0.0.1:4317".to_string()); + let log_level = &config.log_level; + init_tracing("ahnlich-db", Some(log_level), &otel_url); + } + // TODO: replace with rules to retrieve store handler from persistence if persistence exist + let store_handler = Arc::new(StoreHandler::new()); + Ok(Self { + listener, + store_handler, + }) } - let listener = - tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?; - - loop { - let (stream, connect_addr) = listener.accept().await?; - tracing::info!("Connecting to {}", connect_addr); - tokio::spawn(async move { - if let Err(e) = process_stream(stream).await { - tracing::error!("Error handling connection: {}", e) - }; - }); + /// starts accepting connections using the listener and processing the incoming streams + pub async fn start(&self) -> IoResult<()> { + loop { + let (stream, connect_addr) = self.listener.accept().await?; + tracing::info!("Connecting to {}", connect_addr); + // TODO + // - Spawn a tokio task to handle the command while holding on to a reference to self + // - Convert the incoming bincode in a chunked manner to a Vec + // - Use store_handler to process the 
queries + // - Block new incoming connections on shutdown by no longer accepting and then + // cancelling existing ServerTask or forcing them to run to completion + + // "inexpensive" to clone store handler as it is an Arc + let task = ServerTask::new(stream, self.store_handler.clone()); + tokio::spawn(async move { + if let Err(e) = task.process().await { + tracing::error!("Error handling connection: {}", e) + }; + }); + } } -} -#[tracing::instrument] -async fn process_stream(stream: tokio::net::TcpStream) -> Result<(), tokio::io::Error> { - stream.readable().await?; - let mut reader = BufReader::new(stream); - loop { - let mut message = String::new(); - let _ = reader.read_line(&mut message).await?; - tracing::info_span!("Sending Messages"); - reader.get_mut().write_all(message.as_bytes()).await?; - message.clear(); + pub fn local_addr(&self) -> IoResult { + self.listener.local_addr() } } -#[cfg(test)] -mod tests { - // Import the fixtures for use in tests - pub use super::fixtures::*; +#[derive(Debug)] +struct ServerTask { + stream: TcpStream, + store_handler: Arc, +} + +impl ServerTask { + fn new(stream: TcpStream, store_handler: Arc) -> Self { + Self { + stream, + store_handler, + } + } + + /// processes messages from a stream + async fn process(self) -> IoResult<()> { + self.stream.readable().await?; + let mut reader = BufReader::new(self.stream); + let mut length_buf = [0u8; LENGTH_HEADER_SIZE]; + loop { + reader.read_exact(&mut length_buf).await?; + let data_length = u64::from_be_bytes(length_buf); + let mut data = vec![0u8; data_length as usize]; + reader.read_exact(&mut data).await?; + // TODO: Add trace here to catch whenever queries could not be deserialized at all + if let Ok(queries) = ServerQuery::deserialize(false, &data) { + // TODO: Pass in store_handler and use to respond to queries + let results = Self::handle(queries.into_inner()); + if let Ok(binary_results) = results.serialize() { + reader.get_mut().write_all(&binary_results).await?; + } + } + 
} + } + + fn handle(queries: Vec) -> ServerResult { + let mut result = ServerResult::with_capacity(queries.len()); + for query in queries { + result.push(match query { + Query::Ping => Ok(ServerResponse::Pong), + Query::InfoServer => Ok(ServerResponse::Unit), + _ => Err("Response not implemented".to_string()), + }) + } + result + } } #[cfg(test)] diff --git a/ahnlich/server/src/main.rs b/ahnlich/server/src/main.rs index 5ce274d9..d79e812c 100644 --- a/ahnlich/server/src/main.rs +++ b/ahnlich/server/src/main.rs @@ -7,7 +7,8 @@ async fn main() -> Result<(), Box> { let cli = Cli::parse(); match &cli.command { Commands::Run(config) => { - server::run_server(config.to_owned()).await?; + let server = server::Server::new(config).await?; + server.start().await?; } } Ok(()) diff --git a/ahnlich/server/tests/server_test.rs b/ahnlich/server/tests/server_test.rs index 0f7b9611..6d50ef53 100644 --- a/ahnlich/server/tests/server_test.rs +++ b/ahnlich/server/tests/server_test.rs @@ -1,40 +1,77 @@ -use server::ServerConfig; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use futures::future::join_all; +use server::cli::ServerConfig; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; use tokio::time::{timeout, Duration}; +use types::bincode::BinCodeSerAndDeser; +use types::query::Query; +use types::query::ServerQuery; +use types::server::ServerResponse; +use types::server::ServerResult; #[tokio::test] async fn test_run_server_echos() { - let server_config = ServerConfig::default(); - spawn_app(&server_config).await; + let server = server::Server::new(&ServerConfig::default()) + .await + .expect("Could not initialize server"); + let address = server.local_addr().expect("Could not get local addr"); + let _ = tokio::spawn(async move { server.start().await }); // Allow some time for the server to start tokio::time::sleep(Duration::from_millis(100)).await; + let tasks = vec![ + tokio::spawn(async move { + let 
message = ServerQuery::from_queries(&[Query::InfoServer, Query::Ping]); + let mut expected = ServerResult::with_capacity(2); + expected.push(Ok(ServerResponse::Unit)); + expected.push(Ok(ServerResponse::Pong)); + query_server_assert_result(address, message, expected).await + }), + tokio::spawn(async move { + let message = ServerQuery::from_queries(&[Query::Ping, Query::InfoServer]); + let mut expected = ServerResult::with_capacity(2); + expected.push(Ok(ServerResponse::Pong)); + expected.push(Ok(ServerResponse::Unit)); + query_server_assert_result(address, message, expected).await + }), + ]; + join_all(tasks).await; +} +async fn query_server_assert_result( + server_addr: SocketAddr, + query: ServerQuery, + expected_result: ServerResult, +) { // Connect to the server - let stream = TcpStream::connect(format!("{}:{}", &server_config.host, &server_config.port)) - .await - .unwrap(); + let stream = TcpStream::connect(server_addr).await.unwrap(); let mut reader = BufReader::new(stream); // Message to send - let message = "Hello, world!\n"; + let serialized_message = query.serialize().unwrap(); // Send the message - reader.write_all(message.as_bytes()).await.unwrap(); + reader.write_all(&serialized_message).await.unwrap(); - let mut response = String::new(); + // get length of response + let mut length_header = [0u8; types::bincode::LENGTH_HEADER_SIZE]; + timeout( + Duration::from_secs(1), + reader.read_exact(&mut length_header), + ) + .await + .unwrap() + .unwrap(); + let data_length = u64::from_be_bytes(length_header); + let mut response = vec![0u8; data_length as usize]; - timeout(Duration::from_secs(1), reader.read_line(&mut response)) + timeout(Duration::from_secs(1), reader.read_exact(&mut response)) .await .unwrap() .unwrap(); - assert_eq!(message, response); -} - -async fn spawn_app(config: &ServerConfig) { - let test_server = server::run_server(config.to_owned()); + let response = ServerResult::deserialize(false, &response).unwrap(); - let _ = 
tokio::spawn(test_server); + assert_eq!(response, expected_result); } diff --git a/ahnlich/tracer/Cargo.toml b/ahnlich/tracer/Cargo.toml index 00a66e3e..8d2a5403 100644 --- a/ahnlich/tracer/Cargo.toml +++ b/ahnlich/tracer/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -tracing.workplace = true +tracing.workspace = true tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } tracing-opentelemetry = "0.24.0" opentelemetry = "0.23.0" diff --git a/ahnlich/types/Cargo.toml b/ahnlich/types/Cargo.toml index 8de6d6fd..de07697b 100644 --- a/ahnlich/types/Cargo.toml +++ b/ahnlich/types/Cargo.toml @@ -7,3 +7,5 @@ edition = "2021" [dependencies] ndarray.workspace = true +serde.workspace = true +bincode.workspace = true diff --git a/ahnlich/types/src/bincode.rs b/ahnlich/types/src/bincode.rs new file mode 100644 index 00000000..62cb07d6 --- /dev/null +++ b/ahnlich/types/src/bincode.rs @@ -0,0 +1,39 @@ +use bincode::config::DefaultOptions; +use bincode::config::Options; +use serde::Deserialize; +use serde::Serialize; + +pub const LENGTH_HEADER_SIZE: usize = 8; + +/// - Length encoding must use fixed int and not var int +/// - Endianess must be Big Endian. 
+/// - First 8 bytes must contain length of the entire vec of response or queries +/// +/// Used to serialize and deserialize queries and responses into bincode +pub trait BinCodeSerAndDeser<'a> +where + Self: Serialize + Deserialize<'a>, +{ + fn serialize(&self) -> Result, bincode::Error> { + let config = DefaultOptions::new() + .with_fixint_encoding() + .with_big_endian(); + let serialized_data = config.serialize(self)?; + let data_length = serialized_data.len() as u64; + // serialization appends the length buffer to be read first + let mut buffer = Vec::with_capacity(LENGTH_HEADER_SIZE + serialized_data.len()); + buffer.extend(&data_length.to_be_bytes()); + buffer.extend(&serialized_data); + Ok(buffer) + } + + fn deserialize(has_length_header: bool, bytes: &'a [u8]) -> Result { + let config = DefaultOptions::new() + .with_fixint_encoding() + .with_big_endian(); + if has_length_header { + return config.deserialize(&bytes[LENGTH_HEADER_SIZE..]); + } + config.deserialize(bytes) + } +} diff --git a/ahnlich/types/src/keyval.rs b/ahnlich/types/src/keyval.rs index 553380f0..d6ab918d 100644 --- a/ahnlich/types/src/keyval.rs +++ b/ahnlich/types/src/keyval.rs @@ -1,18 +1,45 @@ use crate::metadata::MetadataKey; use crate::metadata::MetadataValue; use ndarray::Array1; +use serde::Deserialize; +use serde::Serialize; use std::collections::HashMap as StdHashMap; +use std::fmt; /// Name of a Store -#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +#[serde(transparent)] pub struct StoreName(pub String); +impl fmt::Display for StoreName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// A store value for now is a simple key value pair of strings pub type StoreValue = StdHashMap; /// A store key is always an f64 one dimensional array -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize)] 
+#[serde(transparent)] pub struct StoreKey(pub Array1); -/// Search input is just also an f64 one dimensional array -pub type SearchInput = Array1; +impl Eq for StoreKey {} + +impl PartialEq for StoreKey { + fn eq(&self, other: &Self) -> bool { + if self.0.shape() != other.0.shape() { + return false; + } + // std::f64::EPSILON adheres to the IEEE 754 standard and we use it here to determine when + // two Array1 are extremely similar to the point where the differences are negligible. + // We can modify to allow for greater precision, however we currently only + // use it for PartialEq and not for its distinctive properties. For that, within the + // server we defer to using StoreKeyId whenever we want to compare distinctive Array1 + self.0 + .iter() + .zip(other.0.iter()) + .all(|(x, y)| (x - y).abs() < std::f64::EPSILON) + } +} diff --git a/ahnlich/types/src/lib.rs b/ahnlich/types/src/lib.rs index e4960ef4..45669774 100644 --- a/ahnlich/types/src/lib.rs +++ b/ahnlich/types/src/lib.rs @@ -1,20 +1,7 @@ +pub mod bincode; pub mod keyval; pub mod metadata; pub mod predicate; pub mod query; +pub mod server; pub mod similarity; - -pub fn add(left: usize, right: usize) -> usize { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} diff --git a/ahnlich/types/src/metadata.rs b/ahnlich/types/src/metadata.rs index 241f2aba..7ea0327b 100644 --- a/ahnlich/types/src/metadata.rs +++ b/ahnlich/types/src/metadata.rs @@ -1,12 +1,24 @@ +use serde::Deserialize; +use serde::Serialize; +use std::fmt; /// New types for store metadata key and values -#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[serde(transparent)] pub struct MetadataKey(String); impl MetadataKey { pub fn new(input: String) -> Self { Self(input) } } -#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] + +impl 
fmt::Display for MetadataKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[serde(transparent)] pub struct MetadataValue(String); impl MetadataValue { pub fn new(input: String) -> Self { diff --git a/ahnlich/types/src/predicate.rs b/ahnlich/types/src/predicate.rs index 3e539eae..9e9e367b 100644 --- a/ahnlich/types/src/predicate.rs +++ b/ahnlich/types/src/predicate.rs @@ -1,8 +1,10 @@ use crate::metadata::MetadataKey; use crate::metadata::MetadataValue; +use serde::Deserialize; +use serde::Serialize; /// PredicateOp are the various operations that can be conducted against a predicate value -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum PredicateOp { Equals, NotEquals, @@ -11,7 +13,7 @@ pub enum PredicateOp { /// Representation of how one predicate value and ops looks /// to specify a predicate of name != "David", one would use the format /// PredicateOp { key: "name", value: "David", op: NotEquals } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Predicate { pub key: MetadataKey, pub value: MetadataValue, @@ -20,7 +22,7 @@ pub struct Predicate { /// All possible representations of a predicate condition /// We can only have a simple And or Or and we can combine those in any fashion -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum PredicateCondition { And(Box, Box), Or(Box, Box), diff --git a/ahnlich/types/src/query.rs b/ahnlich/types/src/query.rs index 4b38df7c..6991098a 100644 --- a/ahnlich/types/src/query.rs +++ b/ahnlich/types/src/query.rs @@ -1,16 +1,23 @@ use std::collections::HashSet; use std::num::NonZeroUsize; -use crate::keyval::{SearchInput, StoreKey, StoreName, StoreValue}; +use crate::bincode::BinCodeSerAndDeser; +use crate::keyval::{StoreKey, StoreName, 
StoreValue}; use crate::metadata::MetadataKey; use crate::predicate::PredicateCondition; use crate::similarity::Algorithm; +use serde::Deserialize; +use serde::Serialize; /// All possible queries for the server to respond to -#[derive(Debug, Clone)] +/// +/// +/// Vec of queries are to be sent by clients in bincode +/// - Length encoding must use fixed int and not var int +/// - Endianess must be Big Endian. +/// - First 8 bytes must contain length of the entire vec of queries +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum Query { - Connect, - Disconnect, Create { store: StoreName, dimension: NonZeroUsize, @@ -26,7 +33,7 @@ pub enum Query { GetSimN { store: StoreName, closest_n: NonZeroUsize, - input: SearchInput, + input: StoreKey, algorithm: Algorithm, condition: Option, }, @@ -53,11 +60,37 @@ pub enum Query { DropStore { store: StoreName, }, - ShutdownServer { - reason: Option, - }, InfoServer, ListStores, ListClients, Ping, } + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ServerQuery { + queries: Vec, +} + +impl ServerQuery { + pub fn with_capacity(len: usize) -> Self { + Self { + queries: Vec::with_capacity(len), + } + } + + pub fn push(&mut self, entry: Query) { + self.queries.push(entry) + } + + pub fn from_queries(queries: &[Query]) -> Self { + Self { + queries: queries.to_vec(), + } + } + + pub fn into_inner(self) -> Vec { + self.queries + } +} + +impl BinCodeSerAndDeser<'_> for ServerQuery {} diff --git a/ahnlich/types/src/server.rs b/ahnlich/types/src/server.rs new file mode 100644 index 00000000..4ca9cc5d --- /dev/null +++ b/ahnlich/types/src/server.rs @@ -0,0 +1,40 @@ +use crate::bincode::BinCodeSerAndDeser; +use serde::Deserialize; +use serde::Serialize; +use std::collections::HashSet; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ServerResponse { + // Unit variant for no action + Unit, + Pong, + // List of connected clients. 
Potentially outdated at the point of read + ClientList(HashSet), + // TODO: Define return types for queries, e.t.c +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ConnectedClient { + pub address: String, +} + +// ServerResult: Given that an array of queries are sent in, we expect that an array of responses +// be returned each being a potential error +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ServerResult { + results: Vec>, +} + +impl BinCodeSerAndDeser<'_> for ServerResult {} + +impl ServerResult { + pub fn with_capacity(len: usize) -> Self { + Self { + results: Vec::with_capacity(len), + } + } + + pub fn push(&mut self, entry: Result) { + self.results.push(entry) + } +} diff --git a/ahnlich/types/src/similarity.rs b/ahnlich/types/src/similarity.rs index fc99ccf3..4e9b0a53 100644 --- a/ahnlich/types/src/similarity.rs +++ b/ahnlich/types/src/similarity.rs @@ -1,4 +1,7 @@ -#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] /// Supported ahnlich similarity algorithms pub enum Algorithm { /// Euclidean distance is defined as the L2-norm of the difference between two vectors or their diff --git a/docs/draft.md b/docs/draft.md index 339e249a..2b69de46 100644 --- a/docs/draft.md +++ b/docs/draft.md @@ -80,9 +80,9 @@ Here's a rough sketch of commands to be expanded on later: - - `CONNECT` - - `DISCONNECT` - - `SHUTDOWNSERVER`: shut down basically discounts from all connected clients, performs cleanup before killing the server + + + - `CREATE`: Create a store which must have a unique name with respect to the server. Create can take in name_of_store, dimensions_of_vectors(immutable) to be stored in that store, ability to create predicate indices - `GETKEY`: takes in store, key and direct return of key within store matching the input key