Motivation
A variety of Orbax Checkpoint users have expressed concerns that the API is too complex. As one user stated:
I would say orbax makes the hard things possible, but makes the easy things hard. If we can progress to the point where it makes hard things possible, and easy things easy -- that would be brilliant!
Another user concurred:
There is a steep learning curve using orbax, compared to lets say safetensors.
Orbax core APIs need simplification in order to provide a better user experience.
Overview
Orbax Checkpoint (OCP) will introduce a V1 API (the current API denoted as “V0”), located at orbax.checkpoint.v1. This will serve as the new entry point for all users. However, migration will not be required. The public documentation page will reference only the V1 API.
The V0 API will continue to exist as the underlying implementation of the V1 API in the short term, while being gradually deprecated in the long term.
The new API will be designed to address a number of complaints about the current API, including its complexity, verbosity, and steep learning curve. The current API has failed to incorporate the principle of progressive disclosure of complexity, instead opting for maximum flexibility while failing to simplify common use cases. While maximum flexibility will still be possible, the most common use cases must be easy to understand and expressible in concise code.
API Surface
Basic users:
save_pytree / save_pytree_async - Allows saving a PyTree synchronously or asynchronously.
load_pytree / load_pytree_async - Allows loading a PyTree synchronously or asynchronously, plus retrieving metadata.
Checkpointer - The primary entry point for managing a sequence of checkpoints. It offers automatic garbage collection, configurable save interval policies, etc. with a similar interface to the free functions.
Intermediate users:
“checkpointable” - The concept of a logically distinct unit of a checkpoint that has minimal relation to the other units and is often separated for loading (e.g. params/opt_state, dataset, other metadata).
save_checkpointables / load_checkpointables - Save and load arbitrary checkpointables (e.g. dataset, embeddings) (PyTree may be one of the checkpointables).
Core customization behaviors like partial load, partial update, and model surgery.
Checkpointable - Support custom checkpointables by implementing an interface.
Advanced users:
CheckpointableHandler - Allows fine-grained save/restore behavior customization (particularly formats) for a given Checkpointable.
LeafHandler - Allows customizing PyTree save/restore behavior for custom leaf objects.
Context - Allows specifying configuration for particular operations (e.g. save/load). The above free functions can run within a given Context (configuration).
configure - Allows overriding the global Context and setting global options.
Use Cases
import orbax.checkpoint as ocp
Single-Checkpoint Use Cases
The following scenarios enumerate functionalities that users need when saving and restoring a single checkpoint, independently of the sequence of checkpoints that is typically required during training. These scenarios can be common when debugging checkpoints locally, or when running evaluations.
Many of the usage patterns listed here also apply when managing a sequence of checkpoints.
Save synchronously and asynchronously
ocp.save_pytree(path, pytree_state)

response = ocp.save_pytree_async(path, pytree_state)
response.result()  # Wait for completion
Restore synchronously and asynchronously
restored = ocp.load_pytree(path)

response = ocp.load_pytree_async(path)
restored = response.result()  # Wait for completion
Save and restore with optional arguments to customize behavior
Save multiple logically distinct checkpointables

ocp.save_checkpointables(path, dict(pytree=pytree, dataset=ds_iter, foo=foo))

ocp.load_checkpointables(path, dict(pytree=abstract_pytree))  # Restore only pytree
ocp.load_pytree(path, abstract_pytree)  # Same as above

result = ocp.load_checkpointables(
    path, dict(pytree=abstract_pytree, dataset=...)
)
pytree, dataset = result['pytree'], result['dataset']
Obtain metadata
ocp.metadata(path) # -> CheckpointMetadata
Support custom checkpointables
class MyCustomClass:  # Implements Checkpointable (see below)
  ...  # Some properties.

  async def save_async(self, path: Path) -> AsyncResponse[None]:
    serialized = self.properties_as_json()
    return await ocp.save_async(path, serialized)  # Saves as basic JSON file.

  @classmethod
  async def load_async(
      cls, path: Path, abstract_checkpointable: None = None,
  ) -> AsyncResponse[MyCustomClass]:
    # Loading this object does not require any abstract information, so we set
    # abstract_checkpointable to None. The user could also define an abstract
    # class corresponding to MyCustomClass (the concrete class) and accept that
    # instead.
    serialized = await ocp.load_async(path)  # Read JSON.
    return cls(**serialized)

  async def metadata(self, path: Path) -> None:
    return None

custom_obj = MyCustomClass(...)
ocp.save_checkpointables(path, dict(pytree=pytree, custom_obj=custom_obj))
Support custom formats
This differs from the above in that instead of some specific user-defined object that is always saved and loaded in the same way, we instead have a core object (like a PyTree) that needs to be handled in an alternative way. For example, this allows us to configure checkpointing behavior in a non-Orbax format (e.g. Roc or PyTorch).
Context based customization

See below for more details. The following example allows saving to a process-local filesystem.
# Update global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
ocp.save_pytree(local_fs_path, pytree)

# Save pytree in a specific context.
multiprocessing_options = MultiprocessingOptions(primary_host=None)
with ocp.Context(multiprocessing_options=multiprocessing_options):
  ocp.save_pytree(local_fs_path, pytree)
Sequence-of-Checkpoints Use Cases
These use cases roughly correspond to those served by the existing CheckpointManager object. Note, however, that the class now uses the name Checkpointer, and that it tries to reuse constructs introduced to serve the single-checkpoint use case.
Saving and restoring
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save
  f = ckptr.save_pytree_async(1, train_state)  # Async save
  f.result()
  ckptr.load_pytree()  # Restores the latest step
  ckptr.load_pytree(1)  # Restore a specific step
  ckptr.load_pytree(1, abstract_train_state)  # Reshard / cast
  f = ckptr.load_pytree_async(1)  # Async restore
  f.result()
Saving and restoring with multiple checkpointables

Determine when to save

# In the future, will default to ContinuousCheckpointingPolicy (at least for
# internal users).
save_decision_policy = EveryNStepsPolicy(steps=1000)
with ocp.Checkpointer(directory, save_decision_policy) as ckptr:
  ckptr.should_save(step)  # -> bool
Context based customization

# Update global context statically.
ocp.configure(
    multiprocessing_options=MultiprocessingOptions(primary_host=None),
)
...
with ocp.Checkpointer(directory) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer with a specific context.
# Overrides global context.
with ocp.Checkpointer(directory, context=ocp.Context(...)) as ckptr:
  ckptr.save_pytree(0, train_state)  # Save

# Checkpointer operation with a specific context.
# Overrides both global and Checkpointer-level context.
with ocp.Checkpointer(directory) as ckptr:
  with ocp.Context(...):
    ckptr.save_pytree(0, train_state)  # Save
Tree-Specific Use Cases
Restore with resharding/reshaping/casting

abstract_tree = {
    'a': jax.ShapeDtypeStruct(shape, dtype, sharding),  # restore as jax.Array
    'b': np.empty(shape, dtype),  # restore as np.ndarray
    'c': '',  # restore as string
}
ocp.load_pytree(path, abstract_tree)
Partial restoration
Partial restore is a way to solve the most common use case of loading a different tree than is present in the checkpoint - where leaves or subtrees can be omitted. The canonical example is to skip loading the optimizer state when you're doing evaluation.
In contrast, model surgery is the more complete version of this, where the user can manipulate trees/leaves in arbitrary ways, as well as load multiple trees and merge them.
abstract_tree = {
    'params': { … },
    # Note: omit 'opt_state' to avoid loading it.
    'step': None,  # Skip loading 'step'.
}
# Unsafe variant; partial_load would need to be True by default.
ocp.load_pytree(path, abstract_tree)
# Safe variant; partial_load must be opted into.
ocp.load_pytree(path, abstract_tree, partial_load=True)
Support custom tree leaves with Context

# With global context.
ocp.configure(leaf_handlers={MyCustomLeaf: CustomLeafHandler})
…
ocp.save_pytree(path, pytree)

# With local context.
with ocp.Context(leaf_handlers={MyCustomLeaf: CustomLeafHandler}):
  ocp.save_pytree(path, pytree_with_custom_leaves)
API Definitions
Overview
In orbax.checkpoint.v1, free functions will be the primary entry point for users into the library. These include save_pytree / load_pytree, which deal with single PyTrees, and save_checkpointables / load_checkpointables, which deal with multiple arbitrary checkpointables. (These functions also include async variants, and metadata access.)
While these functions operate at the level of an individual checkpoint path, the other main entry point is Checkpointer, which operates at the level of a root directory, under which a sequence of checkpoints corresponding to steps in a training loop are stored. This class makes restrictive assumptions about the set of tasks that a user will try to do and the patterns it is used under. In other words, it is obviously less flexible than APIs oriented around singular paths, but provides more features, like automatic garbage collection, metrics management, and save intervals. It will be suitable for many basic training loops, but not for more advanced users with greater customization needs.
The user-facing APIs discussed in this doc, especially the free functions, are based on some global configurations, e.g. multiprocessing options, timeouts, LeafHandlerRegistry, etc. These global configurations are called a Context and are implemented as a context manager. The Orbax operations discussed above can be invoked within a context manager to customize their behavior with a given configuration.
“Checkpointable” remains a key concept in the V1 API. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.
Different checkpointables are handled by CheckpointableHandler implementations. These provide the logic for saving and loading particular objects, and also identify which objects they are capable of processing. For user convenience, a Checkpointable interface is also provided, which allows tightly coupling checkpointing logic to a specific object. While less flexible than the CheckpointableHandler, it is also more intuitive and we intend to expose it as the main interface for checkpointable customization.
The lowest-level user-accessible abstraction is the LeafHandler, which deals specifically with processing individual leaves of a PyTree. Implementing and registering a subclass of this object allows storing custom leaves in a PyTree.
Free Functions
Free functions serve as the primary entry point for users.
Currently async saving in Orbax does not return futures to the user, but instead relies on a wait_until_finished method for the user to block on the result of the save. However, it makes sense to abandon this model for a few reasons.
First, in the typical training checkpointing use case, users rarely block on the result of save, and only do so before exiting the program, typically. They rely on the library itself to block if a save is already ongoing when they try to save again.
Second, the use of context managers makes both futures and wait_until_finished unnecessary in many cases, as exiting the context automatically waits. This use case may be more common for small-scale experimentation, or one-off PyTree writes.
Third, and most importantly, the introduction of load_pytree_async does not really have a viable alternative for providing its result to the user other than via a future. The benefits of aligning the APIs of save_pytree_async and load_pytree_async outweigh any other potential arguments, in my view.
Further note that we should aim to move away from the “Future” terminology. Despite its superficial familiarity to many users, the term can create confusion with other Future implementations.
Instead we will opt for a construct like AsyncResponse. This is a simple container class that is returned by asynchronous APIs. It contains a method like result that allows blocking on the save/load operation and retrieving the operation result. In this respect, it is similar to ocp.Future in the current codebase.
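As a rough illustration, an AsyncResponse could be little more than a thin wrapper over a background operation. The sketch below is a hypothetical stand-in (not the real Orbax implementation): it uses a thread pool for the background work, and a toy `save_pytree_async` that serializes to JSON.

```python
import concurrent.futures
import json
import pathlib
from typing import Callable, Generic, TypeVar

T = TypeVar('T')


class AsyncResponse(Generic[T]):
  """Container returned by async APIs; result() blocks until completion."""

  def __init__(self, fn: Callable[[], T]):
    self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    self._future = self._executor.submit(fn)

  def result(self, timeout=None) -> T:
    # Block until the background operation finishes, then return its result.
    try:
      return self._future.result(timeout=timeout)
    finally:
      self._executor.shutdown(wait=False)


def save_pytree_async(path, pytree) -> AsyncResponse[None]:
  # Toy stand-in for the real API: serialize the tree to JSON in the background.
  def _save():
    pathlib.Path(path).write_text(json.dumps(pytree))
  return AsyncResponse(_save)
```

The caller then blocks only when it actually needs the result, e.g. `save_pytree_async(path, tree).result()`.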
Alternatives to save_pytree/save_checkpointables, load_pytree/load_checkpointables
We can combine both functionalities into one method, e.g.
The real difficulty is how to represent the return type for load. In order to mirror the function inputs, we must return a tuple of (pytree, extra_checkpointables). (We could also just return a single dictionary representing all checkpointables, but this is undesirable for users only providing pytree, since they will not necessarily be aware of how the checkpoint is represented, or that checkpointables is a concept they need to know. Returning a tuple of (pytree, extra_checkpointables) also requires a user to know that extra_checkpointables is an important argument, but is less bad than requiring the user to know the arbitrary name “pytree” in order to access their loaded tree.)
If we accept that a tuple is the ideal return type, in order to mirror the inputs, the return type must be:
tuple[PyTree, dict[str, Any] |None]
This interface is fairly inflexible and not that user friendly. The issues are:
pytree must be present in every checkpoint.
There is no way to restore extra_checkpointables and not pytree, since abstract_pytree=None is used to indicate “restore the pytree however you can”.
Different treatment of None for abstract_pytree and extra_checkpointables is confusing.
The return type is complex and not well-suited to users only interested in the pytree (this violates progressive disclosure of complexity).
Return types depend on both inputs and what is in the checkpoint.
The combined interface is more trouble than it’s worth. Splitting into load_pytree / load_checkpointables meshes well with progressive disclosure of complexity, simplifies input and output signatures, and makes return types more predictable, at the cost of using two functions instead of one.
A “global configuration” is a setting that applies at multiple levels of the Orbax stack. These settings must be applied in the same way to multiple different layers, or the inconsistency can result in unexpected errors.
This includes groups of options like:
@dataclasses.dataclass(frozen=True, kw_only=True)
class Options:
  # save_timeout/load_timeout, etc.
  async_options: AsyncOptions
  # Settings for e.g. primary_host, active_processes
  multiprocessing_options: MultiprocessingOptions
  # Options controlling path permissions, data governance annotations (internal),
  # CNS2 options (internal), etc.
  file_options: FileOptions
  # Options for enabling hashing/signing behavior in save/load.
  signing_options: SigningOptions
  # Other options
  ...
For example, a setting like save_timeout/load_timeout applies globally to an entire operation, rather than being set separately for save/load, CheckpointableHandler, and LeafHandler. Another example is primary_host, which must have the same setting in every layer, or risk difficult-to-debug breakages.
Practically speaking, global options (corresponding roughly to the existing options.py) are not commonly modified. They are typically set once as global configurations. In rare cases, individual operations may need to modify these settings with greater flexibility.
All configurations
Orbax provides a lot of options for configuring specific behaviors at various levels.
For the V1 API, we can subdivide options into a number of categories. Of these, the most interesting are PyTree-related options and Array-related options, which comprise the bulk of all options.
AsyncOptions
timeout_secs
barrier_sync_fn
post_finalization_callback
create_directories_asynchronously
MultiprocessingOptions
primary_host
active_processes
barrier_sync_key_prefix
FileOptions
path_permission_mode
data_governance_annotations
cns2_storage_options
temporary_path_class # Atomicity
SecurityOptions
tree_verity_options
PyTrees
array_storage_options_creator # Creates an ArrayStorageOptions struct on a per-leaf basis that customizes save behavior for individual array leaves. If ArrayStorageOptions are set globally, this option will override them.
leaf_handler_registry # LeafHandlers used for PyTree leaves
enable_descriptor
pytree_metadata_options
array_metadata_validator # Not user-facing, mostly for internal testing
partial_update # Enable partial tree update
partial_load # Enable partial tree loading
Arrays
Saving
concurrent_bytes
OCDBT options
use_ocdbt
ocdbt_target_data_file_size
enable_post_merge_validation
Storage options # Can be customized per-array if we have multiple arrays.
dtype # cast type for storage
chunk_byte_size # loose cap on the size of Tensorstore chunks
shard_axes # Chunks subdivided along this axis
metadata_key # .zarray file name
use_zarr3
enable_pinned_host_transfer
enable_write_sharding_file
use_replica_parallel
Loading
concurrent_bytes
enable_padding_and_truncation
Single-replica restore+broadcast
replica_axis_index
primary_replica_id
broadcast_memory_limit_bytes
broadcast_memory_scaling_factor
It is clear that most of these options do not need to be modified often. When they do, a global setting is possible, and only in rare cases do users need per-operation customization. If all such settings are placed within a global Options, there is considerably less doubt for users about where to find a particular setting. In an inherently complicated landscape with many different settings, it will never be “easy” to find a particular option, but it will be less difficult if all are placed under a common structure.
It is important to distinguish settings from options that unlock commonly-used operations, like model surgery, partial loading, partial updating, and forced overwrite. These operations are core functionalities that users often wish to enable or disable. As such, they should be located in the signature of the function they are used in (e.g. partial_load: bool in load, and force: bool in save). These options can still be settings in Options, to enable global defaults, but can also be exposed directly in save or load as local overrides.
Checkpointing in a Training Loop
In the existing library, the division of labor between Checkpointer and CheckpointManager has not always been well understood. This is because users often conceptualize these terms interchangeably. Now, however, the user-facing API is oriented around free functions (save/load). We can now have a single Checkpointer class that behaves much as the current CheckpointManager does.
In the long run, we will aim to achieve a level of composability that allows users to effectively write their own implementation of Checkpointer with minimal additional code. The Checkpointer itself should be a Protocol with a rigid interface - we will be resistant to adding new features without substantial discussion and agreement that the proposed feature is a core element of checkpointing in a training loop.
Checkpointer will live under orbax.checkpoint.training to make explicit its intended use for training loops. Users with greater customization requirements will be encouraged to use lower-level APIs.
The save/load interface should mirror the free functions almost identically.
class Checkpointer:

  def __init__(
      self,
      root_directory: epath.PathLike | RootDirectoryFormat,
      *,
      # Default to continuous checkpointing.
      save_decision_policy: SaveDecisionPolicy | None = None,
      # Default to the latest few. See design.
      preservation_policy: PreservationPolicy | None = None,
      step_name_format: NameFormat[step_lib.Metadata] | None = None,
      metric_comparator: MetricComparator | None = None,
      # Default to async deletion.
      deletion_options: DeletionOptions | None = None,
      # Context
      context: Context | None = None,
      …
  ):
    …

  @property
  def directory(self) -> Path:
    …

  def latest_step(self) -> StepInfo:
    …

  def steps(self) -> Sequence[StepInfo]:
    …

  def save_pytree(...)
  def save_checkpointables(...)
  def save_pytree_async(...)
  def save_checkpointables_async(...)
  def load_pytree(...)
  def load_checkpointables(...)
  def load_pytree_async(...)
  def load_checkpointables_async(...)

  def metadata(self, step: int | None = None) -> RootMetadata | CheckpointMetadata:
    """Retrieves root-directory-level metadata."""
    …

  def reload(self):
    """Reload internal properties from the root directory."""
    …
Checkpointable
“Checkpointable” is a core concept, at least for any Orbax user beyond a beginner level. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.
While the V0 API drew a distinction between “items” and “PyTree leaves”, this distinction was unnecessary. A jax.Array, str, or scalar are “checkpointable” objects in the same way that a PyTree composed of these objects is.
We can introduce a Protocol to represent this concept called Checkpointable, which defines methods needed to save and load the object.
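The exact protocol is not spelled out here; the sketch below infers a plausible shape from the MyCustomClass example earlier in the doc. The method names, the `runtime_checkable` decoration, and the toy `CounterState` implementation are assumptions for illustration:

```python
from pathlib import Path
from typing import Protocol, runtime_checkable


@runtime_checkable
class Checkpointable(Protocol):
  """Objects that know how to save and load themselves at a path."""

  async def save_async(self, path: Path):  # -> AsyncResponse[None]
    ...

  @classmethod
  async def load_async(cls, path: Path, abstract_checkpointable=None):
    ...  # -> AsyncResponse of the concrete type

  async def metadata(self, path: Path):
    ...


# A minimal conforming type: persists a single counter as text.
class CounterState:

  def __init__(self, count: int = 0):
    self.count = count

  async def save_async(self, path: Path):
    path.write_text(str(self.count))

  @classmethod
  async def load_async(cls, path: Path, abstract_checkpointable=None):
    return cls(int(path.read_text()))

  async def metadata(self, path: Path):
    return None
```

Because the protocol is structural, `CounterState` conforms without inheriting from anything; `isinstance(CounterState(), Checkpointable)` holds.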
When I have a certain object that requires customized logic for serialization, I can easily define the saving and loading logic associated with that object.
Implementing the Checkpointable interface should not be necessary in most cases, as most users simply want to save an object like an array or a PyTree. Even in the case of custom, user-defined PyTree objects, Checkpointable should be rarely needed.
Furthermore, it is important to handle the case where a single PyTree may be saved in multiple different ways. This is common when writing format converters (e.g. Roc -> Orbax, PyTorch -> Orbax, etc.). In these cases, the checkpointable object is the same, but the checkpointing logic is different.
For these cases, CheckpointableHandler makes more sense, as this provides checkpointing logic for a recognizable type and can be swapped in and out as needed.
T = TypeVar('T')
AbstractT = TypeVar('AbstractT')

class CheckpointableHandler(Protocol[T, AbstractT]):

  async def save(
      …
  ) -> AsyncResponse[None]:
    …

  async def load(
      …
  ) -> AsyncResponse[T]:
    …

  async def metadata(self, directory: epath.Path) -> AsyncResponse[AbstractT]:
    …

  def is_handleable(self, checkpointable: T | AbstractT) -> bool:
    """Given any object, determine whether it can be stored with this handler."""
    …
Here are some concrete examples:
class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]):

  async def save(self, path: Path, checkpointable: PyTree) -> AsyncResponse[None]:
    leaf_handler_types = collect_per_leaf_checkpointable_handlers(checkpointable)
    save_responses = []
    for ht, leaf in leaf_handler_types:
      # Construct the per-leaf handler.
      save_responses.append(await ht().save(path, leaf))
    tree_metadata_response = await save_tree_metadata(path, checkpointable)
    # Include finalize behavior in this response.
    return UnifiedResponse(
        *save_responses,
        tree_metadata_response,
    )

# Sometimes, checkpointables are always restorable without an
# abstract checkpointable, in which case it may be None.
class DatasetHandler(CheckpointableHandler[tf.data.Iterator, None]):
  …

# For a singular np.ndarray handler, we can define the following abstract type:
class AbstractNumpyArray:

  @property
  def shape(self) -> tuple[int, …]:
    …

  @property
  def dtype(self) -> np.dtype:
    …

class NumpyHandler(CheckpointableHandler[np.ndarray, AbstractNumpyArray]):
  …
Determining which CheckpointableHandler can save/restore a checkpointable
When the user provides an object to save or restore, how can we determine which handler is appropriate to deal with this object? Ultimately, we do not have that many core CheckpointableHandlers. These include PyTree, JSON, Proto, and Array. Except for a JSON object, which is by definition a PyTree, all objects are easily distinguishable. Furthermore, a user generally doesn’t care how their object is stored, as long as it can be stored successfully. Users seeking maximum performance will go hunting for the ideal handler of their own accord. Setting reasonable defaults thus satisfies both beginners and advanced users.
Each handler can define an is_handleable method that determines whether it is capable of storing the given object. When the user does not explicitly specify a handler, we check all globally-registered CheckpointableHandlers and select the first one capable of saving or restoring the object. Registration order matters, so we can ensure JsonHandler is always preferred for sufficiently simple objects (rather than PyTreeHandler).
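First-match resolution over an ordered registry can be sketched as below. The handler classes, their `is_handleable` logic, and `resolve_handler` are simplified stand-ins (not the real Orbax registry), but they show why registration order lets a JSON handler win over a generic PyTree handler for sufficiently simple objects:

```python
from typing import Any


class JsonHandler:

  def is_handleable(self, c: Any) -> bool:
    # Only plain JSON-serializable scalars and containers.
    if isinstance(c, (str, int, float, bool)) or c is None:
      return True
    if isinstance(c, dict):
      return all(
          isinstance(k, str) and self.is_handleable(v) for k, v in c.items()
      )
    if isinstance(c, (list, tuple)):
      return all(self.is_handleable(v) for v in c)
    return False


class PyTreeHandler:

  def is_handleable(self, c: Any) -> bool:
    # Any container structure, regardless of leaf types.
    return isinstance(c, (dict, list, tuple))


# Order matters: JsonHandler is registered first, so it is preferred
# whenever both handlers could store the object.
_REGISTRY = [JsonHandler(), PyTreeHandler()]


def resolve_handler(checkpointable: Any):
  for handler in _REGISTRY:
    if handler.is_handleable(checkpointable):
      return handler
  raise ValueError(f'No registered handler for {type(checkpointable)}')
```

A dict of strings and numbers resolves to `JsonHandler`, while the same dict with an array-like leaf falls through to `PyTreeHandler`.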
An argument to PyTreeHandler that allows easily setting per-leaf behaviors is the array_storage_options_creator. This is just a function that can be applied to the input PyTree via jax.tree.map_with_path and returns a ArrayStorageOptions struct, which contains a number of per-leaf settings. These options are only relevant to arrays, and are only applied to appropriate leaves.
@dataclasses.dataclass
class ArrayStorageOptions:
  # Cast a leaf when saving.
  dtype: jnp.dtype | None = None
  # Specify a target size for storage chunks.
  chunk_byte_size: int | None = None
  # Specify axes to prioritize for subchunking.
  shard_axes: tuple[int, …] = tuple()

class ArrayStorageOptionsCreator(Protocol):
  """Creates arguments to customize per-leaf saving behavior.

  The function is called by `PyTreeHandler` using::

    jax.tree.map_with_path(storage_options_creator, checkpointable)

  The user may provide a function that returns `StorageOptions`, which will
  then be applied to each leaf while saving.
  """

  def __call__(self, key: jax.tree.KeyPath, value: Any) -> ArrayStorageOptions:
    …
Eliminating RestoreArgs, ArrayRestoreArgs, etc.
When Orbax was first created, there was no notion that every leaf type had a corresponding abstract type that could be used to restore it. (jax.Array was not yet solidified as a concept, and jax.ShapeDtypeStruct existed but sharding did not really exist yet.) As such, RestoreArgs was introduced to capture restoration arguments relevant to a particular leaf.
Now, rather than needing to know and understand an entirely new set of classes, the user only needs to understand the rather intuitive idea that every concrete leaf type has a corresponding abstract type, that conveys properties without storing real data.
The abstract type itself conveys the desired restoration type, if the user wishes to convert from jax.Array to np.ndarray, for example. The desired restoration shape, dtype, and sharding are also conveyed. The only other per-leaf restoration parameter used in Orbax today is strict, which controls whether padding and truncation are allowed. In practice, however, this is always a setting used at a global level, not a per-leaf level.
What if we need to add additional per-leaf restoration options in the future?
Any such option is likely to be highly specialized, as no such need has been revealed after multiple years of Orbax development. It is always possible in this case to introduce a new LeafHandler with a different abstract type that carries additional properties.
Feedback appreciated!
Motivation
A variety of Orbax Checkpoint users have expressed concerns that the API is too complex. As one user stated:
Another user concurred:
Orbax core APIs need simplification in order provide a better user experience.
Overview
Orbax Checkpoint (OCP) will introduce a V1 API (the current API denoted as “V0”), located at
orbax.checkpoint.v1
. This will serve as the new entry point for all users. However, migration will not be required. The public documentation page will reference only the V1 API.The V0 API will continue to exist as the underlying implementation of the V1 API in the short term, while being gradually deprecated in the long term.
The new API will be designed to address a number of complaints about the current API, including its complexity, verbosity, and steep learning curve. The current API has failed to incorporate the principle of progressive disclosure of complexity, instead opting for maximum flexibility while failing to simplify common use cases. While maximum flexibility will still be possible, the most common use cases must be easy to understand and expressable in concise code.
API Surface
Basic users:
save_pytree / save_pytree_async
- Allows saving a PyTree synchronously or asynchronously.load_pytree / load_pytree_async
- Allows loading a PyTree synchronously or asynchronously, plus retrieving metadata.Checkpointer
- The primary entry point for managing a sequence of checkpoints. It offers automatic garbage collection, configurable save interval policies, etc. with a similar interface to the free functions.Intermediate users:
save_checkpointables
/load_checkpointables
- Save and load arbitrary checkpointables (e.g. dataset, embeddings) (PyTree may be one of the checkpointables).Checkpointable
- Support custom checkpointables by implementing an interface.Advanced users:
CheckpointableHandler
- Allows fine-grained save/restore behavior customization (particularly formats) for a givenCheckpointable
.LeafHandler
- Allows customizing PyTree save/restore behavior for custom leaf objects.Context
- Allows specific configuration and working with operations (e.g. save/load). Allows above free functions to run in a specific given Context (configuration).configure
- Allows overriding the globalContext
and setting global options.Use Cases
Single-Checkpoint Use Cases
The following scenarios enumerate functionalities that users need when saving and restoring a single checkpoint, independently of the sequence of checkpoints that is typically required during training. These scenarios can be common when debugging checkpoints locally, or when running evaluations.
Many of the usage patterns listed here also apply when managing a sequence of checkpoints.
Save synchronously and asynchronously
Restore synchronously and asynchronously
Save and restore with optional arguments to customize behavior
Save multiple logically distinct checkpointables
Obtain metadata
Support custom checkpointables
Support custom formats
This differs from the above in that instead of some specific user-defined object that is always saved and loaded in the same way, we instead have a core object (like a PyTree) that needs to be handled in an alternative way. For example, this allows us to configure checkpointing behavior in a non-Orbax format (e.g. Roc or PyTorch).
Context based customization
See below for more details. The following example allows saving to a process-local filesystem.
Sequence-of-Checkpoints Use Cases
These use cases roughly correspond to those served by the existing
CheckpointManager
object. Note however that now the class is using the nameCheckpointer
, and that it tries to reuse constructs introduced to serve the single-checkpoint use case.Saving and restoring
Saving and restoring with multiple checkpointables
Obtain metadata
Determine when to save
Identify existing checkpoints
Handle garbage collection
Rank checkpoints by metrics
Context based customization
Training loop example
Note: model surgery complexity omitted.
Tree-Specific Use Cases
Restore with resharding/reshaping/casting
Partial restoration
Partial restore solves the most common use case of loading a different tree than is present in the checkpoint: leaves or subtrees can be omitted. The canonical example is skipping the optimizer state when running evaluation.
In contrast, model surgery is the more complete version of this, where the user can manipulate trees and leaves in arbitrary ways, as well as load multiple trees and merge them.
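The optimizer-state example can be sketched with plain dictionaries. This is a toy illustration of the idea, not Orbax code; the function name is hypothetical.

```python
# What is stored in the checkpoint on disk.
checkpoint = {
    "params": {"w": [1.0, 2.0]},
    "opt_state": {"momentum": [0.0, 0.0]},
}

def load_pytree_partial(checkpoint, abstract_tree):
    # Load only the subtrees named in the user-provided abstract tree;
    # anything omitted from it (here, opt_state) is simply skipped.
    return {k: checkpoint[k] for k in abstract_tree}

# Evaluation: the abstract tree omits the optimizer state entirely.
abstract_tree = {"params": None}
restored = load_pytree_partial(checkpoint, abstract_tree)
```

Model surgery generalizes this by letting the user remap and merge subtrees rather than only omit them.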
Restore with model surgery
Multi-model restore with model surgery
Partial write (update)
Support custom tree leaves with Context
API Definitions
Overview
In orbax.checkpoint.v1, free functions will be the primary entry point for users into the library. These include save_pytree/load_pytree, which deal with single PyTrees, and save_checkpointables/load_checkpointables, which deal with multiple arbitrary checkpointables. (These functions also include async variants and metadata access.)

While these functions operate at the level of an individual checkpoint path, the other main entry point is Checkpointer, which operates at the level of a root directory, under which a sequence of checkpoints corresponding to steps in a training loop is stored. This class makes restrictive assumptions about the set of tasks a user will perform and the patterns under which it is used. In other words, it is less flexible than APIs oriented around singular paths, but provides more features, like automatic garbage collection, metrics management, and save intervals. It will be suitable for many basic training loops, but not for more advanced users with greater customization needs.

The user-facing APIs discussed in this doc, especially the free functions, depend on some global configurations, e.g. multiprocessing options, timeouts, the LeafHandlerRegistry, etc. These global configurations are called Context and are implemented as a context manager. The Orbax operations discussed above can be invoked within the context manager to customize their behavior with the given configuration.

"Checkpointable" remains a key concept in the V1 API. A "checkpointable" refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently, and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.
Different checkpointables are handled by CheckpointableHandler implementations. These provide the logic for saving and loading particular objects, and also identify which objects they are capable of processing. For user convenience, a Checkpointable interface is also provided, which allows tightly coupling checkpointing logic to a specific object. While less flexible than the CheckpointableHandler, it is more intuitive, and we intend to expose it as the main interface for checkpointable customization.

The lowest-level user-accessible abstraction is the LeafHandler, which deals specifically with processing individual leaves of a PyTree. Implementing and registering a subclass of this object allows storing custom leaves in a PyTree.

Free Functions
Free functions serve as the primary entry point for users.
Futures vs. wait_until_finished
Currently, async saving in Orbax does not return futures to the user, but instead relies on a wait_until_finished method for the user to block on the result of the save. However, it makes sense to abandon this model for a few reasons.

First, in the typical training checkpointing use case, users rarely block on the result of a save, typically doing so only before exiting the program. They rely on the library itself to block if a save is already ongoing when they try to save again.

Second, the use of context managers makes both futures and wait_until_finished unnecessary in many cases, as exiting the context automatically waits. This use case may be more common for small-scale experimentation, or one-off PyTree writes.

Third, and most importantly, the introduction of async_load has no viable alternative for providing its result to the user other than via a future. The benefits of aligning the APIs of save_async and async_load outweigh any other potential arguments, in my view.

Further, note that we should aim to move away from the "Future" terminology. Despite its superficial familiarity to many users, it can create confusion with other Future implementations.
Instead, we will opt for a construct like AsyncResponse. This is a simple container class that is returned by asynchronous APIs. It contains a method like result that allows blocking on the save/load operation and retrieving the operation's result. In this respect, it is similar to ocp.Future in the current codebase.

Alternatives to save_pytree/save_checkpointables, load_pytree/load_checkpointables
We can combine both functionalities into one method, e.g.
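A sketch of what such a combined interface might look like, using the argument names discussed below. This is the hypothetical alternative under discussion, not the adopted API; the in-memory store exists only to make the example runnable.

```python
from typing import Any

_STORE: dict[str, tuple] = {}  # stand-in for checkpoints on disk

def save(path: str, pytree: Any = None, extra_checkpointables=None) -> None:
    # Combined save: a PyTree plus any other checkpointables in one call.
    _STORE[path] = (pytree, extra_checkpointables)

def load(path: str, abstract_pytree: Any = None, extra_checkpointables=None):
    # Combined load: must return (pytree, extra_checkpointables) to mirror
    # the inputs, which is the difficulty discussed below.
    return _STORE[path]

save("/tmp/ckpt", pytree={"w": 1},
     extra_checkpointables={"dataset": "iter-state"})
pytree, extra = load("/tmp/ckpt")
```

Note that even this trivial version forces the caller to unpack a tuple and to know the roles of both slots.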
The real difficulty is how to represent the return type for load. In order to mirror the function inputs, we must return a tuple of (pytree, extra_checkpointables). (We could also just return a single dictionary representing all checkpointables, but this is undesirable for users who only provide pytree, since they will not necessarily be aware of how the checkpoint is represented, or that checkpointables is a concept they need to know about. Returning a tuple of (pytree, extra_checkpointables) also requires a user to know that extra_checkpointables is an important argument, but this is less bad than requiring the user to know the arbitrary name "pytree" in order to access their loaded tree.)

If we accept that a tuple is the ideal return type, then in order to mirror the inputs, the return type must be:
This interface is fairly inflexible and not that user friendly. The issues are:
pytree must be present in every checkpoint.
Loading only extra_checkpointables and not pytree is awkward, since abstract_pytree=None is used to indicate "restore the pytree however you can".
Passing None for both abstract_pytree and extra_checkpointables is confusing.

The combined interface is more trouble than it's worth. Splitting into load_pytree/load_checkpointables meshes well with progressive disclosure of complexity, simplifies input and output signatures, and makes return types more predictable, at the cost of using two functions instead of one.

Context and Configuration
Dealing with global configurations (Context)
A “global configuration” is a setting that applies at multiple levels of the Orbax stack. These settings must be applied in the same way to multiple different layers, or the inconsistency can result in unexpected errors.
This includes groups of options like:
For example, a setting like save_timeout/load_timeout applies globally to an entire operation, rather than being set separately for save/load, CheckpointableHandler, and LeafHandler. Another example is primary_host, which must have the same setting in every layer, or risk difficult-to-debug breakages.

Practically speaking, global options (corresponding roughly to the existing options.py) are not commonly used. They are typically set once as global configurations. In rare cases, individual operations may need to modify the settings with greater flexibility.

All configurations
Orbax provides a lot of options for configuring specific behaviors at various levels.
For the V1 API, we can subdivide options into a number of categories. Of these, the most interesting are PyTree-related options and Array-related options, which comprise the bulk of all options.
It is clear that most of these options do not need to be modified often. When they do, a global setting is possible, and only in rare cases do users need per-operation customization. If all such settings are placed within a global Options, there is considerably less doubt for users about where to find a particular setting. In an inherently complicated landscape with many different settings, it will never be "easy" to find a particular option, but it will be less difficult if all are placed under a common structure.

It is important to distinguish settings from options that unlock commonly-used operations, like model surgery, partial loading, partial updating, and forced overwrite. These operations are core functionalities that users often wish to enable or disable. As such, they should be located in the signature of the function they are used in (e.g. partial_load: bool in load, and force: bool in save). These options can still be settings in Options, to enable global defaults, but can also be exposed directly in save or load as local overrides.

Checkpointing in a Training Loop
In the existing library, the division of labor between Checkpointer and CheckpointManager has not always been well understood, because users often conceptualize these terms interchangeably. Now, however, the user-facing API is oriented around free functions (save/load), so we can have a single Checkpointer class that behaves much as the current CheckpointManager does.

In the long run, we will aim to achieve a level of composability that allows users to effectively write their own implementation of Checkpointer with minimal additional code. The Checkpointer itself should be a Protocol with a rigid interface; we will be resistant to adding new features without substantial discussion and agreement that the proposed feature is a core element of checkpointing in a training loop.

Checkpointer will live under orbax.checkpoint.training to make explicit its intended use in training loops. Users with greater customization requirements will be encouraged to use lower-level APIs.

The save/load interface should mirror the free functions almost identically.
Checkpointable
“Checkpointable” is a core concept, at least for any Orbax user beyond a beginner level. A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.
While the V0 API drew a distinction between "items" and "PyTree leaves", this distinction was unnecessary. A jax.Array, str, or scalar is a "checkpointable" object in the same way that a PyTree composed of these objects is.

We can introduce a Protocol to represent this concept, called Checkpointable, which defines the methods needed to save and load the object. When I have a certain object that requires customized serialization logic, I can easily define the saving and loading logic associated with that object.
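A sketch of what such a Protocol could look like. The proposal does not spell out the method names here, so save/load and the DatasetIterator example are assumptions for illustration.

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Checkpointable(Protocol):
    def save(self, directory: str) -> None: ...
    @classmethod
    def load(cls, directory: str) -> "Checkpointable": ...

class DatasetIterator:
    """Example object that couples its own checkpointing logic to itself."""

    def __init__(self, position: int = 0):
        self.position = position

    def save(self, directory: str) -> None:
        # A real implementation would write self.position under `directory`.
        self._saved_to = directory

    @classmethod
    def load(cls, directory: str) -> "DatasetIterator":
        # A real implementation would read the stored position back.
        return cls(position=0)

it = DatasetIterator(position=7)
```

The object itself knows how to serialize and restore, which is the "tight coupling" convenience described above.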
Implementing the Checkpointable interface should not be necessary in most cases, as most users simply want to save an object like an array or a PyTree. Even in the case of custom, user-defined PyTree objects, Checkpointable should rarely be needed.

Furthermore, it is important to handle the case where a single PyTree may be saved in multiple different ways. This is common when writing format converters (e.g. Roc -> Orbax, PyTorch -> Orbax, etc.). In these cases, the checkpointable object is the same, but the checkpointing logic is different.

For these cases, CheckpointableHandler makes more sense, as it provides checkpointing logic for a recognizable type and can be swapped in and out as needed.

Here are some concrete examples:
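For instance, the same PyTree can be routed through interchangeable handlers. Both handler classes below are hypothetical stand-ins (not Orbax code); they show that the checkpointable object stays fixed while the checkpointing logic is swapped.

```python
from typing import Any

class OrbaxPyTreeHandler:
    """Stand-in for the native PyTree handler."""

    def save(self, tree: Any) -> str:
        return f"orbax-format({sorted(tree)})"

class TorchConverterHandler:
    """Hypothetical stand-in for a PyTorch -> Orbax style converter."""

    def save(self, tree: Any) -> str:
        return f"torch-format({sorted(tree)})"

tree = {"w": [1.0], "b": [0.0]}
# Same checkpointable object, different checkpointing logic:
native = OrbaxPyTreeHandler().save(tree)
converted = TorchConverterHandler().save(tree)
```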
Determining which CheckpointableHandler can save/restore a checkpointable
When the user provides an object to save or restore, how can we determine which handler is appropriate for that object? Ultimately, we do not have that many core CheckpointableHandlers. These include PyTree, JSON, Proto, and Array. Except for a JSON object, which is by definition a PyTree, all objects are easily distinguishable. Furthermore, a user generally doesn't care how their object is stored, as long as it can be stored successfully. Users seeking maximum performance will go hunting for the ideal handler of their own accord. Setting reasonable defaults thus satisfies both beginners and advanced users.

Each handler can define an is_handleable method that determines whether it is capable of storing a given object. When the user does not explicitly specify a handler, we check all globally-registered CheckpointableHandlers and select the first one capable of saving or restoring the object. Registration order matters, so we can ensure JsonHandler is always preferred for sufficiently simple objects (rather than PyTreeHandler).

PyTree Leaf Handlers
Customizing per-leaf save behavior
An argument to PyTreeHandler that allows easily setting per-leaf behaviors is array_storage_options_creator. This is just a function that is applied to the input PyTree via jax.tree.map_with_path and returns an ArrayStorageOptions struct, which contains a number of per-leaf settings. These options are only relevant to arrays, and are only applied to appropriate leaves.

Eliminating RestoreArgs, ArrayRestoreArgs, etc.
When Orbax was first created, there was no notion that every leaf type had a corresponding abstract type that could be used to restore it. (jax.Array was not yet solidified as a concept, and jax.ShapeDtypeStruct existed, but sharding did not really exist yet.) As such, RestoreArgs was introduced to capture restoration arguments relevant to a particular leaf.

Now, rather than needing to know and understand an entirely new set of classes, the user only needs to understand the rather intuitive idea that every concrete leaf type has a corresponding abstract type that conveys its properties without storing real data.
For standard types, these include:
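For example, jax.Array naturally pairs with jax.ShapeDtypeStruct, which records shape, dtype, and sharding without holding data. A minimal dependency-free sketch of the concept, where AbstractArray and to_abstract are hypothetical stand-ins rather than Orbax classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AbstractArray:
    """Conveys shape and dtype (and, for jax.Array, sharding) without data."""
    shape: tuple
    dtype: str

def to_abstract(value):
    # Concrete -> abstract: keep only the properties needed for restoration.
    if hasattr(value, "shape"):
        return AbstractArray(tuple(value.shape), str(value.dtype))
    return AbstractArray((), type(value).__name__)  # scalars and the like

class FakeArray:
    """Stand-in for a concrete array leaf (e.g. jax.Array or np.ndarray)."""
    shape = (2, 3)
    dtype = "float32"

abstract = to_abstract(FakeArray())
```

Passing such an abstract tree to load then fully specifies the desired restoration, replacing the role of RestoreArgs.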
The abstract type itself conveys the desired restoration type, if the user wishes to convert from jax.Array to np.ndarray, for example. The desired restoration shape, dtype, and sharding are also conveyed. The only other per-leaf restoration parameter used in Orbax today is strict, which controls whether padding and truncation are allowed. In practice, however, this setting is always used at a global level, not a per-leaf level.

What if we need to add additional per-leaf restoration options in the future? Any such option is likely to be highly specialized, as no such need has arisen over multiple years of Orbax development. In that case, it is always possible to introduce a new LeafHandler with a different abstract type that carries the additional properties.