diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index 427fdcf..df35e63 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -1,7 +1,7 @@ import typing import jax -from jax.experimental import host_callback +from jax.debug import callback from tqdm.auto import tqdm @@ -135,26 +135,26 @@ def _update_tqdm(arg, transform): def _update_progress_bar(iter_num): "Updates tqdm from a JAX scan or loop" - _ = jax.jax.lax.cond( + _ = jax.lax.cond( iter_num == 0, - lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num), - lambda _: iter_num, + lambda _: callback(_define_tqdm, None, None), + lambda _: None, operand=None, ) _ = jax.lax.cond( # update tqdm every multiple of `print_rate` except at the end (iter_num % print_rate == 0) & (iter_num != n - remainder), - lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num), - lambda _: iter_num, + lambda _: callback(_update_tqdm, print_rate, None), + lambda _: None, operand=None, ) _ = jax.lax.cond( # update tqdm by `remainder` iter_num == n - remainder, - lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num), - lambda _: iter_num, + lambda _: callback(_update_tqdm, remainder, None), + lambda _: None, operand=None, ) @@ -162,11 +162,12 @@ def _close_tqdm(arg, transform): tqdm_bars[0].close() def close_tqdm(result, iter_num): - return jax.lax.cond( + _ = jax.lax.cond( iter_num == n - 1, - lambda _: host_callback.id_tap(_close_tqdm, None, result=result), - lambda _: result, + lambda _: callback(_close_tqdm, None, None), + lambda _: None, operand=None, ) + return result return _update_progress_bar, close_tqdm