Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add xlogy op #219

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax.experimental import checkify
from jax.lib import xla_client
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import tensorflow as tf
from tf2jax._src import config
Expand Down Expand Up @@ -201,6 +202,7 @@ def wrapped(proto):
functools.partial(jax.ops.segment_sum, indices_are_sorted=False),
{"T", "Tindices", "Tnumsegments"}),
"Where": _get_jax_op(jnp.argwhere, {"T"}),
"Xlogy": _get_jax_op(jsp.special.xlogy, {"T"}),
"ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}),
# The assignment logic is handled in _OpNode and convert().
"AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}),
Expand Down
9 changes: 9 additions & 0 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,15 @@ def where(cond):
return tf.raw_ops.Where(condition=cond)
self._test_convert(where, inputs)

@chex.variants(with_jit=True, without_jit=True)
def test_xlogy(self):
x = np.array([0.0, 0.0, 5.0, 5.0])
y = np.array([0.0, 5.0, 0.0, 5.0])

def xlogy(x, y):
return tf.raw_ops.Xlogy(x=x, y=y)
self._test_convert(xlogy, [x, y])

@chex.variants(with_jit=True, without_jit=True)
def test_while_loop(self):
inputs = np.array(np.reshape(range(24), (4, 3, 2)), dtype=np.float32)
Expand Down
Loading