Skip to content

Commit

Permalink
Separate sub-item temporary path class from default temporary path cl…
Browse files Browse the repository at this point in the history
…ass.

PiperOrigin-RevId: 715363908
  • Loading branch information
Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit b32d1f9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def _get_item_temporary_directory(
) -> atomicity_types.TemporaryPath:
temporary_path_class = (
self._temporary_path_class
or atomicity_defaults.get_default_temporary_path_class(directory)
or atomicity_defaults.get_item_default_temporary_path_class(directory)
)
tmp_item_dir = temporary_path_class.from_final(
self._get_item_directory(directory, item_name),
Expand Down
11 changes: 11 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,20 @@
from orbax.checkpoint._src.path import step as step_lib


def get_item_default_temporary_path_class(
final_path: epath.Path,
) -> Type[atomicity_types.TemporaryPath]:
"""Returns the default temporary path class for a given sub-item path."""
if step_lib.is_gcs_path(final_path):
return atomicity.CommitFileTemporaryPath
else:
return atomicity.AtomicRenameTemporaryPath


def get_default_temporary_path_class(
final_path: epath.Path,
) -> Type[atomicity_types.TemporaryPath]:
"""Returns the default temporary path class for a given checkpoint path."""
if step_lib.is_gcs_path(final_path):
return atomicity.CommitFileTemporaryPath
else:
Expand Down

0 comments on commit b32d1f9

Please sign in to comment.