diff --git a/example.json b/example.json index ef32350..790c8dd 100644 --- a/example.json +++ b/example.json @@ -1,159 +1,6 @@ { - "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", - "mlm:task": "classification", - "mlm:framework": "pytorch", - "mlm:framework_version": "2.1.2+cu121", - "mlm:file_size": 1, - "mlm:memory_size": 1, - "mlm:input": [ - { - "name": "13 Band Sentinel-2 Batch", - "bands": [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B10", - "B11", - "B12" - ], - "input_array": { - "shape": [ - -1, - 13, - 64, - 64 - ], - "dim_order": "bchw", - "data_type": "float32" - }, - "norm_by_channel": true, - "norm_type": "z_score", - "statistics": { - "mean": [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798 - ], - "stddev": [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042 - ] - }, - "pre_processing_function": "https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py" - } - ], - "mlm:output": [ - { - "task": "classification", - "result_array": [ - { - "shape": [ - -1, - 10 - ], - "dim_names": [ - "batch", - "class" - ], - "data_type": "float32" - } - ], - "classification_classes": [ - { - "value": 0, - "name": "Annual Crop", - "nodata": false - }, - { - "value": 1, - "name": "Forest", - "nodata": false - }, - { - "value": 2, - "name": "Herbaceous Vegetation", - "nodata": false - }, - { - "value": 3, - "name": "Highway", - "nodata": false - }, - { - "value": 4, - "name": "Industrial Buildings", - "nodata": false - }, - { - "value": 5, - "name": "Pasture", - "nodata": false - }, - { - "value": 6, - "name": "Permanent Crop", - "nodata": false - }, - { - "value": 7, - "name": "Residential Buildings", - "nodata": false - }, - { - "value": 8, - "name": "River", - "nodata": false - }, - { - "value": 9, - "name": "SeaLake", - "nodata": false - } - ] - } - ], - "mlm:runtime": [ - { - "asset": { - "href": "." - }, - "source_code": { - "href": "." - }, - "accelerator": "cuda", - "accelerator_constrained": false, - "hardware_summary": "Unknown" - } - ], - "mlm:total_parameters": 11700000, - "mlm:pretrained_source": "EuroSat Sentinel-2", - "mlm:summary": "Sourced from torchgeo python library,identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO" -} + "type": "Feature", + "stac_version": "1.0.0", + "id": "resnet-18_sentinel-2_all_moco_classification", + "properties": { + "start_datetime": \ No newline at end of file diff --git a/stac_model/__main__.py b/stac_model/__main__.py index 4fbfc32..7c2a1c3 100644 --- a/stac_model/__main__.py +++ b/stac_model/__main__.py @@ -1,6 +1,6 @@ import typer from rich.console import Console - +import json from stac_model import __version__ from stac_model.examples import eurosat_resnet @@ -35,10 +35,8 @@ def main( ) -> None: """Generate example spec.""" ml_model_meta = eurosat_resnet() - json_str = ml_model_meta.model_dump_json(indent=2, exclude_none=True, by_alias=True) - with open("example.json", "w") as file: - file.write(json_str) - print(ml_model_meta.model_dump_json(indent=2, exclude_none=True, by_alias=True)) + with open("example.json", "w") as json_file: + json.dump(ml_model_meta.item.to_dict(), json_file, indent=4) print("Example model metadata written to ./example.json.") return ml_model_meta diff --git a/tests/test_schema.py b/tests/test_schema.py index 2c12ec9..20154f8 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2,18 +2,15 @@ @pytest.fixture -def metadata_json(): +def mlmodel_metadata_item(): from stac_model.examples import eurosat_resnet model_metadata_stac_item = eurosat_resnet() return model_metadata_stac_item +def test_model_metadata_to_dict(mlmodel_metadata_item): + assert mlmodel_metadata_item.item.to_dict() -def test_model_metadata_to_dict(metadata_json): - assert metadata_json.to_dict() - - -def test_model_metadata_json_operations(metadata_json): - from stac_model.schema import MLModelExtension - - assert MLModelExtension(metadata_json.to_dict()) +def test_validate_model_metadata(mlmodel_metadata_item): + import pystac + assert pystac.read_dict(mlmodel_metadata_item.item.to_dict())