Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733831700
  • Loading branch information
Orbax Authors committed Mar 5, 2025
1 parent 19b8e84 commit dd0bf84
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from orbax.checkpoint._src.path import async_utils
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import utils as path_utils



Expand Down Expand Up @@ -423,6 +424,10 @@ async def _save(
directory = tmpdir.get_final()
self.synchronize_next_awaitable_signal_operation_id()

jax.monitoring.record_event(
'/jax/orbax/write/async/storage_type',
storage_type=path_utils.get_storage_type(directory),
)
jax.monitoring.record_event('/jax/orbax/write/async/start')
logging.info(
'[process=%s] Started async saving checkpoint to %s.',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import utils as path_utils
from typing_extensions import Self # for Python version < 3.11


Expand Down Expand Up @@ -212,6 +213,10 @@ def save(
checkpoint_start_time = time.time()
directory = epath.Path(directory)

jax.monitoring.record_event(
'/jax/orbax/write/storage_type',
storage_type=path_utils.get_storage_type(directory),
)
jax.monitoring.record_event('/jax/orbax/write/start')
logging.info(
'[process=%s] Started saving checkpoint to %s.',
Expand Down
14 changes: 14 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@



_GCS_PATH_PREFIX = ('gs://',)


def is_gcs_path(path: epath.Path) -> bool:
return path.as_posix().startswith(_GCS_PATH_PREFIX)


def get_storage_type(path: epath.Path) -> str:
if is_gcs_path(path):
return 'gcs'
else:
return 'local'


class Timer(object):
"""A simple timer to measure the time it takes to run a function."""

Expand Down

0 comments on commit dd0bf84

Please sign in to comment.