diff --git a/src/openeo_gfmap/fetching/generic.py b/src/openeo_gfmap/fetching/generic.py index 21dc209..46edb74 100644 --- a/src/openeo_gfmap/fetching/generic.py +++ b/src/openeo_gfmap/fetching/generic.py @@ -1,10 +1,10 @@ """ Generic extraction of features, supporting VITO backend. """ -from functools import partial -from typing import Callable +from typing import Callable, Optional import openeo +from openeo.rest import OpenEoApiError from openeo_gfmap.backend import Backend, BackendContext from openeo_gfmap.fetching import CollectionFetcher, FetchType, _log @@ -28,15 +28,18 @@ "vapour-pressure": "AGERA5-VAPOUR", "wind-speed": "AGERA5-WIND", } +KNOWN_UNTEMPORAL_COLLECTIONS = ["COPERNICUS_30"] -def _get_generic_fetcher(collection_name: str, fetch_type: FetchType) -> Callable: +def _get_generic_fetcher( + collection_name: str, fetch_type: FetchType, backend: Backend +) -> Callable: + band_mapping: Optional[dict] = None + if collection_name == "COPERNICUS_30": - BASE_MAPPING = BASE_DEM_MAPPING + band_mapping = BASE_DEM_MAPPING elif collection_name == "AGERA5": - BASE_MAPPING = BASE_WEATHER_MAPPING - else: - raise Exception("Please choose a valid collection.") + band_mapping = BASE_WEATHER_MAPPING def generic_default_fetcher( connection: openeo.Connection, @@ -45,23 +48,34 @@ def generic_default_fetcher( bands: list, **params, ) -> openeo.DataCube: - bands = convert_band_names(bands, BASE_MAPPING) + if band_mapping is not None: + bands = convert_band_names(bands, band_mapping) - if (collection_name == "COPERNICUS_30") and (temporal_extent is not None): + if (collection_name in KNOWN_UNTEMPORAL_COLLECTIONS) and ( + temporal_extent is not None + ): _log.warning( - "User set-up non None temporal extent for DEM collection. Ignoring it." + "Ignoring the temporal extent provided by the user as the collection %s is known to be untemporal.", + collection_name, ) temporal_extent = None - cube = _load_collection( - connection, - bands, - collection_name, - spatial_extent, - temporal_extent, - fetch_type, - **params, - ) + try: + cube = _load_collection( + connection, + bands, + collection_name, + spatial_extent, + temporal_extent, + fetch_type, + **params, + ) + except OpenEoApiError as e: + if "CollectionNotFound" in str(e): + raise ValueError( + f"Collection {collection_name} not found in the selected backend {backend.value}." + ) from e + raise e # # Apply if the collection is a GeoJSON Feature collection # if isinstance(spatial_extent, GeoJSON): @@ -76,12 +90,11 @@ def _get_generic_processor(collection_name: str, fetch_type: FetchType) -> Calla """Builds the preprocessing function from the collection name as it stored in the target backend. """ + band_mapping: Optional[dict] = None if collection_name == "COPERNICUS_30": - BASE_MAPPING = BASE_DEM_MAPPING + band_mapping = BASE_DEM_MAPPING elif collection_name == "AGERA5": - BASE_MAPPING = BASE_WEATHER_MAPPING - else: - raise Exception("Please choose a valid collection.") + band_mapping = BASE_WEATHER_MAPPING def generic_default_processor(cube: openeo.DataCube, **params): """Default collection preprocessing method for generic datasets. 
@@ -99,51 +112,14 @@ def generic_default_processor(cube: openeo.DataCube, **params): if collection_name == "COPERNICUS_30": cube = cube.min_time() - cube = rename_bands(cube, BASE_MAPPING) + if band_mapping is not None: + cube = rename_bands(cube, band_mapping) return cube return generic_default_processor -OTHER_BACKEND_MAP = { - "AGERA5": { - Backend.TERRASCOPE: { - "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), - "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), - }, - Backend.CDSE: { - "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), - "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), - }, - Backend.FED: { - "fetch": partial(_get_generic_fetcher, collection_name="AGERA5"), - "preprocessor": partial(_get_generic_processor, collection_name="AGERA5"), - }, - }, - "COPERNICUS_30": { - Backend.TERRASCOPE: { - "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), - "preprocessor": partial( - _get_generic_processor, collection_name="COPERNICUS_30" - ), - }, - Backend.CDSE: { - "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), - "preprocessor": partial( - _get_generic_processor, collection_name="COPERNICUS_30" - ), - }, - Backend.FED: { - "fetch": partial(_get_generic_fetcher, collection_name="COPERNICUS_30"), - "preprocessor": partial( - _get_generic_processor, collection_name="COPERNICUS_30" - ), - }, - }, -} - - def build_generic_extractor( backend_context: BackendContext, bands: list, @@ -152,13 +128,7 @@ def build_generic_extractor( **params, ) -> CollectionFetcher: """Creates a generic extractor adapted to the given backend. Currently only tested with VITO backend""" - backend_functions = OTHER_BACKEND_MAP.get(collection_name).get( - backend_context.backend - ) - - fetcher, preprocessor = ( - backend_functions["fetch"](fetch_type=fetch_type), - backend_functions["preprocessor"](fetch_type=fetch_type), - ) + fetcher = _get_generic_fetcher(collection_name, fetch_type, backend_context.backend) + preprocessor = _get_generic_processor(collection_name, fetch_type) return CollectionFetcher(backend_context, bands, fetcher, preprocessor, **params) diff --git a/src/openeo_gfmap/fetching/s1.py b/src/openeo_gfmap/fetching/s1.py index 6081d40..97d6fc2 100644 --- a/src/openeo_gfmap/fetching/s1.py +++ b/src/openeo_gfmap/fetching/s1.py @@ -67,8 +67,6 @@ def s1_grd_fetch_default( """ bands = convert_band_names(bands, BASE_SENTINEL1_GRD_MAPPING) - load_collection_parameters = params.get("load_collection", {}) - cube = _load_collection( connection, bands, @@ -76,7 +74,7 @@ def s1_grd_fetch_default( spatial_extent, temporal_extent, fetch_type, - **load_collection_parameters, + **params, ) if fetch_type is not FetchType.POINT and isinstance(spatial_extent, GeoJSON): diff --git a/src/openeo_gfmap/manager/job_manager.py b/src/openeo_gfmap/manager/job_manager.py index 9a9f7b2..986bbf4 100644 --- a/src/openeo_gfmap/manager/job_manager.py +++ b/src/openeo_gfmap/manager/job_manager.py @@ -1,9 +1,11 @@ import json +import pickle import threading +import time from concurrent.futures import ThreadPoolExecutor -from enum import Enum from functools import partial from pathlib import Path +from threading import Lock from typing import Callable, Optional, Union import pandas as pd @@ -16,28 +18,62 @@ from openeo_gfmap.stac import constants # Lock to use when writing to the STAC collection -_stac_lock = threading.Lock() +_stac_lock = Lock() + + +def retry_on_exception(max_retries: int, delay_s: 
int = 180): + """Decorator to retry a function if an exception occurs. + Used for post-job actions that can crash due to internal backend issues. Restarting the action + usually helps to solve the issue. + + Parameters + ---------- + max_retries: int + The maximum number of retries to attempt before finally raising the exception. + delay_s: int (default=180 seconds) + The delay in seconds to wait before retrying the decorated function. + """ + + def decorator(func): + def wrapper(*args, **kwargs): + latest_exception = None + for _ in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + time.sleep( + delay_s + ) # Waits before retrying, while allowing other futures to run. + latest_exception = e + raise latest_exception + + return wrapper + + return decorator def done_callback(future, df, idx): - """Sets the status of the job to the given status when the future is done.""" + """Changes the status of the job when the post-job action future is done.""" current_status = df.loc[idx, "status"] - if not future.exception(): + exception = future.exception() + if exception is None: if current_status == "postprocessing": df.loc[idx, "status"] = "finished" elif current_status == "postprocessing-error": df.loc[idx, "status"] = "error" + elif current_status == "running": + df.loc[idx, "status"] = "running" else: raise ValueError( f"Invalid status {current_status} for job {df.loc[idx, 'id']} for done_callback!" ) - - -class PostJobStatus(Enum): - """Indicates the workers if the job finished as sucessful or with an error.""" - - FINISHED = "finished" - ERROR = "error" + else: + _log.exception( + "Exception occurred in post-job future for job %s:\n%s", + df.loc[idx, "id"], + exception, + ) + df.loc[idx, "status"] = "error" class GFMAPJobManager(MultiBackendJobManager): @@ -53,13 +89,50 @@ def __init__( post_job_action: Optional[Callable] = None, poll_sleep: int = 5, n_threads: int = 1, - post_job_params: dict = {}, resume_postproc: bool = True, # If we need to check for post-job actions that crashed restart_failed: bool = False, # If we need to restart failed jobs + stac_enabled: bool = True, ): + """ + Initializes the GFMAP job manager. + + Parameters + ---------- + output_dir: Path + The base output directory where the results/stac/logs of the jobs will be stored. + output_path_generator: Callable + User defined function that generates the output path for the job results. Expects as + inputs the output directory, the index of the job in the job dataframe + and the row of the job, and returns the final path where to save a job result asset. + collection_id: Optional[str] + The ID of the STAC collection that is being generated. Can be left empty if the STAC + catalogue is not being generated or if it is being resumed from an existing catalogue. + collection_description: Optional[str] + The description of the STAC collection that is being generated. + stac: Optional[Union[str, Path]] + The path to the STAC collection to be saved or resumed. + If None, the default path will be used. + post_job_action: Optional[Callable] + A user defined function that will be called after a job is finished. It will receive + the list of items generated by the job and the row of the job, and should return the + updated list of items. + poll_sleep: int + The time in seconds to wait between polling the backend for job status. + n_threads: int + The number of threads to execute `on_job_done` and `on_job_error` functions. 
+ resume_postproc: bool + If set to true, all `on_job_done` and `on_job_error` functions that failed are resumed. + restart_failed: bool + If set to true, all jobs that failed within the OpenEO backend are restarted. + stac_enabled: bool (default=True) + If the STAC generation is enabled or not. Disabling it will prevent the creation, + update and loading of the STAC collection. + """ self._output_dir = output_dir + self._catalogue_cache = output_dir / "catalogue_cache.bin" self.stac = stac + self.stac_enabled = stac_enabled self.collection_id = collection_id self.collection_description = collection_description @@ -74,41 +147,73 @@ def __init__( self._output_path_gen = output_path_generator self._post_job_action = post_job_action - self._post_job_params = post_job_params # Monkey patching the _normalize_df method to ensure we have no modification on the # geometry column MultiBackendJobManager._normalize_df = self._normalize_df super().__init__(poll_sleep) - self._root_collection = self._normalize_stac() + if self.stac_enabled: + self._root_collection = self._initialize_stac() - def _normalize_stac(self): + def _load_stac(self) -> Optional[pystac.Collection]: + """ + Loads the STAC collection from the cache, the specified `stac` path or the default path. + If no STAC collection is found, returns None. + """ default_collection_path = self._output_dir / "stac/collection.json" - if self.stac is not None: + if self._catalogue_cache.exists(): _log.info( - f"Reloading the STAC collection from the provided path: {self.stac}." + "Loading the STAC collection from the persisted binary file: %s.", + self._catalogue_cache, ) - root_collection = pystac.read_file(str(self.stac)) + with open(self._catalogue_cache, "rb") as file: + return pickle.load(file) + elif self.stac is not None: + _log.info( + "Reloading the STAC collection from the provided path: %s.", self.stac + ) + return pystac.read_file(str(self.stac)) elif default_collection_path.exists(): _log.info( - f"Reload the STAC collection from the default path: {default_collection_path}." + "Reload the STAC collection from the default path: %s.", + default_collection_path, ) self.stac = default_collection_path - root_collection = pystac.read_file(str(self.stac)) - else: - _log.info("Starting a fresh STAC collection.") - assert ( - self.collection_id is not None - ), "A collection ID is required to generate a STAC collection." - root_collection = pystac.Collection( - id=self.collection_id, - description=self.collection_description, - extent=None, + return pystac.read_file(str(self.stac)) + + _log.info( + "No STAC collection found as cache, in the default path or in the provided path." + ) + return None + + def _create_stac(self) -> pystac.Collection: + """ + Creates and returns new STAC collection. The created stac collection will use the + `collection_id` and `collection_description` parameters set in the constructor. + """ + if self.collection_id is None: + raise ValueError( + "A collection ID is required to generate a STAC collection." 
) - root_collection.license = constants.LICENSE - root_collection.add_link(constants.LICENSE_LINK) - root_collection.stac_extensions = constants.STAC_EXTENSIONS + collection = pystac.Collection( + id=self.collection_id, + description=self.collection_description, + extent=None, + ) + collection.license = constants.LICENSE + collection.add_link(constants.LICENSE_LINK) + collection.stac_extensions = constants.STAC_EXTENSIONS + return collection + + def _initialize_stac(self) -> pystac.Collection: + """ + Loads and returns if possible an existing stac collection, otherwise creates a new one. + """ + root_collection = self._load_stac() + if not root_collection: + _log.info("Starting a fresh STAC collection.") + root_collection = self._create_stac() return root_collection @@ -150,16 +255,30 @@ def _resume_postjob_actions(self, df: pd.DataFrame): job = connection.job(row.id) if row.status == "postprocessing": _log.info( - f"Resuming postprocessing of job {row.id}, queueing on_job_finished..." + "Resuming postprocessing of job %s, queueing on_job_finished...", + row.id, + ) + future = self._executor.submit(self.on_job_done, job, row, _stac_lock) + future.add_done_callback( + partial( + done_callback, + df=df, + idx=idx, + ) ) - future = self._executor.submit(self.on_job_done, job, row) - future.add_done_callback(partial(done_callback, df=df, idx=idx)) else: _log.info( - f"Resuming postprocessing of job {row.id}, queueing on_job_error..." + "Resuming postprocessing of job %s, queueing on_job_error...", + row.id, ) future = self._executor.submit(self.on_job_error, job, row) - future.add_done_callback(partial(done_callback, df=df, idx=idx)) + future.add_done_callback( + partial( + done_callback, + df=df, + idx=idx, + ) + ) self._futures.append(future) def _restart_failed_jobs(self, df: pd.DataFrame): @@ -167,7 +286,9 @@ def _restart_failed_jobs(self, df: pd.DataFrame): failed_tasks = df[df.status.isin(["error", "start_failed"])] not_started_tasks = df[df.status == "not_started"] _log.info( - f"Resetting {len(failed_tasks)} failed jobs to 'not_started'. {len(not_started_tasks)} jobs are already 'not_started'." + "Resetting %s failed jobs to 'not_started'. %s jobs are already 'not_started'.", + len(failed_tasks), + len(not_started_tasks), ) for idx, _ in failed_tasks.iterrows(): df.loc[idx, "status"] = "not_started" @@ -203,38 +324,53 @@ def _update_statuses(self, df: pd.DataFrame): job_metadata["status"] == "finished" ): _log.info( - f"Job {job.job_id} finished successfully, queueing on_job_done..." 
+ "Job %s finished successfully, queueing on_job_done...", job.job_id ) job_status = "postprocessing" - future = self._executor.submit(self.on_job_done, job, row) + future = self._executor.submit(self.on_job_done, job, row, _stac_lock) # Future will setup the status to finished when the job is done - future.add_done_callback(partial(done_callback, df=df, idx=idx)) - self._futures.append(future) - df.loc[idx, "costs"] = job_metadata["costs"] - df.loc[idx, "memory"] = ( - job_metadata["usage"] - .get("max_executor_memory", {}) - .get("value", None) - ) - df.loc[idx, "cpu"] = ( - job_metadata["usage"].get("cpu", {}).get("value", None) - ) - df.loc[idx, "duration"] = ( - job_metadata["usage"].get("duration", {}).get("value", None) + future.add_done_callback( + partial( + done_callback, + df=df, + idx=idx, + ) ) + self._futures.append(future) + if "costs" in job_metadata: + df.loc[idx, "costs"] = job_metadata["costs"] + df.loc[idx, "memory"] = ( + job_metadata["usage"] + .get("max_executor_memory", {}) + .get("value", None) + ) + + else: + _log.warning( + "Costs not found in job %s metadata. Costs will be set to 'None'.", + job.job_id, + ) # Case in which it failed if (df.loc[idx, "status"] != "error") and ( job_metadata["status"] == "error" ): _log.info( - f"Job {job.job_id} finished with error, queueing on_job_error..." + "Job %s finished with error, queueing on_job_error...", + job.job_id, ) job_status = "postprocessing-error" future = self._executor.submit(self.on_job_error, job, row) # Future will setup the status to error when the job is done - future.add_done_callback(partial(done_callback, df=df, idx=idx)) + future.add_done_callback( + partial( + done_callback, + df=df, + idx=idx, + ) + ) self._futures.append(future) + if "costs" in job_metadata: df.loc[idx, "costs"] = job_metadata["costs"] df.loc[idx, "status"] = job_status @@ -242,6 +378,7 @@ def _update_statuses(self, df: pd.DataFrame): # Clear the futures that are done and raise their potential exceptions if they occurred. self._clear_queued_actions() + @retry_on_exception(max_retries=2, delay_s=180) def on_job_error(self, job: BatchJob, row: pd.Series): """Method called when a job finishes with an error. @@ -252,7 +389,14 @@ def on_job_error(self, job: BatchJob, row: pd.Series): row: pd.Series The row in the dataframe that contains the job relative information. """ - logs = job.logs() + try: + logs = job.logs() + except Exception as e: # pylint: disable=broad-exception-caught + _log.exception( + "Error getting logs in `on_job_error` for job %s:\n%s", job.job_id, e + ) + logs = [] + error_logs = [log for log in logs if log.level.lower() == "error"] job_metadata = job.describe_job() @@ -271,15 +415,21 @@ def on_job_error(self, job: BatchJob, row: pd.Series): f"Couldn't find any error logs. Please check the error manually on job ID: {job.job_id}." ) - def on_job_done(self, job: BatchJob, row: pd.Series): + @retry_on_exception(max_retries=2, delay_s=30) + def on_job_done( + self, job: BatchJob, row: pd.Series, lock: Lock + ): # pylint: disable=arguments-differ """Method called when a job finishes successfully. It will first download the results of the job and then call the `post_job_action` method. """ + job_products = {} for idx, asset in enumerate(job.get_results().get_assets()): try: _log.debug( - f"Generating output path for asset {asset.name} from job {job.job_id}..." 
+ "Generating output path for asset %s from job %s...", + asset.name, + job.job_id, ) output_path = self._output_path_gen(self._output_dir, idx, row) # Make the output path @@ -288,11 +438,17 @@ def on_job_done(self, job: BatchJob, row: pd.Series): # Add to the list of downloaded products job_products[f"{job.job_id}_{asset.name}"] = [output_path] _log.debug( - f"Downloaded {asset.name} from job {job.job_id} -> {output_path}" + "Downloaded %s from job %s -> %s", + asset.name, + job.job_id, + output_path, ) except Exception as e: _log.exception( - f"Error downloading asset {asset.name} from job {job.job_id}", e + "Error downloading asset %s from job %s:\n%s", + asset.name, + job.job_id, + e, ) raise e @@ -313,53 +469,35 @@ def on_job_done(self, job: BatchJob, row: pd.Series): asset.href = str( asset_path ) # Update the asset href to the output location set by the output_path_generator - # item.id = f"{job.job_id}_{item.id}" + # Add the item to the the current job items. job_items.append(item) - _log.info(f"Parsed item {item.id} from job {job.job_id}") + _log.info("Parsed item %s from job %s", item.id, job.job_id) except Exception as e: _log.exception( - f"Error failed to add item {item.id} from job {job.job_id} to STAC collection", + "Error failed to add item %s from job %s to STAC collection:\n%s", + item.id, + job.job_id, e, ) - raise e # _post_job_action returns an updated list of stac items. Post job action can therefore # update the stac items and access their products through the HREF. It is also the # reponsible of adding the appropriate metadata/assets to the items. if self._post_job_action is not None: - _log.debug(f"Calling post job action for job {job.job_id}...") - job_items = self._post_job_action(job_items, row, self._post_job_params) + _log.debug("Calling post job action for job %s...", job.job_id) + job_items = self._post_job_action(job_items, row) - _log.info(f"Adding {len(job_items)} items to the STAC collection...") + _log.info("Adding %s items to the STAC collection...", len(job_items)) - with _stac_lock: # Take the STAC lock to avoid concurrence issues - # Filters the job items to only keep the ones that are not already in the collection - existing_ids = [item.id for item in self._root_collection.get_all_items()] - job_items = [item for item in job_items if item.id not in existing_ids] + if self.stac_enabled: + with lock: + self._update_stac(job.job_id, job_items) - self._root_collection.add_items(job_items) - _log.info(f"Added {len(job_items)} items to the STAC collection.") - - _log.info(f"Writing STAC collection for {job.job_id} to file...") - try: - self._write_stac() - except Exception as e: - _log.exception( - f"Error writing STAC collection for job {job.job_id} to file.", e - ) - raise e - _log.info(f"Wrote STAC collection for {job.job_id} to file.") - - _log.info(f"Job {job.job_id} and post job action finished successfully.") + _log.info("Job %s and post job action finished successfully.", job.job_id) def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame: - """Ensure we have the required columns and the expected type for the geometry column. - - :param df: The dataframe to normalize. - :return: a new dataframe that is normalized. - """ - + """Ensure we have the required columns and the expected type for the geometry column.""" # check for some required columns. required_with_default = [ ("status", "not_started"), @@ -377,7 +515,7 @@ def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame: } df = df.assign(**new_columns) - _log.debug(f"Normalizing dataframe. 
Columns: {df.columns}") + _log.debug("Normalizing dataframe. Columns: %s", df.columns) return df @@ -412,7 +550,7 @@ def run_jobs( The file to track the results of the jobs. """ # Starts the thread pool to work on the on_job_done and on_job_error methods - _log.info(f"Starting ThreadPoolExecutor with {self._n_threads} workers.") + _log.info("Starting ThreadPoolExecutor with %s workers.", self._n_threads) with ThreadPoolExecutor(max_workers=self._n_threads) as executor: _log.info("Creating and running jobs.") self._executor = executor @@ -423,6 +561,13 @@ def run_jobs( self._wait_queued_actions() _log.info("Exiting ThreadPoolExecutor.") self._executor = None + _log.info("All jobs finished running.") + if self.stac_enabled: + _log.info("Saving persisted STAC collection to final .json collection.") + self._write_stac() + _log.info("Saved STAC catalogue to JSON format, all tasks finished!") + else: + _log.info("STAC was disabled, skipping generation of the catalogue.") def _write_stac(self): """Writes the STAC collection to the output directory.""" @@ -439,6 +584,36 @@ def _write_stac(self): self._root_collection.normalize_hrefs(str(root_path)) self._root_collection.save(catalog_type=CatalogType.SELF_CONTAINED) + def _persist_stac(self): + """Persists the STAC collection by saving it into a binary file.""" + _log.debug("Validating the STAC collection before persisting.") + self._root_collection.validate_all() + _log.info("Persisting STAC collection to temp file %s.", self._catalogue_cache) + with open(self._catalogue_cache, "wb") as file: + pickle.dump(self._root_collection, file) + + def _update_stac(self, job_id: str, job_items: list[pystac.Item]): + """Updates the STAC collection by adding the items generated by the job. + Does not add duplicates or override with the same item ID. + """ + try: + _log.info("Thread %s entered the STAC lock.", threading.get_ident()) + # Filters the job items to only keep the ones that are not already in the collection + existing_ids = [item.id for item in self._root_collection.get_all_items()] + job_items = [item for item in job_items if item.id not in existing_ids] + + self._root_collection.add_items(job_items) + _log.info("Added %s items to the STAC collection.", len(job_items)) + + self._persist_stac() + except Exception as e: + _log.exception( + "Error adding items to the STAC collection for job %s:\n%s ", + job_id, + str(e), + ) + raise e + def setup_stac( self, constellation: Optional[str] = None, diff --git a/src/openeo_gfmap/manager/job_splitters.py b/src/openeo_gfmap/manager/job_splitters.py index 7ed9a5f..19ede9c 100644 --- a/src/openeo_gfmap/manager/job_splitters.py +++ b/src/openeo_gfmap/manager/job_splitters.py @@ -2,6 +2,7 @@ form of a GeoDataFrames. """ +from functools import lru_cache from pathlib import Path from typing import List @@ -30,6 +31,11 @@ def load_s2_grid(web_mercator: bool = False) -> gpd.GeoDataFrame: url, timeout=180, # 3mins ) + if response.status_code != 200: + raise ValueError( + "Failed to download the S2 grid from the artifactory. 
" + f"Status code: {response.status_code}" + ) with open(gdf_path, "wb") as f: f.write(response.content) return gpd.read_parquet(gdf_path) @@ -66,6 +72,10 @@ def split_job_s2grid( if polygons.crs is None: raise ValueError("The GeoDataFrame must contain a CRS") +<<<<<<< HEAD + polygons = polygons.to_crs(epsg=4326) + polygons["geometry"] = polygons.geometry.centroid +======= epsg = 3857 if web_mercator else 4326 original_crs = polygons.crs @@ -73,16 +83,26 @@ def split_job_s2grid( polygons = polygons.to_crs(epsg=epsg) polygons["centroid"] = polygons.geometry.centroid +>>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8 # Dataset containing all the S2 tiles, find the nearest S2 tile for each point s2_grid = load_s2_grid(web_mercator) s2_grid["geometry"] = s2_grid.geometry.centroid +<<<<<<< HEAD + # Filter tiles on CDSE availability + s2_grid = s2_grid[s2_grid.cdse_valid] + + polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop( + columns=["index_right"] + ) +======= polygons = gpd.sjoin_nearest( polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]] ).drop(columns=["index_right", "centroid"]) polygons = polygons.set_geometry("geometry").to_crs(original_crs) +>>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8 split_datasets = [] for _, sub_gdf in polygons.groupby("tile"): diff --git a/src/openeo_gfmap/utils/catalogue.py b/src/openeo_gfmap/utils/catalogue.py index 2553a5d..e31f3af 100644 --- a/src/openeo_gfmap/utils/catalogue.py +++ b/src/openeo_gfmap/utils/catalogue.py @@ -1,9 +1,12 @@ """Functionalities to interract with product catalogues.""" +from typing import Optional + import geojson import requests from pyproj.crs import CRS from rasterio.warp import transform_bounds +from requests import adapters from shapely.geometry import box, shape from shapely.ops import unary_union @@ -15,6 +18,20 @@ TemporalContext, ) +request_sessions: Optional[requests.Session] = None + + +def _request_session() -> requests.Session: + global request_sessions + + if request_sessions is None: + request_sessions = requests.Session() + retries = adapters.Retry( + total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504] + ) + request_sessions.mount("https://", adapters.HTTPAdapter(max_retries=retries)) + return request_sessions + class UncoveredS1Exception(Exception): """Exception raised when there is no product available to fully cover spatially a given @@ -39,6 +56,14 @@ def _query_cdse_catalogue( temporal_extent: TemporalContext, **additional_parameters: dict, ) -> dict: + """ + Queries the CDSE catalogue for a given collection, spatio-temporal context and additional + parameters. 
+ + Parameters + ---------- + collection: + The CDSE collection name to query (e.g. "Sentinel1"). + bounds: + The bounding box of the spatial extent, as (minx, miny, maxx, maxy). + temporal_extent: TemporalContext + The temporal range to query; dates are formatted as YYYY-MM-DD. + additional_parameters: + Additional query parameters appended to the request URL (e.g. orbitDirection). + """ minx, miny, maxx, maxy = bounds # The date format should be YYYY-MM-DD @@ -48,13 +73,14 @@ def _query_cdse_catalogue( url = ( f"https://catalogue.dataspace.copernicus.eu/resto/api/collections/" f"{collection}/search.json?box={minx},{miny},{maxx},{maxy}" - f"&sortParam=startDate&maxRecords=100" - f"&dataset=ESA-DATASET&startDate={start_date}&completionDate={end_date}" + f"&sortParam=startDate&maxRecords=1000&dataset=ESA-DATASET" + f"&startDate={start_date}&completionDate={end_date}" ) for key, value in additional_parameters.items(): url += f"&{key}={value}" - response = requests.get(url) + session = _request_session() + response = session.get(url, timeout=60) if response.status_code != 200: raise Exception( @@ -107,19 +133,20 @@ def _check_cdse_catalogue( return len(grd_tiles) > 0 -def s1_area_per_orbitstate( +def s1_area_per_orbitstate_vvvh( backend: BackendContext, spatial_extent: SpatialContext, temporal_extent: TemporalContext, ) -> dict: - """Evaluates for both the ascending and descending state orbits the area of interesection - between the given spatio-temporal context and the products available in the backend's - catalogue. + """ + Evaluates for both the ascending and descending state orbits the area of intersection for the + available products with a VV&VH polarisation. Parameters ---------- backend : BackendContext - The backend to be within, as each backend might use different catalogues. + The backend to be within, as each backend might use different catalogues. Only the CDSE, + CDSE_STAGING and FED backends are supported. spatial_extent : SpatialContext The spatial extent to be checked, it will check within its bounding box. temporal_extent : TemporalContext @@ -159,7 +186,11 @@ if backend.backend in [Backend.CDSE, Backend.CDSE_STAGING, Backend.FED]: ascending_products = _parse_cdse_products( _query_cdse_catalogue( - "Sentinel1", bounds, temporal_extent, orbitDirection="ASCENDING" + "Sentinel1", + bounds, + temporal_extent, + orbitDirection="ASCENDING", + polarisation="VV%26VH", ) ) descending_products = _parse_cdse_products( @@ -168,6 +199,7 @@ bounds, temporal_extent, orbitDirection="DESCENDING", + polarisation="VV%26VH", ) ) else: @@ -204,18 +236,19 @@ } -def select_S1_orbitstate( +def select_s1_orbitstate_vvvh( backend: BackendContext, spatial_extent: SpatialContext, temporal_extent: TemporalContext, ) -> str: - """Selects the orbit state that covers the most area of the given spatio-temporal context - for the Sentinel-1 collection. + """Selects the orbit state that covers the most area of intersection for the + available products with a VV&VH polarisation. Parameters ---------- backend : BackendContext - The backend to be within, as each backend might use different catalogues. + The backend to be within, as each backend might use different catalogues. Only the CDSE, + CDSE_STAGING and FED backends are supported. spatial_extent : SpatialContext The spatial extent to be checked, it will check within its bounding box. 
temporal_extent : TemporalContext @@ -228,7 +261,7 @@ def select_S1_orbitstate( """ # Queries the products in the catalogues - areas = s1_area_per_orbitstate(backend, spatial_extent, temporal_extent) + areas = s1_area_per_orbitstate_vvvh(backend, spatial_extent, temporal_extent) ascending_overlap = areas["ASCENDING"]["full_overlap"] descending_overlap = areas["DESCENDING"]["full_overlap"] diff --git a/tests/test_openeo_gfmap/test_generic_fetchers.py b/tests/test_openeo_gfmap/test_generic_fetchers.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_openeo_gfmap/test_s1_fetchers.py b/tests/test_openeo_gfmap/test_s1_fetchers.py index d805fb8..01e41e2 100644 --- a/tests/test_openeo_gfmap/test_s1_fetchers.py +++ b/tests/test_openeo_gfmap/test_s1_fetchers.py @@ -50,7 +50,7 @@ def sentinel1_grd( "elevation_model": "COPERNICUS_30", "coefficient": "gamma0-ellipsoid", "load_collection": { - "polarization": lambda polar: (polar == "VV") or (polar == "VH"), + "polarization": lambda polar: polar == "VV&VH", }, } @@ -156,7 +156,7 @@ def sentinel1_grd_point_based( "elevation_model": "COPERNICUS_30", "coefficient": "gamma0-ellipsoid", "load_collection": { - "polarization": lambda polar: (polar == "VV") or (polar == "VH"), + "polarization": lambda polar: polar == "VV&VH", }, } extractor = build_sentinel1_grd_extractor( diff --git a/tests/test_openeo_gfmap/test_utils.py b/tests/test_openeo_gfmap/test_utils.py index e63dc1f..95c241a 100644 --- a/tests/test_openeo_gfmap/test_utils.py +++ b/tests/test_openeo_gfmap/test_utils.py @@ -7,7 +7,10 @@ from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext from openeo_gfmap.utils import split_collection_by_epsg, update_nc_attributes -from openeo_gfmap.utils.catalogue import s1_area_per_orbitstate, select_S1_orbitstate +from openeo_gfmap.utils.catalogue import ( + s1_area_per_orbitstate_vvvh, + select_s1_orbitstate_vvvh, +) # Region of Paris, France SPATIAL_CONTEXT = BoundingBoxExtent( @@ -21,7 +24,7 @@ def test_query_cdse_catalogue(): backend_context = BackendContext(Backend.CDSE) - response = s1_area_per_orbitstate( + response = s1_area_per_orbitstate_vvvh( backend=backend_context, spatial_extent=SPATIAL_CONTEXT, temporal_extent=TEMPORAL_CONTEXT, @@ -42,7 +45,7 @@ def test_query_cdse_catalogue(): assert response["DESCENDING"]["full_overlap"] is True # Testing the decision maker, it should return DESCENDING - decision = select_S1_orbitstate( + decision = select_s1_orbitstate_vvvh( backend=backend_context, spatial_extent=SPATIAL_CONTEXT, temporal_extent=TEMPORAL_CONTEXT,
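
For context, a minimal usage sketch of the `retry_on_exception` decorator added to `openeo_gfmap/manager/job_manager.py` above (the job manager already applies it to `on_job_done` and `on_job_error`); the `flaky_post_job_action` function and the job ID below are hypothetical stand-ins, not part of the library:

from openeo_gfmap.manager.job_manager import retry_on_exception

@retry_on_exception(max_retries=2, delay_s=5)
def flaky_post_job_action(job_id: str) -> str:
    # Any exception raised here makes the wrapper sleep for `delay_s` seconds and retry;
    # after `max_retries` failed attempts the last exception is re-raised.
    return f"processed {job_id}"

flaky_post_job_action("some-job-id")  # hypothetical job ID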