diff --git a/.gitignore b/.gitignore index e33f7077..d42db549 100755 --- a/.gitignore +++ b/.gitignore @@ -182,4 +182,5 @@ catboost_info/catboost_training.json *.tar *.zip -.notebook-tests/ \ No newline at end of file +.notebook-tests/ +.local-presto-test/ \ No newline at end of file diff --git a/scripts/inference/cropland_mapping_local.py b/scripts/inference/cropland_mapping_local.py new file mode 100644 index 00000000..0c36039f --- /dev/null +++ b/scripts/inference/cropland_mapping_local.py @@ -0,0 +1,47 @@ +"""Perform cropland mapping inference using a local execution of presto. + +Make sure you test this script on the Python version 3.9+, and have worldcereal +dependencies installed with the presto wheel file installed with it's dependencies. +""" + +from pathlib import Path + +import requests +import xarray as xr +from openeo_gfmap.features.feature_extractor import ( + EPSG_HARMONIZED_NAME, + apply_feature_extractor_local, +) + +from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor + +TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" +TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" + + +if __name__ == "__main__": + if not TEST_FILE_PATH.exists(): + print("Downloading test input data...") + # Download the test input data + with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response: + response.raise_for_status() + with open(TEST_FILE_PATH, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + print("Loading array in-memory...") + arr = ( + xr.open_dataset(TEST_FILE_PATH) + .to_array(dim="bands") + .drop_sel(bands="crs") + .astype("uint16") + ) + + print("Running presto UDF locally") + features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, + ) + + features.to_netcdf(Path.cwd() / "presto_test_features.nc") diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index 40dde71a..08b8f2f0 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -18,6 +18,8 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): specified, should be set as `False`. """ + import functools + PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA @@ -41,6 +43,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): "AGERA5-PRECIP": "precipitation-flux", } + @functools.lru_cache(maxsize=6) def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list: import urllib.request import zipfile @@ -73,6 +76,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: ) presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL) + ignore_dependencies = self._parameters.get("ignore_dependencies", False) + if ignore_dependencies: + self.logger.info( + "`ignore_dependencies` flag is set to True. Make sure that " + "Presto and its dependencies are available on the runtime " + "environment" + ) + # The below is required to avoid flipping of the result # when running on OpenEO backend! inarr = inarr.transpose("bands", "t", "x", "y") @@ -87,16 +98,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: inarr = inarr.fillna(65535) # Unzip de dependencies on the backend - self.logger.info("Unzipping dependencies") - deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME) - self.logger.info("Unpacking presto wheel") - deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir) - - self.logger.info("Appending dependencies") - sys.path.append(str(deps_dir)) - - # Debug, print the dependency directory - self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir())) + if not ignore_dependencies: + self.logger.info("Unzipping dependencies") + deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME) + self.logger.info("Unpacking presto wheel") + deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir) + + self.logger.info("Appending dependencies") + sys.path.append(str(deps_dir)) from presto.inference import ( # pylint: disable=import-outside-toplevel get_presto_features,