From ef3cf59595a77cb92076ef399a5e0abeda1bb076 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Thu, 13 Jun 2024 10:20:10 +0200 Subject: [PATCH 1/3] Added local presto test script --- .gitignore | 3 +- scripts/inference/cropland_mapping_local.py | 47 +++++++++++++++++++++ src/worldcereal/openeo/feature_extractor.py | 29 ++++++++----- 3 files changed, 68 insertions(+), 11 deletions(-) create mode 100644 scripts/inference/cropland_mapping_local.py diff --git a/.gitignore b/.gitignore index e33f7077..d42db549 100755 --- a/.gitignore +++ b/.gitignore @@ -182,4 +182,5 @@ catboost_info/catboost_training.json *.tar *.zip -.notebook-tests/ \ No newline at end of file +.notebook-tests/ +.local-presto-test/ \ No newline at end of file diff --git a/scripts/inference/cropland_mapping_local.py b/scripts/inference/cropland_mapping_local.py new file mode 100644 index 00000000..0c36039f --- /dev/null +++ b/scripts/inference/cropland_mapping_local.py @@ -0,0 +1,47 @@ +"""Perform cropland mapping inference using a local execution of presto. + +Make sure you test this script on the Python version 3.9+, and have worldcereal +dependencies installed with the presto wheel file installed with it's dependencies. +""" + +from pathlib import Path + +import requests +import xarray as xr +from openeo_gfmap.features.feature_extractor import ( + EPSG_HARMONIZED_NAME, + apply_feature_extractor_local, +) + +from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor + +TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" +TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" + + +if __name__ == "__main__": + if not TEST_FILE_PATH.exists(): + print("Downloading test input data...") + # Download the test input data + with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response: + response.raise_for_status() + with open(TEST_FILE_PATH, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print("Loading array in-memory...") + arr = ( + xr.open_dataset(TEST_FILE_PATH) + .to_array(dim="bands") + .drop_sel(bands="crs") + .astype("uint16") + ) + + print("Running presto UDF locally") + features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, + ) + + features.to_netcdf(Path.cwd() / "presto_test_features.nc") diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index 40dde71a..08b8f2f0 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -18,6 +18,8 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): specified, should be set as `False`. """ + import functools + PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA @@ -41,6 +43,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): "AGERA5-PRECIP": "precipitation-flux", } + @functools.lru_cache(maxsize=6) def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list: import urllib.request import zipfile @@ -73,6 +76,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: ) presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL) + ignore_dependencies = self._parameters.get("ignore_dependencies", False) + if ignore_dependencies: + self.logger.info( + "`ignore_dependencies` flag is set to True. Make sure that " + "Presto and its dependencies are available on the runtime " + "environment" + ) + # The below is required to avoid flipping of the result # when running on OpenEO backend! inarr = inarr.transpose("bands", "t", "x", "y") @@ -87,16 +98,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: inarr = inarr.fillna(65535) # Unzip de dependencies on the backend - self.logger.info("Unzipping dependencies") - deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME) - self.logger.info("Unpacking presto wheel") - deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir) - - self.logger.info("Appending dependencies") - sys.path.append(str(deps_dir)) - - # Debug, print the dependency directory - self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir())) + if not ignore_dependencies: + self.logger.info("Unzipping dependencies") + deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME) + self.logger.info("Unpacking presto wheel") + deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir) + + self.logger.info("Appending dependencies") + sys.path.append(str(deps_dir)) from presto.inference import ( # pylint: disable=import-outside-toplevel get_presto_features, From 84f12bcda2fc9ebb18adeeda1f8dcc6655cc40fa Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Fri, 21 Jun 2024 17:08:22 +0200 Subject: [PATCH 2/3] Added cropland model inference --- scripts/inference/cropland_mapping_local.py | 27 ++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/scripts/inference/cropland_mapping_local.py b/scripts/inference/cropland_mapping_local.py index 0c36039f..e20a0920 100644 --- a/scripts/inference/cropland_mapping_local.py +++ b/scripts/inference/cropland_mapping_local.py @@ -12,12 +12,15 @@ EPSG_HARMONIZED_NAME, apply_feature_extractor_local, ) +from openeo_gfmap.inference.model_inference import apply_model_inference_local from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor +from worldcereal.openeo.inference import CroplandClassifier TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" - +PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" +CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" if __name__ == "__main__": if not TEST_FILE_PATH.exists(): @@ -41,7 +44,25 @@ features = apply_feature_extractor_local( PrestoFeatureExtractor, arr, - parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "presto_model_url": PRESTO_URL, + }, + ) + + features.to_netcdf(Path.cwd() / "presto_test_features_cropland.nc") + + print("Running classification inference UDF locally") + + classification = apply_model_inference_local( + CroplandClassifier, + features, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "classifier_url": CATBOOST_URL, + }, ) - features.to_netcdf(Path.cwd() / "presto_test_features.nc") + classification.to_netcdf(Path.cwd() / "test_classification_cropland.nc") From 41c103ffc13f53343e72c45a70d9e3b253d04195 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Mon, 24 Jun 2024 13:05:36 +0200 Subject: [PATCH 3/3] Ruff fix --- src/worldcereal/openeo/feature_extractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index 08b8f2f0..19472ef1 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -64,7 +64,6 @@ def output_labels(self) -> list: def execute(self, inarr: xr.DataArray) -> xr.DataArray: import sys - from pathlib import Path if self.epsg is None: raise ValueError(