From 50caab30537ddfb18387c670b53642e0bb40945b Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 12 Jun 2024 16:41:36 +0200 Subject: [PATCH] Make batch_size for Presto flexible --- src/worldcereal/openeo/feature_extractor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index 19037055..40dde71a 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -19,7 +19,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): """ PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA - PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-py3-none-any.whl" + PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA DEPENDENCY_NAME = "worldcereal_deps.zip" @@ -103,7 +103,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: ) self.logger.info("Extracting presto features") - features = get_presto_features(inarr, presto_model_url, self.epsg) + features = get_presto_features( + inarr, presto_model_url, self.epsg, batch_size=4096 + ) return features def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube: