From dcf2ac4addb5a92723c6b064fb6546ff02ebd1cd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 May 2024 09:29:26 -0600 Subject: [PATCH] Zarr: Optimize `region="auto"` detection (#8997) * Zarr: Optimize region detection * Fix for unindexed dimensions. * Better example * small cleanup --- doc/user-guide/io.rst | 4 +- xarray/backends/api.py | 115 ++++------------------------------------ xarray/backends/zarr.py | 96 ++++++++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 114 deletions(-) diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 63bf8b80d81..b73d0fdcb51 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -874,7 +874,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata # The values of this dask array are entirely irrelevant; only the dtype, # shape and chunks are used dummies = dask.array.zeros(30, chunks=10) - ds = xr.Dataset({"foo": ("x", dummies)}) + ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)}) path = "path/to/directory.zarr" # Now we write the metadata without computing any array values ds.to_zarr(path, compute=False) @@ -890,7 +890,7 @@ where the data should be written (in index space, not label space), e.g., # For convenience, we'll slice a single dataset, but in the real use-case # we would create them separately possibly even from separate processes. - ds = xr.Dataset({"foo": ("x", np.arange(30))}) + ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)}) # Any of the following region specifications are valid ds.isel(x=slice(0, 10)).to_zarr(path, region="auto") ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"}) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 62085fe5e2a..c9a8630a575 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -27,7 +27,6 @@ _normalize_path, ) from xarray.backends.locks import _get_scheduler -from xarray.backends.zarr import open_zarr from xarray.core import indexing from xarray.core.combine import ( _infer_concat_order_from_positions, @@ -1522,92 +1521,6 @@ def save_mfdataset( ) -def _auto_detect_region(ds_new, ds_orig, dim): - # Create a mapping array of coordinates to indices on the original array - coord = ds_orig[dim] - da_map = DataArray(np.arange(coord.size), coords={dim: coord}) - - try: - da_idxs = da_map.sel({dim: ds_new[dim]}) - except KeyError as e: - if "not all values found" in str(e): - raise KeyError( - f"Not all values of coordinate '{dim}' in the new array were" - " found in the original store. Writing to a zarr region slice" - " requires that no dimensions or metadata are changed by the write." - ) - else: - raise e - - if (da_idxs.diff(dim) != 1).any(): - raise ValueError( - f"The auto-detected region of coordinate '{dim}' for writing new data" - " to the original store had non-contiguous indices. Writing to a zarr" - " region slice requires that the new data constitute a contiguous subset" - " of the original store." 
- ) - - dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1) - - return dim_slice - - -def _auto_detect_regions(ds, region, open_kwargs): - ds_original = open_zarr(**open_kwargs) - for key, val in region.items(): - if val == "auto": - region[key] = _auto_detect_region(ds, ds_original, key) - return region - - -def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]: - if region == "auto": - region = {dim: "auto" for dim in ds.dims} - - if not isinstance(region, dict): - raise TypeError(f"``region`` must be a dict, got {type(region)}") - - if any(v == "auto" for v in region.values()): - if mode != "r+": - raise ValueError( - f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" - ) - region = _auto_detect_regions(ds, region, open_kwargs) - - for k, v in region.items(): - if k not in ds.dims: - raise ValueError( - f"all keys in ``region`` are not in Dataset dimensions, got " - f"{list(region)} and {list(ds.dims)}" - ) - if not isinstance(v, slice): - raise TypeError( - "all values in ``region`` must be slice objects, got " - f"region={region}" - ) - if v.step not in {1, None}: - raise ValueError( - "step on all slices in ``region`` must be 1 or None, got " - f"region={region}" - ) - - non_matching_vars = [ - k for k, v in ds.variables.items() if not set(region).intersection(v.dims) - ] - if non_matching_vars: - raise ValueError( - f"when setting `region` explicitly in to_zarr(), all " - f"variables in the dataset to write must have at least " - f"one dimension in common with the region's dimensions " - f"{list(region.keys())}, but that is not " - f"the case for some variables here. To drop these variables " - f"from this dataset before exporting to zarr, write: " - f".drop_vars({non_matching_vars!r})" - ) - - return region - - def _validate_datatypes_for_zarr_append(zstore, dataset): """If variable exists in the store, confirm dtype of the data to append is compatible with existing dtype. @@ -1768,24 +1681,6 @@ def to_zarr( # validate Dataset keys, DataArray names _validate_dataset_names(dataset) - if region is not None: - open_kwargs = dict( - store=store, - synchronizer=synchronizer, - group=group, - consolidated=consolidated, - storage_options=storage_options, - zarr_version=zarr_version, - ) - region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs) - # can't modify indexed with region writes - dataset = dataset.drop_vars(dataset.indexes) - if append_dim is not None and append_dim in region: - raise ValueError( - f"cannot list the same dimension in both ``append_dim`` and " - f"``region`` with to_zarr(), got {append_dim} in both" - ) - if zarr_version is None: # default to 2 if store doesn't specify it's version (e.g. 
a path) zarr_version = int(getattr(store, "_store_version", 2)) @@ -1815,6 +1710,16 @@ def to_zarr( write_empty=write_empty_chunks, ) + if region is not None: + zstore._validate_and_autodetect_region(dataset) + # can't modify indexed with region writes + dataset = dataset.drop_vars(dataset.indexes) + if append_dim is not None and append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {append_dim} in both" + ) + if mode in ["a", "a-", "r+"]: _validate_datatypes_for_zarr_append(zstore, dataset) if append_dim is not None: diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 3d6baeefe01..e4a684e945d 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import pandas as pd from xarray import coding, conventions from xarray.backends.common import ( @@ -509,7 +510,9 @@ def ds(self): # TODO: consider deprecating this in favor of zarr_group return self.zarr_group - def open_store_variable(self, name, zarr_array): + def open_store_variable(self, name, zarr_array=None): + if zarr_array is None: + zarr_array = self.zarr_group[name] data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( @@ -623,11 +626,7 @@ def store( # avoid needing to load index variables into memory. # TODO: consider making loading indexes lazy again? existing_vars, _, _ = conventions.decode_cf_variables( - { - k: v - for k, v in self.get_variables().items() - if k in existing_variable_names - }, + {k: self.open_store_variable(name=k) for k in existing_variable_names}, self.get_attrs(), ) # Modified variables must use the same encoding as the store. @@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No region = tuple(write_region[dim] for dim in dims) writer.add(v.data, zarr_array, region) - def close(self): + def close(self) -> None: if self._close_store_on_close: self.zarr_group.store.close() + def _auto_detect_regions(self, ds, region): + for dim, val in region.items(): + if val != "auto": + continue + + if dim not in ds._variables: + # unindexed dimension + region[dim] = slice(0, ds.sizes[dim]) + continue + + variable = conventions.decode_cf_variable( + dim, self.open_store_variable(dim).compute() + ) + assert variable.dims == (dim,) + index = pd.Index(variable.data) + idxs = index.get_indexer(ds[dim].data) + if any(idxs == -1): + raise KeyError( + f"Not all values of coordinate '{dim}' in the new array were" + " found in the original store. Writing to a zarr region slice" + " requires that no dimensions or metadata are changed by the write." + ) + + if (np.diff(idxs) != 1).any(): + raise ValueError( + f"The auto-detected region of coordinate '{dim}' for writing new data" + " to the original store had non-contiguous indices. Writing to a zarr" + " region slice requires that the new data constitute a contiguous subset" + " of the original store." 
+ ) + region[dim] = slice(idxs[0], idxs[-1] + 1) + return region + + def _validate_and_autodetect_region(self, ds) -> None: + region = self._write_region + + if region == "auto": + region = {dim: "auto" for dim in ds.dims} + + if not isinstance(region, dict): + raise TypeError(f"``region`` must be a dict, got {type(region)}") + if any(v == "auto" for v in region.values()): + if self._mode != "r+": + raise ValueError( + f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}" + ) + region = self._auto_detect_regions(ds, region) + + # validate before attempting to auto-detect since the auto-detection + # should always return a valid slice. + for k, v in region.items(): + if k not in ds.dims: + raise ValueError( + f"all keys in ``region`` are not in Dataset dimensions, got " + f"{list(region)} and {list(ds.dims)}" + ) + if not isinstance(v, slice): + raise TypeError( + "all values in ``region`` must be slice objects, got " + f"region={region}" + ) + if v.step not in {1, None}: + raise ValueError( + "step on all slices in ``region`` must be 1 or None, got " + f"region={region}" + ) + + non_matching_vars = [ + k for k, v in ds.variables.items() if not set(region).intersection(v.dims) + ] + if non_matching_vars: + raise ValueError( + f"when setting `region` explicitly in to_zarr(), all " + f"variables in the dataset to write must have at least " + f"one dimension in common with the region's dimensions " + f"{list(region.keys())}, but that is not " + f"the case for some variables here. To drop these variables " + f"from this dataset before exporting to zarr, write: " + f".drop_vars({non_matching_vars!r})" + ) + + self._write_region = region + def open_zarr( store,
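
The workflow that this patch optimizes, assembled from the ``doc/user-guide/io.rst``
example above, looks roughly as follows. This is a minimal sketch; the store path and
array sizes are illustrative assumptions rather than part of the patch:

    import dask.array
    import numpy as np
    import xarray as xr

    # Initialize the store with metadata only: the dask array's values are never
    # computed, only its dtype, shape and chunks are used. The in-memory "x"
    # coordinate is written eagerly, so later region="auto" writes can read it
    # back from the store.
    dummies = dask.array.zeros(30, chunks=10)
    ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
    path = "path/to/directory.zarr"  # illustrative path
    ds.to_zarr(path, compute=False)

    # Later, possibly from separate processes, write actual values into slices of
    # the initialized store. Any of the following region specifications are valid.
    ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
    ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
    ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
    ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)})

The speed-up comes from how the "auto" slices are found: instead of re-opening the
whole dataset with ``open_zarr``, ``ZarrStore`` now reads only the coordinate being
matched and maps the incoming labels onto it with a pandas indexer (unindexed
dimensions fall back to the full slice). A simplified sketch of the logic in
``_auto_detect_regions``, with made-up variable names for illustration:

    import numpy as np
    import pandas as pd

    index = pd.Index(np.arange(30))     # coordinate already present in the store
    incoming = np.arange(10, 20)        # coordinate of the data being written
    idxs = index.get_indexer(incoming)  # -1 marks labels missing from the store
    assert (idxs != -1).all()           # otherwise: KeyError, values not in store
    assert (np.diff(idxs) == 1).all()   # otherwise: ValueError, non-contiguous region
    region = {"x": slice(idxs[0], idxs[-1] + 1)}  # -> {"x": slice(10, 20)}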