Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _update_metadata() method to CheckpointHandlers. #1463

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ when a custom `snapshot_dir` is specified.
`CompositeCheckpointHandler.metadata()` to retrieve item metadata by
default-constructing `CheckpointHandler`s when they're listed in the saved
`StepMetadata` but aren't found in the checkpoint.
- `FileOptions.format` to specify the underlying checkpointing file format.

### Fixed
- Ignore not-exists and not-dir errors while building step metadata in
Expand All @@ -44,6 +45,7 @@ default-constructing `CheckpointHandler`s when they're listed in the saved

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.
- `Checkpointer.save()` also saves `StepMetadata`.

## [0.10.2] - 2024-12-04

Expand Down
38 changes: 36 additions & 2 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
Expand All @@ -42,6 +43,7 @@
get_legacy_handler_wrapper = (
composite_checkpoint_handler.get_legacy_handler_wrapper
)
StepMetadata = checkpoint.StepMetadata


def construct_checkpoint_args(
Expand Down Expand Up @@ -161,7 +163,12 @@ async def create_temporary_path(
return tmpdir

def save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
self,
directory: epath.PathLike,
*args,
force: bool = False,
custom_metadata: dict[str, Any] | None = None,
**kwargs,
):
"""Saves the given item to the provided directory.

Expand All @@ -176,6 +183,8 @@ def save(
*args: additional args to provide to the CheckpointHandler's save method.
force: if True, allows overwriting an existing directory. May add overhead
due to the need to delete any existing files.
custom_metadata: a dictionary of custom metadata to be written to the
checkpoint directory via StepMetadata.
**kwargs: additional keyword args to provide to the CheckpointHandler's
save method.

Expand Down Expand Up @@ -210,6 +219,17 @@ def save(
processes=self._active_processes,
)

if utils.is_primary_host(self._primary_host):
self._save_step_metadata(tmpdir.get(), custom=custom_metadata)
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:step_metadata_save',
prefix=self._barrier_sync_key_prefix,
# suffix=tmpdir.get().name,
),
processes=self._active_processes,
)

# Ensure save operation atomicity and record time saved by checkpoint.
if utils.is_primary_host(self._primary_host):
self._handler.finalize(tmpdir.get())
Expand Down Expand Up @@ -251,11 +271,25 @@ def _restore(
) -> Any:
return self._handler.restore(directory, args=args)

def metadata(self, directory: epath.PathLike) -> Optional[Any]:
def metadata(self, directory: epath.PathLike) -> StepMetadata | Any | None:
"""See superclass documentation."""
directory = epath.Path(directory)
return self._handler.metadata(directory)

def _save_step_metadata(
    self, directory: epath.Path, custom: dict[str, Any] | None
):
  """Builds a `StepMetadata` and writes it under `directory`.

  The metadata is created from the configured file format and the given
  `custom` value, handed to the handler's `_update_metadata()` so it can be
  filled in further, then serialized and passed to the metadata store.

  Args:
    directory: checkpoint directory the step metadata file path is derived
      from.
    custom: custom metadata forwarded to `StepMetadata`; may be None.
  """
  metadata_to_save = StepMetadata(
      format=self._file_options.format,
      custom=custom,
  )
  # Let the handler adjust the metadata in place before it is serialized.
  self._handler._update_metadata(directory, metadata_to_save)  # pylint: disable=protected-access
  metadata_path = checkpoint.step_metadata_file_path(directory)
  serialized_fields = step_metadata_serialization.serialize(metadata_to_save)
  self._metadata_store.update(file_path=metadata_path, **serialized_fields)

def close(self):
"""Closes the underlying CheckpointHandler."""
self._handler.close()
Expand Down
11 changes: 10 additions & 1 deletion checkpoint/orbax/checkpoint/_src/handlers/checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Any, Optional
from etils import epath

from orbax.checkpoint._src.metadata import checkpoint


class CheckpointHandler(abc.ABC):
"""An interface providing save/restore methods used on a savable item.
Expand Down Expand Up @@ -67,6 +69,13 @@ def metadata(self, directory: epath.Path) -> Optional[Any]:
"""
pass

def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None:
  """Records this handler's type string on `step_metadata`.

  Sets `step_metadata.item_handlers` to the value of `self.typestr()`.

  Args:
    directory: unused by this implementation.
    step_metadata: object whose `item_handlers` attribute is overwritten in
      place.
  """
  del directory  # Unused by this implementation.
  # Annotate a local instead of the attribute assignment itself: declaring a
  # type on a non-self attribute target is rejected by type checkers.
  handler_typestr: checkpoint.CheckpointHandlerTypeStr = self.typestr()
  step_metadata.item_handlers = handler_typestr

def finalize(self, directory: epath.Path) -> None:
"""Optional, custom checkpoint finalization callback.

Expand All @@ -84,4 +93,4 @@ def close(self):
@classmethod
def typestr(cls) -> str:
"""A unique identifier for the CheckpointHandler type."""
return f"{cls.__module__}.{cls.__qualname__}"
return f'{cls.__module__}.{cls.__qualname__}'
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,19 @@ def metadata(self, directory: epath.Path) -> StepMetadata:
item_metadata=CompositeItemMetadata(**item_metadata),
)

def _update_metadata(self, directory: epath.Path, step_metadata: Any) -> None:
  """Copies `item_handlers` from the metadata read at `directory`.

  Calls `self.metadata(directory)` and assigns the resulting `item_handlers`
  to `step_metadata.item_handlers`. If that call raises FileNotFoundError,
  NotImplementedError or ValueError, a warning is logged and `step_metadata`
  is left untouched.

  Args:
    directory: directory passed to `self.metadata()`.
    step_metadata: object whose `item_handlers` attribute is updated in place.
  """
  try:
    saved_metadata: StepMetadata = self.metadata(directory)
  except (FileNotFoundError, NotImplementedError, ValueError):
    # Best effort: missing per-item metadata only means handler types are
    # omitted from the step metadata.
    logging.warning(
        'Failed to get per-item metadata from directory %s. Handler types '
        'will not be saved.',
        directory,
    )
    return
  step_metadata.item_handlers = saved_metadata.item_handlers

def finalize(self, directory: epath.Path):
if not self._current_temporary_paths:
raise ValueError('finalize() called before any items were saved.')
Expand Down
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@



_ORBAX_STANDARD_FORMAT = 'orbax-standard'


@dataclasses.dataclass
class AsyncOptions:
"""Options used to configure async behavior.
Expand Down Expand Up @@ -65,9 +68,13 @@ class FileOptions:
metadata files. e.g. 0o750. Please check
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
path is supported. default=None.
format: The checkpoint file format. This is useful when differentiating
between Orbax and Roc checkpoints, as well as checkpoints saved by
different APIs. Defaults to 'orbax-standard'.
"""

path_permission_mode: Optional[int] = None
format: str = _ORBAX_STANDARD_FORMAT



Loading