Skip to content

Commit

Permalink
Fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
GriffinBabe committed Apr 12, 2024
1 parent 2b4193a commit c690c17
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 39 deletions.
17 changes: 7 additions & 10 deletions src/openeo_gfmap/inference/model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:

# Load the model and the input_name parameters
session = self.load_ort_session(self._parameters.get("model_url"))

input_name = self._parameters.get("input_name")
if input_name is None:
input_name = session.get_inputs()[0].name
udf_inspect(message=f"Input name not defined. Using name of parameters from the model session: {input_name}.", level="warning")
udf_inspect(
message=f"Input name not defined. Using name of parameters from the model session: {input_name}.",
level="warning",
)

# Run the model inference on the input data
input_data = inarr.values.astype(np.float32)
Expand All @@ -132,18 +135,12 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
# Make the prediction
output = self.apply_ml(input_data, session, input_name)

output = output.reshape(
len(self.output_labels()), height, width
)
output = output.reshape(len(self.output_labels()), height, width)

return xr.DataArray(
output,
dims=["bands", "y", "x"],
coords={
"bands": self.output_labels(),
"x": inarr.x,
"y": inarr.y
}
coords={"bands": self.output_labels(), "x": inarr.x, "y": inarr.y},
)


Expand Down
3 changes: 1 addition & 2 deletions src/openeo_gfmap/preprocessing/cloudmasking.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ def max_score_selection(score):
elif isinstance(period, list):
udf_path = Path(__file__).parent / "udf_rank.py"
rank_mask = bap_score.add_dimension(
name="bands",
label=BAPSCORE_HARMONIZED_NAME
name="bands", label=BAPSCORE_HARMONIZED_NAME
).apply_neighborhood(
process=openeo.UDF.from_file(str(udf_path), context={"intervals": period}),
size=[
Expand Down
13 changes: 13 additions & 0 deletions src/openeo_gfmap/preprocessing/sar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Routines to pre-process sar signals."""
import openeo


def compress_backscatter_uint16():
pass


def multitemporal_speckle(cube: openeo.DataCube) -> openeo.DataCube:
_ = cube.filter_bands(
bands=filter(lambda band: band.startswith("S1"), cube.metadata.band_names)
)
pass
3 changes: 1 addition & 2 deletions tests/test_openeo_gfmap/test_cloud_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,12 @@ def test_bap_quintad(backend: Backend):
# Fetch the datacube
s2_extractor = build_sentinel2_l2a_extractor(
backend_context=backend_context,
bands=["S2-L2A-SCL"],
bands=["S2-L2A-B04", "S2-L2A-SCL"],
fetch_type=FetchType.TILE,
**fetching_parameters,
)

cube = s2_extractor.get_cube(connection, spatial_extent, temporal_extent)

compositing_intervals = quintad_intervals(temporal_extent)

expected_intervals = [
Expand Down
52 changes: 27 additions & 25 deletions tests/test_openeo_gfmap/test_model_inference.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
"""Test on model inference implementations, both local and remote."""
from pathlib import Path

import xarray as xr
import numpy as np

import xarray as xr
from openeo.udf import XarrayDataCube

from openeo_gfmap import BoundingBoxExtent, TemporalContext, FetchType, Backend, BackendContext
from openeo_gfmap.fetching.s2 import build_sentinel2_l2a_extractor
from openeo_gfmap import (
Backend,
BackendContext,
BoundingBoxExtent,
FetchType,
TemporalContext,
)
from openeo_gfmap.backend import cdse_connection
from openeo_gfmap.fetching.s2 import build_sentinel2_l2a_extractor
from openeo_gfmap.inference.model_inference import (
ONNXModelInference,
apply_model_inference,
)
from openeo_gfmap.preprocessing.cloudmasking import mask_scl_dilation
from openeo_gfmap.preprocessing.compositing import median_compositing
from openeo_gfmap.inference.model_inference import ONNXModelInference, apply_model_inference, apply_model_inference_local


spatial_context = BoundingBoxExtent(
west=5.0,
south=51.2,
east=5.025,
north=51.225,
epsg=4326
)
spatial_context = BoundingBoxExtent(west=5.0, south=51.2, east=5.025, north=51.225, epsg=4326)

temporal_extent = TemporalContext(start_date="2018-05-01", end_date="2018-10-31")

onnx_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/knn_model_rgbnir.onnx"
onnx_model_url = (
"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/knn_model_rgbnir.onnx"
)
dependency_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


def test_onnx_inference_local():
"""Test the ONNX Model inference locally"""
inds = xr.open_dataarray(Path(__file__).parent / "resources/test_inference_feats.nc")
Expand All @@ -41,8 +45,8 @@ def test_onnx_inference_local():
"model_url": onnx_model_url,
"input_name": "X",
"output_labels": ["label"],
"GEO-EPSG": "whatever"
}
"GEO-EPSG": "whatever",
},
)

output = output.get_array()
Expand Down Expand Up @@ -80,9 +84,9 @@ def test_onnx_inference():
cube = median_compositing(cube, period="year")

# We remove the SCL mask
cube = cube.filter_bands(bands=['S2-L2A-B04', 'S2-L2A-B03', 'S2-L2A-B02', 'S2-L2A-B08'])
cube = cube.filter_bands(bands=["S2-L2A-B04", "S2-L2A-B03", "S2-L2A-B02", "S2-L2A-B08"])

cube = cube.ndvi(nir='S2-L2A-B08', red='S2-L2A-B04', target_band='S2-L2A-NDVI')
cube = cube.ndvi(nir="S2-L2A-B08", red="S2-L2A-B04", target_band="S2-L2A-NDVI")

# Perform model inference
cube = apply_model_inference(
Expand All @@ -96,7 +100,7 @@ def test_onnx_inference():
size=[
{"dimension": "x", "unit": "px", "value": 128},
{"dimension": "y", "unit": "px", "value": 128},
{"dimension": "t", "value": 1}
{"dimension": "t", "value": 1},
],
)

Expand All @@ -107,21 +111,19 @@ def test_onnx_inference():
title="test_onnx_inference",
out_format="GTiff",
job_options={
"udf-dependency-archives": [
f"{dependency_url}#onnx_deps"
],
}
"udf-dependency-archives": [f"{dependency_url}#onnx_deps"],
},
)
job.start_and_wait()

for asset in job.get_results().get_assets():
if asset.metadata["type"].startswith("application/x-netcdf"):
asset.download(output_path)
break

assert output_path.exists()

inds = xr.open_dataset(output_path).to_array(dim='bands')
inds = xr.open_dataset(output_path).to_array(dim="bands")

assert inds.shape == (1, 256, 256)
assert len(np.unique(inds.values)) == 3

0 comments on commit c690c17

Please sign in to comment.