Skip to content

Commit 27ef747

Browse files
committed
WIP fix tests and cleanup
1 parent f9c3bf0 commit 27ef747

File tree

7 files changed

+104
-296
lines changed

7 files changed

+104
-296
lines changed

bioimageio/core/_prediction_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,20 +164,20 @@ def predict_sample_without_blocking(
164164
if out is not None
165165
},
166166
stat=sample.stat,
167-
id=self.get_output_sample_id(sample.id),
167+
id=sample.id,
168168
)
169169
if not skip_postprocessing:
170170
self.apply_postprocessing(output)
171171

172172
return output
173173

174174
def get_output_sample_id(self, input_sample_id: SampleId):
175-
if input_sample_id is None:
176-
return None
177-
else:
178-
return f"{input_sample_id}_" + (
179-
self.model_description.id or self.model_description.name
180-
)
175+
warnings.warn(
176+
"`PredictionPipeline.get_output_sample_id()` is deprecated and will be"
177+
+ " removed soon. Output sample id is equal to input sample id, hence this"
178+
+ " function is not needed."
179+
)
180+
return input_sample_id
181181

182182
def predict_sample_with_fixed_blocking(
183183
self,

bioimageio/core/cli.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class ValidateFormatCmd(CmdBase, WithSource):
116116
"""validate the meta data format of a bioimageio resource."""
117117

118118
def run(self):
119-
validate_format(self.descr)
119+
sys.exit(validate_format(self.descr))
120120

121121

122122
class TestCmd(CmdBase, WithSource):
@@ -134,11 +134,13 @@ class TestCmd(CmdBase, WithSource):
134134
"""Precision for numerical comparisons"""
135135

136136
def run(self):
137-
test(
138-
self.descr,
139-
weight_format=self.weight_format,
140-
devices=self.devices,
141-
decimal=self.decimal,
137+
sys.exit(
138+
test(
139+
self.descr,
140+
weight_format=self.weight_format,
141+
devices=self.devices,
142+
decimal=self.decimal,
143+
)
142144
)
143145

144146

@@ -158,10 +160,12 @@ def run(self):
158160
self.descr.validation_summary.display()
159161
raise ValueError("resource description is invalid")
160162

161-
package(
162-
self.descr,
163-
self.path,
164-
weight_format=self.weight_format,
163+
sys.exit(
164+
package(
165+
self.descr,
166+
self.path,
167+
weight_format=self.weight_format,
168+
)
165169
)
166170

167171

bioimageio/core/commands.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
"""deprecated,
2-
use the CLI object `bioimageio.core.cli.Bioimageio` programmatically instead.
3-
"""
1+
"""These functions implement the logic of the bioimageio command line interface
2+
defined in `bioimageio.core.cli`."""
43

5-
import sys
64
from pathlib import Path
75
from typing import Optional, Sequence, Union
86

@@ -28,7 +26,7 @@ def test(
2826
weight_format: WeightFormatArgAll = "all",
2927
devices: Optional[Union[str, Sequence[str]]] = None,
3028
decimal: int = 4,
31-
):
29+
) -> int:
3230
"""test a bioimageio resource
3331
3432
Args:
@@ -40,7 +38,7 @@ def test(
4038
"""
4139
if isinstance(descr, InvalidDescr):
4240
descr.validation_summary.display()
43-
sys.exit(1)
41+
return 1
4442

4543
summary = test_description(
4644
descr,
@@ -49,7 +47,7 @@ def test(
4947
decimal=decimal,
5048
)
5149
summary.display()
52-
sys.exit(0 if summary.status == "passed" else 1)
50+
return 0 if summary.status == "passed" else 1
5351

5452

5553
def validate_format(
@@ -61,7 +59,7 @@ def validate_format(
6159
descr: a bioimageio resource description
6260
"""
6361
descr.validation_summary.display()
64-
sys.exit(0 if descr.validation_summary.status == "passed" else 1)
62+
return 0 if descr.validation_summary.status == "passed" else 1
6563

6664

6765
def package(
@@ -98,3 +96,4 @@ def package(
9896
output_path=path,
9997
weights_priority_order=weights_priority_order,
10098
)
99+
return 0

bioimageio/core/digest_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def get_test_inputs(model: AnyModelDescr) -> Sample:
159159
for m, arr, ax in zip(member_ids, arrays, axes)
160160
},
161161
stat={},
162-
id="test-input",
162+
id="test-sample",
163163
)
164164

165165

@@ -180,7 +180,7 @@ def get_test_outputs(model: AnyModelDescr) -> Sample:
180180
for m, arr, ax in zip(member_ids, arrays, axes)
181181
},
182182
stat={},
183-
id="test-output",
183+
id="test-sample",
184184
)
185185

186186

tests/conftest.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,39 +165,39 @@ def conda_cmd():
165165
#
166166

167167

168-
@fixture(params=TORCH_MODELS)
168+
@fixture(scope="session", params=TORCH_MODELS)
169169
def any_torch_model(request: FixtureRequest):
170170
return MODEL_SOURCES[request.param]
171171

172172

173-
@fixture(params=TORCHSCRIPT_MODELS)
173+
@fixture(scope="session", params=TORCHSCRIPT_MODELS)
174174
def any_torchscript_model(request: FixtureRequest):
175175
return MODEL_SOURCES[request.param]
176176

177177

178-
@fixture(params=ONNX_MODELS)
178+
@fixture(scope="session", params=ONNX_MODELS)
179179
def any_onnx_model(request: FixtureRequest):
180180
return MODEL_SOURCES[request.param]
181181

182182

183-
@fixture(params=TENSORFLOW_MODELS)
183+
@fixture(scope="session", params=TENSORFLOW_MODELS)
184184
def any_tensorflow_model(request: FixtureRequest):
185185
return MODEL_SOURCES[request.param]
186186

187187

188-
@fixture(params=KERAS_MODELS)
188+
@fixture(scope="session", params=KERAS_MODELS)
189189
def any_keras_model(request: FixtureRequest):
190190
return MODEL_SOURCES[request.param]
191191

192192

193-
@fixture(params=TENSORFLOW_JS_MODELS)
193+
@fixture(scope="session", params=TENSORFLOW_JS_MODELS)
194194
def any_tensorflow_js_model(request: FixtureRequest):
195195
return MODEL_SOURCES[request.param]
196196

197197

198198
# fixture to test with all models that should run in the current environment
199199
# we exclude any 'wrong' model here
200-
@fixture(params=sorted({m for m in ALL_MODELS if "wrong" not in m}))
200+
@fixture(scope="session", params=sorted({m for m in ALL_MODELS if "wrong" not in m}))
201201
def any_model(request: FixtureRequest):
202202
return MODEL_SOURCES[request.param]
203203

@@ -239,48 +239,52 @@ def unet2d_keras(request: FixtureRequest):
239239

240240

241241
# written as model group to automatically skip on missing torch
242-
@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
242+
@fixture(scope="session", params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
243243
def unet2d_nuclei_broad_model(request: FixtureRequest):
244244
return MODEL_SOURCES[request.param]
245245

246246

247247
# written as model group to automatically skip on missing torch
248-
@fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"])
248+
@fixture(scope="session", params=[] if skip_torch else ["unet2d_diff_output_shape"])
249249
def unet2d_diff_output_shape(request: FixtureRequest):
250250
return MODEL_SOURCES[request.param]
251251

252252

253253
# written as model group to automatically skip on missing torch
254-
@fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"])
254+
@fixture(scope="session", params=[] if skip_torch else ["unet2d_expand_output_shape"])
255255
def unet2d_expand_output_shape(request: FixtureRequest):
256256
return MODEL_SOURCES[request.param]
257257

258258

259259
# written as model group to automatically skip on missing torch
260-
@fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
260+
@fixture(scope="session", params=[] if skip_torch else ["unet2d_fixed_shape"])
261261
def unet2d_fixed_shape(request: FixtureRequest):
262262
return MODEL_SOURCES[request.param]
263263

264264

265265
# written as model group to automatically skip on missing torch
266-
@fixture(params=[] if skip_torch else ["shape_change"])
266+
@fixture(scope="session", params=[] if skip_torch else ["shape_change"])
267267
def shape_change_model(request: FixtureRequest):
268268
return MODEL_SOURCES[request.param]
269269

270270

271271
# written as model group to automatically skip on missing tensorflow 1
272-
@fixture(params=["stardist_wrong_shape"] if tf_major_version == 1 else [])
272+
@fixture(
273+
scope="session", params=["stardist_wrong_shape"] if tf_major_version == 1 else []
274+
)
273275
def stardist_wrong_shape(request: FixtureRequest):
274276
return MODEL_SOURCES[request.param]
275277

276278

277279
# written as model group to automatically skip on missing tensorflow 1
278-
@fixture(params=["stardist_wrong_shape2"] if tf_major_version == 1 else [])
280+
@fixture(
281+
scope="session", params=["stardist_wrong_shape2"] if tf_major_version == 1 else []
282+
)
279283
def stardist_wrong_shape2(request: FixtureRequest):
280284
return MODEL_SOURCES[request.param]
281285

282286

283287
# written as model group to automatically skip on missing tensorflow 1
284-
@fixture(params=["stardist"] if tf_major_version == 1 else [])
288+
@fixture(scope="session", params=["stardist"] if tf_major_version == 1 else [])
285289
def stardist(request: FixtureRequest):
286290
return MODEL_SOURCES[request.param]

tests/test_commands.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from bioimageio.core import load_model
77
from bioimageio.core.commands import package, validate_format
88
from bioimageio.core.commands import test as command_tst
9-
from bioimageio.spec.model import ModelDescr
9+
from bioimageio.spec import AnyModelDescr
1010

1111

12-
@pytest.mark.fixture(scope="module")
12+
@pytest.fixture(scope="module")
1313
def model(unet2d_nuclei_broad_model: str):
1414
return load_model(unet2d_nuclei_broad_model, perform_io_checks=False)
1515

@@ -23,14 +23,14 @@ def model(unet2d_nuclei_broad_model: str):
2323
)
2424
def test_package(
2525
weight_format: Literal["all", "pytorch_state_dict"],
26-
model: ModelDescr,
26+
model: AnyModelDescr,
2727
tmp_path: Path,
2828
):
29-
_ = package(model, weight_format=weight_format, path=tmp_path / "out.zip")
29+
assert package(model, weight_format=weight_format, path=tmp_path / "out.zip") == 0
3030

3131

32-
def test_validate_format(model: ModelDescr):
33-
_ = validate_format(model)
32+
def test_validate_format(model: AnyModelDescr):
33+
assert validate_format(model) == 0
3434

3535

3636
@pytest.mark.parametrize(
@@ -39,6 +39,6 @@ def test_validate_format(model: ModelDescr):
3939
def test_test(
4040
weight_format: Literal["all", "pytorch_state_dict"],
4141
devices: Optional[str],
42-
model: ModelDescr,
42+
model: AnyModelDescr,
4343
):
44-
_ = command_tst(model, weight_format=weight_format, devices=devices)
44+
assert command_tst(model, weight_format=weight_format, devices=devices) == 0

0 commit comments

Comments
 (0)