Skip to content

Commit

Permalink
Add a _to_jax_shape function to allow partial TF shape to JAX symbo…
Browse files Browse the repository at this point in the history
…lic shape conversion.

PiperOrigin-RevId: 732931969
  • Loading branch information
Orbax Authors committed Mar 3, 2025
1 parent 6cf3572 commit 7540bbe
Show file tree
Hide file tree
Showing 19 changed files with 857 additions and 351 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Test with pytest
# TODO(yaning): Move these to an exclude target within pytest.ini.
run: |
python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
Expand Down
4 changes: 3 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- support saving and restoring jax.random.key() in PyTree
- support saving and restoring jax.random.key() in PyTree.
- `CheckpointableHandler` for V1.
- Support single-slice checkpointing in `emergency.CheckpointManager`.

## [0.11.6] - 2025-02-20

Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/_src/multihost/multislice.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def slice_count(
global_mesh: jax.sharding.Mesh, *, replica_axis_index: int = 0
) -> int:
"""Number of slices implied by the mesh's replica dimension."""
if len(global_mesh.shape_tuple) == 1:
return 1
return global_mesh.devices.shape[replica_axis_index]


Expand Down
181 changes: 12 additions & 169 deletions checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def best_step(self) -> Optional[int]:

@abc.abstractmethod
def reload(self):
"""Performs disk reads to ensure internal properties are up to date."""
"""Reloads internal properties.
Resets internal cache of checkpoint steps, in case the directory managed
by this object has been updated externally.
"""

@abc.abstractmethod
def reached_preemption(self, step: int) -> bool:
Expand All @@ -112,186 +116,25 @@ def delete(self, step: int):
def save(
self,
step: int,
items: Optional[Union[Any, Mapping[str, Any]]] = None,
save_kwargs: Optional[Union[SaveParams, Mapping[str, SaveParams]]] = None,
metrics: Optional[PyTree] = None,
force: Optional[bool] = False,
args: Optional[args_lib.CheckpointArgs] = None,
custom_metadata: dict[str, Any] | None = None,
*args,
**kwargs,
) -> bool:
"""Saves the provided items.
This method should be called by all hosts - process synchronization and
actions that need to be performed on only one host are managed internally.
NOTE: The `items` and `save_kwargs` arguments are deprecated, use `args`
instead. Make sure to configure `CheckpointManager` with `item_names`.
`args` should be a subclass of
`orbax.checkpoint.args.CheckpointArgs`, the specific type of which is used
to indicate what logic is used to save the object. For a typical, PyTree of
arrays, use `StandardSave`/`StandardRestore`.
When constructing the `CheckpointManager`, if no `item_names` were provided,
it is assumed that we are managing a single object. If `item_names` were
provided, it is assumed that we are managing multiple objects, and `args`
must be `orbax.checkpoint.args.CompositeArgs`. See below for details.
Example::
# Single item
mngr = ocp.CheckpointManager(directory)
mngr.save(step, args=ocp.args.StandardSave(my_train_state))
# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
mngr.save(step, args=ocp.args.Composite(
state=ocp.args.StandardSave(my_train_state),
meta=ocp.args.JsonSave(my_metadata)
))
Args:
step: current step, int
items: a savable object, or a dictionary of object name to savable object.
save_kwargs: save kwargs for a single Checkpointer, or a dictionary of
object name to kwargs needed by the Checkpointer implementation to save
the object.
metrics: a dictionary of metric name (string) to numeric value to be
tracked along with this checkpoint. Required if `options.best_fn` is
set. Allows users to specify a metric value to determine which
checkpoints are best and should be kept (in conjunction with
`options.max_to_keep`).
force: if `True`, this method will attempt to save a checkpoint regardless
of the result of `AbstractCheckpointManager.should_save(step)`. By
default, `save` will only write a checkpoint to disk when the options
permit, e.g. when `step` is in `options.save_interval_steps` or
`options.save_on_steps`. Setting `force=True` will not overwrite
existing checkpoints.
args: `CheckpointArgs` which is used to save checkpointable objects with
the appropriate logic.
custom_metadata: a dictionary of custom metadata to be written to the
checkpoint directory via StepMetadata.
Returns:
bool indicating whether a save operation was performed.
Raises:
ValueError: if `track_best` was indicated but `metrics` is not provided.
ValueError: directory creation failed.
ValueError: if an item is provided for which no `Checkpointer` is
found.
ValueError: if the checkpoint already exists.
"""
"""Saves the given step."""

@abc.abstractmethod
def restore(
self,
step: Optional[int],
items: Optional[Union[Any, Mapping[str, Any]]] = None,
restore_kwargs: Optional[
Union[RestoreParams, Mapping[str, RestoreParams]]
] = None,
directory: Optional[epath.PathLike] = None,
args: Optional[args_lib.CheckpointArgs] = None,
*args,
**kwargs,
) -> Union[Any, Mapping[str, Any], args_lib.Composite]:
"""Restores from the given step and provided items.
This method should be called by all hosts - process synchronization and
actions that need to be performed on only one host are managed internally.
NOTE: The `items` and `restore_kwargs` arguments are deprecated, use `args`
instead. Make sure to configure `CheckpointManager` with `item_names`.
See `save` docstring for additional details.
Example::
# Single item
mngr = ocp.CheckpointManager(directory)
mngr.restore(step, args=ocp.args.StandardRestore(abstract_train_state))
# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
mngr.restore(step, args=ocp.args.Composite(
state=ocp.args.StandardRestore(abstract_train_state),
meta=ocp.args.JsonRestore(),
))
# If it is acceptable to restore without providing additional arguments,
# and if a save has already been performed, it is ok to do the following:
mngr.restore(step, args=ocp.args.Composite(state=None, meta=None))
# If a save has not already been performed, there is no way for Orbax to
# know how to restore the objects. If a save has already been performed,
# it remembers the logic used to save the objects.
Args:
step: current step, int
items: a restoreable object, or a dictionary of object name to restorable
object.
restore_kwargs: restore kwargs for a single Checkpointer, or a dictionary
of object name to kwargs needed by the Checkpointer implementation to
restore the object.
directory: if provided, uses the given directory rather than the
`directory` property of this class. Can be used to restore checkpoints
from an independent location.
args: `CheckpointArgs` which is used to restore checkpointable objects
with the appropriate logic.
Returns:
If managing a single item, returns a single checkpointable object.
If managing multiple items, returns ocp.args.Composite, where the keys
are item names, and values are checkpointable objects.
"""
"""Restores the given step."""

@abc.abstractmethod
def item_metadata(
self, step: int
) -> Union[Any, Mapping[str, Any], args_lib.Composite]:
"""For all Checkpointers, returns any metadata associated with the item.
Calls the `metadata` method for each Checkpointer and returns a
mapping of each item name to the restored metadata. If the manager only
manages a single item, a single metadata will be returned instead.
To avoid errors due to missing CheckpointHandlers, concrete
CheckpointManager constructor must allow mapping from item names to
respective CheckpointHandlers to be input other than via save() and
restore(). Please note that save() and restore() calls automatically
map CheckpointHandlers to respective item names and retain it during the
lifetime of the CheckpointManager instance.
Example::
# Single item
mngr = ocp.CheckpointManager(directory)
# No calls to save() or restore() before calling item_metadata().
mngr.item_metadata(step) # Raises error.
mngr = ocp.CheckpointManager(directory,
item_handlers=ocp.StandardCheckpointHandler)
# No calls to save() or restore() before calling item_metadata().
metadata = mngr.item_metadata(step) # Successful.
# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'extra'))
# No calls to save() or restore() before calling item_metadata().
mngr.item_metadata(step) # Raises error.
mngr = ocp.CheckpointManager(directory,
item_names=('state', 'extra'),
item_handlers={
'state': ocp.StandardCheckpointHandler,
'extra': ocp.PytreeCheckpointHandler,
}
)
# No calls to save() or restore() before calling item_metadata().
metadata = mngr.item_metadata(step) # Successful.
Metadata may be None for an individual item.
Args:
step: Step for which to retrieve metadata.
Returns:
A dictionary mapping name to item metadata, or a single item metadata.
"""
"""Returns metadata for all known items."""

@abc.abstractmethod
def metadata(
Expand Down
Loading

0 comments on commit 7540bbe

Please sign in to comment.