Skip to content

Commit

Permalink
Merge pull request #94 from WorldCereal/93-pydantic-2-issue
Browse files Browse the repository at this point in the history
Fix croptype/cropland parameters for pydantic>2
  • Loading branch information
kvantricht authored Jul 2, 2024
2 parents fefd73d + 04d36ca commit bef1a49
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ dependencies = [
"cftime",
"pytest-depends",
"pyarrow",
"geopandas"]
"geopandas",
"pydantic>=2.6"
]

[project.urls]
"Homepage" = "https://github.com/WorldCereal/worldcereal-classification"
Expand Down
42 changes: 36 additions & 6 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from enum import Enum
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Type, Union

import openeo
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
from pydantic import BaseModel
from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor
from openeo_gfmap.inference.model_inference import ModelInference
from pydantic import BaseModel, Field, ValidationError, model_validator

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
Expand Down Expand Up @@ -76,16 +78,30 @@ class CropLandParameters(BaseModel):
and passed in the process graph.
"""

feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
feature_extractor: Type[PatchFeatureExtractor] = Field(
default=PrestoFeatureExtractor
)
features_parameters: FeaturesParameters = FeaturesParameters(
rescale_s1=False,
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
)
classifier: CroplandClassifier = CroplandClassifier
classifier: Type[ModelInference] = Field(default=CroplandClassifier)
classifier_parameters: ClassifierParameters = ClassifierParameters(
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
)

@model_validator(mode="after")
def check_udf_types(self):
"""Validates the FeatureExtractor and Classifier classes."""
if not issubclass(self.feature_extractor, PatchFeatureExtractor):
raise ValidationError(
f"Feature extractor must be a subclass of PrestoFeatureExtractor, got {self.feature_extractor}"
)
if not issubclass(self.classifier, ModelInference):
raise ValidationError(
f"Classifier must be a subclass of ModelInference, got {self.classifier}"
)


class CropTypeParameters(BaseModel):
"""Parameters for the croptype product inference pipeline. Types are
Expand All @@ -109,16 +125,30 @@ class CropTypeParameters(BaseModel):
and passed in the process graph.
"""

feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
feature_extractor: Type[PatchFeatureExtractor] = Field(
default=PrestoFeatureExtractor
)
feature_parameters: FeaturesParameters = FeaturesParameters(
rescale_s1=False,
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
)
classifier: CroptypeClassifier = CroptypeClassifier
classifier: Type[ModelInference] = Field(default=CroptypeClassifier)
classifier_parameters: ClassifierParameters = ClassifierParameters(
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
)

@model_validator(mode="after")
def check_udf_types(self):
"""Validates the FeatureExtractor and Classifier classes."""
if not issubclass(self.feature_extractor, PatchFeatureExtractor):
raise ValidationError(
f"Feature extractor must be a subclass of PrestoFeatureExtractor, got {self.feature_extractor}"
)
if not issubclass(self.classifier, ModelInference):
raise ValidationError(
f"Classifier must be a subclass of ModelInference, got {self.classifier}"
)


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

Expand Down

0 comments on commit bef1a49

Please sign in to comment.