diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 7d088c5..fa8b0f6 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -24,6 +24,7 @@ from jax.experimental import checkify from jax.lib import xla_client import jax.numpy as jnp +import jax.scipy as jsp import numpy as np import tensorflow as tf from tf2jax._src import config @@ -201,6 +202,7 @@ def wrapped(proto): functools.partial(jax.ops.segment_sum, indices_are_sorted=False), {"T", "Tindices", "Tnumsegments"}), "Where": _get_jax_op(jnp.argwhere, {"T"}), + "Xlogy": _get_jax_op(jsp.special.xlogy, {"T"}), "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), # The assignment logic is handled in _OpNode and convert(). "AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}), diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index ee7fa32..0f5d77f 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -2146,6 +2146,15 @@ def where(cond): return tf.raw_ops.Where(condition=cond) self._test_convert(where, inputs) + @chex.variants(with_jit=True, without_jit=True) + def test_xlogy(self): + x = np.array([0.0, 0.0, 5.0, 5.0]) + y = np.array([0.0, 5.0, 0.0, 5.0]) + + def xlogy(x, y): + return tf.raw_ops.Xlogy(x=x, y=y) + self._test_convert(xlogy, [x, y]) + @chex.variants(with_jit=True, without_jit=True) def test_while_loop(self): inputs = np.array(np.reshape(range(24), (4, 3, 2)), dtype=np.float32)