diff --git a/ahnlich/server/Cargo.toml b/ahnlich/server/Cargo.toml
index 8e3dc306..8c4381c0 100644
--- a/ahnlich/server/Cargo.toml
+++ b/ahnlich/server/Cargo.toml
@@ -23,7 +23,15 @@ env_logger.workspace = true
 log.workspace = true
 thiserror = "1.0"
 types = { path = "../types", version = "*" }
-tokio = { version = "1.37.0", features = ["net", "macros", "io-util", "rt-multi-thread", "sync"] }
+tokio = { version = "1.37.0", features = [
+    "net",
+    "macros",
+    "io-util",
+    "rt-multi-thread",
+    "sync",
+    "signal"
+]}
+tokio-util = "0.7.11"
 
 [dev-dependencies]
 loom = "0.7.2"
diff --git a/ahnlich/server/src/client/mod.rs b/ahnlich/server/src/client/mod.rs
new file mode 100644
index 00000000..5dc203ac
--- /dev/null
+++ b/ahnlich/server/src/client/mod.rs
@@ -0,0 +1,38 @@
+use flurry::HashSet as ConcurrentHashSet;
+use std::collections::HashSet as StdHashSet;
+use std::net::SocketAddr;
+use std::time::SystemTime;
+use types::server::ConnectedClient;
+
+#[derive(Debug)]
+pub(crate) struct ClientHandler {
+    clients: ConcurrentHashSet<ConnectedClient>,
+}
+
+impl ClientHandler {
+    pub fn new() -> Self {
+        Self {
+            clients: ConcurrentHashSet::new(),
+        }
+    }
+
+    pub fn connect(&self, addr: SocketAddr) -> ConnectedClient {
+        let client = ConnectedClient {
+            address: format!("{addr}"),
+            time_connected: SystemTime::now(),
+        };
+        let pinned = self.clients.pin();
+        pinned.insert(client.clone());
+        client
+    }
+
+    pub fn disconnect(&self, client: &ConnectedClient) {
+        let pinned = self.clients.pin();
+        pinned.remove(client);
+    }
+
+    pub fn list(&self) -> StdHashSet<ConnectedClient> {
+        let pinned = self.clients.pin();
+        pinned.into_iter().cloned().collect()
+    }
+}
diff --git a/ahnlich/server/src/lib.rs b/ahnlich/server/src/lib.rs
index 26cddc97..e702b3c1 100644
--- a/ahnlich/server/src/lib.rs
+++ b/ahnlich/server/src/lib.rs
@@ -2,11 +2,13 @@
 #![allow(clippy::size_of_ref)]
 mod algorithm;
 pub mod cli;
+mod client;
 mod engine;
 mod errors;
 mod network;
 mod storage;
 use crate::cli::ServerConfig;
+use crate::client::ClientHandler;
 use crate::engine::store::StoreHandler;
 use std::io::Result as IoResult;
 use std::net::SocketAddr;
@@ -14,10 +16,15 @@ use std::sync::Arc;
 use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
 use tokio::net::TcpListener;
 use tokio::net::TcpStream;
+use tokio::select;
+use tokio::signal;
+use tokio_util::sync::CancellationToken;
 use types::bincode::BinCodeSerAndDeser;
 use types::bincode::LENGTH_HEADER_SIZE;
 use types::query::Query;
 use types::query::ServerQuery;
+use types::server::ConnectedClient;
+use types::server::ServerInfo;
 use types::server::ServerResponse;
 use types::server::ServerResult;
 
@@ -25,6 +32,8 @@ use types::server::ServerResult;
 pub struct Server {
     listener: TcpListener,
     store_handler: Arc<StoreHandler>,
+    client_handler: Arc<ClientHandler>,
+    shutdown_token: CancellationToken,
 }
 
 impl Server {
@@ -34,85 +43,148 @@ impl Server {
             tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?;
         // TODO: replace with rules to retrieve store handler from persistence if persistence exist
         let store_handler = Arc::new(StoreHandler::new());
+        let client_handler = Arc::new(ClientHandler::new());
+        let shutdown_token = CancellationToken::new();
         Ok(Self {
             listener,
             store_handler,
+            client_handler,
+            shutdown_token,
         })
     }
 
     /// starts accepting connections using the listener and processing the incoming streams
+    ///
+    /// listens for a ctrl_c signal to cancel spawned tasks
     pub async fn start(&self) -> IoResult<()> {
+        let server_addr = self.local_addr()?;
        loop {
-            let (stream, connect_addr) = self.listener.accept().await?;
-            log::info!("Connecting to {}", connect_addr);
-            // TODO
-            // - Spawn a tokio task to handle the command while holding on to a reference to self
-            // - Convert the incoming bincode in a chunked manner to a Vec<Query>
-            // - Use store_handler to process the queries
-            // - Block new incoming connections on shutdown by no longer accepting and then
-            // cancelling existing ServerTask or forcing them to run to completion
+            select! {
+                _ = signal::ctrl_c() => {
+                    self.shutdown();
+                    break Ok(());
+                }
+                Ok((stream, connect_addr)) = self.listener.accept() => {
+                    log::info!("Connecting to {}", connect_addr);
+                    // - Spawn a tokio task to handle the command while holding on to a reference to self
+                    // - Convert the incoming bincode in a chunked manner to a Vec<Query>
+                    // - Use store_handler to process the queries
+                    // - Block new incoming connections on shutdown by no longer accepting and then
+                    // cancelling existing ServerTask or forcing them to run to completion
 
-            // "inexpensive" to clone store handler as it is an Arc
-            let task = ServerTask::new(stream, self.store_handler.clone());
-            tokio::spawn(async move {
-                if let Err(e) = task.process().await {
-                    log::error!("Error handling connection: {}", e)
-                };
-            });
+                    let mut task = self.create_task(stream, server_addr)?;
+                    tokio::spawn(async move {
+                        if let Err(e) = task.process().await {
+                            log::error!("Error handling connection: {}", e)
+                        };
+                    });
+                }
+            }
         }
     }
 
+    /// stops all tasks and performs cleanup
+    pub fn shutdown(&self) {
+        // TODO: Add cleanup for instance persistence
+        self.shutdown_token.cancel()
+    }
+
     pub fn local_addr(&self) -> IoResult<SocketAddr> {
         self.listener.local_addr()
     }
+
+    fn create_task(&self, stream: TcpStream, server_addr: SocketAddr) -> IoResult<ServerTask> {
+        let connected_client = self.client_handler.connect(stream.peer_addr()?);
+        let reader = BufReader::new(stream);
+        // add client to client_handler
+        Ok(ServerTask {
+            reader,
+            server_addr,
+            connected_client,
+            shutdown_token: self.shutdown_token.clone(),
+            // "inexpensive" to clone handlers as they can be passed around in an Arc
+            client_handler: self.client_handler.clone(),
+            store_handler: self.store_handler.clone(),
+        })
+    }
 }
 
 #[derive(Debug)]
 struct ServerTask {
-    stream: TcpStream,
+    shutdown_token: CancellationToken,
+    server_addr: SocketAddr,
+    reader: BufReader<TcpStream>,
     store_handler: Arc<StoreHandler>,
+    client_handler: Arc<ClientHandler>,
+    connected_client: ConnectedClient,
 }
 
 impl ServerTask {
-    fn new(stream: TcpStream, store_handler: Arc<StoreHandler>) -> Self {
-        Self {
-            stream,
-            store_handler,
-        }
-    }
-
     /// processes messages from a stream
-    async fn process(self) -> IoResult<()> {
-        self.stream.readable().await?;
-        let mut reader = BufReader::new(self.stream);
+    async fn process(&mut self) -> IoResult<()> {
         let mut length_buf = [0u8; LENGTH_HEADER_SIZE];
+
         loop {
-            reader.read_exact(&mut length_buf).await?;
-            let data_length = u64::from_be_bytes(length_buf);
-            let mut data = vec![0u8; data_length as usize];
-            reader.read_exact(&mut data).await?;
-            // TODO: Add trace here to catch whenever queries could not be deserialized at all
-            if let Ok(queries) = ServerQuery::deserialize(false, &data) {
-                // TODO: Pass in store_handler and use to respond to queries
-                let results = Self::handle(queries.into_inner());
-                if let Ok(binary_results) = results.serialize() {
-                    reader.get_mut().write_all(&binary_results).await?;
+            select! {
+                _ = self.shutdown_token.cancelled() => {
+                    break;
+                }
+                res = self.reader.read_exact(&mut length_buf) => {
+                    match res {
+                        // reader was closed
+                        Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
+                            log::debug!("Client {} hung up buffered stream", self.connected_client.address);
+                            break;
+                        }
+                        Err(e) => {
+                            log::error!("Error reading from task buffered stream: {}", e);
+                        }
+                        Ok(_) => {
+                            let data_length = u64::from_be_bytes(length_buf);
+                            let mut data = vec![0u8; data_length as usize];
+                            self.reader.read_exact(&mut data).await?;
+                            // TODO: Add trace here to catch whenever queries could not be deserialized at all
+                            if let Ok(queries) = ServerQuery::deserialize(false, &data) {
+                                // TODO: Pass in store_handler and use to respond to queries
+                                let results = self.handle(queries.into_inner());
+                                if let Ok(binary_results) = results.serialize() {
+                                    self.reader.get_mut().write_all(&binary_results).await?;
+                                }
+                            }
+                        }
+                    }
                 }
             }
         }
+        Ok(())
     }
 
-    fn handle(queries: Vec<Query>) -> ServerResult {
+    fn handle(&self, queries: Vec<Query>) -> ServerResult {
         let mut result = ServerResult::with_capacity(queries.len());
         for query in queries {
             result.push(match query {
                 Query::Ping => Ok(ServerResponse::Pong),
-                Query::InfoServer => Ok(ServerResponse::Unit),
+                Query::InfoServer => Ok(ServerResponse::InfoServer(self.server_info())),
+                Query::ListClients => Ok(ServerResponse::ClientList(self.client_handler.list())),
                 _ => Err("Response not implemented".to_string()),
             })
         }
         result
     }
+
+    fn server_info(&self) -> ServerInfo {
+        ServerInfo {
+            address: format!("{}", self.server_addr),
+            version: types::VERSION.to_string(),
+            r#type: types::server::ServerType::Database,
+        }
+    }
+}
+
+impl Drop for ServerTask {
+    fn drop(&mut self) {
+        self.client_handler.disconnect(&self.connected_client);
+    }
 }
 
 #[cfg(test)]
diff --git a/ahnlich/server/tests/server_test.rs b/ahnlich/server/tests/server_test.rs
index 6d50ef53..7c0c6158 100644
--- a/ahnlich/server/tests/server_test.rs
+++ b/ahnlich/server/tests/server_test.rs
@@ -1,53 +1,104 @@
 use futures::future::join_all;
 use server::cli::ServerConfig;
-use std::net::SocketAddr;
+use std::collections::HashSet;
+use std::time::SystemTime;
 use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
 use tokio::net::TcpStream;
 use tokio::time::{timeout, Duration};
 use types::bincode::BinCodeSerAndDeser;
 use types::query::Query;
 use types::query::ServerQuery;
+use types::server::ConnectedClient;
+use types::server::ServerInfo;
 use types::server::ServerResponse;
 use types::server::ServerResult;
 
 #[tokio::test]
-async fn test_run_server_echos() {
+async fn test_server_client_info() {
     let server = server::Server::new(&ServerConfig::default())
         .await
         .expect("Could not initialize server");
     let address = server.local_addr().expect("Could not get local addr");
     let _ = tokio::spawn(async move { server.start().await });
+    // Allow some time for the server to start
+    tokio::time::sleep(Duration::from_millis(100)).await;
+    //
+    // Connect to the server multiple times
+    let first_stream = TcpStream::connect(address).await.unwrap();
+    let other_stream = TcpStream::connect(address).await.unwrap();
+    let first_stream_addr = first_stream.local_addr().unwrap();
+    let expected_response = HashSet::from_iter([
+        ConnectedClient {
+            address: format!("{first_stream_addr}"),
+            time_connected: SystemTime::now(),
+        },
+        ConnectedClient {
+            address: format!("{}", other_stream.local_addr().unwrap()),
+            time_connected: SystemTime::now(),
+        },
+    ]);
+    let message = ServerQuery::from_queries(&[Query::ListClients]);
+    let mut expected = ServerResult::with_capacity(1);
+    expected.push(Ok(ServerResponse::ClientList(expected_response.clone())));
+    let mut reader = BufReader::new(first_stream);
+    query_server_assert_result(&mut reader, message, expected.clone()).await;
+    // drop other stream and see if it reflects
+    drop(other_stream);
+    let expected_response = HashSet::from_iter([ConnectedClient {
+        address: format!("{first_stream_addr}"),
+        time_connected: SystemTime::now(),
+    }]);
+    let message = ServerQuery::from_queries(&[Query::ListClients]);
+    let mut expected = ServerResult::with_capacity(1);
+    expected.push(Ok(ServerResponse::ClientList(expected_response.clone())));
+    query_server_assert_result(&mut reader, message, expected.clone()).await;
+}
+
+#[tokio::test]
+async fn test_run_server_echos() {
+    let server = server::Server::new(&ServerConfig::default())
+        .await
+        .expect("Could not initialize server");
+    let address = server.local_addr().expect("Could not get local addr");
+    let _ = tokio::spawn(async move { server.start().await });
     // Allow some time for the server to start
     tokio::time::sleep(Duration::from_millis(100)).await;
     let tasks = vec![
         tokio::spawn(async move {
             let message = ServerQuery::from_queries(&[Query::InfoServer, Query::Ping]);
             let mut expected = ServerResult::with_capacity(2);
-            expected.push(Ok(ServerResponse::Unit));
+            expected.push(Ok(ServerResponse::InfoServer(ServerInfo {
+                address: "127.0.0.1:1369".to_string(),
+                version: types::VERSION.to_string(),
+                r#type: types::server::ServerType::Database,
+            })));
             expected.push(Ok(ServerResponse::Pong));
-            query_server_assert_result(address, message, expected).await
+            let stream = TcpStream::connect(address).await.unwrap();
+            let mut reader = BufReader::new(stream);
+            query_server_assert_result(&mut reader, message, expected).await
         }),
         tokio::spawn(async move {
             let message = ServerQuery::from_queries(&[Query::Ping, Query::InfoServer]);
             let mut expected = ServerResult::with_capacity(2);
             expected.push(Ok(ServerResponse::Pong));
-            expected.push(Ok(ServerResponse::Unit));
-            query_server_assert_result(address, message, expected).await
+            expected.push(Ok(ServerResponse::InfoServer(ServerInfo {
+                address: "127.0.0.1:1369".to_string(),
+                version: types::VERSION.to_string(),
+                r#type: types::server::ServerType::Database,
+            })));
+            let stream = TcpStream::connect(address).await.unwrap();
+            let mut reader = BufReader::new(stream);
+            query_server_assert_result(&mut reader, message, expected).await
         }),
     ];
     join_all(tasks).await;
 }
 
 async fn query_server_assert_result(
-    server_addr: SocketAddr,
+    reader: &mut BufReader<TcpStream>,
     query: ServerQuery,
     expected_result: ServerResult,
 ) {
-    // Connect to the server
-    let stream = TcpStream::connect(server_addr).await.unwrap();
-    let mut reader = BufReader::new(stream);
-
     // Message to send
     let serialized_message = query.serialize().unwrap();
diff --git a/ahnlich/types/src/lib.rs b/ahnlich/types/src/lib.rs
index 45669774..60776df1 100644
--- a/ahnlich/types/src/lib.rs
+++ b/ahnlich/types/src/lib.rs
@@ -5,3 +5,5 @@ pub mod predicate;
 pub mod query;
 pub mod server;
 pub mod similarity;
+
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
diff --git a/ahnlich/types/src/server.rs b/ahnlich/types/src/server.rs
index 4ca9cc5d..c58edd03 100644
--- a/ahnlich/types/src/server.rs
+++ b/ahnlich/types/src/server.rs
@@ -2,6 +2,9 @@ use crate::bincode::BinCodeSerAndDeser;
 use serde::Deserialize;
 use serde::Serialize;
 use std::collections::HashSet;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::time::SystemTime;
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub enum ServerResponse {
@@ -10,12 +13,46 @@ pub enum ServerResponse {
     // Return current server info
     Pong,
     // List of connected clients. Potentially outdated at the point of read
     ClientList(HashSet<ConnectedClient>),
+    InfoServer(ServerInfo),
     // TODO: Define return types for queries, e.t.c
 }
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
+pub struct ServerInfo {
+    pub address: String,
+    pub version: String,
+    pub r#type: ServerType,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
+pub enum ServerType {
+    Database,
+    AI,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialOrd, Ord)]
 pub struct ConnectedClient {
     pub address: String,
+    // NOTE: We are using System specific time so the time marked by clients cannot be relied on to
+    // be monotonic and the size depends on operating system
+    pub time_connected: SystemTime,
+}
+
+// NOTE: ConnectedClient should be unique purely by address assuming we are not doing any TCP magic
+// to allow port reuse
+impl Hash for ConnectedClient {
+    fn hash<H>(&self, state: &mut H)
+    where
+        H: Hasher,
+    {
+        self.address.hash(state)
+    }
+}
+
+impl PartialEq for ConnectedClient {
+    fn eq(&self, other: &Self) -> bool {
+        self.address.eq(&other.address)
+    }
 }
 
 // ServerResult: Given that an array of queries are sent in, we expect that an array of responses
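
As a usage sketch only, not part of the patch: a minimal client that exercises the new ListClients query over the wire format the read loop above expects, mirroring the query_server_assert_result helper in the tests. The list_clients helper and its address argument are hypothetical; it assumes serialize() already emits the 8-byte big-endian length header plus bincode payload, and that ServerResult exposes the same deserialize(false, &bytes) shape used for ServerQuery above.

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use types::bincode::{BinCodeSerAndDeser, LENGTH_HEADER_SIZE};
use types::query::{Query, ServerQuery};
use types::server::ServerResult;

// Hypothetical helper: send a ListClients query and return the deserialized result.
async fn list_clients(addr: &str) -> std::io::Result<Option<ServerResult>> {
    let stream = TcpStream::connect(addr).await?;
    let mut reader = BufReader::new(stream);

    // Assumption: serialize() produces the length-prefixed bincode that the
    // server's read_exact(header) + read_exact(body) loop expects
    let message = ServerQuery::from_queries(&[Query::ListClients])
        .serialize()
        .expect("query should serialize");
    reader.get_mut().write_all(&message).await?;

    // Read the response the same way the server reads requests:
    // an 8-byte big-endian length header, then that many payload bytes
    let mut length_buf = [0u8; LENGTH_HEADER_SIZE];
    reader.read_exact(&mut length_buf).await?;
    let data_length = u64::from_be_bytes(length_buf);
    let mut data = vec![0u8; data_length as usize];
    reader.read_exact(&mut data).await?;

    // Assumption: ServerResult implements BinCodeSerAndDeser the same way ServerQuery does
    Ok(ServerResult::deserialize(false, &data).ok())
}
```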