From fac6cb2012c2056af64daedb5e29b3b8fcaef4f2 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Fri, 21 Jun 2024 09:52:19 +0200 Subject: [PATCH] PR 67 fixes --- scripts/inference/cropland_mapping.py | 2 -- src/worldcereal/job.py | 7 +++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 3d4fbecb..1a2bf5b4 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -8,8 +8,6 @@ from worldcereal.job import generate_map -ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" - if __name__ == "__main__": parser = argparse.ArgumentParser( prog="WC - Cropland Inference", diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 36367491..934f77d9 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -5,7 +5,7 @@ from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext from openeo_gfmap.features.feature_extractor import apply_feature_extractor from openeo_gfmap.inference.model_inference import apply_model_inference -from openeo_gfmap.preprocessing.scaling import compress_uint8 +from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16 from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor from worldcereal.openeo.inference import CroplandClassifier @@ -95,7 +95,10 @@ def generate_map( ) # Cast to uint8 - classes = compress_uint8(classes) + if product == "cropland": + classes = compress_uint8(classes) + else: + classes = compress_uint16(classes) classes.execute_batch( outputfile=output_path,