jnp.linalg.inv requires ndarray or scalar arguments #25694
Replies: 2 comments
-
In general … Alternatively, if you want to invert every array in your tree, you could use …
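A minimal sketch of the "invert every array in your tree" case. It assumes every leaf is a square, invertible matrix, and it uses `jax.tree_util.tree_map` as one way to apply `jnp.linalg.inv` leaf by leaf. The pytree here is a hypothetical dict, and this is not necessarily the exact call the answer above had in mind:

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: each leaf is an independent square matrix.
tree = {
    "a": jnp.array([[2.0, 0.0], [0.0, 4.0]]),
    "b": [jnp.array([[1.0, 2.0], [3.0, 4.0]])],
}

# jnp.linalg.inv only accepts arrays, so apply it to each leaf separately.
inv_tree = jax.tree_util.tree_map(jnp.linalg.inv, tree)

# Sanity check: each leaf multiplied by its inverse gives the identity.
checks = jax.tree_util.tree_map(
    lambda m, m_inv: jnp.allclose(m @ m_inv, jnp.eye(m.shape[0]), atol=1e-5),
    tree,
    inv_tree,
)
```

If a custom class is registered as a pytree node, `tree_map` rebuilds an object of that same class, with each leaf replaced by its inverse.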
-
Note that Jake's answer covers the case in which each array in your pytree is a matrix you'd like to invert. If you are in the other case, in which your whole pytree is morally 'one big matrix' (just chopped into little pieces and represented in block form in a pytree), then you might like Lineax:

```python
import jax
import jax.numpy as jnp
import lineax as lx

# Here is a 4x4 matrix. (`jnp.arange(16.).reshape(4, 4)` on its own is singular,
# so we add the identity to make it invertible.)
a = jnp.arange(16.).reshape(4, 4) + jnp.eye(4)
# As this is a single array, we can invert it:
jnp.linalg.inv(a)

# What if the same data was represented in block form as a PyTree?
# Here it is a list-of-lists-of-arrays, where each array describes a 2x2 submatrix:
a2 = [jnp.split(ai, 2, axis=1) for ai in jnp.split(a, 2, axis=0)]

# In that case, `lineax.PyTreeLinearOperator` might be a convenient way to represent the
# data, and it automatically handles the conversion back to a matrix for you:
struct = [jax.ShapeDtypeStruct((2,), jnp.float32), jax.ShapeDtypeStruct((2,), jnp.float32)]
op = lx.PyTreeLinearOperator(a2, output_structure=struct)

# And now the bit that is really useful:
a3 = op.as_matrix()  # This is an array again.
assert (a == a3).all()
jnp.linalg.inv(a3)
```
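If pulling in Lineax is not an option, a plain-JAX sketch of the same "one big matrix" idea is to assemble the blocks with `jnp.block`, invert the dense matrix, and split the result back into blocks. This assumes the blocks are laid out as a regular list-of-lists grid (as in the example above); it uses the same 4x4 matrix, shifted by the identity because `jnp.arange(16.).reshape(4, 4)` alone is singular:

```python
import jax.numpy as jnp

# 4x4 matrix plus the identity, so that it is invertible.
a = jnp.arange(16.0).reshape(4, 4) + jnp.eye(4)

# Block form: a list of lists of 2x2 submatrices.
blocks = [jnp.split(row, 2, axis=1) for row in jnp.split(a, 2, axis=0)]

# jnp.block assembles a nested list of arrays into one dense matrix.
dense = jnp.block(blocks)
a_inv = jnp.linalg.inv(dense)

# Split the inverse back into the same 2x2 block layout.
inv_blocks = [jnp.split(row, 2, axis=1) for row in jnp.split(a_inv, 2, axis=0)]
```

The design choice here is simply to convert to an array at the boundary, since `jnp.linalg.inv` only works on arrays; this only suits a fixed, grid-shaped block layout, whereas `PyTreeLinearOperator` handles general pytree structures.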
-
Hello,
jnp.linalg.inv requires ndarray or scalar arguments
Okay, but not pytrees?
TypeError: jnp.linalg.inv requires ndarray or scalar arguments, got <class '__main__.Protein_Pytree'> at position 0.
Or is there an issue with my pytree? Surely it takes a pytree? If not, I guess the fix is to convert to a jnp.array inside the function?
Thanks, Tom