Skip to content

Commit

Permalink
Added local presto test script
Browse files Browse the repository at this point in the history
  • Loading branch information
GriffinBabe committed Jun 13, 2024
1 parent 679f33c commit ef3cf59
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,5 @@ catboost_info/catboost_training.json
*.tar
*.zip

.notebook-tests/
.notebook-tests/
.local-presto-test/
47 changes: 47 additions & 0 deletions scripts/inference/cropland_mapping_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Perform cropland mapping inference using a local execution of presto.
Make sure you test this script on the Python version 3.9+, and have worldcereal

This comment has been minimized.

Copy link
@kvantricht

kvantricht Jun 13, 2024

Contributor

backend is at python=3.8 and I think we made everything compatible to 3.8

dependencies installed with the presto wheel file installed with it's dependencies.
"""

from pathlib import Path

import requests
import xarray as xr
from openeo_gfmap.features.feature_extractor import (
EPSG_HARMONIZED_NAME,
apply_feature_extractor_local,
)

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"


if __name__ == "__main__":
if not TEST_FILE_PATH.exists():
print("Downloading test input data...")
# Download the test input data
with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response:
response.raise_for_status()
with open(TEST_FILE_PATH, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

print("Loading array in-memory...")
arr = (
xr.open_dataset(TEST_FILE_PATH)
.to_array(dim="bands")
.drop_sel(bands="crs")
.astype("uint16")
)

print("Running presto UDF locally")
features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
)

features.to_netcdf(Path.cwd() / "presto_test_features.nc")
29 changes: 19 additions & 10 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
specified, should be set as `False`.
"""

import functools

PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
Expand All @@ -41,6 +43,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
"AGERA5-PRECIP": "precipitation-flux",
}

@functools.lru_cache(maxsize=6)
def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list:
import urllib.request
import zipfile
Expand Down Expand Up @@ -73,6 +76,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)
presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL)

ignore_dependencies = self._parameters.get("ignore_dependencies", False)
if ignore_dependencies:
self.logger.info(
"`ignore_dependencies` flag is set to True. Make sure that "
"Presto and its dependencies are available on the runtime "
"environment"
)

# The below is required to avoid flipping of the result
# when running on OpenEO backend!
inarr = inarr.transpose("bands", "t", "x", "y")
Expand All @@ -87,16 +98,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
inarr = inarr.fillna(65535)

# Unzip de dependencies on the backend
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

# Debug, print the dependency directory
self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir()))
if not ignore_dependencies:
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

from presto.inference import ( # pylint: disable=import-outside-toplevel
get_presto_features,
Expand Down

1 comment on commit ef3cf59

@kvantricht
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now the local test only does presto feature extraction, no catboost inference yet, right?

Please sign in to comment.