Skip to content

Commit

Permalink
produce pystac item in example but can't serialize datetime
Browse files Browse the repository at this point in the history
  • Loading branch information
rbavery committed Feb 28, 2024
1 parent 2b62d7b commit 4590141
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
18 changes: 15 additions & 3 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from datetime import datetime

from stac_model.schema import (
Asset,
ClassObject,
InputArray,
MLModelExtension,
MLModelHelper,
MLModelProperties,
ModelInput,
ModelOutput,
ResultArray,
Expand Down Expand Up @@ -112,7 +115,7 @@ def eurosat_resnet():
output_shape=[-1, 10],
result_array=[result_array],
)
ml_model_meta = MLModelExtension(
ml_model_meta = MLModelProperties(
mlm_name="Resnet-18 Sentinel-2 ALL MOCO",
mlm_task="classification",
mlm_framework="pytorch",
Expand All @@ -129,4 +132,13 @@ def eurosat_resnet():
mlm_runtime=[mlm_runtime],
mlm_output=[mlm_output],
)
return ml_model_meta

mlmodel_helper = MLModelHelper(attrs = ml_model_meta.model_dump())
geometry=None
bbox = [-90, -180, 90, 180]
start_time = datetime.strptime("1900-01-01", '%Y-%m-%d')
end_time = None
item = mlmodel_helper.stac_item(geometry, bbox, start_datetime=start_time,
end_datetime=end_time)

return item
16 changes: 2 additions & 14 deletions stac_model/runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, FilePath, field_validator
from pydantic import AnyUrl, BaseModel, ConfigDict, FilePath

from .paths import S3Path

Expand All @@ -11,26 +11,14 @@ class Asset(BaseModel):
Follows the STAC Asset Object spec.
"""

href: S3Path | FilePath | str
href: S3Path | FilePath | AnyUrl| str
title: Optional[str] = None
description: Optional[str] = None
type: Optional[str] = None
roles: Optional[List[str]] = None

model_config = ConfigDict(arbitrary_types_allowed = True)

@field_validator("href")
@classmethod
def check_path_type(cls, v):
if isinstance(v, str):
v = S3Path(url=v) if v.startswith("s3://") else FilePath(f=v)
else:
raise ValueError(
f"Expected str, S3Path, or FilePath input, received {type(v).__name__}"
)
return v


class Container(BaseModel):
container_file: str
image_name: str
Expand Down
5 changes: 2 additions & 3 deletions stac_model/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class MLModelProperties(BaseModel):
class MLModelHelper:
def __init__(self, attrs: MutableMapping[str, Any]):
self.attrs = attrs
self.mlmodel_attrs = attrs["attributes"]

@property
def uid(self) -> str:
Expand All @@ -73,12 +72,12 @@ def uid(self) -> str:
"mlm_task",
]
name = "_".join("_".join(
self.mlmodel_attrs[k].split(" ")) for k in keys).lower()
self.attrs[k].split(" ")) for k in keys).lower()
return name

@property
def properties(self) -> MLModelProperties:
props = MLModelProperties(**self.mlmodel_attrs)
props = MLModelProperties(**self.attrs)
return props

def stac_item(self, geometry: AnyGeometry, bbox: List[float],
Expand Down
10 changes: 5 additions & 5 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
@pytest.fixture
def metadata_json():
from stac_model.examples import eurosat_resnet
model_metadata = eurosat_resnet()
return model_metadata.model_dump_json(indent=2)
model_metadata_stac_item = eurosat_resnet()
return model_metadata_stac_item

def test_model_metadata_json_operations(metadata_json):
from stac_model.schema import MLModel
model_metadata = MLModel.model_validate_json(metadata_json)
def test_model_metadata_json_operations(model_metadata_stac_item):
from stac_model.schema import MLModelExtension
model_metadata = MLModelExtension.apply(model_metadata_stac_item)
assert model_metadata.mlm_name == "Resnet-18 Sentinel-2 ALL MOCO"

0 comments on commit 4590141

Please sign in to comment.