diff --git a/.github/workflows/build-production.yml b/.github/workflows/build-production.yml new file mode 100644 index 0000000..608c92b --- /dev/null +++ b/.github/workflows/build-production.yml @@ -0,0 +1,79 @@ +name: Build Docker images for geolake components and push to the repository + +on: + push: + tags: + - 'v*' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source for drivers + run: python3 -m build ./drivers + - name: Set Docker image tag name + run: echo "TAG=$(date +'%Y.%m.%d.%H.%M')" >> $GITHUB_ENV + - name: Login to Scaleway Container Registry + uses: docker/login-action@v2 + with: + username: nologin + password: ${{ secrets.DOCKER_PASSWORD }} + registry: ${{ vars.DOCKER_REGISTRY }} + - name: Get release tag + run: echo "RELEASE_TAG=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Build and push drivers + uses: docker/build-push-action@v4 + with: + context: ./drivers + file: ./drivers/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-drivers:${{ env.RELEASE_TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-drivers:latest + - name: Build and push datastore component + uses: docker/build-push-action@v4 + with: + context: ./datastore + file: ./datastore/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-datastore:${{ env.RELEASE_TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-datastore:latest + - name: Build and push api component + uses: docker/build-push-action@v4 + with: + context: ./api + file: ./api/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-api:${{ env.RELEASE_TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-api:latest + - name: Build and push executor component + uses: docker/build-push-action@v4 + with: + context: ./executor + file: ./executor/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-executor:${{ env.RELEASE_TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-executor:latest \ No newline at end of file diff --git a/.github/workflows/build-staging.yml b/.github/workflows/build-staging.yml new file mode 100644 index 0000000..7c16ff2 --- /dev/null +++ b/.github/workflows/build-staging.yml @@ -0,0 +1,77 @@ +name: Build Docker images for geolake components and push to the repository + +on: + pull_request: + types: [opened, synchronize] + workflow_dispatch: +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source for drivers + run: python3 -m build ./drivers + - name: Set Docker image tag name + run: echo "TAG=$(date +'%Y.%m.%d.%H.%M')" >> $GITHUB_ENV + - name: Login to Scaleway Container Registry + uses: docker/login-action@v2 + with: + username: nologin + password: ${{ secrets.DOCKER_PASSWORD }} + registry: ${{ vars.DOCKER_REGISTRY }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Build and push drivers + uses: 
docker/build-push-action@v4 + with: + context: ./drivers + file: ./drivers/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-drivers:${{ env.TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-drivers:latest + - name: Build and push datastore component + uses: docker/build-push-action@v4 + with: + context: ./datastore + file: ./datastore/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-datastore:${{ env.TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-datastore:latest + - name: Build and push api component + uses: docker/build-push-action@v4 + with: + context: ./api + file: ./api/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-api:${{ env.TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-api:latest + - name: Build and push executor component + uses: docker/build-push-action@v4 + with: + context: ./executor + file: ./executor/Dockerfile + push: true + build-args: | + REGISTRY=${{ vars.DOCKER_REGISTRY }} + tags: | + ${{ vars.DOCKER_REGISTRY }}/geolake-executor:${{ env.TAG }} + ${{ vars.DOCKER_REGISTRY }}/geolake-executor:latest \ No newline at end of file diff --git a/api/Dockerfile b/api/Dockerfile index 6182cb1..a2cfea0 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -1,15 +1,9 @@ -FROM continuumio/miniconda3 -WORKDIR /code -COPY ./api/requirements.txt /code/requirements.txt +ARG REGISTRY=rg.nl-ams.scw.cloud/geodds-production +ARG TAG=latest +FROM $REGISTRY/geolake-datastore:$TAG +WORKDIR /app +COPY requirements.txt /code/requirements.txt RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt -RUN conda install -c anaconda psycopg2 -COPY ./utils/wait-for-it.sh /code/wait-for-it.sh -COPY ./db/dbmanager /code/db/dbmanager -COPY ./geoquery/ /code/geoquery -COPY ./resources /code/app/resources -COPY ./api/app /code/app +COPY app /app EXPOSE 80 -# VOLUME /code -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"] -# if behind a proxy use --proxy-headers -# CMD ["uvicorn", "app.main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"] \ No newline at end of file +CMD ["uvicorn", "app.main:app", "--proxy-headers", "--host", "0.0.0.0", "--port", "80"] diff --git a/api/app/api_utils.py b/api/app/api_utils.py new file mode 100644 index 0000000..82ea9f6 --- /dev/null +++ b/api/app/api_utils.py @@ -0,0 +1,73 @@ +"""Utils module""" + + +def convert_bytes(size_bytes: int, to: str) -> float: + """Converts size in bytes to the other unit - one out of: + ["kb", "mb", "gb"] + + Parameters + ---------- + size_bytes : int + Size in bytes + to : str + Unit to convert `size_bytes` to + + size : float + `size_bytes` converted to the given unit + """ + assert to is not None, "Expected unit cannot be `None`" + to = to.lower() + match to: + case "bytes": + return size_bytes + case "kb": + return size_bytes / 1024 + case "mb": + return size_bytes / 1024**2 + case "gb": + return size_bytes / 1024**3 + case _: + raise ValueError(f"unsupported units: {to}") + + +def make_bytes_readable_dict( + size_bytes: int, units: str | None = None +) -> dict: + """Prepare dictionary representing size (in bytes) in more readable unit + to keep value in the range [0,1] - if `units` is `None`. + If `units` is not None, converts `size_bytes` to the size expressed by + that argument. 
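# NOTE: an illustrative usage sketch of the two helpers in api_utils.py, not
# part of this patch; requires Python 3.10+ (the `match` statement) and the
# numbers are made up.
from api_utils import convert_bytes, make_bytes_readable_dict

assert convert_bytes(1536, "kb") == 1.5
assert make_bytes_readable_dict(1536) == {"value": 1.5, "units": "kB"}
# when a target unit is passed explicitly, the converted value is returned as-is
print(make_bytes_readable_dict(2 * 1024**3, units="GB"))  # {'value': 2.0, 'units': 'GB'}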
+ + Parameters + ---------- + size_bytes : int + Size expressed in bytes + units : optional str + + Returns + ------- + result : dict + A dictionary with size and units in the form: + { + "value": ..., + "units": ... + } + """ + if units is None: + units = "bytes" + if units != "bytes": + converted_size = convert_bytes(size_bytes=size_bytes, to=units) + return {"value": converted_size, "units": units} + val = size_bytes + if val > 1024: + units = "kB" + val /= 1024 + if val > 1024: + units = "MB" + val /= 1024 + if val > 1024: + units = "GB" + val /= 1024 + if val > 0.0 and (round(val, 2) == 0.00): + val = 0.01 + return {"value": round(val, 2), "units": units} diff --git a/db/dbmanager/__init__.py b/api/app/auth/__init__.py similarity index 100% rename from db/dbmanager/__init__.py rename to api/app/auth/__init__.py diff --git a/api/app/auth/backend.py b/api/app/auth/backend.py new file mode 100644 index 0000000..c172b58 --- /dev/null +++ b/api/app/auth/backend.py @@ -0,0 +1,66 @@ +"""The module contains authentication backend""" +from uuid import UUID + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + UnauthenticatedUser, +) +from dbmanager.dbmanager import DBManager + +import exceptions as exc +from auth.models import DDSUser +from auth import scopes + + +class DDSAuthenticationBackend(AuthenticationBackend): + """Class managing authentication and authorization""" + + async def authenticate(self, conn): + """Authenticate user based on `User-Token` header""" + if "User-Token" in conn.headers: + return self._manage_user_token_auth(conn.headers["User-Token"]) + return AuthCredentials([scopes.ANONYMOUS]), UnauthenticatedUser() + + def _manage_user_token_auth(self, user_token: str): + try: + user_id, api_key = self.get_authorization_scheme_param(user_token) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() + user_dto = DBManager().get_user_details(user_id) + eligible_scopes = [scopes.AUTHENTICATED] + self._get_scopes_for_user( + user_dto=user_dto + ) + if user_dto.api_key != api_key: + raise exc.AuthenticationFailed( + user_dto + ).wrap_around_http_exception() + return AuthCredentials(eligible_scopes), DDSUser(username=user_id) + + def _get_scopes_for_user(self, user_dto) -> list[str]: + if user_dto is None: + return [] + eligible_scopes = [] + for role in user_dto.roles: + if "admin" == role.role_name: + eligible_scopes.append(scopes.ADMIN) + continue + # NOTE: Role-specific scopes + # Maybe need some more logic + eligible_scopes.append(role.role_name) + return eligible_scopes + + def get_authorization_scheme_param(self, user_token: str): + """Get `user_id` and `api_key` if authorization scheme is correct.""" + if user_token is None or user_token.strip() == "": + raise exc.EmptyUserTokenError + if ":" not in user_token: + raise exc.ImproperUserTokenError + user_id, api_key, *rest = user_token.split(":") + if len(rest) > 0: + raise exc.ImproperUserTokenError + try: + _ = UUID(user_id, version=4) + except ValueError as err: + raise exc.ImproperUserTokenError from err + return (user_id, api_key) diff --git a/api/app/auth/manager.py b/api/app/auth/manager.py new file mode 100644 index 0000000..02bf686 --- /dev/null +++ b/api/app/auth/manager.py @@ -0,0 +1,72 @@ +"""Module with access/authentication functions""" +from typing import Optional + +from utils.api_logging import get_dds_logger +import exceptions as exc + +log = get_dds_logger(__name__) + + +def is_role_eligible_for_product( + product_role_name: Optional[str] = None, + 
user_roles_names: Optional[list[str]] = None, +): + """Check if given role is eligible for the product with the provided + `product_role_name`. + + Parameters + ---------- + product_role_name : str, optional, default=None + The role which is eligible for the given product. + If `None`, product_role_name is claimed to be public + user_roles_names: list of str, optional, default=None + A list of user roles names. If `None`, user_roles_names is claimed + to be public + + Returns + ------- + is_eligible : bool + Flag which indicate if any role within the given `user_roles_names` + is eligible for the product with `product_role_name` + """ + log.debug( + "verifying eligibility of the product role '%s' against roles '%s'", + product_role_name, + user_roles_names, + ) + if product_role_name == "public" or product_role_name is None: + return True + if user_roles_names is None: + # NOTE: it means, we consider the public profile + return False + if "admin" in user_roles_names: + return True + if product_role_name in user_roles_names: + return True + return False + + +def assert_is_role_eligible( + product_role_name: Optional[str] = None, + user_roles_names: Optional[list[str]] = None, +): + """Assert that user role is eligible for the product + + Parameters + ---------- + product_role_name : str, optional, default=None + The role which is eligible for the given product. + If `None`, product_role_name is claimed to be public + user_roles_names: list of str, optional, default=None + A list of user roles names. If `None`, user_roles_names is claimed + to be public + + Raises + ------- + AuthorizationFailed + """ + if not is_role_eligible_for_product( + product_role_name=product_role_name, + user_roles_names=user_roles_names, + ): + raise exc.AuthorizationFailed diff --git a/api/app/auth/models.py b/api/app/auth/models.py new file mode 100644 index 0000000..bff896f --- /dev/null +++ b/api/app/auth/models.py @@ -0,0 +1,38 @@ +"""The module contains models related to the authentication and authorization""" +from starlette.authentication import SimpleUser + + +class DDSUser(SimpleUser): + """Immutable class containing information about the authenticated user""" + + def __init__(self, username: str) -> None: + super().__init__(username=username) + + @property + def id(self): + return self.username + + def __eq__(self, other) -> bool: + if not isinstance(other, DDSUser): + return False + if self.username == other.username: + return True + return False + + def __ne__(self, other): + return self != other + + def __repr__(self): + return f"" + + def __delattr__(self, name): + if getattr(self, name, None) is not None: + raise AttributeError("The attribute '{name}' cannot be deleted!") + super().__delattr__(name) + + def __setattr__(self, name, value): + if getattr(self, name, None) is not None: + raise AttributeError( + "The attribute '{name}' cannot modified when not None!" 
+ ) + super().__setattr__(name, value) diff --git a/api/app/auth/scopes.py b/api/app/auth/scopes.py new file mode 100644 index 0000000..75113e4 --- /dev/null +++ b/api/app/auth/scopes.py @@ -0,0 +1,5 @@ +"""This module contains predefined authorization scopes""" + +ADMIN = "admin" +AUTHENTICATED = "authenticated" +ANONYMOUS = "anonymous" diff --git a/api/app/callbacks/__init__.py b/api/app/callbacks/__init__.py new file mode 100644 index 0000000..e003acf --- /dev/null +++ b/api/app/callbacks/__init__.py @@ -0,0 +1 @@ +from .on_startup import all_onstartup_callbacks diff --git a/api/app/callbacks/on_startup.py b/api/app/callbacks/on_startup.py new file mode 100644 index 0000000..ec883d3 --- /dev/null +++ b/api/app/callbacks/on_startup.py @@ -0,0 +1,15 @@ +"""Module with functions call during API server startup""" +from utils.api_logging import get_dds_logger + +from datastore.datastore import Datastore + +log = get_dds_logger(__name__) + + +def _load_cache() -> None: + log.info("loading cache started...") + Datastore()._load_cache() + log.info("cache loaded succesfully!") + + +all_onstartup_callbacks = [_load_cache] diff --git a/geoquery/__init__.py b/api/app/const/__init__.py similarity index 100% rename from geoquery/__init__.py rename to api/app/const/__init__.py diff --git a/api/app/const/tags.py b/api/app/const/tags.py new file mode 100644 index 0000000..58a2213 --- /dev/null +++ b/api/app/const/tags.py @@ -0,0 +1,5 @@ +"""The module with endpoint tags definitions""" + +BASIC = "basic" +DATASET = "dataset" +REQUEST = "request" diff --git a/api/app/const/venv.py b/api/app/const/venv.py new file mode 100644 index 0000000..85c3658 --- /dev/null +++ b/api/app/const/venv.py @@ -0,0 +1,7 @@ +"""This modul contains all supported environment variables names""" + +ENDPOINT_PREFIX = "ENDPOINT_PREFIX" +ALLOWED_CORS_ORIGINS_REGEX = "ALLOWED_CORS_ORIGINS_REGEX" +LOGGING_FORMAT = "LOGGING_FORMAT" +LOGGING_LEVEL = "LOGGING_LEVEL" +WEB_COMPONENT_HOST = "WEB_COMPONENT_HOST" diff --git a/api/app/decorators_factory.py b/api/app/decorators_factory.py new file mode 100644 index 0000000..d2e4b39 --- /dev/null +++ b/api/app/decorators_factory.py @@ -0,0 +1,37 @@ +"""Modules with utils for creating decorators""" +from inspect import Signature + + +def assert_parameters_are_defined( + sig: Signature, required_parameters: list[tuple] +): + """Assert the given callable signature has parameters with + names and types indicated by `required_parameters` argument. + + Parameters + ---------- + sig : Signature + A signature object of a callable + required_parameters : list of tuples + List of two-element tuples containing a name and a type + of the parameter, e.g. 
[("dataset_id", str)] + + Raises + ------ + TypeError + If a required parameter is not defined or is of wrong type + """ + for param_name, param_type in required_parameters: + if param_name not in sig.parameters: + raise TypeError( + f"The parameter '{param_name}' annotated with the type" + f" '{param_type}' must be defined for the callable decorated" + " with 'authenticate_user' decorator" + ) + + +def bind_arguments(sig: Signature, *args, **kwargs): + """Bind arguments to the signature""" + args_bind = sig.bind_partial(*args, **kwargs) + args_bind.apply_defaults() + return args_bind.arguments diff --git a/api/app/encoders.py b/api/app/encoders.py new file mode 100644 index 0000000..9566f57 --- /dev/null +++ b/api/app/encoders.py @@ -0,0 +1,41 @@ +import numpy as np +from fastapi.encoders import encoders_by_class_tuples + + +def make_ndarray_dtypes_valid(o: np.ndarray) -> np.ndarray: + """Convert `numpy.array` dtype to the one which is serializable + to JSON. + + int32 -> int64 + float32 -> float 64 + + Parameters + ---------- + o : np.ndarray + A NumPy array object + + Returns + ------- + res : np.ndarray + A NumPy array object with dtype set properly + + Raises + ------ + AssertionError + If passed object is not of `numpy.ndarray` + """ + assert isinstance(o, np.ndarray) + if np.issubdtype(o.dtype, np.int32): + return o.astype(np.int64) + if np.issubdtype(o.dtype, np.float32): + return o.astype(np.float64) + return o + + +def extend_json_encoders(): + """Extend `encoders_by_class_tuples` module variable from `fastapi.encoders` + with auxiliary encoders necessary for proper application working.""" + encoders_by_class_tuples[lambda o: list(make_ndarray_dtypes_valid(o))] = ( + np.ndarray, + ) + encoders_by_class_tuples[str] += (np.int32, np.float32) diff --git a/api/app/endpoint_handlers/__init__.py b/api/app/endpoint_handlers/__init__.py new file mode 100644 index 0000000..c5a44be --- /dev/null +++ b/api/app/endpoint_handlers/__init__.py @@ -0,0 +1,3 @@ +from . import file as file_handler +from . import dataset as dataset_handler +from . import request as request_handler diff --git a/api/app/endpoint_handlers/dataset.py b/api/app/endpoint_handlers/dataset.py new file mode 100644 index 0000000..c03a54b --- /dev/null +++ b/api/app/endpoint_handlers/dataset.py @@ -0,0 +1,430 @@ +"""Modules realizing logic for dataset-related endpoints""" +import os +import pika +import json +from typing import Optional + +from fastapi.responses import FileResponse + +from dbmanager.dbmanager import DBManager, RequestStatus +from intake_geokube.queries.geoquery import GeoQuery +from intake_geokube.queries.workflow import Workflow +from datastore.datastore import Datastore, DEFAULT_MAX_REQUEST_SIZE_GB +from datastore import exception as datastore_exception + +from utils.metrics import log_execution_time +from utils.api_logging import get_dds_logger +from auth.manager import ( + is_role_eligible_for_product, +) +import exceptions as exc +from api_utils import make_bytes_readable_dict +from validation import assert_product_exists + +from . 
import request + +log = get_dds_logger(__name__) +data_store = Datastore() + +MESSAGE_SEPARATOR = os.environ["MESSAGE_SEPARATOR"] + +def _is_etimate_enabled(dataset_id, product_id): + if dataset_id in ("sentinel-2",): + return False + return True + + +@log_execution_time(log) +def get_datasets(user_roles_names: list[str]) -> list[dict]: + """Realize the logic for the endpoint: + + `GET /datasets` + + Get datasets names, their metadata and products names (if eligible for a user). + If no eligible products are found for a dataset, it is not included. + + Parameters + ---------- + user_roles_names : list of str + List of user's roles + + Returns + ------- + datasets : list of dict + A list of dictionaries with datasets information (including metadata and + eligible products lists) + + Raises + ------- + MissingKeyInCatalogEntryError + If the dataset catalog entry does not contain the required key + """ + log.debug( + "getting all eligible products for datasets...", + ) + datasets = [] + for dataset_id in data_store.dataset_list(): + log.debug( + "getting info and eligible products for `%s`", + dataset_id, + ) + dataset_info = data_store.dataset_info(dataset_id=dataset_id) + try: + eligible_prods = { + prod_name: prod_info + for prod_name, prod_info in dataset_info["products"].items() + if is_role_eligible_for_product( + product_role_name=prod_info.get("role"), + user_roles_names=user_roles_names, + ) + } + except KeyError as err: + log.error( + "dataset `%s` does not have products defined", + dataset_id, + exc_info=True, + ) + raise exc.MissingKeyInCatalogEntryError( + key="products", dataset=dataset_id + ) from err + else: + if len(eligible_prods) == 0: + log.debug( + "no eligible products for dataset `%s` for the role `%s`." + " dataset skipped", + dataset_id, + user_roles_names, + ) + else: + dataset_info["products"] = eligible_prods + datasets.append(dataset_info) + return datasets + + +@log_execution_time(log) +@assert_product_exists +def get_product_details( + user_roles_names: list[str], + dataset_id: str, + product_id: Optional[str] = None, +) -> dict: + """Realize the logic for the endpoint: + + `GET /datasets/{dataset_id}/{product_id}` + + Get details for the given product indicated by `dataset_id` + and `product_id` arguments. + + Parameters + ---------- + user_roles_names : list of str + List of user's roles + dataset_id : str + ID of the dataset + product_id : optional, str + ID of the product. If `None` the 1st product will be considered + + Returns + ------- + details : dict + Details for the given product + + Raises + ------- + AuthorizationFailed + If user is not authorized for the resources + """ + log.debug( + "getting details for eligible products of `%s`", + dataset_id, + ) + try: + if product_id: + return data_store.product_details( + dataset_id=dataset_id, + product_id=product_id, + role=user_roles_names, + use_cache=True, + ) + else: + return data_store.first_eligible_product_details( + dataset_id=dataset_id, role=user_roles_names, use_cache=True + ) + except datastore_exception.UnauthorizedError as err: + raise exc.AuthorizationFailed from err + + +@log_execution_time(log) +@assert_product_exists +def get_metadata(dataset_id: str, product_id: str): + """Realize the logic for the endpoint: + + `GET /datasets/{dataset_id}/{product_id}/metadata` + + Get metadata for the product. 
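# NOTE: an illustrative sketch, not part of this patch, of how the eligibility
# check from auth/manager.py drives the filtering done by get_datasets() above;
# the product names and role values are invented.
from auth.manager import is_role_eligible_for_product

products = {
    "public-reanalysis": {"role": "public"},
    "restricted-forecast": {"role": "internal"},
}
user_roles = ["authenticated"]  # at runtime this comes from request.auth.scopes
visible = {
    name: info
    for name, info in products.items()
    if is_role_eligible_for_product(info.get("role"), user_roles)
}
# visible == {"public-reanalysis": {"role": "public"}}; adding "internal" or
# "admin" to user_roles would expose the second product as well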
+ + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + """ + log.debug( + "getting metadata for '{dataset_id}.{product_id}'", + ) + return data_store.product_metadata(dataset_id, product_id) + + +@log_execution_time(log) +@assert_product_exists +def estimate( + dataset_id: str, + product_id: str, + query: GeoQuery, + unit: Optional[str] = None, +): + """Realize the logic for the nedpoint: + + `POST /datasets/{dataset_id}/{product_id}/estimate` + + Estimate the size of the resulting data. + No authentication is needed for estimation query. + + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery + Query to perform + unit : str + One of unit [bytes, kB, MB, GB] to present the result. If `None`, + unit will be inferred. + + Returns + ------- + size_details : dict + Estimated size of the query in the form: + ```python + { + "value": val, + "units": units + } + ``` + """ + query_bytes_estimation = data_store.estimate(dataset_id, product_id, query) + return make_bytes_readable_dict( + size_bytes=query_bytes_estimation, units=unit + ) + + +@log_execution_time(log) +@assert_product_exists +def async_query( + user_id: str, + dataset_id: str, + product_id: str, + query: GeoQuery, +): + """Realize the logic for the endpoint: + + `POST /datasets/{dataset_id}/{product_id}/execute` + + Query the data and return the ID of the request. + + Parameters + ---------- + user_id : str + ID of the user executing the query + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery + Query to perform + + Returns + ------- + request_id : int + ID of the request + + Raises + ------- + MaximumAllowedSizeExceededError + if the allowed size is below the estimated one + EmptyDatasetError + if estimated size is zero + + """ + log.debug("geoquery: %s", query) + if _is_etimate_enabled(dataset_id, product_id): + estimated_size = estimate(dataset_id, product_id, query, "GB").get("value") + allowed_size = data_store.product_metadata(dataset_id, product_id).get( + "maximum_query_size_gb", DEFAULT_MAX_REQUEST_SIZE_GB + ) + if estimated_size > allowed_size: + raise exc.MaximumAllowedSizeExceededError( + dataset_id=dataset_id, + product_id=product_id, + estimated_size_gb=estimated_size, + allowed_size_gb=allowed_size, + ) + if estimated_size == 0.0: + raise exc.EmptyDatasetError( + dataset_id=dataset_id, product_id=product_id + ) + broker_conn = pika.BlockingConnection( + pika.ConnectionParameters( + host=os.getenv("BROKER_SERVICE_HOST", "broker") + ) + ) + broker_channel = broker_conn.channel() + + request_id = DBManager().create_request( + user_id=user_id, + dataset=dataset_id, + product=product_id, + query=json.dumps(query.model_dump_original()), + ) + + # TODO: find a separator; for the moment use "\" + message = MESSAGE_SEPARATOR.join( + [str(request_id), "query", dataset_id, product_id, query.json()] + ) + + broker_channel.basic_publish( + exchange="", + routing_key="query_queue", + body=message, + properties=pika.BasicProperties( + delivery_mode=2, # make message persistent + ), + ) + broker_conn.close() + return request_id + +@log_execution_time(log) +@assert_product_exists +def sync_query( + user_id: str, + dataset_id: str, + product_id: str, + query: GeoQuery, +): + """Realize the logic for the endpoint: + + `POST /datasets/{dataset_id}/{product_id}/execute` + + Query the data and return the result of the request. 
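# NOTE: an illustrative sketch, not part of this patch, of the message that
# async_query() above publishes to "query_queue"; the separator value and the
# query content are invented, the real separator comes from MESSAGE_SEPARATOR.
sep = "\\"  # assumed separator, see the TODO note in the handler
request_id, dataset_id, product_id = 42, "my-dataset", "my-product"
query_json = '{"variable": ["2t"]}'  # GeoQuery serialized with query.json()
message = sep.join([str(request_id), "query", dataset_id, product_id, query_json])
# message == '42\\query\\my-dataset\\my-product\\{"variable": ["2t"]}'
# run_workflow() publishes to the same queue with "workflow" in place of "query"
# and the serialized workflow as the only payload element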
+ + Parameters + ---------- + user_id : str + ID of the user executing the query + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery + Query to perform + + Returns + ------- + request_id : int + ID of the request + + Raises + ------- + MaximumAllowedSizeExceededError + if the allowed size is below the estimated one + EmptyDatasetError + if estimated size is zero + + """ + + import time + request_id = async_query(user_id, dataset_id, product_id, query) + status, _ = DBManager().get_request_status_and_reason(request_id) + log.debug("sync query: status: %s", status) + while status in (RequestStatus.RUNNING, RequestStatus.QUEUED, + RequestStatus.PENDING): + time.sleep(1) + status, _ = DBManager().get_request_status_and_reason(request_id) + log.debug("sync query: status: %s", status) + + if status is RequestStatus.DONE: + download_details = DBManager().get_download_details_for_request_id( + request_id + ) + return FileResponse( + path=download_details.location_path, + filename=download_details.location_path.split(os.sep)[-1], + ) + raise exc.ProductRetrievingError( + dataset_id=dataset_id, + product_id=product_id, + status=status.name) + + +@log_execution_time(log) +def run_workflow( + user_id: str, + workflow: Workflow, +): + """Realize the logic for the endpoint: + + `POST /datasets/workflow` + + Schedule the workflow and return the ID of the request. + + Parameters + ---------- + user_id : str + ID of the user executing the query + workflow : Workflow + Workflow to perform + + Returns + ------- + request_id : int + ID of the request + + Raises + ------- + MaximumAllowedSizeExceededError + if the allowed size is below the estimated one + EmptyDatasetError + if estimated size is zero + + """ + log.debug("geoquery: %s", workflow) + broker_conn = pika.BlockingConnection( + pika.ConnectionParameters( + host=os.getenv("BROKER_SERVICE_HOST", "broker") + ) + ) + broker_channel = broker_conn.channel() + request_id = DBManager().create_request( + user_id=user_id, + dataset=workflow.dataset_id, + product=workflow.product_id, + query=workflow.json(), + ) + + # TODO: find a separator; for the moment use "\" + message = MESSAGE_SEPARATOR.join( + [str(request_id), "workflow", workflow.json()] + ) + + broker_channel.basic_publish( + exchange="", + routing_key="query_queue", + body=message, + properties=pika.BasicProperties( + delivery_mode=2, # make message persistent + ), + ) + broker_conn.close() + return request_id diff --git a/api/app/endpoint_handlers/file.py b/api/app/endpoint_handlers/file.py new file mode 100644 index 0000000..04cf562 --- /dev/null +++ b/api/app/endpoint_handlers/file.py @@ -0,0 +1,66 @@ +"""Module with functions to handle file related endpoints""" +import os + +from fastapi.responses import FileResponse +from dbmanager.dbmanager import DBManager, RequestStatus + +from utils.api_logging import get_dds_logger +from utils.metrics import log_execution_time +import exceptions as exc + +log = get_dds_logger(__name__) + + +@log_execution_time(log) +def download_request_result(request_id: int): + """Realize the logic for the endpoint: + + `GET /download/{request_id}` + + Get location path of the file being the result of + the request with `request_id`. 
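# NOTE: an illustrative client-side sketch, not part of this patch, of the
# submit/poll/download flow these handlers support; the base URL, credentials,
# dataset and product names, query body and output filename are placeholders.
import time
import requests

BASE = "http://localhost/api"  # assumed deployment prefix
HEADERS = {"User-Token": "<user_id>:<api_key>"}  # format parsed by the auth backend

request_id = requests.post(
    f"{BASE}/datasets/my-dataset/my-product/execute",
    json={"variable": ["2t"]},
    headers=HEADERS,
).json()

while True:
    status = requests.get(
        f"{BASE}/requests/{request_id}/status", headers=HEADERS
    ).json()["status"]
    if status not in ("PENDING", "QUEUED", "RUNNING"):
        break
    time.sleep(5)

if status == "DONE":
    with open("result.bin", "wb") as fh:
        fh.write(
            requests.get(f"{BASE}/download/{request_id}", headers=HEADERS).content
        )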
+ + Parameters + ---------- + request_id : int + ID of the request + + Returns + ------- + path : str + The location of the resulting file + + Raises + ------- + RequestNotYetAccomplished + If dds request was not yet finished + FileNotFoundError + If file was not found + """ + log.debug( + "preparing downloads for request id: %s", + request_id, + ) + ( + request_status, + _, + ) = DBManager().get_request_status_and_reason(request_id=request_id) + if request_status is not RequestStatus.DONE: + log.debug( + "request with id: '%s' does not exist or it is not finished yet!", + request_id, + ) + raise exc.RequestNotYetAccomplished(request_id=request_id) + download_details = DBManager().get_download_details_for_request( + request_id=request_id + ) + if not os.path.exists(download_details.location_path): + log.error( + "file '%s' does not exists!", + download_details.location_path, + ) + raise FileNotFoundError + return FileResponse( + path=download_details.location_path, + filename=download_details.location_path.split(os.sep)[-1], + ) diff --git a/api/app/endpoint_handlers/request.py b/api/app/endpoint_handlers/request.py new file mode 100644 index 0000000..93a0636 --- /dev/null +++ b/api/app/endpoint_handlers/request.py @@ -0,0 +1,144 @@ +"""Modules with functions realizing logic for requests-related endpoints""" +from dbmanager.dbmanager import DBManager + +from utils.api_logging import get_dds_logger +from utils.metrics import log_execution_time +import exceptions as exc + +log = get_dds_logger(__name__) + + +@log_execution_time(log) +def get_requests(user_id: str): + """Realize the logic for the endpoint: + + `GET /requests` + + Get details of all requests for the user. + + Parameters + ---------- + user_id : str + ID of the user for whom requests are taken + + Returns + ------- + requests : list + List of all requests done by the user + """ + return DBManager().get_requests_for_user_id(user_id=user_id) + + +@log_execution_time(log) +def get_request_status(user_id: str, request_id: int): + """Realize the logic for the endpoint: + + `GET /requests/{request_id}/status` + + Get request status and the reason of the eventual fail. + The second item is `None`, it status is other than failed. + + Parameters + ---------- + user_id : str + ID of the user whose request's status is about to be checed + request_id : int + ID of the request + + Returns + ------- + status : tuple + Tuple of status and fail reason. 
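# NOTE: illustrative return shapes, not part of this patch, for the three
# read-only handlers defined in this module (values are invented):
#   get_request_status(...)         -> {"status": "DONE", "fail_reason": None}
#   get_request_resulting_size(...) -> 10485760   (plain size in bytes)
#   get_request_uri(...)            -> the download_uri stored for the request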
+ """ + # NOTE: maybe verification should be added if user checks only him\her requests + try: + status, reason = DBManager().get_request_status_and_reason(request_id) + except IndexError as err: + log.error( + "request with id: '%s' was not found!", + request_id, + ) + raise exc.RequestNotFound(request_id=request_id) from err + return {"status": status.name, "fail_reason": reason} + + +@log_execution_time(log) +def get_request_resulting_size(request_id: int): + """Realize the logic for the endpoint: + + `GET /requests/{request_id}/size` + + Get size of the file being the result of the request with `request_id` + + Parameters + ---------- + request_id : int + ID of the request + + Returns + ------- + size : int + Size in bytes + + Raises + ------- + RequestNotFound + If the request was not found + """ + if request := DBManager().get_request_details(request_id): + size = request.download.size_bytes + if not size or size == 0: + raise exc.EmptyDatasetError(dataset_id=request.dataset, + product_id=request.product) + return size + log.info( + "request with id '%s' could not be found", + request_id, + ) + raise exc.RequestNotFound(request_id=request_id) + + +@log_execution_time(log) +def get_request_uri(request_id: int): + """ + Realize the logic for the endpoint: + + `GET /requests/{request_id}/uri` + + Get URI for the request. + + Parameters + ---------- + request_id : int + ID of the request + + Returns + ------- + uri : str + URI for the download associated with the given request + """ + try: + download_details = DBManager().get_download_details_for_request_id( + request_id + ) + except IndexError as err: + log.error( + "request with id: '%s' was not found!", + request_id, + ) + raise exc.RequestNotFound(request_id=request_id) from err + if download_details is None: + ( + request_status, + _, + ) = DBManager().get_request_status_and_reason(request_id) + log.info( + "download URI not found for request id: '%s'." + " Request status is '%s'", + request_id, + request_status, + ) + raise exc.RequestStatusNotDone( + request_id=request_id, request_status=request_status + ) + return download_details.download_uri diff --git a/api/app/exceptions.py b/api/app/exceptions.py new file mode 100644 index 0000000..af4d072 --- /dev/null +++ b/api/app/exceptions.py @@ -0,0 +1,195 @@ +"""Module with DDS exceptions definitions""" +from typing import Optional + +from fastapi import HTTPException + + +class BaseDDSException(BaseException): + """Base class for DDS.api exceptions""" + + msg: str = "Bad request" + code: int = 400 + + def wrap_around_http_exception(self) -> HTTPException: + """Wrap an exception around `fastapi.HTTPExcetion`""" + return HTTPException( + status_code=self.code, + detail=self.msg, + ) + + +class EmptyUserTokenError(BaseDDSException): + """Raised if `User-Token` is empty""" + + msg: str = "User-Token cannot be empty!" + + +class ImproperUserTokenError(BaseDDSException): + """Raised if `User-Token` format is wrong""" + + msg: str = ( + "The format of the User-Token is wrong. It should be be in the format" + " :!" 
+ ) + + +class NoEligibleProductInDatasetError(BaseDDSException): + """No eligible products in the dataset Error""" + + msg: str = ( + "No eligible products for the dataset '{dataset_id}' for the user" + " with roles '{user_roles_names}'" + ) + + def __init__(self, dataset_id: str, user_roles_names: list[str]) -> None: + self.msg = self.msg.format( + dataset_id=dataset_id, user_roles_names=user_roles_names + ) + super().__init__(self.msg) + + +class MissingKeyInCatalogEntryError(BaseDDSException): + """Missing key in the catalog entry""" + + msg: str = ( + "There is missing '{key}' in the catalog for '{dataset}' dataset." + ) + + def __init__(self, key, dataset): + self.msg = self.msg.format(key=key, dataset=dataset) + super().__init__(self.msg) + + +class MaximumAllowedSizeExceededError(BaseDDSException): + """Estimated size is too big""" + + msg: str = ( + "Maximum allowed size for '{dataset_id}.{product_id}' is" + " {allowed_size_gb:.2f} GB but the estimated size is" + " {estimated_size_gb:.2f} GB" + ) + + def __init__( + self, dataset_id, product_id, estimated_size_gb, allowed_size_gb + ): + self.msg = self.msg.format( + dataset_id=dataset_id, + product_id=product_id, + allowed_size_gb=allowed_size_gb, + estimated_size_gb=estimated_size_gb, + ) + super().__init__(self.msg) + + +class RequestNotYetAccomplished(BaseDDSException): + """Raised if dds request was not finished yet""" + + msg: str = ( + "Request with id: {request_id} does not exist or it is not" + " finished yet!" + ) + + def __init__(self, request_id): + self.msg = self.msg.format(request_id=request_id) + super().__init__(self.msg) + + +class RequestNotFound(BaseDDSException): + """If the given request could not be found""" + + msg: str = "Request with ID '{request_id}' was not found" + + def __init__(self, request_id: int) -> None: + self.msg = self.msg.format(request_id=request_id) + super().__init__(self.msg) + + +class RequestStatusNotDone(BaseDDSException): + """Raised when the submitted request failed""" + + msg: str = ( + "Request with id: `{request_id}` does not have download. URI. Its" + " status is: `{request_status}`!" + ) + + def __init__(self, request_id, request_status) -> None: + self.msg = self.msg.format( + request_id=request_id, request_status=request_status + ) + super().__init__(self.msg) + + +class AuthorizationFailed(BaseDDSException): + """Raised when the user is not authorized for the given resource""" + + msg: str = "{user} is not authorized for the resource!" + code: int = 403 + + def __init__(self, user_id: Optional[str] = None): + if user_id is None: + self.msg = self.msg.format(user="User") + else: + self.msg = self.msg.format(user=f"User '{user_id}'") + super().__init__(self.msg) + + +class AuthenticationFailed(BaseDDSException): + """Raised when the key of the provided user differs from the one s + tored in the DB""" + + msg: str = "Authentication of the user '{user_id}' failed!" + code: int = 401 + + def __init__(self, user_id: str): + self.msg = self.msg.format(user_id=user_id) + super().__init__(self.msg) + + +class MissingDatasetError(BaseDDSException): + """Raied if the queried dataset is not present in the catalog""" + + msg: str = "Dataset '{dataset_id}' does not exist in the catalog!" 
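# NOTE: an illustrative sketch, not part of this patch, of the intended usage of
# this exception hierarchy: handlers raise domain errors and the routes in
# api/app/main.py translate them into fastapi.HTTPException instances.
try:
    raise RequestNotFound(request_id=42)
except BaseDDSException as err:
    http_err = err.wrap_around_http_exception()
    # http_err.status_code == 400 (the BaseDDSException default)
    # http_err.detail == "Request with ID '42' was not found"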
+ + def __init__(self, dataset_id: str): + self.msg = self.msg.format(dataset_id=dataset_id) + super().__init__(self.msg) + + +class MissingProductError(BaseDDSException): + """Raised if the requested product is not defined for the dataset""" + + msg: str = ( + "Product '{dataset_id}.{product_id}' does not exist in the catalog!" + ) + + def __init__(self, dataset_id: str, product_id: str): + self.msg = self.msg.format( + dataset_id=dataset_id, product_id=product_id + ) + super().__init__(self.msg) + + +class EmptyDatasetError(BaseDDSException): + """The size of the requested dataset is zero""" + + msg: str = "The resulting dataset '{dataset_id}.{product_id}' is empty" + + def __init__(self, dataset_id, product_id): + self.msg = self.msg.format( + dataset_id=dataset_id, + product_id=product_id, + ) + super().__init__(self.msg) + +class ProductRetrievingError(BaseDDSException): + """Retrieving of the product failed.""" + + msg: str = "Retrieving of the product '{dataset_id}.{product_id}' failed with the status {status}" + + def __init__(self, dataset_id, product_id, status): + self.msg = self.msg.format( + dataset_id=dataset_id, + product_id=product_id, + status=status + ) + super().__init__(self.msg) \ No newline at end of file diff --git a/api/app/main.py b/api/app/main.py index 2712586..2084394 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,72 +1,468 @@ -from fastapi import FastAPI -import pika -from enum import Enum -from pydantic import BaseModel -from db.dbmanager.dbmanager import DBManager -from geoquery.geoquery import GeoQuery - -app = FastAPI() -db_conn = None -## -# RabbitMQ Broker Connection -broker_conn = pika.BlockingConnection(pika.ConnectionParameters(host='broker')) -broker_chann = broker_conn.channel() - -@app.get("/") +"""Main module with dekube-dds API endpoints defined""" +__version__ = "2.0" +import os +from typing import Optional + +from datetime import datetime + +from fastapi import FastAPI, HTTPException, Request, status, Query +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.authentication import requires + +from aioprometheus import ( + Counter, + Summary, + timer, + MetricsMiddleware, +) +from aioprometheus.asgi.starlette import metrics + +from intake_geokube.queries.workflow import Workflow +from intake_geokube.queries.geoquery import GeoQuery + +from utils.api_logging import get_dds_logger +import exceptions as exc +from endpoint_handlers import ( + dataset_handler, + file_handler, + request_handler, +) +from auth.backend import DDSAuthenticationBackend +from callbacks import all_onstartup_callbacks +from encoders import extend_json_encoders +from const import venv, tags +from auth import scopes + +def map_to_geoquery( + variables: list[str], + format: str, + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + time: datetime | None = None, + **format_kwargs +) -> GeoQuery: + + bbox_ = [float(x) for x in bbox.split(',')] + area = { 'west': bbox_[0], 'south': bbox_[1], 'east': bbox_[2], 'north': bbox_[3], } + time_ = { 'year': time.year, 'month': time.month, 'day': time.day, 'hour': time.hour} + query = GeoQuery(variable=variables, time=time_, area=area, + format_args=format_kwargs, format=format) + return query + +logger = get_dds_logger(__name__) + +# ======== JSON encoders extension ========= # +extend_json_encoders() + +app = FastAPI( + title="geokube-dds API", + description="REST API for geokube-dds", + version=__version__, 
+ contact={ + "name": "geokube Contributors", + "email": "geokube@googlegroups.com", + }, + license_info={ + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + }, + root_path=os.environ.get(venv.ENDPOINT_PREFIX, "/api"), + on_startup=all_onstartup_callbacks, +) + +# ======== Authentication backend ========= # +app.add_middleware( + AuthenticationMiddleware, backend=DDSAuthenticationBackend() +) + +# ======== CORS ========= # +cors_kwargs: dict[str, str | list[str]] +if venv.ALLOWED_CORS_ORIGINS_REGEX in os.environ: + cors_kwargs = { + "allow_origin_regex": os.environ[venv.ALLOWED_CORS_ORIGINS_REGEX] + } +else: + cors_kwargs = {"allow_origins": ["*"]} + +app.add_middleware( + CORSMiddleware, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + **cors_kwargs, +) + + +# ======== Prometheus metrics ========= # +app.add_middleware(MetricsMiddleware) +app.add_route("/metrics", metrics) + +app.state.api_request_duration_seconds = Summary( + "api_request_duration_seconds", "Requests duration" +) +app.state.api_http_requests_total = Counter( + "api_http_requests_total", "Total number of requests" +) + + +# ======== Endpoints definitions ========= # +@app.get("/", tags=[tags.BASIC]) async def dds_info(): - return {"DDS API 2.0"} - -@app.get("/datasets") -async def datasets(): - return {"List of Datasets"} - -@app.get("/datasets/{dataset_id}") -async def dataset(dataset_id: str): - return {f"Dataset Info {dataset_id}"} - -@app.get("/datasets/{dataset_id}/{product_id}") -async def dataset(dataset_id: str, product_id: str): - return {f"Product Info {product_id} from dataset {dataset_id}"} - -@app.post("/datasets/{dataset_id}/{product_id}/estimate") -async def estimate(dataset_id: str, product_id: str, query: GeoQuery): - return {f'estimate size for {dataset_id} {product_id} is 10GB'} - -@app.post("/datasets/{dataset_id}/{product_id}/execute") -async def query(dataset_id: str, product_id: str, format: str, query: GeoQuery): - global db_conn - if not db_conn: - db_conn = DBManager() -# -# -# TODO: Validation Query Schema -# TODO: estimate the size and will not execute if it is above the limit -# -# - request_id = db_conn.create_request(dataset=dataset_id, product=product_id, query=query.json()) - print(f"request id: {request_id}") - -# we should find a separator; for the moment use "\" - message = f'{request_id}\\{dataset_id}\\{product_id}\\{query.json()}\\{format}' - -# submit request to broker queue - broker_chann.basic_publish( - exchange='', - routing_key='query_queue', - body=message, - properties=pika.BasicProperties( - delivery_mode=2, # make message persistent - )) - return request_id - -@app.get("/requests") -async def get_requests(): - return - -@app.get("/requests/{request_id}/status") -async def get_request_status(request_id: int): - return db_conn.get_request_status(request_id) - -@app.get("/requests/{request_id}/uri") -async def get_request_uri(request_id: int): - return \ No newline at end of file + """Return current version of the DDS API""" + return f"DDS API {__version__}" + + +@app.get("/datasets", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, labels={"route": "GET /datasets"} +) +async def get_datasets(request: Request): + """List all products eligible for a user defined by user_token""" + app.state.api_http_requests_total.inc({"route": "GET /datasets"}) + try: + return dataset_handler.get_datasets( + user_roles_names=request.auth.scopes + ) + except exc.BaseDDSException as err: + raise 
err.wrap_around_http_exception() from err + + +@app.get("/datasets/{dataset_id}", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}"}, +) +async def get_first_product_details( + request: Request, + dataset_id: str, +): + """Get details for the 1st product of the dataset""" + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}"} + ) + try: + return dataset_handler.get_product_details( + user_roles_names=request.auth.scopes, + dataset_id=dataset_id, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.get("/datasets/{dataset_id}/{product_id}", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}"}, +) +async def get_product_details( + request: Request, + dataset_id: str, + product_id: str, +): + """Get details for the requested product if user is authorized""" + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}"} + ) + try: + return dataset_handler.get_product_details( + user_roles_names=request.auth.scopes, + dataset_id=dataset_id, + product_id=product_id, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/map", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}"}, +) +async def get_map( + request: Request, + dataset_id: str, + product_id: str, +# OGC WMS parameters + width: int, + height: int, + layers: str | None = None, + format: str | None = 'png', + time: datetime | None = None, + transparent: bool | None = 'true', + bgcolor: str | None = 'FFFFFF', + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, +# OGC map parameters + # subset: str | None = None, + # subset_crs: str | None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/map"} + ) + # query should be the OGC query + # map OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + query = map_to_geoquery(variables=layers, bbox=bbox, time=time, + format="png", width=width, height=height, + transparent=transparent, bgcolor=bgcolor) + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/items/{feature_id}", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"}, +) +async def get_feature( + request: Request, + dataset_id: str, + product_id: str, + feature_id: str, +# OGC feature parameters + time: datetime | None = None, + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, +# OGC map parameters + # subset: str | None = None, + # subset_crs: str 
| None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"} + ) + # query should be the OGC query + # feature OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + + query = map_to_geoquery(variables=[feature_id], bbox=bbox, time=time, + format="geojson") + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/metadata", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}/metadata"}, +) +async def get_metadata( + request: Request, + dataset_id: str, + product_id: str, +): + """Get metadata of the given product""" + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/metadata"} + ) + try: + return dataset_handler.get_metadata( + dataset_id=dataset_id, product_id=product_id + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.post("/datasets/{dataset_id}/{product_id}/estimate", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "POST /datasets/{dataset_id}/{product_id}/estimate"}, +) +async def estimate( + request: Request, + dataset_id: str, + product_id: str, + query: GeoQuery, + unit: str = None, +): + """Estimate the resulting size of the query""" + app.state.api_http_requests_total.inc( + {"route": "POST /datasets/{dataset_id}/{product_id}/estimate"} + ) + try: + return dataset_handler.estimate( + dataset_id=dataset_id, + product_id=product_id, + query=query, + unit=unit, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.post("/datasets/{dataset_id}/{product_id}/execute", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "POST /datasets/{dataset_id}/{product_id}/execute"}, +) +@requires([scopes.AUTHENTICATED]) +async def query( + request: Request, + dataset_id: str, + product_id: str, + query: GeoQuery, +): + """Schedule the job of data retrieve""" + app.state.api_http_requests_total.inc( + {"route": "POST /datasets/{dataset_id}/{product_id}/execute"} + ) + try: + return dataset_handler.async_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.post("/datasets/workflow", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "POST /datasets/workflow"}, +) +@requires([scopes.AUTHENTICATED]) +async def workflow( + request: Request, + tasks: Workflow, +): + """Schedule the job of workflow processing""" + app.state.api_http_requests_total.inc({"route": "POST /datasets/workflow"}) + try: + return dataset_handler.run_workflow( + user_id=request.user.id, + workflow=tasks, + ) + except exc.BaseDDSException as err: + raise 
err.wrap_around_http_exception() from err + + +@app.get("/requests", tags=[tags.REQUEST]) +@timer( + app.state.api_request_duration_seconds, labels={"route": "GET /requests"} +) +@requires([scopes.AUTHENTICATED]) +async def get_requests( + request: Request, +): + """Get all requests for the user""" + app.state.api_http_requests_total.inc({"route": "GET /requests"}) + try: + return request_handler.get_requests(request.user.id) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.get("/requests/{request_id}/status", tags=[tags.REQUEST]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /requests/{request_id}/status"}, +) +@requires([scopes.AUTHENTICATED]) +async def get_request_status( + request: Request, + request_id: int, +): + """Get status of the request without authentication""" + app.state.api_http_requests_total.inc( + {"route": "GET /requests/{request_id}/status"} + ) + try: + return request_handler.get_request_status( + user_id=request.user.id, request_id=request_id + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.get("/requests/{request_id}/size", tags=[tags.REQUEST]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /requests/{request_id}/size"}, +) +@requires([scopes.AUTHENTICATED]) +async def get_request_resulting_size( + request: Request, + request_id: int, +): + """Get size of the file being the result of the request""" + app.state.api_http_requests_total.inc( + {"route": "GET /requests/{request_id}/size"} + ) + try: + return request_handler.get_request_resulting_size( + request_id=request_id + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.get("/requests/{request_id}/uri", tags=[tags.REQUEST]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /requests/{request_id}/uri"}, +) +@requires([scopes.AUTHENTICATED]) +async def get_request_uri( + request: Request, + request_id: int, +): + """Get download URI for the request""" + app.state.api_http_requests_total.inc( + {"route": "GET /requests/{request_id}/uri"} + ) + try: + return request_handler.get_request_uri(request_id=request_id) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + +@app.get("/download/{request_id}", tags=[tags.REQUEST]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /download/{request_id}"}, +) +# @requires([scopes.AUTHENTICATED]) # TODO: mange download auth in the web component +async def download_request_result( + request: Request, + request_id: int, +): + """Download result of the request""" + app.state.api_http_requests_total.inc( + {"route": "GET /download/{request_id}"} + ) + try: + return file_handler.download_request_result(request_id=request_id) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + except FileNotFoundError as err: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File was not found!" 
+ ) from err diff --git a/api/app/validation.py b/api/app/validation.py new file mode 100644 index 0000000..51bdbc1 --- /dev/null +++ b/api/app/validation.py @@ -0,0 +1,36 @@ +from datastore.datastore import Datastore +from utils.api_logging import get_dds_logger +from decorators_factory import assert_parameters_are_defined, bind_arguments +from functools import wraps +from inspect import signature +import exceptions as exc + + +log = get_dds_logger(__name__) + + +def assert_product_exists(func): + """Decorator for convenient checking if product is defined in the catalog + """ + sig = signature(func) + assert_parameters_are_defined( + sig, required_parameters=[("dataset_id", str), ("product_id", str)] + ) + + @wraps(func) + def assert_inner(*args, **kwargs): + args_dict = bind_arguments(sig, *args, **kwargs) + dataset_id = args_dict["dataset_id"] + product_id = args_dict["product_id"] + if dataset_id not in Datastore().dataset_list(): + raise exc.MissingDatasetError(dataset_id=dataset_id) + elif ( + product_id is not None + and product_id not in Datastore().product_list(dataset_id) + ): + raise exc.MissingProductError( + dataset_id=dataset_id, product_id=product_id + ) + return func(*args, **kwargs) + + return assert_inner diff --git a/api/requirements.txt b/api/requirements.txt index e23ebfb..97fcaf3 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,6 +1,5 @@ fastapi -pydantic uvicorn pika -intake -sqlalchemy \ No newline at end of file +sqlalchemy +aioprometheus diff --git a/datastore/Dockerfile b/datastore/Dockerfile new file mode 100644 index 0000000..018ad5e --- /dev/null +++ b/datastore/Dockerfile @@ -0,0 +1,14 @@ +ARG REGISTRY=rg.nl-ams.scw.cloud/geokube-production +ARG TAG=latest +FROM $REGISTRY/geolake-drivers:$TAG +RUN conda install -c conda-forge --yes --freeze-installed psycopg2 \ + && conda clean -afy +COPY requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r /app/requirements.txt +COPY ./datastore /app/datastore +COPY ./workflow /app/workflow +COPY ./dbmanager /app/dbmanager +COPY ./utils /app/utils +COPY ./tests /app/tests +COPY ./wait-for-it.sh / + diff --git a/datastore/datastore.py b/datastore/datastore.py deleted file mode 100644 index 107d821..0000000 --- a/datastore/datastore.py +++ /dev/null @@ -1,63 +0,0 @@ -import intake -from geokube.core.datacube import DataCube -from geokube.core.dataset import Dataset -from typing import Union -from geoquery.geoquery import GeoQuery -import json - -class Datastore(): - - def __init__(self, cat_path: str) -> None: - self.catalog = intake.open_catalog(cat_path) - - def dataset_list(self): - return list(self.catalog) - - def product_list(self, dataset_id: str): - return list(self.catalog[dataset_id]) - - def dataset_info(self, dataset_id: str): - info = {} - entry = self.catalog[dataset_id] - if entry.metadata: - info['metadata'] = entry.metadata - info['products'] = {} - for p in self.products(): - info['products'][p] = self.product_info() - - def product_info(self, dataset_id: str, product_id: str): - info = {} - entry = self.catalog[dataset_id][product_id] - if entry.metadata: - info['metadata'] = entry.metadata - info.update(entry.read_chunked().to_dict()) - return info - - def query(self, dataset: str, product: str, query: Union[GeoQuery, dict, str], compute: bool=False): - """ - :param dataset: dasaset name - :param product: product name - :param query: subset query - :param path: path to store - :return: subsetted geokube of selected dataset product - """ - if isinstance(query, str): - 
query = json.loads(query) - if isinstance(query, dict): - query = GeoQuery(**query) - kube = self.catalog[dataset][product].read_chunked() - if isinstance(kube, Dataset): - kube = kube.filter(query.filters) - if query.variable: - kube = kube[query.variable] - if query.area: - kube = kube.geobbox(query.area) - if query.locations: - kube = kube.locations(**query.locations) - if query.time: - kube = kube.sel(query.time) - if query.vertical: - kube = kube.sel(query.vertical) - if compute: - kube.compute() - return kube \ No newline at end of file diff --git a/datastore/datastore/__init__.py b/datastore/datastore/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datastore/datastore/const.py b/datastore/datastore/const.py new file mode 100644 index 0000000..22435bc --- /dev/null +++ b/datastore/datastore/const.py @@ -0,0 +1,6 @@ +"""This module contains useful constants definitions grouped into classes""" + + +class BaseRole: + PUBLIC = "public" + ADMIN = "admin" diff --git a/datastore/datastore/datastore.py b/datastore/datastore/datastore.py new file mode 100644 index 0000000..ca402fe --- /dev/null +++ b/datastore/datastore/datastore.py @@ -0,0 +1,447 @@ +"""Module for catalog management classes and functions""" +from __future__ import annotations + +import os +import logging +import json + +import intake +from dask.delayed import Delayed + +from intake_geokube.queries.geoquery import GeoQuery + +from geokube.core.datacube import DataCube +from geokube.core.dataset import Dataset + +from .singleton import Singleton +from .util import log_execution_time +from .const import BaseRole +from .exception import UnauthorizedError + +DEFAULT_MAX_REQUEST_SIZE_GB = 10 + + +class Datastore(metaclass=Singleton): + """Singleton component for managing catalog data""" + + _LOG = logging.getLogger("geokube.Datastore") + + def __init__(self) -> None: + if "CATALOG_PATH" not in os.environ: + self._LOG.error( + "missing required environment variable: 'CATALOG_PATH'" + ) + raise KeyError( + "Missing required environment variable: 'CATALOG_PATH'" + ) + if "CACHE_PATH" not in os.environ: + self._LOG.error( + "'CACHE_PATH' environment variable was not set. catalog will" + " not be opened!" + ) + raise RuntimeError( + "'CACHE_PATH' environment variable was not set. catalog will" + " not be opened!" + ) + self.catalog = intake.open_catalog(os.environ["CATALOG_PATH"]) + self.cache_dir = os.environ["CACHE_PATH"] + self._LOG.info("cache dir set to %s", self.cache_dir) + self.cache = None + + @log_execution_time(_LOG) + def get_cached_product_or_read( + self, dataset_id: str, product_id: str, query: GeoQuery | None = None + ) -> DataCube | Dataset: + """Get product from the cache instead of loading files indicated in + the catalog if `metadata_caching` set to `True`. + If might return `geokube.DataCube` or `geokube.Dataset`. + + Parameters + ------- + dataset_id : str + ID of the dataset + product_id : str + ID of the dataset + + Returns + ------- + kube : DataCube or Dataset + """ + if self.cache is None: + self._load_cache() + if ( + dataset_id not in self.cache + or product_id not in self.cache[dataset_id] + ): + self._LOG.info( + "dataset `%s` or product `%s` not found in cache! 
Reading" + " product!", + dataset_id, + product_id, + ) + return self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][ + product_id + ].process(query=query) + return self.cache[dataset_id][product_id] + + @log_execution_time(_LOG) + def _load_cache(self, datasets: list[str] | None = None): + if self.cache is None or datasets is None: + self.cache = {} + datasets = self.dataset_list() + + for i, dataset_id in enumerate(datasets): + self._LOG.info( + "loading cache for `%s` (%d/%d)", + dataset_id, + i + 1, + len(datasets), + ) + self.cache[dataset_id] = {} + for product_id in self.product_list(dataset_id): + catalog_entry = self.catalog(CACHE_DIR=self.cache_dir)[ + dataset_id + ][product_id] + if hasattr(catalog_entry, "metadata_caching") and not catalog_entry.metadata_caching: + self._LOG.info( + "`metadata_caching` for product %s.%s set to `False`", + dataset_id, + product_id, + ) + continue + try: + self.cache[dataset_id][ + product_id + ] = catalog_entry.read() + except ValueError: + self._LOG.error( + "failed to load cache for `%s.%s`", + dataset_id, + product_id, + exc_info=True, + ) + except NotImplementedError: + pass + + @log_execution_time(_LOG) + def dataset_list(self) -> list: + """Get list of datasets available in the catalog stored in `catalog` + attribute + + Returns + ------- + datasets : list + List of datasets present in the catalog + """ + datasets = set(self.catalog(CACHE_DIR=self.cache_dir)) + datasets -= { + "medsea-rea-e3r1", + } + # NOTE: medsae cmip uses cftime.DatetimeNoLeap as time + # need to think how to handle it + return sorted(list(datasets)) + + @log_execution_time(_LOG) + def product_list(self, dataset_id: str): + """Get list of products available in the catalog for dataset + indicated by `dataset_id` + + Parameters + ---------- + dataset_id : str + ID of the dataset + + Returns + ------- + products : list + List of products for the dataset + """ + return list(self.catalog(CACHE_DIR=self.cache_dir)[dataset_id]) + + @log_execution_time(_LOG) + def dataset_info(self, dataset_id: str): + """Get information about the dataset and names of all available + products (with their metadata) + + Parameters + ---------- + dataset_id : str + ID of the dataset + + Returns + ------- + info : dict + Dict of short information about the dataset + """ + info = {} + entry = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id] + if entry.metadata: + info["metadata"] = entry.metadata + info["metadata"]["id"] = dataset_id + info["products"] = {} + for product_id in entry: + prod_entry = entry[product_id] + info["products"][product_id] = prod_entry.metadata + info["products"][product_id][ + "description" + ] = prod_entry.description + return info + + @log_execution_time(_LOG) + def product_metadata(self, dataset_id: str, product_id: str): + """Get product metadata directly from the catalog. + + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + + Returns + ------- + metadata : dict + DatasetMetadata of the product + """ + return self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][ + product_id + ].metadata + + @log_execution_time(_LOG) + def first_eligible_product_details( + self, + dataset_id: str, + role: str | list[str] | None = None, + use_cache: bool = False, + ): + """Get details for the 1st product of the dataset eligible for the `role`. + If `role` is `None`, the `public` role is considered. 
+ + Parameters + ---------- + dataset_id : str + ID of the dataset + role : optional str or list of str, default=`None` + Role code for which the 1st eligible product of a dataset + should be selected + use_cache : bool, optional, default=False + Data will be loaded from cache if set to `True` or directly + from the catalog otherwise + + Returns + ------- + details : dict + Details of the product + + Raises + ------ + UnauthorizedError + if none of product of the requested dataset is eligible for a role + """ + info = {} + product_ids = self.product_list(dataset_id) + for prod_id in product_ids: + if not self.is_product_valid_for_role( + dataset_id, prod_id, role=role + ): + continue + entry = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][prod_id] + if entry.metadata: + info["metadata"] = entry.metadata + info["description"] = entry.description + info["id"] = prod_id + info["dataset"] = self.dataset_info(dataset_id=dataset_id) + if use_cache: + info["data"] = self.get_cached_product_or_read( + dataset_id, prod_id + ).to_dict() + else: + info["data"] = entry.read_chunked().to_dict() + return info + raise UnauthorizedError() + + @log_execution_time(_LOG) + def product_details( + self, + dataset_id: str, + product_id: str, + role: str | list[str] | None = None, + use_cache: bool = False, + ): + """Get details for the single product + + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + role : optional str or list of str, default=`None` + Role code for which the the product is requested. + use_cache : bool, optional, default=False + Data will be loaded from cache if set to `True` or directly + from the catalog otherwise + + Returns + ------- + details : dict + Details of the product + + Raises + ------ + UnauthorizedError + if the requested product is not eligible for a role + """ + info = {} + if not self.is_product_valid_for_role( + dataset_id, product_id, role=role + ): + raise UnauthorizedError() + entry = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][product_id] + if entry.metadata: + info["metadata"] = entry.metadata + info["description"] = entry.description + info["id"] = product_id + info["dataset"] = self.dataset_info(dataset_id=dataset_id) + if use_cache: + info["data"] = self.get_cached_product_or_read( + dataset_id, product_id + ).to_dict() + else: + info["data"] = entry.read_chunked().to_dict() + return info + + def product_info( + self, dataset_id: str, product_id: str, use_cache: bool = False + ): + info = {} + entry = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][product_id] + if entry.metadata: + info["metadata"] = entry.metadata + if use_cache: + info["data"] = self.get_cached_product_or_read( + dataset_id, product_id + ).to_dict() + else: + info["data"] = entry.read_chunked().to_dict() + return info + + @log_execution_time(_LOG) + def query( + self, + dataset_id: str, + product_id: str, + query: GeoQuery | dict | str, + compute: None | bool = False, + ) -> DataCube: + """Query dataset + + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery or dict or str or bytes or bytearray + Query to be executed for the given product + compute : bool, optional, default=False + If True, resulting data of DataCube will be computed, otherwise + DataCube with `dask.Delayed` object will be returned + + Returns + ------- + kube : DataCube + DataCube processed according to `query` + """ + self._LOG.debug("query: %s", query) + geoquery: GeoQuery = 
GeoQuery.parse(query) + self._LOG.debug("processing GeoQuery: %s", geoquery) + # NOTE: we always use catalog directly and single product cache + self._LOG.debug("loading product...") + kube = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][ + product_id + ].process(query=geoquery) + return kube + + @log_execution_time(_LOG) + def estimate( + self, + dataset_id: str, + product_id: str, + query: GeoQuery | dict | str, + ) -> int: + """Estimate dataset size + + Parameters + ---------- + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery or dict or str + Query to be executed for the given product + + Returns + ------- + size : int + Number of bytes of the estimated kube + """ + self._LOG.debug("query: %s", query) + geoquery: GeoQuery = GeoQuery.parse(query) + self._LOG.debug("processing GeoQuery: %s", geoquery) + # NOTE: we always use catalog directly and single product cache + self._LOG.debug("loading product...") + # NOTE: for estimation we use cached products + kube = self.get_cached_product_or_read(dataset_id, product_id, + query=query) + return Datastore._process_query(kube, geoquery, False).nbytes + + @log_execution_time(_LOG) + def is_product_valid_for_role( + self, + dataset_id: str, + product_id: str, + role: str | list[str] | None = None, + ): + entry = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][product_id] + product_role = BaseRole.PUBLIC + if entry.metadata: + product_role = entry.metadata.get("role", BaseRole.PUBLIC) + if product_role == BaseRole.PUBLIC: + return True + if not role: + # NOTE: it means, we consider the public profile + return False + if BaseRole.ADMIN in role: + return True + if product_role in role: + return True + return False + + @staticmethod + def _process_query(kube, query: GeoQuery, compute: None | bool = False): + if isinstance(kube, Dataset): + Datastore._LOG.debug("filtering with: %s", query.filters) + try: + kube = kube.filter(**query.filters) + except ValueError as err: + Datastore._LOG.warning("could not filter by one of the key: %s", err) + if isinstance(kube, Delayed) and compute: + kube = kube.compute() + if query.variable: + Datastore._LOG.debug("selecting fields...") + kube = kube[query.variable] + if query.area: + Datastore._LOG.debug("subsetting by geobbox...") + kube = kube.geobbox(**query.area) + if query.location: + Datastore._LOG.debug("subsetting by locations...") + kube = kube.locations(**query.location) + if query.time: + Datastore._LOG.debug("subsetting by time...") + kube = kube.sel(time=query.time) + if query.vertical: + Datastore._LOG.debug("subsetting by vertical...") + method = None if isinstance(query.vertical, slice) else "nearest" + kube = kube.sel(vertical=query.vertical, method=method) + return kube.compute() if compute else kube diff --git a/datastore/datastore/exception.py b/datastore/datastore/exception.py new file mode 100644 index 0000000..d048e83 --- /dev/null +++ b/datastore/datastore/exception.py @@ -0,0 +1,5 @@ +"""Module with exceptions definitions""" + + +class UnauthorizedError(ValueError): + """Role is not authorized""" diff --git a/datastore/datastore/singleton.py b/datastore/datastore/singleton.py new file mode 100644 index 0000000..ff6ef01 --- /dev/null +++ b/datastore/datastore/singleton.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +"""Singleton module. + +The module contains metaclass called Singleton +for thread-safe singleton-pattern implementation. 
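+
+Illustrative usage (added example; ``Config`` is a hypothetical class):
+
+    class Config(metaclass=Singleton):
+        pass
+
+    assert Config() is Config()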
+""" +import os +import logging +from threading import Lock +from typing import Any, Type + + +class Singleton(type): + """Thread-safe implementation of the singleton design pattern metaclass""" + + _instances: dict[Type, Any] = {} + _lock: Lock = Lock() + + def __call__(cls, *args, **kwargs): + with cls._lock: + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + if hasattr(instance, "_LOG"): + instance._LOG.setLevel( + os.environ.get("LOGGING_LEVEL", "INFO") + ) + instance._LOG.addHandler(logging.StreamHandler()) + cls._instances[cls] = instance + return cls._instances[cls] diff --git a/datastore/datastore/util.py b/datastore/datastore/util.py new file mode 100644 index 0000000..4122d57 --- /dev/null +++ b/datastore/datastore/util.py @@ -0,0 +1,27 @@ +"""Utils module""" +from functools import wraps +import datetime +import logging + + +def log_execution_time(logger: logging.Logger): + """Decorator logging execution time of the method or function""" + + def inner(func): + @wraps(func) + def wrapper(*args, **kwds): + exec_start_time = datetime.datetime.now() + try: + return func(*args, **kwds) + finally: + exec_time = datetime.datetime.now() - exec_start_time + logger.info( + "execution of '%s' function from '%s' package took %s", + func.__name__, + func.__module__, + exec_time, + ) + + return wrapper + + return inner diff --git a/datastore/dbmanager/__init__.py b/datastore/dbmanager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datastore/dbmanager/dbmanager.py b/datastore/dbmanager/dbmanager.py new file mode 100644 index 0000000..d4ff293 --- /dev/null +++ b/datastore/dbmanager/dbmanager.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import os +import yaml +import logging +import uuid +import secrets +from datetime import datetime +from enum import auto, Enum as Enum_, unique + +from sqlalchemy import ( + Column, + create_engine, + DateTime, + Enum, + ForeignKey, + Integer, + JSON, + Sequence, + String, + Table, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import declarative_base, sessionmaker, relationship + +from .singleton import Singleton + + +def is_true(item) -> bool: + """If `item` represents `True` value""" + if isinstance(item, str): + return item.lower() in ["y", "yes", "true", "t"] + return bool(item) + + +def generate_key() -> str: + """Generate as new api key for a user""" + return secrets.token_urlsafe(nbytes=32) + + +@unique +class RequestStatus(Enum_): + """Status of the Request""" + + PENDING = auto() + QUEUED = auto() + RUNNING = auto() + DONE = auto() + FAILED = auto() + TIMEOUT = auto() + + @classmethod + def _missing_(cls, value): + return cls.PENDING + + +class _Repr: + def __repr__(self): + cols = self.__table__.columns.keys() # pylint: disable=no-member + kwa = ", ".join(f"{col}={getattr(self, col)}" for col in cols) + return f"{type(self).__name__}({kwa})" + + +Base = declarative_base(cls=_Repr, name="Base") + + +association_table = Table( + "users_roles", + Base.metadata, + Column("user_id", ForeignKey("users.user_id")), + Column("role_id", ForeignKey("roles.role_id")), +) + + +class Role(Base): + __tablename__ = "roles" + role_id = Column(Integer, Sequence("role_id_seq"), primary_key=True) + role_name = Column(String(255), nullable=False, unique=True) + + +class User(Base): + __tablename__ = "users" + user_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + # keycloak_id = Column(UUID(as_uuid=True), nullable=False, unique=True, default=uuid.uuid4) + 
api_key = Column(
+        String(255), nullable=False, unique=True, default=generate_key
+    )
+    contact_name = Column(String(255))
+    requests = relationship("Request", lazy="dynamic")
+    roles = relationship("Role", secondary=association_table, lazy="selectin")
+
+
+class Worker(Base):
+    __tablename__ = "workers"
+    worker_id = Column(Integer, primary_key=True)
+    status = Column(String(255), nullable=False)
+    host = Column(String(255))
+    dask_scheduler_port = Column(Integer)
+    dask_dashboard_address = Column(String(10))
+    created_on = Column(DateTime, default=datetime.now)
+
+
+class Request(Base):
+    __tablename__ = "requests"
+    request_id = Column(Integer, primary_key=True)
+    status = Column(Enum(RequestStatus), nullable=False)
+    priority = Column(Integer)
+    user_id = Column(
+        UUID(as_uuid=True), ForeignKey("users.user_id"), nullable=False
+    )
+    worker_id = Column(Integer, ForeignKey("workers.worker_id"))
+    dataset = Column(String(255))
+    product = Column(String(255))
+    query = Column(JSON())
+    estimate_size_bytes = Column(Integer)
+    created_on = Column(DateTime, default=datetime.now)
+    last_update = Column(DateTime, default=datetime.now, onupdate=datetime.now)
+    fail_reason = Column(String(1000))
+    download = relationship("Download", uselist=False, lazy="selectin")
+
+
+class Download(Base):
+    __tablename__ = "downloads"
+    download_id = Column(Integer, primary_key=True)
+    download_uri = Column(String(255))
+    request_id = Column(
+        Integer, ForeignKey("requests.request_id"), nullable=False
+    )
+    storage_id = Column(Integer, ForeignKey("storages.storage_id"))
+    location_path = Column(String(255))
+    size_bytes = Column(Integer)
+    created_on = Column(DateTime, default=datetime.now)
+
+
+class Storage(Base):
+    __tablename__ = "storages"
+    storage_id = Column(Integer, primary_key=True)
+    name = Column(String(255))
+    host = Column(String(20))
+    protocol = Column(String(10))
+    port = Column(Integer)
+
+
+class DBManager(metaclass=Singleton):
+    _LOG = logging.getLogger("geokube.DBManager")
+
+    def __init__(self) -> None:
+        for venv_key in [
+            "POSTGRES_DB",
+            "POSTGRES_USER",
+            "POSTGRES_PASSWORD",
+            "DB_SERVICE_HOST",
+            "DB_SERVICE_PORT",
+        ]:
+            self._LOG.info(
+                "attempt to load data from environment variable: `%s`",
+                venv_key,
+            )
+            if venv_key not in os.environ:
+                self._LOG.error(
+                    "missing required environment variable: `%s`", venv_key
+                )
+                raise KeyError(
+                    f"missing required environment variable: {venv_key}"
+                )
+
+        user = os.environ["POSTGRES_USER"]
+        password = os.environ["POSTGRES_PASSWORD"]
+        host = os.environ["DB_SERVICE_HOST"]
+        port = os.environ["DB_SERVICE_PORT"]
+        database = os.environ["POSTGRES_DB"]
+
+        url = f"postgresql://{user}:{password}@{host}:{port}/{database}"
+        self._LOG.info("db connection: `%s`", url)
+        self.__engine = create_engine(
+            url, echo=is_true(os.environ.get("DB_LOGGING", False))
+        )
+        self.__session_maker = sessionmaker(bind=self.__engine)
+
+    def _create_database(self):
+        try:
+            Base.metadata.create_all(self.__engine)
+        except Exception as exception:
+            self._LOG.error(
+                "could not create a database due to an error", exc_info=True
+            )
+            raise exception
+
+    def add_user(
+        self,
+        contact_name: str,
+        user_id: UUID | None = None,
+        api_key: str | None = None,
+        roles_names: list[str] | None = None,
+    ):
+        with self.__session_maker() as session:
+            user = User(
+                user_id=user_id, api_key=api_key, contact_name=contact_name
+            )
+            if roles_names:
+                user.roles.extend(
+                    [
+                        session.query(Role)
+                        .where(Role.role_name == role_name)
+                        .all()[0]  # NOTE: role_name is unique in the database
+                        
for role_name in roles_names + ] + ) + session.add(user) + session.commit() + return user + + def get_user_details(self, user_id: int): + with self.__session_maker() as session: + return session.query(User).get(user_id) + + def get_user_roles_names(self, user_id: int | None = None) -> list[str]: + if user_id is None: + return ["public"] + with self.__session_maker() as session: + return list( + map( + lambda role: role.role_name, + session.query(User).get(user_id).roles, + ) + ) + + def get_request_details(self, request_id: int): + with self.__session_maker() as session: + return session.query(Request).get(request_id) + + def get_download_details_for_request(self, request_id: int): + with self.__session_maker() as session: + request_details = session.query(Request).get(request_id) + if request_details is None: + raise ValueError( + f"Request with id: {request_id} doesn't exist" + ) + return request_details.download + + def create_request( + self, + user_id: int = 1, + dataset: str | None = None, + product: str | None = None, + query: str | None = None, + worker_id: int | None = None, + priority: str | None = None, + estimate_size_bytes: int | None = None, + status: RequestStatus = RequestStatus.PENDING, + ) -> int: + # TODO: Add more request-related parameters to this method. + with self.__session_maker() as session: + request = Request( + status=status, + priority=priority, + user_id=user_id, + worker_id=worker_id, + dataset=dataset, + product=product, + query=query, + estimate_size_bytes=estimate_size_bytes, + created_on=datetime.utcnow(), + ) + session.add(request) + session.commit() + return request.request_id + + def update_request( + self, + request_id: int, + worker_id: int | None = None, + status: RequestStatus | None = None, + location_path: str = None, + size_bytes: int = None, + fail_reason: str = None, + ) -> int: + with self.__session_maker() as session: + request = session.query(Request).get(request_id) + if status: + request.status = status + if worker_id: + request.worker_id = worker_id + request.last_update = datetime.utcnow() + request.fail_reason = fail_reason + session.commit() + if status is RequestStatus.DONE: + download = Download( + location_path=location_path, + storage_id=0, + request_id=request.request_id, + created_on=datetime.utcnow(), + download_uri=f"/download/{request_id}", + size_bytes=size_bytes, + ) + session.add(download) + session.commit() + return request.request_id + + def get_request_status_and_reason( + self, request_id + ) -> None | RequestStatus: + with self.__session_maker() as session: + if request := session.query(Request).get(request_id): + return RequestStatus(request.status), request.fail_reason + raise IndexError( + f"Request with id: `{request_id}` does not exist!" + ) + + def get_requests_for_user_id(self, user_id) -> list[Request]: + with self.__session_maker() as session: + return session.query(User).get(user_id).requests.all() + + def get_requests_for_user_id_and_status( + self, user_id, status: RequestStatus | tuple[RequestStatus] + ) -> list[Request]: + if isinstance(status, RequestStatus): + status = (status,) + with self.__session_maker() as session: + return session.get(User, user_id).requests.filter( + Request.status.in_(status) + ) + + def get_download_details_for_request_id(self, request_id) -> Download: + with self.__session_maker() as session: + request_details = session.query(Request).get(request_id) + if request_details is None: + raise IndexError( + f"Request with id: `{request_id}` does not exist!" 
+ ) + return request_details.download + + def create_worker( + self, + status: str, + dask_scheduler_port: int, + dask_dashboard_address: int, + host: str = "localhost", + ) -> int: + with self.__session_maker() as session: + worker = Worker( + status=status, + host=host, + dask_scheduler_port=dask_scheduler_port, + dask_dashboard_address=dask_dashboard_address, + created_on=datetime.utcnow(), + ) + session.add(worker) + session.commit() + return worker.worker_id diff --git a/datastore/dbmanager/singleton.py b/datastore/dbmanager/singleton.py new file mode 100644 index 0000000..bf7b29b --- /dev/null +++ b/datastore/dbmanager/singleton.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +"""Singleton module. + +The module contains metaclass called Singleton +for thread-safe singleton-pattern implementation. +""" +from threading import Lock + + +class Singleton(type): + """Thread-safe implementation of the singleton design pattern metaclass""" + + _instances = {} + _lock: Lock = Lock() + + def __call__(cls, *args, **kwargs): + with cls._lock: + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] diff --git a/datastore/requirements.txt b/datastore/requirements.txt new file mode 100644 index 0000000..d4a7d44 --- /dev/null +++ b/datastore/requirements.txt @@ -0,0 +1,2 @@ +networkx +pydantic \ No newline at end of file diff --git a/datastore/tests/__init__.py b/datastore/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datastore/tests/workflow/__init__.py b/datastore/tests/workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datastore/tests/workflow/fixtures.py b/datastore/tests/workflow/fixtures.py new file mode 100644 index 0000000..8ce94ad --- /dev/null +++ b/datastore/tests/workflow/fixtures.py @@ -0,0 +1,122 @@ +import pytest + + +@pytest.fixture +def subset_query() -> str: + yield """ + { + "dataset_id": "era5-single-levels", + "product_id": "reanalysis", + "query": { + "area": { + "north": -85, + "south": -90, + "east": 260, + "west": 240 + }, + "time": { + "hour": [ + "15" + ], + "year": [ + "1981", + "1985", + "2022" + ], + "month": [ + "3", + "6" + ], + "day": [ + "23", + "27" + ] + }, + "variable": [ + "2_metre_dewpoint_temperature", + "surface_net_downward_shortwave_flux" + ] + } + } + """ + + +@pytest.fixture +def resample_query(): + yield """ + { + "freq": "1D", + "operator": "nanmax", + "resample_args": { + "closed": "right" + } + } + """ + + +@pytest.fixture +def workflow_str(): + yield """ + [ + { + "id": "subset1", + "op": "subset", + "args": { + "dataset_id": "era5-single-levels", + "product_id": "reanalysis", + "query": { + "area": { + "north": -85, + "south": -90, + "east": 260, + "west": 240 + } + } + } + }, + { + "id": "resample1", + "use": ["subset1"], + "op": "resample", + "args": + { + "freq": "1D", + "operator": "nanmax" + } + } + ] + """ + + +@pytest.fixture +def bad_workflow_str(): + yield """ + [ + { + "id": "subset1", + "op": "subset", + "args": { + "dataset_id": "era5-single-levels", + "product_id": "reanalysis", + "query": { + "area": { + "north": -85, + "south": -90, + "east": 260, + "west": 240 + } + } + } + }, + { + "id": "resample1", + "use": ["subset1", "subset2"], + "op": "resample", + "args": + { + "freq": "1D", + "operator": "nanmax" + } + } + ] + """ diff --git a/datastore/tests/workflow/test_operators.py b/datastore/tests/workflow/test_operators.py new file mode 100644 index 0000000..46cf109 --- /dev/null +++ 
b/datastore/tests/workflow/test_operators.py
@@ -0,0 +1,20 @@
+from workflow import operators as op
+
+from .fixtures import subset_query, resample_query
+
+
+def test_create_subset_operator_with_str_args(subset_query):
+    sub_op = op.Operator("subset", subset_query)
+    assert isinstance(sub_op, op.Subset)
+    assert isinstance(sub_op.args, op.SubsetArgs)
+    assert sub_op.args.dataset_id == "era5-single-levels"
+    assert sub_op.args.product_id == "reanalysis"
+
+
+def test_create_resample_operator_with_str_args(resample_query):
+    res_op = op.Operator("resample", resample_query)
+    assert isinstance(res_op, op.Resample)
+    assert isinstance(res_op.args, op.ResampleArgs)
+    assert res_op.args.freq == "1D"
+    assert res_op.args.operator == "nanmax"
+    assert res_op.args.resample_args == {"closed": "right"}
diff --git a/datastore/tests/workflow/test_workflow.py b/datastore/tests/workflow/test_workflow.py
new file mode 100644
index 0000000..7036b73
--- /dev/null
+++ b/datastore/tests/workflow/test_workflow.py
@@ -0,0 +1,23 @@
+import pytest
+from workflow.workflow import Workflow
+
+from .fixtures import workflow_str, bad_workflow_str
+
+
+def test_create_workflow(workflow_str):
+    comp_graph = Workflow(workflow_str)
+    assert len(comp_graph) == 2
+    task_iter = comp_graph.traverse()
+    node1, preceding1 = next(task_iter)
+    assert preceding1 == tuple()
+    assert node1.operator.name == "subset"
+
+    node2, preceding2 = next(task_iter)
+    assert len(preceding2) == 1
+    assert node2.operator.name == "resample"
+    assert preceding2[0].operator.name == "subset"
+
+
+def test_fail_when_task_not_defined(bad_workflow_str):
+    with pytest.raises(ValueError, match=r"task with id*"):
+        _ = Workflow(bad_workflow_str)
diff --git a/datastore/utils/__init__.py b/datastore/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/datastore/utils/api_logging.py b/datastore/utils/api_logging.py
new file mode 100644
index 0000000..58d148d
--- /dev/null
+++ b/datastore/utils/api_logging.py
@@ -0,0 +1,40 @@
+import os
+from typing import Literal
+import logging as default_logging
+
+
+def get_dds_logger(
+    name: str,
+    level: Literal["debug", "info", "warning", "error", "critical"] = "info",
+):
+    """Get DDS logger with the expected format, handlers and formatter.
+
+    Parameters
+    ----------
+    name : str
+        Name of the logger
+    level : str, default="info"
+        Value of the logging level. One out of ["debug", "info", "warning",
+        "error", "critical"].
+        The logging level is taken from the
+        environment variable `LOGGING_LEVEL`. If this variable is not defined,
+        the value of the `level` argument is used.
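+        For example (illustrative call, not part of the original code),
+        ``get_dds_logger(__name__, level="debug")`` emits DEBUG records
+        unless the `LOGGING_LEVEL` environment variable overrides it.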
+ + Returns + ------- + log : logging.Logger + Logger with the handlers set + """ + log = default_logging.getLogger(name) + format_ = os.environ.get( + "LOGGING_FORMAT", + "%(asctime)s %(name)s %(levelname)s %(message)s", + ) + formatter = default_logging.Formatter(format_) + logging_level = os.environ.get("LOGGING_LEVEL", level.upper()) + log.setLevel(logging_level) + stream_handler = default_logging.StreamHandler() + stream_handler.setFormatter(formatter) + stream_handler.setLevel(logging_level) + log.addHandler(stream_handler) + return log diff --git a/datastore/utils/metrics.py b/datastore/utils/metrics.py new file mode 100644 index 0000000..82aeb55 --- /dev/null +++ b/datastore/utils/metrics.py @@ -0,0 +1,33 @@ +import time +import logging as default_logging +from functools import wraps +from typing import Literal + + +def log_execution_time( + logger: default_logging.Logger, + level: Literal["debug", "info", "warning", "error", "critical"] = "info", +): + """Decorator logging execution time of the method or function""" + level = default_logging.getLevelName(level.upper()) + + def inner(func): + @wraps(func) + def wrapper(*args, **kwds): + exec_start_time = time.monotonic() + try: + return func(*args, **kwds) + finally: + # NOTE: maybe logging should be on DEBUG level + logger.log( + level, + "execution of '%s' function from '%s' package took" + " %.4f sec", + func.__name__, + func.__module__, + time.monotonic() - exec_start_time, + ) + + return wrapper + + return inner diff --git a/datastore/wait-for-it.sh b/datastore/wait-for-it.sh new file mode 100755 index 0000000..d990e0d --- /dev/null +++ b/datastore/wait-for-it.sh @@ -0,0 +1,182 @@ +#!/usr/bin/env bash +# Use this script to test if a given TCP host/port are available + +WAITFORIT_cmdname=${0##*/} + +echoerr() { if [[ $WAITFORIT_QUIET -ne 1 ]]; then echo "$@" 1>&2; fi } + +usage() +{ + cat << USAGE >&2 +Usage: + $WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args] + -h HOST | --host=HOST Host or IP under test + -p PORT | --port=PORT TCP port under test + Alternatively, you specify the host and port as host:port + -s | --strict Only execute subcommand if the test succeeds + -q | --quiet Don't output any status messages + -t TIMEOUT | --timeout=TIMEOUT + Timeout in seconds, zero for no timeout + -- COMMAND ARGS Execute command with args after the test finishes +USAGE + exit 1 +} + +wait_for() +{ + if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then + echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" + else + echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout" + fi + WAITFORIT_start_ts=$(date +%s) + while : + do + if [[ $WAITFORIT_ISBUSY -eq 1 ]]; then + nc -z $WAITFORIT_HOST $WAITFORIT_PORT + WAITFORIT_result=$? + else + (echo -n > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1 + WAITFORIT_result=$? 
+ fi + if [[ $WAITFORIT_result -eq 0 ]]; then + WAITFORIT_end_ts=$(date +%s) + echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds" + break + fi + sleep 1 + done + return $WAITFORIT_result +} + +wait_for_wrapper() +{ + # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692 + if [[ $WAITFORIT_QUIET -eq 1 ]]; then + timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & + else + timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & + fi + WAITFORIT_PID=$! + trap "kill -INT -$WAITFORIT_PID" INT + wait $WAITFORIT_PID + WAITFORIT_RESULT=$? + if [[ $WAITFORIT_RESULT -ne 0 ]]; then + echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" + fi + return $WAITFORIT_RESULT +} + +# process arguments +while [[ $# -gt 0 ]] +do + case "$1" in + *:* ) + WAITFORIT_hostport=(${1//:/ }) + WAITFORIT_HOST=${WAITFORIT_hostport[0]} + WAITFORIT_PORT=${WAITFORIT_hostport[1]} + shift 1 + ;; + --child) + WAITFORIT_CHILD=1 + shift 1 + ;; + -q | --quiet) + WAITFORIT_QUIET=1 + shift 1 + ;; + -s | --strict) + WAITFORIT_STRICT=1 + shift 1 + ;; + -h) + WAITFORIT_HOST="$2" + if [[ $WAITFORIT_HOST == "" ]]; then break; fi + shift 2 + ;; + --host=*) + WAITFORIT_HOST="${1#*=}" + shift 1 + ;; + -p) + WAITFORIT_PORT="$2" + if [[ $WAITFORIT_PORT == "" ]]; then break; fi + shift 2 + ;; + --port=*) + WAITFORIT_PORT="${1#*=}" + shift 1 + ;; + -t) + WAITFORIT_TIMEOUT="$2" + if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi + shift 2 + ;; + --timeout=*) + WAITFORIT_TIMEOUT="${1#*=}" + shift 1 + ;; + --) + shift + WAITFORIT_CLI=("$@") + break + ;; + --help) + usage + ;; + *) + echoerr "Unknown argument: $1" + usage + ;; + esac +done + +if [[ "$WAITFORIT_HOST" == "" || "$WAITFORIT_PORT" == "" ]]; then + echoerr "Error: you need to provide a host and port to test." + usage +fi + +WAITFORIT_TIMEOUT=${WAITFORIT_TIMEOUT:-15} +WAITFORIT_STRICT=${WAITFORIT_STRICT:-0} +WAITFORIT_CHILD=${WAITFORIT_CHILD:-0} +WAITFORIT_QUIET=${WAITFORIT_QUIET:-0} + +# Check to see if timeout is from busybox? +WAITFORIT_TIMEOUT_PATH=$(type -p timeout) +WAITFORIT_TIMEOUT_PATH=$(realpath $WAITFORIT_TIMEOUT_PATH 2>/dev/null || readlink -f $WAITFORIT_TIMEOUT_PATH) + +WAITFORIT_BUSYTIMEFLAG="" +if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then + WAITFORIT_ISBUSY=1 + # Check if busybox timeout uses -t flag + # (recent Alpine versions don't support -t anymore) + if timeout &>/dev/stdout | grep -q -e '-t '; then + WAITFORIT_BUSYTIMEFLAG="-t" + fi +else + WAITFORIT_ISBUSY=0 +fi + +if [[ $WAITFORIT_CHILD -gt 0 ]]; then + wait_for + WAITFORIT_RESULT=$? + exit $WAITFORIT_RESULT +else + if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then + wait_for_wrapper + WAITFORIT_RESULT=$? + else + wait_for + WAITFORIT_RESULT=$? 
+ fi +fi + +if [[ $WAITFORIT_CLI != "" ]]; then + if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then + echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess" + exit $WAITFORIT_RESULT + fi + exec "${WAITFORIT_CLI[@]}" +else + exit $WAITFORIT_RESULT +fi diff --git a/datastore/workflow/__init__.py b/datastore/workflow/__init__.py new file mode 100644 index 0000000..9c75326 --- /dev/null +++ b/datastore/workflow/__init__.py @@ -0,0 +1 @@ +from workflow.workflow import Workflow diff --git a/datastore/workflow/workflow.py b/datastore/workflow/workflow.py new file mode 100644 index 0000000..63e6f78 --- /dev/null +++ b/datastore/workflow/workflow.py @@ -0,0 +1,226 @@ +import json +from typing import Generator, Hashable, Callable, Literal, Any +from functools import partial +import logging + +import networkx as nx +from geokube.core.datacube import DataCube +from intake_geokube.queries.geoquery import GeoQuery +from intake_geokube.queries.workflow import Workflow as WorkflowModel +from datastore.datastore import Datastore + +AggregationFunctionName = ( + Literal["max"] + | Literal["nanmax"] + | Literal["min"] + | Literal["nanmin"] + | Literal["mean"] + | Literal["nanmean"] + | Literal["sum"] + | Literal["nansum"] +) + + +_LOG = logging.getLogger("geokube.workflow") + +TASK_ATTRIBUTE = "task" + + +class _WorkflowTask: + __slots__ = ("id", "dependencies", "operator") + + id: Hashable + dependencies: list[Hashable] | None + operator: Callable[..., DataCube] + + def __init__( + self, + id: Hashable, + operator: Callable[..., DataCube], + dependencies: list[Hashable] | None = None, + ) -> None: + self.operator = operator + self.id = id + if dependencies is None: + dependencies = [] + self.dependencies = dependencies + + def compute(self, kube: DataCube | None) -> DataCube: + return self.operator(kube) + + +class Workflow: + __slots__ = ("graph", "present_nodes_ids", "is_verified") + + graph: nx.DiGraph + present_nodes_ids: set[Hashable] + is_verified: bool + + def __init__(self) -> None: + self.graph = nx.DiGraph() + self.present_nodes_ids = set() + self.is_verified = False + + @classmethod + def from_tasklist(cls, task_list: WorkflowModel) -> "Workflow": + workflow = cls() + for task in task_list.tasks: + match task.op: + case "subset": + workflow.subset(task.id, **task.args) + case "resample": + workflow.resample( + task.id, dependencies=task.use, **task.args + ) + case "average": + workflow.average( + task.id, dependencies=task.use, **task.args + ) + case "to_regular": + workflow.to_regular( + task.id, dependencies=task.use, **task.args + ) + case _: + raise ValueError( + f"task operator: {task.op} is not defined" + ) + return workflow + + def _add_computational_node(self, task: _WorkflowTask): + node_id = task.id + assert ( + node_id not in self.present_nodes_ids + ), "worflow task IDs need to be unique!" 
+ self.present_nodes_ids.add(node_id) + self.graph.add_node(node_id, **{TASK_ATTRIBUTE: task}) + for dependend_node in task.dependencies: + self.graph.add_edge(dependend_node, node_id) + self.is_verified = False + + def subset( + self, + id: Hashable, + dataset_id: str, + product_id: str, + query: GeoQuery | dict, + ) -> "Workflow": + def _subset(kube: DataCube | None = None) -> DataCube: + return Datastore().query( + dataset_id=dataset_id, + product_id=product_id, + query=( + query if isinstance(query, GeoQuery) else GeoQuery(**query) + ), + compute=False, + ) + + task = _WorkflowTask(id=id, operator=_subset) + self._add_computational_node(task) + return self + + def resample( + self, + id: Hashable, + freq: str, + agg: Callable[..., DataCube] | AggregationFunctionName, + resample_kwargs: dict[str, Any] | None, + *, + dependencies: list[Hashable], + ) -> "Workflow": + def _resample(kube: DataCube | None = None) -> DataCube: + assert kube is not None, "`kube` cannot be `None` for resampling" + return kube.resample( + operator=agg, + frequency=freq, + **resample_kwargs, + ) + + task = _WorkflowTask( + id=id, operator=_resample, dependencies=dependencies + ) + self._add_computational_node(task) + return self + + def average( + self, id: Hashable, dim: str, *, dependencies: list[Hashable] + ) -> "Workflow": + def _average(kube: DataCube | None = None) -> DataCube: + assert kube is not None, "`kube` cannot be `None` for averaging" + return kube.average(dim=dim) + + task = _WorkflowTask( + id=id, operator=_average, dependencies=dependencies + ) + self._add_computational_node(task) + return self + + def to_regular( + self, id: Hashable, *, dependencies: list[Hashable] + ) -> "Workflow": + def _to_regular(kube: DataCube | None = None) -> DataCube: + assert ( + kube is not None + ), "`kube` cannot be `None` for `to_regular``" + return kube.to_regular() + + task = _WorkflowTask( + id=id, operator=_to_regular, dependencies=dependencies + ) + self._add_computational_node(task) + return self + + def add_task( + self, + id: Hashable, + func: Callable[..., DataCube], + dependencies: list[str] | None = None, + **func_kwargs, + ) -> "Workflow": + task = _WorkflowTask( + id=id, + operator=partial(func, **func_kwargs), + dependencies=dependencies, + ) + self._add_computational_node(task) + return self + + def verify(self) -> "Workflow": + if self.is_verified: + return + assert nx.is_directed_acyclic_graph( + self.graph + ), "the workflow contains cycles!" 
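+        # Added note: every edge endpoint must carry a task attribute;
+        # an edge that refers to an undefined task id is rejected below.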
+ for u, v in self.graph.edges: + if TASK_ATTRIBUTE not in self.graph.nodes[u].keys(): + _LOG.error( + "task with id `%s` is not defined for the workflow", u + ) + raise ValueError( + f"task with id `{u}` is not defined for the workflow" + ) + if TASK_ATTRIBUTE not in self.graph.nodes[v].keys(): + _LOG.error( + "task with id `%s` is not defined for the workflow", v + ) + raise ValueError( + f"task with id `{v}` is not defined for the workflow" + ) + self.is_verified = True + + def traverse(self) -> Generator[_WorkflowTask, None, None]: + for node_id in nx.topological_sort(self.graph): + _LOG.debug("computing task for the node: %s", node_id) + yield self.graph.nodes[node_id][TASK_ATTRIBUTE] + + def compute(self) -> DataCube: + self.verify() + result = None + for task in self.traverse(): + result = task.compute(result) + return result + + def __len__(self): + return len(self.graph.nodes) + + def __getitem__(self, idx: Hashable): + return self.graph.nodes[idx] diff --git a/db/Dockerfile b/db/Dockerfile deleted file mode 100644 index 8bcf754..0000000 --- a/db/Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -FROM postgres:14.1 -ADD ./scripts/init.sql /docker-entrypoint-initdb.d/ \ No newline at end of file diff --git a/db/dbmanager/dbmanager.py b/db/dbmanager/dbmanager.py deleted file mode 100644 index 16b956b..0000000 --- a/db/dbmanager/dbmanager.py +++ /dev/null @@ -1,183 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from enum import auto, Enum as Enum_, unique - -from sqlalchemy import ( - Column, - create_engine, - DateTime, - Enum, - ForeignKey, - Integer, - JSON, - Sequence, - String -) -from sqlalchemy.orm import declarative_base, relationship, sessionmaker - - -@unique -class RequestStatus(Enum_): - PENDING = auto() - RUNNING = auto() - DONE = auto() - FAILED = auto() - - -class _Repr: - def __repr__(self): - cols = self.__table__.columns.keys() # pylint: disable=no-member - kwa = ', '.join(f'{col}={getattr(self, col)}' for col in cols) - return f'{type(self).__name__}({kwa})' - - -Base = declarative_base(cls=_Repr, name='Base') - - -class Role(Base): - __tablename__ = 'roles' - role_id = Column(Integer, Sequence('role_id_seq'), primary_key=True) - role_name = Column(String(255), nullable=False, unique=True) - - -class User(Base): - __tablename__ = 'users' - user_id = Column(Integer, primary_key=True) - keycloak_id = Column(Integer, nullable=False, unique=True) - api_key = Column(String(255), nullable=False, unique=True) - contact_name = Column(String(255)) - role_id = Column(Integer, ForeignKey('roles.role_id')) - - -class Worker(Base): - __tablename__ = 'workers' - worker_id = Column(Integer, primary_key=True) - status = Column(String(255), nullable=False) - host = Column(String(255)) - dask_scheduler_port = Column(Integer) - dask_dashboard_address = Column(String(10)) - created_on = Column(DateTime, nullable=False) - - -class Request(Base): - __tablename__ = 'requests' - request_id = Column(Integer, primary_key=True) - status = Column(Enum(RequestStatus), nullable=False) - priority = Column(Integer) - user_id = Column(Integer, ForeignKey('users.user_id'), nullable=False) - worker_id = Column(Integer, ForeignKey('workers.worker_id')) - dataset = Column(String(255)) - product = Column(String(255)) - query = Column(JSON()) - estimate_bytes_size = Column(Integer) - download_id = Column(Integer, unique=True) - created_on = Column(DateTime, nullable=False) - last_update = Column(DateTime) - - -class Download(Base): - __tablename__ = 'downloads' - download_id = 
Column( - Integer, primary_key=True - ) - download_uri = Column(String(255)) - storage_id = Column(Integer) - location_path = Column(String(255)) - bytes_size = Column(Integer) - created_on = Column(DateTime, nullable=False) - - -class Storage(Base): - __tablename__ = 'storages' - storage_id = Column(Integer, primary_key=True) - name = Column(String(255)) - host = Column(String(20)) - protocol = Column(String(10)) - port = Column(Integer) - - -class DBManager: - def __init__( - self, - database: str = 'dds', - host: str = 'db', - port: int = 5432, - user: str = 'dds', - password: str = 'dds' - ) -> None: - url = f'postgresql://{user}:{password}@{host}:{port}/{database}' - self.__engine = engine = create_engine(url, echo=True) - self.__session_maker = sessionmaker(bind=engine) - Base.metadata.create_all(engine) - - def create_request( - self, - user_id: int = 1, - dataset: str | None = None, - product: str | None = None, - query: str | None = None, - worker_id: int | None = None, - priority: str | None = None, - estimate_bytes_size: int | None = None, - download_id: int | None = None, - status: RequestStatus = RequestStatus.PENDING, - ) -> int: - # TODO: Add more request-related parameters to this method. - with self.__session_maker() as session: - request = Request( - status=status, - priority=priority, - user_id=user_id, - worker_id=worker_id, - dataset=dataset, - product=product, - query=query, - estimate_bytes_size=estimate_bytes_size, - download_id=download_id, - created_on=datetime.utcnow() - ) - session.add(request) - session.commit() - return request.request_id - - def update_request( - self, - request_id: int, - worker_id: int, - status: RequestStatus - ) -> int: - with self.__session_maker() as session: - request = session.query(Request).get(request_id) - request.status = status - request.worker_id = worker_id - request.last_update = datetime.utcnow() - session.commit() - return request.request_id - - def get_request_status( - self, - request_id - ) -> RequestStatus: - with self.__session_maker() as session: - request = session.query(Request).get(request_id) - return request.status - - def create_worker( - self, - status: str, - dask_scheduler_port: int, - dask_dashboard_address: int, - host: str = 'localhost' - ) -> int: - with self.__session_maker() as session: - worker = Worker( - status=status, - host=host, - dask_scheduler_port=dask_scheduler_port, - dask_dashboard_address=dask_dashboard_address, - created_on=datetime.utcnow() - ) - session.add(worker) - session.commit() - return worker.worker_id diff --git a/db/scripts/1-init.sql b/db/scripts/1-init.sql deleted file mode 100644 index fafd908..0000000 --- a/db/scripts/1-init.sql +++ /dev/null @@ -1,66 +0,0 @@ --- CREATE USER dds WITH PASSWORD 'dds'; --- CREATE DATABASE dds; --- GRANT ALL PRIVILEGES ON DATABASE dds TO dds; - -CREATE TABLE IF NOT EXISTS roles ( - role_id SERIAL PRIMARY KEY, - role_name VARCHAR (255) UNIQUE NOT NULL -); - -CREATE TABLE IF NOT EXISTS users ( - user_id SERIAL PRIMARY KEY, - keycloak_id INT UNIQUE NOT NULL, - api_key VARCHAR(255) UNIQUE NOT NULL, - contact_name VARCHAR(255), - role_id INT, - CONSTRAINT fk_role - FOREIGN KEY(role_id) - REFERENCES roles(role_id) -); - -CREATE TABLE IF NOT EXISTS workers ( - worker_id SERIAL PRIMARY KEY, - status VARCHAR(255) NOT NULL, - host VARCHAR(255), - dask_scheduler_port INT, - dask_dashboard_address CHAR(10), - created_on TIMESTAMP NOT NULL -); - -CREATE TABLE IF NOT EXISTS requests ( - request_id SERIAL PRIMARY KEY, - status VARCHAR(255) NOT NULL, - priority 
INT, - user_id INT NOT NULL, - worker_id INT, - dataset VARCHAR(255), - product VARCHAR(255), - query json, - estimate_bytes_size INT, - download_id INT UNIQUE, - created_on TIMESTAMP NOT NULL, - last_update TIMESTAMP, - CONSTRAINT fk_user - FOREIGN KEY(user_id) - REFERENCES users(user_id), - CONSTRAINT fk_worker - FOREIGN KEY(worker_id) - REFERENCES workers(worker_id) -); - -CREATE TABLE IF NOT EXISTS downloads ( - download_id SERIAL PRIMARY KEY, - download_uri VARCHAR(255), - storage_id INT, - location_path VARCHAR(255), - bytes_size INT, - created_on TIMESTAMP NOT NULL -); - -CREATE TABLE IF NOT EXISTS storages ( - storage_id SERIAL PRIMARY KEY, - name VARCHAR(255), - host VARCHAR(20), - protocol VARCHAR(10), - port INT -); \ No newline at end of file diff --git a/db/scripts/2-populate.sql b/db/scripts/2-populate.sql deleted file mode 100644 index 1406ff9..0000000 --- a/db/scripts/2-populate.sql +++ /dev/null @@ -1,2 +0,0 @@ -INSERT INTO roles VALUES (1, 'internal'); -INSERT INTO users VALUES (1, '1234', '1234:1234', 'Mario Rossi', 1); \ No newline at end of file diff --git a/drivers/Dockerfile b/drivers/Dockerfile new file mode 100644 index 0000000..4980d28 --- /dev/null +++ b/drivers/Dockerfile @@ -0,0 +1,8 @@ +ARG REGISTRY=rg.nl-ams.scw.cloud/geokube-production +ARG TAG=latest +FROM $REGISTRY/geokube:$TAG +RUN conda install -c conda-forge --yes --freeze-installed intake=0.6.6 +RUN conda clean -afy +COPY dist/geolake_drivers-1.0b0-py3-none-any.whl / +RUN pip install /geolake_drivers-1.0b0-py3-none-any.whl +RUN rm /geolake_drivers-1.0b0-py3-none-any.whl diff --git a/drivers/LICENSE b/drivers/LICENSE new file mode 100644 index 0000000..2b65938 --- /dev/null +++ b/drivers/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/drivers/Makefile b/drivers/Makefile new file mode 100644 index 0000000..12a2661 --- /dev/null +++ b/drivers/Makefile @@ -0,0 +1,21 @@ +.PHONY: typehint +typehint: + mypy --ignore-missing-imports --check-untyped-defs intake_geokube + pylint intake_geokube + +.PHONY: test +test: + pytest tests/ + +.PHONY: format +format: + isort intake_geokube + black intake_geokube + black tests/ + isort tests/ + +.PHONY: docs +docs: + pydocstyle -e --convention=numpy intake_geokube + +prepublish: format typehint docs test diff --git a/drivers/README.md b/drivers/README.md new file mode 100644 index 0000000..ed98e22 --- /dev/null +++ b/drivers/README.md @@ -0,0 +1,2 @@ +# geolake-drivers +GeoKube plugin for Intake \ No newline at end of file diff --git a/drivers/intake_geokube/__init__.py b/drivers/intake_geokube/__init__.py new file mode 100644 index 0000000..95b5503 --- /dev/null +++ b/drivers/intake_geokube/__init__.py @@ -0,0 +1,6 @@ +"""Geokube Plugin for Intake.""" + +# This avoids a circular dependency pitfall by ensuring that the +# driver-discovery code runs first, see: +# https://intake.readthedocs.io/en/latest/making-plugins.html#entrypoints +from .queries.geoquery import GeoQuery diff --git a/drivers/intake_geokube/base.py b/drivers/intake_geokube/base.py new file mode 100644 index 0000000..e070427 --- /dev/null +++ b/drivers/intake_geokube/base.py @@ -0,0 +1,132 @@ +"""Module with AbstractBaseDriver definition.""" + +import logging +import os +from abc import ABC, abstractmethod +from typing import Any + +from dask.delayed import Delayed +from geokube.core.datacube import DataCube +from geokube.core.dataset import Dataset +from intake.source.base import DataSourceBase + +from .queries.geoquery import GeoQuery + +_NOT_SET: str = "" + + +class AbstractBaseDriver(ABC, DataSourceBase): + """Abstract base class for all DDS-related drivers.""" + + name: str = _NOT_SET + version: str = _NOT_SET + container: str = "python" + log: logging.Logger + + def __new__(cls, *arr, **kw): # pylint: disable=unused-argument + """Create a new instance of driver and configure logger.""" + obj = super().__new__(cls) + assert ( + obj.name != _NOT_SET + ), f"'name' class attribute was not set for the driver '{cls}'" + assert ( + obj.version != _NOT_SET + ), f"'version' class attribute was not set for the driver '{cls}'" + obj.log = cls.__configure_logger() + return obj + + def __init__(self, *, metadata: dict) -> None: + super().__init__(metadata=metadata) + + @classmethod + def __configure_logger(cls) -> logging.Logger: + log = logging.getLogger(f"dds.intake.{cls.__name__}") + level = os.environ.get("DDS_LOG_LEVEL", "INFO") + logformat = os.environ.get( + "DDS_LOG_FORMAT", + "%(asctime)s %(name)s %(funcName)s %(levelname)s %(message)s", + ) + log.setLevel(level) # type: ignore[arg-type] + for handler in log.handlers: + if isinstance(handler, logging.StreamHandler): + break + else: + log.addHandler(logging.StreamHandler()) + if logformat: + formatter = logging.Formatter(logformat) + for handler in log.handlers: + handler.setFormatter(formatter) + for handler in log.handlers: + handler.setLevel(level) # type: ignore[arg-type] + return log + + @abstractmethod + def read(self) -> Any: + """Read metadata.""" + raise NotImplementedError + + @abstractmethod + def load(self) -> Any: + """Read metadata and load data into the memory.""" + raise NotImplementedError + + def process(self, query: GeoQuery) -> Any: + """ + Process data with the query.
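+ + By default, the source is read with `read` and the query is applied via `_process_geokube_dataset` with `compute=True`.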
+ + Parameters + ---------- + query: GeoQuery + A query to use for data processing + + Returns + ------- + res: Any + Result of `query` processing + """ + data_ = self.read() + return self._process_geokube_dataset(data_, query=query, compute=True) + + def _process_geokube_dataset( + self, + dataset: Dataset | DataCube, + query: GeoQuery, + compute: bool = False, + ) -> Dataset | DataCube: + self.log.info( + "processing geokube structure with GeoQuery: %s", query + ) + if not query: + self.log.info("query is empty!") + return dataset.compute() if compute else dataset + if isinstance(dataset, Dataset): + self.log.info("filtering with: %s", query.filters) + dataset = dataset.filter(**query.filters) + if isinstance(dataset, Delayed) and compute: + dataset = dataset.compute() + if query.variable: + self.log.info("selecting variable: %s", query.variable) + dataset = dataset[query.variable] + if query.area: + self.log.info("subsetting by bounding box: %s", query.area) + dataset = dataset.geobbox(**query.area) + if query.location: + self.log.info("subsetting by location: %s", query.location) + dataset = dataset.locations(**query.location) + if query.time: + self.log.info("subsetting by time: %s", query.time) + dataset = dataset.sel(time=query.time) + if query.vertical: + self.log.info("subsetting by vertical: %s", query.vertical) + method = None if isinstance(query.vertical, slice) else "nearest" + dataset = dataset.sel(vertical=query.vertical, method=method) + if isinstance(dataset, Dataset) and compute: + self.log.info( + "computing delayed datacubes in the dataset with %d" + " records...", + len(dataset), + ) + dataset = dataset.apply( + lambda dc: dc.compute() if isinstance(dc, Delayed) else dc + ) + return dataset diff --git a/drivers/intake_geokube/builders/__init__.py b/drivers/intake_geokube/builders/__init__.py new file mode 100644 index 0000000..0b7eded --- /dev/null +++ b/drivers/intake_geokube/builders/__init__.py @@ -0,0 +1 @@ +"""Subpackage with builders.""" diff --git a/drivers/intake_geokube/iot/__init__.py b/drivers/intake_geokube/iot/__init__.py new file mode 100644 index 0000000..5500b37 --- /dev/null +++ b/drivers/intake_geokube/iot/__init__.py @@ -0,0 +1 @@ +"""Domain-specific subpackage for IoT data.""" diff --git a/drivers/intake_geokube/iot/driver.py b/drivers/intake_geokube/iot/driver.py new file mode 100644 index 0000000..93c52cd --- /dev/null +++ b/drivers/intake_geokube/iot/driver.py @@ -0,0 +1,164 @@ +"""Driver for IoT data.""" + +import json +from collections import deque +from datetime import datetime +from typing import NoReturn + +import dateparser +import numpy as np +import pandas as pd +import streamz + +from ..base import AbstractBaseDriver +from ..queries.geoquery import GeoQuery + +d: deque = deque(maxlen=1) + + +def _build(data_model: dict) -> pd.DataFrame: + model_dict = { + data_model.get("time", ""): pd.to_datetime( + "01-01-1970 00:00:00", format="%d-%m-%Y %H:%M:%S" + ), + data_model.get("latitude", ""): [0.0], + data_model.get("longitude", ""): [0.0], + } + for f in data_model.get("filters", []): + model_dict[f] = [0] + for v in data_model.get("variables", []): + model_dict[v] = [0] + df_model = pd.DataFrame(model_dict) + df_model = df_model.set_index(data_model.get("time", "")) + return df_model + + +def _mqtt_preprocess(df, msg) -> pd.DataFrame: + payload = json.loads(msg.payload.decode("utf-8")) + if ("uplink_message" not in payload) or ( + "frm_payload" not in payload["uplink_message"] + ): + return df + data =
payload["uplink_message"]["decoded_payload"]["data_packet"][ + "measures" + ] + date_time = pd.to_datetime( + datetime.now().strftime("%d-%m-%Y %H:%M:%S"), + format="%d-%m-%Y %H:%M:%S", + ) + data["device_id"] = payload["end_device_ids"]["device_id"] + data["string_type"] = 9 + data["cycle_duration"] = payload["uplink_message"]["decoded_payload"][ + "data_packet" + ]["timestamp"] + data["sensor_time"] = pd.to_datetime( + payload["received_at"], format="%Y-%m-%dT%H:%M:%S.%fZ" + ) + data["latitude"] = data["latitude"] / 10**7 + data["longitude"] = data["longitude"] / 10**7 + data["AirT"] = data["AirT"] / 100 + data["AirH"] = data["AirH"] / 100 + data["surfaceTemp"] = 2840 / 100 + row = pd.Series(data, name=date_time) + df = df._append(row) # pylint: disable=protected-access + return df + + +class IotDriver(AbstractBaseDriver): + """Driver class for IoT data.""" + + name: str = "iot_driver" + version: str = "0.1b0" + + def __init__( + self, + mqtt_kwargs, + time_window, + data_model, + start=False, + metadata=None, + **kwargs, + ): + super().__init__(metadata=metadata) + self.mqtt_kwargs = mqtt_kwargs + self.kwargs = kwargs + self.stream = None + self.time_window = time_window + self.start = start + self.df_model = _build(data_model) + + def _get_schema(self): + if not self.stream: + self.log.debug("creating stream...") + stream = streamz.Stream.from_mqtt(**self.mqtt_kwargs) + self.stream = stream.accumulate( + _mqtt_preprocess, returns_state=False, start=pd.DataFrame() + ).to_dataframe(example=self.df_model) + self.stream.stream.sink(d.append) + if self.start: + self.log.info("streaming started...") + self.stream.start() + return {"stream": str(self.stream)} + + def read(self) -> streamz.dataframe.core.DataFrame: + """Read IoT data.""" + self.log.info("reading stream...") + self._get_schema() + return self.stream + + def load(self) -> NoReturn: + """Load IoT data.""" + self.log.error("loading entire product is not supported for IoT data") + raise NotImplementedError( + "loading entire product is not supported for IoT data" + ) + + def process(self, query: GeoQuery) -> streamz.dataframe.core.DataFrame: + """Process IoT data with the passed query. + + Parameters + ---------- + query : intake_geokube.GeoQuery + A query to use + + Returns + ------- + stream : streamz.dataframe.core.DataFrame + A DataFrame object with streamed content + """ + df = d[0] + if not query: + self.log.info( + "method 'process' called without query. processing skipped." 
+ ) + return df + if query.time: + if not isinstance(query.time, slice): + self.log.error( + "expected 'query.time' type is slice but found %s", + type(query.time), + ) + raise TypeError( + "expected 'query.time' type is slice but found" + f" {type(query.time)}" + ) + self.log.info("querying by time: %s", query.time) + df = df[query.time.start : query.time.stop] + else: + self.log.info( + "getting latest data for the predefined time window: %s", + self.time_window, + ) + start = dateparser.parse(f"NOW - {self.time_window}") + stop = dateparser.parse("NOW") + df = df[start:stop] # type: ignore[misc] + if query.filters: + self.log.info("filtering with: %s", query.filters) + mask = np.logical_and.reduce( + [df[k] == v for k, v in query.filters.items()] + ) + df = df[mask] + if query.variable: + self.log.info("selecting variables: %s", query.variable) + df = df[query.variable] + return df diff --git a/drivers/intake_geokube/netcdf/__init__.py b/drivers/intake_geokube/netcdf/__init__.py new file mode 100644 index 0000000..315792c --- /dev/null +++ b/drivers/intake_geokube/netcdf/__init__.py @@ -0,0 +1 @@ +"""Domain-specific subpackage for netcdf data.""" diff --git a/drivers/intake_geokube/netcdf/driver.py b/drivers/intake_geokube/netcdf/driver.py new file mode 100644 index 0000000..e29cbfa --- /dev/null +++ b/drivers/intake_geokube/netcdf/driver.py @@ -0,0 +1,64 @@ +"""NetCDF driver for DDS.""" + +from geokube import open_datacube, open_dataset +from geokube.core.datacube import DataCube +from geokube.core.dataset import Dataset + +from ..base import AbstractBaseDriver + + +class NetCdfDriver(AbstractBaseDriver): + """Driver class for netCDF files.""" + + name = "netcdf_driver" + version = "0.1a0" + + def __init__( + self, + path: str, + metadata: dict, + pattern: str | None = None, + field_id: str | None = None, + metadata_caching: bool = False, + metadata_cache_path: str | None = None, + storage_options: dict | None = None, + xarray_kwargs: dict | None = None, + mapping: dict[str, dict[str, str]] | None = None, + load_files_on_persistance: bool = True, + ) -> None: + super().__init__(metadata=metadata) + self.path = path + self.pattern = pattern + self.field_id = field_id + self.metadata_caching = metadata_caching + self.metadata_cache_path = metadata_cache_path + self.storage_options = storage_options + self.mapping = mapping + self.xarray_kwargs = xarray_kwargs or {} + self.load_files_on_persistance = load_files_on_persistance + + @property + def _arguments(self) -> dict: + return { + "path": self.path, + "id_pattern": self.field_id, + "metadata_caching": self.metadata_caching, + "metadata_cache_path": self.metadata_cache_path, + "mapping": self.mapping, + } | self.xarray_kwargs + + def read(self) -> Dataset | DataCube: + """Read netCDF.""" + if self.pattern: + return open_dataset( + pattern=self.pattern, delay_read_cubes=True, **self._arguments + ) + return open_datacube(**self._arguments) + + def load(self) -> Dataset | DataCube: + """Load netCDF.""" + if self.pattern: + return open_dataset( + pattern=self.pattern, delay_read_cubes=False, **self._arguments + ) + return open_datacube(**self._arguments) diff --git a/drivers/intake_geokube/queries/__init__.py b/drivers/intake_geokube/queries/__init__.py new file mode 100644 index 0000000..e6847fb --- /dev/null +++ b/drivers/intake_geokube/queries/__init__.py @@ -0,0 +1 @@ +"""Subpackage with queries.""" diff --git a/drivers/intake_geokube/queries/geoquery.py b/drivers/intake_geokube/queries/geoquery.py new file mode 100644 index
0000000..9ab408a --- /dev/null +++ b/drivers/intake_geokube/queries/geoquery.py @@ -0,0 +1,94 @@ +"""Module with GeoQuery definition.""" + +from __future__ import annotations + +import json +from typing import Any + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_serializer, + model_validator, +) + +from .types import BoundingBoxDict, SliceQuery, TimeComboDict +from .utils import maybe_dict_to_slice, slice_to_dict + + +class GeoQuery(BaseModel, extra="allow"): + """GeoQuery definition class.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + variable: list[str] | None = None + time: SliceQuery | TimeComboDict | None = None + area: BoundingBoxDict | None = None + location: dict[str, float | list[float]] | None = None + vertical: SliceQuery | float | list[float] | None = None + filters: dict[str, Any] = Field(default_factory=dict) + format: str | None = None + format_args: dict[str, Any] | None = None + + @field_serializer("time") + def serialize_time(self, time: SliceQuery | TimeComboDict | None, _info): + """Serialize time.""" + if isinstance(time, slice): + return slice_to_dict(time) + return time + + @model_validator(mode="after") + @classmethod + def area_locations_mutually_exclusive_validator(cls, query): + """Assert 'locations' and 'area' are not passed at once.""" + if query.area is not None and query.location is not None: + raise KeyError( + "area and location couldn't be processed together, please use" + " one of them" + ) + return query + + @model_validator(mode="before") + @classmethod + def build_filters(cls, values: dict[str, Any]) -> dict[str, Any]: + """Build filters based on extra arguments.""" + if "filters" in values: + return values + filters = {} + fields = {} + for k in values.keys(): + if k in cls.model_fields: + fields[k] = values[k] + continue + if isinstance(values[k], dict): + values[k] = maybe_dict_to_slice(values[k]) + filters[k] = values[k] + fields["filters"] = filters + return fields + + def model_dump_original(self, skip_empty: bool = True) -> dict: + """Return the JSON representation of the original query.""" + res = super().model_dump() + res = {**res.pop("filters", {}), **res} + if skip_empty: + res = dict(filter(lambda item: item[1] is not None, res.items())) + return res + + @classmethod + def parse( + cls, load: "GeoQuery" | dict | str | bytes | bytearray + ) -> "GeoQuery": + """Parse load to GeoQuery instance.""" + if isinstance(load, cls): + return load + if isinstance(load, (str, bytes, bytearray)): + load = json.loads(load) + if isinstance(load, dict): + load = GeoQuery(**load) + else: + raise TypeError( + f"type of the `load` argument ({type(load).__name__}) is not" + " supported!" + ) + return load diff --git a/drivers/intake_geokube/queries/types.py b/drivers/intake_geokube/queries/types.py new file mode 100644 index 0000000..cfb7327 --- /dev/null +++ b/drivers/intake_geokube/queries/types.py @@ -0,0 +1,10 @@ +"""Module with types definitions.""" + +from pydantic import BeforeValidator +from typing_extensions import Annotated + +from . 
import utils as ut + +SliceQuery = Annotated[slice, BeforeValidator(ut.dict_to_slice)] +TimeComboDict = Annotated[dict, BeforeValidator(ut.assert_time_combo_dict)] +BoundingBoxDict = Annotated[dict, BeforeValidator(ut.assert_bounding_box_dict)] diff --git a/drivers/intake_geokube/queries/utils.py b/drivers/intake_geokube/queries/utils.py new file mode 100644 index 0000000..c2fb2dd --- /dev/null +++ b/drivers/intake_geokube/queries/utils.py @@ -0,0 +1,106 @@ +"""Module with util functions.""" + +from typing import Any, Collection, Hashable, Iterable + +import dateparser +from pydantic.fields import FieldInfo + +_TIME_COMBO_SUPPORTED_KEYS: tuple[str, ...] = ( + "year", + "month", + "day", + "hour", +) + +_BBOX_SUPPORTED_KEYS: tuple[str, ...] = ( + "north", + "south", + "west", + "east", +) + + +def _validate_dict_keys( + provided_keys: Iterable, supported_keys: Collection +) -> None: + for provided_k in provided_keys: + assert ( + provided_k in supported_keys + ), f"key '{provided_k}' is not among supported ones: {supported_keys}" + + +def dict_to_slice(mapping: dict) -> slice: + """Convert dictionary to slice.""" + mapping = mapping or {} + assert "start" in mapping or "stop" in mapping, ( + "missing at least one of the keys ['start', 'stop'] required to" + " construct a slice object based on the dictionary" + ) + if "start" in mapping and "NOW" in mapping["start"]: + mapping["start"] = dateparser.parse(mapping["start"]) + if "stop" in mapping and "NOW" in mapping["stop"]: + mapping["stop"] = dateparser.parse(mapping["stop"]) + return slice( + mapping.get("start"), + mapping.get("stop"), + mapping.get("step"), + ) + + +def maybe_dict_to_slice(mapping: Any) -> slice: + """Convert valid dictionary to slice or return the original one.""" + if "start" in mapping or "stop" in mapping: + return dict_to_slice(mapping) + return mapping + + +def slice_to_dict(slice_: slice) -> dict: + """Convert slice to dictionary.""" + return {"start": slice_.start, "stop": slice_.stop, "step": slice_.step} + + +def assert_time_combo_dict(mapping: dict) -> dict: + """Check if dictionary contains time-combo related keys.""" + _validate_dict_keys(mapping.keys(), _TIME_COMBO_SUPPORTED_KEYS) + return mapping + + +def assert_bounding_box_dict(mapping: dict) -> dict: + """Check if dictionary contains bounding-box related keys.""" + _validate_dict_keys(mapping.keys(), _BBOX_SUPPORTED_KEYS) + return mapping + + +def split_extra_arguments( + values: dict, fields: dict[str, FieldInfo] +) -> tuple[dict, dict]: + """Split arguments to field-related and auxiliary.""" + extra_args: dict = {} + field_args: dict = {} + extra_args = {k: v for k, v in values.items() if k not in fields} + field_args = {k: v for k, v in values.items() if k in fields} + return (field_args, extra_args) + + +def find_value( + content: dict | list, key: Hashable, *, recursive: bool = False +) -> Any: + """Return value for a 'key' (recursive search).""" + result = None + if isinstance(content, dict): + if key in content: + return content[key] + if not recursive: + return result + for value in content.values(): + if isinstance(value, (dict, list)): + result = result or find_value(value, key, recursive=True) + elif isinstance(content, list): + for el in content: + result = result or find_value(el, key, recursive=True) + else: + raise TypeError( + "'content' argument needs to be a dictionary or a list but found" + f" '{type(content)}'" + ) + return result diff --git a/drivers/intake_geokube/queries/workflow.py b/drivers/intake_geokube/queries/workflow.py new
file mode 100644 index 0000000..a93cd91 --- /dev/null +++ b/drivers/intake_geokube/queries/workflow.py @@ -0,0 +1,72 @@ +"""Module with workflow definition.""" + +from __future__ import annotations + +import json +from collections import Counter +from typing import Any + +from pydantic import BaseModel, Field, field_validator, model_validator + +from .utils import find_value + + +class Task(BaseModel): + """Single task model definition.""" + + id: str | int + op: str + use: list[str | int] = Field(default_factory=list) + args: dict[str, Any] = Field(default_factory=dict) + + +class Workflow(BaseModel): + """Workflow model definition.""" + + tasks: list[Task] + dataset_id: str = "" + product_id: str = "" + + @model_validator(mode="before") + @classmethod + def obtain_dataset_id(cls, values): + """Get dataset_id and product_id from included tasks.""" + dataset_id = find_value(values, key="dataset_id", recursive=True) + if not dataset_id: + raise KeyError( + "'dataset_id' key was missing. did you define it in 'args'?" + ) + product_id = find_value(values, key="product_id", recursive=True) + if not product_id: + raise KeyError( + "'product_id' key was missing. did you define it in 'args'?" + ) + return values | {"dataset_id": dataset_id, "product_id": product_id} + + @field_validator("tasks", mode="after") + @classmethod + def match_unique_ids(cls, items): + """Verify the IDs are unique.""" + for id_value, id_count in Counter([item.id for item in items]).items(): + if id_count != 1: + raise ValueError(f"duplicated key found: `{id_value}`") + return items + + @classmethod + def parse( + cls, + workflow: Workflow | dict | list[dict] | str | bytes | bytearray, + ) -> Workflow: + """Parse to Workflow model.""" + if isinstance(workflow, cls): + return workflow + if isinstance(workflow, (str | bytes | bytearray)): + workflow = json.loads(workflow) + if isinstance(workflow, list): + return cls(tasks=workflow) # type: ignore[arg-type] + if isinstance(workflow, dict): + return cls(**workflow) + raise TypeError( + f"`workflow` argument of type `{type(workflow).__name__}`" + " cannot be safely parsed to the `Workflow`" + ) diff --git a/drivers/intake_geokube/sentinel/__init__.py b/drivers/intake_geokube/sentinel/__init__.py new file mode 100644 index 0000000..4957128 --- /dev/null +++ b/drivers/intake_geokube/sentinel/__init__.py @@ -0,0 +1 @@ +"""Domain-specific subpackage for sentinel data.""" diff --git a/drivers/intake_geokube/sentinel/auth.py b/drivers/intake_geokube/sentinel/auth.py new file mode 100644 index 0000000..680bfb2 --- /dev/null +++ b/drivers/intake_geokube/sentinel/auth.py @@ -0,0 +1,45 @@ +"""Module with auth utils for accessing sentinel data.""" + +import os + +import requests +from requests.auth import AuthBase + + +class SentinelAuth(AuthBase): # pylint: disable=too-few-public-methods + """Class with authentication for accessing sentinel data.""" + + _SENTINEL_AUTH_URL: str = os.environ.get( + "SENTINEL_AUTH_URL", + "https://identity.dataspace.copernicus.eu/auth/realms/CDSE/protocol/openid-connect/token", + ) + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + + @classmethod + def _get_access_token(cls, username: str, password: str) -> str: + data = { + "client_id": "cdse-public", + "username": username, + "password": password, + "grant_type": "password", + } + try: + response = requests.post( + cls._SENTINEL_AUTH_URL, data=data, timeout=10 + ) + response.raise_for_status() + except Exception as e: + raise
RuntimeError( + "Access token creation failed. Reponse from the server was:" + f" {response.json()}" + ) from e + return response.json()["access_token"] + + def __call__(self, request): + """Add authorization header.""" + token: str = self._get_access_token(self.username, self.password) + request.headers["Authorization"] = f"Bearer {token}" + return request diff --git a/drivers/intake_geokube/sentinel/driver.py b/drivers/intake_geokube/sentinel/driver.py new file mode 100644 index 0000000..4895103 --- /dev/null +++ b/drivers/intake_geokube/sentinel/driver.py @@ -0,0 +1,342 @@ +"""Geokube driver for sentinel data.""" + +import glob +import os +import string +import zipfile +from multiprocessing.util import get_temp_dir +from typing import Collection, NoReturn + +import dask +import numpy as np +import pandas as pd +import xarray as xr +from geokube.backend.netcdf import open_datacube +from geokube.core.dataset import Dataset +from intake.source.utils import reverse_format +from pyproj import Transformer +from pyproj.crs import CRS, GeographicCRS + +from ..base import AbstractBaseDriver +from ..queries.geoquery import GeoQuery +from ..queries.types import BoundingBoxDict, TimeComboDict +from .auth import SentinelAuth +from .odata_builder import ODataRequest, ODataRequestBuilder + + +def _get_items_nbr(mapping, key) -> int: + if isinstance(mapping[key], str): + return 1 + return len(mapping[key]) if isinstance(mapping[key], Collection) else 1 + + +def _validate_geoquery_for_sentinel(query: GeoQuery) -> None: + if query.time: + if isinstance(query.time, dict) and any([ + _get_items_nbr(query.time, "year") != 1, + _get_items_nbr(query.time, "month") != 1, + _get_items_nbr(query.time, "day") != 1, + ]): + raise ValueError( + "valid time combo for sentinel data should contain exactly one" + " value for 'year', one for 'month', and one for 'day'" + ) + if query.location and ( + "latitude" not in query.location or "longitude" not in query.location + ): + raise ValueError( + "both 'latitude' and 'longitude' must be defined for location" + ) + + +def _bounding_box_to_polygon( + bbox: BoundingBoxDict, +) -> list[tuple[float, float]]: + return [ + (bbox["north"], bbox["west"]), + (bbox["north"], bbox["east"]), + (bbox["south"], bbox["east"]), + (bbox["south"], bbox["west"]), + (bbox["north"], bbox["west"]), + ] + + +def _timecombo_to_day_range(combo: TimeComboDict) -> tuple[str, str]: + return (f"{combo['year']}-{combo['month']}-{combo['day']}T00:00:00", + f"{combo['year']}-{combo['month']}-{combo['day']}T23:59:59") + + +def _location_to_valid_point( + location: dict[str, float | list[float]] +) -> tuple[float, float]: + if isinstance(location["latitude"], list): + if len(location["latitude"]) > 1: + raise ValueError( + "location can have just a single point (single value for" + " 'latitude' and 'longitude')" + ) + lat = location["latitude"][0] + else: + lat = location["latitude"] + if isinstance(location["longitude"], list): + if len(location["longitude"]) > 1: + raise ValueError( + "location can have just a single point (single value for" + " 'latitude' and 'longitude')" + ) + lon = location["longitude"][0] + else: + lon = location["longitude"] + return (lat, lon) + + +def _validate_path_and_pattern(path: str, pattern: str): + if path.startswith(os.sep) or pattern.startswith(os.sep): + raise ValueError(f"path and pattern cannot start with {os.sep}") + + +def _get_attrs_keys_from_pattern(pattern: str) -> list[str]: + return list( + map( + lambda x: str(x[1]), + filter(lambda x: x[1], 
string.Formatter().parse(pattern)), + ) + ) + + +def unzip_and_clear(target: str) -> None: + """Unzip ZIP archives in 'target' dir and remove archive.""" + assert os.path.exists(target), f"directory '{target}' does not exist" + for file in os.listdir(target): + if not file.endswith(".zip"): + continue + prod_id = os.path.splitext(os.path.basename(file))[0] + target_prod = os.path.join(target, prod_id) + os.makedirs(target_prod, exist_ok=True) + try: + with zipfile.ZipFile(os.path.join(target, file)) as archive: + archive.extractall(path=target_prod) + except zipfile.BadZipFile as err: + raise RuntimeError("downloaded ZIP archive is invalid") from err + os.remove(os.path.join(target, file)) + + +def _get_field_name_from_path(path: str): + res, file = path.split(os.sep)[-2:] + band = file.split("_")[-2] + return f"{res}_{band}" + + +def preprocess_sentinel(dset: xr.Dataset) -> xr.Dataset: + """Preprocessing function for sentinel data. + + Parameters + ---------- + dset : xarray.Dataset + xarray.Dataset to be preprocessed + + Returns + ------- + ds : xarray.Dataset + Preprocessed xarray.Dataset + """ + crs = CRS.from_cf(dset["spatial_ref"].attrs) + transformer = Transformer.from_crs( + crs_from=crs, crs_to=GeographicCRS(), always_xy=True + ) + x_vals, y_vals = dset["x"].to_numpy(), dset["y"].to_numpy() + lon_vals, lat_vals = transformer.transform(*np.meshgrid(x_vals, y_vals)) # type: ignore[call-overload] # pylint: disable=unpacking-non-sequence + source_path = dset.encoding["source"] + sensing_time = os.path.splitext(source_path.split(os.sep)[-6])[0].split( + "_" + )[-1] + time = pd.to_datetime([sensing_time]).to_numpy() + dset = dset.assign_coords({ + "time": time, + "latitude": (("x", "y"), lat_vals), + "longitude": (("x", "y"), lon_vals), + }).rename({"band_data": _get_field_name_from_path(source_path)}) + expanded_timedim_dataarrays = {var_name: dset[var_name].expand_dims('time') for var_name in dset.data_vars} + dset = dset.update(expanded_timedim_dataarrays) + return dset + + +class _SentinelKeys: # pylint: disable=too-few-public-methods + UUID: str = "Id" + SENSING_TIME: str = "ContentDate/Start" + TYPE: str = "Name" + + +class SentinelDriver(AbstractBaseDriver): + """Driver class for sentinel data.""" + + name: str = "sentinel_driver" + version: str = "0.1b0" + + def __init__( + self, + metadata: dict, + url: str, + zippattern: str, + zippath: str, + type: str, + username: str | None = None, + password: str | None = None, + sentinel_timeout: int | None = None, + mapping: dict | None = None, + xarray_kwargs: dict | None = None, + ) -> None: + super().__init__(metadata=metadata) + self.url: str = url + self.zippattern: str = zippattern + self.zippath: str = zippath + self.type_ = type + _validate_path_and_pattern(path=self.zippath, pattern=self.zippattern) + self.auth: SentinelAuth = self._get_credentials(username, password) + self.target_dir: str = get_temp_dir() + self.sentinel_timeout: int | None = sentinel_timeout + self.mapping: dict = mapping or {} + self.xarray_kwargs: dict = xarray_kwargs or {} + + def _get_credentials( + self, username: str | None, password: str | None + ) -> SentinelAuth: + if username and password: + return SentinelAuth( + username=username, + password=password, + ) + self.log.debug("getting credentials from environmental variables...") + if ( + "SENTINEL_USERNAME" not in os.environ + or "SENTINEL_PASSWORD" not in os.environ + ): + self.log.error( + "missing at least of of the mandatory environmental variables:" + " ['SENTINEL_USERNAME', 
'SENTINEL_PASSWORD']" + ) + raise KeyError( + "missing at least one of the mandatory environmental variables:" + " ['SENTINEL_USERNAME', 'SENTINEL_PASSWORD']" + ) + return SentinelAuth( + username=os.environ["SENTINEL_USERNAME"], + password=os.environ["SENTINEL_PASSWORD"], + ) + + def _force_sentinel_type(self, builder): + self.log.info("forcing sentinel type: %s...", self.type_) + return builder.filter(_SentinelKeys.TYPE, containing=self.type_) + + def _filter_by_sentinel_attrs(self, builder, query: GeoQuery): + self.log.info("filtering by sentinel attributes...") + path_filter_names: set[str] = { + parsed[1] + for parsed in string.Formatter().parse(self.zippattern) + if parsed[1] + } + if not query.filters: + return builder + sentinel_filter_names: set[str] = ( + query.filters.keys() - path_filter_names + ) + for sf in sentinel_filter_names: + builder = builder.filter_attr(sf, query.filters[sf]) + return builder + + def _build_odata_from_geoquery(self, query: GeoQuery) -> ODataRequest: + self.log.debug("validating geoquery...") + _validate_geoquery_for_sentinel(query) + self.log.debug("constructing odata request...") + builder = ODataRequestBuilder.new(url=self.url) + if "product_id" in query.filters: + builder = builder.filter( + name=_SentinelKeys.UUID, eq=query.filters.get("product_id") + ) + builder = self._filter_by_sentinel_attrs(builder, query=query) + builder = self._force_sentinel_type(builder) + if query.time: + if isinstance(query.time, dict): + timecombo_start, timecombo_end = _timecombo_to_day_range(query.time) + self.log.debug("filtering by timecombo: [%s, %s] ", timecombo_start, timecombo_end) + builder = builder.filter_date( + _SentinelKeys.SENSING_TIME, ge=timecombo_start, le=timecombo_end + ) + elif isinstance(query.time, slice): + self.log.debug("filtering by slice: %s", query.time) + builder = builder.filter_date( + _SentinelKeys.SENSING_TIME, + ge=query.time.start, + le=query.time.stop, + ) + if query.area: + self.log.debug("filtering by polygon") + polygon = _bounding_box_to_polygon(query.area) + builder = builder.intersect_polygon(polygon=polygon) + if query.location: + self.log.debug("filtering by location") + point = _location_to_valid_point(query.location) + builder = builder.intersect_point(point=point) + return builder.build() + + def _prepare_dataset(self) -> Dataset: + data: list = [] + attrs_keys: list[str] = _get_attrs_keys_from_pattern(self.zippattern) + for f in glob.glob(os.path.join(self.target_dir, self.zippath)): + self.log.debug("processing file %s", f) + file_no_tmp_dir = f.removeprefix(self.target_dir).strip(os.sep) + attr = reverse_format(self.zippattern, file_no_tmp_dir) + attr[Dataset.FILES_COL] = [f] + data.append(attr) + # NOTE: eventually, join files if there are several for the same attrs + # combination + df = ( + pd.DataFrame(data) + .groupby(attrs_keys) + .agg({Dataset.FILES_COL: sum}) + ) + datacubes = [] + for ind, files in df.iterrows(): + load = dict(zip(df.index.names, ind)) + load[Dataset.FILES_COL] = files + load[Dataset.DATACUBE_COL] = dask.delayed(open_datacube)( + path=files.item(), + id_pattern=None, + mapping=self.mapping, + metadata_caching=False, + **self.xarray_kwargs, + preprocess=preprocess_sentinel, + ) + datacubes.append(load) + return Dataset(pd.DataFrame(datacubes)) + + def read(self) -> NoReturn: + """Read sentinel data.""" + raise NotImplementedError( + "reading metadata is not supported for sentinel data" + ) + + def load(self) -> NoReturn: + """Load sentinel data.""" + raise NotImplementedError( + "loading entire
product is not supported for sentinel data" + ) + + def process(self, query: GeoQuery) -> Dataset: + """Process query for sentinel data.""" + self.log.info("building odata request based on the passed geoquery...") + req = self._build_odata_from_geoquery(query) + self.log.info("downloading data...") + req.download( + target_dir=self.target_dir, + auth=self.auth, + timeout=self.sentinel_timeout, + ) + self.log.info("unzipping and removing archives...") + unzip_and_clear(self.target_dir) + self.log.info("preparing geokube.Dataset...") + dataset = self._prepare_dataset() + dataset = super()._process_geokube_dataset( + dataset, query=query, compute=True + ) + return dataset diff --git a/drivers/intake_geokube/sentinel/odata_builder.py b/drivers/intake_geokube/sentinel/odata_builder.py new file mode 100644 index 0000000..4036810 --- /dev/null +++ b/drivers/intake_geokube/sentinel/odata_builder.py @@ -0,0 +1,564 @@ +"""Module with OData API classes definitions.""" + +from __future__ import annotations + +__all__ = ( + "datetime_to_isoformat", + "HttpMethod", + "ODataRequestBuilder", + "ODataRequest", +) + +import math +import os +import warnings +from collections import defaultdict +from datetime import datetime +from enum import Enum, auto +from typing import Any, Callable + +import pandas as pd +import requests +from tqdm import tqdm + +from ..utils import create_zip_from_response +from .auth import SentinelAuth + + +def datetime_to_isoformat(date: str | datetime) -> str: + """Convert string or datetime object to an ISO datetime string.""" + if isinstance(date, str): + try: + value = pd.to_datetime([date]).item().isoformat() + except ValueError as exc: + raise ValueError(f"cannot parse '{date}' to datetime") from exc + elif isinstance(date, datetime): + value = date.isoformat() + else: + raise TypeError(f"type '{type(date)}' is not supported") + return f"{value}Z" + + +class HttpMethod(Enum): + """Enum with HTTP methods.""" + + GET = auto() + POST = auto() + + @property + def method_name(self) -> str: + """Get name of the HTTP method.""" + return self.name.lower() + + +class _ODataEntity: # pylint: disable=too-few-public-methods + def __init__( + self, + url: str, + params: dict | None = None, + method: HttpMethod = HttpMethod.GET, + body: dict | None = None, + ) -> None: + if not params: + self.params: dict[str, list] = defaultdict(list) + self.conj: list = [] + if not body: + self.body: dict = {} + self.url = url + self.method = method + self.callbacks: dict = {} + + +class _ODataBuildableMixin: # pylint: disable=too-few-public-methods + odata: _ODataEntity + + def build(self) -> ODataRequest: + """Build ODataRequest object.""" + return ODataRequest(self.odata) + + +class _ODataOrderMixing: # pylint: disable=too-few-public-methods + odata: _ODataEntity + + def order(self, by: str, desc: bool = False) -> _ODataOperation: + """Add ordering option.
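+ + For example, `order(by="ContentDate/Start", desc=True)` adds `ContentDate/Start desc` to the `$orderby` query parameter of the resulting request.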
+ + Parameters + ---------- + by : str + A key by which ordering should be done + desc : bool + If descending order should be used + """ + order = "desc" if desc else "asc" + if "orderby" in self.odata.params: + raise ValueError( + f"ordering was already defined: {self.odata.params['orderby']}" + ) + self.odata.params["orderby"] = [f"{by} {order}"] + match self: + case _ODataOperation(): + return _ODataOperation(self.odata) + case _: + raise TypeError(f"unexpected type: {type(self)}") + + +class ODataRequest: + """OData request object.""" + + _ALL_HTTP_CODES: int = -1 + _DOWNLOAD_PATTERN: str = ( + "https://zipper.dataspace.copernicus.eu" + "/odata/v1/Products({pid})/$value" + ) + + def __init__(self, odata: _ODataEntity) -> None: + self.request_params: dict = {} + self.odata = odata + self._convert_filter_param() + self._convert_order_param() + + def _convert_order_param(self) -> None: + if self.odata.params["orderby"]: + self.request_params["orderby"] = self.odata.params["orderby"] + + def _convert_filter_param(self) -> None: + param: str = "" + for i in range(len(self.odata.params["filter"])): + if not param: + param = self.odata.params["filter"][i] + else: + param = f"{param} {self.odata.params['filter'][i]}" + if i < len(self.odata.params["filter"]) - 1: + param = f"{param} {self.odata.conj[i]}" + self.request_params["filter"] = param + + def _query( + self, + headers: dict | None = None, + auth: Any | None = None, + timeout: int | None = None, + ) -> requests.Response: + if self.odata.params and not self.odata.url.endswith("?"): + self.odata.url = f"{self.odata.url}?" + params = {} + if self.request_params: + params = { + f"${key}": value for key, value in self.request_params.items() + } + match self.odata.method: + case HttpMethod.GET: + return requests.get( + self.odata.url, + params=params, + headers=headers, + timeout=timeout, + ) + case HttpMethod.POST: + return requests.post( + self.odata.url, + data=self.odata.body, + auth=auth, + timeout=timeout, + ) + case _: + raise NotImplementedError( + f"method {self.odata.method} is not supported" + ) + + def with_callback( + self, + callback: Callable[[requests.Response], Any], + http_code: int | None = None, + ) -> "ODataRequest": + """ + Add callbacks for request response. + + Parameters + ---------- + callback : callable + A callback function taking just a single argument, + i.e `requests.Response` object + http_code : int + HTTP code for which callback should be used. + If not passed, callback will be executed for all codes. + """ + if http_code: + if http_code in self.odata.callbacks: + warnings.warn( + f"callback for HTTP code {http_code} will be overwritten" + ) + self.odata.callbacks[http_code] = callback + else: + self.odata.callbacks[self._ALL_HTTP_CODES] = callback + return self + + def query( + self, + headers: dict | None = None, + auth: Any | None = None, + timeout: int | None = None, + ) -> Any: + """Query data based on the built request. 
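+ + The response is dispatched to the callback registered for its HTTP status code, or to the catch-all callback if one was set with `with_callback`; otherwise the raw `requests.Response` object is returned.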
+ + Parameters + ---------- + headers : dict, optional + Headers passed to HTTP request + auth : Any, optional + Authorization object or tuple (,) for basic authentication + + Returns + ------- + res : Any + Value returned from the appropriate callback or `requests.Response` object otherwise + """ + response = self._query(headers=headers, auth=auth, timeout=timeout) + if response.status_code in self.odata.callbacks: + return self.odata.callbacks[response.status_code](response) + if self._ALL_HTTP_CODES in self.odata.callbacks: + return self.odata.callbacks[self._ALL_HTTP_CODES](response) + return response + + def download( + self, + target_dir: str, + headers: dict | None = None, + auth: Any | None = None, + timeout: int | None = None, + ) -> Any: + """Download requested data to `target_dir`. + + Parameters + ---------- + target_dir : str + Path to the directory where files should be downloaded + headers : dict, optional + Headers passed to HTTP request + auth : Any, optional + Authorization object or tuple (,) for basic + authentication + """ + os.makedirs(target_dir, exist_ok=True) + response = self._query(headers=headers, auth=auth, timeout=timeout) + response.raise_for_status() + if response.status_code in self.odata.callbacks: + self.odata.callbacks[response.status_code](response) + if self._ALL_HTTP_CODES in self.odata.callbacks: + self.odata.callbacks[self._ALL_HTTP_CODES](response) + df = pd.DataFrame(response.json()["value"]) + if len(df) == 0: + raise ValueError("no product found for the request") + if not isinstance(auth, SentinelAuth): + raise TypeError( + f"expected authentication of the type '{SentinelAuth}' but" + f" passed '{type(auth)}'" + ) + for pid in tqdm(df["Id"]): + response = requests.get( + self._DOWNLOAD_PATTERN.format(pid=pid), + stream=True, + auth=auth, + timeout=timeout, + ) + response.raise_for_status() + create_zip_from_response( + response, os.path.join(target_dir, f"{pid}.zip") + ) + + +class _ODataOperation(_ODataBuildableMixin, _ODataOrderMixing): + def __init__(self, odata: _ODataEntity) -> None: + self.odata = odata + + def _append_query_param(self, param: str | None) -> None: + if not param: + return + self.odata.params["filter"].append(param) + self.odata.conj.append("and") + + def _validate_args(self, lt, le, eq, ge, gt) -> None: + if eq: + if any(map(lambda x: x is not None, [lt, le, ge, gt])): + raise ValueError( + "cannot define extra operations for a single option if" + " `eq` is defined" + ) + if lt and le: + raise ValueError( + "cannot define both operations `lt` and `le` for a single" + " option" + ) + if gt and ge: + raise ValueError( + "cannot define both operations `gt` and `ge` for a single" + " option" + ) + + def and_(self) -> _ODataOperation: + """Put conjunctive conditions.""" + self.odata.conj[-1] = "and" + return self + + def or_(self) -> _ODataOperation: + """Put alternative conditions.""" + self.odata.conj[-1] = "or" + return self + + def filter_attr(self, name: str, value: str) -> _ODataOperation: + """Filter by attribute value. 
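+ + This adds an `Attributes/OData.CSC.ValueTypeAttribute/any(...)` filter, i.e. it matches products whose attribute `name` has the given `value`.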
+ + Parameters + ---------- + name : str + Name of an attribute + value : str + Value of the attribute + """ + param: str = ( + "Attributes/OData.CSC.ValueTypeAttribute/any(att:att/Name eq" + f" '{name}'" + + f" and att/OData.CSC.ValueTypeAttribute/Value eq '{value}')" + ) + self._append_query_param(param) + return self + + def filter( + self, + name: str, + *, + lt: str | None = None, + le: str | None = None, + eq: str | None = None, + ge: str | None = None, + gt: str | None = None, + containing: str | None = None, + not_containing: str | None = None, + ) -> _ODataOperation: + """Filter option by values. + + Add filter option to the request. Value of an option indicated by + the `name` argument will be checked against the given values or arguments. + You cannot specify both `lt` and `le` or `ge` and `gt`. + + Parameters + ---------- + lt : str, optional + value for `less than` comparison + le : str, optional + value for `less or equal` comparison + eq : str, optional + value for `equal` comparison + ge : str, optional + value for `greater or equal` comparison + gt : str, optional + value for `greater than` comparison + containing : str, optional + value to be contained + not_containing : str, optional + value not to be contained + """ + if not any([le, lt, eq, ge, gt, containing, not_containing]): + return self + self._validate_args(le=le, lt=lt, eq=eq, ge=ge, gt=gt) + build_: _ODataOperation = self + assert isinstance(build_, _ODataOperation), "unexpected type" + if lt: + build_ = build_.with_option_lt(name, lt).and_() + if le: + build_ = build_.with_option_le(name, le).and_() + if eq: + build_ = build_.with_option_equal(name, eq).and_() + if ge: + build_ = build_.with_option_ge(name, ge).and_() + if gt: + build_ = build_.with_option_gt(name, gt).and_() + if containing: + build_ = build_.with_option_containing(name, containing).and_() + if not_containing: + build_ = build_.with_option_not_containing( + name, not_containing + ).and_() + + return build_ + + def filter_date( + self, + name: str, + *, + lt: str | None = None, + le: str | None = None, + eq: str | None = None, + ge: str | None = None, + gt: str | None = None, + ) -> _ODataOperation: + """ + Filter datetime option by values. + + Add filter option to the request. Datetime values of an option + indicated by the `name` argument will be checked against the given + values or arguments. + Values of arguments will be automatically converted to ISO datetime + string format. + You cannot specify both `lt` and `le` or `ge` and `gt`.
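+ For example, a value such as "2023-01-01" is converted by `datetime_to_isoformat` to "2023-01-01T00:00:00Z" before being added to the filter.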
+ + Parameters + ---------- + lt : str, optional + value for `less than` comparison + le : str, optional + value for `less or equal` comparison + eq : str, optional + value for `equal` comparison + ge : str, optional + value for `greater or equal` comparison + gt : str, optional + value for `greater than` comparison + """ + if lt: + lt = datetime_to_isoformat(lt) + if le: + le = datetime_to_isoformat(le) + if eq: + eq = datetime_to_isoformat(eq) + if ge: + ge = datetime_to_isoformat(ge) + if gt: + gt = datetime_to_isoformat(gt) + return self.filter(name, lt=lt, le=le, eq=eq, ge=ge, gt=gt) + + def with_option_equal(self, name: str, value: str) -> "_ODataOperation": + """Add filtering by option `is equal`.""" + param: str = f"{name} eq '{value}'" + self._append_query_param(param) + return self + + def with_option_containing( + self, name: str, value: str + ) -> "_ODataOperation": + """Add filtering by option `containing`.""" + param: str = f"contains({name},'{value}')" + self._append_query_param(param) + return self + + def with_option_not_containing( + self, name: str, value: str + ) -> "_ODataOperation": + """Add filtering by option `not containing`.""" + param: str = f"not contains({name},'{value}')" + self._append_query_param(param) + return self + + def with_option_equal_list( + self, name: str, value: list[str] + ) -> "_ODataOperation": + """Add filtering by equality.""" + self.odata.body.update({"FilterProducts": [{name: v} for v in value]}) + self.odata.method = HttpMethod.POST + return self + + def with_option_lt(self, name: str, value: str) -> "_ODataOperation": + """Add filtering with `less than` option.""" + param: str = f"{name} lt {value}" + self._append_query_param(param) + return self + + def with_option_le(self, name: str, value: str) -> "_ODataOperation": + """Add filtering with `less or equal` option.""" + param: str = f"{name} le {value}" + self._append_query_param(param) + return self + + def with_option_gt(self, name: str, value: str) -> "_ODataOperation": + """Add filtering with `greater than` option.""" + param: str = f"{name} gt {value}" + self._append_query_param(param) + return self + + def with_option_ge(self, name: str, value: str) -> "_ODataOperation": + """Add filtering with `greater or equal` option.""" + param: str = f"{name} ge {value}" + self._append_query_param(param) + return self + + def intersect_polygon( + self, + polygon: list[tuple[float, float]] | list[list[float]], + srid: str | None = "4326", + ) -> "_ODataOperation": + """ + Add filtering by polygon intersection. + + Parameters + ---------- + polygon: list of 2-element tuple or 2-element lists of floats + Points belonging to the polygon [latitude, longitude]. + The first and the last point need to be the same (the polygon + needs to be closed) + srid : str, optional + SRID name, currently supported is only `4326` + """ + if srid != "4326": + raise NotImplementedError( + "currently supported SRID is only ['4326' (EPSG 4326)]" + ) + if not polygon: + return self + if any(map(lambda x: len(x) != 2, polygon)): + raise ValueError( + "polygon should be defined as a 2-element list or tuple" + " (containing latitude and longitude values)" + ) + if not math.isclose(polygon[0][0], polygon[-1][0]) or not math.isclose( + polygon[0][1], polygon[-1][1] + ): + raise ValueError( + "polygon needs to end at the same point it starts!"
+ ) + polygon_repr = ",".join([f"{p[1]} {p[0]}" for p in polygon]) + param = f"OData.CSC.Intersects(area=geography'SRID={srid};POLYGON(({polygon_repr}))')" + self._append_query_param(param) + return self + + def intersect_point( + self, + point: list[float] | tuple[float, float], + srid: str | None = "4326", + ) -> "_ODataOperation": + """Add filtering by intersection with a point. + + Parameters + ---------- + point: 2-element tuple or list of floats + Point definition [latitude, longitude] + srid : str, optional + SRID name, currently supported is only `4326` + """ + if srid != "4326": + raise NotImplementedError( + "currently supported SRID is only ['4326' (EPSG 4326)]" + ) + if len(point) > 2: + # NOTE: to assure the order is [latitude, longitude] and not vice versa! + raise ValueError( + "point need to have just two elemens [latitude, longitude]" + ) + param = ( + f"OData.CSC.Intersects(area=geography'SRID={srid};POINT({point[0]} {point[1]})')" + ) + self._append_query_param(param) + return self + + +class ODataRequestBuilder( + _ODataOperation +): # pylint: disable=too-few-public-methods + """OData API request builder.""" + + _BASE_PATTERN: str = "{url}/Products" + + @classmethod + def new(cls, url: str) -> _ODataOperation: + """Start building OData request.""" + url = cls._BASE_PATTERN.format(url=url.strip("/")) + return _ODataOperation(_ODataEntity(url=url)) diff --git a/drivers/intake_geokube/utils.py b/drivers/intake_geokube/utils.py new file mode 100644 index 0000000..a3a97e2 --- /dev/null +++ b/drivers/intake_geokube/utils.py @@ -0,0 +1,51 @@ +"""Utils module.""" + +import os + +import requests + + +def create_zip_from_response(response: requests.Response, target: str) -> None: + """Create ZIP archive based on the content in streamable response. + + Parameters + ---------- + response : requests.Response + Response whose contant is streamable (`stream=True`) + target : str + Target path containing name and .zip extension + + Raises + ------ + ValueError + if `Content-Type` header is missing + TypeError + if type supplied by `Content-Type` is other than `zip` + RuntimError + if size provided by `Content-Length` header differs from the size + of the downloaded file + """ + content_type = response.headers.get("Content-Type") + if not content_type: + raise ValueError("`Content-Type` mandatory header is missing") + format_ = content_type.split("/")[-1] + _, ext = os.path.splitext(target) + if format_ != "zip": + raise TypeError( + f"provided content type {format_} is not allowed. 
expected 'zip'" + " format" + ) + assert ext[1:] == "zip", "expected target with '.zip' extension" + + expected_length = int(response.headers["Content-Length"]) + total_bytes = 0 + with open(target, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + total_bytes += len(chunk) + if expected_length != total_bytes: + raise RuntimeError( + "downloaded file is not complete in spite of download finished" + " successfully" + ) diff --git a/drivers/intake_geokube/version.py b/drivers/intake_geokube/version.py new file mode 100644 index 0000000..656021a --- /dev/null +++ b/drivers/intake_geokube/version.py @@ -0,0 +1,3 @@ +"""Module with the current version number definition.""" + +__version__ = "1.0b0" diff --git a/drivers/intake_geokube/wrf/__init__.py b/drivers/intake_geokube/wrf/__init__.py new file mode 100644 index 0000000..c528597 --- /dev/null +++ b/drivers/intake_geokube/wrf/__init__.py @@ -0,0 +1 @@ +"""Domain subpackage for WRF datasets.""" diff --git a/drivers/intake_geokube/wrf/driver.py b/drivers/intake_geokube/wrf/driver.py new file mode 100644 index 0000000..d819760 --- /dev/null +++ b/drivers/intake_geokube/wrf/driver.py @@ -0,0 +1,178 @@ +"""WRF driver for DDS.""" + +from functools import partial +from typing import Any + +import numpy as np +import xarray as xr +from geokube import open_datacube, open_dataset +from geokube.core.datacube import DataCube +from geokube.core.dataset import Dataset + +from ..base import AbstractBaseDriver + +_DIM_RENAME_MAP: dict = { + "Time": "time", + "south_north": "latitude", + "west_east": "longitude", +} +_COORD_RENAME_MAP: dict = { + "XTIME": "time", + "XLAT": "latitude", + "XLONG": "longitude", +} +_COORD_SQUEEZE_NAMES: tuple = ("latitude", "longitude") +_PROJECTION: dict = {"grid_mapping_name": "latitude_longitude"} + + +def _cast_to_set(item: Any) -> set: + if item is None: + return set() + if isinstance(item, set): + return item + if isinstance(item, str): + return {item} + if isinstance(item, list): + return set(item) + raise TypeError(f"type '{type(item)}' is not supported!") + + +def rename_coords(dset: xr.Dataset) -> xr.Dataset: + """Rename coordinates.""" + dset_ = dset.rename_vars(_COORD_RENAME_MAP) + # Removing `Time` dimension from latitude and longitude. + coords = dset_.coords + for name in _COORD_SQUEEZE_NAMES: + coord = dset_[name] + if "Time" in coord.dims: + coords[name] = coord.squeeze(dim="Time", drop=True) + return dset_ + + +def change_dims(dset: xr.Dataset) -> xr.Dataset: + """Change dimensions to time, latitude, and longitude.""" + # Preparing new horizontal coordinates. + lat = (["south_north"], dset["latitude"].to_numpy().mean(axis=1)) + lon = (["west_east"], dset["longitude"].to_numpy().mean(axis=0)) + # Removing old horizontal coordinates. + dset_ = dset.drop_vars(["latitude", "longitude"]) + # Adding new horizontal coordinates and setting their units. + coords = dset_.coords + coords["latitude"] = lat + coords["longitude"] = lon + dset_["latitude"].attrs["units"] = "degree_north" + dset_["longitude"].attrs["units"] = "degree_east" + # Making `time`, `latitude`, and `longitude` new dimensions, instead of + # `Time`, `south_north`, and `west_east`. 
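+ # (_DIM_RENAME_MAP: Time -> time, south_north -> latitude, west_east -> longitude.)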
+ dset_ = dset_.swap_dims(_DIM_RENAME_MAP) + return dset_ + + +def add_projection(dset: xr.Dataset) -> xr.Dataset: + """Add projection information to the dataset.""" + coords = dset.coords + coords["crs"] = xr.DataArray(data=np.array(1), attrs=_PROJECTION) + for var in dset.data_vars.values(): + enc = var.encoding + enc["grid_mapping"] = "crs" + if coord_names := enc.get("coordinates"): + for old_name, new_name in _COORD_RENAME_MAP.items(): + coord_names = coord_names.replace(old_name, new_name) + enc["coordinates"] = coord_names + return dset + + +def choose_variables( + dset: xr.Dataset, + variables_to_keep: str | list[str] | None = None, + variables_to_skip: str | list[str] | None = None, +) -> xr.Dataset: + """Choose only some variables by keeping or skipping some of them.""" + variables_to_keep_ = _cast_to_set(variables_to_keep) + variables_to_skip_ = _cast_to_set(variables_to_skip) + selected_variables = set(dset.data_vars.keys()) + if len(variables_to_keep_) > 0: + selected_variables = set(dset.data_vars.keys()) & variables_to_keep_ + selected_variables = selected_variables - variables_to_skip_ + if len(set(dset.data_vars.keys())) != len(selected_variables): + return dset[selected_variables] + return dset + + +def preprocess_wrf( + dset: xr.Dataset, variables_to_keep, variables_to_skip +) -> xr.Dataset: + """Preprocess WRF dataset.""" + dset = rename_coords(dset) + dset = change_dims(dset) + dset = add_projection(dset) + dset = choose_variables(dset, variables_to_keep, variables_to_skip) + return dset + + +class WrfDriver(AbstractBaseDriver): + """Driver class for netCDF files.""" + + name = "wrf_driver" + version = "0.1a0" + + def __init__( + self, + path: str, + metadata: dict, + pattern: str | None = None, + field_id: str | None = None, + metadata_caching: bool = False, + metadata_cache_path: str | None = None, + storage_options: dict | None = None, + xarray_kwargs: dict | None = None, + mapping: dict[str, dict[str, str]] | None = None, + load_files_on_persistance: bool = True, + variables_to_keep: str | list[str] | None = None, + variables_to_skip: str | list[str] | None = None, + ) -> None: + super().__init__(metadata=metadata) + self.path = path + self.pattern = pattern + self.field_id = field_id + self.metadata_caching = metadata_caching + self.metadata_cache_path = metadata_cache_path + self.storage_options = storage_options + self.mapping = mapping + self.xarray_kwargs = xarray_kwargs or {} + self.load_files_on_persistance = load_files_on_persistance + self.preprocess = partial( + preprocess_wrf, + variables_to_keep=variables_to_keep, + variables_to_skip=variables_to_skip, + ) + + @property + def _arguments(self) -> dict: + return { + "path": self.path, + "id_pattern": self.field_id, + "metadata_caching": self.metadata_caching, + "metadata_cache_path": self.metadata_cache_path, + "mapping": self.mapping, + } | self.xarray_kwargs + + def read(self) -> Dataset | DataCube: + """Read netCDF.""" + if self.pattern: + return open_dataset( + pattern=self.pattern, + preprocess=self.preprocess, + **self._arguments, + ) + return open_datacube( + delay_read_cubes=True, + preprocess=self.preprocess, + **self._arguments, + ) + + def load(self) -> Dataset | DataCube: + """Load netCDF.""" + if self.pattern: + return open_dataset(pattern=self.pattern, **self._arguments) + return open_datacube(delay_read_cubes=False, **self._arguments) diff --git a/drivers/pyproject.toml b/drivers/pyproject.toml new file mode 100644 index 0000000..2f0a6d5 --- /dev/null +++ b/drivers/pyproject.toml @@ -0,0 
+1,85 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "geolake-drivers" +description = "opengeokube DDS driver." +requires-python = ">=3.10" +readme = "README.md" +license = {file = "LICENSE"} +dynamic = ["version"] +authors = [ + {name = "Jakub Walczak"}, + {name = "Marco Mancini"}, + {name = "Mirko Stojiljkovic"}, + {name = "Valentina Scardigno"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Atmospheric Science", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "dateparser", + "intake", + "pydantic", + "tqdm", + "streamz@git+https://github.com/python-streamz/streamz.git", + "paho-mqtt" +] +[project.entry-points."intake.drivers"] +netcdf_driver = "intake_geokube.netcdf.driver:NetCdfDriver" +sentinel_driver = "intake_geokube.sentinel.driver:SentinelDriver" +iot_driver = "intake_geokube.iot.driver:IotDriver" +wrf_driver = "intake_geokube.wrf.driver:WrfDriver" + +[tool.setuptools.dynamic] +version = {attr = "intake_geokube.version.__version__"} + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["examples*"] + +[tool.pydocstyle] + +[tool.pylint.'MESSAGES CONTROL'] +disable = "too-many-arguments,too-many-instance-attributes,too-few-public-methods,duplicate-code" + + +[tool.isort] +profile = "black" +include_trailing_comma = true +line_length = 79 +overwrite_in_place = true +use_parentheses = true + +[tool.black] +line_length = 79 +preview = true + +[tool.mypy] +files = [ + "intake_geokube", "." 
+] +exclude = ["tests/"] + +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::DeprecationWarning" +] diff --git a/drivers/setup.py b/drivers/setup.py new file mode 100644 index 0000000..b908cbe --- /dev/null +++ b/drivers/setup.py @@ -0,0 +1,3 @@ +import setuptools + +setuptools.setup() diff --git a/drivers/tests/__init__.py b/drivers/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/drivers/tests/queries/__init__.py b/drivers/tests/queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/drivers/tests/queries/test_utils.py b/drivers/tests/queries/test_utils.py new file mode 100644 index 0000000..0fbefbc --- /dev/null +++ b/drivers/tests/queries/test_utils.py @@ -0,0 +1,50 @@ +from intake_geokube.queries import utils as ut + + +class TestUtils: + def test_find_key_root_level_recusrive_switched_off(self): + assert ut.find_value({"a": 0, "b": 10}, "b", recursive=False) == 10 + + def test_find_key_root_level_recusrive_switched_on(self): + assert ut.find_value({"a": 0, "b": 10}, "b", recursive=True) == 10 + + def test_return_none_on_missing_key_root_level(self): + assert ut.find_value({"a": 0, "b": 10}, "c", recursive=True) is None + + def test_return_none_on_missing_key_another_level(self): + assert ( + ut.find_value({"a": 0, "b": {"c": 10}}, "d", recursive=True) + is None + ) + + def test_find_key_another_level_recursive_switched_off(self): + assert ( + ut.find_value({"a": 0, "b": {"c": "ccc"}}, "c", recursive=False) + is None + ) + + def test_find_key_another_level_recursive_switched_on(self): + assert ( + ut.find_value({"a": 0, "b": {"c": "ccc"}}, "c", recursive=True) + == "ccc" + ) + + def test_find_list_first(self): + assert ( + ut.find_value( + {"a": 0, "b": [{"c": "ccc"}, {"d": "ddd"}]}, + "c", + recursive=True, + ) + == "ccc" + ) + + def test_find_list_not_first(self): + assert ( + ut.find_value( + {"a": 0, "b": [{"d": "ddd"}, {"c": "ccc"}]}, + "c", + recursive=True, + ) + == "ccc" + ) diff --git a/drivers/tests/queries/test_workflow.py b/drivers/tests/queries/test_workflow.py new file mode 100644 index 0000000..1b8f8c3 --- /dev/null +++ b/drivers/tests/queries/test_workflow.py @@ -0,0 +1,61 @@ +import pytest + +from intake_geokube.queries.workflow import Workflow + + +class TestWorkflow: + def test_fail_on_missing_dataset_id(self): + with pytest.raises( + KeyError, + match=r"'dataset_id' key was missing. did you defined it for*", + ): + Workflow.parse({ + "tasks": [{ + "id": 0, + "op": "subset", + "args": { + "product_id": "reanalysis", + }, + }] + }) + + def test_fail_on_missing_product_id(self): + with pytest.raises( + KeyError, + match=r"'product_id' key was missing. 
did you defined it for*", + ): + Workflow.parse({ + "tasks": [{ + "id": 0, + "op": "subset", + "args": { + "dataset_id": "era5", + }, + }] + }) + + def test_fail_on_nonunique_id(self): + with pytest.raises( + ValueError, + match=r"duplicated key found*", + ): + Workflow.parse({ + "tasks": [ + { + "id": 0, + "op": "subset", + "args": { + "dataset_id": "era5", + "product_id": "reanalysis", + }, + }, + { + "id": 0, + "op": "subset", + "args": { + "dataset_id": "era5", + "product_id": "reanalysis", + }, + }, + ] + }) diff --git a/drivers/tests/sentinel/__init__.py b/drivers/tests/sentinel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/drivers/tests/sentinel/fixture.py b/drivers/tests/sentinel/fixture.py new file mode 100644 index 0000000..cfbb8bd --- /dev/null +++ b/drivers/tests/sentinel/fixture.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.fixture +def sentinel_files(): + return [ + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R20m/T32TQM_20231007T100031_B01_20m.jp2", + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R20m/T32TQM_20231007T100031_B10_20m.jp2", + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R30m/T32TQM_20231007T100031_B04_30m.jp2", + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R10m/T32TQM_20231007T100031_B12_40m.jp2", + ] diff --git a/drivers/tests/sentinel/test_builder.py b/drivers/tests/sentinel/test_builder.py new file mode 100644 index 0000000..f2e5cc1 --- /dev/null +++ b/drivers/tests/sentinel/test_builder.py @@ -0,0 +1,376 @@ +from multiprocessing import Value +from unittest import mock + +import pytest +from requests import Response, Session + +from intake_geokube.sentinel.odata_builder import ( + HttpMethod, + ODataRequest, + ODataRequestBuilder, + _ODataEntity, + _ODataOperation, + _ODataOrderMixing, + datetime_to_isoformat, +) + + +@pytest.fixture +def odata() -> _ODataEntity: + return _ODataEntity(url="http://url.com/v1") + + +@pytest.fixture +def odata_op(odata) -> _ODataOperation: + return _ODataOperation(odata=odata) + + +class TestHttpMethod: + @pytest.mark.parametrize( + "method,res", [(HttpMethod.GET, "get"), (HttpMethod.POST, "post")] + ) + def test_get_proper_name(self, method, res): + assert method.method_name == res + + +class TestODataRequestBuildable: + def test_build_from_operation(self, odata): + res = _ODataOperation(odata).build() + assert isinstance(res, ODataRequest) + assert res.odata == odata + + +class TestOrderMixin: + @pytest.mark.parametrize("type_", [_ODataOperation]) + def test_proper_class_when_order(self, type_, odata): + res = type_(odata).order(by="ProductionDate") + assert isinstance(res, type_) + + def test_fail_order_on_wrong_superclass(self, odata): + class A(_ODataOrderMixing): + def __init__(self, odata): + self.odata = odata + + with pytest.raises(TypeError, match=r"unexpected type:*"): + A(odata).order(by="a") + + +class TestODataRequest: + def test_convert_filter_param(self, odata_op): + odata_op.filter("a", eq=10).or_().filter("b", lt=100, ge=10).order( + by="a", desc=True + ) + req = 
ODataRequest(odata_op.odata) + assert req.odata.params["filter"] == [ + "a eq '10'", + "b lt 100", + "b ge 10", + ] + assert ( + req.request_params["filter"] == "a eq '10' or b lt 100 and b ge 10" + ) + assert req.odata.params["orderby"] == ["a desc"] + + +class TestODataRequestBuilder: + def test_create_odata_operation_from_builder(self): + res = ODataRequestBuilder.new(url="http:/url.com") + assert isinstance(res, _ODataOperation) + assert res.odata.url == "http:/url.com/Products" + + +class TestODataOperation: + @pytest.fixture + def odata_request(self) -> ODataRequest: + return ODataRequestBuilder.new("http://aaaa.com").build() + + @pytest.mark.parametrize( + "datestring,result", + [ + ("2002-02-01", "2002-02-01T00:00:00Z"), + ("2001-02-02 12:45", "2001-02-02T12:45:00Z"), + ("1977-12-23 11:00:05", "1977-12-23T11:00:05Z"), + ("1977-12-23T11:00:05", "1977-12-23T11:00:05Z"), + ], + ) + def test_convert_to_isoformat(self, datestring, result): + assert datetime_to_isoformat(datestring) == result + + def testwith_option_equal(self, odata_op): + odata_op.with_option_equal("Name", "some_name") + assert len(odata_op.odata.params) == 1 + assert odata_op.odata.method is HttpMethod.GET + assert odata_op.odata.params["filter"] == ["Name eq 'some_name'"] + + def test_option_containing(self, odata_op): + odata_op.with_option_containing("some_option", "aaa") + assert len(odata_op.odata.params) == 1 + assert odata_op.odata.method is HttpMethod.GET + assert odata_op.odata.params["filter"] == [ + "contains(some_option,'aaa')" + ] + + def test_option_not_containing(self, odata_op): + odata_op.with_option_not_containing("some_option", "aaa") + assert len(odata_op.odata.params) == 1 + assert odata_op.odata.method is HttpMethod.GET + assert odata_op.odata.params["filter"] == [ + "not contains(some_option,'aaa')" + ] + + def testwith_option_equal_list(self, odata_op): + odata_op.with_option_equal_list("Name", ["some_name", "aaa"]) + assert len(odata_op.odata.params) == 0 + assert odata_op.odata.method is HttpMethod.POST + assert odata_op.odata.body == { + "FilterProducts": [{"Name": "some_name"}, {"Name": "aaa"}] + } + + def test_several_options(self, odata_op): + odata_op.with_option_equal("aa", "bb").and_().with_option_lt( + "aaa", "1000" + ) + assert odata_op.odata.method is HttpMethod.GET + assert len(odata_op.odata.params) == 1 + assert odata_op.odata.params["filter"] == ["aa eq 'bb'", "aaa lt 1000"] + + @pytest.mark.parametrize( + "comb", + [ + {"lt": 1, "eq": 10}, + {"le": 1, "eq": 10}, + {"lt": 1, "le": 10}, + {"gt": 1, "ge": 10}, + {"ge": 1, "eq": 10}, + {"gt": 1, "eq": 10}, + {"lt": 1, "eq": 1, "ge": 1}, + ], + ) + def test_filter_fail_on_wrong_arguments_passed(self, comb, odata_op): + with pytest.raises(ValueError, match=r"cannot define *"): + odata_op.filter(name="a", **comb) + + def test_filter_single(self, odata_op): + res = odata_op.filter(name="a", lt=100) + assert res.odata.params["filter"] == ["a lt 100"] + + def test_filter_multiple(self, odata_op): + res = odata_op.filter(name="a", lt=100, gt=10) + assert res.odata.params["filter"] == ["a lt 100", "a gt 10"] + assert res.odata.conj[-1] == "and" + + def test_filter_multiple2(self, odata_op): + res = odata_op.filter(name="a", ge=10, le=100) + assert res.odata.params["filter"] == ["a le 100", "a ge 10"] + assert res.odata.conj[-1] == "and" + + def test_filter_multiple3(self, odata_op): + res = odata_op.filter(name="a", eq=10) + assert res.odata.params["filter"] == ["a eq '10'"] + assert res.odata.conj[-1] == "and" + + 
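For orientation, the assertions in this test class pin down how `filter`, `filter_date`, the conjunction helpers, and `order` compose into a single request. The sketch below chains those calls the way the surrounding tests imply; the catalogue URL is a placeholder and the attribute names (`Name`, `ProductionDate`) are only illustrative values, not anything the builder requires.

```python
# Illustrative sketch only, based on the behaviour asserted in these tests.
# The URL is a placeholder; the attribute names are examples.
from intake_geokube.sentinel.odata_builder import ODataRequestBuilder

request = (
    ODataRequestBuilder.new("https://catalogue.example.com/odata/v1")
    .filter("Name", containing="S2A_MSIL2A")   # -> "contains(Name,'S2A_MSIL2A')"
    .and_()
    .filter_date("ProductionDate", lt="2023-11-01")
    .order(by="ProductionDate", desc=True)
    .build()
)
# ODataRequest joins the stored filter parts with the recorded conjunctions
# ("and" by default, "or" after .or_()) and sends them as the $filter and
# $orderby query parameters; query() returns a plain requests.Response
# unless a callback has been registered with with_callback().
response = request.query()
```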
@pytest.mark.parametrize("arr", ["111", "111", "02-20", "56:45", "aaa"]) + def test_filter_date_fail_arg_nondateparsable(self, arr, odata_op): + with pytest.raises(ValueError, match=r"cannot parse*"): + odata_op.filter_date("ProductionDate", lt=arr) + + @pytest.mark.parametrize("arr", [(1,), 1, 1.2, [1, 2], {1, 2}]) + def test_filter_date_fail_arg_wrong_type(self, arr, odata_op): + with pytest.raises(TypeError, match=r"type .* is not supported"): + odata_op.filter_date("ProductionDate", lt=arr) + + def test_filter_and_order_ascending(self, odata_op): + odata_op.with_option_gt("aaa", "-50").order( + by="ProductionDate", desc=False + ) + assert odata_op.odata.method is HttpMethod.GET + assert len(odata_op.odata.params) == 2 + assert odata_op.odata.body == {} + assert odata_op.odata.params["filter"] == ["aaa gt -50"] + assert odata_op.odata.params["orderby"] == ["ProductionDate asc"] + + def test_filter_and_order_descending(self, odata_op): + odata_op.with_option_gt("aaa", "-50").order( + by="ProductionDate", desc=True + ) + assert odata_op.odata.method is HttpMethod.GET + assert len(odata_op.odata.params) == 2 + assert odata_op.odata.body == {} + assert odata_op.odata.params["filter"] == ["aaa gt -50"] + assert odata_op.odata.params["orderby"] == ["ProductionDate desc"] + + @mock.patch.object(Session, "send") + def test_request_data(self, send_mock, odata_op): + send_mock.json.return_value = "{'response': 'some response'}" + _ = ( + odata_op.with_option_gt("aaa", "-50") + .order(by="ProductionDate", desc=True) + .build() + .query() + ) + send_mock.assert_called_once() + assert ( + send_mock.call_args_list[0].args[0].url + == "http://url.com/v1?%24filter=aaa+gt+-50&%24orderby=ProductionDate+desc" + ) + + @mock.patch.object(Session, "send") + def test_url_passed_with_extra_slashes(self, send_mock): + builder = ODataRequestBuilder.new( + "https://some_url.com/odata/v1" + ).build() + assert builder.odata.url == "https://some_url.com/odata/v1/Products" + + def test_polygon_fail_on_other_srid_passed(self, odata_op): + with pytest.raises( + NotImplementedError, match=r"currently supported SRID is only*" + ): + odata_op.intersect_polygon( + polygon=[[0, 1], [1, 2], [0, 1]], srid="123" + ) + + def test_polygon_fail_on_polygon_with_more_than_two_coords(self, odata_op): + with pytest.raises( + ValueError, + match=r"polygon should be defined as a 2-element list or tuple*", + ): + odata_op.intersect_polygon(polygon=[[0, 1], [1, 2, 3], [0, 1]]) + + def test_polygon_fail_on_polygon_ending_not_on_start_point(self, odata_op): + with pytest.raises( + ValueError, + match=r"polygon needs to end at the same point it starts!", + ): + odata_op.intersect_polygon(polygon=[[0, 1], [1, 3], [1, 1]]) + + def test_location_fail_on_other_srid_passed(self, odata_op): + with pytest.raises( + NotImplementedError, match=r"currently supported SRID is only*" + ): + odata_op.intersect_point(point=(0.1, 2.0), srid="123") + + def test_location_fail_on_more_than_two_coords(self, odata_op): + with pytest.raises( + ValueError, match=r"point need to have just two elemens*" + ): + odata_op.intersect_point(point=[0, 1, 4]) + + @mock.patch.object(Session, "send") + @pytest.mark.parametrize( + "code,callback", [(200, lambda r: "ok"), (400, lambda r: "bad")] + ) + def test_callback_call_on_defined( + self, send_mock, code, callback, odata_request + ): + response = Response() + response.status_code = code + send_mock.return_value = response + res = odata_request.with_callback(callback, code).query() + assert res == callback(None) + + 
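The polygon/point validation and callback tests above also suggest how a spatial query could hand its streamed result straight to `create_zip_from_response` from `intake_geokube.utils`. The following is a sketch under stated assumptions: the URL is a placeholder, the coordinates are arbitrary, and it is an assumption of this example (not something the tests assert) that the endpoint answers `200` with a streamed `application/zip` body carrying a `Content-Length` header.

```python
# Sketch under the assumptions stated above; not asserted by the tests.
from intake_geokube.sentinel.odata_builder import ODataRequestBuilder
from intake_geokube.utils import create_zip_from_response

# [latitude, longitude] pairs; each vertex must have exactly two
# coordinates and the ring must close on its starting point.
area = [[44.4, 11.2], [44.6, 11.2], [44.6, 11.5], [44.4, 11.5], [44.4, 11.2]]

result = (
    ODataRequestBuilder.new("https://catalogue.example.com/odata/v1")  # placeholder
    .intersect_polygon(polygon=area, srid="4326")
    .build()
    .with_callback(
        lambda resp: create_zip_from_response(resp, "/tmp/products.zip"), 200
    )
    .query()  # returns the callback's result when the response code is 200
)
```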
@mock.patch.object(Session, "send") + def test_return_response_on_missing_callback( + self, send_mock, odata_request + ): + response = Response() + response.status_code = 200 + send_mock.return_value = response + res = odata_request.query() + assert isinstance(res, Response) + + @mock.patch.object(Session, "send") + @pytest.mark.parametrize("code", [200, 300, 305, 400, 500]) + def test_callback_without_http_code(self, send_mock, code, odata_request): + response = Response() + response.status_code = code + send_mock.return_value = response + callback = mock.MagicMock() + _ = odata_request.with_callback(callback).query() + callback.assert_called_with(response) + + def test_operations_with_auto_conjunction(self, odata_op): + res = odata_op.filter("a", lt=10).filter("b", ge="aaa") + assert res.odata.params["filter"] == ["a lt 10", "b ge aaa"] + assert len(res.odata.conj) == 2 + assert res.odata.conj == ["and", "and"] + + def test_operations_with_auto_conjunction_with_several_operations( + self, odata_op + ): + res = ( + odata_op.filter("a", lt=10) + .filter("b", ge="aaa") + .filter_date("ProductioNDate", lt="2000-01-01") + ) + assert res.odata.params["filter"] == [ + "a lt 10", + "b ge aaa", + "ProductioNDate lt 2000-01-01T00:00:00Z", + ] + assert len(res.odata.conj) == 3 + assert res.odata.conj == ["and", "and", "and"] + + def test_operations_with_auto_and_explicit_conjunction_with_several_operations( + self, odata_op + ): + res = ( + odata_op.filter("a", lt=10) + .filter("b", ge="aaa") + .or_() + .filter_date("ProductioNDate", lt="2000-01-01") + ) + assert res.odata.params["filter"] == [ + "a lt 10", + "b ge aaa", + "ProductioNDate lt 2000-01-01T00:00:00Z", + ] + assert len(res.odata.conj) == 3 + assert res.odata.conj == ["and", "or", "and"] + + def test_con_conj_on_single_operation(self, odata_op): + res = odata_op.filter("a", lt=10) + assert res.odata.params["filter"] == ["a lt 10"] + assert len(res.odata.conj) == 1 + + def test_operations_with_explicit_conjunction_and(self, odata_op): + res = odata_op.filter("a", lt=10).and_().filter("b", ge="aaa") + assert res.odata.params["filter"] == ["a lt 10", "b ge aaa"] + assert len(res.odata.conj) == 2 + assert res.odata.conj == ["and", "and"] + + def test_operations_with_explicit_conjunction_or(self, odata_op): + res = odata_op.filter("a", lt=10).or_().filter("b", ge="aaa") + assert res.odata.params["filter"] == ["a lt 10", "b ge aaa"] + assert len(res.odata.conj) == 2 + assert res.odata.conj == ["or", "and"] + + def test_operation_with_idempotent_same_conjunction(self, odata_op): + res = odata_op.filter("a", lt=10).or_().or_().filter("b", ge="aaa") + assert res.odata.params["filter"] == ["a lt 10", "b ge aaa"] + assert len(res.odata.conj) == 2 + assert res.odata.conj == ["or", "and"] + + def test_operation_with_idempotent_other_conjunction(self, odata_op): + res = ( + odata_op.filter("a", lt=10) + .or_() + .or_() + .and_() + .filter("b", ge="aaa") + ) + assert res.odata.params["filter"] == ["a lt 10", "b ge aaa"] + assert len(res.odata.conj) == 2 + assert res.odata.conj == ["and", "and"] + + def test_filter_skip_if_all_arg_nones(self, odata_op): + odata_op = odata_op.filter("a").filter("b") + assert len(odata_op.odata.params) == 0 + assert len(odata_op.odata.conj) == 0 + + def test_filter_containing(self, odata_op): + odata_op = odata_op.filter("a", containing="ggg", not_containing="bbb") + assert odata_op.odata.params["filter"] == [ + "contains(a,'ggg')", + "not contains(a,'bbb')", + ] + assert odata_op.odata.conj == ["and", "and"] diff --git 
a/drivers/tests/sentinel/test_driver.py b/drivers/tests/sentinel/test_driver.py new file mode 100644 index 0000000..326bab4 --- /dev/null +++ b/drivers/tests/sentinel/test_driver.py @@ -0,0 +1,177 @@ +import os +from unittest import mock + +import pytest +from intake.source.utils import reverse_format + +import intake_geokube.sentinel.driver as drv +from intake_geokube.queries.geoquery import GeoQuery + +from . import fixture as fxt + + +class TestSentinelDriver: + @pytest.mark.parametrize( + "item,res", + [ + ("aaa", 1), + (["aa", "bb"], 2), + (10, 1), + ([10, 100], 2), + (("a", "b"), 2), + ((-1, -5), 2), + ], + ) + def test_get_items_nbr(self, item, res): + assert drv._get_items_nbr({"key": item}, "key") == res + + @pytest.mark.skip(reason="product_id is not mandatory anymore") + def test_validate_query_fail_on_missing_product_id(self): + query = GeoQuery() + with pytest.raises( + ValueError, match=r"\'product_id\' is mandatory filter" + ): + drv._validate_geoquery_for_sentinel(query) + + @pytest.mark.parametrize( + "time", + [ + {"year": [2000, 2014], "month": 10, "day": 14}, + {"year": 2014, "month": [10, 11], "day": 14}, + {"year": 2000, "month": 10, "day": [14, 15, 16]}, + ], + ) + def test_validate_query_fail_on_multiple_year_month_day(self, time): + query = GeoQuery(product_id="aaa", time=time) + with pytest.raises( + ValueError, + match=( + r"valid time combo for sentinel data should contain exactly" + r" one*" + ), + ): + drv._validate_geoquery_for_sentinel(query) + + @pytest.mark.parametrize( + "time", + [ + {"year": 1999, "month": 10, "day": 14}, + {"year": 2014, "month": 10, "day": 14}, + {"year": 2000, "month": 10, "day": 14}, + ], + ) + def test_validate_query_if_time_passed_as_int(self, time): + query = GeoQuery(product_id="aaa", time=time) + drv._validate_geoquery_for_sentinel(query) + + @pytest.mark.parametrize( + "time", + [ + {"year": "1999", "month": "10", "day": "14"}, + {"year": 2014, "month": "10", "day": 14}, + {"year": "2000", "month": 10, "day": 14}, + ], + ) + def test_validate_query_if_time_passed_as_str(self, time): + query = GeoQuery(product_id="aaa", time=time) + drv._validate_geoquery_for_sentinel(query) + + @pytest.mark.parametrize( + "locs", + [{"latitude": 10}, {"longitude": -10}, {"latitude": 5, "aaa": 10}], + ) + def test_validate_query_Fail_on_missing_key(self, locs): + query = GeoQuery(product_id="aa", location=locs) + with pytest.raises( + ValueError, + match=( + r"both \'latitude\' and \'longitude\' must be defined for" + r" locatio" + ), + ): + drv._validate_geoquery_for_sentinel(query) + + @pytest.mark.parametrize( + "locs", + [ + {"latitude": [10, -5], "longitude": [-1, -2]}, + {"latitude": 10, "longitude": [-1, -2]}, + {"latitude": [10, -5], "longitude": -1}, + ], + ) + def test_location_to_valid_point_fail_on_multielement_list_passed( + self, locs + ): + query = GeoQuery(product_id="aa", location=locs) + with pytest.raises( + ValueError, + match=r"location can have just a single point \(single value for*", + ): + drv._location_to_valid_point(query.location) + + @pytest.mark.parametrize( + "path,res", + [ + ( + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R20m/T32TQM_20231007T100031_B01_20m.jp2", + { + "product_id": "162f8f7e-c954-4f69-bb53-ed820aa6432a", + "resolution": "R20m", + "band": "B01", + }, + ), + ( + 
"/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R30m/T32TQM_20231007T100031_B04_30m.jp2", + { + "product_id": "162f8f7e-c954-4f69-bb53-ed820aa6432a", + "resolution": "R30m", + "band": "B04", + }, + ), + ], + ) + def test_zippatern(self, path, res): + zippattern = "/{product_id}/{}.SAFE/GRANULE/{}/IMG_DATA/{resolution}/{}_{}_{band}_{}.jp2" + target_dir = "/tmp/pymp-2b5gr07m" + assert reverse_format(zippattern, path.removeprefix(target_dir)) == res + + @pytest.mark.parametrize( + "path,exp", + [ + ( + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R20m/T32TQM_20231007T100031_B01_20m.jp2", + "R20m_B01", + ), + ( + "/tmp/pymp-2b5gr07m/162f8f7e-c954-4f69-bb53-ed820aa6432a/S2A_MSIL2A_20231007T100031_N0509_R122_T32TQM_20231007T142901.SAFE/GRANULE/L2A_T32TQM_A043305_20231007T100026/IMG_DATA/R30m/T32TQM_20231007T100031_B04_30m.jp2", + "R30m_B04", + ), + ], + ) + def test_get_field_name_from_path(self, path, exp): + assert drv._get_field_name_from_path(path) == exp + + @mock.patch.dict(os.environ, {}, clear=True) + def test_fail_if_no_username_passed(self): + with pytest.raises( + KeyError, + match=( + r"missing at least of of the mandatory environmental" + r" variables:" + ), + ): + drv.SentinelDriver({}, "", "", "") + + def test_raise_notimplemented_for_read(self): + with pytest.raises( + NotImplementedError, + match=r"reading metadata is not supported for sentinel data*", + ): + drv.SentinelDriver({}, "", "", "").read() + + def test_raise_notimplemented_for_load(self): + with pytest.raises( + NotImplementedError, + match=r"loading entire product is not supported for sentinel data", + ): + drv.SentinelDriver({}, "", "", "").load() diff --git a/drivers/tests/test_geoquery.py b/drivers/tests/test_geoquery.py new file mode 100644 index 0000000..4cb9daa --- /dev/null +++ b/drivers/tests/test_geoquery.py @@ -0,0 +1,41 @@ +from unittest import mock + +import pytest + +from intake_geokube.queries.geoquery import GeoQuery + + +class TestGeoQuery: + def test_pass_time_as_combo(self): + query = GeoQuery( + time={ + "year": ["2002"], + "month": ["6"], + "day": ["21"], + "hour": ["8", "10"], + } + ) + assert isinstance(query.time, dict) + + def test_pass_time_as_slice(self): + query = GeoQuery(time={"start": "2000-01-01", "stop": "2001-12-21"}) + assert isinstance(query.time, slice) + assert query.time.start == "2000-01-01" + assert query.time.stop == "2001-12-21" + + def test_dump_original_from_time_as_combo(self): + query = GeoQuery( + time={ + "year": ["2002"], + "month": ["6"], + "day": ["21"], + "hour": ["8", "10"], + } + ) + res = query.model_dump_original() + assert isinstance(res["time"], dict) + + def test_dump_original_from_time_as_slice(self): + query = GeoQuery(time={"start": "2000-01-01", "stop": "2001-12-21"}) + res = query.model_dump_original() + assert isinstance(res["time"], dict) diff --git a/executor/Dockerfile b/executor/Dockerfile index e3cc317..db3cebb 100644 --- a/executor/Dockerfile +++ b/executor/Dockerfile @@ -1,16 +1,12 @@ -FROM continuumio/miniconda3 -WORKDIR /code -RUN conda install -c conda-forge xesmf cartopy psycopg2 -y -COPY ./executor/requirements.txt /code/requirements.txt -RUN pip install --no-cache-dir -r requirements.txt -COPY geokube_packages/geokube-0.1a0-py3-none-any.whl /code -COPY 
geokube_packages/intake_geokube-0.1a0-py3-none-any.whl /code -RUN pip install /code/geokube-0.1a0-py3-none-any.whl -RUN pip install /code/intake_geokube-0.1a0-py3-none-any.whl -COPY ./db/dbmanager /code/app/db/dbmanager -COPY ./utils/wait-for-it.sh /code/wait-for-it.sh -COPY ./datastore /code/app/datastore -COPY ./geoquery /code/app/geoquery -COPY ./resources /code/app/resources -COPY ./executor/app /code/app -CMD [ "python", "./app/main.py" ] \ No newline at end of file +ARG REGISTRY=rg.nl-ams.scw.cloud/geodds-production +ARG TAG=latest +ARG SENTINEL_USERNAME=... +ARG SENTINEL_PASSWORD=... +FROM $REGISTRY/geolake-datastore:$TAG +WORKDIR /app +ENV SENTINEL_USERNAME=$SENTINEL_USERNAME +ENV SENTINEL_PASSWORD=$SENTINEL_PASSWORD +COPY requirements.txt /code/requirements.txt +RUN pip install --no-cache-dir -r /code/requirements.txt +COPY app /app +CMD [ "python", "main.py" ] diff --git a/executor/app/main.py b/executor/app/main.py index c59ef92..35b90fe 100644 --- a/executor/app/main.py +++ b/executor/app/main.py @@ -1,146 +1,477 @@ -# We have three type of executor: -# - query executor (query) -# - estimate query executor (estimate) -# - catalog info executor (info) -# -# Configuration parameters for the executor: -# type: query, estimate, catalog -# dask cluster base ports (if they are not provided the cluster is not created: (e.g. for estimate and catalog info)) -# channel: channel_queue, channel_type, channel_durable -# catalog path -# store_path (where to store the query results) -# -# An executor will register to the DB and get a worker id -# if dask cluster base ports are provided, a dask cluster is created -# an executor mush have a unique port for the dask scheduler/dashboard - import os -import json +import time +import datetime import pika -from dask.distributed import Client, LocalCluster +import logging +import asyncio +import threading, functools +from zipfile import ZipFile + +import numpy as np +from dask.distributed import Client, LocalCluster, Nanny, Status +from dask.delayed import Delayed +from geokube.core.datacube import DataCube +from geokube.core.dataset import Dataset +from geokube.core.field import Field from datastore.datastore import Datastore -from db.dbmanager.dbmanager import DBManager, RequestStatus +from workflow import Workflow +from intake_geokube.queries.geoquery import GeoQuery +from dbmanager.dbmanager import DBManager, RequestStatus + +from meta import LoggableMeta +from messaging import Message, MessageType + +_BASE_DOWNLOAD_PATH = "/downloads" + + +def get_file_name_for_climate_downscaled(kube: DataCube, message: Message): + query: GeoQuery = GeoQuery.parse(message.content) + is_time_range = False + if query.time: + is_time_range = "start" in query.time or "stop" in query.time + var_names = list(kube.fields.keys()) + if len(kube) == 1: + if is_time_range: + FILENAME_TEMPLATE = "{ncvar_name}_VHR-PRO_IT2km_CMCC-CM_{product_id}_CCLM5-0-9_1hr_{start_date}_{end_date}_{request_id}" + ncvar_name = kube.fields[var_names[0]].ncvar + return FILENAME_TEMPLATE.format( + product_id=message.product_id, + request_id=message.request_id, + ncvar_name=ncvar_name, + start_date=np.datetime_as_string( + kube.time.values[0], unit="D" + ), + end_date=np.datetime_as_string(kube.time.values[-1], unit="D"), + ) + else: + FILENAME_TEMPLATE = "{ncvar_name}_VHR-PRO_IT2km_CMCC-CM_{product_id}_CCLM5-0-9_1hr_{request_id}" + ncvar_name = kube.fields[var_names[0]].ncvar + return FILENAME_TEMPLATE.format( + product_id=message.product_id, + request_id=message.request_id, + 
ncvar_name=ncvar_name, + ) + else: + if is_time_range: + FILENAME_TEMPLATE = "VHR-PRO_IT2km_CMCC-CM_{product_id}_CCLM5-0-9_1hr_{start_date}_{end_date}_{request_id}" + return FILENAME_TEMPLATE.format( + product_id=message.product_id, + request_id=message.request_id, + start_date=np.datetime_as_string( + kube.time.values[0], unit="D" + ), + end_date=np.datetime_as_string(kube.time.values[-1], unit="D"), + ) + else: + FILENAME_TEMPLATE = ( + "VHR-PRO_IT2km_CMCC-CM_{product_id}_CCLM5-0-9_1hr_{request_id}" + ) + return FILENAME_TEMPLATE.format( + product_id=message.product_id, + request_id=message.request_id, + ) + + +def rcp85_filename_condition(kube: DataCube, message: Message) -> bool: + return ( + message.dataset_id == "climate-projections-rcp85-downscaled-over-italy" + ) + + +def get_history_message(): + return ( + f"Generated by CMCC DDS version 0.9.0 {str(datetime.datetime.now())}" + ) + + +def persist_datacube( + kube: DataCube, + message: Message, + base_path: str | os.PathLike, +) -> str | os.PathLike: + if rcp85_filename_condition(kube, message): + path = get_file_name_for_climate_downscaled(kube, message) + else: + var_names = list(kube.fields.keys()) + if len(kube) == 1: + path = "_".join( + [ + var_names[0], + message.dataset_id, + message.product_id, + message.request_id, + ] + ) + else: + path = "_".join( + [message.dataset_id, message.product_id, message.request_id] + ) + kube._properties["history"] = get_history_message() + if isinstance(message.content, GeoQuery): + format = message.content.format + format_args = message.content.format_args + else: + format = "netcdf" + match format: + case "netcdf": + full_path = os.path.join(base_path, f"{path}.nc") + kube.to_netcdf(full_path) + case "geojson": + full_path = os.path.join(base_path, f"{path}.json") + kube.to_geojson(full_path) + case "png": + full_path = os.path.join(base_path, f"{path}.png") + kube.to_image(full_path, **format_args) + case "jpeg": + full_path = os.path.join(base_path, f"{path}.jpg") + kube.to_image(full_path, **format_args) + case _: + raise ValueError(f"format `{format}` is not supported") + return full_path + + +def persist_dataset( + dset: Dataset, + message: Message, + base_path: str | os.PathLike, +): + def _get_attr_comb(dataframe_item, attrs): + return "_".join([dataframe_item[attr_name] for attr_name in attrs]) + + def _persist_single_datacube(dataframe_item, base_path, format, format_args=None): + if not format_args: + format_args = {} + dcube = dataframe_item[dset.DATACUBE_COL] + if isinstance(dcube, Delayed): + dcube = dcube.compute() + if len(dcube) == 0: + return None + for field in dcube.fields.values(): + if 0 in field.shape: + return None + attr_str = _get_attr_comb(dataframe_item, dset._Dataset__attrs) + var_names = list(dcube.fields.keys()) + if len(dcube) == 1: + path = "_".join( + [ + var_names[0], + message.dataset_id, + message.product_id, + attr_str, + message.request_id, + ] + ) + else: + path = "_".join( + [ + message.dataset_id, + message.product_id, + attr_str, + message.request_id, + ] + ) + match format: + case "netcdf": + full_path = os.path.join(base_path, f"{path}.nc") + dcube.to_netcdf(full_path) + case "geojson": + full_path = os.path.join(base_path, f"{path}.json") + dcube.to_geojson(full_path) + case "png": + full_path = os.path.join(base_path, f"{path}.png") + dcube.to_image(full_path, **format_args) + case "jpeg": + full_path = os.path.join(base_path, f"{path}.jpg") + dcube.to_image(full_path, **format_args) + case _: + raise ValueError(f"format: {format} is not 
supported!") + return full_path + + if isinstance(message.content, GeoQuery): + format = message.content.format + format_args = message.content.format_args + else: + format = "netcdf" + datacubes_paths = dset.data.apply( + _persist_single_datacube, base_path=base_path, format=format, format_args=format_args, axis=1 + ) + paths = datacubes_paths[~datacubes_paths.isna()] + if len(paths) == 0: + return None + elif len(paths) == 1: + return paths.iloc[0] + zip_name = "_".join( + [message.dataset_id, message.product_id, message.request_id] + ) + path = os.path.join(base_path, f"{zip_name}.zip") + with ZipFile(path, "w") as archive: + for file in paths: + archive.write(file, arcname=os.path.basename(file)) + for file in paths: + os.remove(file) + return path + -def ds_query(ds_id, prod_id, query, compute, catalog_path): - ds = Datastore(catalog_path) - kube = ds.query(ds_id, prod_id, query, compute) - kube.persist('.') - return kube +def process(message: Message, compute: bool): + res_path = os.path.join(_BASE_DOWNLOAD_PATH, message.request_id) + os.makedirs(res_path, exist_ok=True) + match message.type: + case MessageType.QUERY: + kube = Datastore().query( + message.dataset_id, + message.product_id, + message.content, + compute, + ) + case MessageType.WORKFLOW: + kube = Workflow.from_tasklist(message.content).compute() + case _: + raise ValueError("unsupported message type") + if isinstance(kube, Field): + kube = DataCube( + fields=[kube], + properties=kube.properties, + encoding=kube.encoding, + ) + match kube: + case DataCube(): + return persist_datacube(kube, message, base_path=res_path) + case Dataset(): + return persist_dataset(kube, message, base_path=res_path) + case _: + raise TypeError( + "expected geokube.DataCube or geokube.Dataset, but passed" + f" {type(kube).__name__}" + ) -class Executor(): - def __init__(self, broker, catalog_path, store_path): - self._datastore = Datastore(catalog_path) - self._catalog_path = catalog_path +class Executor(metaclass=LoggableMeta): + _LOG = logging.getLogger("geokube.Executor") + + def __init__(self, broker, store_path): self._store = store_path - broker_conn = pika.BlockingConnection(pika.ConnectionParameters(host=broker)) + broker_conn = pika.BlockingConnection( + pika.ConnectionParameters(host=broker, heartbeat=10), + ) + self._conn = broker_conn self._channel = broker_conn.channel() self._db = DBManager() - - def create_dask_cluster(self, dask_cluster_opts): - self._worker_id = self._db.create_worker(status='enabled', - dask_scheduler_port=dask_cluster_opts['scheduler_port'], - dask_dashboard_address=dask_cluster_opts['dashboard_address']) - dask_cluster = LocalCluster(n_workers=dask_cluster_opts['n_workers'], - scheduler_port=dask_cluster_opts['scheduler_port'], - dashboard_address=dask_cluster_opts['dashboard_address'] - ) + + def create_dask_cluster(self, dask_cluster_opts: dict = None): + if dask_cluster_opts is None: + dask_cluster_opts = {} + dask_cluster_opts["scheduler_port"] = int( + os.getenv("DASK_SCHEDULER_PORT", 8188) + ) + dask_cluster_opts["processes"] = True + port = int(os.getenv("DASK_DASHBOARD_PORT", 8787)) + dask_cluster_opts["dashboard_address"] = f":{port}" + dask_cluster_opts["n_workers"] = None + dask_cluster_opts["memory_limit"] = "auto" + self._worker_id = self._db.create_worker( + status="enabled", + dask_scheduler_port=dask_cluster_opts["scheduler_port"], + dask_dashboard_address=dask_cluster_opts["dashboard_address"], + ) + self._LOG.info( + "creating Dask Cluster with options: `%s`", + dask_cluster_opts, + 
extra={"track_id": self._worker_id}, + ) + dask_cluster = LocalCluster( + n_workers=dask_cluster_opts["n_workers"], + scheduler_port=dask_cluster_opts["scheduler_port"], + dashboard_address=dask_cluster_opts["dashboard_address"], + memory_limit=dask_cluster_opts["memory_limit"], + ) + self._LOG.info( + "creating Dask Client...", extra={"track_id": self._worker_id} + ) self._dask_client = Client(dask_cluster) + self._nanny = Nanny(self._dask_client.cluster.scheduler.address) + + def maybe_restart_cluster(self, status: RequestStatus): + if status is RequestStatus.TIMEOUT: + self._LOG.info("recreating the cluster due to timeout") + self._dask_client.cluster.close() + self.create_dask_cluster() + if self._dask_client.cluster.status is Status.failed: + self._LOG.info("attempt to restart the cluster...") + try: + asyncio.run(self._nanny.restart()) + except Exception as err: + self._LOG.error( + "couldn't restart the cluster due to an error: %s", err + ) + self._LOG.info("closing the cluster") + self._dask_client.cluster.close() + if self._dask_client.cluster.status is Status.closed: + self._LOG.info("recreating the cluster") + self.create_dask_cluster() - def query_and_persist(self, ds_id, prod_id, query, compute, format): - kube = self._datastore.query(ds_id, prod_id, query, compute) - kube.persist(self._store, format=format) - - def estimate(self, channel, method, properties, body): - m = body.decode().split('\\') - dataset_id = m[0] - product_id = m[1] - query = m[2] - kube = self._datastore.query(dataset_id, product_id, query) - channel.basic_publish(exchange='', - routing_key=properties.reply_to, - properties=pika.BasicProperties(correlation_id = properties.correlation_id), - body=str(kube.get_nbytes())) - channel.basic_ack(delivery_tag=method.delivery_tag) - - def info(self, channel, method, properties, body): - m = body.decode().split('\\') - oper = m[0] # could be list or info - if (oper == 'list'): - if len(m) == 1: # list datasets - response = json.loads(self._datastore.dataset_list()) - if len(m) == 2: # list dataset products - dataset_id = m[1] - response = json.loads(self._datastore.product_list(dataset_id)) - - if (oper == 'info'): - if (len(m) == 2): # dataset info - dataset_id = m[1] - response = json.loads(self._datastore.dataset_info(dataset_id)) - if (len(m) == 3): # product info - dataset_id = m[1] - product_id = m[2] - response = json.loads(self._datastore.product_info(dataset_id, product_id)) - - channel.basic_publish(exchange='', - routing_key=properties.reply_to, - properties=pika.BasicProperties(correlation_id = \ - properties.correlation_id), - body=response) - channel.basic_ack(delivery_tag=method.delivery_tag) - - def query(self, channel, method, properties, body): - m = body.decode().split('\\') - request_id = m[0] - dataset_id = m[1] - product_id = m[2] - query = m[3] - format = m[4] - - self._db.update_request(request_id=request_id, worker_id=self._worker_id, status=RequestStatus.RUNNING) - # future = self._dask_client.submit(self.query_and_persist, dataset_id, product_id, query, False, format) - future = self._dask_client.submit(ds_query, dataset_id, product_id, query, False, self._catalog_path) + def ack_message(self, channel, delivery_tag): + """Note that `channel` must be the same pika channel instance via which + the message being ACKed was retrieved (AMQP protocol constraint). + """ + if channel.is_open: + channel.basic_ack(delivery_tag) + else: + self._LOG.info( + "cannot acknowledge the message. channel is closed!" 
+ ) + pass + + def retry_until_timeout( + self, + future, + message: Message, + retries: int = 30, + sleep_time: int = 10, + ): + assert retries is not None, "`retries` cannot be `None`" + assert sleep_time is not None, "`sleep_time` cannot be `None`" + status = fail_reason = location_path = None try: - future.result() - self._db.update_request(request_id=request_id, worker_id=self._worker_id, status=RequestStatus.DONE) + self._LOG.debug( + "attempt to get result for the request", + extra={"track_id": message.request_id}, + ) + for _ in range(retries): + if future.done(): + self._LOG.debug( + "result is done", + extra={"track_id": message.request_id}, + ) + location_path = future.result() + status = RequestStatus.DONE + self._LOG.debug( + "result save under: %s", + location_path, + extra={"track_id": message.request_id}, + ) + break + self._LOG.debug( + f"result is not ready yet. sleeping {sleep_time} sec", + extra={"track_id": message.request_id}, + ) + time.sleep(sleep_time) + else: + self._LOG.info( + "processing timout", + extra={"track_id": message.request_id}, + ) + future.cancel() + status = RequestStatus.TIMEOUT + fail_reason = "Processing timeout" except Exception as e: - print(e) - self._db.update_request(request_id=request_id, worker_id=self._worker_id, status=RequestStatus.FAILED) + self._LOG.error( + "failed to get result due to an error: %s", + e, + exc_info=True, + stack_info=True, + extra={"track_id": message.request_id}, + ) + status = RequestStatus.FAILED + fail_reason = f"{type(e).__name__}: {str(e)}" + return (location_path, status, fail_reason) + + def handle_message(self, connection, channel, delivery_tag, body): + message: Message = Message(body) + self._LOG.debug( + "executing query: `%s`", + message.content, + extra={"track_id": message.request_id}, + ) + + # TODO: estimation size should be updated, too + self._db.update_request( + request_id=message.request_id, + worker_id=self._worker_id, + status=RequestStatus.RUNNING, + ) + + self._LOG.debug( + "submitting job for workflow request", + extra={"track_id": message.request_id}, + ) + future = self._dask_client.submit( + process, + message=message, + compute=False, + ) + location_path, status, fail_reason = self.retry_until_timeout( + future, + message=message, + retries=int(os.environ.get("RESULT_CHECK_RETRIES")), + ) + self._db.update_request( + request_id=message.request_id, + worker_id=self._worker_id, + status=status, + location_path=location_path, + size_bytes=self.get_size(location_path), + fail_reason=fail_reason, + ) + self._LOG.debug( + "acknowledging request", extra={"track_id": message.request_id} + ) + cb = functools.partial(self.ack_message, channel, delivery_tag) + connection.add_callback_threadsafe(cb) - channel.basic_ack(delivery_tag=method.delivery_tag) + self.maybe_restart_cluster(status) + self._LOG.debug( + "request acknowledged", extra={"track_id": message.request_id} + ) + + def on_message(self, channel, method_frame, header_frame, body, args): + (connection, threads) = args + delivery_tag = method_frame.delivery_tag + t = threading.Thread( + target=self.handle_message, + args=(connection, channel, delivery_tag, body), + ) + t.start() + threads.append(t) def subscribe(self, etype): - print(f'subscribe channel: {etype}_queue') - self._channel.queue_declare(queue=f'{etype}_queue', durable=True) + self._LOG.debug( + "subscribe channel: %s_queue", etype, extra={"track_id": "N/A"} + ) + self._channel.queue_declare(queue=f"{etype}_queue", durable=True) self._channel.basic_qos(prefetch_count=1) - 
self._channel.basic_consume(queue=f'{etype}_queue', on_message_callback=getattr(self, etype)) + + threads = [] + on_message_callback = functools.partial( + self.on_message, args=(self._conn, threads) + ) + + self._channel.basic_consume( + queue=f"{etype}_queue", on_message_callback=on_message_callback + ) def listen(self): while True: self._channel.start_consuming() -if __name__ == "__main__": + def get_size(self, location_path): + if location_path and os.path.exists(location_path): + return os.path.getsize(location_path) + return None - broker = os.getenv('BROKER', 'broker') - executor_types = os.getenv('EXECUTOR_TYPES', 'query').split(',') - catalog_path = os.getenv('CATALOG_PATH', 'catalog.yaml') - store_path = os.getenv('STORE_PATH', '.') - executor = Executor(broker=broker, - catalog_path=catalog_path, - store_path=store_path) - print('channel subscribe') +if __name__ == "__main__": + broker = os.getenv("BROKER_SERVICE_HOST", "broker") + executor_types = os.getenv("EXECUTOR_TYPES", "query").split(",") + store_path = os.getenv("STORE_PATH", ".") + + executor = Executor(broker=broker, store_path=store_path) + print("channel subscribe") for etype in executor_types: - if etype == 'query': - dask_cluster_opts = {} - dask_cluster_opts['scheduler_port'] = int(os.getenv('DASK_SCHEDULER_PORT', 8188)) - port = int(os.getenv('DASK_DASHBOARD_PORT', 8787)) - dask_cluster_opts['dashboard_address'] = f':{port}' - dask_cluster_opts['n_workers'] = int(os.getenv('DASK_N_WORKERS', 1)) - executor.create_dask_cluster(dask_cluster_opts) + if etype == "query": + executor.create_dask_cluster() executor.subscribe(etype) - - print('waiting for requests ...') - executor.listen() \ No newline at end of file + + print("waiting for requests ...") + executor.listen() diff --git a/executor/app/messaging.py b/executor/app/messaging.py new file mode 100644 index 0000000..37ce25a --- /dev/null +++ b/executor/app/messaging.py @@ -0,0 +1,45 @@ +import os +import logging +from enum import Enum + +from intake_geokube.queries.geoquery import GeoQuery +from intake_geokube.queries.workflow import Workflow + +MESSAGE_SEPARATOR = os.environ["MESSAGE_SEPARATOR"] + + +class MessageType(Enum): + QUERY = "query" + WORKFLOW = "workflow" + + +class Message: + _LOG = logging.getLogger("geokube.Message") + + request_id: int + dataset_id: str = "" + product_id: str = "" + type: MessageType + content: GeoQuery | Workflow + + def __init__(self, load: bytes) -> None: + self.request_id, msg_type, *query = load.decode().split( + MESSAGE_SEPARATOR + ) + match MessageType(msg_type): + case MessageType.QUERY: + self._LOG.debug("processing content of `query` type") + assert len(query) == 3, "improper content for query message" + self.dataset_id, self.product_id, self.content = query + self.content: GeoQuery = GeoQuery.parse(self.content) + self.type = MessageType.QUERY + case MessageType.WORKFLOW: + self._LOG.debug("processing content of `workflow` type") + assert len(query) == 1, "improper content for workflow message" + self.content: Workflow = Workflow.parse(query[0]) + self.dataset_id = self.content.dataset_id + self.product_id = self.content.product_id + self.type = MessageType.WORKFLOW + case _: + self._LOG.error("type `%s` is not supported", msg_type) + raise ValueError(f"type `{msg_type}` is not supported!") diff --git a/executor/app/meta.py b/executor/app/meta.py new file mode 100644 index 0000000..739ef62 --- /dev/null +++ b/executor/app/meta.py @@ -0,0 +1,27 @@ +"""Module with `LoggableMeta` metaclass""" +import os +import logging + + 
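`LoggableMeta`, defined just below, is what equips `Executor` and `Message` with a configured `_LOG`. A minimal usage sketch follows; the class name is hypothetical and not part of the changeset. Note that the default `LOGGING_FORMAT` contains `%(track_id)s`, so every record has to supply it via `extra`.

```python
# Minimal sketch of opting into LoggableMeta; `IngestWorker` is hypothetical.
import logging

from meta import LoggableMeta  # imported the same way executor/app/main.py does


class IngestWorker(metaclass=LoggableMeta):
    _LOG = logging.getLogger("geokube.IngestWorker")

    def run(self, request_id: str) -> None:
        # The default LOGGING_FORMAT includes %(track_id)s, so it must be
        # passed through `extra` on every call.
        self._LOG.info("starting work", extra={"track_id": request_id})
```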
+class LoggableMeta(type): + """Metaclass for dealing with logger levels and handlers""" + + def __new__(cls, child_cls, bases, namespace): + # NOTE: method is called while creating a class, not an instance! + res = super().__new__(cls, child_cls, bases, namespace) + if hasattr(res, "_LOG"): + format_ = os.environ.get( + "LOGGING_FORMAT", + "%(asctime)s %(name)s %(levelname)s %(lineno)d" + " %(track_id)s %(message)s", + ) + formatter = logging.Formatter(format_) + logging_level = os.environ.get("LOGGING_LEVEL", "INFO") + res._LOG.setLevel(logging_level) + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + stream_handler.setLevel(logging_level) + res._LOG.addHandler(stream_handler) + for handler in logging.getLogger("geokube").handlers: + handler.setFormatter(formatter) + return res diff --git a/executor/requirements.txt b/executor/requirements.txt index c4a403b..f188e90 100644 --- a/executor/requirements.txt +++ b/executor/requirements.txt @@ -1,7 +1,4 @@ -pika -bokeh -dask -distributed -intake -pydantic -sqlalchemy \ No newline at end of file +pika==1.2.1 +prometheus_client +sqlalchemy +pydantic \ No newline at end of file diff --git a/geokube_packages/geokube-0.1a0-py3-none-any.whl b/geokube_packages/geokube-0.1a0-py3-none-any.whl deleted file mode 100644 index 99341a8..0000000 Binary files a/geokube_packages/geokube-0.1a0-py3-none-any.whl and /dev/null differ diff --git a/geokube_packages/intake_geokube-0.1a0-py3-none-any.whl b/geokube_packages/intake_geokube-0.1a0-py3-none-any.whl deleted file mode 100644 index e24fdb3..0000000 Binary files a/geokube_packages/intake_geokube-0.1a0-py3-none-any.whl and /dev/null differ diff --git a/geoquery/geoquery.py b/geoquery/geoquery.py deleted file mode 100644 index dc42414..0000000 --- a/geoquery/geoquery.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional, List, Dict, Union - -from pydantic import BaseModel, root_validator - -class GeoQuery(BaseModel): - variable: List[str] - time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] - area: Optional[Dict[str, float]] - locations: Optional[Dict[str, List[float]]] - vertical: Optional[Union[float, List[float]]] - filters: Optional[Dict] - - @root_validator - def area_locations_mutually_exclusive_validator(cls, query): - if query["area"] is not None and query["locations"] is not None: - raise KeyError("area and locations couldn't be processed together, please use one of them") - return query \ No newline at end of file
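Finally, to make the executor's request path easier to follow end to end, here is a hedged sketch of the wire format that the new `Message` class expects for a `query` message. The separator value and the JSON payload are illustrative assumptions: `MESSAGE_SEPARATOR` is read from the environment at runtime, and the producing side (the API component) is not part of this diff. Workflow messages are analogous but carry only the request id, the literal `workflow` type, and a single task-list payload.

```python
# Illustrative only: mirrors how executor/app/messaging.py splits the body.
# The separator and the query JSON are assumptions made for this example.
import json

MESSAGE_SEPARATOR = "\\"  # assumed; must match the MESSAGE_SEPARATOR env var

query_payload = json.dumps(
    {"time": {"start": "2000-01-01", "stop": "2001-12-21"}}
)
body = MESSAGE_SEPARATOR.join(
    [
        "42",           # request_id
        "query",        # MessageType.QUERY
        "era5",         # dataset_id
        "reanalysis",   # product_id
        query_payload,  # handed to GeoQuery.parse by Message.__init__
    ]
).encode()

# Consumer side (see Executor.handle_message above):
#   message = Message(body)
#   message.request_id == "42"; message.type is MessageType.QUERY;
#   message.content is the GeoQuery parsed from the payload.
```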