Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensured the print buffer is flushed before program exit, short printing #133

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions omnibor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ build-binary = [
"tokio/macros",
"tokio/rt",
"tokio/sync",
"tokio/time",
"tokio/rt-multi-thread"
]

Expand Down
86 changes: 66 additions & 20 deletions omnibor/src/bin/omnibor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,66 @@ use std::path::Path;
use std::path::PathBuf;
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::File as AsyncFile;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::sleep;
use url::Url;

#[tokio::main]
async fn main() -> ExitCode {
let args = Cli::parse();

let printing_done = Arc::new(AtomicBool::new(false));
let printing_done_2 = printing_done.clone();

// TODO(alilleybrinker): Make this channel Msg limit configurable.
let (tx, mut rx) = mpsc::channel::<Msg>(args.buffer.unwrap_or(100));
let (tx, mut rx) = mpsc::channel::<MsgOrEnd>(args.buffer.unwrap_or(100));

// Do all printing in a separate task we spawn to _just_ do printing.
// This stops printing from blocking the worker tasks.
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
match msg {
MsgOrEnd::End => break,
MsgOrEnd::Message(msg) => {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
}
}
}
rx.close();
printing_done_2.store(true, std::sync::atomic::Ordering::Relaxed);
});

let result = match args.command {
Command::Id(ref args) => run_id(&tx, args).await,
Command::Find(ref args) => run_find(&tx, args).await,
};

let mut return_code = ExitCode::SUCCESS;

if let Err(e) = result {
// TODO(alilleybrinker): Handle this erroring out, probably by
// sync-printing as a last resort.
let _ = tx.send(Msg::error(e, args.format())).await;
return ExitCode::FAILURE;
let _ = tx.send(MsgOrEnd::error(e, args.format())).await;
return_code = ExitCode::FAILURE;
}

ExitCode::SUCCESS
// send a message to end the printing
tx.send(MsgOrEnd::End).await.unwrap();

// wait until the printing is done
while !printing_done.load(std::sync::atomic::Ordering::Relaxed) {
sleep(Duration::from_millis(10)).await;
}

return_code
}

/*===========================================================================
Expand Down Expand Up @@ -97,7 +121,7 @@ struct IdArgs {
/// Path to identify
path: PathBuf,

/// Output format (can be "plain" or "json")
/// Output format (can be "plain", "short", or "json")
#[arg(short = 'f', long = "format", default_value_t)]
format: Format,

Expand All @@ -114,7 +138,7 @@ struct FindArgs {
/// The root path to search under
path: PathBuf,

/// Output format (can be "plain" or "json")
/// Output format (can be "plain", "short", or "json")
#[arg(short = 'f', long = "format", default_value_t)]
format: Format,
}
Expand All @@ -124,13 +148,15 @@ enum Format {
#[default]
Plain,
Json,
Short,
}

impl Display for Format {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Format::Plain => write!(f, "plain"),
Format::Json => write!(f, "json"),
Format::Short => write!(f, "short"),
}
}
}
Expand All @@ -142,6 +168,7 @@ impl FromStr for Format {
match s {
"plain" => Ok(Format::Plain),
"json" => Ok(Format::Json),
"short" => Ok(Format::Short),
_ => Err(anyhow!("unknown format '{}'", s)),
}
}
Expand Down Expand Up @@ -172,7 +199,23 @@ impl FromStr for SelectedHash {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum MsgOrEnd {
End,
Message(Msg),
}

impl MsgOrEnd {
fn id(path: &Path, url: &Url, format: Format) -> Self {
MsgOrEnd::Message(Msg::id(path, url, format))
}

fn error<E: Into<Error>>(error: E, format: Format) -> MsgOrEnd {
MsgOrEnd::Message(Msg::error(error, format))
}
}

#[derive(Debug, Clone)]
struct Msg {
content: Content,
status: Status,
Expand All @@ -186,6 +229,7 @@ impl Msg {

match format {
Format::Plain => Msg::plain(status, &format!("{} => {}", path, url)),
Format::Short => Msg::plain(status, &format!("{}", url)),
Format::Json => Msg::json(status, json!({ "path": path, "id": url })),
}
}
Expand All @@ -195,7 +239,9 @@ impl Msg {
let status = Status::Error;

match format {
Format::Plain => Msg::plain(status, &format!("error: {}", error.to_string())),
Format::Plain | Format::Short => {
Msg::plain(status, &format!("error: {}", error.to_string()))
}
Format::Json => Msg::json(status, json!({"error": error.to_string()})),
}
}
Expand Down Expand Up @@ -235,7 +281,7 @@ impl Msg {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum Content {
Json(JsonValue),
Plain(String),
Expand All @@ -250,7 +296,7 @@ impl Display for Content {
}
}

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
enum Status {
Success,
Error,
Expand All @@ -261,7 +307,7 @@ enum Status {
*-------------------------------------------------------------------------*/

/// Run the `id` subcommand.
async fn run_id(tx: &Sender<Msg>, args: &IdArgs) -> Result<()> {
async fn run_id(tx: &Sender<MsgOrEnd>, args: &IdArgs) -> Result<()> {
let mut file = open_async_file(&args.path).await?;

if file_is_dir(&file).await? {
Expand All @@ -272,7 +318,7 @@ async fn run_id(tx: &Sender<Msg>, args: &IdArgs) -> Result<()> {
}

/// Run the `find` subcommand.
async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
async fn run_find(tx: &Sender<MsgOrEnd>, args: &FindArgs) -> Result<()> {
let FindArgs { url, path, format } = args;

let id = ArtifactId::<Sha256>::id_url(url.clone())?;
Expand All @@ -283,7 +329,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, *format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::Message(Msg::error(e, *format))).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -295,7 +341,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
let file_url = hash_file(SelectedHash::Sha256, &mut file, &path).await?;

if url == file_url {
tx.send(Msg::id(&path, &url, *format)).await?;
tx.send(MsgOrEnd::id(&path, &url, *format)).await?;
return Ok(());
}
}
Expand All @@ -311,7 +357,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {

// Identify, recursively, all the files under a directory.
async fn id_directory(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
path: &Path,
format: Format,
hash: SelectedHash,
Expand All @@ -321,7 +367,7 @@ async fn id_directory(
loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::error(e, format)).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -340,14 +386,14 @@ async fn id_directory(

/// Identify a single file.
async fn id_file(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
file: &mut AsyncFile,
path: &Path,
format: Format,
hash: SelectedHash,
) -> Result<()> {
let url = hash_file(hash, file, &path).await?;
tx.send(Msg::id(path, &url, format)).await?;
tx.send(MsgOrEnd::id(path, &url, format)).await?;
Ok(())
}

Expand Down