Skip to content

Commit

Permalink
Enable ArrayMetadata persistence globally.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733769831
  • Loading branch information
niketkumar authored and Orbax Authors committed Mar 5, 2025
1 parent c1a21a6 commit f63dc85
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
10 changes: 4 additions & 6 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- V1: Introduce Context and its usage as a contextmanager.
- V1: Add `ArrayStorageOptions` to customize per-leaf saving behavior for
arrays (e.g. `dtype`).
- Support saving and restoring `jax.random.key()` in PyTree.
- `CheckpointableHandler` for V1.
- Support single-slice checkpointing in `emergency.CheckpointManager`.

### Fixed

Expand All @@ -23,12 +26,7 @@ which are reflected in the CHANGELOG.
### Changed

- Improve `Cannot serialize host local jax.Array` error message.

### Added

- Support saving and restoring `jax.random.key()` in PyTree.
- `CheckpointableHandler` for V1.
- Support single-slice checkpointing in `emergency.CheckpointManager`.
- Enable `ArrayMetadata` persistence globally.

## [0.11.6] - 2025-02-20

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,7 @@ def __init__(
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
multiprocessing_options: See orbax.checkpoint.options.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
specified, the global type handler registry will be used.
enable_post_merge_validation: If True, enables validation of the
parameters after the finalize step.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
Expand Down Expand Up @@ -332,8 +331,10 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)
logging.info(
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s,'
' array_metadata_store=%s',
'Created BasePyTreeCheckpointHandler: use_ocdbt=%s, use_zarr3=%s,'
' pytree_metadata_options=%s, array_metadata_store=%s',
self._use_ocdbt,
self._use_zarr3,
self._pytree_metadata_options,
self._array_metadata_store,
)
Expand Down
10 changes: 6 additions & 4 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,12 +761,11 @@ def get_sharding_tensorstore_spec(
kvstore_tspec = ts_utils.build_kvstore_tspec(
directory, name=_SHARDING, use_ocdbt=False
)
param_name = base64.urlsafe_b64encode(param_name.encode()).decode('utf-8')
return {
'driver': 'json',
'kvstore': kvstore_tspec,
'json_pointer': '/' + base64.urlsafe_b64encode(
param_name.encode()
).decode('utf-8'),
'json_pointer': f'/{param_name}',
}


Expand Down Expand Up @@ -1790,7 +1789,10 @@ def create_type_handler_registry(
(bytes, ScalarHandler()),
(np.number, ScalarHandler()),
(np.ndarray, NumpyHandler()),
(jax.Array, ArrayHandler()),
(
jax.Array,
ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()),
),
(str, StringHandler()),
])

Expand Down
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def add_nested_key(subtree, nested_key, key_name):
continue
if k.name == '_METADATA':
continue
# `array_metadatas` is not a checkpoint param; it is only used when OCDBT is
# used. OCDBT is still disabled in some projects, such as paxml.
if k.name == 'array_metadatas':
continue
tree = add_nested_key(tree, k.name.split('.'), k.name)
return tree

Expand Down

0 comments on commit f63dc85

Please sign in to comment.