From 447b88206b8f1512eba1778bce2ee9610ef6b2dd Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Wed, 29 May 2024 14:44:35 +1000 Subject: [PATCH] refactor: split up dask graph construction --- odc/loader/_builder.py | 75 +++++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index 2768c45..d2c502f 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -152,6 +152,7 @@ def __init__( self.env = env self.rdr = rdr self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks) + self._chunks = chunks self.chunk_tyx = (chunks.get("time", 1), *self.gbt.chunk_shape((0, 0)).yx) self._load_state = rdr.new_load( gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx)) @@ -171,6 +172,29 @@ def build( template=self.template, ) + def _prep_sources( + self, name: str, dsk: dict[Key, Any], deps: list[Any] + ) -> tuple[str, Any]: + load_state = self._load_state + if is_dask_collection(load_state): + deps.append(load_state) + load_state = load_state.key + + tk = self._tk + src_key = f"open-{name}-{tk}" + + for src_idx, mbsrc in enumerate(self.srcs): + rsrc = mbsrc.get(name, None) + if rsrc is not None: + dsk[src_key, src_idx] = ( + _dask_open_reader, + rsrc, + self.rdr, + self.env, + load_state, + ) + return src_key, load_state + def __call__( self, shape: Tuple[int, ...], @@ -181,9 +205,24 @@ def __call__( ) -> Any: # pylint: disable=too-many-locals assert isinstance(name, str) + cfg = self.cfg[name] assert dtype == cfg.dtype assert ydim == cfg.ydim + 1 # +1 for time dimension + + tk = self._tk + deps: list[Any] = [] + cfg_dask_key = f"cfg-{tokenize(cfg)}" + gbt_dask_key = f"grid-{tokenize(self.gbt)}" + + dsk: Dict[Key, Any] = { + cfg_dask_key: cfg, + gbt_dask_key: self.gbt, + } + src_key, load_state = self._prep_sources(name, dsk, deps) + + band_key = f"{name}-{tk}" + postfix_dims = shape[ydim + 2 :] prefix_dims = shape[1:ydim] @@ -199,35 +238,8 @@ def __call__( range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0]) ] - deps: list[Any] = [] - load_state = self._load_state - if is_dask_collection(load_state): - deps.append(load_state) - load_state = load_state.key - - cfg_dask_key = f"cfg-{tokenize(cfg)}" - gbt_dask_key = f"grid-{tokenize(self.gbt)}" - - dsk: Dict[Key, Any] = { - cfg_dask_key: cfg, - gbt_dask_key: self.gbt, - } - tk = self._tk - band_key = f"{name}-{tk}" - src_key = f"open-{name}-{tk}" shape_in_blocks = tuple(len(ch) for ch in chunks) - for src_idx, mbsrc in enumerate(self.srcs): - rsrc = mbsrc.get(name, None) - if rsrc is not None: - dsk[src_key, src_idx] = ( - _dask_open_reader, - rsrc, - self.rdr, - self.env, - load_state, - ) - for block_idx in np.ndindex(shape_in_blocks): ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1] srcs_keys: list[list[tuple[str, int]]] = [] @@ -257,6 +269,15 @@ def __call__( return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape) + def load_tasks(self, name: str) -> Iterator[LoadChunkTask]: + return load_tasks( + self.cfg, + self.tyx_bins, + self.gbt, + extra_dims=self.template.extra_dims_full(name), + bands=[name], + ) + def _dask_open_reader( src: RasterSource,