Jax equivalent of tf.scatter_nd #3658
Hi, is there a function or short code snippet equivalent to TensorFlow's scatter_nd? I know of jax.lax.scatter, but it has an extra [...] Thanks!
Replies: 1 comment 1 reply
`lax.scatter` exists, but is indeed rather complex. For cases where you would use `tf.scatter_nd`, we recommend using indexed update functions or the equivalent syntactic sugar using the `.at` property.

The API for indexed updates is very similar but with a different axis order matching NumPy's advanced indexing. To reproduce `scatter_nd` in JAX you could use:

```python
import jax.numpy as jnp

def scatter_nd(indices, updates, shape):
    zeros = jnp.zeros(shape, updates.dtype)
    key = tuple(jnp.moveaxis(indices, -1, 0))
    return zeros.at[key].add(updates)
```
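As a quick sanity check, here is a minimal sketch that runs the function above on a few small inputs. It is not part of the original reply: the input values and variable names (`indices_1d`, `out_2d`, `dup`, and so on) are illustrative assumptions, and the function is repeated only so the block runs on its own. The last case assumes you want duplicate indices to accumulate, which is what `.add` does.

```python
import jax.numpy as jnp


def scatter_nd(indices, updates, shape):
    # Same function as in the reply, repeated so this block is self-contained.
    zeros = jnp.zeros(shape, updates.dtype)
    # Move the trailing index axis to the front and unpack it into a tuple:
    # one index array per indexed dimension, as NumPy advanced indexing expects.
    key = tuple(jnp.moveaxis(indices, -1, 0))
    return zeros.at[key].add(updates)


# 1-D target: each row of `indices` holds a single coordinate.
indices_1d = jnp.array([[4], [3], [1], [7]])
updates_1d = jnp.array([9, 10, 11, 12])
out_1d = scatter_nd(indices_1d, updates_1d, (8,))
# values: 0, 11, 0, 10, 9, 0, 0, 12

# 2-D target: each row of `indices` is a (row, col) pair.
indices_2d = jnp.array([[0, 1], [2, 3]])
updates_2d = jnp.array([5.0, 6.0])
out_2d = scatter_nd(indices_2d, updates_2d, (3, 4))
# out_2d[0, 1] is 5.0 and out_2d[2, 3] is 6.0; every other entry is 0.0

# Duplicate indices: both updates land on position 1 and are summed.
dup = scatter_nd(jnp.array([[1], [1]]), jnp.array([1, 2]), (3,))
# values: 0, 3, 0
```

If you would rather have a later update overwrite an earlier one at the same index, `.at[key].set(updates)` is the corresponding call, though which of several duplicate writes ends up in the result should not be relied on.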