# Tile-indexing members of ``CogMeta``; the class header lies outside this excerpt.


@property
def num_planes(self):
    """
    Number of planes counted when enumerating tiles.

    Equals ``nsamples`` when the axis order is ``"SYX"``, otherwise ``1``.
    """
    return self.nsamples if self.axis == "SYX" else 1


def flatten(self) -> Tuple["CogMeta", ...]:
    """
    Return this image followed by all of its overviews as one flat tuple.
    """
    return (self,) + tuple(self.overviews)


@property
def chunked(self) -> Shape2d:
    """
    Shape in chunks.

    Image height and width divided by tile height and width, rounded up.
    """
    img_y, img_x = self.shape.yx
    tile_y, tile_x = self.tile.yx
    rows = (img_y + tile_y - 1) // tile_y
    cols = (img_x + tile_x - 1) // tile_x
    return shape_((rows, cols))


@property
def num_tiles(self):
    """
    Total tile count: tiles per plane times the number of planes.
    """
    rows, cols = self.chunked.yx
    return self.num_planes * rows * cols


def tidx(self) -> Iterator[Tuple[int, int, int]]:
    """
    Iterate over ``(plane_idx, iy, ix)`` for every tile of this image.
    """
    rows, cols = self.chunked.yx
    yield from np.ndindex(self.num_planes, rows, cols)


def cog_tidx(self) -> Iterator[Tuple[int, int, int, int]]:
    """
    Iterate over ``(ifd_idx, plane_idx, iy, ix)`` for every tile of every layer.

    Layers are visited in reverse order of :meth:`flatten`: the last overview
    comes first and this image (``ifd_idx == 0``) comes last.
    """
    layers = self.flatten()
    for ifd_idx in reversed(range(len(layers))):
        for plane, iy, ix in layers[ifd_idx].tidx():
            yield (ifd_idx, plane, iy, ix)
def _cog_block_compressor(
    block: np.ndarray,
    *,
    tile_shape: Tuple[int, ...] = (),
    encoder: Any = None,
    predictor: Any = None,
    axis: int = 1,
    **kw,
) -> bytes:
    """
    Turn one block of pixels into the bytes of a single tile.

    :param block: Pixel data, must have as many dimensions as ``tile_shape``.
    :param tile_shape: Shape of a full tile; a smaller block is zero-padded at
       the end of every axis until it reaches this shape.
    :param encoder: Optional callable invoked as ``encoder(block, **kw)``.
       When not supplied the raw bytes of the block are returned.
    :param predictor: Optional callable invoked as ``predictor(block, axis=axis)``
       before encoding.
    :param axis: Axis handed to ``predictor``.
    :param kw: Extra keyword arguments forwarded to ``encoder``.
    :returns: Encoded tile.
    """
    assert block.ndim == len(tile_shape)
    if tile_shape != block.shape:
        pad = tuple((0, want - got) for want, got in zip(tile_shape, block.shape))
        block = np.pad(block, pad, "constant", constant_values=(0,))

    if predictor is not None:
        block = predictor(block, axis=axis)
    if encoder:
        return encoder(block, **kw)

    # ``block.data`` is an N-d memoryview: its ``len()`` is the size of the
    # first axis, not the byte count, and it is not a flat buffer when the
    # array is non-contiguous. ``tobytes`` always yields flat C-order bytes.
    return block.tobytes()


def mk_tile_compressor(
    meta: "CogMeta", **encoder_params
) -> Callable[[np.ndarray], bytes]:
    """
    Make tile compressor.

    The encoder is looked up in tifffile's ``TIFF.COMPRESSORS`` by
    ``meta.compression``; a predictor is looked up in ``TIFF.PREDICTORS`` by
    ``meta.predictor`` unless that value is ``1``. Tiles are padded to
    ``meta.chunks``.

    :param meta: Description of the image the tiles belong to.
    :param encoder_params: Keyword arguments forwarded to the encoder on every call.
    :returns: Callable mapping a pixel block to the bytes of one tile.
    """
    # pylint: disable=import-outside-toplevel
    have.check_or_error("tifffile")
    from tifffile import TIFF

    tile_shape = meta.chunks
    encoder = TIFF.COMPRESSORS[meta.compression]

    # TODO: handle SYX in planar mode
    axis = 1
    predictor = None
    if meta.predictor != 1:
        predictor = TIFF.PREDICTORS[meta.predictor]

    return partial(
        _cog_block_compressor,
        tile_shape=tile_shape,
        encoder=encoder,
        predictor=predictor,
        axis=axis,
        **encoder_params,
    )


def _compress_cog_tile(encoder, block, idx):
    """
    Encode ``block`` and wrap the result as a one-element list of
    ``{"data": ..., "idx": idx}``, i.e. the content of one bag partition.
    """
    return [{"data": encoder(block), "idx": idx}]


def compress_tiles(
    xx: "xr.DataArray",
    meta: "CogMeta",
    plane_idx: Tuple[int, int] = (0, 0),
    **encoder_params,
):
    """
    Compress chunks according to cog spec.

    Every Dask block of ``xx.data`` becomes one partition holding one
    compressed tile.

    :param xx: Dask-backed array of pixels.
    :param meta: Description of the image being written; only a single plane
       with axis order ``"YX"`` or ``"YXS"`` is supported.
    :param plane_idx: ``(scale_idx, sample_idx)`` prefix stored with every tile.
    :param encoder_params: Keyword arguments forwarded to the encoder.
    :returns: Dask bag of dicts with ``{"data": bytes, "idx": (int, int, int)}``,
       where ``idx`` is ``(scale_idx, sample_idx, block_number)`` and
       ``block_number`` counts blocks of ``xx.data`` in C order.
    :raises NotImplementedError: for multi-plane or unsupported axis layouts.
    :raises TypeError: when ``xx.data`` is not a Dask collection.
    """
    # pylint: disable=import-outside-toplevel
    have.check_or_error("dask")
    from dask.bag import Bag
    from dask.base import quote, tokenize
    from dask.highlevelgraph import HighLevelGraph

    from ._interop import is_dask_collection

    # TODO: support planar layout
    # Explicit exceptions rather than ``assert`` so the checks survive ``-O``.
    if meta.num_planes != 1:
        raise NotImplementedError("compress_tiles supports single-plane images only")
    if meta.axis not in ("YX", "YXS"):
        raise NotImplementedError(f"Unsupported axis order: {meta.axis!r}")

    encoder = mk_tile_compressor(meta, **encoder_params)
    data = xx.data

    if not is_dask_collection(data):
        raise TypeError("compress_tiles expects Dask-backed data")

    tk = tokenize(
        data,
        plane_idx,
        meta.axis,
        meta.chunks,
        meta.predictor,
        meta.compression,
        encoder_params,
    )
    scale_idx, sample_idx = plane_idx
    plane_id = "" if scale_idx == 0 else f"_{scale_idx}"
    plane_id += "" if sample_idx == 0 else f"@{sample_idx}"

    name = f"compress{plane_id}-{tk}"

    # One task per source block; the task key doubles as the partition number.
    tasks = {
        (name, i): (
            _compress_cog_tile,
            encoder,
            (data.name, *block_idx),
            quote((scale_idx, sample_idx, i)),
        )
        for i, block_idx in enumerate(np.ndindex(data.blocks.shape))
    }

    graph = HighLevelGraph.from_collections(name, tasks, dependencies=[data])
    return Bag(graph, name, len(tasks))