Skip to content

Commit

Permalink
Add checkpoint/_src/arrays/BUILD and update dependencies on checkpoin…
Browse files Browse the repository at this point in the history
…t/_src/arrays to these targets.

PiperOrigin-RevId: 715238326
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit 9eb185b
Show file tree
Hide file tree
Showing 9 changed files with 593 additions and 29 deletions.
26 changes: 0 additions & 26 deletions checkpoint/orbax/checkpoint/_src/BUILD

This file was deleted.

83 changes: 83 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "types",
srcs = ["types.py"],
srcs_version = "PY3",
)

py_library(
name = "numpy_utils",
srcs = ["numpy_utils.py"],
srcs_version = "PY3",
deps = [":types"],
)

py_test(
name = "numpy_utils_test",
srcs = ["numpy_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":numpy_utils"],
)

py_library(
name = "fragments",
srcs = ["fragments.py"],
deps = [
":numpy_utils",
":types",
],
)

py_test(
name = "fragments_test",
srcs = ["fragments_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":fragments"],
)

py_library(
name = "subchunking",
srcs = ["subchunking.py"],
srcs_version = "PY3",
deps = [
":fragments",
":types",
],
)

py_library(
name = "abstract_arrays",
srcs = ["abstract_arrays.py"],
deps = [
":types",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
],
)

py_test(
name = "abstract_arrays_test",
srcs = ["abstract_arrays_test.py"],
deps = [
":abstract_arrays",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "subchunking_test",
srcs = ["subchunking_test.py"],
args = ["--vmodule=subchunking=1"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":fragments",
":subchunking",
":types",
],
)
112 changes: 112 additions & 0 deletions checkpoint/orbax/checkpoint/_src/checkpointers/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
load("//devtools/python/blaze:strict.bzl", "py_strict_test")

package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "abstract_checkpointer",
srcs = ["abstract_checkpointer.py"],
deps = ["//orbax/checkpoint:version"],
)

py_library(
name = "checkpointer",
srcs = ["checkpointer.py"],
deps = [
":abstract_checkpointer",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//third_party/py/etils/epy",
"//orbax/checkpoint:checkpoint_args",
"//orbax/checkpoint/_src/handlers:checkpoint_handler",
"//orbax/checkpoint/_src/handlers:composite_checkpoint_handler",
"//orbax/checkpoint/_src/path:atomicity",
"//orbax/checkpoint/_src/path:atomicity_defaults",
"//orbax/checkpoint/_src/path:atomicity_types",
],
)

py_library(
name = "pytree_checkpointer",
srcs = ["pytree_checkpointer.py"],
deps = [":checkpointer"],
)

py_library(
name = "standard_checkpointer",
srcs = ["standard_checkpointer.py"],
deps = [
":async_checkpointer",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//orbax/checkpoint/_src/path:atomicity_types",
],
)

py_library(
name = "async_checkpointer",
srcs = ["async_checkpointer.py"],
deps = [
":checkpointer",
"//checkpoint/orbax/checkpoint/_src/metadata:checkpoint",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/path:async_utils",
"//orbax/checkpoint:checkpoint_args",
"//orbax/checkpoint/_src/handlers:async_checkpoint_handler",
"//orbax/checkpoint/_src/path:atomicity",
"//orbax/checkpoint/_src/path:atomicity_types",
],
)

py_library(
name = "checkpointer_test_utils",
srcs = ["checkpointer_test_utils.py"],
deps = [
":async_checkpointer",
":checkpointer",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//third_party/py/flax:core",
"//third_party/py/flax/training:train_state",
"//orbax/checkpoint/_src/handlers:async_checkpoint_handler",
"//orbax/checkpoint/_src/handlers:composite_checkpoint_handler",
"//orbax/checkpoint/_src/metadata:tree",
"//orbax/checkpoint/_src/path:atomicity",
"//orbax/checkpoint/_src/path:step",
"//orbax/checkpoint/_src/serialization",
],
)

py_test(
name = "async_checkpointer_test",
srcs = ["async_checkpointer_test.py"],
args = [
"--tpu_chips_per_process=2",
"--num_processes=2",
],
python_version = "PY3",
deps = [
":async_checkpointer",
":checkpointer_test_utils",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//orbax/checkpoint:checkpoint_args",
"//orbax/checkpoint/_src/handlers:async_checkpoint_handler",
"//orbax/checkpoint/_src/testing:multiprocess_test",
],
)

py_test(
name = "checkpointer_test",
srcs = ["checkpointer_test.py"],
args = [
"--tpu_chips_per_process=2",
"--num_processes=2",
],
python_version = "PY3",
deps = [
":checkpointer",
":checkpointer_test_utils",
"//orbax/checkpoint/_src/testing:multiprocess_test",
],
)
Loading

0 comments on commit 9eb185b

Please sign in to comment.