Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow choice of tqdm submodule #22

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

will update every other step.

### Progress bar type

You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'`, `'notebook'`, or `'auto'`.
```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, tqdm_type='std') # tqdm_type='std' or 'notebook' or 'auto'
def step(carry, x):
return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

### Progress bar options

Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm)
Expand Down
20 changes: 16 additions & 4 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import typing

import jax
import tqdm.auto
import tqdm.notebook
import tqdm.std
from jax.debug import callback
from tqdm.auto import tqdm


def scan_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
) -> typing.Callable:
"""
Expand All @@ -29,7 +32,7 @@ def scan_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved

def _scan_tqdm(func):
"""Decorator that adds a tqdm progress bar to `body_fun` used in `jax.lax.scan`.
Expand All @@ -55,6 +58,7 @@ def wrapper_progress_bar(carry, x):
def loop_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Callable:
"""
Expand All @@ -76,7 +80,7 @@ def loop_tqdm(
Progress bar wrapping function.
"""

_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, **kwargs)
_update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def _loop_tqdm(func):
"""
Expand All @@ -97,12 +101,20 @@ def wrapper_progress_bar(i, val):
def build_tqdm(
n: int,
print_rate: typing.Optional[int],
tqdm_type: str,
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
"""

if tqdm_type not in ("auto", "std", "notebook"):
raise ValueError(
'tqdm_type should be one of "auto", "std", or "notebook" '
f'but got "{tqdm_type}"'
)
pbar = getattr(tqdm, tqdm_type).tqdm

desc = kwargs.pop("desc", f"Running for {n:,} iterations")
message = kwargs.pop("message", desc)
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
Expand All @@ -127,7 +139,7 @@ def build_tqdm(
remainder = n % print_rate

def _define_tqdm(arg, transform):
tqdm_bars[0] = tqdm(range(n), **kwargs)
tqdm_bars[0] = pbar(range(n), **kwargs)
tqdm_bars[0].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
Expand Down
Loading