1
1
"""Executing inference jobs on the OpenEO backend."""
2
2
3
- from dataclasses import dataclass
4
3
from enum import Enum
5
4
from pathlib import Path
6
5
from typing import Optional , Union
7
6
8
7
import openeo
9
- from openeo import DataCube
10
8
from openeo_gfmap import BackendContext , BoundingBoxExtent , TemporalContext
11
- from openeo_gfmap .features .feature_extractor import apply_feature_extractor
12
- from openeo_gfmap .inference .model_inference import apply_model_inference
13
- from openeo_gfmap .preprocessing .scaling import compress_uint8 , compress_uint16
9
+ from pydantic import BaseModel
14
10
15
11
from worldcereal .openeo .feature_extractor import PrestoFeatureExtractor
16
12
from worldcereal .openeo .inference import CroplandClassifier , CroptypeClassifier
13
+ from worldcereal .openeo .mapping import _cropland_map , _croptype_map
17
14
from worldcereal .openeo .preprocessing import worldcereal_preprocessed_inputs_gfmap
18
15
19
16
@@ -24,44 +21,109 @@ class WorldCerealProduct(Enum):
24
21
CROPTYPE = "croptype"
25
22
26
23
27
- ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
24
+ class FeaturesParameters (BaseModel ):
25
+ """Parameters for the feature extraction UDFs. Types are enforced by
26
+ Pydantic.
28
27
29
- PRODUCT_SETTINGS = {
30
- WorldCerealProduct .CROPLAND : {
31
- "features" : {
32
- "extractor" : PrestoFeatureExtractor ,
33
- "parameters" : {
34
- "rescale_s1" : False , # Will be done in the Presto UDF itself!
35
- "presto_model_url" : "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" , # NOQA
36
- },
37
- },
38
- "classification" : {
39
- "classifier" : CroplandClassifier ,
40
- "parameters" : {
41
- "classifier_url" : "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
42
- },
43
- },
44
- },
45
- WorldCerealProduct .CROPTYPE : {
46
- "features" : {
47
- "extractor" : PrestoFeatureExtractor ,
48
- "parameters" : {
49
- "rescale_s1" : False , # Will be done in the Presto UDF itself!
50
- "presto_model_url" : "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" , # NOQA
51
- },
52
- },
53
- "classification" : {
54
- "classifier" : CroptypeClassifier ,
55
- "parameters" : {
56
- "classifier_url" : "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
57
- },
58
- },
59
- },
60
- }
28
+ Attributes
29
+ ----------
30
+ rescale_s1 : bool
31
+ Whether to rescale Sentinel-1 bands before feature extraction. Should be
32
+ left to False, as this is done in the Presto UDF itself.
33
+ presto_model_url : str
34
+ Public URL to the Presto model used for feature extraction. The file
35
+ should be a PyTorch serialized model.
36
+ """
37
+
38
+ rescale_s1 : bool
39
+ presto_model_url : str
61
40
62
41
63
- @dataclass
64
- class InferenceResults :
42
+ class ClassifierParameters (BaseModel ):
43
+ """Parameters for the classifier. Types are enforced by Pydantic.
44
+
45
+ Attributes
46
+ ----------
47
+ classifier_url : str
48
+ Public URL to the classifier model. The file should be an ONNX model accepting
49
+ a `features` field for input data and returning either two output
50
+ probability arrays `true` and `false` in case of cropland mapping, or
51
+ a probability array per-class in case of croptype mapping.
52
+ """
53
+
54
+ classifier_url : str
55
+
56
+
57
+ class CropLandParameters (BaseModel ):
58
+ """Parameters for the cropland product inference pipeline. Types are
59
+ enforced by Pydantic.
60
+
61
+ Attributes
62
+ ----------
63
+ feature_extractor : PrestoFeatureExtractor
64
+ Feature extractor to use for the inference. This class must be a
65
+ subclass of GFMAP's `PatchFeatureExtractor` and returns float32
66
+ features.
67
+ features_parameters : FeaturesParameters
68
+ Parameters for the feature extraction UDF. Will be serialized into a
69
+ dictionary and passed in the process graph.
70
+ classifier : CroplandClassifier
71
+ Classifier to use for the inference. This class must be a subclass of
72
+ GFMAP's `ModelInference` and returns predictions/probabilities for
73
+ cropland.
74
+ classifier_parameters : ClassifierParameters
75
+ Parameters for the classifier UDF. Will be serialized into a dictionary
76
+ and passed in the process graph.
77
+ """
78
+
79
+ feature_extractor : PrestoFeatureExtractor = PrestoFeatureExtractor
80
+ features_parameters : FeaturesParameters = FeaturesParameters (
81
+ rescale_s1 = False ,
82
+ presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" , # NOQA
83
+ )
84
+ classifier : CroplandClassifier = CroplandClassifier
85
+ classifier_parameters : ClassifierParameters = ClassifierParameters (
86
+ classifier_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
87
+ )
88
+
89
+
90
+ class CropTypeParameters (BaseModel ):
91
+ """Parameters for the croptype product inference pipeline. Types are
92
+ enforced by Pydantic.
93
+
94
+ Attributes
95
+ ----------
96
+ feature_extractor : PrestoFeatureExtractor
97
+ Feature extractor to use for the inference. This class must be a
98
+ subclass of GFMAP's `PatchFeatureExtractor` and returns float32
99
+ features.
100
+ feature_parameters : FeaturesParameters
101
+ Parameters for the feature extraction UDF. Will be serialized into a
102
+ dictionary and passed in the process graph.
103
+ classifier : CroptypeClassifier
104
+ Classifier to use for the inference. This class must be a subclass of
105
+ GFMAP's `ModelInference` and returns predictions/probabilities for
106
+ cropland classes.
107
+ classifier_parameters : ClassifierParameters
108
+ Parameters for the classifier UDF. Will be serialized into a dictionary
109
+ and passed in the process graph.
110
+ """
111
+
112
+ feature_extractor : PrestoFeatureExtractor = PrestoFeatureExtractor
113
+ feature_parameters : FeaturesParameters = FeaturesParameters (
114
+ rescale_s1 = False ,
115
+ presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" , # NOQA
116
+ )
117
+ classifier : CroptypeClassifier = CroptypeClassifier
118
+ classifier_parameters : ClassifierParameters = ClassifierParameters (
119
+ classifier_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
120
+ )
121
+
122
+
123
+ ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
124
+
125
+
126
+ class InferenceResults (BaseModel ):
65
127
"""Pydantic model to store the results of the WorldCereal job.
66
128
67
129
Attributes
@@ -88,6 +150,8 @@ def generate_map(
88
150
backend_context : BackendContext ,
89
151
output_path : Optional [Union [Path , str ]],
90
152
product_type : WorldCerealProduct = WorldCerealProduct .CROPLAND ,
153
+ cropland_parameters : CropLandParameters = CropLandParameters (),
154
+ croptype_parameters : Optional [CropTypeParameters ] = CropTypeParameters (),
91
155
out_format : str = "GTiff" ,
92
156
) -> InferenceResults :
93
157
"""Main function to generate a WorldCereal product.
@@ -104,6 +168,12 @@ def generate_map(
104
168
output path to download the product to
105
169
product_type : WorldCerealProduct, optional
106
170
product describer, by default WorldCerealProduct.CROPLAND
171
+ cropland_parameters: CropLandParameters
172
+ Parameters for the cropland product inference pipeline.
173
+ croptype_parameters: Optional[CropTypeParameters]
174
+ Parameters for the croptype product inference pipeline. Only required
175
+ whenever `product_type` is set to `WorldCerealProduct.CROPTYPE`, will be
176
+ ignored otherwise.
107
177
out_format : str, optional
108
178
Output format, by default "GTiff"
109
179
@@ -120,7 +190,7 @@ def generate_map(
120
190
if the out_format is not supported
121
191
"""
122
192
123
- if product_type not in PRODUCT_SETTINGS . keys () :
193
+ if product_type not in WorldCerealProduct :
124
194
raise ValueError (f"Product { product_type .value } not supported." )
125
195
126
196
if out_format not in ["GTiff" , "NetCDF" ]:
@@ -145,18 +215,22 @@ def generate_map(
145
215
146
216
# Construct the feature extraction and model inference pipeline
147
217
if product_type == WorldCerealProduct .CROPLAND :
148
- classes = _cropland_map (inputs )
218
+ classes = _cropland_map (inputs , cropland_parameters = cropland_parameters )
149
219
elif product_type == WorldCerealProduct .CROPTYPE :
150
220
# First compute cropland map
151
221
cropland_mask = (
152
- _cropland_map (inputs )
222
+ _cropland_map (inputs , cropland_parameters = cropland_parameters )
153
223
.filter_bands ("classification" )
154
224
.reduce_dimension (
155
225
dimension = "t" , reducer = "mean"
156
226
) # Temporary fix to make this work as mask
157
227
)
158
228
159
- classes = _croptype_map (inputs , cropland_mask = cropland_mask )
229
+ classes = _croptype_map (
230
+ inputs ,
231
+ croptype_parameters = croptype_parameters ,
232
+ cropland_mask = cropland_mask ,
233
+ )
160
234
161
235
# Submit the job
162
236
job = classes .execute_batch (
@@ -219,129 +293,3 @@ def collect_inputs(
219
293
out_format = "NetCDF" ,
220
294
job_options = {"driver-memory" : "4g" , "executor-memoryOverhead" : "4g" },
221
295
)
222
-
223
-
224
- def _cropland_map (inputs : DataCube ) -> DataCube :
225
- """Method to produce cropland map from preprocessed inputs, using
226
- a Presto feature extractor and a CatBoost classifier.
227
-
228
- Parameters
229
- ----------
230
- inputs : DataCube
231
- preprocessed input cube
232
-
233
- Returns
234
- -------
235
- DataCube
236
- binary labels and probability
237
- """
238
-
239
- # Run feature computer
240
- features = apply_feature_extractor (
241
- feature_extractor_class = PRODUCT_SETTINGS [WorldCerealProduct .CROPLAND ][
242
- "features"
243
- ]["extractor" ],
244
- cube = inputs ,
245
- parameters = PRODUCT_SETTINGS [WorldCerealProduct .CROPLAND ]["features" ][
246
- "parameters"
247
- ],
248
- size = [
249
- {"dimension" : "x" , "unit" : "px" , "value" : 100 },
250
- {"dimension" : "y" , "unit" : "px" , "value" : 100 },
251
- ],
252
- overlap = [
253
- {"dimension" : "x" , "unit" : "px" , "value" : 0 },
254
- {"dimension" : "y" , "unit" : "px" , "value" : 0 },
255
- ],
256
- )
257
-
258
- # Run model inference on features
259
- classes = apply_model_inference (
260
- model_inference_class = PRODUCT_SETTINGS [WorldCerealProduct .CROPLAND ][
261
- "classification"
262
- ]["classifier" ],
263
- cube = features ,
264
- parameters = PRODUCT_SETTINGS [WorldCerealProduct .CROPLAND ]["classification" ][
265
- "parameters"
266
- ],
267
- size = [
268
- {"dimension" : "x" , "unit" : "px" , "value" : 100 },
269
- {"dimension" : "y" , "unit" : "px" , "value" : 100 },
270
- {"dimension" : "t" , "value" : "P1D" },
271
- ],
272
- overlap = [
273
- {"dimension" : "x" , "unit" : "px" , "value" : 0 },
274
- {"dimension" : "y" , "unit" : "px" , "value" : 0 },
275
- ],
276
- )
277
-
278
- # Cast to uint8
279
- classes = compress_uint8 (classes )
280
-
281
- return classes
282
-
283
-
284
- def _croptype_map (inputs : DataCube , cropland_mask : DataCube = None ) -> DataCube :
285
- """Method to produce croptype map from preprocessed inputs, using
286
- a Presto feature extractor and a CatBoost classifier.
287
-
288
- Parameters
289
- ----------
290
- inputs : DataCube
291
- preprocessed input cube
292
- cropland_mask : DataCube, optional
293
- optional cropland mask, by default None
294
-
295
- Returns
296
- -------
297
- DataCube
298
- croptype labels and probability
299
- """
300
-
301
- # Run feature computer
302
- features = apply_feature_extractor (
303
- feature_extractor_class = PRODUCT_SETTINGS [WorldCerealProduct .CROPTYPE ][
304
- "features"
305
- ]["extractor" ],
306
- cube = inputs ,
307
- parameters = PRODUCT_SETTINGS [WorldCerealProduct .CROPTYPE ]["features" ][
308
- "parameters"
309
- ],
310
- size = [
311
- {"dimension" : "x" , "unit" : "px" , "value" : 100 },
312
- {"dimension" : "y" , "unit" : "px" , "value" : 100 },
313
- ],
314
- overlap = [
315
- {"dimension" : "x" , "unit" : "px" , "value" : 0 },
316
- {"dimension" : "y" , "unit" : "px" , "value" : 0 },
317
- ],
318
- )
319
-
320
- # Run model inference on features
321
- classes = apply_model_inference (
322
- model_inference_class = PRODUCT_SETTINGS [WorldCerealProduct .CROPTYPE ][
323
- "classification"
324
- ]["classifier" ],
325
- cube = features ,
326
- parameters = PRODUCT_SETTINGS [WorldCerealProduct .CROPTYPE ]["classification" ][
327
- "parameters"
328
- ],
329
- size = [
330
- {"dimension" : "x" , "unit" : "px" , "value" : 100 },
331
- {"dimension" : "y" , "unit" : "px" , "value" : 100 },
332
- {"dimension" : "t" , "value" : "P1D" },
333
- ],
334
- overlap = [
335
- {"dimension" : "x" , "unit" : "px" , "value" : 0 },
336
- {"dimension" : "y" , "unit" : "px" , "value" : 0 },
337
- ],
338
- )
339
-
340
- # Mask cropland
341
- if cropland_mask is not None :
342
- classes = classes .mask (cropland_mask == 0 , replacement = 0 )
343
-
344
- # Cast to uint16
345
- classes = compress_uint16 (classes )
346
-
347
- return classes
0 commit comments