Skip to content

Commit

Permalink
Adding make_empty_cog function
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Sep 14, 2023
1 parent 9f4cb64 commit 405d0ff
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 2 deletions.
141 changes: 140 additions & 1 deletion odc/geo/_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,30 @@
"""
Write Cloud Optimized GeoTIFFs from xarrays.
"""
import itertools
import warnings
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4

import numpy as np
import rasterio
import xarray as xr
from rasterio.shutil import copy as rio_copy # pylint: disable=no-name-in-module

from ._interop import have
from .converters import geotiff_metadata
from .geobox import GeoBox
from .math import align_down_pow2, align_up
from .types import MaybeNodata, Shape2d, SomeShape, Unset, shape_, wh_
from .warp import resampling_s2rio

# pylint: disable=too-many-locals,too-many-branches,too-many-arguments,too-many-statements

# Axis layouts understood by this module: "YX" is a single 2-d plane,
# "YXS" keeps samples on the last axis, "SYX" keeps samples on the first axis.
AxisOrder = Union[Literal["YX"], Literal["YXS"], Literal["SYX"]]


def _without(cfg: Dict[str, Any], *skip: str) -> Dict[str, Any]:
skip = set(skip)
Expand Down Expand Up @@ -60,6 +66,15 @@ def _adjust_blocksize(block: int, dim: int = 0) -> int:
return align_up(block, 16)


def _norm_blocksize(block: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """
    Normalize a tile size into a pair of adjusted block sizes.

    A single integer is passed through ``_adjust_blocksize`` once and used
    for both elements of the result. A pair has each of its two elements
    adjusted independently.
    """
    if not isinstance(block, int):
        # Unpacking enforces that exactly two sizes were supplied.
        first, second = (_adjust_blocksize(side) for side in block)
        return (first, second)

    side = _adjust_blocksize(block)
    return (side, side)


def _num_overviews(block: int, dim: int) -> int:
c = 0
while block < dim:
Expand Down Expand Up @@ -480,3 +495,127 @@ def write_cog_layers(
else:
rio_copy(temp_fname, dst, copy_src_overviews=True, **rio_opts)
return Path(dst)


def _yaxis_from_shape(
    shape: Tuple[int, ...], gbox: Optional[GeoBox] = None
) -> Tuple[AxisOrder, int]:
    """
    Work out the axis order and the index of the Y axis from an array shape.

    :param shape: Shape of the image data, 2 or 3 dimensions.
    :param gbox: Optional geobox used to disambiguate 3-d shapes.
    :returns: ``(axis_order, yaxis)`` where ``yaxis`` is the position of the
       Y axis within ``shape``.
    :raises ValueError: when ``shape`` is not 2-d or 3-d, or when a geobox is
       supplied and matches neither the leading nor the trailing two axes.
    """
    n_axes = len(shape)

    if n_axes == 2:
        return "YX", 0

    if n_axes != 3:
        raise ValueError("Can only work with 2-d or 3-d data")

    # A trailing axis of 3 or 4 is taken to be RGB(A) samples; this check
    # runs before the geobox is consulted.
    trailing = shape[-1]
    if trailing in (3, 4):
        return "YXS", 0

    if gbox is not None:
        gbox_shape = gbox.shape
        if gbox_shape == shape[:2]:  # samples last
            return "YXS", 0
        if gbox_shape == shape[1:]:  # samples first
            return "SYX", 1
        raise ValueError("Geobox and image shape do not match")

    # Nothing to disambiguate with: assume samples come first.
    return "SYX", 1


def make_empty_cog(
    shape: Tuple[int, ...],
    dtype: Any,
    gbox: Optional[GeoBox] = None,
    *,
    nodata: MaybeNodata = None,
    gdal_metadata: Optional[str] = None,
    compression: Union[str, Unset] = Unset(),
    predictor: Union[int, str, bool, Unset] = Unset(),
    blocksize: Union[int, List[Union[int, Tuple[int, int]]]] = 2048,
    **kw,
) -> memoryview:
    """
    Build an in-memory tiled BigTIFF with overview pages but no pixel data.

    One page is written for the full resolution image followed by one page
    per overview level. Every page is given an endless iterator of empty
    byte strings as its tile data, so only the file structure and tags are
    produced.

    :param shape: Shape of the image, 2-d ``(Y, X)`` or 3-d (``YXS``/``SYX``).
    :param dtype: Pixel data type, forwarded to ``tifffile``.
    :param gbox: Optional geobox. When supplied it is used to resolve the axis
       order of 3-d shapes and to generate geo tags for the first page.
    :param nodata: Forwarded to ``geotiff_metadata``; unused when ``gbox`` is
       ``None``.
    :param gdal_metadata: Forwarded to ``geotiff_metadata``; unused when
       ``gbox`` is ``None``.
    :param compression: Compression name (case-insensitive, ``"DEFLATE"`` is
       mapped to ``"ADOBE_DEFLATE"``). When unset, falls back to ``compress=``
       from ``kw`` and then to ``"ADOBE_DEFLATE"``.
    :param predictor: Predictor setting passed to ``tifffile``. When unset it
       is enabled unless the compression is listed in
       ``TIFF.IMAGE_COMPRESSIONS``.
    :param blocksize: Tile size, or a list of tile sizes, one per page
       starting at full resolution. The last entry is reused for any
       remaining overview levels and also drives the overview layout.
    :param kw: Extra options passed to every ``TiffWriter.write`` call; these
       take precedence over the options computed here.
    :returns: ``memoryview`` over the buffer holding the TIFF file.
    :raises ValueError: when ``blocksize`` is an empty list, or when the axis
       order can not be determined from ``shape`` and ``gbox``.
    """
    # pylint: disable=import-outside-toplevel
    have.check_or_error("tifffile", "rasterio", "xarray")
    from tifffile import (
        COMPRESSION,
        FILETYPE,
        PHOTOMETRIC,
        PLANARCONFIG,
        TIFF,
        TiffWriter,
    )

    if isinstance(compression, Unset):
        compression = str(kw.pop("compress", "ADOBE_DEFLATE"))
    codec_name = compression.upper()
    codec_name = {"DEFLATE": "ADOBE_DEFLATE"}.get(codec_name, codec_name)
    codec = COMPRESSION[codec_name]

    if isinstance(predictor, Unset):
        predictor = codec not in TIFF.IMAGE_COMPRESSIONS

    if isinstance(blocksize, int):
        blocksize = [blocksize]
    if not blocksize:
        # Fail early with a clear message instead of an IndexError below.
        raise ValueError("blocksize must not be an empty list")

    ax, yaxis = _yaxis_from_shape(shape, gbox)
    im_shape = shape_(shape[yaxis : yaxis + 2])
    photometric = PHOTOMETRIC.MINISBLACK
    planarconfig = PLANARCONFIG.SEPARATE
    if ax == "YX":
        nsamples = 1
    elif ax == "YXS":
        nsamples = shape[-1]
        planarconfig = PLANARCONFIG.CONTIG
        if nsamples in (3, 4):
            photometric = PHOTOMETRIC.RGB
    else:
        nsamples = shape[0]

    extratags: List[Tuple[int, int, int, Any]] = []
    if gbox is not None:
        extratags, _ = geotiff_metadata(
            gbox, nodata=nodata, gdal_metadata=gdal_metadata
        )

    # Caller supplied ``kw`` goes last so that it can override the defaults.
    opts_common = {
        "dtype": dtype,
        "photometric": photometric,
        "planarconfig": planarconfig,
        "predictor": predictor,
        "compression": codec,
        "software": False,
        **kw,
    }

    def _sh(yx: Shape2d) -> Tuple[int, ...]:
        # Re-attach the sample axis in the position dictated by the axis order.
        if ax == "YX":
            return yx.shape
        if ax == "YXS":
            return (*yx.shape, nsamples)
        return (nsamples, *yx.shape)

    # Overview layout is derived from the last (repeated) tile size.
    im_shape, _, nlevels = _compute_cog_spec(im_shape, _norm_blocksize(blocksize[-1]))

    # Per-page tile sizes: those supplied, then the last one repeated forever.
    level_blocks = itertools.chain(iter(blocksize), itertools.repeat(blocksize[-1]))

    buf = BytesIO()

    # Context manager guarantees the writer is closed even if a write fails.
    with TiffWriter(buf, bigtiff=True) as tw:
        for level, level_block in zip(range(nlevels + 1), level_blocks):
            if level == 0:
                # Geo tags only go on the full resolution page.
                page_opts = {**opts_common, "extratags": extratags}
            else:
                page_opts = {**opts_common, "subfiletype": FILETYPE.REDUCEDIMAGE}

            tw.write(
                itertools.repeat(b""),
                shape=_sh(im_shape),
                tile=_norm_blocksize(level_block),
                **page_opts,
            )

            im_shape = im_shape.shrink2()

    return buf.getbuffer()
84 changes: 83 additions & 1 deletion tests/test_cog.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import itertools
from io import BytesIO
from typing import Optional, Tuple

import pytest

from odc.geo._cog import _compute_cog_spec, _num_overviews
from odc.geo._cog import _compute_cog_spec, _num_overviews, make_empty_cog
from odc.geo.gridspec import GridSpec
from odc.geo.types import Unset
from odc.geo.xr import xr_zeros


Expand Down Expand Up @@ -87,3 +90,82 @@ def test_cog_spec(
if max_pad is not None:
assert _shape[0] - shape[0] <= max_pad
assert _shape[1] - shape[1] <= max_pad


@pytest.mark.parametrize(
    "shape, blocksize, expect_ax",
    [
        ((800, 600), [400, 200], "YX"),
        ((800, 600, 3), [400, 200], "YXS"),
        ((800, 600, 4), [400, 200], "YXS"),
        ((2, 800, 600), [400, 200], "SYX"),
        ((160, 30), 16, "YX"),
        ((160, 30, 5), 16, "YXS"),
    ],
)
@pytest.mark.parametrize(
    "dtype, compression, expect_predictor",
    [
        ("int16", "deflate", 2),
        ("int16", "zstd", 2),
        ("uint8", "webp", 1),
        ("float32", Unset(), 3),
    ],
)
def test_empty_cog(shape, blocksize, expect_ax, dtype, compression, expect_predictor):
    """
    Check structure of the file produced by ``make_empty_cog``.

    Verifies BigTIFF flag, first page shape/dtype/axes/predictor/compression,
    that the last page is a single chunk spatially, and that tile sizes on
    every page are multiples of 16 within 16 of the requested size.
    """
    tifffile = pytest.importorskip("tifffile")

    # Spatial part of the shape depends on where the sample axis sits.
    yx_shape = shape[1:] if expect_ax == "SYX" else shape[:2]
    gbox = GridSpec.web_tiles(0)[0, 0].zoom_to(yx_shape)
    assert gbox.shape == yx_shape

    cog_bytes = make_empty_cog(
        shape,
        dtype,
        gbox=gbox,
        blocksize=blocksize,
        compression=compression,
    )
    assert isinstance(cog_bytes, memoryview)

    tiff = tifffile.TiffFile(BytesIO(cog_bytes))
    assert tiff.tiff.is_bigtiff

    first_page = tiff.pages[0]
    assert first_page.shape[0] >= shape[0]
    assert first_page.shape[1] >= shape[1]
    assert first_page.dtype == dtype
    assert first_page.axes == expect_ax
    assert first_page.predictor == expect_predictor

    if not isinstance(compression, str):
        # should default to deflate
        assert first_page.compression == 8
    else:
        expect_name = compression.upper()
        expect_name = {"DEFLATE": "ADOBE_DEFLATE"}.get(expect_name, expect_name)
        assert first_page.compression.name == expect_name

    # Select the spatial axes of ``chunked`` for the smallest overview.
    yx_axes = {
        "YX": slice(None),
        "YXS": slice(None, 2),
        "SYX": slice(1, None),
    }.get(expect_ax)
    if yx_axes is not None:
        assert tiff.pages[-1].chunked[yx_axes] == (1, 1)

    block_list = blocksize if isinstance(blocksize, list) else [blocksize]
    per_page = itertools.chain(iter(block_list), itertools.repeat(block_list[-1]))

    for page, requested in zip(tiff.pages, per_page):
        want = (requested, requested) if isinstance(requested, int) else requested

        for axis in (0, 1):
            assert page.chunks[axis] % 16 == 0

        for axis in (0, 1):
            assert want[axis] <= page.chunks[axis] < want[axis] + 16

0 comments on commit 405d0ff

Please sign in to comment.