Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workflow RDF #310

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions bioimageio/core/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
import os
import sys
import warnings
from argparse import ArgumentParser
from functools import partial
from glob import glob

from pathlib import Path
from pprint import pformat, pprint
from typing import List, Optional
from pprint import pformat
from typing import List, Optional, Union

import typer

from bioimageio.core import __version__, prediction, commands, resource_tests, load_raw_resource_description
from bioimageio.core import __version__, commands, load_raw_resource_description, prediction, resource_tests
from bioimageio.core.common import TestSummary
from bioimageio.core.prediction_pipeline import get_weight_formats
from bioimageio.core.image_helper import load_image, save_image
from bioimageio.core.resource_io import nodes
from bioimageio.core.workflow.operators import run_workflow
from bioimageio.spec.__main__ import app, help_version as help_version_spec
from bioimageio.spec.model.raw_nodes import WeightsFormat
from bioimageio.spec.workflow.raw_nodes import Workflow

try:
from typing import get_args
Expand Down Expand Up @@ -244,7 +248,7 @@ def predict_images(
tiling = json.loads(tiling.replace("'", '"'))
assert isinstance(tiling, dict)

# this is a weird typer bug: default devices are empty tuple although they should be None
# this is a weird typer bug: default devices are empty tuple, although they should be None
if len(devices) == 0:
devices = None
prediction.predict_images(
Expand Down Expand Up @@ -310,5 +314,88 @@ def convert_keras_weights_to_tensorflow(
)


@app.command(context_settings=dict(allow_extra_args=True, ignore_unknown_options=True), add_help_option=False)
def run(
rdf_source: str = typer.Argument(..., help="BioImage.IO RDF id/url/path."),
*,
output_folder: Path = Path("outputs"),
output_tensor_extension: str = ".npy",
ctx: typer.Context,
):
resource = load_raw_resource_description(rdf_source, update_to_format="latest")
if not isinstance(resource, Workflow):
raise NotImplementedError(f"Non-workflow RDFs not yet supported (got type {resource.type})")

map_type = dict(
any=str,
boolean=bool,
float=float,
int=int,
list=str,
string=str,
)
wf = resource
parser = ArgumentParser(description=f"CLI for {wf.name}")

# replicate typer args to show up in help
parser.add_argument(
metavar="rdf-source",
dest="rdf_source",
help="BioImage.IO RDF id/url/path. The optional arguments below are RDF specific.",
)
parser.add_argument(
metavar="output-folder", dest="output_folder", help="Folder to save outputs to.", default=Path("outputs")
)
parser.add_argument(
metavar="output-tensor-extension",
dest="output_tensor_extension",
help="Output tensor extension.",
default=".npy",
)

def add_param_args(params):
for param in params:
argument_kwargs = {}
if param.type == "tensor":
argument_kwargs["type"] = partial(load_image, axes=[a.name or a.type for a in param.axes])
else:
argument_kwargs["type"] = map_type[param.type]

if param.type == "list":
argument_kwargs["nargs"] = "*"

argument_kwargs["help"] = param.description or ""
if hasattr(param, "default"):
argument_kwargs["default"] = param.default
else:
argument_kwargs["required"] = True

argument_kwargs["metavar"] = param.name[0].capitalize()
parser.add_argument("--" + param.name.replace("_", "-"), **argument_kwargs)

def prepare_parameter(value, param: Union[nodes.InputSpec, nodes.OptionSpec]):
if param.type == "tensor":
return load_image(value, [a.name or a.type for a in param.axes])
else:
return value

add_param_args(wf.inputs_spec)
add_param_args(wf.options_spec)
args = parser.parse_args([rdf_source, str(output_folder), output_tensor_extension] + list(ctx.args))
outputs = run_workflow(
rdf_source,
inputs=[prepare_parameter(getattr(args, ipt.name), ipt) for ipt in wf.inputs_spec],
options={opt.name: prepare_parameter(getattr(args, opt.name), opt) for opt in wf.options_spec},
)
output_folder.mkdir(parents=True, exist_ok=True)
for out_spec, out in zip(wf.outputs_spec, outputs):
out_path = output_folder / out_spec.name
if out_spec.type == "tensor":
save_image(out_path.with_suffix(output_tensor_extension), out)
else:
with out_path.with_suffix(".json").open("w") as f:
json.dump(out, f)


if __name__ == "__main__":
app()
28 changes: 19 additions & 9 deletions bioimageio/core/image_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#


def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
def transform_input_image(image: np.ndarray, tensor_axes: Sequence[str], image_axes: Optional[str] = None):
"""Transform input image into output tensor with desired axes.

Args:
Expand All @@ -35,7 +35,16 @@ def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optio
image_axes = "bczyx"
else:
raise ValueError(f"Invalid number of image dimensions: {ndim}")
tensor = DataArray(image, dims=tuple(image_axes))

# instead of 'b' we might want 'batch', etc...
axis_letter_map = {
letter: name
for letter, name in {"b": "batch", "c": "channel", "i": "index", "t": "time"}
if name in tensor_axes # only do this mapping if the full name is in the desired tensor_axes
}
image_axes = tuple(axis_letter_map.get(a, a) for a in image_axes)

tensor = DataArray(image, dims=image_axes)
# expand the missing image axes
missing_axes = tuple(set(tensor_axes) - set(image_axes))
tensor = tensor.expand_dims(dim=missing_axes)
Expand Down Expand Up @@ -75,9 +84,10 @@ def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: s


def to_channel_last(image):
chan_id = image.dims.index("c")
c = "c" if "c" in image.dims else "channel"
chan_id = image.dims.index(c)
if chan_id != image.ndim - 1:
target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
target_axes = tuple(ax for ax in image.dims if ax != c) + (c,)
image = image.transpose(*target_axes)
return image

Expand All @@ -102,20 +112,20 @@ def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]])
return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def save_image(out_path, image):
ext = os.path.splitext(out_path)[1]
def save_image(out_path: os.PathLike, image):
ext = os.path.splitext(str(out_path))[1]
if ext == ".npy":
np.save(out_path, image)
np.save(str(out_path), image)
else:
is_volume = "z" in image.dims

# squeeze batch or channel axes if they are singletons
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
image = image[squeeze]

if "b" in image.dims:
if "b" in image.dims or "batch" in image.dims:
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
if "c" in image.dims: # image formats need channel last
if "c" in image.dims or "channel" in image.dims: # image formats need channel last
image = to_channel_last(image)

save_function = imageio.volsave if is_volume else imageio.imsave
Expand Down
Loading