Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make jax callback compatible with tqdm #20

Merged
merged 2 commits into from
May 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,29 +131,29 @@ def _define_tqdm(arg, transform):
tqdm_bars[0].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
tqdm_bars[0].update(arg)
tqdm_bars[0].update(int(arg))
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved

def _update_progress_bar(iter_num):
"Updates tqdm from a JAX scan or loop"
_ = jax.lax.cond(
iter_num == 0,
lambda _: callback(_define_tqdm, None, None),
lambda _: callback(_define_tqdm, None, None, ordered=True),
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) & (iter_num != n - remainder),
lambda _: callback(_update_tqdm, print_rate, None),
lambda _: callback(_update_tqdm, print_rate, None, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm by `remainder`
iter_num == n - remainder,
lambda _: callback(_update_tqdm, remainder, None),
lambda _: callback(_update_tqdm, remainder, None, ordered=True),
lambda _: None,
operand=None,
)
Expand All @@ -164,7 +164,7 @@ def _close_tqdm(arg, transform):
def close_tqdm(result, iter_num):
_ = jax.lax.cond(
iter_num == n - 1,
lambda _: callback(_close_tqdm, None, None),
lambda _: callback(_close_tqdm, None, None, ordered=True),
lambda _: None,
operand=None,
)
Expand Down
Loading