From b32d1f9806b96e445d86561d079aa58b336b49be Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Tue, 14 Jan 2025 06:53:43 -0800 Subject: [PATCH] Separate sub-item temporary path class from default temporary path class. PiperOrigin-RevId: 715363908 --- .../_src/handlers/composite_checkpoint_handler.py | 2 +- .../orbax/checkpoint/_src/path/atomicity_defaults.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py index 811a6b12..8b86b924 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py @@ -637,7 +637,7 @@ def _get_item_temporary_directory( ) -> atomicity_types.TemporaryPath: temporary_path_class = ( self._temporary_path_class - or atomicity_defaults.get_default_temporary_path_class(directory) + or atomicity_defaults.get_item_default_temporary_path_class(directory) ) tmp_item_dir = temporary_path_class.from_final( self._get_item_directory(directory, item_name), diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py b/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py index 232bd3f3..dd592d18 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py @@ -26,9 +26,20 @@ from orbax.checkpoint._src.path import step as step_lib +def get_item_default_temporary_path_class( + final_path: epath.Path, +) -> Type[atomicity_types.TemporaryPath]: + """Returns the default temporary path class for a given sub-item path.""" + if step_lib.is_gcs_path(final_path): + return atomicity.CommitFileTemporaryPath + else: + return atomicity.AtomicRenameTemporaryPath + + def get_default_temporary_path_class( final_path: epath.Path, ) -> Type[atomicity_types.TemporaryPath]: + """Returns the default temporary path class for a given checkpoint path.""" if step_lib.is_gcs_path(final_path): return atomicity.CommitFileTemporaryPath else: