diff --git a/environment.yml b/environment.yml index d2522c17..13844683 100644 --- a/environment.yml +++ b/environment.yml @@ -16,7 +16,7 @@ dependencies: - openeo=0.29.0 - pyarrow=16.1.0 - python=3.10.0 - - pytorch=2.3.0 + - pytorch=2.3.1 - rasterio=1.3.10 - rioxarray=0.15.5 - scikit-image=0.22.0 diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 5068a184..c94eae3a 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -11,8 +11,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( - prog="WC - Cropland Inference", - description="Cropland inference using GFMAP, Presto and WorldCereal classifiers", + prog="WC - Crop Mapping Inference", + description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers", ) parser.add_argument("minx", type=float, help="Minimum X coordinate (west)") @@ -25,6 +25,11 @@ default=4326, help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", ) + parser.add_argument( + "product", + type=str, + help="Product to generate. One of ['cropland', 'croptype']", + ) parser.add_argument( "start_date", type=str, help="Starting date for data extraction." ) @@ -46,6 +51,15 @@ start_date = args.start_date end_date = args.end_date + product = args.product + + # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test + # minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test + # epsg = 32631 + # start_date = "2020-11-01" + # end_date = "2021-10-31" + # product = "croptype" + spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg) temporal_extent = TemporalContext(start_date, end_date) @@ -56,7 +70,7 @@ temporal_extent, backend_context, args.output_path, - product_type=WorldCerealProduct.CROPLAND, + product_type=WorldCerealProduct(product), out_format="GTiff", ) logger.success("Job finished:\n\t%s", job_results) diff --git a/scripts/inference/croptype_mapping_local.py b/scripts/inference/croptype_mapping_local.py index 65962eac..d663d655 100644 --- a/scripts/inference/croptype_mapping_local.py +++ b/scripts/inference/croptype_mapping_local.py @@ -15,7 +15,7 @@ from openeo_gfmap.inference.model_inference import apply_model_inference_local from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor -from worldcereal.openeo.inference import CroptypeClassifier +from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" @@ -40,24 +40,40 @@ .astype("uint16") ) - print("Running presto UDF locally") - features = apply_feature_extractor_local( + print("Get Presto cropland features") + cropland_features = apply_feature_extractor_local( PrestoFeatureExtractor, arr, + parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, + ) + + print("Running cropland classification inference UDF locally") + + cropland_classification = apply_model_inference_local( + CroplandClassifier, + cropland_features, parameters={ EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True, - "presto_model_url": PRESTO_URL, }, ) - features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc") + print("Get Presto croptype features") + croptype_features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "presto_model_url": PRESTO_URL, + }, + ) - print("Running classification inference UDF locally") + print("Running croptype classification inference UDF locally") - classification = apply_model_inference_local( + croptype_classification = apply_model_inference_local( CroptypeClassifier, - features, + croptype_features, parameters={ EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True, @@ -65,4 +81,12 @@ }, ) - classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc") + # Apply cropland mask -> on the backend this is done with mask process + croptype_classification = croptype_classification.where( + cropland_classification.sel(bands="classification") == 1, 0 + ) + + croptype_classification.to_netcdf( + Path("/vitodata/worldcereal/validation/internal_validation/") + / "test_classification_croptype_local.nc" + ) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 1e29a4a4..8f6ac9e4 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -1,10 +1,12 @@ """Executing inference jobs on the OpenEO backend.""" + from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Optional, Union import openeo +from openeo import DataCube from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext from openeo_gfmap.features.feature_extractor import apply_feature_extractor from openeo_gfmap.inference.model_inference import apply_model_inference @@ -87,31 +89,49 @@ def generate_map( output_path: Optional[Union[Path, str]], product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, out_format: str = "GTiff", -): +) -> InferenceResults: """Main function to generate a WorldCereal product. - Args: - spatial_extent (BoundingBoxExtent): spatial extent of the map - temporal_extent (TemporalContext): temporal range to consider - backend_context (BackendContext): backend to run the job on - output_path (Union[Path, str]): output path to download the product to - product (str, optional): product describer. Defaults to "cropland". - format (str, optional): Output format. Defaults to "GTiff". - - Raises: - ValueError: if the product is not supported - + Parameters + ---------- + spatial_extent : BoundingBoxExtent + spatial extent of the map + temporal_extent : TemporalContext + temporal range to consider + backend_context : BackendContext + backend to run the job on + output_path : Optional[Union[Path, str]] + output path to download the product to + product_type : WorldCerealProduct, optional + product describer, by default WorldCerealProduct.CROPLAND + out_format : str, optional + Output format, by default "GTiff" + + Returns + ------- + InferenceResults + Results of the finished WorldCereal job. + + Raises + ------ + ValueError + if the product is not supported + ValueError + if the out_format is not supported """ if product_type not in PRODUCT_SETTINGS.keys(): raise ValueError(f"Product {product_type.value} not supported.") + if out_format not in ["GTiff", "NetCDF"]: + raise ValueError(f"Format {format} not supported.") + # Connect to openeo connection = openeo.connect( "https://openeo.creo.vito.be/openeo/" ).authenticate_oidc() - # Preparing the input cube for the inference + # Preparing the input cube for inference inputs = worldcereal_preprocessed_inputs_gfmap( connection=connection, backend_context=backend_context, @@ -119,53 +139,32 @@ def generate_map( temporal_extent=temporal_extent, ) - # Run feature computer - features = apply_feature_extractor( - feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"], - cube=inputs, - parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"], - size=[ - {"dimension": "x", "unit": "px", "value": 100}, - {"dimension": "y", "unit": "px", "value": 100}, - ], - overlap=[ - {"dimension": "x", "unit": "px", "value": 0}, - {"dimension": "y", "unit": "px", "value": 0}, - ], - ) - - if out_format not in ["GTiff", "NetCDF"]: - raise ValueError(f"Format {format} not supported.") - - classes = apply_model_inference( - model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][ - "classifier" - ], - cube=features, - parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"], - size=[ - {"dimension": "x", "unit": "px", "value": 100}, - {"dimension": "y", "unit": "px", "value": 100}, - {"dimension": "t", "value": "P1D"}, - ], - overlap=[ - {"dimension": "x", "unit": "px", "value": 0}, - {"dimension": "y", "unit": "px", "value": 0}, - ], - ) + # Explicit filtering again for bbox because of METEO low + # resolution causing issues + inputs = inputs.filter_bbox(dict(spatial_extent)) - # Cast to uint8 + # Construct the feature extraction and model inference pipeline if product_type == WorldCerealProduct.CROPLAND: - classes = compress_uint8(classes) - else: - classes = compress_uint16(classes) - + classes = _cropland_map(inputs) + elif product_type == WorldCerealProduct.CROPTYPE: + # First compute cropland map + cropland_mask = ( + _cropland_map(inputs) + .filter_bands("classification") + .reduce_dimension( + dimension="t", reducer="mean" + ) # Temporary fix to make this work as mask + ) + + classes = _croptype_map(inputs, cropland_mask=cropland_mask) + + # Submit the job job = classes.execute_batch( outputfile=output_path, out_format=out_format, job_options={ "driver-memory": "4g", - "executor-memoryOverhead": "6g", + "executor-memoryOverhead": "4g", "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) @@ -173,7 +172,7 @@ def generate_map( asset = job.get_results().get_assets()[0] return InferenceResults( - job_id=classes.job_id, + job_id=job.job_id, product_url=asset.href, output_path=output_path, product=product_type, @@ -189,14 +188,16 @@ def collect_inputs( """Function to retrieve preprocessed inputs that are being used in the generation of WorldCereal products. - Args: - spatial_extent (BoundingBoxExtent): spatial extent of the map - temporal_extent (TemporalContext): temporal range to consider - backend_context (BackendContext): backend to run the job on - output_path (Union[Path, str]): output path to download the product to - - Raises: - ValueError: if the product is not supported + Parameters + ---------- + spatial_extent : BoundingBoxExtent + spatial extent of the map + temporal_extent : TemporalContext + temporal range to consider + backend_context : BackendContext + backend to run the job on + output_path : Union[Path, str] + output path to download the product to """ @@ -218,3 +219,129 @@ def collect_inputs( out_format="NetCDF", job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"}, ) + + +def _cropland_map(inputs: DataCube) -> DataCube: + """Method to produce cropland map from preprocessed inputs, using + a Presto feature extractor and a CatBoost classifier. + + Parameters + ---------- + inputs : DataCube + preprocessed input cube + + Returns + ------- + DataCube + binary labels and probability + """ + + # Run feature computer + features = apply_feature_extractor( + feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ + "features" + ]["extractor"], + cube=inputs, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Run model inference on features + classes = apply_model_inference( + model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ + "classification" + ]["classifier"], + cube=features, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + {"dimension": "t", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Cast to uint8 + classes = compress_uint8(classes) + + return classes + + +def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: + """Method to produce croptype map from preprocessed inputs, using + a Presto feature extractor and a CatBoost classifier. + + Parameters + ---------- + inputs : DataCube + preprocessed input cube + cropland_mask : DataCube, optional + optional cropland mask, by default None + + Returns + ------- + DataCube + croptype labels and probability + """ + + # Run feature computer + features = apply_feature_extractor( + feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ + "features" + ]["extractor"], + cube=inputs, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Run model inference on features + classes = apply_model_inference( + model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ + "classification" + ]["classifier"], + cube=features, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + {"dimension": "t", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Mask cropland + if cropland_mask is not None: + classes = classes.mask(cropland_mask == 0, replacement=0) + + # Cast to uint16 + classes = compress_uint16(classes) + + return classes diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index e5875bdb..e837994a 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -21,7 +21,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): import functools PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA - PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" + PRESTO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA DEPENDENCY_NAME = "worldcereal_deps.zip" @@ -73,7 +73,7 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: presto_model_url = self._parameters.get( "presto_model_url", self.PRESTO_MODEL_URL ) - presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL) + presto_wheel_url = self._parameters.get("presto_wheel_url", self.PRESTO_WHL_URL) ignore_dependencies = self._parameters.get("ignore_dependencies", False) if ignore_dependencies: @@ -110,7 +110,7 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: get_presto_features, ) - batch_size = self._parameters.get("batch_size", 4096) + batch_size = self._parameters.get("batch_size", 256) self.logger.info("Extracting presto features") features = get_presto_features(