Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] improve save_image #331

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 3 additions & 40 deletions bioimageio/core/build_spec/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def _get_dependencies(dependencies, root):


def _get_deepimagej_macro(name, kwargs, export_folder):

# macros available in deepimagej
macro_names = ("binarize", "scale_linear", "scale_range", "zero_mean_unit_variance")
if name == "scale_linear":
Expand Down Expand Up @@ -382,7 +381,7 @@ def get_size(fname, axes):
assert len(shape) == 4
return " x ".join(map(str, shape))

# deepimagej always expexts a pixel size for the z axis
# deepimagej always expects a pixel size for the z axis
pixel_sizes_ = [pix_size if "z" in pix_size else dict(z=1.0, **pix_size) for pix_size in pixel_sizes]

test_info = {
Expand Down Expand Up @@ -410,55 +409,19 @@ def get_size(fname, axes):


def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path):
def write_im(path, im, axes, pixel_size=None):
assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}"
assert im.ndim in (4, 5), f"{im.ndim}"

# convert the image to expects (Z)CYX axis order
if im.ndim == 4:
assert set(axes) == {"b", "x", "y", "c"}, f"{axes}"
resolution_axes_ij = "cyxb"
else:
assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}"
resolution_axes_ij = "bzcyx"

def addMissingAxes(im_axes):
needed_axes = ["b", "c", "x", "y", "z", "s"]
for ax in needed_axes:
if ax not in im_axes:
im_axes += ax
return im_axes

axes_ij = "bzcyxs"
# Expand the image to ImageJ dimensions
im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij))))

axis_permutation = tuple(addMissingAxes(axes).index(ax) for ax in axes_ij)
im = im.transpose(axis_permutation)

if pixel_size is None:
resolution = None
else:
spatial_axes = list(set(resolution_axes_ij) - set("bc"))
resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes)
# does not work for double
if np.dtype(im.dtype) == np.dtype("float64"):
im = im.astype("float32")
tifffile.imwrite(path, im, imagej=True, resolution=resolution)

sample_in_paths = []
for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)):
inp = np.load(export_folder / in_path)
sample_in_path = export_folder / f"sample_input_{i}.tif"
pixel_size = None if pixel_sizes is None else pixel_sizes[i]
write_im(sample_in_path, inp, axes, pixel_size)
write_tiff_image(sample_in_path, inp, axes, pixel_size)
sample_in_paths.append(sample_in_path)

sample_out_paths = []
for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)):
outp = np.load(export_folder / out_path)
sample_out_path = export_folder / f"sample_output_{i}.tif"
write_im(sample_out_path, outp, axes)
write_tiff_image(sample_out_path, outp, axes)
sample_out_paths.append(sample_out_path)

return [Path(p.name) for p in sample_in_paths], [Path(p.name) for p in sample_out_paths]
Expand Down
65 changes: 58 additions & 7 deletions bioimageio/core/image_helper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

import imageio
import numpy as np
from tifffile import tifffile
from xarray import DataArray
from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor

Expand Down Expand Up @@ -87,7 +90,7 @@ def to_channel_last(image):
#


def load_image(in_path, axes: Sequence[str]) -> DataArray:
def load_image(in_path, axes: Optional[Sequence[str]] = None) -> DataArray:
ext = os.path.splitext(in_path)[1]
if ext == ".npy":
im = np.load(in_path)
Expand All @@ -102,19 +105,25 @@ def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]])
return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def save_image(out_path, image):
ext = os.path.splitext(out_path)[1]
if ext == ".npy":
np.save(out_path, image)
def save_image(out_path: os.PathLike, image: DataArray, pixel_size=None):
out_path = Path(out_path)
if out_path.suffix == ".npy":
if pixel_size is not None:
warnings.warn("Ignoring 'pixel_size'")
np.save(str(out_path), image)
elif out_path.suffix in (".tif", ".tiff"):
save_imagej_tiff_image(out_path, image)
else:
if pixel_size is not None:
warnings.warn("Ignoring 'pixel_size'")
is_volume = "z" in image.dims

# squeeze batch or channel axes if they are singletons
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
image = image[squeeze]

if "b" in image.dims:
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {out_path.suffix}-file")
if "c" in image.dims: # image formats need channel last
image = to_channel_last(image)

Expand All @@ -133,6 +142,49 @@ def save_image(out_path, image):
save_function(chan_out_path, image[..., c])


def save_imagej_tiff_image(path, image: DataArray, pixel_size: Optional[Dict[str, float]] = None):
pixel_size = pixel_size or image.attrs.get("pixel_size")
assert (
pixel_size is None
or isinstance(pixel_size, dict)
and all(isinstance(k, str) and isinstance(v, (int, float)) for k, v in pixel_size.items())
)
assert im.ndim in (4, 5), f"{im.ndim}"

# convert the image to expected (Z)CYX axis order
if im.ndim == 4:
assert set(axes) == {"b", "x", "y", "c"}, f"{axes}"
resolution_axes_ij = "cyxb"
else:
assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}"
resolution_axes_ij = "bzcyx"

def add_missing_axes(im_axes):
needed_axes = ["b", "c", "x", "y", "z", "s"]
for ax in needed_axes:
if ax not in im_axes:
im_axes += ax
return im_axes

axes_ij = "bzcyxs"
# Expand the image to ImageJ dimensions
im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij))))

axis_permutation = tuple(add_missing_axes(axes).index(ax) for ax in axes_ij)
im = im.transpose(axis_permutation)

tiff_metadata = {}
if pixel_size is None:
resolution = None
else:
spatial_axes = list(set(resolution_axes_ij) - set("bc"))
resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes)
# does not work for double
if np.dtype(im.dtype) == np.dtype("float64"):
im = im.astype("float32")
tifffile.imwrite(path, im, imagej=True, resolution=resolution)


#
# helper function for padding
#
Expand All @@ -157,7 +209,6 @@ def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray
pad_width = []
crop = {}
for ax, dlen, pr in zip(axes, image.shape, pad_right):

if ax in "zyx":
pad_to = padding_[ax]

Expand Down
4 changes: 4 additions & 0 deletions tests/test_image_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ def test_transform_output_tensor():
for out_axes in out_ax_list:
out = transform_output_tensor(tensor, tensor_axes, out_axes)
assert out.ndim == len(out_axes)


def test_save_image():
assert False