From 072408e8385d4d93a8329c9d322161de5b1f3103 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Sun, 5 Jan 2025 12:06:40 -0800 Subject: [PATCH] Add `_update_metadata()` method to `CheckpointHandler`s. PiperOrigin-RevId: 712293541 --- checkpoint/CHANGELOG.md | 2 + .../_src/checkpointers/checkpointer.py | 38 ++++++++++++++++++- .../_src/handlers/checkpoint_handler.py | 21 +++++++++- .../handlers/composite_checkpoint_handler.py | 14 +++++++ checkpoint/orbax/checkpoint/options.py | 7 ++++ 5 files changed, 79 insertions(+), 3 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5e031e62e..e1a75541b 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -30,6 +30,7 @@ properties not included in any tree mapping operations. `CompositeCheckpointHandler.metadata()` to retrieve item metadata by default-constructing `CheckpointHandler`s when they're listed in the saved `StepMetadata` but aren't found in the checkpoint. +- `FileOptions.format` to specify the underlying checkpointing file format. ### Fixed - Ignore not-exists and not-dir errors while building step metadata in @@ -38,6 +39,7 @@ default-constructing `CheckpointHandler`s when they're listed in the saved ### Changed - Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`. +- `Checkpointer.save()` also saves `StepMetadata`. ## [0.10.2] - 2024-12-04 diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py index 8ce7c71ca..8285a114e 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py @@ -29,6 +29,7 @@ from orbax.checkpoint._src.handlers import checkpoint_handler from orbax.checkpoint._src.handlers import composite_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint +from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import atomicity from orbax.checkpoint._src.path import atomicity_defaults @@ -42,6 +43,7 @@ get_legacy_handler_wrapper = ( composite_checkpoint_handler.get_legacy_handler_wrapper ) +StepMetadata = checkpoint.StepMetadata def construct_checkpoint_args( @@ -161,7 +163,12 @@ async def create_temporary_path( return tmpdir def save( - self, directory: epath.PathLike, *args, force: bool = False, **kwargs + self, + directory: epath.PathLike, + *args, + force: bool = False, + custom_metadata: dict[str, Any] | None = None, + **kwargs, ): """Saves the given item to the provided directory. @@ -176,6 +183,8 @@ def save( *args: additional args to provide to the CheckpointHandler's save method. force: if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files. + custom_metadata: a dictionary of custom metadata to be written to the + checkpoint directory via StepMetadata. **kwargs: additional keyword args to provide to the CheckpointHandler's save method. @@ -210,6 +219,17 @@ def save( processes=self._active_processes, ) + if utils.is_primary_host(self._primary_host): + self._save_step_metadata(tmpdir.get(), custom=custom_metadata) + multihost.sync_global_processes( + multihost.unique_barrier_key( + 'Checkpointer:step_metadata_save', + prefix=self._barrier_sync_key_prefix, + # suffix=tmpdir.get().name, + ), + processes=self._active_processes, + ) + # Ensure save operation atomicity and record time saved by checkpoint. if utils.is_primary_host(self._primary_host): self._handler.finalize(tmpdir.get()) @@ -251,11 +271,25 @@ def _restore( ) -> Any: return self._handler.restore(directory, args=args) - def metadata(self, directory: epath.PathLike) -> Optional[Any]: + def metadata(self, directory: epath.PathLike) -> StepMetadata | Any | None: """See superclass documentation.""" directory = epath.Path(directory) return self._handler.metadata(directory) + def _save_step_metadata( + self, directory: epath.Path, custom: dict[str, Any] | None + ): + """Saves StepMetadata to the checkpoint directory.""" + step_metadata = StepMetadata( + format=self._file_options.format, + custom=custom, + ) + self._handler._update_metadata(directory, step_metadata) # pylint: disable=protected-access + self._metadata_store.update( + file_path=checkpoint.step_metadata_file_path(directory), + **step_metadata_serialization.serialize(step_metadata), + ) + def close(self): """Closes the underlying CheckpointHandler.""" self._handler.close() diff --git a/checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py index 5fa8a119e..b5c6db2f8 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py @@ -16,8 +16,12 @@ import abc from typing import Any, Optional + +from absl import logging from etils import epath +from orbax.checkpoint._src.metadata import checkpoint + class CheckpointHandler(abc.ABC): """An interface providing save/restore methods used on a savable item. @@ -67,6 +71,21 @@ def metadata(self, directory: epath.Path) -> Optional[Any]: """ pass + def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None: + """Updates `item_metadata` and `item_handlers` in `step_metadata`.""" + try: + step_metadata.item_metadata: checkpoint.SingleItemMetadata = ( + self.metadata(directory) + ) + except (FileNotFoundError, NotImplementedError, ValueError): + logging.warning( + 'Failed to get handler metadata from directory %s.', + directory, + ) + step_metadata.item_handlers: checkpoint.CheckpointHandlerTypeStr = ( + self.typestr() + ) + def finalize(self, directory: epath.Path) -> None: """Optional, custom checkpoint finalization callback. @@ -84,4 +103,4 @@ def close(self): @classmethod def typestr(cls) -> str: """A unique identifier for the CheckpointHandler type.""" - return f"{cls.__module__}.{cls.__qualname__}" + return f'{cls.__module__}.{cls.__qualname__}' diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py index 30c9208f4..3121d77f2 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py @@ -902,6 +902,20 @@ def metadata(self, directory: epath.Path) -> StepMetadata: item_metadata=CompositeItemMetadata(**item_metadata), ) + def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None: + """Updates `item_metadata` and `item_handlers` in `step_metadata`.""" + try: + partial_metadata: StepMetadata = self.metadata(directory) + except (FileNotFoundError, NotImplementedError, ValueError): + logging.warning( + 'Failed to get per-item metadata from directory %s. Handler types ' + 'will not be saved.', + directory, + ) + else: + step_metadata.item_metadata = partial_metadata.item_metadata + step_metadata.item_handlers = partial_metadata.item_handlers + def finalize(self, directory: epath.Path): if not self._current_temporary_paths: raise ValueError('finalize() called before any items were saved.') diff --git a/checkpoint/orbax/checkpoint/options.py b/checkpoint/orbax/checkpoint/options.py index affc36b69..65f19e456 100644 --- a/checkpoint/orbax/checkpoint/options.py +++ b/checkpoint/orbax/checkpoint/options.py @@ -21,6 +21,9 @@ +_ORBAX_STANDARD_FORMAT = 'orbax-standard' + + @dataclasses.dataclass class AsyncOptions: """Options used to configure async behavior. @@ -65,9 +68,13 @@ class FileOptions: metadata files. e.g. 0o750. Please check https://github.com/google/etils/blob/main/etils/epath/backend.py if your path is supported. default=None. + format: The checkpoint file format. This is useful when differentiating + between Orbax and Roc checkpoints, as well as checkpoints saved by + different apis. Defaults to 'orbax-standard'. """ path_permission_mode: Optional[int] = None + format: str = _ORBAX_STANDARD_FORMAT