diff --git a/stac_model/examples.py b/stac_model/examples.py index 79aaf41..fbc2ab8 100644 --- a/stac_model/examples.py +++ b/stac_model/examples.py @@ -1,8 +1,11 @@ +from datetime import datetime + from stac_model.schema import ( Asset, ClassObject, InputArray, - MLModelExtension, + MLModelHelper, + MLModelProperties, ModelInput, ModelOutput, ResultArray, @@ -112,7 +115,7 @@ def eurosat_resnet(): output_shape=[-1, 10], result_array=[result_array], ) - ml_model_meta = MLModelExtension( + ml_model_meta = MLModelProperties( mlm_name="Resnet-18 Sentinel-2 ALL MOCO", mlm_task="classification", mlm_framework="pytorch", @@ -129,4 +132,13 @@ def eurosat_resnet(): mlm_runtime=[mlm_runtime], mlm_output=[mlm_output], ) - return ml_model_meta + + mlmodel_helper = MLModelHelper(attrs = ml_model_meta.model_dump()) + geometry=None + bbox = [-90, -180, 90, 180] + start_time = datetime.strptime("1900-01-01", '%Y-%m-%d') + end_time = None + item = mlmodel_helper.stac_item(geometry, bbox, start_datetime=start_time, + end_datetime=end_time) + + return item diff --git a/stac_model/runtime.py b/stac_model/runtime.py index 2872abf..ef78812 100644 --- a/stac_model/runtime.py +++ b/stac_model/runtime.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel, ConfigDict, FilePath, field_validator +from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath from .paths import S3Path @@ -11,7 +11,7 @@ class Asset(BaseModel): Follows the STAC Asset Object spec. """ - href: S3Path | FilePath | str + href: S3Path | FilePath | AnyUrl| str title: Optional[str] = None description: Optional[str] = None type: Optional[str] = None @@ -19,18 +19,6 @@ class Asset(BaseModel): model_config = ConfigDict(arbitrary_types_allowed = True) - @field_validator("href") - @classmethod - def check_path_type(cls, v): - if isinstance(v, str): - v = S3Path(url=v) if v.startswith("s3://") else FilePath(f=v) - else: - raise ValueError( - f"Expected str, S3Path, or FilePath input, received {type(v).__name__}" - ) - return v - - class Container(BaseModel): container_file: str image_name: str diff --git a/stac_model/schema.py b/stac_model/schema.py index 1a60c30..115b66a 100644 --- a/stac_model/schema.py +++ b/stac_model/schema.py @@ -63,7 +63,6 @@ class MLModelProperties(BaseModel): class MLModelHelper: def __init__(self, attrs: MutableMapping[str, Any]): self.attrs = attrs - self.mlmodel_attrs = attrs["attributes"] @property def uid(self) -> str: @@ -73,12 +72,12 @@ def uid(self) -> str: "mlm_task", ] name = "_".join("_".join( - self.mlmodel_attrs[k].split(" ")) for k in keys).lower() + self.attrs[k].split(" ")) for k in keys).lower() return name @property def properties(self) -> MLModelProperties: - props = MLModelProperties(**self.mlmodel_attrs) + props = MLModelProperties(**self.attrs) return props def stac_item(self, geometry: AnyGeometry, bbox: List[float], diff --git a/tests/test_schema.py b/tests/test_schema.py index 71359d2..cd3eb8a 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -5,10 +5,10 @@ @pytest.fixture def metadata_json(): from stac_model.examples import eurosat_resnet - model_metadata = eurosat_resnet() - return model_metadata.model_dump_json(indent=2) + model_metadata_stac_item = eurosat_resnet() + return model_metadata_stac_item -def test_model_metadata_json_operations(metadata_json): - from stac_model.schema import MLModel - model_metadata = MLModel.model_validate_json(metadata_json) +def test_model_metadata_json_operations(model_metadata_stac_item): + from stac_model.schema import MLModelExtension + model_metadata = MLModelExtension.apply(model_metadata_stac_item) assert model_metadata.mlm_name == "Resnet-18 Sentinel-2 ALL MOCO"