From 872e44f14dc36c53a477af9c6d71b4cc39d04691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?John=20Arg=C3=A9rus?= Date: Tue, 5 Nov 2024 13:30:54 +0100 Subject: [PATCH] Add support for unix domain sockets --- databroker/Cargo.toml | 4 +- databroker/src/grpc/server.rs | 79 +++++++++++++++++++++++------------ databroker/src/main.rs | 63 +++++++++++++++++++++++++++- databroker/tests/world/mod.rs | 4 +- 4 files changed, 119 insertions(+), 31 deletions(-) diff --git a/databroker/Cargo.toml b/databroker/Cargo.toml index 492b0db7..9955764c 100644 --- a/databroker/Cargo.toml +++ b/databroker/Cargo.toml @@ -59,10 +59,10 @@ glob-match = "0.2.1" jemallocator = { version = "0.5.0", optional = true } lazy_static = "1.4.0" thiserror = "1.0.47" +futures = { version = "0.3.28" } # VISS axum = { version = "0.6.20", optional = true, features = ["ws"] } -futures = { version = "0.3.28", optional = true } chrono = { version = "0.4.31", optional = true, features = ["std"] } uuid = { version = "1.4.1", optional = true, features = ["v4"] } @@ -74,7 +74,7 @@ sd-notify = "0.4.1" default = ["tls"] tls = ["tonic/tls"] jemalloc = ["dep:jemallocator"] -viss = ["dep:axum", "dep:chrono", "dep:futures", "dep:uuid"] +viss = ["dep:axum", "dep:chrono", "dep:uuid"] libtest = [] [build-dependencies] diff --git a/databroker/src/grpc/server.rs b/databroker/src/grpc/server.rs index 8bc282ca..5fe24264 100644 --- a/databroker/src/grpc/server.rs +++ b/databroker/src/grpc/server.rs @@ -13,11 +13,15 @@ use std::{convert::TryFrom, future::Future, time::Duration}; -use tokio::net::TcpListener; -use tokio_stream::wrappers::TcpListenerStream; -use tonic::transport::Server; +use futures::Stream; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpListener, UnixListener}, +}; +use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream}; #[cfg(feature = "tls")] use tonic::transport::ServerTlsConfig; +use tonic::transport::{server::Connected, Server}; use tracing::{debug, info}; use databroker_proto::{kuksa, sdv}; @@ -34,7 +38,7 @@ pub enum ServerTLS { Enabled { tls_config: ServerTlsConfig }, } -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub enum Api { KuksaValV1, SdvDatabrokerV1, @@ -95,7 +99,7 @@ where databroker.shutdown().await; } -pub async fn serve( +pub async fn serve_tcp( addr: impl Into, broker: broker::DataBroker, #[cfg(feature = "tls")] server_tls: ServerTLS, @@ -109,25 +113,14 @@ where let socket_addr = addr.into(); let listener = TcpListener::bind(socket_addr).await?; - /* On Linux systems try to notify daemon readiness to systemd. - * This function determines whether the a system is using systemd - * or not, so it is safe to use on non-systemd systems as well. - */ - #[cfg(target_os = "linux")] - { - match sd_notify::booted() { - Ok(true) => { - info!("Notifying systemd that the service is ready"); - sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; - } - _ => { - debug!("System is not using systemd, will not try to notify"); - } - } + if let Ok(addr) = listener.local_addr() { + info!("Listening on {}", addr); } + let incoming = TcpListenerStream::new(listener); + serve_with_incoming_shutdown( - listener, + incoming, broker, #[cfg(feature = "tls")] server_tls, @@ -138,10 +131,9 @@ where .await } -pub async fn serve_with_incoming_shutdown( - listener: TcpListener, +pub async fn serve_uds( + path: impl AsRef, broker: broker::DataBroker, - #[cfg(feature = "tls")] server_tls: ServerTLS, apis: &[Api], authorization: Authorization, signal: F, @@ -149,12 +141,45 @@ pub async fn serve_with_incoming_shutdown( where F: Future, { - broker.start_housekeeping_task(); + let listener = UnixListener::bind(path)?; + if let Ok(addr) = listener.local_addr() { - info!("Listening on {}", addr); + match addr.as_pathname() { + Some(pathname) => info!("Listening on unix socket at {}", pathname.display()), + None => info!("Listening on unix socket (unknown path)"), + } } - let incoming = TcpListenerStream::new(listener); + let incoming = UnixListenerStream::new(listener); + + serve_with_incoming_shutdown( + incoming, + broker, + ServerTLS::Disabled, + apis, + authorization, + signal, + ) + .await +} + +pub async fn serve_with_incoming_shutdown( + incoming: I, + broker: broker::DataBroker, + #[cfg(feature = "tls")] server_tls: ServerTLS, + apis: &[Api], + authorization: Authorization, + signal: F, +) -> Result<(), Box> +where + F: Future, + I: Stream>, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, + IE: Into>, +{ + broker.start_housekeeping_task(); + let mut server = Server::builder() .http2_keepalive_interval(Some(Duration::from_secs(10))) .http2_keepalive_timeout(Some(Duration::from_secs(20))); diff --git a/databroker/src/main.rs b/databroker/src/main.rs index f141c305..0fed0694 100644 --- a/databroker/src/main.rs +++ b/databroker/src/main.rs @@ -15,6 +15,10 @@ #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; +use std::io; +use std::os::unix::fs::FileTypeExt; +use std::path::Path; + use databroker::authorization::Authorization; use databroker::broker::RegistrationError; @@ -170,6 +174,15 @@ async fn read_metadata_file<'a, 'b>( Ok(()) } +fn unlink_unix_domain_socket(path: impl AsRef) -> Result<(), io::Error> { + if let Ok(metadata) = std::fs::metadata(&path) { + if metadata.file_type().is_socket() { + std::fs::remove_file(&path)?; + } + }; + Ok(()) +} + #[tokio::main] async fn main() -> Result<(), Box> { let version = option_env!("CARGO_PKG_VERSION").unwrap_or_default(); @@ -218,6 +231,16 @@ async fn main() -> Result<(), Box> { .value_parser(clap::value_parser!(u16)) .default_value("55555"), ) + .arg( + Arg::new("unix-socket") + .display_order(3) + .long("unix-socket") + .help("Listen on unix socket, e.g. /tmp/kuksa/databroker.sock") + .action(ArgAction::Set) + .value_name("PATH") + .required(false) + .env("KUKSA_DATABROKER_UNIX_SOCKET"), + ) .arg( Arg::new("vss-file") .display_order(4) @@ -457,7 +480,45 @@ async fn main() -> Result<(), Box> { apis.push(grpc::server::Api::SdvDatabrokerV1); } - grpc::server::serve( + let unix_socket = args.get_one::("unix-socket").cloned(); + if let Some(path) = unix_socket { + unlink_unix_domain_socket(&path)?; + std::fs::create_dir_all(Path::new(&path).parent().unwrap())?; + let broker = broker.clone(); + let authorization = authorization.clone(); + let apis = apis.clone(); + tokio::spawn(async move { + if let Err(err) = + grpc::server::serve_uds(&path, broker, &apis, authorization, shutdown_handler()) + .await + { + error!("{err}"); + } + + info!("Unlinking unix domain socket at {}", path); + unlink_unix_domain_socket(path) + .unwrap_or_else(|_| error!("Failed to unlink unix domain socket")); + }); + } + + /* On Linux systems try to notify daemon readiness to systemd. + * This function determines whether the a system is using systemd + * or not, so it is safe to use on non-systemd systems as well. + */ + #[cfg(target_os = "linux")] + { + match sd_notify::booted() { + Ok(true) => { + info!("Notifying systemd that the service is ready"); + sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; + } + _ => { + debug!("System is not using systemd, will not try to notify"); + } + } + } + + grpc::server::serve_tcp( addr, broker, #[cfg(feature = "tls")] diff --git a/databroker/tests/world/mod.rs b/databroker/tests/world/mod.rs index e3e6a7c6..8aada34c 100644 --- a/databroker/tests/world/mod.rs +++ b/databroker/tests/world/mod.rs @@ -32,6 +32,7 @@ use databroker::{ }; use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; use tracing::debug; use lazy_static::lazy_static; @@ -188,6 +189,7 @@ impl DataBrokerWorld { let addr = listener .local_addr() .expect("failed to determine listener's port"); + let incoming = TcpListenerStream::new(listener); tokio::spawn(async move { let version = option_env!("VERGEN_GIT_SEMVER_LIGHTWEIGHT") @@ -228,7 +230,7 @@ impl DataBrokerWorld { } grpc::server::serve_with_incoming_shutdown( - listener, + incoming, data_broker, #[cfg(feature = "tls")] CERTS.server_tls_config(),