Commit

internal
PiperOrigin-RevId: 733036803
qdhack authored and Orbax Authors committed Mar 3, 2025
1 parent 4a24304 commit 594c6dd
Showing 10 changed files with 217 additions and 19 deletions.
7 changes: 6 additions & 1 deletion checkpoint/CHANGELOG.md
@@ -7,13 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Fix RESOURCE_EXHAUSTED while writing array_metadatas.

### Changed

- Improve `Cannot serialize host local jax.Array` error message.

### Added

- support saving and restoring jax.random.key() in PyTree
- support saving and restoring jax.random.key() in PyTree.
- `CheckpointableHandler` for V1.

## [0.11.6] - 2025-02-20

@@ -323,6 +323,8 @@ def __init__(
type_handler_registry
)
)
if self._array_metadata_store:
self._array_metadata_store.set_primary_host(self._primary_host)
self._array_metadata_validator = array_metadata_validator


@@ -133,9 +133,45 @@ def __init__(
      self,
      path_resolver: PathResolver = PathResolver(),
      ser_deser: SerDeserializer = SerDeserializer(),
      primary_host: int | None = 0,  # None means all hosts are primary hosts.
      write_timeout_secs: int = 600,  # 10 minutes.
  ):
    self._path_resolver = path_resolver
    self._ser_deser = ser_deser
    self._primary_host = primary_host
    self._write_timeout_secs = write_timeout_secs

  def set_primary_host(self, primary_host: int | None) -> None:
    """Sets the primary host."""
    self._primary_host = primary_host

  async def _maybe_create_base_dir(self, base_dir: epath.Path) -> None:
    """Primary host creates the base directory; the other hosts wait for it."""
    if multihost.is_primary_host(self._primary_host):
      # The primary host creates the directory; the rest of the hosts wait.
      return await asyncio.to_thread(
          base_dir.mkdir, parents=True, exist_ok=True
      )

    # Non-primary hosts wait for the primary host to create the base dir/folder.
    async def wait_for_base_dir_creation():
      while not await asyncio.to_thread(base_dir.exists):
        await asyncio.sleep(0.25)

    try:
      await asyncio.wait_for(
          wait_for_base_dir_creation(), timeout=self._write_timeout_secs
      )
    except asyncio.TimeoutError as e:
      primary_process = (
          'LOCAL' if self._primary_host is None else self._primary_host
      )
      raise ValueError(
          f'[process_index={multihost.process_index()}] Timed out waiting for'
          f' array_metadatas base directory creation: {base_dir}.'
          f' timeout={self._write_timeout_secs} seconds.'
          f' primary_process={primary_process}'
      ) from e

  async def write(
      self,
@@ -155,7 +191,7 @@ async def write(
    file_path = self._path_resolver.get_write_file_path(
        checkpoint_dir, process_index
    )
    await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
    await self._maybe_create_base_dir(file_path.parent)
    await asyncio.to_thread(
        file_path.write_text, self._ser_deser.serialize(array_metadatas)
    )
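
As an aside, the directory-coordination pattern introduced above can be sketched in isolation as follows. This is a minimal illustration under assumed names (`create_or_await_dir`, `is_primary`, and `timeout_secs` are not part of Orbax): the primary process creates the directory in a worker thread, while the other processes poll for its existence and give up after a timeout.

import asyncio
from pathlib import Path


async def create_or_await_dir(
    base_dir: Path, is_primary: bool, timeout_secs: float = 600.0
) -> None:
  """Primary process creates base_dir; other processes poll until it exists."""
  if is_primary:
    # mkdir is blocking I/O, so run it in a worker thread.
    await asyncio.to_thread(base_dir.mkdir, parents=True, exist_ok=True)
    return

  async def _wait_for_dir() -> None:
    while not await asyncio.to_thread(base_dir.exists):
      await asyncio.sleep(0.25)

  try:
    await asyncio.wait_for(_wait_for_dir(), timeout=timeout_secs)
  except asyncio.TimeoutError as e:
    raise TimeoutError(
        f'Timed out after {timeout_secs}s waiting for {base_dir} to be created.'
    ) from e

Polling with a timeout keeps non-primary processes from hanging forever if the primary process fails before creating the directory, which is the failure mode the error message above is meant to surface.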
12 changes: 12 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/BUILD
@@ -0,0 +1,12 @@
load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "types",
srcs = ["types.py"],
deps = [
"//orbax/checkpoint/experimental/v1/_src/path:types",
"//orbax/checkpoint/experimental/v1/_src/synchronization:types",
],
)
143 changes: 143 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/types.py
@@ -0,0 +1,143 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines types for `CheckpointableHandler`."""

from typing import Awaitable, Protocol, TypeVar
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types


T = TypeVar('T')
AbstractT = TypeVar('AbstractT')
MetadataT = TypeVar('MetadataT')

PathLike = path_types.PathLike
AsyncResponse = async_types.AsyncResponse


class CheckpointableHandler(Protocol[T, AbstractT, MetadataT]):
  """An interface that defines save/load logic for a `checkpointable` object.

  NOTE: Prefer to use the `Checkpointable` interface when possible.

  A "checkpointable" is a fundamental concept in Orbax: it refers to a logical
  piece of the checkpoint that is distinct in some way from other pieces.
  Checkpointables are separable; they may or may not be loaded concurrently,
  and some may be omitted from the checkpoint entirely. Checkpointables are
  often represented by different types, and have different representations on
  disk. The quintessential example is model params vs. dataset.

  A PyTree of arrays, representing model parameters, is the most basic
  "checkpointable". A singular array is also a checkpointable.

  In most contexts, when dealing with just a PyTree, the API of choice is::

    ocp.save_pytree(directory, pytree)

  The concept of "checkpointable" is not so obvious in this case. When dealing
  with multiple objects, we can use::

    ocp.save_checkpointables(
        directory,
        dict(
            pytree=model_params,
            dataset=dataset_iterator,
            # other checkpointables, e.g. extra metadata, etc.
        ),
    )

  Now, it is easy to simply skip loading the dataset, as is commonly desired
  when running evals or inference::

    ocp.load_checkpointables(
        directory,
        dict(
            pytree=abstract_model_params,
        ),
    )
    # Equivalently,
    ocp.load_pytree(directory, abstract_model_params)

  With the methods defined in this Protocol (`save`, `load`), logic within the
  method itself is executed in the main thread, in a blocking fashion.
  Additional logic can be executed in the background by returning an
  `Awaitable` function (which itself may return a result).

  TODO(b/398249409): Include more details on implementing this Protocol.
  """

  async def save(
      self, directory: path_types.PathLike, checkpointable: T
  ) -> Awaitable[None]:
    """Saves the given `checkpointable` to the given `directory`.

    Save should perform any operations that need to block the main thread,
    such as device-to-host copying of on-device arrays. It then creates a
    background operation to continue writing the object to the storage
    location.

    Args:
      directory: The directory to save the checkpoint to.
      checkpointable: The checkpointable object to save.

    Returns:
      An `Awaitable`. This object represents the result of the save operation
      running in the background.
    """
    ...

  async def load(
      self,
      directory: path_types.PathLike,
      abstract_checkpointable: AbstractT | None = None,
  ) -> Awaitable[T]:
    """Loads the checkpointable from the given `directory`.

    Args:
      directory: The directory to load the checkpoint from.
      abstract_checkpointable: An optional abstract representation of the
        checkpointable to load. If provided, it supplies properties that guide
        the restoration logic of the checkpoint. In the case of arrays, for
        example, this conveys properties like shape and dtype, for casting and
        reshaping. In some cases, no information is needed, and `AbstractT`
        may always be None. In other cases, the abstract representation may be
        a hard requirement for loading.

    Returns:
      An `Awaitable` that continues to load the checkpointable in the
      background and returns the loaded checkpointable when complete.
    """
    ...

  async def metadata(self, directory: path_types.PathLike) -> MetadataT:
    """Returns the metadata for the given `directory`.

    The logic in this method must be executed fully in the main thread;
    metadata access is expected to be cheap and fast.

    Args:
      directory: The directory where the checkpoint is located.

    Returns:
      A `MetadataT`: the abstract representation of the checkpointable.
      `MetadataT` differs from `AbstractT` in that it may contain additional
      properties that cannot be directly consumed to customize loading
      behavior, but are nevertheless present and useful to know about in some
      cases.
    """
    ...
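
To make the Protocol concrete, here is a hypothetical handler sketch (not part of this commit; `JsonHandler` and its `data.json` layout are invented for illustration). Blocking work runs in the method body; deferred I/O is returned as an awaitable, matching the save/load contract described in the docstrings above.

import asyncio
import json
from pathlib import Path
from typing import Any, Awaitable


class JsonHandler:
  """Hypothetical CheckpointableHandler for a small JSON-serializable dict."""

  async def save(
      self, directory: Path, checkpointable: dict[str, Any]
  ) -> Awaitable[None]:
    # Blocking work (serialization) happens here, before returning.
    payload = json.dumps(checkpointable)

    async def _background_write() -> None:
      # The file write itself is deferred to the returned awaitable.
      await asyncio.to_thread((directory / 'data.json').write_text, payload)

    return _background_write()

  async def load(
      self, directory: Path, abstract_checkpointable: Any = None
  ) -> Awaitable[dict[str, Any]]:
    async def _background_read() -> dict[str, Any]:
      text = await asyncio.to_thread((directory / 'data.json').read_text)
      return json.loads(text)

    return _background_read()

  async def metadata(self, directory: Path) -> dict[str, str]:
    # Cheap, main-thread metadata access: report each key's value type.
    data = json.loads((directory / 'data.json').read_text())
    return {k: type(v).__name__ for k, v in data.items()}

A caller would typically do `response = await handler.save(directory, obj)` to complete the blocking phase, then `await response` once the background work is allowed to proceed.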
@@ -20,7 +20,7 @@
from orbax.experimental.model.core.python.saved_model_proto import node_builder
from orbax.experimental.model.core.python.saved_model_proto import nodes

from .net.proto2.contrib.pyutil import compare
from tensorflow.python.util.protobuf import compare
from absl.testing import absltest

Node = nodes.Node
@@ -30,7 +30,7 @@
class NodesTest(absltest.TestCase):

def assertProtoEqual(self, a, b):
compare.assertProto2Equal(self, a, b)
compare.assertProtoEqual(self, a, b)

def test_node(self):
node = Node(
@@ -19,7 +19,7 @@
from orbax.experimental.model import core as obm
from orbax.experimental.model.jax2obm import jax_specific_info
from orbax.experimental.model.jax2obm import jax_supplemental_pb2
from .net.proto2.contrib.pyutil import compare
from tensorflow.python.util.protobuf import compare
from google.protobuf import text_format
from absl.testing import absltest

@@ -102,7 +102,7 @@ def test_to_shape_dtype_refinements_proto(self):
}
}
"""
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
jax_specific_info._to_shape_dtype_refinements_proto(input1),
text_format.Parse(
@@ -135,7 +135,7 @@ def test_to_shape_dtype_refinements_proto(self):
}
}
"""
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
jax_specific_info._to_shape_dtype_refinements_proto(input2),
text_format.Parse(
10 changes: 5 additions & 5 deletions model/orbax/experimental/model/jax2obm/main_lib_test.py
@@ -29,7 +29,7 @@
from orbax.experimental.model.jax2obm import jax_supplemental_pb2
from orbax.experimental.model.jax2obm import main_lib

from .net.proto2.contrib.pyutil import compare
from tensorflow.python.util.protobuf import compare
from google.protobuf import text_format
from absl.testing import absltest

@@ -586,7 +586,7 @@ def _generated_sharded_params():
expected_manifest_proto = text_format.Parse(
expected_manifest_proto_text, obm.manifest_pb2.Manifest()
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
manifest_proto,
expected_manifest_proto,
@@ -620,7 +620,7 @@ def _generated_sharded_params():
expected_orchestration_proto_text,
obm.simple_orchestration_pb2.SimpleOrchestration(),
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self, expected_orchestration_proto, orchestration_proto
)

@@ -671,7 +671,7 @@ def _generated_sharded_params():
expected_jax_supplemental_proto_text,
jax_supplemental_pb2.Function(),
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self, expected_jax_supplemental_proto, jax_supplemental_proto
)

@@ -1190,7 +1190,7 @@ def get_mesh():
expected_manifest_proto_text, obm.manifest_pb2.Manifest()
)

compare.assertProto2Equal(
compare.assertProtoEqual(
self,
manifest_proto,
expected_manifest_proto,
6 changes: 3 additions & 3 deletions model/orbax/experimental/model/jax2obm/sharding_test.py
@@ -17,7 +17,7 @@
import numpy as np
from orbax.experimental.model import core as obm
from orbax.experimental.model.jax2obm import sharding
from .net.proto2.contrib.pyutil import compare
from tensorflow.python.util.protobuf import compare
from google.protobuf import text_format
from absl.testing import absltest

@@ -47,7 +47,7 @@ def test_jax_mesh_to_obm_device_mesh(self):
expected_device_mesh = text_format.Parse(
expected_device_mesh_text, obm.manifest_pb2.DeviceMesh()
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
device_mesh,
expected_device_mesh,
@@ -174,7 +174,7 @@ def test_jax_mesh_to_obm_device_assignment_by_coords(self):
expected_device_assignment_by_coords_text,
obm.manifest_pb2.DeviceAssignmentByCoords(),
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
device_assignment_by_coords,
expected_device_assignment_by_coords,
@@ -28,7 +28,7 @@
from orbax.experimental.model.tf2obm.tf_concrete_functions_to_obm import TF_SAVED_MODEL_SUPPLEMENTAL_NAME
import tensorflow as tf

from .net.proto2.contrib.pyutil import compare
from tensorflow.python.util.protobuf import compare
from google.protobuf import text_format
from absl.testing import absltest

@@ -332,7 +332,7 @@ def tf_fn(a):
expected_manifest_proto = text_format.Parse(
expected_manifest_proto_text, obm.manifest_pb2.Manifest()
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
manifest_proto,
expected_manifest_proto,
@@ -360,7 +360,7 @@ def tf_fn(a):
expected_pre_processor_proto_text,
tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(),
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
pre_processor_proto,
expected_pre_processor_proto,
@@ -388,7 +388,7 @@ def tf_fn(a):
expected_post_processor_proto_text,
tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(),
)
compare.assertProto2Equal(
compare.assertProtoEqual(
self,
post_processor_proto,
expected_post_processor_proto,