Skip to content

Commit

Permalink
Support UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd, …
Browse files Browse the repository at this point in the history
…UnsortedSegmentSum and TensorScatterUpdate.

PiperOrigin-RevId: 464543303
  • Loading branch information
shaobohou authored and TF2JAXDev committed Aug 1, 2022
1 parent 50508ab commit c2efe70
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ def wrapped(proto):
"Tan": _get_jax_op(jnp.tan, {"T"}),
"Tanh": _get_jax_op(jnp.tanh, {"T"}),
"Tile": _get_jax_op(anp.tile, {"T", "Tmultiples"}),
"UnsortedSegmentMax": _get_jax_op(
functools.partial(jax.ops.segment_max, indices_are_sorted=False),
{"T", "Tindices", "Tnumsegments"}),
"UnsortedSegmentMin": _get_jax_op(
functools.partial(jax.ops.segment_min, indices_are_sorted=False),
{"T", "Tindices", "Tnumsegments"}),
"UnsortedSegmentProd": _get_jax_op(
functools.partial(jax.ops.segment_prod, indices_are_sorted=False),
{"T", "Tindices", "Tnumsegments"}),
"UnsortedSegmentSum": _get_jax_op(
functools.partial(jax.ops.segment_sum, indices_are_sorted=False),
{"T", "Tindices", "Tnumsegments"}),
"Where": _get_jax_op(jnp.argwhere, {"T"}),
"ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}),
# The assignment logic is handled in _OpNode and convert().
Expand Down Expand Up @@ -1692,6 +1704,25 @@ def _func(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
return _func


@register_operation("TensorScatterUpdate")
def _tensor_scatter_update(proto):
"""Parse an TensorScatterUpdate Op."""
_check_attrs(proto, {"T", "Tindices"})

def _func(
operand: jnp.ndarray,
indices: jnp.ndarray,
updates: jnp.ndarray,
) -> jnp.ndarray:
dimension_numbers = jax.lax.ScatterDimensionNumbers(
range(1, updates.ndim), range(indices.shape[-1]),
range(indices.shape[-1]))
return jax.lax.scatter(
operand, indices, updates, dimension_numbers=dimension_numbers)

return _func


@register_operation("TopKV2")
def _top_k(proto):
_check_attrs(proto, {"T", "sorted"})
Expand Down
74 changes: 74 additions & 0 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,64 @@ def switch_fn(x):
)
self._test_convert(switch_fn, [np.array(1, dtype=np.int32)])

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
(
(
"1D",
np.ones([8], dtype=np.float32),
np.array([[4], [3], [1], [7]], dtype=np.int32),
np.array([9, 10, 11, 12], dtype=np.float32),
),
(
"2D_scalar",
np.ones([8, 2], dtype=np.float32),
np.array([[4, 0], [3, 1], [1, 0], [7, 1]], dtype=np.int32),
np.array([9, 10, 11, 12], dtype=np.float32),
),
(
"2D_slice",
np.ones([8, 2], dtype=np.float32),
np.array([[4], [3], [1], [7]], dtype=np.int32),
np.array([[9, 90], [10, 100], [11, 110], [12, 120]],
dtype=np.float32),
),
(
"3D_scalar",
np.ones([8, 3, 2], dtype=np.float32),
np.array([[4, 0, 0], [3, 1, 1], [1, 2, 0], [7, 0, 1]],
dtype=np.int32),
np.array([9, 10, 11, 12], dtype=np.float32),
),
(
"3D_slice",
np.ones([8, 3, 2], dtype=np.float32),
np.array([[4, 0], [3, 1], [1, 2], [7, 0]], dtype=np.int32),
np.array([[9, 90], [10, 100], [11, 110], [12, 120]],
dtype=np.float32),
),
(
"3D_block",
np.ones([8, 3, 2], dtype=np.float32),
np.array([[4], [3], [1], [7]], dtype=np.int32),
np.array([
[[9, 90], [91, 92], [93, 94]],
[[10, 100], [101, 102], [103, 104]],
[[11, 110], [111, 112], [113, 114]],
[[12, 120], [121, 122], [123, 124]],
],
dtype=np.float32),
),
),
named=True,
))
def test_tensor_scatter_update(self, tensor, indices, updates):
def scatter(x, inds, ups):
return tf.raw_ops.TensorScatterUpdate(tensor=x, indices=inds, updates=ups)

self._test_convert(scatter, [tensor, indices, updates])

@chex.variants(with_jit=True, without_jit=True)
def test_unpack(self):
inputs = np.array([[1, 2], [3, 4], [5, 6]])
Expand All @@ -1558,6 +1616,22 @@ def unpack_static():
return [tf.zeros(s) for s in unpack(inputs)]
self._test_convert(unpack_static, [])

@chex.variants(with_jit=True, without_jit=True)
@parameterized.parameters("UnsortedSegmentSum", "UnsortedSegmentMax",
"UnsortedSegmentMin", "UnsortedSegmentProd")
def test_unsorted_segment(self, op_name):
def segment_reduce(x, ids):
return getattr(tf.raw_ops, op_name)(
data=x, segment_ids=ids, num_segments=2)

data = np.array([5, 1, 7, 2, 3, 4], np.float32)
segment_ids = np.array([0, 0, 1, 1, 0, 1], np.int32)
self._test_convert(segment_reduce, [data, segment_ids])

data = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]], np.float32)
segment_ids = np.array([0, 1, 0], np.int32)
self._test_convert(segment_reduce, [data, segment_ids])

@chex.variants(without_jit=True)
def test_where(self):
inputs = [np.array([True, False])]
Expand Down

0 comments on commit c2efe70

Please sign in to comment.