Skip to content

Commit

Permalink
Properly parse dtypes at validation (#71)
Browse files Browse the repository at this point in the history
* adjust schemas and templates to allow custom types

* adjust dtype types in pipeline schemas

* fix isort in tests
  • Loading branch information
zigaLuksic authored Jun 14, 2022
1 parent e83d9c7 commit 4a2b0b0
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 24 deletions.
2 changes: 2 additions & 0 deletions eogrow/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class Config:

extra = "forbid"
validate_all = True
allow_mutation = False
arbitrary_types_allowed = True

config: Schema

Expand Down
15 changes: 8 additions & 7 deletions eogrow/core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,41 +137,42 @@ def _get_referred_model_name(openapi_schema: dict) -> Optional[str]:

def build_minimal_template(
schema: Type[BaseModel],
required_only: bool,
required_only: bool = False,
pipeline_import_path: Optional[str] = None,
add_descriptions: bool = False,
) -> dict:
rec_flags: dict = dict(required_only=required_only, add_descriptions=add_descriptions) # type is needed
json_schema = schema.schema() # needed for descriptions

template: dict = {}
for name, field in schema.__fields__.items():

if required_only and not field.required:
continue

description = field.field_info.description if add_descriptions else None

if name == "pipeline" and pipeline_import_path:
template[name] = pipeline_import_path
elif isclass(field.type_) and issubclass(field.type_, BaseModel):
# Contains a subschema in the nesting
if isclass(field.outer_type_) and issubclass(field.outer_type_, BaseModel):
template[name] = build_minimal_template(field.type_, **rec_flags)
if "description" in json_schema["properties"][name]:
template[name]["<< description >>"] = json_schema["properties"][name]["description"]
if description:
template[name]["<< description >>"] = description
else:
template[name] = {
"<< type >>": repr(field._type_display()),
"<< nested schema >>": str(field.type_),
"<< sub-template >>": build_minimal_template(field.type_, **rec_flags),
}
else:
template[name] = _field_description(field, json_schema["properties"][name], add_descriptions)
template[name] = _field_description(field, description)

return template


def _field_description(field: ModelField, field_schema: dict, add_descriptions: bool) -> str:
description = " // " + field_schema["description"] if "description" in field_schema and add_descriptions else ""
def _field_description(field: ModelField, description: Optional[str]) -> str:
description = f" // {description}" if description else ""
field_type = repr(field._type_display())
default = repr(field.default) + " : " if field.default else ""
return f"<< {default}{field_type}{description} >>"
6 changes: 4 additions & 2 deletions eogrow/pipelines/batch_to_eopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from typing import Any, Dict, List, Optional

import numpy as np
from pydantic import Field, root_validator

from eolearn.core import (
Expand All @@ -23,7 +24,7 @@
from ..tasks.batch_to_eopatch import DeleteFilesTask, FixImportedTimeDependentFeatureTask, LoadUserDataTask
from ..utils.filter import get_patches_with_missing_features
from ..utils.types import Feature, FeatureSpec
from ..utils.validators import field_validator
from ..utils.validators import field_validator, optional_field_validator, parse_dtype


class FeatureMappingSchema(BaseSchema):
Expand All @@ -37,7 +38,8 @@ class FeatureMappingSchema(BaseSchema):
)
feature: Feature
multiply_factor: float = Field(1, description="Factor used to multiply feature values with.")
dtype: Optional[str] = Field(description="Dtype of the output feature.")
dtype: Optional[np.dtype] = Field(description="Dtype of the output feature.")
_parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
save_early: bool = Field(
True,
description=(
Expand Down
11 changes: 9 additions & 2 deletions eogrow/pipelines/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
from ..core.schemas import BaseSchema
from ..utils.filter import get_patches_with_missing_features
from ..utils.types import Feature, FeatureSpec, Path, ProcessingType, TimePeriod
from ..utils.validators import field_validator, parse_data_collection, parse_time_period
from ..utils.validators import (
field_validator,
optional_field_validator,
parse_data_collection,
parse_dtype,
parse_time_period,
)

LOGGER = logging.getLogger(__name__)

Expand All @@ -53,7 +59,8 @@ def get_valid_session(self) -> SentinelHubSession:

class RescaleSchema(BaseSchema):
rescale_factor: float = Field(1, description="Amount by which the selected features are multiplied")
dtype: Optional[str] = Field(description="The output dtype of data")
dtype: Optional[np.dtype] = Field(description="The output dtype of data")
_parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
features_to_rescale: List[Feature]


Expand Down
6 changes: 4 additions & 2 deletions eogrow/pipelines/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
from pydantic import Field

from eolearn.core import (
Expand All @@ -30,7 +31,7 @@
)
from ..utils.filter import get_patches_with_missing_features
from ..utils.types import Feature, FeatureSpec, TimePeriod
from ..utils.validators import field_validator, parse_time_period
from ..utils.validators import field_validator, optional_field_validator, parse_dtype, parse_time_period

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +71,8 @@ class Schema(Pipeline.Schema):
),
)

dtype: Optional[str] = Field(description="The dtype under which the concatenated features should be saved")
dtype: Optional[np.dtype] = Field(description="The dtype under which the concatenated features should be saved")
_parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
output_feature_name: str = Field(description="Name of output data feature encompassing bands and NDIs")
compress_level: int = Field(1, description="Level of compression used in saving eopatches")

Expand Down
5 changes: 4 additions & 1 deletion eogrow/pipelines/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import abc
from typing import List, Optional, Tuple

import numpy as np
from pydantic import Field, validator

from eolearn.core import EONode, EOWorkflow, FeatureType, LoadTask, MergeEOPatchesTask, OverwritePermission, SaveTask
Expand All @@ -12,6 +13,7 @@
from ..tasks.prediction import ClassificationPredictionTask, RegressionPredictionTask
from ..utils.filter import get_patches_with_missing_features
from ..utils.types import Feature, FeatureSpec
from ..utils.validators import optional_field_validator, parse_dtype


class BasePredictionPipeline(Pipeline, metaclass=abc.ABCMeta):
Expand All @@ -31,9 +33,10 @@ class Schema(Pipeline.Schema):
output_folder_key: str = Field(
description="The storage manager key pointing to the output folder for the prediction pipeline."
)
dtype: Optional[str] = Field(
dtype: Optional[np.dtype] = Field(
description="Casts the result to desired type. Uses predictor output type by default."
)
_parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)

prediction_mask_folder_key: Optional[str]
prediction_mask_feature_name: Optional[str] = Field(
Expand Down
4 changes: 3 additions & 1 deletion eogrow/pipelines/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils.filter import get_patches_with_missing_features
from ..utils.fs import LocalFile
from ..utils.types import Feature, FeatureSpec
from ..utils.validators import field_validator, parse_dtype
from ..utils.vector import concat_gdf

LOGGER = logging.getLogger(__name__)
Expand All @@ -51,7 +52,8 @@ class VectorColumnSchema(BaseSchema):
polygon_buffer: float = Field(0, description="The size of polygon buffering to be applied before rasterization.")
resolution: float = Field(description="Rendering resolution in meters.")
overlap_value: Optional[int] = Field(description="Value to write over the areas where polygons overlap.")
dtype: str = Field("int32", description="Numpy dtype of the output feature.")
dtype: np.dtype = Field(np.dtype("int32"), description="Numpy dtype of the output feature.")
_parse_dtype = field_validator("dtype", parse_dtype, pre=True)
no_data_value: int = Field(0, description="The no_data_value argument to be passed to VectorToRasterTask")

@validator("values_column")
Expand Down
5 changes: 3 additions & 2 deletions eogrow/pipelines/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..core.schemas import BaseSchema
from ..tasks.testing import DummyRasterFeatureTask, DummyTimestampFeatureTask
from ..utils.types import Feature, TimePeriod
from ..utils.validators import field_validator, parse_time_period
from ..utils.validators import field_validator, parse_dtype, parse_time_period

Self = TypeVar("Self", bound="TestPipeline")
LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,7 +72,8 @@ class FeatureSchema(BaseSchema):
class RasterFeatureSchema(FeatureSchema):
feature: Feature = Field(description="A feature to be processed.")
shape: Tuple[int, ...] = Field(description="A shape of a feature")
dtype: str = Field(description="The output dtype of the feature")
dtype: np.dtype = Field(description="The output dtype of the feature")
_parse_dtype = field_validator("dtype", parse_dtype, pre=True)
min_value: int = Field(0, description="All values in the feature will be greater or equal to this value.")
max_value: int = Field(1, description="All values in the feature will be smaller to this value.")

Expand Down
2 changes: 1 addition & 1 deletion eogrow/tasks/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
input_features: List[Feature],
mask_feature: Feature,
output_feature: Feature,
output_dtype: Optional[str],
output_dtype: Optional[np.dtype],
mp_lock: bool,
sh_config: SHConfig,
):
Expand Down
5 changes: 5 additions & 0 deletions eogrow/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
from typing import TYPE_CHECKING, Any, Callable, Tuple

import numpy as np
from pydantic import validator

from sentinelhub import DataCollection
Expand Down Expand Up @@ -106,6 +107,10 @@ def parse_data_collection(value: str) -> DataCollection:
)


def parse_dtype(value: str) -> np.dtype:
return np.dtype(value)


def validate_manager(value: dict) -> "ManagerSchema":
"""Parse and validate schema describing a manager."""
assert "manager" in value, "Manager definition has no `manager` field that specifies its class."
Expand Down
11 changes: 6 additions & 5 deletions tests/test_core/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import pytest

from eogrow.core.area import UtmZoneAreaManager
from eogrow.core.schemas import build_schema_template
from eogrow.core.schemas import build_minimal_template, build_schema_template
from eogrow.core.storage import StorageManager
from eogrow.pipelines.download import DownloadPipeline
from eogrow.pipelines.export_maps import ExportMapsPipeline
from eogrow.pipelines.mapping import MappingPipeline


@pytest.mark.fast
@pytest.mark.parametrize("eogrow_object", [UtmZoneAreaManager, DownloadPipeline, ExportMapsPipeline, StorageManager])
def test_build_schema_template(eogrow_object):
template = build_schema_template(eogrow_object.Schema)
@pytest.mark.parametrize("eogrow_object", [UtmZoneAreaManager, MappingPipeline, ExportMapsPipeline, StorageManager])
@pytest.mark.parametrize("schema_builder", [build_schema_template, build_minimal_template])
def test_build_schema_template(eogrow_object, schema_builder):
template = schema_builder(eogrow_object.Schema)
assert isinstance(template, dict)
2 changes: 1 addition & 1 deletion tests/test_pipelines/test_batch_to_eopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def prepare_batch_files(
width: int,
height: int,
num_timestamps: int,
dtype: str,
dtype: np.dtype,
add_userdata: bool,
):
transform = rasterio.transform.from_bounds(*tiff_bbox, width=width, height=height)
Expand Down

0 comments on commit 4a2b0b0

Please sign in to comment.