diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst
index b029f0f0c..21b91724a 100644
--- a/docs/api/optimizer_wrappers.rst
+++ b/docs/api/optimizer_wrappers.rst
@@ -25,7 +25,6 @@ Apply if finite
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: apply_if_finite
 .. autoclass:: ApplyIfFiniteState
-   :members:

 Flatten
 ~~~~~~~~
@@ -37,26 +36,22 @@ Lookahead
 .. autoclass:: LookaheadParams
    :members:
 .. autoclass:: LookaheadState
-   :members:

 Masked update
 ~~~~~~~~~~~~~~
 .. autofunction:: masked
 .. autoclass:: MaskedState
-   :members:

 Maybe update
 ~~~~~~~~~~~~~~
 .. autofunction:: maybe_update
 .. autoclass:: MaybeUpdateState
-   :members:

 Multi-step update
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: MultiSteps
    :members:
 .. autoclass:: MultiStepsState
-   :members:
 .. autoclass:: ShouldSkipUpdateFunction
    :members:
 .. autofunction:: skip_large_updates
diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index 9d0d50bc3..b901a5228 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -89,23 +89,23 @@ def update_fn(updates, state, params=None, **extra_args):
 class ApplyIfFiniteState(NamedTuple):
   """State of the `GradientTransformation` returned by `apply_if_finite`.

-  Fields:
-    notfinite_count: Number of consecutive gradient updates containing an Inf or
+  Attributes:
+    notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient updates containing an Inf or
       a NaN. This number is reset to 0 whenever a gradient update without an
       Inf or a NaN is done.
-    last_finite: Whether or not the last gradient update contained an Inf or a
+    last_finite (``Union[jax.Array, int]``): Whether or not the last gradient update contained an Inf or a
       NaN.
-    total_notfinite: Total number of gradient updates containing an Inf or
+    total_notfinite (``Union[jax.Array, int]``): Total number of gradient updates containing an Inf or
       a NaN since this optimizer was initialised. This number is never reset.
-    inner_state: The state of the inner `GradientTransformation`.
+    inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
   """
   # TODO(optax-dev): notfinite_count, last_finite and inner_state used to be
   # annotated as `jnp.array` but that is not a valid annotation (it's a function
   # and secretly resolved to `Any`). We should add back typing.
-  notfinite_count: Any
-  last_finite: Any
-  total_notfinite: Any
-  inner_state: Any
+  notfinite_count: Union[jax.Array, int]
+  last_finite: Union[jax.Array, int]
+  total_notfinite: Union[jax.Array, int]
+  inner_state: base.OptState


 def apply_if_finite(
@@ -175,23 +175,23 @@ def _zeros_tree_like(inp_tree: chex.ArrayTree) -> chex.ArrayTree:
 class MultiStepsState(NamedTuple):
   """State of the `GradientTransformation` returned by `MultiSteps`.

-  Fields:
-    mini_step: current mini-step counter. At an update, this either increases by
+  Attributes:
+    mini_step (``Union[jax.Array, int]``): current mini-step counter. At an update, this either increases by
       1 or is reset to 0.
-    gradient_step: gradient step counter. This only increases after enough
+    gradient_step (``Union[jax.Array, int]``): gradient step counter. This only increases after enough
       mini-steps have been accumulated.
-    inner_opt_state: the state of the wrapped otpimiser.
-    acc_grads: accumulated gradients over multiple mini-steps.
-    skip_state: an arbitrarily nested tree of arrays. This is only
+    inner_opt_state (:class:`optax.OptState`): the state of the wrapped optimiser.
+    acc_grads (:class:`optax.Updates`): accumulated gradients over multiple mini-steps.
+    skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This is only
       relevant when passing a `should_skip_update_fn` to `MultiSteps`. This
       structure will then contain values for debugging and or monitoring. The
       actual structure will vary depending on the choice of
       `ShouldSkipUpdateFunction`.
   """
-  mini_step: Array
-  gradient_step: Array
-  inner_opt_state: Any
-  acc_grads: Any
+  mini_step: Union[jax.Array, int]
+  gradient_step: Union[jax.Array, int]
+  inner_opt_state: base.OptState
+  acc_grads: base.Updates
   skip_state: chex.ArrayTree = ()


@@ -448,8 +448,12 @@ def gradient_transformation(self) -> base.GradientTransformation:
 class MaskedState(NamedTuple):
-  """Maintains inner transform state for masked transformations."""
-  inner_state: Any
+  """Maintains inner transform state for masked transformations.
+
+  Attributes:
+    inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
+  """
+  inner_state: base.OptState


 class MaskedNode(NamedTuple):
@@ -563,9 +567,14 @@ def update_fn(updates, state, params=None, **extra_args):
 class MaybeUpdateState(NamedTuple):
-  """Maintains inner transform state and adds a step counter."""
-  inner_state: Any
-  step: Array
+  """Maintains inner transform state and adds a step counter.
+
+  Attributes:
+    inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
+    step (``Union[jax.Array, int]``): The current step counter.
+  """
+  inner_state: base.OptState
+  step: Union[jax.Array, int]


 def maybe_update(
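
As context for the types documented above, a minimal sketch of the behavior `ApplyIfFiniteState` records, using only public optax API; the inner optimizer and the error threshold are illustrative, not part of this patch:

    import jax.numpy as jnp
    import optax

    # Skip non-finite updates; propagate errors after 5 consecutive bad steps.
    opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
    params = {'w': jnp.ones(3)}
    state = opt.init(params)

    # A NaN gradient yields zero updates, leaves inner_state untouched, and
    # increments notfinite_count and total_notfinite; last_finite is False.
    grads = {'w': jnp.array([1.0, jnp.nan, 0.0])}
    updates, state = opt.update(grads, state, params)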
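
Likewise, a sketch of how the `MultiStepsState` counters evolve under gradient accumulation (the inner optimizer and `every_k_schedule` value are illustrative):

    import jax.numpy as jnp
    import optax

    # Accumulate gradients over 4 mini-steps per real optimizer step.
    opt = optax.MultiSteps(optax.sgd(0.1), every_k_schedule=4)
    params = {'w': jnp.zeros(2)}
    state = opt.init(params)

    for _ in range(4):
      updates, state = opt.update({'w': jnp.ones(2)}, state, params)
      params = optax.apply_updates(params, updates)

    # mini_step has cycled 1, 2, 3, 0 and gradient_step is now 1; acc_grads
    # holds the running gradient pytree between real steps.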
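
A sketch of `masked`, whose `MaskedState` holds inner state only for the unmasked leaves (the mask and the inner `trace` transform are illustrative):

    import jax.numpy as jnp
    import optax

    params = {'w': jnp.ones(2), 'b': jnp.ones(2)}
    # Apply momentum to 'w' only; 'b' is passed through unchanged.
    opt = optax.masked(optax.trace(decay=0.9), {'w': True, 'b': False})
    state = opt.init(params)  # a MaskedState
    # state.inner_state.trace has a momentum buffer for 'w' and a MaskedNode
    # placeholder (a leafless container) at 'b'.
    grads = {'w': jnp.ones(2), 'b': jnp.ones(2)}
    updates, state = opt.update(grads, state, params)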
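
Finally, a sketch of `maybe_update`, whose `MaybeUpdateState.step` counter feeds the predicate (the inner clipping transform and the predicate are illustrative):

    import jax.numpy as jnp
    import optax

    # Apply the inner clipping only when step % 5 == 0; on other steps the
    # updates pass through unchanged while step still advances.
    opt = optax.maybe_update(optax.clip(1.0), lambda step: step % 5 == 0)
    params = {'w': jnp.ones(2)}
    state = opt.init(params)  # state.step starts at 0
    updates, state = opt.update({'w': jnp.full(2, 10.0)}, state, params)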