diff --git a/ome_zarr.py b/ome_zarr.py index a18307b1..5af84c08 100644 --- a/ome_zarr.py +++ b/ome_zarr.py @@ -102,6 +102,20 @@ def is_zarr(self): def is_ome_zarr(self): return self.zgroup and "multiscales" in self.root_attrs + def has_ome_masks(self): + "Does the zarr Image also include /masks sub-dir" + return self.get_json('masks/.zgroup') + + def is_ome_mask(self): + return self.zarr_path.endswith('masks/') and self.get_json('.zgroup') + + def get_mask_names(self): + """ + Called if is_ome_mask is true + """ + # If this is a mask, the names are in root .zattrs + return self.root_attrs.get('masks', []) + def get_json(self, subpath): raise NotImplementedError("unknown") @@ -110,6 +124,10 @@ def get_reader_function(self): raise Exception(f"not a zarr: {self}") return self.reader_function + def to_rgba(self, v): + """Get rgba (0-1) e.g. (1, 0.5, 0, 1) from integer""" + return [x/255 for x in v.to_bytes(4, signed=True, byteorder='big')] + def reader_function(self, path: PathLike) -> List[LayerData]: """Take a path or list of paths and return a list of LayerData tuples.""" @@ -118,12 +136,22 @@ def reader_function(self, path: PathLike) -> List[LayerData]: # TODO: safe to ignore this path? if self.is_ome_zarr(): - return [self.load_ome_zarr()] + layers = [self.load_ome_zarr()] + # If the Image contains masks... + if self.has_ome_masks(): + mask_path = os.path.join(self.zarr_path, 'masks') + # Create a new OME Zarr Reader to load masks + masks = self.__class__(mask_path).reader_function(None) + layers.extend(masks) + return layers elif self.zarray: data = da.from_zarr(f"{self.zarr_path}") return [(data,)] + elif self.is_ome_mask(): + return self.load_ome_masks() + def load_omero_metadata(self, assert_channel_count=None): """Load OMERO metadata as json and convert for napari""" metadata = {} @@ -191,7 +219,6 @@ def load_omero_metadata(self, assert_channel_count=None): return metadata - def load_ome_zarr(self): resolutions = ["0"] # TODO: could be first alphanumeric dataset on err @@ -219,6 +246,25 @@ def load_ome_zarr(self): return (pyramid, {'channel_axis': 1, **metadata}) + def load_ome_masks(self): + # look for masks in this dir... + mask_names = self.get_mask_names() + masks = [] + for name in mask_names: + mask_path = os.path.join(self.zarr_path, name) + mask_attrs = self.get_json(f'{name}/.zattrs') + colors = {} + if 'color' in mask_attrs: + color_dict = mask_attrs.get('color') + colors = {int(k):self.to_rgba(v) for (k, v) in color_dict.items()} + data = da.from_zarr(mask_path) + # Split masks into separate channels, 1 per layer + for n in range(data.shape[1]): + masks.append((data[:,n,:,:,:], + {'name': name, 'color': colors}, + 'labels')) + return masks + class LocalZarr(BaseZarr): @@ -231,7 +277,6 @@ def get_json(self, subpath): with open(filename) as f: return json.loads(f.read()) - class RemoteZarr(BaseZarr): def get_json(self, subpath): @@ -249,7 +294,6 @@ def get_json(self, subpath): LOGGER.error(f"({rsp.status_code}): {rsp.text}") return {} - def info(path): """ print information about the ome-zarr fileset