Skip to content

Commit

Permalink
allow choice of tqdm submodule (#22)
Browse files Browse the repository at this point in the history
* allow choice of tqdm submodule

* revert build_tqdm to required args and add check for tqdm_type
  • Loading branch information
mdmould authored Jul 17, 2024
1 parent 9c43885 commit 9050d7d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

will update every other step.

### Progress bar type

You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'`, `'notebook'`, or `'auto'`.
```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, tqdm_type='std') # tqdm_type='std' or 'notebook' or 'auto'
def step(carry, x):
return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

### Progress bar options

Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm)
Expand Down
20 changes: 16 additions & 4 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import typing

import jax
import tqdm.auto
import tqdm.notebook
import tqdm.std
from jax.debug import callback
from tqdm.auto import tqdm


def scan_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Callable:
"""
Expand All @@ -29,7 +32,7 @@ def scan_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def _scan_tqdm(func):
"""Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
Expand All @@ -55,6 +58,7 @@ def wrapper_progress_bar(carry, x):
def loop_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Callable:
"""
Expand All @@ -76,7 +80,7 @@ def loop_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def _loop_tqdm(func):
"""
Expand All @@ -97,12 +101,20 @@ def wrapper_progress_bar(i, val):
def build_tqdm(
n: int,
print_rate: typing.Optional[int],
tqdm_type: str,
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
"""

if tqdm_type not in ("auto", "std", "notebook"):
raise ValueError(
'tqdm_type should be one of "auto", "std", or "notebook" '
f'but got "{tqdm_type}"'
)
pbar = getattr(tqdm, tqdm_type).tqdm

desc = kwargs.pop("desc", f"Running for {n:,} iterations")
message = kwargs.pop("message", desc)
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
Expand All @@ -127,7 +139,7 @@ def build_tqdm(
remainder = n % print_rate

def _define_tqdm(arg, transform):
tqdm_bars[0] = tqdm(range(n), **kwargs)
tqdm_bars[0] = pbar(range(n), **kwargs)
tqdm_bars[0].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
Expand Down

0 comments on commit 9050d7d

Please sign in to comment.