Skip to content

Commit

Permalink
Merge pull request #561 from bioimage-io/data_dep_size
Browse files Browse the repository at this point in the history
disallow ParameterizedSize for output axes
  • Loading branch information
FynnBe authored Mar 15, 2024
2 parents 3b60948 + 7b73ee9 commit 382c978
Showing 1 changed file with 60 additions and 24 deletions.
84 changes: 60 additions & 24 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,17 @@ def get_size(
self,
axis: Union[
ChannelAxis,
IndexAxis,
IndexInputAxis,
IndexOutputAxis,
TimeInputAxis,
SpaceInputAxis,
TimeOutputAxis,
SpaceOutputAxis,
],
ref_axis: Union[
ChannelAxis,
IndexAxis,
IndexInputAxis,
IndexOutputAxis,
TimeInputAxis,
SpaceInputAxis,
TimeOutputAxis,
Expand Down Expand Up @@ -338,7 +340,8 @@ def get_size(
def _get_unit(
axis: Union[
ChannelAxis,
IndexAxis,
IndexInputAxis,
IndexOutputAxis,
TimeInputAxis,
SpaceInputAxis,
TimeOutputAxis,
Expand Down Expand Up @@ -400,7 +403,20 @@ def unit(self):
return None


class IndexTimeSpaceAxisBase(AxisBase):
class IndexAxisBase(AxisBase):
type: Literal["index"] = "index"
id: NonBatchAxisId = AxisId("index")

@property
def scale(self) -> float:
return 1.0

@property
def unit(self):
return None


class _WithInputAxisSize(Node):
size: Annotated[
Union[Annotated[int, Gt(0)], ParameterizedSize, SizeReference],
Field(
Expand All @@ -413,64 +429,78 @@ class IndexTimeSpaceAxisBase(AxisBase):
]
),
]
"""The size/length of an axis can be specified as
"""The size/length of this axis can be specified as
- fixed integer
- parameterized series of valid sizes (`ParameterizedSize`)
- reference to another axis with an optional offset (`SizeReference`)
"""


class IndexAxis(IndexTimeSpaceAxisBase):
type: Literal["index"] = "index"
id: NonBatchAxisId = AxisId("index")
class _WithOutputAxisSize(Node):
size: Annotated[
Union[Annotated[int, Gt(0)], SizeReference],
Field(
examples=[
10,
SizeReference(
tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
).model_dump(mode="json"),
]
),
]
"""The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
# TODO: add `DataDependentSize(min, max, step)`
"""

@property
def scale(self) -> float:
return 1.0

@property
def unit(self):
return None
class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
pass


class IndexOutputAxis(IndexAxisBase, _WithOutputAxisSize):
pass


class TimeAxisBase(IndexTimeSpaceAxisBase):
class TimeAxisBase(AxisBase):
type: Literal["time"] = "time"
id: NonBatchAxisId = AxisId("time")
unit: Optional[TimeUnit] = None
scale: Annotated[float, Gt(0)] = 1.0


class TimeInputAxis(TimeAxisBase):
class TimeInputAxis(TimeAxisBase, _WithInputAxisSize):
pass


class SpaceAxisBase(IndexTimeSpaceAxisBase):
class SpaceAxisBase(AxisBase):
type: Literal["space"] = "space"
id: Annotated[NonBatchAxisId, Field(examples=["x", "y", "z"])] = AxisId("x")
unit: Optional[SpaceUnit] = None
scale: Annotated[float, Gt(0)] = 1.0


class SpaceInputAxis(SpaceAxisBase):
class SpaceInputAxis(SpaceAxisBase, _WithInputAxisSize):
pass


_InputAxisUnion = Union[
BatchAxis, ChannelAxis, IndexAxis, TimeInputAxis, SpaceInputAxis
BatchAxis, ChannelAxis, IndexInputAxis, TimeInputAxis, SpaceInputAxis
]
InputAxis = Annotated[_InputAxisUnion, Field(discriminator="type")]


class TimeOutputAxis(TimeAxisBase, WithHalo):
class TimeOutputAxis(TimeAxisBase, WithHalo, _WithOutputAxisSize):
pass


class SpaceOutputAxis(SpaceAxisBase, WithHalo):
class SpaceOutputAxis(SpaceAxisBase, WithHalo, _WithOutputAxisSize):
pass


_OutputAxisUnion = Union[
BatchAxis, ChannelAxis, IndexAxis, TimeOutputAxis, SpaceOutputAxis
BatchAxis, ChannelAxis, IndexOutputAxis, TimeOutputAxis, SpaceOutputAxis
]
OutputAxis = Annotated[_OutputAxisUnion, Field(discriminator="type")]

Expand Down Expand Up @@ -1116,7 +1146,10 @@ def convert_axes(
)
)
elif axis_type == "index":
ret.append(IndexAxis(size=size))
if tensor_type == "input":
ret.append(IndexInputAxis(size=size))
else:
ret.append(IndexOutputAxis(size=size))
elif axis_type == "channel":
assert not isinstance(size, ParameterizedSize)
if isinstance(size, SizeReference):
Expand Down Expand Up @@ -2322,7 +2355,10 @@ def to_2d_image(data: NDArray[Any], axes: Sequence[AnyAxis]):
for i, a in enumerate(axes):
s = data.shape[i]
assert s > 1
if isinstance(a, (BatchAxis, IndexAxis)) and ndim > ndim_need:
if (
isinstance(a, (BatchAxis, IndexInputAxis, IndexOutputAxis))
and ndim > ndim_need
):
data = data[slices + (slice(s // 2 - 1, s // 2),)]
ndim -= 1
elif isinstance(a, ChannelAxis):
Expand Down

0 comments on commit 382c978

Please sign in to comment.