Skip to content

Commit a227388

Browse files
committed
add numpy_load helper
1 parent e3b253b commit a227388

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

bioimageio/core/image_helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
from numpy.typing import NDArray
1010
from xarray import DataArray
1111

12-
from bioimageio.spec.model.v0_4 import InputTensor as InputTensor04
13-
from bioimageio.spec.model.v0_4 import OutputTensor as OutputTensor04
14-
from bioimageio.spec.model.v0_5 import InputTensor as InputTensor05
15-
from bioimageio.spec.model.v0_5 import OutputTensor as OutputTensor05
12+
from bioimageio.spec._internal.io_utils import load_array
13+
from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensor04
14+
from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensor04
15+
from bioimageio.spec.model.v0_5 import InputTensorDescr as InputTensor05
16+
from bioimageio.spec.model.v0_5 import OutputTensorDescr as OutputTensor05
1617

1718
InputTensor = Union[InputTensor04, InputTensor05]
1819
OutputTensor = Union[OutputTensor04, OutputTensor05]
@@ -103,7 +104,7 @@ def to_channel_last(image):
103104
def load_image(in_path, axes: Sequence[str]) -> DataArray:
104105
ext = os.path.splitext(in_path)[1]
105106
if ext == ".npy":
106-
im = np.load(in_path)
107+
im = load_array(in_path)
107108
else:
108109
is_volume = "z" in axes
109110
im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)

bioimageio/core/resource_tests.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy
1010
import numpy as np
1111
import xarray as xr
12-
from marshmallow import ValidationError
1312

1413
from bioimageio.core import __version__ as bioimageio_core_version
1514
from bioimageio.core import load_raw_resource_description, load_resource_description
@@ -25,6 +24,7 @@
2524
ResourceDescription,
2625
)
2726
from bioimageio.spec import __version__ as bioimageio_spec_version
27+
from bioimageio.spec._internal.io_utils import load_array
2828
from bioimageio.spec.model.raw_nodes import WeightsFormat
2929
from bioimageio.spec.shared import resolve_source
3030
from bioimageio.spec.shared.common import ValidationWarning
@@ -161,8 +161,8 @@ def _test_model_inference(model: Model, weight_format: str, devices: Optional[Li
161161
tb: Optional = None
162162
with warnings.catch_warnings(record=True) as all_warnings:
163163
try:
164-
inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
165-
expected = [np.load(str(out_path)) for out_path in model.test_outputs]
164+
inputs = [load_array(str(in_path)) for in_path in model.test_inputs]
165+
expected = [load_array(str(out_path)) for out_path in model.test_outputs]
166166

167167
assert len(inputs) == len(model.inputs) # should be checked by validation
168168
input_shapes = {}
@@ -362,7 +362,7 @@ def debug_model(
362362
bioimageio_model=model, devices=devices, weight_format=weight_format
363363
)
364364
inputs = [
365-
xr.DataArray(np.load(str(in_path)), dims=input_spec.axes)
365+
xr.DataArray(load_array(str(in_path)), dims=input_spec.axes)
366366
for in_path, input_spec in zip(model.test_inputs, model.inputs)
367367
]
368368
input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
@@ -383,7 +383,7 @@ def debug_model(
383383
outputs = [outputs]
384384

385385
expected = [
386-
xr.DataArray(np.load(str(out_path)), dims=output_spec.axes)
386+
xr.DataArray(load_array(str(out_path)), dims=output_spec.axes)
387387
for out_path, output_spec in zip(model.test_outputs, model.outputs)
388388
]
389389
if len(outputs) != len(expected):

bioimageio/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from bioimageio.spec.model import v0_4, v0_5
99
from bioimageio.spec.model.v0_5 import TensorId
10-
from bioimageio.spec.utils import download
10+
from bioimageio.spec.utils import download, load_array
1111

1212
# @singledispatch
1313
# def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool:
@@ -24,7 +24,7 @@ def get_test_input_tensors(model: object) -> List[xr.DataArray]:
2424

2525
@get_test_input_tensors.register
2626
def _(model: v0_4.Model):
27-
data = [np.load(download(ipt).path) for ipt in model.test_inputs]
27+
data = [load_array(download(ipt).path) for ipt in model.test_inputs]
2828
assert all(isinstance(d, np.ndarray) for d in data)
2929

3030

tests/prediction_pipeline/test_measures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
import pytest
77
import xarray as xr
88

9-
from bioimageio.core import statistical_measures
9+
from bioimageio.core import stat_measures
1010
from bioimageio.core.prediction_pipeline._measure_groups import get_measure_groups
1111
from bioimageio.core.prediction_pipeline._utils import PER_DATASET, PER_SAMPLE
12-
from bioimageio.core.statistical_measures import Mean, Percentile, Std, Var
12+
from bioimageio.core.stat_measures import Mean, Percentile, Std, Var
1313

1414

1515
@pytest.mark.parametrize("name_axes", product(["mean", "var", "std"], [None, ("x", "y")]))
1616
def test_individual_normal_measure(name_axes):
1717
name, axes = name_axes
18-
measure = getattr(statistical_measures, name.title())(axes=axes)
18+
measure = getattr(stat_measures, name.title())(axes=axes)
1919
data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c"))
2020

2121
expected = getattr(data, name)(dim=axes)
@@ -26,7 +26,7 @@ def test_individual_normal_measure(name_axes):
2626
@pytest.mark.parametrize("axes_n", product([None, ("x", "y")], [0, 10, 50, 100]))
2727
def test_individual_percentile_measure(axes_n):
2828
axes, n = axes_n
29-
measure = statistical_measures.Percentile(axes=axes, n=n)
29+
measure = stat_measures.Percentile(axes=axes, n=n)
3030
data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c"))
3131

3232
expected = data.quantile(q=n / 100, dim=axes)

0 commit comments

Comments
 (0)