Skip to content

Commit

Permalink
Merge branch 'main' into david/granular-tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
Iamdavidonuh committed Jun 4, 2024
2 parents e67b2ac + d635d56 commit 7fc0f62
Show file tree
Hide file tree
Showing 22 changed files with 381 additions and 118 deletions.
4 changes: 3 additions & 1 deletion ahnlich/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ members = [
resolver = "2"

[workspace.dependencies]
serde = { version = "1.0.*", default-features = false }
serde = { version = "1.0.*", features = ["derive"] }
bincode = "1.3.3"
ndarray = { version = "0.15.6", features = ["serde"] }
itertools = "0.10.0"
clap = { version = "4.5.4", features = ["derive"] }
tracing = "0.1"
once_cell = "1.19.0"
futures = "0.3.30"
6 changes: 4 additions & 2 deletions ahnlich/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ path = "src/lib.rs"
flurry = "0.5.1"
sha2 = "0.10.3"
ndarray.workspace = true
bincode.workspace = true
itertools.workspace = true
clap.workspace = true
once_cell.workspace = true
tracing.workspace = true
types = { path = "../types", version = "*" }
tracer = { path = "../tracer", version = "*" }

tokio = { version = "1.37.0", features = ["net", "macros", "io-util", "rt-multi-thread", "sync"] }
tracer = { path = "../tracer", version = "*" }
thiserror = "1.0"

[dev-dependencies]
loom = "0.7.2"
reqwest = "0.12.4"
serde_json = "1.0.116"
tokio = { version = "1.37.0", features = ["io-util", "sync"] }
futures.workspace = true
13 changes: 6 additions & 7 deletions ahnlich/server/src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod similarity;

use std::num::NonZeroUsize;

use types::keyval::{SearchInput, StoreKey};
use types::keyval::StoreKey;
use types::similarity::Algorithm;

use self::{heap::AlgorithmHeapType, similarity::SimilarityFunc};
Expand Down Expand Up @@ -48,7 +48,7 @@ impl Ord for SimilarityVector<'_> {
pub(crate) trait FindSimilarN {
fn find_similar_n<'a>(
&'a self,
search_vector: &SearchInput,
search_vector: &StoreKey,
search_list: impl Iterator<Item = &'a StoreKey>,
n: NonZeroUsize,
) -> Vec<(&'a StoreKey, f64)>;
Expand All @@ -57,17 +57,16 @@ pub(crate) trait FindSimilarN {
impl FindSimilarN for Algorithm {
fn find_similar_n<'a>(
&'a self,
search_vector: &SearchInput,
search_vector: &StoreKey,
search_list: impl Iterator<Item = &'a StoreKey>,
n: NonZeroUsize,
) -> Vec<(&'a StoreKey, f64)> {
let mut heap: AlgorithmHeapType = (self, n).into();

let similarity_function: SimilarityFunc = self.into();
let search_vector = StoreKey(search_vector.clone());

for second_vector in search_list {
let similarity = similarity_function(&search_vector, second_vector);
let similarity = similarity_function(search_vector, second_vector);

let heap_value: SimilarityVector = (second_vector, similarity).into();
heap.push(heap_value)
Expand All @@ -79,7 +78,7 @@ impl FindSimilarN for Algorithm {
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::*;
use crate::fixtures::*;

#[test]
fn test_teststore_find_top_3_similar_words_using_find_nearest_n() {
Expand All @@ -100,7 +99,7 @@ mod tests {
let cosine_algorithm = Algorithm::CosineSimilarity;

let similar_n_search = cosine_algorithm.find_similar_n(
&first_vector.0,
&first_vector,
search_list.iter(),
NonZeroUsize::new(no_similar_values).unwrap(),
);
Expand Down
2 changes: 1 addition & 1 deletion ahnlich/server/src/algorithm/similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn euclidean_distance(first: &StoreKey, second: &StoreKey) -> f64 {
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::*;
use crate::fixtures::*;

#[test]
fn test_find_top_3_similar_words_using_cosine_similarity() {
Expand Down
15 changes: 7 additions & 8 deletions ahnlich/server/src/cli/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,15 @@ pub struct ServerConfig {
pub(crate) log_level: String,
}

impl ServerConfig {
fn new() -> Self {
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: String::from("127.0.0.1"),
port: 1396,
#[cfg(not(test))]
port: 1369,
// allow OS to pick a port
#[cfg(test)]
port: 0,
enable_persistence: false,
persist_location: None,
persistence_intervals: 1000 * 60 * 5,
Expand All @@ -61,8 +65,3 @@ impl ServerConfig {
}
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self::new()
}
}
2 changes: 1 addition & 1 deletion ahnlich/server/src/engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod predicate;
mod store;
pub(crate) mod store;
12 changes: 6 additions & 6 deletions ahnlich/server/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::fmt::Write;
use std::mem::size_of_val;
use std::num::NonZeroUsize;
use std::sync::Arc;
use types::keyval::SearchInput;
use types::keyval::StoreKey;
use types::keyval::StoreName;
use types::keyval::StoreValue;
Expand Down Expand Up @@ -133,7 +132,7 @@ impl StoreHandler {
pub(crate) fn get_sim_in_store(
&self,
store_name: &StoreName,
search_input: SearchInput,
search_input: StoreKey,
closest_n: NonZeroUsize,
algorithm: Algorithm,
condition: Option<PredicateCondition>,
Expand Down Expand Up @@ -383,6 +382,7 @@ impl Store {
/// TODO: Fix nested calculation of sizes using size_of_val
fn size(&self) -> usize {
size_of_val(&self)
+ size_of_val(&self.dimension)
+ size_of_val(&self.id_to_value)
+ self
.id_to_value
Expand All @@ -404,7 +404,7 @@ impl Store {

#[cfg(test)]
mod tests {
use crate::tests::*;
use crate::fixtures::*;
use std::num::NonZeroUsize;

use super::*;
Expand Down Expand Up @@ -863,12 +863,12 @@ mod tests {
StoreInfo {
name: odd_store,
len: 2,
size_in_bytes: 2096,
size_in_bytes: 2104,
},
StoreInfo {
name: even_store,
len: 0,
size_in_bytes: 1728,
size_in_bytes: 1736,
},
])
)
Expand Down Expand Up @@ -929,7 +929,7 @@ mod tests {
value: MetadataValue::new("Chunin".into()),
op: PredicateOp::Equals,
});
let search_input = vectors.get(SEACH_TEXT).unwrap().0.clone();
let search_input = StoreKey(vectors.get(SEACH_TEXT).unwrap().0.clone());
let algorithm = Algorithm::CosineSimilarity;

let closest_n = NonZeroUsize::new(3).unwrap();
Expand Down
9 changes: 8 additions & 1 deletion ahnlich/server/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
use thiserror::Error;
use types::keyval::StoreName;
use types::metadata::MetadataKey;

/// TODO: Move to shared rust types so library can deserialize it from the TCP response
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)]
#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub enum ServerError {
/// Reported when the named predicate (metadata key) is not found in the store;
/// the message suggests reindexing with that predicate.
#[error("Predicate {0} not found in store, attempt reindexing with predicate")]
PredicateNotFound(MetadataKey),
/// Reported when no store with the given name is found.
#[error("Store {0} not found")]
StoreNotFound(StoreName),
/// Reported when a store with the given name already exists.
#[error("Store {0} already exists")]
StoreAlreadyExists(StoreName),
/// Reported when the dimension specified by the input differs from the store's dimension.
#[error("Store dimension is [{store_dimension}], input dimension of [{input_dimension}] was specified")]
StoreDimensionMismatch {
// Dimension the store has.
store_dimension: usize,
// Dimension that was specified by the input.
input_dimension: usize,
},
/// Reported when a query could not be deserialized; carries the underlying error text.
#[error("Could not deserialize query, error is {0}")]
QueryDeserializeError(String),
}
143 changes: 107 additions & 36 deletions ahnlich/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,122 @@ mod engine;
mod errors;
mod network;
mod storage;
pub use crate::cli::ServerConfig;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use crate::cli::ServerConfig;
use crate::engine::store::StoreHandler;
use std::io::Result as IoResult;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tracer::init_tracing;
use types::bincode::BinCodeSerAndDeser;
use types::bincode::LENGTH_HEADER_SIZE;
use types::query::Query;
use types::query::ServerQuery;
use types::server::ServerResponse;
use types::server::ServerResult;

pub async fn run_server(config: ServerConfig) -> std::io::Result<()> {
// setup tracing
if config.enable_tracing {
let otel_url = &config
.otel_endpoint
.unwrap_or("http://127.0.0.1:4317".to_string());
let log_level = &config.log_level;
init_tracing("ahnlich-db", Some(log_level), otel_url);
/// TCP server that owns the bound listener and the store handler handed to
/// each connection task.
#[derive(Debug)]
pub struct Server {
    /// Listener that client connections are accepted from.
    listener: TcpListener,
    /// Store handler kept behind an `Arc` so every connection task can hold its own clone.
    store_handler: Arc<StoreHandler>,
}

impl Server {
/// initializes a server using server configuration
pub async fn new(config: &ServerConfig) -> IoResult<Self> {
let listener =
tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?;
// Setup tracing
if config.enable_tracing {
let otel_url = (config.otel_endpoint)
.to_owned()
.unwrap_or("http://127.0.0.1:4317".to_string());
let log_level = &config.log_level;
init_tracing("ahnlich-db", Some(log_level), &otel_url);
}
// TODO: replace with rules to retrieve store handler from persistence if persistence exist
let store_handler = Arc::new(StoreHandler::new());
Ok(Self {
listener,
store_handler,
})
}

let listener =
tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?;

loop {
let (stream, connect_addr) = listener.accept().await?;
tracing::info!("Connecting to {}", connect_addr);
tokio::spawn(async move {
if let Err(e) = process_stream(stream).await {
tracing::error!("Error handling connection: {}", e)
};
});
/// starts accepting connections using the listener and processing the incoming streams
pub async fn start(&self) -> IoResult<()> {
loop {
let (stream, connect_addr) = self.listener.accept().await?;
tracing::info!("Connecting to {}", connect_addr);
// TODO
// - Spawn a tokio task to handle the command while holding on to a reference to self
// - Convert the incoming bincode in a chunked manner to a Vec<Query>
// - Use store_handler to process the queries
// - Block new incoming connections on shutdown by no longer accepting and then
// cancelling existing ServerTask or forcing them to run to completion

// "inexpensive" to clone store handler as it is an Arc
let task = ServerTask::new(stream, self.store_handler.clone());
tokio::spawn(async move {
if let Err(e) = task.process().await {
tracing::error!("Error handling connection: {}", e)
};
});
}
}
}

#[tracing::instrument]
async fn process_stream(stream: tokio::net::TcpStream) -> Result<(), tokio::io::Error> {
stream.readable().await?;
let mut reader = BufReader::new(stream);
loop {
let mut message = String::new();
let _ = reader.read_line(&mut message).await?;
tracing::info_span!("Sending Messages");
reader.get_mut().write_all(message.as_bytes()).await?;
message.clear();
    /// Returns the socket address the listener is bound to, as reported by the listener.
    pub fn local_addr(&self) -> IoResult<SocketAddr> {
        self.listener.local_addr()
    }
}

#[cfg(test)]
mod tests {
// Import the fixtures for use in tests
pub use super::fixtures::*;
/// Per-connection unit of work: one client stream plus a handle to the stores.
#[derive(Debug)]
struct ServerTask {
    /// Client connection this task reads queries from and writes results to.
    stream: TcpStream,
    /// Reference-counted handle to the store handler.
    store_handler: Arc<StoreHandler>,
}

impl ServerTask {
fn new(stream: TcpStream, store_handler: Arc<StoreHandler>) -> Self {
Self {
stream,
store_handler,
}
}

/// processes messages from a stream
async fn process(self) -> IoResult<()> {
self.stream.readable().await?;
let mut reader = BufReader::new(self.stream);
let mut length_buf = [0u8; LENGTH_HEADER_SIZE];
loop {
reader.read_exact(&mut length_buf).await?;
let data_length = u64::from_be_bytes(length_buf);
let mut data = vec![0u8; data_length as usize];
reader.read_exact(&mut data).await?;
// TODO: Add trace here to catch whenever queries could not be deserialized at all
if let Ok(queries) = ServerQuery::deserialize(false, &data) {
// TODO: Pass in store_handler and use to respond to queries
let results = Self::handle(queries.into_inner());
if let Ok(binary_results) = results.serialize() {
reader.get_mut().write_all(&binary_results).await?;
}
}
}
}

fn handle(queries: Vec<Query>) -> ServerResult {
let mut result = ServerResult::with_capacity(queries.len());
for query in queries {
result.push(match query {
Query::Ping => Ok(ServerResponse::Pong),
Query::InfoServer => Ok(ServerResponse::Unit),
_ => Err("Response not implemented".to_string()),
})
}
result
}
}

#[cfg(test)]
Expand Down
3 changes: 2 additions & 1 deletion ahnlich/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
let cli = Cli::parse();
match &cli.command {
Commands::Run(config) => {
server::run_server(config.to_owned()).await?;
let server = server::Server::new(config).await?;
server.start().await?;
}
}
Ok(())
Expand Down
Loading

0 comments on commit 7fc0f62

Please sign in to comment.