Skip to content

Commit

Permalink
linting for lookahead and wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
amosyou committed Feb 21, 2024
1 parent a225bcb commit d44b905
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
4 changes: 2 additions & 2 deletions optax/_src/lookahead.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class LookaheadState(NamedTuple):
Attributes:
fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps taken since slow and fast
parameters were synchronized.
steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps
taken since slow and fast parameters were synchronized.
"""
fast_state: base.OptState
steps_since_sync: Union[jax.Array, int]
Expand Down
46 changes: 24 additions & 22 deletions optax/_src/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,17 @@ class ApplyIfFiniteState(NamedTuple):
"""State of the `GradientTransformation` returned by `apply_if_finite`.
Attributes:
notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient updates containing an Inf or
a NaN. This number is reset to 0 whenever a gradient update without an Inf
or a NaN is done.
last_finite (``Union[jax.Array, int]``): Whether or not the last gradient update contained an Inf or a
NaN.
total_notfinite (``Union[jax.Array, int]``): Total number of gradient updates containing an Inf or
a NaN since this optimizer was initialised. This number is never reset.
inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient
updates containing an Inf or a NaN. This number is reset to 0 whenever a
gradient update without an Inf or a NaN is done.
last_finite (``Union[jax.Array, int]``): Whether or not the last gradient
update contained an Inf or a NaN.
total_notfinite (``Union[jax.Array, int]``): Total number of gradient
updates containing an Inf or a NaN since this optimizer was initialised.
This number is never reset.
inner_state (:class:`optax.OptState`): The state of the inner
`GradientTransformation`.
"""
# TODO(optax-dev): notfinite_count, last_finite and inner_state used to be
# annotated as `jnp.array` but that is not a valid annotation (it's a function
# and secretely resolved to `Any`. We should add back typing.
notfinite_count: Union[jax.Array, int]
last_finite: Union[jax.Array, int]
total_notfinite: Union[jax.Array, int]
Expand Down Expand Up @@ -176,16 +175,17 @@ class MultiStepsState(NamedTuple):
"""State of the `GradientTransformation` returned by `MultiSteps`.
Attributes:
mini_step (``Union[jax.Array, int]``): current mini-step counter. At an update, this either increases by
1 or is reset to 0.
gradient_step (``Union[jax.Array, int]``): gradient step counter. This only increases after enough
mini-steps have been accumulated.
inner_opt_state (:class:`optax.OptState`): the state of the wrapped optimiser.
mini_step (``Union[jax.Array, int]``): current mini-step counter. At an
update, this either increases by 1 or is reset to 0.
gradient_step (``Union[jax.Array, int]``): gradient step counter. This only
increases after enough mini-steps have been accumulated.
inner_opt_state (:class:`optax.OptState`): the state of the wrapped
optimiser.
acc_grads (``jax.Array``): accumulated gradients over multiple mini-steps.
skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This is only
relevant when passing a `should_skip_update_fn` to `MultiSteps`. This
structure will then contain values for debugging and or monitoring. The
actual structure will vary depending on the choice of
skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This
is only relevant when passing a `should_skip_update_fn` to `MultiSteps`.
This structure will then contain values for debugging and or monitoring.
The actual structure will vary depending on the choice of
`ShouldSkipUpdateFunction`.
"""
mini_step: Union[jax.Array, int]
Expand Down Expand Up @@ -451,7 +451,8 @@ class MaskedState(NamedTuple):
"""Maintains inner transform state for masked transformations.
Attributes:
inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
inner_state (:class:`optax.OptState`): The state of the inner
`GradientTransformation`.
"""
inner_state: base.OptState

Expand Down Expand Up @@ -570,7 +571,8 @@ class MaybeUpdateState(NamedTuple):
"""Maintains inner transform state and adds a step counter.
Attributes:
inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
inner_state (:class:`optax.OptState`): The state of the inner
`GradientTransformation`.
step (``Union[jax.Array, int]``): The current step counter.
"""
inner_state: base.OptState
Expand Down

0 comments on commit d44b905

Please sign in to comment.