Skip to content

Commit

Permalink
Perform bounded loop with decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein committed Oct 18, 2024
1 parent 40d2afd commit 1f12f5f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 36 deletions.
55 changes: 21 additions & 34 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,46 +127,37 @@ def wrapper_progress_bar(i, val):


def bounded_while_tqdm(
cond_fun: typing.Callable,
body_fun: typing.Callable,
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
) -> typing.Callable:

update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def cond_fun_wrapper(val) -> bool:

if isinstance(val, tuple):
iter_num, *_ = val
else:
iter_num = val

cond = cond_fun(val)
cond = jax.lax.cond(
cond,
lambda _cond, *_: _cond,
close_tqdm,
cond,
iter_num,
iter_num - 1,
)
return cond
def _bounded_while_tqdm(cond_fun) -> typing.Callable:
def cond_fun_wrapper(val) -> bool:

def body_fun_wrapper(val):
if isinstance(val, tuple):
iter_num, val = val
else:
iter_num = val

val = update_progress_bar(val, iter_num)
val = body_fun(val)
if isinstance(val, tuple):
iter_num, *_ = val
else:
iter_num = val

val = update_progress_bar(val, iter_num)
cond = cond_fun(val)
cond = jax.lax.cond(
cond,
lambda _cond, *_: _cond,
close_tqdm,
cond,
iter_num,
iter_num - 1,
)
return cond

return val
return cond_fun_wrapper

return cond_fun_wrapper, body_fun_wrapper
return _bounded_while_tqdm


def build_tqdm(
Expand Down Expand Up @@ -221,9 +212,6 @@ def build_tqdm(
f"number of steps {n}, got {print_rate}"
)

remainder = n % print_rate
remainder = remainder if remainder > 0 else print_rate

def _define_tqdm(bar_id: int):
bar_id = int(bar_id)
tqdm_bars[bar_id] = pbar(
Expand All @@ -238,7 +226,6 @@ def _update_tqdm(bar_id: int):

def _close_tqdm(bar_id: int, final_value: int):
_pbar = tqdm_bars.pop(int(bar_id))
print(final_value, _pbar.n)
_pbar.update(int(final_value) - _pbar.n)
_pbar.clear()
_pbar.close()
Expand Down
3 changes: 1 addition & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,13 @@ def test_bounded_while_loop():
n_total = 10_000
n_stop = 5_000

@bounded_while_tqdm(n_total)
def cond_fun(x):
return x < n_stop

def body_fun(x):
return x + 1

cond_fun, body_fun = bounded_while_tqdm(cond_fun, body_fun, n_total)

result = jax.lax.while_loop(cond_fun, body_fun, 0)

assert result == 5_000
Expand Down

0 comments on commit 1f12f5f

Please sign in to comment.