diff --git a/checkpoint/orbax/checkpoint/_src/BUILD b/checkpoint/orbax/checkpoint/_src/BUILD deleted file mode 100644 index ff512f325..000000000 --- a/checkpoint/orbax/checkpoint/_src/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -package( - default_applicable_licenses = ["//:package_license"], - default_visibility = ["//visibility:public"], -) - -py_library( - name = "asyncio_utils", - srcs = ["asyncio_utils.py"], -) - -py_test( - name = "asyncio_utils_test", - srcs = ["asyncio_utils_test.py"], - deps = [":asyncio_utils"], -) - -py_library( - name = "composite", - srcs = ["composite.py"], -) - -py_test( - name = "composite_test", - srcs = ["composite_test.py"], - deps = [":composite"], -) diff --git a/checkpoint/orbax/checkpoint/_src/arrays/BUILD b/checkpoint/orbax/checkpoint/_src/arrays/BUILD new file mode 100644 index 000000000..bfebfa135 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/arrays/BUILD @@ -0,0 +1,83 @@ +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "types", + srcs = ["types.py"], + srcs_version = "PY3", +) + +py_library( + name = "numpy_utils", + srcs = ["numpy_utils.py"], + srcs_version = "PY3", + deps = [":types"], +) + +py_test( + name = "numpy_utils_test", + srcs = ["numpy_utils_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":numpy_utils"], +) + +py_library( + name = "fragments", + srcs = ["fragments.py"], + deps = [ + ":numpy_utils", + ":types", + ], +) + +py_test( + name = "fragments_test", + srcs = ["fragments_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":fragments"], +) + +py_library( + name = "subchunking", + srcs = ["subchunking.py"], + srcs_version = "PY3", + deps = [ + ":fragments", + ":types", + ], +) + +py_library( + name = "abstract_arrays", + srcs = ["abstract_arrays.py"], + deps = [ + ":types", + "//checkpoint/orbax/checkpoint/_src/metadata:sharding", + ], +) + +py_test( + name = "abstract_arrays_test", + srcs = ["abstract_arrays_test.py"], + deps = [ + ":abstract_arrays", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + ], +) + +py_test( + name = "subchunking_test", + srcs = ["subchunking_test.py"], + args = ["--vmodule=subchunking=1"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":fragments", + ":subchunking", + ":types", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD b/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD new file mode 100644 index 000000000..199d96028 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/BUILD @@ -0,0 +1,112 @@ +load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library") +load("//devtools/python/blaze:strict.bzl", "py_strict_test") + +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "abstract_checkpointer", + srcs = ["abstract_checkpointer.py"], + deps = ["//orbax/checkpoint:version"], +) + +py_library( + name = "checkpointer", + srcs = ["checkpointer.py"], + deps = [ + ":abstract_checkpointer", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//third_party/py/etils/epy", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint/_src/handlers:checkpoint_handler", + "//orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//orbax/checkpoint/_src/path:atomicity", + "//orbax/checkpoint/_src/path:atomicity_defaults", + "//orbax/checkpoint/_src/path:atomicity_types", + ], +) + +py_library( + name = "pytree_checkpointer", + srcs = ["pytree_checkpointer.py"], + deps = [":checkpointer"], +) + +py_library( + name = "standard_checkpointer", + srcs = ["standard_checkpointer.py"], + deps = [ + ":async_checkpointer", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//orbax/checkpoint/_src/path:atomicity_types", + ], +) + +py_library( + name = "async_checkpointer", + srcs = ["async_checkpointer.py"], + deps = [ + ":checkpointer", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:async_utils", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//orbax/checkpoint/_src/path:atomicity", + "//orbax/checkpoint/_src/path:atomicity_types", + ], +) + +py_library( + name = "checkpointer_test_utils", + srcs = ["checkpointer_test_utils.py"], + deps = [ + ":async_checkpointer", + ":checkpointer", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//third_party/py/flax:core", + "//third_party/py/flax/training:train_state", + "//orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//orbax/checkpoint/_src/handlers:composite_checkpoint_handler", + "//orbax/checkpoint/_src/metadata:tree", + "//orbax/checkpoint/_src/path:atomicity", + "//orbax/checkpoint/_src/path:step", + "//orbax/checkpoint/_src/serialization", + ], +) + +py_test( + name = "async_checkpointer_test", + srcs = ["async_checkpointer_test.py"], + args = [ + "--tpu_chips_per_process=2", + "--num_processes=2", + ], + python_version = "PY3", + deps = [ + ":async_checkpointer", + ":checkpointer_test_utils", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint/_src/handlers:async_checkpoint_handler", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) + +py_test( + name = "checkpointer_test", + srcs = ["checkpointer_test.py"], + args = [ + "--tpu_chips_per_process=2", + "--num_processes=2", + ], + python_version = "PY3", + deps = [ + ":checkpointer", + ":checkpointer_test_utils", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/BUILD b/checkpoint/orbax/checkpoint/_src/handlers/BUILD new file mode 100644 index 000000000..71c69b765 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/handlers/BUILD @@ -0,0 +1,295 @@ +load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library", pytype_strict_test = "pytype_strict_contrib_test") +load("//devtools/python/blaze:strict.bzl", "py_strict_test") + +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "checkpoint_handler", + srcs = ["checkpoint_handler.py"], +) + +py_library( + name = "composite_checkpoint_handler", + srcs = ["composite_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + ":checkpoint_handler", + ":handler_registration", + ":proto_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint/_src/path:atomicity", + "//orbax/checkpoint/_src/path:atomicity_defaults", + "//orbax/checkpoint/_src/path:atomicity_types", + "//orbax/checkpoint/google:build_data_utils", + ], +) + +py_test( + name = "composite_checkpoint_handler_test", + srcs = ["composite_checkpoint_handler_test.py"], + deps = [ + ":checkpoint_handler", + ":composite_checkpoint_handler", + ":handler_registration", + ":json_checkpoint_handler", + ":proto_checkpoint_handler", + ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:checkpoint", + "//checkpoint/orbax/checkpoint/_src/metadata:step_metadata_serialization", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//orbax/checkpoint:logging", + "//orbax/checkpoint/_src/path:step", + ], +) + +py_library( + name = "pytree_checkpoint_handler", + srcs = ["pytree_checkpoint_handler.py"], + srcs_version = "PY3", + tags = ["ignore_for_dep=third_party.py.orbax.checkpoint.google.pathways_type_handlers"], + deps = [ + ":async_checkpoint_handler", + ":base_pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint:aggregate_handlers", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint:transform_utils", + "//orbax/checkpoint/_src/metadata:tree", + "//orbax/checkpoint/_src/serialization", + ], +) + +py_library( + name = "base_pytree_checkpoint_handler", + srcs = ["base_pytree_checkpoint_handler.py"], + srcs_version = "PY3", + tags = ["ignore_for_dep=third_party.py.orbax.checkpoint.google.pathways_type_handlers"], + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/path:format_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//checkpoint/orbax/checkpoint/_src/serialization:types", + "//third_party/py/google/protobuf", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint/_src/metadata:tree", + "//orbax/checkpoint/_src/serialization", + "//orbax/checkpoint/google:build_data_utils", + ], +) + +py_library( + name = "json_checkpoint_handler", + srcs = ["json_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_library( + name = "async_checkpoint_handler", + srcs = ["async_checkpoint_handler.py"], + deps = [":checkpoint_handler"], +) + +py_library( + name = "array_checkpoint_handler", + srcs = ["array_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint:aggregate_handlers", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_test( + name = "pytree_checkpoint_handler_test", + srcs = ["pytree_checkpoint_handler_test.py"], + args = [ + "--tpu_chips_per_process=2", + "--num_processes=2", + ], + python_version = "PY3", + deps = [ + ":pytree_checkpoint_handler_test_utils", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) + +py_test( + name = "json_checkpoint_handler_test", + srcs = ["json_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [":json_checkpoint_handler"], +) + +py_test( + name = "array_checkpoint_handler_test", + srcs = ["array_checkpoint_handler_test.py"], + args = [ + "--tpu_chips_per_process=2", + "--num_processes=2", + ], + python_version = "PY3", + deps = [ + ":array_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) + +py_library( + name = "proto_checkpoint_handler", + srcs = ["proto_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + "//third_party/py/google/protobuf", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_library( + name = "pytree_checkpoint_handler_test_utils", + srcs = ["pytree_checkpoint_handler_test_utils.py"], + deps = [ + ":base_pytree_checkpoint_handler", + ":proto_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", + "//checkpoint/orbax/checkpoint/_src/metadata:sharding", + "//checkpoint/orbax/checkpoint/_src/metadata:value", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//third_party/py/flax:core", + "//third_party/py/flax/training:train_state", + "//orbax/checkpoint:msgpack_utils", + "//orbax/checkpoint:transform_utils", + "//orbax/checkpoint/_src/metadata:tree", + "//orbax/checkpoint/_src/serialization", + "//orbax/checkpoint/_src/serialization:replica_slices", + ], +) + +py_library( + name = "standard_checkpoint_handler", + srcs = ["standard_checkpoint_handler.py"], + deps = [ + ":async_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", + "//orbax/checkpoint:checkpoint_args", + "//orbax/checkpoint:checkpoint_utils", + "//orbax/checkpoint/_src/metadata:tree", + ], +) + +py_test( + name = "proto_checkpoint_handler_test", + srcs = ["proto_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [ + ":proto_checkpoint_handler", + "//orbax/checkpoint/proto/testing:foo_py_pb2", + ], +) + +py_library( + name = "standard_checkpoint_handler_test_utils", + srcs = ["standard_checkpoint_handler_test_utils.py"], + deps = [ + ":standard_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/multihost", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//third_party/py/flax:core", + "//third_party/py/flax/training:train_state", + ], +) + +py_test( + name = "standard_checkpoint_handler_test", + srcs = ["standard_checkpoint_handler_test.py"], + args = [ + "--tpu_chips_per_process=2", + "--num_processes=2", + ], + python_version = "PY3", + deps = [ + ":standard_checkpoint_handler_test_utils", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) + +py_library( + name = "random_key_checkpoint_handler", + srcs = ["random_key_checkpoint_handler.py"], + deps = [ + ":array_checkpoint_handler", + ":async_checkpoint_handler", + ":composite_checkpoint_handler", + ":json_checkpoint_handler", + ":pytree_checkpoint_handler", + "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_test( + name = "random_key_checkpoint_handler_test", + srcs = ["random_key_checkpoint_handler_test.py"], + python_version = "PY3", + deps = [ + ":composite_checkpoint_handler", + ":random_key_checkpoint_handler", + ], +) + +py_library( + name = "handler_registration", + srcs = ["handler_registration.py"], + deps = [ + ":checkpoint_handler", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_test( + name = "handler_registration_test", + srcs = ["handler_registration_test.py"], + deps = [ + ":checkpoint_handler", + ":handler_registration", + ":standard_checkpoint_handler", + "//orbax/checkpoint:checkpoint_args", + ], +) + +py_library( + name = "handler_type_registry", + srcs = ["handler_type_registry.py"], + deps = [":checkpoint_handler"], +) + +py_test( + name = "handler_type_registry_test", + srcs = ["handler_type_registry_test.py"], + deps = [ + ":checkpoint_handler", + ":handler_type_registry", + ":standard_checkpoint_handler", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/BUILD b/checkpoint/orbax/checkpoint/_src/metadata/BUILD index 6c5fe1598..45f16f830 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/BUILD +++ b/checkpoint/orbax/checkpoint/_src/metadata/BUILD @@ -45,7 +45,10 @@ py_test( py_library( name = "value", srcs = ["value.py"], - deps = [":sharding"], + deps = [ + ":sharding", + "//checkpoint/orbax/checkpoint/_src/arrays:types", + ], ) py_library( @@ -111,6 +114,7 @@ py_library( deps = [ ":empty_values", ":pytree_metadata_options", + "//checkpoint/orbax/checkpoint/_src/arrays:types", "//checkpoint/orbax/checkpoint/_src/serialization:types", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/BUILD b/checkpoint/orbax/checkpoint/_src/serialization/BUILD index b868cabff..965ad1d05 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/BUILD +++ b/checkpoint/orbax/checkpoint/_src/serialization/BUILD @@ -7,6 +7,10 @@ py_library( name = "tensorstore_utils", srcs = ["tensorstore_utils.py"], srcs_version = "PY3", + deps = [ + "//checkpoint/orbax/checkpoint/_src/arrays:subchunking", + "//checkpoint/orbax/checkpoint/_src/arrays:types", + ], ) py_library( @@ -14,6 +18,7 @@ py_library( srcs = ["types.py"], deps = [ ":serialization", + "//checkpoint/orbax/checkpoint/_src/arrays:types", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options", "//checkpoint/orbax/checkpoint/_src/metadata:value", @@ -28,6 +33,8 @@ py_library( ":serialization", ":tensorstore_utils", ":types", + "//checkpoint/orbax/checkpoint/_src/arrays:subchunking", + "//checkpoint/orbax/checkpoint/_src/arrays:types", "//checkpoint/orbax/checkpoint/_src/metadata:empty_values", "//checkpoint/orbax/checkpoint/_src/metadata:sharding", "//checkpoint/orbax/checkpoint/_src/metadata:value", @@ -43,7 +50,11 @@ py_test( srcs = ["tensorstore_utils_test.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":tensorstore_utils"], + deps = [ + ":tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/arrays:subchunking", + "//checkpoint/orbax/checkpoint/_src/arrays:types", + ], ) py_library( @@ -52,6 +63,9 @@ py_library( deps = [ ":replica_slices", ":tensorstore_utils", + "//checkpoint/orbax/checkpoint/_src/arrays:fragments", + "//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils", + "//checkpoint/orbax/checkpoint/_src/arrays:types", "//checkpoint/orbax/checkpoint/_src/multihost", ], ) @@ -59,7 +73,12 @@ py_library( py_library( name = "replica_slices", srcs = ["replica_slices.py"], - deps = ["//checkpoint/orbax/checkpoint/_src/multihost"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/arrays:fragments", + "//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils", + "//checkpoint/orbax/checkpoint/_src/arrays:types", + "//checkpoint/orbax/checkpoint/_src/multihost", + ], ) py_test( diff --git a/checkpoint/orbax/checkpoint/_src/testing/BUILD b/checkpoint/orbax/checkpoint/_src/testing/BUILD new file mode 100644 index 000000000..828502843 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/BUILD @@ -0,0 +1,28 @@ +load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library") + +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "multiprocess_test", + testonly = 1, + srcs = ["multiprocess_test.py"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/multihost", + "//third_party/py/absl:app", + "//third_party/py/portpicker", + ], +) + +py_library( + name = "test_tree_utils", + srcs = ["test_tree_utils.py"], + srcs_version = "PY3", + deps = [ + "//third_party/py/flax:core", + "//orbax/checkpoint/_src/metadata:tree", + "//orbax/checkpoint/_src/metadata:tree_rich_types", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/testing/tree_verity/BUILD b/checkpoint/orbax/checkpoint/_src/testing/tree_verity/BUILD new file mode 100644 index 000000000..e82af5eb5 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/tree_verity/BUILD @@ -0,0 +1,19 @@ +load("//devtools/python/blaze:pytype.bzl", pytype_strict_test = "pytype_strict_contrib_test") + +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_test( + name = "checkpoint_manager_test", + timeout = "long", + srcs = ["checkpoint_manager_test.py"], + tags = ["optonly"], + deps = [ + "//checkpoint/orbax/checkpoint/_src/multihost", + "//security/bcid/tree_verity:tree_verity_py_pb2", + "//orbax/checkpoint/_src/path:step", + "//orbax/checkpoint/_src/testing:multiprocess_test", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/tree/BUILD b/checkpoint/orbax/checkpoint/_src/tree/BUILD new file mode 100644 index 000000000..6e2db17ea --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/tree/BUILD @@ -0,0 +1,30 @@ +load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library", pytype_strict_test = "pytype_strict_contrib_test") + +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "types", + srcs = ["types.py"], + srcs_version = "PY3", +) + +py_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + ":types", + "//orbax/checkpoint/_src/arrays:abstract_arrays", + ], +) + +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + deps = [ + ":utils", + "//third_party/py/flax:core", + ], +)