From 7e3fc910c0b50ce4b2dd10e1898009596ea4bd23 Mon Sep 17 00:00:00 2001
From: Kirill Kouzoubov
Date: Sun, 17 Sep 2023 19:39:22 +1000
Subject: [PATCH] Dask cog compressor

---
 odc/geo/_cog.py   | 365 +++++++++++++++++++++++++++++++++++++++++++++-
 tests/test_cog.py |  56 ++++++-
 2 files changed, 416 insertions(+), 5 deletions(-)

diff --git a/odc/geo/_cog.py b/odc/geo/_cog.py
index 43ad2eaf..2359717f 100644
--- a/odc/geo/_cog.py
+++ b/odc/geo/_cog.py
@@ -9,9 +9,23 @@
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Union
+from time import monotonic
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)
 from uuid import uuid4
 
 import numpy as np
@@ -27,6 +41,7 @@
 from .warp import resampling_s2rio
 
 # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements,too-many-instance-attributes
+# pylint: disable=too-many-lines
 
 AxisOrder = Union[Literal["YX"], Literal["YXS"], Literal["SYX"]]
 
@@ -62,6 +77,63 @@ def chunks(self) -> Tuple[int, ...]:
     def pix_shape(self) -> Tuple[int, ...]:
         return self._pix_shape(self.shape)
 
+    @property
+    def num_planes(self):
+        if self.axis == "SYX":
+            return self.nsamples
+        return 1
+
+    def flatten(self) -> Tuple["CogMeta", ...]:
+        return (self, *self.overviews)
+
+    @property
+    def chunked(self) -> Shape2d:
+        """
+        Shape in chunks.
+        """
+        ny, nx = ((N + n - 1) // n for N, n in zip(self.shape.yx, self.tile.yx))
+        return shape_((ny, nx))
+
+    @property
+    def num_tiles(self):
+        ny, nx = self.chunked.yx
+        return self.num_planes * ny * nx
+
+    def tidx(self) -> Iterator[Tuple[int, int, int]]:
+        """``[(plane_idx, iy, ix), ...]``"""
+        yield from np.ndindex((self.num_planes, *self.chunked.yx))
+
+    def flat_tile_idx(self, idx: Tuple[int, int, int]) -> int:
+        """Convert from sample,iy,ix to flat tile index."""
+        ns = self.num_planes
+        ny, nx = self.chunked.yx
+        for n, i in zip((ns, ny, nx), idx):
+            if i < 0 or i >= n:
+                raise IndexError()
+
+        sample, y, x = idx
+        return sample * (ny * nx) + y * nx + x
+
+    def cog_tidx(self) -> Iterator[Tuple[int, int, int, int]]:
+        """``[(ifd_idx, plane_idx, iy, ix), ...]``"""
+        idx_layers = list(enumerate(self.flatten()))[::-1]
+        for idx, mm in idx_layers:
+            yield from ((idx, pi, yi, xi) for pi, yi, xi in mm.tidx())
+
+    def __dask_tokenize__(self):
+        return (
+            "odc.CogMeta",
+            self.axis,
+            *self.shape.yx,
+            *self.tile.yx,
+            self.nsamples,
+            self.dtype,
+            self.compression,
+            self.predictor,
+            self.gbox,
+            len(self.overviews),
+        )
+
 
 def _without(cfg: Dict[str, Any], *skip: str) -> Dict[str, Any]:
     skip = set(skip)
@@ -685,12 +757,11 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
     for tsz, idx in zip(_blocks, range(nlevels + 1)):
         tile = _norm_blocksize(tsz)
         meta = CogMeta(
-            ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor
+            ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor, gbox
         )
 
         if idx == 0:
             kw = {**opts_common, "extratags": extratags}
-            meta.gbox = gbox
         else:
             kw = {**opts_common, "subfiletype": FILETYPE.REDUCEDIMAGE}
 
@@ -703,6 +774,8 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
         metas.append(meta)
         im_shape = im_shape.shrink2()
+        if gbox is not None:
+            gbox = gbox.zoom_to(im_shape)
 
     meta = metas[0]
     meta.overviews = tuple(metas[1:])
 
@@ -710,3 +783,289 @@
     tw.close()
 
     return meta, buf.getbuffer()
+
+
+def _cog_block_compressor(
+    block: np.ndarray,
+    *,
+    tile_shape: Tuple[int, ...] = (),
+    encoder: Any = None,
+    predictor: Any = None,
+    axis: int = 1,
+    **kw,
+) -> bytes:
+    assert block.ndim == len(tile_shape)
+    if tile_shape != block.shape:
+        pad = tuple((0, want - have) for want, have in zip(tile_shape, block.shape))
+        block = np.pad(block, pad, "constant", constant_values=(0,))
+
+    if predictor is not None:
+        block = predictor(block, axis=axis)
+    if encoder:
+        try:
+            return encoder(block, **kw)
+        except Exception:  # pylint: disable=broad-except
+            return b""
+
+    return block.data
+
+
+def mk_tile_compressor(
+    meta: CogMeta, **encoder_params
+) -> Callable[[np.ndarray], bytes]:
+    """
+    Make tile compressor.
+    """
+    # pylint: disable=import-outside-toplevel
+    have.check_or_error("tifffile")
+    from tifffile import TIFF
+
+    tile_shape = meta.chunks
+    encoder = TIFF.COMPRESSORS[meta.compression]
+
+    # TODO: handle SYX in planar mode
+    axis = 1
+    predictor = None
+    if meta.predictor != 1:
+        predictor = TIFF.PREDICTORS[meta.predictor]
+
+    return partial(
+        _cog_block_compressor,
+        tile_shape=tile_shape,
+        encoder=encoder,
+        predictor=predictor,
+        axis=axis,
+        **encoder_params,
+    )
+
+
+def _compress_cog_tile(encoder, block, idx):
+    return [{"data": encoder(block), "idx": idx}]
+
+
+def compress_tiles(
+    xx: xr.DataArray,
+    meta: CogMeta,
+    scale_idx: int = 0,
+    sample_idx: int = 0,
+    **encoder_params,
+):
+    """
+    Compress chunks according to the COG spec.
+
+    :returns: Dask bag of dicts with ``{"data": bytes, "idx": (int, int, int, int)}``
+    """
+    # pylint: disable=import-outside-toplevel
+    have.check_or_error("dask")
+    from dask.bag import Bag
+    from dask.base import quote, tokenize
+    from dask.highlevelgraph import HighLevelGraph
+
+    from ._interop import is_dask_collection
+
+    # TODO: deal with SYX planar data
+    assert meta.num_planes == 1
+    assert meta.axis in ("YX", "YXS")
+    src_ydim = 0  # for now assume Y,X or Y,X,S
+
+    encoder = mk_tile_compressor(meta, **encoder_params)
+    data = xx.data
+
+    assert is_dask_collection(data)
+
+    tk = tokenize(
+        data,
+        scale_idx,
+        meta.axis,
+        meta.chunks,
+        meta.predictor,
+        meta.compression,
+        encoder_params,
+    )
+    plane_id = "" if scale_idx == 0 else f"_{scale_idx}"
+    plane_id += "" if sample_idx == 0 else f"@{sample_idx}"
+
+    name = f"compress{plane_id}-{tk}"
+
+    src_data_name = data.name
+
+    def block_name(p, y, x):
+        if data.ndim == 2:
+            return (src_data_name, y, x)
+        if src_ydim == 0:
+            return (src_data_name, y, x, p)
+        return (src_data_name, p, y, x)
+
+    dsk = {}
+    for i, (p, y, x) in enumerate(meta.tidx()):
+        block = block_name(p, y, x)
+        dsk[name, i] = (_compress_cog_tile, encoder, block, quote((scale_idx, p, y, x)))
+
+    nparts = len(dsk)
+    dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[data])
+    return Bag(dsk, name, nparts)
+
+
+def _pyramids_from_cog_metadata(
+    xx: xr.DataArray,
+    cog_meta: CogMeta,
+    resampling: Union[str, int] = "nearest",
+) -> Tuple[xr.DataArray, ...]:
+    out = [xx]
+
+    for mm in cog_meta.overviews:
+        gbox = mm.gbox
+        out.append(
+            out[-1].odc.reproject(gbox, chunks=mm.tile.yx, resampling=resampling)
+        )
+
+    return tuple(out)
+
+
+def _extact_tile_info(
+    meta: CogMeta,
+    tiles: List[Tuple[int, int, int, int, int]],
+    start_offset: int = 0,
+) -> List[Tuple[List[int], List[int]]]:
+    mm = meta.flatten()
+    tile_info = [([0] * m.num_tiles, [0] * m.num_tiles) for m in mm]
+
+    byte_offset = start_offset
+    for scale_idx, p, y, x, sz in tiles:
+        m = mm[scale_idx]
+        b_offsets, b_lengths = tile_info[scale_idx]
+
+        tidx = m.flat_tile_idx((p, y, x))
+        if sz != 0:
+            b_lengths[tidx] = sz
+            b_offsets[tidx] = byte_offset
+            byte_offset += sz
+
+    return tile_info
+
+
+def _update_header(
+    meta: CogMeta,
+    hdr0: bytes,
+    tiles: List[Tuple[int, int, int, int, int]],
+) -> bytes:
+    # pylint: disable=import-outside-toplevel
+    from tifffile import TiffFile
+
+    tile_info = _extact_tile_info(meta, tiles, len(hdr0))
+
+    _bio = BytesIO(hdr0)
+    with TiffFile(_bio, mode="r+", name=":mem:") as tr:
+        assert len(tile_info) == len(tr.pages)
+
+        # 324 -- offsets
+        # 325 -- byte counts
+        for info, page in zip(tile_info, tr.pages):
+            tags = page.tags
+            offsets, lengths = info
+            tags[324].overwrite(offsets)
+            tags[325].overwrite(lengths)
+
+    return bytes(_bio.getbuffer())
+
+
+def save_cog_with_dask(
+    xx: xr.DataArray,
+    dst: str = "",
+    *,
+    client: Any = None,
+    compression: Union[str, Unset] = Unset(),
+    predictor: Union[int, bool, Unset] = Unset(),
+    blocksize: Union[int, List[Union[int, Tuple[int, int]]]] = 2048,
+    bigtiff: bool = True,
+    overview_resampling: Union[int, str] = "nearest",
+    verbose: bool = False,
+    encoder_params: Any = None,
+    **kw,
+):
+    # pylint: disable=import-outside-toplevel
+    t0 = monotonic()
+    from dask import bag, delayed
+    from dask.utils import format_time
+
+    if encoder_params is None:
+        encoder_params = {}
+
+    # useful when debugging
+    optimize_graph = kw.pop("optimize_graph", True)
+
+    meta, hdr0 = make_empty_cog(
+        xx.shape,
+        xx.dtype,
+        xx.odc.geobox,
+        predictor=predictor,
+        compression=compression,
+        blocksize=blocksize,
+        bigtiff=bigtiff,
+        **kw,
+    )
+    hdr0 = bytes(hdr0)
+
+    layers = _pyramids_from_cog_metadata(xx, meta, resampling=overview_resampling)
+
+    _tiles = []
+    for scale_idx, (mm, img) in enumerate(zip(meta.flatten(), layers)):
+        _tiles.append(compress_tiles(img, mm, scale_idx=scale_idx, **encoder_params))
+
+    hdr_info = bag.concat(
+        [t.map(lambda d: (*d["idx"], len(d["data"]))) for t in _tiles[::-1]]
+    )
+    tile_bytes = bag.concat([t.map(lambda d: d["data"]) for t in _tiles[::-1]])
+
+    new_hdr = delayed(_update_header, pure=True)(meta, hdr0, hdr_info)
+
+    dbg = {
+        "hdr_info": hdr_info,
+        "meta": meta,
+        "hdr0": hdr0,
+        "t0": t0,
+    }
+
+    if dst == "":
+        return (new_hdr, tile_bytes, dbg)
+
+    if client is None:
+        return (new_hdr, tile_bytes, dbg)
+
+    def time_past() -> str:
+        return format_time(monotonic() - t0)
+
+    with open(dst, "wb") as fp:
+        if verbose:
+            print(f"[{time_past()}] Will write to: {dst}")
+            print(f"[{time_past()}] Starting computation on client: {client}")
+
+        new_hdr, tile_bytes = client.persist(
+            [new_hdr, tile_bytes],
+            optimize_graph=optimize_graph,
+        )
+
+        if verbose:
+            print(f"[{time_past()}] Waiting for Dask to compress to RAM...")
+
+        new_hdr = client.compute(new_hdr, optimize_graph=optimize_graph)
+        new_hdr = new_hdr.result()  # blocks
+
+        if verbose:
+            print(f"[{time_past()}] DONE: compressed data is now in Dask")
+            print(f"... writing tiles to: {dst}")
writing tiles to: {dst}") + + fp.write(new_hdr) + for fut in tile_bytes.to_delayed(): + for chunk in fut.compute(): + fp.write(chunk) + if verbose: + print(".", end="") + if verbose: + print(f"\n[{time_past()}] DONE") + + dbg["t1"] = monotonic() + dbg["total_seconds"] = dbg["t1"] - dbg["t0"] + return new_hdr, tile_bytes, dbg diff --git a/tests/test_cog.py b/tests/test_cog.py index 5471e48b..1f4f7f05 100644 --- a/tests/test_cog.py +++ b/tests/test_cog.py @@ -4,10 +4,16 @@ import pytest -from odc.geo._cog import _compute_cog_spec, _num_overviews, cog_gbox, make_empty_cog +from odc.geo._cog import ( + CogMeta, + _compute_cog_spec, + _num_overviews, + cog_gbox, + make_empty_cog, +) from odc.geo.geobox import GeoBox from odc.geo.gridspec import GridSpec -from odc.geo.types import Unset +from odc.geo.types import Unset, wh_ from odc.geo.xr import xr_zeros gbox_globe = GridSpec.web_tiles(0)[0, 0] @@ -203,3 +209,49 @@ def test_empty_cog(shape, blocksize, expect_ax, dtype, compression, expect_predi assert tsz[0] <= p.chunks[0] < tsz[0] + 16 assert tsz[1] <= p.chunks[1] < tsz[1] + 16 + + +@pytest.mark.parametrize( + "meta", + [ + CogMeta("YX", wh_(256, 256), wh_(128, 128), 1, "int16", 8, 1), + CogMeta("YXS", wh_(500, 256), wh_(128, 128), 3, "uint8", 8, 1), + CogMeta("SYX", wh_(500, 256), wh_(128, 128), 5, "float32", 8, 1), + ], +) +def test_cog_meta(meta: CogMeta): + for idx_flat, idx in enumerate(meta.tidx()): + assert meta.flat_tile_idx(idx) == idx_flat + + assert meta.num_tiles == (meta.chunked.x * meta.chunked.y * meta.num_planes) + assert meta.num_tiles == len(list(meta.tidx())) + + assert len(meta.pix_shape) == {"YX": 2, "SYX": 3, "YXS": 3}[meta.axis] + assert len(meta.pix_shape) == len(meta.chunks) + + if meta.axis in ("YX", "YXS"): + assert meta.num_planes == 1 + assert meta.nsamples >= meta.num_planes + + if meta.axis == "SYX": + assert meta.num_planes == meta.nsamples + + layers = meta.flatten() + for idx in meta.cog_tidx(): + assert isinstance(idx, tuple) + assert len(idx) == 4 + img_idx, s, y, x = idx + tidx = (s, y, x) + assert 0 <= img_idx < len(layers) + assert layers[img_idx].flat_tile_idx(tidx) <= layers[img_idx].num_tiles + + for bad_idx in [ + (0, -1, 0), + (-1, 0, 0), + (0, 0, -1), + (meta.num_planes, 0, 0), + (0, meta.chunked.y, 0), + (0, meta.chunked.y - 1, meta.chunked.x + 2), + ]: + with pytest.raises(IndexError): + _ = meta.flat_tile_idx(bad_idx)