Skip to content

Commit

Permalink
fix zero nan state typing
Browse files Browse the repository at this point in the history
  • Loading branch information
amosyou committed Feb 21, 2024
1 parent 65d4be6 commit 6398a1e
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions optax/_src/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Gradient transformations used to enforce specific constraints."""

from typing import Any, NamedTuple
from typing import NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -57,13 +57,15 @@ def update_fn(updates, state, params):


class ZeroNansState(NamedTuple):
"""Contains a tree.
"""State of the `GradientTransformation` returned by `zero_nans`.
The entry `found_nan` has the same tree structure as that of the parameters.
Each leaf is a single boolean which contains True iff a NaN was detected in
the corresponding parameter array at the last call to `update`.
Attributes:
found_nan (``jax.Array``): tree that has the same structure as that of the
parameters. Each leaf is a single boolean which contains True iff a NaN
was detected in the corresponding parameter array at the last call to
`update`.
"""
found_nan: Any
found_nan: jax.Array


def zero_nans() -> base.GradientTransformation:
Expand Down

0 comments on commit 6398a1e

Please sign in to comment.