Issues checkpointing optimizer state using Optax, nnx.Optimizer, and Orbax #4423
@SandSnip3r: It seems that also adding a type handler for …
Hey @SandSnip3r, try serializing the state instead:

checkpoint = nnx.state(optimizer)
# optional but works better
checkpoint = checkpoint.to_pure_dict()
checkpointer.save(checkpointDir, checkpoint, force=True)

Then to load it you can do:

checkpoint = checkpointer.restore(...)
nnx.update(optimizer, checkpoint)
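For reference, a self-contained sketch of that save path; the toy nnx.Linear model, the adamw learning rate, and the /tmp checkpoint path are placeholders, not taken from the thread:

# Hedged sketch of the save flow described above.
import optax
import orbax.checkpoint as ocp
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

checkpointer = ocp.PyTreeCheckpointer()
checkpoint = nnx.state(optimizer)       # full optimizer State, including the optax opt_state
checkpoint = checkpoint.to_pure_dict()  # optional but works better
checkpointer.save('/tmp/optimizer_ckpt', checkpoint, force=True)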
Thanks @cgarciae! I'll give this a try when I get time. Why does to_pure_dict() work better? What about the arguments to …?
It removes the VariableState wrappers, so you're left with a plain nested dict of arrays that Orbax can serialize without any custom type handlers.
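A tiny sketch of the difference, assuming a toy nnx.Linear module (not from the thread):

# nnx.state(...) returns a State whose leaves still carry Variable metadata;
# to_pure_dict() drops that wrapping, leaving only nested dicts and arrays.
import jax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state = nnx.state(model)
pure = state.to_pure_dict()
print(jax.tree_util.tree_structure(pure))  # plain dict nodes, array leaves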
I think if you have a pure dictionary you don't need the target structure. Else just recreate the checkpoint structure using nnx.eval_shape:

abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()
...
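To make that concrete, here is a hedged restore sketch; create_optimizer() is the hypothetical factory from the snippet above, and the checkpoint path is made up:

# Restore sketch under the assumptions stated above.
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.PyTreeCheckpointer()

# A pure-dict checkpoint can usually be restored without a target structure.
restored = checkpointer.restore('/tmp/optimizer_ckpt')

# If a target is wanted, rebuild the structure cheaply with eval_shape
# (no real parameter allocation happens here).
abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()
restored = checkpointer.restore('/tmp/optimizer_ckpt', item=target)

# Push the restored values back into a real optimizer.
optimizer = create_optimizer()
nnx.update(optimizer, restored)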
@cgarciae Honestly great explanation. Based on this conversation I was able to make this. It's not perfect, but it works pretty well.

# Imports needed by the class below.
import shutil
from pathlib import Path
from typing import Any, Optional

import jax
import orbax.checkpoint
from flax import nnx
from flax.training import orbax_utils
from jax.sharding import Mesh

# TrainingConfig and create_tx are defined elsewhere in my project.

class CheckpointManager:
def __init__(self, checkpoint_dir: str, keep_n: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)  # Path so glob()/absolute() below work
self.checkpointer = orbax.checkpoint.PyTreeCheckpointer()
self.keep_n = keep_n
def _cleanup_old_checkpoints(self, step: int):
# Get all checkpoint steps except 'best'
checkpoints = []
for path in self.checkpoint_dir.glob("model-*"):
if "best" not in str(path):
try:
ckpt_step = int(path.name.split("-")[-1])
checkpoints.append(ckpt_step)
except ValueError:
continue
# Sort and remove old checkpoints while keeping newest n
checkpoints.sort(reverse=True)
for old_step in checkpoints[self.keep_n :]:
for prefix in ["model", "optimizer", "metrics"]:
path = self.checkpoint_dir / f"{prefix}-{old_step}"
if path.exists():
shutil.rmtree(path)
def save_model(self, model: Any, step: int, is_best: bool = False):
state = nnx.state(model).to_pure_dict()
self._save(obj=state, filename=f"model-{step}")
if is_best:
self._save(obj=state, filename="model-best")
self._cleanup_old_checkpoints(step)
def save_optimizer(self, optimizer: nnx.Optimizer, step: int):
state = nnx.state(optimizer).to_pure_dict()
self._save(obj=state, filename=f"optimizer-{step}")
    def save_training_state(self, step: int, total_tokens: int, metrics: Optional[dict] = None):
metrics = {"step": step, "total_tokens": total_tokens, "metrics": metrics}
self._save(obj=metrics, filename=f"metrics-{step}")
def restore_model(self, model_cls: Any, model_config: dict, mesh: Mesh, step: int) -> Any:
abs_model = nnx.eval_shape(lambda: model_cls(**model_config, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model).to_pure_dict()
target = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_state,
nnx.get_named_sharding(nnx.state(abs_model), mesh),
)
state = self._restore(filename=f"model-{step}", target=target)
model = nnx.merge(abs_model, state)
return model
def restore_optimizer(
self, model: Any, train_config: TrainingConfig, step: int
) -> nnx.Optimizer:
tx = create_tx(train_config) # function where I construct my optax optimizer
abs_optimizer = nnx.eval_shape(lambda: nnx.Optimizer(model, tx))
abs_state = nnx.state(abs_optimizer).to_pure_dict()
state = self._restore(filename=f"optimizer-{step}", target=abs_state)
optimizer = nnx.Optimizer(model, tx)
nnx.update(optimizer, state)
return optimizer
def restore_training_state(self, step: int) -> dict:
"""Restore training metrics and token count."""
return self._restore(filename=f"metrics-{step}")
def _save(self, obj: Any, filename: str):
path = self.checkpoint_dir.absolute() / filename
save_args = orbax_utils.save_args_from_target(obj)
self.checkpointer.save(str(Path(path)), obj, save_args=save_args)
def _restore(self, filename: str, target: Any = None) -> Any:
path = self.checkpoint_dir.absolute() / filename
        return self.checkpointer.restore(str(Path(path)), target=target)
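For what it's worth, a hypothetical smoke test of the save side of this class; the toy model, hyper-parameters, and paths are placeholders, not from the original code:

# Minimal save-side usage sketch for the CheckpointManager above.
import optax
from flax import nnx

model = nnx.Linear(8, 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

ckpt = CheckpointManager("/tmp/run-001/checkpoints", keep_n=3)
ckpt.save_model(model, step=100)
ckpt.save_optimizer(optimizer, step=100)
ckpt.save_training_state(step=100, total_tokens=12_345, metrics={"loss": 2.31})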
As described, this also works without using …
Then restoring can be done with …
Thanks a lot, @cgarciae. What do you think about adding a small blurb about saving & restoring optimizer state to the Flax documentation section on checkpointing? https://flax.readthedocs.io/en/latest/guides/checkpointing.html#save-checkpoints I think this would be nice, especially since flax.nnx offers an API for optimizers.
I am running into an error while trying to checkpoint an Optax optimizer state, wrapped as an nnx.Optimizer, using the Orbax checkpointing library.

I am using packages:

Minimal repro:

I see y'all have documentation about using Orbax for model checkpointing, but I don't see any official info about optimizer-state checkpointing. I see an older GitHub issue where the Optax folks recommended cloudpickle: google-deepmind/optax#180

I tried adding custom serialization/deserialization for nnx.training.optimizer.OptArray via https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html#custom-serialization-deserialization, but then just ran into the next error:

Maybe I'd need to also add a type handler for OptVariable, as well as maybe even Variable? That seems quite annoying. Am I missing something?

Thanks for your time!
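A guess at what a minimal repro along those lines could look like; the toy model, optimizer settings, and path are all assumptions, not the original snippet:

# Hypothetical reconstruction of the failing setup: saving the raw optimizer
# State (with its wrapped optax state) is where the missing-type-handler
# errors for OptArray / OptVariable reportedly appear.
import optax
import orbax.checkpoint as ocp
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/optimizer_ckpt', nnx.state(optimizer))  # fails as described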