Skip to content

Commit

Permalink
Add checkpoint/_src/handlers/BUILD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715152882
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit b134d8a
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 3 deletions.
253 changes: 253 additions & 0 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "checkpoint_handler",
srcs = ["checkpoint_handler.py"],
)

py_library(
name = "composite_checkpoint_handler",
srcs = ["composite_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
":checkpoint_handler",
":handler_registration",
":proto_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src:composite",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization",
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
],
)

py_test(
name = "composite_checkpoint_handler_test",
srcs = ["composite_checkpoint_handler_test.py"],
deps = [
":checkpoint_handler",
":composite_checkpoint_handler",
":handler_registration",
":json_checkpoint_handler",
":proto_checkpoint_handler",
":standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:step",
],
)

py_library(
name = "pytree_checkpoint_handler",
srcs = ["pytree_checkpoint_handler.py"],
srcs_version = "PY3",
deps = [
":async_checkpoint_handler",
":base_pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "base_pytree_checkpoint_handler",
srcs = ["base_pytree_checkpoint_handler.py"],
srcs_version = "PY3",
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:format_utils",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

py_library(
name = "json_checkpoint_handler",
srcs = ["json_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_library(
name = "async_checkpoint_handler",
srcs = ["async_checkpoint_handler.py"],
deps = [":checkpoint_handler"],
)

py_library(
name = "array_checkpoint_handler",
srcs = ["array_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "pytree_checkpoint_handler_test",
srcs = ["pytree_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":pytree_checkpoint_handler_test_utils"],
)

py_test(
name = "json_checkpoint_handler_test",
srcs = ["json_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":json_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_test(
name = "array_checkpoint_handler_test",
srcs = ["array_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":array_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "proto_checkpoint_handler",
srcs = ["proto_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_library(
name = "pytree_checkpoint_handler_test_utils",
srcs = ["pytree_checkpoint_handler_test_utils.py"],
deps = [
":base_pytree_checkpoint_handler",
":proto_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization",
"//checkpoint/orbax/checkpoint/_src/serialization:replica_slices",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_library(
name = "standard_checkpoint_handler",
srcs = ["standard_checkpoint_handler.py"],
deps = [
":async_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:tree",
],
)

py_test(
name = "proto_checkpoint_handler_test",
srcs = ["proto_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":proto_checkpoint_handler"],
)

py_library(
name = "standard_checkpoint_handler_test_utils",
srcs = ["standard_checkpoint_handler_test_utils.py"],
deps = [
":standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "standard_checkpoint_handler_test",
srcs = ["standard_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [":standard_checkpoint_handler_test_utils"],
)

py_library(
name = "random_key_checkpoint_handler",
srcs = ["random_key_checkpoint_handler.py"],
deps = [
":array_checkpoint_handler",
":async_checkpoint_handler",
":composite_checkpoint_handler",
":json_checkpoint_handler",
":pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "random_key_checkpoint_handler_test",
srcs = ["random_key_checkpoint_handler_test.py"],
python_version = "PY3",
deps = [
":composite_checkpoint_handler",
":random_key_checkpoint_handler",
],
)

py_library(
name = "handler_registration",
srcs = ["handler_registration.py"],
deps = [":checkpoint_handler"],
)

py_test(
name = "handler_registration_test",
srcs = ["handler_registration_test.py"],
deps = [
":checkpoint_handler",
":handler_registration",
":standard_checkpoint_handler",
],
)

py_library(
name = "handler_type_registry",
srcs = ["handler_type_registry.py"],
deps = [":checkpoint_handler"],
)

py_test(
name = "handler_type_registry_test",
srcs = ["handler_type_registry_test.py"],
deps = [
":checkpoint_handler",
":handler_type_registry",
":standard_checkpoint_handler",
],
)
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ py_library(
":tree_rich_types",
":value",
":value_metadata_entry",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
Expand Down Expand Up @@ -62,7 +63,10 @@ py_test(
py_library(
name = "checkpoint",
srcs = ["checkpoint.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/logging:step_statistics"],
deps = [
"//checkpoint/orbax/checkpoint/_src:composite",
"//checkpoint/orbax/checkpoint/_src/logging:step_statistics",
],
)

py_test(
Expand Down
7 changes: 6 additions & 1 deletion checkpoint/orbax/checkpoint/_src/path/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ py_test(
py_library(
name = "async_utils",
srcs = ["async_utils.py"],
deps = [":step"],
deps = [
":step",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

py_library(
Expand Down Expand Up @@ -109,6 +112,8 @@ py_test(
srcs = ["format_utils_test.py"],
deps = [
":format_utils",
"//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
],
)
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ py_library(
":serialization",
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
Expand Down Expand Up @@ -69,6 +70,7 @@ py_test(
deps = [
":serialization",
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
],
)

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ optax
mock
nest_asyncio
tensorstore
humanize
humanize
flax

0 comments on commit b134d8a

Please sign in to comment.