Skip to content

Commit

Permalink
Handling for INFOSERVER and LISTCLIENTS
Browse files Browse the repository at this point in the history
  • Loading branch information
deven96 committed Jun 3, 2024
1 parent d635d56 commit a447e01
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 51 deletions.
10 changes: 9 additions & 1 deletion ahnlich/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ env_logger.workspace = true
log.workspace = true
thiserror = "1.0"
types = { path = "../types", version = "*" }
tokio = { version = "1.37.0", features = ["net", "macros", "io-util", "rt-multi-thread", "sync"] }
tokio = { version = "1.37.0", features = [
"net",
"macros",
"io-util",
"rt-multi-thread",
"sync",
"signal"
]}
tokio-util = "0.7.11"

[dev-dependencies]
loom = "0.7.2"
Expand Down
38 changes: 38 additions & 0 deletions ahnlich/server/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use flurry::HashSet as ConcurrentHashSet;
use std::collections::HashSet as StdHashSet;
use std::net::SocketAddr;
use std::time::SystemTime;
use types::server::ConnectedClient;

#[derive(Debug)]
pub(crate) struct ClientHandler {
clients: ConcurrentHashSet<ConnectedClient>,
}

impl ClientHandler {
pub fn new() -> Self {
Self {
clients: ConcurrentHashSet::new(),
}
}

pub fn connect(&self, addr: SocketAddr) -> ConnectedClient {
let client = ConnectedClient {
address: format!("{addr}"),
time_connected: SystemTime::now(),
};
let pinned = self.clients.pin();
pinned.insert(client.clone());
client
}

pub fn disconnect(&self, client: &ConnectedClient) {
let pinned = self.clients.pin();
pinned.remove(client);
}

pub fn list(&self) -> StdHashSet<ConnectedClient> {
let pinned = self.clients.pin();
pinned.into_iter().cloned().collect()
}
}
148 changes: 110 additions & 38 deletions ahnlich/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,38 @@
#![allow(clippy::size_of_ref)]
mod algorithm;
pub mod cli;
mod client;
mod engine;
mod errors;
mod network;
mod storage;
use crate::cli::ServerConfig;
use crate::client::ClientHandler;
use crate::engine::store::StoreHandler;
use std::io::Result as IoResult;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::select;
use tokio::signal;
use tokio_util::sync::CancellationToken;
use types::bincode::BinCodeSerAndDeser;
use types::bincode::LENGTH_HEADER_SIZE;
use types::query::Query;
use types::query::ServerQuery;
use types::server::ConnectedClient;
use types::server::ServerInfo;
use types::server::ServerResponse;
use types::server::ServerResult;

#[derive(Debug)]
pub struct Server {
listener: TcpListener,
store_handler: Arc<StoreHandler>,
client_handler: Arc<ClientHandler>,
shutdown_token: CancellationToken,
}

impl Server {
Expand All @@ -34,85 +43,148 @@ impl Server {
tokio::net::TcpListener::bind(format!("{}:{}", &config.host, &config.port)).await?;
// TODO: replace with rules to retrieve store handler from persistence if persistence exist
let store_handler = Arc::new(StoreHandler::new());
let client_handler = Arc::new(ClientHandler::new());
let shutdown_token = CancellationToken::new();
Ok(Self {
listener,
store_handler,
client_handler,
shutdown_token,
})
}

/// starts accepting connections using the listener and processing the incoming streams
///
/// listens for a ctrl_c signals to cancel spawned tasks
pub async fn start(&self) -> IoResult<()> {
let server_addr = self.local_addr()?;
loop {
let (stream, connect_addr) = self.listener.accept().await?;
log::info!("Connecting to {}", connect_addr);
// TODO
// - Spawn a tokio task to handle the command while holding on to a reference to self
// - Convert the incoming bincode in a chunked manner to a Vec<Query>
// - Use store_handler to process the queries
// - Block new incoming connections on shutdown by no longer accepting and then
// cancelling existing ServerTask or forcing them to run to completion
select! {
_ = signal::ctrl_c() => {
self.shutdown();
break Ok(());
}
Ok((stream, connect_addr)) = self.listener.accept() => {
log::info!("Connecting to {}", connect_addr);
// - Spawn a tokio task to handle the command while holding on to a reference to self
// - Convert the incoming bincode in a chunked manner to a Vec<Query>
// - Use store_handler to process the queries
// - Block new incoming connections on shutdown by no longer accepting and then
// cancelling existing ServerTask or forcing them to run to completion

// "inexpensive" to clone store handler as it is an Arc
let task = ServerTask::new(stream, self.store_handler.clone());
tokio::spawn(async move {
if let Err(e) = task.process().await {
log::error!("Error handling connection: {}", e)
};
});
let mut task = self.create_task(stream, server_addr)?;
tokio::spawn(async move {
if let Err(e) = task.process().await {
log::error!("Error handling connection: {}", e)
};
});
}
}
}
}

/// stops all tasks and performs cleanup
pub fn shutdown(&self) {
// TODO: Add cleanup for instance persistence
self.shutdown_token.cancel()
}

pub fn local_addr(&self) -> IoResult<SocketAddr> {
self.listener.local_addr()
}

fn create_task(&self, stream: TcpStream, server_addr: SocketAddr) -> IoResult<ServerTask> {
let connected_client = self.client_handler.connect(stream.peer_addr()?);
let reader = BufReader::new(stream);
// add client to client_handler
Ok(ServerTask {
reader,
server_addr,
connected_client,
shutdown_token: self.shutdown_token.clone(),
// "inexpensive" to clone handlers they can be passed around in an Arc
client_handler: self.client_handler.clone(),
store_handler: self.store_handler.clone(),
})
}
}

#[derive(Debug)]
struct ServerTask {
stream: TcpStream,
shutdown_token: CancellationToken,
server_addr: SocketAddr,
reader: BufReader<TcpStream>,
store_handler: Arc<StoreHandler>,
client_handler: Arc<ClientHandler>,
connected_client: ConnectedClient,
}

impl ServerTask {
fn new(stream: TcpStream, store_handler: Arc<StoreHandler>) -> Self {
Self {
stream,
store_handler,
}
}

/// processes messages from a stream
async fn process(self) -> IoResult<()> {
self.stream.readable().await?;
let mut reader = BufReader::new(self.stream);
async fn process(&mut self) -> IoResult<()> {
let mut length_buf = [0u8; LENGTH_HEADER_SIZE];

loop {
reader.read_exact(&mut length_buf).await?;
let data_length = u64::from_be_bytes(length_buf);
let mut data = vec![0u8; data_length as usize];
reader.read_exact(&mut data).await?;
// TODO: Add trace here to catch whenever queries could not be deserialized at all
if let Ok(queries) = ServerQuery::deserialize(false, &data) {
// TODO: Pass in store_handler and use to respond to queries
let results = Self::handle(queries.into_inner());
if let Ok(binary_results) = results.serialize() {
reader.get_mut().write_all(&binary_results).await?;
select! {
_ = self.shutdown_token.cancelled() => {
break;
}
res = self.reader.read_exact(&mut length_buf) => {
match res {
// reader was closed
Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
log::debug!("Client {} hung up buffered stream", self.connected_client.address);
break;
}
Err(e) => {
log::error!("Error reading from task buffered stream: {}", e);
}
Ok(_) => {
let data_length = u64::from_be_bytes(length_buf);
let mut data = vec![0u8; data_length as usize];
self.reader.read_exact(&mut data).await?;
// TODO: Add trace here to catch whenever queries could not be deserialized at all
if let Ok(queries) = ServerQuery::deserialize(false, &data) {
// TODO: Pass in store_handler and use to respond to queries
let results = self.handle(queries.into_inner());
if let Ok(binary_results) = results.serialize() {
self.reader.get_mut().write_all(&binary_results).await?;
}
}
}
}
}
}
}
Ok(())
}

fn handle(queries: Vec<Query>) -> ServerResult {
fn handle(&self, queries: Vec<Query>) -> ServerResult {
let mut result = ServerResult::with_capacity(queries.len());
for query in queries {
result.push(match query {
Query::Ping => Ok(ServerResponse::Pong),
Query::InfoServer => Ok(ServerResponse::Unit),
Query::InfoServer => Ok(ServerResponse::InfoServer(self.server_info())),
Query::ListClients => Ok(ServerResponse::ClientList(self.client_handler.list())),
_ => Err("Response not implemented".to_string()),
})
}
result
}

fn server_info(&self) -> ServerInfo {
ServerInfo {
address: format!("{}", self.server_addr),
version: types::VERSION.to_string(),
r#type: types::server::ServerType::Database,
}
}
}

impl Drop for ServerTask {
fn drop(&mut self) {
self.client_handler.disconnect(&self.connected_client);
}
}

#[cfg(test)]
Expand Down
73 changes: 62 additions & 11 deletions ahnlich/server/tests/server_test.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,104 @@
use futures::future::join_all;
use server::cli::ServerConfig;
use std::net::SocketAddr;
use std::collections::HashSet;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
use types::bincode::BinCodeSerAndDeser;
use types::query::Query;
use types::query::ServerQuery;
use types::server::ConnectedClient;
use types::server::ServerInfo;
use types::server::ServerResponse;
use types::server::ServerResult;

#[tokio::test]
async fn test_run_server_echos() {
async fn test_server_client_info() {
let server = server::Server::new(&ServerConfig::default())
.await
.expect("Could not initialize server");
let address = server.local_addr().expect("Could not get local addr");
let _ = tokio::spawn(async move { server.start().await });
// Allow some time for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;
//
// Connect to the server multiple times
let first_stream = TcpStream::connect(address).await.unwrap();
let other_stream = TcpStream::connect(address).await.unwrap();
let first_stream_addr = first_stream.local_addr().unwrap();
let expected_response = HashSet::from_iter([
ConnectedClient {
address: format!("{first_stream_addr}"),
time_connected: SystemTime::now(),
},
ConnectedClient {
address: format!("{}", other_stream.local_addr().unwrap()),
time_connected: SystemTime::now(),
},
]);
let message = ServerQuery::from_queries(&[Query::ListClients]);
let mut expected = ServerResult::with_capacity(1);
expected.push(Ok(ServerResponse::ClientList(expected_response.clone())));
let mut reader = BufReader::new(first_stream);
query_server_assert_result(&mut reader, message, expected.clone()).await;
// drop other stream and see if it reflects
drop(other_stream);
let expected_response = HashSet::from_iter([ConnectedClient {
address: format!("{first_stream_addr}"),
time_connected: SystemTime::now(),
}]);
let message = ServerQuery::from_queries(&[Query::ListClients]);
let mut expected = ServerResult::with_capacity(1);
expected.push(Ok(ServerResponse::ClientList(expected_response.clone())));
query_server_assert_result(&mut reader, message, expected.clone()).await;
}

#[tokio::test]
async fn test_run_server_echos() {
let server = server::Server::new(&ServerConfig::default())
.await
.expect("Could not initialize server");
let address = server.local_addr().expect("Could not get local addr");
let _ = tokio::spawn(async move { server.start().await });
// Allow some time for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;
let tasks = vec![
tokio::spawn(async move {
let message = ServerQuery::from_queries(&[Query::InfoServer, Query::Ping]);
let mut expected = ServerResult::with_capacity(2);
expected.push(Ok(ServerResponse::Unit));
expected.push(Ok(ServerResponse::InfoServer(ServerInfo {
address: "127.0.0.1:1369".to_string(),
version: types::VERSION.to_string(),
r#type: types::server::ServerType::Database,
})));
expected.push(Ok(ServerResponse::Pong));
query_server_assert_result(address, message, expected).await
let stream = TcpStream::connect(address).await.unwrap();
let mut reader = BufReader::new(stream);
query_server_assert_result(&mut reader, message, expected).await
}),
tokio::spawn(async move {
let message = ServerQuery::from_queries(&[Query::Ping, Query::InfoServer]);
let mut expected = ServerResult::with_capacity(2);
expected.push(Ok(ServerResponse::Pong));
expected.push(Ok(ServerResponse::Unit));
query_server_assert_result(address, message, expected).await
expected.push(Ok(ServerResponse::InfoServer(ServerInfo {
address: "127.0.0.1:1369".to_string(),
version: types::VERSION.to_string(),
r#type: types::server::ServerType::Database,
})));
let stream = TcpStream::connect(address).await.unwrap();
let mut reader = BufReader::new(stream);
query_server_assert_result(&mut reader, message, expected).await
}),
];
join_all(tasks).await;
}

async fn query_server_assert_result(
server_addr: SocketAddr,
reader: &mut BufReader<TcpStream>,
query: ServerQuery,
expected_result: ServerResult,
) {
// Connect to the server
let stream = TcpStream::connect(server_addr).await.unwrap();
let mut reader = BufReader::new(stream);

// Message to send
let serialized_message = query.serialize().unwrap();

Expand Down
Loading

0 comments on commit a447e01

Please sign in to comment.