Skip to content

Commit

Permalink
Added cropland model inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kvantricht committed Jun 21, 2024
1 parent aa19af4 commit 84f12bc
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions scripts/inference/cropland_mapping_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
EPSG_HARMONIZED_NAME,
apply_feature_extractor_local,
)
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"

PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt"
CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx"

if __name__ == "__main__":
if not TEST_FILE_PATH.exists():
Expand All @@ -41,7 +44,25 @@
features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features_cropland.nc")

print("Running classification inference UDF locally")

classification = apply_model_inference_local(
CroplandClassifier,
features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"classifier_url": CATBOOST_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features.nc")
classification.to_netcdf(Path.cwd() / "test_classification_cropland.nc")

0 comments on commit 84f12bc

Please sign in to comment.