Skip to content

Commit

Permalink
refactor: split up dask graph construction
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed May 29, 2024
1 parent 50518e5 commit 447b882
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __init__(
self.env = env
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._chunks = chunks
self.chunk_tyx = (chunks.get("time", 1), *self.gbt.chunk_shape((0, 0)).yx)
self._load_state = rdr.new_load(
gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx))
Expand All @@ -171,6 +172,29 @@ def build(
template=self.template,
)

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
load_state = self._load_state
if is_dask_collection(load_state):
deps.append(load_state)
load_state = load_state.key

tk = self._tk
src_key = f"open-{name}-{tk}"

for src_idx, mbsrc in enumerate(self.srcs):
rsrc = mbsrc.get(name, None)
if rsrc is not None:
dsk[src_key, src_idx] = (
_dask_open_reader,
rsrc,
self.rdr,
self.env,
load_state,
)
return src_key, load_state

def __call__(
self,
shape: Tuple[int, ...],
Expand All @@ -181,9 +205,24 @@ def __call__(
) -> Any:
# pylint: disable=too-many-locals
assert isinstance(name, str)

cfg = self.cfg[name]
assert dtype == cfg.dtype
assert ydim == cfg.ydim + 1 # +1 for time dimension

tk = self._tk
deps: list[Any] = []
cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"

dsk: Dict[Key, Any] = {
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
src_key, load_state = self._prep_sources(name, dsk, deps)

band_key = f"{name}-{tk}"

postfix_dims = shape[ydim + 2 :]
prefix_dims = shape[1:ydim]

Expand All @@ -199,35 +238,8 @@ def __call__(
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

deps: list[Any] = []
load_state = self._load_state
if is_dask_collection(load_state):
deps.append(load_state)
load_state = load_state.key

cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"

dsk: Dict[Key, Any] = {
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
tk = self._tk
band_key = f"{name}-{tk}"
src_key = f"open-{name}-{tk}"
shape_in_blocks = tuple(len(ch) for ch in chunks)

for src_idx, mbsrc in enumerate(self.srcs):
rsrc = mbsrc.get(name, None)
if rsrc is not None:
dsk[src_key, src_idx] = (
_dask_open_reader,
rsrc,
self.rdr,
self.env,
load_state,
)

for block_idx in np.ndindex(shape_in_blocks):
ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1]
srcs_keys: list[list[tuple[str, int]]] = []
Expand Down Expand Up @@ -257,6 +269,15 @@ def __call__(

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)

def load_tasks(self, name: str) -> Iterator[LoadChunkTask]:
return load_tasks(
self.cfg,
self.tyx_bins,
self.gbt,
extra_dims=self.template.extra_dims_full(name),
bands=[name],
)


def _dask_open_reader(
src: RasterSource,
Expand Down

0 comments on commit 447b882

Please sign in to comment.