diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 47d88a33..b77ca7de 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -72,8 +72,7 @@ jobs: - name: Initialize mypy run: | - mypy . > /dev/null || true - mypy --install-types --non-interactive + mypy --install-types --non-interactive . || true - name: Run tests run: | diff --git a/docs/api.rst b/docs/api.rst index 1845dd43..58e70650 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -15,30 +15,27 @@ Get a driver instance .. autofunction:: terracotta.get_driver -SQLite driver -------------- +TerracottaDriver +---------------- -.. autoclass:: terracotta.drivers.sqlite.SQLiteDriver +.. autoclass:: terracotta.drivers.TerracottaDriver :members: - :undoc-members: - :special-members: __init__ - :inherited-members: -Remote SQLite driver --------------------- -.. autoclass:: terracotta.drivers.sqlite_remote.RemoteSQLiteDriver - :members: - :undoc-members: - :special-members: __init__ - :inherited-members: - :exclude-members: delete, insert, create +Supported metadata stores +------------------------- -MySQL driver ------------- +SQLite metadata store ++++++++++++++++++++++ -.. autoclass:: terracotta.drivers.mysql.MySQLDriver - :members: - :undoc-members: - :special-members: __init__ - :inherited-members: +.. autoclass:: terracotta.drivers.sqlite_meta_store.SQLiteMetaStore + +Remote SQLite metadata store +++++++++++++++++++++++++++++ + +.. autoclass:: terracotta.drivers.sqlite_remote_meta_store.RemoteSQLiteMetaStore + +MySQL metadata store +++++++++++++++++++++ + +.. autoclass:: terracotta.drivers.mysql_meta_store.MySQLMetaStore diff --git a/setup.py b/setup.py index 9959aa84..bf34bf66 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Framework :: Flask', 'Operating System :: Microsoft :: Windows :: Windows 10', 'Operating System :: MacOS :: MacOS X', @@ -72,6 +73,7 @@ 'shapely', 'rasterio>=1.0', 'shapely', + 'sqlalchemy', 'toml', 'tqdm' ], diff --git a/terracotta/cog.py b/terracotta/cog.py index 14470604..08d612e4 100644 --- a/terracotta/cog.py +++ b/terracotta/cog.py @@ -25,7 +25,7 @@ def validate(src_path: str, strict: bool = True) -> bool: def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover """ Implementation from - https://github.com/cogeotiff/rio-cogeo/blob/0f00a6ee1eff602014fbc88178a069bd9f4a10da/rio_cogeo/cogeo.py + https://github.com/cogeotiff/rio-cogeo/blob/a07d914e2d898878417638bbc089179f01eb5b28/rio_cogeo/cogeo.py#L385 This function is the rasterio equivalent of https://svn.osgeo.org/gdal/trunk/gdal/swig/python/samples/validate_cloud_optimized_geotiff.py @@ -44,15 +44,13 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover errors.append('The file is not a GeoTIFF') return errors, warnings, details - filelist = [os.path.basename(f) for f in src.files] - src_bname = os.path.basename(src_path) - if len(filelist) > 1 and src_bname + '.ovr' in filelist: + if any(os.path.splitext(x)[-1] == '.ovr' for x in src.files): errors.append( 'Overviews found in external .ovr file. 
They should be internal' ) overviews = src.overviews(1) - if src.width >= 512 or src.height >= 512: + if src.width > 512 and src.height > 512: if not src.is_tiled: errors.append( 'The file is greater than 512xH or 512xW, but is not tiled' @@ -65,16 +63,28 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover ) ifd_offset = int(src.get_tag_item('IFD_OFFSET', 'TIFF', bidx=1)) - ifd_offsets = [ifd_offset] + # Starting from GDAL 3.1, GeoTIFF and COG have ghost headers + # e.g: + # """ + # GDAL_STRUCTURAL_METADATA_SIZE=000140 bytes + # LAYOUT=IFDS_BEFORE_DATA + # BLOCK_ORDER=ROW_MAJOR + # BLOCK_LEADER=SIZE_AS_UINT4 + # BLOCK_TRAILER=LAST_4_BYTES_REPEATED + # KNOWN_INCOMPATIBLE_EDITION=NO + # """ + # + # This header should be < 200bytes if ifd_offset > 300: errors.append( f'The offset of the main IFD should be < 300. It is {ifd_offset} instead' ) + ifd_offsets = [ifd_offset] details['ifd_offsets'] = {} details['ifd_offsets']['main'] = ifd_offset - if not overviews == sorted(overviews): + if overviews and overviews != sorted(overviews): errors.append('Overviews should be sorted') for ix, dec in enumerate(overviews): @@ -111,9 +121,7 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover ) ) - block_offset = int(src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1)) - if not block_offset: - errors.append('Missing BLOCK_OFFSET_0_0') + block_offset = src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1) data_offset = int(block_offset) if block_offset else 0 data_offsets = [data_offset] @@ -121,13 +129,14 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover details['data_offsets']['main'] = data_offset for ix, dec in enumerate(overviews): - data_offset = int( - src.get_tag_item('BLOCK_OFFSET_0_0', 'TIFF', bidx=1, ovr=ix) + block_offset = src.get_tag_item( + 'BLOCK_OFFSET_0_0', 'TIFF', bidx=1, ovr=ix ) + data_offset = int(block_offset) if block_offset else 0 data_offsets.append(data_offset) details['data_offsets']['overview_{}'.format(ix)] = data_offset - if data_offsets[-1] < ifd_offsets[-1]: + if data_offsets[-1] != 0 and data_offsets[-1] < ifd_offsets[-1]: if len(overviews) > 0: errors.append( 'The offset of the first block of the smallest overview ' @@ -156,7 +165,7 @@ def check_raster_file(src_path: str) -> ValidationInfo: # pragma: no cover for ix, dec in enumerate(overviews): with rasterio.open(src_path, OVERVIEW_LEVEL=ix) as ovr_dst: - if ovr_dst.width >= 512 or ovr_dst.height >= 512: + if ovr_dst.width > 512 and ovr_dst.height > 512: if not ovr_dst.is_tiled: errors.append('Overview of index {} is not tiled'.format(ix)) diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index c8fd7415..26f38988 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -3,33 +3,36 @@ Define an interface to retrieve Terracotta drivers. 
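A minimal usage sketch (paths are placeholders; the provider is auto-detected from the path
unless passed explicitly):

    >>> import terracotta as tc
    >>> tc.get_driver('tc.sqlite')                    # local file -> 'sqlite' provider
    >>> tc.get_driver('s3://bucket/tc.sqlite')        # S3 URL -> 'sqlite-remote' provider
    >>> tc.get_driver('mysql://root@localhost/tc')    # MySQL URL -> 'mysql' provider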
""" +import os from typing import Union, Tuple, Dict, Type import urllib.parse as urlparse from pathlib import Path -from terracotta.drivers.base import Driver +from terracotta.drivers.base_classes import MetaStore +from terracotta.drivers.terracotta_driver import TerracottaDriver +from terracotta.drivers.geotiff_raster_store import GeoTiffRasterStore URLOrPathType = Union[str, Path] -def load_driver(provider: str) -> Type[Driver]: +def load_driver(provider: str) -> Type[MetaStore]: if provider == 'sqlite-remote': - from terracotta.drivers.sqlite_remote import RemoteSQLiteDriver - return RemoteSQLiteDriver + from terracotta.drivers.sqlite_remote_meta_store import RemoteSQLiteMetaStore + return RemoteSQLiteMetaStore if provider == 'mysql': - from terracotta.drivers.mysql import MySQLDriver - return MySQLDriver + from terracotta.drivers.mysql_meta_store import MySQLMetaStore + return MySQLMetaStore if provider == 'sqlite': - from terracotta.drivers.sqlite import SQLiteDriver - return SQLiteDriver + from terracotta.drivers.sqlite_meta_store import SQLiteMetaStore + return SQLiteMetaStore raise ValueError(f'Unknown database provider {provider}') -def auto_detect_provider(url_or_path: Union[str, Path]) -> str: - parsed_path = urlparse.urlparse(str(url_or_path)) +def auto_detect_provider(url_or_path: str) -> str: + parsed_path = urlparse.urlparse(url_or_path) scheme = parsed_path.scheme if scheme == 's3': @@ -41,10 +44,10 @@ def auto_detect_provider(url_or_path: Union[str, Path]) -> str: return 'sqlite' -_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str], Driver] = {} +_DRIVER_CACHE: Dict[Tuple[URLOrPathType, str, int], TerracottaDriver] = {} -def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: +def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDriver: """Retrieve Terracotta driver instance for the given path. This function always returns the same instance for identical inputs. 
@@ -65,25 +68,37 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> Driver: >>> import terracotta as tc >>> tc.get_driver('tc.sqlite') - SQLiteDriver('/home/terracotta/tc.sqlite') + TerracottaDriver( + meta_store=SQLiteDriver('/home/terracotta/tc.sqlite'), + raster_store=GeoTiffRasterStore() + ) >>> tc.get_driver('mysql://root@localhost/tc') - MySQLDriver('mysql://root@localhost:3306/tc') + TerracottaDriver( + meta_store=MySQLDriver('mysql+pymysql://localhost:3306/tc'), + raster_store=GeoTiffRasterStore() + ) >>> # pass provider if path is given in a non-standard way >>> tc.get_driver('root@localhost/tc', provider='mysql') - MySQLDriver('mysql://root@localhost:3306/tc') + TerracottaDriver( + meta_store=MySQLDriver('mysql+pymysql://localhost:3306/tc'), + raster_store=GeoTiffRasterStore() + ) """ + url_or_path = str(url_or_path) + if provider is None: # try and auto-detect provider = auto_detect_provider(url_or_path) - if isinstance(url_or_path, Path) or provider == 'sqlite': - url_or_path = str(Path(url_or_path).resolve()) - DriverClass = load_driver(provider) normalized_path = DriverClass._normalize_path(url_or_path) - cache_key = (normalized_path, provider) + cache_key = (normalized_path, provider, os.getpid()) if cache_key not in _DRIVER_CACHE: - _DRIVER_CACHE[cache_key] = DriverClass(url_or_path) + driver = TerracottaDriver( + meta_store=DriverClass(url_or_path), + raster_store=GeoTiffRasterStore() + ) + _DRIVER_CACHE[cache_key] = driver return _DRIVER_CACHE[cache_key] diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py deleted file mode 100644 index 37cd738b..00000000 --- a/terracotta/drivers/base.py +++ /dev/null @@ -1,183 +0,0 @@ -"""drivers/base.py - -Base class for drivers. -""" - -from typing import Callable, List, Mapping, Any, Tuple, Sequence, Dict, Union, TypeVar -from abc import ABC, abstractmethod -from collections import OrderedDict -import functools -import contextlib - -Number = TypeVar('Number', int, float) -T = TypeVar('T') - - -def requires_connection(fun: Callable[..., T]) -> Callable[..., T]: - @functools.wraps(fun) - def inner(self: Driver, *args: Any, **kwargs: Any) -> T: - with self.connect(): - return fun(self, *args, **kwargs) - return inner - - -class Driver(ABC): - """Abstract base class for all Terracotta data backends. - - Defines a common interface for all drivers. - """ - _RESERVED_KEYS = ('limit', 'page') - - db_version: str #: Terracotta version used to create the database - key_names: Tuple[str] #: Names of all keys defined by the database - - @abstractmethod - def __init__(self, url_or_path: str) -> None: - self.path = url_or_path - - @classmethod - def _normalize_path(cls, path: str) -> str: - """Convert given path to normalized version (that can be used for caching)""" - return path - - @abstractmethod - def create(self, keys: Sequence[str], *, - key_descriptions: Mapping[str, str] = None) -> None: - # Create a new, empty database (driver dependent) - pass - - @abstractmethod - def connect(self) -> contextlib.AbstractContextManager: - """Context manager to connect to a given database and clean up on exit. - - This allows you to pool interactions with the database to prevent possibly - expensive reconnects, or to roll back several interactions if one of them fails. - - Note: - - Make sure to call :meth:`create` on a fresh database before using this method. - - Example: - - >>> import terracotta as tc - >>> driver = tc.get_driver('tc.sqlite') - >>> with driver.connect(): - ... 
for keys, dataset in datasets.items(): - ... # connection will be kept open between insert operations - ... driver.insert(keys, dataset) - - """ - pass - - @abstractmethod - def get_keys(self) -> OrderedDict: - """Get all known keys and their fulltext descriptions. - - Returns: - - An :class:`~collections.OrderedDict` in the form - ``{key_name: key_description}`` - - """ - pass - - @abstractmethod - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: - # Get all known dataset key combinations matching the given constraints, - # and a handle to retrieve the data (driver dependent) - pass - - @abstractmethod - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - """Return all stored metadata for given keys. - - Arguments: - - keys: Keys of the requested dataset. Can either be given as a sequence of key values, - or as a mapping ``{key_name: key_value}``. - - Returns: - - A :class:`dict` with the values - - - ``range``: global minimum and maximum value in dataset - - ``bounds``: physical bounds covered by dataset in latitude-longitude projection - - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude - projection - - ``percentiles``: array of pre-computed percentiles from 1% through 99% - - ``mean``: global mean - - ``stdev``: global standard deviation - - ``metadata``: any additional client-relevant metadata - - """ - pass - - @abstractmethod - # TODO: add accurate signature if mypy ever supports conditional return types - def get_raster_tile(self, keys: Union[Sequence[str], Mapping[str, str]], *, - tile_bounds: Sequence[float] = None, - tile_size: Sequence[int] = (256, 256), - preserve_values: bool = False, - asynchronous: bool = False) -> Any: - """Load a raster tile with given keys and bounds. - - Arguments: - - keys: Keys of the requested dataset. Can either be given as a sequence of key values, - or as a mapping ``{key_name: key_value}``. - tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG3857). - Reads the whole dataset if not given. - tile_size: Shape of the output array to return. Must be two-dimensional. - Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`. - preserve_values: Whether to preserve exact numerical values (e.g. when reading - categorical data). Sets all interpolation to nearest neighbor. - asynchronous: If given, the tile will be read asynchronously in a separate thread. - This function will return immediately with a :class:`~concurrent.futures.Future` - that can be used to retrieve the result. - - Returns: - - Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if - ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing - the result. - - """ - pass - - @staticmethod - @abstractmethod - def compute_metadata(data: Any, *, - extra_metadata: Any = None, - **kwargs: Any) -> Dict[str, Any]: - # Compute metadata for a given input file (driver dependent) - pass - - @abstractmethod - def insert(self, keys: Union[Sequence[str], Mapping[str, str]], - handle: Any, **kwargs: Any) -> None: - """Register a new dataset. Used to populate metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - handle: Handle to access dataset (driver dependent). 
- - """ - pass - - @abstractmethod - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - """Remove a dataset from the metadata database. - - Arguments: - - keys: Keys of the dataset. Can either be given as a sequence of key values, or - as a mapping ``{key_name: key_value}``. - - """ - pass - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(\'{self.path}\')' diff --git a/terracotta/drivers/base_classes.py b/terracotta/drivers/base_classes.py new file mode 100644 index 00000000..bae3f206 --- /dev/null +++ b/terracotta/drivers/base_classes.py @@ -0,0 +1,135 @@ +"""drivers/base_classes.py + +Base class for drivers. +""" + +import contextlib +import functools +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, + Tuple, TypeVar, Union) + +KeysType = Mapping[str, str] +MultiValueKeysType = Mapping[str, Union[str, List[str]]] +Number = TypeVar('Number', int, float) +T = TypeVar('T') + + +def requires_connection( + fun: Callable[..., T] = None, *, + verify: bool = True +) -> Union[Callable[..., T], functools.partial]: + if fun is None: + return functools.partial(requires_connection, verify=verify) + + @functools.wraps(fun) + def inner(self: MetaStore, *args: Any, **kwargs: Any) -> T: + assert fun is not None + with self.connect(verify=verify): + return fun(self, *args, **kwargs) + + return inner + + +class MetaStore(ABC): + """Abstract base class for all Terracotta metadata backends. + + Defines a common interface for all metadata backends. + """ + _RESERVED_KEYS = ('limit', 'page') + + @property + @abstractmethod + def db_version(self) -> str: + """Terracotta version used to create the database.""" + pass + + @property + @abstractmethod + def key_names(self) -> Tuple[str, ...]: + """Names of all keys defined by the database.""" + pass + + @abstractmethod + def __init__(self, url_or_path: str) -> None: + self.path = url_or_path + + @classmethod + def _normalize_path(cls, path: str) -> str: + """Convert given path to normalized version (that can be used for caching)""" + return path + + @abstractmethod + def create(self, keys: Sequence[str], *, + key_descriptions: Mapping[str, str] = None) -> None: + """Create a new, empty database""" + pass + + @abstractmethod + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + """Context manager to connect to a given database and clean up on exit. + + This allows you to pool interactions with the database to prevent possibly + expensive reconnects, or to roll back several interactions if one of them fails. + """ + pass + + @abstractmethod + def get_keys(self) -> OrderedDict: + """Get all known keys and their fulltext descriptions.""" + pass + + @abstractmethod + def get_datasets(self, where: MultiValueKeysType = None, + page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + """Get all known dataset key combinations matching the given constraints, + and a path to retrieve the data + """ + pass + + @abstractmethod + def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: + """Return all stored metadata for given keys.""" + pass + + @abstractmethod + def insert(self, keys: KeysType, path: str, *, metadata: Mapping[str, Any] = None) -> None: + """Register a new dataset. 
This also populates the metadata database, + if metadata is specified and not `None`.""" + pass + + @abstractmethod + def delete(self, keys: KeysType) -> None: + """Remove a dataset, including information from the metadata database.""" + pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(\'{self.path}\')' + + +class RasterStore(ABC): + """Abstract base class for all Terracotta raster backends. + + Defines a common interface for all raster backends.""" + + @abstractmethod + # TODO: add accurate signature if mypy ever supports conditional return types + def get_raster_tile(self, path: str, *, + tile_bounds: Sequence[float] = None, + tile_size: Sequence[int] = (256, 256), + preserve_values: bool = False, + asynchronous: bool = False) -> Any: + """Load a raster tile with given path and bounds.""" + pass + + @abstractmethod + def compute_metadata(self, path: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: + """Compute metadata for a given input file""" + pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' diff --git a/terracotta/drivers/geotiff_raster_store.py b/terracotta/drivers/geotiff_raster_store.py new file mode 100644 index 00000000..d187383b --- /dev/null +++ b/terracotta/drivers/geotiff_raster_store.py @@ -0,0 +1,170 @@ +"""drivers/geotiff_raster_store.py + +Base class for drivers operating on physical raster files. +""" + +from typing import Any, Callable, Sequence, Dict, TypeVar +from concurrent.futures import Future, Executor, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool + +import functools +import logging +import warnings +import threading + +import numpy as np + +from terracotta import get_settings +from terracotta import raster +from terracotta.cache import CompressedLFUCache +from terracotta.drivers.base_classes import RasterStore + +Number = TypeVar('Number', int, float) + +logger = logging.getLogger(__name__) + +context = threading.local() +context.executor = None + + +def create_executor() -> Executor: + settings = get_settings() + + if not settings.USE_MULTIPROCESSING: + return ThreadPoolExecutor(max_workers=1) + + executor: Executor + + try: + # this fails on architectures without /dev/shm + executor = ProcessPoolExecutor(max_workers=3) + except OSError: + # fall back to serial evaluation + warnings.warn( + 'Multiprocessing is not available on this system. ' + 'Falling back to serial execution.' + ) + executor = ThreadPoolExecutor(max_workers=1) + + return executor + + +def submit_to_executor(task: Callable[..., Any]) -> Future: + if context.executor is None: + context.executor = create_executor() + + try: + future = context.executor.submit(task) + except BrokenProcessPool: + # re-create executor and try again + logger.warn('Re-creating broken process pool') + context.executor = create_executor() + future = context.executor.submit(task) + + return future + + +def ensure_hashable(val: Any) -> Any: + if isinstance(val, list): + return tuple(val) + + if isinstance(val, dict): + return tuple((k, ensure_hashable(v)) for k, v in val.items()) + + return val + + +class GeoTiffRasterStore(RasterStore): + """Raster store that operates on GeoTiff raster files from disk. + + Path arguments are expected to be file paths. 
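A rough usage sketch (the file path is a placeholder); tiles come back as masked arrays, or as
a :class:`~concurrent.futures.Future` when ``asynchronous=True``:

    >>> store = GeoTiffRasterStore()
    >>> tile = store.get_raster_tile('raster.tif', tile_size=(256, 256))
    >>> future = store.get_raster_tile('raster.tif', asynchronous=True)
    >>> tile = future.result()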
+ """ + _TARGET_CRS: str = 'epsg:3857' + _LARGE_RASTER_THRESHOLD: int = 10980 * 10980 + _RIO_ENV_OPTIONS = dict( + GDAL_TIFF_INTERNAL_MASK=True, + GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR' + ) + + def __init__(self) -> None: + settings = get_settings() + self._raster_cache = CompressedLFUCache( + settings.RASTER_CACHE_SIZE, + compression_level=settings.RASTER_CACHE_COMPRESS_LEVEL + ) + self._cache_lock = threading.RLock() + + def compute_metadata(self, path: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None) -> Dict[str, Any]: + return raster.compute_metadata(path, extra_metadata=extra_metadata, + use_chunks=use_chunks, max_shape=max_shape, + large_raster_threshold=self._LARGE_RASTER_THRESHOLD, + rio_env_options=self._RIO_ENV_OPTIONS) + + # return type has to be Any until mypy supports conditional return types + def get_raster_tile(self, + path: str, *, + tile_bounds: Sequence[float] = None, + tile_size: Sequence[int] = None, + preserve_values: bool = False, + asynchronous: bool = False) -> Any: + future: Future[np.ma.MaskedArray] + result: np.ma.MaskedArray + + settings = get_settings() + + if tile_size is None: + tile_size = settings.DEFAULT_TILE_SIZE + + kwargs = dict( + path=path, + tile_bounds=tile_bounds, + tile_size=tuple(tile_size), + preserve_values=preserve_values, + reprojection_method=settings.REPROJECTION_METHOD, + resampling_method=settings.RESAMPLING_METHOD, + target_crs=self._TARGET_CRS, + rio_env_options=self._RIO_ENV_OPTIONS, + ) + + cache_key = hash(ensure_hashable(kwargs)) + + try: + with self._cache_lock: + result = self._raster_cache[cache_key] + except KeyError: + pass + else: + if asynchronous: + # wrap result in a future + future = Future() + future.set_result(result) + return future + else: + return result + + retrieve_tile = functools.partial(raster.get_raster_tile, **kwargs) + + future = submit_to_executor(retrieve_tile) + + def cache_callback(future: Future) -> None: + # insert result into global cache if execution was successful + if future.exception() is None: + self._add_to_cache(cache_key, future.result()) + + if asynchronous: + future.add_done_callback(cache_callback) + return future + else: + result = future.result() + cache_callback(future) + return result + + def _add_to_cache(self, key: Any, value: Any) -> None: + try: + with self._cache_lock: + self._raster_cache[key] = value + except ValueError: # value too large + pass diff --git a/terracotta/drivers/mysql.py b/terracotta/drivers/mysql.py deleted file mode 100644 index 1edbc402..00000000 --- a/terracotta/drivers/mysql.py +++ /dev/null @@ -1,499 +0,0 @@ -"""drivers/sqlite.py - -MySQL-backed raster driver. Metadata is stored in a MySQL database, raster data is assumed -to be present on disk. -""" - -from typing import (List, Tuple, Dict, Iterator, Sequence, Union, - Mapping, Any, Optional, cast, TypeVar) -from collections import OrderedDict -import contextlib -from contextlib import AbstractContextManager -import re -import json -import urllib.parse as urlparse -from urllib.parse import ParseResult - -import numpy as np -import pymysql -from pymysql.connections import Connection -from pymysql.cursors import DictCursor - -from terracotta import get_settings, __version__ -from terracotta.drivers.raster_base import RasterDriver -from terracotta.drivers.base import requires_connection -from terracotta import exceptions -from terracotta.profile import trace - - -T = TypeVar('T') - -_ERROR_ON_CONNECT = ( - 'Could not connect to database. 
Make sure that the given path points ' - 'to a valid Terracotta database, and that you ran driver.create().' -) - -DEFAULT_PORT = 3306 - - -@contextlib.contextmanager -def convert_exceptions(msg: str) -> Iterator: - """Convert internal mysql exceptions to our InvalidDatabaseError""" - from pymysql import OperationalError, InternalError, ProgrammingError - try: - yield - except (OperationalError, InternalError, ProgrammingError) as exc: - raise exceptions.InvalidDatabaseError(msg) from exc - - -class MySQLCredentials: - __slots__ = ('host', 'port', 'db', '_user', '_password') - - def __init__(self, - host: str, - port: int, - db: str, - user: Optional[str] = None, - password: Optional[str] = None): - self.host = host - self.port = port - self.db = db - self._user = user - self._password = password - - @property - def user(self) -> Optional[str]: - return self._user or get_settings().MYSQL_USER - - @property - def password(self) -> str: - pw = self._password or get_settings().MYSQL_PASSWORD - - if pw is None: - pw = '' - - return pw - - -class MySQLDriver(RasterDriver): - """A MySQL-backed raster driver. - - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. - Stores metadata and paths to raster files in MySQL. - - Requires a running MySQL server. - - The MySQL database consists of 4 different tables: - - - ``terracotta``: Metadata about the database itself. - - ``key_names``: Contains two columns holding all available keys and their description. - - ``datasets``: Maps key values to physical raster path. - - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. - - This driver caches raster data and key names, but not metadata. - """ - _MAX_PRIMARY_KEY_LENGTH = 767 // 4 # Max key length for MySQL is at least 767B - _METADATA_COLUMNS: Tuple[Tuple[str, ...], ...] = ( - ('bounds_north', 'REAL'), - ('bounds_east', 'REAL'), - ('bounds_south', 'REAL'), - ('bounds_west', 'REAL'), - ('convex_hull', 'LONGTEXT'), - ('valid_percentage', 'REAL'), - ('min', 'REAL'), - ('max', 'REAL'), - ('mean', 'REAL'), - ('stdev', 'REAL'), - ('percentiles', 'BLOB'), - ('metadata', 'LONGTEXT') - ) - _CHARSET: str = 'utf8mb4' - - def __init__(self, mysql_path: str) -> None: - """Initialize the MySQLDriver. - - This should not be called directly, use :func:`~terracotta.get_driver` instead. 
- - Arguments: - - mysql_path: URL to running MySQL server, in the form - ``mysql://username:password@hostname/database`` - - """ - settings = get_settings() - - self.DB_CONNECTION_TIMEOUT: int = settings.DB_CONNECTION_TIMEOUT - - con_params = urlparse.urlparse(mysql_path) - - if not con_params.hostname: - con_params = urlparse.urlparse(f'mysql://{mysql_path}') - - assert con_params.hostname is not None - - if con_params.scheme != 'mysql': - raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') - - self._db_args = MySQLCredentials( - host=con_params.hostname, - user=con_params.username, - password=con_params.password, - port=con_params.port or DEFAULT_PORT, - db=self._parse_db_name(con_params) - ) - - self._connection: Connection - self._cursor: DictCursor - self._connected = False - - self._version_checked: bool = False - self._db_keys: Optional[OrderedDict] = None - - # use normalized path to make sure username and password don't leak into __repr__ - qualified_path = self._normalize_path(mysql_path) - super().__init__(qualified_path) - - @classmethod - def _normalize_path(cls, path: str) -> str: - parts = urlparse.urlparse(path) - - if not parts.hostname: - parts = urlparse.urlparse(f'mysql://{path}') - - path = f'{parts.scheme}://{parts.hostname}:{parts.port or DEFAULT_PORT}{parts.path}' - path = path.rstrip('/') - return path - - @staticmethod - def _parse_db_name(con_params: ParseResult) -> str: - if not con_params.path: - raise ValueError('database must be specified in MySQL path') - - path = con_params.path.strip('/') - if '/' in path: - raise ValueError('invalid database path') - - return path - - @requires_connection - @convert_exceptions(_ERROR_ON_CONNECT) - def _get_db_version(self) -> str: - """Terracotta version used to create the database""" - cursor = self._cursor - cursor.execute('SELECT version from terracotta') - db_row = cast(Dict[str, str], cursor.fetchone()) - return db_row['version'] - - db_version = cast(str, property(_get_db_version)) - - def _connection_callback(self) -> None: - if not self._version_checked: - # check for version compatibility - def versiontuple(version_string: str) -> Sequence[str]: - return version_string.split('.') - - db_version = self.db_version - current_version = __version__ - - if versiontuple(db_version)[:2] != versiontuple(current_version)[:2]: - raise exceptions.InvalidDatabaseError( - f'Version conflict: database was created in v{db_version}, ' - f'but this is v{current_version}' - ) - self._version_checked = True - - def _get_key_names(self) -> Tuple[str, ...]: - """Names of all keys defined by the database""" - return tuple(self.get_keys().keys()) - - key_names = cast(Tuple[str], property(_get_key_names)) - - def connect(self) -> AbstractContextManager: - return self._connect(check=True) - - @contextlib.contextmanager - def _connect(self, check: bool = True) -> Iterator: - close = False - try: - if not self._connected: - with convert_exceptions(_ERROR_ON_CONNECT): - self._connection = pymysql.connect( - host=self._db_args.host, user=self._db_args.user, db=self._db_args.db, - password=self._db_args.password, port=self._db_args.port, - read_timeout=self.DB_CONNECTION_TIMEOUT, - write_timeout=self.DB_CONNECTION_TIMEOUT, - binary_prefix=True, charset='utf8mb4' - ) - self._cursor = self._connection.cursor(DictCursor) - self._connected = close = True - - if check: - self._connection_callback() - - try: - yield - except Exception: - self._connection.rollback() - raise - - finally: - if close: - self._connected = False - 
self._cursor.close() - self._connection.commit() - self._connection.close() - - @convert_exceptions('Could not create database') - def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: - """Create and initialize database with empty tables. - - This must be called before opening the first connection. The MySQL database must not - exist already. - - Arguments: - - keys: Key names to use throughout the Terracotta database. - key_descriptions: Optional (but recommended) full-text description for some keys, - in the form of ``{key_name: description}``. - - """ - if key_descriptions is None: - key_descriptions = {} - else: - key_descriptions = dict(key_descriptions) - - if not all(k in keys for k in key_descriptions.keys()): - raise exceptions.InvalidKeyError('key description dict contains unknown keys') - - if not all(re.match(r'^\w+$', key) for key in keys): - raise exceptions.InvalidKeyError('key names must be alphanumeric') - - if any(key in self._RESERVED_KEYS for key in keys): - raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') - - for key in keys: - if key not in key_descriptions: - key_descriptions[key] = '' - - # total primary key length has an upper limit in MySQL - key_size = self._MAX_PRIMARY_KEY_LENGTH // len(keys) - key_type = f'VARCHAR({key_size})' - - connection = pymysql.connect( - host=self._db_args.host, user=self._db_args.user, - password=self._db_args.password, port=self._db_args.port, - read_timeout=self.DB_CONNECTION_TIMEOUT, - write_timeout=self.DB_CONNECTION_TIMEOUT, - binary_prefix=True, charset='utf8mb4' - ) - - with connection, connection.cursor() as cursor: - cursor.execute(f'CREATE DATABASE {self._db_args.db}') - - with self._connect(check=False): - cursor = self._cursor - cursor.execute(f'CREATE TABLE terracotta (version VARCHAR(255)) ' - f'CHARACTER SET {self._CHARSET}') - cursor.execute('INSERT INTO terracotta VALUES (%s)', [str(__version__)]) - - cursor.execute(f'CREATE TABLE key_names (key_name {key_type}, ' - f'description VARCHAR(8000)) CHARACTER SET {self._CHARSET}') - key_rows = [(key, key_descriptions[key]) for key in keys] - cursor.executemany('INSERT INTO key_names VALUES (%s, %s)', key_rows) - - key_string = ', '.join([f'{key} {key_type}' for key in keys]) - cursor.execute(f'CREATE TABLE datasets ({key_string}, filepath VARCHAR(8000), ' - f'PRIMARY KEY({", ".join(keys)})) CHARACTER SET {self._CHARSET}') - - column_string = ', '.join(f'{col} {col_type}' for col, col_type - in self._METADATA_COLUMNS) - cursor.execute(f'CREATE TABLE metadata ({key_string}, {column_string}, ' - f'PRIMARY KEY ({", ".join(keys)})) CHARACTER SET {self._CHARSET}') - - # invalidate key cache - self._db_keys = None - - def get_keys(self) -> OrderedDict: - if self._db_keys is None: - self._db_keys = self._get_keys() - return self._db_keys - - @requires_connection - @convert_exceptions('Could not retrieve keys from database') - def _get_keys(self) -> OrderedDict: - out: OrderedDict = OrderedDict() - - cursor = self._cursor - cursor.execute('SELECT * FROM key_names') - key_rows = cursor.fetchall() or () - - for row in key_rows: - out[row['key_name']] = row['description'] - - return out - - @trace('get_datasets') - @requires_connection - @convert_exceptions('Could not retrieve datasets') - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - cursor = self._cursor - - if limit is not None: - # explicitly cast to int to 
prevent SQL injection - page_fragment = f'LIMIT {int(limit)} OFFSET {int(page) * int(limit)}' - else: - page_fragment = '' - - # sort by keys to ensure deterministic results - order_fragment = f'ORDER BY {", ".join(self.key_names)}' - - if where is None: - cursor.execute(f'SELECT * FROM datasets {order_fragment} {page_fragment}') - else: - if not all(key in self.key_names for key in where.keys()): - raise exceptions.InvalidKeyError('Encountered unrecognized keys in ' - 'where clause') - conditions = [] - values = [] - for key, value in where.items(): - if isinstance(value, str): - value = [value] - values.extend(value) - conditions.append(' OR '.join([f'{key}=%s'] * len(value))) - where_fragment = ' AND '.join([f'({condition})' for condition in conditions]) - cursor.execute( - f'SELECT * FROM datasets WHERE {where_fragment} {order_fragment} {page_fragment}', - values - ) - - def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: - return tuple(row[key] for key in self.key_names) - - datasets = {} - for row in cursor: - datasets[keytuple(row)] = row['filepath'] - - return datasets - - @staticmethod - def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from internal format to database representation""" - encoded = { - 'bounds_north': decoded['bounds'][0], - 'bounds_east': decoded['bounds'][1], - 'bounds_south': decoded['bounds'][2], - 'bounds_west': decoded['bounds'][3], - 'convex_hull': json.dumps(decoded['convex_hull']), - 'valid_percentage': decoded['valid_percentage'], - 'min': decoded['range'][0], - 'max': decoded['range'][1], - 'mean': decoded['mean'], - 'stdev': decoded['stdev'], - 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), - 'metadata': json.dumps(decoded['metadata']) - } - return encoded - - @staticmethod - def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from database format to internal representation""" - decoded = { - 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), - 'convex_hull': json.loads(encoded['convex_hull']), - 'valid_percentage': encoded['valid_percentage'], - 'range': (encoded['min'], encoded['max']), - 'mean': encoded['mean'], - 'stdev': encoded['stdev'], - 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), - 'metadata': json.loads(encoded['metadata']) - } - return decoded - - @trace('get_metadata') - @requires_connection - @convert_exceptions('Could not retrieve metadata') - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - keys = tuple(self._key_dict_to_sequence(keys)) - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError('Got wrong number of keys') - - cursor = self._cursor - - where_string = ' AND '.join([f'{key}=%s' for key in self.key_names]) - cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys) - row = cursor.fetchone() - - if not row: # support lazy loading - filepath = self.get_datasets(dict(zip(self.key_names, keys))) - if not filepath: - raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') - assert len(filepath) == 1 - - # compute metadata and try again - self.insert(keys, filepath[keys], skip_metadata=False) - cursor.execute(f'SELECT * FROM metadata WHERE {where_string}', keys) - row = cursor.fetchone() - - assert row - - data_columns, _ = zip(*self._METADATA_COLUMNS) - encoded_data = {col: row[col] for col in self.key_names + data_columns} - return self._decode_data(encoded_data) - - 
@trace('insert') - @requires_connection - @convert_exceptions('Could not write to database') - def insert(self, - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - cursor = self._cursor - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if override_path is None: - override_path = filepath - - keys = self._key_dict_to_sequence(keys) - template_string = ', '.join(['%s'] * (len(keys) + 1)) - cursor.execute(f'REPLACE INTO datasets VALUES ({template_string})', - [*keys, override_path]) - - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(filepath) - - if metadata is not None: - encoded_data = self._encode_data(metadata) - row_keys, row_values = zip(*encoded_data.items()) - template_string = ', '.join(['%s'] * (len(keys) + len(row_values))) - cursor.execute(f'REPLACE INTO metadata ({", ".join(self.key_names)}, ' - f'{", ".join(row_keys)}) VALUES ({template_string})', - [*keys, *row_values]) - - @trace('delete') - @requires_connection - @convert_exceptions('Could not write to database') - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - cursor = self._cursor - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - - if not self.get_datasets(key_dict): - raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') - - where_string = ' AND '.join([f'{key}=%s' for key in self.key_names]) - cursor.execute(f'DELETE FROM datasets WHERE {where_string}', keys) - cursor.execute(f'DELETE FROM metadata WHERE {where_string}', keys) diff --git a/terracotta/drivers/mysql_meta_store.py b/terracotta/drivers/mysql_meta_store.py new file mode 100644 index 00000000..7d6fa5e9 --- /dev/null +++ b/terracotta/drivers/mysql_meta_store.py @@ -0,0 +1,86 @@ +"""drivers/mysql_meta_store.py + +MySQL-backed metadata driver. Metadata is stored in a MySQL database. +""" + +import functools +from typing import Mapping, Sequence + +import sqlalchemy as sqla +from sqlalchemy.dialects.mysql import TEXT, VARCHAR +from terracotta.drivers.relational_meta_store import RelationalMetaStore + + +class MySQLMetaStore(RelationalMetaStore): + """A MySQL-backed metadata driver. + + Stores metadata and paths to raster files in MySQL. + + Requires a running MySQL server. + + The MySQL database consists of 4 different tables: + + - ``terracotta``: Metadata about the database itself. + - ``key_names``: Contains two columns holding all available keys and their description. + - ``datasets``: Maps key values to physical raster path. + - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. + + This driver caches key names. + """ + SQL_DIALECT = 'mysql' + SQL_DRIVER = 'pymysql' + SQL_TIMEOUT_KEY = 'connect_timeout' + + _CHARSET = 'utf8mb4' + SQLA_STRING = functools.partial(VARCHAR, charset=_CHARSET) + + MAX_PRIMARY_KEY_SIZE = 767 // 4 # Max key length for MySQL is at least 767B + DEFAULT_PORT = 3306 + + def __init__(self, mysql_path: str) -> None: + """Initialize the MySQLDriver. + + This should not be called directly, use :func:`~terracotta.get_driver` instead. 
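A hedged sketch of how the connection string is normalized for driver caching (host, port and
database are placeholders; username and password are not part of the normalized form, and the
``mysql+pymysql`` prefix follows the examples shown for :func:`~terracotta.get_driver`):

    >>> MySQLMetaStore._normalize_path('mysql://root:secret@localhost/tc')
    'mysql+pymysql://localhost:3306/tc'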
+ + Arguments: + + mysql_path: URL to running MySQL server, in the form + ``mysql://username:password@hostname/database`` + + """ + super().__init__(f'{mysql_path}?charset={self._CHARSET}') + + self.SQLA_METADATA_TYPE_LOOKUP['text'] = functools.partial(TEXT, charset=self._CHARSET) + + # raise an exception if database name is invalid + if not self.url.database: + raise ValueError('database must be specified in MySQL path') + if '/' in self.url.database.strip('/'): + raise ValueError('invalid database path') + + @classmethod + def _normalize_path(cls, path: str) -> str: + url = cls._parse_path(path) + + path = f'{url.drivername}://{url.host}:{url.port or cls.DEFAULT_PORT}/{url.database}' + path = path.rstrip('/') + return path + + def _create_database(self) -> None: + engine = sqla.create_engine( + self.url.set(database=''), # `.set()` returns a copy with changed parameters + echo=False, + future=True + ) + with engine.connect() as connection: + connection.execute(sqla.text(f'CREATE DATABASE {self.url.database}')) + connection.commit() + + def _initialize_database( + self, + keys: Sequence[str], + key_descriptions: Mapping[str, str] = None + ) -> None: + # total primary key length has an upper limit in MySQL + self.SQL_KEY_SIZE = self.MAX_PRIMARY_KEY_SIZE // len(keys) + super()._initialize_database(keys, key_descriptions) diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py deleted file mode 100644 index ccbcd669..00000000 --- a/terracotta/drivers/raster_base.py +++ /dev/null @@ -1,614 +0,0 @@ -"""drivers/raster_base.py - -Base class for drivers operating on physical raster files. -""" - -from typing import (Any, Callable, Union, Mapping, Sequence, Dict, List, Tuple, - TypeVar, Optional, cast, TYPE_CHECKING) -from abc import abstractmethod -from concurrent.futures import Future, Executor, ProcessPoolExecutor, ThreadPoolExecutor -from concurrent.futures.process import BrokenProcessPool - -import contextlib -import functools -import logging -import warnings -import threading - -import numpy as np - -if TYPE_CHECKING: # pragma: no cover - from rasterio.io import DatasetReader # noqa: F401 - -try: - from crick import TDigest, SummaryStats - has_crick = True -except ImportError: # pragma: no cover - has_crick = False - -from terracotta import get_settings, exceptions -from terracotta.cache import CompressedLFUCache -from terracotta.drivers.base import requires_connection, Driver -from terracotta.profile import trace - -Number = TypeVar('Number', int, float) - -logger = logging.getLogger(__name__) - -context = threading.local() -context.executor = None - - -def create_executor() -> Executor: - settings = get_settings() - - if not settings.USE_MULTIPROCESSING: - return ThreadPoolExecutor(max_workers=1) - - executor: Executor - - try: - # this fails on architectures without /dev/shm - executor = ProcessPoolExecutor(max_workers=3) - except OSError: - # fall back to serial evaluation - warnings.warn( - 'Multiprocessing is not available on this system. ' - 'Falling back to serial execution.' 
- ) - executor = ThreadPoolExecutor(max_workers=1) - - return executor - - -def submit_to_executor(task: Callable[..., Any]) -> Future: - if context.executor is None: - context.executor = create_executor() - - try: - future = context.executor.submit(task) - except BrokenProcessPool: - # re-create executor and try again - logger.warn('Re-creating broken process pool') - context.executor = create_executor() - future = context.executor.submit(task) - - return future - - -class RasterDriver(Driver): - """Mixin that implements methods to load raster data from disk. - - get_datasets has to return path to raster file as sole dict value. - """ - _TARGET_CRS: str = 'epsg:3857' - _LARGE_RASTER_THRESHOLD: int = 10980 * 10980 - _RIO_ENV_KEYS = dict( - GDAL_TIFF_INTERNAL_MASK=True, - GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR' - ) - - @abstractmethod - def __init__(self, *args: Any, **kwargs: Any) -> None: - settings = get_settings() - self._raster_cache = CompressedLFUCache( - settings.RASTER_CACHE_SIZE, - compression_level=settings.RASTER_CACHE_COMPRESS_LEVEL - ) - self._cache_lock = threading.RLock() - super().__init__(*args, **kwargs) - - # specify signature and docstring for insert - @abstractmethod - def insert(self, # type: ignore - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - """Insert a raster file into the database. - - Arguments: - - keys: Keys identifying the new dataset. Can either be given as a sequence of key - values, or as a mapping ``{key_name: key_value}``. - filepath: Path to the GDAL-readable raster file. - metadata: If not given (default), call :meth:`compute_metadata` with default arguments - to compute raster metadata. Otherwise, use the given values. This can be used to - decouple metadata computation from insertion, or to use the optional arguments - of :meth:`compute_metadata`. - skip_metadata: Do not compute any raster metadata (will be computed during the first - request instead). Use sparingly; this option has a detrimental result on the end - user experience and might lead to surprising results. Has no effect if ``metadata`` - is given. - override_path: Override the path to the raster file in the database. Use this option if - you intend to copy the data somewhere else after insertion (e.g. when moving files - to a cloud storage later on). - - """ - pass - - # specify signature and docstring for get_datasets - @abstractmethod - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - """Retrieve keys and file paths of datasets. - - Arguments: - - where: Constraints on returned datasets in the form ``{key_name: allowed_key_value}``. - Returns all datasets if not given (default). - page: Current page of results. Has no effect if ``limit`` is not given. - limit: If given, return at most this many datasets. Unlimited by default. 
- - - Returns: - - :class:`dict` containing - ``{(key_value1, key_value2, ...): raster_file_path}`` - - Example: - - >>> import terracotta as tc - >>> driver = tc.get_driver('tc.sqlite') - >>> driver.get_datasets() - { - ('reflectance', '20180101', 'B04'): 'reflectance_20180101_B04.tif', - ('reflectance', '20180102', 'B04'): 'reflectance_20180102_B04.tif', - } - >>> driver.get_datasets({'date': '20180101'}) - {('reflectance', '20180101', 'B04'): 'reflectance_20180101_B04.tif'} - - """ - pass - - def _key_dict_to_sequence(self, keys: Union[Mapping[str, Any], Sequence[Any]]) -> List[Any]: - """Convert {key_name: key_value} to [key_value] with the correct key order.""" - try: - keys_as_mapping = cast(Mapping[str, Any], keys) - return [keys_as_mapping[key] for key in self.key_names] - except TypeError: # not a mapping - return list(keys) - except KeyError as exc: - raise exceptions.InvalidKeyError('Encountered unknown key') from exc - - @staticmethod - def _hull_candidate_mask(mask: np.ndarray) -> np.ndarray: - """Returns a reduced boolean mask to speed up convex hull computations. - - Exploits the fact that only the first and last elements of each row and column - can contribute to the convex hull of a dataset. - """ - assert mask.ndim == 2 - assert mask.dtype == np.bool_ - - nx, ny = mask.shape - out = np.zeros_like(mask) - - # these operations do not short-circuit, but seems to be the best we can do - # NOTE: argmax returns 0 if a slice is all True or all False - first_row = np.argmax(mask, axis=0) - last_row = nx - 1 - np.argmax(mask[::-1, :], axis=0) - first_col = np.argmax(mask, axis=1) - last_col = ny - 1 - np.argmax(mask[:, ::-1], axis=1) - - all_rows = np.arange(nx) - all_cols = np.arange(ny) - - out[first_row, all_cols] = out[last_row, all_cols] = True - out[all_rows, first_col] = out[all_rows, last_col] = True - - # filter all-False slices - out &= mask - - return out - - @staticmethod - def _compute_image_stats_chunked(dataset: 'DatasetReader') -> Optional[Dict[str, Any]]: - """Compute statistics for the given rasterio dataset by looping over chunks.""" - from rasterio import features, warp, windows - from shapely import geometry - - total_count = valid_data_count = 0 - tdigest = TDigest() - sstats = SummaryStats() - convex_hull = geometry.Polygon() - - block_windows = [w for _, w in dataset.block_windows(1)] - - for w in block_windows: - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', message='invalid value encountered.*') - block_data = dataset.read(1, window=w, masked=True) - - # handle NaNs for float rasters - block_data = np.ma.masked_invalid(block_data, copy=False) - - total_count += int(block_data.size) - valid_data = block_data.compressed() - - if valid_data.size == 0: - continue - - valid_data_count += int(valid_data.size) - - if np.any(block_data.mask): - hull_candidates = RasterDriver._hull_candidate_mask(~block_data.mask) - hull_shapes = [geometry.shape(s) for s, _ in features.shapes( - np.ones(hull_candidates.shape, 'uint8'), - mask=hull_candidates, - transform=windows.transform(w, dataset.transform) - )] - else: - w, s, e, n = windows.bounds(w, dataset.transform) - hull_shapes = [geometry.Polygon([(w, s), (e, s), (e, n), (w, n)])] - convex_hull = geometry.MultiPolygon([convex_hull, *hull_shapes]).convex_hull - - tdigest.update(valid_data) - sstats.update(valid_data) - - if sstats.count() == 0: - return None - - convex_hull_wgs = warp.transform_geom( - dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) - ) - - return { - 'valid_percentage': 
valid_data_count / total_count * 100, - 'range': (sstats.min(), sstats.max()), - 'mean': sstats.mean(), - 'stdev': sstats.std(), - 'percentiles': tdigest.quantile(np.arange(0.01, 1, 0.01)), - 'convex_hull': convex_hull_wgs - } - - @staticmethod - def _compute_image_stats(dataset: 'DatasetReader', - max_shape: Sequence[int] = None) -> Optional[Dict[str, Any]]: - """Compute statistics for the given rasterio dataset by reading it into memory.""" - from rasterio import features, warp, transform - from shapely import geometry - - out_shape = (dataset.height, dataset.width) - - if max_shape is not None: - out_shape = ( - min(max_shape[0], out_shape[0]), - min(max_shape[1], out_shape[1]) - ) - - data_transform = transform.from_bounds( - *dataset.bounds, height=out_shape[0], width=out_shape[1] - ) - raster_data = dataset.read(1, out_shape=out_shape, masked=True) - - if dataset.nodata is not None: - # nodata values might slip into output array if out_shape < dataset.shape - raster_data = np.ma.masked_equal(raster_data, dataset.nodata, copy=False) - - # handle NaNs for float rasters - raster_data = np.ma.masked_invalid(raster_data, copy=False) - - valid_data = raster_data.compressed() - - if valid_data.size == 0: - return None - - if np.any(raster_data.mask): - hull_candidates = RasterDriver._hull_candidate_mask(~raster_data.mask) - hull_shapes = (geometry.shape(s) for s, _ in features.shapes( - np.ones(hull_candidates.shape, 'uint8'), - mask=hull_candidates, - transform=data_transform - )) - convex_hull = geometry.MultiPolygon(hull_shapes).convex_hull - else: - # no masked entries -> convex hull == dataset bounds - w, s, e, n = dataset.bounds - convex_hull = geometry.Polygon([(w, s), (e, s), (e, n), (w, n)]) - - convex_hull_wgs = warp.transform_geom( - dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) - ) - - return { - 'valid_percentage': valid_data.size / raster_data.size * 100, - 'range': (float(valid_data.min()), float(valid_data.max())), - 'mean': float(valid_data.mean()), - 'stdev': float(valid_data.std()), - 'percentiles': np.percentile(valid_data, np.arange(1, 100)), - 'convex_hull': convex_hull_wgs - } - - @classmethod - @trace('compute_metadata') - def compute_metadata(cls, raster_path: str, *, # type: ignore[override] # noqa: F821 - extra_metadata: Any = None, - use_chunks: bool = None, - max_shape: Sequence[int] = None) -> Dict[str, Any]: - """Read given raster file and compute metadata from it. - - This handles most of the heavy lifting during raster ingestion. The returned metadata can - be passed directly to :meth:`insert`. - - Arguments: - - raster_path: Path to GDAL-readable raster file - extra_metadata: Any additional metadata to attach to the dataset. Will be - JSON-serialized and returned verbatim by :meth:`get_metadata`. - use_chunks: Whether to process the image in chunks (slower, but uses less memory). - If not given, use chunks for large images only. - max_shape: Gives the maximum number of pixels used in each dimension to compute - metadata. Setting this to a relatively small size such as ``(1024, 1024)`` will - result in much faster metadata computation for large images, at the expense of - inaccurate results. 
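A brief sketch of the two options described above, as exposed by the new raster store (the file
path is a placeholder; ``use_chunks`` and ``max_shape`` are mutually exclusive):

    >>> store = GeoTiffRasterStore()
    >>> meta = store.compute_metadata('raster.tif', max_shape=(1024, 1024))  # downsampled, approximate
    >>> meta = store.compute_metadata('raster.tif', use_chunks=True)         # streamed, low memory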
- - """ - import rasterio - from rasterio import warp - from terracotta.cog import validate - - row_data: Dict[str, Any] = {} - extra_metadata = extra_metadata or {} - - if max_shape is not None and len(max_shape) != 2: - raise ValueError('max_shape argument must contain 2 values') - - if use_chunks and max_shape is not None: - raise ValueError('Cannot use both use_chunks and max_shape arguments') - - with rasterio.Env(**cls._RIO_ENV_KEYS): - if not validate(raster_path): - warnings.warn( - f'Raster file {raster_path} is not a valid cloud-optimized GeoTIFF. ' - 'Any interaction with it will be significantly slower. Consider optimizing ' - 'it through `terracotta optimize-rasters` before ingestion.', - exceptions.PerformanceWarning, stacklevel=3 - ) - - with rasterio.open(raster_path) as src: - if src.nodata is None and not cls._has_alpha_band(src): - warnings.warn( - f'Raster file {raster_path} does not have a valid nodata value, ' - 'and does not contain an alpha band. No data will be masked.' - ) - - bounds = warp.transform_bounds( - src.crs, 'epsg:4326', *src.bounds, densify_pts=21 - ) - - if use_chunks is None and max_shape is None: - use_chunks = src.width * src.height > RasterDriver._LARGE_RASTER_THRESHOLD - - if use_chunks: - logger.debug( - f'Computing metadata for file {raster_path} using more than ' - f'{RasterDriver._LARGE_RASTER_THRESHOLD // 10**6}M pixels, iterating ' - 'over chunks' - ) - - if use_chunks and not has_crick: - warnings.warn( - 'Processing a large raster file, but crick failed to import. ' - 'Reading whole file into memory instead.', exceptions.PerformanceWarning - ) - use_chunks = False - - if use_chunks: - raster_stats = RasterDriver._compute_image_stats_chunked(src) - else: - raster_stats = RasterDriver._compute_image_stats(src, max_shape) - - if raster_stats is None: - raise ValueError(f'Raster file {raster_path} does not contain any valid data') - - row_data.update(raster_stats) - - row_data['bounds'] = bounds - row_data['metadata'] = extra_metadata - - return row_data - - @staticmethod - def _get_resampling_enum(method: str) -> Any: - from rasterio.enums import Resampling - - if method == 'nearest': - return Resampling.nearest - - if method == 'linear': - return Resampling.bilinear - - if method == 'cubic': - return Resampling.cubic - - if method == 'average': - return Resampling.average - - raise ValueError(f'unknown resampling method {method}') - - @staticmethod - def _has_alpha_band(src: 'DatasetReader') -> bool: - from rasterio.enums import MaskFlags, ColorInterp - return ( - any([MaskFlags.alpha in flags for flags in src.mask_flag_enums]) - or ColorInterp.alpha in src.colorinterp - ) - - @classmethod - @trace('get_raster_tile') - def _get_raster_tile(cls, path: str, *, - reprojection_method: str, - resampling_method: str, - tile_bounds: Tuple[float, float, float, float] = None, - tile_size: Tuple[int, int] = (256, 256), - preserve_values: bool = False) -> np.ma.MaskedArray: - """Load a raster dataset from a file through rasterio. 
- - Heavily inspired by mapbox/rio-tiler - """ - import rasterio - from rasterio import transform, windows, warp - from rasterio.vrt import WarpedVRT - from affine import Affine - - dst_bounds: Tuple[float, float, float, float] - - if preserve_values: - reproject_enum = resampling_enum = cls._get_resampling_enum('nearest') - else: - reproject_enum = cls._get_resampling_enum(reprojection_method) - resampling_enum = cls._get_resampling_enum(resampling_method) - - with contextlib.ExitStack() as es: - es.enter_context(rasterio.Env(**cls._RIO_ENV_KEYS)) - try: - with trace('open_dataset'): - src = es.enter_context(rasterio.open(path)) - except OSError: - raise IOError('error while reading file {}'.format(path)) - - # compute buonds in target CRS - dst_bounds = warp.transform_bounds(src.crs, cls._TARGET_CRS, *src.bounds) - - if tile_bounds is None: - tile_bounds = dst_bounds - - # prevent loads of very sparse data - cover_ratio = ( - (dst_bounds[2] - dst_bounds[0]) / (tile_bounds[2] - tile_bounds[0]) - * (dst_bounds[3] - dst_bounds[1]) / (tile_bounds[3] - tile_bounds[1]) - ) - - if cover_ratio < 0.01: - raise exceptions.TileOutOfBoundsError('dataset covers less than 1% of tile') - - # compute suggested resolution in target CRS - dst_transform, _, _ = warp.calculate_default_transform( - src.crs, cls._TARGET_CRS, src.width, src.height, *src.bounds - ) - dst_res = (abs(dst_transform.a), abs(dst_transform.e)) - - # make sure VRT resolves the entire tile - tile_transform = transform.from_bounds(*tile_bounds, *tile_size) - tile_res = (abs(tile_transform.a), abs(tile_transform.e)) - - if tile_res[0] < dst_res[0] or tile_res[1] < dst_res[1]: - dst_res = tile_res - resampling_enum = cls._get_resampling_enum('nearest') - - # pad tile bounds to prevent interpolation artefacts - num_pad_pixels = 2 - - # compute tile VRT shape and transform - dst_width = max(1, round((tile_bounds[2] - tile_bounds[0]) / dst_res[0])) - dst_height = max(1, round((tile_bounds[3] - tile_bounds[1]) / dst_res[1])) - vrt_transform = ( - transform.from_bounds(*tile_bounds, width=dst_width, height=dst_height) - * Affine.translation(-num_pad_pixels, -num_pad_pixels) - ) - vrt_height, vrt_width = dst_height + 2 * num_pad_pixels, dst_width + 2 * num_pad_pixels - - # remove padding in output - out_window = windows.Window( - col_off=num_pad_pixels, row_off=num_pad_pixels, width=dst_width, height=dst_height - ) - - # construct VRT - vrt = es.enter_context( - WarpedVRT( - src, crs=cls._TARGET_CRS, resampling=reproject_enum, - transform=vrt_transform, width=vrt_width, height=vrt_height, - add_alpha=not cls._has_alpha_band(src) - ) - ) - - # read data - with warnings.catch_warnings(), trace('read_from_vrt'): - warnings.filterwarnings('ignore', message='invalid value encountered.*') - tile_data = vrt.read( - 1, resampling=resampling_enum, window=out_window, out_shape=tile_size - ) - - # assemble alpha mask - mask_idx = vrt.count - mask = vrt.read(mask_idx, window=out_window, out_shape=tile_size) == 0 - - if src.nodata is not None: - mask |= tile_data == src.nodata - - return np.ma.masked_array(tile_data, mask=mask) - - # return type has to be Any until mypy supports conditional return types - @requires_connection - def get_raster_tile(self, - keys: Union[Sequence[str], Mapping[str, str]], *, - tile_bounds: Sequence[float] = None, - tile_size: Sequence[int] = None, - preserve_values: bool = False, - asynchronous: bool = False) -> Any: - # This wrapper handles cache interaction and asynchronous tile retrieval. 
- # The real work is done in _get_raster_tile. - - future: Future[np.ma.MaskedArray] - result: np.ma.MaskedArray - - settings = get_settings() - key_tuple = tuple(self._key_dict_to_sequence(keys)) - datasets = self.get_datasets(dict(zip(self.key_names, key_tuple))) - assert len(datasets) == 1 - path = datasets[key_tuple] - - if tile_size is None: - tile_size = settings.DEFAULT_TILE_SIZE - - # make sure all arguments are hashable - kwargs = dict( - path=path, - tile_bounds=tuple(tile_bounds) if tile_bounds else None, - tile_size=tuple(tile_size), - preserve_values=preserve_values, - reprojection_method=settings.REPROJECTION_METHOD, - resampling_method=settings.RESAMPLING_METHOD - ) - - cache_key = hash(tuple(kwargs.items())) - - try: - with self._cache_lock: - result = self._raster_cache[cache_key] - except KeyError: - pass - else: - if asynchronous: - # wrap result in a future - future = Future() - future.set_result(result) - return future - else: - return result - - retrieve_tile = functools.partial(self._get_raster_tile, **kwargs) - - future = submit_to_executor(retrieve_tile) - - def cache_callback(future: Future) -> None: - # insert result into global cache if execution was successful - if future.exception() is None: - self._add_to_cache(cache_key, future.result()) - - if asynchronous: - future.add_done_callback(cache_callback) - return future - else: - result = future.result() - cache_callback(future) - return result - - def _add_to_cache(self, key: Any, value: Any) -> None: - try: - with self._cache_lock: - self._raster_cache[key] = value - except ValueError: # value too large - pass diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py new file mode 100644 index 00000000..2ccdf45f --- /dev/null +++ b/terracotta/drivers/relational_meta_store.py @@ -0,0 +1,438 @@ +"""drivers/relational_meta_store.py + +Base class for relational database drivers, using SQLAlchemy. +""" + +import contextlib +import functools +import json +import re +import urllib.parse as urlparse +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import (Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple, + Type, Union) + +import numpy as np +import sqlalchemy as sqla +import terracotta +from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.url import URL +from terracotta import exceptions +from terracotta.drivers.base_classes import (KeysType, MetaStore, + MultiValueKeysType, + requires_connection) +from terracotta.profile import trace + +_ERROR_ON_CONNECT = ( + 'Could not connect to database. Make sure that the given path points ' + 'to a valid Terracotta database, and that you ran driver.create().' +) + +DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT: Tuple[Type[Exception], ...] 
= ( + sqla.exc.OperationalError, + sqla.exc.InternalError, + sqla.exc.ProgrammingError, + sqla.exc.InvalidRequestError, +) + +ExceptionType = Union[Type[Exception], Tuple[Type[Exception], ...]] + + +@contextlib.contextmanager +def convert_exceptions( + error_message: str, + exceptions_to_convert: ExceptionType = DATABASE_DRIVER_EXCEPTIONS_TO_CONVERT, +) -> Iterator: + try: + yield + except exceptions_to_convert as exception: + raise exceptions.InvalidDatabaseError(error_message) from exception + + +class RelationalMetaStore(MetaStore, ABC): + SQL_DIALECT: str # The database flavour, eg mysql, sqlite, etc + SQL_DRIVER: str # The actual database driver, eg pymysql, sqlite3, etc + SQL_KEY_SIZE: int + SQL_TIMEOUT_KEY: str + + SQLA_STRING = sqla.types.String + SQLA_METADATA_TYPE_LOOKUP: Dict[str, sqla.types.TypeEngine] = { + 'real': functools.partial(sqla.types.Float, precision=8), + 'text': sqla.types.Text, + 'blob': sqla.types.LargeBinary + } + + _METADATA_COLUMNS: Tuple[Tuple[str, str], ...] = ( + ('bounds_north', 'real'), + ('bounds_east', 'real'), + ('bounds_south', 'real'), + ('bounds_west', 'real'), + ('convex_hull', 'text'), + ('valid_percentage', 'real'), + ('min', 'real'), + ('max', 'real'), + ('mean', 'real'), + ('stdev', 'real'), + ('percentiles', 'blob'), + ('metadata', 'text') + ) + + def __init__(self, path: str) -> None: + settings = terracotta.get_settings() + db_connection_timeout: int = settings.DB_CONNECTION_TIMEOUT + + self.url = self._parse_path(path) + self.sqla_engine = sqla.create_engine( + self.url, + echo=False, + future=True, + connect_args={self.SQL_TIMEOUT_KEY: db_connection_timeout} + ) + self.sqla_metadata = sqla.MetaData() + + self._db_keys: Optional[OrderedDict] = None + + self.connection: Connection + self.connected: bool = False + self.db_version_verified: bool = False + + # use normalized path to make sure username and password don't leak into __repr__ + super().__init__(self._normalize_path(path)) + + @classmethod + def _parse_path(cls, connection_string: str) -> URL: + if "//" not in connection_string: + connection_string = f"//{connection_string}" + + con_params = urlparse.urlparse(connection_string) + + if not con_params.scheme: + con_params = urlparse.urlparse(f'{cls.SQL_DIALECT}:{connection_string}') + + if con_params.scheme != cls.SQL_DIALECT: + raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') + + url = URL.create( + drivername=f'{cls.SQL_DIALECT}+{cls.SQL_DRIVER}', + username=con_params.username, + password=con_params.password, + host=con_params.hostname, + port=con_params.port, + database=con_params.path[1:], # remove leading '/' from urlparse + query=dict(urlparse.parse_qsl(con_params.query)) + ) + + return url + + @contextlib.contextmanager + def connect(self, verify: bool = True) -> Iterator: + @convert_exceptions(_ERROR_ON_CONNECT, sqla.exc.OperationalError) + def get_connection() -> Connection: + return self.sqla_engine.connect().execution_options(isolation_level='READ UNCOMMITTED') + + if not self.connected: + try: + with get_connection() as connection: + self.connection = connection + self.connected = True + if verify: + self._connection_callback() + + yield + self.connection.commit() + finally: + self.connected = False + self.connection = None + else: + try: + yield + except Exception as exception: + self.connection.rollback() + raise exception + + def _connection_callback(self) -> None: + if not self.db_version_verified: + # check for version compatibility + def version_tuple(version_string: str) -> Sequence[str]: + return 
version_string.split('.') + + db_version = self.db_version + current_version = terracotta.__version__ + + if version_tuple(db_version)[:2] != version_tuple(current_version)[:2]: + raise exceptions.InvalidDatabaseError( + f'Version conflict: database was created in v{db_version}, ' + f'but this is v{current_version}' + ) + self.db_version_verified = True + + @property # type: ignore + @requires_connection + @convert_exceptions(_ERROR_ON_CONNECT) + def db_version(self) -> str: + """Terracotta version used to create the database""" + terracotta_table = sqla.Table( + 'terracotta', + self.sqla_metadata, + autoload_with=self.sqla_engine + ) + stmt = sqla.select(terracotta_table.c.version) + version = self.connection.execute(stmt).scalar() + return version + + @convert_exceptions('Could not create database') + def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: + """Create and initialize database with empty tables. + + This must be called before opening the first connection. The MySQL database must not + exist already. + + Arguments: + + keys: Key names to use throughout the Terracotta database. + key_descriptions: Optional (but recommended) full-text description for some keys, + in the form of ``{key_name: description}``. + + """ + self._create_database() + self._initialize_database(keys, key_descriptions) + + @abstractmethod + def _create_database(self) -> None: + # Note that some subclasses may not actually create any database here, as + # it may be created automatically on connection for some database vendors + pass + + @requires_connection(verify=False) + def _initialize_database( + self, + keys: Sequence[str], + key_descriptions: Mapping[str, str] = None + ) -> None: + if key_descriptions is None: + key_descriptions = {} + else: + key_descriptions = dict(key_descriptions) + + if not all(k in keys for k in key_descriptions.keys()): + raise exceptions.InvalidKeyError('key description dict contains unknown keys') + + if not all(re.match(r'^\w+$', key) for key in keys): + raise exceptions.InvalidKeyError('key names must be alphanumeric') + + if any(key in self._RESERVED_KEYS for key in keys): + raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') + + terracotta_table = sqla.Table( + 'terracotta', self.sqla_metadata, + sqla.Column('version', self.SQLA_STRING(255), primary_key=True) + ) + key_names_table = sqla.Table( + 'key_names', self.sqla_metadata, + sqla.Column('key_name', self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True), + sqla.Column('description', self.SQLA_STRING(8000)), + sqla.Column('index', sqla.types.Integer, unique=True) + ) + _ = sqla.Table( + 'datasets', self.sqla_metadata, + *[ + sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) + for key in keys + ], + sqla.Column('path', self.SQLA_STRING(8000)) + ) + _ = sqla.Table( + 'metadata', self.sqla_metadata, + *[ + sqla.Column(key, self.SQLA_STRING(self.SQL_KEY_SIZE), primary_key=True) + for key in keys], + *[ + sqla.Column(name, self.SQLA_METADATA_TYPE_LOOKUP[column_type]()) + for name, column_type in self._METADATA_COLUMNS + ] + ) + self.sqla_metadata.create_all(self.sqla_engine) + + self.connection.execute( + terracotta_table.insert().values(version=terracotta.__version__) + ) + self.connection.execute( + key_names_table.insert(), + [ + dict(key_name=key, description=key_descriptions.get(key, ''), index=i) + for i, key in enumerate(keys) + ] + ) + + @requires_connection + @convert_exceptions('Could not retrieve keys from database') + 
def get_keys(self) -> OrderedDict: + keys_table = sqla.Table('key_names', self.sqla_metadata, autoload_with=self.sqla_engine) + result = self.connection.execute( + sqla.select( + keys_table.c.get('key_name'), + keys_table.c.get('description') + ) + .order_by(keys_table.c.get('index'))) + return OrderedDict(result.all()) + + @property + def key_names(self) -> Tuple[str, ...]: + """Names of all keys defined by the database""" + if self._db_keys is None: + self._db_keys = self.get_keys() + return tuple(self._db_keys.keys()) + + @trace('get_datasets') + @requires_connection + @convert_exceptions('Could not retrieve datasets') + def get_datasets( + self, + where: MultiValueKeysType = None, + page: int = 0, + limit: int = None + ) -> Dict[Tuple[str, ...], str]: + if where is None: + where = {} + + where = { + key: value if isinstance(value, list) else [value] + for key, value in where.items() + } + + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + stmt = ( + datasets_table + .select() + .where( + *[ + sqla.or_(*[datasets_table.c.get(column) == value for value in values]) + for column, values in where.items() + ] + ) + .order_by(*datasets_table.c.values()) + .limit(limit) + .offset(page * limit if limit is not None else None) + ) + + result = self.connection.execute(stmt) + + def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: + return tuple(row[key] for key in self.key_names) + + datasets = {keytuple(row): row['path'] for row in result} + return datasets + + @trace('get_metadata') + @requires_connection + @convert_exceptions('Could not retrieve metadata') + def get_metadata(self, keys: KeysType) -> Optional[Dict[str, Any]]: + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + stmt = ( + metadata_table + .select() + .where( + *[ + metadata_table.c.get(key) == value + for key, value in keys.items() + ] + ) + ) + + row = self.connection.execute(stmt).first() + if not row: + return None + + data_columns, _ = zip(*self._METADATA_COLUMNS) + encoded_data = {col: row[col] for col in self.key_names + data_columns} + return self._decode_data(encoded_data) + + @trace('insert') + @requires_connection + @convert_exceptions('Could not write to database') + def insert( + self, + keys: KeysType, + path: str, *, + metadata: Mapping[str, Any] = None + ) -> None: + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + + self.connection.execute( + datasets_table + .delete() + .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) + ) + self.connection.execute( + datasets_table.insert().values(**keys, path=path) + ) + + if metadata is not None: + encoded_data = self._encode_data(metadata) + self.connection.execute( + metadata_table + .delete() + .where( + *[metadata_table.c.get(column) == value for column, value in keys.items()] + ) + ) + self.connection.execute( + metadata_table.insert().values(**keys, **encoded_data) + ) + + @trace('delete') + @requires_connection + @convert_exceptions('Could not write to database') + def delete(self, keys: KeysType) -> None: + if not self.get_datasets(keys): + raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') + + datasets_table = sqla.Table('datasets', self.sqla_metadata, autoload_with=self.sqla_engine) + metadata_table = sqla.Table('metadata', self.sqla_metadata, autoload_with=self.sqla_engine) + + 
self.connection.execute( + datasets_table + .delete() + .where(*[datasets_table.c.get(column) == value for column, value in keys.items()]) + ) + self.connection.execute( + metadata_table + .delete() + .where(*[metadata_table.c.get(column) == value for column, value in keys.items()]) + ) + + @staticmethod + def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: + """Transform from internal format to database representation""" + encoded = { + 'bounds_north': decoded['bounds'][0], + 'bounds_east': decoded['bounds'][1], + 'bounds_south': decoded['bounds'][2], + 'bounds_west': decoded['bounds'][3], + 'convex_hull': json.dumps(decoded['convex_hull']), + 'valid_percentage': decoded['valid_percentage'], + 'min': decoded['range'][0], + 'max': decoded['range'][1], + 'mean': decoded['mean'], + 'stdev': decoded['stdev'], + 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), + 'metadata': json.dumps(decoded['metadata']) + } + return encoded + + @staticmethod + def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: + """Transform from database format to internal representation""" + decoded = { + 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), + 'convex_hull': json.loads(encoded['convex_hull']), + 'valid_percentage': encoded['valid_percentage'], + 'range': (encoded['min'], encoded['max']), + 'mean': encoded['mean'], + 'stdev': encoded['stdev'], + 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), + 'metadata': json.loads(encoded['metadata']) + } + return decoded diff --git a/terracotta/drivers/sqlite.py b/terracotta/drivers/sqlite.py deleted file mode 100644 index 0cfd68cc..00000000 --- a/terracotta/drivers/sqlite.py +++ /dev/null @@ -1,392 +0,0 @@ -"""drivers/sqlite.py - -SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed -to be present on disk. -""" - -from typing import Any, List, Sequence, Mapping, Tuple, Union, Iterator, Dict, cast -import os -import contextlib -from contextlib import AbstractContextManager -import json -import re -import sqlite3 -from sqlite3 import Connection -from pathlib import Path -from collections import OrderedDict - -import numpy as np - -from terracotta import get_settings, exceptions, __version__ -from terracotta.profile import trace -from terracotta.drivers.base import requires_connection -from terracotta.drivers.raster_base import RasterDriver - -_ERROR_ON_CONNECT = ( - 'Could not connect to database. Make sure that the given path points ' - 'to a valid Terracotta database, and that you ran driver.create().' -) - - -@contextlib.contextmanager -def convert_exceptions(msg: str) -> Iterator: - """Convert internal sqlite exceptions to our InvalidDatabaseError""" - try: - yield - except sqlite3.OperationalError as exc: - raise exceptions.InvalidDatabaseError(msg) from exc - - -class SQLiteDriver(RasterDriver): - """An SQLite-backed raster driver. - - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. - Stores metadata and paths to raster files in SQLite. - - This is the simplest Terracotta driver, as it requires no additional infrastructure. - The SQLite database is simply a file that can be stored together with the actual - raster files. - - Note: - - This driver requires the SQLite database to be physically present on the server. - For remote SQLite databases hosted on S3, use - :class:`~terracotta.drivers.sqlite_remote.RemoteSQLiteDriver`. 
- - The SQLite database consists of 4 different tables: - - - ``terracotta``: Metadata about the database itself. - - ``keys``: Contains two columns holding all available keys and their description. - - ``datasets``: Maps key values to physical raster path. - - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. - - This driver caches raster data, but not metadata. - - Warning: - - This driver is not thread-safe. It is not possible to connect to the database - outside the main thread. - - """ - _KEY_TYPE: str = 'VARCHAR[256]' - _METADATA_COLUMNS: Tuple[Tuple[str, ...], ...] = ( - ('bounds_north', 'REAL'), - ('bounds_east', 'REAL'), - ('bounds_south', 'REAL'), - ('bounds_west', 'REAL'), - ('convex_hull', 'VARCHAR[max]'), - ('valid_percentage', 'REAL'), - ('min', 'REAL'), - ('max', 'REAL'), - ('mean', 'REAL'), - ('stdev', 'REAL'), - ('percentiles', 'BLOB'), - ('metadata', 'VARCHAR[max]') - ) - - def __init__(self, path: Union[str, Path]) -> None: - """Initialize the SQLiteDriver. - - This should not be called directly, use :func:`~terracotta.get_driver` instead. - - Arguments: - - path: File path to target SQLite database (may or may not exist yet) - - """ - path = str(path) - - settings = get_settings() - self.DB_CONNECTION_TIMEOUT: int = settings.DB_CONNECTION_TIMEOUT - self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE - - self._connection: Connection - self._connected = False - - super().__init__(os.path.realpath(path)) - - @classmethod - def _normalize_path(cls, path: str) -> str: - return os.path.normpath(os.path.realpath(path)) - - def connect(self) -> AbstractContextManager: - return self._connect(check=True) - - @contextlib.contextmanager - def _connect(self, check: bool = True) -> Iterator: - try: - close = False - if not self._connected: - with convert_exceptions(_ERROR_ON_CONNECT): - self._connection = sqlite3.connect( - self.path, timeout=self.DB_CONNECTION_TIMEOUT - ) - self._connection.row_factory = sqlite3.Row - self._connected = close = True - - if check: - self._connection_callback() - - try: - yield - except Exception: - self._connection.rollback() - raise - - finally: - if close: - self._connected = False - self._connection.commit() - self._connection.close() - - @requires_connection - @convert_exceptions(_ERROR_ON_CONNECT) - def _get_db_version(self) -> str: - """Terracotta version used to create the database""" - conn = self._connection - db_row = conn.execute('SELECT version from terracotta').fetchone() - return db_row['version'] - - db_version = cast(str, property(_get_db_version)) - - def _connection_callback(self) -> None: - """Called after opening a new connection""" - # check for version compatibility - def versiontuple(version_string: str) -> Sequence[str]: - return version_string.split('.') - - db_version = self.db_version - current_version = __version__ - - if versiontuple(db_version)[:2] != versiontuple(current_version)[:2]: - raise exceptions.InvalidDatabaseError( - f'Version conflict: database was created in v{db_version}, ' - f'but this is v{current_version}' - ) - - def _get_key_names(self) -> Tuple[str, ...]: - """Names of all keys defined by the database""" - return tuple(self.get_keys().keys()) - - key_names = cast(Tuple[str], property(_get_key_names)) - - @convert_exceptions('Could not create database') - def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: - """Create and initialize database with empty tables. 
- - This must be called before opening the first connection. Tables must not exist already. - - Arguments: - - keys: Key names to use throughout the Terracotta database. - key_descriptions: Optional (but recommended) full-text description for some keys, - in the form of ``{key_name: description}``. - - """ - if key_descriptions is None: - key_descriptions = {} - else: - key_descriptions = dict(key_descriptions) - - if not all(k in keys for k in key_descriptions.keys()): - raise exceptions.InvalidKeyError('key description dict contains unknown keys') - - if not all(re.match(r'^\w+$', key) for key in keys): - raise exceptions.InvalidKeyError('key names must be alphanumeric') - - if any(key in self._RESERVED_KEYS for key in keys): - raise exceptions.InvalidKeyError(f'key names cannot be one of {self._RESERVED_KEYS!s}') - - for key in keys: - if key not in key_descriptions: - key_descriptions[key] = '' - - with self._connect(check=False): - conn = self._connection - conn.execute('CREATE TABLE terracotta (version VARCHAR[255])') - conn.execute('INSERT INTO terracotta VALUES (?)', [str(__version__)]) - - conn.execute(f'CREATE TABLE keys (key {self._KEY_TYPE}, description VARCHAR[max])') - key_rows = [(key, key_descriptions[key]) for key in keys] - conn.executemany('INSERT INTO keys VALUES (?, ?)', key_rows) - - key_string = ', '.join([f'{key} {self._KEY_TYPE}' for key in keys]) - conn.execute(f'CREATE TABLE datasets ({key_string}, filepath VARCHAR[8000], ' - f'PRIMARY KEY({", ".join(keys)}))') - - column_string = ', '.join(f'{col} {col_type}' for col, col_type - in self._METADATA_COLUMNS) - conn.execute(f'CREATE TABLE metadata ({key_string}, {column_string}, ' - f'PRIMARY KEY ({", ".join(keys)}))') - - @requires_connection - @convert_exceptions('Could not retrieve keys from database') - def get_keys(self) -> OrderedDict: - conn = self._connection - key_rows = conn.execute('SELECT * FROM keys') - - out: OrderedDict = OrderedDict() - for row in key_rows: - out[row['key']] = row['description'] - return out - - @trace('get_datasets') - @requires_connection - @convert_exceptions('Could not retrieve datasets') - def get_datasets(self, where: Mapping[str, Union[str, List[str]]] = None, - page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], str]: - conn = self._connection - - if limit is not None: - # explicitly cast to int to prevent SQL injection - page_fragment = f'LIMIT {int(limit)} OFFSET {int(page) * int(limit)}' - else: - page_fragment = '' - - # sort by keys to ensure deterministic results - order_fragment = f'ORDER BY {", ".join(self.key_names)}' - - if where is None: - rows = conn.execute(f'SELECT * FROM datasets {order_fragment} {page_fragment}') - else: - if not all(key in self.key_names for key in where.keys()): - raise exceptions.InvalidKeyError('Encountered unrecognized keys in ' - 'where clause') - conditions = [] - values = [] - for key, value in where.items(): - if isinstance(value, str): - value = [value] - values.extend(value) - conditions.append(' OR '.join([f'{key}=?'] * len(value))) - where_fragment = ' AND '.join([f'({condition})' for condition in conditions]) - rows = conn.execute( - f'SELECT * FROM datasets WHERE {where_fragment} {order_fragment} {page_fragment}', - values - ) - - def keytuple(row: Dict[str, Any]) -> Tuple[str, ...]: - return tuple(row[key] for key in self.key_names) - - return {keytuple(row): row['filepath'] for row in rows} - - @staticmethod - def _encode_data(decoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from internal format to database 
representation""" - encoded = { - 'bounds_north': decoded['bounds'][0], - 'bounds_east': decoded['bounds'][1], - 'bounds_south': decoded['bounds'][2], - 'bounds_west': decoded['bounds'][3], - 'convex_hull': json.dumps(decoded['convex_hull']), - 'valid_percentage': decoded['valid_percentage'], - 'min': decoded['range'][0], - 'max': decoded['range'][1], - 'mean': decoded['mean'], - 'stdev': decoded['stdev'], - 'percentiles': np.array(decoded['percentiles'], dtype='float32').tobytes(), - 'metadata': json.dumps(decoded['metadata']) - } - return encoded - - @staticmethod - def _decode_data(encoded: Mapping[str, Any]) -> Dict[str, Any]: - """Transform from database format to internal representation""" - decoded = { - 'bounds': tuple([encoded[f'bounds_{d}'] for d in ('north', 'east', 'south', 'west')]), - 'convex_hull': json.loads(encoded['convex_hull']), - 'valid_percentage': encoded['valid_percentage'], - 'range': (encoded['min'], encoded['max']), - 'mean': encoded['mean'], - 'stdev': encoded['stdev'], - 'percentiles': np.frombuffer(encoded['percentiles'], dtype='float32').tolist(), - 'metadata': json.loads(encoded['metadata']) - } - return decoded - - @trace('get_metadata') - @requires_connection - @convert_exceptions('Could not retrieve metadata') - def get_metadata(self, keys: Union[Sequence[str], Mapping[str, str]]) -> Dict[str, Any]: - keys = tuple(self._key_dict_to_sequence(keys)) - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - conn = self._connection - - where_string = ' AND '.join([f'{key}=?' for key in self.key_names]) - row = conn.execute(f'SELECT * FROM metadata WHERE {where_string}', keys).fetchone() - - if not row: # support lazy loading - filepath = self.get_datasets(dict(zip(self.key_names, keys)), page=0, limit=1) - if not filepath: - raise exceptions.DatasetNotFoundError(f'No dataset found for given keys {keys}') - - # compute metadata and try again - metadata = self.compute_metadata(filepath[keys], max_shape=self.LAZY_LOADING_MAX_SHAPE) - self.insert(keys, filepath[keys], metadata=metadata) - row = conn.execute(f'SELECT * FROM metadata WHERE {where_string}', keys).fetchone() - - assert row - - data_columns, _ = zip(*self._METADATA_COLUMNS) - encoded_data = {col: row[col] for col in self.key_names + data_columns} - return self._decode_data(encoded_data) - - @trace('insert') - @requires_connection - @convert_exceptions('Could not write to database') - def insert(self, - keys: Union[Sequence[str], Mapping[str, str]], - filepath: str, *, - metadata: Mapping[str, Any] = None, - skip_metadata: bool = False, - override_path: str = None) -> None: - conn = self._connection - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - if override_path is None: - override_path = filepath - - keys = self._key_dict_to_sequence(keys) - template_string = ', '.join(['?'] * (len(keys) + 1)) - conn.execute(f'INSERT OR REPLACE INTO datasets VALUES ({template_string})', - [*keys, override_path]) - - if metadata is None and not skip_metadata: - metadata = self.compute_metadata(filepath) - - if metadata is not None: - encoded_data = self._encode_data(metadata) - row_keys, row_values = zip(*encoded_data.items()) - template_string = ', '.join(['?'] * (len(keys) + len(row_values))) - conn.execute(f'INSERT OR REPLACE INTO metadata ({", ".join(self.key_names)}, ' - f'{", ".join(row_keys)}) VALUES 
({template_string})', [*keys, *row_values]) - - @trace('delete') - @requires_connection - @convert_exceptions('Could not write to database') - def delete(self, keys: Union[Sequence[str], Mapping[str, str]]) -> None: - conn = self._connection - - if len(keys) != len(self.key_names): - raise exceptions.InvalidKeyError( - f'Got wrong number of keys (available keys: {self.key_names})' - ) - - keys = self._key_dict_to_sequence(keys) - key_dict = dict(zip(self.key_names, keys)) - - if not self.get_datasets(key_dict): - raise exceptions.DatasetNotFoundError(f'No dataset found with keys {keys}') - - where_string = ' AND '.join([f'{key}=?' for key in self.key_names]) - conn.execute(f'DELETE FROM datasets WHERE {where_string}', keys) - conn.execute(f'DELETE FROM metadata WHERE {where_string}', keys) diff --git a/terracotta/drivers/sqlite_meta_store.py b/terracotta/drivers/sqlite_meta_store.py new file mode 100644 index 00000000..1d504d2e --- /dev/null +++ b/terracotta/drivers/sqlite_meta_store.py @@ -0,0 +1,71 @@ +"""drivers/sqlite_meta_store.py + +SQLite-backed metadata driver. Metadata is stored in an SQLite database. +""" + +import os +from pathlib import Path +from typing import Union + +from terracotta.drivers.relational_meta_store import RelationalMetaStore + + +class SQLiteMetaStore(RelationalMetaStore): + """An SQLite-backed metadata driver. + + Stores metadata and paths to raster files in SQLite. + + This is the simplest Terracotta driver, as it requires no additional infrastructure. + The SQLite database is simply a file that can e.g. be stored together with the actual + raster files. + + Note: + + This driver requires the SQLite database to be physically present on the server. + For remote SQLite databases hosted on S3, use + :class:`~terracotta.drivers.sqlite_remote.RemoteSQLiteDriver`. + + The SQLite database consists of 4 different tables: + + - ``terracotta``: Metadata about the database itself. + - ``keys``: Contains two columns holding all available keys and their description. + - ``datasets``: Maps key values to physical raster path. + - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. + + This driver caches key names, but not metadata. + + Warning: + + This driver is not thread-safe. It is not possible to connect to the database + outside the main thread. + + """ + SQL_DIALECT = 'sqlite' + SQL_DRIVER = 'pysqlite' + SQL_KEY_SIZE = 256 + SQL_TIMEOUT_KEY = 'timeout' + + def __init__(self, path: Union[str, Path]) -> None: + """Initialize the SQLiteDriver. + + This should not be called directly, use :func:`~terracotta.get_driver` instead. 
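+
+        Example (a minimal sketch; ``tc.sqlite`` and the key names are placeholder values):
+
+            >>> import terracotta as tc
+            >>> driver = tc.get_driver('tc.sqlite', provider='sqlite')
+            >>> driver.create(['type', 'date'])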
+ + Arguments: + + path: File path to target SQLite database (may or may not exist yet) + + """ + super().__init__(f'{self.SQL_DIALECT}:///{path}') + + @classmethod + def _normalize_path(cls, path: str) -> str: + if path.startswith(f'{cls.SQL_DIALECT}:///'): + path = path.replace(f'{cls.SQL_DIALECT}:///', '') + + return os.path.normpath(os.path.realpath(path)) + + def _create_database(self) -> None: + """The database is automatically created by the sqlite driver on connection, + so no need to do anything here + """ + pass diff --git a/terracotta/drivers/sqlite_remote.py b/terracotta/drivers/sqlite_remote_meta_store.py similarity index 81% rename from terracotta/drivers/sqlite_remote.py rename to terracotta/drivers/sqlite_remote_meta_store.py index b4f5de37..e40968ad 100644 --- a/terracotta/drivers/sqlite_remote.py +++ b/terracotta/drivers/sqlite_remote_meta_store.py @@ -1,20 +1,20 @@ -"""drivers/sqlite.py +"""drivers/sqlite_remote_meta_store.py -SQLite-backed raster driver. Metadata is stored in an SQLite database, raster data is assumed -to be present on disk. +SQLite-backed metadata driver. Metadata is stored in an SQLite database. """ -from typing import Any, Iterator +import contextlib +import logging import os -import time -import tempfile import shutil -import logging -import contextlib +import tempfile +import time import urllib.parse as urlparse +from pathlib import Path +from typing import Any, Iterator, Union -from terracotta import get_settings, exceptions -from terracotta.drivers.sqlite import SQLiteDriver +from terracotta import exceptions, get_settings +from terracotta.drivers.sqlite_meta_store import SQLiteMetaStore from terracotta.profile import trace logger = logging.getLogger(__name__) @@ -22,12 +22,11 @@ @contextlib.contextmanager def convert_exceptions(msg: str) -> Iterator: - """Convert internal sqlite and boto exceptions to our InvalidDatabaseError""" - import sqlite3 + """Convert internal boto exceptions to our InvalidDatabaseError""" import botocore.exceptions try: yield - except (sqlite3.OperationalError, botocore.exceptions.ClientError) as exc: + except botocore.exceptions.ClientError as exc: raise exceptions.InvalidDatabaseError(msg) from exc @@ -50,10 +49,9 @@ def _update_from_s3(remote_path: str, local_path: str) -> None: shutil.copyfileobj(obj_bytes, f) -class RemoteSQLiteDriver(SQLiteDriver): - """An SQLite-backed raster driver, where the database file is stored remotely on S3. +class RemoteSQLiteMetaStore(SQLiteMetaStore): + """An SQLite-backed metadata driver, where the database file is stored remotely on S3. - Assumes raster data to be present in separate GDAL-readable files on disk or remotely. Stores metadata and paths to raster files in SQLite. See also: @@ -61,7 +59,7 @@ class RemoteSQLiteDriver(SQLiteDriver): :class:`~terracotta.drivers.sqlite.SQLiteDriver` for the local version of this driver. - The SQLite database is simply a file that can be stored together with the actual + The SQLite database is simply a file that can be stored e.g. together with the actual raster files on S3. Before handling the first request, this driver will download a temporary copy of the remote database file. It is thus not feasible for large databases. @@ -75,7 +73,7 @@ class RemoteSQLiteDriver(SQLiteDriver): """ - def __init__(self, remote_path: str) -> None: + def __init__(self, remote_path: Union[str, Path]) -> None: """Initialize the RemoteSQLiteDriver. This should not be called directly, use :func:`~terracotta.get_driver` instead. 
@@ -99,6 +97,7 @@ def __init__(self, remote_path: str) -> None: ) local_db_file.close() + self._local_path = local_db_file.name self._remote_path = str(remote_path) self._last_updated = -float('inf') @@ -130,7 +129,7 @@ def _update_db(self, remote_path: str, local_path: str) -> None: self._last_updated = time.time() def _connection_callback(self) -> None: - self._update_db(self._remote_path, self.path) + self._update_db(self._remote_path, self._local_path) super()._connection_callback() def create(self, *args: Any, **kwargs: Any) -> None: @@ -144,4 +143,4 @@ def delete(self, *args: Any, **kwargs: Any) -> None: def __del__(self) -> None: """Clean up temporary database upon exit""" - self.__rm(self.path) + self.__rm(self._local_path) diff --git a/terracotta/drivers/terracotta_driver.py b/terracotta/drivers/terracotta_driver.py new file mode 100644 index 00000000..e0ac77d1 --- /dev/null +++ b/terracotta/drivers/terracotta_driver.py @@ -0,0 +1,349 @@ +"""drivers/terracotta_driver.py + +The driver to interact with. +""" + +import contextlib +from collections import OrderedDict +from typing import (Any, Collection, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, + Union) + +import terracotta +from terracotta import exceptions +from terracotta.drivers.base_classes import (KeysType, MetaStore, + MultiValueKeysType, RasterStore, + requires_connection) + +ExtendedKeysType = Union[Sequence[str], Mapping[str, str]] +ExtendedMultiValueKeysType = Union[Sequence[str], Mapping[str, Union[str, List[str]]]] +T = TypeVar('T') + + +def squeeze(iterable: Collection[T]) -> T: + assert len(iterable) == 1 + return next(iter(iterable)) + + +class TerracottaDriver: + """Terracotta driver object used to retrieve raster tiles and metadata. + + Do not instantiate directly, use :func:`terracotta.get_driver` instead. + """ + def __init__(self, meta_store: MetaStore, raster_store: RasterStore) -> None: + self.meta_store = meta_store + self.raster_store = raster_store + + settings = terracotta.get_settings() + self.LAZY_LOADING_MAX_SHAPE: Tuple[int, int] = settings.LAZY_LOADING_MAX_SHAPE + + @property + def db_version(self) -> str: + """Terracotta version used to create the meta store. + + Returns: + + A str specifying the version of Terracotta that was used to create the meta store. + + """ + return self.meta_store.db_version + + @property + def key_names(self) -> Tuple[str, ...]: + """Get names of all keys defined by the meta store. + + Returns: + + A tuple defining the key names and order. + + """ + return self.meta_store.key_names + + def create(self, keys: Sequence[str], *, + key_descriptions: Mapping[str, str] = None) -> None: + """Create a new, empty metadata store. + + Arguments: + + keys: A sequence defining the key names and order. + key_descriptions: A mapping from key name to a human-readable + description of what the key encodes. + + """ + self.meta_store.create(keys=keys, key_descriptions=key_descriptions) + + def connect(self, verify: bool = True) -> contextlib.AbstractContextManager: + """Context manager to connect to the metastore and clean up on exit. + + This allows you to pool interactions with the metastore to prevent possibly + expensive reconnects, or to roll back several interactions if one of them fails. + + Arguments: + + verify: Whether to verify the metastore (primarily its version) when connecting. + Should be `true` unless absolutely necessary, such as when instantiating the + metastore during creation of it. 
+ + Note: + + Make sure to call :meth:`create` on a fresh metastore before using this method. + + Example: + + >>> import terracotta as tc + >>> driver = tc.get_driver('tc.sqlite') + >>> with driver.connect(): + ... for keys, dataset in datasets.items(): + ... # connection will be kept open between insert operations + ... driver.insert(keys, dataset) + + """ + return self.meta_store.connect(verify=verify) + + @requires_connection + def get_keys(self) -> OrderedDict: + """Get all known keys and their fulltext descriptions. + + Returns: + + An :class:`~collections.OrderedDict` in the form + ``{key_name: key_description}`` + + """ + return self.meta_store.get_keys() + + @requires_connection + def get_datasets(self, where: MultiValueKeysType = None, + page: int = 0, limit: int = None) -> Dict[Tuple[str, ...], Any]: + """Get all known dataset key combinations matching the given constraints, + and a path to retrieve the data (dependent on the raster store). + + Arguments: + + where: A mapping from key name to key value constraint(s) + page: A pagination parameter, skips first page * limit results + limit: A pagination parameter, max number of results to return + + Returns: + + A :class:`dict` mapping from key sequence tuple to dataset path. + + """ + return self.meta_store.get_datasets( + where=self._standardize_multi_value_keys(where, requires_all_keys=False), + page=page, + limit=limit + ) + + @requires_connection + def get_metadata(self, keys: ExtendedKeysType) -> Dict[str, Any]: + """Return all stored metadata for given keys. + + Arguments: + + keys: Keys of the requested dataset. Can either be given as a sequence of key values, + or as a mapping ``{key_name: key_value}``. + + Returns: + + A :class:`dict` with the values + + - ``range``: global minimum and maximum value in dataset + - ``bounds``: physical bounds covered by dataset in latitude-longitude projection + - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude + projection + - ``percentiles``: array of pre-computed percentiles from 1% through 99% + - ``mean``: global mean + - ``stdev``: global standard deviation + - ``metadata``: any additional client-relevant metadata + + """ + keys = self._standardize_keys(keys) + + metadata = self.meta_store.get_metadata(keys) + + if metadata is None: + # metadata is not computed yet, trigger lazy loading + dataset = self.get_datasets(keys) + if not dataset: + raise exceptions.DatasetNotFoundError('No dataset found') + + path = squeeze(dataset.values()) + metadata = self.compute_metadata(path, max_shape=self.LAZY_LOADING_MAX_SHAPE) + self.insert(keys, path, metadata=metadata) + + # ensure standardized/consistent output (types and floating point precision) + metadata = self.meta_store.get_metadata(keys) + assert metadata is not None + + return metadata + + @requires_connection + def insert( + self, keys: ExtendedKeysType, + path: str, *, + override_path: str = None, + metadata: Mapping[str, Any] = None, + skip_metadata: bool = False + ) -> None: + """Register a new dataset. Used to populate meta store. + + Arguments: + + keys: Keys of the dataset. Can either be given as a sequence of key values, or + as a mapping ``{key_name: key_value}``. + path: Path to access dataset (driver dependent). + override_path: If given, this path will be inserted into the meta store + instead of the one used to load the dataset. + metadata: Metadata dict for the dataset. If not given, metadata will be computed + via :meth:`compute_metadata`. 
+            skip_metadata: If True, will skip metadata computation (will be computed
+                during first request instead). Has no effect if ``metadata`` argument is given.
+
+        """
+        keys = self._standardize_keys(keys)
+
+        if metadata is None and not skip_metadata:
+            metadata = self.compute_metadata(path)
+
+        self.meta_store.insert(
+            keys=keys,
+            path=override_path or path,
+            metadata=metadata
+        )
+
+    @requires_connection
+    def delete(self, keys: ExtendedKeysType) -> None:
+        """Remove a dataset from the meta store.
+
+        Arguments:
+
+            keys: Keys of the dataset. Can either be given as a sequence of key values, or
+                as a mapping ``{key_name: key_value}``.
+
+        """
+        keys = self._standardize_keys(keys)
+
+        self.meta_store.delete(keys)
+
+    def get_raster_tile(self, keys: ExtendedKeysType, *,
+                        tile_bounds: Sequence[float] = None,
+                        tile_size: Sequence[int] = (256, 256),
+                        preserve_values: bool = False,
+                        asynchronous: bool = False) -> Any:
+        """Load a raster tile with given keys and bounds.
+
+        Arguments:
+
+            keys: Key sequence identifying the dataset to load tile from.
+            tile_bounds: Physical bounds of the tile to read, in Web Mercator projection (EPSG:3857).
+                Reads the whole dataset if not given.
+            tile_size: Shape of the output array to return. Must be two-dimensional.
+                Defaults to :attr:`~terracotta.config.TerracottaSettings.DEFAULT_TILE_SIZE`.
+            preserve_values: Whether to preserve exact numerical values (e.g. when reading
+                categorical data). Sets all interpolation to nearest neighbor.
+            asynchronous: If True, the tile will be read asynchronously in a separate thread.
+                This function will return immediately with a :class:`~concurrent.futures.Future`
+                that can be used to retrieve the result.
+
+        Returns:
+
+            Requested tile as :class:`~numpy.ma.MaskedArray` of shape ``tile_size`` if
+            ``asynchronous=False``, otherwise a :class:`~concurrent.futures.Future` containing
+            the result.
+
+        """
+        path = squeeze(self.get_datasets(keys).values())
+
+        return self.raster_store.get_raster_tile(
+            path=path,
+            tile_bounds=tile_bounds,
+            tile_size=tile_size,
+            preserve_values=preserve_values,
+            asynchronous=asynchronous,
+        )
+
+    def compute_metadata(self, path: str, *,
+                         extra_metadata: Any = None,
+                         use_chunks: bool = None,
+                         max_shape: Sequence[int] = None) -> Dict[str, Any]:
+        """Compute metadata for a dataset.
+
+        Arguments:
+
+            path: Path identifying the dataset.
+            extra_metadata: Any additional metadata that will be returned as is
+                in the result, under the `metadata` key.
+            use_chunks: Whether to load the dataset in chunks, when computing.
+                Useful if the dataset is too large to fit in memory.
+                Mutually exclusive with `max_shape`.
+            max_shape: If dataset is larger than this shape, it will be downsampled
+                while loading. Useful if the dataset is too large to fit in memory.
+                Mutually exclusive with `use_chunks`.
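+
+        For example (a sketch; the path and ``keys`` are placeholders):
+
+            >>> metadata = driver.compute_metadata('big_raster.tif', use_chunks=True)
+            >>> driver.insert(keys, 'big_raster.tif', metadata=metadata)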
+ + Returns: + + A :class:`dict` with the values + + - ``range``: global minimum and maximum value in dataset + - ``bounds``: physical bounds covered by dataset in latitude-longitude projection + - ``convex_hull``: GeoJSON shape specifying total data coverage in latitude-longitude + projection + - ``percentiles``: array of pre-computed percentiles from 1% through 99% + - ``mean``: global mean + - ``stdev``: global standard deviation + - ``metadata``: any additional client-relevant metadata + + """ + return self.raster_store.compute_metadata( + path=path, + extra_metadata=extra_metadata, + use_chunks=use_chunks, + max_shape=max_shape, + ) + + def _standardize_keys( + self, keys: ExtendedKeysType, requires_all_keys: bool = True + ) -> KeysType: + return self._ensure_keys_as_dict(keys, requires_all_keys) + + def _standardize_multi_value_keys( + self, keys: Optional[ExtendedMultiValueKeysType], requires_all_keys: bool = True + ) -> MultiValueKeysType: + return self._ensure_keys_as_dict(keys, requires_all_keys) + + def _ensure_keys_as_dict( + self, + keys: Union[ExtendedKeysType, Optional[MultiValueKeysType]], + requires_all_keys: bool = True + ) -> Dict[str, Any]: + if requires_all_keys and (keys is None or len(keys) != len(self.key_names)): + raise exceptions.InvalidKeyError( + f'Got wrong number of keys (available keys: {self.key_names})' + ) + + if isinstance(keys, Mapping): + keys = dict(keys.items()) + elif isinstance(keys, Sequence): + keys = dict(zip(self.key_names, keys)) + elif keys is None: + keys = {} + else: + raise exceptions.InvalidKeyError( + 'Encountered unknown key type, expected Mapping or Sequence' + ) + + unknown_keys = set(keys) - set(self.key_names) + if unknown_keys: + raise exceptions.InvalidKeyError( + f'Encountered unrecognized keys {unknown_keys} (available keys: {self.key_names})' + ) + + return keys + + def __repr__(self) -> str: + return ( + f'{self.__class__.__name__}(\n' + f' meta_store={self.meta_store!r},\n' + f' raster_store={self.raster_store!r}\n' + ')' + ) diff --git a/terracotta/raster.py b/terracotta/raster.py new file mode 100644 index 00000000..428a65e8 --- /dev/null +++ b/terracotta/raster.py @@ -0,0 +1,385 @@ +"""raster.py + +Extract information from raster files through rasterio. +""" + +from typing import Optional, Any, Dict, Tuple, Sequence, TYPE_CHECKING +import contextlib +import warnings +import logging + +import numpy as np + +if TYPE_CHECKING: # pragma: no cover + from rasterio.io import DatasetReader # noqa: F401 + +try: + from crick import TDigest, SummaryStats + has_crick = True +except ImportError: # pragma: no cover + has_crick = False + +from terracotta import exceptions +from terracotta.profile import trace + +logger = logging.getLogger(__name__) + + +def convex_hull_candidate_mask(mask: np.ndarray) -> np.ndarray: + """Returns a reduced boolean mask to speed up convex hull computations. + + Exploits the fact that only the first and last elements of each row and column + can contribute to the convex hull of a dataset. 
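+
+    For example (illustrative only), a fully valid 3x3 mask keeps just its
+    eight border pixels as candidates:
+
+    >>> int(convex_hull_candidate_mask(np.ones((3, 3), dtype=bool)).sum())
+    8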
+ """ + assert mask.ndim == 2 + assert mask.dtype == np.bool_ + + nx, ny = mask.shape + out = np.zeros_like(mask) + + # these operations do not short-circuit, but seems to be the best we can do + # NOTE: argmax returns 0 if a slice is all True or all False + first_row = np.argmax(mask, axis=0) + last_row = nx - 1 - np.argmax(mask[::-1, :], axis=0) + first_col = np.argmax(mask, axis=1) + last_col = ny - 1 - np.argmax(mask[:, ::-1], axis=1) + + all_rows = np.arange(nx) + all_cols = np.arange(ny) + + out[first_row, all_cols] = out[last_row, all_cols] = True + out[all_rows, first_col] = out[all_rows, last_col] = True + + # filter all-False slices + out &= mask + + return out + + +def compute_image_stats_chunked(dataset: 'DatasetReader') -> Optional[Dict[str, Any]]: + """Compute statistics for the given rasterio dataset by looping over chunks.""" + from rasterio import features, warp, windows + from shapely import geometry + + total_count = valid_data_count = 0 + tdigest = TDigest() + sstats = SummaryStats() + convex_hull = geometry.Polygon() + + block_windows = [w for _, w in dataset.block_windows(1)] + + for w in block_windows: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='invalid value encountered.*') + block_data = dataset.read(1, window=w, masked=True) + + # handle NaNs for float rasters + block_data = np.ma.masked_invalid(block_data, copy=False) + + total_count += int(block_data.size) + valid_data = block_data.compressed() + + if valid_data.size == 0: + continue + + valid_data_count += int(valid_data.size) + + if np.any(block_data.mask): + hull_candidates = convex_hull_candidate_mask(~block_data.mask) + hull_shapes = [geometry.shape(s) for s, _ in features.shapes( + np.ones(hull_candidates.shape, 'uint8'), + mask=hull_candidates, + transform=windows.transform(w, dataset.transform) + )] + else: + w, s, e, n = windows.bounds(w, dataset.transform) + hull_shapes = [geometry.Polygon([(w, s), (e, s), (e, n), (w, n)])] + convex_hull = geometry.MultiPolygon([convex_hull, *hull_shapes]).convex_hull + + tdigest.update(valid_data) + sstats.update(valid_data) + + if sstats.count() == 0: + return None + + convex_hull_wgs = warp.transform_geom( + dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) + ) + + return { + 'valid_percentage': valid_data_count / total_count * 100, + 'range': (sstats.min(), sstats.max()), + 'mean': sstats.mean(), + 'stdev': sstats.std(), + 'percentiles': tdigest.quantile(np.arange(0.01, 1, 0.01)), + 'convex_hull': convex_hull_wgs + } + + +def compute_image_stats(dataset: 'DatasetReader', + max_shape: Sequence[int] = None) -> Optional[Dict[str, Any]]: + """Compute statistics for the given rasterio dataset by reading it into memory.""" + from rasterio import features, warp, transform + from shapely import geometry + + out_shape = (dataset.height, dataset.width) + + if max_shape is not None: + out_shape = ( + min(max_shape[0], out_shape[0]), + min(max_shape[1], out_shape[1]) + ) + + data_transform = transform.from_bounds( + *dataset.bounds, height=out_shape[0], width=out_shape[1] + ) + raster_data = dataset.read(1, out_shape=out_shape, masked=True) + + if dataset.nodata is not None: + # nodata values might slip into output array if out_shape < dataset.shape + raster_data = np.ma.masked_equal(raster_data, dataset.nodata, copy=False) + + # handle NaNs for float rasters + raster_data = np.ma.masked_invalid(raster_data, copy=False) + + valid_data = raster_data.compressed() + + if valid_data.size == 0: + return None + + if np.any(raster_data.mask): 
+ hull_candidates = convex_hull_candidate_mask(~raster_data.mask) + hull_shapes = (geometry.shape(s) for s, _ in features.shapes( + np.ones(hull_candidates.shape, 'uint8'), + mask=hull_candidates, + transform=data_transform + )) + convex_hull = geometry.MultiPolygon(hull_shapes).convex_hull + else: + # no masked entries -> convex hull == dataset bounds + w, s, e, n = dataset.bounds + convex_hull = geometry.Polygon([(w, s), (e, s), (e, n), (w, n)]) + + convex_hull_wgs = warp.transform_geom( + dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) + ) + + return { + 'valid_percentage': valid_data.size / raster_data.size * 100, + 'range': (float(valid_data.min()), float(valid_data.max())), + 'mean': float(valid_data.mean()), + 'stdev': float(valid_data.std()), + 'percentiles': np.percentile(valid_data, np.arange(1, 100)), + 'convex_hull': convex_hull_wgs + } + + +@trace('compute_metadata') +def compute_metadata(path: str, *, + extra_metadata: Any = None, + use_chunks: bool = None, + max_shape: Sequence[int] = None, + large_raster_threshold: int = None, + rio_env_options: Dict[str, Any] = None) -> Dict[str, Any]: + import rasterio + from rasterio import warp + from terracotta.cog import validate + + row_data: Dict[str, Any] = {} + extra_metadata = extra_metadata or {} + + if max_shape is not None and len(max_shape) != 2: + raise ValueError('max_shape argument must contain 2 values') + + if use_chunks and max_shape is not None: + raise ValueError('Cannot use both use_chunks and max_shape arguments') + + if rio_env_options is None: + rio_env_options = {} + + with rasterio.Env(**rio_env_options): + if not validate(path): + warnings.warn( + f'Raster file {path} is not a valid cloud-optimized GeoTIFF. ' + 'Any interaction with it will be significantly slower. Consider optimizing ' + 'it through `terracotta optimize-rasters` before ingestion.', + exceptions.PerformanceWarning, stacklevel=3 + ) + + with rasterio.open(path) as src: + if src.nodata is None and not has_alpha_band(src): + warnings.warn( + f'Raster file {path} does not have a valid nodata value, ' + 'and does not contain an alpha band. No data will be masked.' + ) + + bounds = warp.transform_bounds( + src.crs, 'epsg:4326', *src.bounds, densify_pts=21 + ) + + if use_chunks is None and max_shape is None and large_raster_threshold is not None: + use_chunks = src.width * src.height > large_raster_threshold + + if use_chunks: + logger.debug( + f'Computing metadata for file {path} using more than ' + f'{large_raster_threshold // 10**6}M pixels, iterating ' + 'over chunks' + ) + + if use_chunks and not has_crick: + warnings.warn( + 'Processing a large raster file, but crick failed to import. 
' + 'Reading whole file into memory instead.', exceptions.PerformanceWarning + ) + use_chunks = False + + if use_chunks: + raster_stats = compute_image_stats_chunked(src) + else: + raster_stats = compute_image_stats(src, max_shape) + + if raster_stats is None: + raise ValueError(f'Raster file {path} does not contain any valid data') + + row_data.update(raster_stats) + + row_data['bounds'] = bounds + row_data['metadata'] = extra_metadata + return row_data + + +def get_resampling_enum(method: str) -> Any: + from rasterio.enums import Resampling + + if method == 'nearest': + return Resampling.nearest + + if method == 'linear': + return Resampling.bilinear + + if method == 'cubic': + return Resampling.cubic + + if method == 'average': + return Resampling.average + + raise ValueError(f'unknown resampling method {method}') + + +def has_alpha_band(src: 'DatasetReader') -> bool: + from rasterio.enums import MaskFlags, ColorInterp + return ( + any([MaskFlags.alpha in flags for flags in src.mask_flag_enums]) + or ColorInterp.alpha in src.colorinterp + ) + + +@trace("get_raster_tile") +def get_raster_tile(path: str, *, + reprojection_method: str = "nearest", + resampling_method: str = "nearest", + tile_bounds: Tuple[float, float, float, float] = None, + tile_size: Tuple[int, int] = (256, 256), + preserve_values: bool = False, + target_crs: str = 'epsg:3857', + rio_env_options: Dict[str, Any] = None) -> np.ma.MaskedArray: + """Load a raster dataset from a file through rasterio. + + Heavily inspired by mapbox/rio-tiler + """ + import rasterio + from rasterio import transform, windows, warp + from rasterio.vrt import WarpedVRT + from affine import Affine + + dst_bounds: Tuple[float, float, float, float] + + if rio_env_options is None: + rio_env_options = {} + + if preserve_values: + reproject_enum = resampling_enum = get_resampling_enum('nearest') + else: + reproject_enum = get_resampling_enum(reprojection_method) + resampling_enum = get_resampling_enum(resampling_method) + + with contextlib.ExitStack() as es: + es.enter_context(rasterio.Env(**rio_env_options)) + try: + with trace('open_dataset'): + src = es.enter_context(rasterio.open(path)) + except OSError: + raise IOError('error while reading file {}'.format(path)) + + # compute buonds in target CRS + dst_bounds = warp.transform_bounds(src.crs, target_crs, *src.bounds) + + if tile_bounds is None: + tile_bounds = dst_bounds + + # prevent loads of very sparse data + cover_ratio = ( + (dst_bounds[2] - dst_bounds[0]) / (tile_bounds[2] - tile_bounds[0]) + * (dst_bounds[3] - dst_bounds[1]) / (tile_bounds[3] - tile_bounds[1]) + ) + + if cover_ratio < 0.01: + raise exceptions.TileOutOfBoundsError('dataset covers less than 1% of tile') + + # compute suggested resolution in target CRS + dst_transform, _, _ = warp.calculate_default_transform( + src.crs, target_crs, src.width, src.height, *src.bounds + ) + dst_res = (abs(dst_transform.a), abs(dst_transform.e)) + + # in some cases (e.g. 
at extreme latitudes), the default transform + # suggests very coarse resolutions - in this case, fall back to native tile res + tile_transform = transform.from_bounds(*tile_bounds, *tile_size) + tile_res = (abs(tile_transform.a), abs(tile_transform.e)) + + if tile_res[0] < dst_res[0] or tile_res[1] < dst_res[1]: + dst_res = tile_res + resampling_enum = get_resampling_enum('nearest') + + # pad tile bounds to prevent interpolation artefacts + num_pad_pixels = 2 + + # compute tile VRT shape and transform + dst_width = max(1, round((tile_bounds[2] - tile_bounds[0]) / dst_res[0])) + dst_height = max(1, round((tile_bounds[3] - tile_bounds[1]) / dst_res[1])) + vrt_transform = ( + transform.from_bounds(*tile_bounds, width=dst_width, height=dst_height) + * Affine.translation(-num_pad_pixels, -num_pad_pixels) + ) + vrt_height, vrt_width = dst_height + 2 * num_pad_pixels, dst_width + 2 * num_pad_pixels + + # remove padding in output + out_window = windows.Window( + col_off=num_pad_pixels, row_off=num_pad_pixels, width=dst_width, height=dst_height + ) + + # construct VRT + vrt = es.enter_context( + WarpedVRT( + src, crs=target_crs, resampling=reproject_enum, + transform=vrt_transform, width=vrt_width, height=vrt_height, + add_alpha=not has_alpha_band(src) + ) + ) + + # read data + with warnings.catch_warnings(), trace('read_from_vrt'): + warnings.filterwarnings('ignore', message='invalid value encountered.*') + tile_data = vrt.read( + 1, resampling=resampling_enum, window=out_window, out_shape=tile_size + ) + + # assemble alpha mask + mask_idx = vrt.count + mask = vrt.read(mask_idx, window=out_window, out_shape=tile_size) == 0 + + if src.nodata is not None: + mask |= tile_data == src.nodata + + return np.ma.masked_array(tile_data, mask=mask) diff --git a/terracotta/scripts/optimize_rasters.py b/terracotta/scripts/optimize_rasters.py index cc52f48f..85361af2 100644 --- a/terracotta/scripts/optimize_rasters.py +++ b/terracotta/scripts/optimize_rasters.py @@ -85,8 +85,8 @@ def _prefered_compression_method() -> str: def _get_vrt(src: DatasetReader, rs_method: int) -> WarpedVRT: - from terracotta.drivers.raster_base import RasterDriver - target_crs = RasterDriver._TARGET_CRS + from terracotta.drivers.geotiff_raster_store import GeoTiffRasterStore + target_crs = GeoTiffRasterStore._TARGET_CRS vrt_transform, vrt_width, vrt_height = calculate_default_transform( src.crs, target_crs, src.width, src.height, *src.bounds ) diff --git a/terracotta/xyz.py b/terracotta/xyz.py index b73253e1..dcbfe031 100644 --- a/terracotta/xyz.py +++ b/terracotta/xyz.py @@ -8,11 +8,11 @@ import mercantile from terracotta import exceptions -from terracotta.drivers.base import Driver +from terracotta.drivers.terracotta_driver import TerracottaDriver # TODO: add accurate signature if mypy ever supports conditional return types -def get_tile_data(driver: Driver, +def get_tile_data(driver: TerracottaDriver, keys: Union[Sequence[str], Mapping[str, str]], tile_xyz: Tuple[int, int, int] = None, *, tile_size: Tuple[int, int] = (256, 256), diff --git a/tests/benchmarks.py b/tests/benchmarks.py index c2f64699..adc9ec34 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -114,7 +114,7 @@ def test_bench_singleband(benchmark, zoom, resampling, big_raster_file_nodata, b rv = benchmark(client.get, '/singleband/nodata/1/preview.png') assert rv.status_code == 200 - assert not len(get_driver(str(benchmark_database))._raster_cache) + assert not len(get_driver(str(benchmark_database)).raster_store._raster_cache) def 
test_bench_singleband_out_of_bounds(benchmark, benchmark_database): @@ -136,12 +136,14 @@ def test_bench_singleband_out_of_bounds(benchmark, benchmark_database): @pytest.mark.parametrize('raster_type', ['nodata', 'masked']) def test_bench_compute_metadata(benchmark, big_raster_file_nodata, big_raster_file_mask, chunks, raster_type): - from terracotta.drivers.raster_base import RasterDriver + from terracotta import raster + if raster_type == 'nodata': raster_file = big_raster_file_nodata elif raster_type == 'masked': raster_file = big_raster_file_mask - benchmark(RasterDriver.compute_metadata, str(raster_file), use_chunks=chunks) + + benchmark(raster.compute_metadata, str(raster_file), use_chunks=chunks) @pytest.mark.parametrize('in_memory', [False, True]) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index f4352bc6..d3881839 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -2,9 +2,9 @@ TESTABLE_DRIVERS = ['sqlite', 'mysql'] DRIVER_CLASSES = { - 'sqlite': 'SQLiteDriver', - 'sqlite-remote': 'SQLiteRemoteDriver', - 'mysql': 'MySQLDriver' + 'sqlite': 'SQLiteMetaStore', + 'sqlite-remote': 'SQLiteRemoteMetaStore', + 'mysql': 'MySQLMetaStore' } @@ -12,14 +12,14 @@ def test_auto_detect(driver_path, provider): from terracotta import drivers db = drivers.get_driver(driver_path) - assert db.__class__.__name__ == DRIVER_CLASSES[provider] + assert db.meta_store.__class__.__name__ == DRIVER_CLASSES[provider] assert drivers.get_driver(driver_path, provider=provider) is db def test_normalize_base(tmpdir): - from terracotta.drivers import Driver + from terracotta.drivers import MetaStore # base class normalize is noop - assert Driver._normalize_path(str(tmpdir)) == str(tmpdir) + assert MetaStore._normalize_path(str(tmpdir)) == str(tmpdir) @pytest.mark.parametrize('provider', ['sqlite']) @@ -65,6 +65,22 @@ def test_normalize_url(provider): assert driver._normalize_path(p) == first_path +@pytest.mark.parametrize('provider', ['mysql']) +def test_parse_connection_string_with_invalid_schemes(provider): + from terracotta import drivers + + invalid_schemes = ( + 'fakescheme://test.example.com/foo', + 'fakescheme://test.example.com:80/foo', + ) + + for invalid_scheme in invalid_schemes: + with pytest.raises(ValueError) as exc: + driver = drivers.get_driver(invalid_scheme, provider) + print(type(driver)) + assert 'unsupported URL scheme' in str(exc.value) + + def test_get_driver_invalid(): from terracotta import drivers with pytest.raises(ValueError) as exc: @@ -161,10 +177,10 @@ def __getattribute__(self, key): with pytest.raises(RuntimeError): with db.connect(): - db._connection = Evanescence() + db.meta_store.connection = Evanescence() db.get_keys() - assert not db._connected + assert not db.meta_store.connected with db.connect(): db.get_keys() @@ -174,7 +190,7 @@ def __getattribute__(self, key): def test_repr(driver_path, provider): from terracotta import drivers db = drivers.get_driver(driver_path, provider=provider) - assert repr(db).startswith(DRIVER_CLASSES[provider]) + assert f'meta_store={DRIVER_CLASSES[provider]}' in repr(db) @pytest.mark.parametrize('provider', TESTABLE_DRIVERS) @@ -205,11 +221,42 @@ def test_version_conflict(driver_path, provider, raster_file, monkeypatch): with monkeypatch.context() as m: fake_version = '0.0.0' - m.setattr(f'{db.__module__}.__version__', fake_version) - db._version_checked = False + m.setattr('terracotta.__version__', fake_version) + db.meta_store.db_version_verified = False with 
pytest.raises(exceptions.InvalidDatabaseError) as exc: with db.connect(): pass assert fake_version in str(exc.value) + + +@pytest.mark.parametrize('provider', TESTABLE_DRIVERS) +def test_invalid_key_types(driver_path, provider): + from terracotta import drivers, exceptions + + db = drivers.get_driver(driver_path, provider) + keys = ('some', 'keys') + + db.create(keys) + + db.get_datasets() + db.get_datasets(['a', 'b']) + db.get_datasets({'some': 'a', 'keys': 'b'}) + db.get_datasets(None) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.get_datasets(45) + assert 'unknown key type' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.delete(['a']) + assert 'wrong number of keys' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.delete(None) + assert 'wrong number of keys' in str(exc) + + with pytest.raises(exceptions.InvalidKeyError) as exc: + db.get_datasets({'not-a-key': 'val'}) + assert 'unrecognized keys' in str(exc) diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_mysql.py index 8ddc546f..e51aec44 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_mysql.py @@ -2,19 +2,19 @@ TEST_CASES = { 'mysql://root@localhost:5000/test': dict( - user='root', password='', host='localhost', port=5000, db='test' + username='root', password=None, host='localhost', port=5000, database='test' ), 'root@localhost:5000/test': dict( - user='root', password='', host='localhost', port=5000, db='test' + username='root', password=None, host='localhost', port=5000, database='test' ), 'mysql://root:foo@localhost/test': dict( - user='root', password='foo', host='localhost', port=3306, db='test' + username='root', password='foo', host='localhost', port=None, database='test' ), 'mysql://localhost/test': dict( - password='', host='localhost', port=3306, db='test' + password=None, host='localhost', port=None, database='test' ), 'localhost/test': dict( - password='', host='localhost', port=3306, db='test' + password=None, host='localhost', port=None, database='test' ) } @@ -32,9 +32,8 @@ def test_path_parsing(case): drivers._DRIVER_CACHE = {} db = drivers.get_driver(case, provider='mysql') - db_args = db._db_args - for attr in ('user', 'password', 'host', 'port', 'db'): - assert getattr(db_args, attr) == TEST_CASES[case].get(attr, None) + for attr in ('username', 'password', 'host', 'port', 'database'): + assert getattr(db.meta_store.url, attr) == TEST_CASES[case].get(attr, None) @pytest.mark.parametrize('case', INVALID_TEST_CASES) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index f0b7e1d5..7954cf07 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -3,9 +3,6 @@ import platform import time -import rasterio -import rasterio.features -from shapely.geometry import shape, MultiPolygon import numpy as np DRIVERS = ['sqlite', 'mysql'] @@ -143,7 +140,7 @@ def test_lazy_loading(driver_path, provider, raster_file): data1 = db.get_metadata(['some', 'value']) data2 = db.get_metadata({'some': 'some', 'keynames': 'other_value'}) - assert list(data1.keys()) == list(data2.keys()) + assert set(data1.keys()) == set(data2.keys()) assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) @@ -208,7 +205,7 @@ def test_wrong_key_number(driver_path, provider, raster_file): assert 'wrong number of keys' in str(exc.value) with pytest.raises(exceptions.InvalidKeyError) as exc: - db.insert(['a', 'b'], '') + db.insert(['a', 'b'], '', skip_metadata=True) assert 
'wrong number of keys' in str(exc.value) with pytest.raises(exceptions.InvalidKeyError) as exc: @@ -302,7 +299,11 @@ def test_insertion_invalid_raster(driver_path, provider, invalid_raster_file): @pytest.mark.parametrize('provider', DRIVERS) -def test_raster_retrieval(driver_path, provider, raster_file): +@pytest.mark.parametrize('resampling_method', ['nearest', 'linear', 'cubic', 'average']) +def test_raster_retrieval(driver_path, provider, raster_file, resampling_method): + import terracotta + terracotta.update_settings(RESAMPLING_METHOD=resampling_method) + from terracotta import drivers db = drivers.get_driver(driver_path, provider=provider) keys = ('some', 'keynames') @@ -331,7 +332,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): db.insert(['some', 'value'], str(raster_file)) db.insert(['some', 'other_value'], str(raster_file)) - assert len(db._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -339,7 +340,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db._raster_cache) == 1 + assert len(db.raster_store._raster_cache) == 1 data2 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -347,7 +348,7 @@ def test_raster_cache(driver_path, provider, raster_file, asynchronous): data2 = data2.result() np.testing.assert_array_equal(data1, data2) - assert len(db._raster_cache) == 1 + assert len(db.raster_store._raster_cache) == 1 @pytest.mark.parametrize('provider', DRIVERS) @@ -363,7 +364,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): db.create(keys) db.insert(['some', 'value'], str(raster_file)) - assert len(db._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256), asynchronous=asynchronous) @@ -371,7 +372,7 @@ def test_raster_cache_fail(driver_path, provider, raster_file, asynchronous): data1 = data1.result() time.sleep(1) # allow callback to finish - assert len(db._raster_cache) == 0 + assert len(db.raster_store._raster_cache) == 0 @pytest.mark.parametrize('provider', DRIVERS) @@ -379,7 +380,7 @@ def test_multiprocessing_fallback(driver_path, provider, raster_file, monkeypatc import concurrent.futures from importlib import reload from terracotta import drivers - import terracotta.drivers.raster_base + import terracotta.drivers.geotiff_raster_store def dummy(*args, **kwargs): raise OSError('monkeypatched') @@ -388,7 +389,7 @@ def dummy(*args, **kwargs): with monkeypatch.context() as m, pytest.warns(UserWarning): m.setattr(concurrent.futures, 'ProcessPoolExecutor', dummy) - reload(terracotta.drivers.raster_base) + reload(terracotta.drivers.geotiff_raster_store) db = drivers.get_driver(driver_path, provider=provider) keys = ('some', 'keynames') @@ -404,7 +405,7 @@ def dummy(*args, **kwargs): np.testing.assert_array_equal(data1, data2) finally: - reload(terracotta.drivers.raster_base) + reload(terracotta.drivers.geotiff_raster_store) @pytest.mark.parametrize('provider', DRIVERS) @@ -475,206 +476,11 @@ def test_nodata_consistency(driver_path, provider, big_raster_file_mask, big_ras np.testing.assert_array_equal(data_mask.mask, data_nodata.mask) -def geometry_mismatch(shape1, shape2): - """Compute relative mismatch of two shapes""" - return shape1.symmetric_difference(shape2).area / 
shape1.union(shape2).area - - -def convex_hull_exact(src): - kwargs = dict(bidx=1, band=False, as_mask=True, geographic=True) - - data = src.read() - if np.any(np.isnan(data)) and src.nodata is not None: - # hack: replace NaNs with nodata to make sure they are excluded - with rasterio.MemoryFile() as memfile, memfile.open(**src.profile) as tmpsrc: - data[np.isnan(data)] = src.nodata - tmpsrc.write(data) - dataset_shape = list(rasterio.features.dataset_features(tmpsrc, **kwargs)) - else: - dataset_shape = list(rasterio.features.dataset_features(src, **kwargs)) - - convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull - return convex_hull - - -@pytest.mark.parametrize('use_chunks', [True, False]) -@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) -def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, - big_raster_file_mask, raster_file_float, nodata_type, use_chunks): - from terracotta.drivers.raster_base import RasterDriver - - if nodata_type == 'nodata': - raster_file = big_raster_file_nodata - elif nodata_type == 'masked': - raster_file = big_raster_file_mask - elif nodata_type == 'none': - raster_file = big_raster_file_nomask - elif nodata_type == 'nan': - raster_file = raster_file_float - - if use_chunks: - pytest.importorskip('crick') - - with rasterio.open(str(raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - if nodata_type == 'none': - with pytest.warns(UserWarning) as record: - mtd = RasterDriver.compute_metadata(str(raster_file), use_chunks=use_chunks) - assert 'does not have a valid nodata value' in str(record[0].message) - else: - mtd = RasterDriver.compute_metadata(str(raster_file), use_chunks=use_chunks) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow some error margin since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2, atol=valid_data.max() / 100 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - -@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) -def test_compute_metadata_approximate(nodata_type, big_raster_file_nodata, big_raster_file_mask, - big_raster_file_nomask, raster_file_float): - from terracotta.drivers.raster_base import RasterDriver - - if nodata_type == 'nodata': - raster_file = big_raster_file_nodata - elif nodata_type == 'masked': - raster_file = big_raster_file_mask - elif nodata_type == 'none': - raster_file = big_raster_file_nomask - elif nodata_type == 'nan': - raster_file = raster_file_float - - with rasterio.open(str(raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - if nodata_type == 'none': - with pytest.warns(UserWarning) as record: - mtd = RasterDriver.compute_metadata(str(raster_file), max_shape=(512, 512)) - assert 'does not have a valid nodata value' in str(record[0].message) - else: - mtd = RasterDriver.compute_metadata(str(raster_file), max_shape=(512, 512)) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * 
valid_data.size / data.size, atol=1) - np.testing.assert_allclose( - mtd['range'], (valid_data.min(), valid_data.max()), atol=valid_data.max() / 100 - ) - np.testing.assert_allclose(mtd['mean'], valid_data.mean(), rtol=0.02) - np.testing.assert_allclose(mtd['stdev'], valid_data.std(), rtol=0.02) - - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - atol=valid_data.max() / 100, rtol=0.02 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 0.05 - - -def test_compute_metadata_invalid_options(big_raster_file_nodata): - from terracotta.drivers.raster_base import RasterDriver - - with pytest.raises(ValueError): - RasterDriver.compute_metadata( - str(big_raster_file_nodata), max_shape=(256, 256), use_chunks=True - ) - - with pytest.raises(ValueError): - RasterDriver.compute_metadata(str(big_raster_file_nodata), max_shape=(256, 256, 1)) - - -@pytest.mark.parametrize('use_chunks', [True, False]) -def test_compute_metadata_invalid_raster(invalid_raster_file, use_chunks): - from terracotta.drivers.raster_base import RasterDriver - - if use_chunks: - pytest.importorskip('crick') - - with pytest.raises(ValueError): - RasterDriver.compute_metadata(str(invalid_raster_file), use_chunks=use_chunks) - - -def test_compute_metadata_nocrick(big_raster_file_nodata, monkeypatch): - with rasterio.open(str(big_raster_file_nodata)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - from terracotta import exceptions - import terracotta.drivers.raster_base - - with monkeypatch.context() as m: - m.setattr(terracotta.drivers.raster_base, 'has_crick', False) - - with pytest.warns(exceptions.PerformanceWarning): - mtd = terracotta.drivers.raster_base.RasterDriver.compute_metadata( - str(big_raster_file_nodata), use_chunks=True - ) - - # compare - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow error of 1%, since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - -def test_compute_metadata_unoptimized(unoptimized_raster_file): - from terracotta import exceptions - from terracotta.drivers.raster_base import RasterDriver - - with rasterio.open(str(unoptimized_raster_file)) as src: - data = src.read(1, masked=True) - valid_data = np.ma.masked_invalid(data).compressed() - convex_hull = convex_hull_exact(src) - - # compare - with pytest.warns(exceptions.PerformanceWarning): - mtd = RasterDriver.compute_metadata(str(unoptimized_raster_file), use_chunks=False) - - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow some error margin since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=2e-2 - ) - - assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 - - 
@pytest.mark.parametrize('provider', DRIVERS) def test_broken_process_pool(driver_path, provider, raster_file): import concurrent.futures from terracotta import drivers - from terracotta.drivers.raster_base import context + from terracotta.drivers.geotiff_raster_store import context class BrokenPool: def submit(self, *args, **kwargs): @@ -701,7 +507,7 @@ def submit(self, *args, **kwargs): def test_no_multiprocessing(): import concurrent.futures from terracotta import update_settings - from terracotta.drivers.raster_base import create_executor + from terracotta.drivers.geotiff_raster_store import create_executor update_settings(USE_MULTIPROCESSING=False) diff --git a/tests/drivers/test_sqlite_remote.py b/tests/drivers/test_sqlite_remote.py index 9750bca8..d4852617 100644 --- a/tests/drivers/test_sqlite_remote.py +++ b/tests/drivers/test_sqlite_remote.py @@ -4,9 +4,9 @@ """ import os +import tempfile import time import uuid -import tempfile from pathlib import Path import pytest @@ -90,7 +90,7 @@ def test_invalid_url(): @moto.mock_s3 def test_nonexisting_url(): - from terracotta import get_driver, exceptions + from terracotta import exceptions, get_driver driver = get_driver('s3://foo/db.sqlite') with pytest.raises(exceptions.InvalidDatabaseError): with driver.connect(): @@ -105,32 +105,32 @@ def test_remote_database_cache(s3_db_factory, raster_file, monkeypatch): from terracotta import get_driver driver = get_driver(dbpath) - driver._last_updated = -float('inf') + driver.meta_store._last_updated = -float('inf') with driver.connect(): assert driver.key_names == keys assert driver.get_datasets() == {} - modification_date = os.path.getmtime(driver.path) + modification_date = os.path.getmtime(driver.meta_store._local_path) s3_db_factory(keys, datasets={('some', 'value'): str(raster_file)}) # no change yet assert driver.get_datasets() == {} - assert os.path.getmtime(driver.path) == modification_date + assert os.path.getmtime(driver.meta_store._local_path) == modification_date # check if remote db is cached correctly - driver._last_updated = time.time() + driver.meta_store._last_updated = time.time() with driver.connect(): # db connection is cached; so still no change assert driver.get_datasets() == {} - assert os.path.getmtime(driver.path) == modification_date + assert os.path.getmtime(driver.meta_store._local_path) == modification_date # invalidate cache - driver._last_updated = -float('inf') + driver.meta_store._last_updated = -float('inf') with driver.connect(): # now db is updated on reconnect assert list(driver.get_datasets().keys()) == [('some', 'value')] - assert os.path.getmtime(driver.path) != modification_date + assert os.path.getmtime(driver.meta_store._local_path) != modification_date @moto.mock_s3 @@ -160,14 +160,14 @@ def test_destructor(s3_db_factory, raster_file, capsys): from terracotta import get_driver driver = get_driver(dbpath) - assert os.path.isfile(driver.path) + assert os.path.isfile(driver.meta_store._local_path) - driver.__del__() - assert not os.path.isfile(driver.path) + driver.meta_store.__del__() + assert not os.path.isfile(driver.meta_store._local_path) captured = capsys.readouterr() assert 'Exception ignored' not in captured.err # re-create file to prevent actual destructor from failing - with open(driver.path, 'w'): + with open(driver.meta_store._local_path, 'w'): pass diff --git a/tests/test_cog.py b/tests/test_cog.py index 448b301e..75f6e6aa 100644 --- a/tests/test_cog.py +++ b/tests/test_cog.py @@ -69,12 +69,12 @@ def test_validate_unoptimized(tmpdir): from 
terracotta import cog outfile = str(tmpdir / 'raster.tif') - raster_data = 1000 * np.random.rand(512, 512).astype(np.uint16) + raster_data = 1000 * np.random.rand(1024, 1024).astype(np.uint16) profile = BASE_PROFILE.copy() profile.update( height=raster_data.shape[0], - width=raster_data.shape[1] + width=raster_data.shape[1], ) with rasterio.open(outfile, 'w', **profile) as dst: @@ -87,7 +87,7 @@ def test_validate_no_overviews(tmpdir): from terracotta import cog outfile = str(tmpdir / 'raster.tif') - raster_data = 1000 * np.random.rand(512, 512).astype(np.uint16) + raster_data = 1000 * np.random.rand(1024, 1024).astype(np.uint16) profile = BASE_PROFILE.copy() profile.update( diff --git a/tests/test_raster.py b/tests/test_raster.py new file mode 100644 index 00000000..2e24f4f8 --- /dev/null +++ b/tests/test_raster.py @@ -0,0 +1,256 @@ +import pytest + +import numpy as np +import rasterio +import rasterio.features +from shapely.geometry import shape, MultiPolygon + + +def geometry_mismatch(shape1, shape2): + """Compute relative mismatch of two shapes""" + return shape1.symmetric_difference(shape2).area / shape1.union(shape2).area + + +def convex_hull_exact(src): + kwargs = dict(bidx=1, band=False, as_mask=True, geographic=True) + + data = src.read() + if np.any(np.isnan(data)) and src.nodata is not None: + # hack: replace NaNs with nodata to make sure they are excluded + with rasterio.MemoryFile() as memfile, memfile.open(**src.profile) as tmpsrc: + data[np.isnan(data)] = src.nodata + tmpsrc.write(data) + dataset_shape = list(rasterio.features.dataset_features(tmpsrc, **kwargs)) + else: + dataset_shape = list(rasterio.features.dataset_features(src, **kwargs)) + + convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull + return convex_hull + + +@pytest.mark.parametrize('large_raster_threshold', [None, 0]) +@pytest.mark.parametrize('use_chunks', [True, False, None]) +@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) +def test_compute_metadata(big_raster_file_nodata, big_raster_file_nomask, big_raster_file_mask, + raster_file_float, nodata_type, use_chunks, large_raster_threshold): + from terracotta import raster + + if nodata_type == 'nodata': + raster_file = big_raster_file_nodata + elif nodata_type == 'masked': + raster_file = big_raster_file_mask + elif nodata_type == 'none': + raster_file = big_raster_file_nomask + elif nodata_type == 'nan': + raster_file = raster_file_float + + if use_chunks: + pytest.importorskip('crick') + + with rasterio.open(str(raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + if nodata_type == 'none': + with pytest.warns(UserWarning) as record: + mtd = raster.compute_metadata( + str(raster_file), + use_chunks=use_chunks, + large_raster_threshold=large_raster_threshold + ) + assert 'does not have a valid nodata value' in str(record[0].message) + else: + mtd = raster.compute_metadata( + str(raster_file), + use_chunks=use_chunks, + large_raster_threshold=large_raster_threshold + ) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow some error margin since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + 
np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2, atol=valid_data.max() / 100 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +@pytest.mark.parametrize('nodata_type', ['nodata', 'masked', 'none', 'nan']) +def test_compute_metadata_approximate(nodata_type, big_raster_file_nodata, big_raster_file_mask, + big_raster_file_nomask, raster_file_float): + from terracotta import raster + + if nodata_type == 'nodata': + raster_file = big_raster_file_nodata + elif nodata_type == 'masked': + raster_file = big_raster_file_mask + elif nodata_type == 'none': + raster_file = big_raster_file_nomask + elif nodata_type == 'nan': + raster_file = raster_file_float + + with rasterio.open(str(raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + if nodata_type == 'none': + with pytest.warns(UserWarning) as record: + mtd = raster.compute_metadata(str(raster_file), max_shape=(512, 512)) + assert 'does not have a valid nodata value' in str(record[0].message) + else: + mtd = raster.compute_metadata(str(raster_file), max_shape=(512, 512)) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size, atol=1) + np.testing.assert_allclose( + mtd['range'], (valid_data.min(), valid_data.max()), atol=valid_data.max() / 100 + ) + np.testing.assert_allclose(mtd['mean'], valid_data.mean(), rtol=0.02) + np.testing.assert_allclose(mtd['stdev'], valid_data.std(), rtol=0.02) + + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + atol=valid_data.max() / 100, rtol=0.02 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 0.05 + + +def test_compute_metadata_invalid_options(big_raster_file_nodata): + from terracotta import raster + + with pytest.raises(ValueError): + raster.compute_metadata( + str(big_raster_file_nodata), max_shape=(256, 256), use_chunks=True + ) + + with pytest.raises(ValueError): + raster.compute_metadata(str(big_raster_file_nodata), max_shape=(256, 256, 1)) + + +@pytest.mark.parametrize('use_chunks', [True, False]) +def test_compute_metadata_invalid_raster(invalid_raster_file, use_chunks): + from terracotta import raster + + if use_chunks: + pytest.importorskip('crick') + + with pytest.raises(ValueError): + raster.compute_metadata(str(invalid_raster_file), use_chunks=use_chunks) + + +def test_compute_metadata_nocrick(big_raster_file_nodata, monkeypatch): + with rasterio.open(str(big_raster_file_nodata)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + from terracotta import exceptions + import terracotta.drivers.geotiff_raster_store + + with monkeypatch.context() as m: + m.setattr(terracotta.raster, 'has_crick', False) + + with pytest.warns(exceptions.PerformanceWarning): + mtd = terracotta.drivers.geotiff_raster_store.raster.compute_metadata( + str(big_raster_file_nodata), use_chunks=True + ) + + # compare + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow error of 1%, since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2 + 
) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +def test_compute_metadata_unoptimized(unoptimized_raster_file): + from terracotta import exceptions + from terracotta import raster + + with rasterio.open(str(unoptimized_raster_file)) as src: + data = src.read(1, masked=True) + valid_data = np.ma.masked_invalid(data).compressed() + convex_hull = convex_hull_exact(src) + + # compare + with pytest.warns(exceptions.PerformanceWarning): + mtd = raster.compute_metadata(str(unoptimized_raster_file), use_chunks=False) + + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow some error margin since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=2e-2 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-6 + + +@pytest.mark.parametrize('preserve_values', [True, False]) +@pytest.mark.parametrize('resampling_method', ['nearest', 'linear', 'cubic', 'average']) +def test_get_raster_tile(raster_file, preserve_values, resampling_method): + from terracotta import raster + + data = raster.get_raster_tile( + str(raster_file), + reprojection_method=resampling_method, + resampling_method=resampling_method, + preserve_values=preserve_values, + tile_size=(256, 256) + ) + assert data.shape == (256, 256) + + +def test_get_raster_tile_out_of_bounds(raster_file): + from terracotta import exceptions + from terracotta import raster + + bounds = ( + -1e30, + -1e30, + 1e30, + 1e30, + ) + + with pytest.raises(exceptions.TileOutOfBoundsError): + raster.get_raster_tile(str(raster_file), tile_bounds=bounds) + + +def test_get_raster_no_nodata(big_raster_file_nomask): + from terracotta import raster + + tile_size = (256, 256) + out = raster.get_raster_tile(str(big_raster_file_nomask), tile_size=tile_size) + assert out.shape == tile_size + + +def test_invalid_resampling_method(): + from terracotta import raster + + with pytest.raises(ValueError) as exc: + raster.get_resampling_enum('not-a-resampling-method') + assert 'unknown resampling method' in str(exc)
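
Usage sketch for the API exercised by the tests above. This is a minimal illustration, not part of the patch; the paths 'example.tif' and 'example.sqlite' are placeholders, while the module and attribute names mirror tests/test_raster.py and tests/drivers/:

    # Sketch only: file paths are hypothetical, the calls mirror the tests in this diff.
    from terracotta import get_driver, raster

    # raster handling now lives in a standalone module instead of RasterDriver methods
    metadata = raster.compute_metadata('example.tif', use_chunks=False)
    print(metadata['range'], metadata['valid_percentage'])

    tile = raster.get_raster_tile('example.tif', tile_size=(256, 256))
    assert tile.shape == (256, 256)  # returned as an np.ma.MaskedArray

    # drivers are now thin wrappers around a metadata store and a raster store
    db = get_driver('example.sqlite')
    print(type(db.meta_store).__name__)    # e.g. SQLiteMetaStore
    print(type(db.raster_store).__name__)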