From fc8a1e2bb3ef1334f6d298c5cf23f233dca70c7d Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 25 Jun 2024 14:31:47 +0200 Subject: [PATCH 1/9] Reduce default batch size --- src/worldcereal/openeo/feature_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index e5875bdb..4526ba37 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -110,7 +110,7 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: get_presto_features, ) - batch_size = self._parameters.get("batch_size", 4096) + batch_size = self._parameters.get("batch_size", 256) self.logger.info("Extracting presto features") features = get_presto_features( From 3721fe3a61f70194c23c38212ede548275be5468 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 25 Jun 2024 16:08:54 +0200 Subject: [PATCH 2/9] First attempt cropland masking --- src/worldcereal/job.py | 189 +++++++++++++++++++++------- src/worldcereal/openeo/inference.py | 9 ++ 2 files changed, 154 insertions(+), 44 deletions(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 1e29a4a4..0bfdc5fd 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -1,10 +1,12 @@ """Executing inference jobs on the OpenEO backend.""" + from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Optional, Union import openeo +from openeo import DataCube from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext from openeo_gfmap.features.feature_extractor import apply_feature_extractor from openeo_gfmap.inference.model_inference import apply_model_inference @@ -87,6 +89,7 @@ def generate_map( output_path: Optional[Union[Path, str]], product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, out_format: str = "GTiff", + apply_cropland_mask: bool = False, ): """Main function to generate a WorldCereal product. @@ -97,21 +100,32 @@ def generate_map( output_path (Union[Path, str]): output path to download the product to product (str, optional): product describer. Defaults to "cropland". format (str, optional): Output format. Defaults to "GTiff". + apply_cropland_mask (bool, optional). If True the output will be masked + with the cropland map. Defaults to False. Raises: ValueError: if the product is not supported + ValueError: if the out_format is not supported + ValueError: if a cropland mask is applied on a cropland workflow + """ if product_type not in PRODUCT_SETTINGS.keys(): raise ValueError(f"Product {product_type.value} not supported.") + if out_format not in ["GTiff", "NetCDF"]: + raise ValueError(f"Format {format} not supported.") + + if product_type == WorldCerealProduct.CROPLAND and apply_cropland_mask: + raise ValueError("Cannot apply a cropland mask on a cropland workflow.") + # Connect to openeo connection = openeo.connect( "https://openeo.creo.vito.be/openeo/" ).authenticate_oidc() - # Preparing the input cube for the inference + # Preparing the input cube for inference inputs = worldcereal_preprocessed_inputs_gfmap( connection=connection, backend_context=backend_context, @@ -119,53 +133,24 @@ def generate_map( temporal_extent=temporal_extent, ) - # Run feature computer - features = apply_feature_extractor( - feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"], - cube=inputs, - parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"], - size=[ - {"dimension": "x", "unit": "px", "value": 100}, - {"dimension": "y", "unit": "px", "value": 100}, - ], - overlap=[ - {"dimension": "x", "unit": "px", "value": 0}, - {"dimension": "y", "unit": "px", "value": 0}, - ], - ) - - if out_format not in ["GTiff", "NetCDF"]: - raise ValueError(f"Format {format} not supported.") - - classes = apply_model_inference( - model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][ - "classifier" - ], - cube=features, - parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"], - size=[ - {"dimension": "x", "unit": "px", "value": 100}, - {"dimension": "y", "unit": "px", "value": 100}, - {"dimension": "t", "value": "P1D"}, - ], - overlap=[ - {"dimension": "x", "unit": "px", "value": 0}, - {"dimension": "y", "unit": "px", "value": 0}, - ], - ) - - # Cast to uint8 + # Construct the feature extraction and model inference pipeline if product_type == WorldCerealProduct.CROPLAND: - classes = compress_uint8(classes) - else: - classes = compress_uint16(classes) - + classes = _cropland_map(inputs) + elif product_type == WorldCerealProduct.CROPTYPE: + if apply_cropland_mask: + # First compute cropland map + cropland_mask = _cropland_map(inputs) + else: + cropland_mask = None + classes = _croptype_map(inputs, cropland_mask=cropland_mask) + + # Submit the job job = classes.execute_batch( outputfile=output_path, out_format=out_format, job_options={ "driver-memory": "4g", - "executor-memoryOverhead": "6g", + "executor-memoryOverhead": "4g", "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], }, ) @@ -173,7 +158,7 @@ def generate_map( asset = job.get_results().get_assets()[0] return InferenceResults( - job_id=classes.job_id, + job_id=job.job_id, product_url=asset.href, output_path=output_path, product=product_type, @@ -185,7 +170,7 @@ def collect_inputs( temporal_extent: TemporalContext, backend_context: BackendContext, output_path: Union[Path, str], -): +) -> DataCube: """Function to retrieve preprocessed inputs that are being used in the generation of WorldCereal products. @@ -218,3 +203,119 @@ def collect_inputs( out_format="NetCDF", job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"}, ) + + +def _cropland_map(inputs: DataCube) -> DataCube: + """Method to produce cropland map from preprocessed inputs, using + a Presto feature extractor and a CatBoost classifier. + + Args: + inputs (DataCube): preprocessed input cube + + Returns: + DataCube: binary labels and probability + """ + + # Run feature computer + features = apply_feature_extractor( + feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ + "features" + ]["extractor"], + cube=inputs, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Run model inference on features + classes = apply_model_inference( + model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][ + "classification" + ]["classifier"], + cube=features, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + {"dimension": "t", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Cast to uint8 + classes = compress_uint8(classes) + + return classes + + +def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: + """Method to produce croptype map from preprocessed inputs, using + a Presto feature extractor and a CatBoost classifier. + + Args: + inputs (DataCube): preprocessed input cube + cropland_mask (DataCube): optional cropland mask + + Returns: + DataCube: croptype labels and probability + """ + + # Run feature computer + features = apply_feature_extractor( + feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ + "features" + ]["extractor"], + cube=inputs, + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][ + "parameters" + ], + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Run model inference on features + classes = apply_model_inference( + model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][ + "classification" + ]["classifier"], + cube=features, + parameters=dict( + **PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][ + "parameters" + ], + **{"cropland_mask": cropland_mask}, + ), + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + {"dimension": "t", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + # Cast to uint16 + classes = compress_uint16(classes) + + return classes diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py index 2feb409f..b1fc1f3a 100644 --- a/src/worldcereal/openeo/inference.py +++ b/src/worldcereal/openeo/inference.py @@ -165,4 +165,13 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: }, ) + # Apply optional cropland mask + cropland_mask = self._parameters.get("cropland_mask", None) + if cropland_mask is not None: + # Non-cropland pixels are set to 0 + self.logger.info("Applying cropland mask ...") + classification = classification.where( + cropland_mask.sel(bands="classification") == 1, 0 + ) + return classification From 43e8d2ae3ebda4d051ce5eb2195364619dd0c971 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:40:11 +0200 Subject: [PATCH 3/9] torch==2.3.1 --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index d2522c17..13844683 100644 --- a/environment.yml +++ b/environment.yml @@ -16,7 +16,7 @@ dependencies: - openeo=0.29.0 - pyarrow=16.1.0 - python=3.10.0 - - pytorch=2.3.0 + - pytorch=2.3.1 - rasterio=1.3.10 - rioxarray=0.15.5 - scikit-image=0.22.0 From fe241d3dbd7a08d0c76caddae7a7660cd3ea5b25 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:41:48 +0200 Subject: [PATCH 4/9] Working version of cropland masking --- src/worldcereal/job.py | 37 ++++++++++++++++------------- src/worldcereal/openeo/inference.py | 9 ------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 0bfdc5fd..309553ec 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -89,7 +89,6 @@ def generate_map( output_path: Optional[Union[Path, str]], product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, out_format: str = "GTiff", - apply_cropland_mask: bool = False, ): """Main function to generate a WorldCereal product. @@ -100,8 +99,6 @@ def generate_map( output_path (Union[Path, str]): output path to download the product to product (str, optional): product describer. Defaults to "cropland". format (str, optional): Output format. Defaults to "GTiff". - apply_cropland_mask (bool, optional). If True the output will be masked - with the cropland map. Defaults to False. Raises: ValueError: if the product is not supported @@ -117,9 +114,6 @@ def generate_map( if out_format not in ["GTiff", "NetCDF"]: raise ValueError(f"Format {format} not supported.") - if product_type == WorldCerealProduct.CROPLAND and apply_cropland_mask: - raise ValueError("Cannot apply a cropland mask on a cropland workflow.") - # Connect to openeo connection = openeo.connect( "https://openeo.creo.vito.be/openeo/" @@ -133,15 +127,23 @@ def generate_map( temporal_extent=temporal_extent, ) + # Explicit filtering again for bbox because of METEO low + # resolution causing issues + inputs = inputs.filter_bbox(dict(spatial_extent)) + # Construct the feature extraction and model inference pipeline if product_type == WorldCerealProduct.CROPLAND: classes = _cropland_map(inputs) elif product_type == WorldCerealProduct.CROPTYPE: - if apply_cropland_mask: - # First compute cropland map - cropland_mask = _cropland_map(inputs) - else: - cropland_mask = None + # First compute cropland map + cropland_mask = ( + _cropland_map(inputs) + .filter_bands("classification") + .reduce_dimension( + dimension="t", reducer="mean" + ) # Temporary fix to make this work as mask + ) + classes = _croptype_map(inputs, cropland_mask=cropland_mask) # Submit the job @@ -298,12 +300,9 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: "classification" ]["classifier"], cube=features, - parameters=dict( - **PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][ - "parameters" - ], - **{"cropland_mask": cropland_mask}, - ), + parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][ + "parameters" + ], size=[ {"dimension": "x", "unit": "px", "value": 100}, {"dimension": "y", "unit": "px", "value": 100}, @@ -315,6 +314,10 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: ], ) + # Mask cropland + if cropland_mask is not None: + classes = classes.mask(cropland_mask == 0, replacement=0) + # Cast to uint16 classes = compress_uint16(classes) diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py index b1fc1f3a..2feb409f 100644 --- a/src/worldcereal/openeo/inference.py +++ b/src/worldcereal/openeo/inference.py @@ -165,13 +165,4 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: }, ) - # Apply optional cropland mask - cropland_mask = self._parameters.get("cropland_mask", None) - if cropland_mask is not None: - # Non-cropland pixels are set to 0 - self.logger.info("Applying cropland mask ...") - classification = classification.where( - cropland_mask.sel(bands="classification") == 1, 0 - ) - return classification From ef3936f36dc14265b40cfcf4815c4b1d26df36d8 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:42:12 +0200 Subject: [PATCH 5/9] Support both cropland and croptype products --- scripts/inference/cropland_mapping.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 5068a184..f95a4260 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -11,8 +11,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( - prog="WC - Cropland Inference", - description="Cropland inference using GFMAP, Presto and WorldCereal classifiers", + prog="WC - Crop Mapping Inference", + description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers", ) parser.add_argument("minx", type=float, help="Minimum X coordinate (west)") @@ -25,6 +25,11 @@ default=4326, help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", ) + parser.add_argument( + "product", + type=str, + help="Product to generate. One of ['cropland', 'croptype']", + ) parser.add_argument( "start_date", type=str, help="Starting date for data extraction." ) @@ -46,6 +51,14 @@ start_date = args.start_date end_date = args.end_date + product = args.product + + # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) + # epsg = 32631 + # start_date = "2020-11-01" + # end_date = "2021-10-31" + # product = "croptype" + spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg) temporal_extent = TemporalContext(start_date, end_date) @@ -56,7 +69,7 @@ temporal_extent, backend_context, args.output_path, - product_type=WorldCerealProduct.CROPLAND, + product_type=WorldCerealProduct(product), out_format="GTiff", ) logger.success("Job finished:\n\t%s", job_results) From ba7a54758739b5c01e781b3ca435fd263b00a602 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:42:30 +0200 Subject: [PATCH 6/9] Fix typos --- src/worldcereal/openeo/feature_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py index 4526ba37..e837994a 100644 --- a/src/worldcereal/openeo/feature_extractor.py +++ b/src/worldcereal/openeo/feature_extractor.py @@ -21,7 +21,7 @@ class PrestoFeatureExtractor(PatchFeatureExtractor): import functools PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA - PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" + PRESTO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.1-py3-none-any.whl" BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA DEPENDENCY_NAME = "worldcereal_deps.zip" @@ -73,7 +73,7 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: presto_model_url = self._parameters.get( "presto_model_url", self.PRESTO_MODEL_URL ) - presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL) + presto_wheel_url = self._parameters.get("presto_wheel_url", self.PRESTO_WHL_URL) ignore_dependencies = self._parameters.get("ignore_dependencies", False) if ignore_dependencies: From b6119811889469937fecd2ce112e4b8123f17345 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:45:01 +0200 Subject: [PATCH 7/9] Add test bounds --- scripts/inference/cropland_mapping.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index f95a4260..c94eae3a 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -53,7 +53,8 @@ product = args.product - # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) + # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test + # minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test # epsg = 32631 # start_date = "2020-11-01" # end_date = "2021-10-31" From c3c317ac31af876ae1b616c114dda227b81e5c08 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 09:54:38 +0200 Subject: [PATCH 8/9] Updated local croptype mapping script --- scripts/inference/croptype_mapping_local.py | 42 ++++++++++++++++----- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/scripts/inference/croptype_mapping_local.py b/scripts/inference/croptype_mapping_local.py index 65962eac..d663d655 100644 --- a/scripts/inference/croptype_mapping_local.py +++ b/scripts/inference/croptype_mapping_local.py @@ -15,7 +15,7 @@ from openeo_gfmap.inference.model_inference import apply_model_inference_local from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor -from worldcereal.openeo.inference import CroptypeClassifier +from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" @@ -40,24 +40,40 @@ .astype("uint16") ) - print("Running presto UDF locally") - features = apply_feature_extractor_local( + print("Get Presto cropland features") + cropland_features = apply_feature_extractor_local( PrestoFeatureExtractor, arr, + parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True}, + ) + + print("Running cropland classification inference UDF locally") + + cropland_classification = apply_model_inference_local( + CroplandClassifier, + cropland_features, parameters={ EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True, - "presto_model_url": PRESTO_URL, }, ) - features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc") + print("Get Presto croptype features") + croptype_features = apply_feature_extractor_local( + PrestoFeatureExtractor, + arr, + parameters={ + EPSG_HARMONIZED_NAME: 32631, + "ignore_dependencies": True, + "presto_model_url": PRESTO_URL, + }, + ) - print("Running classification inference UDF locally") + print("Running croptype classification inference UDF locally") - classification = apply_model_inference_local( + croptype_classification = apply_model_inference_local( CroptypeClassifier, - features, + croptype_features, parameters={ EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True, @@ -65,4 +81,12 @@ }, ) - classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc") + # Apply cropland mask -> on the backend this is done with mask process + croptype_classification = croptype_classification.where( + cropland_classification.sel(bands="classification") == 1, 0 + ) + + croptype_classification.to_netcdf( + Path("/vitodata/worldcereal/validation/internal_validation/") + / "test_classification_croptype_local.nc" + ) From 3437b23bdd2647dfc81e1e5492ac7c017d227a19 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Thu, 27 Jun 2024 13:52:04 +0200 Subject: [PATCH 9/9] Docstyle updated to Numpy --- src/worldcereal/job.py | 91 ++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/src/worldcereal/job.py b/src/worldcereal/job.py index 309553ec..8f6ac9e4 100644 --- a/src/worldcereal/job.py +++ b/src/worldcereal/job.py @@ -89,23 +89,35 @@ def generate_map( output_path: Optional[Union[Path, str]], product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND, out_format: str = "GTiff", -): +) -> InferenceResults: """Main function to generate a WorldCereal product. - Args: - spatial_extent (BoundingBoxExtent): spatial extent of the map - temporal_extent (TemporalContext): temporal range to consider - backend_context (BackendContext): backend to run the job on - output_path (Union[Path, str]): output path to download the product to - product (str, optional): product describer. Defaults to "cropland". - format (str, optional): Output format. Defaults to "GTiff". - - Raises: - ValueError: if the product is not supported - ValueError: if the out_format is not supported - ValueError: if a cropland mask is applied on a cropland workflow - - + Parameters + ---------- + spatial_extent : BoundingBoxExtent + spatial extent of the map + temporal_extent : TemporalContext + temporal range to consider + backend_context : BackendContext + backend to run the job on + output_path : Optional[Union[Path, str]] + output path to download the product to + product_type : WorldCerealProduct, optional + product describer, by default WorldCerealProduct.CROPLAND + out_format : str, optional + Output format, by default "GTiff" + + Returns + ------- + InferenceResults + Results of the finished WorldCereal job. + + Raises + ------ + ValueError + if the product is not supported + ValueError + if the out_format is not supported """ if product_type not in PRODUCT_SETTINGS.keys(): @@ -172,18 +184,20 @@ def collect_inputs( temporal_extent: TemporalContext, backend_context: BackendContext, output_path: Union[Path, str], -) -> DataCube: +): """Function to retrieve preprocessed inputs that are being used in the generation of WorldCereal products. - Args: - spatial_extent (BoundingBoxExtent): spatial extent of the map - temporal_extent (TemporalContext): temporal range to consider - backend_context (BackendContext): backend to run the job on - output_path (Union[Path, str]): output path to download the product to - - Raises: - ValueError: if the product is not supported + Parameters + ---------- + spatial_extent : BoundingBoxExtent + spatial extent of the map + temporal_extent : TemporalContext + temporal range to consider + backend_context : BackendContext + backend to run the job on + output_path : Union[Path, str] + output path to download the product to """ @@ -211,11 +225,15 @@ def _cropland_map(inputs: DataCube) -> DataCube: """Method to produce cropland map from preprocessed inputs, using a Presto feature extractor and a CatBoost classifier. - Args: - inputs (DataCube): preprocessed input cube + Parameters + ---------- + inputs : DataCube + preprocessed input cube - Returns: - DataCube: binary labels and probability + Returns + ------- + DataCube + binary labels and probability """ # Run feature computer @@ -267,12 +285,17 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube: """Method to produce croptype map from preprocessed inputs, using a Presto feature extractor and a CatBoost classifier. - Args: - inputs (DataCube): preprocessed input cube - cropland_mask (DataCube): optional cropland mask - - Returns: - DataCube: croptype labels and probability + Parameters + ---------- + inputs : DataCube + preprocessed input cube + cropland_mask : DataCube, optional + optional cropland mask, by default None + + Returns + ------- + DataCube + croptype labels and probability """ # Run feature computer