pytensor/pytensor/link/jax/dispatch/elemwise.py
Lines 72 to 89 in d3bd1f1
```python
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op, **kwargs):
    def dimshuffle(x):
        res = jnp.transpose(x, op.transposition)
        shape = list(res.shape[: len(op.shuffle)])
        for augm in op.augment:
            shape.insert(augm, 1)
        res = jnp.reshape(res, shape)
        if not op.inplace:
            res = jnp.copy(res)
        return res

    return dimshuffle
```
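To make the shape bookkeeping concrete, here is the same transpose-then-reshape logic as a standalone function. The function name `dimshuffle_reshape` and passing `transposition`, `shuffle` and `augment` as plain arguments are assumptions for illustration; in PyTensor these come from the `DimShuffle` op, and the `inplace` copy is omitted. The example uses `new_order = ("x", 2, 1)` on an input of shape `(1, 3, 2)`.

```python
import jax
import jax.numpy as jnp


def dimshuffle_reshape(x, transposition, shuffle, augment):
    # Hypothetical standalone version of the quoted dimshuffle closure.
    # Kept axes move to the front (in output order), dropped length-1 axes to the back.
    res = jnp.transpose(x, transposition)
    # Keep only the leading len(shuffle) axes; the trailing ones all have length 1.
    shape = list(res.shape[: len(shuffle)])
    # Insert a length-1 axis at each new output position.
    for augm in augment:
        shape.insert(augm, 1)
    # One reshape both removes the dropped axes and adds the new ones.
    return jnp.reshape(res, shape)


# new_order = ("x", 2, 1): drop axis 0, swap axes 2 and 1, prepend a new axis.
x = jnp.arange(6.0).reshape(1, 3, 2)
y = dimshuffle_reshape(x, transposition=(2, 1, 0), shuffle=(2, 1), augment=(0,))
print(y.shape)  # (1, 2, 3)

# The traced program records the removed and added axes only as a generic reshape.
print(jax.make_jaxpr(lambda a: dimshuffle_reshape(a, (2, 1, 0), (2, 1), (0,)))(x))
```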
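The JAX documentation for `lax.reshape` (which `jnp.reshape` uses) suggests that removing and inserting length-1 axes with dedicated operations may be better for further optimizations than a plain reshape: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reshape.html#jax.lax.reshape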
Relevant part:

> For inserting/removing dimensions of size 1, prefer using `lax.squeeze` / `lax.expand_dims`. These preserve information about axis identity that may be useful for advanced transformation rules.
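A minimal sketch of what the suggested change could look like, assuming the kept, dropped and new axes are available as `shuffle`, `drop` and `augment` (the function name `dimshuffle_squeeze_expand` and its standalone signature are hypothetical; in PyTensor the values would come from the `DimShuffle` op, and the `inplace` copy is left out). It keeps the transpose and replaces the single `jnp.reshape` with `lax.squeeze` for the dropped axes and `lax.expand_dims` for the new ones.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dimshuffle_squeeze_expand(x, shuffle, drop, augment):
    """Hypothetical DimShuffle built from lax.squeeze / lax.expand_dims.

    shuffle: input axes to keep, in output order
    drop:    input axes of length 1 to remove
    augment: ascending output positions of new length-1 axes
    """
    # Kept axes first (in output order), dropped axes last, as with op.transposition.
    res = jnp.transpose(x, tuple(shuffle) + tuple(drop))
    # Remove the trailing dropped axes explicitly instead of via reshape.
    n_keep = len(shuffle)
    if drop:
        res = lax.squeeze(res, tuple(range(n_keep, n_keep + len(drop))))
    # Insert the new length-1 axes explicitly instead of via reshape.
    if augment:
        res = lax.expand_dims(res, tuple(augment))
    return res


# new_order = ("x", 2, 1): drop axis 0, swap axes 2 and 1, prepend a new axis.
x = jnp.arange(6.0).reshape(1, 3, 2)
y = dimshuffle_squeeze_expand(x, shuffle=(2, 1), drop=(0,), augment=(0,))
print(y.shape)  # (1, 2, 3)
print(jax.make_jaxpr(lambda a: dimshuffle_squeeze_expand(a, (2, 1), (0,), (0,)))(x))
```

The result has the same values as the reshape-based version; `jax.make_jaxpr` shows that the traced program no longer contains a `reshape`, so the removed and inserted axes are recorded by operations that keep track of which axis is which.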