[RFC] orbax.checkpoint.v1 design #1624

orbax-dev opened this issue Feb 21, 2025 · 0 comments
Labels: checkpoint, type:feature (New feature or request)


Feedback appreciated!

Motivation

A variety of Orbax Checkpoint users have expressed concerns that the API is too complex. As one user stated:

I would say orbax makes the hard things possible, but makes the easy things hard. If we can progress to the point where it makes hard things possible, and easy things easy -- that would be brilliant!

Another user concurred:

There is a steep learning curve using orbax, compared to lets say safetensors.

Orbax core APIs need simplification in order to provide a better user experience.

Overview

Orbax Checkpoint (OCP) will introduce a V1 API (the current API denoted as “V0”), located at orbax.checkpoint.v1. This will serve as the new entry point for all users. However, migration will not be required. The public documentation page will reference only the V1 API.

The V0 API will continue to exist as the underlying implementation of the V1 API in the short term, while being gradually deprecated in the long term.

The new API will be designed to address a number of complaints about the current API, including its complexity, verbosity, and steep learning curve. The current API does not follow the principle of progressive disclosure of complexity: it opts for maximum flexibility without simplifying common use cases. Maximum flexibility will still be possible, but the most common use cases must be easy to understand and expressible in concise code.

API Surface

Basic users:

  • save_pytree / save_pytree_async - Allows saving a PyTree synchronously or asynchronously.
  • load_pytree / load_pytree_async - Allows loading a PyTree synchronously or asynchronously, plus retrieving metadata.
  • Checkpointer - The primary entry point for managing a sequence of checkpoints. It offers automatic garbage collection, configurable save interval policies, etc. with a similar interface to the free functions.

Intermediate users:

  • “checkpointable” - The concept of a logically distinct unit of a checkpoint that has minimal relation to the other units and is often loaded separately (e.g. params/opt_state, dataset, other metadata).
  • save_checkpointables / load_checkpointables - Save and load arbitrary checkpointables (e.g. dataset, embeddings) (PyTree may be one of the checkpointables).
  • Core customization behaviors like partial load, partial update, and model surgery.
  • Checkpointable - Support custom checkpointables by implementing an interface.

Advanced users:

  • CheckpointableHandler - Allows fine-grained save/restore behavior customization (particularly formats) for a given Checkpointable.
  • LeafHandler - Allows customizing PyTree save/restore behavior for custom leaf objects.
  • Context - Bundles configuration (e.g. multiprocessing options, timeouts) for operations such as save/load. The free functions above can be run within a given Context to apply that configuration.
  • configure - Allows overriding the global Context and setting global options.

Use Cases

import orbax.checkpoint.v1 as ocp

Single-Checkpoint Use Cases

The following scenarios enumerate functionalities that users need when saving and restoring a single checkpoint, independently of the sequence of checkpoints that is typically required during training. These scenarios can be common when debugging checkpoints locally, or when running evaluations.

Many of the usage patterns listed here also apply when managing a sequence of checkpoints.

Save synchronously and asynchronously

ocp.save_pytree(path, pytree_state)
response = ocp.save_pytree_async(path, pytree_state)
response.result()  # Wait for completion

Restore synchronously and asynchronously

restored = ocp.load_pytree(path)
response = ocp.load_pytree_async(path)
restored = response.result()  # Wait for completion

Save and restore with optional arguments to customize behavior

ocp.save_pytree(path, pytree, partial_update=True)
restored = ocp.load_pytree(path, abstract_pytree, partial_load=True)

Save multiple logically distinct checkpointables

ocp.save_checkpointables(path, dict(pytree=pytree, dataset=ds_iter, foo=foo))
ocp.load_checkpointables(path, dict(pytree=abstract_pytree))  # Restore only pytree
ocp.load_pytree(path, abstract_pytree)  # Same as above

result = ocp.load_checkpointables(
    path, dict(pytree=abstract_pytree, dataset=...)
)
pytree, dataset = result['pytree'], result['dataset']
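To illustrate why checkpointables can be loaded independently, here is a minimal sketch that assumes (hypothetically) one subdirectory per checkpointable name and plain JSON-serializable values. The helper names are illustrative only and are not Orbax API; the point is that loading a subset of names never touches the others.

```python
import json
import pathlib
import tempfile
from typing import Any, Optional


def save_checkpointables_sketch(directory: pathlib.Path, checkpointables: dict) -> None:
  """Hypothetical layout: each named checkpointable gets its own subdirectory."""
  for name, value in checkpointables.items():
    subdir = directory / name
    subdir.mkdir(parents=True, exist_ok=True)
    (subdir / 'data.json').write_text(json.dumps(value))


def load_checkpointables_sketch(
    directory: pathlib.Path, names: Optional[list] = None
) -> dict:
  """Loads only the requested checkpointables (all of them if names is None)."""
  if names is None:
    names = sorted(p.name for p in directory.iterdir() if p.is_dir())
  return {
      name: json.loads((directory / name / 'data.json').read_text())
      for name in names
  }


with tempfile.TemporaryDirectory() as tmp:
  ckpt = pathlib.Path(tmp) / 'ckpt'
  save_checkpointables_sketch(
      ckpt, {'pytree': {'w': [1.0, 2.0]}, 'dataset': {'position': 7}}
  )
  print(load_checkpointables_sketch(ckpt, ['pytree']))  # {'pytree': {'w': [1.0, 2.0]}}
```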

Obtain metadata

ocp.metadata(path)  # -> CheckpointMetadata

Support custom checkpointables

class MyCustomClass:  # Implements Checkpointable (see below).
  ...  # Some properties.

  async def save(self, path: Path) -> AsyncResponse[None]:
    serialized = self.properties_as_json()
    return await ocp.save_async(path, serialized)  # Saves as basic JSON file.

  @classmethod
  async def load(
      cls, path: Path, abstract_checkpointable: None = None,
  ) -> AsyncResponse[MyCustomClass]:
    # Loading this object does not require any abstract information, so
    # abstract_checkpointable is None. The user could also define an abstract
    # class corresponding to MyCustomClass (the concrete class) and accept that
    # instead.
    serialized = await ocp.load_async(path)  # Read JSON.
    return cls(**serialized)

  async def metadata(self, path: Path) -> None:
    return None

custom_obj = MyCustomClass(...)
ocp.save_checkpointables(path, dict(pytree=pytree, custom_obj=custom_obj))
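To make the custom checkpointable above concrete without depending on Orbax, here is a runnable sketch of a JSON-backed object. It assumes (hypothetically) that save/load may return their results directly instead of wrapping them in an AsyncResponse, and it uses asyncio.to_thread for file I/O in place of the ocp.save_async / ocp.load_async helpers referenced above. The fields are invented for the example.

```python
from __future__ import annotations

import asyncio
import dataclasses
import json
import pathlib
import tempfile


@dataclasses.dataclass
class MyCustomClass:
  """Hypothetical checkpointable whose state is a few JSON-friendly properties."""
  name: str
  learning_rate: float

  async def save(self, directory: pathlib.Path) -> None:
    directory.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(dataclasses.asdict(self))
    # Run blocking file I/O off the event loop.
    await asyncio.to_thread((directory / 'properties.json').write_text, serialized)

  @classmethod
  async def load(
      cls, directory: pathlib.Path, abstract_checkpointable: None = None
  ) -> MyCustomClass:
    # No abstract information is needed to reconstruct this object.
    text = await asyncio.to_thread((directory / 'properties.json').read_text)
    return cls(**json.loads(text))


async def round_trip(directory: pathlib.Path) -> MyCustomClass:
  await MyCustomClass('adam', 1e-3).save(directory)
  return await MyCustomClass.load(directory)


with tempfile.TemporaryDirectory() as tmp:
  print(asyncio.run(round_trip(pathlib.Path(tmp) / 'custom_obj')))
  # MyCustomClass(name='adam', learning_rate=0.001)
```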

Support custom formats

This differs from the above in that instead of some specific user-defined object that is always saved and loaded in the same way, we instead have a core object (like a PyTree) that needs to be handled in an alternative way. For example, this allows us to configure checkpointing behavior in a non-Orbax format (e.g. Roc or PyTorch).

ocp.save_pytree(
    path,
    pytree,
    handler=RocHandler(format=einshape_numpy_proto)
)

Context based customization

See below for more details. The following example allows saving to a process-local filesystem.

# Update global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
ocp.save_pytree(local_fs_path, pytree)

# Save pytree in a specific context.
multiprocessing_options = MultiprocessingOptions(primary_host=None)
with ocp.Context(multiprocessing_options=multiprocessing_options):
  ocp.save_pytree(local_fs_path, pytree)

Sequence-of-Checkpoints Use Cases

These use cases roughly correspond to those served by the existing CheckpointManager object. Note, however, that the class is now named Checkpointer, and that it reuses constructs introduced for the single-checkpoint use cases.

Saving and restoring

with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save
  f = ckptr.save_pytree_async(1, train_state)  # Async save
  f.result()

  ckptr.load_pytree()  # Restores the latest step
  ckptr.load_pytree(1)  # Restore a specific step
  ckptr.load_pytree(1, abstract_train_state)  # Reshard / cast
  f = ckptr.load_pytree_async(1)  # Async restore
  f.result()

Saving and restoring with multiple checkpointables

with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_checkpointables(step, dict(pytree=train_state, dataset=train_iter))
  ckptr.load_pytree(step, abstract_train_state)

Obtain metadata

with ocp.Checkpointer(directory) as ckptr:
  ckptr.metadata()  # -> RootMetadata: Root-directory-level metadata
  ckptr.metadata(step)  # -> CheckpointMetadata: Checkpoint-level metadata

Determine when to save

# In the future, will default to ContinuousCheckpointingPolicy (at least for internal users)
save_decision_policy = EveryNStepsPolicy(steps=1000)
with ocp.Checkpointer(directory, save_decision_policy=save_decision_policy) as ckptr:
  ckptr.should_save(step)  # -> bool
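A save-decision policy only has to answer one question per step. A minimal sketch, assuming (hypothetically) that EveryNStepsPolicy saves whenever the step is a multiple of `steps`:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class EveryNStepsPolicy:
  """Hypothetical policy: save on every step that is a multiple of `steps`."""
  steps: int

  def should_save(self, step: int) -> bool:
    return step % self.steps == 0


policy = EveryNStepsPolicy(steps=1000)
print([s for s in range(0, 3001, 500) if policy.should_save(s)])  # [0, 1000, 2000, 3000]
```

Under this sketch, Checkpointer.should_save(step) could simply delegate to the configured policy.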

Identify existing checkpoints

with ocp.Checkpointer(directory) as ckptr:
  ckptr.latest_step()  # -> int
  ckptr.steps()  # -> set[int]

Handle garbage collection

with ocp.Checkpointer(directory, preservation_policy=LatestN(10)) as ckptr:
  …
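A preservation policy decides which existing steps survive garbage collection. A minimal sketch of LatestN, assuming (hypothetically) that the policy is a pure function of the set of saved steps that returns the steps to delete:

```python
def latest_n_to_delete(steps: set, n: int) -> set:
  """Hypothetical LatestN: keeps the n most recent steps and returns the rest."""
  keep = set(sorted(steps, reverse=True)[:n])
  return steps - keep


print(sorted(latest_n_to_delete({0, 100, 200, 300, 400}, n=2)))  # [0, 100, 200]
```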

Rank checkpoints by metrics

# Mimics builtin `sorted` function.
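# reverse=True keeps the highest values (e.g. accuracy); use reverse=False for metrics like loss.
preservation_policy = BestN(10, lambda m: m['accuracy'], reverse=True)
with ocp.Checkpointer(directory, preservation_policy=preservation_policy) as ckptr: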
  ckptr.save_pytree(step, …, metrics={'accuracy': 0.9, 'loss': 0.65})
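Because BestN mimics the builtin `sorted`, its selection logic can be sketched directly with `sorted`. This assumes (hypothetically) that the policy sees the metrics recorded for each step and keeps the first n steps after sorting with the given key and reverse flag:

```python
from typing import Any, Callable


def best_n_to_keep(
    metrics_by_step: dict,
    n: int,
    key: Callable[[dict], Any],
    reverse: bool = False,
) -> set:
  """Hypothetical BestN: keeps the first n steps after sorting by key(metrics)."""
  ranked = sorted(
      metrics_by_step, key=lambda step: key(metrics_by_step[step]), reverse=reverse
  )
  return set(ranked[:n])


metrics = {
    100: {'accuracy': 0.71, 'loss': 0.90},
    200: {'accuracy': 0.88, 'loss': 0.70},
    300: {'accuracy': 0.90, 'loss': 0.65},
}
# Higher accuracy is better, so sort in descending order.
print(sorted(best_n_to_keep(metrics, 2, lambda m: m['accuracy'], reverse=True)))  # [200, 300]
# Lower loss is better, so sort in ascending order.
print(sorted(best_n_to_keep(metrics, 1, lambda m: m['loss'])))  # [300]
```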

Context based customization

# Update global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer with a specific context.
# Overrides global context.
with ocp.Checkpointer(directory, context=ocp.Context(...)) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer operation with a specific context.
# Overrides both global and Checkpointer level context.
with ocp.Checkpointer(directory) as ckptr:
  with ocp.Context(...):
    ckptr.save_pytree(0, train_state)  # Save

Training loop example

Note: model surgery complexity omitted.

def init_or_restore() -> PyTree:
  if exp_cfg.init_checkpoint_path:  # Restore initial checkpoint (e.g. finetuning)
    return ocp.load_and_transform(
        [exp_cfg.init_checkpoint_path], abstract_train_state, surgery_fn
    )
  else:  # Init from scratch
    return init()

with ocp.Checkpointer(directory, **other_options) as ckptr:
  if ckptr.latest_step() is not None:  # Recovering after restart
    train_state = ckptr.load_pytree()  # Restores the latest step
    start_step = ckptr.latest_step() + 1
  else:
    train_state = init_or_restore()
    ckptr.save_pytree(0, train_state)  # Save initial model
    start_step = 1

  for step in range(start_step, end_step):
    train_state = train_step(train_state)
    if ckptr.should_save(step):
      ckptr.save_checkpointables(step, dict(pytree=train_state, dataset=train_iter))

Tree-Specific Use Cases

Restore with resharding/reshaping/casting

abstract_tree = {
  'a': jax.ShapeDtypeStruct(shape, dtype, sharding),  # restore as jax.Array
  'b': np.empty(shape, dtype),  # restore as np.ndarray
  'c': '',  # restore as string
}
ocp.load_pytree(path, abstract_tree)

Partial restoration

Partial restore solves the most common case of loading a tree that differs from the one stored in the checkpoint: leaves or subtrees are omitted. The canonical example is skipping the optimizer state when running evaluation.

In contrast, model surgery is the more complete version of this, where the user can manipulate trees/leaves in arbitrary ways, as well as load multiple trees and merge them.

abstract_tree = {
  'params': { … },
  # Note: omit 'opt_state' to avoid loading it
  'step': None,  # Skip loading 'step'
}
# Unsafe variant: would require partial_load to default to True.
ocp.load_pytree(path, abstract_tree)
# Safe variant: partial_load must be explicitly opted into.
ocp.load_pytree(path, abstract_tree, partial_load=True)
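The difference between the two variants is how a mismatch between the abstract tree and the stored tree is treated. A minimal sketch over nested dicts, assuming (hypothetically) that a key omitted from the abstract tree, or mapped to None, means "skip", and that without partial_load any such skip is an error. Strings stand in for abstract leaves such as jax.ShapeDtypeStruct; this is not the Orbax implementation.

```python
from typing import Any


def load_tree_sketch(stored: dict, abstract: dict, *, partial_load: bool = False) -> dict:
  """Hypothetical partial-load semantics over nested dicts."""
  unknown = set(abstract) - set(stored)
  if unknown:
    raise KeyError(f'Not in checkpoint: {sorted(unknown)}')
  result = {}
  for key, stored_value in stored.items():
    abstract_value = abstract.get(key)  # None if omitted or explicitly set to None.
    if abstract_value is None:
      if not partial_load:
        raise ValueError(
            f'{key!r} is missing from the abstract tree; '
            'pass partial_load=True to skip it.'
        )
      continue  # Skip this subtree entirely.
    if isinstance(stored_value, dict):
      result[key] = load_tree_sketch(stored_value, abstract_value, partial_load=partial_load)
    else:
      # A real loader would restore according to abstract_value (dtype, sharding, ...).
      result[key] = stored_value
  return result


stored = {'params': {'w': [1.0, 2.0]}, 'opt_state': {'mu': [0.0, 0.0]}, 'step': 10}
abstract = {'params': {'w': 'float32[2]'}, 'step': None}
print(load_tree_sketch(stored, abstract, partial_load=True))  # {'params': {'w': [1.0, 2.0]}}
# load_tree_sketch(stored, abstract) raises ValueError for 'opt_state'.
```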

Restore with model surgery

def transform_fn(source: PyTree) -> PyTree:
  ...

ocp.load_and_transform([path], abstract_tree, transform_fn)

Multi-model restore with model surgery

def transform_fn(source_a: PyTree, source_b: PyTree) -> PyTree:
  ...

ocp.load_and_transform(
    [path_a, path_b], abstract_tree, transform_fn
)
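For illustration, here is a hypothetical transform_fn for the multi-model case, written over plain nested dicts: it takes the backbone from one source and the head from the other. The tree keys are invented for this example. Because transform_fn is an ordinary function from source trees to the target tree, it can be tested without any checkpoint I/O.

```python
from typing import Any

PyTree = Any


def transform_fn(source_a: PyTree, source_b: PyTree) -> PyTree:
  """Hypothetical surgery: backbone from model A, head from model B."""
  return {
      'params': {
          'backbone': source_a['params']['backbone'],
          'head': source_b['params']['head'],
      },
      # opt_state is intentionally dropped from the target tree.
  }


model_a = {'params': {'backbone': {'w': [1, 2]}, 'head': {'w': [3]}}, 'opt_state': {}}
model_b = {'params': {'backbone': {'w': [9, 9]}, 'head': {'w': [4]}}}
print(transform_fn(model_a, model_b))
# {'params': {'backbone': {'w': [1, 2]}, 'head': {'w': [4]}}}
```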

Partial write (update)

ocp.save_pytree(path, partial_pytree_one, partial_update=True)
ocp.save_pytree(path, partial_pytree_two, partial_update=True)
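Repeated partial updates accumulate into one tree. A minimal sketch of the merge semantics over nested dicts, assuming (hypothetically) that keys in the new partial tree are added and all other stored keys are left untouched. Whether an existing leaf may be overwritten is a design decision; this sketch overwrites it.

```python
from typing import Any


def apply_partial_update(stored: dict, partial: dict) -> dict:
  """Hypothetical partial_update=True semantics: recursively merge `partial` into `stored`."""
  merged = dict(stored)  # Shallow copy; inputs are not mutated.
  for key, value in partial.items():
    if isinstance(value, dict) and isinstance(merged.get(key), dict):
      merged[key] = apply_partial_update(merged[key], value)
    else:
      merged[key] = value
  return merged


checkpoint = {}
checkpoint = apply_partial_update(checkpoint, {'params': {'encoder': [1.0]}})
checkpoint = apply_partial_update(checkpoint, {'params': {'decoder': [2.0]}, 'step': 5})
print(checkpoint)  # {'params': {'encoder': [1.0], 'decoder': [2.0]}, 'step': 5}
```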

Support custom tree leaves with Context

# With global context.
ocp.configure(leaf_handlers={MyCustomLeaf: CustomLeafHandler})
…
ocp.save_pytree(path, pytree)

# With local context.
with ocp.Context(leaf_handlers={MyCustomLeaf: CustomLeafHandler}):
  ocp.save_pytree(path, pytree_with_custom_leaves)

API Definitions

Overview

In orbax.checkpoint.v1, free functions will be the primary entry point for users into the library. These include save_pytree / load_pytree, which deal with a single PyTree, and save_checkpointables / load_checkpointables, which deal with multiple arbitrary checkpointables. (The free functions also include async variants and metadata access.)

While these functions operate at the level of an individual checkpoint path, the other main entry point is Checkpointer, which operates at the level of a root directory, under which a sequence of checkpoints corresponding to steps in a training loop is stored. This class makes restrictive assumptions about the tasks a user will perform and the patterns under which it is used. As a result, it is less flexible than APIs oriented around individual paths, but it provides more features, such as automatic garbage collection, metrics management, and save intervals. It will suit many basic training loops, but not advanced users with greater customization needs.

The user-facing API discussed in this doc, especially the free functions, depends on a number of global configurations, e.g. multiprocessing options, timeouts, and the LeafHandlerRegistry. These global configurations are collectively called the Context, which is implemented as a context manager. The Orbax operations discussed above can be invoked within a Context to customize their behavior with the given configuration.

“Checkpointable” remains a key concept in the V1 API. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.

Different checkpointables are handled by CheckpointableHandler implementations. These provide the logic for saving and loading particular objects, and also identify which objects they are capable of processing. For user convenience, a Checkpointable interface is also provided, which allows tightly coupling checkpointing logic to a specific object. While less flexible than the CheckpointableHandler, it is also more intuitive and we intend to expose it as the main interface for checkpointable customization.

The lowest-level user-accessible abstraction is the LeafHandler, which deals specifically with processing individual leaves of a PyTree. Implementing and registering a subclass of this object allows storing custom leaves in a PyTree.

Free Functions

Free functions serve as the primary entry point for users.

Checkpointable = Any

### SAVING ###

def save_pytree(
    directory: PathLike,
    pytree: PyTree,
    *,
    partial_update: bool = False,
    force: bool = False,
    handler: CheckpointableHandler | None = None,  # Defaults to PyTreeHandler.
):
  …

def save_pytree_async(...) -> AsyncResponse[None]:
  …

def save_checkpointables(
    directory: PathLike,
    checkpointables: dict[str, Checkpointable],
    *,
    force: bool = False,
):
  …

def save_checkpointables_async(...) -> AsyncResponse[None]:
  …

### LOADING ###

def load_pytree(
    directory: PathLike,
    abstract_pytree: PyTree | None = None,
    *,
    partial_load: bool = False,
) -> PyTree:
  …

def load_pytree_async(...) -> AsyncResponse[PyTree]:
  …

def load_checkpointables(
    directory: PathLike,
    abstract_checkpointables: dict[str, Checkpointable] | None = None,
) -> dict[str, Checkpointable]:
  …

def load_checkpointables_async(...) -> AsyncResponse[dict[str, Checkpointable]]:
  …

def metadata(directory: PathLike) -> CheckpointMetadata:
  …

### MODEL SURGERY ###

def load_and_transform(
    directories: Sequence[PathLike],
    abstract_pytree: PyTree,
    transform_fn: TransformFn,
) -> PyTree:
  …

Futures vs. wait_until_finished

Currently async saving in Orbax does not return futures to the user, but instead relies on a wait_until_finished method for the user to block on the result of the save. However, it makes sense to abandon this model for a few reasons.

First, in the typical training checkpointing use case, users rarely block on the result of a save, and typically do so only before exiting the program. They rely on the library itself to block if a save is already in progress when they try to save again.

Second, the use of context managers makes both futures and wait_until_finished unnecessary in many cases, as exiting the context automatically waits. This use case may be more common for small-scale experimentation, or one-off PyTree writes.

Third, and most importantly, asynchronous loading (load_pytree_async) has no viable way to deliver its result to the user other than via a future-like object. The benefits of aligning the APIs of save_pytree_async and load_pytree_async outweigh any other potential arguments, in my view.

Further note that we should aim to move away from the “Future” terminology. Although the term is familiar to many users, that familiarity can cause confusion with other Future implementations (e.g. concurrent.futures.Future).

Instead we will opt for a construct like AsyncResponse. This is a simple container class that is returned by asynchronous APIs. It contains a method like result that allows blocking on the save/load operation and retrieving the operation result. In this respect, it is similar to ocp.Future in the current codebase.
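A minimal sketch of an AsyncResponse-like container, assuming (hypothetically) that the operation runs on a background thread. It deliberately exposes only result(), which blocks until the operation finishes and then returns its value or re-raises its exception. It is an illustration, not the Orbax class.

```python
import threading
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar('T')


class AsyncResponse(Generic[T]):
  """Hypothetical container returned by *_async functions."""

  def __init__(self, operation: Callable[[], T]):
    self._operation = operation
    self._value: Optional[T] = None
    self._error: Optional[BaseException] = None
    # Start the operation immediately; the caller decides when to block.
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()

  def _run(self) -> None:
    try:
      self._value = self._operation()
    except BaseException as e:  # Re-raised from result() in the caller's thread.
      self._error = e

  def result(self, timeout: Optional[float] = None) -> T:
    """Blocks until the operation completes, then returns its value or re-raises."""
    self._thread.join(timeout)
    if self._thread.is_alive():
      raise TimeoutError('The operation did not complete within the timeout.')
    if self._error is not None:
      raise self._error
    return self._value


response = AsyncResponse(lambda: sum(range(10)))
# Other work can happen here while the operation runs.
print(response.result())  # 45
```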

Alternatives to save_pytree/save_checkpointables, load_pytree/load_checkpointables

We can combine both functionalities into one method, e.g.

def load(
    self,
    path: PathLike,
    abstract_pytree: PyTree | None = None,
    *,
    extra_checkpointables: dict[str, Any] | None = None,
):
  …

The real difficulty is how to represent the return type for load. In order to mirror the function inputs, we must return a tuple of (pytree, extra_checkpointables). (We could also just return a single dictionary representing all checkpointables, but this is undesirable for users only providing pytree, since they will not necessarily be aware of how the checkpoint is represented, or that checkpointables is a concept they need to know. Returning a tuple of (pytree, extra_checkpointables) also requires a user to know that extra_checkpointables is an important argument, but is less bad than requiring the user to know the arbitrary name “pytree” in order to access their loaded tree.)

If we accept that a tuple is the ideal return type, in order to mirror the inputs, the return type must be:

tuple[PyTree, dict[str, Any] | None]

This interface is fairly inflexible and not that user friendly. The issues are:

  • pytree must be present in every checkpoint.
  • There is no way to restore extra_checkpointables and not pytree, since abstract_pytree=None is used to indicate “restore the pytree however you can”.
  • Different treatment of None for abstract_pytree and extra_checkpointables is confusing.
  • The return type is complex and not well-suited to users only interested in the pytree (this violates progressive disclosure of complexity).
  • Return types depend on both inputs and what is in the checkpoint.

The combined interface is more trouble than it’s worth. Splitting into load_pytree / load_checkpointables meshes well with progressive disclosure of complexity, simplifies input and output signatures, and makes return types more predictable, at the cost of using two functions instead of one.

Context and Configuration

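@dataclasses.dataclass(frozen=True)
class Context:

  options: Options

  def __enter__(self) -> Context:
    ...  # Make this Context the active one for operations in the block.
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    ...  # Restore the previously active Context.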
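One way to realize the layering described in this doc (a global configuration, overridden by an enclosing Context, overridden in turn by a nested Context) is a contextvars.ContextVar holding the active options. A minimal sketch: configure, Context, and current_options are illustrative stand-ins, and Options is reduced to two hypothetical fields. Unlike the frozen dataclass above, this Context stores a reset token on the instance.

```python
import contextvars
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class Options:
  """Hypothetical, heavily reduced stand-in for the Options struct."""
  timeout_secs: int = 600
  primary_host: Optional[int] = 0


# Options used when no Context is active; updated by configure().
_global_options = Options()
# Options of the innermost active Context, or None outside any Context.
_active_options = contextvars.ContextVar('_active_options', default=None)


def configure(**overrides) -> None:
  """Updates the global options."""
  global _global_options
  _global_options = dataclasses.replace(_global_options, **overrides)


def current_options() -> Options:
  """Returns the innermost active Context's options, else the global options."""
  active = _active_options.get()
  return active if active is not None else _global_options


class Context:
  """Hypothetical Context: overrides selected fields of the current options in a block."""

  def __init__(self, **overrides):
    self._overrides = overrides
    self._token = None

  def __enter__(self) -> 'Context':
    # Layer this Context's overrides on top of whatever is currently in effect.
    options = dataclasses.replace(current_options(), **self._overrides)
    self._token = _active_options.set(options)
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    _active_options.reset(self._token)


configure(timeout_secs=60)
with Context(primary_host=None):
  print(current_options())  # Options(timeout_secs=60, primary_host=None)
print(current_options())  # Options(timeout_secs=60, primary_host=0)
```

A design consequence of this sketch: a Context snapshots the options when it is entered, so a configure() call made inside the block does not affect it.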

Dealing with global configurations (Context)

A “global configuration” is a setting that applies at multiple levels of the Orbax stack. These settings must be applied in the same way to multiple different layers, or the inconsistency can result in unexpected errors.

This includes groups of options like:

@dataclasses.dataclass(frozen=True, kw_only=True)
class Options:
  # e.g. save_timeout/load_timeout
  async_options: AsyncOptions

  # Settings for e.g. primary_host, active_processes
  multiprocessing_options: MultiprocessingOptions

  # Options controlling path permissions, data governance annotations (internal),
  # CNS2 options (internal), etc.
  file_options: FileOptions

  # Options for enabling hashing/signing behavior in save/load.
  signing_options: SigningOptions

  # Other options
  ...

For example, a setting like save_timeout/load_timeout applies globally to an entire operation, rather than being set separately for save/load, CheckpointableHandler, and LeafHandler. Another example is primary_host, which must have the same setting in every layer, or risk difficult-to-debug breakages.

Practically speaking, global options (corresponding roughly to the existing options.py) are rarely modified. They are typically set once as global configuration; only in rare cases do individual operations need to modify them with greater flexibility.

All configurations

Orbax provides a lot of options for configuring specific behaviors at various levels.

For the V1 API, we can subdivide options into a number of categories. Of these, the most interesting are PyTree-related options and Array-related options, which comprise the bulk of all options.

  • AsyncOptions
    • timeout_secs
    • barrier_sync_fn
    • post_finalization_callback
    • create_directories_asynchronously
  • MultiprocessingOptions
    • primary_host
    • active_processes
    • barrier_sync_key_prefix
  • FileOptions
    • path_permission_mode
    • data_governance_annotations
    • cns2_storage_options
    • temporary_path_class # Atomicity
  • SecurityOptions
    • tree_verity_options
  • PyTrees
    • array_storage_options_creator # Creates an ArrayStorageOptions struct on a per-leaf basis that customizes save behavior for individual array leaves. If ArrayStorageOptions are set globally, this option will override them.
    • leaf_handler_registry # LeafHandlers used for PyTree leaves
    • enable_descriptor
    • pytree_metadata_options
    • array_metadata_validator # Not user-facing, mostly for internal testing
    • partial_update # Enable partial tree update
    • partial_load # Enable partial tree loading
  • Arrays
    • Saving
      • concurrent_bytes
      • OCDBT options
        • use_ocdbt
        • ocdbt_target_data_file_size
        • enable_post_merge_validation
      • Storage options # Can be customized per-array if we have multiple arrays.
        • dtype # cast type for storage
        • chunk_byte_size # loose cap on the size of Tensorstore chunks
        • shard_axes # Chunks subdivided along this axis
      • metadata_key # .zarray file name
      • use_zarr3
      • enable_pinned_host_transfer
      • enable_write_sharding_file
      • use_replica_parallel
    • Loading
      • concurrent_bytes
      • enable_padding_and_truncation
      • Single-replica restore+broadcast
        • replica_axis_index
        • primary_replica_id
        • broadcast_memory_limit_bytes
        • broadcast_memory_scaling_factor

It is clear that most of these options do not need to be modified often. When they do, a global setting is possible, and only in rare cases do users need per-operation customization. If all such settings are placed within a global Options, there is considerably less doubt for users about where to find a particular setting. In an inherently complicated landscape with many different settings, it will never be “easy” to find a particular option, but it will be less difficult if all are placed under a common structure.

ocp.configure(
  ocp.Options(
    async_options=ocp.Options.AsyncOptions(timeout_secs=60),
    array_options=ocp.Options.ArrayOptions(
      concurrent_bytes=1_000_000_000,
      ocdbt_options=ocp.Options.ArrayOptions.OcdbtOptions(use_ocdbt=False),
      # Cast everything to bfloat16
      storage_options=ocp.Options.ArrayOptions.StorageOptions(dtype=bfloat16),
    )
  )
)
# Alternatively:
ocp.configure(
    ocp.Options({
        'async_options': {'timeout_secs': 60},
        'array_options': {
            'concurrent_bytes': 1_000_000_000,
            'ocdbt_options': {'use_ocdbt': False},
            'storage_options': {'dtype': bfloat16},
        }
    })
)
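The dictionary form can be treated as sugar for the nested dataclass form. A minimal sketch, assuming (hypothetically) nested frozen dataclasses whose field names match the dictionary keys; the classes are flattened to module level and storage options are omitted for brevity. `from_dict` recursively converts nested dictionaries into the corresponding option classes.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class OcdbtOptions:
  use_ocdbt: bool = True


@dataclasses.dataclass(frozen=True)
class ArrayOptions:
  concurrent_bytes: Optional[int] = None
  ocdbt_options: OcdbtOptions = dataclasses.field(default_factory=OcdbtOptions)


@dataclasses.dataclass(frozen=True)
class AsyncOptions:
  timeout_secs: int = 600


@dataclasses.dataclass(frozen=True)
class Options:
  async_options: AsyncOptions = dataclasses.field(default_factory=AsyncOptions)
  array_options: ArrayOptions = dataclasses.field(default_factory=ArrayOptions)


def from_dict(cls: type, values: dict) -> Any:
  """Builds `cls` from a nested dict, recursing into dataclass-typed fields."""
  field_types = {f.name: f.type for f in dataclasses.fields(cls)}
  kwargs = {}
  for name, value in values.items():
    field_type = field_types[name]  # Unknown option names raise KeyError.
    if isinstance(value, dict) and dataclasses.is_dataclass(field_type):
      value = from_dict(field_type, value)
    kwargs[name] = value
  return cls(**kwargs)


from_nested_dict = from_dict(Options, {
    'async_options': {'timeout_secs': 60},
    'array_options': {
        'concurrent_bytes': 1_000_000_000,
        'ocdbt_options': {'use_ocdbt': False},
    },
})
explicit = Options(
    async_options=AsyncOptions(timeout_secs=60),
    array_options=ArrayOptions(
        concurrent_bytes=1_000_000_000, ocdbt_options=OcdbtOptions(use_ocdbt=False)
    ),
)
print(from_nested_dict == explicit)  # True
```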

It is important to distinguish settings from options that unlock commonly-used operations, like model surgery, partial loading, partial updating, and forced overwrite. These operations are core functionalities that users often wish to enable or disable. As such, they should be located in the signature of the function they are used in (e.g. partial_load: bool in load, and force: bool in save). These options can still be settings in Options, to enable global defaults, but can also be exposed directly in save or load as local overrides.

Checkpointing in a Training Loop

In the existing library, the division of labor between Checkpointer and CheckpointManager has not always been well understood, because users often treat the two terms as interchangeable. Now that the user-facing API is oriented around free functions (save/load), we can have a single Checkpointer class that behaves much as the current CheckpointManager does.

In the long run, we will aim to achieve a level of composability that allows users to effectively write their own implementation of Checkpointer with minimal additional code. The Checkpointer itself should be a Protocol with a rigid interface - we will be resistant to adding new features without substantial discussion and agreement that the proposed feature is a core element of checkpointing in a training loop.

Checkpointer will live under orbax.checkpoint.training to make explicit its intended use for training loops. Users with greater customization requirements will be encouraged to use lower-level APIs.

The save/load interface should mirror the free functions almost identically.

class Checkpointer:

  def __init__(
      self,
      root_directory: epath.PathLike | RootDirectoryFormat,
      *,
      # Default to continuous checkpointing.
      save_decision_policy: SaveDecisionPolicy | None = None,
      # Default to the latest few. See design.
      preservation_policy: PreservationPolicy | None = None,
      step_name_format: NameFormat[step_lib.Metadata] | None = None,
      metric_comparator: MetricComparator | None = None,
      # Default to async deletion.
      deletion_options: DeletionOptions | None = None,
      # Context for this Checkpointer; overrides the global Context.
      context: Context | None = None,
      …
  ):
    …

  @property
  def directory(self) -> Path:
    …
  def latest_step(self) -> StepInfo:
    …
  def steps(self) -> Sequence[StepInfo]:
    …

  def save_pytree(...)
  def save_checkpointables(...)
  def save_pytree_async(...)
  def save_checkpointables_async(...)

  def load_pytree(...)
  def load_checkpointables(...)
  def load_pytree_async(...)
  def load_checkpointables_async(...)

  def metadata(self, step: int | None = None) -> RootMetadata | CheckpointMetadata:
    """Retrieves root-directory-level metadata, or checkpoint-level metadata for `step`."""
    …

  def reload(self):
    """Reload internal properties from the root directory."""
    …

Checkpointable

“Checkpointable” is a core concept, at least for any Orbax user beyond a beginner level. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.

While the V0 API drew a distinction between “items” and “PyTree leaves”, this distinction was unnecessary. A jax.Array, a str, or a scalar is “checkpointable” in the same way that a PyTree composed of these objects is.

We can introduce a Protocol to represent this concept called Checkpointable, which defines methods needed to save and load the object.

T = TypeVar('T')
AbstractT = TypeVar('AbstractT')

class Checkpointable(Protocol[T, AbstractT]):
  async def save(self, directory: Path) -> AsyncResponse[None]:
    …
  @classmethod
  async def load(
      cls, directory: Path, abstract_checkpointable: AbstractT | None = None
  ) -> AsyncResponse[T]:
    …
  async def metadata(self, directory: Path) -> AsyncResponse[AbstractT]:
    …

When an object requires customized serialization logic, its saving and loading logic can be defined directly on the object:

class MyObject:
  
  …  # Some properties

  async def save(self, directory: Path) -> AsyncResponse[None]:
    …

  @classmethod
  async def load(
      cls,
      directory: Path,
      abstract_checkpointable: AbstractMyObject | None = None
  ) -> AsyncResponse[MyObject]:
    …

Implementing the Checkpointable interface should not be necessary in most cases, as most users simply want to save an object like an array or a PyTree. Even in the case of custom, user-defined PyTree objects, Checkpointable should be rarely needed.

Furthermore, it is important to handle the case where a single PyTree may be saved in multiple different ways. This is common when writing format converters (e.g. Roc -> Orbax, PyTorch -> Orbax, etc.). In these cases, the checkpointable object is the same, but the checkpointing logic is different.

For these cases, CheckpointableHandler makes more sense, as this provides checkpointing logic for a recognizable type and can be swapped in and out as needed.

T = TypeVar('T')
AbstractT = TypeVar('AbstractT')

class CheckpointableHandler(Protocol[T, AbstractT]):

  async def save(
      …
  ) -> AsyncResponse[None]:
    …
  async def load(
      …
  ) -> AsyncResponse[T]:
    …

  async def metadata(self, directory: epath.Path) -> AsyncResponse[AbstractT]:
    …

  def is_handleable(self, checkpointable: T | AbstractT) -> bool:
    """Given any object, determine whether it can be stored with this handler."""
    …

Here are some concrete examples:

class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]):

  async def save(self, path: Path, checkpointable: PyTree) -> AsyncResponse[None]:
    leaves = jax.tree.leaves(checkpointable)
    # Assumed to return one handler type per leaf, in the same order as `leaves`.
    leaf_handler_types = collect_per_leaf_checkpointable_handlers(checkpointable)
    save_responses = []
    for leaf, handler_type in zip(leaves, leaf_handler_types):
      # Construct the per-leaf handler and save the leaf.
      save_responses.append(await handler_type().save(path, leaf))
    tree_metadata_response = await save_tree_metadata(path, checkpointable)
    # Include finalize behavior in this response.
    return UnifiedResponse(
        *save_responses,
        tree_metadata_response,
    )

# Sometimes, checkpointables are always restorable without an
# abstract checkpointable, in which case it may be None.
class DatasetHandler(CheckpointableHandler[tf.data.Iterator, None]):
  …

# For a singular np.ndarray handler, we can define the following abstract type:
class AbstractNumpyArray:
  @property
  def shape(self) -> tuple[int, …]:
    …
  @property
  def dtype(self) -> np.dtype:
    …

class NumpyHandler(CheckpointableHandler[np.ndarray, AbstractNumpyArray]):
  …

Determining which CheckpointableHandler can save/restore a checkpointable

When the user provides an object to save or restore, how can we determine which handler is appropriate to deal with this object? Ultimately, we do not have that many core CheckpointableHandlers. These include PyTree, JSON, Proto, and Array. Except for a JSON object, which is by definition a PyTree, all objects are easily distinguishable. Furthermore, a user generally doesn’t care how their object is stored, as long as it can be stored successfully. Users seeking maximum performance will go hunting for the ideal handler of their own accord. Setting reasonable defaults thus satisfies both beginners and advanced users.

Each handler can define an is_handleable method that determines whether it is capable of storing the given object. When the user does not explicitly specify a handler, we check all globally-registered CheckpointableHandlers and select the first one capable of saving or restoring the object. Registration order matters, so we can ensure JsonHandler is always preferred for sufficiently simple objects (rather than PyTreeHandler).
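The first-match resolution described above could be sketched as follows. `HandlerRegistry`, `add`, and `resolve` are hypothetical names, and both handlers implement only `is_handleable`, with deliberately simplified checks. The point is that registering JsonHandler before PyTreeHandler makes plain JSON-compatible objects resolve to JsonHandler, while anything else falls through to PyTreeHandler.

```python
from __future__ import annotations

from typing import Any, Protocol


class CheckpointableHandler(Protocol):
  def is_handleable(self, checkpointable: Any) -> bool:
    ...


class JsonHandler:
  """Accepts only JSON-compatible objects (simplified check)."""

  def is_handleable(self, checkpointable: Any) -> bool:
    if isinstance(checkpointable, dict):
      return all(
          isinstance(k, str) and self.is_handleable(v)
          for k, v in checkpointable.items()
      )
    if isinstance(checkpointable, (list, tuple)):
      return all(self.is_handleable(v) for v in checkpointable)
    return checkpointable is None or isinstance(
        checkpointable, (bool, int, float, str)
    )


class PyTreeHandler:
  """Accepts any container (simplified; a real check would inspect leaves)."""

  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, (dict, list, tuple))


class HandlerRegistry:
  """Hypothetical global registry; earlier registrations take priority."""

  def __init__(self) -> None:
    self._handlers: list[CheckpointableHandler] = []

  def add(self, handler: CheckpointableHandler) -> None:
    self._handlers.append(handler)

  def resolve(self, checkpointable: Any) -> CheckpointableHandler:
    # Return the first registered handler that accepts the object.
    for handler in self._handlers:
      if handler.is_handleable(checkpointable):
        return handler
    raise ValueError(f'No registered handler for {type(checkpointable)!r}')


registry = HandlerRegistry()
registry.add(JsonHandler())  # Registered first, so preferred for simple objects.
registry.add(PyTreeHandler())
```

With this ordering, `{'step': 1, 'lr': 0.1}` resolves to JsonHandler, while `{'w': b'\x00'}` (bytes are not a JSON scalar) resolves to PyTreeHandler.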

PyTree Leaf Handlers

import dataclasses
from typing import Generic, Protocol, TypeVar

import jax
from etils import epath

Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')

@dataclasses.dataclass
class SerializationParam(Generic[Leaf]):
  name: str
  keypath: jax.tree_util.KeyPath
  value: Leaf

@dataclasses.dataclass
class SerializationContext:
  path: epath.Path

@dataclasses.dataclass
class DeserializationParam(Generic[AbstractLeaf]):
  name: str
  value: AbstractLeaf | None = None

@dataclasses.dataclass
class DeserializationContext:
  path: epath.Path

class LeafHandler(Protocol[Leaf, AbstractLeaf]):

  async def serialize(
      self,
      params: list[SerializationParam[Leaf]],
      context: SerializationContext,
  ) -> AsyncResponse[None]:
    ...

  async def deserialize(
      self,
      params: list[DeserializationParam[AbstractLeaf]],
      context: DeserializationContext,
  ) -> AsyncResponse[Leaf]:
    ...

  async def metadata(
      self,
      params: list[DeserializationParam[AbstractLeaf]],
      context: DeserializationContext,
  ) -> AsyncResponse[AbstractLeaf]:
    ...

  def finalize(self):
    ...
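To illustrate why LeafHandler methods take a list of params rather than a single leaf, the following sketch writes every scalar leaf of a save into one JSON file. `ScalarLeafHandler` is hypothetical, key paths are plain tuples instead of JAX key paths, and results are returned directly rather than as `AsyncResponse`. It also shows the abstract leaf conveying the restore type: for Python scalars the abstract type is the scalar itself, so passing `0.0` as the abstract value requests a float.

```python
from __future__ import annotations

import asyncio
import dataclasses
import json
import pathlib
from typing import Any, Generic, TypeVar, Union

Leaf = TypeVar('Leaf')
AbstractLeaf = TypeVar('AbstractLeaf')
Scalar = Union[int, float]


@dataclasses.dataclass
class SerializationParam(Generic[Leaf]):
  name: str
  keypath: tuple[Any, ...]  # Stand-in for a JAX key path.
  value: Leaf


@dataclasses.dataclass
class SerializationContext:
  path: pathlib.Path


@dataclasses.dataclass
class DeserializationParam(Generic[AbstractLeaf]):
  name: str
  value: AbstractLeaf | None = None


@dataclasses.dataclass
class DeserializationContext:
  path: pathlib.Path


class ScalarLeafHandler:
  """Hypothetical LeafHandler[Scalar, Scalar] storing one JSON file per batch."""

  _FILENAME = 'scalars.json'

  async def serialize(
      self, params: list[SerializationParam[Scalar]], context: SerializationContext
  ) -> None:
    payload = {p.name: p.value for p in params}
    context.path.mkdir(parents=True, exist_ok=True)
    target = context.path / self._FILENAME
    # Write the whole batch at once, off the event loop.
    await asyncio.to_thread(target.write_text, json.dumps(payload))

  async def _read_all(self, context: DeserializationContext) -> dict[str, Scalar]:
    target = context.path / self._FILENAME
    return json.loads(await asyncio.to_thread(target.read_text))

  async def deserialize(
      self,
      params: list[DeserializationParam[Scalar]],
      context: DeserializationContext,
  ) -> list[Scalar]:
    stored = await self._read_all(context)
    # The abstract value's type, if given, is the requested restore type.
    return [
        type(p.value)(stored[p.name]) if p.value is not None else stored[p.name]
        for p in params
    ]

  async def metadata(
      self,
      params: list[DeserializationParam[Scalar]],
      context: DeserializationContext,
  ) -> list[Scalar]:
    # For scalars the abstract leaf is the scalar itself (int -> int, ...).
    stored = await self._read_all(context)
    return [stored[p.name] for p in params]

  def finalize(self) -> None:
    # Nothing to commit: serialize() already wrote the file.
    pass
```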

class ArrayHandler(LeafHandler[jax.Array, jax.ShapeDtypeStruct]):
  ...

class AbstractNumpyArray:
  @property
  def shape(self) -> tuple[int, ...]:
    ...
  @property
  def dtype(self) -> np.dtype:
    ...

class NumpyHandler(LeafHandler[np.ndarray, AbstractNumpyArray]):
  ...

Customizing per-leaf save behavior

An argument to PyTreeHandler that makes it easy to set per-leaf behavior is array_storage_options_creator. This is a function that is applied to the input PyTree via jax.tree.map_with_path and returns an ArrayStorageOptions struct, which contains a number of per-leaf settings. These options are only relevant to arrays, and are only applied to appropriate leaves.

@dataclasses.dataclass
class ArrayStorageOptions:
  # Cast a leaf to this dtype when saving.
  dtype: jnp.dtype | None = None
  # Target size, in bytes, for storage chunks.
  chunk_byte_size: int | None = None
  # Axes to prioritize for subchunking.
  shard_axes: tuple[int, ...] = ()


class ArrayStorageOptionsCreator(Protocol):
  """Creates arguments to customize per-leaf saving behavior.

  The function is called by `PyTreeHandler` using::

    jax.tree.map_with_path(array_storage_options_creator, checkpointable)

  The user may provide a function that returns `ArrayStorageOptions`, which
  will then be applied to each array leaf while saving.
  """

  def __call__(self, key: jax.tree_util.KeyPath, value: Any) -> ArrayStorageOptions:
    ...
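A user-supplied creator might look like the following hypothetical policy, which asks for float32 arrays to be cast to bfloat16 on save, except for normalization scales, and gives non-array leaves default options. It uses `jax.tree_util.tree_map_with_path` to mimic how PyTreeHandler is described as calling the creator. The ArrayStorageOptions dataclass is copied from above; the policy itself and the 64 MiB chunk size are arbitrary illustrative choices.

```python
from __future__ import annotations

import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass
class ArrayStorageOptions:
  dtype: jnp.dtype | None = None
  chunk_byte_size: int | None = None
  shard_axes: tuple[int, ...] = ()


def array_storage_options_creator(
    keypath: jax.tree_util.KeyPath, value: Any
) -> ArrayStorageOptions:
  """Hypothetical policy: store float32 arrays as bfloat16, except norm scales."""
  if not isinstance(value, (np.ndarray, jax.Array)):
    # Non-array leaves keep the defaults; the options would be ignored anyway.
    return ArrayStorageOptions()
  name = jax.tree_util.keystr(keypath)  # A string such as "['norm']['scale']".
  if value.dtype == np.float32 and 'norm' not in name:
    return ArrayStorageOptions(
        dtype=jnp.bfloat16, chunk_byte_size=64 * 1024 * 1024
    )
  return ArrayStorageOptions()


params = {
    'dense': {'kernel': np.ones((4, 4), np.float32)},
    'norm': {'scale': np.ones((4,), np.float32)},
    'step': 10,
}
# Mirrors how PyTreeHandler is described as applying the creator.
options = jax.tree_util.tree_map_with_path(array_storage_options_creator, params)
```

The result has the same structure as `params`, with one ArrayStorageOptions per leaf.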

Eliminating RestoreArgs, ArrayRestoreArgs, etc.

When Orbax was first created, there was no notion that every leaf type had a corresponding abstract type that could be used to restore it. (jax.Array was not yet solidified as a concept, and jax.ShapeDtypeStruct existed but sharding did not really exist yet.) As such, RestoreArgs was introduced to capture restoration arguments relevant to a particular leaf.

Now, rather than needing to learn an entirely new set of classes, the user only needs to understand the intuitive idea that every concrete leaf type has a corresponding abstract type, which conveys the leaf's properties without storing real data.

For standard types, these include:

jax.Array -> jax.ShapeDtypeStruct
np.ndarray -> AbstractNumpyArray  # Uses duck typing
int -> int
float -> float
bytes -> bytes
str -> str
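The mapping above can be applied to a whole PyTree to build the abstract tree a user would pass at restore time. This sketch assumes the AbstractNumpyArray duck type is a simple frozen dataclass, and `to_abstract` is a hypothetical helper rather than an Orbax function.

```python
from __future__ import annotations

import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class AbstractNumpyArray:
  """Duck-typed abstract np.ndarray: only shape and dtype."""
  shape: tuple[int, ...]
  dtype: np.dtype


def to_abstract(leaf: Any) -> Any:
  """Hypothetical helper mapping a concrete leaf to its abstract counterpart."""
  if isinstance(leaf, jax.Array):
    # Shape, dtype, and sharding describe how to restore the array.
    return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=leaf.sharding)
  if isinstance(leaf, np.ndarray):
    return AbstractNumpyArray(leaf.shape, leaf.dtype)
  if isinstance(leaf, (int, float, bytes, str)):
    # Scalars, bytes, and strings act as their own abstract type.
    return leaf
  raise TypeError(f'No abstract type known for {type(leaf)!r}')


state = {
    'params': jnp.zeros((2, 3), jnp.float32),
    'counts': np.zeros((3,), np.int32),
    'step': 7,
    'name': 'run-1',
}
abstract_state = jax.tree_util.tree_map(to_abstract, state)
```

To restore a leaf as a different type, a user would replace its abstract leaf, for example swapping the `jax.ShapeDtypeStruct` for an `AbstractNumpyArray` to request an np.ndarray.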

The abstract type itself conveys the desired restoration type, for example when the user wishes to convert a jax.Array to an np.ndarray. It also conveys the desired restoration shape, dtype, and sharding. The only other per-leaf restoration parameter used in Orbax today is strict, which controls whether padding and truncation are allowed. In practice, however, strict is always set globally rather than per leaf.

What if we need to add additional per-leaf restoration options in the future?

Any such option is likely to be highly specialized, since no such need has emerged in multiple years of Orbax development. In that case, it is always possible to introduce a new LeafHandler with a different abstract type that carries the additional properties.
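For instance, a hypothetical option controlling the value used when padding could live on a new abstract type, with the corresponding LeafHandler reading it during deserialization. Both `AbstractPaddableNumpyArray` and `restore_leaf` are invented for illustration. The function shows only the per-leaf pad/truncate logic such a handler might apply, not any existing Orbax behavior.

```python
from __future__ import annotations

import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class AbstractPaddableNumpyArray:
  """Hypothetical abstract type: AbstractNumpyArray plus one extra property."""
  shape: tuple[int, ...]
  dtype: np.dtype
  pad_value: float = 0.0  # The additional per-leaf restoration option.


def restore_leaf(
    stored: np.ndarray, abstract: AbstractPaddableNumpyArray
) -> np.ndarray:
  """Truncates or pads `stored` to `abstract.shape`, then casts it."""
  if len(abstract.shape) != stored.ndim:
    raise ValueError(f'Rank mismatch: {abstract.shape} vs {stored.shape}')
  # Keep at most the requested number of elements along each axis.
  truncated = stored[tuple(slice(0, n) for n in abstract.shape)]
  # Pad at the end of each axis that is still shorter than requested.
  pad_width = [
      (0, target - current)
      for target, current in zip(abstract.shape, truncated.shape)
  ]
  padded = np.pad(truncated, pad_width, constant_values=abstract.pad_value)
  return padded.astype(abstract.dtype)
```

Restoring a stored `(2, 3)` array into a requested `(3, 2)` shape keeps the first two columns and fills the new third row with `pad_value`.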
