diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py index 08808d002d9..21123bffcd6 100644 --- a/ci/minimum_versions.py +++ b/ci/minimum_versions.py @@ -30,6 +30,7 @@ "coveralls", "pip", "pytest", + "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mypy-plugins", diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index ca4943bddb1..987adc7dfdd 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -28,6 +28,7 @@ dependencies: - pip - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml index fa7ad81f198..1d49f92133c 100644 --- a/ci/requirements/all-but-numba.yml +++ b/ci/requirements/all-but-numba.yml @@ -41,6 +41,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index 02e99d34af2..cc34a6e4824 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -7,6 +7,7 @@ dependencies: - coveralls - pip - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml index 1e6ee7ff5f9..bfbeababa56 100644 --- a/ci/requirements/environment-3.14.yml +++ b/ci/requirements/environment-3.14.yml @@ -37,6 +37,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml index 4eb2049f2e6..d5143470614 100644 --- a/ci/requirements/environment-windows-3.14.yml +++ b/ci/requirements/environment-windows-3.14.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 45cbebd38db..6aeca2cb0ab 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index b4354b14f40..9c253d5d489 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -38,6 +38,7 @@ dependencies: - pydap - pydap-server - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 03e14773d53..1293f4d78d6 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -44,6 +44,7 @@ dependencies: - pip - pydap=3.5 - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 9a6037cf3c4..98d3704de9b 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -228,6 +228,7 @@ Variable.isnull Variable.item Variable.load + Variable.load_async Variable.max Variable.mean Variable.median diff --git a/doc/api.rst b/doc/api.rst index b6023866eb8..80715555e56 
100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1122,6 +1122,7 @@ Dataset methods Dataset.filter_by_attrs Dataset.info Dataset.load + Dataset.load_async Dataset.persist Dataset.unify_chunks @@ -1154,6 +1155,7 @@ DataArray methods DataArray.compute DataArray.persist DataArray.load + DataArray.load_async DataArray.unify_chunks DataTree methods diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index e4f6d54f75c..883c817dccc 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -325,10 +325,12 @@ information on plugins. How to support lazy loading +++++++++++++++++++++++++++ -If you want to make your backend effective with big datasets, then you should -support lazy loading. -Basically, you shall replace the :py:class:`numpy.ndarray` inside the -variables with a custom class that supports lazy loading indexing. +If you want to make your backend effective with big datasets, then you should take advantage of xarray's +support for lazy loading and indexing. + +Basically, when your backend constructs the ``Variable`` objects, +you need to replace the :py:class:`numpy.ndarray` inside the +variables with a custom :py:class:`~xarray.backends.BackendArray` subclass that supports lazy loading and indexing. See the example below: .. code-block:: python @@ -339,25 +341,27 @@ See the example below: Where: -- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class - provided by Xarray that manages the lazy loading. -- ``MyBackendArray`` shall be implemented by the backend and shall inherit +- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a wrapper class + provided by Xarray that manages the lazy loading and indexing. +- ``MyBackendArray`` should be implemented by the backend and must inherit from :py:class:`~xarray.backends.BackendArray`. BackendArray subclassing ^^^^^^^^^^^^^^^^^^^^^^^^ -The BackendArray subclass shall implement the following method and attributes: +The BackendArray subclass must implement the following method and attributes: -- the ``__getitem__`` method that takes in input an index and returns a - `NumPy `__ array -- the ``shape`` attribute +- the ``__getitem__`` method that takes an index as input and returns a + `NumPy `__ array, +- the ``shape`` attribute, - the ``dtype`` attribute. -Xarray supports different type of :doc:`/user-guide/indexing`, that can be -grouped in three types of indexes +It may also optionally implement an ``async_getitem`` method. + +Xarray supports different types of :doc:`/user-guide/indexing`, which can be +grouped into three types of indexers: :py:class:`~xarray.core.indexing.BasicIndexer`, -:py:class:`~xarray.core.indexing.OuterIndexer` and +:py:class:`~xarray.core.indexing.OuterIndexer`, and :py:class:`~xarray.core.indexing.VectorizedIndexer`. This implies that the implementation of the method ``__getitem__`` can be tricky. In order to simplify this task, Xarray provides a helper function, @@ -413,8 +417,22 @@ input the ``key``, the array ``shape`` and the following parameters: For more details see :py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`. +Async support +^^^^^^^^^^^^^ + +Backends can also optionally support loading data asynchronously via xarray's asynchronous loading methods +(e.g. :py:meth:`~xarray.Dataset.load_async`). +To support async loading, the ``BackendArray`` subclass must additionally implement the ``BackendArray.async_getitem`` method.
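+
+For example, a minimal sketch of an async-capable ``BackendArray`` might look like the following.
+(``MyBackendArray`` and its ``_async_read`` helper are illustrative names only, not part of the xarray API;
+a real backend would await whatever asynchronous read primitive its storage layer provides.)
+
+.. code-block:: python
+
+    class MyBackendArray(BackendArray):
+        async def _async_read(self, key: tuple) -> np.typing.ArrayLike:
+            # Hypothetical non-blocking read of the requested region,
+            # e.g. awaiting an async storage client.
+            ...
+
+        async def async_getitem(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
+            # Mirrors __getitem__, but delegates to the async helper
+            # provided by xarray to decompose the indexer.
+            return await indexing.async_explicit_indexing_adapter(
+                key, self.shape, indexing.IndexingSupport.BASIC, self._async_read
+            )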
+ +Note that implementing this method is only necessary if you want to be able to load data from different xarray objects concurrently. +Even without this method, your ``BackendArray`` implementation is still free to load chunks of data for a single ``Variable`` concurrently, +so long as it does so behind the synchronous ``__getitem__`` interface. + +Dask support +^^^^^^^^^^^^ + In order to support `Dask Distributed `__ and -:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable +:py:mod:`multiprocessing`, the ``BackendArray`` subclass should be serializable either with :ref:`io.pickle` or `cloudpickle `__. That implies that all the reference to open files should be dropped. For diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 941e52764ae..05dfe8c8abe 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,9 @@ v2025.05.0 (unreleased) New Features ~~~~~~~~~~~~ + +- Added new asynchronous loading methods :py:meth:`~xarray.Dataset.load_async`, :py:meth:`~xarray.DataArray.load_async` and :py:meth:`~xarray.Variable.load_async`. + (:issue:`10326`, :pull:`10327`) By `Tom Nicholas `_. - Allow an Xarray index that uses multiple dimensions checking equality with another index for only a subset of those dimensions (i.e., ignoring the dimensions that are excluded from alignment). @@ -42,7 +45,6 @@ Bug fixes ~~~~~~~~~ - Fix :py:class:`~xarray.groupers.BinGrouper` when ``labels`` is not specified (:issue:`10284`). By `Deepak Cherian `_. - - Allow accessing arbitrary attributes on Pandas ExtensionArrays. By `Deepak Cherian `_. diff --git a/pyproject.toml b/pyproject.toml index fa087abbc13..7dc784f170f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dev = [ "pytest-mypy-plugins", "pytest-timeout", "pytest-xdist", + "pytest-asyncio", "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 58a98598a5b..10a698ac329 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -270,10 +270,17 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): __slots__ = () + async def async_getitem(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: + raise NotImplementedError("Backend does not support asynchronous loading") + def get_duck_array(self, dtype: np.typing.DTypeLike = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return self[key] # type: ignore[index] + async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None): + key = indexing.BasicIndexer((slice(None),) * self.ndim) + return await self.async_getitem(key) # type: ignore[index] + class AbstractDataStore: __slots__ = () diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1a46346dda7..8f814c7f1f3 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -185,6 +185,8 @@ class ZarrArrayWrapper(BackendArray): def __init__(self, zarr_array): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. + + # TODO type hint this?
self._array = zarr_array self.shape = self._array.shape @@ -212,6 +214,18 @@ def _vindex(self, key): def _getitem(self, key): return self._array[key] + async def _async_getitem(self, key): + async_array = self._array._async_array + return await async_array.getitem(key) + + async def _async_oindex(self, key): + async_array = self._array._async_array + return await async_array.oindex.getitem(key) + + async def _async_vindex(self, key): + async_array = self._array._async_array + return await async_array.vindex.getitem(key) + def __getitem__(self, key): array = self._array if isinstance(key, indexing.BasicIndexer): @@ -227,6 +241,19 @@ def __getitem__(self, key): # if self.ndim == 0: # could possibly have a work-around for 0d data here + async def async_getitem(self, key): + array = self._array + if isinstance(key, indexing.BasicIndexer): + method = self._async_getitem + elif isinstance(key, indexing.VectorizedIndexer): + method = self._async_vindex + elif isinstance(key, indexing.OuterIndexer): + method = self._async_oindex + return await indexing.async_explicit_indexing_adapter( + key, array.shape, indexing.IndexingSupport.VECTORIZED, method + ) + def _determine_zarr_chunks( enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape diff --git a/xarray/coding/common.py b/xarray/coding/common.py index 1b455009668..8093827138b 100644 --- a/xarray/coding/common.py +++ b/xarray/coding/common.py @@ -75,6 +75,9 @@ def __getitem__(self, key): def get_duck_array(self): return self.func(self.array.get_duck_array()) + async def async_get_duck_array(self): + return self.func(await self.array.async_get_duck_array()) + def __repr__(self) -> str: return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})" diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 4ca6a3f0a46..a2295c218a6 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -250,14 +250,17 @@ def __repr__(self): return f"{type(self).__name__}({self.array!r})" def _vindex_get(self, key): - return _numpy_char_to_bytes(self.array.vindex[key]) + return type(self)(self.array.vindex[key]) def _oindex_get(self, key): - return _numpy_char_to_bytes(self.array.oindex[key]) + return type(self)(self.array.oindex[key]) def __getitem__(self, key): # require slicing the last dimension completely key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) if key.tuple[-1] != slice(None): raise IndexError("too many indices") - return _numpy_char_to_bytes(self.array[key]) + return type(self)(self.array[key]) + + def get_duck_array(self): + return _numpy_char_to_bytes(self.array.get_duck_array()) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 1b7bc95e2b4..f82f0c65768 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -58,13 +58,16 @@ def dtype(self) -> np.dtype: return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize)) def _oindex_get(self, key): - return np.asarray(self.array.oindex[key], dtype=self.dtype) + return type(self)(self.array.oindex[key]) def _vindex_get(self, key): - return np.asarray(self.array.vindex[key], dtype=self.dtype) + return type(self)(self.array.vindex[key]) def __getitem__(self, key) -> np.ndarray: - return np.asarray(self.array[key], dtype=self.dtype) + return type(self)(self.array[key]) + + def get_duck_array(self): + return duck_array_ops.astype(self.array.get_duck_array(), dtype=self.dtype) class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin): @@
-96,13 +99,16 @@ def dtype(self) -> np.dtype: return np.dtype("bool") def _oindex_get(self, key): - return np.asarray(self.array.oindex[key], dtype=self.dtype) + return type(self)(self.array.oindex[key]) def _vindex_get(self, key): - return np.asarray(self.array.vindex[key], dtype=self.dtype) + return type(self)(self.array.vindex[key]) def __getitem__(self, key) -> np.ndarray: - return np.asarray(self.array[key], dtype=self.dtype) + return type(self)(self.array[key]) + + def get_duck_array(self): + return duck_array_ops.astype(self.array.get_duck_array(), dtype=self.dtype) def _apply_mask( diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1e7e1069076..05f5d4c7fa8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1160,6 +1160,14 @@ def load(self, **kwargs) -> Self: self._coords = new._coords return self + async def load_async(self, **kwargs) -> Self: + temp_ds = self._to_temp_dataset() + ds = await temp_ds.load_async(**kwargs) + new = self._from_temp_dataset(ds) + self._variable = new._variable + self._coords = new._coords + return self + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5a7f757ba8a..8a4e7177caa 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import datetime import math @@ -531,24 +532,50 @@ def load(self, **kwargs) -> Self: dask.compute """ # access .data to coerce everything to numpy or dask arrays - lazy_data = { + chunked_data = { k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) } - if lazy_data: - chunkmanager = get_chunked_array_type(*lazy_data.values()) + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) # evaluate all the chunked arrays simultaneously evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( - *lazy_data.values(), **kwargs + *chunked_data.values(), **kwargs ) - for k, data in zip(lazy_data, evaluated_data, strict=False): + for k, data in zip(chunked_data, evaluated_data, strict=False): self.variables[k].data = data # load everything else sequentially - for k, v in self.variables.items(): - if k not in lazy_data: - v.load() + for k, v in self.variables.items(): + if k not in chunked_data: + v.load() + + return self + + async def load_async(self, **kwargs) -> Self: + # TODO refactor this to pull out the common chunked_data codepath + + # this blocks on chunked arrays but not on lazily indexed arrays + + # access .data to coerce everything to numpy or dask arrays + chunked_data = { + k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) + } + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) + + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...]
= chunkmanager.compute( + *chunked_data.values(), **kwargs + ) + + for k, data in zip(chunked_data, evaluated_data, strict=False): + self.variables[k].data = data + + # load everything else concurrently + coros = [ + v.load_async() for k, v in self.variables.items() if k not in chunked_data + ] + await asyncio.gather(*coros) return self diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c1b847202c7..824558010e1 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -516,13 +516,31 @@ def get_duck_array(self): return self.array -class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): - __slots__ = () +class IndexingAdapter: + """Marker class for indexing adapters. + + These classes translate between Xarray's indexing semantics and the underlying array's + indexing semantics. + """ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] + async def async_get_duck_array(self): + """These classes are applied to in-memory arrays, so specific async support isn't needed.""" + return self.get_duck_array() + + +class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): + __slots__ = () + + def get_duck_array(self): + raise NotImplementedError + + async def async_get_duck_array(self): + raise NotImplementedError + def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" @@ -646,19 +664,25 @@ def shape(self) -> _Shape: return self._shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): - array = apply_indexer(self.array, self.key) - else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): array = self.array[self.key] + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def transpose(self, order): @@ -722,18 +746,26 @@ def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = self.array[self.key] + else: array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a
BackendArray so use its __getitem__ - array = self.array[self.key] - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def _updated_key(self, new_key: ExplicitIndexer): @@ -797,6 +829,9 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() + async def async_get_duck_array(self): + return await self.array.async_get_duck_array() + def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -837,12 +872,17 @@ class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): def __init__(self, array): self.array = _wrap_numpy_scalars(as_indexable(array)) - def _ensure_cached(self): - self.array = as_indexable(self.array.get_duck_array()) - def get_duck_array(self): - self._ensure_cached() - return self.array.get_duck_array() + duck_array = self.array.get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array + + async def async_get_duck_array(self): + duck_array = await self.array.async_get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -1027,6 +1067,21 @@ def explicit_indexing_adapter( return result +async def async_explicit_indexing_adapter( + key: ExplicitIndexer, + shape: _Shape, + indexing_support: IndexingSupport, + raw_indexing_method: Callable[..., Any], +) -> Any: + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = await raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded duck array + indexable = as_indexable(result) + result = apply_indexer(indexable, numpy_indices) + return result + + def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): @@ -1526,7 +1581,7 @@ def is_fancy_indexer(indexer: Any) -> bool: return True -class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class NumpyIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a NumPy array to use explicit indexing.""" __slots__ = ("array",) @@ -1605,7 +1660,7 @@ def __init__(self, array): self.array = array -class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class ArrayApiIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap an array API array to use explicit indexing.""" __slots__ = ("array",) @@ -1670,7 +1725,7 @@ def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None: ) -class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class DaskIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" __slots__ = ("array",) @@ -1746,7 +1801,7 @@ def transpose(self, order): return self.array.transpose(order) -class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class PandasIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = 
("_dtype", "array") @@ -2063,7 +2118,9 @@ def copy(self, deep: bool = True) -> Self: return type(self)(array, self._dtype, self.level) -class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class CoordinateTransformIndexingAdapter( + IndexingAdapter, ExplicitlyIndexedNDArrayMixin +): """Wrap a CoordinateTransform as a lazy coordinate array. Supports explicit indexing (both outer and vectorized). diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4e58b0d4b20..38f2676ec52 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -47,6 +47,7 @@ from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import ( + async_to_duck_array, integer_types, is_0d_dask_array, is_chunked_array, @@ -957,6 +958,10 @@ def load(self, **kwargs): self._data = to_duck_array(self._data, **kwargs) return self + async def load_async(self, **kwargs): + self._data = await async_to_duck_array(self._data, **kwargs) + return self + def compute(self, **kwargs): """Manually trigger loading of this variable's data from disk or a remote source into memory and return a new variable. The original is diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index 68b6a7853bf..6e61d3445ab 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -145,3 +145,23 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return data else: return np.asarray(data) # type: ignore[return-value] + + +async def async_to_duck_array( + data: Any, **kwargs: dict[str, Any] +) -> duckarray[_ShapeType, _DType]: + from xarray.core.indexing import ( + ExplicitlyIndexed, + ImplicitToExplicitIndexingAdapter, + IndexingAdapter, + ) + + print(type(data)) + if isinstance(data, IndexingAdapter): + # These wrap in-memory arrays, and async isn't needed + return data.get_duck_array() + elif isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): + print("async inside to_duck_array") + return await data.async_get_duck_array() # type: ignore[no-untyped-call, no-any-return] + else: + return to_duck_array(data, **kwargs) diff --git a/xarray/tests/test_async.py b/xarray/tests/test_async.py new file mode 100644 index 00000000000..918a2508ea0 --- /dev/null +++ b/xarray/tests/test_async.py @@ -0,0 +1,221 @@ +import asyncio +import time +from collections.abc import Iterable +from contextlib import asynccontextmanager +from typing import TypeVar +from unittest.mock import patch + +import numpy as np +import pytest + +import xarray as xr +import xarray.testing as xrt +from xarray.tests import has_zarr_v3, requires_zarr_v3 + +if has_zarr_v3: + import zarr + from zarr.abc.store import ByteRequest, Store + from zarr.core.buffer import Buffer, BufferPrototype + from zarr.storage import MemoryStore + from zarr.storage._wrapper import WrapperStore + + T_Store = TypeVar("T_Store", bound=Store) + + class LatencyStore(WrapperStore[T_Store]): + """Works the same way as the zarr LoggingStore""" + + latency: float + + # TODO only have to add this because of dumb behaviour in zarr where it raises with "ValueError: Store is not read-only but mode is 'r'" + read_only = True + + def __init__( + self, + store: T_Store, + latency: float = 0.0, + ) -> None: + """ + Store wrapper that adds artificial latency to each get call. 
+ + Parameters + ---------- + store : Store + Store to wrap + latency : float + Amount of artificial latency to add to each get call, in seconds. + """ + super().__init__(store) + self.latency = latency + + def __str__(self) -> str: + return f"latency-{self._store}" + + def __repr__(self) -> str: + return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})" + + async def get( + self, + key: str, + prototype: BufferPrototype, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + await asyncio.sleep(self.latency) + return await self._store.get( + key=key, prototype=prototype, byte_range=byte_range + ) + + async def get_partial_values( + self, + prototype: BufferPrototype, + key_ranges: Iterable[tuple[str, ByteRequest | None]], + ) -> list[Buffer | None]: + await asyncio.sleep(self.latency) + return await self._store.get_partial_values( + prototype=prototype, key_ranges=key_ranges + ) +else: + LatencyStore = {} + + +@pytest.fixture +def memorystore() -> "MemoryStore": + memorystore = zarr.storage.MemoryStore({}) + z1 = zarr.create_array( + store=memorystore, + name="foo", + shape=(10, 10), + chunks=(5, 5), + dtype="f4", + dimension_names=["x", "y"], + attributes={"add_offset": 1, "scale_factor": 2}, + ) + z1[:, :] = np.random.random((10, 10)) + + z2 = zarr.create_array( + store=memorystore, + name="x", + shape=(10,), + chunks=(5,), + dtype="f4", + dimension_names=["x"], + ) + z2[:] = np.arange(10) + + return memorystore + + +class AsyncTimer: + """Context manager for timing async operations and making assertions about their execution time.""" + + start_time: float + end_time: float + total_time: float + + @asynccontextmanager + async def measure(self): + """Measure the execution time of the async code within this context.""" + self.start_time = time.time() + try: + yield self + finally: + self.end_time = time.time() + self.total_time = self.end_time - self.start_time + + +@requires_zarr_v3 +@pytest.mark.asyncio class TestAsyncLoad: + LATENCY: float = 1.0 + + @pytest.fixture(params=["var", "ds", "da"]) + def xr_obj(self, request, memorystore) -> xr.Dataset | xr.DataArray | xr.Variable: + latencystore = LatencyStore(memorystore, latency=self.LATENCY) + ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None) + + match request.param: + case "var": + return ds["foo"].variable + case "da": + return ds["foo"] + case "ds": + return ds + + def assert_time_as_expected( + self, total_time: float, latency: float, n_loads: int + ) -> None: + assert total_time > latency # Cannot possibly be quicker than this + assert ( + total_time < latency * n_loads + ) # If this isn't true we're gaining nothing from async + assert ( + abs(total_time - latency) < 2.0 + ) # Should take approximately `latency` seconds, but allow some buffer + + async def test_concurrent_load_multiple_variables(self, memorystore) -> None: + latencystore = LatencyStore(memorystore, latency=self.LATENCY) + ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None) + + # TODO up the number of variables in the dataset?
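+        # With a per-request latency of LATENCY seconds, loading the lazy variables
+        # concurrently should take roughly LATENCY overall rather than
+        # n_loads * LATENCY; assert_time_as_expected checks both bounds.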
+ async with AsyncTimer().measure() as timer: + result_ds = await ds.load_async() + + xrt.assert_identical(result_ds, ds.load()) + + # 2 because there are 2 lazy variables in the dataset + self.assert_time_as_expected( + total_time=timer.total_time, latency=self.LATENCY, n_loads=2 + ) + + async def test_concurrent_load_multiple_objects(self, xr_obj) -> None: + N_OBJECTS = 5 + + async with AsyncTimer().measure() as timer: + coros = [xr_obj.load_async() for _ in range(N_OBJECTS)] + results = await asyncio.gather(*coros) + + for result in results: + xrt.assert_identical(result, xr_obj.load()) + + self.assert_time_as_expected( + total_time=timer.total_time, latency=self.LATENCY, n_loads=N_OBJECTS + ) + + @pytest.mark.parametrize("method", ["sel", "isel"]) + @pytest.mark.parametrize( + "indexer, zarr_getitem_method", + [ + ({"x": 2}, zarr.AsyncArray.getitem), + ({"x": slice(2, 4)}, zarr.AsyncArray.getitem), + ({"x": [2, 3]}, zarr.core.indexing.AsyncOIndex.getitem), + ( + { + "x": xr.DataArray([2, 3], dims="points"), + "y": xr.DataArray([2, 3], dims="points"), + }, + zarr.core.indexing.AsyncVIndex.getitem, + ), + ], + ids=["basic-int", "basic-slice", "outer", "vectorized"], + ) + async def test_indexing( + self, memorystore, method, indexer, zarr_getitem_method + ) -> None: + # TODO we don't need a LatencyStore for this test + latencystore = LatencyStore(memorystore, latency=0.0) + + with patch.object( + zarr.AsyncArray, "getitem", wraps=zarr_getitem_method, autospec=True + ) as mocked_meth: + ds = xr.open_zarr( + latencystore, zarr_format=3, consolidated=False, chunks=None + ) + + # TODO we're not actually testing that these indexing methods are not blocking... + result = await getattr(ds, method)(**indexer).load_async() + + assert mocked_meth.call_count > 0 + mocked_meth.assert_called() + mocked_meth.assert_awaited() + + expected = getattr(ds, method)(**indexer).load() + xrt.assert_identical(result, expected) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6dd75b58c6a..d308844c6fa 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -490,6 +490,23 @@ def test_sub_array(self) -> None: assert isinstance(child.array, indexing.NumpyIndexingAdapter) assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + async def test_async_wrapper(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + await wrapped.async_get_duck_array() + assert_array_equal(wrapped, np.arange(10)) + assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) + + async def test_async_sub_array(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[B[:5]] + assert isinstance(child, indexing.MemoryCachedArray) + await child.async_get_duck_array() + assert_array_equal(child, np.arange(5)) + assert isinstance(child.array, indexing.NumpyIndexingAdapter) + assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + def test_setitem(self) -> None: original = np.arange(10) wrapped = indexing.MemoryCachedArray(original)