Zarr: Optimize region="auto" detection (#8997)
* Zarr: Optimize region detection

* Fix for unindexed dimensions.

* Better example

* small cleanup
dcherian authored May 7, 2024
1 parent 2ad98b1 commit dcf2ac4
Showing 3 changed files with 101 additions and 114 deletions.
4 changes: 2 additions & 2 deletions doc/user-guide/io.rst
@@ -874,7 +874,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata
# The values of this dask array are entirely irrelevant; only the dtype,
# shape and chunks are used
dummies = dask.array.zeros(30, chunks=10)
ds = xr.Dataset({"foo": ("x", dummies)})
ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
path = "path/to/directory.zarr"
# Now we write the metadata without computing any array values
ds.to_zarr(path, compute=False)
@@ -890,7 +890,7 @@ where the data should be written (in index space, not label space), e.g.,
# For convenience, we'll slice a single dataset, but in the real use-case
# we would create them separately possibly even from separate processes.
ds = xr.Dataset({"foo": ("x", np.arange(30))})
ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
# Any of the following region specifications are valid
ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
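For orientation, a minimal end-to-end sketch assembling the snippets above (the path is a placeholder; this assumes dask and zarr are installed, and relies on ``to_zarr`` defaulting to ``mode="r+"`` when ``region`` is given):

import dask.array
import numpy as np
import xarray as xr

# Initialize the store with metadata only; the placeholder values are
# never computed.
dummies = dask.array.zeros(30, chunks=10)
ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
path = "path/to/directory.zarr"
ds.to_zarr(path, compute=False)

# Writers (possibly separate processes) each fill one contiguous slice;
# "auto" infers the slice by matching the "x" coordinate against the store.
ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)})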
115 changes: 10 additions & 105 deletions xarray/backends/api.py
@@ -27,7 +27,6 @@
_normalize_path,
)
from xarray.backends.locks import _get_scheduler
from xarray.backends.zarr import open_zarr
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
@@ -1522,92 +1521,6 @@ def save_mfdataset(
)


def _auto_detect_region(ds_new, ds_orig, dim):
# Create a mapping array of coordinates to indices on the original array
coord = ds_orig[dim]
da_map = DataArray(np.arange(coord.size), coords={dim: coord})

try:
da_idxs = da_map.sel({dim: ds_new[dim]})
except KeyError as e:
if "not all values found" in str(e):
raise KeyError(
f"Not all values of coordinate '{dim}' in the new array were"
" found in the original store. Writing to a zarr region slice"
" requires that no dimensions or metadata are changed by the write."
)
else:
raise e

if (da_idxs.diff(dim) != 1).any():
raise ValueError(
f"The auto-detected region of coordinate '{dim}' for writing new data"
" to the original store had non-contiguous indices. Writing to a zarr"
" region slice requires that the new data constitute a contiguous subset"
" of the original store."
)

dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1)

return dim_slice


def _auto_detect_regions(ds, region, open_kwargs):
ds_original = open_zarr(**open_kwargs)
for key, val in region.items():
if val == "auto":
region[key] = _auto_detect_region(ds, ds_original, key)
return region


def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]:
if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")

if any(v == "auto" for v in region.values()):
if mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}"
)
region = _auto_detect_regions(ds, region, open_kwargs)

for k, v in region.items():
if k not in ds.dims:
raise ValueError(
f"all keys in ``region`` are not in Dataset dimensions, got "
f"{list(region)} and {list(ds.dims)}"
)
if not isinstance(v, slice):
raise TypeError(
"all values in ``region`` must be slice objects, got "
f"region={region}"
)
if v.step not in {1, None}:
raise ValueError(
"step on all slices in ``region`` must be 1 or None, got "
f"region={region}"
)

non_matching_vars = [
k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
]
if non_matching_vars:
raise ValueError(
f"when setting `region` explicitly in to_zarr(), all "
f"variables in the dataset to write must have at least "
f"one dimension in common with the region's dimensions "
f"{list(region.keys())}, but that is not "
f"the case for some variables here. To drop these variables "
f"from this dataset before exporting to zarr, write: "
f".drop_vars({non_matching_vars!r})"
)

return region


def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
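For reference, the coordinate-to-index mapping that the deleted ``_auto_detect_region`` performed can be reproduced standalone; a toy sketch with illustrative values:

import numpy as np
from xarray import DataArray

stored = np.arange(10, 40)    # coordinate values as stored on disk
incoming = np.arange(15, 25)  # coordinate values of the data to write
da_map = DataArray(np.arange(stored.size), coords={"x": stored}, dims="x")
idxs = da_map.sel(x=incoming).values  # label-based lookup of integer positions
dim_slice = slice(idxs[0], idxs[-1] + 1)  # slice(5, 15)

Note that the deleted helpers had to re-open the entire dataset via ``open_zarr`` just to perform this lookup; the replacement in ``xarray/backends/zarr.py`` below decodes only the one coordinate variable it needs.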
@@ -1768,24 +1681,6 @@ def to_zarr(
# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)

if region is not None:
open_kwargs = dict(
store=store,
synchronizer=synchronizer,
group=group,
consolidated=consolidated,
storage_options=storage_options,
zarr_version=zarr_version,
)
region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs)
# can't modify indexed with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)

if zarr_version is None:
# default to 2 if store doesn't specify it's version (e.g. a path)
zarr_version = int(getattr(store, "_store_version", 2))
@@ -1815,6 +1710,16 @@
write_empty=write_empty_chunks,
)

if region is not None:
zstore._validate_and_autodetect_region(dataset)
# can't modify indexed with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)

if mode in ["a", "a-", "r+"]:
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
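The guard against overlapping ``append_dim`` and ``region`` keeps its old behaviour after the move; a hypothetical call that would raise, reusing ``ds`` and ``path`` from the io.rst example above:

# ValueError: cannot list the same dimension in both ``append_dim``
# and ``region`` with to_zarr(), got x in both
ds.isel(x=slice(0, 10)).to_zarr(path, append_dim="x", region={"x": slice(0, 10)})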
96 changes: 89 additions & 7 deletions xarray/backends/zarr.py
@@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd

from xarray import coding, conventions
from xarray.backends.common import (
@@ -509,7 +510,9 @@ def ds(self):
# TODO: consider deprecating this in favor of zarr_group
return self.zarr_group

def open_store_variable(self, name, zarr_array):
def open_store_variable(self, name, zarr_array=None):
if zarr_array is None:
zarr_array = self.zarr_group[name]
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
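With the new ``zarr_array=None`` default, a single stored variable can now be opened by name alone; schematically (``zstore`` stands for an open ``ZarrStore``):

var = zstore.open_store_variable("foo", zstore.zarr_group["foo"])  # before
var = zstore.open_store_variable("foo")  # after: array is looked up internally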
@@ -623,11 +626,7 @@ def store(
# avoid needing to load index variables into memory.
# TODO: consider making loading indexes lazy again?
existing_vars, _, _ = conventions.decode_cf_variables(
{
k: v
for k, v in self.get_variables().items()
if k in existing_variable_names
},
{k: self.open_store_variable(name=k) for k in existing_variable_names},
self.get_attrs(),
)
# Modified variables must use the same encoding as the store.
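The effect of this rewrite in ``store()``: existing variables are decoded from store variables opened one at a time, instead of wrapping every array in the group via ``get_variables()`` and then filtering. Side by side:

# before: wraps every array in the group, then filters
{k: v for k, v in self.get_variables().items() if k in existing_variable_names}
# after: opens only the variables that already exist in the store
{k: self.open_store_variable(name=k) for k in existing_variable_names}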
@@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

def close(self):
def close(self) -> None:
if self._close_store_on_close:
self.zarr_group.store.close()

def _auto_detect_regions(self, ds, region):
for dim, val in region.items():
if val != "auto":
continue

if dim not in ds._variables:
# unindexed dimension
region[dim] = slice(0, ds.sizes[dim])
continue

variable = conventions.decode_cf_variable(
dim, self.open_store_variable(dim).compute()
)
assert variable.dims == (dim,)
index = pd.Index(variable.data)
idxs = index.get_indexer(ds[dim].data)
if any(idxs == -1):
raise KeyError(
f"Not all values of coordinate '{dim}' in the new array were"
" found in the original store. Writing to a zarr region slice"
" requires that no dimensions or metadata are changed by the write."
)

if (np.diff(idxs) != 1).any():
raise ValueError(
f"The auto-detected region of coordinate '{dim}' for writing new data"
" to the original store had non-contiguous indices. Writing to a zarr"
" region slice requires that the new data constitute a contiguous subset"
" of the original store."
)
region[dim] = slice(idxs[0], idxs[-1] + 1)
return region

def _validate_and_autodetect_region(self, ds) -> None:
region = self._write_region

if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
if any(v == "auto" for v in region.values()):
if self._mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
)
region = self._auto_detect_regions(ds, region)

# validate before attempting to auto-detect since the auto-detection
# should always return a valid slice.
for k, v in region.items():
if k not in ds.dims:
raise ValueError(
f"all keys in ``region`` are not in Dataset dimensions, got "
f"{list(region)} and {list(ds.dims)}"
)
if not isinstance(v, slice):
raise TypeError(
"all values in ``region`` must be slice objects, got "
f"region={region}"
)
if v.step not in {1, None}:
raise ValueError(
"step on all slices in ``region`` must be 1 or None, got "
f"region={region}"
)

non_matching_vars = [
k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
]
if non_matching_vars:
raise ValueError(
f"when setting `region` explicitly in to_zarr(), all "
f"variables in the dataset to write must have at least "
f"one dimension in common with the region's dimensions "
f"{list(region.keys())}, but that is not "
f"the case for some variables here. To drop these variables "
f"from this dataset before exporting to zarr, write: "
f".drop_vars({non_matching_vars!r})"
)

self._write_region = region


def open_zarr(
store,
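The core of the optimization is ``_auto_detect_regions`` above: rather than re-opening the whole dataset and using label-based ``.sel``, it decodes one coordinate variable per dimension and maps labels to positions with ``pandas.Index.get_indexer``, while dimensions without a coordinate variable ("unindexed dimensions" in the commit message) fall back to ``slice(0, ds.sizes[dim])``. A self-contained sketch of the same lookup, with toy arrays standing in for the stored and incoming coordinates:

import numpy as np
import pandas as pd

def detect_region(stored_coord, incoming_coord):
    # get_indexer returns -1 for labels missing from the store
    idxs = pd.Index(stored_coord).get_indexer(incoming_coord)
    if (idxs == -1).any():
        raise KeyError("not all coordinate values were found in the original store")
    if (np.diff(idxs) != 1).any():
        raise ValueError("auto-detected region has non-contiguous indices")
    return slice(int(idxs[0]), int(idxs[-1]) + 1)

print(detect_region(np.arange(10, 40), np.arange(15, 25)))  # slice(5, 15)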
