diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index a9907dd0..1dcc634a 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed +- Ignore not-exists and not-dir errors while building step metadata in + _StandardNameFormat. + ## [0.10.2] - 2024-12-04 ### Added diff --git a/checkpoint/orbax/checkpoint/_src/path/step.py b/checkpoint/orbax/checkpoint/_src/path/step.py index 5e115142..d57b9b19 100644 --- a/checkpoint/orbax/checkpoint/_src/path/step.py +++ b/checkpoint/orbax/checkpoint/_src/path/step.py @@ -303,10 +303,7 @@ def _build_metadata( self, step_path: epath.Path, step: Optional[int] = None ) -> Optional[Metadata]: """Returns metadata for given `step_path` if it is valid or None.""" - if not step_path.is_dir(): - return None - - if not is_checkpoint_finalized(step_path): + if not _is_checkpoint_finalized_ignore_error(step_path): return None if step is not None: @@ -569,6 +566,13 @@ def is_checkpoint_finalized(path: epath.PathLike) -> bool: return True +def _is_checkpoint_finalized_ignore_error(path: epath.PathLike) -> bool: + try: + return is_checkpoint_finalized(path) + except ValueError: + return False + + def tmp_checkpoints(checkpoint_dir: epath.PathLike) -> List[str]: """Returns a list of tmp checkpoint dir names in `checkpoint_dir`.""" checkpoint_dir = epath.Path(checkpoint_dir) @@ -627,7 +631,7 @@ def record_saved_duration(checkpoint_start_time: float): _LAST_CHECKPOINT_WRITE_TIME = checkpoint_start_time -def _is_step_checkpoint(path: epath.Path) -> bool: +def _is_legacy_step_checkpoint(path: epath.Path) -> bool: """Determines if the path resembles an Orbax step directory. Note that this is not foolproof, and users should not add extra files to the @@ -644,6 +648,12 @@ def _is_step_checkpoint(path: epath.Path) -> bool: return path.is_dir() and (name.isdigit() or name.split('_')[-1].isdigit()) +def _is_legacy_finalized_step_checkpoint(path: epath.Path) -> bool: + return _is_legacy_step_checkpoint( + path + ) and _is_checkpoint_finalized_ignore_error(path) + + def step_from_checkpoint_name(name: str) -> int: """Returns the step from a checkpoint name. Also works for tmp checkpoints.""" if name.isdigit(): @@ -665,21 +675,11 @@ def checkpoint_steps_paths( if not checkpoint_dir.exists(): raise ValueError(f'Path {checkpoint_dir} does not exist.') - def check_step_dir(step_dir: epath.Path) -> bool: - # This block allows catching errors in which the checkpoint was deleted - # between checking _is_step_checkpoint and is_checkpoint_finalized. - try: - result = _is_step_checkpoint(step_dir) and is_checkpoint_finalized( - step_dir - ) - except ValueError: - return False - - return result - with concurrent.futures.ThreadPoolExecutor() as executor: futures = { - step_dir: executor.submit(check_step_dir, step_dir) + step_dir: executor.submit( + _is_legacy_finalized_step_checkpoint, step_dir + ) for step_dir in checkpoint_dir.iterdir() } return [step_dir for step_dir, future in futures.items() if future.result()] @@ -723,6 +723,6 @@ def any_checkpoint_step(checkpoint_dir: epath.PathLike) -> Optional[int]: """ checkpoint_dir = epath.Path(checkpoint_dir) for s in checkpoint_dir.iterdir(): - if _is_step_checkpoint(s) and is_checkpoint_finalized(s): + if _is_legacy_finalized_step_checkpoint(s): return step_from_checkpoint_name(s.name) return None