From 9e9658adbe794740ce8a44988fa7cbdd9ca0037d Mon Sep 17 00:00:00 2001 From: hr0nix Date: Sun, 18 Sep 2022 18:36:30 +0300 Subject: [PATCH] Move eps outside sqrt --- optax_adan/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax_adan/transform.py b/optax_adan/transform.py index 88a8407..309289c 100644 --- a/optax_adan/transform.py +++ b/optax_adan/transform.py @@ -73,7 +73,7 @@ def update_fn(updates, state, params=None): n_hat = bias_correction(n, decay_n, count_inc) new_updates = jax.tree_util.tree_map( - lambda mm, vv, nn: (mm + decay_v * vv) / jnp.sqrt(nn + eps), m_hat, v_hat, n_hat) + lambda mm, vv, nn: (mm + decay_v * vv) / (jnp.sqrt(nn) + eps), m_hat, v_hat, n_hat) return new_updates, ScaleByAdanState(count=count_inc, m=m, v=v, n=n, prev_grad=updates)