Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GriffinBabe committed Apr 15, 2024
1 parent c690c17 commit 276d0ba
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/test_openeo_gfmap/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

import numpy as np
import xarray as xr
import rasterio
from openeo.udf import XarrayDataCube

from openeo_gfmap import (
Expand All @@ -21,10 +21,16 @@
from openeo_gfmap.preprocessing.cloudmasking import mask_scl_dilation
from openeo_gfmap.preprocessing.compositing import median_compositing

from .utils import load_dataarray_url

spatial_context = BoundingBoxExtent(west=5.0, south=51.2, east=5.025, north=51.225, epsg=4326)

temporal_extent = TemporalContext(start_date="2018-05-01", end_date="2018-10-31")

resources_file = (
"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/test_inference_feats.nc"
)

onnx_model_url = (
"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/knn_model_rgbnir.onnx"
)
Expand All @@ -33,7 +39,7 @@

def test_onnx_inference_local():
"""Test the ONNX Model inference locally"""
inds = xr.open_dataarray(Path(__file__).parent / "resources/test_inference_feats.nc")
inds = load_dataarray_url(resources_file)

inference = ONNXModelInference()

Expand Down Expand Up @@ -100,7 +106,7 @@ def test_onnx_inference():
size=[
{"dimension": "x", "unit": "px", "value": 128},
{"dimension": "y", "unit": "px", "value": 128},
{"dimension": "t", "value": 1},
{"dimension": "t", "value": "P1D"},
],
)

Expand All @@ -117,13 +123,11 @@ def test_onnx_inference():
job.start_and_wait()

for asset in job.get_results().get_assets():
if asset.metadata["type"].startswith("application/x-netcdf"):
if asset.metadata["type"].startswith("image/tiff"):
asset.download(output_path)
break

assert output_path.exists()

inds = xr.open_dataset(output_path).to_array(dim="bands")

assert inds.shape == (1, 256, 256)
assert len(np.unique(inds.values)) == 3
inds = rasterio.open(output_path, "r")
assert len(np.unique(inds.read(1))) == 3
30 changes: 30 additions & 0 deletions tests/test_openeo_gfmap/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Utilitiaries used in tests, such as download test resources."""

from tempfile import NamedTemporaryFile

import requests
import xarray as xr


def load_dataset_url(url: str) -> NamedTemporaryFile:
"""Download a NetCDF file from the internet and return a Xarray Dataset."""
with NamedTemporaryFile(suffix=".nc", delete=True) as tmpfile:
response = requests.get(url)
response.raise_for_status()
tmpfile.write(response.content)

inds = xr.open_dataset(tmpfile.name)

return inds


def load_dataarray_url(url: str) -> NamedTemporaryFile:
"""Download a NetCDF file from the internet and return a Xarray Dataset."""
with NamedTemporaryFile(suffix=".nc", delete=True) as tmpfile:
response = requests.get(url)
response.raise_for_status()
tmpfile.write(response.content)

inds = xr.open_dataarray(tmpfile.name)

return inds

0 comments on commit 276d0ba

Please sign in to comment.