Skip to content

Commit

Permalink
Ignore not-exists and not-dir errors while building step metadata in …
Browse files Browse the repository at this point in the history
…_StandardNameFormat.

PiperOrigin-RevId: 703105717
  • Loading branch information
niketkumar authored and Orbax Authors committed Dec 5, 2024
1 parent 9b38ce3 commit ac2d276
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- Ignore not-exists and not-dir errors while building step metadata in
_StandardNameFormat.

## [0.10.2] - 2024-12-04

### Added
Expand Down
38 changes: 19 additions & 19 deletions checkpoint/orbax/checkpoint/_src/path/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def _build_metadata(
self, step_path: epath.Path, step: Optional[int] = None
) -> Optional[Metadata]:
"""Returns metadata for given `step_path` if it is valid or None."""
if not step_path.is_dir():
return None

if not is_checkpoint_finalized(step_path):
if not _is_checkpoint_finalized_ignore_error(step_path):
return None

if step is not None:
Expand Down Expand Up @@ -569,6 +566,13 @@ def is_checkpoint_finalized(path: epath.PathLike) -> bool:
return True


def _is_checkpoint_finalized_ignore_error(path: epath.PathLike) -> bool:
try:
return is_checkpoint_finalized(path)
except ValueError:
return False


def tmp_checkpoints(checkpoint_dir: epath.PathLike) -> List[str]:
"""Returns a list of tmp checkpoint dir names in `checkpoint_dir`."""
checkpoint_dir = epath.Path(checkpoint_dir)
Expand Down Expand Up @@ -627,7 +631,7 @@ def record_saved_duration(checkpoint_start_time: float):
_LAST_CHECKPOINT_WRITE_TIME = checkpoint_start_time


def _is_step_checkpoint(path: epath.Path) -> bool:
def _is_legacy_step_checkpoint(path: epath.Path) -> bool:
"""Determines if the path resembles an Orbax step directory.
Note that this is not foolproof, and users should not add extra files to the
Expand All @@ -644,6 +648,12 @@ def _is_step_checkpoint(path: epath.Path) -> bool:
return path.is_dir() and (name.isdigit() or name.split('_')[-1].isdigit())


def _is_legacy_finalized_step_checkpoint(path: epath.Path) -> bool:
return _is_legacy_step_checkpoint(
path
) and _is_checkpoint_finalized_ignore_error(path)


def step_from_checkpoint_name(name: str) -> int:
"""Returns the step from a checkpoint name. Also works for tmp checkpoints."""
if name.isdigit():
Expand All @@ -665,21 +675,11 @@ def checkpoint_steps_paths(
if not checkpoint_dir.exists():
raise ValueError(f'Path {checkpoint_dir} does not exist.')

def check_step_dir(step_dir: epath.Path) -> bool:
# This block allows catching errors in which the checkpoint was deleted
# between checking _is_step_checkpoint and is_checkpoint_finalized.
try:
result = _is_step_checkpoint(step_dir) and is_checkpoint_finalized(
step_dir
)
except ValueError:
return False

return result

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {
step_dir: executor.submit(check_step_dir, step_dir)
step_dir: executor.submit(
_is_legacy_finalized_step_checkpoint, step_dir
)
for step_dir in checkpoint_dir.iterdir()
}
return [step_dir for step_dir, future in futures.items() if future.result()]
Expand Down Expand Up @@ -723,6 +723,6 @@ def any_checkpoint_step(checkpoint_dir: epath.PathLike) -> Optional[int]:
"""
checkpoint_dir = epath.Path(checkpoint_dir)
for s in checkpoint_dir.iterdir():
if _is_step_checkpoint(s) and is_checkpoint_finalized(s):
if _is_legacy_finalized_step_checkpoint(s):
return step_from_checkpoint_name(s.name)
return None

0 comments on commit ac2d276

Please sign in to comment.