From f4f83a09b10eeb3c86a9c6cadd9ed4056849d820 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Fri, 21 Jun 2024 17:09:36 +0200 Subject: [PATCH 1/3] Add croptype inference --- scripts/inference/croptype_mapping_local.py | 68 +++++++++++++++ src/worldcereal/job.py | 92 +++++++++++++++++---- src/worldcereal/openeo/inference.py | 86 +++++++++++++++++++ 3 files changed, 231 insertions(+), 15 deletions(-) create mode 100644 scripts/inference/croptype_mapping_local.py diff --git a/scripts/inference/croptype_mapping_local.py b/scripts/inference/croptype_mapping_local.py new file mode 100644 index 00000000..65962eac --- /dev/null +++ b/scripts/inference/croptype_mapping_local.py @@ -0,0 +1,68 @@ +"""Perform cropland mapping inference using a local execution of presto. + +Make sure you test this script on the Python version 3.9+, and have worldcereal +dependencies installed with the presto wheel file installed with it's dependencies. +""" + +from pathlib import Path + +import requests +import xarray as xr +from openeo_gfmap.features.feature_extractor import ( + EPSG_HARMONIZED_NAME, + apply_feature_extractor_local, +) +from openeo_gfmap.inference.model_inference import apply_model_inference_local + +from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor +from worldcereal.openeo.inference import CroptypeClassifier + +TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" +TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" +PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" +CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" + +if __name__ == "__main__": + if not TEST_FILE_PATH.exists(): + print("Downloading test input data...") + # Download the test input data + with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response: + response.raise_for_status() + with open(TEST_FILE_PATH, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print("Loading array in-memory...") + arr = ( + xr.open_dataset(TEST_FILE_PATH) + .to_array(dim="bands") + .drop_sel(bands="crs") + .astype("uint16") + ) + + print("Running presto UDF locally") + features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "presto_model_url": PRESTO_URL, + }, + ) + + features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc") + + print("Running classification inference UDF locally") + + classification = apply_model_inference_local( + CroptypeClassifier, + features, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "classifier_url": CATBOOST_URL, + }, + ) + + classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc") diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 934f77d9..3922a17b 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -14,6 +14,36 @@ ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" +PRODUCT_SETTINGS = { + "cropland": { + "features": { + "extractor": PrestoFeatureExtractor, + "parameters": { + "rescale_s1": False, # Will be done in the Presto UDF itself! + "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA + }, + }, + "classification": { + "classifier": CroplandClassifier, + "parameters": {"classifier_url": ""}, + }, + }, + "croptype": { + "features": { + "extractor": PrestoFeatureExtractor, + "parameters": { + "rescale_s1": False, # Will be done in the Presto UDF itself! + "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA + }, + }, + "classification": { + "classifier": CroplandClassifier, # TODO: update to croptype classifier + "parameters": {"classifier_url": ""}, + }, + }, +} + + def generate_map( spatial_extent: BoundingBoxExtent, temporal_extent: TemporalContext, @@ -37,6 +67,9 @@ def generate_map( """ + if product not in PRODUCT_SETTINGS.keys(): + raise ValueError(f"Product {product} not supported.") + # Connect to openeo connection = openeo.connect( "https://openeo.creo.vito.be/openeo/" @@ -51,14 +84,10 @@ def generate_map( ) # Run feature computer - presto_parameters = { - "rescale_s1": False, # Will be done in the Presto UDF itself! - } - features = apply_feature_extractor( - feature_extractor_class=PrestoFeatureExtractor, + feature_extractor_class=PRODUCT_SETTINGS[product]["features"]["extractor"], cube=inputs, - parameters=presto_parameters, + parameters=PRODUCT_SETTINGS[product]["features"]["parameters"], size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -69,20 +98,13 @@ def generate_map( ], ) - if product == "cropland": - # initiate default cropland model - model_inference_class = CroplandClassifier - model_inference_parameters = {} - else: - raise ValueError(f"Product {product} not supported.") - if format not in ["GTiff", "NetCDF"]: raise ValueError(f"Format {format} not supported.") classes = apply_model_inference( - model_inference_class=model_inference_class, + model_inference_class=PRODUCT_SETTINGS[product]["classification"]["classifier"], cube=features, - parameters=model_inference_parameters, + parameters=PRODUCT_SETTINGS[product]["classification"]["parameters"], size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -109,3 +131,43 @@ def generate_map( "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) + + +def collect_inputs( + spatial_extent: BoundingBoxExtent, + temporal_extent: TemporalContext, + backend_context: BackendContext, + output_path: Union[Path, str], +): + """Function to retrieve preprocessed inputs that are being + used in the generation of WorldCereal products. + + Args: + spatial_extent (BoundingBoxExtent): spatial extent of the map + temporal_extent (TemporalContext): temporal range to consider + backend_context (BackendContext): backend to run the job on + output_path (Union[Path, str]): output path to download the product to + + Raises: + ValueError: if the product is not supported + + """ + + # Connect to openeo + connection = openeo.connect( + "https://openeo.creo.vito.be/openeo/" + ).authenticate_oidc() + + # Preparing the input cube for the inference + inputs = worldcereal_preprocessed_inputs_gfmap( + connection=connection, + backend_context=backend_context, + spatial_extent=spatial_extent, + temporal_extent=temporal_extent, + ) + + inputs.execute_batch( + outputfile=output_path, + out_format="NetCDF", + job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"}, + ) diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py index cdfd2bf5..2feb409f 100644 --- a/src/worldcereal/openeo/inference.py +++ b/src/worldcereal/openeo/inference.py @@ -80,3 +80,89 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: ) return classification + + +class CroptypeClassifier(ModelInference): + """Multi-class crop classifier using ONNX to load a catboost model. + + The classifier use the embeddings computed from the Presto Feature + Extractor. + + Interesting UDF parameters: + - classifier_url: A public URL to the ONNX classification model. Default is + the public Presto model. + """ + + import numpy as np + + CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA + + def __init__(self): + super().__init__() + + self.onnx_session = None + + def dependencies(self) -> list: + return [] # Disable the dependencies from PIP install + + def output_labels(self) -> list: + return ["classification", "probability"] + + def predict(self, features: np.ndarray) -> np.ndarray: + """ + Predicts labels using the provided features array. + """ + import numpy as np + + if self.onnx_session is None: + raise ValueError("Model has not been loaded. Please load a model first.") + + # Prepare input data for ONNX model + outputs = self.onnx_session.run(None, {"features": features}) + + # Apply LUT: TODO:this needs an update! + LUT = { + "barley": 1, + "maize": 2, + "millet_sorghum": 3, + "other_crop": 4, + "rapeseed_rape": 5, + "soy_soybeans": 6, + "sunflower": 7, + "wheat": 8, + } + + # Extract classes as INTs and probability of winning class values + labels = np.zeros((len(outputs[0]),), dtype=np.uint16) + probabilities = np.zeros((len(outputs[0]),), dtype=np.uint8) + for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])): + labels[i] = LUT[label] + probabilities[i] = int(prob[label] * 100) + + return np.stack([labels, probabilities], axis=0) + + def execute(self, inarr: xr.DataArray) -> xr.DataArray: + classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH) + + # shape and indices for output ("xy", "bands") + x_coords, y_coords = inarr.x.values, inarr.y.values + inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose() + + self.onnx_session = self.load_ort_session(classifier_url) + + # Run catboost classification + self.logger.info("Catboost classification with input shape: %s", inarr.shape) + classification = self.predict(inarr.values) + self.logger.info("Classification done with shape: %s", inarr.shape) + + classification = xr.DataArray( + classification.reshape((2, len(x_coords), len(y_coords))), + dims=["bands", "x", "y"], + coords={ + "bands": ["classification", "probability"], + "x": x_coords, + "y": y_coords, + }, + ) + + return classification From bdfd4dcfde62f4ec226ffab1f01c6f43bae90679 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Mon, 24 Jun 2024 15:39:14 +0200 Subject: [PATCH 2/3] Avoid going bankrupt --- src/worldcereal/job.py | 2 +- src/worldcereal/openeo/preprocessing.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 3922a17b..6e7fc6cf 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -127,7 +127,7 @@ def generate_map( out_format=format, job_options={ "driver-memory": "4g", - "executor-memoryOverhead": "12g", + "executor-memoryOverhead": "6g", "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) diff --git a/src/worldcereal/openeo/preprocessing.py b/src/worldcereal/openeo/preprocessing.py index a9025383..fad5102a 100644 --- a/src/worldcereal/openeo/preprocessing.py +++ b/src/worldcereal/openeo/preprocessing.py @@ -264,6 +264,7 @@ def precomposited_datacube_METEO( temporal_extent=temporal_extent, bands=["precipitation-flux", "temperature-mean"], ) + cube.result_node().update_arguments(featureflags={"tilesize": 1}) cube = cube.rename_labels( dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"] ) From 15ef9ffa5bdda08e8ab02db091245579fa640138 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Mon, 24 Jun 2024 16:12:55 +0200 Subject: [PATCH 3/3] Added model URLs --- src/worldcereal/job.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 6e7fc6cf..e6973b0b 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -25,7 +25,9 @@ }, "classification": { "classifier": CroplandClassifier, - "parameters": {"classifier_url": ""}, + "parameters": { + "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA + }, }, }, "croptype": { @@ -38,7 +40,9 @@ }, "classification": { "classifier": CroplandClassifier, # TODO: update to croptype classifier - "parameters": {"classifier_url": ""}, + "parameters": { + "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA + }, }, }, }