Skip to content

Commit cfab5e9

Browse files
committed
amend me: cog compressor
1 parent 8b3a7e6 commit cfab5e9

File tree

1 file changed

+157
-3
lines changed

1 file changed

+157
-3
lines changed

odc/geo/_cog.py

Lines changed: 157 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,22 @@
99
import warnings
1010
from contextlib import contextmanager
1111
from dataclasses import dataclass, field
12+
from functools import partial
1213
from io import BytesIO
1314
from pathlib import Path
14-
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Union
15+
from typing import (
16+
Any,
17+
Callable,
18+
Dict,
19+
Generator,
20+
Iterable,
21+
Iterator,
22+
List,
23+
Literal,
24+
Optional,
25+
Tuple,
26+
Union,
27+
)
1528
from uuid import uuid4
1629

1730
import numpy as np
@@ -62,6 +75,38 @@ def chunks(self) -> Tuple[int, ...]:
6275
def pix_shape(self) -> Tuple[int, ...]:
6376
return self._pix_shape(self.shape)
6477

78+
@property
79+
def num_planes(self):
80+
if self.axis == "SYX":
81+
return self.nsamples
82+
return 1
83+
84+
def flatten(self) -> Tuple["CogMeta", ...]:
85+
return (self, *self.overviews)
86+
87+
@property
88+
def chunked(self) -> Shape2d:
89+
"""
90+
Shape in chunks.
91+
"""
92+
ny, nx = ((N + n - 1) // n for N, n in zip(self.shape.yx, self.tile.yx))
93+
return shape_((ny, nx))
94+
95+
@property
96+
def num_tiles(self):
97+
ny, nx = self.chunked.yx
98+
return self.num_planes * ny * nx
99+
100+
def tidx(self) -> Iterator[Tuple[int, int, int]]:
101+
"""``[(plane_idx, iy, ix), ...]``"""
102+
yield from np.ndindex((self.num_planes, *self.chunked.yx))
103+
104+
def cog_tidx(self) -> Iterator[Tuple[int, int, int, int]]:
105+
"""``[(ifd_idx, plane_idx, iy, ix), ...]``"""
106+
idx_layers = list(enumerate(self.flatten()))[::-1]
107+
for idx, mm in idx_layers:
108+
yield from ((idx, pi, yi, xi) for pi, yi, xi in mm.tidx())
109+
65110

66111
def _without(cfg: Dict[str, Any], *skip: str) -> Dict[str, Any]:
67112
skip = set(skip)
@@ -685,12 +730,11 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
685730
for tsz, idx in zip(_blocks, range(nlevels + 1)):
686731
tile = _norm_blocksize(tsz)
687732
meta = CogMeta(
688-
ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor
733+
ax, im_shape, shape_(tile), nsamples, dtype, _compression, predictor, gbox
689734
)
690735

691736
if idx == 0:
692737
kw = {**opts_common, "extratags": extratags}
693-
meta.gbox = gbox
694738
else:
695739
kw = {**opts_common, "subfiletype": FILETYPE.REDUCEDIMAGE}
696740

@@ -703,10 +747,120 @@ def _sh(shape: Shape2d) -> Tuple[int, ...]:
703747

704748
metas.append(meta)
705749
im_shape = im_shape.shrink2()
750+
if gbox is not None:
751+
gbox = gbox.zoom_to(im_shape)
706752

707753
meta = metas[0]
708754
meta.overviews = tuple(metas[1:])
709755

710756
tw.close()
711757

712758
return meta, buf.getbuffer()
759+
760+
761+
def _cog_block_compressor(
762+
block: np.ndarray,
763+
*,
764+
tile_shape: Tuple[int, ...] = (),
765+
encoder: Any = None,
766+
predictor: Any = None,
767+
axis: int = 1,
768+
**kw,
769+
) -> bytes:
770+
assert block.ndim == len(tile_shape)
771+
if tile_shape != block.shape:
772+
pad = tuple((0, want - have) for want, have in zip(tile_shape, block.shape))
773+
block = np.pad(block, pad, "constant", constant_values=(0,))
774+
775+
if predictor is not None:
776+
block = predictor(block, axis=axis)
777+
if encoder:
778+
return encoder(block, **kw)
779+
return block.data
780+
781+
782+
def mk_tile_compressor(
783+
meta: CogMeta, **encoder_params
784+
) -> Callable[[np.ndarray], bytes]:
785+
"""
786+
Make tile compressor.
787+
788+
"""
789+
# pylint: disable=import-outside-toplevel
790+
have.check_or_error("tifffile")
791+
from tifffile import TIFF
792+
793+
tile_shape = meta.chunks
794+
encoder = TIFF.COMPRESSORS[meta.compression]
795+
796+
# TODO: handle SYX in planar mode
797+
axis = 1
798+
predictor = None
799+
if meta.predictor != 1:
800+
predictor = TIFF.PREDICTORS[meta.predictor]
801+
802+
return partial(
803+
_cog_block_compressor,
804+
tile_shape=tile_shape,
805+
encoder=encoder,
806+
predictor=predictor,
807+
axis=axis,
808+
**encoder_params,
809+
)
810+
811+
812+
def _compress_cog_tile(encoder, block, idx):
813+
return [{"data": encoder(block), "idx": idx}]
814+
815+
816+
def compress_tiles(
817+
xx: xr.DataArray,
818+
meta: CogMeta,
819+
plane_idx: Tuple[int, int] = (0, 0),
820+
**encoder_params,
821+
):
822+
"""
823+
Compress chunks according to cog spec.
824+
825+
:returns: Dask bag of dicts with ``{"data": bytes, "idx": (int, int, int)}``
826+
"""
827+
# pylint: disable=import-outside-toplevel
828+
have.check_or_error("dask")
829+
from dask.bag import Bag
830+
from dask.base import quote, tokenize
831+
from dask.highlevelgraph import HighLevelGraph
832+
833+
from ._interop import is_dask_collection
834+
835+
# TODO:
836+
assert meta.num_planes == 1
837+
assert meta.axis in ("YX", "YXS")
838+
839+
encoder = mk_tile_compressor(meta, **encoder_params)
840+
data = xx.data
841+
842+
assert is_dask_collection(data)
843+
844+
tk = tokenize(
845+
data,
846+
plane_idx,
847+
meta.axis,
848+
meta.chunks,
849+
meta.predictor,
850+
meta.compression,
851+
encoder_params,
852+
)
853+
scale_idx, sample_idx = plane_idx
854+
plane_id = "" if scale_idx == 0 else f"_{scale_idx}"
855+
plane_id += "" if sample_idx == 0 else f"@{sample_idx}"
856+
857+
name = f"compress{plane_id}-{tk}"
858+
859+
dsk = {}
860+
block_names = ((data.name, *idx) for idx in np.ndindex(data.blocks.shape))
861+
for i, b in enumerate(block_names):
862+
dsk[name, i] = (_compress_cog_tile, encoder, b, quote(plane_idx + (i,)))
863+
864+
nparts = len(dsk)
865+
dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[data])
866+
return Bag(dsk, name, nparts)

0 commit comments

Comments
 (0)