From 6fb0de59405a96a78cac3d6c44267dd8cb8ff37d Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Wed, 8 Jan 2025 18:33:37 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 713488575 --- .../orbax/checkpoint/_src/arrays/types.py | 3 + .../orbax/checkpoint/_src/metadata/BUILD | 23 ++ .../_src/metadata/array_metadata.py | 45 ++++ .../_src/metadata/array_metadata_store.py | 220 ++++++++++++++++++ .../metadata/array_metadata_store_test.py | 135 +++++++++++ .../orbax/checkpoint/_src/metadata/value.py | 2 +- .../orbax/checkpoint/_src/serialization/BUILD | 2 + .../_src/serialization/tensorstore_utils.py | 29 +-- .../_src/serialization/type_handlers.py | 58 +++-- 9 files changed, 477 insertions(+), 40 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/metadata/array_metadata.py create mode 100644 checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py create mode 100644 checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store_test.py diff --git a/checkpoint/orbax/checkpoint/_src/arrays/types.py b/checkpoint/orbax/checkpoint/_src/arrays/types.py index 556512590..4af778db3 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/types.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/types.py @@ -16,10 +16,12 @@ import dataclasses +from jax import numpy as jnp import numpy as np Shape = tuple[int, ...] +DType = jnp.dtype | np.dtype # Indexing an np.ndarray with an empty tuple gives an array of the same shape, # *unless* the array is zero-dimensional in which case the result is a scalar. @@ -34,6 +36,7 @@ @dataclasses.dataclass(frozen=True) class NumpyShapeDtypeStruct: """Abstract representation of a Numpy array.""" + shape: Shape dtype: np.dtype diff --git a/checkpoint/orbax/checkpoint/_src/metadata/BUILD b/checkpoint/orbax/checkpoint/_src/metadata/BUILD index 6c5fe1598..8f0bd3238 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/BUILD +++ b/checkpoint/orbax/checkpoint/_src/metadata/BUILD @@ -131,3 +131,26 @@ py_test( ":pytree_metadata_options", ], ) + +py_library( + name = "array_metadata", + srcs = ["array_metadata.py"], +) + +py_library( + name = "array_metadata_store", + srcs = ["array_metadata_store.py"], + deps = [ + ":array_metadata", + "//checkpoint/orbax/checkpoint/_src/multihost", + ], +) + +py_test( + name = "array_metadata_store_test", + srcs = ["array_metadata_store_test.py"], + deps = [ + ":array_metadata", + ":array_metadata_store", + ], +) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/array_metadata.py b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata.py new file mode 100644 index 000000000..8850bfdad --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata.py @@ -0,0 +1,45 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metadata describing Arrays. Meant to be used internally.""" + +import dataclasses +from orbax.checkpoint._src.arrays import types + + +@dataclasses.dataclass(frozen=True) +class ArrayMetadata: + """TensorStore metadata for a single array in a checkpoint.""" + + param_name: str # Unique full name of the parameter. + shape: types.Shape + dtype: types.DType + write_shape: types.Shape + chunk_shape: types.Shape + use_ocdbt: bool + use_zarr3: bool + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class SerializedArrayMetadata: + """Serialized version of `ArrayMetadata`. + + Not all fields of `ArrayMetadata` are serialized. + + Used in subchunking based checkpointing context. + """ + + param_name: str # Unique full name of the parameter. + write_shape: types.Shape + chunk_shape: types.Shape diff --git a/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py new file mode 100644 index 000000000..84aaa1d89 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py @@ -0,0 +1,220 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Storage for `array_metadata.ArrayMetadata` (not value.ArrayMetadata).""" + +import json +import threading +from typing import Any, Iterator, List, Sequence +from absl import logging +from etils import epath +from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib +from orbax.checkpoint._src.multihost import multihost + + +class PathResolver: + """Resolves paths for the ArrayMetadata store read and write.""" + + _metadata_subdir = 'array_metadatas' + + def _file_name(self, process_index: int | str) -> str: + return f'process_{process_index}' + + def get_process_index(self, file_path: epath.Path) -> int: + """Returns the process index from the file path.""" + process_index = file_path.name.removeprefix('process_') + if process_index.isdigit(): + return int(process_index) + raise ValueError( + f'Invalid ArrayMetadata file path: {file_path}; expected file name' + ' to start with "process_"' + ) + + def get_write_file_path( + self, checkpoint_dir: epath.Path, process_index: int + ) -> epath.Path: + """Returns the file path to write.""" + return ( + checkpoint_dir / self._metadata_subdir / self._file_name(process_index) + ) + + def get_read_file_paths( + self, checkpoint_dir: epath.Path, process_index: int | None = None + ) -> Iterator[epath.Path] | epath.Path | None: + """Returns the file paths to read. + + Args: + checkpoint_dir: The base path containing metadata for each process. + process_index: The process index to read. If None, then read all processes + under `checkpoint_dir`. + + Returns: + Iterator of file paths to read if `process_index` is None. A file path to + read if `process_index` is not None. None if `process_index` is not None + but metadata file does not exist. + """ + if process_index is None: + file_name_pattern = self._file_name('*') + return checkpoint_dir.glob(f'{self._metadata_subdir}/{file_name_pattern}') + file_path = ( + checkpoint_dir / self._metadata_subdir / self._file_name(process_index) + ) + if file_path.exists(): + return file_path + return None + + +class SerDeserializer: + """Serializes and deserializes `tensorstore_utils.ArrayMetadata`.""" + + def _to_dict( + self, array_metadata: array_metadata_lib.ArrayMetadata + ) -> dict[str, Any]: + """Converts `array_metadata` to a dictionary.""" + return { + 'array_metadata': { + 'param_name': array_metadata.param_name, + 'write_shape': array_metadata.write_shape, + 'chunk_shape': array_metadata.chunk_shape, + } + } + + def _from_dict(self, obj: dict[str, Any]) -> Any: + """Converts a json object to `SerializedArrayMetadata` or `obj`.""" + if 'array_metadata' in obj: + array_metadata = obj['array_metadata'] + return array_metadata_lib.SerializedArrayMetadata( + param_name=array_metadata['param_name'], + write_shape=tuple(array_metadata['write_shape']), + chunk_shape=tuple(array_metadata['chunk_shape']), + ) + return obj + + def serialize( + self, array_metadatas: Sequence[array_metadata_lib.ArrayMetadata] + ) -> str: + """Serializes `array_metadatas` to string.""" + obj = { + 'array_metadatas': [ + self._to_dict(array_metadata) for array_metadata in array_metadatas + ] + } + return json.dumps(obj) + + def deserialize( + self, serialized: str + ) -> List[array_metadata_lib.SerializedArrayMetadata]: + """Deserializes `serialized` to `tensorstore_utils.ArrayMetadata`.""" + obj = json.loads(serialized, object_hook=self._from_dict) + return obj['array_metadatas'] + + +class Store: + """Storage for `tensorstore_utils.ArrayMetadata` (not value.ArrayMetadata).""" + + def __init__( + self, + path_resolver: PathResolver = PathResolver(), + ser_deser: SerDeserializer = SerDeserializer(), + ): + self._path_resolver = path_resolver + self._ser_deser = ser_deser + + async def write( + self, + checkpoint_dir: epath.Path, + array_metadatas: Sequence[array_metadata_lib.ArrayMetadata], + process_index: int, + ) -> None: + """Writes `array_metadatas` to a file under `checkpoint_dir`. + + See `PathResolver.get_write_file_path()` for the file path resolution. + + Args: + checkpoint_dir: The base path containing metadata for each process. + array_metadatas: The sequence of metadata to write. + process_index: The Jax process index used to resolve the file path. + """ + file_path = self._path_resolver.get_write_file_path( + checkpoint_dir, process_index + ) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(self._ser_deser.serialize(array_metadatas)) + logging.info( + '[process=%s][thread=%s] Wrote %d tensorstore_utils.ArrayMetadata' + ' to %s', + multihost.process_index(), + threading.current_thread().name, + len(array_metadatas), + file_path, + ) + + def read( + self, + checkpoint_dir: epath.Path, + process_index: int | None = None, + ) -> ( + dict[int, List[array_metadata_lib.SerializedArrayMetadata]] + | List[array_metadata_lib.SerializedArrayMetadata] + | None + ): + """Reads `SerializedArrayMetadata` from storage under `checkpoint_dir`. + + Args: + checkpoint_dir: The base path containing metadata for each process. + process_index: The process index to read. If None, then read all processes + under `checkpoint_dir`. + + Returns: + A dictionary of process index to list of metadata if `process_index` + is None. A list of metadata if `process_index` is not None. None if + metadata does not exist. + """ + if not checkpoint_dir.exists(): + raise ValueError( + f'Checkpoint directory does not exist: {checkpoint_dir}.' + ) + file_paths = self._path_resolver.get_read_file_paths( + checkpoint_dir, process_index + ) + if file_paths is None: + logging.warning( + '[process=%s][thread=%s] No metadata found for process_index=%s,' + ' checkpoint_dir=%s. Please ignore if input checkpoint does not' + ' contain any jax.Array.', + multihost.process_index(), + threading.current_thread().name, + process_index, + checkpoint_dir, + ) + return None + if isinstance(file_paths, epath.Path): + return self._ser_deser.deserialize(file_paths.read_text()) + result = { + self._path_resolver.get_process_index( + file_path + ): self._ser_deser.deserialize(file_path.read_text()) + for file_path in file_paths + } + if not result: + logging.warning( + '[process=%s][thread=%s] No metadata found for any process_index,' + ' checkpoint_dir=%s. Please ignore if input checkpoint does not' + ' contain any jax.Array.', + multihost.process_index(), + threading.current_thread().name, + checkpoint_dir, + ) + return None + return result diff --git a/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store_test.py b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store_test.py new file mode 100644 index 000000000..225a79551 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store_test.py @@ -0,0 +1,135 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for `array_metadata_store` module.""" + +import unittest +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import numpy as np +from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib +from orbax.checkpoint._src.metadata import array_metadata_store + + +class StoreTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): + + def setUp(self): + super().setUp() + self.checkpoint_dir = epath.Path(self.create_tempdir().full_path) + self.store = array_metadata_store.Store() + + def test_non_existing_checkpoint_dir(self): + with self.assertRaisesRegex( + ValueError, 'Checkpoint directory does not exist' + ): + _ = self.store.read(self.checkpoint_dir / 'unknown_dir') + + def test_non_existing_metadata_files(self): + self.assertIsNone(self.store.read(self.checkpoint_dir)) + + (self.checkpoint_dir / 'array_metadatas').mkdir( + parents=True, exist_ok=False + ) + self.assertIsNone(self.store.read(self.checkpoint_dir)) + + async def test_write_and_read_single_process(self): + process_index = 0 + array_metadatas = [ + array_metadata_lib.ArrayMetadata( + param_name='a', + shape=(10, 20, 30), + dtype=np.dtype(int), + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + use_ocdbt=False, + use_zarr3=False, + ), + array_metadata_lib.ArrayMetadata( + param_name='b', + shape=(1, 1, 1), + dtype=np.dtype(int), + write_shape=(1, 1, 1), + chunk_shape=(1, 1, 1), + use_ocdbt=False, + use_zarr3=False, + ), + ] + await self.store.write( + self.checkpoint_dir, array_metadatas, process_index=process_index + ) + + self.assertEqual( + self.store.read(self.checkpoint_dir, process_index=process_index), + [ + array_metadata_lib.SerializedArrayMetadata( + param_name='a', + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + ), + array_metadata_lib.SerializedArrayMetadata( + param_name='b', + write_shape=(1, 1, 1), + chunk_shape=(1, 1, 1), + ), + ], + ) + + async def test_write_and_read_multiple_process(self): + for process_index in [0, 1, 2]: + array_metadatas = [ + array_metadata_lib.ArrayMetadata( + param_name=f'a_{process_index}', + shape=(10, 20, 30), + dtype=np.dtype(int), + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + use_ocdbt=False, + use_zarr3=False, + ), + ] + await self.store.write( + self.checkpoint_dir, array_metadatas, process_index=process_index + ) + + self.assertEqual( + self.store.read(self.checkpoint_dir, process_index=None), + { + 0: [ + array_metadata_lib.SerializedArrayMetadata( + param_name='a_0', + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + ) + ], + 1: [ + array_metadata_lib.SerializedArrayMetadata( + param_name='a_1', + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + ) + ], + 2: [ + array_metadata_lib.SerializedArrayMetadata( + param_name='a_2', + write_shape=(10, 20, 30), + chunk_shape=(1, 2, 3), + ) + ], + }, + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/metadata/value.py b/checkpoint/orbax/checkpoint/_src/metadata/value.py index 1e71acd18..65f30ff65 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/value.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/value.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Metadata describing PyTree values.""" +"""User facing metadata describing PyTree values.""" from __future__ import annotations diff --git a/checkpoint/orbax/checkpoint/_src/serialization/BUILD b/checkpoint/orbax/checkpoint/_src/serialization/BUILD index b868cabff..63bf3808c 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/BUILD +++ b/checkpoint/orbax/checkpoint/_src/serialization/BUILD @@ -7,6 +7,7 @@ py_library( name = "tensorstore_utils", srcs = ["tensorstore_utils.py"], srcs_version = "PY3", + deps = ["//orbax/checkpoint/_src/metadata:array_metadata"], ) py_library( @@ -35,6 +36,7 @@ py_library( "//checkpoint/orbax/checkpoint/_src/multihost:multislice", "//checkpoint/orbax/checkpoint/_src/path:async_utils", "//checkpoint/orbax/checkpoint/_src/path:format_utils", + "//orbax/checkpoint/_src/metadata:array_metadata_store", ], ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index d4a72dfcd..ddc62cf0d 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -14,19 +14,22 @@ """TensorStore serialization helper functions.""" -import dataclasses import math import os import re -from typing import Any, TypeAlias +from typing import Any, TypeAlias from absl import logging from jax import numpy as jnp -import numpy as np from orbax.checkpoint._src.arrays import subchunking from orbax.checkpoint._src.arrays import types +from orbax.checkpoint._src.metadata import array_metadata import tensorstore as ts +JsonSpec: TypeAlias = dict[str, Any] +Shape: TypeAlias = types.Shape +DType: TypeAlias = types.DType +ArrayMetadata: TypeAlias = array_metadata.ArrayMetadata DEFAULT_DRIVER = 'file' @@ -48,10 +51,6 @@ STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = True -JsonSpec: TypeAlias = dict[str, Any] -Shape: TypeAlias = types.Shape -DType: TypeAlias = jnp.dtype | np.dtype - _BASE_TS_CONTEXT = { 'file_io_concurrency': {'limit': 128}, } @@ -112,8 +111,8 @@ def build_kvstore_tspec( name: Name (filename) of the parameter. use_ocdbt: Whether to use OCDBT driver. process_id: [only used with OCDBT driver] If provided, - `{directory}/ocdbt.process_{process_id}` path is used as the base path. - If a string, must conform to [A-Za-z0-9]+ pattern. + `{directory}/ocdbt.process_{process_id}` path is used as the base path. If + a string, must conform to [A-Za-z0-9]+ pattern. Returns: A Tensorstore KvStore spec in dictionary form. @@ -285,17 +284,6 @@ def calculate_chunk_byte_size( ### Building TensorStore array specs. -@dataclasses.dataclass(frozen=True) -class ArrayMetadata: - """TensorStore metadata for a single array in a checkpoint.""" - shape: Shape - dtype: DType - write_shape: Shape - chunk_shape: Shape - use_ocdbt: bool - use_zarr3: bool - - def _maybe_add_cast_to_write_spec( array_tspec: JsonSpec, *, @@ -398,6 +386,7 @@ def __init__( # Keep the metadata in a separate field. self._metadata = ArrayMetadata( + param_name=relative_array_filename, shape=global_shape, dtype=target_storage_dtype, write_shape=write_shape, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index 14d9d70c3..8bf36cd59 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -39,6 +39,7 @@ from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.arrays import subchunking from orbax.checkpoint._src.arrays import types as arrays_types +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import sharding as sharding_metadata from orbax.checkpoint._src.metadata import value as value_metadata @@ -202,7 +203,7 @@ def get_json_tspec_write( return tspec -def _build_array_tspec_write( +def _build_array_write_spec( info: types.ParamInfo, arg: Optional[types.SaveArgs] = None, *, @@ -212,8 +213,8 @@ def _build_array_tspec_write( use_ocdbt: bool, process_index: Optional[Union[int, str]] = None, metadata_key: Optional[str] = None, -) -> ts_utils.JsonSpec: - """Gets Tensorstore spec for writing.""" +) -> ts_utils.ArrayWriteSpec: + """Gets ArrayWriteSpec for writing.""" if info.path is None: raise ValueError('Must construct serialization path.') parent_dir = info.parent_dir @@ -234,7 +235,7 @@ def _build_array_tspec_write( process_id=process_index, ocdbt_target_data_file_size=info.ocdbt_target_data_file_size, metadata_key=metadata_key, - ).json + ) class _CommitFuture(future.Future): @@ -345,7 +346,8 @@ async def _validate_params( ) unique = with_zarray | without_zarray - logging.info( + logging.vlog( + 1, '[process=%s][thread=%s] Validating params in TensorStore KvStore.', process_index, current_thread_name, @@ -542,16 +544,16 @@ def __init__(self, metadata_key: Optional[str] = None): # managed by this controller. self._override_ocdbt_process_id: Optional[str] = None - def _get_json_tspec_write( + def _get_array_write_spec( self, info: types.ParamInfo, value: np.ndarray, use_ocdbt: bool, process_index: Optional[Union[int, str]] = None, arg: Optional[types.SaveArgs] = None, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for writing.""" - return _build_array_tspec_write( + ) -> ts_utils.ArrayWriteSpec: + """Gets ArrayWriteSpec for writing.""" + return _build_array_write_spec( info=info, arg=arg, global_shape=value.shape, @@ -614,7 +616,7 @@ async def _background_serialize( """Serializes numpy arrays in a background thread.""" write_coros = [] for value, info, arg in zip(values, infos, args): - tspec = self._get_json_tspec_write( + array_write_spec = self._get_array_write_spec( info, value, use_ocdbt=info.is_ocdbt_checkpoint, @@ -624,6 +626,7 @@ async def _background_serialize( ), arg=arg, ) + tspec = array_write_spec.json if logging.vlog_is_on(1): logging.vlog(1, 'tspec = %s', tspec) logging.vlog(1, 'infos = %s', info) @@ -846,6 +849,7 @@ def __init__( replica_id: Optional[int] = 0, use_replica_parallel: bool = True, enable_write_sharding_file: bool = True, + array_metadata_store: array_metadata_store_lib.Store | None = None, ): """Constructor. @@ -860,19 +864,25 @@ def __init__( use_replica_parallel: Whether to parallelize saving across replicas. enable_write_sharding_file: whether to write sharding file, defaults to True. + array_metadata_store: Store to manage per host ArrayMetadata. To disable + ArrayMetadata persistence, set it to None. """ self._metadata_key = metadata_key self._primary_host = primary_host self._replica_id = replica_id self._enable_write_sharding_file = enable_write_sharding_file self._use_replica_parallel = use_replica_parallel + self._array_metadata_store = array_metadata_store - logging.info( - 'Created `ArrayHandler` with primary_host=%s, replica_id=%s,' - ' use_replica_parallel=%s', + logging.vlog( + 1, + 'Created `%s` with primary_host=%s, replica_id=%s,' + ' use_replica_parallel=%s, array_metadata_store=%s', + self.__class__.__qualname__, self._primary_host, self._replica_id, self._use_replica_parallel, + self._array_metadata_store, ) if self._primary_host is None and jax.__version_info__ <= (0, 4, 25): # pylint:disable=unreachable @@ -880,7 +890,7 @@ def __init__( 'Setting `primary_host` to None requires JAX version > 0.4.25.' ) - def _get_json_tspec_write( + def _get_array_write_spec( self, info: types.ParamInfo, value: replica_slices.ReplicaSlices, @@ -888,9 +898,9 @@ def _get_json_tspec_write( use_ocdbt: bool, process_index: Optional[Union[int, str]] = None, arg: Optional[types.SaveArgs] = None, - ) -> Dict[str, Any]: - """Gets Tensorstore spec for writing.""" - return _build_array_tspec_write( + ) -> ts_utils.ArrayWriteSpec: + """Gets ArrayWriteSpec for writing.""" + return _build_array_write_spec( info=info, arg=arg, global_shape=value.global_shape, @@ -1013,6 +1023,7 @@ async def _background_serialize( write_coros = [] sharding_metadata_txn = ts.Transaction() ocdbt_transaction: Optional[ts.Transaction] = None + array_metadatas = [] for value, info, arg in zip(values, infos, args): # The byte_limiter can't be used with a transaction, because awaiting the # `write` only waits until the in-memory transaction state reflects the @@ -1021,13 +1032,14 @@ async def _background_serialize( if info.is_ocdbt_checkpoint and info.byte_limiter is None: if ocdbt_transaction is None: ocdbt_transaction = ts.Transaction(atomic=True) - tspec = self._get_json_tspec_write( + array_write_spec = self._get_array_write_spec( info, value, use_ocdbt=info.is_ocdbt_checkpoint, process_index=get_process_index_for_subdir(info.is_ocdbt_checkpoint), arg=arg, ) + tspec = array_write_spec.json ts_context = info.ts_context write_coros.append( serialization.async_serialize_from_host( @@ -1045,7 +1057,15 @@ async def _background_serialize( value.sharding, info, sharding_metadata_txn ) ) - + array_metadatas.append(array_write_spec.metadata) + if self._array_metadata_store is not None: + write_coros.append( + self._array_metadata_store.write( + checkpoint_dir=infos[0].parent_dir, + array_metadatas=array_metadatas, + process_index=multihost.process_index(), + ) + ) await asyncio.gather(*write_coros) await sharding_metadata_txn.commit_async() if ocdbt_transaction is not None: