Implement new dask graph mode creation
There are now two ways to create the dask graph:
- "mem" and "concurrency"
- "auto" mode currently always default to "mem"
  mode, unless explicitly set to "concurrency"
  via environment variable
- "mem" is the original way, optimized for
  reducing Dask graph size
- "concurrency" is the new way, optimized for
  reducing Dask graph execution time
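
A minimal sketch of the mode selection, assuming the helper stays private in `odc/loader/_builder.py` as in this diff (the import below is for illustration only; the `ODC_STAC_DASK_MODE` name is taken from the diff):

```python
import os

from odc.loader._builder import _default_dask_mode

# "auto" resolves to "mem" unless the env var opts in to "concurrency".
os.environ.pop("ODC_STAC_DASK_MODE", None)
assert _default_dask_mode() == "mem"

os.environ["ODC_STAC_DASK_MODE"] = "concurrency"
assert _default_dask_mode() == "concurrency"
```

With `mode="auto"` (the default), `DaskGraphBuilder.__init__` calls this helper, so the env var only matters when the caller does not pass `mode` explicitly.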

Other changes:

- Manually handle layer creation for the high-level
   graph: `cfg->open->band` (see the sketch below)
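
A hedged sketch of that layer structure, mirroring the `layers`/`layer_deps` dictionaries built in the new `__call__` (the band name, token, and task payloads are placeholders, not the real task tuples):

```python
from dask.highlevelgraph import HighLevelGraph

name, tk = "red", "token0"  # placeholder band name and dask token

# One cfg layer, one open layer, and one band layer per band:
layers = {
    f"cfg-{name}-{tk}": {f"cfg-{tk}": "cfg"},
    f"open-{name}-{tk}": {(f"open-{name}-{tk}", 0): "opened source"},
    f"{name}-{tk}": {(f"{name}-{tk}", 0, 0, 0): "loaded chunk"},
}
layer_deps = {
    f"cfg-{name}-{tk}": set(),
    f"open-{name}-{tk}": {f"cfg-{name}-{tk}"},
    f"{name}-{tk}": {f"cfg-{name}-{tk}", f"open-{name}-{tk}"},
}
dsk = HighLevelGraph(layers, layer_deps)
assert dsk.dependencies[f"{name}-{tk}"] == {f"cfg-{name}-{tk}", f"open-{name}-{tk}"}
```

Building the named layers by hand with explicit dependencies is what replaces the earlier single `HighLevelGraph.from_collections(...)` call.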
Kirill888 committed Jun 10, 2024
1 parent 841a3a0 commit e40f33a
Showing 5 changed files with 189 additions and 42 deletions.
196 changes: 160 additions & 36 deletions odc/loader/_builder.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import dataclasses
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import (
@@ -33,9 +34,15 @@
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
from odc.geo.xr import xr_coords

from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata
from ._reader import (
ReaderDaskAdaptor,
nodata_mask,
resolve_dst_fill_value,
resolve_src_nodata,
)
from ._utils import SizedIterable, pmap
from .types import (
DaskRasterReader,
MultiBandRasterSource,
RasterGroupMetadata,
RasterLoadParams,
@@ -45,6 +52,8 @@
T,
)

DaskBuilderMode = Literal["mem", "concurrency"]


class MkArray(Protocol):
"""Internal interface."""
@@ -137,6 +146,13 @@ def resolve_sources_dask(
]


def _default_dask_mode() -> DaskBuilderMode:
mode = os.environ.get("ODC_STAC_DASK_MODE", "mem")
if mode == "concurrency":
return "concurrency"
return "mem"


class DaskGraphBuilder:
"""
Build xarray from parsed metadata.
@@ -154,13 +170,17 @@ def __init__(
env: Dict[str, Any],
rdr: ReaderDriver,
chunks: Mapping[str, int],
mode: DaskBuilderMode | Literal["auto"] = "auto",
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
# make sure chunks for tyx match our structure
chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx)
chunks = {**chunks}
chunks.update(dict(zip(["time", "y", "x"], chunk_tyx)))
if mode == "auto":
# "mem" unless overwriten by env var
mode = _default_dask_mode()

self.cfg = cfg
self.template = template
@@ -169,8 +189,9 @@ def __init__(
self.gbt = gbt
self.env = env
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks, mode)
self._chunks = chunks
self._mode = mode
self._load_state = rdr.new_load(gbox, chunks=chunks)

def _band_chunks(
@@ -191,7 +212,7 @@ def build(
gbox: GeoBox,
time: Sequence[datetime],
bands: Mapping[str, RasterLoadParams],
):
) -> xr.Dataset:
return mk_dataset(
gbox,
time,
Expand All @@ -200,16 +221,21 @@ def build(
template=self.template,
)

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
def _norm_load_state(self, cfg_layer: dict[Key, Any]) -> tuple[Any, Any]:
load_state = self._load_state
if is_dask_collection(load_state):
deps.append(load_state)
load_state = load_state.key
cfg_layer.update(load_state.dask)
return load_state, load_state.key

tk = self._tk
src_key = f"open-{name}-{tk}"
return load_state, load_state

def _prep_sources(
self,
name: str,
dsk: dict[Key, Any],
load_state_dsk: Any,
) -> dict[Key, Any]:
src_key = f"open-{name}-{self._tk}"

for src_idx, mbsrc in enumerate(self.srcs):
rsrc = mbsrc.get(name, None)
@@ -219,9 +245,41 @@ def _prep_sources(
rsrc,
self.rdr,
self.env,
load_state,
load_state_dsk,
)
return src_key, load_state
return dsk

def _dask_rdr(self) -> DaskRasterReader:
if (dask_reader := self.rdr.dask_reader) is not None:
return dask_reader
return ReaderDaskAdaptor(self.rdr, self.env)

def _task_futures(
self,
task: LoadChunkTask,
dask_reader: DaskRasterReader,
layer_name: str,
dsk: dict[Key, Any],
) -> list[list[Key]]:
# pylint: disable=too-many-locals
srcs = task.resolve_sources(self.srcs)
out: list[list[Key]] = []
ctx = self._load_state
cfg = task.cfg
dst_gbox = task.dst_gbox

for i_time, layer in enumerate(srcs, start=task.idx[0]):
keys_out: list[Key] = []
for i_src, src in enumerate(layer):
idx = (i_time, *task.idx[1:], i_src)
rdr = dask_reader.open(src, ctx, layer_name=layer_name)
fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
keys_out.append(fut.key)
dsk.update(fut.dask)

out.append(keys_out)

return out

def __call__(
self,
@@ -237,39 +295,81 @@ def __call__(
cfg = self.cfg[name]
assert dtype == cfg.dtype
assert ydim == cfg.ydim + 1 # +1 for time dimension
chunks = self._band_chunks(name, shape, ydim)

tk = self._tk
deps: list[Any] = []
cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"
cfg_dsk = f"cfg-{tokenize(cfg)}"
gbt_dsk = f"grid-{tokenize(self.gbt)}"

cfg_layer, open_layer, band_layer = (
f"cfg-{name}-{tk}",
f"open-{name}-{tk}",
f"{name}-{tk}",
)

dsk: Dict[Key, Any] = {
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
layers: Dict[str, Dict[Key, Any]] = {
cfg_layer: {
cfg_dsk: cfg,
gbt_dsk: self.gbt,
},
open_layer: {},
band_layer: {},
}
layer_deps: Dict[str, Any] = {
cfg_layer: set(),
open_layer: set([cfg_layer]),
band_layer: set([cfg_layer, open_layer]),
}
src_key, load_state = self._prep_sources(name, dsk, deps)

band_key = f"{name}-{tk}"
chunks = self._band_chunks(name, shape, ydim)
dsk = layers[f"{name}-{tk}"]

dask_reader: DaskRasterReader | None = None
load_state, load_state_dsk = self._norm_load_state(layers[cfg_layer])
assert load_state is load_state_dsk or is_dask_collection(load_state)

if self._mode == "mem":
self._prep_sources(name, layers[open_layer], load_state_dsk)
else:
dask_reader = self._dask_rdr()

fill_value = resolve_dst_fill_value(
np.dtype(dtype),
cfg,
resolve_src_nodata(cfg.fill_value, cfg),
)

for task in self.load_tasks(name, shape[0]):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
task.resolve_sources_dask(src_key, dsk),
gbt_dask_key,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)
task_key: Key = (band_layer, *task.idx)
if dask_reader is None:
dsk[task_key] = (
_dask_loader_tyx,
task.resolve_sources_dask(open_layer, layers[open_layer]),
gbt_dsk,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dsk,
self.rdr,
self.env,
load_state_dsk,
task.selection,
)
else:
srcs_futures = self._task_futures(
task, dask_reader, open_layer, layers[open_layer]
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)
dsk[task_key] = (
_dask_fuser,
srcs_futures,
task.shape,
dtype,
fill_value,
ydim - 1,
)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)
dsk = HighLevelGraph(layers, layer_deps)
return da.Array(dsk, band_layer, chunks, dtype=dtype, shape=shape)

def load_tasks(self, name: str, nt: int) -> Iterator[LoadChunkTask]:
return load_tasks(
@@ -320,6 +420,30 @@ def _dask_loader_tyx(
return chunk


def _dask_fuser(
srcs: list[list[Any]],
shape: tuple[int, ...],
dtype: DTypeLike,
fill_value: float | int,
src_ydim: int = 0,
):
assert shape[0] == len(srcs)
assert len(shape) >= 3 # time, ..., y, x, ...

dst = np.full(shape, fill_value, dtype=dtype)

for ti, layer in enumerate(srcs):
fuse_nd_slices(
layer,
fill_value,
dst[ti],
ydim=src_ydim,
prefilled=True,
)

return dst


def fuse_nd_slices(
srcs: Iterable[tuple[tuple[slice, slice], np.ndarray]],
fill_value: float | int,
16 changes: 13 additions & 3 deletions odc/loader/_reader.py
@@ -53,6 +53,7 @@ def __init__(
env: dict[str, Any] | None = None,
ctx: Any | None = None,
src: RasterSource | None = None,
layer_name: str = "",
) -> None:
if env is None:
env = driver.capture_env()
@@ -61,17 +62,19 @@
self._env = env
self._ctx = ctx
self._src = src
self._layer_name = layer_name

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
selection: Optional[ReaderSubsetSelection] = None,
idx: tuple[int, ...],
) -> Any:
assert self._src is not None
assert self._ctx is not None
read_op = delayed(_dask_read_adaptor)
read_op = delayed(_dask_read_adaptor, name=self._layer_name)

# TODO: supply `dask_key_name=` that makes sense
return read_op(
@@ -82,10 +85,17 @@
self._driver,
self._env,
selection=selection,
dask_key_name=(self._layer_name, *idx),
)

def open(self, src: RasterSource, ctx: Any) -> "ReaderDaskAdaptor":
return ReaderDaskAdaptor(self._driver, self._env, ctx, src)
def open(self, src: RasterSource, ctx: Any, layer_name: str) -> "ReaderDaskAdaptor":
return ReaderDaskAdaptor(
self._driver,
self._env,
ctx,
src,
layer_name=layer_name,
)


def resolve_load_cfg(
3 changes: 3 additions & 0 deletions odc/loader/test_builder.py
@@ -149,12 +149,14 @@ def test_mk_dataset(

@pytest.mark.parametrize("bands,extra_coords,extra_dims,expect", rlp_fixtures)
@pytest.mark.parametrize("chunk_extra_dims", [False, True])
@pytest.mark.parametrize("mode", ["auto", "concurrency"])
def test_dask_builder(
bands: Dict[str, RasterLoadParams],
extra_coords: Sequence[FixedCoord] | None,
extra_dims: Mapping[str, int] | None,
expect: Mapping[str, _sn],
chunk_extra_dims: bool,
mode,
):
_bands = {
k: RasterBandMetadata(b.dtype, b.fill_value, dims=b.dims)
@@ -195,6 +197,7 @@ def test_dask_builder(
env=rdr_env,
rdr=rdr,
chunks=chunks,
mode=mode,
)

xx = builder.build(gbox, tss, bands)
7 changes: 5 additions & 2 deletions odc/loader/test_reader.py
@@ -380,13 +380,16 @@ def test_dask_reader_adaptor(dtype: str):
ctx = base_driver.new_load(gbox, chunks={"x": 64, "y": 64})

src = RasterSource("mem://", meta=meta)
rdr = driver.open(src, ctx)
rdr = driver.open(src, ctx, layer_name="aa")

assert isinstance(rdr, ReaderDaskAdaptor)

cfg = RasterLoadParams.same_as(src)
xx = rdr.read(cfg, gbox)
xx = rdr.read(cfg, gbox, idx=(0,))
assert is_dask_collection(xx)
assert xx.key == ("aa", 0)
assert rdr.read(cfg, gbox, idx=(1,)).key == ("aa", 1)
assert rdr.read(cfg, gbox, idx=(1, 2, 3)).key == ("aa", 1, 2, 3)

yy = xx.compute(scheduler="synchronous")
assert isinstance(yy, tuple)