Skip to content

Commit

Permalink
Merge pull request #77 from WorldCereal/57-udf-memory-profiling
Browse files Browse the repository at this point in the history
57 udf memory profiling
  • Loading branch information
GriffinBabe authored Jun 24, 2024
2 parents 4680b7e + 41c103f commit f4ab084
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,5 @@ catboost_info/catboost_training.json
*.tar
*.zip

.notebook-tests/
.notebook-tests/
.local-presto-test/
68 changes: 68 additions & 0 deletions scripts/inference/cropland_mapping_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Perform cropland mapping inference using a local execution of presto.
Make sure you run this script with Python 3.9+ and that the worldcereal
dependencies are installed, along with the presto wheel file and its dependencies.
"""

from pathlib import Path

import requests
import xarray as xr
from openeo_gfmap.features.feature_extractor import (
EPSG_HARMONIZED_NAME,
apply_feature_extractor_local,
)
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier

# Remote NetCDF file with the test inputs; downloaded once by the main block below.
TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
# Local copy of the test inputs, stored in the current working directory.
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
# Passed to the feature extractor as its "presto_model_url" parameter.
PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt"
# Passed to the cropland classifier as its "classifier_url" parameter.
CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx"

if __name__ == "__main__":
    # Fetch the test inputs once; later runs reuse the local copy.
    if not TEST_FILE_PATH.exists():
        print("Downloading test input data...")
        # Stream into a sibling ".part" file and move it into place only after
        # the download completed. Writing straight to TEST_FILE_PATH would leave
        # a truncated file behind on failure, and the `exists()` check above
        # would then skip the download on every subsequent run.
        partial_path = TEST_FILE_PATH.with_name(TEST_FILE_PATH.name + ".part")
        try:
            with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            partial_path.replace(TEST_FILE_PATH)
        finally:
            # No-op after a successful move; removes the leftover on failure.
            partial_path.unlink(missing_ok=True)

    print("Loading array in-memory...")
    # Stack the dataset variables along a "bands" dimension, drop the "crs"
    # entry from it and cast the remaining values to uint16.
    arr = (
        xr.open_dataset(TEST_FILE_PATH)
        .to_array(dim="bands")
        .drop_sel(bands="crs")
        .astype("uint16")
    )

    print("Running presto UDF locally")
    # "ignore_dependencies" makes the extractor skip unpacking its packaged
    # dependencies, so Presto must already be importable in this environment.
    features = apply_feature_extractor_local(
        PrestoFeatureExtractor,
        arr,
        parameters={
            EPSG_HARMONIZED_NAME: 32631,
            "ignore_dependencies": True,
            "presto_model_url": PRESTO_URL,
        },
    )

    # Keep the intermediate features on disk for inspection.
    features.to_netcdf(Path.cwd() / "presto_test_features_cropland.nc")

    print("Running classification inference UDF locally")

    classification = apply_model_inference_local(
        CroplandClassifier,
        features,
        parameters={
            EPSG_HARMONIZED_NAME: 32631,
            "ignore_dependencies": True,
            "classifier_url": CATBOOST_URL,
        },
    )

    classification.to_netcdf(Path.cwd() / "test_classification_cropland.nc")
30 changes: 19 additions & 11 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
specified, should be set as `False`.
"""

import functools

PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
Expand All @@ -41,6 +43,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
"AGERA5-PRECIP": "precipitation-flux",
}

@functools.lru_cache(maxsize=6)
def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list:
import urllib.request
import zipfile
Expand All @@ -61,7 +64,6 @@ def output_labels(self) -> list:

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
import sys
from pathlib import Path

if self.epsg is None:
raise ValueError(
Expand All @@ -73,6 +75,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)
presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL)

ignore_dependencies = self._parameters.get("ignore_dependencies", False)
if ignore_dependencies:
self.logger.info(
"`ignore_dependencies` flag is set to True. Make sure that "
"Presto and its dependencies are available on the runtime "
"environment"
)

# The below is required to avoid flipping of the result
# when running on OpenEO backend!
inarr = inarr.transpose("bands", "t", "x", "y")
Expand All @@ -87,16 +97,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
inarr = inarr.fillna(65535)

# Unzip the dependencies on the backend
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

# Debug, print the dependency directory
self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir()))
if not ignore_dependencies:
self.logger.info("Unzipping dependencies")
deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
self.logger.info("Unpacking presto wheel")
deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

self.logger.info("Appending dependencies")
sys.path.append(str(deps_dir))

from presto.inference import ( # pylint: disable=import-outside-toplevel
get_presto_features,
Expand Down

0 comments on commit f4ab084

Please sign in to comment.