
Issues checkpointing optimizer state using Optax, nnx.Optimizer, and Orbax #4423

Open
SandSnip3r opened this issue Dec 9, 2024 · 7 comments


@SandSnip3r

I am running into an error while trying to checkpoint an Optax optimizer state, wrapped as an nnx.Optimizer, using the Orbax checkpointing library.

ValueError: Unsupported type: <class 'flax.nnx.training.optimizer.OptArray'> for key: ('0', 'count'). Supported types are (<class 'int'>, <class 'float'>, <class 'numpy.ndarray'>, <class 'jax.Array'>).

I am using packages:

  • flax 0.10.2
  • jax 0.4.36
  • jax-cuda12-pjrt 0.4.36
  • jax-cuda12-plugin 0.4.36
  • jaxlib 0.4.36
  • optax 0.2.4
  • orbax-checkpoint 0.10.2

Minimal repro:

from flax import nnx
import numpy as np
import orbax.checkpoint as ocp
import optax
import os
import pathlib

class MyModel(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(in_features=4, out_features=1, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

rngs = nnx.Rngs(0)
model = MyModel(rngs)

tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)

checkpointDir = pathlib.Path('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()

# This call fails with: ValueError: Unsupported type: <class 'flax.nnx.training.optimizer.OptArray'>
checkpointer.save(checkpointDir, optimizer.opt_state, force=True)

I see y'all have documentation about using Orbax for model checkpointing, but I don't see any official info about checkpointing optimizer state.

I see an older GitHub issue where the Optax folks recommended cloudpickle: google-deepmind/optax#180

I tried adding a custom serialization/deserialization for nnx.training.optimizer.OptArray via https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html#custom-serialization-deserialization, but then just ran into the next error:

ValueError: TypeHandler lookup failed for: type=<class 'flax.nnx.training.optimizer.OptVariable'>

Maybe I'd also need to add a type handler for OptVariable, and maybe even Variable? That seems quite annoying. Am I missing something?

Thanks for your time!

@SandSnip3r
Author

It seems that also adding a type handler for nnx.training.optimizer.OptVariable is sufficient. I'd rather not need to look into the details of these classes to write serializers/deserializers though.

@BeeGass

BeeGass commented Dec 10, 2024

> It seems that also adding a type handler for nnx.training.optimizer.OptVariable is sufficient. I'd rather not need to look into the details of these classes to write serializers/deserializers though.

@SandSnip3r
I'm hoping you could share what the solution turned out to be. I've also been having issues with this.

@cgarciae
Collaborator

cgarciae commented Dec 10, 2024

Hey @SandSnip3r, try serializing the State instead:

checkpoint = nnx.state(optimizer)
# optional but works better
checkpoint = checkpoint.to_pure_dict()

checkpointer.save(checkpointDir, checkpoint, force=True)

Then to load it:

checkpoint = checkpointer.restore(...)
nnx.update(optimizer, checkpoint)

@SandSnip3r
Author

Thanks @cgarciae! I'll give this a try when I get time.

Why does .to_pure_dict() work better?

What about the arguments to checkpointer.restore? The checkpointer requires a structure to load the data into, right? Like an abstract state tree? What would I use for this?

@cgarciae
Collaborator

> Why does .to_pure_dict() work better?

It removes the VariableState objects that carry metadata and leaves only the Arrays. It also turns the checkpoint into a plain dict, which is usually easier to serialize.
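
For illustration (a minimal sketch, assuming the optimizer from the repro above), the two representations differ only in their leaves:

state = nnx.state(optimizer)    # nnx.State; leaves are VariableState objects (array plus metadata)
pure = state.to_pure_dict()     # plain nested dict; leaves are bare jax.Arrays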

> The checkpointer requires a structure to load the data into

I think if you have a pure dictionary you don't need the target structure. Otherwise, just recreate the checkpoint structure using nnx.eval_shape, e.g.

abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()
...
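
For a fuller picture, here is a rough end-to-end sketch of the pure-dict path following the suggestions above. The create_optimizer() helper is illustrative (not part of the thread) and assumes the MyModel class and checkpointDir from the original repro:

import optax
from flax import nnx
import orbax.checkpoint as ocp

def create_optimizer():
    # Rebuild exactly the same model/optimizer structure that was used at save time.
    return nnx.Optimizer(MyModel(nnx.Rngs(0)), optax.adam(1e-3))

checkpointer = ocp.StandardCheckpointer()

# Save the optimizer state as a pure dict.
optimizer = create_optimizer()
checkpointer.save(checkpointDir, nnx.state(optimizer).to_pure_dict(), force=True)

# Restore: build an abstract target with the same structure, then update in place.
abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
# Under nnx.eval_shape the state leaves are jax.ShapeDtypeStructs, which is what
# the checkpointer needs as a restore target.
target = nnx.state(abstract_optimizer).to_pure_dict()
restored = checkpointer.restore(checkpointDir, target)
nnx.update(optimizer, restored)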

@BeeGass

BeeGass commented Dec 11, 2024

@cgarciae Honestly, great explanation. Based on this conversation I was able to put this together. It's not perfect, but it works pretty well.

import shutil
from pathlib import Path
from typing import Any

import jax
import orbax.checkpoint
from flax import nnx
from flax.training import orbax_utils  # provides save_args_from_target
from jax.sharding import Mesh
# TrainingConfig and create_tx below are user-defined helpers (see restore_optimizer).

class CheckpointManager:
    def __init__(self, checkpoint_dir: str, keep_n: int = 3):
        # Keep the directory as a Path so the glob/"/" operations below work.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        self.keep_n = keep_n

    def _cleanup_old_checkpoints(self, step: int):
        # Get all checkpoint steps except 'best'
        checkpoints = []
        for path in self.checkpoint_dir.glob("model-*"):
            if "best" not in str(path):
                try:
                    ckpt_step = int(path.name.split("-")[-1])
                    checkpoints.append(ckpt_step)
                except ValueError:
                    continue

        # Sort and remove old checkpoints while keeping newest n
        checkpoints.sort(reverse=True)
        for old_step in checkpoints[self.keep_n :]:
            for prefix in ["model", "optimizer", "metrics"]:
                path = self.checkpoint_dir / f"{prefix}-{old_step}"
                if path.exists():
                    shutil.rmtree(path)

    def save_model(self, model: Any, step: int, is_best: bool = False):
        state = nnx.state(model).to_pure_dict()
        self._save(obj=state, filename=f"model-{step}")
        if is_best:
            self._save(obj=state, filename="model-best")
        self._cleanup_old_checkpoints(step)

    def save_optimizer(self, optimizer: nnx.Optimizer, step: int):
        state = nnx.state(optimizer).to_pure_dict()
        self._save(obj=state, filename=f"optimizer-{step}")

    def save_training_state(self, step: int, total_tokens: int, metrics: dict = None):
        metrics = {"step": step, "total_tokens": total_tokens, "metrics": metrics}
        self._save(obj=metrics, filename=f"metrics-{step}")

    def restore_model(self, model_cls: Any, model_config: dict, mesh: Mesh, step: int) -> Any:
        abs_model = nnx.eval_shape(lambda: model_cls(**model_config, rngs=nnx.Rngs(0)))
        abs_state = nnx.state(abs_model).to_pure_dict()

        target = jax.tree.map(
            lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
            abs_state,
            nnx.get_named_sharding(nnx.state(abs_model), mesh),
        )

        state = self._restore(filename=f"model-{step}", target=target)
        model = nnx.merge(abs_model, state)
        return model

    def restore_optimizer(
        self, model: Any, train_config: TrainingConfig, step: int
    ) -> nnx.Optimizer:
        tx = create_tx(train_config) # function where I construct my optax optimizer
        abs_optimizer = nnx.eval_shape(lambda: nnx.Optimizer(model, tx))
        abs_state = nnx.state(abs_optimizer).to_pure_dict()

        state = self._restore(filename=f"optimizer-{step}", target=abs_state)
        optimizer = nnx.Optimizer(model, tx)
        nnx.update(optimizer, state)
        return optimizer

    def restore_training_state(self, step: int) -> dict:
        """Restore training metrics and token count."""
        return self._restore(filename=f"metrics-{step}")

    def _save(self, obj: Any, filename: str):
        path = self.checkpoint_dir.absolute() / filename
        save_args = orbax_utils.save_args_from_target(obj)
        self.checkpointer.save(str(Path(path)), obj, save_args=save_args)

    def _restore(self, filename: str, target: Any = None) -> Any:
        path = self.checkpoint_dir.absolute() / filename
        return self.checkpointer.restore(str(Path(path)), target=target)
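
A hypothetical usage sketch of the class above; the model/optimizer construction mirrors the original repro, and the step counts and token numbers are placeholders:

import optax
from flax import nnx

ckpt_manager = CheckpointManager(checkpoint_dir="/tmp/my-checkpoints", keep_n=3)

model = MyModel(nnx.Rngs(0))                        # MyModel from the repro above
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

num_steps, save_every = 100, 10                     # placeholder values
for step in range(1, num_steps + 1):
    # ... run one training step here ...
    if step % save_every == 0:
        ckpt_manager.save_model(model, step, is_best=False)
        ckpt_manager.save_optimizer(optimizer, step)
        ckpt_manager.save_training_state(step, total_tokens=step * 1024)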

@SandSnip3r
Author

As described, this also works without using to_pure_dict() by simply changing the last line of my repro:

checkpointer.save(checkpointDir, nnx.state(optimizer), force=True)

Then restoring can be done with:

import jax

tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)
# Build an abstract pytree with the same structure and dtypes as the live optimizer state.
abstractOptStateTree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, nnx.state(optimizer))
optimizerState = checkpointer.restore(checkpointDir, abstractOptStateTree)
nnx.update(optimizer, optimizerState)

Thanks a lot, @cgarciae.

What do you think about adding a small blurb about saving & restoring optimizer state to the Flax documentation section on checkpointing? https://flax.readthedocs.io/en/latest/guides/checkpointing.html#save-checkpoints I think this would be nice, especially since flax.nnx offers an API for optimizers.
