Skip to content

Commit

Permalink
Add checkpoint/_src/arrays/BUILD and update dependencies on checkpoin…
Browse files Browse the repository at this point in the history
…t/_src/arrays to these targets.

PiperOrigin-RevId: 715238326
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 14, 2025
1 parent 2a7e309 commit b5c7971
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
83 changes: 83 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "types",
srcs = ["types.py"],
srcs_version = "PY3",
)

py_library(
name = "numpy_utils",
srcs = ["numpy_utils.py"],
srcs_version = "PY3",
deps = [":types"],
)

py_test(
name = "numpy_utils_test",
srcs = ["numpy_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":numpy_utils"],
)

py_library(
name = "fragments",
srcs = ["fragments.py"],
deps = [
":numpy_utils",
":types",
],
)

py_test(
name = "fragments_test",
srcs = ["fragments_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":fragments"],
)

py_library(
name = "subchunking",
srcs = ["subchunking.py"],
srcs_version = "PY3",
deps = [
":fragments",
":types",
],
)

py_library(
name = "abstract_arrays",
srcs = ["abstract_arrays.py"],
deps = [
":types",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
],
)

py_test(
name = "abstract_arrays_test",
srcs = ["abstract_arrays_test.py"],
deps = [
":abstract_arrays",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)

py_test(
name = "subchunking_test",
srcs = ["subchunking_test.py"],
args = ["--vmodule=subchunking=1"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":fragments",
":subchunking",
":types",
],
)
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ py_test(
py_library(
name = "value",
srcs = ["value.py"],
deps = [":sharding"],
deps = [
":sharding",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
Expand Down Expand Up @@ -111,6 +114,7 @@ py_library(
deps = [
":empty_values",
":pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)
Expand Down
23 changes: 21 additions & 2 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ py_library(
name = "tensorstore_utils",
srcs = ["tensorstore_utils.py"],
srcs_version = "PY3",
deps = [
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
name = "types",
srcs = ["types.py"],
deps = [
":serialization",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
Expand All @@ -28,6 +33,8 @@ py_library(
":serialization",
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
Expand All @@ -43,7 +50,11 @@ py_test(
srcs = ["tensorstore_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":tensorstore_utils"],
deps = [
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
Expand All @@ -52,14 +63,22 @@ py_library(
deps = [
":replica_slices",
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:fragments",
"//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/multihost",
],
)

py_library(
name = "replica_slices",
srcs = ["replica_slices.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/multihost"],
deps = [
"//checkpoint/orbax/checkpoint/_src/arrays:fragments",
"//checkpoint/orbax/checkpoint/_src/arrays:numpy_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/multihost",
],
)

py_test(
Expand Down

0 comments on commit b5c7971

Please sign in to comment.