Skip to content

Commit

Permalink
Split up tf2jax.py to speed up type checking.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446944266
  • Loading branch information
shaobohou authored and TF2JAXDev committed May 6, 2022
1 parent 6557a9e commit 0d7997b
Show file tree
Hide file tree
Showing 7 changed files with 2,098 additions and 2,018 deletions.
8 changes: 4 additions & 4 deletions tf2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
# ==============================================================================
"""API of tf2jax."""

from tf2jax._src.config import get_config
from tf2jax._src.config import override_config
from tf2jax._src.config import update_config

from tf2jax._src.tf2jax import convert
from tf2jax._src.tf2jax import convert_from_restored
from tf2jax._src.tf2jax import convert_functional
from tf2jax._src.tf2jax import convert_functional_from_restored

from tf2jax._src.tf2jax import get_config
from tf2jax._src.tf2jax import override_config
from tf2jax._src.tf2jax import update_config

__version__ = "0.2.0"

# _________________________________________
Expand Down
50 changes: 50 additions & 0 deletions tf2jax/_src/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TF2JAX configurations."""

import contextlib

from typing import Any

_config = dict(
strict_shape_check=True,
strict_dtype_check=False,
force_const_float32_to_bfloat16=False,
force_const_float64_to_bfloat16=False,
convert_custom_gradient=True,
infer_relu_from_jax2tf=True,
)


def get_config(name: str) -> bool:
return _config[name]


def update_config(name: str, value: Any):
if name in _config:
_config[name] = value
else:
raise ValueError(
f"Parameter named {name} not found in config={_config}")


@contextlib.contextmanager
def override_config(name: str, value: Any):
old_value = get_config(name)
update_config(name, value)
try:
yield
finally:
update_config(name, old_value)
Loading

0 comments on commit 0d7997b

Please sign in to comment.