Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Downsample scale option #166

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/omero_zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ def _configure(self, parser: Parser) -> None:
"overlapping labels"
),
)
masks.add_argument(
"--ds_scale",
type=str,
default=None,
help="Downsample scale factors, e.g. 1,1,2,2,2, omitting Image"
" dimensions of size 1 which will be squeezed out of the exported labels",
)

export = parser.add(sub, self.export, EXPORT_HELP)
export.add_argument(
Expand Down Expand Up @@ -283,6 +290,13 @@ def _configure(self, parser: Parser) -> None:
type=ProxyStringType("Image"),
help="The Image to export.",
)
export.add_argument(
"--ds_scale",
type=str,
default=None,
help="Downsample scale factors, e.g. 1,1,2,2,2, omitting dimensions where "
"the image is size 1, since they will be squeezed from the exported Image",
)

for subcommand in (polygons, masks, export):
subcommand.add_argument(
Expand Down
53 changes: 36 additions & 17 deletions src/omero_zarr/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@
import time
from collections import defaultdict
from fileinput import input as finput
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import omero.clients # noqa
from ome_zarr.conversions import int_to_rgba_255
from ome_zarr.io import parse_url
from ome_zarr.reader import Multiscales, Node
from ome_zarr.scale import Scaler
from ome_zarr.types import JSONDict
from ome_zarr.writer import write_multiscale_labels
from ome_zarr.writer import write_multiscales_metadata
from omero.model import MaskI, PolygonI
from omero.rtypes import unwrap
from skimage.draw import polygon as sk_polygon
from zarr.convenience import save_array
from zarr.hierarchy import open_group

from .raw_pixels import downsample_pyramid_on_disk
from .util import marshal_axes, marshal_transformations, open_store, print_status

LOGGER = logging.getLogger("omero_zarr.masks")
Expand Down Expand Up @@ -184,6 +185,7 @@ def image_shapes_to_zarr(
args.style,
args.source_image,
args.overlaps,
args.ds_scale,
)

if args.style == "split":
Expand Down Expand Up @@ -218,10 +220,13 @@ def __init__(
style: str = "labeled",
source: str = "..",
overlaps: str = "error",
ds_scale: Union[str, None] = None,
) -> None:
self.dtype = dtype
self.path = path
self.style = style
if ds_scale is not None:
self.ds_scale = [int(scale) for scale in ds_scale.split(",")]
self.source_image = source
self.plate = plate
self.plate_path = Optional[str]
Expand Down Expand Up @@ -310,15 +315,14 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None:
assert src, "Source image does not exist"
input_pyramid = Node(src, [])
assert input_pyramid.load(Multiscales), "No multiscales metadata found"
input_pyramid_levels = len(input_pyramid.data)

store = open_store(filename)
root = open_group(store)

if self.plate:
label_group = root.require_group(self.plate_path)
labels_group = root.require_group(self.plate_path)
else:
label_group = root
labels_group = root.require_group("labels")

_mask_shape: List[int] = list(self.image_shape)
for d in ignored_dimensions:
Expand Down Expand Up @@ -347,10 +351,6 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None:
dims_to_squeeze.append(dim)
labels = np.squeeze(labels, axis=tuple(dims_to_squeeze))

scaler = Scaler(max_layer=input_pyramid_levels)
label_pyramid = scaler.nearest(labels)
transformations = marshal_transformations(self.image, levels=len(label_pyramid))

# Specify and store metadata
image_label_colors: List[JSONDict] = []
label_properties: List[JSONDict] = []
Expand All @@ -369,13 +369,32 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None:
{"label-value": label_value, "rgba": int_to_rgba_255(rgba_int)}
)

write_multiscale_labels(
label_pyramid,
label_group,
name,
axes=axes,
coordinate_transformations=transformations,
label_metadata=image_label,
# Target size for smallest multiresolution
TARGET_SIZE = 96
level_count = 1
longest = max(self.image_shape[-1], self.image_shape[-2])
while longest > TARGET_SIZE:
longest = longest // 2
level_count += 1
paths = [str(level) for level in range(level_count)]

axes = marshal_axes(self.image)
transformations = marshal_transformations(self.image, len(paths), self.ds_scale)

datasets: List[Dict[Any, Any]] = [{"path": path} for path in paths]
for dataset, transform in zip(datasets, transformations):
dataset["coordinateTransformations"] = transform

label_group = labels_group.require_group("0")
labels_group.attrs["labels"] = ["0"]
save_array(store, labels, path="labels/0/0")

label_group.attrs["image-label"] = image_label

downsample_pyramid_on_disk(label_group, paths, ds_scale=self.ds_scale)

write_multiscales_metadata(
label_group, datasets, axes=axes, ds_scale=self.ds_scale
)

def shape_to_binim_yx(
Expand Down
36 changes: 26 additions & 10 deletions src/omero_zarr/raw_pixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,17 @@ def image_to_zarr(image: omero.gateway.ImageWrapper, args: argparse.Namespace) -
target_dir = args.output
tile_width = args.tile_width
tile_height = args.tile_height
ds_scale = None
if args.ds_scale:
ds_scale = [int(x) for x in args.ds_scale.split(",")]

name = os.path.join(target_dir, "%s.zarr" % image.id)
print(f"Exporting to {name} ({VERSION})")
store = open_store(name)
root = open_group(store)
add_image(image, root, tile_width=tile_width, tile_height=tile_height)
add_image(
image, root, tile_width=tile_width, tile_height=tile_height, ds_scale=ds_scale
)
add_omero_metadata(root, image)
add_toplevel_metadata(root)
print("Finished.")
Expand All @@ -71,6 +76,7 @@ def add_image(
parent: Group,
tile_width: Optional[int] = None,
tile_height: Optional[int] = None,
ds_scale: Optional[List[int]] = None,
) -> Tuple[int, List[Dict[str, Any]]]:
"""Adds an OMERO image pixel data as array to the given parent zarr group.
Returns the number of resolution levels generated for the image.
Expand All @@ -87,16 +93,16 @@ def add_image(
longest = longest // 2
level_count += 1

paths = add_raw_image(image, parent, level_count, tile_width, tile_height)
paths = add_raw_image(image, parent, level_count, tile_width, tile_height, ds_scale)

axes = marshal_axes(image)
transformations = marshal_transformations(image, len(paths))
transformations = marshal_transformations(image, len(paths), ds_scale)

datasets: List[Dict[Any, Any]] = [{"path": path} for path in paths]
for dataset, transform in zip(datasets, transformations):
dataset["coordinateTransformations"] = transform

write_multiscales_metadata(parent, datasets, axes=axes)
write_multiscales_metadata(parent, datasets, axes=axes, ds_scale=ds_scale)

return (level_count, axes)

Expand All @@ -107,6 +113,7 @@ def add_raw_image(
level_count: int,
tile_width: Optional[int] = None,
tile_height: Optional[int] = None,
ds_scale: Optional[List[int]] = None,
) -> List[str]:
pixels = image.getPrimaryPixels()
omero_dtype = image.getPixelsType()
Expand Down Expand Up @@ -198,14 +205,18 @@ def add_raw_image(

paths = [str(level) for level in range(level_count)]

downsample_pyramid_on_disk(parent, paths)
downsample_pyramid_on_disk(parent, paths, ds_scale)
return paths


def downsample_pyramid_on_disk(parent: Group, paths: List[str]) -> List[str]:
def downsample_pyramid_on_disk(
parent: Group, paths: List[str], ds_scale: Optional[List[int]] = None
) -> List[str]:
"""
Takes a high-resolution Zarr array at paths[0] in the zarr group
and down-samples it by a factor of 2 for each of the other paths
and down-samples it by a factor of 2 for each of the other paths by default.
If ds_scale is provided, it will down-sample by the specified factor for each
dimension. e.g. [1, 1, 2, 2, 2]
"""
group_path = parent.store.path
image_path = os.path.join(group_path, parent.path)
Expand All @@ -219,10 +230,15 @@ def downsample_pyramid_on_disk(parent: Group, paths: List[str]) -> List[str]:
path_to_array = os.path.join(image_path, paths[count])
dask_image = da.from_zarr(path_to_array)

# resize in X and Y
dims = list(dask_image.shape)
dims[-1] = dims[-1] // 2
dims[-2] = dims[-2] // 2
# downsample as specified...
if ds_scale is not None:
for dim, dim_scale in enumerate(ds_scale):
dims[dim] = dims[dim] // dim_scale
else:
# resize in X and Y by default
dims[-1] = dims[-1] // 2
dims[-2] = dims[-2] // 2
output = da_resize(
dask_image, tuple(dims), preserve_range=True, anti_aliasing=False
)
Expand Down
18 changes: 12 additions & 6 deletions src/omero_zarr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import time
from typing import Dict, List
from typing import Dict, List, Optional

from omero.gateway import ImageWrapper
from zarr.storage import FSStore
Expand Down Expand Up @@ -105,7 +105,7 @@ def marshal_axes(image: ImageWrapper) -> List[Dict]:


def marshal_transformations(
image: ImageWrapper, levels: int = 1, multiscales_zoom: float = 2.0
image: ImageWrapper, levels: int = 1, ds_scale: Optional[List[int]] = None
) -> List[List[Dict]]:
axes = marshal_axes(image)
pixel_sizes = marshal_pixel_sizes(image)
Expand All @@ -114,7 +114,6 @@ def marshal_transformations(
transformations = []
zooms = {"x": 1.0, "y": 1.0, "z": 1.0, "c": 1.0, "t": 1.0}
for level in range(levels):
# {"type": "scale", "scale": [1, 1, 0.3, 0.5, 0.5]
scales = []
for index, axis in enumerate(axes):
pixel_size = 1
Expand All @@ -123,8 +122,15 @@ def marshal_transformations(
scales.append(zooms[axis["name"]] * pixel_size)
# ...with a single 'scale' transformation each
transformations.append([{"type": "scale", "scale": scales}])
# NB we rescale X and Y for each level, but not Z, C, T
zooms["x"] = zooms["x"] * multiscales_zoom
zooms["y"] = zooms["y"] * multiscales_zoom

if ds_scale is None:
# NB we rescale X and Y for each level, but not Z, C, T
multiscales_zoom = 2.0
zooms["x"] = zooms["x"] * multiscales_zoom
zooms["y"] = zooms["y"] * multiscales_zoom
else:
assert len(ds_scale) == len(axes)
for axis, scale in zip(axes, ds_scale):
zooms[axis["name"]] = zooms[axis["name"]] * scale

return transformations
Loading