Skip to content

Commit

Permalink
Add 1. checkpoint/_src/handlers/BUILD, 2. handler tests and 3. update…
Browse files Browse the repository at this point in the history
…s dependencies on checkpoint/_src/handlers to these targets.

PiperOrigin-RevId: 715152882
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit 838f262
Show file tree
Hide file tree
Showing 10 changed files with 2,550 additions and 3 deletions.
253 changes: 253 additions & 0 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "checkpoint_handler",
srcs = ["checkpoint_handler.py"],
)

py_library(
name = "composite_checkpoint_handler",
srcs = ["composite_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
":checkpoint_handler",
":handler_registration",
":proto_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src:composite",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization",
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
],
)

py_test(
name = "composite_checkpoint_handler_test",
srcs = ["composite_checkpoint_handler_test.py"],
deps = [
":checkpoint_handler",
":composite_checkpoint_handler",
":handler_registration",
":json_checkpoint_handler",
":proto_checkpoint_handler",
":standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:step",
],
)

py_library(
name = "pytree_checkpoint_handler",
srcs = ["pytree_checkpoint_handler.py"],
srcs_version = "PY3",
deps = [
":async_checkpoint_handler",
":base_pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "base_pytree_checkpoint_handler",
srcs = ["base_pytree_checkpoint_handler.py"],
srcs_version = "PY3",
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:format_utils",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

py_library(
name = "json_checkpoint_handler",
srcs = ["json_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_library(
name = "async_checkpoint_handler",
srcs = ["async_checkpoint_handler.py"],
deps = [":checkpoint_handler"],
)

py_library(
name = "array_checkpoint_handler",
srcs = ["array_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "pytree_checkpoint_handler_test",
srcs = ["pytree_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":pytree_checkpoint_handler_test_utils"],
)

py_test(
name = "json_checkpoint_handler_test",
srcs = ["json_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":json_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_test(
name = "array_checkpoint_handler_test",
srcs = ["array_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":array_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "proto_checkpoint_handler",
srcs = ["proto_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_library(
name = "pytree_checkpoint_handler_test_utils",
srcs = ["pytree_checkpoint_handler_test_utils.py"],
deps = [
":base_pytree_checkpoint_handler",
":proto_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:replica_slices",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "standard_checkpoint_handler",
srcs = ["standard_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
],
)

py_test(
name = "proto_checkpoint_handler_test",
srcs = ["proto_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":proto_checkpoint_handler"],
)

py_library(
name = "standard_checkpoint_handler_test_utils",
srcs = ["standard_checkpoint_handler_test_utils.py"],
deps = [
":standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "standard_checkpoint_handler_test",
srcs = ["standard_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":standard_checkpoint_handler_test_utils"],
)

py_library(
name = "random_key_checkpoint_handler",
srcs = ["random_key_checkpoint_handler.py"],
deps = [
":array_checkpoint_handler",
":async_checkpoint_handler",
":composite_checkpoint_handler",
":json_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "random_key_checkpoint_handler_test",
srcs = ["random_key_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":composite_checkpoint_handler",
":random_key_checkpoint_handler",
],
)

py_library(
name = "handler_registration",
srcs = ["handler_registration.py"],
deps = [":checkpoint_handler"],
)

py_test(
name = "handler_registration_test",
srcs = ["handler_registration_test.py"],
deps = [
":checkpoint_handler",
":handler_registration",
":standard_checkpoint_handler",
],
)

py_library(
name = "handler_type_registry",
srcs = ["handler_type_registry.py"],
deps = [":checkpoint_handler"],
)

py_test(
name = "handler_type_registry_test",
srcs = ["handler_type_registry_test.py"],
deps = [
":checkpoint_handler",
":handler_type_registry",
":standard_checkpoint_handler",
],
)
Loading

0 comments on commit 838f262

Please sign in to comment.