diff --git a/tf2jax/_src/numpy_compat.py b/tf2jax/_src/numpy_compat.py index d52acf3..ecf0589 100644 --- a/tf2jax/_src/numpy_compat.py +++ b/tf2jax/_src/numpy_compat.py @@ -227,12 +227,22 @@ def broadcast_to(arr, shape): flip = lambda arr, axis: _get_np(arr).flip(arr, axis=axis) roll = lambda arr, shift, axis: _get_np(arr).roll(arr, shift=shift, axis=axis) split = lambda arr, sections, axis: _get_np(arr).split(arr, sections, axis=axis) -squeeze = lambda arr, axis: _get_np(arr).squeeze(arr, axis=axis) stack = lambda arrs, axis: _get_np(*arrs).stack(arrs, axis=axis) tile = lambda arr, reps: _get_np(arr, reps).tile(arr, reps=reps) where = lambda cond, x, y: _get_np(cond, x, y).where(cond, x, y) +def squeeze(arr, axis): + # tf.squeeze and np/jnp.squeeze have different behaviors when axis=(). + # - tf.squeeze will squeeze all dimensions. + # - np/jnp.squeeze will not squeeze any dimensions. + # Here we change () to None to ensure that squeeze has the same behavior + # when converted from tf to np/jnp. + if axis == tuple(): + axis = None + return _get_np(arr).squeeze(arr, axis=axis) + + def moveaxis( arr, source: Union[int, Sequence[int]], diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index 0d96158..7f40fd8 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -1458,8 +1458,9 @@ def roll_static(): self._test_convert(roll_static, []) @chex.variants(with_jit=True, without_jit=True) - def test_squeeze(self): - inputs, dims = np.array([[[42], [47]]]), (0, 2) + @parameterized.parameters(((0, 2),), (tuple(),), (None,)) + def test_squeeze(self, dims): + inputs = np.array([[[42], [47]]]) def squeeze(x): return tf.raw_ops.Squeeze(input=x, axis=dims)