Flax NNX checkpointing is not that straightforward #1105
Comments
I am having a similar issue with Orbax-checkpoint in … Hence, it is not about …
I encounter the same problem with a Dropout layer in Flax NNX.
I have the same issue. Based on your code, I came up with a slightly different approach that includes an Adam optimizer.
I think Orbax should support … Meanwhile, you can split out the RNG keys and convert them to raw key data before saving:

    import jax
    from flax import nnx

    class Model(nnx.Module):
        def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
            self.linear = nnx.Linear(din, dmid, rngs=rngs)
            self.bn = nnx.BatchNorm(dmid, rngs=rngs)
            self.dropout = nnx.Dropout(0.2, rngs=rngs)  # holds an RNG key
            self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

        def __call__(self, x):
            x = nnx.relu(self.dropout(self.bn(self.linear(x))))
            return self.linear_out(x)

    # saving
    model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
    # split the typed RNG keys from the rest of the state
    keys, state = nnx.state(model, nnx.RngKey, ...)
    # typed keys -> raw uint32 key data, which can be serialized
    keys = jax.tree.map(jax.random.key_data, keys)
    ...  # save keys and state checkpoints

    # loading
    ...  # load keys and state checkpoints
    # raw key data -> typed keys, then write both parts back into the model
    keys = jax.tree.map(jax.random.wrap_key_data, keys)
    nnx.update(model, keys, state)
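The core of the workaround is that typed PRNG keys (created with jax.random.key) cannot be written out as ordinary arrays, while their raw key data can. The following is a minimal JAX-only sketch of that round trip. It uses NumPy's .npz format in place of Orbax purely to keep the example self-contained, and it assumes the same default PRNG implementation at save and load time; keys_to_data and data_to_keys are hypothetical helper names.

    import os
    import tempfile

    import jax
    import numpy as np


    def keys_to_data(keys):
        """Replace every typed PRNG key in a pytree with its raw uint32 key data."""
        return jax.tree.map(lambda k: np.asarray(jax.random.key_data(k)), keys)


    def data_to_keys(data):
        """Wrap raw key data back into typed PRNG keys.

        Assumes the default PRNG implementation matches the one used when saving.
        """
        return jax.tree.map(jax.random.wrap_key_data, data)


    # Hypothetical stand-in for the RNG-key part of a model's state.
    keys = {"dropout": jax.random.key(0), "default": jax.random.key(1)}

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "rng_keys.npz")
        # Plain uint32 arrays serialize fine; typed key arrays would not.
        np.savez(path, **keys_to_data(keys))
        with np.load(path) as f:
            loaded = {name: f[name] for name in f.files}

    restored = data_to_keys(loaded)

    # The restored keys drive the same random stream as the originals.
    print(np.allclose(jax.random.uniform(keys["dropout"], (3,)),
                      jax.random.uniform(restored["dropout"], (3,))))

With Orbax, the same idea applies: convert the key tree before handing it to the checkpointer and wrap it again after restoring.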
The code is now a bit simpler, but it needs a new function to merge the RNG keys back into the state.
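In the workaround, nnx.update(model, keys, state) applies both parts to the model directly. If you instead want one combined tree, for example to pass a single object to a checkpointer, a recursive merge over nested dicts is one possible shape for such a function. This is a plain-Python sketch: merge_nested is a hypothetical helper, and the toy dicts only imitate the nesting of an NNX state; they are not produced by Flax.

    from typing import Any, Mapping


    def merge_nested(a: Mapping[str, Any], b: Mapping[str, Any]) -> dict:
        """Recursively merge two nested dicts whose leaves do not overlap.

        Sub-dicts present in both inputs are merged; a leaf present in both is
        an error, since the RNG-key tree and the rest of the state are disjoint.
        """
        out = dict(a)  # shallow copy, so `a` is not mutated
        for k, v in b.items():
            if k in out and isinstance(out[k], Mapping) and isinstance(v, Mapping):
                out[k] = merge_nested(out[k], v)
            elif k in out:
                raise ValueError(f"conflicting leaf at key {k!r}")
            else:
                out[k] = v
        return out


    # Hypothetical pure-dict views of the two parts of a model's state.
    keys = {"dropout": {"rngs": {"default": {"key": [0, 42]}}}}
    rest = {"dropout": {"rngs": {"default": {"count": 3}}},
            "linear": {"kernel": [[1.0, 2.0]]}}

    merged = merge_nested(keys, rest)
    print(merged)

Raising on overlapping leaves, rather than silently letting one side win, makes a mismatch between the two trees visible immediately.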
The Orbax team is looking into it and plans to add support for RNG keys in early January. Thanks!
Thanks for working on this. I'm looking forward to being able to easily save models with Dropout layers.
Waiting for better support as well :) Currently, when I try to save/restore a simple module like:

    import orbax.checkpoint as ocp
    from flax import nnx

    checkpoint = ocp.PyTreeCheckpointer()
    # renderer is the nnx.Module being saved
    checkpoint.save("models/renderer-nnx", nnx.state(renderer))

I get: …
I think this is related to this issue? Happy to see an update soon, thanks!
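The error text is not shown here, so its cause can't be confirmed. Assuming it is the typed-RNG-key problem discussed in this thread, one way to check a state before saving is to list every leaf whose dtype is a PRNG key. find_prng_key_paths is a hypothetical helper, and the toy dict stands in for something like nnx.state(renderer).

    import jax
    import jax.numpy as jnp


    def find_prng_key_paths(tree):
        """Return the path (as a string) of every typed PRNG-key leaf in a pytree."""
        leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
        return [
            jax.tree_util.keystr(path)
            for path, leaf in leaves_with_paths
            # Leaves without a dtype (Python ints, strings, ...) are skipped.
            if hasattr(leaf, "dtype")
            and jax.dtypes.issubdtype(leaf.dtype, jax.dtypes.prng_key)
        ]


    # Toy stand-in for a model state: one typed key plus an ordinary weight.
    toy_state = {
        "dropout": {"key": jax.random.key(0)},
        "linear": {"kernel": jnp.ones((2, 3))},
    }
    print(find_prng_key_paths(toy_state))

Legacy keys from jax.random.PRNGKey are plain uint32 arrays, so they are not flagged and already serialize like any other array.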
I am facing a related issue. Following the NNX "scan over layers" example, I have a simple MLP module that contains other modules. Unfortunately, after attempting @cgarciae's workaround for checkpointing the model, …
I have been trying to use Orbax to checkpoint Flax NNX models, and getting checkpointing to work for models with Dropout layers, which hold JAX RNG keys, is not very straightforward. After various attempts, this was the only way I could get it to work.
The comments describe the issues I ran into.
It would be good to improve ease of use for models that have Dropout layers.