diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 30b3373..484a33c 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -149,6 +149,18 @@ def wrapped(proto): "Tan": _get_jax_op(jnp.tan, {"T"}), "Tanh": _get_jax_op(jnp.tanh, {"T"}), "Tile": _get_jax_op(anp.tile, {"T", "Tmultiples"}), + "UnsortedSegmentMax": _get_jax_op( + functools.partial(jax.ops.segment_max, indices_are_sorted=False), + {"T", "Tindices", "Tnumsegments"}), + "UnsortedSegmentMin": _get_jax_op( + functools.partial(jax.ops.segment_min, indices_are_sorted=False), + {"T", "Tindices", "Tnumsegments"}), + "UnsortedSegmentProd": _get_jax_op( + functools.partial(jax.ops.segment_prod, indices_are_sorted=False), + {"T", "Tindices", "Tnumsegments"}), + "UnsortedSegmentSum": _get_jax_op( + functools.partial(jax.ops.segment_sum, indices_are_sorted=False), + {"T", "Tindices", "Tnumsegments"}), "Where": _get_jax_op(jnp.argwhere, {"T"}), "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), # The assignment logic is handled in _OpNode and convert(). @@ -1692,6 +1704,25 @@ def _func(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: return _func +@register_operation("TensorScatterUpdate") +def _tensor_scatter_update(proto): + """Parse an TensorScatterUpdate Op.""" + _check_attrs(proto, {"T", "Tindices"}) + + def _func( + operand: jnp.ndarray, + indices: jnp.ndarray, + updates: jnp.ndarray, + ) -> jnp.ndarray: + dimension_numbers = jax.lax.ScatterDimensionNumbers( + range(1, updates.ndim), range(indices.shape[-1]), + range(indices.shape[-1])) + return jax.lax.scatter( + operand, indices, updates, dimension_numbers=dimension_numbers) + + return _func + + @register_operation("TopKV2") def _top_k(proto): _check_attrs(proto, {"T", "sorted"}) diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index 9fbd3ab..d2b1917 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -1545,6 +1545,64 @@ def switch_fn(x): ) self._test_convert(switch_fn, [np.array(1, dtype=np.int32)]) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.named_parameters( + chex.params_product( + ( + ( + "1D", + np.ones([8], dtype=np.float32), + np.array([[4], [3], [1], [7]], dtype=np.int32), + np.array([9, 10, 11, 12], dtype=np.float32), + ), + ( + "2D_scalar", + np.ones([8, 2], dtype=np.float32), + np.array([[4, 0], [3, 1], [1, 0], [7, 1]], dtype=np.int32), + np.array([9, 10, 11, 12], dtype=np.float32), + ), + ( + "2D_slice", + np.ones([8, 2], dtype=np.float32), + np.array([[4], [3], [1], [7]], dtype=np.int32), + np.array([[9, 90], [10, 100], [11, 110], [12, 120]], + dtype=np.float32), + ), + ( + "3D_scalar", + np.ones([8, 3, 2], dtype=np.float32), + np.array([[4, 0, 0], [3, 1, 1], [1, 2, 0], [7, 0, 1]], + dtype=np.int32), + np.array([9, 10, 11, 12], dtype=np.float32), + ), + ( + "3D_slice", + np.ones([8, 3, 2], dtype=np.float32), + np.array([[4, 0], [3, 1], [1, 2], [7, 0]], dtype=np.int32), + np.array([[9, 90], [10, 100], [11, 110], [12, 120]], + dtype=np.float32), + ), + ( + "3D_block", + np.ones([8, 3, 2], dtype=np.float32), + np.array([[4], [3], [1], [7]], dtype=np.int32), + np.array([ + [[9, 90], [91, 92], [93, 94]], + [[10, 100], [101, 102], [103, 104]], + [[11, 110], [111, 112], [113, 114]], + [[12, 120], [121, 122], [123, 124]], + ], + dtype=np.float32), + ), + ), + named=True, + )) + def test_tensor_scatter_update(self, tensor, indices, updates): + def scatter(x, inds, ups): + return tf.raw_ops.TensorScatterUpdate(tensor=x, indices=inds, updates=ups) + + self._test_convert(scatter, [tensor, indices, updates]) + @chex.variants(with_jit=True, without_jit=True) def test_unpack(self): inputs = np.array([[1, 2], [3, 4], [5, 6]]) @@ -1558,6 +1616,22 @@ def unpack_static(): return [tf.zeros(s) for s in unpack(inputs)] self._test_convert(unpack_static, []) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.parameters("UnsortedSegmentSum", "UnsortedSegmentMax", + "UnsortedSegmentMin", "UnsortedSegmentProd") + def test_unsorted_segment(self, op_name): + def segment_reduce(x, ids): + return getattr(tf.raw_ops, op_name)( + data=x, segment_ids=ids, num_segments=2) + + data = np.array([5, 1, 7, 2, 3, 4], np.float32) + segment_ids = np.array([0, 0, 1, 1, 0, 1], np.int32) + self._test_convert(segment_reduce, [data, segment_ids]) + + data = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]], np.float32) + segment_ids = np.array([0, 1, 0], np.int32) + self._test_convert(segment_reduce, [data, segment_ids]) + @chex.variants(without_jit=True) def test_where(self): inputs = [np.array([True, False])]