From cbf41c66c43f275fcf9f18836aec7d118ce0d0db Mon Sep 17 00:00:00 2001 From: s22chan Date: Thu, 3 Oct 2024 00:11:48 -0400 Subject: [PATCH 1/2] feat: add xlogy op --- tf2jax/_src/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 7d088c5..fa8b0f6 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -24,6 +24,7 @@ from jax.experimental import checkify from jax.lib import xla_client import jax.numpy as jnp +import jax.scipy as jsp import numpy as np import tensorflow as tf from tf2jax._src import config @@ -201,6 +202,7 @@ def wrapped(proto): functools.partial(jax.ops.segment_sum, indices_are_sorted=False), {"T", "Tindices", "Tnumsegments"}), "Where": _get_jax_op(jnp.argwhere, {"T"}), + "Xlogy": _get_jax_op(jsp.special.xlogy, {"T"}), "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), # The assignment logic is handled in _OpNode and convert(). "AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}), From ca9ee016284f7348fad5d674bab40828af6b54a3 Mon Sep 17 00:00:00 2001 From: Steve Chan Date: Fri, 4 Oct 2024 19:40:25 -0400 Subject: [PATCH 2/2] add tests --- tf2jax/_src/ops_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index ee7fa32..0f5d77f 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -2146,6 +2146,15 @@ def where(cond): return tf.raw_ops.Where(condition=cond) self._test_convert(where, inputs) + @chex.variants(with_jit=True, without_jit=True) + def test_xlogy(self): + x = np.array([0.0, 0.0, 5.0, 5.0]) + y = np.array([0.0, 5.0, 0.0, 5.0]) + + def xlogy(x, y): + return tf.raw_ops.Xlogy(x=x, y=y) + self._test_convert(xlogy, [x, y]) + @chex.variants(with_jit=True, without_jit=True) def test_while_loop(self): inputs = np.array(np.reshape(range(24), (4, 3, 2)), dtype=np.float32)