diff --git a/best-practices.md b/best-practices.md index 1d24ff8..8d6023c 100644 --- a/best-practices.md +++ b/best-practices.md @@ -2,6 +2,10 @@ This document makes a number of recommendations for creating real world ML Model Extensions. None of them are required to meet the core specification, but following these practices will improve the documentation of your model and make life easier for client tooling and users. They come about from practical experience of implementors and introduce a bit more 'constraint' for those who are creating STAC objects representing their models or creating tools to work with STAC. +## Using STAC Common Metadata Fields for the ML Model Extension + +We recommend using the `start_datetime` and `end_datetime`, `geometry`, and `bbox` to represent the recommended context of data the model was trained with and for which the model should have appropriate domain knowledge for inference. For example, we can consider a model which is trained on imagery from all over the world and is robust enough to be applied to any time period. In this case, the common metadata to use with the model would include the bbox of "the world" `[-90, -180, 90, 180]` and the start_datetime and end_datetime range could be generic values like `["1900-01-01", null]`. + ## Recommended Extensions to Compose with the ML Model Extension ### Processing Extension diff --git a/poetry.lock b/poetry.lock index 3e3b246..55aef26 100644 --- a/poetry.lock +++ b/poetry.lock @@ -785,6 +785,29 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pystac" +version = "1.9.0" +description = "Python library for working with the SpatioTemporal Asset Catalog (STAC) specification" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pystac-1.9.0-py3-none-any.whl", hash = "sha256:64d5654166290169ad6ad2bc0d5337a1664ede1165635f0b73b327065b801a2f"}, + {file = "pystac-1.9.0.tar.gz", hash = "sha256:c6b5a86e241fca5e9267a7902c26679f208749a107e9015fe6aaf73a9dd40948"}, +] + +[package.dependencies] +python-dateutil = ">=2.7.0" + +[package.extras] +bench = ["asv (>=0.6.0,<0.7.0)", "packaging (>=23.1,<24.0)", "virtualenv (>=20.22,<21.0)"] +docs = ["Sphinx (>=6.2,<7.0)", "boto3 (>=1.28,<2.0)", "ipython (>=8.12,<9.0)", "jinja2 (<4.0)", "jupyter (>=1.0,<2.0)", "nbsphinx (>=0.9.0,<0.10.0)", "pydata-sphinx-theme (>=0.13,<1.0)", "rasterio (>=1.3,<2.0)", "shapely (>=2.0,<3.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-design (>=0.5.0,<0.6.0)", "sphinxcontrib-fulltoc (>=1.2,<2.0)"] +jinja2 = ["jinja2 (<4.0)"] +orjson = ["orjson (>=3.5)"] +test = ["black (>=23.3,<24.0)", "codespell (>=2.2,<3.0)", "coverage (>=7.2,<8.0)", "doc8 (>=1.1,<2.0)", "html5lib (>=1.1,<2.0)", "jinja2 (<4.0)", "jsonschema (>=4.18,<5.0)", "mypy (>=1.2,<2.0)", "orjson (>=3.8,<4.0)", "pre-commit (>=3.2,<4.0)", "pytest (>=7.3,<8.0)", "pytest-cov (>=4.0,<5.0)", "pytest-mock (>=3.10,<4.0)", "pytest-recording (>=0.13.0,<0.14.0)", "requests-mock (>=1.11,<2.0)", "ruff (==0.1.1)", "types-html5lib (>=1.1,<2.0)", "types-jsonschema (>=4.18,<5.0)", "types-orjson (>=3.6,<4.0)", "types-python-dateutil (>=2.8,<3.0)", "types-urllib3 (>=1.26,<2.0)"] +urllib3 = ["urllib3 (>=1.26)"] +validation = ["jsonschema (>=4.18,<5.0)"] + [[package]] name = "pytest" version = "7.4.4" @@ -957,6 +980,20 @@ files = [ [package.dependencies] pytest = ">=5.0.0" +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "pyyaml" version = "6.0.1" @@ -1208,6 +1245,17 @@ files = [ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, ] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -1333,4 +1381,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8adf56c14896ce1548bf241e7ce24532f82c9a8a580bbde30c444fa5d6d6415d" +content-hash = "22fb0b0e7386f5abc1f2f7aa52630ace3c35cbdba9d94e75f0d5a1935f3574e9" diff --git a/pyproject.toml b/pyproject.toml index b9bd1ce..bad04fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,7 @@ typer = {extras = ["all"], version = "^0.9.0"} rich = "^13.7.0" pydantic = "2.3" # bug in post 2.3 https://github.com/pydantic/pydantic/issues/7720 pydantic-core = "^2" -numpy = "^1.26.2" -# fastapi="^0.108.0" +pystac = "^1.9.0" [tool.poetry.group.dev.dependencies] diff --git a/stac_model/examples.py b/stac_model/examples.py index 79f3a44..79aaf41 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -2,7 +2,7 @@ Asset, ClassObject, InputArray, - MLModel, + MLModelExtension, ModelInput, ModelOutput, ResultArray, @@ -112,7 +112,7 @@ def eurosat_resnet(): output_shape=[-1, 10], result_array=[result_array], ) - ml_model_meta = MLModel( + ml_model_meta = MLModelExtension( mlm_name="Resnet-18 Sentinel-2 ALL MOCO", mlm_task="classification", mlm_framework="pytorch", diff --git a/stac_model/geometry_models.py b/stac_model/geometry_models.py new file mode 100644 index 0000000..125e08b --- /dev/null +++ b/stac_model/geometry_models.py @@ -0,0 +1,39 @@ +from typing import List, Literal, Union + +from pydantic import ( + BaseModel, +) + + +class Geometry(BaseModel): + type: str + coordinates: List + + +class GeoJSONPoint(Geometry): + type: Literal["Point"] + coordinates: List[float] + + +class GeoJSONMultiPoint(Geometry): + type: Literal["MultiPoint"] + coordinates: List[List[float]] + + +class GeoJSONPolygon(Geometry): + type: Literal["Polygon"] + coordinates: List[List[List[float]]] + + +class GeoJSONMultiPolygon(Geometry): + type: Literal["MultiPolygon"] + coordinates: List[List[List[List[float]]]] + + +AnyGeometry = Union[ + Geometry, + GeoJSONPoint, + GeoJSONMultiPoint, + GeoJSONPolygon, + GeoJSONMultiPolygon, +] diff --git a/stac_model/schema.py b/stac_model/schema.py index 84a6f8c..1a60c30 100644 --- a/stac_model/schema.py +++ b/stac_model/schema.py @@ -1,16 +1,48 @@ -from typing import Dict, List, Optional, Union +import json +from datetime import datetime +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Literal, + MutableMapping, + Optional, + TypeVar, + Union, + cast, + get_args, +) +import pystac from pydantic import BaseModel, ConfigDict +from pydantic.fields import FieldInfo +from pystac.extensions import item_assets +from pystac.extensions.base import ( + ExtensionManagementMixin, + PropertiesExtension, + S, # generic pystac.STACObject + SummariesExtension, +) +from .geometry_models import AnyGeometry from .input import Band, InputArray, ModelInput, Statistics from .output import ClassObject, ModelOutput, ResultArray, TaskEnum from .runtime import Asset, Container, Runtime +T = TypeVar("T", pystac.Collection, pystac.Item, pystac.Asset, + item_assets.AssetDefinition) + +SchemaName = Literal["mlm"] +# TODO update +SCHEMA_URI: str = "https://raw.githubusercontent.com/crim-ca/dlm-extension/main/json-schema/schema.json" #noqa: E501 +PREFIX = f"{get_args(SchemaName)[0]}:" def mlm_prefix_replacer(field_name: str) -> str: return field_name.replace("mlm_", "mlm:") -class MLModel(BaseModel): +class MLModelProperties(BaseModel): mlm_name: str mlm_task: TaskEnum mlm_framework: str @@ -28,9 +60,209 @@ class MLModel(BaseModel): model_config = ConfigDict(alias_generator=mlm_prefix_replacer, populate_by_name=True, extra="ignore") +class MLModelHelper: + def __init__(self, attrs: MutableMapping[str, Any]): + self.attrs = attrs + self.mlmodel_attrs = attrs["attributes"] + + @property + def uid(self) -> str: + """Return a unique ID for MLModel data item.""" + keys = [ + "mlm_name", + "mlm_task", + ] + name = "_".join("_".join( + self.mlmodel_attrs[k].split(" ")) for k in keys).lower() + return name + + @property + def properties(self) -> MLModelProperties: + props = MLModelProperties(**self.mlmodel_attrs) + return props + + def stac_item(self, geometry: AnyGeometry, bbox: List[float], + start_datetime: datetime, end_datetime: datetime) -> pystac.Item: + item = pystac.Item( + id=self.uid, + geometry=geometry, + bbox=bbox, + properties={ + "start_datetime": start_datetime, + "end_datetime": end_datetime, + }, + datetime=None, + ) + item_mlmodel = MLModelExtension.ext(item, add_if_missing=True) + item_mlmodel.apply(self.properties) + return item + + +class MLModelExtension( + Generic[T], + PropertiesExtension, + ExtensionManagementMixin[Union[pystac.Asset, pystac.Item, pystac.Collection]], +): + @property + def name(self) -> SchemaName: + return get_args(SchemaName)[0] + + def apply( + self, + properties: Union[MLModelProperties, dict[str, Any]], + ) -> None: + """Applies Machine Learning Model Extension properties to the extended + :class:`~pystac.Item` or :class:`~pystac.Asset`. + """ + if isinstance(properties, dict): + properties = MLModelProperties(**properties) + data_json = json.loads(properties.model_dump_json(by_alias=True)) + for prop, val in data_json.items(): + self._set_property(prop, val) + + @classmethod + def get_schema_uri(cls) -> str: + return SCHEMA_URI + + @classmethod + def has_extension(cls, obj: S): + # FIXME: this override should be removed once an official and + # versioned schema is released ignore the original implementation + # logic for a version regex since in our case, the VERSION_REGEX + # is not fulfilled (ie: using 'main' branch, no tag available...) + ext_uri = cls.get_schema_uri() + return obj.stac_extensions is not None and any( + uri == ext_uri for uri in obj.stac_extensions) + + @classmethod + def ext(cls, obj: T, add_if_missing: bool = False) -> "MLModelExtension[T]": + """Extends the given STAC Object with properties from the + :stac-ext:`Machine Learning Model Extension `. + + This extension can be applied to instances of :class:`~pystac.Item` or + :class:`~pystac.Asset`. + + Raises: + + pystac.ExtensionTypeError : If an invalid object type is passed. + """ + if isinstance(obj, pystac.Collection): + cls.ensure_has_extension(obj, add_if_missing) + return cast(MLModelExtension[T], CollectionMLModelExtension(obj)) + elif isinstance(obj, pystac.Item): + cls.ensure_has_extension(obj, add_if_missing) + return cast(MLModelExtension[T], ItemMLModelExtension(obj)) + elif isinstance(obj, pystac.Asset): + cls.ensure_owner_has_extension(obj, add_if_missing) + return cast(MLModelExtension[T], AssetMLModelExtension(obj)) + elif isinstance(obj, item_assets.AssetDefinition): + cls.ensure_owner_has_extension(obj, add_if_missing) + return cast(MLModelExtension[T], ItemAssetsMLModelExtension(obj)) + else: + raise pystac.ExtensionTypeError(cls._ext_error_message(obj)) + + @classmethod + def summaries( + cls, obj: pystac.Collection, add_if_missing: bool = False + ) -> "SummariesMLModelExtension": + """Returns the extended summaries object for the given collection.""" + cls.ensure_has_extension(obj, add_if_missing) + return SummariesMLModelExtension(obj) + +class SummariesMLModelExtension(SummariesExtension): + """A concrete implementation of :class:`~SummariesExtension` that extends + the ``summaries`` field of a :class:`~pystac.Collection` to include properties + defined in the :stac-ext:`Machine Learning Model `. + """ + def _check_mlm_property(self, prop: str) -> FieldInfo: + try: + return MLModelProperties.model_fields[prop] + except KeyError as err: + raise AttributeError( + f"Name '{prop}' is not a valid MLM property.") from err + + def _validate_mlm_property(self, prop: str, summaries: list[Any]) -> None: + model = MLModelProperties.model_construct() + validator = MLModelProperties.__pydantic_validator__ + for value in summaries: + validator.validate_assignment(model, prop, value) + + def get_mlm_property(self, prop: str) -> list[Any]: + self._check_mlm_property(prop) + return self.summaries.get_list(prop) + + def set_mlm_property(self, prop: str, summaries: list[Any]) -> None: + self._check_mlm_property(prop) + self._validate_mlm_property(prop, summaries) + self._set_summary(prop, summaries) + + def __getattr__(self, prop): + return self.get_mlm_property(prop) + + def __setattr__(self, prop, value): + self.set_mlm_property(prop, value) + +class ItemMLModelExtension(MLModelExtension[pystac.Item]): + """A concrete implementation of :class:`MLModelExtension` on an + :class:`~pystac.Item` that extends the properties of the Item to + include properties defined in the :stac-ext:`Machine Learning Model + Extension `. + + This class should generally not be instantiated directly. Instead, call + :meth:`MLModelExtension.ext` on an :class:`~pystac.Item` to extend it. + """ + + def __init__(self, item: pystac.Item): + self.item = item + self.properties = item.properties + + def __repr__(self) -> str: + return f"" + +class ItemAssetsMLModelExtension(MLModelExtension[item_assets.AssetDefinition]): + properties: dict[str, Any] + asset_defn: item_assets.AssetDefinition + + def __init__(self, item_asset: item_assets.AssetDefinition): + self.asset_defn = item_asset + self.properties = item_asset.properties + +class AssetMLModelExtension(MLModelExtension[pystac.Asset]): + """A concrete implementation of :class:`MLModelExtension` on an + :class:`~pystac.Asset` that extends the Asset fields to include + properties defined in the :stac-ext:`Machine Learning Model + Extension `. + + This class should generally not be instantiated directly. Instead, call + :meth:`MLModelExtension.ext` on an :class:`~pystac.Asset` to extend it. + """ + + asset_href: str + """The ``href`` value of the :class:`~pystac.Asset` being extended.""" + + properties: dict[str, Any] + """The :class:`~pystac.Asset` fields, including extension properties.""" + + additional_read_properties: Optional[Iterable[dict[str, Any]]] = None + """If present, this will be a list containing 1 dictionary representing the + properties of the owning :class:`~pystac.Item`.""" + + def __init__(self, asset: pystac.Asset): + self.asset_href = asset.href + self.properties = asset.extra_fields + if asset.owner and isinstance(asset.owner, pystac.Item): + self.additional_read_properties = [asset.owner.properties] + + def __repr__(self) -> str: + return f"" + +class CollectionMLModelExtension(MLModelExtension[pystac.Collection]): + + def __init__(self, collection: pystac.Collection): + self.collection = collection __all__ = [ - "MLModel", + "MLModelExtension", "ModelInput", "InputArray", "Band",