Skip to content

Commit

Permalink
Make batch_size for Presto flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
kvantricht committed Jun 12, 2024
1 parent 77d5efc commit 50caab3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor):
"""

PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-py3-none-any.whl"
PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl"
BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA
DEPENDENCY_NAME = "worldcereal_deps.zip"

Expand Down Expand Up @@ -103,7 +103,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)

self.logger.info("Extracting presto features")
features = get_presto_features(inarr, presto_model_url, self.epsg)
features = get_presto_features(
inarr, presto_model_url, self.epsg, batch_size=4096
)
return features

def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
Expand Down

0 comments on commit 50caab3

Please sign in to comment.