Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Add orbax/checkpoint/_src/serialization/BUILD and update serilaizatio…
Browse files Browse the repository at this point in the history
…n dependencies to this BUILD.

PiperOrigin-RevId: 714849208
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jan 13, 2025
1 parent e97f19d commit 6fe0813
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 28 deletions.
26 changes: 0 additions & 26 deletions checkpoint/orbax/checkpoint/_src/BUILD

This file was deleted.

9 changes: 8 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ py_library(
":tree_rich_types",
":value",
":value_metadata_entry",
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
deps = [":tree"],
deps = [
":tree",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

py_library(
Expand Down Expand Up @@ -102,6 +108,7 @@ py_library(
deps = [
":empty_values",
":pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

Expand Down
79 changes: 79 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package(
default_applicable_licenses = ["//:package_license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "tensorstore_utils",
srcs = ["tensorstore_utils.py"],
srcs_version = "PY3",
)

py_library(
name = "types",
srcs = ["types.py"],
deps = [
":serialization",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:pytree_metadata_options",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
],
)

py_library(
name = "type_handlers",
srcs = ["type_handlers.py"],
deps = [
":replica_slices",
":serialization",
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
"//checkpoint/orbax/checkpoint/_src/metadata:sharding",
"//checkpoint/orbax/checkpoint/_src/metadata:value",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/multihost:multislice",
"//checkpoint/orbax/checkpoint/_src/path:async_utils",
"//checkpoint/orbax/checkpoint/_src/path:format_utils",
],
)

py_test(
name = "tensorstore_utils_test",
srcs = ["tensorstore_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":tensorstore_utils"],
)

py_library(
name = "serialization",
srcs = ["serialization.py"],
deps = [
":replica_slices",
":tensorstore_utils",
"//checkpoint/orbax/checkpoint/_src/multihost",
],
)

py_library(
name = "replica_slices",
srcs = ["replica_slices.py"],
deps = ["//checkpoint/orbax/checkpoint/_src/multihost"],
)

py_test(
name = "serialization_test",
srcs = ["serialization_test.py"],
python_version = "PY3",
deps = [
":serialization",
":tensorstore_utils",
],
)

py_test(
name = "replica_slices_test",
srcs = ["replica_slices_test.py"],
deps = [":replica_slices"],
)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ simplejson
chex
optax
nest_asyncio

tensorstore
humanize

0 comments on commit 6fe0813

Please sign in to comment.