Skip to content

Commit

Permalink
refactor cancellation system to a separate cancellation token struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementTsang committed Aug 11, 2024
1 parent bdbb3d0 commit 473ccca
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 53 deletions.
86 changes: 33 additions & 53 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

pub mod app;
pub mod utils {
pub mod cancellation_token;
pub mod data_prefixes;
pub mod data_units;
pub mod general;
Expand All @@ -29,7 +30,7 @@ use std::{
panic::{self, PanicInfo},
sync::{
mpsc::{self, Receiver, Sender},
Arc, Condvar, Mutex,
Arc,
},
thread::{self, JoinHandle},
time::{Duration, Instant},
Expand All @@ -49,6 +50,7 @@ use data_conversion::*;
use event::{handle_key_event_or_break, handle_mouse_event, BottomEvent, CollectionThreadEvent};
use options::{args, get_or_create_config, init_app};
use tui::{backend::CrosstermBackend, Terminal};
use utils::cancellation_token::CancellationToken;
#[allow(unused_imports)]
use utils::logging::*;

Expand Down Expand Up @@ -131,23 +133,27 @@ fn panic_hook(panic_info: &PanicInfo<'_>) {
)),
);
}

// TODO: Might be cleaner in the future to use a cancellation token; for now if it panics then shut down the
// main program entirely.
std::process::exit(1);
}

/// Create a thread to poll for user inputs and forward them to the main thread.
fn create_input_thread(
sender: Sender<BottomEvent>, termination_ctrl_lock: Arc<Mutex<bool>>,
sender: Sender<BottomEvent>, cancellation_token: Arc<CancellationToken>,
) -> JoinHandle<()> {
thread::spawn(move || {
let mut mouse_timer = Instant::now();

loop {
if let Ok(is_terminated) = termination_ctrl_lock.try_lock() {
// We don't block.
if *is_terminated {
drop(is_terminated);
// We don't block.
if let Some(is_terminated) = cancellation_token.try_check() {
if is_terminated {
break;
}
}

if let Ok(poll) = poll(Duration::from_millis(20)) {
if poll {
if let Ok(event) = read() {
Expand Down Expand Up @@ -206,8 +212,8 @@ fn create_input_thread(
/// Create a thread to handle data collection.
fn create_collection_thread(
sender: Sender<BottomEvent>, control_receiver: Receiver<CollectionThreadEvent>,
termination_lock: Arc<Mutex<bool>>, termination_cvar: Arc<Condvar>,
app_config_fields: &AppConfigFields, filters: DataFilters, used_widget_set: UsedWidgets,
cancellation_token: Arc<CancellationToken>, app_config_fields: &AppConfigFields,
filters: DataFilters, used_widget_set: UsedWidgets,
) -> JoinHandle<()> {
let temp_type = app_config_fields.temperature_type;
let use_current_cpu_total = app_config_fields.use_current_cpu_total;
Expand All @@ -228,9 +234,8 @@ fn create_collection_thread(

loop {
// Check once at the very top... don't block though.
if let Ok(is_terminated) = termination_lock.try_lock() {
if *is_terminated {
drop(is_terminated);
if let Some(is_terminated) = cancellation_token.try_check() {
if is_terminated {
break;
}
}
Expand All @@ -247,9 +252,8 @@ fn create_collection_thread(
data_state.update_data();

// Yet another check to bail if needed... do not block!
if let Ok(is_terminated) = termination_lock.try_lock() {
if *is_terminated {
drop(is_terminated);
if let Some(is_terminated) = cancellation_token.try_check() {
if is_terminated {
break;
}
}
Expand All @@ -260,15 +264,9 @@ fn create_collection_thread(
break;
}

// This is actually used as a "sleep" that can be interrupted by another thread.
if let Ok((is_terminated, _)) = termination_cvar.wait_timeout(
termination_lock.lock().unwrap(),
Duration::from_millis(update_time),
) {
if *is_terminated {
drop(is_terminated);
break;
}
// Sleep while allowing for interruptions...
if cancellation_token.sleep_with_cancellation(Duration::from_millis(update_time)) {
break;
}
}
})
Expand Down Expand Up @@ -308,8 +306,6 @@ fn generate_schema() -> anyhow::Result<()> {
}

fn main() -> anyhow::Result<()> {
// TODO: If there is any panic in any thread, send a cancellation token (or similar) to shut down everything.

// let _profiler = dhat::Profiler::new_heap();

let args = args::get_args();
Expand Down Expand Up @@ -341,12 +337,7 @@ fn main() -> anyhow::Result<()> {
// Check if the current environment is in a terminal.
check_if_terminal();

// Create termination mutex and cvar. We use this setup because we need to sleep
// at some points in the update thread, but we want to be able to interrupt
// the "sleep" if a termination occurs.
let termination_lock = Arc::new(Mutex::new(false));
let termination_cvar = Arc::new(Condvar::new());

let cancellation_token = Arc::new(CancellationToken::default());
let (sender, receiver) = mpsc::channel();

// Set up the event loop thread; we set this up early to speed up
Expand All @@ -355,37 +346,27 @@ fn main() -> anyhow::Result<()> {
let _collection_thread = create_collection_thread(
sender.clone(),
collection_thread_ctrl_receiver,
termination_lock.clone(),
termination_cvar.clone(),
cancellation_token.clone(),
&app.app_config_fields,
app.filters.clone(),
app.used_widgets,
);

// Set up the input handling loop thread.
let _input_thread = create_input_thread(sender.clone(), termination_lock.clone());
let _input_thread = create_input_thread(sender.clone(), cancellation_token.clone());

// Set up the cleaning loop thread.
let _cleaning_thread = {
let lock = termination_lock.clone();
let cvar = termination_cvar.clone();
let cancellation_token = cancellation_token.clone();
let cleaning_sender = sender.clone();
let offset_wait_time = app.app_config_fields.retention_ms + 60000;
thread::spawn(move || {
loop {
let result = cvar.wait_timeout(
lock.lock().unwrap(),
Duration::from_millis(offset_wait_time),
);
if let Ok(result) = result {
if *(result.0) {
break;
}
}
if cleaning_sender.send(BottomEvent::Clean).is_err() {
// debug!("Failed to send cleaning sender...");
break;
}
thread::spawn(move || loop {
if cancellation_token.sleep_with_cancellation(Duration::from_millis(offset_wait_time)) {
break;
}

if cleaning_sender.send(BottomEvent::Clean).is_err() {
break;
}
})
};
Expand Down Expand Up @@ -585,8 +566,7 @@ fn main() -> anyhow::Result<()> {
}

// I think doing it in this order is safe...
*termination_lock.lock().unwrap() = true;
termination_cvar.notify_all();
cancellation_token.cancel();
cleanup_terminal(&mut terminal)?;

Ok(())
Expand Down
51 changes: 51 additions & 0 deletions src/utils/cancellation_token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::{
sync::{Condvar, Mutex},
time::Duration,
};

/// A cancellation token.
pub(crate) struct CancellationToken {
// The "check" for the cancellation token. Setting this to true will mark the cancellation token as "cancelled".
mutex: Mutex<bool>,
cvar: Condvar,
}

impl Default for CancellationToken {
fn default() -> Self {
Self {
mutex: Mutex::new(false),
cvar: Condvar::new(),
}
}
}

impl CancellationToken {
/// Mark the [`CancellationToken`] as cancelled.
pub fn cancel(&self) {
*self.mutex.lock().unwrap() = true;
self.cvar.notify_all();
}

/// Try and check the [`CancellationToken`]'s status. Note that
/// this will not block.
pub fn try_check(&self) -> Option<bool> {
self.mutex.try_lock().ok().map(|guard| *guard)
}

/// Allows a thread to sleep while still being interruptible with by the token.
///
/// Returns the condition state after either sleeping or being woken up.
pub fn sleep_with_cancellation(&self, duration: Duration) -> bool {
let guard = self
.mutex
.lock()
.expect("cancellation token lock should not be poisoned");

let (result, _) = self
.cvar
.wait_timeout(guard, duration)
.expect("cancellation token lock should not be poisoned");

*result
}
}

0 comments on commit 473ccca

Please sign in to comment.