Skip to content

Commit

Permalink
add test_validate_parameterized_size
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 12, 2024
1 parent c1a339a commit 3290dd9
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions tests/test_model/test_v0_5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, Union
from types import MappingProxyType
from typing import Any, Dict, Mapping, Union

import pytest
from pydantic import RootModel, ValidationError
Expand Down Expand Up @@ -222,10 +224,12 @@ def test_input_axis(kwargs: Union[Dict[str, Any], SpaceInputAxis]):
check_type(InputAxis, kwargs)


@pytest.fixture
def model_data():
@pytest.fixture(scope="module")
def model():
"""reuse model object to avoid expensive model validation,
use only when not manipulating the model!"""
with ValidationContext(perform_io_checks=False):
model = ModelDescr(
return ModelDescr(
documentation=UNET2D_ROOT / "README.md",
license=LicenseId("MIT"),
git_repo=HttpUrl("https://github.com/bioimage-io/core-bioimage-io-python"),
Expand Down Expand Up @@ -288,12 +292,21 @@ def model_data():
),
type="model",
)
data = model.model_dump(mode="json")
assert data["documentation"] == str(UNET2D_ROOT / "README.md"), (
data["documentation"],
str(UNET2D_ROOT / "README.md"),
)
return data


@pytest.fixture(scope="module")
def const_model_data(model: ModelDescr):
data = model.model_dump(mode="json")
assert data["documentation"] == str(UNET2D_ROOT / "README.md"), (
data["documentation"],
str(UNET2D_ROOT / "README.md"),
)
return MappingProxyType(data)


@pytest.fixture
def model_data(const_model_data: Mapping[str, Any]):
return deepcopy(dict(const_model_data))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -389,18 +402,12 @@ def test_output_fixed_shape_too_small(model_data: Dict[str, Any]):
assert summary.status == "failed", summary.format()


def test_get_axis_sizes_with_surplus_n(model_data: Dict[str, Any]):
with ValidationContext(perform_io_checks=False):
model = ModelDescr(**model_data)

def test_get_axis_sizes_with_surplus_n(model: ModelDescr):
key = (model.inputs[0].id, AxisId("y"))
_ = model.get_axis_sizes(ns={key: 1}, batch_size=1)


def test_get_axis_sizes_with_partial_max_size(model_data: Dict[str, Any]):
with ValidationContext(perform_io_checks=False):
model = ModelDescr(**model_data)

def test_get_axis_sizes_with_partial_max_size(model: ModelDescr):
key = (model.inputs[0].id, AxisId("y"))
ns = {key: 100}
wo_max_shape = model.get_axis_sizes(ns=ns)
Expand All @@ -426,7 +433,7 @@ def test_output_ref_shape_mismatch(model_data: Dict[str, Any]):
model_data["outputs"][0]["axes"][2] = {
"type": "space",
"id": "x",
"size": {"tensor_id": "input_1", "axis_id": "x"},
"size": {"tensor_id": "input_1", "axis_id": "y"},
"halo": 2,
}
summary = validate_format(
Expand All @@ -450,7 +457,7 @@ def test_output_ref_shape_too_small(model_data: Dict[str, Any]):
model_data["outputs"][0]["axes"][2] = {
"type": "space",
"id": "x",
"size": {"tensor_id": "input_1", "axis_id": "x"},
"size": {"tensor_id": "input_1", "axis_id": "y"},
"halo": 2,
}
summary = validate_format(
Expand Down Expand Up @@ -553,3 +560,9 @@ def test_empty_axis_data():
)
def test_identifier_identity(a: Any, b: Any):
assert a == b


def test_validate_parameterized_size(model: ModelDescr):
param_size = model.inputs[0].axes[2].size
assert isinstance(param_size, ParameterizedSize)
assert param_size.validate_size(512) == 512

0 comments on commit 3290dd9

Please sign in to comment.