Flax NNX and Orbax Checkpointing require hacking to work together #4383
Thank you! I'll contact the Orbax team to see if they can fix this on their end.
Hey! Here's a quick and dirty workaround. The idea is to use `nnx.split` with NNX filters to split the `nnx.RngState` variables out of the state, and then save only the remaining state. This means the RNG state will not be restored, which might be suboptimal in some scenarios, but it should work for most use cases. Hope it helps!
Another workaround for Dropout layers (and possibly custom layers, if they follow the same pattern) is to initialize them without the `rngs` argument and pass the RNGs at call time instead:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

# Initialize Dropout without an `rngs` argument, so it stores no RNG state.
model = nnx.Dropout(0.5)
# Pass the RNGs at call time instead.
output = model(jnp.ones(()), rngs=nnx.Rngs(0))

# Saving the state now works.
ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints/")
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", nnx.split(model)[1])
# StandardCheckpointer saves asynchronously; block until the write completes.
checkpointer.wait_until_finished()
```

By contrast, if the RNG is supplied at initialization, the `save` call throws an error.
This is only a workaround, though: the RNG state is still not serialized, and passing RNGs at every call makes for a more verbose call signature.
I'm building a system using `flax.nnx` and `orbax.checkpoint`. However, saving and restoring models is overly complicated because `flax.nnx` uses the new-style `jax.random.key()` rather than `jax.random.PRNGKey()`.

I have had to create a workaround where all layers with `rng` and `key` in their path are converted from `dtype=key<fry>` to a format suitable for saving, and then converted back upon restoration. I am attaching a link to a notebook explaining what I've done, but I would be keen to hear whether there are simpler workarounds or, preferably, a way to simply save and restore models.
https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing
Note: I am also opening an issue on Orbax to see if there is a fix on their side (#1337).