Skip to content

Commit fefd73d

Browse files
authored
Merge pull request #91 from WorldCereal/86-generete-map-params
Added configuration data classes for cropland/croptype mapping workflow
2 parents ab9293e + ce3aa8f commit fefd73d

File tree

3 files changed

+255
-178
lines changed

3 files changed

+255
-178
lines changed

scripts/inference/cropland_mapping.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,25 @@
2020
parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)")
2121
parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)")
2222
parser.add_argument(
23-
"--epsg",
24-
type=int,
25-
default=4326,
26-
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
23+
"start_date", type=str, help="Starting date for data extraction."
2724
)
25+
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
2826
parser.add_argument(
2927
"product",
3028
type=str,
3129
help="Product to generate. One of ['cropland', 'croptype']",
3230
)
33-
parser.add_argument(
34-
"start_date", type=str, help="Starting date for data extraction."
35-
)
36-
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
3731
parser.add_argument(
3832
"output_path",
3933
type=Path,
4034
help="Path to folder where to save the resulting GeoTiff.",
4135
)
36+
parser.add_argument(
37+
"--epsg",
38+
type=int,
39+
default=4326,
40+
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
41+
)
4242

4343
args = parser.parse_args()
4444

src/worldcereal/job.py

Lines changed: 118 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
"""Executing inference jobs on the OpenEO backend."""
22

3-
from dataclasses import dataclass
43
from enum import Enum
54
from pathlib import Path
65
from typing import Optional, Union
76

87
import openeo
9-
from openeo import DataCube
108
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
11-
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
12-
from openeo_gfmap.inference.model_inference import apply_model_inference
13-
from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16
9+
from pydantic import BaseModel
1410

1511
from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
1612
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
13+
from worldcereal.openeo.mapping import _cropland_map, _croptype_map
1714
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
1815

1916

@@ -24,44 +21,109 @@ class WorldCerealProduct(Enum):
2421
CROPTYPE = "croptype"
2522

2623

27-
ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
24+
class FeaturesParameters(BaseModel):
25+
"""Parameters for the feature extraction UDFs. Types are enforced by
26+
Pydantic.
2827
29-
PRODUCT_SETTINGS = {
30-
WorldCerealProduct.CROPLAND: {
31-
"features": {
32-
"extractor": PrestoFeatureExtractor,
33-
"parameters": {
34-
"rescale_s1": False, # Will be done in the Presto UDF itself!
35-
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
36-
},
37-
},
38-
"classification": {
39-
"classifier": CroplandClassifier,
40-
"parameters": {
41-
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
42-
},
43-
},
44-
},
45-
WorldCerealProduct.CROPTYPE: {
46-
"features": {
47-
"extractor": PrestoFeatureExtractor,
48-
"parameters": {
49-
"rescale_s1": False, # Will be done in the Presto UDF itself!
50-
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
51-
},
52-
},
53-
"classification": {
54-
"classifier": CroptypeClassifier,
55-
"parameters": {
56-
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
57-
},
58-
},
59-
},
60-
}
28+
Attributes
29+
----------
30+
rescale_s1 : bool (default=False)
31+
Whether to rescale Sentinel-1 bands before feature extraction. Should be
32+
left to False, as this is done in the Presto UDF itself.
33+
presto_model_url : str
34+
Public URL to the Presto model used for feature extraction. The file
35+
should be a PyTorch serialized model.
36+
"""
37+
38+
rescale_s1: bool
39+
presto_model_url: str
6140

6241

63-
@dataclass
64-
class InferenceResults:
42+
class ClassifierParameters(BaseModel):
43+
"""Parameters for the classifier. Types are enforced by Pydantic.
44+
45+
Attributes
46+
----------
47+
classifier_url : str
48+
Public URL to the classifier model. Te file should be an ONNX accepting
49+
a `features` field for input data and returning either two output
50+
probability arrays `true` and `false` in case of cropland mapping, or
51+
a probability array per-class in case of croptype mapping.
52+
"""
53+
54+
classifier_url: str
55+
56+
57+
class CropLandParameters(BaseModel):
58+
"""Parameters for the cropland product inference pipeline. Types are
59+
enforced by Pydantic.
60+
61+
Attributes
62+
----------
63+
feature_extractor : PrestoFeatureExtractor
64+
Feature extractor to use for the inference. This class must be a
65+
subclass of GFMAP's `PatchFeatureExtractor` and returns float32
66+
features.
67+
features_parameters : FeaturesParameters
68+
Parameters for the feature extraction UDF. Will be serialized into a
69+
dictionnary and passed in the process graph.
70+
classifier : CroplandClassifier
71+
Classifier to use for the inference. This class must be a subclass of
72+
GFMAP's `ModelInference` and returns predictions/probabilities for
73+
cropland.
74+
classifier_parameters : ClassifierParameters
75+
Parameters for the classifier UDF. Will be serialized into a dictionnary
76+
and passed in the process graph.
77+
"""
78+
79+
feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
80+
features_parameters: FeaturesParameters = FeaturesParameters(
81+
rescale_s1=False,
82+
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
83+
)
84+
classifier: CroplandClassifier = CroplandClassifier
85+
classifier_parameters: ClassifierParameters = ClassifierParameters(
86+
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
87+
)
88+
89+
90+
class CropTypeParameters(BaseModel):
91+
"""Parameters for the croptype product inference pipeline. Types are
92+
enforced by Pydantic.
93+
94+
Attributes
95+
----------
96+
feature_extractor : PrestoFeatureExtractor
97+
Feature extractor to use for the inference. This class must be a
98+
subclass of GFMAP's `PatchFeatureExtractor` and returns float32
99+
features.
100+
features_parameters : FeaturesParameters
101+
Parameters for the feature extraction UDF. Will be serialized into a
102+
dictionnary and passed in the process graph.
103+
classifier : CroptypeClassifier
104+
Classifier to use for the inference. This class must be a subclass of
105+
GFMAP's `ModelInference` and returns predictions/probabilities for
106+
cropland classes.
107+
classifier_parameters : ClassifierParameters
108+
Parameters for the classifier UDF. Will be serialized into a dictionnary
109+
and passed in the process graph.
110+
"""
111+
112+
feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
113+
feature_parameters: FeaturesParameters = FeaturesParameters(
114+
rescale_s1=False,
115+
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
116+
)
117+
classifier: CroptypeClassifier = CroptypeClassifier
118+
classifier_parameters: ClassifierParameters = ClassifierParameters(
119+
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
120+
)
121+
122+
123+
ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
124+
125+
126+
class InferenceResults(BaseModel):
65127
"""Dataclass to store the results of the WorldCereal job.
66128
67129
Attributes
@@ -88,6 +150,8 @@ def generate_map(
88150
backend_context: BackendContext,
89151
output_path: Optional[Union[Path, str]],
90152
product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND,
153+
cropland_parameters: CropLandParameters = CropLandParameters(),
154+
croptype_parameters: Optional[CropTypeParameters] = CropTypeParameters(),
91155
out_format: str = "GTiff",
92156
) -> InferenceResults:
93157
"""Main function to generate a WorldCereal product.
@@ -104,6 +168,12 @@ def generate_map(
104168
output path to download the product to
105169
product_type : WorldCerealProduct, optional
106170
product describer, by default WorldCerealProduct.CROPLAND
171+
cropland_parameters: CropLandParameters
172+
Parameters for the cropland product inference pipeline.
173+
croptype_parameters: Optional[CropTypeParameters]
174+
Parameters for the croptype product inference pipeline. Only required
175+
whenever `product_type` is set to `WorldCerealProduct.CROPTYPE`, will be
176+
ignored otherwise.
107177
out_format : str, optional
108178
Output format, by default "GTiff"
109179
@@ -120,7 +190,7 @@ def generate_map(
120190
if the out_format is not supported
121191
"""
122192

123-
if product_type not in PRODUCT_SETTINGS.keys():
193+
if product_type not in WorldCerealProduct:
124194
raise ValueError(f"Product {product_type.value} not supported.")
125195

126196
if out_format not in ["GTiff", "NetCDF"]:
@@ -145,18 +215,22 @@ def generate_map(
145215

146216
# Construct the feature extraction and model inference pipeline
147217
if product_type == WorldCerealProduct.CROPLAND:
148-
classes = _cropland_map(inputs)
218+
classes = _cropland_map(inputs, cropland_parameters=cropland_parameters)
149219
elif product_type == WorldCerealProduct.CROPTYPE:
150220
# First compute cropland map
151221
cropland_mask = (
152-
_cropland_map(inputs)
222+
_cropland_map(inputs, cropland_parameters=cropland_parameters)
153223
.filter_bands("classification")
154224
.reduce_dimension(
155225
dimension="t", reducer="mean"
156226
) # Temporary fix to make this work as mask
157227
)
158228

159-
classes = _croptype_map(inputs, cropland_mask=cropland_mask)
229+
classes = _croptype_map(
230+
inputs,
231+
croptype_parameters=croptype_parameters,
232+
cropland_mask=cropland_mask,
233+
)
160234

161235
# Submit the job
162236
job = classes.execute_batch(
@@ -219,129 +293,3 @@ def collect_inputs(
219293
out_format="NetCDF",
220294
job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"},
221295
)
222-
223-
224-
def _cropland_map(inputs: DataCube) -> DataCube:
225-
"""Method to produce cropland map from preprocessed inputs, using
226-
a Presto feature extractor and a CatBoost classifier.
227-
228-
Parameters
229-
----------
230-
inputs : DataCube
231-
preprocessed input cube
232-
233-
Returns
234-
-------
235-
DataCube
236-
binary labels and probability
237-
"""
238-
239-
# Run feature computer
240-
features = apply_feature_extractor(
241-
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
242-
"features"
243-
]["extractor"],
244-
cube=inputs,
245-
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][
246-
"parameters"
247-
],
248-
size=[
249-
{"dimension": "x", "unit": "px", "value": 100},
250-
{"dimension": "y", "unit": "px", "value": 100},
251-
],
252-
overlap=[
253-
{"dimension": "x", "unit": "px", "value": 0},
254-
{"dimension": "y", "unit": "px", "value": 0},
255-
],
256-
)
257-
258-
# Run model inference on features
259-
classes = apply_model_inference(
260-
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
261-
"classification"
262-
]["classifier"],
263-
cube=features,
264-
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][
265-
"parameters"
266-
],
267-
size=[
268-
{"dimension": "x", "unit": "px", "value": 100},
269-
{"dimension": "y", "unit": "px", "value": 100},
270-
{"dimension": "t", "value": "P1D"},
271-
],
272-
overlap=[
273-
{"dimension": "x", "unit": "px", "value": 0},
274-
{"dimension": "y", "unit": "px", "value": 0},
275-
],
276-
)
277-
278-
# Cast to uint8
279-
classes = compress_uint8(classes)
280-
281-
return classes
282-
283-
284-
def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube:
285-
"""Method to produce croptype map from preprocessed inputs, using
286-
a Presto feature extractor and a CatBoost classifier.
287-
288-
Parameters
289-
----------
290-
inputs : DataCube
291-
preprocessed input cube
292-
cropland_mask : DataCube, optional
293-
optional cropland mask, by default None
294-
295-
Returns
296-
-------
297-
DataCube
298-
croptype labels and probability
299-
"""
300-
301-
# Run feature computer
302-
features = apply_feature_extractor(
303-
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
304-
"features"
305-
]["extractor"],
306-
cube=inputs,
307-
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][
308-
"parameters"
309-
],
310-
size=[
311-
{"dimension": "x", "unit": "px", "value": 100},
312-
{"dimension": "y", "unit": "px", "value": 100},
313-
],
314-
overlap=[
315-
{"dimension": "x", "unit": "px", "value": 0},
316-
{"dimension": "y", "unit": "px", "value": 0},
317-
],
318-
)
319-
320-
# Run model inference on features
321-
classes = apply_model_inference(
322-
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
323-
"classification"
324-
]["classifier"],
325-
cube=features,
326-
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][
327-
"parameters"
328-
],
329-
size=[
330-
{"dimension": "x", "unit": "px", "value": 100},
331-
{"dimension": "y", "unit": "px", "value": 100},
332-
{"dimension": "t", "value": "P1D"},
333-
],
334-
overlap=[
335-
{"dimension": "x", "unit": "px", "value": 0},
336-
{"dimension": "y", "unit": "px", "value": 0},
337-
],
338-
)
339-
340-
# Mask cropland
341-
if cropland_mask is not None:
342-
classes = classes.mask(cropland_mask == 0, replacement=0)
343-
344-
# Cast to uint16
345-
classes = compress_uint16(classes)
346-
347-
return classes

0 commit comments

Comments
 (0)