Skip to content

Commit

Permalink
Add _update_metadata() method to CheckpointHandlers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712293541
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Jan 5, 2025
1 parent 58d7b26 commit 072408e
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 3 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ properties not included in any tree mapping operations.
`CompositeCheckpointHandler.metadata()` to retrieve item metadata by
default-constructing `CheckpointHandler`s when they're listed in the saved
`StepMetadata` but aren't found in the checkpoint.
- `FileOptions.format` to specify the underlying checkpointing file format.

### Fixed
- Ignore not-exists and not-dir errors while building step metadata in
Expand All @@ -38,6 +39,7 @@ default-constructing `CheckpointHandler`s when they're listed in the saved

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.
- `Checkpointer.save()` also saves `StepMetadata`.

## [0.10.2] - 2024-12-04

Expand Down
38 changes: 36 additions & 2 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
Expand All @@ -42,6 +43,7 @@
get_legacy_handler_wrapper = (
composite_checkpoint_handler.get_legacy_handler_wrapper
)
StepMetadata = checkpoint.StepMetadata


def construct_checkpoint_args(
Expand Down Expand Up @@ -161,7 +163,12 @@ async def create_temporary_path(
return tmpdir

def save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
self,
directory: epath.PathLike,
*args,
force: bool = False,
custom_metadata: dict[str, Any] | None = None,
**kwargs,
):
"""Saves the given item to the provided directory.
Expand All @@ -176,6 +183,8 @@ def save(
*args: additional args to provide to the CheckpointHandler's save method.
force: if True, allows overwriting an existing directory. May add overhead
due to the need to delete any existing files.
custom_metadata: a dictionary of custom metadata to be written to the
checkpoint directory via StepMetadata.
**kwargs: additional keyword args to provide to the CheckpointHandler's
save method.
Expand Down Expand Up @@ -210,6 +219,17 @@ def save(
processes=self._active_processes,
)

if utils.is_primary_host(self._primary_host):
self._save_step_metadata(tmpdir.get(), custom=custom_metadata)
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:step_metadata_save',
prefix=self._barrier_sync_key_prefix,
# suffix=tmpdir.get().name,
),
processes=self._active_processes,
)

# Ensure save operation atomicity and record time saved by checkpoint.
if utils.is_primary_host(self._primary_host):
self._handler.finalize(tmpdir.get())
Expand Down Expand Up @@ -251,11 +271,25 @@ def _restore(
) -> Any:
return self._handler.restore(directory, args=args)

def metadata(self, directory: epath.PathLike) -> StepMetadata | Any | None:
  """Returns the metadata the wrapped handler reports for `directory`.

  See superclass documentation.

  Args:
    directory: location of the checkpoint. It is converted to an `epath.Path`
      before being handed to the handler.

  Returns:
    Whatever `self._handler.metadata()` returns for that path.
  """
  return self._handler.metadata(epath.Path(directory))

def _save_step_metadata(
    self, directory: epath.Path, custom: dict[str, Any] | None
):
  """Builds a `StepMetadata` for this save and writes it via the metadata store.

  Args:
    directory: checkpoint directory; the step metadata file path is derived
      from it.
    custom: optional user-provided metadata stored in the `custom` field of
      the `StepMetadata`.
  """
  metadata_to_write = StepMetadata(
      format=self._file_options.format, custom=custom
  )
  # The handler updates the metadata object in place before serialization.
  self._handler._update_metadata(directory, metadata_to_write)  # pylint: disable=protected-access
  metadata_path = checkpoint.step_metadata_file_path(directory)
  serialized_fields = step_metadata_serialization.serialize(metadata_to_write)
  self._metadata_store.update(file_path=metadata_path, **serialized_fields)

def close(self):
"""Closes the underlying CheckpointHandler."""
self._handler.close()
Expand Down
21 changes: 20 additions & 1 deletion checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

import abc
from typing import Any, Optional

from absl import logging
from etils import epath

from orbax.checkpoint._src.metadata import checkpoint


class CheckpointHandler(abc.ABC):
"""An interface providing save/restore methods used on a savable item.
Expand Down Expand Up @@ -67,6 +71,21 @@ def metadata(self, directory: epath.Path) -> Optional[Any]:
"""
pass

def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None:
  """Updates `item_metadata` and `item_handlers` in `step_metadata`.

  `item_metadata` is taken from `self.metadata(directory)` and `item_handlers`
  from `self.typestr()`. `step_metadata` is modified in place.

  Args:
    directory: directory passed to `self.metadata()`.
    step_metadata: object whose `item_metadata` and `item_handlers`
      attributes are assigned.
  """
  recoverable_errors = (FileNotFoundError, NotImplementedError, ValueError)
  try:
    # Intended type: checkpoint.SingleItemMetadata.
    step_metadata.item_metadata = self.metadata(directory)
  except recoverable_errors:
    # Best effort: on failure only a warning is emitted and `item_metadata`
    # is not assigned here.
    logging.warning(
        'Failed to get handler metadata from directory %s.', directory
    )
  # The handler type string is recorded whether or not the lookup above
  # succeeded. Intended type: checkpoint.CheckpointHandlerTypeStr.
  handler_typestr = self.typestr()
  step_metadata.item_handlers = handler_typestr

def finalize(self, directory: epath.Path) -> None:
"""Optional, custom checkpoint finalization callback.
Expand All @@ -84,4 +103,4 @@ def close(self):
@classmethod
def typestr(cls) -> str:
  """A unique identifier for the CheckpointHandler type.

  Returns:
    The class's `__module__` and `__qualname__` joined by a dot.
  """
  return f'{cls.__module__}.{cls.__qualname__}'
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,20 @@ def metadata(self, directory: epath.Path) -> StepMetadata:
item_metadata=CompositeItemMetadata(**item_metadata),
)

def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None:
  """Updates `item_metadata` and `item_handlers` in `step_metadata`.

  Both fields are copied from the object returned by
  `self.metadata(directory)`. If that call fails with one of the handled
  errors, a warning is logged and `step_metadata` is left unchanged.

  Args:
    directory: directory passed to `self.metadata()`.
    step_metadata: object whose `item_metadata` and `item_handlers`
      attributes are assigned on success.
  """
  try:
    saved_metadata = self.metadata(directory)
  except (FileNotFoundError, NotImplementedError, ValueError):
    logging.warning(
        'Failed to get per-item metadata from directory %s. Handler types '
        'will not be saved.',
        directory,
    )
    return
  # Only reached when the metadata lookup succeeded.
  step_metadata.item_metadata = saved_metadata.item_metadata
  step_metadata.item_handlers = saved_metadata.item_handlers

def finalize(self, directory: epath.Path):
if not self._current_temporary_paths:
raise ValueError('finalize() called before any items were saved.')
Expand Down
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@



_ORBAX_STANDARD_FORMAT = 'orbax-standard'


@dataclasses.dataclass
class AsyncOptions:
"""Options used to configure async behavior.
Expand Down Expand Up @@ -65,9 +68,13 @@ class FileOptions:
metadata files. e.g. 0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported. default=None.
format: The checkpoint file format. This is useful when differentiating
between Orbax and Roc checkpoints, as well as checkpoints saved by
different APIs. Defaults to 'orbax-standard'.
"""

path_permission_mode: Optional[int] = None
format: str = _ORBAX_STANDARD_FORMAT



0 comments on commit 072408e

Please sign in to comment.