This repository was archived by the owner on Oct 2, 2024. It is now read-only.

crim-ca/dlm-extension


Machine Learning Model Extension Specification


The STAC Machine Learning Model (MLM) Extension provides a standard set of fields to describe machine learning models trained on overhead imagery and enable running model inference.

The main objectives of the extension are:

  1. to enable building model collections that can be searched alongside associated STAC datasets
  2. to record all bands, parameters, modeling artifact locations, and high-level processing steps necessary to deploy an inference service.

Specifically, this extension records the following information to make ML models searchable and reusable:

  1. Sensor band specifications
  2. Model input transforms including rescale and normalization
  3. Model output shape, data type, and its semantic interpretation
  4. An optional, flexible description of the runtime environment to be able to run the model
  5. Scientific references

The MLM specification is biased towards supervised ML models that produce classifications. However, fields that relate to supervised ML are optional, and users can use only the fields they need for other tasks.

For more details, see the original technical report, which describes an earlier version of the Model Extension.


Item Properties and Collection Fields

| Field Name | Type | Description |
|---|---|---|
| mlm:name | string | REQUIRED. A unique name for the model. It should be distinct from the name of the architecture it is based on and from the name(s) of the input(s). |
| mlm:input | [Model Input Object] | REQUIRED. Describes the transformation between the EO data and the model input. |
| mlm:architecture | Architecture Object | REQUIRED. Describes the model architecture. |
| mlm:runtime | Runtime Object | REQUIRED. Describes the runtime environments used to run the model for inference. |
| mlm:output | Model Output Object | REQUIRED. Describes each model output and how to interpret it. |
| mlm:parameters | Parameters Object | Mapping with names for the parameters and their values. Some models may take additional scalars, tuples, and other non-tensor inputs like text. |
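The following minimal Python sketch checks an Item's `properties` dictionary for the five REQUIRED `mlm:` fields listed above. It does not replace validation against the extension's JSON Schema. The model name and parameter values are hypothetical.

```python
from typing import Any, Dict, List

# Field names taken from the table above; all five are marked REQUIRED there.
REQUIRED_MLM_FIELDS = ("mlm:name", "mlm:input", "mlm:architecture", "mlm:runtime", "mlm:output")


def missing_mlm_fields(properties: Dict[str, Any]) -> List[str]:
    """Return the REQUIRED mlm:* fields that are absent from an Item's properties."""
    return [field for field in REQUIRED_MLM_FIELDS if field not in properties]


# Hypothetical, abbreviated properties: nested objects are left empty for brevity.
properties = {
    "mlm:name": "example-landcover-model",
    "mlm:input": [],  # a list of Model Input Objects
    "mlm:architecture": {},
    "mlm:runtime": {},
    # "mlm:output" is deliberately omitted
    "mlm:parameters": {"threshold": 0.5},  # optional
}

print(missing_mlm_fields(properties))  # ['mlm:output']
```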

In addition, fields from the following extensions must be imported in the item:

Model Input Object

| Field Name | Type | Description |
|---|---|---|
| name | string | REQUIRED. Informative name of the input variable, for example "RGB Time Series". |
| bands | [string] | REQUIRED. The names of the raster bands used to train or fine-tune the model. These may be all of the bands available in a STAC Item's Band Object or a subset of them. |
| input_array | Array Object | REQUIRED. The N-dimensional array object that describes the shape, dimension ordering, and data type. |
| parameters | Parameters Object | Mapping with names for the parameters and their values. Some models may take additional scalars, tuples, and other non-tensor inputs like text. |
| norm_by_channel | boolean | Whether to normalize each channel with channel-wise statistics or with dataset statistics. If True, use an array of Statistics Objects ordered like the bands field in this object. |
| norm_type | string | Normalization method. Select one option from "min_max", "z_score", "max_norm", "mean_norm", "unit_variance", "none". |
| rescale_type | string | High-level descriptor of the rescaling method used to change the image shape. Select one option from "crop", "pad", "interpolation", "none". If your rescaling method combines more than one of these operations, provide the name of the operation instead. |
| statistics | Statistics Object \| [Statistics Object] | Statistics of the training dataset, used to normalize the inputs. |
| pre_processing_function | string | A URL to the preprocessing function where normalization, rescaling, and any other significant operations take place. Alternatively, the function code path, for example: my_python_module_name:my_processing_function |
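The table names the normalization types but does not define their formulas. The sketch below assumes the conventional definitions: `z_score` is (v − mean) / stddev and `min_max` is (v − minimum) / (maximum − minimum). It reads values from STAC 1.1 Statistics fields (`minimum`, `maximum`, `mean`, `stddev`) and uses `norm_by_channel` to choose between one Statistics Object per band and a single dataset-wide object. Channels are plain Python lists, and the other `norm_type` values are intentionally left unimplemented.

```python
from typing import Dict, List, Union

Stats = Dict[str, float]


def normalize(
    channels: List[List[float]],
    norm_type: str,
    statistics: Union[Stats, List[Stats]],
    norm_by_channel: bool,
) -> List[List[float]]:
    """Normalize C flat channels, ordered like the input's bands field."""
    if norm_by_channel:
        if isinstance(statistics, dict):
            raise TypeError("norm_by_channel=True expects a list of Statistics Objects")
        per_channel = list(statistics)  # one Statistics Object per band
    else:
        per_channel = [statistics] * len(channels)  # one dataset-wide object
    if len(per_channel) != len(channels):
        raise ValueError("expected one Statistics Object per channel")

    normalized = []
    for values, stats in zip(channels, per_channel):
        if norm_type == "z_score":
            normalized.append([(v - stats["mean"]) / stats["stddev"] for v in values])
        elif norm_type == "min_max":
            span = stats["maximum"] - stats["minimum"]
            normalized.append([(v - stats["minimum"]) / span for v in values])
        elif norm_type == "none":
            normalized.append(list(values))
        else:
            raise NotImplementedError(f"norm_type {norm_type!r} is not covered by this sketch")
    return normalized


bands = [[0.0, 50.0, 100.0], [10.0, 20.0, 30.0]]
stats = [{"minimum": 0.0, "maximum": 100.0}, {"minimum": 10.0, "maximum": 30.0}]
print(normalize(bands, "min_max", stats, norm_by_channel=True))
# [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]
```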

Parameters Object

| Field Name | Type | Description |
|---|---|---|
| parameter names depend on the model | number \| string \| boolean \| array | The number of fields and their names depend on the model. Values should not be n-dimensional array inputs. If the model input can be represented as an n-dimensional array, it should instead be supplied as another Model Input Object. |

The parameters field can be specified either in the Model Input Object, if the parameters are associated with a specific input, or as an Item or Collection field, if they are supplied without relation to a specific model input.
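To illustrate the rule that parameter values should be numbers, strings, booleans, or arrays, but not n-dimensional arrays, here is a small sketch that flags offending entries. It assumes that a nested list (or an object inside a list) indicates an n-dimensional value. All parameter names are hypothetical.

```python
from typing import Any, Dict, List


def invalid_parameters(parameters: Dict[str, Any]) -> List[str]:
    """Return names of parameters whose values are not number, string, boolean, or flat array."""
    invalid = []
    for name, value in parameters.items():
        # bool is a subclass of int in Python, so this also accepts True/False.
        if isinstance(value, (int, float, str)):
            continue
        # A flat array (e.g. a tuple of scalars) is fine; nested arrays look n-dimensional.
        if isinstance(value, list) and not any(isinstance(item, (list, dict)) for item in value):
            continue
        invalid.append(name)
    return invalid


parameters = {
    "threshold": 0.5,
    "tile_size": [256, 256],
    "prompt": "buildings",
    "use_tta": False,
    "mask": [[0, 1], [1, 0]],  # should be a separate Model Input Object instead
}
print(invalid_parameters(parameters))  # ['mask']
```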

Bands and Statistics

We use the STAC 1.1 Bands Object to represent band information, including the nodata value, data type, and common band names. Only the bands used to train or fine-tune the model should be included in this bands field.

A deviation from the STAC 1.1 Bands Object is that we include the Statistics Object at the Model Input level rather than at the band object level. This is because, in machine learning, it is common to need only overall statistics for the training dataset to normalize all bands.
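A model input's `bands` list selects a subset of the Item's bands by name. Per-channel statistics are ordered like this list, so the selected bands should keep the model's order. The sketch below does this with plain dictionaries. The Sentinel-2-style band names are hypothetical, and `eo:common_name` is the STAC 1.1 field for common band names.

```python
from typing import Dict, List


def select_bands(item_bands: List[Dict], model_bands: List[str]) -> List[Dict]:
    """Return the Item's Band Objects named in the model input, in the model's order."""
    by_name = {band["name"]: band for band in item_bands}
    missing = [name for name in model_bands if name not in by_name]
    if missing:
        raise KeyError(f"bands not available in the Item: {missing}")
    return [by_name[name] for name in model_bands]


# Hypothetical Sentinel-2-style bands from a STAC 1.1 Item.
item_bands = [
    {"name": "B02", "eo:common_name": "blue"},
    {"name": "B03", "eo:common_name": "green"},
    {"name": "B04", "eo:common_name": "red"},
    {"name": "B08", "eo:common_name": "nir"},
]
rgb = select_bands(item_bands, ["B04", "B03", "B02"])
print([band["eo:common_name"] for band in rgb])  # ['red', 'green', 'blue']
```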

Array Object

| Field Name | Type | Description |
|---|---|---|
| shape | [integer] | REQUIRED. Shape of the input n-dimensional array (N × C × H × W), including the batch size dimension. The batch size dimension must be either greater than 0, or -1 to indicate an unspecified batch size. |
| dim_order | string | REQUIRED. How the dimensions in shape are ordered. "bhw", "bchw", "bthw", and "btchw" are valid orderings, where b=batch, c=channel, t=time, h=height, w=width. |
| dtype | string | REQUIRED. The data type of values in the n-dimensional array. NumPy numerical type names without the numpy module prefix are suggested, e.g. "float32". |

Note: It is common in the machine learning, computer vision, and remote sensing communities to refer to rasters that are inputs to a model as arrays or tensors. Array Objects are distinct from the JSON array type used to represent lists of values.
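The following sketch validates an Array Object against the rules above. It checks that `dim_order` is one of the listed orderings, that `shape` has one entry per dimension letter, and that the batch size is either positive or -1. It returns a mapping from each dimension letter to its size. The example values are hypothetical.

```python
from typing import Dict

VALID_DIM_ORDERS = {"bhw", "bchw", "bthw", "btchw"}


def describe_array(array: Dict) -> Dict[str, int]:
    """Validate an Array Object and map each dimension letter to its size."""
    shape, dim_order = array["shape"], array["dim_order"]
    if dim_order not in VALID_DIM_ORDERS:
        raise ValueError(f"unsupported dim_order {dim_order!r}")
    if len(shape) != len(dim_order):
        raise ValueError("shape and dim_order must have the same number of dimensions")
    batch = shape[dim_order.index("b")]
    if batch != -1 and batch <= 0:
        raise ValueError("batch size must be greater than 0, or -1 if unspecified")
    return dict(zip(dim_order, shape))


input_array = {"shape": [-1, 3, 224, 224], "dim_order": "bchw", "dtype": "float32"}
print(describe_array(input_array))  # {'b': -1, 'c': 3, 'h': 224, 'w': 224}
```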

Architecture Object

| Field Name | Type | Description |
|---|---|---|
| name | string | REQUIRED. The name of the model architecture, for example "ResNet-18" or "Random Forest". |
| file_size | integer | REQUIRED. The size on disk of the model artifact (bytes). |
| memory_size | integer | REQUIRED. The in-memory size of the model on the accelerator during inference (bytes). |
| summary | string | Summary of the layers; can be the output of print(model). |
| pretrained_source | string | Indicates the source of the pretraining (e.g. ImageNet). |
| total_parameters | integer | Total number of parameters. |
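Two of these fields can be computed directly. This sketch gets `total_parameters` by summing the element counts of the weight tensor shapes and reads `file_size` from the artifact on disk. In PyTorch, the same total is commonly computed as `sum(p.numel() for p in model.parameters())`. The shapes, architecture name, and stand-in artifact are all hypothetical. `memory_size` has to be measured on the accelerator during inference, so the sketch does not derive it.

```python
import math
import os
import tempfile
from typing import Iterable, Tuple


def count_parameters(parameter_shapes: Iterable[Tuple[int, ...]]) -> int:
    """Total number of scalar parameters across all weight tensors."""
    return sum(math.prod(shape) for shape in parameter_shapes)


# Hypothetical shapes: a 7x7 convolution (64 filters, 3 input channels) and a batch norm layer.
shapes = [(64, 3, 7, 7), (64,), (64,)]

# Stand-in for a serialized model artifact; a real one would be the file in model_asset.
with tempfile.NamedTemporaryFile(delete=False) as artifact:
    artifact.write(b"\0" * 1024)
    artifact_path = artifact.name

architecture = {
    "name": "ExampleNet",  # hypothetical architecture name
    "file_size": os.path.getsize(artifact_path),  # bytes on disk
    "total_parameters": count_parameters(shapes),
    # memory_size must be measured on the accelerator during inference; it is not derived here.
}
os.remove(artifact_path)
print(architecture)  # {'name': 'ExampleNet', 'file_size': 1024, 'total_parameters': 9536}
```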

Runtime Object

| Field Name | Type | Description |
|---|---|---|
| framework | string | REQUIRED. Framework used to train the model (e.g. PyTorch, TensorFlow). |
| version | string | REQUIRED. Framework version (some models require a specific version of the framework to run). |
| model_asset | Asset Object | REQUIRED. Asset object containing the URI of the model file. |
| source_code | Asset Object | REQUIRED. Source code description. It can describe a GitHub repository, a zip archive, etc. This description should reference the inference function, for example my_package.my_module.predict. |
| accelerator | Accelerator Enum | REQUIRED. The intended computational hardware that runs inference. |
| accelerator_constrained | boolean | REQUIRED. True if the intended accelerator is the only accelerator that can run inference. False if other accelerators, such as amd64 (CPU), can run inference. |
| hardware_summary | string | REQUIRED. A high-level description of the number of accelerators, the specific generation of the accelerator, or other relevant inference details. |
| container | Container Object | RECOMMENDED. Information to run the model in a container instance. |
| model_commit_hash | string | Hash value pointing to a specific version of the code. |
| batch_size_suggestion | number | A suggested batch size for the accelerator and summarized hardware. |

Accelerator Enum

It is recommended to define accelerator with one of the following values:

  • amd64 for models compatible with AMD or Intel CPUs (no hardware-specific optimizations)
  • cuda for models compatible with NVIDIA GPUs
  • xla for models compiled with XLA. Models trained on TPUs are typically compiled with XLA.
  • amd-rocm for models trained on AMD GPUs
  • intel-ipex-cpu for models optimized with IPEX for Intel CPUs
  • intel-ipex-gpu for models optimized with IPEX for Intel GPUs
  • macos-arm for models trained on Apple Silicon
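The sketch below checks a Runtime Object. A missing REQUIRED field counts as an error. An accelerator outside the recommended list above, or a missing RECOMMENDED container, produces only a warning, because the spec recommends those values rather than mandating them. The URLs and hardware summary are hypothetical.

```python
from typing import Dict, List, Tuple

REQUIRED_RUNTIME_FIELDS = (
    "framework", "version", "model_asset", "source_code",
    "accelerator", "accelerator_constrained", "hardware_summary",
)
RECOMMENDED_ACCELERATORS = {
    "amd64", "cuda", "xla", "amd-rocm", "intel-ipex-cpu", "intel-ipex-gpu", "macos-arm",
}


def check_runtime(runtime: Dict) -> Tuple[List[str], List[str]]:
    """Return (errors, warnings) for a Runtime Object."""
    errors = [f"missing required field: {name}" for name in REQUIRED_RUNTIME_FIELDS if name not in runtime]
    warnings = []
    accelerator = runtime.get("accelerator")
    if accelerator is not None and accelerator not in RECOMMENDED_ACCELERATORS:
        # Only a warning: the values above are recommended, not mandatory.
        warnings.append(f"accelerator {accelerator!r} is not a recommended value")
    if "container" not in runtime:
        warnings.append("container is RECOMMENDED but missing")
    return errors, warnings


runtime = {
    "framework": "PyTorch",
    "version": "2.1.2",
    "model_asset": {"href": "https://example.com/models/model.pt"},
    "source_code": {"href": "https://github.com/example/model-repo"},
    "accelerator": "cuda",
    "accelerator_constrained": False,
    "hardware_summary": "Single NVIDIA GPU with at least 8 GB of memory",
}
print(check_runtime(runtime))  # ([], ['container is RECOMMENDED but missing'])
```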

Container Object

| Field Name | Type | Description |
|---|---|---|
| container_file | string | URL of the container file (Dockerfile). |
| image_name | string | Name of the container image. |
| tag | string | Tag of the image. |
| working_dir | string | Working directory in the instance that can be mapped. |
| run | string | Command to run. |

If you're unsure how to containerize your model, we suggest starting from the latest official container image for your framework that works with your model and pinning the container version.

Examples: PyTorch Docker Hub, PyTorch Docker run example, TensorFlow Docker Hub, and TensorFlow Docker run example.

Using a framework's base image looks like this:

```dockerfile
# In your Dockerfile, pull a pinned framework base image with all dependencies, including accelerator drivers
FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime

### Your specific environment setup to run your model
RUN pip install my_package
```

You can also use other base images. PyTorch and TensorFlow also offer Docker images for serving models for inference.
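The Container Object fields can be turned into a `docker run` invocation. This sketch assumes that "can be mapped" means bind-mounting a host directory onto `working_dir` (`-v`) and making it the working directory (`-w`). It also assumes that `run` is a shell-style command appended after the image reference. The image name, tag, and `predict.py` script are hypothetical.

```python
import shlex
from typing import Dict, List


def docker_run_command(container: Dict[str, str], host_dir: str, gpus: bool = False) -> List[str]:
    """Build a `docker run` argument list from a Container Object."""
    command = ["docker", "run", "--rm"]
    if gpus:
        command += ["--gpus", "all"]  # requires the NVIDIA Container Toolkit on the host
    if "working_dir" in container:
        # Assumption: "can be mapped" means bind-mounting a host directory onto working_dir.
        command += ["-v", f"{host_dir}:{container['working_dir']}", "-w", container["working_dir"]]
    command.append(f"{container['image_name']}:{container['tag']}")
    command += shlex.split(container.get("run", ""))
    return command


container = {
    "image_name": "example/landcover-model",
    "tag": "1.0.0",
    "working_dir": "/workspace",
    "run": "python predict.py --input /workspace/tile.tif",
}
print(shlex.join(docker_run_command(container, host_dir="/data")))
# docker run --rm -v /data:/workspace -w /workspace example/landcover-model:1.0.0 python predict.py --input /workspace/tile.tif
```

Returning an argument list instead of a single string lets callers pass it straight to `subprocess.run` without going through a shell.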

Model Output Object

| Field Name | Type | Description |
|---|---|---|
| task | Task Enum | REQUIRED. Specifies the machine learning task for which the output can be used. |
| number_of_classes | integer | Number of classes. |
| result | [Result Array Object] | The list of output arrays/tensors from the model, for example (N × H × W). Use -1 to indicate variable dimensions, like the batch dimension. |
| class_name_mapping | Class Map Object | Mapping of each class name to the index representing that label in the model output. |
| post_processing_function | string | A URL to the postprocessing function where normalization, rescaling, and other operations take place. Alternatively, the function code path, for example: my_package.my_module.my_processing_function |

While only task is required, all fields are recommended for supervised tasks that produce a fixed-shape tensor and have output classes. Image captioning, multi-modal, and generative tasks may not return fixed-shape tensors or classes.
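For a supervised classification output, `class_name_mapping` is what turns raw model output into labels. The sketch assumes each result row holds one score per class, with the class index as the last dimension (N × number_of_classes), and that a higher score means a more likely class. It takes the argmax of each row and looks the index up in the inverted mapping. The class names and scores are hypothetical.

```python
from typing import Dict, List, Sequence


def decode_classification(scores: Sequence[Sequence[float]], class_name_mapping: Dict[str, int]) -> List[str]:
    """Map each row of class scores (shape N x number_of_classes) to its top class name."""
    index_to_name = {index: name for name, index in class_name_mapping.items()}
    labels = []
    for row in scores:
        best = max(range(len(row)), key=lambda i: row[i])  # argmax over the class dimension
        labels.append(index_to_name[best])
    return labels


class_name_mapping = {"water": 0, "forest": 1, "urban": 2}
scores = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
print(decode_classification(scores, class_name_mapping))  # ['forest', 'water']
```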

Task Enum

It is recommended to define task with one of the following values:

  • regression
  • classification
  • object detection
  • semantic segmentation
  • instance segmentation
  • panoptic segmentation
  • multi-modal
  • similarity search
  • image captioning
  • generative

If the task falls within supervised machine learning and uses labels during training, it should align with the label:tasks values defined in the STAC Label Extension for the relevant STAC Collections and Items used with the model described by this extension.

Result Array Object

| Field Name | Type | Description |
|---|---|---|
| shape | [integer] | REQUIRED. Shape of the n-dimensional result array (N × H × W), possibly including a batch size dimension. The batch size dimension must be either greater than 0, or -1 to indicate an unspecified batch size. |
| dim_names | [string] | REQUIRED. The names of the dimensions of the result array, in the same order as this object's shape field. |
| dtype | string | REQUIRED. The data type of values in the array. NumPy numerical type names without the numpy module prefix are suggested, e.g. "float32". |
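When running inference, you can check the shape of an actual model output against its Result Array Object. This sketch treats -1 as "any size" and requires `shape` and `dim_names` to have the same length. The dimension names and dtype in the example are hypothetical, because the spec does not fix particular names.

```python
from typing import Dict, Sequence


def matches_result_array(result: Dict, actual_shape: Sequence[int]) -> bool:
    """True if an actual output shape fits a Result Array Object, treating -1 as any size."""
    expected = result["shape"]
    if len(expected) != len(result["dim_names"]):
        raise ValueError("shape and dim_names must have the same length")
    if len(expected) != len(actual_shape):
        return False
    return all(e == -1 or e == a for e, a in zip(expected, actual_shape))


result = {"shape": [-1, 256, 256], "dim_names": ["batch", "height", "width"], "dtype": "uint8"}
print(matches_result_array(result, (8, 256, 256)))  # True
print(matches_result_array(result, (8, 128, 128)))  # False
```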

Class Map Object

| Field Name | Type | Description |
|---|---|---|
| class names depend on the model | integer | The index of the class in the model output. A model with N classes has N fields, each mapping a class name to its integer value. |

Users can supply any number of class fields, one for each class of their model, if the model produces a supervised classification result.
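A quick consistency check for a Class Map Object: every value should be an integer, no two classes should share an index, and the number of fields should match `number_of_classes` from the Model Output Object when that is given. The sketch treats distinct indices as implied, since each index represents one label. It does not require the indices to be contiguous, because the text above does not say so. The class names are hypothetical.

```python
from typing import Dict, Optional


def check_class_map(class_name_mapping: Dict[str, int], number_of_classes: Optional[int] = None) -> None:
    """Raise if a Class Map Object has non-integer or duplicate indices, or the wrong count."""
    indices = list(class_name_mapping.values())
    if not all(isinstance(i, int) and not isinstance(i, bool) for i in indices):
        raise TypeError("class indices must be integers")
    if len(set(indices)) != len(indices):
        raise ValueError("two class names share the same index")
    if number_of_classes is not None and len(indices) != number_of_classes:
        raise ValueError(f"expected {number_of_classes} classes, found {len(indices)}")


check_class_map({"water": 0, "forest": 1, "urban": 2}, number_of_classes=3)  # passes silently
```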

Relation types

The following types should be used as applicable rel types in the Link Object of STAC Items describing Band Assets used with a model.

| Type | Description |
|---|---|
| derived_from | This link points to _item.json or _collection.json. Replace with the unique mlm:name field's value. |

Contributing

All contributions are subject to the STAC Specification Code of Conduct. For contributions, please follow the STAC specification contributing guide. Instructions for running tests are copied here for convenience.

Running tests

The same checks that run on PRs are part of the repository and can be run locally to verify that changes are valid. To run tests locally, you'll need npm, which is a standard part of any Node.js installation.

First, install everything with npm once. Navigate to the root of this repository and on your command line run:

```bash
npm install
```

Then to check Markdown formatting and test the examples against the JSON schema, you can run:

```bash
npm test
```

This prints the same output you see in the online checks, so you can then fix your Markdown or examples.

If the tests reveal formatting problems with the examples, you can fix them with:

```bash
npm run format-examples
```
