diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index df35e63..6c85f71 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -131,13 +131,13 @@ def _define_tqdm(arg, transform): tqdm_bars[0].set_description(message, refresh=False) def _update_tqdm(arg, transform): - tqdm_bars[0].update(arg) + tqdm_bars[0].update(int(arg)) def _update_progress_bar(iter_num): "Updates tqdm from a JAX scan or loop" _ = jax.lax.cond( iter_num == 0, - lambda _: callback(_define_tqdm, None, None), + lambda _: callback(_define_tqdm, None, None, ordered=True), lambda _: None, operand=None, ) @@ -145,7 +145,7 @@ def _update_progress_bar(iter_num): _ = jax.lax.cond( # update tqdm every multiple of `print_rate` except at the end (iter_num % print_rate == 0) & (iter_num != n - remainder), - lambda _: callback(_update_tqdm, print_rate, None), + lambda _: callback(_update_tqdm, print_rate, None, ordered=True), lambda _: None, operand=None, ) @@ -153,7 +153,7 @@ def _update_progress_bar(iter_num): _ = jax.lax.cond( # update tqdm by `remainder` iter_num == n - remainder, - lambda _: callback(_update_tqdm, remainder, None), + lambda _: callback(_update_tqdm, remainder, None, ordered=True), lambda _: None, operand=None, ) @@ -164,7 +164,7 @@ def _close_tqdm(arg, transform): def close_tqdm(result, iter_num): _ = jax.lax.cond( iter_num == n - 1, - lambda _: callback(_close_tqdm, None, None), + lambda _: callback(_close_tqdm, None, None, ordered=True), lambda _: None, operand=None, )