Add asynchronous load method #10327


Draft · wants to merge 50 commits into main

Conversation

@TomNicholas (Member) commented May 16, 2025

Adds a .load_async() method to Variable, which works by plumbing an async version of get_duck_array all the way down until it finally reaches the async methods zarr v3 exposes.
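Roughly, the call chain this adds (names as they appear in the traceback later in this thread):

Dataset.load_async
  -> Variable.load_async
    -> async_to_duck_array (namedarray/pycompat.py)
      -> async_get_duck_array (the lazy indexing layers in core/indexing.py)
        -> async_getitem (backends/zarr.py)
          -> zarr's async array API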

Needs a lot of refactoring before it could be merged, but it works.

API:

  • Variable.load_async
  • DataArray.load_async
  • Dataset.load_async
  • DataTree.load_async
  • load_dataset?
  • load_dataarray?
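A minimal usage sketch of the proposed API (the store path is a placeholder; chunks=None opens lazy, non-dask arrays so there is something to load):

import asyncio

import xarray as xr

async def main() -> None:
    # open lazily, then fetch all variables' data concurrently
    ds = xr.open_zarr("path/to/store.zarr", chunks=None)
    ds = await ds.load_async()

asyncio.run(main())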

@TomNicholas (Member, Author) commented May 19, 2025

These tests fail in CI but not when I run them locally, which is interesting.

FAILED xarray/tests/test_backends.py::TestH5NetCDFViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestNetCDF4ViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestDask::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
= 3 failed, 18235 passed, 1269 skipped, 77 xfailed, 15 xpassed, 2555 warnings in 487.15s (0:08:07) =
Error: Process completed with exit code 1.

@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
time.sleep(1e-3 * next_delay)


-class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
+class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed):
Member Author

As __getitem__ is required, I feel like BackendArray should always have been an ABC.

Contributor

This class is public API, so this is a backwards-incompatible change.

Comment on lines +277 to +278
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
    raise NotImplementedError("Backend does not support asynchronous loading")
Member Author

I've implemented this for the ZarrArray class but in theory it could be supported by other backends too.

Member Author

This might not be the desired behaviour, though: it currently means that if you opened a dataset from netCDF and called ds.load_async you would get a NotImplementedError. Would it be better to just quietly block instead?

Contributor

Yes absolutely.
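One way to "quietly block", as a hedged sketch (the real BackendArray also mixes in NdimSizeLenMixin and indexing.ExplicitlyIndexed, elided here):

import asyncio

class BackendArray:  # sketch only
    def __getitem__(self, key):
        raise NotImplementedError

    async def async_getitem(self, key):
        # Fallback for backends without native async support: run the
        # blocking __getitem__ in a worker thread so awaiting callers
        # don't stall the event loop.
        return await asyncio.to_thread(self.__getitem__, key)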

Comment on lines +574 to +578
# load everything else concurrently
coros = [
    v.load_async() for k, v in self.variables.items() if k not in chunked_data
]
await asyncio.gather(*coros)
Member Author

We could actually do this same thing inside the synchronous ds.load() too, but it would require:

  1. Xarray to decide how to call the async code, e.g. with a ThreadPool or similar (see Support concurrent loading of variables #8965)
  2. The backend to support async_getitem (it could fall back to synchronous loading if it's not supported)
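A minimal sketch of point 1, assuming a dedicated background event-loop thread (the _LoopRunner name is hypothetical, not part of this PR):

import asyncio
import threading

class _LoopRunner:
    """Run coroutines from synchronous code on a dedicated event-loop thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        thread.start()

    def run(self, coro):
        # Schedule onto the background loop and block until the result is ready.
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

This sidesteps the "asyncio.run() cannot be called from a running event loop" problem that a bare asyncio.run() would hit if users called ds.load() from inside async code.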

Contributor

We should rate-limit all gather calls with a Semaphore, using something like this:

import asyncio
from typing import Any, Optional

async def async_gather(
    *coros, concurrency: Optional[int] = None, return_exceptions: bool = False
) -> list[Any]:
    """Execute a gather while limiting the number of concurrent tasks.

    Args:
        coros: coroutines
            list of coroutines to execute
        concurrency: int
            concurrency limit
            if None, defaults to config_obj.get('async.concurrency', 4)
            if <= 0, no concurrency limit

    """
    if concurrency is None:
        # config_obj is an external configuration object, not defined here
        concurrency = int(config_obj.get("async.concurrency", 4))

    if concurrency > 0:
        # use a semaphore to limit the number of concurrent coroutines
        semaphore = asyncio.Semaphore(concurrency)

        async def sem_coro(coro):
            async with semaphore:
                return await coro

        results = await asyncio.gather(
            *(sem_coro(c) for c in coros), return_exceptions=return_exceptions
        )
    else:
        results = await asyncio.gather(*coros, return_exceptions=return_exceptions)

    return results
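Hypothetical usage at the gather call site above:

await async_gather(*coros, concurrency=4)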

@ianhi (Contributor) commented May 22, 2025

There is something funky going on when using .sel:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "arraylake",
#     "yappi",
#     "zarr==3.0.8",
#     "xarray",
#     "icechunk"
# ]
#
# [tool.uv.sources]
# xarray = { git = "https://github.com/TomNicholas/xarray", rev = "async.load" }
# ///

import asyncio
from collections.abc import Iterable
from typing import TypeVar

import numpy as np

import xarray as xr

import zarr
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.storage._wrapper import WrapperStore

T_Store = TypeVar("T_Store", bound=Store)


class LatencyStore(WrapperStore[T_Store]):
    """Works the same way as the zarr LoggingStore"""

    latency: float

    def __init__(
        self,
        store: T_Store,
        latency: float = 0.0,
    ) -> None:
        """
        Store wrapper that adds artificial latency to each get call.

        Parameters
        ----------
        store : Store
            Store to wrap
        latency : float
            Amount of artificial latency to add to each get call, in seconds.
        """
        super().__init__(store)
        self.latency = latency

    def __str__(self) -> str:
        return f"latency-{self._store}"

    def __repr__(self) -> str:
        return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        await asyncio.sleep(self.latency)
        return await self._store.get(
            key=key, prototype=prototype, byte_range=byte_range
        )

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        await asyncio.sleep(self.latency)
        return await self._store.get_partial_values(
            prototype=prototype, key_ranges=key_ranges
        )


memorystore = zarr.storage.MemoryStore({})

shape = 5
X = np.arange(5) * 10
ds = xr.Dataset(
    {
        "data": xr.DataArray(
            np.zeros(shape),
            coords={"x": X},
        )
    }
)

ds.to_zarr(memorystore)


latencystore = LatencyStore(memorystore, latency=0.1)
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)

# no problem for any of these
asyncio.run(ds["data"][0].load_async())
asyncio.run(ds["data"].sel(x=10).load_async())
asyncio.run(ds["data"].sel(x=11, method="nearest").load_async())

# also fine
ds["data"].sel(x=[30, 40]).load()

# broken!
asyncio.run(ds["data"].sel(x=[30, 40]).load_async())

Running that script with uv run gives:

Traceback (most recent call last):
  File "/Users/ian/tmp/async_error.py", line 109, in <module>
    asyncio.run(ds["data"].sel(x=[30, 40]).load_async())
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/miniforge3/envs/test/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataarray.py", line 1165, in load_async
    ds = await temp_ds.load_async(**kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/dataset.py", line 578, in load_async
    await asyncio.gather(*coros)
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/variable.py", line 963, in load_async
    self._data = await async_to_duck_array(self._data, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/namedarray/pycompat.py", line 168, in async_to_duck_array
    return await data.async_get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 875, in async_get_duck_array
    await self._async_ensure_cached()
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 867, in _async_ensure_cached
    duck_array = await self.array.async_get_duck_array()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 821, in async_get_duck_array
    return await self.array.async_get_duck_array()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 674, in async_get_duck_array
    array = await self.array.async_getitem(self.key)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/backends/zarr.py", line 248, in async_getitem
    return await indexing.async_explicit_indexing_adapter(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ian/.cache/uv/environments-v2/async-error-29817fa21dae3c0f/lib/python3.12/site-packages/xarray/core/indexing.py", line 1068, in async_explicit_indexing_adapter
    result = await raw_indexing_method(raw_key.tuple)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: object numpy.ndarray can't be used in 'await' expression

Comment on lines 240 to 245
elif isinstance(key, indexing.VectorizedIndexer):
    # TODO
    method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
    # TODO
    method = self._oindex
Member Author

@ianhi almost certainly these need to become async to fix your bug

Member Author

Outer (also known as "orthogonal") indexing support was added in 5eacdb0, but it requires changes to zarr-python: zarr-developers/zarr-python#3083

Comment on lines +192 to +196
# test vectorized indexing
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet...
indexer = xr.DataArray([2, 3], dims=["x"])
result = await ds.foo[indexer].load_async()
xrt.assert_identical(result, ds.foo[indexer].load())
Member Author

This currently passes, even though it shouldn't, because I haven't added support for async vectorized indexing yet!

I think this means that my test is wrong, and what I'm doing here is apparently not vectorized indexing. I'm unsure what my test would have to look like though 😕

Contributor

This is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"]).
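For context, a hedged illustration of the distinction (a 1-D indexer over a single existing dimension reduces to outer/orthogonal indexing; indexers that broadcast along a shared dimension trigger vectorized indexing):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(16).reshape(4, 4), dims=["x", "y"])

# Outer (orthogonal): the 1-D indexer selects along "x" only,
# leaving "y" intact -> result shape (2, 4).
outer = da[xr.DataArray([2, 3], dims=["x"])]

# Vectorized: two indexers broadcast along a shared new dimension,
# picking elements pointwise -> result shape (2,).
ix = xr.DataArray([2, 3], dims=["points"])
iy = xr.DataArray([0, 1], dims=["points"])
vectorized = da[ix, iy]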




case "ds":
return ds

def assert_time_as_expected(
Contributor

Let's instead use mocks to assert that the async methods were called; Xarray's job is only to do that.
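A hedged sketch of such a test (assumes a zarr-backed ds fixture and an async-capable test runner; the patch target follows the zarr backend path visible in the traceback above):

from unittest import mock

from xarray.backends.zarr import ZarrArrayWrapper

async def test_load_async_uses_async_getitem(ds) -> None:
    original = ZarrArrayWrapper.async_getitem

    async def spy(self, key):
        # Delegate to the real implementation while recording the call,
        # so we assert the async path was taken without timing real I/O.
        spy.called = True
        return await original(self, key)

    spy.called = False
    with mock.patch.object(ZarrArrayWrapper, "async_getitem", spy):
        await ds.load_async()
    assert spy.called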

Comment on lines +192 to +196
# test vectorized indexing
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet...
indexer = xr.DataArray([2, 3], dims=["x"])
result = await ds.foo[indexer].load_async()
xrt.assert_identical(result, ds.foo[indexer].load())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"])

async def _async_ensure_cached(self):
    duck_array = await self.array.async_get_duck_array()
    self.array = as_indexable(duck_array)

def get_duck_array(self):
    self._ensure_cached()
Contributor

_ensure_cached seems like pointless indirection; it is only used once. Let's consolidate.
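A sketch of the consolidation, assuming the synchronous _ensure_cached mirrors the async version above:

def get_duck_array(self):
    # former _ensure_cached, inlined: materialize the lazy array once,
    # cache it, and return it.
    duck_array = self.array.get_duck_array()
    self.array = as_indexable(duck_array)
    return duck_array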

    return self

async def load_async(self, **kwargs) -> Self:
    # TODO refactor this to pull out the common chunked_data codepath
Contributor

Let's instead just have the sync methods issue a blocking call to the async versions.
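A minimal sketch of that delegation (hypothetical; note that a bare asyncio.run() fails if an event loop is already running, which is where something like the loop-runner sketched earlier in this thread comes in):

def load(self, **kwargs) -> Self:
    # The sync method just blocks on the async implementation.
    return asyncio.run(self.load_async(**kwargs))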

Labels: CI, dependencies, enhancement, io, topic-backends, topic-documentation, topic-indexing, topic-NamedArray, topic-zarr

Successfully merging this pull request may close these issues: Add an asynchronous load method?

3 participants