Skip to content

Commit

Permalink
feat: Ensured the print buffer is flushed before program exit.
Browse files Browse the repository at this point in the history
Added a new 'Short' format that only prints the gitoid

Signed-off-by: David Pollak <[email protected]>
  • Loading branch information
dpp committed Feb 28, 2024
1 parent c9253f3 commit db8a027
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
1 change: 1 addition & 0 deletions omnibor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ build-binary = [
"tokio/macros",
"tokio/rt",
"tokio/sync",
"tokio/time",
"tokio/rt-multi-thread"
]

Expand Down
86 changes: 66 additions & 20 deletions omnibor/src/bin/omnibor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,66 @@ use std::path::Path;
use std::path::PathBuf;
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::File as AsyncFile;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::sleep;
use url::Url;

#[tokio::main]
async fn main() -> ExitCode {
let args = Cli::parse();

let printing_done = Arc::new(AtomicBool::new(false));
let printing_done_2 = printing_done.clone();

// TODO(alilleybrinker): Make this channel Msg limit configurable.
let (tx, mut rx) = mpsc::channel::<Msg>(args.buffer.unwrap_or(100));
let (tx, mut rx) = mpsc::channel::<MsgOrEnd>(args.buffer.unwrap_or(100));

// Do all printing in a separate task we spawn to _just_ do printing.
// This stops printing from blocking the worker tasks.
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
match msg {
MsgOrEnd::End => break,
MsgOrEnd::Message(msg) => {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
}
}
}
rx.close();
printing_done_2.store(true, std::sync::atomic::Ordering::Relaxed);
});

let result = match args.command {
Command::Id(ref args) => run_id(&tx, args).await,
Command::Find(ref args) => run_find(&tx, args).await,
};

let mut return_code = ExitCode::SUCCESS;

if let Err(e) = result {
// TODO(alilleybrinker): Handle this erroring out, probably by
// sync-printing as a last resort.
let _ = tx.send(Msg::error(e, args.format())).await;
return ExitCode::FAILURE;
let _ = tx.send(MsgOrEnd::error(e, args.format())).await;
return_code = ExitCode::FAILURE;
}

ExitCode::SUCCESS
// send a message to end the printing
tx.send(MsgOrEnd::End).await.unwrap();

// wait until the printing is done
while !printing_done.load(std::sync::atomic::Ordering::Relaxed) {
sleep(Duration::from_millis(10)).await;
}

return_code
}

/*===========================================================================
Expand Down Expand Up @@ -97,7 +121,7 @@ struct IdArgs {
/// Path to identify
path: PathBuf,

/// Output format (can be "plain" or "json")
/// Output format (can be "plain", "short", or "json")
#[arg(short = 'f', long = "format", default_value_t)]
format: Format,

Expand All @@ -114,7 +138,7 @@ struct FindArgs {
/// The root path to search under
path: PathBuf,

/// Output format (can be "plain" or "json")
/// Output format (can be "plain", "short", or "json")
#[arg(short = 'f', long = "format", default_value_t)]
format: Format,
}
Expand All @@ -124,13 +148,15 @@ enum Format {
#[default]
Plain,
Json,
Short,
}

impl Display for Format {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Format::Plain => write!(f, "plain"),
Format::Json => write!(f, "json"),
Format::Short => write!(f, "short"),
}
}
}
Expand All @@ -142,6 +168,7 @@ impl FromStr for Format {
match s {
"plain" => Ok(Format::Plain),
"json" => Ok(Format::Json),
"short" => Ok(Format::Short),
_ => Err(anyhow!("unknown format '{}'", s)),
}
}
Expand Down Expand Up @@ -172,7 +199,23 @@ impl FromStr for SelectedHash {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum MsgOrEnd {
End,
Message(Msg),
}

impl MsgOrEnd {
fn id(path: &Path, url: &Url, format: Format) -> Self {
MsgOrEnd::Message(Msg::id(path, url, format))
}

fn error<E: Into<Error>>(error: E, format: Format) -> MsgOrEnd {
MsgOrEnd::Message(Msg::error(error, format))
}
}

#[derive(Debug, Clone)]
struct Msg {
content: Content,
status: Status,
Expand All @@ -186,6 +229,7 @@ impl Msg {

match format {
Format::Plain => Msg::plain(status, &format!("{} => {}", path, url)),
Format::Short => Msg::plain(status, &format!("{}", url)),
Format::Json => Msg::json(status, json!({ "path": path, "id": url })),
}
}
Expand All @@ -195,7 +239,9 @@ impl Msg {
let status = Status::Error;

match format {
Format::Plain => Msg::plain(status, &format!("error: {}", error.to_string())),
Format::Plain | Format::Short => {
Msg::plain(status, &format!("error: {}", error.to_string()))
}
Format::Json => Msg::json(status, json!({"error": error.to_string()})),
}
}
Expand Down Expand Up @@ -235,7 +281,7 @@ impl Msg {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum Content {
Json(JsonValue),
Plain(String),
Expand All @@ -250,7 +296,7 @@ impl Display for Content {
}
}

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
enum Status {
Success,
Error,
Expand All @@ -261,7 +307,7 @@ enum Status {
*-------------------------------------------------------------------------*/

/// Run the `id` subcommand.
async fn run_id(tx: &Sender<Msg>, args: &IdArgs) -> Result<()> {
async fn run_id(tx: &Sender<MsgOrEnd>, args: &IdArgs) -> Result<()> {
let mut file = open_async_file(&args.path).await?;

if file_is_dir(&file).await? {
Expand All @@ -272,7 +318,7 @@ async fn run_id(tx: &Sender<Msg>, args: &IdArgs) -> Result<()> {
}

/// Run the `find` subcommand.
async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
async fn run_find(tx: &Sender<MsgOrEnd>, args: &FindArgs) -> Result<()> {
let FindArgs { url, path, format } = args;

let id = ArtifactId::<Sha256>::id_url(url.clone())?;
Expand All @@ -283,7 +329,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, *format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::Message(Msg::error(e, *format))).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -295,7 +341,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
let file_url = hash_file(SelectedHash::Sha256, &mut file, &path).await?;

if url == file_url {
tx.send(Msg::id(&path, &url, *format)).await?;
tx.send(MsgOrEnd::id(&path, &url, *format)).await?;
return Ok(());
}
}
Expand All @@ -311,7 +357,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {

// Identify, recursively, all the files under a directory.
async fn id_directory(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
path: &Path,
format: Format,
hash: SelectedHash,
Expand All @@ -321,7 +367,7 @@ async fn id_directory(
loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::error(e, format)).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -340,14 +386,14 @@ async fn id_directory(

/// Identify a single file.
async fn id_file(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
file: &mut AsyncFile,
path: &Path,
format: Format,
hash: SelectedHash,
) -> Result<()> {
let url = hash_file(hash, file, &path).await?;
tx.send(Msg::id(path, &url, format)).await?;
tx.send(MsgOrEnd::id(path, &url, format)).await?;
Ok(())
}

Expand Down

0 comments on commit db8a027

Please sign in to comment.