diff --git a/odc/geo/_cog.py b/odc/geo/_cog.py index 71cc7acd..3add9cb7 100644 --- a/odc/geo/_cog.py +++ b/odc/geo/_cog.py @@ -5,10 +5,12 @@ """ Write Cloud Optimized GeoTIFFs from xarrays. """ +import itertools import warnings from contextlib import contextmanager +from io import BytesIO from pathlib import Path -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Union from uuid import uuid4 import numpy as np @@ -16,6 +18,8 @@ import xarray as xr from rasterio.shutil import copy as rio_copy # pylint: disable=no-name-in-module +from ._interop import have +from .converters import geotiff_metadata from .geobox import GeoBox from .math import align_down_pow2, align_up from .types import MaybeNodata, Shape2d, SomeShape, Unset, shape_, wh_ @@ -23,6 +27,8 @@ # pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements +AxisOrder = Union[Literal["YX"], Literal["YXS"], Literal["SYX"]] + def _without(cfg: Dict[str, Any], *skip: str) -> Dict[str, Any]: skip = set(skip) @@ -60,6 +66,15 @@ def _adjust_blocksize(block: int, dim: int = 0) -> int: return align_up(block, 16) +def _norm_blocksize(block: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + if isinstance(block, int): + block = _adjust_blocksize(block) + return (block, block) + + b1, b2 = map(_adjust_blocksize, block) + return (b1, b2) + + def _num_overviews(block: int, dim: int) -> int: c = 0 while block < dim: @@ -480,3 +495,127 @@ def write_cog_layers( else: rio_copy(temp_fname, dst, copy_src_overviews=True, **rio_opts) return Path(dst) + + +def _yaxis_from_shape( + shape: Tuple[int, ...], gbox: Optional[GeoBox] = None +) -> Tuple[AxisOrder, int]: + ndim = len(shape) + + if ndim == 2: + return "YX", 0 + + if ndim != 3: + raise ValueError("Can only work with 2-d or 3-d data") + if shape[-1] in (3, 4): # YXS in RGB(A) + return "YXS", 0 + + if gbox is None: + return "SYX", 1 + if gbox.shape == shape[:2]: # YXS + return "YXS", 0 + if gbox.shape == shape[1:]: # SYX + return "SYX", 1 + + raise ValueError("Geobox and image shape do not match") + + +def make_empty_cog( + shape: Tuple[int, ...], + dtype: Any, + gbox: Optional[GeoBox] = None, + *, + nodata: MaybeNodata = None, + gdal_metadata: Optional[str] = None, + compression: Union[str, Unset] = Unset(), + predictor: Union[int, str, bool, Unset] = Unset(), + blocksize: Union[int, List[Union[int, Tuple[int, int]]]] = 2048, + **kw, +) -> memoryview: + # pylint: disable=import-outside-toplevel + have.check_or_error("tifffile", "rasterio", "xarray") + from tifffile import ( + COMPRESSION, + FILETYPE, + PHOTOMETRIC, + PLANARCONFIG, + TIFF, + TiffWriter, + ) + + if isinstance(compression, Unset): + compression = str(kw.pop("compress", "ADOBE_DEFLATE")) + compression = compression.upper() + compression = {"DEFLATE": "ADOBE_DEFLATE"}.get(compression, compression) + compression = COMPRESSION[compression] + + if isinstance(predictor, Unset): + predictor = compression not in TIFF.IMAGE_COMPRESSIONS + + if isinstance(blocksize, int): + blocksize = [blocksize] + + ax, yaxis = _yaxis_from_shape(shape, gbox) + im_shape = shape_(shape[yaxis : yaxis + 2]) + photometric = PHOTOMETRIC.MINISBLACK + planarconfig = PLANARCONFIG.SEPARATE + if ax == "YX": + nsamples = 1 + elif ax == "YXS": + nsamples = shape[-1] + planarconfig = PLANARCONFIG.CONTIG + if nsamples in (3, 4): + photometric = PHOTOMETRIC.RGB + else: + nsamples = shape[0] + + extratags: List[Tuple[int, int, int, Any]] = [] + if gbox is not None: + extratags, _ = geotiff_metadata( + gbox, nodata=nodata, gdal_metadata=gdal_metadata + ) + + buf = BytesIO() + + opts_common = { + "dtype": dtype, + "photometric": photometric, + "planarconfig": planarconfig, + "predictor": predictor, + "compression": compression, + "software": False, + **kw, + } + + def _sh(shape: Shape2d) -> Tuple[int, ...]: + if ax == "YX": + return shape.shape + if ax == "YXS": + return (*shape.shape, nsamples) + return (nsamples, *shape.shape) + + tsz = _norm_blocksize(blocksize[-1]) + im_shape, _, nlevels = _compute_cog_spec(im_shape, tsz) + + _blocks = itertools.chain(iter(blocksize), itertools.repeat(blocksize[-1])) + + tw = TiffWriter(buf, bigtiff=True) + + for tsz, idx in zip(_blocks, range(nlevels + 1)): + if idx == 0: + kw = {**opts_common, "extratags": extratags} + else: + kw = {**opts_common, "subfiletype": FILETYPE.REDUCEDIMAGE} + + tw.write( + itertools.repeat(b""), + shape=_sh(im_shape), + tile=_norm_blocksize(tsz), + **kw, + ) + + im_shape = im_shape.shrink2() + + tw.close() + + return buf.getbuffer() diff --git a/tests/test_cog.py b/tests/test_cog.py index ce61d000..74aec53b 100644 --- a/tests/test_cog.py +++ b/tests/test_cog.py @@ -1,9 +1,12 @@ +import itertools +from io import BytesIO from typing import Optional, Tuple import pytest -from odc.geo._cog import _compute_cog_spec, _num_overviews +from odc.geo._cog import _compute_cog_spec, _num_overviews, make_empty_cog from odc.geo.gridspec import GridSpec +from odc.geo.types import Unset from odc.geo.xr import xr_zeros @@ -87,3 +90,82 @@ def test_cog_spec( if max_pad is not None: assert _shape[0] - shape[0] <= max_pad assert _shape[1] - shape[1] <= max_pad + + +@pytest.mark.parametrize( + "shape, blocksize, expect_ax", + [ + ((800, 600), [400, 200], "YX"), + ((800, 600, 3), [400, 200], "YXS"), + ((800, 600, 4), [400, 200], "YXS"), + ((2, 800, 600), [400, 200], "SYX"), + ((160, 30), 16, "YX"), + ((160, 30, 5), 16, "YXS"), + ], +) +@pytest.mark.parametrize( + "dtype, compression, expect_predictor", + [ + ("int16", "deflate", 2), + ("int16", "zstd", 2), + ("uint8", "webp", 1), + ("float32", Unset(), 3), + ], +) +def test_empty_cog(shape, blocksize, expect_ax, dtype, compression, expect_predictor): + tifffile = pytest.importorskip("tifffile") + gbox = GridSpec.web_tiles(0)[0, 0] + if expect_ax == "SYX": + gbox = gbox.zoom_to(shape[1:]) + assert gbox.shape == shape[1:] + else: + gbox = gbox.zoom_to(shape[:2]) + assert gbox.shape == shape[:2] + + mm = make_empty_cog( + shape, + dtype, + gbox=gbox, + blocksize=blocksize, + compression=compression, + ) + assert isinstance(mm, memoryview) + + f = tifffile.TiffFile(BytesIO(mm)) + assert f.tiff.is_bigtiff + + p = f.pages[0] + assert p.shape[0] >= shape[0] + assert p.shape[1] >= shape[1] + assert p.dtype == dtype + assert p.axes == expect_ax + assert p.predictor == expect_predictor + + if isinstance(compression, str): + compression = compression.upper() + compression = {"DEFLATE": "ADOBE_DEFLATE"}.get(compression, compression) + assert p.compression.name == compression + else: + # should default to deflate + assert p.compression == 8 + + if expect_ax == "YX": + assert f.pages[-1].chunked == (1, 1) + elif expect_ax == "YXS": + assert f.pages[-1].chunked[:2] == (1, 1) + elif expect_ax == "SYX": + assert f.pages[-1].chunked[1:] == (1, 1) + + if not isinstance(blocksize, list): + blocksize = [blocksize] + + _blocks = itertools.chain(iter(blocksize), itertools.repeat(blocksize[-1])) + for p, tsz in zip(f.pages, _blocks): + if isinstance(tsz, int): + tsz = (tsz, tsz) + + assert p.chunks[0] % 16 == 0 + assert p.chunks[1] % 16 == 0 + + assert tsz[0] <= p.chunks[0] < tsz[0] + 16 + assert tsz[1] <= p.chunks[1] < tsz[1] + 16