Zarr: Optimize appending (#8998)
* Zarr: optimize appending

* Update xarray/backends/zarr.py

* Don't run `encoding` check if it wasn't provided.

* Add regression test

* fix types

* fix test

* Use mock instead
dcherian authored May 10, 2024
1 parent cd25bfa commit 6fe1234
Showing 4 changed files with 237 additions and 68 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -50,7 +50,8 @@ nosetests.xml
dask-worker-space/

# asv environments
.asv
asv_bench/.asv
asv_bench/pkgs

# Translations
*.mo
@@ -68,7 +69,7 @@ dask-worker-space/

# xarray specific
doc/_build
generated/
doc/generated/
xarray/tests/data/*.grib.*.idx

# Sync tools
58 changes: 3 additions & 55 deletions xarray/backends/api.py
@@ -1521,42 +1521,6 @@ def save_mfdataset(
)


def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
"""

existing_vars = zstore.get_variables()

def check_dtype(vname, var):
if (
vname not in existing_vars
or np.issubdtype(var.dtype, np.number)
or np.issubdtype(var.dtype, np.datetime64)
or np.issubdtype(var.dtype, np.bool_)
or var.dtype == object
):
# We can skip dtype equality checks under two conditions: (1) if the var to append is
# new to the dataset, because in this case there is no existing var to compare it to;
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
# we can be confident appending won't cause problems. Examples of dtypes which are not
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
# positive integer character length). For these dtypes, appending dissimilar lengths
# can result in truncation of appended data. Therefore, variables which already exist
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
# exact dtype equality, as checked below.
pass
elif not var.dtype == existing_vars[vname].dtype:
raise ValueError(
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but "
f"dataset to append has dtype {var.dtype}."
)

for vname, var in dataset.data_vars.items():
check_dtype(vname, var)


# compute=True returns ZarrStore
@overload
def to_zarr(
@@ -1712,37 +1676,21 @@ def to_zarr(

if region is not None:
zstore._validate_and_autodetect_region(dataset)
# can't modify indexed with region writes
# can't modify indexes with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)

if mode in ["a", "a-", "r+"]:
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
existing_dims = zstore.get_dimensions()
if append_dim not in existing_dims:
raise ValueError(
f"append_dim={append_dim!r} does not match any existing "
f"dataset dimensions {existing_dims}"
)
if encoding and mode in ["a", "a-", "r+"]:
existing_var_names = set(zstore.zarr_group.array_keys())
for var_name in existing_var_names:
if var_name in encoding.keys():
if var_name in encoding:
raise ValueError(
f"variable {var_name!r} already exists, but encoding was provided"
)
if mode == "r+":
new_names = [k for k in dataset.variables if k not in existing_var_names]
if new_names:
raise ValueError(
f"dataset contains non-pre-existing variables {new_names}, "
"which is not allowed in ``xarray.Dataset.to_zarr()`` with "
"mode='r+'. To allow writing new variables, set mode='a'."
)

writer = ArrayWriter()
# TODO: figure out how to properly handle unlimited_dims
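With the dtype and `append_dim` validation moved into `ZarrStore.store` (see `xarray/backends/zarr.py` below), the main append-mode check left in `to_zarr` is the user-facing one: `encoding` may not be passed for a variable that already exists in the store. A minimal sketch of that behavior, assuming a writable local path `example.zarr` (hypothetical) and an installed `zarr`:

```python
import xarray as xr

ds = xr.Dataset({"foo": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]})
ds.to_zarr("example.zarr", mode="w")  # encoding may be supplied on the initial write

more = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [3, 4]})
try:
    # Re-specifying encoding for an existing variable is rejected in append modes:
    more.to_zarr(
        "example.zarr", mode="a", append_dim="x",
        encoding={"foo": {"chunks": (2,)}},
    )
except ValueError as err:
    print(err)  # variable 'foo' already exists, but encoding was provided

# Appending without encoding reuses whatever is already stored on disk.
more.to_zarr("example.zarr", mode="a", append_dim="x")
```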
81 changes: 70 additions & 11 deletions xarray/backends/zarr.py
@@ -324,6 +324,34 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
return var


def _validate_datatypes_for_zarr_append(vname, existing_var, new_var):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
"""
if (
np.issubdtype(new_var.dtype, np.number)
or np.issubdtype(new_var.dtype, np.datetime64)
or np.issubdtype(new_var.dtype, np.bool_)
or new_var.dtype == object
):
# We can skip dtype equality checks under two conditions: (1) if the var to append is
# new to the dataset, because in this case there is no existing var to compare it to;
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
# we can be confident appending won't cause problems. Examples of dtypes which are not
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
# positive integer character length). For these dtypes, appending dissimilar lengths
# can result in truncation of appended data. Therefore, variables which already exist
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
# exact dtype equality, as checked below.
pass
elif not new_var.dtype == existing_var.dtype:
raise ValueError(
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
f"and dataset to append. Store has dtype {existing_var.dtype} but "
f"dataset to append has dtype {new_var.dtype}."
)


def _validate_and_transpose_existing_dims(
var_name, new_var, existing_var, region, append_dim
):
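The dtype check now runs per-variable inside the store, against the decoded on-disk variable. The hazard it guards against is fixed-width string truncation, which numpy performs silently. A minimal sketch of both the hazard and the check, noting that `_validate_datatypes_for_zarr_append` is a private helper and is imported here purely for illustration:

```python
import numpy as np
import xarray as xr
from xarray.backends.zarr import _validate_datatypes_for_zarr_append  # private helper

# Fixed-width string dtypes are not "easy to append": numpy silently truncates
# values longer than the existing itemsize.
existing = np.array(["abc"], dtype="<U3")
existing[0] = "abcdef"
print(existing)  # ['abc'] -- data loss, no error

# The check raises instead of letting an append truncate the new values.
store_var = xr.Variable("x", np.array(["abc"], dtype="<U3"))
new_var = xr.Variable("x", np.array(["abcdef"], dtype="<U6"))
try:
    _validate_datatypes_for_zarr_append("strings", store_var, new_var)
except ValueError as err:
    print(err)  # Mismatched dtypes for variable strings between Zarr store on disk ...
```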
@@ -612,26 +640,58 @@ def store(
import zarr

existing_keys = tuple(self.zarr_group.array_keys())

if self._mode == "r+":
new_names = [k for k in variables if k not in existing_keys]
if new_names:
raise ValueError(
f"dataset contains non-pre-existing variables {new_names}, "
"which is not allowed in ``xarray.Dataset.to_zarr()`` with "
"``mode='r+'``. To allow writing new variables, set ``mode='a'``."
)

if self._append_dim is not None and self._append_dim not in existing_keys:
# For dimensions without coordinate values, we must parse
# the _ARRAY_DIMENSIONS attribute on *all* arrays to check if it
# is a valid existing dimension name.
# TODO: This `get_dimensions` method also does shape checking
# which isn't strictly necessary for our check.
existing_dims = self.get_dimensions()
if self._append_dim not in existing_dims:
raise ValueError(
f"append_dim={self._append_dim!r} does not match any existing "
f"dataset dimensions {existing_dims}"
)

existing_variable_names = {
vn for vn in variables if _encode_variable_name(vn) in existing_keys
}
new_variables = set(variables) - existing_variable_names
variables_without_encoding = {vn: variables[vn] for vn in new_variables}
new_variable_names = set(variables) - existing_variable_names
variables_encoded, attributes = self.encode(
variables_without_encoding, attributes
{vn: variables[vn] for vn in new_variable_names}, attributes
)

if existing_variable_names:
# Decode variables directly, without going via xarray.Dataset to
# avoid needing to load index variables into memory.
# TODO: consider making loading indexes lazy again?
# We make sure that values to be appended are encoded *exactly*
# as the current values in the store.
# To do so, we decode variables directly to access the proper encoding,
# without going via xarray.Dataset to avoid needing to load
# index variables into memory.
existing_vars, _, _ = conventions.decode_cf_variables(
{k: self.open_store_variable(name=k) for k in existing_variable_names},
self.get_attrs(),
variables={
k: self.open_store_variable(name=k) for k in existing_variable_names
},
# attributes = {} since we don't care about parsing the global
# "coordinates" attribute
attributes={},
)
# Modified variables must use the same encoding as the store.
vars_with_encoding = {}
for vn in existing_variable_names:
if self._mode in ["a", "a-", "r+"]:
_validate_datatypes_for_zarr_append(
vn, existing_vars[vn], variables[vn]
)
vars_with_encoding[vn] = variables[vn].copy(deep=False)
vars_with_encoding[vn].encoding = existing_vars[vn].encoding
vars_with_encoding, _ = self.encode(vars_with_encoding, {})
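This is the core of the optimization: only the variables that already exist in the store are decoded (directly, without building an `xarray.Dataset`), and the to-be-appended variables are re-encoded with the exact encoding found on disk. A minimal sketch of the user-visible effect, assuming a writable local path `temps.zarr` (hypothetical):

```python
import numpy as np
import xarray as xr

# Initial write with an explicit on-disk encoding.
ds = xr.Dataset({"temp": ("time", np.array([20.0, 21.5]))})
ds.to_zarr(
    "temps.zarr", mode="w",
    encoding={"temp": {"dtype": "int16", "scale_factor": 0.1}},
)

# Append without re-specifying encoding: store() copies the encoding of the
# existing on-disk variable onto the appended variable, so new values are
# packed exactly like the old ones.
extra = xr.Dataset({"temp": ("time", np.array([22.25]))})
extra.to_zarr("temps.zarr", mode="a", append_dim="time")

with xr.open_dataset("temps.zarr", engine="zarr") as actual:
    print(actual["temp"].values)  # decoded floats; 22.25 is quantized by the int16 packing
```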
@@ -696,7 +756,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None

for vn, v in variables.items():
name = _encode_variable_name(vn)
check = vn in check_encoding_set
attrs = v.attrs.copy()
dims = v.dims
dtype = v.dtype
@@ -712,7 +771,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=check,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)
@@ -815,7 +874,7 @@ def _auto_detect_regions(self, ds, region):
assert variable.dims == (dim,)
index = pd.Index(variable.data)
idxs = index.get_indexer(ds[dim].data)
if any(idxs == -1):
if (idxs == -1).any():
raise KeyError(
f"Not all values of coordinate '{dim}' in the new array were"
" found in the original store. Writing to a zarr region slice"
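The `(idxs == -1).any()` change above uses the numpy method form of the same missing-label test: `pd.Index.get_indexer` returns `-1` for labels that are absent from the on-disk coordinate, and the vectorized `.any()` checks the whole indexer array at once. A short illustration:

```python
import pandas as pd

index = pd.Index([0, 1, 2, 3])        # coordinate values already in the store
idxs = index.get_indexer([2, 3, 4])   # 4 is not present on disk
print(idxs)                           # [ 2  3 -1]
print((idxs == -1).any())             # True -> _auto_detect_regions raises KeyError
```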
161 changes: 161 additions & 0 deletions xarray/tests/test_backends.py
@@ -2990,6 +2990,167 @@ def test_chunked_cftime_datetime(self) -> None:
assert original.chunks == actual.chunks


@requires_zarr
@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3")
class TestInstrumentedZarrStore:
methods = [
"__iter__",
"__contains__",
"__setitem__",
"__getitem__",
"listdir",
"list_prefix",
]

@contextlib.contextmanager
def create_zarr_target(self):
import zarr

if Version(zarr.__version__) < Version("2.18.0"):
pytest.skip("Instrumented tests only work on latest Zarr.")

store = KVStoreV3({})
yield store

def make_patches(self, store):
from unittest.mock import MagicMock

return {
method: MagicMock(
f"KVStoreV3.{method}",
side_effect=getattr(store, method),
autospec=True,
)
for method in self.methods
}

def summarize(self, patches):
summary = {}
for name, patch_ in patches.items():
count = 0
for call in patch_.mock_calls:
if "zarr.json" not in call.args:
count += 1
summary[name.strip("__")] = count
return summary

def check_requests(self, expected, patches):
summary = self.summarize(patches)
for k in summary:
assert summary[k] <= expected[k], (k, summary)

def test_append(self) -> None:
original = Dataset({"foo": ("x", [1])}, coords={"x": [0]})
modified = Dataset({"foo": ("x", [2])}, coords={"x": [1]})
with self.create_zarr_target() as store:
expected = {
"iter": 2,
"contains": 9,
"setitem": 9,
"getitem": 6,
"listdir": 2,
"list_prefix": 2,
}
patches = self.make_patches(store)
with patch.multiple(KVStoreV3, **patches):
original.to_zarr(store)
self.check_requests(expected, patches)

patches = self.make_patches(store)
# v2024.03.0: {'iter': 6, 'contains': 2, 'setitem': 5, 'getitem': 10, 'listdir': 6, 'list_prefix': 0}
# 6057128b: {'iter': 5, 'contains': 2, 'setitem': 5, 'getitem': 10, "listdir": 5, "list_prefix": 0}
expected = {
"iter": 2,
"contains": 2,
"setitem": 5,
"getitem": 6,
"listdir": 2,
"list_prefix": 0,
}
with patch.multiple(KVStoreV3, **patches):
modified.to_zarr(store, mode="a", append_dim="x")
self.check_requests(expected, patches)

patches = self.make_patches(store)
expected = {
"iter": 2,
"contains": 2,
"setitem": 5,
"getitem": 6,
"listdir": 2,
"list_prefix": 0,
}
with patch.multiple(KVStoreV3, **patches):
modified.to_zarr(store, mode="a-", append_dim="x")
self.check_requests(expected, patches)

with open_dataset(store, engine="zarr") as actual:
assert_identical(
actual, xr.concat([original, modified, modified], dim="x")
)

@requires_dask
def test_region_write(self) -> None:
ds = Dataset({"foo": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]}).chunk()
with self.create_zarr_target() as store:
expected = {
"iter": 2,
"contains": 7,
"setitem": 8,
"getitem": 6,
"listdir": 2,
"list_prefix": 4,
}
patches = self.make_patches(store)
with patch.multiple(KVStoreV3, **patches):
ds.to_zarr(store, mode="w", compute=False)
self.check_requests(expected, patches)

# v2024.03.0: {'iter': 5, 'contains': 2, 'setitem': 1, 'getitem': 6, 'listdir': 5, 'list_prefix': 0}
# 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 5, 'listdir': 4, 'list_prefix': 0}
expected = {
"iter": 2,
"contains": 2,
"setitem": 1,
"getitem": 3,
"listdir": 2,
"list_prefix": 0,
}
patches = self.make_patches(store)
with patch.multiple(KVStoreV3, **patches):
ds.to_zarr(store, region={"x": slice(None)})
self.check_requests(expected, patches)

# v2024.03.0: {'iter': 6, 'contains': 4, 'setitem': 1, 'getitem': 11, 'listdir': 6, 'list_prefix': 0}
# 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 7, 'listdir': 4, 'list_prefix': 0}
expected = {
"iter": 2,
"contains": 2,
"setitem": 1,
"getitem": 5,
"listdir": 2,
"list_prefix": 0,
}
patches = self.make_patches(store)
with patch.multiple(KVStoreV3, **patches):
ds.to_zarr(store, region="auto")
self.check_requests(expected, patches)

expected = {
"iter": 1,
"contains": 2,
"setitem": 0,
"getitem": 5,
"listdir": 1,
"list_prefix": 0,
}
patches = self.make_patches(store)
with patch.multiple(KVStoreV3, **patches):
with open_dataset(store, engine="zarr") as actual:
assert_identical(actual, ds)
self.check_requests(expected, patches)
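The regression test above wraps each `KVStoreV3` method in a `MagicMock` that delegates to the real method via `side_effect`, counts the resulting calls (skipping `zarr.json` metadata requests), and asserts upper bounds per operation. Outside the test suite, a similar measurement can be taken without mocks by handing `to_zarr` a counting mapping; a minimal sketch, assuming zarr-python 2.x, which accepts any `MutableMapping` as a store:

```python
from collections import Counter
from collections.abc import MutableMapping

import xarray as xr


class CountingStore(MutableMapping):
    """Dict-backed store that records how often each operation is used."""

    def __init__(self):
        self._data = {}
        self.counts = Counter()

    def __getitem__(self, key):
        self.counts["getitem"] += 1
        return self._data[key]

    def __setitem__(self, key, value):
        self.counts["setitem"] += 1
        self._data[key] = value

    def __delitem__(self, key):
        self.counts["delitem"] += 1
        del self._data[key]

    def __iter__(self):
        self.counts["iter"] += 1
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __contains__(self, key):
        self.counts["contains"] += 1
        return key in self._data


store = CountingStore()
xr.Dataset({"foo": ("x", [1])}, coords={"x": [0]}).to_zarr(store, mode="w")
before = dict(store.counts)

xr.Dataset({"foo": ("x", [2])}, coords={"x": [1]}).to_zarr(
    store, mode="a", append_dim="x"
)
# Requests made by the append alone:
print({k: count - before.get(k, 0) for k, count in store.counts.items()})
```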


@requires_zarr
class TestZarrDictStore(ZarrBase):
@contextlib.contextmanager
