From 04485301f92b09aba98a358f5e59e386048143a5 Mon Sep 17 00:00:00 2001 From: Matthew Mould Date: Tue, 14 May 2024 11:43:32 -0700 Subject: [PATCH 1/2] make jax callback compatible with tqdm --- jax_tqdm/pbar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index df35e63..22b0cd1 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -131,7 +131,7 @@ def _define_tqdm(arg, transform): tqdm_bars[0].set_description(message, refresh=False) def _update_tqdm(arg, transform): - tqdm_bars[0].update(arg) + tqdm_bars[0].update(int(arg)) def _update_progress_bar(iter_num): "Updates tqdm from a JAX scan or loop" From 4a445629095a8e584ee4ad183f9f7c3c9aac30a5 Mon Sep 17 00:00:00 2001 From: Matthew Mould Date: Tue, 14 May 2024 13:32:01 -0700 Subject: [PATCH 2/2] ordered callbacks to flush progress bar creation --- jax_tqdm/pbar.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index 22b0cd1..6c85f71 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -137,7 +137,7 @@ def _update_progress_bar(iter_num): "Updates tqdm from a JAX scan or loop" _ = jax.lax.cond( iter_num == 0, - lambda _: callback(_define_tqdm, None, None), + lambda _: callback(_define_tqdm, None, None, ordered=True), lambda _: None, operand=None, ) @@ -145,7 +145,7 @@ def _update_progress_bar(iter_num): _ = jax.lax.cond( # update tqdm every multiple of `print_rate` except at the end (iter_num % print_rate == 0) & (iter_num != n - remainder), - lambda _: callback(_update_tqdm, print_rate, None), + lambda _: callback(_update_tqdm, print_rate, None, ordered=True), lambda _: None, operand=None, ) @@ -153,7 +153,7 @@ def _update_progress_bar(iter_num): _ = jax.lax.cond( # update tqdm by `remainder` iter_num == n - remainder, - lambda _: callback(_update_tqdm, remainder, None), + lambda _: callback(_update_tqdm, remainder, None, ordered=True), lambda _: None, operand=None, ) @@ -164,7 +164,7 @@ def _close_tqdm(arg, transform): def close_tqdm(result, iter_num): _ = jax.lax.cond( iter_num == n - 1, - lambda _: callback(_close_tqdm, None, None), + lambda _: callback(_close_tqdm, None, None, ordered=True), lambda _: None, operand=None, )