From 3fe0f775100589a8e05df49b753d7adb57ac46bf Mon Sep 17 00:00:00 2001 From: David Pollak Date: Wed, 28 Feb 2024 11:28:42 -0500 Subject: [PATCH] Ensured the print buffer is flushed before program exit. Added an option to print the plain URL in short format --- omnibor/Cargo.toml | 1 + omnibor/src/bin/omnibor.rs | 92 +++++++++++++++++++++++++++++--------- 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/omnibor/Cargo.toml b/omnibor/Cargo.toml index 53939b5..7f996fe 100644 --- a/omnibor/Cargo.toml +++ b/omnibor/Cargo.toml @@ -45,6 +45,7 @@ build-binary = [ "tokio/macros", "tokio/rt", "tokio/sync", + "tokio/time", "tokio/rt-multi-thread" ] diff --git a/omnibor/src/bin/omnibor.rs b/omnibor/src/bin/omnibor.rs index 8faca98..714775b 100644 --- a/omnibor/src/bin/omnibor.rs +++ b/omnibor/src/bin/omnibor.rs @@ -13,6 +13,7 @@ use omnibor::Sha256; use serde_json::json; use serde_json::Value as JsonValue; use smart_default::SmartDefault; +use tokio::time::sleep; use std::default::Default; use std::fmt::Display; use std::fmt::Formatter; @@ -21,6 +22,9 @@ use std::path::Path; use std::path::PathBuf; use std::process::ExitCode; use std::str::FromStr; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::Duration; use tokio::fs::File as AsyncFile; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt as _; @@ -32,16 +36,26 @@ use url::Url; async fn main() -> ExitCode { let args = Cli::parse(); + let printing_done = Arc::new(AtomicBool::new(false)); + let printing_done_2 = printing_done.clone(); + // TODO(alilleybrinker): Make this channel Msg limit configurable. - let (tx, mut rx) = mpsc::channel::(args.buffer.unwrap_or(100)); + let (tx, mut rx) = mpsc::channel::(args.buffer.unwrap_or(100)); // Do all printing in a separate task we spawn to _just_ do printing. // This stops printing from blocking the worker tasks. tokio::spawn(async move { while let Some(msg) = rx.recv().await { - // TODO(alilleybrinker): Handle this error. - let _ = msg.print().await; + match msg { + MsgOrEnd::End => break, + MsgOrEnd::Message(msg) => { + // TODO(alilleybrinker): Handle this error. + let _ = msg.print().await; + } + } } + rx.close(); + printing_done_2.store(true, std::sync::atomic::Ordering::Relaxed); }); let result = match args.command { @@ -49,14 +63,25 @@ async fn main() -> ExitCode { Command::Find(ref args) => run_find(&tx, args).await, }; + let mut return_code = ExitCode::SUCCESS; + if let Err(e) = result { // TODO(alilleybrinker): Handle this erroring out, probably by // sync-printing as a last resort. - let _ = tx.send(Msg::error(e, args.format())).await; - return ExitCode::FAILURE; + let _ = tx.send(MsgOrEnd::error(e, args.format())).await; + return_code = ExitCode::FAILURE; + } + + + // send a message to end the printing + tx.send(MsgOrEnd::End).await.unwrap(); + + // wait until the printing is done + while !printing_done.load(std::sync::atomic::Ordering::Relaxed) { + sleep(Duration::from_millis(10)).await; } - ExitCode::SUCCESS + return_code } /*=========================================================================== @@ -104,6 +129,10 @@ struct IdArgs { /// Hash algorithm (can be "sha256") #[arg(short = 'H', long = "hash", default_value_t)] hash: SelectedHash, + + /// Should the messages be short (just contain the gitoid)? + #[arg(short = 's', long = "short")] + short: bool, } #[derive(Debug, Args)] @@ -172,20 +201,37 @@ impl FromStr for SelectedHash { } } -#[derive(Debug)] +#[derive(Debug, Clone)] +enum MsgOrEnd { + End, + Message(Msg), +} + +impl MsgOrEnd { + fn id(path: &Path, url: &Url, format: Format, short: bool) -> Self { + MsgOrEnd::Message(Msg::id(path, url, format, short)) + } + + fn error>(error: E, format: Format) -> MsgOrEnd { + MsgOrEnd::Message(Msg::error(error, format)) + } +} + +#[derive(Debug, Clone)] struct Msg { content: Content, status: Status, } impl Msg { - fn id(path: &Path, url: &Url, format: Format) -> Self { + fn id(path: &Path, url: &Url, format: Format, short: bool) -> Self { let status = Status::Success; let path = path.display().to_string(); let url = url.to_string(); match format { - Format::Plain => Msg::plain(status, &format!("{} => {}", path, url)), + Format::Plain if !short => Msg::plain(status, &format!("{} => {}", path, url)), + Format::Plain => Msg::plain(status, &format!("{}", url)), Format::Json => Msg::json(status, json!({ "path": path, "id": url })), } } @@ -235,7 +281,7 @@ impl Msg { } } -#[derive(Debug)] +#[derive(Debug, Clone)] enum Content { Json(JsonValue), Plain(String), @@ -250,7 +296,7 @@ impl Display for Content { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] enum Status { Success, Error, @@ -261,18 +307,18 @@ enum Status { *-------------------------------------------------------------------------*/ /// Run the `id` subcommand. -async fn run_id(tx: &Sender, args: &IdArgs) -> Result<()> { +async fn run_id(tx: &Sender, args: &IdArgs) -> Result<()> { let mut file = open_async_file(&args.path).await?; if file_is_dir(&file).await? { - id_directory(tx, &args.path, args.format, args.hash).await + id_directory(tx, &args.path, args.format, args.hash, args.short).await } else { - id_file(tx, &mut file, &args.path, args.format, args.hash).await + id_file(tx, &mut file, &args.path, args.format, args.hash, args.short).await } } /// Run the `find` subcommand. -async fn run_find(tx: &Sender, args: &FindArgs) -> Result<()> { +async fn run_find(tx: &Sender, args: &FindArgs) -> Result<()> { let FindArgs { url, path, format } = args; let id = ArtifactId::::id_url(url.clone())?; @@ -283,7 +329,7 @@ async fn run_find(tx: &Sender, args: &FindArgs) -> Result<()> { loop { match entries.next().await { None => break, - Some(Err(e)) => tx.send(Msg::error(e, *format)).await?, + Some(Err(e)) => tx.send(MsgOrEnd::Message(Msg::error(e, *format))).await?, Some(Ok(entry)) => { let path = &entry.path(); @@ -295,7 +341,7 @@ async fn run_find(tx: &Sender, args: &FindArgs) -> Result<()> { let file_url = hash_file(SelectedHash::Sha256, &mut file, &path).await?; if url == file_url { - tx.send(Msg::id(&path, &url, *format)).await?; + tx.send(MsgOrEnd::id(&path, &url, *format, false)).await?; return Ok(()); } } @@ -311,17 +357,18 @@ async fn run_find(tx: &Sender, args: &FindArgs) -> Result<()> { // Identify, recursively, all the files under a directory. async fn id_directory( - tx: &Sender, + tx: &Sender, path: &Path, format: Format, hash: SelectedHash, + short: bool ) -> Result<()> { let mut entries = WalkDir::new(path); loop { match entries.next().await { None => break, - Some(Err(e)) => tx.send(Msg::error(e, format)).await?, + Some(Err(e)) => tx.send(MsgOrEnd::error(e, format)).await?, Some(Ok(entry)) => { let path = &entry.path(); @@ -330,7 +377,7 @@ async fn id_directory( } let mut file = open_async_file(&path).await?; - id_file(tx, &mut file, &path, format, hash).await?; + id_file(tx, &mut file, &path, format, hash, short).await?; } } } @@ -340,14 +387,15 @@ async fn id_directory( /// Identify a single file. async fn id_file( - tx: &Sender, + tx: &Sender, file: &mut AsyncFile, path: &Path, format: Format, hash: SelectedHash, + short: bool ) -> Result<()> { let url = hash_file(hash, file, &path).await?; - tx.send(Msg::id(path, &url, format)).await?; + tx.send(MsgOrEnd::id(path, &url, format, short)).await?; Ok(()) }