Skip to content

Commit

Permalink
make jax callback compatible with tqdm (#20)
Browse files Browse the repository at this point in the history
* make jax callback compatible with tqdm

* ordered callbacks to flush progress bar creation

---------

Co-authored-by: Matthew Mould <[email protected]>
  • Loading branch information
mdmould and Matthew Mould authored May 15, 2024
1 parent 43d5fda commit cd2fe02
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,29 +131,29 @@ def _define_tqdm(arg, transform):
tqdm_bars[0].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
tqdm_bars[0].update(arg)
tqdm_bars[0].update(int(arg))

def _update_progress_bar(iter_num):
"Updates tqdm from a JAX scan or loop"
_ = jax.lax.cond(
iter_num == 0,
lambda _: callback(_define_tqdm, None, None),
lambda _: callback(_define_tqdm, None, None, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) & (iter_num != n - remainder),
lambda _: callback(_update_tqdm, print_rate, None),
lambda _: callback(_update_tqdm, print_rate, None, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm by `remainder`
iter_num == n - remainder,
lambda _: callback(_update_tqdm, remainder, None),
lambda _: callback(_update_tqdm, remainder, None, ordered=True),
lambda _: None,
operand=None,
)
Expand All @@ -164,7 +164,7 @@ def _close_tqdm(arg, transform):
def close_tqdm(result, iter_num):
_ = jax.lax.cond(
iter_num == n - 1,
lambda _: callback(_close_tqdm, None, None),
lambda _: callback(_close_tqdm, None, None, ordered=True),
lambda _: None,
operand=None,
)
Expand Down

0 comments on commit cd2fe02

Please sign in to comment.