Why doesn't nnx.fori_loop work here? #4433
Unanswered
onnoeberhard asked this question in Q&A
Replies: 2 comments
-
It seems that a possible workaround is:

    import jax
    from flax import nnx

    model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
    model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

    container = nnx.Module()
    container.model = model
    container.model2 = model2

    def f(i, x):
        return x

    nnx.fori_loop(0, 10, f, container)

Is there a reason why this works and the above does not? And will this "solution" yield the expected results? I am very confused about this error.
-
Thanks for reporting this. Looks like a bug. Converting this into an issue.
-
I want to train two models at the same time. To do this, I use a fori_loop:

The above code throws the following error:
ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from index_mapping field, you might have modified reference structure within the body function, which is not allowed.
If I loop with only one model, for example nnx.fori_loop(0, 10, f, (model, model)), there is no error. What is the problem here?
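Until the nnx behaviour is resolved, one possible way around it is to keep the two models as plain parameter pytrees and loop with jax.lax.fori_loop directly. The following is only a minimal sketch under stated assumptions, not the fix for the nnx.fori_loop error. The helper init_linear, the squared-error loss, the toy data X and Y, and the learning rate LR are all hypothetical and only there for illustration. What it shows is the invariant the error message refers to: the loop body must return a carry with the same pytree structure it received.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_features, out_features):
    # Hypothetical helper: a linear layer's parameters as a plain dict pytree.
    return {
        "kernel": jax.random.normal(key, (in_features, out_features)),
        "bias": jnp.zeros((out_features,)),
    }


def loss(params, x, y):
    # Illustrative mean-squared-error loss for a linear model.
    pred = x @ params["kernel"] + params["bias"]
    return jnp.mean((pred - y) ** 2)


# Toy data and step size, assumed for the example only.
X = jnp.ones((4, 2))
Y = jnp.zeros((4, 2))
LR = 0.1


def body(i, carry):
    # The carry is a tuple of two parameter dicts; the body returns a tuple
    # with the same structure, shapes, and dtypes.
    params_a, params_b = carry
    grads_a = jax.grad(loss)(params_a, X, Y)
    grads_b = jax.grad(loss)(params_b, X, Y)
    sgd = lambda p, g: p - LR * g
    new_a = jax.tree_util.tree_map(sgd, params_a, grads_a)
    new_b = jax.tree_util.tree_map(sgd, params_b, grads_b)
    return (new_a, new_b)


init = (
    init_linear(jax.random.PRNGKey(0), 2, 2),
    init_linear(jax.random.PRNGKey(1), 2, 2),
)
final = jax.lax.fori_loop(0, 10, body, init)
```

Both models are updated in the same loop, and jax.tree_util.tree_structure(final) matches jax.tree_util.tree_structure(init). The trade-off is that this gives up the nnx module objects inside the loop, so it only fits cases where the update can be written functionally.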