Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rbavery committed Mar 7, 2024
1 parent 0bed29b commit bf3b07f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 179 deletions.
24 changes: 13 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,18 @@ Note: It is common in the machine learning, computer vision, and remote sensing

### Runtime Object

| Field Name | Type | Description |
|-------------------------|---------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| model_asset | [Asset Object](stac-asset) | **REQUIRED.** Asset object containing URI to the model file. |
| source_code | [Asset Object](stac-asset) | **REQUIRED.** Source code description. Can describe a github repo, zip archive, etc. This description should reference the inference function, for example my_package.my_module.predict |
| accelerator | [Accelerator Enum](#accelerator-enum) | **REQUIRED.** The intended computational hardware that runs inference. |
| accelerator_constrained | boolean | **REQUIRED.** True if the intended `accelerator` is the only `accelerator` that can run inference. False if other accelerators, such as amd64 (CPU), can run inference. |
| hardware_summary | string | **REQUIRED.** A high level description of the number of accelerators, specific generation of the `accelerator`, or other relevant inference details. |
| container | [Container](#container) | **RECOMMENDED.** Information to run the model in a container instance. |
| model_commit_hash | string | Hash value pointing to a specific version of the code. |
| batch_size_suggestion | number | A suggested batch size for the accelerator and summarized hardware. |
| Field Name | Type | Description |
| ----------------------- | ------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| model_asset | [Asset Object](stac-asset) | **REQUIRED.** Asset object containing URI to the model file. Recommended asset `roles` include `weights` for model weights that need to be loaded by a model definition and `compiled` for models that can be loaded directly without an intermediate model definition. |
| source_code | [Asset Object](stac-asset) | **REQUIRED.** Source code description. Can describe a github repo, zip archive, etc. This description should reference the inference function, for example my_package.my_module.predict |
| accelerator | [Accelerator Enum](#accelerator-enum) | **REQUIRED.** The intended computational hardware that runs inference. |
| accelerator_constrained | boolean | **REQUIRED.** True if the intended `accelerator` is the only `accelerator` that can run inference. False if other accelerators, such as amd64 (CPU), can run inference. |
| hardware_summary | string | **REQUIRED.** A high level description of the number of accelerators, specific generation of the `accelerator`, or other relevant inference details. |
| container | [Container](#container) | **RECOMMENDED.** Information to run the model in a container instance. |
| model_commit_hash | string | Hash value pointing to a specific version of the code. |
| batch_size_suggestion | number | A suggested batch size for the accelerator and summarized hardware. |

For the `model_a`

#### Accelerator Enum

Expand Down Expand Up @@ -175,7 +177,7 @@ You can also use other base images. Pytorch and Tensorflow offer docker images f
| Field Name | Type | Description |
|--------------------------|-----------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| task | [Task Enum](#task-enum) | **REQUIRED.** Specifies the Machine Learning task for which the output can be used for. |
| result | [[Result Array Object](#result-array-object)] | The list of output array/tensor from the model. For example ($N \times H \times W$). Use -1 to indicate variable dimensions, like the batch dimension. |
| result_array | [[Result Array Object](#result-array-object)] | The list of output arrays/tensors from the model. |
| classification:classes | [[Class Object](#class-object)] | A list of class objects adhering to the [Classification extension](https://github.com/stac-extensions/classification). |
| post_processing_function | string | A url to the postprocessing function where normalization, rescaling, and other operations take place.. Or, instead, the function code path, for example: `my_package.my_module.my_processing_function` |

Expand Down
165 changes: 1 addition & 164 deletions README_STAC_MODEL.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,170 +42,7 @@ stac-model --help
stac-model
```

This will make an example example.json metadata file for an example model.

Currently this looks like

```json
"mlm_name": "Resnet-18 Sentinel-2 ALL MOCO",
"mlm_task": "classification",
"mlm_framework": "pytorch",
"mlm_framework_version": "2.1.2+cu121",
"mlm_file_size": 1,
"mlm_memory_size": 1,
"mlm_input": [
{
"name": "13 Band Sentinel-2 Batch",
"bands": [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12"
],
"input_array": {
"shape": [
-1,
13,
64,
64
],
"dim_order": "bchw",
"data_type": "float32"
},
"norm_by_channel": true,
"norm_type": "z_score",
"statistics": {
"mean": [
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798
],
"stddev": [
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042
]
},
"pre_processing_function": "https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py"
}
],
"mlm_output": [
{
"task": "classification",
"result_array": [
{
"shape": [
-1,
10
],
"dim_names": [
"batch",
"class"
],
"data_type": "float32"
}
],
"classification_classes": [
{
"value": 0,
"name": "Annual Crop",
"nodata": false
},
{
"value": 1,
"name": "Forest",
"nodata": false
},
{
"value": 2,
"name": "Herbaceous Vegetation",
"nodata": false
},
{
"value": 3,
"name": "Highway",
"nodata": false
},
{
"value": 4,
"name": "Industrial Buildings",
"nodata": false
},
{
"value": 5,
"name": "Pasture",
"nodata": false
},
{
"value": 6,
"name": "Permanent Crop",
"nodata": false
},
{
"value": 7,
"name": "Residential Buildings",
"nodata": false
},
{
"value": 8,
"name": "River",
"nodata": false
},
{
"value": 9,
"name": "SeaLake",
"nodata": false
}
]
}
],
"mlm_runtime": [
{
"asset": {
"href": "."
},
"source_code": {
"href": "."
},
"accelerator": "cuda",
"accelerator_constrained": false,
"hardware_summary": "Unknown"
}
],
"mlm_total_parameters": 11700000,
"mlm_pretrained_source": "EuroSat Sentinel-2",
"mlm_summary": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
}
```
This will make [this example item](./examples/example.json) for an example model.

## :chart_with_upwards_trend: Releases

Expand Down
14 changes: 10 additions & 4 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pystac

import json
import shapely
from stac_model.schema import (
Asset,
ClassObject,
Expand Down Expand Up @@ -73,7 +74,7 @@ def eurosat_resnet():
norm_type="z_score",
resize_type="none",
statistics=stats,
pre_processing_function="https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py", # noqa: E501
pre_processing_function="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn", # noqa: E501
)
runtime = Runtime(
framework="torch",
Expand Down Expand Up @@ -135,8 +136,13 @@ def eurosat_resnet():
# Is this a problem that we don't do date validation if we supply as str?
start_datetime = "1900-01-01"
end_datetime = None
geometry = None
bbox = [-90, -180, 90, 180]
bbox = [
-7.882190080512502,
37.13739173208318,
27.911651652899923,
58.21798141355221
]
geometry = json.dumps(shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__, indent=2)
name = (
"_".join(ml_model_meta.name.split(" ")).lower()
+ f"_{ml_model_meta.task}".lower()
Expand Down

0 comments on commit bf3b07f

Please sign in to comment.