Merge pull request #88 from Open-EO/27-minimal-inference
27 minimal inference
GriffinBabe authored Apr 15, 2024
2 parents 11d9c39 + eaf3d8a commit faf53d2
Showing 10 changed files with 451 additions and 43 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"geojson>=3.0.0",
"h5netcdf>=1.2.0",
"openeo[localprocessing]",
"onnxruntime>=1.16.3",
"cftime",
"pyarrow",
"fastparquet",
8 changes: 4 additions & 4 deletions src/openeo_gfmap/features/feature_extractor.py
@@ -229,7 +229,7 @@ def _get_apply_udf_data(feature_extractor: FeatureExtractor) -> str:
return source.replace('"<feature_extractor_class>"', feature_extractor.__name__)


def generate_udf_code(feature_extractor_class: FeatureExtractor) -> openeo.UDF:
def _generate_udf_code(feature_extractor_class: FeatureExtractor) -> openeo.UDF:
"""Generates the udf code by packing imports of this file, the necessary
superclass and subclasses as well as the user defined feature extractor
class and the apply_datacube function.
@@ -271,7 +271,7 @@ def apply_feature_extractor(
feature_extractor._parameters = parameters
output_labels = feature_extractor.output_labels()

udf_code = generate_udf_code(feature_extractor_class)
udf_code = _generate_udf_code(feature_extractor_class)

udf = openeo.UDF(code=udf_code, context=parameters)

@@ -282,7 +282,7 @@
def apply_feature_extractor_local(
feature_extractor_class: FeatureExtractor, cube: xr.DataArray, parameters: dict
) -> xr.DataArray:
"""Applies and user-define feature extractor, but locally. The
"""Applies and user-defined feature extractor, but locally. The
parameters are the same as in the `apply_feature_extractor` function,
excepts for the cube parameter which expects a `xarray.DataArray` instead of
a `openeo.rest.datacube.DataCube` object.
@@ -291,7 +291,7 @@ def apply_feature_extractor_local(
feature_extractor._parameters = parameters
output_labels = feature_extractor.output_labels()

udf_code = generate_udf_code(feature_extractor_class)
udf_code = _generate_udf_code(feature_extractor_class)

udf = openeo.UDF(code=udf_code, context=parameters)

4 changes: 4 additions & 0 deletions src/openeo_gfmap/fetching/commons.py
@@ -127,6 +127,10 @@ def load_collection(
properties=load_collection_parameters,
)

# Adding the process graph updates for experimental features
if params.get("update_arguments") is not None:
cube.result_node().update_arguments(**params["update_arguments"])

# Performing pre-mask optimization
pre_mask = params.get("pre_mask", None)
if pre_mask is not None:
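For context, a hedged sketch of this hook (the `featureflags` argument is a hypothetical backend-specific name, not part of this PR):

# Minimal usage sketch, names assumed: forward an experimental argument
# to the load_collection node through the fetching parameters.
params = {"update_arguments": {"featureflags": {"tilesize": 128}}}
# Inside load_collection(), this is applied to the process graph as:
# cube.result_node().update_arguments(featureflags={"tilesize": 128})
# which mutates the underlying load_collection node in place, so the extra
# arguments end up in the serialized process graph.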
36 changes: 0 additions & 36 deletions src/openeo_gfmap/inference/inference_models.py

This file was deleted.

262 changes: 262 additions & 0 deletions src/openeo_gfmap/inference/model_inference.py
@@ -0,0 +1,262 @@
"""Inference functionalities. Such as a base class to assist the implementation
of inference models on an UDF.
"""

import functools
import inspect
import re
import sys
from abc import ABC, abstractmethod

import numpy as np
import openeo
import requests
import xarray as xr
from openeo.udf import XarrayDataCube
from openeo.udf import inspect as udf_inspect
from openeo.udf.run_code import execute_local_udf
from openeo.udf.udf_data import UdfData

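# The "onnx_deps" folder is assumed to be made available next to the UDF by the
# backend (for example as an unpacked dependency archive); prepending it to the
# import path makes onnxruntime importable on the workers.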
sys.path.insert(0, "onnx_deps")
import onnxruntime as ort # noqa: E402

EPSG_HARMONIZED_NAME = "GEO-EPSG"


class ModelInference(ABC):
"""Base class for all model inference UDFs. It provides some common
methods and attributes to be used by other model inference classes.
"""

@functools.lru_cache(maxsize=6)
def load_ort_session(self, model_url: str):
"""Loads an onnx session from a publicly available URL. The URL must be a direct
download link to the ONNX session file.
The `lru_cache` decorator avoids loading multiple time the model within the same worker.
"""
# Two minutes timeout to download the model
response = requests.get(model_url, timeout=120)
model = response.content

return ort.InferenceSession(model)

def apply_ml(
self, tensor: np.ndarray, session: ort.InferenceSession, input_name: str
) -> np.ndarray:
"""Applies the machine learning model to the input data as a tensor.
Parameters
----------
tensor: np.ndarray
The input data with shape (instance, bands). If the input data is a tile (bands, y, x),
the y and x dimensions must be flattened into a single instance dimension first.
session: ort.InferenceSession
The ONNX Session object, loaded from the `load_ort_session` class method.
input_name: str
The name of the input tensor in the ONNX session. This depends on how the serialized
ONNX model was generated. For example, CatBoost models have their input tensor named
`features`: https://catboost.ai/en/docs/concepts/apply-onnx-ml
"""
return session.run(None, {input_name: tensor})[0]
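For reference, a hedged sketch of how a model with a named input tensor could be produced in the first place (assuming a fitted scikit-learn estimator and the skl2onnx package, neither of which this PR depends on):

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Export the model, naming its input tensor "features"; that name is what
# would later be passed as `input_name` to apply_ml.
onnx_model = convert_sklearn(
    fitted_model,  # assumption: any fitted scikit-learn estimator
    initial_types=[("features", FloatTensorType([None, n_features]))],
)
with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())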

def _common_preparations(self, inarr: xr.DataArray, parameters: dict) -> xr.DataArray:
"""Common preparations for all inference models. This method will be
executed at the very beginning of the process.
"""
self._epsg = parameters.pop(EPSG_HARMONIZED_NAME)
self._parameters = parameters
return inarr

def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
arr = cube.get_array().transpose("bands", "y", "x")
arr = self._common_preparations(arr, parameters)
arr = self.execute(arr).transpose("bands", "y", "x")
return XarrayDataCube(arr)

@property
def epsg(self) -> int:
"""EPSG code of the input data."""
return self._epsg

@abstractmethod
def output_labels(self) -> list:
"""Returns the labels of the output data."""
raise NotImplementedError(
"ModelInference is a base abstract class, please implement the "
"output_labels property."
)

@abstractmethod
def execute(self, inarr: xr.DataArray) -> xr.DataArray:
"""Executes the model inference."""
raise NotImplementedError(
"ModelInference is a base abstract class, please implement the " "execute method."
)


class ONNXModelInference(ModelInference):
"""Basic implementation of model inference that loads an ONNX model and runs the data
through it. The input data, as for all model inference classes, is expected to have the
('bands', 'y', 'x') dimension order, where 'bands' holds the features, computed the same
way as for the training data.
The following parameters are necessary:
- `model_url`: URL to download the ONNX model.
- `input_name`: Name of the input tensor in the ONNX model.
- `output_labels`: Labels of the output data.
"""

def output_labels(self) -> list:
return self._parameters["output_labels"]

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
if self._parameters.get("model_url") is None:
raise ValueError("The model_url must be defined in the parameters.")

# Load the model and the input_name parameters
session = self.load_ort_session(self._parameters.get("model_url"))

input_name = self._parameters.get("input_name")
if input_name is None:
input_name = session.get_inputs()[0].name
udf_inspect(
message=f"Input name not defined. Using name of parameters from the model session: {input_name}.",
level="warning",
)

# Run the model inference on the input data
input_data = inarr.values.astype(np.float32)
n_bands, height, width = input_data.shape

# Flatten the x and y coordinates into a single dimension
input_data = input_data.reshape(n_bands, -1).T

# Make the prediction
output = self.apply_ml(input_data, session, input_name)

output = output.reshape(len(self.output_labels()), height, width)

return xr.DataArray(
output,
dims=["bands", "y", "x"],
coords={"bands": self.output_labels(), "x": inarr.x, "y": inarr.y},
)
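To illustrate the contract, a toy subclass (hypothetical, not part of this PR) only needs `output_labels` and `execute`:

class ThresholdInference(ModelInference):
    """Toy example: thresholds the first band instead of running a model."""

    def output_labels(self) -> list:
        return ["above_threshold"]

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        # "threshold" is an assumed user parameter, defaulting to 0.5.
        mask = inarr.isel(bands=0) > self._parameters.get("threshold", 0.5)
        out = mask.astype("float32").expand_dims(dim="bands", axis=0)
        return out.assign_coords(bands=self.output_labels())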


def apply_udf_data(udf_data: UdfData) -> UdfData:
model_inference_class = "<model_inference_class>"

# The user-defined model inference class is instantiated here
model_inference = model_inference_class()

cube = udf_data.datacube_list[0]
parameters = udf_data.user_context

proj = udf_data.proj
if proj is not None:
proj = proj.get("EPSG")

parameters[EPSG_HARMONIZED_NAME] = proj

cube = model_inference._execute(cube, parameters=parameters)

udf_data.datacube_list = [cube]

return udf_data


def _get_imports() -> str:
with open(__file__, "r", encoding="UTF-8") as f:
script_source = f.read()

lines = script_source.split("\n")

imports = []
static_globals = []

for line in lines:
if line.strip().startswith(("import ", "from ", "sys.path.insert(", "sys.path.append(")):
imports.append(line)
elif re.match(r"^[A-Z_0-9]+\s*=.*$", line):
static_globals.append(line)

return "\n".join(imports) + "\n\n" + "\n".join(static_globals)


def _get_apply_udf_data(model_inference: ModelInference) -> str:
source_lines = inspect.getsource(apply_udf_data)
source = "".join(source_lines)
# Replace the `model_inference_class` placeholder in the source function
return source.replace('"<model_inference_class>"', model_inference.__name__)


def _generate_udf_code(model_inference_class: ModelInference) -> str:
"""Generates the UDF code by packing the imports of this file, the necessary
superclass and subclasses, as well as the user-defined model inference
class and the apply_udf_data function.
"""

# UDF code that will be built here
udf_code = ""

assert issubclass(
model_inference_class, ModelInference
), "The model inference class must be a subclass of ModelInference."

udf_code += _get_imports() + "\n\n"
udf_code += f"{inspect.getsource(ModelInference)}\n\n"
udf_code += f"{inspect.getsource(model_inference_class)}\n\n"
udf_code += _get_apply_udf_data(model_inference_class)
return udf_code


def apply_model_inference(
model_inference_class: ModelInference,
cube: openeo.rest.datacube.DataCube,
parameters: dict,
size: list,
overlap: list = [],
) -> openeo.rest.datacube.DataCube:
"""Applies an user-defined model inference on the cube by using the
`openeo.Cube.apply_neighborhood` method. The defined class as well as the
required subclasses will be packed into a generated UDF file that will be
executed.
"""
model_inference = model_inference_class()
model_inference._parameters = parameters
output_labels = model_inference.output_labels()

udf_code = _generate_udf_code(model_inference_class)

udf = openeo.UDF(code=udf_code, context=parameters)

cube = cube.apply_neighborhood(process=udf, size=size, overlap=overlap)
return cube.rename_labels(dimension="bands", target=output_labels)
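As a usage sketch (the model URL and labels are placeholders), running the packaged ONNX inference over 256x256-pixel neighbourhoods could look like:

predictions = apply_model_inference(
    model_inference_class=ONNXModelInference,
    cube=features,  # assumption: an openeo DataCube of computed features
    parameters={
        "model_url": "https://example.com/model.onnx",  # placeholder URL
        "input_name": "features",
        "output_labels": ["probability"],
    },
    size=[
        {"dimension": "x", "unit": "px", "value": 256},
        {"dimension": "y", "unit": "px", "value": 256},
    ],
)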


def apply_model_inference_local(
model_inference_class: ModelInference, cube: xr.DataArray, parameters: dict
) -> xr.DataArray:
"""Applies and user-defined model inference, but locally. The
parameters are the same as in the `apply_model_inference` function,
excepts for the cube parameter which expects a `xarray.DataArray` instead of
a `openeo.rest.datacube.DataCube` object.
"""
model_inference = model_inference_class()
model_inference._parameters = parameters
output_labels = model_inference.output_labels()

udf_code = _generate_udf_code(model_inference_class)

udf = openeo.UDF(code=udf_code, context=parameters)

cube = XarrayDataCube(cube)

out_udf_data: UdfData = execute_local_udf(udf, cube, fmt="NetCDF")

output_cubes = out_udf_data.datacube_list

assert len(output_cubes) == 1, "UDF should have only a single output cube."

return output_cubes[0].get_array().assign_coords({"bands": output_labels})
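The local variant takes an in-memory array instead; a minimal sketch under the same placeholder parameters:

# `arr` is assumed to be an xr.DataArray with dims ("bands", "y", "x").
result = apply_model_inference_local(
    ONNXModelInference,
    arr,
    parameters={
        "model_url": "https://example.com/model.onnx",  # placeholder URL
        "output_labels": ["probability"],
    },
)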
4 changes: 3 additions & 1 deletion src/openeo_gfmap/preprocessing/cloudmasking.py
@@ -183,7 +183,9 @@ def max_score_selection(score):
)
elif isinstance(period, list):
udf_path = Path(__file__).parent / "udf_rank.py"
rank_mask = bap_score.apply_neighborhood(
rank_mask = bap_score.add_dimension(
name="bands", label=BAPSCORE_HARMONIZED_NAME
).apply_neighborhood(
process=openeo.UDF.from_file(str(udf_path), context={"intervals": period}),
size=[
{"dimension": "x", "unit": "px", "value": 256},
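For context, the pattern this change relies on (label name assumed): a band-reduced cube loses its bands dimension, so one is re-attached before the neighborhood UDF runs, giving the UDF the expected ('bands', 'y', 'x') layout.

# Minimal sketch: re-attach a single-label bands dimension to a reduced cube.
scored = scored.add_dimension(name="bands", label="score", type="bands")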
13 changes: 13 additions & 0 deletions src/openeo_gfmap/preprocessing/sar.py
@@ -0,0 +1,13 @@
"""Routines to pre-process sar signals."""
import openeo


def compress_backscatter_uint16():
pass


def multitemporal_speckle(cube: openeo.DataCube) -> openeo.DataCube:
_ = cube.filter_bands(
bands=filter(lambda band: band.startswith("S1"), cube.metadata.band_names)
)
pass
3 changes: 1 addition & 2 deletions tests/test_openeo_gfmap/test_cloud_masking.py
@@ -122,13 +122,12 @@ def test_bap_quintad(backend: Backend):
# Fetch the datacube
s2_extractor = build_sentinel2_l2a_extractor(
backend_context=backend_context,
bands=["S2-L2A-SCL"],
bands=["S2-L2A-B04", "S2-L2A-SCL"],
fetch_type=FetchType.TILE,
**fetching_parameters,
)

cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent)

compositing_intervals = quintad_intervals(temporal_extent)

expected_intervals = [