From 269bd734721ae2306fcbe177369f3cbf0813c241 Mon Sep 17 00:00:00 2001 From: Francis Charette-Migneault Date: Thu, 4 Apr 2024 18:19:06 -0400 Subject: [PATCH] update pydantic models with new json-schema fields --- json-schema/schema.json | 6 +- stac_model/base.py | 66 ++++++++++++++++++++ stac_model/examples.py | 110 +++++++++++++++++++--------------- stac_model/input.py | 65 ++++++++++++-------- stac_model/output.py | 129 ++++++++++++++++++++++++++++------------ stac_model/runtime.py | 18 +++--- stac_model/schema.py | 57 ++++++++---------- 7 files changed, 299 insertions(+), 152 deletions(-) create mode 100644 stac_model/base.py diff --git a/json-schema/schema.json b/json-schema/schema.json index ed79f02..7578e74 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -539,7 +539,11 @@ ] }, "NormalizeClip": { - + "type": "array", + "minItems": 1, + "items": { + "type": "number" + } }, "ResizeType": { "oneOf": [ diff --git a/stac_model/base.py b/stac_model/base.py new file mode 100644 index 0000000..4f4235d --- /dev/null +++ b/stac_model/base.py @@ -0,0 +1,66 @@ +from enum import Enum +from typing import Any, Literal, Union, TypeAlias + +from pydantic import BaseModel + + +DataType: TypeAlias = Literal[ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", + "cint16", + "cint32", + "cfloat32", + "cfloat64", + "other" +] + + +class TaskEnum(str, Enum): + REGRESSION = "regression" + CLASSIFICATION = "classification" + SCENE_CLASSIFICATION = "scene-classification" + DETECTION = "detection" + OBJECT_DETECTION = "object-detection" + SEGMENTATION = "segmentation" + SEMANTIC_SEGMENTATION = "semantic-segmentation" + INSTANCE_SEGMENTATION = "instance-segmentation" + PANOPTIC_SEGMENTATION = "panoptic-segmentation" + SIMILARITY_SEARCH = "similarity-search" + GENERATIVE = "generative" + IMAGE_CAPTIONING = "image-captioning" + SUPER_RESOLUTION = "super-resolution" + + +ModelTaskNames: TypeAlias = Literal[ + "regression", + "classification", + "scene-classification", + "detection", + "object-detection", + "segmentation", + "semantic-segmentation", + "instance-segmentation", + "panoptic-segmentation", + "similarity-search", + "generative", + "image-captioning", + "super-resolution" +] + + +ModelTask = Union[ModelTaskNames, TaskEnum] + + +class ProcessingExpression(BaseModel): + # FIXME: should use 'pystac' reference, but 'processing' extension is not implemented yet! + format: str + expression: Any diff --git a/stac_model/examples.py b/stac_model/examples.py index 8e0ede4..9747086 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -1,15 +1,15 @@ import pystac import json import shapely +from stac_model.base import ProcessingExpression +from stac_model.input import ModelInput +from stac_model.output import ModelOutput, ModelResult from stac_model.schema import ( Asset, - ClassObject, InputArray, + MLMClassification, MLModelExtension, MLModelProperties, - ModelInput, - ModelOutput, - ResultArray, Runtime, Statistics, ) @@ -17,7 +17,14 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: input_array = InputArray( - shape=[-1, 13, 64, 64], dim_order="bchw", data_type="float32" + shape=[-1, 13, 64, 64], + dim_order=[ + "batch", + "channel", + "height", + "width" + ], + data_type="float32", ) band_names = [ "B01", @@ -69,29 +76,34 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: input = ModelInput( name="13 Band Sentinel-2 Batch", bands=band_names, - input_array=input_array, + input=input_array, norm_by_channel=True, - norm_type="z_score", - resize_type="none", + norm_type="z-score", + resize_type=None, statistics=stats, - pre_processing_function="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn", # noqa: E501 - ) - runtime = Runtime( - framework="torch", - version="2.1.2+cu121", - asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501 - type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501 - ), - source_code=Asset( - href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501 - ), - accelerator="cuda", - accelerator_constrained=False, - hardware_summary="Unknown", - commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a", + pre_processing_function=ProcessingExpression( + format="python", + expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" + ), # noqa: E501 ) - result_array = ResultArray( - shape=[-1, 10], dim_order=["batch", "class"], data_type="float32" + # runtime = Runtime( + # framework="torch", + # version="2.1.2+cu121", + # asset=Asset(title = "Pytorch weights checkpoint", description="A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", # noqa: E501 + # type=".pth", roles=["weights"], href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth" # noqa: E501 + # ), + # source_code=Asset( + # href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207" # noqa: E501 + # ), + # accelerator="cuda", + # accelerator_constrained=False, + # hardware_summary="Unknown", + # commit_hash="61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a", + # ) + result_array = ModelResult( + shape=[-1, 10], + dim_order=["batch", "class"], + data_type="float32" ) class_map = { "Annual Crop": 0, @@ -106,30 +118,26 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: "SeaLake": 9, } class_objects = [ - ClassObject(value=class_map[class_name], name=class_name) - for class_name in class_map + MLMClassification(value=class_value, name=class_name) + for class_name, class_value in class_map.items() ] output = ModelOutput( - task="classification", - classification_classes=class_objects, - output_shape=[-1, 10], - result_array=[result_array], + name="classification", + tasks={"classification"}, + classes=class_objects, + result=result_array, + post_processing_function=None, ) ml_model_meta = MLModelProperties( name="Resnet-18 Sentinel-2 ALL MOCO", - task="classification", + tasks={"classification"}, framework="pytorch", framework_version="2.1.2+cu121", file_size=43000000, memory_size=1, - summary=( - "Sourced from torchgeo python library," - "identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO" - ), pretrained_source="EuroSat Sentinel-2", total_parameters=11_700_000, input=[input], - runtime=[runtime], output=[output], ) # TODO, this can't be serialized but pystac.item calls for a datetime @@ -138,26 +146,30 @@ def eurosat_resnet() -> MLModelExtension[pystac.Item]: start_datetime = "1900-01-01" end_datetime = None bbox = [ - -7.882190080512502, - 37.13739173208318, - 27.911651652899923, - 58.21798141355221 - ] - geometry = json.dumps(shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__, indent=2) - name = ( - "_".join(ml_model_meta.name.split(" ")).lower() - + f"_{ml_model_meta.task}".lower() - ) + -7.882190080512502, + 37.13739173208318, + 27.911651652899923, + 58.21798141355221 + ] + geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__ + name = "_".join(ml_model_meta.name.split(" ")).lower() item = pystac.Item( id=name, geometry=geometry, bbox=bbox, datetime=None, - properties={"start_datetime": start_datetime, "end_datetime": end_datetime}, + properties={ + "start_datetime": start_datetime, + "end_datetime": end_datetime, + "description": ( + "Sourced from torchgeo python library," + "identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO" + ), + }, ) item.add_derived_from( "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a" ) item_mlm = MLModelExtension.ext(item, add_if_missing=True) - item_mlm.apply(ml_model_meta.model_dump()) + item_mlm.apply(ml_model_meta.model_dump(by_alias=True)) return item_mlm diff --git a/stac_model/input.py b/stac_model/input.py index c453dbb..107fc5c 100644 --- a/stac_model/input.py +++ b/stac_model/input.py @@ -1,15 +1,14 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Set, TypeAlias, Union -from pydantic import AnyUrl, BaseModel, Field +from pydantic import BaseModel, Field + +from stac_model.base import DataType, ProcessingExpression class InputArray(BaseModel): - shape: List[Union[int, float]] - dim_order: List[str] - data_type: str = Field( - ..., - pattern="^(uint8|uint16|uint32|uint64|int8|int16|int32|int64|float16|float32|float64|cint16|cint32|cfloat32|cfloat64|other)$", - ) + shape: List[Union[int, float]] = Field(..., min_items=1) + dim_order: List[str] = Field(..., min_items=1) + data_type: DataType class Statistics(BaseModel): @@ -24,29 +23,45 @@ class Statistics(BaseModel): class Band(BaseModel): name: str description: Optional[str] = None - nodata: float | int | str + nodata: Union[float, int, str] data_type: str unit: Optional[str] = None +NormalizeType: TypeAlias = Optional[Literal[ + "min-max", + "z-score", + "l1", + "l2", + "l2sqr", + "hamming", + "hamming2", + "type-mask", + "relative", + "inf" +]] + +ResizeType: TypeAlias = Optional[Literal[ + "crop", + "pad", + "interpolation-nearest", + "interpolation-linear", + "interpolation-cubic", + "interpolation-area", + "interpolation-lanczos4", + "interpolation-max", + "wrap-fill-outliers", + "wrap-inverse-map" +]] + + class ModelInput(BaseModel): name: str bands: List[str] - input_array: InputArray + input: InputArray norm_by_channel: bool = None - norm_type: Literal[ - "min_max", - "z_score", - "max_norm", - "mean_norm", - "unit_variance", - "norm_with_clip", - "none", - ] = None - resize_type: Literal["crop", "pad", "interpolate", "none"] = None - parameters: Optional[ - Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]] - ] = None + norm_type: NormalizeType = None + norm_clip: Optional[List[Union[float, int]]] = None + resize_type: ResizeType = None statistics: Optional[Union[Statistics, List[Statistics]]] = None - norm_with_clip_values: Optional[List[Union[float, int]]] = None - pre_processing_function: Optional[str | AnyUrl] = None + pre_processing_function: Optional[ProcessingExpression] = None diff --git a/stac_model/output.py b/stac_model/output.py index 0b2e919..f08cbb2 100644 --- a/stac_model/output.py +++ b/stac_model/output.py @@ -1,43 +1,96 @@ -from enum import Enum -from typing import List, Optional, Union - -from pydantic import BaseModel, Field - - -class TaskEnum(str, Enum): - regression = "regression" - classification = "classification" - object_detection = "object detection" - semantic_segmentation = "semantic segmentation" - instance_segmentation = "instance segmentation" - panoptic_segmentation = "panoptic segmentation" - multi_modal = "multi-modal" - similarity_search = "similarity search" - image_captioning = "image captioning" - generative = "generative" - super_resolution = "super resolution" - - -class ResultArray(BaseModel): - shape: List[Union[int, float]] - dim_order: List[str] - data_type: str = Field( - ..., - pattern="^(uint8|uint16|uint32|uint64|int8|int16|int32|int64|float16|float32|float64)$", - ) +from typing import Annotated, Any, Dict, List, Optional, Set, TypeAlias, Union +from typing_extensions import NotRequired, TypedDict +from pystac.extensions.classification import Classification +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, PlainSerializer, model_serializer -class ClassObject(BaseModel): - value: int - name: str - description: Optional[str] = None - title: Optional[str] = None - color_hint: Optional[str] = None - nodata: Optional[bool] = False +from stac_model.base import DataType, ModelTask, ProcessingExpression + + +class ModelResult(BaseModel): + shape: List[Union[int, float]] = Field(..., min_items=1) + dim_order: List[str] = Field(..., min_items=1) + data_type: DataType + + +# MLMClassification: TypeAlias = Annotated[ +# Classification, +# PlainSerializer( +# lambda x: x.to_dict(), +# when_used="json", +# return_type=TypedDict( +# "Classification", +# { +# "value": int, +# "name": str, +# "description": NotRequired[str], +# "color_hint": NotRequired[str], +# } +# ) +# ) +# ] + + +class MLMClassification(BaseModel, Classification): + @model_serializer() + def model_dump(self, *_, **__) -> Dict[str, Any]: + return self.to_dict() + + def __init__( + self, + value: int, + description: Optional[str] = None, + name: Optional[str] = None, + color_hint: Optional[str] = None + ) -> None: + Classification.__init__(self, {}) + if not name and not description: + raise ValueError("Class name or description is required!") + self.apply( + value=value, + name=name or description, + description=description or name, + color_hint=color_hint, + ) + + def __hash__(self) -> int: + return sum(map(hash, self.to_dict().items())) + + def __setattr__(self, key: str, value: Any) -> None: + if key == "properties": + Classification.__setattr__(self, key, value) + else: + BaseModel.__setattr__(self, key, value) + + model_config = ConfigDict(arbitrary_types_allowed=True) + +# class ClassObject(BaseModel): +# value: int +# name: str +# description: Optional[str] = None +# title: Optional[str] = None +# color_hint: Optional[str] = None +# nodata: Optional[bool] = False class ModelOutput(BaseModel): - task: TaskEnum - result_array: Optional[List[ResultArray]] = None - classification_classes: Optional[List[ClassObject]] = None - post_processing_function: Optional[str] = None + name: str + tasks: Set[ModelTask] + result: ModelResult + + # NOTE: + # Although it is preferable to have 'Set' to avoid duplicate, + # it is more important to keep the order in this case, + # which we would lose with 'Set'. + # We also get some unhashable errors with 'Set', although 'MLMClassification' implements '__hash__'. + classes: List[MLMClassification] = Field( + alias="classification:classes", + validation_alias=AliasChoices("classification:classes", "classification_classes"), + exclude_unset=True, + exclude_defaults=True + ) + post_processing_function: Optional[ProcessingExpression] = None + + model_config = ConfigDict( + populate_by_name=True + ) diff --git a/stac_model/runtime.py b/stac_model/runtime.py index b1a564a..1c0491f 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath +from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath, Field class Asset(BaseModel): @@ -41,11 +41,13 @@ def __str__(self): class Runtime(BaseModel): - asset: Asset - source_code: Asset - accelerator: AcceleratorEnum - accelerator_constrained: bool - hardware_summary: str - container: Optional[Container] = None - commit_hash: Optional[str] = None + framework: str + framework_version: str + file_size: int = Field(alias="file:size") + memory_size: int batch_size_suggestion: Optional[int] = None + + accelerator: Optional[AcceleratorEnum] = Field(exclude_unset=True, default=None) + accelerator_constrained: bool = Field(exclude_unset=True, default=False) + accelerator_summary: str = Field(exclude_unset=True, exclude_defaults=True, default="") + accelerator_count: int = Field(minimum=1, exclude_unset=True, exclude_defaults=True, default=-1) diff --git a/stac_model/schema.py b/stac_model/schema.py index 4f41603..13f13b9 100644 --- a/stac_model/schema.py +++ b/stac_model/schema.py @@ -7,6 +7,7 @@ List, Literal, Optional, + Set, TypeVar, Union, cast, @@ -14,7 +15,7 @@ ) import pystac -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from pydantic.fields import FieldInfo from pystac.extensions import item_assets from pystac.extensions.base import ( @@ -24,9 +25,10 @@ SummariesExtension, ) -from .input import Band, InputArray, ModelInput, Statistics -from .output import ClassObject, ModelOutput, ResultArray, TaskEnum -from .runtime import Asset, Container, Runtime +from stac_model.base import DataType, ModelTask +from stac_model.input import Band, InputArray, ModelInput, Statistics +from stac_model.output import MLMClassification, ModelOutput +from stac_model.runtime import Asset, Container, Runtime T = TypeVar( "T", pystac.Collection, pystac.Item, pystac.Asset, item_assets.AssetDefinition @@ -41,25 +43,20 @@ def mlm_prefix_adder(field_name: str) -> str: return "mlm:" + field_name -class MLModelProperties(BaseModel): +class MLModelProperties(Runtime): name: str - task: TaskEnum - framework: str - framework_version: str - file_size: int - memory_size: int + tasks: Set[ModelTask] input: List[ModelInput] output: List[ModelOutput] - runtime: List[Runtime] + total_parameters: int - pretrained_source: str - summary: str - parameters: Optional[ - Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]] - ] = None # noqa: E501 + pretrained: bool = Field(exclude_unset=True, default=True) + pretrained_source: Optional[str] = Field(exclude_unset=True) model_config = ConfigDict( - alias_generator=mlm_prefix_adder, populate_by_name=True, extra="ignore" + alias_generator=mlm_prefix_adder, + populate_by_name=True, + extra="ignore" ) @@ -221,17 +218,15 @@ def __init__(self, collection: pystac.Collection): self.collection = collection -__all__ = [ - "MLModelExtension", - "ModelInput", - "InputArray", - "Band", - "Statistics", - "ModelOutput", - "ClassObject", - "Asset", - "ResultArray", - "Runtime", - "Container", - "Asset", -] +# __all__ = [ +# "MLModelExtension", +# "ModelInput", +# "InputArray", +# "Band", +# "Statistics", +# "ModelOutput", +# "Asset", +# "Runtime", +# "Container", +# "Asset", +# ]