diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py
index 48f7c1ba..3d4fbecb 100644
--- a/scripts/inference/cropland_mapping.py
+++ b/scripts/inference/cropland_mapping.py
@@ -3,16 +3,10 @@
 import argparse
 from pathlib import Path
 
-import openeo
 from openeo_gfmap import BoundingBoxExtent, TemporalContext
 from openeo_gfmap.backend import Backend, BackendContext
-from openeo_gfmap.features.feature_extractor import apply_feature_extractor
-from openeo_gfmap.inference.model_inference import apply_model_inference
-from openeo_gfmap.preprocessing.scaling import compress_uint8
 
-from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
-from worldcereal.openeo.inference import CroplandClassifier
-from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
+from worldcereal.job import generate_map
 
 ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
 
@@ -58,63 +52,11 @@
 
     backend_context = BackendContext(Backend.FED)
 
-    connection = openeo.connect(
-        "https://openeo.creo.vito.be/openeo/"
-    ).authenticate_oidc()
-
-    # Preparing the input cube for the inference
-    inputs = worldcereal_preprocessed_inputs_gfmap(
-        connection=connection,
-        backend_context=backend_context,
-        spatial_extent=spatial_extent,
-        temporal_extent=temporal_extent,
-    )
-
-    # Test feature computer
-    presto_parameters = {
-        "rescale_s1": False,  # Will be done in the Presto UDF itself!
-    }
-
-    features = apply_feature_extractor(
-        feature_extractor_class=PrestoFeatureExtractor,
-        cube=inputs,
-        parameters=presto_parameters,
-        size=[
-            {"dimension": "x", "unit": "px", "value": 100},
-            {"dimension": "y", "unit": "px", "value": 100},
-        ],
-        overlap=[
-            {"dimension": "x", "unit": "px", "value": 0},
-            {"dimension": "y", "unit": "px", "value": 0},
-        ],
-    )
-
-    catboost_parameters = {}
-
-    classes = apply_model_inference(
-        model_inference_class=CroplandClassifier,
-        cube=features,
-        parameters=catboost_parameters,
-        size=[
-            {"dimension": "x", "unit": "px", "value": 100},
-            {"dimension": "y", "unit": "px", "value": 100},
-            {"dimension": "t", "value": "P1D"},
-        ],
-        overlap=[
-            {"dimension": "x", "unit": "px", "value": 0},
-            {"dimension": "y", "unit": "px", "value": 0},
-        ],
-    )
-
-    # Cast to uint8
-    classes = compress_uint8(classes)
-
-    classes.execute_batch(
-        outputfile=args.output_path,
-        out_format="GTiff",
-        job_options={
-            "driver-memory": "4g",
-            "executor-memoryOverhead": "12g",
-            "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
-        },
+    generate_map(
+        spatial_extent,
+        temporal_extent,
+        backend_context,
+        args.output_path,
+        product="cropland",
+        format="GTiff",
     )
diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py
new file mode 100644
index 00000000..36367491
--- /dev/null
+++ b/src/worldcereal/job.py
@@ -0,0 +1,110 @@
+from pathlib import Path
+from typing import Union
+
+import openeo
+from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
+from openeo_gfmap.features.feature_extractor import apply_feature_extractor
+from openeo_gfmap.inference.model_inference import apply_model_inference
+from openeo_gfmap.preprocessing.scaling import compress_uint8
+
+from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
+from worldcereal.openeo.inference import CroplandClassifier
+from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
+
+ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
+
+
+def generate_map(
+    spatial_extent: BoundingBoxExtent,
+    temporal_extent: TemporalContext,
+    backend_context: BackendContext,
+    output_path: Union[Path, str],
+    product: str = "cropland",
+    format: str = "GTiff",
+):
+    """Main function to generate a WorldCereal product.
+
+    Args:
+        spatial_extent (BoundingBoxExtent): spatial extent of the map
+        temporal_extent (TemporalContext): temporal range to consider
+        backend_context (BackendContext): backend to run the job on
+        output_path (Union[Path, str]): output path to download the product to
+        product (str, optional): product identifier. Defaults to "cropland".
+        format (str, optional): output format, either "GTiff" or "NetCDF".
+            Defaults to "GTiff".
+
+    Raises:
+        ValueError: if the product or the output format is not supported
+    """
+
+    # Validate the request up front, so that an unsupported product or
+    # format fails immediately instead of after the OIDC authentication.
+    if product == "cropland":
+        # initiate default cropland model
+        model_inference_class = CroplandClassifier
+        model_inference_parameters = {}
+    else:
+        raise ValueError(f"Product {product} not supported.")
+
+    if format not in ["GTiff", "NetCDF"]:
+        raise ValueError(f"Format {format} not supported.")
+
+    # Connect to openeo
+    connection = openeo.connect(
+        "https://openeo.creo.vito.be/openeo/"
+    ).authenticate_oidc()
+
+    # Preparing the input cube for the inference
+    inputs = worldcereal_preprocessed_inputs_gfmap(
+        connection=connection,
+        backend_context=backend_context,
+        spatial_extent=spatial_extent,
+        temporal_extent=temporal_extent,
+    )
+
+    # Run feature computer
+    presto_parameters = {
+        "rescale_s1": False,  # Will be done in the Presto UDF itself!
+    }
+
+    features = apply_feature_extractor(
+        feature_extractor_class=PrestoFeatureExtractor,
+        cube=inputs,
+        parameters=presto_parameters,
+        size=[
+            {"dimension": "x", "unit": "px", "value": 100},
+            {"dimension": "y", "unit": "px", "value": 100},
+        ],
+        overlap=[
+            {"dimension": "x", "unit": "px", "value": 0},
+            {"dimension": "y", "unit": "px", "value": 0},
+        ],
+    )
+
+    classes = apply_model_inference(
+        model_inference_class=model_inference_class,
+        cube=features,
+        parameters=model_inference_parameters,
+        size=[
+            {"dimension": "x", "unit": "px", "value": 100},
+            {"dimension": "y", "unit": "px", "value": 100},
+            {"dimension": "t", "value": "P1D"},
+        ],
+        overlap=[
+            {"dimension": "x", "unit": "px", "value": 0},
+            {"dimension": "y", "unit": "px", "value": 0},
+        ],
+    )
+
+    # Cast to uint8
+    classes = compress_uint8(classes)
+
+    classes.execute_batch(
+        outputfile=output_path,
+        out_format=format,
+        job_options={
+            "driver-memory": "4g",
+            "executor-memoryOverhead": "12g",
+            "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
+        },
+    )
diff --git a/src/worldcereal/seasons.py b/src/worldcereal/seasons.py
index f889295d..c46198e7 100644
--- a/src/worldcereal/seasons.py
+++ b/src/worldcereal/seasons.py
@@ -17,6 +17,10 @@ class NoSeasonError(Exception):
     pass
 
 
+class SeasonMaxDiffError(Exception):
+    pass
+
+
 def doy_to_angle(day_of_year, total_days=365):
     return 2 * math.pi * (day_of_year / total_days)
 
@@ -25,6 +29,37 @@ def angle_to_doy(angle, total_days=365):
     return (angle / (2 * math.pi)) * total_days
 
 
+def max_doy_difference(doy_array, total_days=365):
+    """Compute the maximum difference in days between all DOY values in an
+    array, taking into account wrap-around effects due to the circular
+    nature of the day of year.
+
+    Args:
+        doy_array (array-like): day-of-year values to compare
+        total_days (int): number of days in a year. Defaults to 365,
+            which holds for the crop calendars.
+
+    Returns:
+        maximum circular difference in days between any two values
+    """
+
+    # Duplicates cannot change the maximum, so only the unique values are
+    # compared: the pairwise matrix is then bounded by the number of
+    # distinct DOYs instead of growing with the square of the pixel count.
+    # Casting to float avoids underflow when subtracting unsigned integers.
+    unique_doys = np.unique(np.asarray(doy_array, dtype=float))
+
+    # Direct difference between every pair of values
+    direct_difference = np.abs(unique_doys[:, None] - unique_doys[None, :])
+
+    # Difference when wrapping around the end of the year instead
+    wrap_around_difference = total_days - direct_difference
+
+    # The effective difference of a pair is the smallest of the two; the
+    # result is the largest effective difference over all combinations
+    return np.minimum(direct_difference, wrap_around_difference).max()
+
+
 def circular_median_day_of_year(doy_array, total_days=365):
     """This function computes the median doy from a given array taking
     into account its circular nature. Still has to be used with caution!
@@ -211,7 +246,10 @@ def season_doys_to_dates(
 
 
 def get_processing_dates_for_extent(
-    extent: BoundingBoxExtent, year: int, season: str = "tc-annual"
+    extent: BoundingBoxExtent,
+    year: int,
+    season: str = "tc-annual",
+    max_seasonality_difference: int = 60,
 ):
     """Function to retrieve required temporal range of input products for a
     given extent, season and year. Based on the requested season's end date
@@ -221,9 +259,12 @@ def get_processing_dates_for_extent(
         extent (BoundingBoxExtent): extent for which to infer dates
         year (int): year in which the end of season needs to be
         season (str): season identifier for which to infer dates. Defaults to tc-annual
+        max_seasonality_difference (int): maximum difference in seasonality for all pixels
+            in extent before raising an exception. Defaults to 60.
 
     Raises:
         ValueError: invalid season specified
+        SeasonMaxDiffError: raised when seasonality difference is too large
 
     Returns:
         (start_date, end_date): tuple of date strings specifying
@@ -243,6 +284,16 @@ def get_processing_dates_for_extent(
     if not np.isfinite(eos_doy).any():
         raise NoSeasonError(f"No valid EOS DOY found for season `{season}`")
 
+    # Only consider valid seasonality pixels
+    eos_doy = eos_doy[np.isfinite(eos_doy)]
+
+    # Check max seasonality difference
+    seasonality_difference = max_doy_difference(eos_doy)
+    if seasonality_difference > max_seasonality_difference:
+        raise SeasonMaxDiffError(
+            f"Seasonality difference too large: {seasonality_difference} days"
+        )
+
     # Compute median DOY
     eos_doy_median = circular_median_day_of_year(eos_doy)
 
diff --git a/tests/worldcerealtests/test_seasons.py b/tests/worldcerealtests/test_seasons.py
index 71be2a0e..516a528c 100644
--- a/tests/worldcerealtests/test_seasons.py
+++ b/tests/worldcerealtests/test_seasons.py
@@ -45,7 +45,7 @@ def test_doy_to_date_after():
 def test_get_processing_dates_for_extent():
     # Test to check if we can infer processing dates for default season
     # tc-annual
-    bounds = (167286, 553423, 943774, 997257)
+    bounds = (574680, 5621800, 575320, 5622440)
     epsg = 32631
     year = 2021
     extent = BoundingBoxExtent(*bounds, epsg)