Skip to content

Commit

Permalink
amend me: cog compressor
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Sep 17, 2023
1 parent 8b3a7e6 commit cfab5e9
Showing 1 changed file with 157 additions and 3 deletions.
160 changes: 157 additions & 3 deletions odc/geo/_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,22 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
from uuid import uuid4

import numpy as np
Expand Down Expand Up @@ -62,6 +75,38 @@ def chunks(self) -> Tuple[int, ...]:
def pix_shape(self) -> Tuple[int, ...]:
return self._pix_shape(self.shape)

@property
def num_planes(self) -> int:
    """Number of separate image planes; more than one only for planar ("SYX") layout."""
    return self.nsamples if self.axis == "SYX" else 1

def flatten(self) -> Tuple["CogMeta", ...]:
    """Return this level followed by all of its overview levels as one tuple."""
    levels = [self]
    levels.extend(self.overviews)
    return tuple(levels)

@property
def chunked(self) -> Shape2d:
    """
    Shape in chunks: number of whole tiles needed to cover the image (ceil division).
    """
    tiles_y, tiles_x = (
        (npix + tsz - 1) // tsz for npix, tsz in zip(self.shape.yx, self.tile.yx)
    )
    return shape_((tiles_y, tiles_x))

@property
def num_tiles(self) -> int:
    """Total tile count for this level across all planes."""
    tiles_y, tiles_x = self.chunked.yx
    return self.num_planes * tiles_y * tiles_x

def tidx(self) -> Iterator[Tuple[int, int, int]]:
    """``[(plane_idx, iy, ix), ...]``"""
    # Row-major walk over (plane, tile-row, tile-col) for this level.
    for plane_row_col in np.ndindex((self.num_planes, *self.chunked.yx)):
        yield plane_row_col

def cog_tidx(self) -> Iterator[Tuple[int, int, int, int]]:
    """``[(ifd_idx, plane_idx, iy, ix), ...]``"""
    # flatten() puts the full-resolution image first; walk levels in
    # reverse so the deepest overview's tiles come out first.
    for ifd_idx, level in reversed(list(enumerate(self.flatten()))):
        for plane, row, col in level.tidx():
            yield (ifd_idx, plane, row, col)


def _without(cfg: Dict[str, Any], *skip: str) -> Dict[str, Any]:
skip = set(skip)
Expand Down Expand Up @@ -685,12 +730,11 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
for tsz, idx in zip(_blocks, range(nlevels + 1)):
tile = _norm_blocksize(tsz)
meta = CogMeta(
ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor
ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor, gbox
)

if idx == 0:
kw = {**opts_common, "extratags": extratags}
meta.gbox = gbox
else:
kw = {**opts_common, "subfiletype": FILETYPE.REDUCEDIMAGE}

Expand All @@ -703,10 +747,120 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:

metas.append(meta)
im_shape = im_shape.shrink2()
if gbox is not None:
gbox = gbox.zoom_to(im_shape)

meta = metas[0]
meta.overviews = tuple(metas[1:])

tw.close()

return meta, buf.getbuffer()


def _cog_block_compressor(
block: np.ndarray,
*,
tile_shape: Tuple[int, ...] = (),
encoder: Any = None,
predictor: Any = None,
axis: int = 1,
**kw,
) -> bytes:
assert block.ndim == len(tile_shape)
if tile_shape != block.shape:
pad = tuple((0, want - have) for want, have in zip(tile_shape, block.shape))
block = np.pad(block, pad, "constant", constant_values=(0,))

if predictor is not None:
block = predictor(block, axis=axis)
if encoder:
return encoder(block, **kw)
return block.data


def mk_tile_compressor(
    meta: CogMeta, **encoder_params
) -> Callable[[np.ndarray], bytes]:
    """
    Make tile compressor.

    Builds a callable that zero-pads and encodes one tile according to
    ``meta`` (tile shape, compression scheme, predictor), with
    ``encoder_params`` forwarded to the underlying tifffile encoder.
    """
    # pylint: disable=import-outside-toplevel
    have.check_or_error("tifffile")
    from tifffile import TIFF

    # TODO: handle SYX in planar mode
    axis = 1
    # predictor value 1 means "no predictor" in TIFF terms
    predictor = None if meta.predictor == 1 else TIFF.PREDICTORS[meta.predictor]

    return partial(
        _cog_block_compressor,
        tile_shape=meta.chunks,
        encoder=TIFF.COMPRESSORS[meta.compression],
        predictor=predictor,
        axis=axis,
        **encoder_params,
    )


def _compress_cog_tile(encoder, block, idx):
return [{"data": encoder(block), "idx": idx}]


def compress_tiles(
    xx: xr.DataArray,
    meta: CogMeta,
    plane_idx: Tuple[int, int] = (0, 0),
    **encoder_params,
):
    """
    Compress chunks according to cog spec.

    :returns: Dask bag of dicts with ``{"data": bytes, "idx": (int, int, int)}``
    """
    # pylint: disable=import-outside-toplevel
    have.check_or_error("dask")
    from dask.bag import Bag
    from dask.base import quote, tokenize
    from dask.highlevelgraph import HighLevelGraph

    from ._interop import is_dask_collection

    # TODO:
    assert meta.num_planes == 1
    assert meta.axis in ("YX", "YXS")

    encoder = mk_tile_compressor(meta, **encoder_params)
    data = xx.data

    assert is_dask_collection(data)

    token = tokenize(
        data,
        plane_idx,
        meta.axis,
        meta.chunks,
        meta.predictor,
        meta.compression,
        encoder_params,
    )
    scale_idx, sample_idx = plane_idx
    plane_id = "" if scale_idx == 0 else f"_{scale_idx}"
    plane_id += "" if sample_idx == 0 else f"@{sample_idx}"

    name = f"compress{plane_id}-{token}"

    # One task per source dask block, numbered in np.ndindex order; each
    # task's idx records (scale, sample, flat-block-number).
    src_blocks = ((data.name, *idx) for idx in np.ndindex(data.blocks.shape))
    layer = {
        (name, part): (_compress_cog_tile, encoder, src, quote(plane_idx + (part,)))
        for part, src in enumerate(src_blocks)
    }

    graph = HighLevelGraph.from_collections(name, layer, dependencies=[data])
    return Bag(graph, name, len(layer))

0 comments on commit cfab5e9

Please sign in to comment.