Fix NamedTuples Formatting #803

Open
wants to merge 8 commits into main
2 changes: 1 addition & 1 deletion docs/api/combining_optimizers.rst
@@ -6,6 +6,7 @@ Combining Optimizers
 .. autosummary::
     chain
     multi_transform
+    MultiTransformState

 Chain
 ~~~~~
@@ -15,4 +16,3 @@ Multi-transform
 ~~~~~~~~~~~~~~~
 .. autofunction:: multi_transform
 .. autoclass:: MultiTransformState
-    :members:
16 changes: 0 additions & 16 deletions docs/api/contrib.rst
@@ -27,50 +27,34 @@ Complex-valued Optimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: split_real_and_imaginary
 .. autoclass:: SplitRealAndImaginaryState
-    :members:

 Continuous coin betting
 ~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: cocob
 .. autoclass:: COCOBState
-    :members:

 D-adaptation
 ~~~~~~~~~~~~
 .. autofunction:: dadapt_adamw
 .. autoclass:: DAdaptAdamWState
-    :members:

-Privacy-Sensitive Optax Methods
--------------------------------
-
-.. autosummary::
-    DifferentiallyPrivateAggregateState
-    differentially_private_aggregate
-
-
 Differentially Private Aggregate
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: differentially_private_aggregate
 .. autoclass:: DifferentiallyPrivateAggregateState
-    :members:
 .. autofunction:: dpsgd

-
 Mechanize
 ~~~~~~~~~
 .. autofunction:: mechanize
 .. autoclass:: MechanicState
-    :members:

 Prodigy
 ~~~~~~~
 .. autofunction:: prodigy
 .. autoclass:: ProdigyState
-    :members:

 Sharpness aware minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: sam
 .. autoclass:: SAMState
-    :members:
6 changes: 3 additions & 3 deletions docs/api/control_variates.rst
@@ -9,13 +9,13 @@ Control Variates
     moving_avg_baseline

 Control delta method
-~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: control_delta_method

 Control variates Jacobians
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: control_variates_jacobians

 Moving average baseline
-~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: moving_avg_baseline
1 change: 0 additions & 1 deletion docs/api/optimizer_schedules.rst
@@ -46,7 +46,6 @@ Inject hyperparameters
 ~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: inject_hyperparams
 .. autoclass:: InjectHyperparamsState
-    :members:

 Linear schedules
 ~~~~~~~~~~~~~~~~
8 changes: 0 additions & 8 deletions docs/api/optimizer_wrappers.rst
@@ -25,7 +25,6 @@ Apply if finite
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: apply_if_finite
 .. autoclass:: ApplyIfFiniteState
-    :members:

 Flatten
 ~~~~~~~~
@@ -35,29 +34,22 @@ Lookahead
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: lookahead
 .. autoclass:: LookaheadParams
-    :members:
 .. autoclass:: LookaheadState
-    :members:

 Masked update
 ~~~~~~~~~~~~~~
 .. autofunction:: masked
 .. autoclass:: MaskedState
-    :members:

 Maybe update
 ~~~~~~~~~~~~~~
 .. autofunction:: maybe_update
 .. autoclass:: MaybeUpdateState
-    :members:

 Multi-step update
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: MultiSteps
-    :members:
 .. autoclass:: MultiStepsState
-    :members:
 .. autoclass:: ShouldSkipUpdateFunction
-    :members:
 .. autofunction:: skip_large_updates
 .. autofunction:: skip_not_finite
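Since this file's docs cover `MultiSteps`, here is a minimal sketch of the gradient-accumulation behaviour it documents (optimizer choice, step count, and shapes are arbitrary, illustrative values):

```python
import jax.numpy as jnp
import optax

# Accumulate gradients over 4 mini-batches before applying one Adam step.
tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
params = {'w': jnp.ones(3)}
state = tx.init(params)  # a MultiStepsState

for _ in range(4):
  grads = {'w': jnp.ones(3)}
  # Updates are all-zero until the 4th call, which emits the combined step.
  updates, state = tx.update(grads, state, params)
  params = optax.apply_updates(params, updates)
```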
26 changes: 0 additions & 26 deletions docs/api/transformations.rst
@@ -118,19 +118,15 @@ Transformations and states

 .. autofunction:: adaptive_grad_clip
 .. autoclass:: AdaptiveGradClipState
-    :members:

 .. autofunction:: add_decayed_weights
 .. autoclass:: AddDecayedWeightsState
-    :members:

 .. autofunction:: add_noise
 .. autoclass:: AddNoiseState
-    :members:

 .. autofunction:: apply_every
 .. autoclass:: ApplyEvery
-    :members:

 .. autofunction:: bias_correction

@@ -139,67 +135,54 @@ Transformations and states
 .. autofunction:: clip
 .. autofunction:: clip_by_block_rms
 .. autoclass:: ClipState
-    :members:

 .. autofunction:: clip_by_global_norm
 .. autoclass:: ClipByGlobalNormState
-    :members:

 .. autofunction:: ema
 .. autoclass:: EmaState
-    :members:

 .. autoclass:: EmptyState
-    :members:

 .. autofunction:: global_norm

 .. autofunction:: identity

 .. autofunction:: keep_params_nonnegative
 .. autoclass:: NonNegativeParamsState
-    :members:

 .. autofunction:: per_example_global_norm_clip
 .. autofunction:: per_example_layer_norm_clip

 .. autofunction:: scale
 .. autoclass:: ScaleState
-    :members:

 .. autofunction:: scale_by_adadelta
 .. autoclass:: ScaleByAdaDeltaState
-    :members:

 .. autofunction:: scale_by_adam
 .. autofunction:: scale_by_adamax
 .. autoclass:: ScaleByAdamState
-    :members:

 .. autofunction:: scale_by_amsgrad
 .. autoclass:: ScaleByAmsgradState
-    :members:

 .. autofunction:: scale_by_backtracking_linesearch
 .. autoclass:: ScaleByBacktrackingLinesearchState

 .. autofunction:: scale_by_belief
 .. autoclass:: ScaleByBeliefState
-    :members:

 .. autofunction:: scale_by_factored_rms
 .. autoclass:: FactoredState
-    :members:

 .. autofunction:: scale_by_learning_rate

 .. autofunction:: scale_by_lion
 .. autoclass:: ScaleByLionState
-    :members:

 .. autofunction:: scale_by_novograd
 .. autoclass:: ScaleByNovogradState
-    :members:

 .. autofunction:: scale_by_optimistic_gradient

@@ -213,31 +196,24 @@ Transformations and states

 .. autofunction:: scale_by_rms
 .. autoclass:: ScaleByRmsState
-    :members:

 .. autofunction:: scale_by_rprop
 .. autoclass:: ScaleByRpropState
-    :members:

 .. autofunction:: scale_by_rss
 .. autoclass:: ScaleByRssState
-    :members:

 .. autofunction:: scale_by_schedule
 .. autoclass:: ScaleByScheduleState
-    :members:

 .. autofunction:: scale_by_sm3
 .. autoclass:: ScaleBySM3State
-    :members:

 .. autofunction:: scale_by_stddev
 .. autoclass:: ScaleByRStdDevState
-    :members:

 .. autofunction:: scale_by_trust_ratio
 .. autoclass:: ScaleByTrustRatioState
-    :members:

 .. autofunction:: scale_by_yogi

@@ -248,7 +224,6 @@ Transformations and states

 .. autofunction:: trace
 .. autoclass:: TraceState
-    :members:

 .. autofunction:: update_infinity_moment
 .. autofunction:: update_moment
@@ -258,4 +233,3 @@ Transformations and states

 .. autofunction:: zero_nans
 .. autoclass:: ZeroNansState
-    :members:
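All of the deletions above follow one pattern: the `:members:` option is dropped from each `autoclass` directive because, in the source files below, the field documentation moves into an `Attributes:` section of each NamedTuple docstring, which Sphinx renders directly. A minimal sketch of the convention, using a hypothetical state class that is not part of this diff:

```python
from typing import NamedTuple

import jax


class ExampleState(NamedTuple):
  """State of the `GradientTransformation` returned by a hypothetical `example`.

  Attributes:
    count (``jax.Array``): number of times `update` has been called.
  """
  count: jax.Array
```

With the attributes documented this way, `.. autoclass:: ExampleState` needs no `:members:` option.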
1 change: 1 addition & 0 deletions optax/_src/combine.py
@@ -128,6 +128,7 @@ def update_fn(updates, state, params=None, **extra_args):


 class MultiTransformState(NamedTuple):
+  """State of the `GradientTransformation` returned by `multi_transform`."""
   inner_states: Mapping[Hashable, base.OptState]


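For context, a minimal sketch of where `MultiTransformState` shows up in practice (the labels and shapes here are made up for illustration):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# One inner optimizer per label; `param_labels` maps each leaf to a label.
tx = optax.multi_transform(
    {'weights': optax.adam(1e-3), 'biases': optax.sgd(1e-2)},
    param_labels={'w': 'weights', 'b': 'biases'},
)

state = tx.init(params)
# `state` is a MultiTransformState; `state.inner_states` maps each label
# ('weights', 'biases') to the state of the corresponding inner transform.
grads = {'w': jnp.ones((2, 2)), 'b': jnp.ones(2)}
updates, state = tx.update(grads, state, params)
```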
15 changes: 9 additions & 6 deletions optax/_src/constrain.py
@@ -14,8 +14,9 @@
 # ==============================================================================
 """Gradient transformations used to enforce specific constraints."""

-from typing import Any, NamedTuple
+from typing import NamedTuple

+import chex
 import jax
 import jax.numpy as jnp

@@ -57,13 +58,15 @@ def update_fn(updates, state, params):


 class ZeroNansState(NamedTuple):
-  """Contains a tree.
+  """State of the `GradientTransformation` returned by `zero_nans`.

-  The entry `found_nan` has the same tree structure as that of the parameters.
-  Each leaf is a single boolean which contains True iff a NaN was detected in
-  the corresponding parameter array at the last call to `update`.
+  Attributes:
+    found_nan (``jax.Array``): tree that has the same structure as that of the
+      parameters. Each leaf is a single boolean which contains True iff a NaN
+      was detected in the corresponding parameter array at the last call to
+      `update`.
   """
-  found_nan: Any
+  found_nan: chex.ArrayTree


 def zero_nans() -> base.GradientTransformation:
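A minimal usage sketch for `zero_nans`, with illustrative values:

```python
import jax.numpy as jnp
import optax

tx = optax.zero_nans()
params = {'w': jnp.ones(3)}
state = tx.init(params)  # `found_nan` starts as all-False

grads = {'w': jnp.array([1.0, jnp.nan, 2.0])}
updates, state = tx.update(grads, state)
# The NaN is replaced by 0, and `state.found_nan['w']` is a scalar True
# because a NaN was seen in that array on this call to `update`.
```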
12 changes: 6 additions & 6 deletions optax/_src/lookahead.py
@@ -29,12 +29,12 @@ class LookaheadState(NamedTuple):
   """State of the `GradientTransformation` returned by `lookahead`.

   Attributes:
-    fast_state: Optimizer state of the fast optimizer.
-    steps_since_sync: Number of fast optimizer steps taken since slow and fast
-      parameters were synchronized.
+    fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
+    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps
+      taken since slow and fast parameters were synchronized.
   """
   fast_state: base.OptState
-  steps_since_sync: jnp.ndarray
+  steps_since_sync: Union[jax.Array, int]


 class LookaheadParams(NamedTuple):
@@ -48,8 +48,8 @@ class LookaheadParams(NamedTuple):
   [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

   Attributes:
-    fast: Fast parameters.
-    slow: Slow parameters.
+    fast (:class:`optax.Params`): Fast parameters.
+    slow (:class:`optax.Params`): Slow parameters.
   """
   fast: base.Params
   slow: base.Params
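A minimal sketch of the fast/slow parameter flow these docstrings describe (step counts and learning rates are arbitrary):

```python
import jax.numpy as jnp
import optax

fast = {'w': jnp.ones(3)}
# lookahead keeps a fast and a slow copy of the parameters.
params = optax.LookaheadParams(fast=fast, slow=fast)

tx = optax.lookahead(optax.sgd(1e-2), sync_period=5, slow_step_size=0.5)
state = tx.init(params)  # LookaheadState(fast_state=..., steps_since_sync=0)

grads = {'w': jnp.ones(3)}  # gradients w.r.t. the fast parameters
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
```

Every `sync_period` steps, the slow parameters take a `slow_step_size` step toward the fast ones and the two copies are re-synchronized, which is when `steps_since_sync` resets to zero.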