diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py
index 08808d002d9..21123bffcd6 100644
--- a/ci/minimum_versions.py
+++ b/ci/minimum_versions.py
@@ -30,6 +30,7 @@
"coveralls",
"pip",
"pytest",
+ "pytest-asyncio",
"pytest-cov",
"pytest-env",
"pytest-mypy-plugins",
diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml
index ca4943bddb1..987adc7dfdd 100644
--- a/ci/requirements/all-but-dask.yml
+++ b/ci/requirements/all-but-dask.yml
@@ -28,6 +28,7 @@ dependencies:
- pip
- pydap
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml
index fa7ad81f198..1d49f92133c 100644
--- a/ci/requirements/all-but-numba.yml
+++ b/ci/requirements/all-but-numba.yml
@@ -41,6 +41,7 @@ dependencies:
- pyarrow # pandas raises a deprecation warning without this, breaking doctests
- pydap
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml
index 02e99d34af2..cc34a6e4824 100644
--- a/ci/requirements/bare-minimum.yml
+++ b/ci/requirements/bare-minimum.yml
@@ -7,6 +7,7 @@ dependencies:
- coveralls
- pip
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml
index 1e6ee7ff5f9..bfbeababa56 100644
--- a/ci/requirements/environment-3.14.yml
+++ b/ci/requirements/environment-3.14.yml
@@ -37,6 +37,7 @@ dependencies:
- pyarrow # pandas raises a deprecation warning without this, breaking doctests
- pydap
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml
index 4eb2049f2e6..d5143470614 100644
--- a/ci/requirements/environment-windows-3.14.yml
+++ b/ci/requirements/environment-windows-3.14.yml
@@ -32,6 +32,7 @@ dependencies:
- pyarrow # importing dask.dataframe raises an ImportError without this
- pydap
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index 45cbebd38db..6aeca2cb0ab 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -32,6 +32,7 @@ dependencies:
- pyarrow # importing dask.dataframe raises an ImportError without this
- pydap
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index b4354b14f40..9c253d5d489 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -38,6 +38,7 @@ dependencies:
- pydap
- pydap-server
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml
index 03e14773d53..1293f4d78d6 100644
--- a/ci/requirements/min-all-deps.yml
+++ b/ci/requirements/min-all-deps.yml
@@ -44,6 +44,7 @@ dependencies:
- pip
- pydap=3.5
- pytest
+ - pytest-asyncio
- pytest-cov
- pytest-env
- pytest-mypy-plugins
diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 9a6037cf3c4..98d3704de9b 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -228,6 +228,7 @@
Variable.isnull
Variable.item
Variable.load
+ Variable.load_async
Variable.max
Variable.mean
Variable.median
diff --git a/doc/api.rst b/doc/api.rst
index b6023866eb8..80715555e56 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -1122,6 +1122,7 @@ Dataset methods
Dataset.filter_by_attrs
Dataset.info
Dataset.load
+ Dataset.load_async
Dataset.persist
Dataset.unify_chunks
@@ -1154,6 +1155,7 @@ DataArray methods
DataArray.compute
DataArray.persist
DataArray.load
+ DataArray.load_async
DataArray.unify_chunks
DataTree methods
diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst
index e4f6d54f75c..883c817dccc 100644
--- a/doc/internals/how-to-add-new-backend.rst
+++ b/doc/internals/how-to-add-new-backend.rst
@@ -325,10 +325,12 @@ information on plugins.
How to support lazy loading
+++++++++++++++++++++++++++
-If you want to make your backend effective with big datasets, then you should
-support lazy loading.
-Basically, you shall replace the :py:class:`numpy.ndarray` inside the
-variables with a custom class that supports lazy loading indexing.
+If you want to make your backend effective with big datasets, then you should take advantage of xarray's
+support for lazy loading and indexing.
+
+Basically, when your backend constructs the ``Variable`` objects,
+you need to replace the :py:class:`numpy.ndarray` inside the
+variables with a custom :py:class:`~xarray.backends.BackendArray` subclass that supports lazy loading and indexing.
See the example below:
.. code-block:: python
@@ -339,25 +341,27 @@ See the example below:
Where:
-- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class
- provided by Xarray that manages the lazy loading.
-- ``MyBackendArray`` shall be implemented by the backend and shall inherit
+- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a wrapper class
+ provided by Xarray that manages the lazy loading and indexing.
+- ``MyBackendArray`` should be implemented by the backend and must inherit
from :py:class:`~xarray.backends.BackendArray`.
BackendArray subclassing
^^^^^^^^^^^^^^^^^^^^^^^^
-The BackendArray subclass shall implement the following method and attributes:
+The BackendArray subclass must implement the following method and attributes:
-- the ``__getitem__`` method that takes in input an index and returns a
- `NumPy `__ array
-- the ``shape`` attribute
+- the ``__getitem__`` method that takes an index as an input and returns a
+ `NumPy `__ array,
+- the ``shape`` attribute,
- the ``dtype`` attribute.
-Xarray supports different type of :doc:`/user-guide/indexing`, that can be
-grouped in three types of indexes
+It may also optionally implement an additional ``async_getitem`` method.
+
+Xarray supports different types of :doc:`/user-guide/indexing`, that can be
+grouped in three types of indexes:
:py:class:`~xarray.core.indexing.BasicIndexer`,
-:py:class:`~xarray.core.indexing.OuterIndexer` and
+:py:class:`~xarray.core.indexing.OuterIndexer`, and
:py:class:`~xarray.core.indexing.VectorizedIndexer`.
This implies that the implementation of the method ``__getitem__`` can be tricky.
In order to simplify this task, Xarray provides a helper function,
@@ -413,8 +417,22 @@ input the ``key``, the array ``shape`` and the following parameters:
For more details see
:py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`.
+Async support
+^^^^^^^^^^^^^
+
+Backends can also optionally support loading data asynchronously via xarray's asynchronous loading methods
+(e.g. ``~xarray.Dataset.load_async``).
+To support async loading the ``BackendArray`` subclass must additionally implement the ``BackendArray.async_getitem`` method.
+
+Note that implementing this method is only necessary if you want to be able to load data from different xarray objects concurrently.
+Even without this method your ``BackendArray`` implementation is still free to concurrently load chunks of data for a single ``Variable`` itself,
+so long as it does so behind the synchronous ``__getitem__`` interface.
+
+Dask support
+^^^^^^^^^^^^
+
In order to support `Dask Distributed `__ and
-:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable
+:py:mod:`multiprocessing`, the ``BackendArray`` subclass should be serializable
either with :ref:`io.pickle` or
`cloudpickle `__.
That implies that all the reference to open files should be dropped. For
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 941e52764ae..05dfe8c8abe 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -24,6 +24,9 @@ v2025.05.0 (unreleased)
New Features
~~~~~~~~~~~~
+
+- Added new asynchronous loading methods :py:meth:`~xarray.Dataset.load_async`, :py:meth:`~xarray.DataArray.load_async`, :py:meth:`~xarray.Variable.load_async`.
+ (:issue:`10326`, :pull:`10327`) By `Tom Nicholas `_.
- Allow an Xarray index that uses multiple dimensions checking equality with another
index for only a subset of those dimensions (i.e., ignoring the dimensions
that are excluded from alignment).
@@ -42,7 +45,6 @@ Bug fixes
~~~~~~~~~
- Fix :py:class:`~xarray.groupers.BinGrouper` when ``labels`` is not specified (:issue:`10284`).
By `Deepak Cherian `_.
-
- Allow accessing arbitrary attributes on Pandas ExtensionArrays.
By `Deepak Cherian `_.
diff --git a/pyproject.toml b/pyproject.toml
index fa087abbc13..7dc784f170f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ dev = [
"pytest-mypy-plugins",
"pytest-timeout",
"pytest-xdist",
+ "pytest-asyncio",
"ruff>=0.8.0",
"sphinx",
"sphinx_autosummary_accessors",
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 58a98598a5b..10a698ac329 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -270,10 +270,17 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
__slots__ = ()
+ async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
+ raise NotImplementedError("Backend does not not support asynchronous loading")
+
def get_duck_array(self, dtype: np.typing.DTypeLike = None):
key = indexing.BasicIndexer((slice(None),) * self.ndim)
return self[key] # type: ignore[index]
+ async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None):
+ key = indexing.BasicIndexer((slice(None),) * self.ndim)
+ return await self.async_getitem(key) # type: ignore[index]
+
class AbstractDataStore:
__slots__ = ()
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 1a46346dda7..8f814c7f1f3 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -185,6 +185,8 @@ class ZarrArrayWrapper(BackendArray):
def __init__(self, zarr_array):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
+
+ # TODO type hint this?
self._array = zarr_array
self.shape = self._array.shape
@@ -212,6 +214,18 @@ def _vindex(self, key):
def _getitem(self, key):
return self._array[key]
+ async def _async_getitem(self, key):
+ async_array = self._array._async_array
+ return await async_array.getitem(key)
+
+ async def _async_oindex(self, key):
+ async_array = self._array._async_array
+ return await async_array.oindex.getitem(key)
+
+ async def _async_vindex(self, key):
+ async_array = self._array._async_array
+ return await async_array.vindex.getitem(key)
+
def __getitem__(self, key):
array = self._array
if isinstance(key, indexing.BasicIndexer):
@@ -227,6 +241,19 @@ def __getitem__(self, key):
# if self.ndim == 0:
# could possibly have a work-around for 0d data here
+ async def async_getitem(self, key):
+ print("async getting")
+ array = self._array
+ if isinstance(key, indexing.BasicIndexer):
+ method = self._async_getitem
+ elif isinstance(key, indexing.VectorizedIndexer):
+ method = self._async_vindex
+ elif isinstance(key, indexing.OuterIndexer):
+ method = self._async_oindex
+ return await indexing.async_explicit_indexing_adapter(
+ key, array.shape, indexing.IndexingSupport.VECTORIZED, method
+ )
+
def _determine_zarr_chunks(
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
diff --git a/xarray/coding/common.py b/xarray/coding/common.py
index 1b455009668..8093827138b 100644
--- a/xarray/coding/common.py
+++ b/xarray/coding/common.py
@@ -75,6 +75,9 @@ def __getitem__(self, key):
def get_duck_array(self):
return self.func(self.array.get_duck_array())
+ async def async_get_duck_array(self):
+ return self.func(await self.array.async_get_duck_array())
+
def __repr__(self) -> str:
return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})"
diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py
index 4ca6a3f0a46..a2295c218a6 100644
--- a/xarray/coding/strings.py
+++ b/xarray/coding/strings.py
@@ -250,14 +250,17 @@ def __repr__(self):
return f"{type(self).__name__}({self.array!r})"
def _vindex_get(self, key):
- return _numpy_char_to_bytes(self.array.vindex[key])
+ return type(self)(self.array.vindex[key])
def _oindex_get(self, key):
- return _numpy_char_to_bytes(self.array.oindex[key])
+ return type(self)(self.array.oindex[key])
def __getitem__(self, key):
# require slicing the last dimension completely
key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
if key.tuple[-1] != slice(None):
raise IndexError("too many indices")
- return _numpy_char_to_bytes(self.array[key])
+ return type(self)(self.array[key])
+
+ def get_duck_array(self):
+ return _numpy_char_to_bytes(self.array.get_duck_array())
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index 1b7bc95e2b4..f82f0c65768 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -58,13 +58,16 @@ def dtype(self) -> np.dtype:
return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))
def _oindex_get(self, key):
- return np.asarray(self.array.oindex[key], dtype=self.dtype)
+ return type(self)(self.array.oindex[key])
def _vindex_get(self, key):
- return np.asarray(self.array.vindex[key], dtype=self.dtype)
+ return type(self)(self.array.vindex[key])
def __getitem__(self, key) -> np.ndarray:
- return np.asarray(self.array[key], dtype=self.dtype)
+ return type(self)(self.array[key])
+
+ def get_duck_array(self):
+ return duck_array_ops.astype(self.array.get_duck_array(), dtype=self.dtype)
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
@@ -96,13 +99,16 @@ def dtype(self) -> np.dtype:
return np.dtype("bool")
def _oindex_get(self, key):
- return np.asarray(self.array.oindex[key], dtype=self.dtype)
+ return type(self)(self.array.oindex[key])
def _vindex_get(self, key):
- return np.asarray(self.array.vindex[key], dtype=self.dtype)
+ return type(self)(self.array.vindex[key])
def __getitem__(self, key) -> np.ndarray:
- return np.asarray(self.array[key], dtype=self.dtype)
+ return type(self)(self.array[key])
+
+ def get_duck_array(self):
+ return duck_array_ops.astype(self.array.get_duck_array(), dtype=self.dtype)
def _apply_mask(
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 1e7e1069076..05f5d4c7fa8 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1160,6 +1160,14 @@ def load(self, **kwargs) -> Self:
self._coords = new._coords
return self
+ async def load_async(self, **kwargs) -> Self:
+ temp_ds = self._to_temp_dataset()
+ ds = await temp_ds.load_async(**kwargs)
+ new = self._from_temp_dataset(ds)
+ self._variable = new._variable
+ self._coords = new._coords
+ return self
+
def compute(self, **kwargs) -> Self:
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array.
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 5a7f757ba8a..8a4e7177caa 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import asyncio
import copy
import datetime
import math
@@ -531,24 +532,50 @@ def load(self, **kwargs) -> Self:
dask.compute
"""
# access .data to coerce everything to numpy or dask arrays
- lazy_data = {
+ chunked_data = {
k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
}
- if lazy_data:
- chunkmanager = get_chunked_array_type(*lazy_data.values())
+ if chunked_data:
+ chunkmanager = get_chunked_array_type(*chunked_data.values())
# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
- *lazy_data.values(), **kwargs
+ *chunked_data.values(), **kwargs
)
- for k, data in zip(lazy_data, evaluated_data, strict=False):
+ for k, data in zip(chunked_data, evaluated_data, strict=False):
self.variables[k].data = data
# load everything else sequentially
- for k, v in self.variables.items():
- if k not in lazy_data:
- v.load()
+ [v.load() for k, v in self.variables.items() if k not in chunked_data]
+
+ return self
+
+ async def load_async(self, **kwargs) -> Self:
+ # TODO refactor this to pull out the common chunked_data codepath
+
+ # this blocks on chunked arrays but not on lazily indexed arrays
+
+ # access .data to coerce everything to numpy or dask arrays
+ chunked_data = {
+ k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
+ }
+ if chunked_data:
+ chunkmanager = get_chunked_array_type(*chunked_data.values())
+
+ # evaluate all the chunked arrays simultaneously
+ evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
+ *chunked_data.values(), **kwargs
+ )
+
+ for k, data in zip(chunked_data, evaluated_data, strict=False):
+ self.variables[k].data = data
+
+ # load everything else concurrently
+ coros = [
+ v.load_async() for k, v in self.variables.items() if k not in chunked_data
+ ]
+ await asyncio.gather(*coros)
return self
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index c1b847202c7..824558010e1 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -516,13 +516,31 @@ def get_duck_array(self):
return self.array
-class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed):
- __slots__ = ()
+class IndexingAdapter:
+ """Marker class for indexing adapters.
+
+ These classes translate between Xarray's indexing semantics and the underlying array's
+ indexing semantics.
+ """
def get_duck_array(self):
key = BasicIndexer((slice(None),) * self.ndim)
return self[key]
+ async def async_get_duck_array(self):
+ """These classes are applied to in-memory arrays, so specific async support isn't needed."""
+ return self.get_duck_array()
+
+
+class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed):
+ __slots__ = ()
+
+ def get_duck_array(self):
+ raise NotImplementedError
+
+ async def async_get_duck_array(self):
+ raise NotImplementedError
+
def _oindex_get(self, indexer: OuterIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
@@ -646,19 +664,25 @@ def shape(self) -> _Shape:
return self._shape
def get_duck_array(self):
- if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
- array = apply_indexer(self.array, self.key)
- else:
- # If the array is not an ExplicitlyIndexedNDArrayMixin,
- # it may wrap a BackendArray so use its __getitem__
+ from xarray.backends.common import BackendArray
+
+ if isinstance(self.array, BackendArray):
array = self.array[self.key]
+ else:
+ array = apply_indexer(self.array, self.key)
+ if isinstance(array, ExplicitlyIndexed):
+ array = array.get_duck_array()
+ return _wrap_numpy_scalars(array)
- # self.array[self.key] is now a numpy array when
- # self.array is a BackendArray subclass
- # and self.key is BasicIndexer((slice(None, None, None),))
- # so we need the explicit check for ExplicitlyIndexed
- if isinstance(array, ExplicitlyIndexed):
- array = array.get_duck_array()
+ async def async_get_duck_array(self):
+ from xarray.backends.common import BackendArray
+
+ if isinstance(self.array, BackendArray):
+ array = await self.array.async_getitem(self.key)
+ else:
+ array = apply_indexer(self.array, self.key)
+ if isinstance(array, ExplicitlyIndexed):
+ array = await array.async_get_duck_array()
return _wrap_numpy_scalars(array)
def transpose(self, order):
@@ -722,18 +746,26 @@ def shape(self) -> _Shape:
return np.broadcast(*self.key.tuple).shape
def get_duck_array(self):
- if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
+ from xarray.backends.common import BackendArray
+
+ if isinstance(self.array, BackendArray):
+ array = self.array[self.key]
+ else:
array = apply_indexer(self.array, self.key)
+ if isinstance(array, ExplicitlyIndexed):
+ array = array.get_duck_array()
+ return _wrap_numpy_scalars(array)
+
+ async def async_get_duck_array(self):
+ print("inside LazilyVectorizedIndexedArray.async_get_duck_array")
+ from xarray.backends.common import BackendArray
+
+ if isinstance(self.array, BackendArray):
+ array = await self.array.async_getitem(self.key)
else:
- # If the array is not an ExplicitlyIndexedNDArrayMixin,
- # it may wrap a BackendArray so use its __getitem__
- array = self.array[self.key]
- # self.array[self.key] is now a numpy array when
- # self.array is a BackendArray subclass
- # and self.key is BasicIndexer((slice(None, None, None),))
- # so we need the explicit check for ExplicitlyIndexed
- if isinstance(array, ExplicitlyIndexed):
- array = array.get_duck_array()
+ array = apply_indexer(self.array, self.key)
+ if isinstance(array, ExplicitlyIndexed):
+ array = await array.async_get_duck_array()
return _wrap_numpy_scalars(array)
def _updated_key(self, new_key: ExplicitIndexer):
@@ -797,6 +829,9 @@ def _ensure_copied(self):
def get_duck_array(self):
return self.array.get_duck_array()
+ async def async_get_duck_array(self):
+ return await self.array.async_get_duck_array()
+
def _oindex_get(self, indexer: OuterIndexer):
return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer]))
@@ -837,12 +872,17 @@ class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin):
def __init__(self, array):
self.array = _wrap_numpy_scalars(as_indexable(array))
- def _ensure_cached(self):
- self.array = as_indexable(self.array.get_duck_array())
-
def get_duck_array(self):
- self._ensure_cached()
- return self.array.get_duck_array()
+ duck_array = self.array.get_duck_array()
+ # ensure the array object is cached in-memory
+ self.array = as_indexable(duck_array)
+ return duck_array
+
+ async def async_get_duck_array(self):
+ duck_array = await self.array.async_get_duck_array()
+ # ensure the array object is cached in-memory
+ self.array = as_indexable(duck_array)
+ return duck_array
def _oindex_get(self, indexer: OuterIndexer):
return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer]))
@@ -1027,6 +1067,21 @@ def explicit_indexing_adapter(
return result
+async def async_explicit_indexing_adapter(
+ key: ExplicitIndexer,
+ shape: _Shape,
+ indexing_support: IndexingSupport,
+ raw_indexing_method: Callable[..., Any],
+) -> Any:
+ raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support)
+ result = await raw_indexing_method(raw_key.tuple)
+ if numpy_indices.tuple:
+ # index the loaded duck array
+ indexable = as_indexable(result)
+ result = apply_indexer(indexable, numpy_indices)
+ return result
+
+
def apply_indexer(indexable, indexer: ExplicitIndexer):
"""Apply an indexer to an indexable object."""
if isinstance(indexer, VectorizedIndexer):
@@ -1526,7 +1581,7 @@ def is_fancy_indexer(indexer: Any) -> bool:
return True
-class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+class NumpyIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin):
"""Wrap a NumPy array to use explicit indexing."""
__slots__ = ("array",)
@@ -1605,7 +1660,7 @@ def __init__(self, array):
self.array = array
-class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+class ArrayApiIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""
__slots__ = ("array",)
@@ -1670,7 +1725,7 @@ def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None:
)
-class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+class DaskIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""
__slots__ = ("array",)
@@ -1746,7 +1801,7 @@ def transpose(self, order):
return self.array.transpose(order)
-class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+class PandasIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin):
"""Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""
__slots__ = ("_dtype", "array")
@@ -2063,7 +2118,9 @@ def copy(self, deep: bool = True) -> Self:
return type(self)(array, self._dtype, self.level)
-class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
+class CoordinateTransformIndexingAdapter(
+ IndexingAdapter, ExplicitlyIndexedNDArrayMixin
+):
"""Wrap a CoordinateTransform as a lazy coordinate array.
Supports explicit indexing (both outer and vectorized).
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 4e58b0d4b20..38f2676ec52 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -47,6 +47,7 @@
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import (
+ async_to_duck_array,
integer_types,
is_0d_dask_array,
is_chunked_array,
@@ -957,6 +958,10 @@ def load(self, **kwargs):
self._data = to_duck_array(self._data, **kwargs)
return self
+ async def load_async(self, **kwargs):
+ self._data = await async_to_duck_array(self._data, **kwargs)
+ return self
+
def compute(self, **kwargs):
"""Manually trigger loading of this variable's data from disk or a
remote source into memory and return a new variable. The original is
diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py
index 68b6a7853bf..6e61d3445ab 100644
--- a/xarray/namedarray/pycompat.py
+++ b/xarray/namedarray/pycompat.py
@@ -145,3 +145,23 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType,
return data
else:
return np.asarray(data) # type: ignore[return-value]
+
+
+async def async_to_duck_array(
+ data: Any, **kwargs: dict[str, Any]
+) -> duckarray[_ShapeType, _DType]:
+ from xarray.core.indexing import (
+ ExplicitlyIndexed,
+ ImplicitToExplicitIndexingAdapter,
+ IndexingAdapter,
+ )
+
+ print(type(data))
+ if isinstance(data, IndexingAdapter):
+ # These wrap in-memory arrays, and async isn't needed
+ return data.get_duck_array()
+ elif isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter):
+ print("async inside to_duck_array")
+ return await data.async_get_duck_array() # type: ignore[no-untyped-call, no-any-return]
+ else:
+ return to_duck_array(data, **kwargs)
diff --git a/xarray/tests/test_async.py b/xarray/tests/test_async.py
new file mode 100644
index 00000000000..918a2508ea0
--- /dev/null
+++ b/xarray/tests/test_async.py
@@ -0,0 +1,221 @@
+import asyncio
+import time
+from collections.abc import Iterable
+from contextlib import asynccontextmanager
+from typing import TypeVar
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+import xarray as xr
+import xarray.testing as xrt
+from xarray.tests import has_zarr_v3, requires_zarr_v3
+
+if has_zarr_v3:
+ import zarr
+ from zarr.abc.store import ByteRequest, Store
+ from zarr.core.buffer import Buffer, BufferPrototype
+ from zarr.storage import MemoryStore
+ from zarr.storage._wrapper import WrapperStore
+
+ T_Store = TypeVar("T_Store", bound=Store)
+
+ class LatencyStore(WrapperStore[T_Store]):
+ """Works the same way as the zarr LoggingStore"""
+
+ latency: float
+
+ # TODO only have to add this because of dumb behaviour in zarr where it raises with "ValueError: Store is not read-only but mode is 'r'"
+ read_only = True
+
+ def __init__(
+ self,
+ store: T_Store,
+ latency: float = 0.0,
+ ) -> None:
+ """
+ Store wrapper that adds artificial latency to each get call.
+
+ Parameters
+ ----------
+ store : Store
+ Store to wrap
+ latency : float
+ Amount of artificial latency to add to each get call, in seconds.
+ """
+ super().__init__(store)
+ self.latency = latency
+
+ def __str__(self) -> str:
+ return f"latency-{self._store}"
+
+ def __repr__(self) -> str:
+ return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"
+
+ async def get(
+ self,
+ key: str,
+ prototype: BufferPrototype,
+ byte_range: ByteRequest | None = None,
+ ) -> Buffer | None:
+ await asyncio.sleep(self.latency)
+ return await self._store.get(
+ key=key, prototype=prototype, byte_range=byte_range
+ )
+
+ async def get_partial_values(
+ self,
+ prototype: BufferPrototype,
+ key_ranges: Iterable[tuple[str, ByteRequest | None]],
+ ) -> list[Buffer | None]:
+ await asyncio.sleep(self.latency)
+ return await self._store.get_partial_values(
+ prototype=prototype, key_ranges=key_ranges
+ )
+else:
+ LatencyStore = {}
+
+
+@pytest.fixture
+def memorystore() -> "MemoryStore":
+ memorystore = zarr.storage.MemoryStore({})
+ z1 = zarr.create_array(
+ store=memorystore,
+ name="foo",
+ shape=(10, 10),
+ chunks=(5, 5),
+ dtype="f4",
+ dimension_names=["x", "y"],
+ attributes={"add_offset": 1, "scale_factor": 2},
+ )
+ z1[:, :] = np.random.random((10, 10))
+
+ z2 = zarr.create_array(
+ store=memorystore,
+ name="x",
+ shape=(10,),
+ chunks=(5),
+ dtype="f4",
+ dimension_names=["x"],
+ )
+ z2[:] = np.arange(10)
+
+ return memorystore
+
+
+class AsyncTimer:
+ """Context manager for timing async operations and making assertions about their execution time."""
+
+ start_time: float
+ end_time: float
+ total_time: float
+
+ @asynccontextmanager
+ async def measure(self):
+ """Measure the execution time of the async code within this context."""
+ self.start_time = time.time()
+ try:
+ yield self
+ finally:
+ self.end_time = time.time()
+ self.total_time = self.end_time - self.start_time
+
+
+@requires_zarr_v3
+@pytest.mark.asyncio
+class TestAsyncLoad:
+ LATENCY: float = 1.0
+
+ @pytest.fixture(params=["var", "ds", "da"])
+ def xr_obj(self, request, memorystore) -> xr.Dataset | xr.DataArray | xr.Variable:
+ latencystore = LatencyStore(memorystore, latency=self.LATENCY)
+ ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
+
+ match request.param:
+ case "var":
+ return ds["foo"].variable
+ case "da":
+ return ds["foo"]
+ case "ds":
+ return ds
+
+ def assert_time_as_expected(
+ self, total_time: float, latency: float, n_loads: int
+ ) -> None:
+ assert total_time > latency # Cannot possibly be quicker than this
+ assert (
+ total_time < latency * n_loads
+ ) # If this isn't true we're gaining nothing from async
+ assert (
+ abs(total_time - latency) < 2.0
+ ) # Should take approximately `latency` seconds, but allow some buffer
+
+ async def test_concurrent_load_multiple_variables(self, memorystore) -> None:
+ latencystore = LatencyStore(memorystore, latency=self.LATENCY)
+ ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)
+
+ # TODO up the number of variables in the dataset?
+ async with AsyncTimer().measure() as timer:
+ result_ds = await ds.load_async()
+
+ xrt.assert_identical(result_ds, ds.load())
+
+ # 2 because there are 2 lazy variables in the dataset
+ self.assert_time_as_expected(
+ total_time=timer.total_time, latency=self.LATENCY, n_loads=2
+ )
+
+ async def test_concurrent_load_multiple_objects(self, xr_obj) -> None:
+ N_OBJECTS = 5
+
+ async with AsyncTimer().measure() as timer:
+ coros = [xr_obj.load_async() for _ in range(N_OBJECTS)]
+ results = await asyncio.gather(*coros)
+
+ for result in results:
+ xrt.assert_identical(result, xr_obj.load())
+
+ self.assert_time_as_expected(
+ total_time=timer.total_time, latency=self.LATENCY, n_loads=N_OBJECTS
+ )
+
+ @pytest.mark.parametrize("method", ["sel", "isel"])
+ @pytest.mark.parametrize(
+ "indexer, zarr_getitem_method",
+ [
+ ({"x": 2}, zarr.AsyncArray.getitem),
+ ({"x": slice(2, 4)}, zarr.AsyncArray.getitem),
+ ({"x": [2, 3]}, zarr.core.indexing.AsyncOIndex.getitem),
+ (
+ {
+ "x": xr.DataArray([2, 3], dims="points"),
+ "y": xr.DataArray([2, 3], dims="points"),
+ },
+ zarr.core.indexing.AsyncVIndex.getitem,
+ ),
+ ],
+ ids=["basic-int", "basic-slice", "outer", "vectorized"],
+ )
+ async def test_indexing(
+ self, memorystore, method, indexer, zarr_getitem_method
+ ) -> None:
+ # TODO we don't need a LatencyStore for this test
+ latencystore = LatencyStore(memorystore, latency=0.0)
+
+ with patch.object(
+ zarr.AsyncArray, "getitem", wraps=zarr_getitem_method, autospec=True
+ ) as mocked_meth:
+ ds = xr.open_zarr(
+ latencystore, zarr_format=3, consolidated=False, chunks=None
+ )
+
+ # TODO we're not actually testing that these indexing methods are not blocking...
+ result = await getattr(ds, method)(**indexer).load_async()
+
+ assert mocked_meth.call_count > 0
+ mocked_meth.assert_called()
+ mocked_meth.assert_awaited()
+
+ expected = getattr(ds, method)(**indexer).load()
+ xrt.assert_identical(result, expected)
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 6dd75b58c6a..d308844c6fa 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -490,6 +490,23 @@ def test_sub_array(self) -> None:
assert isinstance(child.array, indexing.NumpyIndexingAdapter)
assert isinstance(wrapped.array, indexing.LazilyIndexedArray)
+ async def test_async_wrapper(self) -> None:
+ original = indexing.LazilyIndexedArray(np.arange(10))
+ wrapped = indexing.MemoryCachedArray(original)
+ await wrapped.async_get_duck_array()
+ assert_array_equal(wrapped, np.arange(10))
+ assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter)
+
+ async def test_async_sub_array(self) -> None:
+ original = indexing.LazilyIndexedArray(np.arange(10))
+ wrapped = indexing.MemoryCachedArray(original)
+ child = wrapped[B[:5]]
+ assert isinstance(child, indexing.MemoryCachedArray)
+ await child.async_get_duck_array()
+ assert_array_equal(child, np.arange(5))
+ assert isinstance(child.array, indexing.NumpyIndexingAdapter)
+ assert isinstance(wrapped.array, indexing.LazilyIndexedArray)
+
def test_setitem(self) -> None:
original = np.arange(10)
wrapped = indexing.MemoryCachedArray(original)