diff --git a/ome_zarr/reader.py b/ome_zarr/reader.py index ccc4ea01..d314cc74 100644 --- a/ome_zarr/reader.py +++ b/ome_zarr/reader.py @@ -2,6 +2,7 @@ import logging import math +import os from abc import ABC from typing import Any, Dict, Iterator, List, Optional, Type, Union, cast, overload @@ -26,7 +27,7 @@ def __init__( zarr: ZarrLocation, root: Union["Node", "Reader", List[ZarrLocation]], visibility: bool = True, - plate_labels: bool = False, + # plate_labels: bool = False, ): self.zarr = zarr self.root = root @@ -53,11 +54,11 @@ def __init__( self.specs.append(Multiscales(self)) if OMERO.matches(zarr): self.specs.append(OMERO(self)) - if plate_labels: + # if plate_labels: + if PlateLabels.matches(zarr): self.specs.append(PlateLabels(self)) elif Plate.matches(zarr): self.specs.append(Plate(self)) - # self.add(zarr, plate_labels=True) if Well.matches(zarr): self.specs.append(Well(self)) @@ -136,7 +137,7 @@ def add( visibility = self.visible self.seen.append(zarr) - node = Node(zarr, self, visibility=visibility, plate_labels=plate_labels) + node = Node(zarr, self, visibility=visibility) if prepend: self.pre_nodes.append(node) else: @@ -474,19 +475,18 @@ def matches(zarr: ZarrLocation) -> bool: def __init__(self, node: Node) -> None: super().__init__(node) + LOGGER.debug("Plate created with ZarrLocation fmt: %s", self.zarr.fmt) - self.get_pyramid_lazy(node) - def get_pyramid_lazy(self, node: Node) -> None: - """ - Return a pyramid of dask data, where the highest resolution is the - stitched full-resolution images. - """ - self.plate_data = self.lookup("plate", {}) + self.first_field = "0" + # For Plate, plate_zarr is same as self.zarr, but for PlateLabels + # (node at /plate.zarr/labels) this is the parent at /plate.zarr node. + self.plate_zarr = self.get_plate_zarr() + self.plate_data = self.plate_zarr.root_attrs.get("plate", {}) + LOGGER.info("plate_data: %s", self.plate_data) self.rows = self.plate_data.get("rows") self.columns = self.plate_data.get("columns") - self.first_field = "0" self.row_names = [row["name"] for row in self.rows] self.col_names = [col["name"] for col in self.columns] @@ -496,40 +496,59 @@ def get_pyramid_lazy(self, node: Node) -> None: self.row_count = len(self.rows) self.column_count = len(self.columns) - # Get the first well... - well_zarr = self.zarr.create(self.well_paths[0]) - well_node = Node(well_zarr, node) - well_spec: Optional[Well] = well_node.first(Well) - if well_spec is None: - raise Exception("Could not find first well") - self.numpy_type = well_spec.numpy_type + img_path = self.get_image_path(self.well_paths[0]) + if not img_path: + # E.g. PlateLabels subclass has no Labels + return + image_zarr = self.plate_zarr.create(img_path) + # Create a Node for image, with no 'root' + self.first_well_image = Node(image_zarr, []) + + self.get_pyramid_lazy(node) - LOGGER.debug("img_pyramid_shapes: %s", well_spec.img_pyramid_shapes) + # Load possible node data IF this is a Plate + if Plate.matches(self.zarr): + child_zarr = self.zarr.create("labels") + # This is a 'virtual' path to plate.zarr/labels + node.add(child_zarr) - self.axes = well_spec.img_metadata["axes"] + def get_plate_zarr(self) -> ZarrLocation: + return self.zarr + + def get_pyramid_lazy(self, node: Node) -> None: + """ + Return a pyramid of dask data, where the highest resolution is the + stitched full-resolution images. + """ + + # Use the first well for dtype and shapes + img_data = self.first_well_image.data + img_pyramid_shapes = [d.shape for d in img_data] + level = 0 + + LOGGER.debug("img_pyramid_shapes: %s", img_pyramid_shapes) # Create a dask pyramid for the plate pyramid = [] - for level, tile_shape in enumerate(well_spec.img_pyramid_shapes): + for level, tile_shape in enumerate(img_pyramid_shapes): + self.numpy_type = img_data[level].dtype lazy_plate = self.get_stitched_grid(level, tile_shape) pyramid.append(lazy_plate) # Set the node.data to be pyramid view of the plate node.data = pyramid # Use the first image's metadata for viewing the whole Plate - node.metadata = well_spec.img_metadata + node.metadata = self.first_well_image.metadata # "metadata" dict gets added to each 'plate' layer in napari node.metadata.update({"metadata": {"plate": self.plate_data}}) - def get_numpy_type(self, image_node: Node) -> np.dtype: - return image_node.data[0].dtype + def get_image_path(self, well_path: str) -> Optional[str]: + return f"{well_path}/{self.first_field}/" def get_tile_path(self, level: int, row: int, col: int) -> str: - return ( - f"{self.row_names[row]}/" - f"{self.col_names[col]}/{self.first_field}/{level}" - ) + well_path = f"{self.row_names[row]}/{self.col_names[col]}" + return f"{self.get_image_path(well_path)}{level}/" def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array: LOGGER.debug("get_stitched_grid() level: %s, tile_shape: %s", level, tile_shape) @@ -541,9 +560,9 @@ def get_tile(tile_name: str) -> np.ndarray: LOGGER.debug("LOADING tile... %s with shape: %s", path, tile_shape) try: - data = self.zarr.load(path) + data = self.plate_zarr.load(path) except ValueError: - LOGGER.exception("Failed to load %s", path) + LOGGER.error("Failed to load %s", path) data = np.zeros(tile_shape, dtype=self.numpy_type) return data @@ -559,53 +578,69 @@ def get_tile(tile_name: str) -> np.ndarray: lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type ) lazy_row.append(lazy_tile) - lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1)) - return da.concatenate(lazy_rows, axis=len(self.axes) - 2) + lazy_rows.append(da.concatenate(lazy_row, axis=len(tile_shape) - 1)) + return da.concatenate(lazy_rows, axis=len(tile_shape) - 2) class PlateLabels(Plate): - def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover - """251.zarr/A/1/0/labels/0/3/""" - path = ( - f"{self.row_names[row]}/{self.col_names[col]}/" - f"{self.first_field}/labels/0/{level}" - ) - return path - - def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover - super().get_pyramid_lazy(node) - # pyramid data may be multi-channel, but we only have 1 labels channel - # TODO: when PlateLabels are re-enabled, update the logic to handle - # 0.4 axes (list of dictionaries) - if "c" in self.axes: - c_index = self.axes.index("c") - idx = [slice(None)] * len(self.axes) - idx[c_index] = slice(0, 1) - node.data[0] = node.data[0][tuple(idx)] + @staticmethod + def matches(zarr: ZarrLocation) -> bool: + # If the path ends in plate/labels... + if not zarr.path.endswith("labels"): + return False + + # and the parent is a plate + parent_path = os.path.dirname(zarr.path) + parent = zarr.create(parent_path) + return "plate" in parent.root_attrs + + def __init__(self, node: Node) -> None: + # cache well/image/labels/.zattrs for first field of each well. Key is e.g. A/1 + self.well_labels_zattrs: Dict[str, Dict] = {} + super().__init__(node) + # remove image metadata - node.metadata = {} + # node.metadata = {} # combine 'properties' from each image # from https://github.com/ome/ome-zarr-py/pull/61/ properties: Dict[int, Dict[str, Any]] = {} - for row in self.row_names: - for col in self.col_names: - path = f"{row}/{col}/{self.first_field}/labels/0/.zattrs" - labels_json = self.zarr.get_json(path).get("image-label", {}) - # NB: assume that 'label_val' is unique across all images - props_list = labels_json.get("properties", []) - if props_list: - for props in props_list: - label_val = props["label-value"] - properties[label_val] = dict(props) - del properties[label_val]["label-value"] + for well_path in self.well_paths: + path = self.get_image_path(well_path) + if not path: + continue + labels_json = self.zarr.get_json(path + ".zattrs").get("image-label", {}) + # NB: assume that 'label_val' is unique across all images + props_list = labels_json.get("properties", []) + if props_list: + for props in props_list: + label_val = props["label-value"] + properties[label_val] = dict(props) + del properties[label_val]["label-value"] node.metadata["properties"] = properties - def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover - # FIXME - don't assume Well A1 is valid - path = self.get_tile_path(0, 0, 0) - label_zarr = self.zarr.load(path) - return label_zarr.dtype + def get_plate_zarr(self) -> ZarrLocation: + # lookup parent plate, remove the /labels + parent_path = os.path.dirname(self.zarr.path) + return self.zarr.create(parent_path) + + def get_image_path(self, well_path: str) -> Optional[str]: + """Returns path to .zattr for Well labels, e.g. /A/1/0/labels/my_cells/""" + labels_attrs = self.well_labels_zattrs.get(well_path) + if labels_attrs is None: + # if not cached, load... + path = f"{well_path}/{self.first_field}/labels/" + LOGGER.info("loading labels/.zattrs: %s.zattrs", path) + plate_zarr = self.get_plate_zarr() + first_field_labels = plate_zarr.create(path) + # loads labels/.zattrs when new ZarrLocation is created + labels_attrs = first_field_labels.root_attrs + self.well_labels_zattrs[well_path] = labels_attrs + label_paths = labels_attrs.get("labels", []) + LOGGER.debug("label_paths: %s", label_paths) + if len(label_paths) > 0: + return f"{well_path}/{self.first_field}/labels/{label_paths[0]}/" + return None class Reader: diff --git a/tests/test_reader.py b/tests/test_reader.py index 556a2087..d3611c75 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -4,8 +4,13 @@ from ome_zarr.data import create_zarr from ome_zarr.io import parse_url -from ome_zarr.reader import Node, Plate, Reader -from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata +from ome_zarr.reader import Node, Plate, PlateLabels, Reader +from ome_zarr.writer import ( + write_image, + write_labels, + write_plate_metadata, + write_well_metadata, +) class TestReader: @@ -68,8 +73,7 @@ def test_minimal_plate(self): reader = Reader(parse_url(str(self.path))) nodes = list(reader()) - # currently reading plate labels disabled. Only 1 node - assert len(nodes) == 1 + assert len(nodes) == 2 assert len(nodes[0].specs) == 1 assert isinstance(nodes[0].specs[0], Plate) # assert len(nodes[1].specs) == 1 @@ -87,13 +91,19 @@ def test_multiwells_plate(self): write_well_metadata(well, ["0", "1", "2"]) for field in range(3): image = well.require_group(str(field)) - write_image(zeros((1, 1, 1, 256, 256)), image) + write_image(zeros((256, 256)), image) + + write_labels(zeros((256, 256)), image, name="test_labels") reader = Reader(parse_url(str(self.path))) nodes = list(reader()) - # currently reading plate labels disabled. Only 1 node - assert len(nodes) == 1 + assert len(nodes) == 2 assert len(nodes[0].specs) == 1 assert isinstance(nodes[0].specs[0], Plate) - # assert len(nodes[1].specs) == 1 - # assert isinstance(nodes[1].specs[0], PlateLabels) + assert len(nodes[1].specs) == 1 + assert isinstance(nodes[1].specs[0], PlateLabels) + # plate shape is the single image * grid dimensions + plate_shape = (256 * len(row_names), 256 * len(col_names)) + # check largest data for image and labels + assert nodes[0].data[0].shape == plate_shape + assert nodes[1].data[0].shape == plate_shape