From 4a2b0b010e9e7ebe89b92c4a8afe55719faae344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=BDiga=20Luk=C5=A1i=C4=8D?= <31988337+zigaLuksic@users.noreply.github.com> Date: Tue, 14 Jun 2022 14:40:31 +0200 Subject: [PATCH] Properly parse dtypes at validation (#71) * adjust schemas and templates to allow custom types * adjust dtype types in pipeline schemas * fix isort in tests --- eogrow/core/base.py | 2 ++ eogrow/core/schemas.py | 15 ++++++++------- eogrow/pipelines/batch_to_eopatch.py | 6 ++++-- eogrow/pipelines/download.py | 11 +++++++++-- eogrow/pipelines/features.py | 6 ++++-- eogrow/pipelines/prediction.py | 5 ++++- eogrow/pipelines/rasterize.py | 4 +++- eogrow/pipelines/testing.py | 5 +++-- eogrow/tasks/prediction.py | 2 +- eogrow/utils/validators.py | 5 +++++ tests/test_core/test_schemas.py | 11 ++++++----- tests/test_pipelines/test_batch_to_eopatch.py | 2 +- 12 files changed, 50 insertions(+), 24 deletions(-) diff --git a/eogrow/core/base.py b/eogrow/core/base.py index fb0b4281..5797bc16 100644 --- a/eogrow/core/base.py +++ b/eogrow/core/base.py @@ -19,6 +19,8 @@ class Config: extra = "forbid" validate_all = True + allow_mutation = False + arbitrary_types_allowed = True config: Schema diff --git a/eogrow/core/schemas.py b/eogrow/core/schemas.py index 300351eb..47712259 100644 --- a/eogrow/core/schemas.py +++ b/eogrow/core/schemas.py @@ -137,12 +137,11 @@ def _get_referred_model_name(openapi_schema: dict) -> Optional[str]: def build_minimal_template( schema: Type[BaseModel], - required_only: bool, + required_only: bool = False, pipeline_import_path: Optional[str] = None, add_descriptions: bool = False, ) -> dict: rec_flags: dict = dict(required_only=required_only, add_descriptions=add_descriptions) # type is needed - json_schema = schema.schema() # needed for descriptions template: dict = {} for name, field in schema.__fields__.items(): @@ -150,14 +149,16 @@ def build_minimal_template( if required_only and not field.required: continue + description = field.field_info.description if add_descriptions else None + if name == "pipeline" and pipeline_import_path: template[name] = pipeline_import_path elif isclass(field.type_) and issubclass(field.type_, BaseModel): # Contains a subschema in the nesting if isclass(field.outer_type_) and issubclass(field.outer_type_, BaseModel): template[name] = build_minimal_template(field.type_, **rec_flags) - if "description" in json_schema["properties"][name]: - template[name]["<< description >>"] = json_schema["properties"][name]["description"] + if description: + template[name]["<< description >>"] = description else: template[name] = { "<< type >>": repr(field._type_display()), @@ -165,13 +166,13 @@ def build_minimal_template( "<< sub-template >>": build_minimal_template(field.type_, **rec_flags), } else: - template[name] = _field_description(field, json_schema["properties"][name], add_descriptions) + template[name] = _field_description(field, description) return template -def _field_description(field: ModelField, field_schema: dict, add_descriptions: bool) -> str: - description = " // " + field_schema["description"] if "description" in field_schema and add_descriptions else "" +def _field_description(field: ModelField, description: Optional[str]) -> str: + description = f" // {description}" if description else "" field_type = repr(field._type_display()) default = repr(field.default) + " : " if field.default else "" return f"<< {default}{field_type}{description} >>" diff --git a/eogrow/pipelines/batch_to_eopatch.py b/eogrow/pipelines/batch_to_eopatch.py index 6e2a0d7b..60d39bdd 100644 --- a/eogrow/pipelines/batch_to_eopatch.py +++ b/eogrow/pipelines/batch_to_eopatch.py @@ -3,6 +3,7 @@ """ from typing import Any, Dict, List, Optional +import numpy as np from pydantic import Field, root_validator from eolearn.core import ( @@ -23,7 +24,7 @@ from ..tasks.batch_to_eopatch import DeleteFilesTask, FixImportedTimeDependentFeatureTask, LoadUserDataTask from ..utils.filter import get_patches_with_missing_features from ..utils.types import Feature, FeatureSpec -from ..utils.validators import field_validator +from ..utils.validators import field_validator, optional_field_validator, parse_dtype class FeatureMappingSchema(BaseSchema): @@ -37,7 +38,8 @@ class FeatureMappingSchema(BaseSchema): ) feature: Feature multiply_factor: float = Field(1, description="Factor used to multiply feature values with.") - dtype: Optional[str] = Field(description="Dtype of the output feature.") + dtype: Optional[np.dtype] = Field(description="Dtype of the output feature.") + _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True) save_early: bool = Field( True, description=( diff --git a/eogrow/pipelines/download.py b/eogrow/pipelines/download.py index fcbfc3fd..fb08fd00 100644 --- a/eogrow/pipelines/download.py +++ b/eogrow/pipelines/download.py @@ -30,7 +30,13 @@ from ..core.schemas import BaseSchema from ..utils.filter import get_patches_with_missing_features from ..utils.types import Feature, FeatureSpec, Path, ProcessingType, TimePeriod -from ..utils.validators import field_validator, parse_data_collection, parse_time_period +from ..utils.validators import ( + field_validator, + optional_field_validator, + parse_data_collection, + parse_dtype, + parse_time_period, +) LOGGER = logging.getLogger(__name__) @@ -53,7 +59,8 @@ def get_valid_session(self) -> SentinelHubSession: class RescaleSchema(BaseSchema): rescale_factor: float = Field(1, description="Amount by which the selected features are multiplied") - dtype: Optional[str] = Field(description="The output dtype of data") + dtype: Optional[np.dtype] = Field(description="The output dtype of data") + _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True) features_to_rescale: List[Feature] diff --git a/eogrow/pipelines/features.py b/eogrow/pipelines/features.py index 076f3426..4b916a38 100644 --- a/eogrow/pipelines/features.py +++ b/eogrow/pipelines/features.py @@ -4,6 +4,7 @@ import logging from typing import Dict, List, Optional, Tuple +import numpy as np from pydantic import Field from eolearn.core import ( @@ -30,7 +31,7 @@ ) from ..utils.filter import get_patches_with_missing_features from ..utils.types import Feature, FeatureSpec, TimePeriod -from ..utils.validators import field_validator, parse_time_period +from ..utils.validators import field_validator, optional_field_validator, parse_dtype, parse_time_period LOGGER = logging.getLogger(__name__) @@ -70,7 +71,8 @@ class Schema(Pipeline.Schema): ), ) - dtype: Optional[str] = Field(description="The dtype under which the concatenated features should be saved") + dtype: Optional[np.dtype] = Field(description="The dtype under which the concatenated features should be saved") + _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True) output_feature_name: str = Field(description="Name of output data feature encompassing bands and NDIs") compress_level: int = Field(1, description="Level of compression used in saving eopatches") diff --git a/eogrow/pipelines/prediction.py b/eogrow/pipelines/prediction.py index 7f1e9b23..afd0343c 100644 --- a/eogrow/pipelines/prediction.py +++ b/eogrow/pipelines/prediction.py @@ -4,6 +4,7 @@ import abc from typing import List, Optional, Tuple +import numpy as np from pydantic import Field, validator from eolearn.core import EONode, EOWorkflow, FeatureType, LoadTask, MergeEOPatchesTask, OverwritePermission, SaveTask @@ -12,6 +13,7 @@ from ..tasks.prediction import ClassificationPredictionTask, RegressionPredictionTask from ..utils.filter import get_patches_with_missing_features from ..utils.types import Feature, FeatureSpec +from ..utils.validators import optional_field_validator, parse_dtype class BasePredictionPipeline(Pipeline, metaclass=abc.ABCMeta): @@ -31,9 +33,10 @@ class Schema(Pipeline.Schema): output_folder_key: str = Field( description="The storage manager key pointing to the output folder for the prediction pipeline." ) - dtype: Optional[str] = Field( + dtype: Optional[np.dtype] = Field( description="Casts the result to desired type. Uses predictor output type by default." ) + _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True) prediction_mask_folder_key: Optional[str] prediction_mask_feature_name: Optional[str] = Field( diff --git a/eogrow/pipelines/rasterize.py b/eogrow/pipelines/rasterize.py index e6d6babf..56e5300d 100644 --- a/eogrow/pipelines/rasterize.py +++ b/eogrow/pipelines/rasterize.py @@ -31,6 +31,7 @@ from ..utils.filter import get_patches_with_missing_features from ..utils.fs import LocalFile from ..utils.types import Feature, FeatureSpec +from ..utils.validators import field_validator, parse_dtype from ..utils.vector import concat_gdf LOGGER = logging.getLogger(__name__) @@ -51,7 +52,8 @@ class VectorColumnSchema(BaseSchema): polygon_buffer: float = Field(0, description="The size of polygon buffering to be applied before rasterization.") resolution: float = Field(description="Rendering resolution in meters.") overlap_value: Optional[int] = Field(description="Value to write over the areas where polygons overlap.") - dtype: str = Field("int32", description="Numpy dtype of the output feature.") + dtype: np.dtype = Field(np.dtype("int32"), description="Numpy dtype of the output feature.") + _parse_dtype = field_validator("dtype", parse_dtype, pre=True) no_data_value: int = Field(0, description="The no_data_value argument to be passed to VectorToRasterTask") @validator("values_column") diff --git a/eogrow/pipelines/testing.py b/eogrow/pipelines/testing.py index 3468f6ec..afa4678b 100644 --- a/eogrow/pipelines/testing.py +++ b/eogrow/pipelines/testing.py @@ -14,7 +14,7 @@ from ..core.schemas import BaseSchema from ..tasks.testing import DummyRasterFeatureTask, DummyTimestampFeatureTask from ..utils.types import Feature, TimePeriod -from ..utils.validators import field_validator, parse_time_period +from ..utils.validators import field_validator, parse_dtype, parse_time_period Self = TypeVar("Self", bound="TestPipeline") LOGGER = logging.getLogger(__name__) @@ -72,7 +72,8 @@ class FeatureSchema(BaseSchema): class RasterFeatureSchema(FeatureSchema): feature: Feature = Field(description="A feature to be processed.") shape: Tuple[int, ...] = Field(description="A shape of a feature") - dtype: str = Field(description="The output dtype of the feature") + dtype: np.dtype = Field(description="The output dtype of the feature") + _parse_dtype = field_validator("dtype", parse_dtype, pre=True) min_value: int = Field(0, description="All values in the feature will be greater or equal to this value.") max_value: int = Field(1, description="All values in the feature will be smaller to this value.") diff --git a/eogrow/tasks/prediction.py b/eogrow/tasks/prediction.py index d469221c..eec5c611 100644 --- a/eogrow/tasks/prediction.py +++ b/eogrow/tasks/prediction.py @@ -24,7 +24,7 @@ def __init__( input_features: List[Feature], mask_feature: Feature, output_feature: Feature, - output_dtype: Optional[str], + output_dtype: Optional[np.dtype], mp_lock: bool, sh_config: SHConfig, ): diff --git a/eogrow/utils/validators.py b/eogrow/utils/validators.py index a1f46973..17faa61d 100644 --- a/eogrow/utils/validators.py +++ b/eogrow/utils/validators.py @@ -5,6 +5,7 @@ import inspect from typing import TYPE_CHECKING, Any, Callable, Tuple +import numpy as np from pydantic import validator from sentinelhub import DataCollection @@ -106,6 +107,10 @@ def parse_data_collection(value: str) -> DataCollection: ) +def parse_dtype(value: str) -> np.dtype: + return np.dtype(value) + + def validate_manager(value: dict) -> "ManagerSchema": """Parse and validate schema describing a manager.""" assert "manager" in value, "Manager definition has no `manager` field that specifies its class." diff --git a/tests/test_core/test_schemas.py b/tests/test_core/test_schemas.py index df71c686..27782f28 100644 --- a/tests/test_core/test_schemas.py +++ b/tests/test_core/test_schemas.py @@ -4,14 +4,15 @@ import pytest from eogrow.core.area import UtmZoneAreaManager -from eogrow.core.schemas import build_schema_template +from eogrow.core.schemas import build_minimal_template, build_schema_template from eogrow.core.storage import StorageManager -from eogrow.pipelines.download import DownloadPipeline from eogrow.pipelines.export_maps import ExportMapsPipeline +from eogrow.pipelines.mapping import MappingPipeline @pytest.mark.fast -@pytest.mark.parametrize("eogrow_object", [UtmZoneAreaManager, DownloadPipeline, ExportMapsPipeline, StorageManager]) -def test_build_schema_template(eogrow_object): - template = build_schema_template(eogrow_object.Schema) +@pytest.mark.parametrize("eogrow_object", [UtmZoneAreaManager, MappingPipeline, ExportMapsPipeline, StorageManager]) +@pytest.mark.parametrize("schema_builder", [build_schema_template, build_minimal_template]) +def test_build_schema_template(eogrow_object, schema_builder): + template = schema_builder(eogrow_object.Schema) assert isinstance(template, dict) diff --git a/tests/test_pipelines/test_batch_to_eopatch.py b/tests/test_pipelines/test_batch_to_eopatch.py index c69576a2..c3cd277b 100644 --- a/tests/test_pipelines/test_batch_to_eopatch.py +++ b/tests/test_pipelines/test_batch_to_eopatch.py @@ -29,7 +29,7 @@ def prepare_batch_files( width: int, height: int, num_timestamps: int, - dtype: str, + dtype: np.dtype, add_userdata: bool, ): transform = rasterio.transform.from_bounds(*tiff_bbox, width=width, height=height)