diff --git a/tf2jax/__init__.py b/tf2jax/__init__.py index 7a90115..8cd6072 100644 --- a/tf2jax/__init__.py +++ b/tf2jax/__init__.py @@ -14,15 +14,15 @@ # ============================================================================== """API of tf2jax.""" +from tf2jax._src.config import get_config +from tf2jax._src.config import override_config +from tf2jax._src.config import update_config + from tf2jax._src.tf2jax import convert from tf2jax._src.tf2jax import convert_from_restored from tf2jax._src.tf2jax import convert_functional from tf2jax._src.tf2jax import convert_functional_from_restored -from tf2jax._src.tf2jax import get_config -from tf2jax._src.tf2jax import override_config -from tf2jax._src.tf2jax import update_config - __version__ = "0.2.0" # _________________________________________ diff --git a/tf2jax/_src/config.py b/tf2jax/_src/config.py new file mode 100644 index 0000000..497bc33 --- /dev/null +++ b/tf2jax/_src/config.py @@ -0,0 +1,50 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TF2JAX configurations.""" + +import contextlib + +from typing import Any + +_config = dict( + strict_shape_check=True, + strict_dtype_check=False, + force_const_float32_to_bfloat16=False, + force_const_float64_to_bfloat16=False, + convert_custom_gradient=True, + infer_relu_from_jax2tf=True, +) + + +def get_config(name: str) -> bool: + return _config[name] + + +def update_config(name: str, value: Any): + if name in _config: + _config[name] = value + else: + raise ValueError( + f"Parameter named {name} not found in config={_config}") + + +@contextlib.contextmanager +def override_config(name: str, value: Any): + old_value = get_config(name) + update_config(name, value) + try: + yield + finally: + update_config(name, old_value) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py new file mode 100644 index 0000000..4600bae --- /dev/null +++ b/tf2jax/_src/ops.py @@ -0,0 +1,1999 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental functions for converting TF graphs to Jax functions.""" + +import dataclasses +import functools +from typing import Any, Callable, Optional, Mapping, Sequence, Set, Tuple + +from absl import logging + +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow as tf +from tf2jax._src import config +from tf2jax._src import numpy_compat as anp +from tf2jax._src import xla_utils + + +# NoOp inserted to trigger side effects in function with no return values. +_EMPTY_RETURN_OP_NAME = "__NO_RETURN__" +_EMPTY_RETURN_VALUE = object() + + +def _check_attrs(proto, expected: Set[str]): + unexpected = [] + for k, v in proto.attr.items(): + # Ignore attributes with "_" prefix, as they appear to be undocumented. + if k not in expected and not k.startswith("_"): + unexpected.append(" `" + f"{k}={v}".strip() + "`") + if unexpected: + raise ValueError("\n".join( + [f"Unexpected attr(s) when parsing {proto.op}: {proto.name}"] + + unexpected)) + + +def _get_jax_op( + jax_op: Callable[..., Any], + expected_attrs: Set[str], +) -> Callable[..., Any]: + """For wrapping simple ops with no optional parameters.""" + + def wrapped(proto): + _check_attrs(proto, expected_attrs) + return jax_op + + return wrapped + + +def _fix_jax_poly_shape(shape: Tuple[Any, ...]) -> Tuple[Any, ...]: + good_shape = [] + for dim in shape: + try: + # This catches _DimPolynomial from jax2tf. + tf.compat.v1.Dimension(dim) + good_shape.append(dim) + except TypeError: + good_shape.append(None) + return tuple(good_shape) + + +_jax_ops = { + "Abs": _get_jax_op(jnp.abs, {"T"}), + "Add": _get_jax_op(anp.add, {"T"}), + "AddN": _get_jax_op( + lambda *args: anp.sum_(anp.stack(args, axis=0), axis=0, keepdims=False), + {"T", "N"}), + "AddV2": _get_jax_op(anp.add, {"T"}), + "ArgMax": _get_jax_op(jnp.argmax, {"T", "Tidx", "output_type"}), + "ArgMin": _get_jax_op(jnp.argmin, {"T", "Tidx", "output_type"}), + "Acosh": _get_jax_op(jnp.arccosh, {"T"}), + "Asinh": _get_jax_op(jnp.arcsinh, {"T"}), + "Atanh": _get_jax_op(jnp.arctanh, {"T"}), + "Atan2": _get_jax_op(jnp.arctan2, {"T"}), + "BitwiseAnd": _get_jax_op(jnp.bitwise_and, {"T"}), + "BitwiseOr": _get_jax_op(jnp.bitwise_or, {"T"}), + "BitwiseXor": _get_jax_op(jnp.bitwise_xor, {"T"}), + "BroadcastTo": _get_jax_op(anp.broadcast_to, {"T", "Tidx"}), + "Ceil": _get_jax_op(jnp.ceil, {"T"}), + "Complex": _get_jax_op(jax.lax.complex, {"T", "Tout"}), + "ComplexAbs": _get_jax_op(jax.lax.abs, {"T", "Tout"}), + "Conj": _get_jax_op(jax.lax.conj, {"T", "Tout"}), + "Cos": _get_jax_op(jnp.cos, {"T"}), + "Cosh": _get_jax_op(jnp.cosh, {"T"}), + "Digamma": _get_jax_op(jax.lax.digamma, {"T"}), + "Div": _get_jax_op(anp.divide, {"T"}), + "Elu": _get_jax_op(jax.nn.elu, {"T"}), + "Equal": _get_jax_op(anp.equal, {"T", "incompatible_shape_error"}), + "Erf": _get_jax_op(jax.lax.erf, {"T"}), + "Erfc": _get_jax_op(jax.lax.erfc, {"T"}), + "Erfinv": _get_jax_op(jax.lax.erf_inv, {"T"}), + "Exp": _get_jax_op(jnp.exp, {"T"}), + "Expm1": _get_jax_op(jnp.expm1, {"T"}), + "ExpandDims": _get_jax_op(anp.expand_dims, {"T", "Tdim"}), + "Floor": _get_jax_op(jnp.floor, {"T"}), + "FloorMod": _get_jax_op(anp.mod, {"T"}), + "FloorDiv": _get_jax_op(anp.floor_divide, {"T"}), + "Greater": _get_jax_op(anp.greater, {"T"}), + "GreaterEqual": _get_jax_op(anp.greater_equal, {"T"}), + "Identity": _get_jax_op(lambda x: x, {"T"}), + "Igamma": _get_jax_op(jax.lax.igamma, {"T"}), + "Igammac": _get_jax_op(jax.lax.igammac, {"T"}), + "Imag": 
_get_jax_op(jax.lax.imag, {"T", "Tout"}), + "IsFinite": _get_jax_op(jnp.isfinite, {"T"}), + "Invert": _get_jax_op(jnp.bitwise_not, {"T"}), + "L2Loss": _get_jax_op(lambda x: 0.5 * jnp.sum(jnp.square(x)), {"T"}), + "LeftShift": _get_jax_op(jnp.left_shift, {"T"}), + "Less": _get_jax_op(anp.less, {"T", "incompatible_shape_error"}), + "LessEqual": _get_jax_op(anp.less_equal, {"T", "incompatible_shape_error"}), + "Lgamma": _get_jax_op(jax.lax.lgamma, {"T"}), + "Log": _get_jax_op(jnp.log, {"T"}), + "Log1p": _get_jax_op(jnp.log1p, {"T"}), + "LogicalAnd": _get_jax_op(jnp.logical_and, {"T"}), + "LogicalNot": _get_jax_op(jnp.logical_not, {"T"}), + "LogicalOr": _get_jax_op(jnp.logical_or, {"T"}), + "Minimum": _get_jax_op(anp.minimum, {"T"}), + "Maximum": _get_jax_op(anp.maximum, {"T"}), + "Mul": _get_jax_op(anp.multiply, {"T"}), + "Neg": _get_jax_op(anp.negative, {"T"}), + "NoOp": _get_jax_op(lambda: _EMPTY_RETURN_VALUE, set({})), + "NotEqual": _get_jax_op(anp.not_equal, {"T", "incompatible_shape_error"}), + "OnesLike": _get_jax_op(jnp.ones_like, {"T"}), + "Pow": _get_jax_op(anp.power, {"T"}), + "Real": _get_jax_op(jax.lax.real, {"T", "Tout"}), + "ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}), + "RealDiv": _get_jax_op(anp.true_divide, {"T"}), + "Reciprocal": _get_jax_op(anp.reciprocal, {"T"}), + "Relu": _get_jax_op(jax.nn.relu, {"T"}), + "Relu6": _get_jax_op(jax.nn.relu6, {"T"}), + "ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}), + "RightShift": _get_jax_op(jnp.right_shift, {"T"}), + "Round": _get_jax_op(jnp.round, {"T"}), + "Rsqrt": _get_jax_op(jax.lax.rsqrt, {"T"}), + "Shape": _get_jax_op(lambda x: np.array(jnp.shape(x)), {"T", "out_type"}), + "Sigmoid": _get_jax_op(jax.nn.sigmoid, {"T"}), + "Sign": _get_jax_op(jnp.sign, {"T"}), + "Sin": _get_jax_op(jnp.sin, {"T"}), + "Sinh": _get_jax_op(jnp.sinh, {"T"}), + "Size": _get_jax_op(lambda x: np.prod(jnp.shape(x), dtype=np.int32), + {"T", "out_type"}), + "Softplus": _get_jax_op(jax.nn.softplus, {"T"}), + "Sqrt": _get_jax_op(jnp.sqrt, {"T"}), + "Square": _get_jax_op(jnp.square, {"T"}), + "StopGradient": _get_jax_op(jax.lax.stop_gradient, {"T"}), + "Sub": _get_jax_op(anp.subtract, {"T"}), + "Tan": _get_jax_op(jnp.tan, {"T"}), + "Tanh": _get_jax_op(jnp.tanh, {"T"}), + "Tile": _get_jax_op(anp.tile, {"T", "Tmultiples"}), + "Where": _get_jax_op(jnp.argwhere, {"T"}), + "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), + # The assignment logic is handled in _OpNode and convert(). 
+ "AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}), + "AssignSubVariableOp": _get_jax_op(jnp.subtract, {"dtype"}), + "AssignVariableOp": _get_jax_op( + lambda var, x: x, {"dtype", "validate_shape"}), +} + + +def get_parser(op_name: str) -> Callable[..., Callable[..., Any]]: + return _jax_ops[op_name] + + +def get_unsupported_operations(op_names: Sequence[str]) -> Set[str]: + return {name for name in op_names if name not in _jax_ops} + + +def _register(func: Callable[..., Any], op_name: str): + curr_func = _jax_ops.get(op_name, None) + if curr_func is None: + _jax_ops[op_name] = func + else: + if curr_func != func: + raise ValueError( + f"{op_name} is already registered as {curr_func}, received {func}.") + + return func + + +def register_operation(op_name): + return functools.partial(_register, op_name=op_name) + + +@dataclasses.dataclass +class _HigherOrderFunction: + """Base class for higher order ops.""" + inner_fn_names: Mapping[str, str] + + def get_inner_functions( + self, library_functions: Mapping[str, Callable[..., Any]] + ) -> Mapping[str, Callable[..., Any]]: + return {k: library_functions[v] for k, v in self.inner_fn_names.items()} + + +@register_operation("All") +def _all(proto): + _check_attrs(proto, {"Tidx", "keep_dims"}) + keep_dims = proto.attr["keep_dims"].b + return lambda x, axis: anp.all_(x, axis=axis.tolist(), keepdims=keep_dims) + + +@register_operation("Any") +def _any(proto): + _check_attrs(proto, {"Tidx", "keep_dims"}) + keep_dims = proto.attr["keep_dims"].b + return lambda x, axis: anp.any_(x, axis=axis.tolist(), keepdims=keep_dims) + + +@register_operation("Assert") +def _assert(proto): + _check_attrs(proto, {"T", "summarize"}) + + logging.warning("Assert has no effect and will just return the data.") + + return lambda cond, data: data + + +@register_operation("AvgPool") +def _avg_pool(proto): + """Parse a AvgPool Op.""" + _check_attrs( + proto, + {"T", "padding", "explicit_paddings", "ksize", "strides", "data_format"}) + + explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) + if explicit_paddings: + raise ValueError("explicit_padding in AvgPool not yet supported.") + + padding = str(proto.attr["padding"].s, "utf-8") + ksize = tuple(proto.attr["ksize"].list.i) + strides = tuple(proto.attr["strides"].list.i) + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format not in ("NHWC", "NCHW"): + raise ValueError(f"Found unsupported data format {data_format}.") + + reduce_window_args = dict( + init_value=0., + computation=jax.lax.add, + window_dimensions=ksize, + window_strides=strides, + padding=padding) + + def _func(x: jnp.ndarray) -> jnp.ndarray: + pooled = jax.lax.reduce_window(x, **reduce_window_args) + if padding == "VALID": + window_counts = np.prod(ksize) + else: + window_counts = jax.lax.reduce_window( + jnp.ones_like(x), **reduce_window_args) + return pooled / window_counts + + return _func + + +@register_operation("BiasAdd") +def _bias_add(proto): + """Parse a BiasAdd Op.""" + _check_attrs(proto, {"T", "data_format"}) + + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format == "NHWC": + expand_axis_fn = lambda x: [d for d in range(x.ndim) if d != x.ndim - 1] + elif data_format == "NCHW": + # TODO(shaobohou) this seems wrong but matches TF behaviour. 
+ expand_axis_fn = lambda x: [d for d in range(x.ndim) if d != 1] + else: + raise ValueError(f"Found unsupported data format {data_format}.") + + def _func(value: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray: + if bias.ndim != 1: + raise ValueError( + f"Expected `bias` as a 1D array, found array with {bias.ndim} dims.") + bias = anp.expand_dims(bias, axis=expand_axis_fn(value)) + return anp.add(value, bias) + + return _func + + +@register_operation("Bitcast") +def _bit_cast(proto): + _check_attrs(proto, {"T", "type"}) + dst_type = tf.as_dtype(proto.attr["type"].type) + return lambda x: jax.lax.bitcast_convert_type(x, anp.get_jax_dtype(dst_type)) + + +@register_operation("BroadcastArgs") +def _broadcast_args(proto): + _check_attrs(proto, {"T"}) + return lambda s0, s1: np.array(np.broadcast(np.zeros(s0), np.zeros(s1)).shape) + + +class _CaseOp(_HigherOrderFunction): + """Represents a Case Op.""" + + def __call__(self, branch_index, *operand, **branch_fns): + def create_branch(fn): + return lambda args: fn(*args) + branches = [create_branch(fn) for _, fn in sorted(branch_fns.items())] + return jax.lax.switch(branch_index, branches=branches, operand=operand) + + +@register_operation("StatelessCase") +@register_operation("Case") +def _case(proto): + """Parse a Case op.""" + _check_attrs(proto, {"Tin", "Tout", "output_shapes", "branches"}) + + branches = [f.name for f in proto.attr["branches"].list.func] + output_shapes = [ + [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape + ] + del output_shapes + + return _CaseOp({f"fn_{k:06}": v for k, v in enumerate(branches)}) + + +@register_operation("Cast") +def _cast(proto): + """Parse a Cast Op.""" + _check_attrs(proto, {"SrcT", "DstT", "Truncate"}) + + src_type = tf.as_dtype(proto.attr["SrcT"].type) + dst_type = tf.as_dtype(proto.attr["DstT"].type) + truncate = proto.attr["Truncate"].b + del src_type + + if truncate: + raise ValueError(f"Cast does not support truncate={truncate}.") + + def _func(x: jnp.ndarray) -> jnp.ndarray: + return anp.asarray(x, dst_type) + + return _func + + +@register_operation("ConjugateTranspose") +def _conjugate_transpose(proto): + _check_attrs(proto, {"T", "Tperm"}) + return lambda x, axes: jax.lax.conj(jnp.transpose(x, axes=axes)) + + +@register_operation("ConcatV2") +def _concatenate(proto): + """Parse a ConcatV2 Op.""" + _check_attrs(proto, {"T", "Tidx", "N"}) + + num_arrays = proto.attr["N"].i + + def _func(*args) -> jnp.ndarray: + if len(args) != num_arrays + 1: + raise ValueError( + f"Concatenate expects {num_arrays} args, received {len(args)}.") + + *inputs, axis = args + return anp.concatenate(inputs, axis=axis) + + return _func + + +@register_operation("Const") +def _const(proto): + """Parse a Const Op.""" + _check_attrs(proto, {"dtype", "value"}) + value = tf.make_ndarray(proto.attr["value"].tensor) + dtype = value.dtype + + force_float32 = config.get_config("force_const_float32_to_bfloat16") + force_float64 = config.get_config("force_const_float64_to_bfloat16") + if ((force_float32 and dtype == np.float32) or + (force_float64 and dtype == np.float64)): + # NOTE: `jnp.asarray` (rather than the `np` version) cannot be used here; + # using it in a jitted context can produce runtime `UnexpectedTracerError`s. 
+ bf16_value = np.asarray(value, dtype=jnp.bfloat16) + logging.warning("Converting float consts to bfloat16, from %s, to %s.", + value, bf16_value) + value = bf16_value + + return lambda: value + + +@register_operation("Conv2D") +def _conv2d(proto): + """Parse a Conv2D Op.""" + _check_attrs( + proto, { + "T", "padding", "explicit_paddings", "dilations", "strides", + "data_format", "use_cudnn_on_gpu" + }) + + explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) + if explicit_paddings: + raise ValueError("explicit_padding in Conv2D not yet supported.") + + padding = str(proto.attr["padding"].s, "utf-8") + dilations = tuple(proto.attr["dilations"].list.i) + strides = tuple(proto.attr["strides"].list.i) + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format == "NHWC": + dimension_numbers = ("NHWC", "HWIO", "NHWC") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) + feature_group_count_fn = lambda lhs, rhs: lhs.shape[3] // rhs.shape[2] + elif data_format == "NCHW": + dimension_numbers = ("NCHW", "HWIO", "NCHW") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) + feature_group_count_fn = lambda lhs, rhs: lhs.shape[1] // rhs.shape[2] + else: + raise ValueError(f"Found unsupported data format {data_format}.") + + _ = proto.attr["use_cudnn_on_gpu"].b + + def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: + feature_group_count = feature_group_count_fn(lhs, rhs) + return jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=strides, + padding=padding, + dimension_numbers=dimension_numbers, + rhs_dilation=dilations, + feature_group_count=feature_group_count) + + return _func + + +@register_operation("Conv2DBackpropInput") +def _conv2d_backprop_input(proto): + """Parse a Conv2DBackpropInput Op.""" + _check_attrs( + proto, { + "T", "padding", "explicit_paddings", "dilations", "strides", + "data_format", "use_cudnn_on_gpu" + }) + + explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) + if explicit_paddings: + raise ValueError( + "explicit_padding in Conv2DBackpropInput not yet supported.") + + padding = str(proto.attr["padding"].s, "utf-8") + dilations = tuple(proto.attr["dilations"].list.i) + strides = tuple(proto.attr["strides"].list.i) + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format == "NHWC": + dimension_numbers = ("NHWC", "HWIO", "NHWC") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) + elif data_format == "NCHW": + dimension_numbers = ("NCHW", "HWIO", "NCHW") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) + else: + raise ValueError(f"Found unsupported data format {data_format}.") + + _ = proto.attr["use_cudnn_on_gpu"].b + + def _func( + input_sizes: jnp.ndarray, + filters: jnp.ndarray, + out_backprop: jnp.ndarray, + ) -> jnp.ndarray: + del input_sizes + return jax.lax.conv_transpose( + out_backprop, + filters, + strides=strides, + padding=padding, + rhs_dilation=dilations, + transpose_kernel=True, + dimension_numbers=dimension_numbers) + + return _func + + +@register_operation("Cumsum") +def _cumsum(proto): + """Parse a Cumsum Op.""" + _check_attrs(proto, {"T", "Tidx", "exclusive", "reverse"}) + + 
exclusive = proto.attr["exclusive"].b + reverse = proto.attr["reverse"].b + + def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: + axis: int = axis.tolist() + if reverse: + x = anp.flip(x, axis=axis) + if exclusive: + pad_shape = list(x.shape) + pad_shape[axis] = 1 + x = anp.concatenate([np.zeros(pad_shape, dtype=x.dtype), x], axis=axis) + x = x[(slice(None),) * axis + (slice(0, -1), Ellipsis)] + res = anp.cumsum(x, axis=axis) + if reverse: + res = anp.flip(res, axis=axis) + return res + + return _func + + +@register_operation("DepthwiseConv2dNative") +def _depthwise_conv2d(proto): + """Parse a DepthwiseConv2d Op.""" + _check_attrs(proto, { + "T", "strides", "dilations", "padding", "data_format", "explicit_paddings" + }) + + explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) + if explicit_paddings: + explicit_paddings = [ + tuple(x) for x in np.array(explicit_paddings).reshape(4, 2).tolist() + ] + + padding = explicit_paddings or str(proto.attr["padding"].s, "utf-8") + dilations = tuple(proto.attr["dilations"].list.i) + strides = tuple(proto.attr["strides"].list.i) + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format == "NHWC": + if explicit_paddings: + padding = padding[1:3] + dimension_numbers = ("NHWC", "HWIO", "NHWC") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) + channel_index = -1 + elif data_format == "NCHW": + if explicit_paddings: + padding = padding[2:] + dimension_numbers = ("NCHW", "HWIO", "NCHW") + strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) + dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) + channel_index = 1 + else: + raise ValueError(f"Found unsupported data format {data_format}.") + + def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: + output_dim = rhs.shape[2] * rhs.shape[3] + return jax.lax.conv_general_dilated( + lhs, + jnp.reshape(rhs, rhs.shape[:2] + (1, output_dim)), + window_strides=strides, + padding=padding, + dimension_numbers=dimension_numbers, + rhs_dilation=dilations, + feature_group_count=lhs.shape[channel_index]) + + return _func + + +@register_operation("Einsum") +def _einsum(proto): + """Parse an Einsum Op.""" + _check_attrs(proto, {"T", "N", "equation"}) + + num_inputs = proto.attr["N"].i + equation = str(proto.attr["equation"].s, "utf-8") + + def _func(*operands): + if len(operands) != num_inputs: + raise ValueError( + f"Expected {num_inputs} input arrays, found {len(operands)}") + return jnp.einsum(equation, *operands) + + return _func + + +@register_operation("Empty") +def _empty(proto): + """Parse an Empty op.""" + _check_attrs(proto, {"dtype", "init"}) + + dtype = tf.as_dtype(proto.attr["dtype"].type) + init = proto.attr["init"].b + + def _func(shape: jnp.ndarray) -> jnp.ndarray: + return anp.empty(shape=shape, dtype=dtype, init=init) + + return _func + + +@register_operation("Fill") +def _fill(proto): + """Parse an Fill op.""" + _check_attrs(proto, {"T", "index_type"}) + + dtype = tf.as_dtype(proto.attr["T"].type) + + def _func(shape: jnp.ndarray, fill_value: jnp.ndarray) -> jnp.ndarray: + return anp.full(shape=shape, fill_value=fill_value, dtype=dtype) + + return _func + + +@register_operation("FusedBatchNormV3") +@register_operation("FusedBatchNormV2") +def _fused_batch_norm(proto): + """Parse a FusedBatchNorm Op.""" + _check_attrs(proto, { + "T", "U", "data_format", "epsilon", "exponential_avg_factor", + "is_training" + }) + + 
data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format == "NHWC": + reduce_axis = (0, 1, 2) + channel_dim = 3 + elif data_format == "NCHW": + reduce_axis = (0, 2, 3) + channel_dim = 1 + else: + raise ValueError(f"Found unsupported data format {data_format}.") + + epsilon = proto.attr["epsilon"].f + exponential_avg_factor = proto.attr["exponential_avg_factor"].f + one_minus_factor = 1. - exponential_avg_factor + is_training = proto.attr["is_training"].b + + def _func( + x: jnp.ndarray, + scale: jnp.ndarray, + offset: jnp.ndarray, + running_mean: jnp.ndarray, + running_var: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + batch_mean = jnp.mean(x, axis=reduce_axis) + batch_var = jnp.var(x, axis=reduce_axis) + est_mean = batch_mean if is_training else running_mean + est_var = batch_var if is_training else running_var + + # Prep for broadcasting. + scale = jnp.expand_dims(scale, axis=reduce_axis) + offset = jnp.expand_dims(offset, axis=reduce_axis) + est_mean = jnp.expand_dims(est_mean, axis=reduce_axis) + est_var = jnp.expand_dims(est_var, axis=reduce_axis) + + inv = scale * jax.lax.rsqrt(est_var + epsilon) + norm_x = jnp.asarray((x - est_mean) * inv + offset, x.dtype) + + if is_training: + # Apply Bessel's correction and additional smoothing. + ndata = x.size / x.shape[channel_dim] + correction = ndata / jnp.maximum(ndata - 1.0, 1.0) + running_var = running_var if running_var.size else 0 + running_mean = running_mean if running_mean.size else 0 + new_var = ( + one_minus_factor * running_var + + exponential_avg_factor * batch_var * correction) + new_mean = ( + one_minus_factor * running_mean + + exponential_avg_factor * batch_mean) + return norm_x, new_mean, new_var + else: + return norm_x, running_mean, running_var + + return _func + + +@register_operation("GatherNd") +def _gather_nd(proto): + """Parse a GatherNd Op.""" + _check_attrs(proto, {"Tindices", "Tparams"}) + + def _func(params: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray: + return params[tuple(anp.moveaxis(indices, -1, 0))] + + return _func + + +@register_operation("GatherV2") +def _gather(proto): + """Parse a GatherV2 Op.""" + _check_attrs(proto, {"Taxis", "Tindices", "Tparams", "batch_dims"}) + + batch_dims = proto.attr["batch_dims"].i + if batch_dims < 0: + raise ValueError(f"batch_dims={batch_dims} must be non-negative.") + + def _func( + params: jnp.ndarray, + indices: jnp.ndarray, + axis: jnp.ndarray, + ) -> jnp.ndarray: + return anp.gather( + params, indices, axis=axis.tolist(), batch_dims=batch_dims) + + return _func + + +@dataclasses.dataclass +class _IdentityN(_HigherOrderFunction): + """Represents a IdentityN Op.""" + + gradient_op_type: str # For debug, custom_gradient is handled by _Subgraph. 
+ + def __call__(self, *args): + return args + + +@register_operation("IdentityN") +def _identity_n(proto): + """Parse a IdentityN Op.""" + _check_attrs(proto, {"T"}) + + gradient_op_type = str(proto.attr["_gradient_op_type"].s, "utf-8") + if gradient_op_type: + logging.info("Found custom gradient %s", gradient_op_type) + + return _IdentityN({}, gradient_op_type=gradient_op_type) + + +class _IfOp(_HigherOrderFunction): + """Represents a If Op.""" + + def __call__(self, pred, *operand, then_fun, else_fun): + true_fun = lambda args: then_fun(*args) + false_fun = lambda args: else_fun(*args) + return jax.lax.cond( + pred, true_fun=true_fun, false_fun=false_fun, operand=operand) + + +@register_operation("StatelessIf") +@register_operation("If") +def _if(proto): + """Parse a If op.""" + _check_attrs(proto, { + "Tcond", "Tin", "Tout", "output_shapes", "then_branch", "else_branch" + }) + + then_name = proto.attr["then_branch"].func.name + else_name = proto.attr["else_branch"].func.name + output_shapes = [ + [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape + ] + del output_shapes + + return _IfOp(dict(then_fun=then_name, else_fun=else_name)) + + +@register_operation("InplaceAdd") +def _inplace_add(proto): + """Parse a InplaceAdd op.""" + _check_attrs(proto, {"T"}) + + def _func( + inputs: jnp.ndarray, + indices: jnp.ndarray, + updates: jnp.ndarray, + ) -> jnp.ndarray: + return jnp.asarray(inputs).at[indices].add(updates) + + return _func + + +@register_operation("InplaceUpdate") +def _inplace_update(proto): + """Parse a InplaceUpdate op.""" + _check_attrs(proto, {"T"}) + + def _func( + inputs: jnp.ndarray, + indices: jnp.ndarray, + updates: jnp.ndarray, + ) -> jnp.ndarray: + return jnp.asarray(inputs).at[indices].set(updates) + + return _func + + +@register_operation("LogSoftmax") +def _log_softmax(proto): + _check_attrs(proto, {"T"}) + return lambda x: jax.nn.log_softmax(x, axis=-1) + + +@register_operation("MatMul") +def _matmul(proto): + """Parse a MatMul Op.""" + _check_attrs(proto, {"T", "transpose_a", "transpose_b"}) + + transpose_a = proto.attr["transpose_a"].b + transpose_b = proto.attr["transpose_b"].b + + def _func(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + if transpose_a: + a = jnp.transpose(a) + if transpose_b: + b = jnp.transpose(b) + + return jnp.matmul(a, b) + + return _func + + +@register_operation("BatchMatMulV2") +def _batch_matmul(proto): + """Parse a BatchMatMul Op.""" + _check_attrs(proto, {"T", "adj_x", "adj_y"}) + + adj_x = proto.attr["adj_x"].b + adj_y = proto.attr["adj_y"].b + + # TODO(shaobohou) Add test for arrays with complex values. 
+ def _func(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + if adj_x: + x = jnp.conjugate(jnp.swapaxes(x, -1, -2)) + if adj_y: + y = jnp.conjugate(jnp.swapaxes(y, -1, -2)) + + return jnp.matmul(x, y) + + return _func + + +@register_operation("MatrixDiagV3") +def _matrix_diag(proto): + """Parse a MatrixDiagV3 op.""" + _check_attrs(proto, {"T", "align"}) + + align = str(proto.attr["align"].s, "utf-8") + if align != "RIGHT_LEFT": + raise ValueError(f"MatrixDiagV3 does not support `align={align}` yet.") + + def _func( + diagonals: jnp.ndarray, + k: jnp.ndarray, + num_rows: jnp.ndarray, + num_cols: jnp.ndarray, + padding_value: jnp.ndarray, + ) -> jnp.ndarray: + if num_rows != -1 or num_cols != -1: + raise ValueError(f"MatrixDiagV3 does not yet support num_rows={num_rows} " + f"or num_cols={num_cols}.") + + diag_fn = lambda inputs: jnp.diagflat(inputs, k=k) + for _ in range(len(diagonals.shape) - 1): + diag_fn = jax.vmap(diag_fn) + + outputs = diag_fn(diagonals) + mask = diag_fn(jnp.ones_like(diagonals, dtype=jnp.bool_)) + return jnp.where(mask, outputs, padding_value) + + return _func + + +@register_operation("MatrixBandPart") +def _matrix_band_part(proto): + """Parse a MatrixBandPart op.""" + _check_attrs(proto, {"T", "Tindex"}) + + def _func( + x: jnp.ndarray, + lower: jnp.ndarray, + upper: jnp.ndarray, + ) -> jnp.ndarray: + if len(x.shape) < 2: + raise ValueError( + f"Expected input of at least rank 2, found {len(x.shape)}") + mask_shape = x.shape[-2:] + lower = lower.tolist() + 1 if lower.tolist() >= 0 else max(mask_shape) + mask_lower = jnp.tril(jnp.ones(mask_shape, jnp.int32), -lower) + upper = upper.tolist() + 1 if upper.tolist() >= 0 else max(mask_shape) + mask_upper = jnp.triu(jnp.ones(mask_shape, jnp.int32), upper) + return jnp.where((mask_lower + mask_upper) == 0, x, 0) + + return _func + + +@register_operation("Max") +def _max(proto): + _check_attrs(proto, {"T", "Tidx", "keep_dims"}) + + keep_dims = proto.attr["keep_dims"].b + + def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: + return anp.max_(x, axis=axis.tolist(), keepdims=keep_dims) + + return _func + + +@register_operation("MaxPool") +def _max_pool(proto): + """Parse a MaxPool Op.""" + _check_attrs( + proto, + {"T", "padding", "explicit_paddings", "ksize", "strides", "data_format"}) + + explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) + if explicit_paddings: + raise ValueError("explicit_padding in MaxPool not yet supported.") + + padding = str(proto.attr["padding"].s, "utf-8") + ksize = tuple(proto.attr["ksize"].list.i) + strides = tuple(proto.attr["strides"].list.i) + data_format = str(proto.attr["data_format"].s, "utf-8") + if data_format not in ("NHWC", "NCHW"): + raise ValueError(f"Found unsupported data format {data_format}.") + + def _func(x: jnp.ndarray) -> jnp.ndarray: + return jax.lax.reduce_window( + x, + init_value=-jnp.inf, + computation=jax.lax.max, + window_dimensions=ksize, + window_strides=strides, + padding=padding) + + return _func + + +@register_operation("Mean") +def _mean(proto): + _check_attrs(proto, {"T", "Tidx", "keep_dims"}) + + keep_dims = proto.attr["keep_dims"].b + + def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: + return jnp.mean(x, axis=axis.tolist(), keepdims=keep_dims) + + return _func + + +@register_operation("Min") +def _min(proto): + _check_attrs(proto, {"T", "Tidx", "keep_dims"}) + + keep_dims = proto.attr["keep_dims"].b + + def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: + return anp.min_(x, axis=axis.tolist(), keepdims=keep_dims) + 
+ return _func + + +@register_operation("OneHot") +def _one_hot(proto): + """Parse a OneHot Op.""" + _check_attrs(proto, {"T", "TI", "axis"}) + + axis = proto.attr["axis"].i + + def _func( + indices: jnp.ndarray, + depth: jnp.ndarray, + on_value: jnp.ndarray, + off_value: jnp.ndarray, + ) -> jnp.ndarray: + if axis != -1 and axis != len(indices.shape): + raise ValueError(f"OneHot does not support axis={axis} yet, " + f"indices.shape={indices.shape}.") + + mask = jax.nn.one_hot(indices, num_classes=depth, dtype=jnp.int32) + return mask * on_value + (1 - mask) * off_value + + return _func + + +@register_operation("Pack") +def _pack(proto): + """Parse a Pack op.""" + _check_attrs(proto, {"T", "axis", "N"}) + + num_arrays = proto.attr["N"].i + axis = proto.attr["axis"].i + + def _func(*args) -> jnp.ndarray: + if len(args) != num_arrays: + raise ValueError( + f"Pack expects {num_arrays} args, received {len(args)}.") + return anp.stack(args, axis=axis) + + return _func + + +@register_operation("Pad") +def _pad(proto): + _check_attrs(proto, {"T", "Tpaddings"}) + return lambda x, paddings: jnp.pad(x, pad_width=paddings) + + +@register_operation("PadV2") +def _pad_v2(proto): + """Parse a PadV2 op.""" + _check_attrs(proto, {"T", "Tpaddings"}) + + def _func( + inputs: jnp.ndarray, + padding: jnp.ndarray, + constant_values: jnp.ndarray, + ) -> jnp.ndarray: + return jnp.pad(inputs, pad_width=padding, constant_values=constant_values) + + return _func + + +class _PartitionedCall(_HigherOrderFunction): + """Represents a PartitionedCall Op.""" + + def __call__(self, *args, inner_fn, rng=None): + return inner_fn(*args, rng=rng) + + +# TODO(shaobohou) Add test for StatefulPartitionedCall. +@register_operation("StatefulPartitionedCall") +@register_operation("PartitionedCall") +def _partitioned_call(proto): + """Parse a PartitionedCall op.""" + _check_attrs(proto, + {"f", "Tin", "Tout", "config", "config_proto", "executor_type"}) + + inner_fn = proto.attr["f"].func.name + op_config = str(proto.attr["config"].s, "utf-8") + op_config_proto = proto.attr["config_proto"].s # TODO(shaobohou) decode this? + executor_type = str(proto.attr["executor_type"].s, "utf-8") + del op_config, op_config_proto, executor_type + + return _PartitionedCall(dict(inner_fn=inner_fn)) + + +@register_operation("Placeholder") +def _placeholder(proto): + _check_attrs(proto, {"dtype", "shape"}) + + name = proto.name + + def _func(): + raise ValueError(f"Placeholder `{name}` cannot be evaluated.") + + return _func + + +@register_operation("PreventGradient") +def _prevent_gradient(proto): + """Parse a PreventGradient op.""" + _check_attrs(proto, {"T", "message"}) + + message = str(proto.attr["message"].s, "utf-8") + jax_message = ( + f"Gradient explicitly prevented on node {proto.name}. 
Reason: {message}") + + @jax.custom_gradient + def _func(operand: jnp.ndarray) -> jnp.ndarray: + def grad_fn(_): + raise LookupError(jax_message) + return operand, grad_fn + + return _func + + +@register_operation("Prod") +def _prod(proto): + """Parse a Prod op.""" + _check_attrs(proto, {"T", "Tidx", "keep_dims"}) + + keep_dims = proto.attr["keep_dims"].b + + def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: + return anp.prod(x, axis=axis.tolist(), keepdims=keep_dims) + + return _func + + +@register_operation("RandomStandardNormal") +def _random_standard_normal(proto): + """Parse a RandomStandardNormal op.""" + _check_attrs(proto, {"T", "dtype", "seed", "seed2"}) + + seed = proto.attr["seed"].i + seed2 = proto.attr["seed2"].i + dtype = tf.as_dtype(proto.attr["dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + if seed != 0 or seed2 != 0: + logging.warning( + "RandomStandardNormal does not yet support non-zero seeds, found " + "seed=%s and seed2=%s.", seed, seed2) + + return lambda shape, *, rng: jax.random.normal(rng, shape, dtype=jax_dtype) + + +@register_operation("RandomUniform") +def _random_uniform(proto): + """Parse a RandomUniform op.""" + _check_attrs(proto, {"T", "dtype", "seed", "seed2"}) + + seed = proto.attr["seed"].i + seed2 = proto.attr["seed2"].i + dtype = tf.as_dtype(proto.attr["dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + if seed != 0 or seed2 != 0: + logging.warning( + "RandomUniform does not yet support non-zero seeds, found " + "seed=%s and seed2=%s.", seed, seed2) + + return lambda shape, *, rng: jax.random.uniform(rng, shape, dtype=jax_dtype) + + +@register_operation("RandomUniformInt") +def _random_uniform_int(proto): + """Parse a RandomUniformInt op.""" + _check_attrs(proto, {"T", "Tout", "seed", "seed2"}) + + seed = proto.attr["seed"].i + seed2 = proto.attr["seed2"].i + dtype = tf.as_dtype(proto.attr["Tout"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + if seed != 0 or seed2 != 0: + logging.warning( + "RandomUniformInt does not yet support non-zero seeds, found " + "seed=%s and seed2=%s.", seed, seed2) + + def _func(shape, minval, maxval, *, rng): + return jax.random.randint( + rng, shape, minval=minval, maxval=maxval, dtype=jax_dtype) + + return _func + + +@register_operation("Range") +def _range(proto): + """Parse a Range op.""" + _check_attrs(proto, {"Tidx"}) + + dtype = tf.as_dtype(proto.attr["Tidx"].type) + + def _func( + start: jnp.ndarray, + limit: jnp.ndarray, + delta: jnp.ndarray, + ) -> jnp.ndarray: + return anp.arange(start, stop=limit, step=delta, dtype=dtype) + + return _func + + +@register_operation("Reshape") +def _reshape(proto): + _check_attrs(proto, {"T", "Tshape"}) + return lambda x, shape: jnp.reshape(x, newshape=shape) + + +@register_operation("ResizeBilinear") +def _resize_linear(proto): + """Parse a ResizeBilinear op.""" + _check_attrs(proto, {"T", "align_corners", "half_pixel_centers"}) + + align_corners = proto.attr["align_corners"].b + half_pixel_centers = proto.attr["half_pixel_centers"].b + if align_corners and half_pixel_centers: + # Not supported by tf.raw_ops.ResizeBilinear. + raise ValueError( + "align_corners=True and half_pixel_centers=True are not supported. 
") + + def _func(images: jnp.ndarray, size: jnp.ndarray) -> jnp.ndarray: + if len(images.shape) != 4: + raise ValueError( + "Expected A 4D tensor with shape [batch, height, width, channels], " + f"found {images.shape}") + + inp_batch, inp_height, inp_width, inp_channels = images.shape + out_height, out_width = size.tolist() + + height_scale = out_height / inp_height + width_scale = out_width / inp_width + if align_corners: + if out_height > 1: + height_scale = (out_height - 1) / (inp_height - 1) + if out_width > 1: + width_scale = (out_width - 1) / (inp_width - 1) + scale = np.array((height_scale, width_scale)) + + translation = np.array(([0.0] * 2)) + if not half_pixel_centers: + translation = translation - scale * 0.5 + 0.5 + + return jax.image.scale_and_translate( + images, + shape=(inp_batch, out_height, out_width, inp_channels), + spatial_dims=(1, 2), + scale=scale, + translation=translation, + method="linear", + antialias=False, + precision=None, + ) + + return _func + + +@register_operation("ScatterNd") +def _scatter_nd(proto): + """Parse a ScatterNd op.""" + _check_attrs(proto, {"T", "Tindices"}) + + def _func( + indices: jnp.ndarray, + updates: jnp.ndarray, + shape: jnp.ndarray, + ) -> jnp.ndarray: + zeros = jnp.zeros(shape, updates.dtype) + key = tuple(jnp.moveaxis(indices, -1, 0)) + return zeros.at[key].set(updates) + + return _func + + +@register_operation("SelectV2") +@register_operation("Select") +def _select(proto): + """Parse a Select op.""" + _check_attrs(proto, {"T"}) + + def _func( + conds: jnp.ndarray, + x: jnp.ndarray, + y: jnp.ndarray, + ) -> jnp.ndarray: + conds = anp.expand_dims(conds, axis=tuple(range(conds.ndim, x.ndim))) + return anp.where(conds, x, y) + + return _func + + +@register_operation("Slice") +def _slice(proto): + """Parse a Slice Op.""" + _check_attrs(proto, {"T", "Index"}) + + def _func( + x: jnp.ndarray, + begins: jnp.ndarray, + sizes: jnp.ndarray, + ) -> jnp.ndarray: + """`begins` and `sizes` must be concrete arrays.""" + slices = [slice(b, b + s) for b, s in zip(begins, sizes)] + return x[tuple(slices)] + + return _func + + +@register_operation("Softmax") +def _softmax(proto): + _check_attrs(proto, {"T"}) + return lambda x: jax.nn.softmax(x, axis=-1) + + +@register_operation("SparseSoftmaxCrossEntropyWithLogits") +def _sparse_softmax_cross_entropy_with_logits(proto): + """Parse a SparseSoftmaxCrossEntropyWithLogits Op.""" + _check_attrs(proto, {"T", "Tlabels"}) + + def cross_entropy(xs, ys): + return -1.0 * jnp.take_along_axis( + jax.nn.log_softmax(xs, axis=-1), ys[:, jnp.newaxis], axis=-1)[:, 0] + + def _func(features: jnp.ndarray, + labels: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + + loss = cross_entropy(features, labels) + vjp_outputs, vjp_fn = jax.vjp(cross_entropy, features, labels) + grads = vjp_fn(jnp.ones(vjp_outputs.shape))[0] + return loss, grads + + return _func + + +@register_operation("Split") +def _split(proto): + """Parse a Split op.""" + _check_attrs(proto, {"T", "num_split"}) + num_split = proto.attr["num_split"].i + return lambda axis, inputs: anp.split(inputs, num_split, axis=axis) + + +@register_operation("SplitV") +def _splitv(proto): + """Parse a SplitV op.""" + _check_attrs(proto, {"T", "Tlen", "num_split"}) + num_split = proto.attr["num_split"].i + + def _func( + value: jnp.ndarray, + size_splits: jnp.ndarray, + axis: jnp.ndarray, + ) -> jnp.ndarray: + assert size_splits.shape[0] == num_split, (size_splits.shape[0], num_split) + splits = size_splits.tolist() + axis = axis.tolist() + defined_size = sum([x for x in 
splits if x >= 0]) + splits = [x if x >= 0 else value.shape[axis] - defined_size for x in splits] + indices = np.cumsum(np.array(splits), axis=0) + return anp.split(value, indices, axis=axis) + + return _func + + +@register_operation("SquaredDifference") +def _squared_difference(proto): + _check_attrs(proto, {"T"}) + return lambda x1, x2: jnp.square(x1 - x2) + + +@register_operation("Squeeze") +def _squeeze(proto): + _check_attrs(proto, {"T", "squeeze_dims"}) + + axis = tuple(proto.attr["squeeze_dims"].list.i) + + return lambda x: anp.squeeze(x, axis=axis) + + +@register_operation("StatelessRandomGetKeyCounter") +def _stateless_random_get_key_counter(proto): + _check_attrs(proto, {"T", "Tseed"}) + + def _func(seed: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + assert seed.shape == (2,), seed.shape + seed = jnp.sum(seed) # An arbitrary choice. + return jax.random.PRNGKey(seed), jnp.array(0, dtype=jnp.int32) + + return _func + + +@register_operation("StatelessMultinomial") +def _stateless_multinomial(proto): + """Parse a StatelessMultinomial op.""" + _check_attrs(proto, {"T", "Tseed", "output_dtype"}) + + dtype = tf.as_dtype(proto.attr["output_dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + def _func( + logits: jnp.ndarray, + num_samples: jnp.ndarray, + seed: jnp.ndarray, + ) -> jnp.ndarray: + assert seed.shape == (2,), seed.shape + seed = seed.astype(jnp.uint32) + shape = (num_samples, logits.shape[0]) + samples = jax.random.categorical(seed, logits, shape=shape) + return samples.astype(jax_dtype).transpose([1, 0]) + + return _func + + +@register_operation("StatelessRandomNormalV2") +def _stateless_random_normal_v2(proto): + """Parse a StatelessRandomNormalV2 op.""" + _check_attrs(proto, {"T", "Tshape", "dtype"}) + + dtype = tf.as_dtype(proto.attr["dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + def _func( + shape: jnp.ndarray, + key: jnp.ndarray, + counter: jnp.ndarray, + alg: jnp.ndarray, + ) -> jnp.ndarray: + del counter, alg # TODO(shaobohou) combine key and counter? + return jax.random.normal(key=key, shape=shape, dtype=jax_dtype) + + return _func + + +@register_operation("StatelessRandomUniformV2") +def _stateless_random_uniform_v2(proto): + """Parse a StatelessRandomNormalV2 op.""" + _check_attrs(proto, {"T", "Tshape", "dtype"}) + + dtype = tf.as_dtype(proto.attr["dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + def _func( + shape: jnp.ndarray, + key: jnp.ndarray, + counter: jnp.ndarray, + alg: jnp.ndarray, + ) -> jnp.ndarray: + del counter, alg # TODO(shaobohou) combine key and counter? + return jax.random.uniform(key=key, shape=shape, dtype=jax_dtype) + + return _func + + +@register_operation("StatelessRandomUniformFullIntV2") +@register_operation("StatelessRandomUniformIntV2") +def _stateless_random_uniform_int_v2(proto): + """Parse a StatelessRandomUniformIntV2 op.""" + _check_attrs(proto, {"T", "Tshape", "dtype"}) + + dtype = tf.as_dtype(proto.attr["dtype"].type) + jax_dtype = anp.get_jax_dtype(dtype) + + def _func( + shape: jnp.ndarray, + key: jnp.ndarray, + counter: jnp.ndarray, + alg: jnp.ndarray, + minval: jnp.ndarray = jnp.iinfo(jax_dtype).min, + maxval: jnp.ndarray = jnp.iinfo(jax_dtype).max, + ) -> jnp.ndarray: + del counter, alg # TODO(shaobohou) combine key and counter? 
+ return jax.random.randint( + key=key, shape=shape, minval=minval, maxval=maxval, dtype=jax_dtype,) + + return _func + + +class _StatelessWhile(_HigherOrderFunction): + """Represents a StatelessWhile Op.""" + + def __call__(self, *args, cond_fun, body_fun, rng=None): + def real_cond_fun(args): + *cond_args, rng = args + _, cond_key, _ = [None] * 3 if rng is None else jax.random.split(rng, 3) + outputs = cond_fun(*cond_args, rng=cond_key) + if len(outputs) != 1: + raise ValueError( + f"Expected cond_fun to return a single value, found {outputs}") + return outputs[0] + + def real_body_fun(args): + *body_args, rng = args + key, _, body_key = [None] * 3 if rng is None else jax.random.split(rng, 3) + outputs = tuple(body_fun(*body_args, rng=body_key)) + return outputs + (key,) + + outputs = jax.lax.while_loop(real_cond_fun, real_body_fun, args + (rng,)) + return outputs + + +@register_operation("StatelessWhile") +@register_operation("While") +def _stateless_while(proto): + """Parse a StatelessWhile op.""" + _check_attrs(proto, + {"T", "body", "cond", "parallel_iterations", "output_shapes"}) + # TODO(shaobohou) Check proto.arg_attr? + + body_name = proto.attr["body"].func.name + cond_name = proto.attr["cond"].func.name + parallel_iterations = proto.attr["parallel_iterations"].i + output_shapes = [ + [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape + ] + del parallel_iterations, output_shapes + + return _StatelessWhile(dict(cond_fun=cond_name, body_fun=body_name)) + + +@register_operation("StridedSlice") +def _strided_slice(proto): + """Parse a StridedSlice Op.""" + _check_attrs( + proto, { + "T", "Index", "begin_mask", "ellipsis_mask", "end_mask", + "new_axis_mask", "shrink_axis_mask" + }) + + def unpack(x): + return [(1 if(x & (1 << v)) else 0) for v in range(32)] + + begin_mask = unpack(proto.attr["begin_mask"].i) + ellipsis_mask = unpack(proto.attr["ellipsis_mask"].i) + end_mask = unpack(proto.attr["end_mask"].i) + new_axis_mask = unpack(proto.attr["new_axis_mask"].i) + shrink_axis_mask = unpack(proto.attr["shrink_axis_mask"].i) + + def _func( + x: jnp.ndarray, + begin: jnp.ndarray, + end: jnp.ndarray, + strides: jnp.ndarray, + ) -> jnp.ndarray: + """`begin`, `end` and `strides` must be concrete arrays.""" + + num_specs = len(begin) + inserted = ( + len(x.shape) + sum(new_axis_mask) - (num_specs - sum(ellipsis_mask))) + + # Rebuild slices. + dim = 0 + slices = [] + for idx in range(num_specs): + if new_axis_mask[idx] == 1: + slices.append(jnp.newaxis) + elif ellipsis_mask[idx] == 1: + slices.append(Ellipsis) + dim += inserted + else: + if shrink_axis_mask[idx] == 1: + slices.append(begin[idx]) + else: + beg_dim = begin[idx] if begin_mask[idx] == 0 else None + end_dim = end[idx] if end_mask[idx] == 0 else None + stride = strides[idx] + x_dim = x.shape[dim] + if (stride == 1 and + jax.core.symbolic_equal_dim(beg_dim or 0, 0) and + jax.core.symbolic_equal_dim(end_dim or x_dim, x_dim)): + slices.append(slice(None, None, None)) + else: + slices.append(slice(beg_dim, end_dim, stride)) + dim += 1 + + # TODO(shaobohou) Handle stride=1 slicing along polymoprhic dimensions. 
+    return x[tuple(slices)]
+
+  return _func
+
+
+@register_operation("Sum")
+def _sum(proto):
+  _check_attrs(proto, {"T", "Tidx", "keep_dims"})
+
+  keep_dims = proto.attr["keep_dims"].b
+
+  def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
+    return anp.sum_(x, axis=axis.tolist(), keepdims=keep_dims)
+
+  return _func
+
+
+@register_operation("TopKV2")
+def _top_k(proto):
+  _check_attrs(proto, {"T", "sorted"})
+  sorted_arg = proto.attr["sorted"].b
+  if not sorted_arg:
+    raise ValueError("sorted=False in TopKV2 is not yet supported.")
+
+  return jax.lax.top_k
+
+
+@register_operation("Transpose")
+def _transpose(proto):
+  _check_attrs(proto, {"T", "Tperm"})
+  return lambda x, axes: jnp.transpose(x, axes=axes)
+
+
+@register_operation("Unpack")
+def _unpack(proto):
+  """Parse an Unpack op."""
+  _check_attrs(proto, {"T", "axis", "num"})
+
+  axis = proto.attr["axis"].i
+  num = proto.attr["num"].i
+
+  def _func(x: jnp.ndarray) -> jnp.ndarray:
+    if x.shape[axis] != num:
+      raise ValueError(f"Unpack expects dimension of {num} for axis={axis}, "
+                       f"found {x.shape[axis]}, shape={x.shape}")
+    return [anp.squeeze(v, axis=axis) for v in anp.split(x, num, axis=axis)]
+
+  return _func
+
+
+@register_operation("XlaConvV2")
+@register_operation("XlaConv")
+def _xla_conv(proto):
+  """Parse a XlaConv op."""
+  _check_attrs(
+      proto, {
+          "T", "LhsT", "RhsT", "Tindices", "dimension_numbers",
+          "precision_config", "preferred_element_type", "batch_group_count"
+      })
+
+  dimension_numbers = xla_utils.convolution_dimension_numbers_from_proto(
+      proto.attr["dimension_numbers"].s)
+  precision_config = xla_utils.precision_config_from_proto(
+      proto.attr["precision_config"].s)
+  batch_group_count = proto.attr["batch_group_count"].i
+  if "preferred_element_type" in proto.attr:
+    dst_dtype = tf.as_dtype(proto.attr["preferred_element_type"].type)
+    dst_dtype = anp.get_jax_dtype(dst_dtype)
+  else:
+    dst_dtype = None
+
+  def _func(
+      lhs: jnp.ndarray,
+      rhs: jnp.ndarray,
+      strides: jnp.ndarray,
+      padding: jnp.ndarray,
+      lhs_dilation: jnp.ndarray,
+      rhs_dilation: jnp.ndarray,
+      feature_group_count: jnp.ndarray,
+  ) -> jnp.ndarray:
+    return jax.lax.conv_general_dilated(
+        lhs,
+        rhs,
+        window_strides=strides.tolist(),
+        padding=[tuple(v) for v in padding],
+        lhs_dilation=lhs_dilation.tolist(),
+        rhs_dilation=rhs_dilation.tolist(),
+        dimension_numbers=dimension_numbers,
+        feature_group_count=feature_group_count.tolist(),  # Should be int.
+ batch_group_count=batch_group_count or 1, + precision=precision_config, + preferred_element_type=dst_dtype) + + return _func + + +@register_operation("XlaDotV2") +@register_operation("XlaDot") +def _xla_dot(proto): + """Parse a XlaDot op.""" + _check_attrs( + proto, { + "T", "LhsT", "RhsT", "dimension_numbers", "precision_config", + "preferred_element_type" + }) + + dimension_numbers = xla_utils.dot_dimension_numbers_from_proto( + proto.attr["dimension_numbers"].s) + precision_config = xla_utils.precision_config_from_proto( + proto.attr["precision_config"].s) + if "preferred_element_type" in proto.attr: + dst_dtype = tf.as_dtype(proto.attr["preferred_element_type"].type) + dst_dtype = anp.get_jax_dtype(dst_dtype) + else: + dst_dtype = None + + def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: + return jax.lax.dot_general( + lhs, + rhs, + dimension_numbers, + precision_config, + preferred_element_type=dst_dtype) + + return _func + + +@register_operation("XlaDynamicSlice") +def _xla_dynamic_slice(proto): + """Parse a XlaDynamicSlice op.""" + _check_attrs(proto, {"T", "Tindices"}) + return jax.lax.dynamic_slice + + +@register_operation("XlaDynamicUpdateSlice") +def _xla_dynamic_update_slice(proto): + """Parse a XlaDynamicUpdateSlice op.""" + _check_attrs(proto, {"T", "Tindices"}) + return jax.lax.dynamic_update_slice + + +@register_operation("XlaGather") +def _xla_gather(proto): + """Parse a XlaGather op.""" + _check_attrs(proto, + {"T", "Tindices", "dimension_numbers", "indices_are_sorted"}) + + dimension_numbers = xla_utils.gather_dimension_numbers_from_proto( + proto.attr["dimension_numbers"].s) + # This should exist on the XLA op, even though it's not exposed by JAX. + indices_are_sorted = proto.attr["indices_are_sorted"].b + del indices_are_sorted + + def _func( + operand: jnp.ndarray, + start_indices: jnp.ndarray, + slice_indices: jnp.ndarray, + ) -> jnp.ndarray: + return jax.lax.gather(operand, start_indices, dimension_numbers, + slice_indices) + + return _func + + +@register_operation("XlaPad") +def _xla_pad(proto): + """Parse a XlaPad op.""" + _check_attrs(proto, {"T", "Tindices"}) + + def _func( + operand: jnp.ndarray, + padding_value: jnp.ndarray, + padding_low: jnp.ndarray, + padding_high: jnp.ndarray, + padding_interior: jnp.ndarray,) -> jnp.ndarray: + padding_config = np.stack([padding_low, padding_high, padding_interior], + axis=0) + padding_config = [tuple(x) for x in padding_config.transpose().tolist()] + return jax.lax.pad(operand, padding_value, padding_config) + + return _func + + +def _maybe_get_jaxpreqn( + jaxpr: jax.core.ClosedJaxpr) -> Optional[jax.core.JaxprEqn]: + def is_all_vars(vs): + return all([isinstance(v, jax.core.Var) for v in vs]) + + if (len(jaxpr.eqns) == 1 and + is_all_vars(jaxpr.jaxpr.invars) and is_all_vars(jaxpr.jaxpr.outvars) and + is_all_vars(jaxpr.eqns[0].invars) and is_all_vars(jaxpr.eqns[0].outvars)): + return jaxpr.eqns[0] + return None + + +@dataclasses.dataclass +class _XlaVariadicReduce(_HigherOrderFunction): + """Represents a XlaVariadicReduce Op.""" + + dimensions: Sequence[int] + + def __call__(self, *args: jnp.ndarray, reducer: Callable[..., Any]): + num_args = len(args) + operands = args[:(num_args//2)] + init_values = args[(num_args//2):] + assert len(operands) == len(init_values) + reducer_fn = lambda xs, ys: reducer(*xs, *ys) + return jax.lax.reduce(operands, init_values, reducer_fn, self.dimensions) + + +@register_operation("XlaVariadicReduceV2") +def _xla_variadic_reduce(proto): + """Parse a XlaVariadicReduceV2 op.""" + 
_check_attrs(proto, {"T", "reducer", "dimensions_to_reduce"}) + + reducer = proto.attr["reducer"].func.name + dimensions = tuple(proto.attr["dimensions_to_reduce"].list.i) + + return _XlaVariadicReduce(dict(reducer=reducer), dimensions=dimensions) + + +@dataclasses.dataclass +class _XlaVariadicSort(_HigherOrderFunction): + """Represents a XlaVariadicSort Op.""" + + is_stable: bool + + def _compute_num_keys( + self, + dtypes: Sequence[jnp.dtype], + comparator: Callable[..., Any], + ) -> int: + """Infer num_keys from the comparator and operands.""" + def get_operands(): + return sum([[jnp.array(0, dtype)] * 2 for dtype in dtypes], []) + + for idx in range(len(dtypes)): + operands = get_operands() + is_eq, = comparator(*operands) + + operands[idx * 2 + 1] = jnp.array(1, dtypes[idx]) + is_lt, = comparator(*operands) + + if idx == 0 and (not is_lt or is_eq): + raise ValueError( + "Only less-than comparator is supported for XlaVariadicSort.") + + if is_lt: + num_keys = idx + 1 + else: + break + + return num_keys + + def __call__(self, *args: jnp.ndarray, comparator: Callable[..., Any]): + operands = args[:-1] + dimension = args[-1].tolist() + + with jax.ensure_compile_time_eval(): + dtypes = [x.dtype for x in operands] + num_keys = self._compute_num_keys(dtypes, comparator) + + return jax.lax.sort( + operands, + dimension=dimension, + is_stable=self.is_stable, + num_keys=num_keys) + + +@register_operation("XlaVariadicSort") +def _xla_variadic_sort(proto): + """Parse a XlaVariadicSort op.""" + _check_attrs(proto, {"T", "comparator", "is_stable"}) + + comparator = proto.attr["comparator"].func.name + is_stable = proto.attr["is_stable"].b + + logging.warning( + "Support for XlaVariadicSort is limited, current implementation assumes " + "the op is generated by `jax2tf.convert(jax.lax.sort)` and does not " + "support arbitrary comparators") + + return _XlaVariadicSort(dict(comparator=comparator), is_stable=is_stable) + + +class _XlaReduceWindow(_HigherOrderFunction): + """Represents a XlaReduceWindow Op.""" + + def __call__( + self, + operand: jnp.ndarray, + init_value: jnp.ndarray, + window_dimensions: jnp.ndarray, + window_strides: jnp.ndarray, + base_dilation: jnp.ndarray, + window_dilation: jnp.ndarray, + padding: jnp.ndarray, + *, + computation: Callable[..., Any], + ): + # Pattern matching computations that can be specialized. 
+    primitives = {
+        jax.lax.max_p: jax.lax.max,
+        jax.lax.min_p: jax.lax.min,
+        jax.lax.add_p: jax.lax.add,
+    }
+    computation_jaxpr = jax.make_jaxpr(computation)(init_value, init_value)
+    computation_eqn = _maybe_get_jaxpreqn(computation_jaxpr)
+    if computation_eqn is not None and computation_eqn.primitive in primitives:
+      computation_fn = primitives[computation_eqn.primitive]
+    else:
+      computation_fn = lambda *args: computation(*args)[0]
+      logging.info("Calling reduce_window with the following computation:\n%s",
+                   computation_jaxpr)
+
+    return jax.lax.reduce_window(
+        operand,
+        init_value,
+        computation=computation_fn,
+        window_dimensions=window_dimensions.tolist(),
+        window_strides=window_strides.tolist(),
+        padding=[tuple(v) for v in padding.tolist()],
+        base_dilation=base_dilation.tolist(),
+        window_dilation=window_dilation.tolist())
+
+
+@register_operation("XlaReduceWindow")
+def _xla_reduce_window(proto):
+  """Parse a XlaReduceWindow op."""
+  _check_attrs(proto, {"T", "Tindices", "computation"})
+
+  computation = proto.attr["computation"].func.name
+
+  return _XlaReduceWindow(dict(computation=computation))
+
+
+@register_operation("XlaRngBitGenerator")
+def _xla_rng_bit_generator(proto):
+  """Parse a XlaRngBitGenerator op."""
+  _check_attrs(proto, {"Tshape", "dtype"})
+
+  dtype = tf.as_dtype(proto.attr["dtype"].type)
+  jax_dtype = anp.get_jax_dtype(dtype)
+  if jax_dtype != jnp.uint32:
+    raise ValueError(
+        f"XlaRngBitGenerator currently only supports uint32, found {jax_dtype}.")
+
+  def _func(
+      algorithm: jnp.ndarray,
+      key: jnp.ndarray,
+      shape: jnp.ndarray,
+  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    return jax.lax.rng_bit_generator(
+        key=key.reshape(-1),
+        shape=shape,
+        dtype=jax_dtype,
+        # See tensorflow/compiler/tf2xla/ops/xla_ops.cc#L812
+        algorithm=xla_utils.get_random_algorithm_from_tf(algorithm.tolist()),
+    )
+
+  return _func
+
+
+@dataclasses.dataclass
+class _XlaScatter(_HigherOrderFunction):
+  """Represents a XlaScatter Op."""
+
+  dimension_numbers: jax.lax.ScatterDimensionNumbers
+  indices_are_sorted: bool
+
+  def __call__(
+      self,
+      operand: jnp.ndarray,
+      indices: jnp.ndarray,
+      updates: jnp.ndarray,
+      *,
+      update_computation: Callable[..., Any],
+  ) -> jnp.ndarray:
+    dummy_zero = jnp.array(0).astype(operand.dtype)
+    jaxpr = jax.make_jaxpr(update_computation)(dummy_zero, dummy_zero)
+    if not jaxpr.eqns:
+      scatter_fn = jax.lax.scatter
+    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.add_p:
+      scatter_fn = jax.lax.scatter_add
+    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.mul_p:
+      scatter_fn = jax.lax.scatter_mul
+    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.min_p:
+      scatter_fn = jax.lax.scatter_min
+    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.max_p:
+      scatter_fn = jax.lax.scatter_max
+    else:
+      raise ValueError(
+          f"Reducer not supported as `update_computation`, found {jaxpr}")
+
+    return scatter_fn(
+        operand,
+        indices,
+        updates,
+        dimension_numbers=self.dimension_numbers,
+        indices_are_sorted=self.indices_are_sorted)
+
+
+@register_operation("XlaScatter")
+def _xla_scatter(proto):
+  """Parse a XlaScatter op."""
+  _check_attrs(
+      proto, {
+          "T", "Tindices", "dimension_numbers", "indices_are_sorted",
+          "update_computation"
+      })
+
+  dimension_numbers = xla_utils.scatter_dimension_numbers_from_proto(
+      proto.attr["dimension_numbers"].s)
+  update_computation = proto.attr["update_computation"].func.name
+  indices_are_sorted = proto.attr["indices_are_sorted"].b
+
+  return _XlaScatter(
dict(update_computation=update_computation), dimension_numbers, + indices_are_sorted) + + +class _XlaSelectAndScatter(_HigherOrderFunction): + """Represents a XlaSelectAndScatter Op.""" + + def __call__( + self, + operand: jnp.ndarray, + window_dimensions: jnp.ndarray, + window_strides: jnp.ndarray, + padding: jnp.ndarray, + source: jnp.ndarray, + inner_init_value: jnp.ndarray, + *, + scatter: Callable[..., Any], + select: Callable[..., Any], + ) -> jnp.ndarray: + # Because jax.lax._select_and_scatter is not part of the JAX public api, we + # are using a crude pattern matching to determine the reducer used in the + # original reduce_window call. + + scatter_jaxpr = jax.make_jaxpr(scatter)(inner_init_value, inner_init_value) + scatter_eqn = _maybe_get_jaxpreqn(scatter_jaxpr) + if scatter_eqn is not None and scatter_eqn.primitive is not jax.lax.add_p: + raise ValueError( + f"Only Add is supported as scatter function, found {scatter_jaxpr}.") + + # TODO(shaobohou) Support jax.lax.add for AvgPool. + select_primitives = { + jax.lax.ge_p: (-jnp.inf, jax.lax.max), + jax.lax.le_p: (jnp.inf, jax.lax.min), + } + select_jaxpr = jax.make_jaxpr(select)(inner_init_value, inner_init_value) + select_eqn = _maybe_get_jaxpreqn(select_jaxpr) + if select_eqn is not None and select_eqn.primitive in select_primitives: + init_value, computation = select_primitives[select_eqn.primitive] + else: + raise ValueError("Only greater_equal (Max) and less_equal (Min) are " + f"supported as select function, found {select_jaxpr}") + + def reduce_window(x): + return jax.lax.reduce_window( + x, + init_value, + computation=computation, + window_dimensions=tuple(window_dimensions.tolist()), + window_strides=tuple(window_strides.tolist()), + padding=[tuple(v) for v in padding.tolist()]) + + _, f_vjp = jax.vjp(reduce_window, operand) + return f_vjp(source) + + +@register_operation("XlaSelectAndScatter") +def _xla_select_and_scatter(proto): + """Parse a XlaSelectAndScatter op.""" + _check_attrs(proto, {"T", "Tindices", "scatter", "select"}) + + scatter = proto.attr["scatter"].func.name + select = proto.attr["select"].func.name + + return _XlaSelectAndScatter(dict(scatter=scatter, select=select)) diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index 8f20806..e05d22b 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -23,6 +23,7 @@ import numpy as np import tensorflow as tf +from tf2jax._src import ops from tf2jax._src import tf2jax import tree @@ -38,6 +39,11 @@ def _nullcontext(enter_result=None): class OpsTest(tf.test.TestCase, parameterized.TestCase): + def test_get_unsupported(self): + unsupported = ops.get_unsupported_operations( + ["Add", "Relu", "NotAnOp", "Blah", "Relu"]) + self.assertEqual(unsupported, {"NotAnOp", "Blah"}) + def _assert_if_jitted(self, err): jitted = self.variant.type == chex.ChexVariantType.WITH_JIT return self.assertRaises(err) if jitted else _nullcontext() diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py index ae6a128..5a6b02b 100644 --- a/tf2jax/_src/roundtrip_test.py +++ b/tf2jax/_src/roundtrip_test.py @@ -25,6 +25,7 @@ import numpy as np import tensorflow as tf +from tf2jax._src import config from tf2jax._src import tf2jax import tree @@ -68,7 +69,7 @@ def assert_grad_all_close(*args): jax.tree_map(self.assertAllClose, jax_outputs, tf_outputs) # Jax -> TF -> Jax - with tf2jax.override_config("convert_custom_gradient", with_custom_grad): + with config.override_config("convert_custom_gradient", with_custom_grad): rejax_func = 
tf2jax.convert_functional( tf.function(tf_func), *tree.map_structure(np.zeros_like, inputs)) rejax_func = self.variant(rejax_func) @@ -90,7 +91,7 @@ def assert_grad_all_close(*args): jax.tree_map(self.assertAllClose, jax_outputs, restored_tf_outputs) # Jax -> TF -> SavedModel -> TF -> Jax - with tf2jax.override_config("convert_custom_gradient", with_custom_grad): + with config.override_config("convert_custom_gradient", with_custom_grad): rejax_too_func = tf2jax.convert_functional( restored.f, *tree.map_structure(np.zeros_like, inputs)) rejax_too_func = self.variant(rejax_too_func) @@ -616,7 +617,7 @@ def grad(dy): tf_forward = tf.function(tf_forward) # JAX -> TF -> JAX - with tf2jax.override_config("convert_custom_gradient", True): + with config.override_config("convert_custom_gradient", True): jax_forward = tf2jax.convert_functional(tf_forward, tf.zeros_like(inputs)) jax_forward = self.variant(jax_forward) @@ -634,7 +635,7 @@ def grad(dy): restored = tf.saved_model.load(tmp_dir.full_path) # Jax -> TF -> SavedModel -> TF -> Jax - with tf2jax.override_config("convert_custom_gradient", True): + with config.override_config("convert_custom_gradient", True): re_jax_forward = tf2jax.convert_functional(restored.f, tf.zeros_like(inputs)) re_jax_forward = self.variant(re_jax_forward) @@ -678,7 +679,7 @@ def grad(dy): tf_fn_too = tf.function(tf_fn_too) # JAX -> TF -> CALL_TF -> TF -> JAX - with tf2jax.override_config("convert_custom_gradient", True): + with config.override_config("convert_custom_gradient", True): jax_fn_too = tf2jax.convert_functional(tf_fn_too, np.zeros_like(inputs)) jax_outputs = jax_fn_too(inputs) @@ -729,7 +730,7 @@ def forward(x): tf_forward = tf.function(tf_forward) # JAX -> TF -> JAX - with tf2jax.override_config("convert_custom_gradient", True): + with config.override_config("convert_custom_gradient", True): jax_forward = tf2jax.convert_functional(tf_forward, tf.zeros_like(inputs)) jax_forward = self.variant(jax_forward) @@ -748,7 +749,7 @@ def forward(x): restored = tf.saved_model.load(tmp_dir.full_path) # Jax -> TF -> SavedModel -> TF -> Jax - with tf2jax.override_config("convert_custom_gradient", True): + with config.override_config("convert_custom_gradient", True): re_jax_forward = tf2jax.convert_functional(restored.f, tf.zeros_like(inputs)) re_jax_forward = self.variant(re_jax_forward) diff --git a/tf2jax/_src/tf2jax.py b/tf2jax/_src/tf2jax.py index 4a74489..062ba4e 100644 --- a/tf2jax/_src/tf2jax.py +++ b/tf2jax/_src/tf2jax.py @@ -15,12 +15,9 @@ """Experimental functions for converting TF graphs to Jax functions.""" import collections -import contextlib -import dataclasses -import functools import inspect import itertools -from typing import Any, Callable, Iterator, Optional, Mapping, NamedTuple, Sequence, Set, Tuple, Union +from typing import Any, Callable, Iterator, Optional, Mapping, NamedTuple, Sequence, Tuple, Union from absl import logging @@ -28,9 +25,9 @@ import jax.numpy as jnp import numpy as np import tensorflow as tf -from tf2jax._src import numpy_compat as anp +from tf2jax._src import config +from tf2jax._src import ops from tf2jax._src import utils -from tf2jax._src import xla_utils import tree # Import usage logging here. @@ -39,67 +36,11 @@ from tensorflow.python.framework import ops as tf_ops # pylint: disable=no-name-in-module -# NoOp inserted to trigger side effects in function with no return values. 
-_EMPTY_RETURN_OP_NAME = "__NO_RETURN__" -_EMPTY_RETURN_VALUE = object() - -_UNUSED_INPUT = object() - -_config = dict( - strict_shape_check=True, - strict_dtype_check=False, - force_const_float32_to_bfloat16=False, - force_const_float64_to_bfloat16=False, - convert_custom_gradient=True, - infer_relu_from_jax2tf=True, -) - - -def get_config(name: str): - return _config[name] - - -def update_config(name: str, value: Any): - if name in _config: - _config[name] = value - else: - raise ValueError( - f"Parameter named {name} not found in config={_config}") +_EMPTY_RETURN_OP_NAME = ops._EMPTY_RETURN_OP_NAME # pylint: disable=protected-access +_EMPTY_RETURN_VALUE = ops._EMPTY_RETURN_VALUE # pylint: disable=protected-access -@contextlib.contextmanager -def override_config(name: str, value: Any): - old_value = get_config(name) - update_config(name, value) - try: - yield - finally: - update_config(name, old_value) - - -def _check_attrs(proto, expected: Set[str]): - unexpected = [] - for k, v in proto.attr.items(): - # Ignore attributes with "_" prefix, as they appear to be undocumented. - if k not in expected and not k.startswith("_"): - unexpected.append(" `" + f"{k}={v}".strip() + "`") - if unexpected: - raise ValueError("\n".join( - [f"Unexpected attr(s) when parsing {proto.op}: {proto.name}"] + - unexpected)) - - -def _get_jax_op( - jax_op: Callable[..., Any], - expected_attrs: Set[str], -) -> Callable[..., Any]: - """For wrapping simple ops with no optional parameters.""" - - def wrapped(proto): - _check_attrs(proto, expected_attrs) - return jax_op - - return wrapped +_UNUSED_INPUT = object() def _fix_jax_poly_shape(shape: Tuple[Any, ...]) -> Tuple[Any, ...]: @@ -114,1926 +55,6 @@ def _fix_jax_poly_shape(shape: Tuple[Any, ...]) -> Tuple[Any, ...]: return tuple(good_shape) -_jax_ops = { - "Abs": _get_jax_op(jnp.abs, {"T"}), - "Add": _get_jax_op(anp.add, {"T"}), - "AddN": _get_jax_op( - lambda *args: anp.sum_(anp.stack(args, axis=0), axis=0, keepdims=False), - {"T", "N"}), - "AddV2": _get_jax_op(anp.add, {"T"}), - "ArgMax": _get_jax_op(jnp.argmax, {"T", "Tidx", "output_type"}), - "ArgMin": _get_jax_op(jnp.argmin, {"T", "Tidx", "output_type"}), - "Acosh": _get_jax_op(jnp.arccosh, {"T"}), - "Asinh": _get_jax_op(jnp.arcsinh, {"T"}), - "Atanh": _get_jax_op(jnp.arctanh, {"T"}), - "Atan2": _get_jax_op(jnp.arctan2, {"T"}), - "BitwiseAnd": _get_jax_op(jnp.bitwise_and, {"T"}), - "BitwiseOr": _get_jax_op(jnp.bitwise_or, {"T"}), - "BitwiseXor": _get_jax_op(jnp.bitwise_xor, {"T"}), - "BroadcastTo": _get_jax_op(anp.broadcast_to, {"T", "Tidx"}), - "Ceil": _get_jax_op(jnp.ceil, {"T"}), - "Complex": _get_jax_op(jax.lax.complex, {"T", "Tout"}), - "ComplexAbs": _get_jax_op(jax.lax.abs, {"T", "Tout"}), - "Conj": _get_jax_op(jax.lax.conj, {"T", "Tout"}), - "Cos": _get_jax_op(jnp.cos, {"T"}), - "Cosh": _get_jax_op(jnp.cosh, {"T"}), - "Digamma": _get_jax_op(jax.lax.digamma, {"T"}), - "Div": _get_jax_op(anp.divide, {"T"}), - "Elu": _get_jax_op(jax.nn.elu, {"T"}), - "Equal": _get_jax_op(anp.equal, {"T", "incompatible_shape_error"}), - "Erf": _get_jax_op(jax.lax.erf, {"T"}), - "Erfc": _get_jax_op(jax.lax.erfc, {"T"}), - "Erfinv": _get_jax_op(jax.lax.erf_inv, {"T"}), - "Exp": _get_jax_op(jnp.exp, {"T"}), - "Expm1": _get_jax_op(jnp.expm1, {"T"}), - "ExpandDims": _get_jax_op(anp.expand_dims, {"T", "Tdim"}), - "Floor": _get_jax_op(jnp.floor, {"T"}), - "FloorMod": _get_jax_op(anp.mod, {"T"}), - "FloorDiv": _get_jax_op(anp.floor_divide, {"T"}), - "Greater": _get_jax_op(anp.greater, {"T"}), - "GreaterEqual": 
_get_jax_op(anp.greater_equal, {"T"}), - "Identity": _get_jax_op(lambda x: x, {"T"}), - "Igamma": _get_jax_op(jax.lax.igamma, {"T"}), - "Igammac": _get_jax_op(jax.lax.igammac, {"T"}), - "Imag": _get_jax_op(jax.lax.imag, {"T", "Tout"}), - "IsFinite": _get_jax_op(jnp.isfinite, {"T"}), - "Invert": _get_jax_op(jnp.bitwise_not, {"T"}), - "L2Loss": _get_jax_op(lambda x: 0.5 * jnp.sum(jnp.square(x)), {"T"}), - "LeftShift": _get_jax_op(jnp.left_shift, {"T"}), - "Less": _get_jax_op(anp.less, {"T", "incompatible_shape_error"}), - "LessEqual": _get_jax_op(anp.less_equal, {"T", "incompatible_shape_error"}), - "Lgamma": _get_jax_op(jax.lax.lgamma, {"T"}), - "Log": _get_jax_op(jnp.log, {"T"}), - "Log1p": _get_jax_op(jnp.log1p, {"T"}), - "LogicalAnd": _get_jax_op(jnp.logical_and, {"T"}), - "LogicalNot": _get_jax_op(jnp.logical_not, {"T"}), - "LogicalOr": _get_jax_op(jnp.logical_or, {"T"}), - "Minimum": _get_jax_op(anp.minimum, {"T"}), - "Maximum": _get_jax_op(anp.maximum, {"T"}), - "Mul": _get_jax_op(anp.multiply, {"T"}), - "Neg": _get_jax_op(anp.negative, {"T"}), - "NoOp": _get_jax_op(lambda: _EMPTY_RETURN_VALUE, set({})), - "NotEqual": _get_jax_op(anp.not_equal, {"T", "incompatible_shape_error"}), - "OnesLike": _get_jax_op(jnp.ones_like, {"T"}), - "Pow": _get_jax_op(anp.power, {"T"}), - "Real": _get_jax_op(jax.lax.real, {"T", "Tout"}), - "ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}), - "RealDiv": _get_jax_op(anp.true_divide, {"T"}), - "Reciprocal": _get_jax_op(anp.reciprocal, {"T"}), - "Relu": _get_jax_op(jax.nn.relu, {"T"}), - "Relu6": _get_jax_op(jax.nn.relu6, {"T"}), - "ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}), - "RightShift": _get_jax_op(jnp.right_shift, {"T"}), - "Round": _get_jax_op(jnp.round, {"T"}), - "Rsqrt": _get_jax_op(jax.lax.rsqrt, {"T"}), - "Shape": _get_jax_op(lambda x: np.array(jnp.shape(x)), {"T", "out_type"}), - "Sigmoid": _get_jax_op(jax.nn.sigmoid, {"T"}), - "Sign": _get_jax_op(jnp.sign, {"T"}), - "Sin": _get_jax_op(jnp.sin, {"T"}), - "Sinh": _get_jax_op(jnp.sinh, {"T"}), - "Size": _get_jax_op(lambda x: np.prod(jnp.shape(x), dtype=np.int32), - {"T", "out_type"}), - "Softplus": _get_jax_op(jax.nn.softplus, {"T"}), - "Sqrt": _get_jax_op(jnp.sqrt, {"T"}), - "Square": _get_jax_op(jnp.square, {"T"}), - "StopGradient": _get_jax_op(jax.lax.stop_gradient, {"T"}), - "Sub": _get_jax_op(anp.subtract, {"T"}), - "Tan": _get_jax_op(jnp.tan, {"T"}), - "Tanh": _get_jax_op(jnp.tanh, {"T"}), - "Tile": _get_jax_op(anp.tile, {"T", "Tmultiples"}), - "Where": _get_jax_op(jnp.argwhere, {"T"}), - "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), - # The assignment logic is handled in _OpNode and convert(). 
- "AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}), - "AssignSubVariableOp": _get_jax_op(jnp.subtract, {"dtype"}), - "AssignVariableOp": _get_jax_op( - lambda var, x: x, {"dtype", "validate_shape"}), -} - - -def _register(func: Callable[..., Any], op_name: str): - curr_func = _jax_ops.get(op_name, None) - if curr_func is None: - _jax_ops[op_name] = func - else: - if curr_func != func: - raise ValueError( - f"{op_name} is already registered as {curr_func}, received {func}.") - - return func - - -def register_operation(op_name): - return functools.partial(_register, op_name=op_name) - - -@dataclasses.dataclass -class _HigherOrderFunction: - """Base class for higher order ops.""" - inner_fn_names: Mapping[str, str] - - def get_inner_functions( - self, library_functions: Mapping[str, Callable[..., Any]] - ) -> Mapping[str, Callable[..., Any]]: - return {k: library_functions[v] for k, v in self.inner_fn_names.items()} - - -@register_operation("All") -def _all(proto): - _check_attrs(proto, {"Tidx", "keep_dims"}) - keep_dims = proto.attr["keep_dims"].b - return lambda x, axis: anp.all_(x, axis=axis.tolist(), keepdims=keep_dims) - - -@register_operation("Any") -def _any(proto): - _check_attrs(proto, {"Tidx", "keep_dims"}) - keep_dims = proto.attr["keep_dims"].b - return lambda x, axis: anp.any_(x, axis=axis.tolist(), keepdims=keep_dims) - - -@register_operation("Assert") -def _assert(proto): - _check_attrs(proto, {"T", "summarize"}) - - logging.warning("Assert has no effect and will just return the data.") - - return lambda cond, data: data - - -@register_operation("AvgPool") -def _avg_pool(proto): - """Parse a AvgPool Op.""" - _check_attrs( - proto, - {"T", "padding", "explicit_paddings", "ksize", "strides", "data_format"}) - - explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) - if explicit_paddings: - raise ValueError("explicit_padding in AvgPool not yet supported.") - - padding = str(proto.attr["padding"].s, "utf-8") - ksize = tuple(proto.attr["ksize"].list.i) - strides = tuple(proto.attr["strides"].list.i) - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format not in ("NHWC", "NCHW"): - raise ValueError(f"Found unsupported data format {data_format}.") - - reduce_window_args = dict( - init_value=0., - computation=jax.lax.add, - window_dimensions=ksize, - window_strides=strides, - padding=padding) - - def _func(x: jnp.ndarray) -> jnp.ndarray: - pooled = jax.lax.reduce_window(x, **reduce_window_args) - if padding == "VALID": - window_counts = np.prod(ksize) - else: - window_counts = jax.lax.reduce_window( - jnp.ones_like(x), **reduce_window_args) - return pooled / window_counts - - return _func - - -@register_operation("BiasAdd") -def _bias_add(proto): - """Parse a BiasAdd Op.""" - _check_attrs(proto, {"T", "data_format"}) - - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format == "NHWC": - expand_axis_fn = lambda x: [d for d in range(x.ndim) if d != x.ndim - 1] - elif data_format == "NCHW": - # TODO(shaobohou) this seems wrong but matches TF behaviour. 
- expand_axis_fn = lambda x: [d for d in range(x.ndim) if d != 1] - else: - raise ValueError(f"Found unsupported data format {data_format}.") - - def _func(value: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray: - if bias.ndim != 1: - raise ValueError( - f"Expected `bias` as a 1D array, found array with {bias.ndim} dims.") - bias = anp.expand_dims(bias, axis=expand_axis_fn(value)) - return anp.add(value, bias) - - return _func - - -@register_operation("Bitcast") -def _bit_cast(proto): - _check_attrs(proto, {"T", "type"}) - dst_type = tf.as_dtype(proto.attr["type"].type) - return lambda x: jax.lax.bitcast_convert_type(x, anp.get_jax_dtype(dst_type)) - - -@register_operation("BroadcastArgs") -def _broadcast_args(proto): - _check_attrs(proto, {"T"}) - return lambda s0, s1: np.array(np.broadcast(np.zeros(s0), np.zeros(s1)).shape) - - -class _CaseOp(_HigherOrderFunction): - """Represents a Case Op.""" - - def __call__(self, branch_index, *operand, **branch_fns): - def create_branch(fn): - return lambda args: fn(*args) - branches = [create_branch(fn) for _, fn in sorted(branch_fns.items())] - return jax.lax.switch(branch_index, branches=branches, operand=operand) - - -@register_operation("StatelessCase") -@register_operation("Case") -def _case(proto): - """Parse a Case op.""" - _check_attrs(proto, {"Tin", "Tout", "output_shapes", "branches"}) - - branches = [f.name for f in proto.attr["branches"].list.func] - output_shapes = [ - [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape - ] - del output_shapes - - return _CaseOp({f"fn_{k:06}": v for k, v in enumerate(branches)}) - - -@register_operation("Cast") -def _cast(proto): - """Parse a Cast Op.""" - _check_attrs(proto, {"SrcT", "DstT", "Truncate"}) - - src_type = tf.as_dtype(proto.attr["SrcT"].type) - dst_type = tf.as_dtype(proto.attr["DstT"].type) - truncate = proto.attr["Truncate"].b - del src_type - - if truncate: - raise ValueError(f"Cast does not support truncate={truncate}.") - - def _func(x: jnp.ndarray) -> jnp.ndarray: - return anp.asarray(x, dst_type) - - return _func - - -@register_operation("ConjugateTranspose") -def _conjugate_transpose(proto): - _check_attrs(proto, {"T", "Tperm"}) - return lambda x, axes: jax.lax.conj(jnp.transpose(x, axes=axes)) - - -@register_operation("ConcatV2") -def _concatenate(proto): - """Parse a ConcatV2 Op.""" - _check_attrs(proto, {"T", "Tidx", "N"}) - - num_arrays = proto.attr["N"].i - - def _func(*args) -> jnp.ndarray: - if len(args) != num_arrays + 1: - raise ValueError( - f"Concatenate expects {num_arrays} args, received {len(args)}.") - - *inputs, axis = args - return anp.concatenate(inputs, axis=axis) - - return _func - - -@register_operation("Const") -def _const(proto): - """Parse a Const Op.""" - _check_attrs(proto, {"dtype", "value"}) - value = tf.make_ndarray(proto.attr["value"].tensor) - dtype = value.dtype - - force_float32 = get_config("force_const_float32_to_bfloat16") - force_float64 = get_config("force_const_float64_to_bfloat16") - if ((force_float32 and dtype == np.float32) or - (force_float64 and dtype == np.float64)): - # NOTE: `jnp.asarray` (rather than the `np` version) cannot be used here; - # using it in a jitted context can produce runtime `UnexpectedTracerError`s. 
- bf16_value = np.asarray(value, dtype=jnp.bfloat16) - logging.warning("Converting float consts to bfloat16, from %s, to %s.", - value, bf16_value) - value = bf16_value - - return lambda: value - - -@register_operation("Conv2D") -def _conv2d(proto): - """Parse a Conv2D Op.""" - _check_attrs( - proto, { - "T", "padding", "explicit_paddings", "dilations", "strides", - "data_format", "use_cudnn_on_gpu" - }) - - explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) - if explicit_paddings: - raise ValueError("explicit_padding in Conv2D not yet supported.") - - padding = str(proto.attr["padding"].s, "utf-8") - dilations = tuple(proto.attr["dilations"].list.i) - strides = tuple(proto.attr["strides"].list.i) - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format == "NHWC": - dimension_numbers = ("NHWC", "HWIO", "NHWC") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) - feature_group_count_fn = lambda lhs, rhs: lhs.shape[3] // rhs.shape[2] - elif data_format == "NCHW": - dimension_numbers = ("NCHW", "HWIO", "NCHW") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) - feature_group_count_fn = lambda lhs, rhs: lhs.shape[1] // rhs.shape[2] - else: - raise ValueError(f"Found unsupported data format {data_format}.") - - _ = proto.attr["use_cudnn_on_gpu"].b - - def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: - feature_group_count = feature_group_count_fn(lhs, rhs) - return jax.lax.conv_general_dilated( - lhs, - rhs, - window_strides=strides, - padding=padding, - dimension_numbers=dimension_numbers, - rhs_dilation=dilations, - feature_group_count=feature_group_count) - - return _func - - -@register_operation("Conv2DBackpropInput") -def _conv2d_backprop_input(proto): - """Parse a Conv2DBackpropInput Op.""" - _check_attrs( - proto, { - "T", "padding", "explicit_paddings", "dilations", "strides", - "data_format", "use_cudnn_on_gpu" - }) - - explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) - if explicit_paddings: - raise ValueError( - "explicit_padding in Conv2DBackpropInput not yet supported.") - - padding = str(proto.attr["padding"].s, "utf-8") - dilations = tuple(proto.attr["dilations"].list.i) - strides = tuple(proto.attr["strides"].list.i) - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format == "NHWC": - dimension_numbers = ("NHWC", "HWIO", "NHWC") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) - elif data_format == "NCHW": - dimension_numbers = ("NCHW", "HWIO", "NCHW") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) - else: - raise ValueError(f"Found unsupported data format {data_format}.") - - _ = proto.attr["use_cudnn_on_gpu"].b - - def _func( - input_sizes: jnp.ndarray, - filters: jnp.ndarray, - out_backprop: jnp.ndarray, - ) -> jnp.ndarray: - del input_sizes - return jax.lax.conv_transpose( - out_backprop, - filters, - strides=strides, - padding=padding, - rhs_dilation=dilations, - transpose_kernel=True, - dimension_numbers=dimension_numbers) - - return _func - - -@register_operation("Cumsum") -def _cumsum(proto): - """Parse a Cumsum Op.""" - _check_attrs(proto, {"T", "Tidx", "exclusive", "reverse"}) - - 
exclusive = proto.attr["exclusive"].b - reverse = proto.attr["reverse"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - axis: int = axis.tolist() - if reverse: - x = anp.flip(x, axis=axis) - if exclusive: - pad_shape = list(x.shape) - pad_shape[axis] = 1 - x = anp.concatenate([np.zeros(pad_shape, dtype=x.dtype), x], axis=axis) - x = x[(slice(None),) * axis + (slice(0, -1), Ellipsis)] - res = anp.cumsum(x, axis=axis) - if reverse: - res = anp.flip(res, axis=axis) - return res - - return _func - - -@register_operation("DepthwiseConv2dNative") -def _depthwise_conv2d(proto): - """Parse a DepthwiseConv2d Op.""" - _check_attrs(proto, { - "T", "strides", "dilations", "padding", "data_format", "explicit_paddings" - }) - - explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) - if explicit_paddings: - explicit_paddings = [ - tuple(x) for x in np.array(explicit_paddings).reshape(4, 2).tolist() - ] - - padding = explicit_paddings or str(proto.attr["padding"].s, "utf-8") - dilations = tuple(proto.attr["dilations"].list.i) - strides = tuple(proto.attr["strides"].list.i) - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format == "NHWC": - if explicit_paddings: - padding = padding[1:3] - dimension_numbers = ("NHWC", "HWIO", "NHWC") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=-1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=-1) - channel_index = -1 - elif data_format == "NCHW": - if explicit_paddings: - padding = padding[2:] - dimension_numbers = ("NCHW", "HWIO", "NCHW") - strides = xla_utils.get_conv_sequence(strides, ndim=2, channel_index=1) - dilations = xla_utils.get_conv_sequence(dilations, ndim=2, channel_index=1) - channel_index = 1 - else: - raise ValueError(f"Found unsupported data format {data_format}.") - - def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: - output_dim = rhs.shape[2] * rhs.shape[3] - return jax.lax.conv_general_dilated( - lhs, - jnp.reshape(rhs, rhs.shape[:2] + (1, output_dim)), - window_strides=strides, - padding=padding, - dimension_numbers=dimension_numbers, - rhs_dilation=dilations, - feature_group_count=lhs.shape[channel_index]) - - return _func - - -@register_operation("Einsum") -def _einsum(proto): - """Parse an Einsum Op.""" - _check_attrs(proto, {"T", "N", "equation"}) - - num_inputs = proto.attr["N"].i - equation = str(proto.attr["equation"].s, "utf-8") - - def _func(*operands): - if len(operands) != num_inputs: - raise ValueError( - f"Expected {num_inputs} input arrays, found {len(operands)}") - return jnp.einsum(equation, *operands) - - return _func - - -@register_operation("Empty") -def _empty(proto): - """Parse an Empty op.""" - _check_attrs(proto, {"dtype", "init"}) - - dtype = tf.as_dtype(proto.attr["dtype"].type) - init = proto.attr["init"].b - - def _func(shape: jnp.ndarray) -> jnp.ndarray: - return anp.empty(shape=shape, dtype=dtype, init=init) - - return _func - - -@register_operation("Fill") -def _fill(proto): - """Parse an Fill op.""" - _check_attrs(proto, {"T", "index_type"}) - - dtype = tf.as_dtype(proto.attr["T"].type) - - def _func(shape: jnp.ndarray, fill_value: jnp.ndarray) -> jnp.ndarray: - return anp.full(shape=shape, fill_value=fill_value, dtype=dtype) - - return _func - - -@register_operation("FusedBatchNormV3") -@register_operation("FusedBatchNormV2") -def _fused_batch_norm(proto): - """Parse a FusedBatchNorm Op.""" - _check_attrs(proto, { - "T", "U", "data_format", "epsilon", "exponential_avg_factor", - "is_training" - }) - - 
data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format == "NHWC": - reduce_axis = (0, 1, 2) - channel_dim = 3 - elif data_format == "NCHW": - reduce_axis = (0, 2, 3) - channel_dim = 1 - else: - raise ValueError(f"Found unsupported data format {data_format}.") - - epsilon = proto.attr["epsilon"].f - exponential_avg_factor = proto.attr["exponential_avg_factor"].f - one_minus_factor = 1. - exponential_avg_factor - is_training = proto.attr["is_training"].b - - def _func( - x: jnp.ndarray, - scale: jnp.ndarray, - offset: jnp.ndarray, - running_mean: jnp.ndarray, - running_var: jnp.ndarray, - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - batch_mean = jnp.mean(x, axis=reduce_axis) - batch_var = jnp.var(x, axis=reduce_axis) - est_mean = batch_mean if is_training else running_mean - est_var = batch_var if is_training else running_var - - # Prep for broadcasting. - scale = jnp.expand_dims(scale, axis=reduce_axis) - offset = jnp.expand_dims(offset, axis=reduce_axis) - est_mean = jnp.expand_dims(est_mean, axis=reduce_axis) - est_var = jnp.expand_dims(est_var, axis=reduce_axis) - - inv = scale * jax.lax.rsqrt(est_var + epsilon) - norm_x = jnp.asarray((x - est_mean) * inv + offset, x.dtype) - - if is_training: - # Apply Bessel's correction and additional smoothing. - ndata = x.size / x.shape[channel_dim] - correction = ndata / jnp.maximum(ndata - 1.0, 1.0) - running_var = running_var if running_var.size else 0 - running_mean = running_mean if running_mean.size else 0 - new_var = ( - one_minus_factor * running_var + - exponential_avg_factor * batch_var * correction) - new_mean = ( - one_minus_factor * running_mean + - exponential_avg_factor * batch_mean) - return norm_x, new_mean, new_var - else: - return norm_x, running_mean, running_var - - return _func - - -@register_operation("GatherNd") -def _gather_nd(proto): - """Parse a GatherNd Op.""" - _check_attrs(proto, {"Tindices", "Tparams"}) - - def _func(params: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray: - return params[tuple(anp.moveaxis(indices, -1, 0))] - - return _func - - -@register_operation("GatherV2") -def _gather(proto): - """Parse a GatherV2 Op.""" - _check_attrs(proto, {"Taxis", "Tindices", "Tparams", "batch_dims"}) - - batch_dims = proto.attr["batch_dims"].i - if batch_dims < 0: - raise ValueError(f"batch_dims={batch_dims} must be non-negative.") - - def _func( - params: jnp.ndarray, - indices: jnp.ndarray, - axis: jnp.ndarray, - ) -> jnp.ndarray: - return anp.gather( - params, indices, axis=axis.tolist(), batch_dims=batch_dims) - - return _func - - -@dataclasses.dataclass -class _IdentityN(_HigherOrderFunction): - """Represents a IdentityN Op.""" - - gradient_op_type: str # For debug, custom_gradient is handled by _Subgraph. 
- - def __call__(self, *args): - return args - - -@register_operation("IdentityN") -def _identity_n(proto): - """Parse a IdentityN Op.""" - _check_attrs(proto, {"T"}) - - gradient_op_type = str(proto.attr["_gradient_op_type"].s, "utf-8") - if gradient_op_type: - logging.info("Found custom gradient %s", gradient_op_type) - - return _IdentityN({}, gradient_op_type=gradient_op_type) - - -class _IfOp(_HigherOrderFunction): - """Represents a If Op.""" - - def __call__(self, pred, *operand, then_fun, else_fun): - true_fun = lambda args: then_fun(*args) - false_fun = lambda args: else_fun(*args) - return jax.lax.cond( - pred, true_fun=true_fun, false_fun=false_fun, operand=operand) - - -@register_operation("StatelessIf") -@register_operation("If") -def _if(proto): - """Parse a If op.""" - _check_attrs(proto, { - "Tcond", "Tin", "Tout", "output_shapes", "then_branch", "else_branch" - }) - - then_name = proto.attr["then_branch"].func.name - else_name = proto.attr["else_branch"].func.name - output_shapes = [ - [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape - ] - del output_shapes - - return _IfOp(dict(then_fun=then_name, else_fun=else_name)) - - -@register_operation("InplaceAdd") -def _inplace_add(proto): - """Parse a InplaceAdd op.""" - _check_attrs(proto, {"T"}) - - def _func( - inputs: jnp.ndarray, - indices: jnp.ndarray, - updates: jnp.ndarray, - ) -> jnp.ndarray: - return jnp.asarray(inputs).at[indices].add(updates) - - return _func - - -@register_operation("InplaceUpdate") -def _inplace_update(proto): - """Parse a InplaceUpdate op.""" - _check_attrs(proto, {"T"}) - - def _func( - inputs: jnp.ndarray, - indices: jnp.ndarray, - updates: jnp.ndarray, - ) -> jnp.ndarray: - return jnp.asarray(inputs).at[indices].set(updates) - - return _func - - -@register_operation("LogSoftmax") -def _log_softmax(proto): - _check_attrs(proto, {"T"}) - return lambda x: jax.nn.log_softmax(x, axis=-1) - - -@register_operation("MatMul") -def _matmul(proto): - """Parse a MatMul Op.""" - _check_attrs(proto, {"T", "transpose_a", "transpose_b"}) - - transpose_a = proto.attr["transpose_a"].b - transpose_b = proto.attr["transpose_b"].b - - def _func(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: - if transpose_a: - a = jnp.transpose(a) - if transpose_b: - b = jnp.transpose(b) - - return jnp.matmul(a, b) - - return _func - - -@register_operation("BatchMatMulV2") -def _batch_matmul(proto): - """Parse a BatchMatMul Op.""" - _check_attrs(proto, {"T", "adj_x", "adj_y"}) - - adj_x = proto.attr["adj_x"].b - adj_y = proto.attr["adj_y"].b - - # TODO(shaobohou) Add test for arrays with complex values. 
- def _func(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - if adj_x: - x = jnp.conjugate(jnp.swapaxes(x, -1, -2)) - if adj_y: - y = jnp.conjugate(jnp.swapaxes(y, -1, -2)) - - return jnp.matmul(x, y) - - return _func - - -@register_operation("MatrixDiagV3") -def _matrix_diag(proto): - """Parse a MatrixDiagV3 op.""" - _check_attrs(proto, {"T", "align"}) - - align = str(proto.attr["align"].s, "utf-8") - if align != "RIGHT_LEFT": - raise ValueError(f"MatrixDiagV3 does not support `align={align}` yet.") - - def _func( - diagonals: jnp.ndarray, - k: jnp.ndarray, - num_rows: jnp.ndarray, - num_cols: jnp.ndarray, - padding_value: jnp.ndarray, - ) -> jnp.ndarray: - if num_rows != -1 or num_cols != -1: - raise ValueError(f"MatrixDiagV3 does not yet support num_rows={num_rows} " - f"or num_cols={num_cols}.") - - diag_fn = lambda inputs: jnp.diagflat(inputs, k=k) - for _ in range(len(diagonals.shape) - 1): - diag_fn = jax.vmap(diag_fn) - - outputs = diag_fn(diagonals) - mask = diag_fn(jnp.ones_like(diagonals, dtype=jnp.bool_)) - return jnp.where(mask, outputs, padding_value) - - return _func - - -@register_operation("MatrixBandPart") -def _matrix_band_part(proto): - """Parse a MatrixBandPart op.""" - _check_attrs(proto, {"T", "Tindex"}) - - def _func( - x: jnp.ndarray, - lower: jnp.ndarray, - upper: jnp.ndarray, - ) -> jnp.ndarray: - if len(x.shape) < 2: - raise ValueError( - f"Expected input of at least rank 2, found {len(x.shape)}") - mask_shape = x.shape[-2:] - lower = lower.tolist() + 1 if lower.tolist() >= 0 else max(mask_shape) - mask_lower = jnp.tril(jnp.ones(mask_shape, jnp.int32), -lower) - upper = upper.tolist() + 1 if upper.tolist() >= 0 else max(mask_shape) - mask_upper = jnp.triu(jnp.ones(mask_shape, jnp.int32), upper) - return jnp.where((mask_lower + mask_upper) == 0, x, 0) - - return _func - - -@register_operation("Max") -def _max(proto): - _check_attrs(proto, {"T", "Tidx", "keep_dims"}) - - keep_dims = proto.attr["keep_dims"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - return anp.max_(x, axis=axis.tolist(), keepdims=keep_dims) - - return _func - - -@register_operation("MaxPool") -def _max_pool(proto): - """Parse a MaxPool Op.""" - _check_attrs( - proto, - {"T", "padding", "explicit_paddings", "ksize", "strides", "data_format"}) - - explicit_paddings = tuple(proto.attr["explicit_paddings"].list.i) - if explicit_paddings: - raise ValueError("explicit_padding in MaxPool not yet supported.") - - padding = str(proto.attr["padding"].s, "utf-8") - ksize = tuple(proto.attr["ksize"].list.i) - strides = tuple(proto.attr["strides"].list.i) - data_format = str(proto.attr["data_format"].s, "utf-8") - if data_format not in ("NHWC", "NCHW"): - raise ValueError(f"Found unsupported data format {data_format}.") - - def _func(x: jnp.ndarray) -> jnp.ndarray: - return jax.lax.reduce_window( - x, - init_value=-jnp.inf, - computation=jax.lax.max, - window_dimensions=ksize, - window_strides=strides, - padding=padding) - - return _func - - -@register_operation("Mean") -def _mean(proto): - _check_attrs(proto, {"T", "Tidx", "keep_dims"}) - - keep_dims = proto.attr["keep_dims"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - return jnp.mean(x, axis=axis.tolist(), keepdims=keep_dims) - - return _func - - -@register_operation("Min") -def _min(proto): - _check_attrs(proto, {"T", "Tidx", "keep_dims"}) - - keep_dims = proto.attr["keep_dims"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - return anp.min_(x, axis=axis.tolist(), keepdims=keep_dims) - 
- return _func - - -@register_operation("OneHot") -def _one_hot(proto): - """Parse a OneHot Op.""" - _check_attrs(proto, {"T", "TI", "axis"}) - - axis = proto.attr["axis"].i - - def _func( - indices: jnp.ndarray, - depth: jnp.ndarray, - on_value: jnp.ndarray, - off_value: jnp.ndarray, - ) -> jnp.ndarray: - if axis != -1 and axis != len(indices.shape): - raise ValueError(f"OneHot does not support axis={axis} yet, " - f"indices.shape={indices.shape}.") - - mask = jax.nn.one_hot(indices, num_classes=depth, dtype=jnp.int32) - return mask * on_value + (1 - mask) * off_value - - return _func - - -@register_operation("Pack") -def _pack(proto): - """Parse a Pack op.""" - _check_attrs(proto, {"T", "axis", "N"}) - - num_arrays = proto.attr["N"].i - axis = proto.attr["axis"].i - - def _func(*args) -> jnp.ndarray: - if len(args) != num_arrays: - raise ValueError( - f"Pack expects {num_arrays} args, received {len(args)}.") - return anp.stack(args, axis=axis) - - return _func - - -@register_operation("Pad") -def _pad(proto): - _check_attrs(proto, {"T", "Tpaddings"}) - return lambda x, paddings: jnp.pad(x, pad_width=paddings) - - -@register_operation("PadV2") -def _pad_v2(proto): - """Parse a PadV2 op.""" - _check_attrs(proto, {"T", "Tpaddings"}) - - def _func( - inputs: jnp.ndarray, - padding: jnp.ndarray, - constant_values: jnp.ndarray, - ) -> jnp.ndarray: - return jnp.pad(inputs, pad_width=padding, constant_values=constant_values) - - return _func - - -class _PartitionedCall(_HigherOrderFunction): - """Represents a PartitionedCall Op.""" - - def __call__(self, *args, inner_fn, rng=None): - return inner_fn(*args, rng=rng) - - -# TODO(shaobohou) Add test for StatefulPartitionedCall. -@register_operation("StatefulPartitionedCall") -@register_operation("PartitionedCall") -def _partitioned_call(proto): - """Parse a PartitionedCall op.""" - _check_attrs(proto, - {"f", "Tin", "Tout", "config", "config_proto", "executor_type"}) - - inner_fn = proto.attr["f"].func.name - config = str(proto.attr["config"].s, "utf-8") - config_proto = proto.attr["config_proto"].s # TODO(shaobohou) decode this? - executor_type = str(proto.attr["executor_type"].s, "utf-8") - del config, config_proto, executor_type - - return _PartitionedCall(dict(inner_fn=inner_fn)) - - -@register_operation("Placeholder") -def _placeholder(proto): - _check_attrs(proto, {"dtype", "shape"}) - - name = proto.name - - def _func(): - raise ValueError(f"Placeholder `{name}` cannot be evaluated.") - - return _func - - -@register_operation("PreventGradient") -def _prevent_gradient(proto): - """Parse a PreventGradient op.""" - _check_attrs(proto, {"T", "message"}) - - message = str(proto.attr["message"].s, "utf-8") - jax_message = ( - f"Gradient explicitly prevented on node {proto.name}. 
Reason: {message}") - - @jax.custom_gradient - def _func(operand: jnp.ndarray) -> jnp.ndarray: - def grad_fn(_): - raise LookupError(jax_message) - return operand, grad_fn - - return _func - - -@register_operation("Prod") -def _prod(proto): - """Parse a Prod op.""" - _check_attrs(proto, {"T", "Tidx", "keep_dims"}) - - keep_dims = proto.attr["keep_dims"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - return anp.prod(x, axis=axis.tolist(), keepdims=keep_dims) - - return _func - - -@register_operation("RandomStandardNormal") -def _random_standard_normal(proto): - """Parse a RandomStandardNormal op.""" - _check_attrs(proto, {"T", "dtype", "seed", "seed2"}) - - seed = proto.attr["seed"].i - seed2 = proto.attr["seed2"].i - dtype = tf.as_dtype(proto.attr["dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - if seed != 0 or seed2 != 0: - logging.warning( - "RandomStandardNormal does not yet support non-zero seeds, found " - "seed=%s and seed2=%s.", seed, seed2) - - return lambda shape, *, rng: jax.random.normal(rng, shape, dtype=jax_dtype) - - -@register_operation("RandomUniform") -def _random_uniform(proto): - """Parse a RandomUniform op.""" - _check_attrs(proto, {"T", "dtype", "seed", "seed2"}) - - seed = proto.attr["seed"].i - seed2 = proto.attr["seed2"].i - dtype = tf.as_dtype(proto.attr["dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - if seed != 0 or seed2 != 0: - logging.warning( - "RandomUniform does not yet support non-zero seeds, found " - "seed=%s and seed2=%s.", seed, seed2) - - return lambda shape, *, rng: jax.random.uniform(rng, shape, dtype=jax_dtype) - - -@register_operation("RandomUniformInt") -def _random_uniform_int(proto): - """Parse a RandomUniformInt op.""" - _check_attrs(proto, {"T", "Tout", "seed", "seed2"}) - - seed = proto.attr["seed"].i - seed2 = proto.attr["seed2"].i - dtype = tf.as_dtype(proto.attr["Tout"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - if seed != 0 or seed2 != 0: - logging.warning( - "RandomUniformInt does not yet support non-zero seeds, found " - "seed=%s and seed2=%s.", seed, seed2) - - def _func(shape, minval, maxval, *, rng): - return jax.random.randint( - rng, shape, minval=minval, maxval=maxval, dtype=jax_dtype) - - return _func - - -@register_operation("Range") -def _range(proto): - """Parse a Range op.""" - _check_attrs(proto, {"Tidx"}) - - dtype = tf.as_dtype(proto.attr["Tidx"].type) - - def _func( - start: jnp.ndarray, - limit: jnp.ndarray, - delta: jnp.ndarray, - ) -> jnp.ndarray: - return anp.arange(start, stop=limit, step=delta, dtype=dtype) - - return _func - - -@register_operation("Reshape") -def _reshape(proto): - _check_attrs(proto, {"T", "Tshape"}) - return lambda x, shape: jnp.reshape(x, newshape=shape) - - -@register_operation("ResizeBilinear") -def _resize_linear(proto): - """Parse a ResizeBilinear op.""" - _check_attrs(proto, {"T", "align_corners", "half_pixel_centers"}) - - align_corners = proto.attr["align_corners"].b - half_pixel_centers = proto.attr["half_pixel_centers"].b - if align_corners and half_pixel_centers: - # Not supported by tf.raw_ops.ResizeBilinear. - raise ValueError( - "align_corners=True and half_pixel_centers=True are not supported. 
") - - def _func(images: jnp.ndarray, size: jnp.ndarray) -> jnp.ndarray: - if len(images.shape) != 4: - raise ValueError( - "Expected A 4D tensor with shape [batch, height, width, channels], " - f"found {images.shape}") - - inp_batch, inp_height, inp_width, inp_channels = images.shape - out_height, out_width = size.tolist() - - height_scale = out_height / inp_height - width_scale = out_width / inp_width - if align_corners: - if out_height > 1: - height_scale = (out_height - 1) / (inp_height - 1) - if out_width > 1: - width_scale = (out_width - 1) / (inp_width - 1) - scale = np.array((height_scale, width_scale)) - - translation = np.array(([0.0] * 2)) - if not half_pixel_centers: - translation = translation - scale * 0.5 + 0.5 - - return jax.image.scale_and_translate( - images, - shape=(inp_batch, out_height, out_width, inp_channels), - spatial_dims=(1, 2), - scale=scale, - translation=translation, - method="linear", - antialias=False, - precision=None, - ) - - return _func - - -@register_operation("ScatterNd") -def _scatter_nd(proto): - """Parse a ScatterNd op.""" - _check_attrs(proto, {"T", "Tindices"}) - - def _func( - indices: jnp.ndarray, - updates: jnp.ndarray, - shape: jnp.ndarray, - ) -> jnp.ndarray: - zeros = jnp.zeros(shape, updates.dtype) - key = tuple(jnp.moveaxis(indices, -1, 0)) - return zeros.at[key].set(updates) - - return _func - - -@register_operation("SelectV2") -@register_operation("Select") -def _select(proto): - """Parse a Select op.""" - _check_attrs(proto, {"T"}) - - def _func( - conds: jnp.ndarray, - x: jnp.ndarray, - y: jnp.ndarray, - ) -> jnp.ndarray: - conds = anp.expand_dims(conds, axis=tuple(range(conds.ndim, x.ndim))) - return anp.where(conds, x, y) - - return _func - - -@register_operation("Slice") -def _slice(proto): - """Parse a Slice Op.""" - _check_attrs(proto, {"T", "Index"}) - - def _func( - x: jnp.ndarray, - begins: jnp.ndarray, - sizes: jnp.ndarray, - ) -> jnp.ndarray: - """`begins` and `sizes` must be concrete arrays.""" - slices = [slice(b, b + s) for b, s in zip(begins, sizes)] - return x[tuple(slices)] - - return _func - - -@register_operation("Softmax") -def _softmax(proto): - _check_attrs(proto, {"T"}) - return lambda x: jax.nn.softmax(x, axis=-1) - - -@register_operation("SparseSoftmaxCrossEntropyWithLogits") -def _sparse_softmax_cross_entropy_with_logits(proto): - """Parse a SparseSoftmaxCrossEntropyWithLogits Op.""" - _check_attrs(proto, {"T", "Tlabels"}) - - def cross_entropy(xs, ys): - return -1.0 * jnp.take_along_axis( - jax.nn.log_softmax(xs, axis=-1), ys[:, jnp.newaxis], axis=-1)[:, 0] - - def _func(features: jnp.ndarray, - labels: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - - loss = cross_entropy(features, labels) - vjp_outputs, vjp_fn = jax.vjp(cross_entropy, features, labels) - grads = vjp_fn(jnp.ones(vjp_outputs.shape))[0] - return loss, grads - - return _func - - -@register_operation("Split") -def _split(proto): - """Parse a Split op.""" - _check_attrs(proto, {"T", "num_split"}) - num_split = proto.attr["num_split"].i - return lambda axis, inputs: anp.split(inputs, num_split, axis=axis) - - -@register_operation("SplitV") -def _splitv(proto): - """Parse a SplitV op.""" - _check_attrs(proto, {"T", "Tlen", "num_split"}) - num_split = proto.attr["num_split"].i - - def _func( - value: jnp.ndarray, - size_splits: jnp.ndarray, - axis: jnp.ndarray, - ) -> jnp.ndarray: - assert size_splits.shape[0] == num_split, (size_splits.shape[0], num_split) - splits = size_splits.tolist() - axis = axis.tolist() - defined_size = sum([x for x in 
splits if x >= 0]) - splits = [x if x >= 0 else value.shape[axis] - defined_size for x in splits] - indices = np.cumsum(np.array(splits), axis=0) - return anp.split(value, indices, axis=axis) - - return _func - - -@register_operation("SquaredDifference") -def _squared_difference(proto): - _check_attrs(proto, {"T"}) - return lambda x1, x2: jnp.square(x1 - x2) - - -@register_operation("Squeeze") -def _squeeze(proto): - _check_attrs(proto, {"T", "squeeze_dims"}) - - axis = tuple(proto.attr["squeeze_dims"].list.i) - - return lambda x: anp.squeeze(x, axis=axis) - - -@register_operation("StatelessRandomGetKeyCounter") -def _stateless_random_get_key_counter(proto): - _check_attrs(proto, {"T", "Tseed"}) - - def _func(seed: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - assert seed.shape == (2,), seed.shape - seed = jnp.sum(seed) # An arbitrary choice. - return jax.random.PRNGKey(seed), jnp.array(0, dtype=jnp.int32) - - return _func - - -@register_operation("StatelessMultinomial") -def _stateless_multinomial(proto): - """Parse a StatelessMultinomial op.""" - _check_attrs(proto, {"T", "Tseed", "output_dtype"}) - - dtype = tf.as_dtype(proto.attr["output_dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - def _func( - logits: jnp.ndarray, - num_samples: jnp.ndarray, - seed: jnp.ndarray, - ) -> jnp.ndarray: - assert seed.shape == (2,), seed.shape - seed = seed.astype(jnp.uint32) - shape = (num_samples, logits.shape[0]) - samples = jax.random.categorical(seed, logits, shape=shape) - return samples.astype(jax_dtype).transpose([1, 0]) - - return _func - - -@register_operation("StatelessRandomNormalV2") -def _stateless_random_normal_v2(proto): - """Parse a StatelessRandomNormalV2 op.""" - _check_attrs(proto, {"T", "Tshape", "dtype"}) - - dtype = tf.as_dtype(proto.attr["dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - def _func( - shape: jnp.ndarray, - key: jnp.ndarray, - counter: jnp.ndarray, - alg: jnp.ndarray, - ) -> jnp.ndarray: - del counter, alg # TODO(shaobohou) combine key and counter? - return jax.random.normal(key=key, shape=shape, dtype=jax_dtype) - - return _func - - -@register_operation("StatelessRandomUniformV2") -def _stateless_random_uniform_v2(proto): - """Parse a StatelessRandomNormalV2 op.""" - _check_attrs(proto, {"T", "Tshape", "dtype"}) - - dtype = tf.as_dtype(proto.attr["dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - def _func( - shape: jnp.ndarray, - key: jnp.ndarray, - counter: jnp.ndarray, - alg: jnp.ndarray, - ) -> jnp.ndarray: - del counter, alg # TODO(shaobohou) combine key and counter? - return jax.random.uniform(key=key, shape=shape, dtype=jax_dtype) - - return _func - - -@register_operation("StatelessRandomUniformFullIntV2") -@register_operation("StatelessRandomUniformIntV2") -def _stateless_random_uniform_int_v2(proto): - """Parse a StatelessRandomUniformIntV2 op.""" - _check_attrs(proto, {"T", "Tshape", "dtype"}) - - dtype = tf.as_dtype(proto.attr["dtype"].type) - jax_dtype = anp.get_jax_dtype(dtype) - - def _func( - shape: jnp.ndarray, - key: jnp.ndarray, - counter: jnp.ndarray, - alg: jnp.ndarray, - minval: jnp.ndarray = jnp.iinfo(jax_dtype).min, - maxval: jnp.ndarray = jnp.iinfo(jax_dtype).max, - ) -> jnp.ndarray: - del counter, alg # TODO(shaobohou) combine key and counter? 
- return jax.random.randint( - key=key, shape=shape, minval=minval, maxval=maxval, dtype=jax_dtype,) - - return _func - - -class _StatelessWhile(_HigherOrderFunction): - """Represents a StatelessWhile Op.""" - - def __call__(self, *args, cond_fun, body_fun, rng=None): - def real_cond_fun(args): - *cond_args, rng = args - _, cond_key, _ = [None] * 3 if rng is None else jax.random.split(rng, 3) - outputs = cond_fun(*cond_args, rng=cond_key) - if len(outputs) != 1: - raise ValueError( - f"Expected cond_fun to return a single value, found {outputs}") - return outputs[0] - - def real_body_fun(args): - *body_args, rng = args - key, _, body_key = [None] * 3 if rng is None else jax.random.split(rng, 3) - outputs = tuple(body_fun(*body_args, rng=body_key)) - return outputs + (key,) - - outputs = jax.lax.while_loop(real_cond_fun, real_body_fun, args + (rng,)) - return outputs - - -@register_operation("StatelessWhile") -@register_operation("While") -def _stateless_while(proto): - """Parse a StatelessWhile op.""" - _check_attrs(proto, - {"T", "body", "cond", "parallel_iterations", "output_shapes"}) - # TODO(shaobohou) Check proto.arg_attr? - - body_name = proto.attr["body"].func.name - cond_name = proto.attr["cond"].func.name - parallel_iterations = proto.attr["parallel_iterations"].i - output_shapes = [ - [d.size for d in xs.dim] for xs in proto.attr["output_shapes"].list.shape - ] - del parallel_iterations, output_shapes - - return _StatelessWhile(dict(cond_fun=cond_name, body_fun=body_name)) - - -@register_operation("StridedSlice") -def _strided_slice(proto): - """Parse a StridedSlice Op.""" - _check_attrs( - proto, { - "T", "Index", "begin_mask", "ellipsis_mask", "end_mask", - "new_axis_mask", "shrink_axis_mask" - }) - - def unpack(x): - return [(1 if(x & (1 << v)) else 0) for v in range(32)] - - begin_mask = unpack(proto.attr["begin_mask"].i) - ellipsis_mask = unpack(proto.attr["ellipsis_mask"].i) - end_mask = unpack(proto.attr["end_mask"].i) - new_axis_mask = unpack(proto.attr["new_axis_mask"].i) - shrink_axis_mask = unpack(proto.attr["shrink_axis_mask"].i) - - def _func( - x: jnp.ndarray, - begin: jnp.ndarray, - end: jnp.ndarray, - strides: jnp.ndarray, - ) -> jnp.ndarray: - """`begin`, `end` and `strides` must be concrete arrays.""" - - num_specs = len(begin) - inserted = ( - len(x.shape) + sum(new_axis_mask) - (num_specs - sum(ellipsis_mask))) - - # Rebuild slices. - dim = 0 - slices = [] - for idx in range(num_specs): - if new_axis_mask[idx] == 1: - slices.append(jnp.newaxis) - elif ellipsis_mask[idx] == 1: - slices.append(Ellipsis) - dim += inserted - else: - if shrink_axis_mask[idx] == 1: - slices.append(begin[idx]) - else: - beg_dim = begin[idx] if begin_mask[idx] == 0 else None - end_dim = end[idx] if end_mask[idx] == 0 else None - stride = strides[idx] - x_dim = x.shape[dim] - if (stride == 1 and - jax.core.symbolic_equal_dim(beg_dim or 0, 0) and - jax.core.symbolic_equal_dim(end_dim or x_dim, x_dim)): - slices.append(slice(None, None, None)) - else: - slices.append(slice(beg_dim, end_dim, stride)) - dim += 1 - - # TODO(shaobohou) Handle stride=1 slicing along polymoprhic dimensions. 
- return x[tuple(slices)] - - return _func - - -@register_operation("Sum") -def _sum(proto): - _check_attrs(proto, {"T", "Tidx", "keep_dims"}) - - keep_dims = proto.attr["keep_dims"].b - - def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray: - return anp.sum_(x, axis=axis.tolist(), keepdims=keep_dims) - - return _func - - -@register_operation("TopKV2") -def _top_k(proto): - _check_attrs(proto, {"T", "sorted"}) - sorted_arg = proto.attr["sorted"].b - if not sorted_arg: - raise ValueError("sorted=False in TopKV2 is not yet supported.") - - return jax.lax.top_k - - -@register_operation("Transpose") -def _transpose(proto): - _check_attrs(proto, {"T", "Tperm"}) - return lambda x, axes: jnp.transpose(x, axes=axes) - - -@register_operation("Unpack") -def _unpack(proto): - """Parse a Unpack op.""" - _check_attrs(proto, {"T", "axis", "num"}) - - axis = proto.attr["axis"].i - num = proto.attr["num"].i - - def _func(x: jnp.ndarray) -> jnp.ndarray: - if x.shape[axis] != num: - raise ValueError("Unpack expects dimension of {num} for axis={axis}, " - "found {x.shape[axis]}, shape={x.shape}") - return [anp.squeeze(v, axis=axis) for v in anp.split(x, num, axis=axis)] - - return _func - - -@register_operation("XlaConvV2") -@register_operation("XlaConv") -def _xla_conv(proto): - """Parse a XlaConv op.""" - _check_attrs( - proto, { - "T", "LhsT", "RhsT", "Tindices", "dimension_numbers", - "precision_config", "preferred_element_type", "batch_group_count" - }) - - dimension_numbers = xla_utils.convolution_dimension_numbers_from_proto( - proto.attr["dimension_numbers"].s) - precision_config = xla_utils.precision_config_from_proto( - proto.attr["precision_config"].s) - batch_group_count = proto.attr["batch_group_count"].i - if "preferred_element_type" in proto.attr: - dst_dtype = tf.as_dtype(proto.attr["preferred_element_type"].type) - dst_dtype = anp.get_jax_dtype(dst_dtype) - else: - dst_dtype = None - - def _func( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - strides: jnp.ndarray, - padding: jnp.ndarray, - lhs_dilation: jnp.ndarray, - rhs_dilation: jnp.ndarray, - feature_group_count: jnp.ndarray, - ) -> jnp.ndarray: - return jax.lax.conv_general_dilated( - lhs, - rhs, - window_strides=strides.tolist(), - padding=[tuple(v) for v in padding], - lhs_dilation=lhs_dilation.tolist(), - rhs_dilation=rhs_dilation.tolist(), - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count.tolist(), # Should be int. 
- batch_group_count=batch_group_count or 1, - precision=precision_config, - preferred_element_type=dst_dtype) - - return _func - - -@register_operation("XlaDotV2") -@register_operation("XlaDot") -def _xla_dot(proto): - """Parse a XlaDot op.""" - _check_attrs( - proto, { - "T", "LhsT", "RhsT", "dimension_numbers", "precision_config", - "preferred_element_type" - }) - - dimension_numbers = xla_utils.dot_dimension_numbers_from_proto( - proto.attr["dimension_numbers"].s) - precision_config = xla_utils.precision_config_from_proto( - proto.attr["precision_config"].s) - if "preferred_element_type" in proto.attr: - dst_dtype = tf.as_dtype(proto.attr["preferred_element_type"].type) - dst_dtype = anp.get_jax_dtype(dst_dtype) - else: - dst_dtype = None - - def _func(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: - return jax.lax.dot_general( - lhs, - rhs, - dimension_numbers, - precision_config, - preferred_element_type=dst_dtype) - - return _func - - -@register_operation("XlaDynamicSlice") -def _xla_dynamic_slice(proto): - """Parse a XlaDynamicSlice op.""" - _check_attrs(proto, {"T", "Tindices"}) - return jax.lax.dynamic_slice - - -@register_operation("XlaDynamicUpdateSlice") -def _xla_dynamic_update_slice(proto): - """Parse a XlaDynamicUpdateSlice op.""" - _check_attrs(proto, {"T", "Tindices"}) - return jax.lax.dynamic_update_slice - - -@register_operation("XlaGather") -def _xla_gather(proto): - """Parse a XlaGather op.""" - _check_attrs(proto, - {"T", "Tindices", "dimension_numbers", "indices_are_sorted"}) - - dimension_numbers = xla_utils.gather_dimension_numbers_from_proto( - proto.attr["dimension_numbers"].s) - # This should exist on the XLA op, even though it's not exposed by JAX. - indices_are_sorted = proto.attr["indices_are_sorted"].b - del indices_are_sorted - - def _func( - operand: jnp.ndarray, - start_indices: jnp.ndarray, - slice_indices: jnp.ndarray, - ) -> jnp.ndarray: - return jax.lax.gather(operand, start_indices, dimension_numbers, - slice_indices) - - return _func - - -@register_operation("XlaPad") -def _xla_pad(proto): - """Parse a XlaPad op.""" - _check_attrs(proto, {"T", "Tindices"}) - - def _func( - operand: jnp.ndarray, - padding_value: jnp.ndarray, - padding_low: jnp.ndarray, - padding_high: jnp.ndarray, - padding_interior: jnp.ndarray,) -> jnp.ndarray: - padding_config = np.stack([padding_low, padding_high, padding_interior], - axis=0) - padding_config = [tuple(x) for x in padding_config.transpose().tolist()] - return jax.lax.pad(operand, padding_value, padding_config) - - return _func - - -def _maybe_get_jaxpreqn( - jaxpr: jax.core.ClosedJaxpr) -> Optional[jax.core.JaxprEqn]: - def is_all_vars(vs): - return all([isinstance(v, jax.core.Var) for v in vs]) - - if (len(jaxpr.eqns) == 1 and - is_all_vars(jaxpr.jaxpr.invars) and is_all_vars(jaxpr.jaxpr.outvars) and - is_all_vars(jaxpr.eqns[0].invars) and is_all_vars(jaxpr.eqns[0].outvars)): - return jaxpr.eqns[0] - return None - - -@dataclasses.dataclass -class _XlaVariadicReduce(_HigherOrderFunction): - """Represents a XlaVariadicReduce Op.""" - - dimensions: Sequence[int] - - def __call__(self, *args: jnp.ndarray, reducer: Callable[..., Any]): - num_args = len(args) - operands = args[:(num_args//2)] - init_values = args[(num_args//2):] - assert len(operands) == len(init_values) - reducer_fn = lambda xs, ys: reducer(*xs, *ys) - return jax.lax.reduce(operands, init_values, reducer_fn, self.dimensions) - - -@register_operation("XlaVariadicReduceV2") -def _xla_variadic_reduce(proto): - """Parse a XlaVariadicReduceV2 op.""" - 
_check_attrs(proto, {"T", "reducer", "dimensions_to_reduce"}) - - reducer = proto.attr["reducer"].func.name - dimensions = tuple(proto.attr["dimensions_to_reduce"].list.i) - - return _XlaVariadicReduce(dict(reducer=reducer), dimensions=dimensions) - - -@dataclasses.dataclass -class _XlaVariadicSort(_HigherOrderFunction): - """Represents a XlaVariadicSort Op.""" - - is_stable: bool - - def _compute_num_keys( - self, - dtypes: Sequence[jnp.dtype], - comparator: Callable[..., Any], - ) -> int: - """Infer num_keys from the comparator and operands.""" - def get_operands(): - return sum([[jnp.array(0, dtype)] * 2 for dtype in dtypes], []) - - for idx in range(len(dtypes)): - operands = get_operands() - is_eq, = comparator(*operands) - - operands[idx * 2 + 1] = jnp.array(1, dtypes[idx]) - is_lt, = comparator(*operands) - - if idx == 0 and (not is_lt or is_eq): - raise ValueError( - "Only less-than comparator is supported for XlaVariadicSort.") - - if is_lt: - num_keys = idx + 1 - else: - break - - return num_keys - - def __call__(self, *args: jnp.ndarray, comparator: Callable[..., Any]): - operands = args[:-1] - dimension = args[-1].tolist() - - with jax.ensure_compile_time_eval(): - dtypes = [x.dtype for x in operands] - num_keys = self._compute_num_keys(dtypes, comparator) - - return jax.lax.sort( - operands, - dimension=dimension, - is_stable=self.is_stable, - num_keys=num_keys) - - -@register_operation("XlaVariadicSort") -def _xla_variadic_sort(proto): - """Parse a XlaVariadicSort op.""" - _check_attrs(proto, {"T", "comparator", "is_stable"}) - - comparator = proto.attr["comparator"].func.name - is_stable = proto.attr["is_stable"].b - - logging.warning( - "Support for XlaVariadicSort is limited, current implementation assumes " - "the op is generated by `jax2tf.convert(jax.lax.sort)` and does not " - "support arbitrary comparators") - - return _XlaVariadicSort(dict(comparator=comparator), is_stable=is_stable) - - -class _XlaReduceWindow(_HigherOrderFunction): - """Represents a XlaReduceWindow Op.""" - - def __call__( - self, - operand: jnp.ndarray, - init_value: jnp.ndarray, - window_dimensions: jnp.ndarray, - window_strides: jnp.ndarray, - base_dilation: jnp.ndarray, - window_dilation: jnp.ndarray, - padding: jnp.ndarray, - *, - computation: Callable[..., Any], - ): - # Pattern matching computations that can be specialized. 
-    primitives = {
-        jax.lax.max_p: jax.lax.max,
-        jax.lax.min_p: jax.lax.min,
-        jax.lax.add_p: jax.lax.add,
-    }
-    computation_jaxpr = jax.make_jaxpr(computation)(init_value, init_value)
-    computation_eqn = _maybe_get_jaxpreqn(computation_jaxpr)
-    if computation_eqn is not None and computation_eqn.primitive in primitives:
-      computation_fn = primitives[computation_eqn.primitive]
-    else:
-      computation_fn = lambda *args: computation(*args)[0]
-      logging.info("Calling reduce_window with the following computation:\n%s",
-                   computation_jaxpr)
-
-    return jax.lax.reduce_window(
-        operand,
-        init_value,
-        computation=computation_fn,
-        window_dimensions=window_dimensions.tolist(),
-        window_strides=window_strides.tolist(),
-        padding=[tuple(v) for v in padding.tolist()],
-        base_dilation=base_dilation.tolist(),
-        window_dilation=window_dilation.tolist())
-
-
-@register_operation("XlaReduceWindow")
-def _xla_reduce_window(proto):
-  """Parse a XlaReduceWindow op."""
-  _check_attrs(proto, {"T", "Tindices", "computation"})
-
-  computation = proto.attr["computation"].func.name
-
-  return _XlaReduceWindow(dict(computation=computation))
-
-
-@register_operation("XlaRngBitGenerator")
-def _xla_rng_bit_generator(proto):
-  """Parse a XlaRngBitGenerator op."""
-  _check_attrs(proto, {"Tshape", "dtype"})
-
-  dtype = tf.as_dtype(proto.attr["dtype"].type)
-  jax_dtype = anp.get_jax_dtype(dtype)
-  if jax_dtype != jnp.uint32:
-    raise ValueError(
-        f"XlaRngBitGenerator currently only supports uint32, found {jax_dtype}.")
-
-  def _func(
-      algorithm: jnp.ndarray,
-      key: jnp.ndarray,
-      shape: jnp.ndarray,
-  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-    return jax.lax.rng_bit_generator(
-        key=key.reshape(-1),
-        shape=shape,
-        dtype=jax_dtype,
-        # See tensorflow/compiler/tf2xla/ops/xla_ops.cc#L812
-        algorithm=xla_utils.get_random_algorithm_from_tf(algorithm.tolist()),
-    )
-
-  return _func
-
-
-@dataclasses.dataclass
-class _XlaScatter(_HigherOrderFunction):
-  """Represents a XlaScatter Op."""
-
-  dimension_numbers: jax.lax.ScatterDimensionNumbers
-  indices_are_sorted: bool
-
-  def __call__(
-      self,
-      operand: jnp.ndarray,
-      indices: jnp.ndarray,
-      updates: jnp.ndarray,
-      *,
-      update_computation: Callable[..., Any],
-  ) -> jnp.ndarray:
-    dummy_zero = jnp.array(0).astype(operand.dtype)
-    jaxpr = jax.make_jaxpr(update_computation)(dummy_zero, dummy_zero)
-    if not jaxpr.eqns:
-      scatter_fn = jax.lax.scatter
-    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.add_p:
-      scatter_fn = jax.lax.scatter_add
-    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.mul_p:
-      scatter_fn = jax.lax.scatter_mul
-    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.min_p:
-      scatter_fn = jax.lax.scatter_min
-    elif len(jaxpr.eqns) == 1 and jaxpr.eqns[0].primitive == jax.lax.max_p:
-      scatter_fn = jax.lax.scatter_max
-    else:
-      raise ValueError(
-          f"Reducer not supported as `update_computation`, found {jaxpr}")
-
-    return scatter_fn(
-        operand,
-        indices,
-        updates,
-        dimension_numbers=self.dimension_numbers,
-        indices_are_sorted=self.indices_are_sorted)
-
-
-@register_operation("XlaScatter")
-def _xla_scatter(proto):
-  """Parse a XlaScatter op."""
-  _check_attrs(
-      proto, {
-          "T", "Tindices", "dimension_numbers", "indices_are_sorted",
-          "update_computation"
-      })
-
-  dimension_numbers = xla_utils.scatter_dimension_numbers_from_proto(
-      proto.attr["dimension_numbers"].s)
-  update_computation = proto.attr["update_computation"].func.name
-  indices_are_sorted = proto.attr["indices_are_sorted"].b
-
-  return _XlaScatter(
-      dict(update_computation=update_computation), dimension_numbers,
-      indices_are_sorted)
-
-
-class _XlaSelectAndScatter(_HigherOrderFunction):
-  """Represents a XlaSelectAndScatter Op."""
-
-  def __call__(
-      self,
-      operand: jnp.ndarray,
-      window_dimensions: jnp.ndarray,
-      window_strides: jnp.ndarray,
-      padding: jnp.ndarray,
-      source: jnp.ndarray,
-      inner_init_value: jnp.ndarray,
-      *,
-      scatter: Callable[..., Any],
-      select: Callable[..., Any],
-  ) -> jnp.ndarray:
-    # Because jax.lax._select_and_scatter is not part of the JAX public api, we
-    # are using a crude pattern matching to determine the reducer used in the
-    # original reduce_window call.
-
-    scatter_jaxpr = jax.make_jaxpr(scatter)(inner_init_value, inner_init_value)
-    scatter_eqn = _maybe_get_jaxpreqn(scatter_jaxpr)
-    if scatter_eqn is not None and scatter_eqn.primitive is not jax.lax.add_p:
-      raise ValueError(
-          f"Only Add is supported as scatter function, found {scatter_jaxpr}.")
-
-    # TODO(shaobohou) Support jax.lax.add for AvgPool.
-    select_primitives = {
-        jax.lax.ge_p: (-jnp.inf, jax.lax.max),
-        jax.lax.le_p: (jnp.inf, jax.lax.min),
-    }
-    select_jaxpr = jax.make_jaxpr(select)(inner_init_value, inner_init_value)
-    select_eqn = _maybe_get_jaxpreqn(select_jaxpr)
-    if select_eqn is not None and select_eqn.primitive in select_primitives:
-      init_value, computation = select_primitives[select_eqn.primitive]
-    else:
-      raise ValueError("Only greater_equal (Max) and less_equal (Min) are "
-                       f"supported as select function, found {select_jaxpr}")
-
-    def reduce_window(x):
-      return jax.lax.reduce_window(
-          x,
-          init_value,
-          computation=computation,
-          window_dimensions=tuple(window_dimensions.tolist()),
-          window_strides=tuple(window_strides.tolist()),
-          padding=[tuple(v) for v in padding.tolist()])
-
-    _, f_vjp = jax.vjp(reduce_window, operand)
-    return f_vjp(source)
-
-
-@register_operation("XlaSelectAndScatter")
-def _xla_select_and_scatter(proto):
-  """Parse a XlaSelectAndScatter op."""
-  _check_attrs(proto, {"T", "Tindices", "scatter", "select"})
-
-  scatter = proto.attr["scatter"].func.name
-  select = proto.attr["select"].func.name
-
-  return _XlaSelectAndScatter(dict(scatter=scatter, select=select))
-
-
 class _TensorEdge(NamedTuple):
   """Represents an input/output Tensor."""
@@ -2138,12 +159,12 @@ class _OpNode:
 
   def __init__(self, proto, library: Mapping[str, _LibraryFunction],
                node_map: Mapping[str, Any]):
-    self.jax_func = _jax_ops[proto.op](proto)
+    self.jax_func = ops.get_parser(proto.op)(proto)
     self.op = proto.op
     self.name = proto.name
 
     self.inner_fns = dict()
-    if isinstance(self.jax_func, _HigherOrderFunction):
+    if isinstance(self.jax_func, ops._HigherOrderFunction):
       self.inner_fns = self.jax_func.get_inner_functions(library)
 
     inputs = [_TensorEdge.from_string(inp, node_map) for inp in proto.input]
@@ -2413,7 +434,7 @@ def maybe_tensor_to_spec(v):
   signature = inspect.Signature(parameters=parameters)
 
   # Extract custom_gradient functions from the registry.
-  if get_config("convert_custom_gradient"):
+  if config.get_config("convert_custom_gradient"):
     library = _convert_all_gradient_functions(graph, {})
   else:
     library = {}
@@ -2682,7 +703,8 @@ def _infer_relu_from_jax2tf(nodes):
         # The Cast and Const ops may now be redundant but are kept anyway.
node.op = "Relu" node.inputs = node.inputs[:1] - node.jax_func = _jax_ops["Relu"](_NodeDef("Relu", node.name, (), {})) + node.jax_func = ( + ops.get_parser("Relu")(_NodeDef("Relu", node.name, (), {}))) if not found_jax2tf: logging.warning("Replaced max(x, 0) with jax.nn.relu but did not " "find jax2tf_out.") @@ -2797,7 +819,8 @@ def _convert( assert len(input_names) == len(set(input_names)) output_names = tuple([v.name for v in tree.flatten(structured_outputs)]) - unsupported = {node.op for node in graphdef.node if node.op not in _jax_ops} + unsupported = ops.get_unsupported_operations( + [node.op for node in graphdef.node]) if unsupported: raise ValueError(f"Unsupported operations in graph: {list(unsupported)}\n" "Support for additional TensorFlow ops are added on an " @@ -2838,10 +861,10 @@ def _convert( output_args = [_TensorEdge.from_string(v, node_map) for v in output_names] num_rng_required = sum([node.require_rng for node in nodes]) - if get_config("infer_relu_from_jax2tf"): + if config.get_config("infer_relu_from_jax2tf"): _infer_relu_from_jax2tf(nodes) - if get_config("convert_custom_gradient"): + if config.get_config("convert_custom_gradient"): subgraphs = _extract_subgraphs(graphdef, nodes, library) for _, subgraph in subgraphs.items(): nodes = subgraph.rewrite(nodes) @@ -2883,7 +906,7 @@ def jax_func( "a Tensor or Array.") else: continue - if (get_config("strict_shape_check") and + if (config.get_config("strict_shape_check") and not spec.shape.is_compatible_with(_fix_jax_poly_shape(inp.shape))): raise ValueError( f"Found incompatible input shape: {inp.shape}, expected " @@ -2894,7 +917,7 @@ def jax_func( "either tf2jax.update_config('strict_shape_check', False) or the " "context manager " "tf2jax.override_config('strict_shape_check', False)") - if (get_config("strict_dtype_check") and + if (config.get_config("strict_dtype_check") and not spec.dtype.is_compatible_with(inp.dtype)): raise ValueError( f"Found incompatible input dtype: {inp.dtype}, expected " diff --git a/tf2jax/_src/tf2jax_test.py b/tf2jax/_src/tf2jax_test.py index 2d61065..ed669fe 100644 --- a/tf2jax/_src/tf2jax_test.py +++ b/tf2jax/_src/tf2jax_test.py @@ -24,6 +24,7 @@ import sonnet as snt import tensorflow as tf +from tf2jax._src import config from tf2jax._src import tf2jax import tree @@ -133,7 +134,7 @@ def test_saved_model(self): test_inputs = np.ones([20, 5], dtype=np.float32) expected_outputs = tf_func(test_inputs) - with tf2jax.override_config("strict_shape_check", False): + with config.override_config("strict_shape_check", False): actual_outputs, _ = jax_func({}, test_inputs) self.assertAllClose(expected_outputs, actual_outputs) @@ -150,7 +151,7 @@ def test_saved_model_ambiguous(self): test_inputs = np.ones([20, 7, 2], dtype=np.float32) expected_outputs = tf_func(test_inputs) - with tf2jax.override_config("strict_shape_check", False): + with config.override_config("strict_shape_check", False): actual_outputs, _ = jax_func({}, test_inputs) self.assertAllClose(expected_outputs, actual_outputs) @@ -164,7 +165,7 @@ def test_saved_model_functional(self): test_inputs = np.ones([20, 5], dtype=np.float32) expected_outputs = tf_func(test_inputs) - with tf2jax.override_config("strict_shape_check", False): + with config.override_config("strict_shape_check", False): actual_outputs = jax_func(test_inputs) self.assertAllClose(expected_outputs, actual_outputs) @@ -200,30 +201,30 @@ def tf_func(x): with self.subTest("valid_input_shape"): expected_outputs = tf_func(orig_inputs) - with 
tf2jax.override_config("strict_shape_check", True): + with config.override_config("strict_shape_check", True): actual_outputs = jax_func(orig_inputs) self.assertAllClose(expected_outputs, actual_outputs) - with tf2jax.override_config("strict_dtype_check", True): + with config.override_config("strict_dtype_check", True): actual_outputs = jax_func(orig_inputs) self.assertAllClose(expected_outputs, actual_outputs) test_inputs = np.ones([10, 5, 2], dtype=np.float32) with self.subTest("invalid_input_shape"): expected_outputs = tf_func(test_inputs) - with tf2jax.override_config("strict_shape_check", True): + with config.override_config("strict_shape_check", True): with self.assertRaisesRegex(ValueError, "incompatible input shape"): actual_outputs = jax_func(test_inputs) - with tf2jax.override_config("strict_shape_check", False): + with config.override_config("strict_shape_check", False): actual_outputs = jax_func(test_inputs) self.assertNotAllClose(expected_outputs, actual_outputs) test_inputs = np.ones([10, 5], dtype=np.int32) with self.subTest("invalid_input_dtype"): expected_outputs = tf_func(test_inputs) - with tf2jax.override_config("strict_dtype_check", True): + with config.override_config("strict_dtype_check", True): with self.assertRaisesRegex(ValueError, "incompatible input dtype"): actual_outputs = jax_func(test_inputs) - with tf2jax.override_config("strict_dtype_check", False): + with config.override_config("strict_dtype_check", False): actual_outputs = jax_func(test_inputs) self.assertAllClose(expected_outputs, actual_outputs) @@ -246,7 +247,7 @@ def tf_func(x): "force_const_float64_to_bfloat16") # This is the default. - with tf2jax.override_config(dtype_config_name, False): + with config.override_config(dtype_config_name, False): jax_func = tf2jax.convert_functional(tf_func, np_inputs) jax_func = self.variant(jax_func) orig_jax_outputs = jax_func(np_inputs) @@ -255,7 +256,7 @@ def tf_func(x): jnp.array(orig_jax_outputs).dtype) self.assertAllClose(tf_outputs, orig_jax_outputs) - with tf2jax.override_config(dtype_config_name, True): + with config.override_config(dtype_config_name, True): jax_func = tf2jax.convert_functional(tf_func, np_inputs) jax_func = self.variant(jax_func) forced_jax_outputs = jax_func(jnp.asarray(np_inputs, jnp.bfloat16)) @@ -278,7 +279,7 @@ class CachedFn(hk.Module): def __init__(self): super().__init__(name=None) if not self.cache: - with tf2jax.override_config("force_const_float32_to_bfloat16", True): + with config.override_config("force_const_float32_to_bfloat16", True): self.cache.append(jax.jit( tf2jax.convert_functional(tf_func, np_inputs))) @@ -344,7 +345,7 @@ def grad(dy): tf_outputs = tf_func(tf_inputs) tf_grads = tape.gradient(tf_outputs, tf_inputs) - with tf2jax.override_config("convert_custom_gradient", use_custom_gradient): + with config.override_config("convert_custom_gradient", use_custom_gradient): jax_func = tf2jax.convert_functional(tf_func, np.zeros_like(np_inputs)) jax_func = self.variant(jax_func) jax_outputs = jax_func(np_inputs)