Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plate labels fix #207

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 59 additions & 62 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.specs.append(PlateLabels(self))
elif Plate.matches(zarr):
self.specs.append(Plate(self))
# self.add(zarr, plate_labels=True)
self.add(zarr, plate_labels=True)
if Well.matches(zarr):
self.specs.append(Well(self))

Expand Down Expand Up @@ -465,18 +465,17 @@ def matches(zarr: ZarrLocation) -> bool:
def __init__(self, node: Node) -> None:
super().__init__(node)
LOGGER.debug(f"Plate created with ZarrLocation fmt:{ self.zarr.fmt}")
self.get_pyramid_lazy(node)

def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""
self.first_field = "0"
self.plate_data = self.lookup("plate", {})
first_well_path = self.plate_data["wells"][0]["path"]
image_zarr = self.zarr.create(self.get_image_path(first_well_path))
# Create a Node for image, with no 'root'
self.first_well_image = Node(image_zarr, [])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joshmoore @sbesson Here I'm creating a Node for an Image as I want a Multiscales spec to do the parsing of the datasets, to give me the sizes of each resolution of the pyramid. However, this is leading to recursion errors in the tests (although it seems to work fine for me viewing local data in napari).

I don't see a way to reuse that logic in the Multiscales spec without creating a Node. But I don't need any Node traversing logic. I guess I want to create a Multiscales spec with no node. Should I make the node optional in the Spec class, or is there another way I should be thinking about this?

NB: get_image_path() is overridden by the PlateLabels subclass to point to a labels image, so this works both for images in Wells and their child labels.

Copy link
Member

@joshmoore joshmoore Jun 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for having missed this. Without stepping through your code, I don't know offhand why you're getting the recursion. I assume that the seen field is not getting updated and therefore it just keeps looping. Perhaps that's caused by not setting the root.

What functionality do you want from the Spec without a Node? Could we just refactor that logic somewhere re-usable? (Perhaps statically?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I just need the data.shape of each resolution of the multiscales pyramid, and the dtype.
Before my last commit d731b87 I was passing in a root node (plate node) but that was also giving me recursion errors:

self.first_well_image = Node(image_zarr, node)

I'll try stepping through my code and see if I can work out the loop...


LOGGER.info("plate_data: %s", self.plate_data)
self.rows = self.plate_data.get("rows")
self.columns = self.plate_data.get("columns")
self.first_field = "0"
self.row_names = [row["name"] for row in self.rows]
self.col_names = [col["name"] for col in self.columns]

Expand All @@ -486,40 +485,42 @@ def get_pyramid_lazy(self, node: Node) -> None:
self.row_count = len(self.rows)
self.column_count = len(self.columns)

# Get the first well...
well_zarr = self.zarr.create(self.well_paths[0])
well_node = Node(well_zarr, node)
well_spec: Optional[Well] = well_node.first(Well)
if well_spec is None:
raise Exception("could not find first well")
self.numpy_type = well_spec.numpy_type
self.get_pyramid_lazy(node)

LOGGER.debug(f"img_pyramid_shapes: {well_spec.img_pyramid_shapes}")
def get_pyramid_lazy(self, node: Node) -> None:
"""
Return a pyramid of dask data, where the highest resolution is the
stitched full-resolution images.
"""

# Use the first well for dtype and shapes
img_data = self.first_well_image.data
img_pyramid_shapes = [d.shape for d in img_data]
level = 0
self.numpy_type = img_data[level].dtype

self.axes = well_spec.img_metadata["axes"]
LOGGER.debug(f"img_pyramid_shapes: {img_pyramid_shapes}")

# Create a dask pyramid for the plate
pyramid = []
for level, tile_shape in enumerate(well_spec.img_pyramid_shapes):
for level, tile_shape in enumerate(img_pyramid_shapes):
lazy_plate = self.get_stitched_grid(level, tile_shape)
pyramid.append(lazy_plate)

# Set the node.data to be pyramid view of the plate
node.data = pyramid
# Use the first image's metadata for viewing the whole Plate
node.metadata = well_spec.img_metadata
node.metadata = self.first_well_image.metadata

# "metadata" dict gets added to each 'plate' layer in napari
node.metadata.update({"metadata": {"plate": self.plate_data}})

def get_numpy_type(self, image_node: Node) -> np.dtype:
return image_node.data[0].dtype
def get_image_path(self, well_path: str) -> str:
return f"{well_path}/{self.first_field}/"

def get_tile_path(self, level: int, row: int, col: int) -> str:
return (
f"{self.row_names[row]}/"
f"{self.col_names[col]}/{self.first_field}/{level}"
)
well_path = f"{self.row_names[row]}/{self.col_names[col]}"
return f"{self.get_image_path(well_path)}{level}/"

def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug(f"get_stitched_grid() level: {level}, tile_shape: {tile_shape}")
Expand Down Expand Up @@ -550,53 +551,49 @@ def get_tile(tile_name: str) -> np.ndarray:
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)
lazy_rows.append(da.concatenate(lazy_row, axis=len(tile_shape) - 1))
return da.concatenate(lazy_rows, axis=len(tile_shape) - 2)


class PlateLabels(Plate):
def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover
"""251.zarr/A/1/0/labels/0/3/"""
path = (
f"{self.row_names[row]}/{self.col_names[col]}/"
f"{self.first_field}/labels/0/{level}"
)
return path

def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover
super().get_pyramid_lazy(node)
# pyramid data may be multi-channel, but we only have 1 labels channel
# TODO: when PlateLabels are re-enabled, update the logic to handle
# 0.4 axes (list of dictionaries)
if "c" in self.axes:
c_index = self.axes.index("c")
idx = [slice(None)] * len(self.axes)
idx[c_index] = slice(0, 1)
node.data[0] = node.data[0][tuple(idx)]
def __init__(self, node: Node) -> None:
# cache well/image/labels/.zattrs for first field of each well. Key is e.g. A/1
self.well_labels_zattrs: Dict[str, Dict] = {}
super().__init__(node)

# remove image metadata
node.metadata = {}
# node.metadata = {}

# combine 'properties' from each image
# from https://github.com/ome/ome-zarr-py/pull/61/
properties: Dict[int, Dict[str, Any]] = {}
for row in self.row_names:
for col in self.col_names:
path = f"{row}/{col}/{self.first_field}/labels/0/.zattrs"
labels_json = self.zarr.get_json(path).get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
for well_path in self.well_paths:
path = self.get_image_path(well_path) + ".zattrs"
labels_json = self.zarr.get_json(path).get("image-label", {})
# NB: assume that 'label_val' is unique across all images
props_list = labels_json.get("properties", [])
if props_list:
for props in props_list:
label_val = props["label-value"]
properties[label_val] = dict(props)
del properties[label_val]["label-value"]
node.metadata["properties"] = properties

def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover
# FIXME - don't assume Well A1 is valid
path = self.get_tile_path(0, 0, 0)
label_zarr = self.zarr.load(path)
return label_zarr.dtype
def get_image_path(self, well_path: str) -> str:
"""Returns path to .zattr for Well labels, e.g. /A/1/0/labels/my_cells/"""
labels_attrs = self.well_labels_zattrs.get(well_path)
if labels_attrs is None:
# if not cached, load...
path = f"{well_path}/{self.first_field}/labels/"
LOGGER.info("loading labels/.zattrs: %s.zattrs", path)
first_field_labels = self.zarr.create(path)
# loads labels/.zattrs when new ZarrLocation is created
labels_attrs = first_field_labels.root_attrs
self.well_labels_zattrs[well_path] = labels_attrs
label_paths = labels_attrs.get("labels", [])
if len(label_paths) > 0:
return f"{well_path}/{self.first_field}/labels/{label_paths[0]}/"
return ""


class Reader:
Expand Down