Skip to content

Commit

Permalink
remove members in wrappers + fix wrapper state typing
Browse files Browse the repository at this point in the history
  • Loading branch information
amosyou committed Feb 15, 2024
1 parent 5d4186f commit a225bcb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
5 changes: 0 additions & 5 deletions docs/api/optimizer_wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ Apply if finite
~~~~~~~~~~~~~~~~~
.. autofunction:: apply_if_finite
.. autoclass:: ApplyIfFiniteState
:members:

Flatten
~~~~~~~~
Expand All @@ -37,26 +36,22 @@ Lookahead
.. autoclass:: LookaheadParams
:members:
.. autoclass:: LookaheadState
:members:

Masked update
~~~~~~~~~~~~~~
.. autofunction:: masked
.. autoclass:: MaskedState
:members:

Maybe update
~~~~~~~~~~~~~~
.. autofunction:: maybe_update
.. autoclass:: MaybeUpdateState
:members:

Multi-step update
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MultiSteps
:members:
.. autoclass:: MultiStepsState
:members:
.. autoclass:: ShouldSkipUpdateFunction
:members:
.. autofunction:: skip_large_updates
Expand Down
57 changes: 33 additions & 24 deletions optax/_src/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,23 @@ def update_fn(updates, state, params=None, **extra_args):
class ApplyIfFiniteState(NamedTuple):
  """State of the `GradientTransformation` returned by `apply_if_finite`.

  Attributes:
    notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient
      updates containing an Inf or a NaN. This number is reset to 0 whenever a
      gradient update without an Inf or a NaN is done.
    last_finite (``Union[jax.Array, bool]``): Whether or not the last gradient
      update contained an Inf or a NaN.
    total_notfinite (``Union[jax.Array, int]``): Total number of gradient
      updates containing an Inf or a NaN since this optimizer was initialised.
      This number is never reset.
    inner_state (:class:`optax.OptState`): The state of the inner
      `GradientTransformation`.
  """
  notfinite_count: Union[jax.Array, int]
  # A finiteness flag, so the scalar leaf type is bool, not int.
  last_finite: Union[jax.Array, bool]
  total_notfinite: Union[jax.Array, int]
  inner_state: base.OptState


def apply_if_finite(
Expand Down Expand Up @@ -175,23 +175,23 @@ def _zeros_tree_like(inp_tree: chex.ArrayTree) -> chex.ArrayTree:
class MultiStepsState(NamedTuple):
  """State of the `GradientTransformation` returned by `MultiSteps`.

  Attributes:
    mini_step (``Union[jax.Array, int]``): current mini-step counter. At an
      update, this either increases by 1 or is reset to 0.
    gradient_step (``Union[jax.Array, int]``): gradient step counter. This only
      increases after enough mini-steps have been accumulated.
    inner_opt_state (:class:`optax.OptState`): the state of the wrapped
      optimiser.
    acc_grads (``chex.ArrayTree``): accumulated gradients over multiple
      mini-steps.
    skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This
      is only relevant when passing a `should_skip_update_fn` to `MultiSteps`.
      This structure will then contain values for debugging and or monitoring.
      The actual structure will vary depending on the choice of
      `ShouldSkipUpdateFunction`.
  """
  mini_step: Union[jax.Array, int]
  gradient_step: Union[jax.Array, int]
  inner_opt_state: base.OptState
  # Accumulated gradients mirror the params pytree (see `_zeros_tree_like`),
  # so this is a tree of arrays rather than a single `jax.Array`.
  acc_grads: chex.ArrayTree
  skip_state: chex.ArrayTree = ()


Expand Down Expand Up @@ -448,8 +448,12 @@ def gradient_transformation(self) -> base.GradientTransformation:


class MaskedState(NamedTuple):
  """Maintains inner transform state for masked transformations.

  Attributes:
    inner_state (:class:`optax.OptState`): The state of the inner
      `GradientTransformation`.
  """
  inner_state: base.OptState


class MaskedNode(NamedTuple):
Expand Down Expand Up @@ -563,9 +567,14 @@ def update_fn(updates, state, params=None, **extra_args):


class MaybeUpdateState(NamedTuple):
  """Maintains inner transform state and adds a step counter.

  Attributes:
    inner_state (:class:`optax.OptState`): The state of the inner
      `GradientTransformation`.
    step (``Union[jax.Array, int]``): The current step counter.
  """
  inner_state: base.OptState
  step: Union[jax.Array, int]


def maybe_update(
Expand Down

0 comments on commit a225bcb

Please sign in to comment.