From 7f25b68793303c57ddaee9fb57ecd57a338f8770 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Tue, 2 Jul 2024 11:19:48 +0200 Subject: [PATCH] Added configuration data classes for cropland/croptype mapping workflow --- scripts/inference/cropland_mapping.py | 16 +-- src/worldcereal/job.py | 191 +++++++++++++++++--------- 2 files changed, 134 insertions(+), 73 deletions(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 90f2d197..a7287726 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -20,25 +20,25 @@ parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)") parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)") parser.add_argument( - "--epsg", - type=int, - default=4326, - help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", + "start_date", type=str, help="Starting date for data extraction." ) + parser.add_argument("end_date", type=str, help="Ending date for data extraction.") parser.add_argument( "product", type=str, help="Product to generate. One of ['cropland', 'croptype']", ) - parser.add_argument( - "start_date", type=str, help="Starting date for data extraction." - ) - parser.add_argument("end_date", type=str, help="Ending date for data extraction.") parser.add_argument( "output_path", type=Path, help="Path to folder where to save the resulting GeoTiff.", ) + parser.add_argument( + "--epsg", + type=int, + default=4326, + help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", + ) args = parser.parse_args() diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 8f6ac9e4..50d260c5 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -11,6 +11,7 @@ from openeo_gfmap.features.feature_extractor import apply_feature_extractor from openeo_gfmap.inference.model_inference import apply_model_inference from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16 +from pydantic import BaseModel from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier @@ -24,44 +25,108 @@ class WorldCerealProduct(Enum): CROPTYPE = "croptype" -ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" +class FeaturesParameters(BaseModel): + """Parameters for the feature extraction UDFs. Types are enforced by + Pydantic. + + Attributes + ---------- + rescale_s1 : bool (default=False) + Whether to rescale Sentinel-1 bands before feature extraction. Should be + left to False, as this is done in the Presto UDF itself. + presto_model_url : str + Public URL to the Presto model used for feature extraction. The file + should be a PyTorch serialized model. + """ + + rescale_s1: bool + presto_model_url: str + + +class ClassifierParameters(BaseModel): + """Parameters for the classifier. Types are enforced by Pydantic. + + Attributes + ---------- + classifier_url : str + Public URL to the classifier model. Te file should be an ONNX accepting + a `features` field for input data and returning two output probability + arrays `true` and `false`. + """ + + classifier_url: str + + +class CropLandParameters(BaseModel): + """Parameters for the cropland product inference pipeline. Types are + enforced by Pydantic. + + Attributes + ---------- + feature_extractor : PrestoFeatureExtractor + Feature extractor to use for the inference. This class must be a + subclass of GFMAP's `PatchFeatureExtractor` and returns float32 + features. + features_parameters : FeaturesParameters + Parameters for the feature extraction UDF. Will be serialized into a + dictionnary and passed in the process graph. + classifier : CroplandClassifier + Classifier to use for the inference. This class must be a subclass of + GFMAP's `ModelInference` and returns predictions/probabilities for + cropland. + classifier_parameters : ClassifierParameters + Parameters for the classifier UDF. Will be serialized into a dictionnary + and passed in the process graph. + """ + + feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor + features_parameters: FeaturesParameters = FeaturesParameters( + rescale_s1=False, + presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA + ) + classifier: CroplandClassifier = CroplandClassifier + classifier_parameters: ClassifierParameters = ClassifierParameters( + classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA + ) -PRODUCT_SETTINGS = { - WorldCerealProduct.CROPLAND: { - "features": { - "extractor": PrestoFeatureExtractor, - "parameters": { - "rescale_s1": False, # Will be done in the Presto UDF itself! - "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA - }, - }, - "classification": { - "classifier": CroplandClassifier, - "parameters": { - "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA - }, - }, - }, - WorldCerealProduct.CROPTYPE: { - "features": { - "extractor": PrestoFeatureExtractor, - "parameters": { - "rescale_s1": False, # Will be done in the Presto UDF itself! - "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA - }, - }, - "classification": { - "classifier": CroptypeClassifier, - "parameters": { - "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA - }, - }, - }, -} +class CropTypeParameters(BaseModel): + """Parameters for the croptype product inference pipeline. Types are + enforced by Pydantic. -@dataclass -class InferenceResults: + Attributes + ---------- + feature_extractor : PrestoFeatureExtractor + Feature extractor to use for the inference. This class must be a + subclass of GFMAP's `PatchFeatureExtractor` and returns float32 + features. + features_parameters : FeaturesParameters + Parameters for the feature extraction UDF. Will be serialized into a + dictionnary and passed in the process graph. + classifier : CroptypeClassifier + Classifier to use for the inference. This class must be a subclass of + GFMAP's `ModelInference` and returns predictions/probabilities for + cropland classes. + classifier_parameters : ClassifierParameters + Parameters for the classifier UDF. Will be serialized into a dictionnary + and passed in the process graph. + """ + + feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor + feature_parameters: FeaturesParameters = FeaturesParameters( + rescale_s1=False, + presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA + ) + classifier: CroptypeClassifier = CroptypeClassifier + classifier_parameters: ClassifierParameters = ClassifierParameters( + classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA + ) + + +ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" + + +class InferenceResults(BaseModel): """Dataclass to store the results of the WorldCereal job. Attributes @@ -88,6 +153,8 @@ def generate_map( backend_context: BackendContext, output_path: Optional[Union[Path, str]], product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, + cropland_parameters: CropLandParameters = CropLandParameters(), + croptype_parameters: CropTypeParameters = CropTypeParameters(), out_format: str = "GTiff", ) -> InferenceResults: """Main function to generate a WorldCereal product. @@ -120,7 +187,7 @@ def generate_map( if the out_format is not supported """ - if product_type not in PRODUCT_SETTINGS.keys(): + if product_type not in WorldCerealProduct: raise ValueError(f"Product {product_type.value} not supported.") if out_format not in ["GTiff", "NetCDF"]: @@ -145,18 +212,22 @@ def generate_map( # Construct the feature extraction and model inference pipeline if product_type == WorldCerealProduct.CROPLAND: - classes = _cropland_map(inputs) + classes = _cropland_map(inputs, cropland_parameters=cropland_parameters) elif product_type == WorldCerealProduct.CROPTYPE: # First compute cropland map cropland_mask = ( - _cropland_map(inputs) + _cropland_map(inputs, cropland_parameters=cropland_parameters) .filter_bands("classification") .reduce_dimension( dimension="t", reducer="mean" ) # Temporary fix to make this work as mask ) - classes = _croptype_map(inputs, cropland_mask=cropland_mask) + classes = _croptype_map( + inputs, + croptype_parameters=croptype_parameters, + cropland_mask=cropland_mask, + ) # Submit the job job = classes.execute_batch( @@ -221,7 +292,9 @@ def collect_inputs( ) -def _cropland_map(inputs: DataCube) -> DataCube: +def _cropland_map( + inputs: DataCube, cropland_parameters: CropLandParameters +) -> DataCube: """Method to produce cropland map from preprocessed inputs, using a Presto feature extractor and a CatBoost classifier. @@ -238,13 +311,9 @@ def _cropland_map(inputs: DataCube) -> DataCube: # Run feature computer features = apply_feature_extractor( - feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ - "features" - ]["extractor"], + feature_extractor_class=cropland_parameters.feature_extractor, cube=inputs, - parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][ - "parameters" - ], + parameters=cropland_parameters.features_parameters.dict(), size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -257,13 +326,9 @@ def _cropland_map(inputs: DataCube) -> DataCube: # Run model inference on features classes = apply_model_inference( - model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ - "classification" - ]["classifier"], + model_inference_class=cropland_parameters.classifier, cube=features, - parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][ - "parameters" - ], + parameters=cropland_parameters.classifier_parameters.dict(), size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -281,7 +346,11 @@ def _cropland_map(inputs: DataCube) -> DataCube: return classes -def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: +def _croptype_map( + inputs: DataCube, + croptype_parameters: CropTypeParameters, + cropland_mask: DataCube = None, +) -> DataCube: """Method to produce croptype map from preprocessed inputs, using a Presto feature extractor and a CatBoost classifier. @@ -300,13 +369,9 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: # Run feature computer features = apply_feature_extractor( - feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ - "features" - ]["extractor"], + feature_extractor_class=croptype_parameters.feature_extractor, cube=inputs, - parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][ - "parameters" - ], + parameters=croptype_parameters.feature_parameters.dict(), size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -319,13 +384,9 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: # Run model inference on features classes = apply_model_inference( - model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ - "classification" - ]["classifier"], + model_inference_class=croptype_parameters.classifier, cube=features, - parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][ - "parameters" - ], + parameters=croptype_parameters.classifier_parameters.dict(), size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100},