diff --git a/scripts/inference/croptype_mapping_local.py b/scripts/inference/croptype_mapping_local.py new file mode 100644 index 00000000..65962eac --- /dev/null +++ b/scripts/inference/croptype_mapping_local.py @@ -0,0 +1,68 @@ +"""Perform cropland mapping inference using a local execution of presto. + +Make sure you test this script on the Python version 3.9+, and have worldcereal +dependencies installed with the presto wheel file installed with it's dependencies. +""" + +from pathlib import Path + +import requests +import xarray as xr +from openeo_gfmap.features.feature_extractor import ( + EPSG_HARMONIZED_NAME, + apply_feature_extractor_local, +) +from openeo_gfmap.inference.model_inference import apply_model_inference_local + +from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor +from worldcereal.openeo.inference import CroptypeClassifier + +TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" +TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" +PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" +CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" + +if __name__ == "__main__": + if not TEST_FILE_PATH.exists(): + print("Downloading test input data...") + # Download the test input data + with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response: + response.raise_for_status() + with open(TEST_FILE_PATH, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print("Loading array in-memory...") + arr = ( + xr.open_dataset(TEST_FILE_PATH) + .to_array(dim="bands") + .drop_sel(bands="crs") + .astype("uint16") + ) + + print("Running presto UDF locally") + features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "presto_model_url": PRESTO_URL, + }, + ) + + features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc") + + print("Running classification inference UDF locally") + + classification = apply_model_inference_local( + CroptypeClassifier, + features, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "classifier_url": CATBOOST_URL, + }, + ) + + classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc") diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 6ffeb071..1e29a4a4 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -11,11 +11,9 @@ from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16 from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor -from worldcereal.openeo.inference import CroplandClassifier +from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap -ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" - class WorldCerealProduct(Enum): """Enum to define the different WorldCereal products.""" @@ -24,6 +22,42 @@ class WorldCerealProduct(Enum): CROPTYPE = "croptype" +ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" + +PRODUCT_SETTINGS = { + WorldCerealProduct.CROPLAND: { + "features": { + "extractor": PrestoFeatureExtractor, + "parameters": { + "rescale_s1": False, # Will be done in the Presto UDF itself! + "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA + }, + }, + "classification": { + "classifier": CroplandClassifier, + "parameters": { + "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA + }, + }, + }, + WorldCerealProduct.CROPTYPE: { + "features": { + "extractor": PrestoFeatureExtractor, + "parameters": { + "rescale_s1": False, # Will be done in the Presto UDF itself! + "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA + }, + }, + "classification": { + "classifier": CroptypeClassifier, + "parameters": { + "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA + }, + }, + }, +} + + @dataclass class InferenceResults: """Dataclass to store the results of the WorldCereal job. @@ -69,6 +103,9 @@ def generate_map( """ + if product_type not in PRODUCT_SETTINGS.keys(): + raise ValueError(f"Product {product_type.value} not supported.") + # Connect to openeo connection = openeo.connect( "https://openeo.creo.vito.be/openeo/" @@ -83,14 +120,10 @@ def generate_map( ) # Run feature computer - presto_parameters = { - "rescale_s1": False, # Will be done in the Presto UDF itself! - } - features = apply_feature_extractor( - feature_extractor_class=PrestoFeatureExtractor, + feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"], cube=inputs, - parameters=presto_parameters, + parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"], size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -101,20 +134,15 @@ def generate_map( ], ) - if product_type == WorldCerealProduct.CROPLAND: - # initiate default cropland model - model_inference_class = CroplandClassifier - model_inference_parameters = {} - else: - raise ValueError(f"Product {product_type} not supported.") - if out_format not in ["GTiff", "NetCDF"]: - raise ValueError(f"Format {out_format} not supported.") + raise ValueError(f"Format {format} not supported.") classes = apply_model_inference( - model_inference_class=model_inference_class, + model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][ + "classifier" + ], cube=features, - parameters=model_inference_parameters, + parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"], size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -137,11 +165,11 @@ def generate_map( out_format=out_format, job_options={ "driver-memory": "4g", - "executor-memoryOverhead": "12g", + "executor-memoryOverhead": "6g", "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) - # Should contain a single job as this is a single-jon tile inference. + asset = job.get_results().get_assets()[0] return InferenceResults( @@ -150,3 +178,43 @@ def generate_map( output_path=output_path, product=product_type, ) + + +def collect_inputs( + spatial_extent: BoundingBoxExtent, + temporal_extent: TemporalContext, + backend_context: BackendContext, + output_path: Union[Path, str], +): + """Function to retrieve preprocessed inputs that are being + used in the generation of WorldCereal products. + + Args: + spatial_extent (BoundingBoxExtent): spatial extent of the map + temporal_extent (TemporalContext): temporal range to consider + backend_context (BackendContext): backend to run the job on + output_path (Union[Path, str]): output path to download the product to + + Raises: + ValueError: if the product is not supported + + """ + + # Connect to openeo + connection = openeo.connect( + "https://openeo.creo.vito.be/openeo/" + ).authenticate_oidc() + + # Preparing the input cube for the inference + inputs = worldcereal_preprocessed_inputs_gfmap( + connection=connection, + backend_context=backend_context, + spatial_extent=spatial_extent, + temporal_extent=temporal_extent, + ) + + inputs.execute_batch( + outputfile=output_path, + out_format="NetCDF", + job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"}, + ) diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py index cdfd2bf5..2feb409f 100644 --- a/src/worldcereal/openeo/inference.py +++ b/src/worldcereal/openeo/inference.py @@ -80,3 +80,89 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: ) return classification + + +class CroptypeClassifier(ModelInference): + """Multi-class crop classifier using ONNX to load a catboost model. + + The classifier use the embeddings computed from the Presto Feature + Extractor. + + Interesting UDF parameters: + - classifier_url: A public URL to the ONNX classification model. Default is + the public Presto model. + """ + + import numpy as np + + CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA + + def __init__(self): + super().__init__() + + self.onnx_session = None + + def dependencies(self) -> list: + return [] # Disable the dependencies from PIP install + + def output_labels(self) -> list: + return ["classification", "probability"] + + def predict(self, features: np.ndarray) -> np.ndarray: + """ + Predicts labels using the provided features array. + """ + import numpy as np + + if self.onnx_session is None: + raise ValueError("Model has not been loaded. Please load a model first.") + + # Prepare input data for ONNX model + outputs = self.onnx_session.run(None, {"features": features}) + + # Apply LUT: TODO:this needs an update! + LUT = { + "barley": 1, + "maize": 2, + "millet_sorghum": 3, + "other_crop": 4, + "rapeseed_rape": 5, + "soy_soybeans": 6, + "sunflower": 7, + "wheat": 8, + } + + # Extract classes as INTs and probability of winning class values + labels = np.zeros((len(outputs[0]),), dtype=np.uint16) + probabilities = np.zeros((len(outputs[0]),), dtype=np.uint8) + for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])): + labels[i] = LUT[label] + probabilities[i] = int(prob[label] * 100) + + return np.stack([labels, probabilities], axis=0) + + def execute(self, inarr: xr.DataArray) -> xr.DataArray: + classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH) + + # shape and indices for output ("xy", "bands") + x_coords, y_coords = inarr.x.values, inarr.y.values + inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose() + + self.onnx_session = self.load_ort_session(classifier_url) + + # Run catboost classification + self.logger.info("Catboost classification with input shape: %s", inarr.shape) + classification = self.predict(inarr.values) + self.logger.info("Classification done with shape: %s", inarr.shape) + + classification = xr.DataArray( + classification.reshape((2, len(x_coords), len(y_coords))), + dims=["bands", "x", "y"], + coords={ + "bands": ["classification", "probability"], + "x": x_coords, + "y": y_coords, + }, + ) + + return classification diff --git a/src/worldcereal/openeo/preprocessing.py b/src/worldcereal/openeo/preprocessing.py index a9025383..fad5102a 100644 --- a/src/worldcereal/openeo/preprocessing.py +++ b/src/worldcereal/openeo/preprocessing.py @@ -264,6 +264,7 @@ def precomposited_datacube_METEO( temporal_extent=temporal_extent, bands=["precipitation-flux", "temperature-mean"], ) + cube.result_node().update_arguments(featureflags={"tilesize": 1}) cube = cube.rename_labels( dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"] )