Add asynchronous load method #10327
Conversation
for more information, see https://pre-commit.ci
These failing tests from the CI do not fail when I run them locally, which is interesting.
FAILED xarray/tests/test_backends.py::TestH5NetCDFViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestNetCDF4ViaDaskData::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
FAILED xarray/tests/test_backends.py::TestDask::test_outer_indexing_reversed - ValueError: dimensions ('t', 'y', 'x') must have the same length as the number of data dimensions, ndim=4
= 3 failed, 18235 passed, 1269 skipped, 77 xfailed, 15 xpassed, 2555 warnings in 487.15s (0:08:07) =
Error: Process completed with exit code 1.
@@ -267,13 +268,23 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
         time.sleep(1e-3 * next_delay)


-class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
+class BackendArray(ABC, NdimSizeLenMixin, indexing.ExplicitlyIndexed):
As __getitem__ is required, I feel like BackendArray should always have been an ABC.
This class is public API and this is a backwards incompatible change.
async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
    raise NotImplementedError("Backend does not support asynchronous loading")
I've implemented this for the ZarrArray class, but in theory it could be supported by other backends too.
This might not be the desired behaviour though - this currently means that if you opened a dataset from netCDF and called ds.load_async, you would get a NotImplementedError. Would it be better to quietly just block instead?
Yes absolutely.
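One way to "quietly block" by default, sketched below with asyncio.to_thread: the base class runs its own synchronous __getitem__ in a worker thread, so every backend gets a working coroutine even without native async support. This is a minimal stand-in class, not xarray's actual BackendArray.

```python
import asyncio

import numpy as np


class BackendArray:
    """Hypothetical stand-in for xarray's BackendArray."""

    def __getitem__(self, key):
        # Synchronous path that every backend already implements.
        return np.arange(10)[key]

    async def async_getitem(self, key):
        # Default fallback: run the blocking __getitem__ in a worker
        # thread instead of raising NotImplementedError, so callers
        # can always await the result.
        return await asyncio.to_thread(self.__getitem__, key)


result = asyncio.run(BackendArray().async_getitem(slice(0, 3)))
print(result)  # [0 1 2]
```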
# load everything else concurrently
coros = [
    v.load_async() for k, v in self.variables.items() if k not in chunked_data
]
await asyncio.gather(*coros)
We could actually do this same thing inside of the synchronous ds.load() too, but it would require:
- Xarray to decide how to call the async code, e.g. with a ThreadPool or similar (see Support concurrent loading of variables #8965)
- The backend to support async_getitem (it could fall back to synchronous loading if it's not supported)
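A rough sketch of how a synchronous entry point could drive the async code (all names here are hypothetical, not xarray's): use asyncio.run when no event loop is running, and hand the coroutine to a dedicated thread when one already is.

```python
import asyncio
import concurrent.futures


async def _async_load_all(names):
    # Stand-in for loading several variables concurrently;
    # the sleep plays the role of an awaited async_getitem call.
    async def fetch(name):
        await asyncio.sleep(0.01)
        return name, len(name)

    return dict(await asyncio.gather(*(fetch(n) for n in names)))


def load_all(names):
    # Synchronous wrapper around the async implementation.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: safe to use asyncio.run.
        return asyncio.run(_async_load_all(names))
    # Called from inside a running loop: run the coroutine to
    # completion on a separate thread with its own event loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _async_load_all(names)).result()


loaded = load_all(["temp", "salinity"])
print(loaded)  # {'temp': 4, 'salinity': 8}
```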
We should rate-limit all gather calls with a Semaphore, using something like this:
async def async_gather(*coros, concurrency: Optional[int] = None, return_exceptions: bool = False) -> list[Any]:
    """Execute a gather while limiting the number of concurrent tasks.

    Args:
        coros: list of coroutines to execute
        concurrency: concurrency limit;
            if None, defaults to config_obj.get('async.concurrency', 4);
            if <= 0, no concurrency limit
    """
    if concurrency is None:
        concurrency = int(config_obj.get("async.concurrency", 4))

    if concurrency > 0:
        # use a semaphore to limit the number of concurrent coroutines
        semaphore = asyncio.Semaphore(concurrency)

        async def sem_coro(coro):
            async with semaphore:
                return await coro

        results = await asyncio.gather(*(sem_coro(c) for c in coros), return_exceptions=return_exceptions)
    else:
        results = await asyncio.gather(*coros, return_exceptions=return_exceptions)
    return results
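A self-contained version of that semaphore pattern (dropping the config_obj lookup, which is specific to the reviewer's codebase) can be checked by recording the peak number of tasks running at once:

```python
import asyncio


async def async_gather(*coros, concurrency=4):
    # Limit how many of the coroutines may run concurrently.
    semaphore = asyncio.Semaphore(concurrency)

    async def sem_coro(coro):
        async with semaphore:
            return await coro

    # gather preserves the order of its arguments in the result list.
    return await asyncio.gather(*(sem_coro(c) for c in coros))


running = 0
peak = 0


async def task(i):
    # Track how many tasks are in flight simultaneously.
    global running, peak
    running += 1
    peak = max(peak, running)
    await asyncio.sleep(0.01)
    running -= 1
    return i


results = asyncio.run(async_gather(*(task(i) for i in range(10)), concurrency=3))
print(results, peak)
```

With concurrency=3 the semaphore caps peak at 3 while the results still come back in submission order.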
There is something funky going on when using the script below:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "arraylake",
#     "yappi",
#     "zarr==3.0.8",
#     "xarray",
#     "icechunk"
# ]
#
# [tool.uv.sources]
# xarray = { git = "https://github.com/TomNicholas/xarray", rev = "async.load" }
# ///
import asyncio
from collections.abc import Iterable
from typing import TypeVar

import numpy as np
import xarray as xr
import zarr
from zarr.abc.store import ByteRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.storage._wrapper import WrapperStore

T_Store = TypeVar("T_Store", bound=Store)


class LatencyStore(WrapperStore[T_Store]):
    """Works the same way as the zarr LoggingStore"""

    latency: float

    def __init__(
        self,
        store: T_Store,
        latency: float = 0.0,
    ) -> None:
        """
        Store wrapper that adds artificial latency to each get call.

        Parameters
        ----------
        store : Store
            Store to wrap
        latency : float
            Amount of artificial latency to add to each get call, in seconds.
        """
        super().__init__(store)
        self.latency = latency

    def __str__(self) -> str:
        return f"latency-{self._store}"

    def __repr__(self) -> str:
        return f"LatencyStore({self._store.__class__.__name__}, '{self._store}', latency={self.latency})"

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        await asyncio.sleep(self.latency)
        return await self._store.get(
            key=key, prototype=prototype, byte_range=byte_range
        )

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        await asyncio.sleep(self.latency)
        return await self._store.get_partial_values(
            prototype=prototype, key_ranges=key_ranges
        )


memorystore = zarr.storage.MemoryStore({})

shape = 5
X = np.arange(5) * 10
ds = xr.Dataset(
    {
        "data": xr.DataArray(
            np.zeros(shape),
            coords={"x": X},
        )
    }
)
ds.to_zarr(memorystore)

latencystore = LatencyStore(memorystore, latency=0.1)
ds = xr.open_zarr(latencystore, zarr_format=3, consolidated=False, chunks=None)

# no problem for any of these
asyncio.run(ds["data"][0].load_async())
asyncio.run(ds["data"].sel(x=10).load_async())
asyncio.run(ds["data"].sel(x=11, method="nearest").load_async())

# also fine
ds["data"].sel(x=[30, 40]).load()

# broken!
asyncio.run(ds["data"].sel(x=[30, 40]).load_async())
xarray/backends/zarr.py
elif isinstance(key, indexing.VectorizedIndexer):
    # TODO
    method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
    # TODO
    method = self._oindex
@ianhi almost certainly these need to become async to fix your bug
Outer (also known as "orthogonal") indexing support was added in 5eacdb0, but it requires changes to zarr-python: zarr-developers/zarr-python#3083
# test vectorized indexing
# TODO this shouldn't pass! I haven't implemented async vectorized indexing yet...
indexer = xr.DataArray([2, 3], dims=["x"])
result = await ds.foo[indexer].load_async()
xrt.assert_identical(result, ds.foo[indexer].load())
This currently passes, even though it shouldn't, because I haven't added support for async vectorized indexing yet!
I think this means that my test is wrong, and what I'm doing here is apparently not vectorized indexing. I'm unsure what my test would have to look like though 😕
This is an outer indexer. Try xr.DataArray([[2, 3]], dims=["y", "x"]) instead.
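The distinction the reviewer is drawing can be seen in plain NumPy: a single 1-D index array per axis is outer (orthogonal) indexing, selecting a cross product of positions, while index arrays that broadcast against each other are vectorized (fancy) indexing, selecting pointwise elements.

```python
import numpy as np

arr = np.arange(12).reshape(3, 4)

# Outer ("orthogonal") indexing: each index array applies to its own
# axis, so we get the cross product of rows [0, 2] and columns [1, 3].
outer = arr[np.ix_([0, 2], [1, 3])]
print(outer)  # [[ 1  3]
              #  [ 9 11]]

# Vectorized ("fancy") indexing: the index arrays broadcast together,
# picking elements (0, 1) and (2, 3) pointwise.
vectorized = arr[[0, 2], [1, 3]]
print(vectorized)  # [ 1 11]
```

A 1-D xr.DataArray indexer along an existing dimension decays to the outer case, which is why the test above did not actually exercise vectorized indexing.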
        case "ds":
            return ds


def assert_time_as_expected(
Let's instead use mocks to assert that the async methods were called. Xarray's job is only to call them.
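A minimal sketch of that mock-based style using unittest.mock.AsyncMock (the load_async wrapper and its argument are hypothetical, standing in for the real call path):

```python
import asyncio
from unittest.mock import AsyncMock

# Stand-in for the backend object whose async hook xarray should call.
backend_array = AsyncMock()
backend_array.async_getitem.return_value = [1, 2, 3]


async def load_async(arr):
    # Code under test: the only contract is that the async hook
    # gets awaited with the right indexer.
    return await arr.async_getitem((slice(None),))


result = asyncio.run(load_async(backend_array))

# Assert the interaction rather than timing behaviour.
backend_array.async_getitem.assert_awaited_once_with((slice(None),))
print(result)  # [1, 2, 3]
```

This keeps the test about xarray's responsibility (dispatching to the backend) instead of measuring wall-clock latency.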
    async def _async_ensure_cached(self):
        duck_array = await self.array.async_get_duck_array()
        self.array = as_indexable(duck_array)

    def get_duck_array(self):
        self._ensure_cached()
_ensure_cached seems like pointless indirection; it is only used once. Let's consolidate.
        return self

    async def load_async(self, **kwargs) -> Self:
        # TODO refactor this to pull out the common chunked_data codepath
Let's instead just have the sync methods issue a blocking call to the async versions.
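That suggestion boils down to the pattern below: only the async method holds the real logic, and the sync method is a thin blocking wrapper. This is a hypothetical sketch, not xarray's Variable; note that asyncio.run would fail if load() were called from inside an already-running event loop, which is the main caveat of this approach.

```python
import asyncio


class Variable:
    """Hypothetical sketch: load() just blocks on load_async()."""

    async def load_async(self):
        # The single implementation; the sleep stands in for
        # awaiting async_getitem on the backend.
        await asyncio.sleep(0)
        self.data = [1, 2, 3]
        return self

    def load(self):
        # Sync entry point: run the coroutine to completion.
        # (Fails inside a running loop; a real implementation would
        # need to detect that case and dispatch to another thread.)
        return asyncio.run(self.load_async())


v = Variable().load()
print(v.data)  # [1, 2, 3]
```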
Adds an .async_load() method to Variable, which works by plumbing async get_duck_array all the way down until it finally gets to the async methods zarr v3 exposes. Needs a lot of refactoring before it could be merged, but it works.
- whats-new.rst
- api.rst
- API:
  - Variable.load_async
  - DataArray.load_async
  - Dataset.load_async
  - DataTree.load_async
  - load_dataset?
  - load_dataarray?