Skip to content

Commit 6057855

Browse files
levskaya authored and Flax Authors committed Mar 29, 2022
Unpin pytype dependency and fix various pytype/typing errors.
PiperOrigin-RevId: 438008523
1 parent b4e9b12 commit 6057855

14 files changed

+39
-46
lines changed
 

‎examples/nlp_seq/input_pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import collections
1919
import enum
2020

21-
import tensorflow.compat.v2 as tf
21+
import tensorflow.compat.v2 as tf # pytype: disable=import-error
2222

2323

2424
# Values for padding, unknown words and a root.

‎flax/linen/linear.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _conv_dimension_numbers(input_shape):
205205
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
206206

207207

208-
PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
208+
PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
209209
LaxPadding = Union[str, Sequence[Tuple[int, int]]]
210210

211211

@@ -281,7 +281,6 @@ class _Conv(Module):
281281
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
282282

283283
@property
284-
@abc.abstractmethod
285284
def shared_weights(self) -> bool:
286285
"""Defines whether weights are shared or not between different pixels.
287286

‎flax/linen/module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import typing
2323
from typing import (Any, Callable, Dict, Iterable, List, Optional,
2424
Set, Tuple, Type, TypeVar, Union, overload)
25-
from typing_extensions import dataclass_transform # pytype: disable=import-error
25+
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
2626
import weakref
2727

2828
from flax import config

‎flax/linen/recurrent.py

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ class RNNCellBase(Module):
4646
"""RNN cell base class."""
4747

4848
@staticmethod
49-
@abc.abstractmethod
5049
def initialize_carry(rng, batch_dims, size, init_fn=zeros):
5150
"""Initialize the RNN cell carry.
5251

‎flax/metrics/tensorboard.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# pylint: disable=g-import-not-at-top
2121
import numpy as np
2222

23-
import tensorflow.compat.v2 as tf
23+
import tensorflow.compat.v2 as tf # pytype: disable=import-error
2424
from tensorboard.plugins.hparams import api as hparams_api
2525

2626

‎flax/optim/adadelta.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Adadelta Optimizer."""
1616

17+
from typing import Optional
1718
from .. import struct
1819
from .base import OptimizerDef
1920
import jax.numpy as jnp
@@ -43,7 +44,7 @@ class Adadelta(OptimizerDef):
4344
"""
4445

4546
def __init__(self,
46-
learning_rate: float = None,
47+
learning_rate: Optional[float] = None,
4748
rho: float = 0.9,
4849
eps: float = 1e-6,
4950
weight_decay: float = 0.0):

‎flax/optim/adagrad.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
1516
import jax.numpy as jnp
1617
import numpy as np
1718
from .. import struct
@@ -35,9 +36,9 @@ class _AdagradParamState:
3536

3637
class Adagrad(OptimizerDef):
3738
"""Adagrad optimizer"""
38-
def __init__(self, learning_rate: float = None, eps=1e-8):
39+
def __init__(self, learning_rate: Optional[float] = None, eps=1e-8):
3940
"""Constructor for the Adagrad optimizer.
40-
41+
4142
Args:
4243
learning_rate: the step size used to update the parameters.
4344
"""

‎flax/optim/base.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Flax Optimizer api."""
1616

1717
import dataclasses
18-
from typing import Any, List, Tuple
18+
from typing import Any, List, Tuple, Optional
1919
import warnings
2020

2121
from .. import jax_utils
@@ -40,7 +40,7 @@ class OptimizerState:
4040

4141
class OptimizerDef:
4242
"""Base class for an optimizer definition, which specifies the initialization and gradient application logic.
43-
43+
4444
See docstring of :class:`Optimizer` for more details.
4545
"""
4646

@@ -122,7 +122,7 @@ def update_hyper_params(self, **hyper_param_overrides):
122122
hp = hp.replace(**hyper_param_overrides)
123123
return hp
124124

125-
def create(self, target, focus: 'ModelParamTraversal' = None):
125+
def create(self, target, focus: Optional['ModelParamTraversal'] = None):
126126
"""Creates a new optimizer for the given target.
127127
128128
See docstring of :class:`Optimizer` for more details.
@@ -133,7 +133,7 @@ def create(self, target, focus: 'ModelParamTraversal' = None):
133133
of variables dicts, e.g. `(v1, v2)` and `('var1': v1, 'var2': v2)`
134134
are valid inputs as well.
135135
focus: a `flax.traverse_util.Traversal` that selects which subset of
136-
the target is optimized. See docstring of :class:`MultiOptimizer`
136+
the target is optimized. See docstring of :class:`MultiOptimizer`
137137
for an example of how to define a `Traversal` object.
138138
Returns:
139139
An instance of `Optimizer`.
@@ -183,10 +183,10 @@ class _NoAux:
183183
class Optimizer(struct.PyTreeNode):
184184
"""
185185
Flax optimizers are created using the :class:`OptimizerDef` class. That class
186-
specifies the initialization and gradient application logic. Creating an
187-
optimizer using the :meth:`OptimizerDef.create` method will result in an
186+
specifies the initialization and gradient application logic. Creating an
187+
optimizer using the :meth:`OptimizerDef.create` method will result in an
188188
instance of the :class:`Optimizer` class, which encapsulates the optimization
189-
target and state. The optimizer is updated using the method
189+
target and state. The optimizer is updated using the method
190190
:meth:`apply_gradient`.
191191
192192
Example of constructing an optimizer for a model::

‎flax/optim/rmsprop.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Optional
1516
import jax.numpy as jnp
1617
import numpy as np
1718
from .. import struct
@@ -38,7 +39,7 @@ class _RMSPropParamState:
3839

3940
class RMSProp(OptimizerDef):
4041
"""RMSProp optimizer"""
41-
def __init__(self, learning_rate: float = None, beta2=0.9, eps=1e-8,
42+
def __init__(self, learning_rate: Optional[float] = None, beta2=0.9, eps=1e-8,
4243
centered=False):
4344
"""Constructor for the RMSProp optimizer
4445
@@ -73,7 +74,7 @@ def apply_param_gradient(self, step, hyper_params, param, state, grad):
7374
else:
7475
new_mg = state.mg
7576
maybe_centered_v = new_v
76-
new_param = param - hyper_params.learning_rate * grad / (
77+
new_param = param - hyper_params.learning_rate * grad / (
7778
jnp.sqrt(maybe_centered_v) + hyper_params.eps)
7879
new_state = _RMSPropParamState(new_v, new_mg)
7980

‎flax/struct.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import serialization
2323

2424
import jax
25-
from typing_extensions import dataclass_transform # pytype: disable=import-error
25+
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
2626

2727

2828
_T = TypeVar("_T")

‎flax/training/checkpoints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from flax import core
2929
from flax import errors
3030
from flax import serialization
31-
from tensorflow.io import gfile
31+
from tensorflow.io import gfile # pytype: disable=import-error
3232

3333

3434
# Single-group reg-exps for int or float numerical substrings.

‎pytype.cfg

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# NOTE: All relative paths are relative to github root directory.
2+
3+
[pytype]
4+
5+
# TODO(levskaya): figure out why we get pyi-error from flax's root __init__.py
6+
# could be a pytype bug.
7+
disable =
8+
pyi-error

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"pytest",
4545
"pytest-cov",
4646
"pytest-xdist==1.34.0", # upgrading to 2.0 broke tests, need to investigate
47-
"pytype==2021.5.25", # pytype 2021.6.17 complains on recurrent.py, need to investigate!
47+
"pytype",
4848
"sentencepiece", # WMT example.
4949
"svn",
5050
"tensorflow_text>=2.4.0", # WMT example.

‎tests/run_all_tests.sh

+9-25
Original file line numberDiff line numberDiff line change
@@ -51,35 +51,19 @@ pytest -n auto tests $PYTEST_OPTS
5151
# In pytest foo/bar/baz_test.py and baz/bleep/baz_test.py will collide and error out when
5252
# /foo/bar and /baz/bleep aren't set up as packages.
5353
for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do
54-
# Skip tests on deprecated example -- we get this error.
55-
# ValueError: Failed to construct dataset lm1b: BuilderConfig subwords32k not found. Available: []
56-
#
57-
# TODO: Remove after github.com/google/flax/issues/567 is resolved
58-
if [[ "$egd" == "examples/lm1b_deprecated" ]]; then
59-
continue
60-
fi
6154
pytest $egd
6255
done
6356

64-
# validate types
65-
if [[ "$OSTYPE" == "darwin"* ]]; then
66-
echo "Pytype is currently not working on MacOS, see https://github.com/google/pytype/issues/661"
67-
else
68-
pytype flax/
69-
70-
for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do
71-
# Skip pytype on deprecated example
72-
# TODO: Remove after github.com/google/flax/issues/567 is resolved
73-
if [[ "$egd" == "examples/lm1b_deprecated" ]]; then
74-
continue
75-
fi
57+
# Validate types in library code.
58+
pytype --config pytype.cfg flax/
7659

77-
# use cd to make sure pytpe cache lives in example dir and doesn't name clash
78-
# use *.py to avoid importing configs as a top-level import which leads tot import errors
79-
# because config files use relative imports (e.g. from config import ...).
80-
(cd $egd ; pytype "*.py")
81-
done
82-
fi
60+
# Validate types in examples.
61+
for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do
62+
# use cd to make sure pytype cache lives in example dir and doesn't name clash
63+
# use *.py to avoid importing configs as a top-level import which leads to import errors
64+
# because config files use relative imports (e.g. from config import ...).
65+
(cd $egd ; pytype --config ../../pytype.cfg "*.py")
66+
done
8367

8468
# Return error code 0 if no real failures happened.
8569
echo "finished all tests."

0 commit comments

Comments
 (0)
Please sign in to comment.