
Why doesn't nnx.fori_loop work here? #4436

Open
cgarciae opened this issue Dec 13, 2024 · 2 comments

Comments

@cgarciae (Collaborator)

Discussed in #4433

Originally posted by onnoeberhard December 13, 2024
I want to train two models at the same time. To do this, I use a fori_loop:

import jax
from flax import nnx

model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

def f(i, x):
    return x

nnx.fori_loop(0, 10, f, (model, model2))

The above code throws the following error:

ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from index_mapping field, you might have modified reference structure within the body function, which is not allowed.

If I loop with only one model, for example nnx.fori_loop(0, 10, f, (model, model)), there is no error. What is the problem here?

@BeeGass commented Dec 15, 2024

Hey @cgarciae and @onnoeberhard

I'm going to do a bit more testing, but looking into nnx.fori_loop: it calls ForiLoopBodyFn, which calls extract.from_tree. Given that you are looping over two different models, I believe the error arises because the states of the two models are being merged when performing pure_out = extract.to_tree(out, ctxtag='fori_loop_body'). This means that pure_val is being mapped over a combined/merged state of model and model2. That's the current theory, anyway; I'm going to check whether it's actually true.

edit:
Just to add: I believe the reason nnx.fori_loop(0, 10, f, (model, model)) works is that within (model, model) the two states are identical, so merging them produces the same result as calling nnx.fori_loop(0, 10, f, model) directly. I believe you could also get the same result if you did:

import jax
from flax import nnx

rng = nnx.Rngs(jax.random.PRNGKey(0))
model = nnx.Linear(2, 2, rngs=rng)
model2 = nnx.Linear(2, 2, rngs=rng)

def f(i, x):
    return x

nnx.fori_loop(0, 10, f, (model, model2))

edit2:
Yes this seems to be what is happening. You can take a look with this example:

import jax
from flax import nnx

# Create two models with different seeds
model1 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

# First let's look at what extract.to_tree does with each model individually
state1 = nnx.extract.to_tree(model1)
state2 = nnx.extract.to_tree(model2)

print("Model 1 state structure:")
nnx.display(state1)

print("\nModel 2 state structure:")
nnx.display(state2)

# Now try extract.to_tree on the tuple of both models
try:
    combined_state = nnx.extract.to_tree((model1, model2))
    nnx.display(combined_state)
except ValueError as e:
    print("\nError when combining states:")
    print(e)

When I run this, I notice model1 has kernel values:

[[0.3179389536380768, 1.2253450155258179],
[0.14415453374385834, -1.0611951351165771]]

model2 has kernel values:

[[-0.6824085712432861, -1.2025362253189087],
[1.5878416299819946, 0.880131185054779]]

then the merged version of the two has kernel values identical to model2's:

[[-0.6824085712432861, -1.2025362253189087],
[1.5878416299819946, 0.880131185054779]]

meaning we aren't actually combining these states; we are just overwriting, in this case, model1's state with model2's.

This explains the error, because there are two inputs being mapped onto what is now only a single merged state.

@cgarciae (Collaborator, Author)

Hey @BeeGass, thanks for the context!

The error above arises when comparing the pytree structures of the input and output. It would be good to just print both structures and test for equality.
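As a minimal sketch of that debugging step (using plain jax.tree_util rather than anything Flax-specific; the toy pytrees below just stand in for the body function's input and output), comparing the treedefs directly shows whether and how two pytrees diverge:

```python
import jax

# Two toy pytrees standing in for the body function's input and output.
# The real Flax states are much larger, but the comparison is the same.
tree_in = {"params": (1.0, 2.0)}
tree_out = {"params": (1.0, 2.0, 3.0)}  # one extra leaf

struct_in = jax.tree_util.tree_structure(tree_in)
struct_out = jax.tree_util.tree_structure(tree_out)

# Printing both structures makes the mismatch visible at a glance.
print(struct_in)
print(struct_out)
print(struct_in == struct_out)  # False: the structures differ
```

Treedefs compare by structure, not by leaf values, so two states with the same layout but different weights would still compare equal here.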

(btw: I'm a bit curious about the overriding behaviour you mention)
