@@ -15,7 +15,7 @@
 """Flax Optimizer api."""
 
 import dataclasses
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Optional
 import warnings
 
 from .. import jax_utils
@@ -40,7 +40,7 @@ class OptimizerState:
 
 class OptimizerDef:
   """Base class for an optimizer definition, which specifies the initialization and gradient application logic.
-
+
   See docstring of :class:`Optimizer` for more details.
   """
 
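For orientation, this is roughly the shape of an `OptimizerDef` subclass in the legacy `flax.optim` API: a minimal sketch assuming the `init_param_state`/`apply_param_gradient` hooks, with an illustrative hyper-parameter dataclass name (`_SgdHyperParams`), not code from this diff::

    import jax.numpy as jnp
    from flax import optim, struct

    @struct.dataclass
    class _SgdHyperParams:
      learning_rate: jnp.ndarray

    class Sgd(optim.OptimizerDef):
      """Plain gradient descent via the two OptimizerDef hooks."""

      def __init__(self, learning_rate=0.01):
        super().__init__(_SgdHyperParams(learning_rate))

      def init_param_state(self, param):
        # Stateless update rule: no per-parameter slots needed.
        return ()

      def apply_param_gradient(self, step, hyper_params, param, state, grad):
        # Return the updated parameter and its (empty) new state.
        return param - hyper_params.learning_rate * grad, ()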
@@ -122,7 +122,7 @@ def update_hyper_params(self, **hyper_param_overrides):
     hp = hp.replace(**hyper_param_overrides)
     return hp
 
-  def create(self, target, focus: 'ModelParamTraversal' = None):
+  def create(self, target, focus: Optional['ModelParamTraversal'] = None):
     """Creates a new optimizer for the given target.
 
     See docstring of :class:`Optimizer` for more details.
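The `hyper_param_overrides` plumbing above is what enables per-step overrides; a minimal usage sketch, assuming the legacy `flax.optim` API::

    import jax.numpy as jnp
    from flax import optim

    params = {'w': jnp.ones((3,))}
    optimizer = optim.GradientDescent(learning_rate=0.1).create(params)
    grads = {'w': jnp.full((3,), 0.5)}

    # Keyword arguments to apply_gradient are routed through
    # update_hyper_params(), overriding the stored learning rate
    # for this step only.
    optimizer = optimizer.apply_gradient(grads, learning_rate=0.01)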
@@ -133,7 +133,7 @@ def create(self, target, focus: 'ModelParamTraversal' = None):
         of variables dicts, e.g. `(v1, v2)` and `{'var1': v1, 'var2': v2}`
         are valid inputs as well.
       focus: a `flax.traverse_util.Traversal` that selects which subset of
-        the target is optimized. See docstring of :class:`MultiOptimizer`
+        the target is optimized. See docstring of :class:`MultiOptimizer`
         for an example of how to define a `Traversal` object.
     Returns:
       An instance of `Optimizer`.
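Since `focus` is best understood next to `MultiOptimizer`, here is the pattern the docstring points to, sketched against the legacy `flax.optim` API (the path predicates and `params` tree are illustrative)::

    import jax.numpy as jnp
    from flax import optim

    params = {'dense': {'kernel': jnp.ones((4, 2)), 'bias': jnp.zeros((2,))}}

    # ModelParamTraversal filters parameters by their '/'-joined path.
    biases = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
    kernels = optim.ModelParamTraversal(lambda path, _: 'bias' not in path)

    # Each traversal (a focus) is paired with its own OptimizerDef.
    opt_def = optim.MultiOptimizer(
        (biases, optim.GradientDescent(learning_rate=0.1)),
        (kernels, optim.Adam(learning_rate=1e-3)))
    optimizer = opt_def.create(params)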
@@ -183,10 +183,10 @@ class _NoAux:
 class Optimizer(struct.PyTreeNode):
   """
   Flax optimizers are created using the :class:`OptimizerDef` class. That class
-  specifies the initialization and gradient application logic. Creating an
-  optimizer using the :meth:`OptimizerDef.create` method will result in an
+  specifies the initialization and gradient application logic. Creating an
+  optimizer using the :meth:`OptimizerDef.create` method will result in an
   instance of the :class:`Optimizer` class, which encapsulates the optimization
-  target and state. The optimizer is updated using the method
+  target and state. The optimizer is updated using the method
   :meth:`apply_gradient`.
 
   Example of constructing an optimizer for a model::
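The docstring's example itself falls outside this hunk; as a stand-in, a minimal sketch of the create/apply_gradient cycle it describes, using a hypothetical parameter pytree in place of a real model::

    import jax
    import jax.numpy as jnp
    from flax import optim

    params = {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}
    optimizer = optim.Momentum(learning_rate=0.1, beta=0.9).create(params)

    def loss_fn(p):
      # Toy quadratic loss over the parameter pytree.
      return jnp.sum(p['kernel'] ** 2) + jnp.sum(p['bias'] ** 2)

    grads = jax.grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)  # returns a new Optimizer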