diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py index dfe50d14..50218756 100644 --- a/bioimageio/core/build_spec/build_model.py +++ b/bioimageio/core/build_spec/build_model.py @@ -279,7 +279,6 @@ def _get_dependencies(dependencies, root): def _get_deepimagej_macro(name, kwargs, export_folder): - # macros available in deepimagej macro_names = ("binarize", "scale_linear", "scale_range", "zero_mean_unit_variance") if name == "scale_linear": @@ -382,7 +381,7 @@ def get_size(fname, axes): assert len(shape) == 4 return " x ".join(map(str, shape)) - # deepimagej always expexts a pixel size for the z axis + # deepimagej always expects a pixel size for the z axis pixel_sizes_ = [pix_size if "z" in pix_size else dict(z=1.0, **pix_size) for pix_size in pixel_sizes] test_info = { @@ -410,55 +409,19 @@ def get_size(fname, axes): def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path): - def write_im(path, im, axes, pixel_size=None): - assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}" - assert im.ndim in (4, 5), f"{im.ndim}" - - # convert the image to expects (Z)CYX axis order - if im.ndim == 4: - assert set(axes) == {"b", "x", "y", "c"}, f"{axes}" - resolution_axes_ij = "cyxb" - else: - assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}" - resolution_axes_ij = "bzcyx" - - def addMissingAxes(im_axes): - needed_axes = ["b", "c", "x", "y", "z", "s"] - for ax in needed_axes: - if ax not in im_axes: - im_axes += ax - return im_axes - - axes_ij = "bzcyxs" - # Expand the image to ImageJ dimensions - im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij)))) - - axis_permutation = tuple(addMissingAxes(axes).index(ax) for ax in axes_ij) - im = im.transpose(axis_permutation) - - if pixel_size is None: - resolution = None - else: - spatial_axes = list(set(resolution_axes_ij) - set("bc")) - resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes) - # does not work for double - if np.dtype(im.dtype) == np.dtype("float64"): - im = im.astype("float32") - tifffile.imwrite(path, im, imagej=True, resolution=resolution) - sample_in_paths = [] for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)): inp = np.load(export_folder / in_path) sample_in_path = export_folder / f"sample_input_{i}.tif" pixel_size = None if pixel_sizes is None else pixel_sizes[i] - write_im(sample_in_path, inp, axes, pixel_size) + write_tiff_image(sample_in_path, inp, axes, pixel_size) sample_in_paths.append(sample_in_path) sample_out_paths = [] for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)): outp = np.load(export_folder / out_path) sample_out_path = export_folder / f"sample_output_{i}.tif" - write_im(sample_out_path, outp, axes) + write_tiff_image(sample_out_path, outp, axes) sample_out_paths.append(sample_out_path) return [Path(p.name) for p in sample_in_paths], [Path(p.name) for p in sample_out_paths] diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index 0468b61f..8433a4f0 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -1,9 +1,12 @@ import os +import warnings from copy import deepcopy +from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple, Union import imageio import numpy as np +from tifffile import tifffile from xarray import DataArray from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor @@ -87,7 +90,7 @@ def to_channel_last(image): # -def load_image(in_path, axes: Sequence[str]) -> DataArray: +def load_image(in_path, axes: Optional[Sequence[str]] = None) -> DataArray: ext = os.path.splitext(in_path)[1] if ext == ".npy": im = np.load(in_path) @@ -102,11 +105,17 @@ def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] -def save_image(out_path, image): - ext = os.path.splitext(out_path)[1] - if ext == ".npy": - np.save(out_path, image) +def save_image(out_path: os.PathLike, image: DataArray, pixel_size=None): + out_path = Path(out_path) + if out_path.suffix == ".npy": + if pixel_size is not None: + warnings.warn("Ignoring 'pixel_size'") + np.save(str(out_path), image) + elif out_path.suffix in (".tif", ".tiff"): + save_imagej_tiff_image(out_path, image) else: + if pixel_size is not None: + warnings.warn("Ignoring 'pixel_size'") is_volume = "z" in image.dims # squeeze batch or channel axes if they are singletons @@ -114,7 +123,7 @@ def save_image(out_path, image): image = image[squeeze] if "b" in image.dims: - raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file") + raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {out_path.suffix}-file") if "c" in image.dims: # image formats need channel last image = to_channel_last(image) @@ -133,6 +142,49 @@ def save_image(out_path, image): save_function(chan_out_path, image[..., c]) +def save_imagej_tiff_image(path, image: DataArray, pixel_size: Optional[Dict[str, float]] = None): + pixel_size = pixel_size or image.attrs.get("pixel_size") + assert ( + pixel_size is None + or isinstance(pixel_size, dict) + and all(isinstance(k, str) and isinstance(v, (int, float)) for k, v in pixel_size.items()) + ) + assert im.ndim in (4, 5), f"{im.ndim}" + + # convert the image to expected (Z)CYX axis order + if im.ndim == 4: + assert set(axes) == {"b", "x", "y", "c"}, f"{axes}" + resolution_axes_ij = "cyxb" + else: + assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}" + resolution_axes_ij = "bzcyx" + + def add_missing_axes(im_axes): + needed_axes = ["b", "c", "x", "y", "z", "s"] + for ax in needed_axes: + if ax not in im_axes: + im_axes += ax + return im_axes + + axes_ij = "bzcyxs" + # Expand the image to ImageJ dimensions + im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij)))) + + axis_permutation = tuple(add_missing_axes(axes).index(ax) for ax in axes_ij) + im = im.transpose(axis_permutation) + + tiff_metadata = {} + if pixel_size is None: + resolution = None + else: + spatial_axes = list(set(resolution_axes_ij) - set("bc")) + resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes) + # does not work for double + if np.dtype(im.dtype) == np.dtype("float64"): + im = im.astype("float32") + tifffile.imwrite(path, im, imagej=True, resolution=resolution) + + # # helper function for padding # @@ -157,7 +209,6 @@ def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray pad_width = [] crop = {} for ax, dlen, pr in zip(axes, image.shape, pad_right): - if ax in "zyx": pad_to = padding_[ax] diff --git a/tests/test_image_helper.py b/tests/test_image_helper.py index 9c495de1..6638e9e5 100644 --- a/tests/test_image_helper.py +++ b/tests/test_image_helper.py @@ -27,3 +27,7 @@ def test_transform_output_tensor(): for out_axes in out_ax_list: out = transform_output_tensor(tensor, tensor_axes, out_axes) assert out.ndim == len(out_axes) + + +def test_save_image(): + assert False