diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index d2c502f..06136a1 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -122,6 +122,9 @@ def resolve_sources( out.append(_srcs) return out + def resolve_sources_dask(self, dask_key: str) -> list[list[tuple[str, int]]]: + return [[(dask_key, idx) for idx, _ in layer] for layer in self.srcs] + class DaskGraphBuilder: """ @@ -234,35 +237,20 @@ def __call__( ) assert len(chunk_shape) == len(shape) chunks: tuple[tuple[int, ...], ...] = normalize_chunks(chunk_shape, shape) - tchunk_range = [ - range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0]) - ] - - shape_in_blocks = tuple(len(ch) for ch in chunks) - - for block_idx in np.ndindex(shape_in_blocks): - ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1] - srcs_keys: list[list[tuple[str, int]]] = [] - for _ti in tchunk_range[ti]: - srcs_keys.append( - [ - (src_key, src_idx) - for src_idx in self.tyx_bins.get((_ti, yi, xi), []) - if (src_key, src_idx) in dsk - ] - ) - dsk[(band_key, *block_idx)] = ( + for task in self.load_tasks(name): + dsk[(band_key, *task.idx)] = ( _dask_loader_tyx, - srcs_keys, + task.resolve_sources_dask(src_key), gbt_dask_key, - quote((yi, xi)), - quote(prefix_dims), - quote(postfix_dims), + quote(task.idx_tyx), + quote(task.prefix_dims), + quote(task.postfix_dims), cfg_dask_key, self.rdr, self.env, load_state, + task.selection, ) dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps) @@ -270,10 +258,17 @@ def __call__( return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape) def load_tasks(self, name: str) -> Iterator[LoadChunkTask]: + chunks = {**self._chunks} + + # make sure chunks for tyx match our structure + chunks.update(dict(zip(["time", "y", "x"], self.chunk_tyx))) + return load_tasks( self.cfg, self.tyx_bins, self.gbt, + nt=len(self.tyx_bins), + chunks=chunks, extra_dims=self.template.extra_dims_full(name), bands=[name], ) @@ -299,6 +294,7 @@ def _dask_loader_tyx( rdr: ReaderDriver, env: Dict[str, Any], load_state: Any, + selection: Any | None = None, ): assert cfg.dtype is not None gbox = cast(GeoBox, gbt[iyx]) @@ -309,7 +305,9 @@ def _dask_loader_tyx( ydim = len(prefix_dims) with rdr.restore_env(env, load_state): for ti, ti_srcs in enumerate(srcs): - _fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti], ydim=ydim) + _fill_nd_slice( + ti_srcs, gbox, cfg, chunk[ti], ydim=ydim, selection=selection + ) return chunk @@ -319,12 +317,14 @@ def _fill_nd_slice( cfg: RasterLoadParams, dst: Any, ydim: int = 0, + selection: Any | None = None, ) -> Any: # TODO: support masks not just nodata based fusing # # ``nodata`` marks missing pixels, but it might be None (everything is valid) # ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set, # otherwise defaults to .nan for floats and 0 for integers + # pylint: disable=too-many-locals assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :]) @@ -338,7 +338,7 @@ def _fill_nd_slice( return dst src, *rest = srcs - yx_roi, pix = src.read(cfg, dst_gbox, dst=dst) + yx_roi, pix = src.read(cfg, dst_gbox, dst=dst, selection=selection) assert len(yx_roi) == 2 assert pix.ndim == dst.ndim @@ -612,6 +612,8 @@ def load_tasks( ) if len(selection) == 1: selection = selection[0] + if shape_in_chunks[3] == 1: + selection = None yield LoadChunkTask( band_name, diff --git a/odc/stac/_stac_load.py b/odc/stac/_stac_load.py index 3659a0a..5d216f0 100644 --- a/odc/stac/_stac_load.py +++ b/odc/stac/_stac_load.py @@ -442,7 +442,7 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset: return _with_debug_info( chunked_load( load_cfg, - collection.meta, + collection.meta_for(bands), _parsed, tyx_bins, gbt, diff --git a/odc/stac/model.py b/odc/stac/model.py index eaecab8..8b68185 100644 --- a/odc/stac/model.py +++ b/odc/stac/model.py @@ -166,6 +166,17 @@ def __getitem__(self, band: BandIdentifier) -> RasterBandMetadata: def bands(self) -> Dict[BandKey, RasterBandMetadata]: return self.meta.bands + def meta_for(self, bands: BandQuery = None) -> RasterGroupMetadata: + """ + Extract raster group metadata for a subset of bands. + + Output uses supplied band names as keys, effectively replacing canonical + names with aliases supplied by the user. + """ + return self.meta.patch( + bands={norm_key(b): self[b] for b in self.normalize_band_query(bands)} + ) + @property def aliases(self) -> Dict[str, List[BandKey]]: return self.meta.aliases