diff --git a/README.md b/README.md index c3563e3..895a9ff 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,23 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) will update every other step. +### Progress bar type + +You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'`, `'notebook'`, or `'auto'`. +```python +from jax_tqdm import scan_tqdm +from jax import lax +import jax.numpy as jnp + +n = 10_000 + +@scan_tqdm(n, print_rate=1, tqdm_type='std') # tqdm_type='std' or 'notebook' or 'auto' +def step(carry, x): + return carry + 1, carry + 1 + +last_number, all_numbers = lax.scan(step, 0, jnp.arange(n)) +``` + ### Progress bar options Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm) diff --git a/jax_tqdm/pbar.py b/jax_tqdm/pbar.py index 6c85f71..4c68bf8 100644 --- a/jax_tqdm/pbar.py +++ b/jax_tqdm/pbar.py @@ -1,13 +1,16 @@ import typing import jax +import tqdm.auto +import tqdm.notebook +import tqdm.std from jax.debug import callback -from tqdm.auto import tqdm def scan_tqdm( n: int, print_rate: typing.Optional[int] = None, + tqdm_type: str = "auto", **kwargs, ) -> typing.Callable: """ @@ -29,7 +32,7 @@ def scan_tqdm( Progress bar wrapping function. """ - _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs) + _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) def _scan_tqdm(func): """Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`. @@ -55,6 +58,7 @@ def wrapper_progress_bar(carry, x): def loop_tqdm( n: int, print_rate: typing.Optional[int] = None, + tqdm_type: str = "auto", **kwargs, ) -> typing.Callable: """ @@ -76,7 +80,7 @@ def loop_tqdm( Progress bar wrapping function. """ - _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs) + _update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs) def _loop_tqdm(func): """ @@ -97,12 +101,20 @@ def wrapper_progress_bar(i, val): def build_tqdm( n: int, print_rate: typing.Optional[int], + tqdm_type: str, **kwargs, ) -> typing.Tuple[typing.Callable, typing.Callable]: """ Build the tqdm progress bar on the host """ + if tqdm_type not in ("auto", "std", "notebook"): + raise ValueError( + 'tqdm_type should be one of "auto", "std", or "notebook" ' + f'but got "{tqdm_type}"' + ) + pbar = getattr(tqdm, tqdm_type).tqdm + desc = kwargs.pop("desc", f"Running for {n:,} iterations") message = kwargs.pop("message", desc) for kwarg in ("total", "mininterval", "maxinterval", "miniters"): @@ -127,7 +139,7 @@ def build_tqdm( remainder = n % print_rate def _define_tqdm(arg, transform): - tqdm_bars[0] = tqdm(range(n), **kwargs) + tqdm_bars[0] = pbar(range(n), **kwargs) tqdm_bars[0].set_description(message, refresh=False) def _update_tqdm(arg, transform):