Skip to content

Commit

Permalink
remove helper and use pystac.Item in example
Browse files Browse the repository at this point in the history
  • Loading branch information
rbavery committed Feb 28, 2024
1 parent 4590141 commit 4fc2e8e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 70 deletions.
3 changes: 3 additions & 0 deletions stac_model/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
)
console = Console()


def version_callback(print_version: bool) -> None:
"""Print the version of the package."""
if print_version:
console.print(f"[yellow]stac-model[/] version: [bold blue]{__version__}[/]")
raise typer.Exit()


@app.command(name="")
def main(
print_version: bool = typer.Option(
Expand All @@ -40,5 +42,6 @@ def main(
print("Example model metadata written to ./example.json.")
return ml_model_meta


if __name__ == "__main__":
app()
37 changes: 24 additions & 13 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import datetime

import pystac

from stac_model.schema import (
Asset,
ClassObject,
InputArray,
MLModelHelper,
MLModelExtension,
MLModelProperties,
ModelInput,
ModelOutput,
Expand All @@ -15,7 +17,6 @@


def eurosat_resnet():

input_array = InputArray(
shape=[-1, 13, 64, 64], dim_order="bchw", data_type="float32"
)
Expand Down Expand Up @@ -74,8 +75,7 @@ def eurosat_resnet():
norm_type="z_score",
resize_type="none",
statistics=stats,
pre_processing_function = "https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py" # noqa: E501
,
pre_processing_function="https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py", # noqa: E501
)
mlm_runtime = Runtime(
framework="torch",
Expand Down Expand Up @@ -132,13 +132,24 @@ def eurosat_resnet():
mlm_runtime=[mlm_runtime],
mlm_output=[mlm_output],
)

mlmodel_helper = MLModelHelper(attrs = ml_model_meta.model_dump())
geometry=None
start_datetime = datetime.strptime("1900-01-01", "%Y-%m-%d")
end_datetime = None
geometry = None
bbox = [-90, -180, 90, 180]
start_time = datetime.strptime("1900-01-01", '%Y-%m-%d')
end_time = None
item = mlmodel_helper.stac_item(geometry, bbox, start_datetime=start_time,
end_datetime=end_time)

return item
name = (
"_".join(ml_model_meta.mlm_name.split(" ")).lower()
+ f"_{ml_model_meta.mlm_task}".lower()
)
item = pystac.Item(
id=name,
geometry=geometry,
bbox=bbox,
datetime=None,
properties={"start_datetime": start_datetime, "end_datetime": end_datetime},
)
item.add_derived_from(
"https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a"
)
item_mlmodel = MLModelExtension.ext(item, add_if_missing=True)
item_mlmodel.apply(ml_model_meta.model_dump())
return item_mlmodel
5 changes: 3 additions & 2 deletions stac_model/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class ModelInput(BaseModel):
"none",
] = None
resize_type: Literal["crop", "pad", "interpolate", "none"] = None
parameters: Optional[Dict[str, Union[int, str, bool,
List[Union[int, str, bool]]]]] = None
parameters: Optional[
Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]]
] = None
statistics: Optional[Union[Statistics, List[Statistics]]] = None
norm_with_clip_values: Optional[List[Union[float, int]]] = None
pre_processing_function: Optional[str | AnyUrl] = None
5 changes: 3 additions & 2 deletions stac_model/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ class Asset(BaseModel):
Follows the STAC Asset Object spec.
"""

href: S3Path | FilePath | AnyUrl| str
href: S3Path | FilePath | AnyUrl | str
title: Optional[str] = None
description: Optional[str] = None
type: Optional[str] = None
roles: Optional[List[str]] = None

model_config = ConfigDict(arbitrary_types_allowed = True)
model_config = ConfigDict(arbitrary_types_allowed=True)


class Container(BaseModel):
container_file: str
Expand Down
71 changes: 22 additions & 49 deletions stac_model/schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import json
from datetime import datetime
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Literal,
MutableMapping,
Optional,
TypeVar,
Union,
Expand All @@ -26,22 +24,24 @@
SummariesExtension,
)

from .geometry_models import AnyGeometry
from .input import Band, InputArray, ModelInput, Statistics
from .output import ClassObject, ModelOutput, ResultArray, TaskEnum
from .runtime import Asset, Container, Runtime

T = TypeVar("T", pystac.Collection, pystac.Item, pystac.Asset,
item_assets.AssetDefinition)
T = TypeVar(
"T", pystac.Collection, pystac.Item, pystac.Asset, item_assets.AssetDefinition
)

SchemaName = Literal["mlm"]
# TODO update
SCHEMA_URI: str = "https://raw.githubusercontent.com/crim-ca/dlm-extension/main/json-schema/schema.json" #noqa: E501
SCHEMA_URI: str = "https://raw.githubusercontent.com/crim-ca/dlm-extension/main/json-schema/schema.json" # noqa: E501
PREFIX = f"{get_args(SchemaName)[0]}:"


def mlm_prefix_replacer(field_name: str) -> str:
return field_name.replace("mlm_", "mlm:")


class MLModelProperties(BaseModel):
mlm_name: str
mlm_task: TaskEnum
Expand All @@ -55,46 +55,13 @@ class MLModelProperties(BaseModel):
mlm_total_parameters: int
mlm_pretrained_source: str
mlm_summary: str
mlm_parameters: Optional[Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]]] = None # noqa: E501

model_config = ConfigDict(alias_generator=mlm_prefix_replacer,
populate_by_name=True, extra="ignore")

class MLModelHelper:
def __init__(self, attrs: MutableMapping[str, Any]):
self.attrs = attrs
mlm_parameters: Optional[
Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]]
] = None # noqa: E501

@property
def uid(self) -> str:
"""Return a unique ID for MLModel data item."""
keys = [
"mlm_name",
"mlm_task",
]
name = "_".join("_".join(
self.attrs[k].split(" ")) for k in keys).lower()
return name

@property
def properties(self) -> MLModelProperties:
props = MLModelProperties(**self.attrs)
return props

def stac_item(self, geometry: AnyGeometry, bbox: List[float],
start_datetime: datetime, end_datetime: datetime) -> pystac.Item:
item = pystac.Item(
id=self.uid,
geometry=geometry,
bbox=bbox,
properties={
"start_datetime": start_datetime,
"end_datetime": end_datetime,
},
datetime=None,
)
item_mlmodel = MLModelExtension.ext(item, add_if_missing=True)
item_mlmodel.apply(self.properties)
return item
model_config = ConfigDict(
alias_generator=mlm_prefix_replacer, populate_by_name=True, extra="ignore"
)


class MLModelExtension(
Expand Down Expand Up @@ -131,7 +98,8 @@ def has_extension(cls, obj: S):
# is not fulfilled (ie: using 'main' branch, no tag available...)
ext_uri = cls.get_schema_uri()
return obj.stac_extensions is not None and any(
uri == ext_uri for uri in obj.stac_extensions)
uri == ext_uri for uri in obj.stac_extensions
)

@classmethod
def ext(cls, obj: T, add_if_missing: bool = False) -> "MLModelExtension[T]":
Expand Down Expand Up @@ -168,17 +136,18 @@ def summaries(
cls.ensure_has_extension(obj, add_if_missing)
return SummariesMLModelExtension(obj)


class SummariesMLModelExtension(SummariesExtension):
"""A concrete implementation of :class:`~SummariesExtension` that extends
the ``summaries`` field of a :class:`~pystac.Collection` to include properties
defined in the :stac-ext:`Machine Learning Model <mlm>`.
"""

def _check_mlm_property(self, prop: str) -> FieldInfo:
try:
return MLModelProperties.model_fields[prop]
except KeyError as err:
raise AttributeError(
f"Name '{prop}' is not a valid MLM property.") from err
raise AttributeError(f"Name '{prop}' is not a valid MLM property.") from err

def _validate_mlm_property(self, prop: str, summaries: list[Any]) -> None:
model = MLModelProperties.model_construct()
Expand All @@ -201,6 +170,7 @@ def __getattr__(self, prop):
def __setattr__(self, prop, value):
self.set_mlm_property(prop, value)


class ItemMLModelExtension(MLModelExtension[pystac.Item]):
"""A concrete implementation of :class:`MLModelExtension` on an
:class:`~pystac.Item` that extends the properties of the Item to
Expand All @@ -218,6 +188,7 @@ def __init__(self, item: pystac.Item):
def __repr__(self) -> str:
return f"<ItemMLModelExtension Item id={self.item.id}>"


class ItemAssetsMLModelExtension(MLModelExtension[item_assets.AssetDefinition]):
properties: dict[str, Any]
asset_defn: item_assets.AssetDefinition
Expand All @@ -226,6 +197,7 @@ def __init__(self, item_asset: item_assets.AssetDefinition):
self.asset_defn = item_asset
self.properties = item_asset.properties


class AssetMLModelExtension(MLModelExtension[pystac.Asset]):
"""A concrete implementation of :class:`MLModelExtension` on an
:class:`~pystac.Asset` that extends the Asset fields to include
Expand Down Expand Up @@ -255,11 +227,12 @@ def __init__(self, asset: pystac.Asset):
def __repr__(self) -> str:
return f"<AssetMLModelExtension Asset href={self.asset_href}>"

class CollectionMLModelExtension(MLModelExtension[pystac.Collection]):

class CollectionMLModelExtension(MLModelExtension[pystac.Collection]):
def __init__(self, collection: pystac.Collection):
self.collection = collection


__all__ = [
"MLModelExtension",
"ModelInput",
Expand Down
13 changes: 9 additions & 4 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@

import pytest


@pytest.fixture
def metadata_json():
from stac_model.examples import eurosat_resnet

model_metadata_stac_item = eurosat_resnet()
return model_metadata_stac_item

def test_model_metadata_json_operations(model_metadata_stac_item):

def test_model_metadata_to_dict(metadata_json):
assert metadata_json.to_dict()


def test_model_metadata_json_operations(metadata_json):
from stac_model.schema import MLModelExtension
model_metadata = MLModelExtension.apply(model_metadata_stac_item)
assert model_metadata.mlm_name == "Resnet-18 Sentinel-2 ALL MOCO"

assert MLModelExtension(metadata_json.to_dict())

0 comments on commit 4fc2e8e

Please sign in to comment.