-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add 1. checkpoint/_src/handlers/BUILD, 2. handler tests and 3. update…
…s dependencies on checkpoint/_src/handlers to these targets. PiperOrigin-RevId: 715152882
- Loading branch information
1 parent
2a7e309
commit 838f262
Showing
10 changed files
with
2,550 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
package( | ||
default_applicable_licenses = ["//:package_license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
py_library( | ||
name = "checkpoint_handler", | ||
srcs = ["checkpoint_handler.py"], | ||
) | ||
|
||
py_library( | ||
name = "composite_checkpoint_handler", | ||
srcs = ["composite_checkpoint_handler.py"], | ||
deps = [ | ||
":async_checkpoint_handler", | ||
":checkpoint_handler", | ||
":handler_registration", | ||
":proto_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src:composite", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", | ||
"//checkpoint/orbax/checkpoint/_src/path:atomicity", | ||
"//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults", | ||
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "composite_checkpoint_handler_test", | ||
srcs = ["composite_checkpoint_handler_test.py"], | ||
deps = [ | ||
":checkpoint_handler", | ||
":composite_checkpoint_handler", | ||
":handler_registration", | ||
":json_checkpoint_handler", | ||
":proto_checkpoint_handler", | ||
":standard_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:value", | ||
"//checkpoint/orbax/checkpoint/_src/multihost", | ||
"//checkpoint/orbax/checkpoint/_src/path:step", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "pytree_checkpoint_handler", | ||
srcs = ["pytree_checkpoint_handler.py"], | ||
srcs_version = "PY3", | ||
deps = [ | ||
":async_checkpoint_handler", | ||
":base_pytree_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:tree", | ||
"//checkpoint/orbax/checkpoint/_src/serialization", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "base_pytree_checkpoint_handler", | ||
srcs = ["base_pytree_checkpoint_handler.py"], | ||
srcs_version = "PY3", | ||
deps = [ | ||
":async_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:tree", | ||
"//checkpoint/orbax/checkpoint/_src/multihost", | ||
"//checkpoint/orbax/checkpoint/_src/path:format_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:types", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "json_checkpoint_handler", | ||
srcs = ["json_checkpoint_handler.py"], | ||
deps = [ | ||
":async_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "async_checkpoint_handler", | ||
srcs = ["async_checkpoint_handler.py"], | ||
deps = [":checkpoint_handler"], | ||
) | ||
|
||
py_library( | ||
name = "array_checkpoint_handler", | ||
srcs = ["array_checkpoint_handler.py"], | ||
deps = [ | ||
":async_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "pytree_checkpoint_handler_test", | ||
srcs = ["pytree_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [":pytree_checkpoint_handler_test_utils"], | ||
) | ||
|
||
py_test( | ||
name = "json_checkpoint_handler_test", | ||
srcs = ["json_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [ | ||
":json_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "array_checkpoint_handler_test", | ||
srcs = ["array_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [ | ||
":array_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src/multihost", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "proto_checkpoint_handler", | ||
srcs = ["proto_checkpoint_handler.py"], | ||
deps = [ | ||
":async_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "pytree_checkpoint_handler_test_utils", | ||
srcs = ["pytree_checkpoint_handler_test_utils.py"], | ||
deps = [ | ||
":base_pytree_checkpoint_handler", | ||
":proto_checkpoint_handler", | ||
":pytree_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:sharding", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:tree", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:value", | ||
"//checkpoint/orbax/checkpoint/_src/multihost", | ||
"//checkpoint/orbax/checkpoint/_src/serialization", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:replica_slices", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "standard_checkpoint_handler", | ||
srcs = ["standard_checkpoint_handler.py"], | ||
deps = [ | ||
":async_checkpoint_handler", | ||
":pytree_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", | ||
"//checkpoint/orbax/checkpoint/_src/metadata:tree", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "proto_checkpoint_handler_test", | ||
srcs = ["proto_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [":proto_checkpoint_handler"], | ||
) | ||
|
||
py_library( | ||
name = "standard_checkpoint_handler_test_utils", | ||
srcs = ["standard_checkpoint_handler_test_utils.py"], | ||
deps = [ | ||
":standard_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src/multihost", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "standard_checkpoint_handler_test", | ||
srcs = ["standard_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [":standard_checkpoint_handler_test_utils"], | ||
) | ||
|
||
py_library( | ||
name = "random_key_checkpoint_handler", | ||
srcs = ["random_key_checkpoint_handler.py"], | ||
deps = [ | ||
":array_checkpoint_handler", | ||
":async_checkpoint_handler", | ||
":composite_checkpoint_handler", | ||
":json_checkpoint_handler", | ||
":pytree_checkpoint_handler", | ||
"//checkpoint/orbax/checkpoint/_src:asyncio_utils", | ||
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "random_key_checkpoint_handler_test", | ||
srcs = ["random_key_checkpoint_handler_test.py"], | ||
python_version = "PY3", | ||
deps = [ | ||
":composite_checkpoint_handler", | ||
":random_key_checkpoint_handler", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "handler_registration", | ||
srcs = ["handler_registration.py"], | ||
deps = [":checkpoint_handler"], | ||
) | ||
|
||
py_test( | ||
name = "handler_registration_test", | ||
srcs = ["handler_registration_test.py"], | ||
deps = [ | ||
":checkpoint_handler", | ||
":handler_registration", | ||
":standard_checkpoint_handler", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "handler_type_registry", | ||
srcs = ["handler_type_registry.py"], | ||
deps = [":checkpoint_handler"], | ||
) | ||
|
||
py_test( | ||
name = "handler_type_registry_test", | ||
srcs = ["handler_type_registry_test.py"], | ||
deps = [ | ||
":checkpoint_handler", | ||
":handler_type_registry", | ||
":standard_checkpoint_handler", | ||
], | ||
) |
Oops, something went wrong.