I have a custom datatype consisting of a list of arrays. I want to use this datatype inside a custom `nnx.Module`:

```python
from typing import List

import jax.numpy as jnp
from flax import nnx


class MyDType(object):
    def __init__(self, x: List[jnp.ndarray]):
        # x = preprocess(x)
        # Each core is wrapped in an nnx.Param so that it is trainable.
        self.cores: List[nnx.Param] = [nnx.Param(y) for y in x]

    # some custom functions like __add__, __mul__, copy etc.


class MyModule(nnx.Module):
    def __init__(self, a: MyDType):
        self.order = len(a.cores)
        self.a = a

    def __call__(self, x):
        # some functionality, e.g.:
        assert x.shape[0] == self.order
        result = 0
        for i in range(self.order):
            result += jnp.sum(self.a.cores[i] ** x[i])
        return result


a = MyDType([jnp.ones(1), 2 * jnp.ones(2), 3 * jnp.ones(3)])
module = MyModule(a)
vars = nnx.state(module)
```

Here, … A workaround that I found is having … How do I get the cores of …?
Hi @NKlug, you can inherit from `nnx.Object`, which is the base class for `Module`, `Optimizer`, and other NNX types.