Workflow RDF #310

Open · wants to merge 24 commits into base: main

Changes from all commits
9 changes: 3 additions & 6 deletions bioimageio/core/__main__.py
@@ -4,16 +4,14 @@
import sys
import warnings
from glob import glob

from pathlib import Path
from pprint import pformat, pprint
from pprint import pformat
from typing import List, Optional

import typer

from bioimageio.core import __version__, prediction, commands, resource_tests, load_raw_resource_description
from bioimageio.core import __version__, commands, prediction, resource_tests
from bioimageio.core.common import TestSummary
from bioimageio.core.prediction_pipeline import get_weight_formats
from bioimageio.spec.__main__ import app, help_version as help_version_spec
from bioimageio.spec.model.raw_nodes import WeightsFormat

@@ -192,7 +190,6 @@ def predict_image(
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
):

if isinstance(padding, str):
padding = json.loads(padding.replace("'", '"'))
assert isinstance(padding, dict)
@@ -244,7 +241,7 @@ def predict_images(
tiling = json.loads(tiling.replace("'", '"'))
assert isinstance(tiling, dict)

# this is a weird typer bug: default devices are empty tuple although they should be None
# this is a weird typer bug: default devices are empty tuple, although they should be None
if len(devices) == 0:
devices = None
prediction.predict_images(
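Two quick sketches of the CLI plumbing above (standalone and hypothetical, not part of this PR). First, the padding/tiling options arrive as single-quoted strings on the command line and are coerced into dicts:

import json

padding = "{'x': 16, 'y': 16}"  # hypothetical CLI value
padding = json.loads(padding.replace("'", '"'))  # single -> double quotes for valid JSON
assert padding == {"x": 16, "y": 16}

Second, the devices workaround normalizes typer's empty-sequence default back to None before the prediction functions see it:

from typing import List, Optional

import typer

app = typer.Typer()

@app.command()
def run(devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model.")):
    # typer materializes the None default as an empty sequence
    if not devices:
        devices = None
    print(devices)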
33 changes: 21 additions & 12 deletions bioimageio/core/image_helper.py
@@ -13,13 +13,13 @@
#


def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
def transform_input_image(image: np.ndarray, tensor_axes: Sequence[str], image_axes: Optional[Sequence[str]] = None):
"""Transform input image into output tensor with desired axes.

Args:
image: the input image
tensor_axes: the desired tensor axes
input_axes: the axes of the input image (optional)
image_axes: the axes of the input image (optional)
"""
# if the image axes are not given deduce them from the required axes and image shape
if image_axes is None:
@@ -35,7 +35,7 @@ def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
image_axes = "bczyx"
else:
raise ValueError(f"Invalid number of image dimensions: {ndim}")
tensor = DataArray(image, dims=tuple(image_axes))

# instead of 'b' we might want 'batch', etc...
axis_letter_map = {
letter: name
for letter, name in {"b": "batch", "c": "channel", "i": "index", "t": "time"}.items()
if name in tensor_axes # only do this mapping if the full name is in the desired tensor_axes
}
image_axes = tuple(axis_letter_map.get(a, a) for a in image_axes)

tensor = DataArray(image, dims=image_axes)
# expand the missing image axes
missing_axes = tuple(set(tensor_axes) - set(image_axes))
tensor = tensor.expand_dims(dim=missing_axes)
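To illustrate the new axis-letter mapping: a single-letter image axis is renamed to its full name only when the desired tensor axes use the full name, and the renamed tuple becomes the DataArray dims. A rough sketch with hypothetical shapes (assuming numpy and xarray):

import numpy as np
from xarray import DataArray

tensor_axes = ("batch", "c", "y", "x")  # desired axes use the full name "batch"
image_axes = "bcyx"                     # single-letter axes deduced for a 4d image

axis_letter_map = {
    letter: name
    for letter, name in {"b": "batch", "c": "channel", "i": "index", "t": "time"}.items()
    if name in tensor_axes  # here only "batch" qualifies
}
dims = tuple(axis_letter_map.get(a, a) for a in image_axes)  # ("batch", "c", "y", "x")
tensor = DataArray(np.zeros((1, 3, 64, 64)), dims=dims)
assert tensor.dims == ("batch", "c", "y", "x")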
@@ -75,9 +84,10 @@ def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: s


def to_channel_last(image):
chan_id = image.dims.index("c")
c = "c" if "c" in image.dims else "channel"
chan_id = image.dims.index(c)
if chan_id != image.ndim - 1:
target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
target_axes = tuple(ax for ax in image.dims if ax != c) + (c,)
image = image.transpose(*target_axes)
return image
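With this change to_channel_last accepts either naming convention for the channel axis. A minimal sketch (assuming numpy and xarray):

import numpy as np
from xarray import DataArray

from bioimageio.core.image_helper import to_channel_last

image = DataArray(np.zeros((3, 32, 32)), dims=("channel", "y", "x"))
image = to_channel_last(image)  # falls back to "channel" when "c" is absent
assert image.dims == ("y", "x", "channel")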

@@ -95,27 +105,27 @@ def load_image(in_path, axes: Sequence[str]) -> DataArray:
is_volume = "z" in axes
im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
im = transform_input_image(im, axes)
return DataArray(im, dims=axes)
return DataArray(im, dims=tuple(axes))


def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]:
return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def save_image(out_path, image):
ext = os.path.splitext(out_path)[1]
def save_image(out_path: os.PathLike, image):
ext = os.path.splitext(str(out_path))[1]
if ext == ".npy":
np.save(out_path, image)
np.save(str(out_path), image)
else:
is_volume = "z" in image.dims

# squeeze batch or channel axes if they are singletons
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
image = image[squeeze]

if "b" in image.dims:
if "b" in image.dims or "batch" in image.dims:
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
if "c" in image.dims: # image formats need channel last
if "c" in image.dims or "channel" in image.dims: # image formats need channel last
image = to_channel_last(image)

save_function = imageio.volsave if is_volume else imageio.imsave
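For reference, the squeeze indexer above maps each singleton batch/channel axis to 0 (which drops it) and every other axis to slice(None) (which keeps it). Worked through with a hypothetical image (assuming xarray):

import numpy as np
from xarray import DataArray

image = DataArray(np.zeros((1, 3, 64, 64)), dims=("b", "c", "y", "x"))
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None)
           for ax, sh in zip(image.dims, image.shape)}
# -> {"b": 0, "c": slice(None), "y": slice(None), "x": slice(None)}
image = image[squeeze]  # xarray drops the singleton "b" axis
assert image.dims == ("c", "y", "x")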
@@ -157,7 +167,6 @@ def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray
pad_width = []
crop = {}
for ax, dlen, pr in zip(axes, image.shape, pad_right):

if ax in "zyx":
pad_to = padding_[ax]

@@ -55,7 +55,7 @@ def _unload(self) -> None:
def get_nn_instance(model_node: nodes.Model, **kwargs):
weight_spec = model_node.weights.get("pytorch_state_dict")
assert weight_spec is not None
assert isinstance(weight_spec.architecture, nodes.ImportedSource)
assert isinstance(weight_spec.architecture, nodes.ImportedCallable)
model_kwargs = weight_spec.kwargs
joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs)
joined_kwargs.update(kwargs)
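The kwargs handling in get_nn_instance merges the kwargs stored in the weight spec with caller-supplied overrides, the caller winning on conflicts. In spirit (hypothetical values; model_kwargs is marshmallow's missing when the spec defines none):

from marshmallow import missing

model_kwargs = {"depth": 4}                 # from weight_spec.kwargs, or `missing`
kwargs = {"depth": 5, "pretrained": False}  # caller overrides

joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs)
joined_kwargs.update(kwargs)
assert joined_kwargs == {"depth": 5, "pretrained": False}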
100 changes: 76 additions & 24 deletions bioimageio/core/resource_io/nodes.py
@@ -6,10 +6,12 @@
from marshmallow import missing
from marshmallow.utils import _Missing

from bioimageio.spec.model import raw_nodes as model_raw_nodes
from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes
from bioimageio.spec.collection import raw_nodes as collection_raw_nodes
from bioimageio.spec.dataset import raw_nodes as dataset_raw_nodes
from bioimageio.spec.model.v0_4 import raw_nodes as model_raw_nodes
from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes
from bioimageio.spec.shared import raw_nodes
from bioimageio.spec.workflow import raw_nodes as workflow_raw_nodes


@dataclass
@@ -48,12 +50,12 @@ class CiteEntry(Node, rdf_raw_nodes.CiteEntry):


@dataclass
class Author(Node, model_raw_nodes.Author):
class Author(Node, rdf_raw_nodes.Author):
pass


@dataclass
class Maintainer(Node, model_raw_nodes.Maintainer):
class Maintainer(Node, rdf_raw_nodes.Maintainer):
pass


@@ -62,10 +64,19 @@ class Badge(Node, rdf_raw_nodes.Badge):
pass


@dataclass
class Attachments(Node, rdf_raw_nodes.Attachments):
files: Union[_Missing, List[Path]] = missing
unknown: Union[_Missing, Dict[str, Any]] = missing


@dataclass
class RDF(rdf_raw_nodes.RDF, ResourceDescription):
authors: Union[_Missing, List[Author]] = missing
attachments: Union[_Missing, Attachments] = missing
badges: Union[_Missing, List[Badge]] = missing
covers: Union[_Missing, List[Path]] = missing
cite: Union[_Missing, List[CiteEntry]] = missing
maintainers: Union[_Missing, List[Maintainer]] = missing


@dataclass
@@ -74,17 +85,22 @@ class CollectionEntry(Node, collection_raw_nodes.CollectionEntry):


@dataclass
class LinkedDataset(Node, model_raw_nodes.LinkedDataset):
class Collection(collection_raw_nodes.Collection, RDF):
collection: List[CollectionEntry] = missing


@dataclass
class Dataset(Node, dataset_raw_nodes.Dataset):
pass


@dataclass
class ModelParent(Node, model_raw_nodes.ModelParent):
class LinkedDataset(Node, model_raw_nodes.LinkedDataset):
pass


@dataclass
class Collection(collection_raw_nodes.Collection, RDF):
class ModelParent(Node, model_raw_nodes.ModelParent):
pass


@@ -106,6 +122,7 @@ class Postprocessing(Node, model_raw_nodes.Postprocessing):
@dataclass
class InputTensor(Node, model_raw_nodes.InputTensor):
axes: Tuple[str, ...] = missing
preprocessing: Union[_Missing, List[Preprocessing]] = missing

def __post_init__(self):
super().__post_init__()
@@ -116,6 +133,7 @@ def __post_init__(self):
@dataclass
class OutputTensor(Node, model_raw_nodes.OutputTensor):
axes: Tuple[str, ...] = missing
postprocessing: Union[_Missing, List[Postprocessing]] = missing

def __post_init__(self):
super().__post_init__()
@@ -124,48 +142,47 @@ def __post_init__(self):


@dataclass
class ImportedSource(Node):
factory: Callable
class ImportedCallable(Node):
call: Callable

def __call__(self, *args, **kwargs):
return self.factory(*args, **kwargs)
return self.call(*args, **kwargs)


@dataclass
class KerasHdf5WeightsEntry(Node, model_raw_nodes.KerasHdf5WeightsEntry):
source: Path = missing
class WeightsEntryBase(model_raw_nodes._WeightsEntryBase):
dependencies: Union[_Missing, Dependencies] = missing


@dataclass
class OnnxWeightsEntry(Node, model_raw_nodes.OnnxWeightsEntry):
class KerasHdf5WeightsEntry(WeightsEntryBase, model_raw_nodes.KerasHdf5WeightsEntry):
source: Path = missing


@dataclass
class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeightsEntry):
class OnnxWeightsEntry(WeightsEntryBase, model_raw_nodes.OnnxWeightsEntry):
source: Path = missing
architecture: Union[_Missing, ImportedSource] = missing


@dataclass
class TorchscriptWeightsEntry(Node, model_raw_nodes.TorchscriptWeightsEntry):
class PytorchStateDictWeightsEntry(WeightsEntryBase, model_raw_nodes.PytorchStateDictWeightsEntry):
source: Path = missing
architecture: Union[_Missing, ImportedCallable] = missing


@dataclass
class TensorflowJsWeightsEntry(Node, model_raw_nodes.TensorflowJsWeightsEntry):
class TorchscriptWeightsEntry(WeightsEntryBase, model_raw_nodes.TorchscriptWeightsEntry):
source: Path = missing


@dataclass
class TensorflowSavedModelBundleWeightsEntry(Node, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry):
class TensorflowJsWeightsEntry(WeightsEntryBase, model_raw_nodes.TensorflowJsWeightsEntry):
source: Path = missing


@dataclass
class Attachments(Node, rdf_raw_nodes.Attachments):
files: Union[_Missing, List[Path]] = missing
unknown: Union[_Missing, Dict[str, Any]] = missing
class TensorflowSavedModelBundleWeightsEntry(WeightsEntryBase, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry):
source: Path = missing


WeightsEntry = Union[
@@ -180,8 +197,43 @@ class Attachments(Node, rdf_raw_nodes.Attachments):

@dataclass
class Model(model_raw_nodes.Model, RDF):
authors: List[Author] = missing
maintainers: Union[_Missing, List[Maintainer]] = missing
inputs: List[InputTensor] = missing
outputs: List[OutputTensor] = missing
parent: Union[_Missing, ModelParent] = missing
run_mode: Union[_Missing, RunMode] = missing
test_inputs: List[Path] = missing
test_outputs: List[Path] = missing
training_data: Union[_Missing, Dataset, LinkedDataset] = missing
weights: Dict[model_raw_nodes.WeightsFormat, WeightsEntry] = missing


@dataclass
class Axis(Node, workflow_raw_nodes.Axis):
pass


@dataclass
class BatchAxis(Node, workflow_raw_nodes.Axis):
pass


@dataclass
class Input(Node, workflow_raw_nodes.Input):
pass


@dataclass
class Option(Node, workflow_raw_nodes.Option):
pass


@dataclass
class Output(Node, workflow_raw_nodes.Output):
pass


@dataclass
class Workflow(workflow_raw_nodes.Workflow, RDF):
inputs: List[Input] = missing
options: List[Option] = missing
outputs: List[Output] = missing
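A note on the ImportedSource -> ImportedCallable rename above: the node stays callable and simply forwards to the wrapped callable, now stored in a field named call instead of factory. A minimal sketch (assuming Node contributes no required constructor arguments):

from bioimageio.core.resource_io.nodes import ImportedCallable

def build_model(depth: int):
    return {"depth": depth}  # stand-in for a real architecture factory

arch = ImportedCallable(call=build_model)
assert arch(depth=4) == {"depth": 4}  # __call__ forwards to build_model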