Skip to content

Commit

Permalink
refactor(socket): handle recoverable errors from read/write threads o…
Browse files Browse the repository at this point in the history
…n the socket
  • Loading branch information
dev-msp committed Apr 13, 2024
1 parent c76b71a commit 48ccf40
Showing 1 changed file with 100 additions and 14 deletions.
114 changes: 100 additions & 14 deletions src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
io::{BufRead, BufReader, Read, Write},
io::{self, BufRead, BufReader, Read, Write},
os::unix::{
fs::FileTypeExt,
net::{UnixListener, UnixStream},
Expand All @@ -11,12 +11,54 @@ use std::{
use anyhow::anyhow;
use crossbeam::{
atomic::AtomicCell,
channel::{unbounded, Receiver, Sender},
channel::{unbounded, Receiver, SendError, Sender},
};
use serde_json::Value;

struct DebugBufReader<R: BufRead>(R);

#[derive(Debug, thiserror::Error)]
#[error("socket error: {0}")]
enum Error {
Stream(#[from] std::io::Error),
Read(#[from] ReadError),
Write(#[from] WriteError),
}

#[derive(Debug, thiserror::Error)]
#[error("read: {0}")]
enum ReadError {
#[error("parse: {0}")]
Parse(#[from] serde_json::Error),

#[error("channel: {0}")]
Channel(#[from] SendError<Value>),
}

#[derive(Debug, thiserror::Error)]
#[error("write error: {0}")]
enum WriteError {
#[error("io: {0}")]
Io(#[from] io::Error),
}

impl Error {
fn is_broken_pipe(&self) -> bool {
match self {
Error::Write(WriteError::Io(e)) => e.kind() == io::ErrorKind::BrokenPipe,
_ => false,
}
}

fn recoverable(&self) -> bool {
if self.is_broken_pipe() {
return true;
}

matches!(self, Error::Read(ReadError::Parse(_)))
}
}

impl<R: BufRead> Read for DebugBufReader<R> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let n = self.0.read(buf)?;
Expand All @@ -40,7 +82,7 @@ fn write_thread(
mut wstream: UnixStream,
r: &Receiver<Value>,
is_done: Arc<AtomicCell<bool>>,
) -> thread::JoinHandle<Result<(), anyhow::Error>> {
) -> thread::JoinHandle<Result<(), WriteError>> {
let r = r.clone();
thread::spawn(move || {
while !is_done.load() {
Expand All @@ -61,7 +103,7 @@ fn read_thread(
stream: UnixStream,
s: Sender<Value>,
is_done: Arc<AtomicCell<bool>>,
) -> thread::JoinHandle<Result<(), anyhow::Error>> {
) -> thread::JoinHandle<Result<(), ReadError>> {
thread::spawn(move || {
let reader = DebugBufReader(BufReader::new(stream));
let it = reader.lines().flat_map(|l| {
Expand All @@ -86,12 +128,54 @@ fn read_thread(
})
}

struct ThreadName<S: ToString>(Option<S>);

impl<S: ToString> ThreadName<S> {
fn new(name: Option<S>) -> Self {
Self(name)
}

fn realize(self) -> String {
self.0
.map(|n| format!("thread named {}", n.to_string()))
.unwrap_or_else(|| "thread".to_string())
}
}

/// Opinionated function to handle thread join results
///
/// When the thread cannot be joined, this function will panic.
fn settle_thread<T, E>(
handle: thread::JoinHandle<Result<T, E>>,
name: Option<&'static str>,
) -> Result<T, Error>
where
E: Into<Error>,
{
let join_result = handle.join().unwrap_or_else(|_| {
panic!("Failed to join {} thread", ThreadName::new(name).realize());
});

match join_result.map_err(Into::into) {
Err(e) if e.recoverable() => {
log::warn!(
"Recoverable failure in {}: {:?}",
name.map(|n| format!("{} thread", n))
.unwrap_or_else(|| "thread".to_string()),
e
);
Err(e)
}
x => x,
}
}

#[tracing::instrument(skip_all)]
fn handle_stream(
stream: UnixStream,
cmd_send: Sender<Value>,
res_recv: &Receiver<Value>,
) -> Result<(), anyhow::Error> {
) -> Result<(), Error> {
let wstream = stream.try_clone()?;
log::trace!("Cloned stream");

Expand All @@ -100,21 +184,20 @@ fn handle_stream(
let reads = read_thread(stream, cmd_send, is_done.clone());
let writes = write_thread(wstream, res_recv, is_done);

let w_outcome = writes.join().expect("Failed to join write thread");
let r_outcome = reads.join().expect("Failed to join read thread");

w_outcome.unwrap();
r_outcome.unwrap();
settle_thread(writes, Some("write"))?;
settle_thread(reads, Some("read"))?;

log::debug!("Exiting handle_stream");
Ok(())
}

type Handle = std::thread::JoinHandle<Result<(), anyhow::Error>>;
type ChannelPair<T> = (Receiver<T>, Sender<T>);

// Triple of channel pair (commands), sender (responses), and handle for the socket thread
type InstructionHandle = (ChannelPair<Value>, Sender<Value>, Handle);

pub fn receive_instructions(
socket_path: &str,
) -> Result<((Receiver<Value>, Sender<Value>), Sender<Value>, Handle), anyhow::Error> {
pub fn receive_instructions(socket_path: &str) -> Result<InstructionHandle, anyhow::Error> {
match std::fs::metadata(socket_path) {
Ok(metadata) if metadata.file_type().is_socket() => {
std::fs::remove_file(socket_path)?;
Expand All @@ -134,7 +217,10 @@ pub fn receive_instructions(

let mut incoming = listener.incoming();
while let Some(rstream) = incoming.next().transpose()? {
handle_stream(rstream, csend.clone(), &rrecv)?;
match handle_stream(rstream, csend.clone(), &rrecv) {
Err(e) if e.recoverable() => Ok(()),
x => x,
}?
}
log::warn!("Listener done providing streams");
Ok(())
Expand Down

0 comments on commit 48ccf40

Please sign in to comment.