Skip to content

Commit

Permalink
fix zero nan state typing
Browse files Browse the repository at this point in the history
  • Loading branch information
amosyou committed Feb 21, 2024
1 parent 0bf3765 commit 5dbe9cb
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions optax/_src/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def update_fn(updates, state, params):


class ZeroNansState(NamedTuple):
"""Contains a tree.
"""State of the `GradientTransformation` returned by `zero_nans`.
The entry `found_nan` has the same tree structure as that of the parameters.
Each leaf is a single boolean which contains True iff a NaN was detected in
the corresponding parameter array at the last call to `update`.
Attributes:
found_nan (``jax.Array``): tree that has the same structure as that of the
parameters. Each leaf is a single boolean which contains True iff a NaN
was detected in the corresponding parameter array at the last call to
`update`.
"""
found_nan: Any
found_nan: jax.Array


def zero_nans() -> base.GradientTransformation:
Expand Down

0 comments on commit 5dbe9cb

Please sign in to comment.