Why doesn't nnx.fori_loop work here? #4436
Hey @cgarciae and @onnoeberhard, I'm going to do a bit more testing, but looking into
edit:
import jax
from flax import nnx
rng = nnx.Rngs(jax.random.PRNGKey(0))
model = nnx.Linear(2, 2, rngs=rng)
model2 = nnx.Linear(2, 2, rngs=rng)
def f(i, x):
    return x

nnx.fori_loop(0, 10, f, (model, model2))

edit2:
import jax
from flax import nnx
# Create two models with different seeds
model1 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))
# First let's look at what extract.to_tree does with each model individually
state1 = nnx.extract.to_tree(model1)
state2 = nnx.extract.to_tree(model2)
print("Model 1 state structure:")
nnx.display(state1)
print("\nModel 2 state structure:")
nnx.display(state2)
# Now try extract.to_tree on the tuple of both models
try:
    combined_state = nnx.extract.to_tree((model1, model2))
    nnx.display(combined_state)
except ValueError as e:
    print("\nError when combining states:")
    print(e)

When I do this, I notice model1 has kernel values:
model2 has kernel values:
then the merged version of the two has kernel values (same as model1):
Meaning we aren't actually combining these states; we are just overwriting one with the other (the merged result keeps model1's values). This explains the error, because there are two inputs trying to be mapped to what is now only one available function.
Hey @BeeGass, thanks for the context! The error above arises when comparing the pytree structures of the input and output. It would be good to just print both structures and test for equality. (btw: I'm a bit curious about the overriding behaviour you mention)
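A minimal sketch of the "print both structures and test for equality" check, using plain JAX pytrees as stand-ins for the state of two models. It does not reproduce the NNX reference-structure comparison that `nnx.fori_loop` performs; it only shows the pytree half of the check. The helper `same_structure` and the dict layout are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def same_structure(a, b):
    """Return True when two pytrees have identical tree structure."""
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)


def body(i, carry):
    # Identity body: the output carry has the same structure as the input.
    return carry


# Stand-in for the state of two models (plain dicts of arrays, not NNX objects).
carry_in = (
    {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
    {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
)
carry_out = body(0, carry_in)

# Print both structures, then compare them.
print(jax.tree_util.tree_structure(carry_in))
print(jax.tree_util.tree_structure(carry_out))
structures_match = same_structure(carry_in, carry_out)
print("structures match:", structures_match)

# A carry that drops one model no longer matches the input structure.
mismatch = same_structure(carry_in, carry_in[:1])
print("structures match after dropping a model:", mismatch)
```

If the two printed structures differ, the difference points at which part of the carry the body function changed.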
Discussed in #4433
Originally posted by onnoeberhard December 13, 2024
I want to train two models at the same time. To do this, I use a fori_loop:
The above code throws the following error:
ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from index_mapping field, you might have modified reference structure within the body function, which is not allowed.
If I loop with only one model, for example nnx.fori_loop(0, 10, f, (model, model)), there is no error. What is the problem here?