Skip to content

Commit

Permalink
Refactor host_callback with debug.callback
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Apr 24, 2024
1 parent 8afbd6d commit 942e93f
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

import jax
from jax.experimental import host_callback
from jax.debug import callback
from tqdm.auto import tqdm


Expand Down Expand Up @@ -135,38 +135,39 @@ def _update_tqdm(arg, transform):

def _update_progress_bar(iter_num):
"Updates tqdm from a JAX scan or loop"
_ = jax.jax.lax.cond(
_ = jax.lax.cond(
iter_num == 0,
lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
lambda _: iter_num,
lambda _: callback(_define_tqdm, None, None),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) & (iter_num != n - remainder),
lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
lambda _: iter_num,
lambda _: callback(_update_tqdm, print_rate, None),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm by `remainder`
iter_num == n - remainder,
lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
lambda _: iter_num,
lambda _: callback(_update_tqdm, remainder, None),
lambda _: None,
operand=None,
)

def _close_tqdm(arg, transform):
tqdm_bars[0].close()

def close_tqdm(result, iter_num):
return jax.lax.cond(
_ = jax.lax.cond(
iter_num == n - 1,
lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
lambda _: result,
lambda _: callback(_close_tqdm, None, None),
lambda _: None,
operand=None,
)
return result

return _update_progress_bar, close_tqdm

0 comments on commit 942e93f

Please sign in to comment.