From e17ad034b99634915eb38f06f93a6984d58fe530 Mon Sep 17 00:00:00 2001 From: Yaning Liang Date: Mon, 13 Jan 2025 16:42:56 -0800 Subject: [PATCH] Add checkpoint/_src/handlers/BUILD and updates dependencies on checkpoint/_src/handlers to these targets. PiperOrigin-RevId: 715152882 --- .../orbax/checkpoint/_src/handlers/BUILD | 253 ++++++++++++++++++ .../orbax/checkpoint/_src/metadata/BUILD | 6 +- checkpoint/orbax/checkpoint/_src/path/BUILD | 7 +- .../orbax/checkpoint/_src/serialization/BUILD | 2 + requirements.txt | 3 +- 5 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/handlers/BUILD diff --git a/checkpoint/orbax/checkpoint/_src/handlers/BUILD b/checkpoint/orbax/checkpoint/_src/handlers/BUILD new file mode 100644 index 000000000..604fad04e --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/handlers/BUILD @@ -0,0 +1,253 @@ +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "checkpoint_handler", + srcs = ["checkpoint_handler.py"], +) + +py_library( + name = "composite_checkpoint_handler", + srcs = ["composite_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + ":checkpoint_handler", + ":handler_registration", + ":proto_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src:composite", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", + "//checkpoint/orbax/checkpoint/_src/path:atomicity", + "//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults", + "//checkpoint/orbax/checkpoint/_src/path:atomicity_types", + ], +) + +py_test( + name = "composite_checkpoint_handler_test", + srcs = ["composite_checkpoint_handler_test.py"], + deps = [ + ":checkpoint_handler", + ":composite_checkpoint_handler", + ":handler_registration", + ":json_checkpoint_handler", + ":proto_checkpoint_handler", + ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:step", + ], +) + +py_library( + name = "pytree_checkpoint_handler", + srcs = ["pytree_checkpoint_handler.py"], + srcs_version = "PY3", + deps = [ + ":async_checkpoint_handler", + ":base_pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + "//checkpoint/orbax/checkpoint/_src/serialization", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_library( + name = "base_pytree_checkpoint_handler", + srcs = ["base_pytree_checkpoint_handler.py"], + srcs_version = "PY3", + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:format_utils", + "//checkpoint/orbax/checkpoint/_src/serialization", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//checkpoint/orbax/checkpoint/_src/serialization:types", + ], +) + +py_library( + name = "json_checkpoint_handler", + srcs = ["json_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + ], +) + +py_library( + name = "async_checkpoint_handler", + srcs = ["async_checkpoint_handler.py"], + deps = [":checkpoint_handler"], +) + +py_library( + name = "array_checkpoint_handler", + srcs = ["array_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_test( + name = "pytree_checkpoint_handler_test", + srcs = ["pytree_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [":pytree_checkpoint_handler_test_utils"], +) + +py_test( + name = "json_checkpoint_handler_test", + srcs = ["json_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [ + ":json_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + ], +) + +py_test( + name = "array_checkpoint_handler_test", + srcs = ["array_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [ + ":array_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_library( + name = "proto_checkpoint_handler", + srcs = ["proto_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + ], +) + +py_library( + name = "pytree_checkpoint_handler_test_utils", + srcs = ["pytree_checkpoint_handler_test_utils.py"], + deps = [ + ":base_pytree_checkpoint_handler", + ":proto_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/metadata:sharding", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization", + "//checkpoint/orbax/checkpoint/_src/serialization:replica_slices", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_library( + name = "standard_checkpoint_handler", + srcs = ["standard_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", + "//checkpoint/orbax/checkpoint/_src/metadata:tree", + ], +) + +py_test( + name = "proto_checkpoint_handler_test", + srcs = ["proto_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [":proto_checkpoint_handler"], +) + +py_library( + name = "standard_checkpoint_handler_test_utils", + srcs = ["standard_checkpoint_handler_test_utils.py"], + deps = [ + ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_test( + name = "standard_checkpoint_handler_test", + srcs = ["standard_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [":standard_checkpoint_handler_test_utils"], +) + +py_library( + name = "random_key_checkpoint_handler", + srcs = ["random_key_checkpoint_handler.py"], + deps = [ + ":array_checkpoint_handler", + ":async_checkpoint_handler", + ":composite_checkpoint_handler", + ":json_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_test( + name = "random_key_checkpoint_handler_test", + srcs = ["random_key_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [ + ":composite_checkpoint_handler", + ":random_key_checkpoint_handler", + ], +) + +py_library( + name = "handler_registration", + srcs = ["handler_registration.py"], + deps = [":checkpoint_handler"], +) + +py_test( + name = "handler_registration_test", + srcs = ["handler_registration_test.py"], + deps = [ + ":checkpoint_handler", + ":handler_registration", + ":standard_checkpoint_handler", + ], +) + +py_library( + name = "handler_type_registry", + srcs = ["handler_type_registry.py"], + deps = [":checkpoint_handler"], +) + +py_test( + name = "handler_type_registry_test", + srcs = ["handler_type_registry_test.py"], + deps = [ + ":checkpoint_handler", + ":handler_type_registry", + ":standard_checkpoint_handler", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/BUILD b/checkpoint/orbax/checkpoint/_src/metadata/BUILD index 6c5fe1598..f2f11e208 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/BUILD +++ b/checkpoint/orbax/checkpoint/_src/metadata/BUILD @@ -27,6 +27,7 @@ py_library( ":tree_rich_types", ":value", ":value_metadata_entry", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", "//checkpoint/orbax/checkpoint/_src/serialization:types", ], @@ -62,7 +63,10 @@ py_test( py_library( name = "checkpoint", srcs = ["checkpoint.py"], - deps = ["//checkpoint/orbax/checkpoint/_src/logging:step_statistics"], + deps = [ + "//checkpoint/orbax/checkpoint/_src:composite", + "//checkpoint/orbax/checkpoint/_src/logging:step_statistics", + ], ) py_test( diff --git a/checkpoint/orbax/checkpoint/_src/path/BUILD b/checkpoint/orbax/checkpoint/_src/path/BUILD index 88ab19010..f180725da 100644 --- a/checkpoint/orbax/checkpoint/_src/path/BUILD +++ b/checkpoint/orbax/checkpoint/_src/path/BUILD @@ -53,7 +53,10 @@ py_test( py_library( name = "async_utils", srcs = ["async_utils.py"], - deps = [":step"], + deps = [ + ":step", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", + ], ) py_library( @@ -109,6 +112,8 @@ py_test( srcs = ["format_utils_test.py"], deps = [ ":format_utils", + "//checkpoint/orbax/checkpoint/_src/handlers:pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/handlers:standard_checkpoint_handler", "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/BUILD b/checkpoint/orbax/checkpoint/_src/serialization/BUILD index b868cabff..ad02dc942 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/BUILD +++ b/checkpoint/orbax/checkpoint/_src/serialization/BUILD @@ -28,6 +28,7 @@ py_library( ":serialization", ":tensorstore_utils", ":types", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:sharding", "//checkpoint/orbax/checkpoint/_src/metadata:value", @@ -69,6 +70,7 @@ py_test( deps = [ ":serialization", ":tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src:asyncio_utils", ], ) diff --git a/requirements.txt b/requirements.txt index 6212640aa..3026541f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,4 +23,5 @@ optax mock nest_asyncio tensorstore -humanize \ No newline at end of file +humanize +flax \ No newline at end of file