Fix NamedTuples Formatting #803

Open
wants to merge 8 commits into main
Changes from 5 commits
1 change: 0 additions & 1 deletion docs/api/combining_optimizers.rst
@@ -15,4 +15,3 @@ Multi-transform
~~~~~~~~~~~~~~~
.. autofunction:: multi_transform
.. autoclass:: MultiTransformState
:members:
15 changes: 0 additions & 15 deletions docs/api/contrib.rst
@@ -27,47 +27,32 @@ Complex-valued Optimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: split_real_and_imaginary
.. autoclass:: SplitRealAndImaginaryState
:members:

Continuous coin betting
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: cocob
.. autoclass:: COCOBState
:members:

D-adaptation
~~~~~~~~~~~~
.. autofunction:: dadapt_adamw
.. autoclass:: DAdaptAdamWState
:members:

Privacy-Sensitive Optax Methods
-------------------------------

.. autosummary::
DifferentiallyPrivateAggregateState
differentially_private_aggregate


Differentially Private Aggregate
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: differentially_private_aggregate
.. autoclass:: DifferentiallyPrivateAggregateState
:members:
.. autofunction:: dpsgd


Mechanize
~~~~~~~~~
.. autofunction:: mechanize
.. autoclass:: MechanicState
:members:

Prodigy
~~~~~~~
.. autofunction:: prodigy
.. autoclass:: ProdigyState
:members:

Sharpness aware minimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6 changes: 3 additions & 3 deletions docs/api/control_variates.rst
@@ -9,13 +9,13 @@ Control Variates
moving_avg_baseline

Control delta method
~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: control_delta_method

Control variates Jacobians
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: control_variates_jacobians

Moving average baseline
~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: moving_avg_baseline
1 change: 0 additions & 1 deletion docs/api/optimizer_schedules.rst
@@ -46,7 +46,6 @@ Inject hyperparameters
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: inject_hyperparams
.. autoclass:: InjectHyperparamsState
:members:

Linear schedules
~~~~~~~~~~~~~~~~
5 changes: 0 additions & 5 deletions docs/api/optimizer_wrappers.rst
Collaborator

Since you made a pass on all the States here to give them nice formatting, you can keep the :members: attribute.
My idea was:

  • if you're brave and you change the docstring such that it's properly formatted, then let's keep :members: for that state.
  • otherwise, simply remove :members: and we may take care of the formatting of the states later.

Contributor Author

Sounds good! I brought back the members.
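
For reference, a minimal sketch of the "properly formatted" docstring style being discussed: an Attributes: section whose fields carry explicit types, so that :members: renders cleanly under Sphinx autodoc. The class and field names below are made up for illustration and are not part of this diff.

from typing import NamedTuple, Union

import jax


class ExampleScaleState(NamedTuple):  # hypothetical state, for illustration only
  """State of the `GradientTransformation` returned by a hypothetical `scale_by_example`.

  Attributes:
    count (``Union[jax.Array, int]``): Number of updates applied so far.
    nu (``jax.Array``): Tree with the same structure as the parameters,
      holding the running statistics used to rescale the updates.
  """
  count: Union[jax.Array, int]
  nu: jax.Array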

@@ -25,7 +25,6 @@ Apply if finite
~~~~~~~~~~~~~~~~~
.. autofunction:: apply_if_finite
.. autoclass:: ApplyIfFiniteState
:members:

Flatten
~~~~~~~~
@@ -37,26 +36,22 @@ Lookahead
.. autoclass:: LookaheadParams
:members:
.. autoclass:: LookaheadState
:members:

Masked update
~~~~~~~~~~~~~~
.. autofunction:: masked
.. autoclass:: MaskedState
:members:

Maybe update
~~~~~~~~~~~~~~
.. autofunction:: maybe_update
.. autoclass:: MaybeUpdateState
:members:

Multi-step update
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MultiSteps
:members:
.. autoclass:: MultiStepsState
:members:
.. autoclass:: ShouldSkipUpdateFunction
:members:
.. autofunction:: skip_large_updates
26 changes: 0 additions & 26 deletions docs/api/transformations.rst
@@ -115,19 +115,15 @@ Transformations and states

.. autofunction:: adaptive_grad_clip
.. autoclass:: AdaptiveGradClipState
:members:

.. autofunction:: add_decayed_weights
.. autoclass:: AddDecayedWeightsState
:members:

.. autofunction:: add_noise
.. autoclass:: AddNoiseState
:members:

.. autofunction:: apply_every
.. autoclass:: ApplyEvery
:members:

.. autofunction:: bias_correction

@@ -136,64 +132,51 @@ Transformations and states
.. autofunction:: clip
.. autofunction:: clip_by_block_rms
.. autoclass:: ClipState
:members:

.. autofunction:: clip_by_global_norm
.. autoclass:: ClipByGlobalNormState
:members:

.. autofunction:: ema
.. autoclass:: EmaState
:members:

.. autoclass:: EmptyState
:members:

.. autofunction:: global_norm

.. autofunction:: identity

.. autofunction:: keep_params_nonnegative
.. autoclass:: NonNegativeParamsState
:members:

.. autofunction:: per_example_global_norm_clip
.. autofunction:: per_example_layer_norm_clip

.. autofunction:: scale
.. autoclass:: ScaleState
:members:

.. autofunction:: scale_by_adadelta
.. autoclass:: ScaleByAdaDeltaState
:members:

.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autoclass:: ScaleByAdamState
:members:

.. autofunction:: scale_by_amsgrad
.. autoclass:: ScaleByAmsgradState
:members:

.. autofunction:: scale_by_belief
.. autoclass:: ScaleByBeliefState
:members:

.. autofunction:: scale_by_factored_rms
.. autoclass:: FactoredState
:members:

.. autofunction:: scale_by_learning_rate

.. autofunction:: scale_by_lion
.. autoclass:: ScaleByLionState
:members:

.. autofunction:: scale_by_novograd
.. autoclass:: ScaleByNovogradState
:members:

.. autofunction:: scale_by_optimistic_gradient

@@ -205,31 +188,24 @@ Transformations and states

.. autofunction:: scale_by_rms
.. autoclass:: ScaleByRmsState
:members:

.. autofunction:: scale_by_rprop
.. autoclass:: ScaleByRpropState
:members:

.. autofunction:: scale_by_rss
.. autoclass:: ScaleByRssState
:members:

.. autofunction:: scale_by_schedule
.. autoclass:: ScaleByScheduleState
:members:

.. autofunction:: scale_by_sm3
.. autoclass:: ScaleBySM3State
:members:

.. autofunction:: scale_by_stddev
.. autoclass:: ScaleByRStdDevState
:members:

.. autofunction:: scale_by_trust_ratio
.. autoclass:: ScaleByTrustRatioState
:members:

.. autofunction:: scale_by_yogi

@@ -240,7 +216,6 @@ Transformations and states

.. autofunction:: trace
.. autoclass:: TraceState
:members:

.. autofunction:: update_infinity_moment
.. autofunction:: update_moment
@@ -250,4 +225,3 @@ Transformations and states

.. autofunction:: zero_nans
.. autoclass:: ZeroNansState
:members:
14 changes: 8 additions & 6 deletions optax/_src/constrain.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Gradient transformations used to enforce specific constraints."""

from typing import Any, NamedTuple
from typing import NamedTuple

import jax
import jax.numpy as jnp
@@ -57,13 +57,15 @@ def update_fn(updates, state, params):


class ZeroNansState(NamedTuple):
"""Contains a tree.
"""State of the `GradientTransformation` returned by `zero_nans`.
The entry `found_nan` has the same tree structure as that of the parameters.
Each leaf is a single boolean which contains True iff a NaN was detected in
the corresponding parameter array at the last call to `update`.
Attributes:
found_nan (``jax.Array``): tree that has the same structure as that of the
parameters. Each leaf is a single boolean which contains True iff a NaN
was detected in the corresponding parameter array at the last call to
`update`.
"""
found_nan: Any
found_nan: jax.Array


def zero_nans() -> base.GradientTransformation:
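
For context on the ZeroNansState change above, a small usage sketch (values are illustrative, not from the PR) showing that each leaf of found_nan ends up as a single boolean per parameter array:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
grads = {'w': jnp.array([1.0, jnp.nan, 3.0]), 'b': jnp.zeros(2)}

tx = optax.zero_nans()
state = tx.init(params)
# NaNs in the updates are replaced by zeros; found_nan records where they occurred.
new_grads, new_state = tx.update(grads, state)
# Here new_state.found_nan == {'w': Array(True), 'b': Array(False)}.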
12 changes: 6 additions & 6 deletions optax/_src/lookahead.py
@@ -29,12 +29,12 @@ class LookaheadState(NamedTuple):
"""State of the `GradientTransformation` returned by `lookahead`.

Attributes:
fast_state: Optimizer state of the fast optimizer.
steps_since_sync: Number of fast optimizer steps taken since slow and fast
parameters were synchronized.
fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps
taken since slow and fast parameters were synchronized.
"""
fast_state: base.OptState
steps_since_sync: jnp.ndarray
steps_since_sync: Union[jax.Array, int]


class LookaheadParams(NamedTuple):
@@ -48,8 +48,8 @@ class LookaheadParams(NamedTuple):
[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

Attributes:
fast: Fast parameters.
slow: Slow parameters.
fast (:class:`optax.Params`): Fast parameters.
slow (:class:`optax.Params`): Slow parameters.
"""
fast: base.Params
slow: base.Params
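
To ground the attribute descriptions above, a minimal usage sketch of the lookahead wrapper (hyperparameters and shapes are illustrative, not taken from the PR):

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
# The wrapper operates on a pair of fast/slow parameter copies.
lookahead_params = optax.LookaheadParams(fast=params, slow=params)

opt = optax.lookahead(optax.sgd(1e-2), sync_period=5, slow_step_size=0.5)
opt_state = opt.init(lookahead_params)

# Gradients are computed with respect to the fast parameters.
grads = {'w': jnp.full(3, 0.1)}
updates, opt_state = opt.update(grads, opt_state, lookahead_params)
lookahead_params = optax.apply_updates(lookahead_params, updates)
# opt_state.steps_since_sync counts fast steps since the last fast/slow sync.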