Skip to content

Commit

Permalink
Move eps outside sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
hr0nix committed Sep 18, 2022
1 parent 41ecc29 commit 9e9658a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax_adan/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def update_fn(updates, state, params=None):
n_hat = bias_correction(n, decay_n, count_inc)

new_updates = jax.tree_util.tree_map(
lambda mm, vv, nn: (mm + decay_v * vv) / jnp.sqrt(nn + eps), m_hat, v_hat, n_hat)
lambda mm, vv, nn: (mm + decay_v * vv) / (jnp.sqrt(nn) + eps), m_hat, v_hat, n_hat)

return new_updates, ScaleByAdanState(count=count_inc, m=m, v=v, n=n, prev_grad=updates)

Expand Down

0 comments on commit 9e9658a

Please sign in to comment.