From 4fc2e8ed721779f752d5913bc1c8397c1246ff3e Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 28 Feb 2024 13:12:05 -0800 Subject: [PATCH] remove helper and use pystac.Item in example --- stac_model/__main__.py | 3 ++ stac_model/examples.py | 37 ++++++++++++++-------- stac_model/input.py | 5 +-- stac_model/runtime.py | 5 +-- stac_model/schema.py | 71 +++++++++++++----------------------------- tests/test_schema.py | 13 +++++--- 6 files changed, 64 insertions(+), 70 deletions(-) diff --git a/stac_model/__main__.py b/stac_model/__main__.py index 6eeb4bd..4fbfc32 100644 --- a/stac_model/__main__.py +++ b/stac_model/__main__.py @@ -14,12 +14,14 @@ ) console = Console() + def version_callback(print_version: bool) -> None: """Print the version of the package.""" if print_version: console.print(f"[yellow]stac-model[/] version: [bold blue]{__version__}[/]") raise typer.Exit() + @app.command(name="") def main( print_version: bool = typer.Option( @@ -40,5 +42,6 @@ def main( print("Example model metadata written to ./example.json.") return ml_model_meta + if __name__ == "__main__": app() diff --git a/stac_model/examples.py b/stac_model/examples.py index fbc2ab8..40f3365 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -1,10 +1,12 @@ from datetime import datetime +import pystac + from stac_model.schema import ( Asset, ClassObject, InputArray, - MLModelHelper, + MLModelExtension, MLModelProperties, ModelInput, ModelOutput, @@ -15,7 +17,6 @@ def eurosat_resnet(): - input_array = InputArray( shape=[-1, 13, 64, 64], dim_order="bchw", data_type="float32" ) @@ -74,8 +75,7 @@ def eurosat_resnet(): norm_type="z_score", resize_type="none", statistics=stats, - pre_processing_function = "https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py" # noqa: E501 -, + pre_processing_function="https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py", # noqa: E501 ) mlm_runtime = Runtime( framework="torch", @@ -132,13 +132,24 @@ def eurosat_resnet(): mlm_runtime=[mlm_runtime], mlm_output=[mlm_output], ) - - mlmodel_helper = MLModelHelper(attrs = ml_model_meta.model_dump()) - geometry=None + start_datetime = datetime.strptime("1900-01-01", "%Y-%m-%d") + end_datetime = None + geometry = None bbox = [-90, -180, 90, 180] - start_time = datetime.strptime("1900-01-01", '%Y-%m-%d') - end_time = None - item = mlmodel_helper.stac_item(geometry, bbox, start_datetime=start_time, - end_datetime=end_time) - - return item + name = ( + "_".join(ml_model_meta.mlm_name.split(" ")).lower() + + f"_{ml_model_meta.mlm_task}".lower() + ) + item = pystac.Item( + id=name, + geometry=geometry, + bbox=bbox, + datetime=None, + properties={"start_datetime": start_datetime, "end_datetime": end_datetime}, + ) + item.add_derived_from( + "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a" + ) + item_mlmodel = MLModelExtension.ext(item, add_if_missing=True) + item_mlmodel.apply(ml_model_meta.model_dump()) + return item_mlmodel diff --git a/stac_model/input.py b/stac_model/input.py index 4bc0db1..318d766 100644 --- a/stac_model/input.py +++ b/stac_model/input.py @@ -44,8 +44,9 @@ class ModelInput(BaseModel): "none", ] = None resize_type: Literal["crop", "pad", "interpolate", "none"] = None - parameters: Optional[Dict[str, Union[int, str, bool, - List[Union[int, str, bool]]]]] = None + parameters: Optional[ + Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]] + ] = None statistics: Optional[Union[Statistics, List[Statistics]]] = None norm_with_clip_values: Optional[List[Union[float, int]]] = None pre_processing_function: Optional[str | AnyUrl] = None diff --git a/stac_model/runtime.py b/stac_model/runtime.py index ef78812..dc11081 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -11,13 +11,14 @@ class Asset(BaseModel): Follows the STAC Asset Object spec. """ - href: S3Path | FilePath | AnyUrl| str + href: S3Path | FilePath | AnyUrl | str title: Optional[str] = None description: Optional[str] = None type: Optional[str] = None roles: Optional[List[str]] = None - model_config = ConfigDict(arbitrary_types_allowed = True) + model_config = ConfigDict(arbitrary_types_allowed=True) + class Container(BaseModel): container_file: str diff --git a/stac_model/schema.py b/stac_model/schema.py index 115b66a..cbf4ac5 100644 --- a/stac_model/schema.py +++ b/stac_model/schema.py @@ -1,5 +1,4 @@ import json -from datetime import datetime from typing import ( Any, Dict, @@ -7,7 +6,6 @@ Iterable, List, Literal, - MutableMapping, Optional, TypeVar, Union, @@ -26,22 +24,24 @@ SummariesExtension, ) -from .geometry_models import AnyGeometry from .input import Band, InputArray, ModelInput, Statistics from .output import ClassObject, ModelOutput, ResultArray, TaskEnum from .runtime import Asset, Container, Runtime -T = TypeVar("T", pystac.Collection, pystac.Item, pystac.Asset, - item_assets.AssetDefinition) +T = TypeVar( + "T", pystac.Collection, pystac.Item, pystac.Asset, item_assets.AssetDefinition +) SchemaName = Literal["mlm"] # TODO update -SCHEMA_URI: str = "https://raw.githubusercontent.com/crim-ca/dlm-extension/main/json-schema/schema.json" #noqa: E501 +SCHEMA_URI: str = "https://raw.githubusercontent.com/crim-ca/dlm-extension/main/json-schema/schema.json" # noqa: E501 PREFIX = f"{get_args(SchemaName)[0]}:" + def mlm_prefix_replacer(field_name: str) -> str: return field_name.replace("mlm_", "mlm:") + class MLModelProperties(BaseModel): mlm_name: str mlm_task: TaskEnum @@ -55,46 +55,13 @@ class MLModelProperties(BaseModel): mlm_total_parameters: int mlm_pretrained_source: str mlm_summary: str - mlm_parameters: Optional[Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]]] = None # noqa: E501 - - model_config = ConfigDict(alias_generator=mlm_prefix_replacer, - populate_by_name=True, extra="ignore") - -class MLModelHelper: - def __init__(self, attrs: MutableMapping[str, Any]): - self.attrs = attrs + mlm_parameters: Optional[ + Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]] + ] = None # noqa: E501 - @property - def uid(self) -> str: - """Return a unique ID for MLModel data item.""" - keys = [ - "mlm_name", - "mlm_task", - ] - name = "_".join("_".join( - self.attrs[k].split(" ")) for k in keys).lower() - return name - - @property - def properties(self) -> MLModelProperties: - props = MLModelProperties(**self.attrs) - return props - - def stac_item(self, geometry: AnyGeometry, bbox: List[float], - start_datetime: datetime, end_datetime: datetime) -> pystac.Item: - item = pystac.Item( - id=self.uid, - geometry=geometry, - bbox=bbox, - properties={ - "start_datetime": start_datetime, - "end_datetime": end_datetime, - }, - datetime=None, - ) - item_mlmodel = MLModelExtension.ext(item, add_if_missing=True) - item_mlmodel.apply(self.properties) - return item + model_config = ConfigDict( + alias_generator=mlm_prefix_replacer, populate_by_name=True, extra="ignore" + ) class MLModelExtension( @@ -131,7 +98,8 @@ def has_extension(cls, obj: S): # is not fulfilled (ie: using 'main' branch, no tag available...) ext_uri = cls.get_schema_uri() return obj.stac_extensions is not None and any( - uri == ext_uri for uri in obj.stac_extensions) + uri == ext_uri for uri in obj.stac_extensions + ) @classmethod def ext(cls, obj: T, add_if_missing: bool = False) -> "MLModelExtension[T]": @@ -168,17 +136,18 @@ def summaries( cls.ensure_has_extension(obj, add_if_missing) return SummariesMLModelExtension(obj) + class SummariesMLModelExtension(SummariesExtension): """A concrete implementation of :class:`~SummariesExtension` that extends the ``summaries`` field of a :class:`~pystac.Collection` to include properties defined in the :stac-ext:`Machine Learning Model `. """ + def _check_mlm_property(self, prop: str) -> FieldInfo: try: return MLModelProperties.model_fields[prop] except KeyError as err: - raise AttributeError( - f"Name '{prop}' is not a valid MLM property.") from err + raise AttributeError(f"Name '{prop}' is not a valid MLM property.") from err def _validate_mlm_property(self, prop: str, summaries: list[Any]) -> None: model = MLModelProperties.model_construct() @@ -201,6 +170,7 @@ def __getattr__(self, prop): def __setattr__(self, prop, value): self.set_mlm_property(prop, value) + class ItemMLModelExtension(MLModelExtension[pystac.Item]): """A concrete implementation of :class:`MLModelExtension` on an :class:`~pystac.Item` that extends the properties of the Item to @@ -218,6 +188,7 @@ def __init__(self, item: pystac.Item): def __repr__(self) -> str: return f"" + class ItemAssetsMLModelExtension(MLModelExtension[item_assets.AssetDefinition]): properties: dict[str, Any] asset_defn: item_assets.AssetDefinition @@ -226,6 +197,7 @@ def __init__(self, item_asset: item_assets.AssetDefinition): self.asset_defn = item_asset self.properties = item_asset.properties + class AssetMLModelExtension(MLModelExtension[pystac.Asset]): """A concrete implementation of :class:`MLModelExtension` on an :class:`~pystac.Asset` that extends the Asset fields to include @@ -255,11 +227,12 @@ def __init__(self, asset: pystac.Asset): def __repr__(self) -> str: return f"" -class CollectionMLModelExtension(MLModelExtension[pystac.Collection]): +class CollectionMLModelExtension(MLModelExtension[pystac.Collection]): def __init__(self, collection: pystac.Collection): self.collection = collection + __all__ = [ "MLModelExtension", "ModelInput", diff --git a/tests/test_schema.py b/tests/test_schema.py index cd3eb8a..2c12ec9 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,14 +1,19 @@ - import pytest @pytest.fixture def metadata_json(): from stac_model.examples import eurosat_resnet + model_metadata_stac_item = eurosat_resnet() return model_metadata_stac_item -def test_model_metadata_json_operations(model_metadata_stac_item): + +def test_model_metadata_to_dict(metadata_json): + assert metadata_json.to_dict() + + +def test_model_metadata_json_operations(metadata_json): from stac_model.schema import MLModelExtension - model_metadata = MLModelExtension.apply(model_metadata_stac_item) - assert model_metadata.mlm_name == "Resnet-18 Sentinel-2 ALL MOCO" + + assert MLModelExtension(metadata_json.to_dict())