Skip to content

Commit

Permalink
Merge pull request #3 from hr0nix/eps_outside_sqrt
Browse files Browse the repository at this point in the history
Move eps outside sqrt
  • Loading branch information
hr0nix authored Sep 18, 2022
2 parents 41ecc29 + 9e9658a commit cf71bab
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax_adan/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def update_fn(updates, state, params=None):
n_hat = bias_correction(n, decay_n, count_inc)

new_updates = jax.tree_util.tree_map(
lambda mm, vv, nn: (mm + decay_v * vv) / jnp.sqrt(nn + eps), m_hat, v_hat, n_hat)
lambda mm, vv, nn: (mm + decay_v * vv) / (jnp.sqrt(nn) + eps), m_hat, v_hat, n_hat)

return new_updates, ScaleByAdanState(count=count_inc, m=m, v=v, n=n, prev_grad=updates)

Expand Down

0 comments on commit cf71bab

Please sign in to comment.