diff --git a/README.md b/README.md index e7f6a245..a4181082 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# CellSeg3D: self-supervised (and supervised) 3D cell segmentation, primarily for mesoSPIM data! +# CellSeg3D: self-supervised (and supervised) 3D cell segmentation, primarily for mesoSPIM data! [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-cellseg3d)](https://www.napari-hub.org/plugins/napari-cellseg3d) [![PyPI](https://img.shields.io/pypi/v/napari-cellseg3d.svg?color=green)](https://pypi.org/project/napari-cellseg3d) [![Downloads](https://static.pepy.tech/badge/napari-cellseg3d)](https://pepy.tech/project/napari-cellseg3d) @@ -59,11 +59,17 @@ The strength of our approach is we can match supervised model performance with p ## News -**New version: v0.2.0** +**New version: v0.2.1** -- Changed project name to "napari_cellseg3d" to avoid setuptools deprecation -- Small API changes for training/inference from a script -- Some fixes to WandB integration ad csv saving after training +- v0.2.1: + - Updated plugin default behaviors across the board to be more readily applicable to demo data + - Threshold value in inference is now automatically set by default according to performance on demo data on a per-model basis + - Added a grid search utility to find best thresholds for supervised models + +- v0.2.0: + - Changed project name to "napari_cellseg3d" to avoid setuptools deprecation + - Small API changes for training/inference from a script + - Some fixes to WandB integration and csv saving after training Previous additions: diff --git a/docs/source/guides/inference_module_guide.rst b/docs/source/guides/inference_module_guide.rst index e209486f..7a986be5 100644 --- a/docs/source/guides/inference_module_guide.rst +++ b/docs/source/guides/inference_module_guide.rst @@ -50,13 +50,13 @@ Interface and functionalities Inference parameters -* **Loading data** : +* **Loading data**: | When launching the module, select either an **image layer** or an **image folder** containing the 3D volumes you wish to label. | When loading from folder : All images with the chosen extension ( currently **.tif**) will be labeled. | Specify an **output folder**, where the labelled results will be saved. -* **Model selection** : +* **Model selection**: | You can then choose from the listed **models** for inference. | You may also **load custom weights** rather than the pre-trained ones. Make sure these weights are **compatible** (e.g. produced from the training module for the same model). @@ -66,7 +66,7 @@ Interface and functionalities Currently the SegResNet and SwinUNetR models require you to provide the size of the images the model was trained with. Provided weights use a size of 64, please leave it on the default value if you're not using custom weights. -* **Inference parameters** : +* **Inference parameters**: * **Window inference**: You can choose to use inference on the entire image at once (disabled) or divide the image (enabled) on smaller chunks, based on your memory constraints. * **Window overlap**: Define the overlap between windows to reduce border effects; @@ -74,11 +74,11 @@ Interface and functionalities * **Keep on CPU**: You can choose to keep the dataset in RAM rather than VRAM to avoid running out of VRAM if you have several images. * **Device Selection**: You can choose to run inference on either CPU or GPU. A GPU is recommended for faster inference. -* **Anisotropy** : +* **Anisotropy**: For **anisotropic images** you may set the **resolution of your volume in micron**, to view and save the results without anisotropy. -* **Thresholding** : +* **Thresholding**: You can perform thresholding to **binarize your labels**. All values below the **confidence threshold** will be set to 0. @@ -87,7 +87,7 @@ Interface and functionalities It is recommended to first run without thresholding. You can then use the napari contrast limits to find a good threshold value, and run inference later with your chosen threshold. -* **Instance segmentation** : +* **Instance segmentation**: | You can convert the semantic segmentation into instance labels by using either the `Voronoi-Otsu`_, `Watershed`_ or `Connected Components`_ method, as detailed in :ref:`utils_module_guide`. | Instance labels will be saved (and shown if applicable) separately from other results. @@ -98,7 +98,7 @@ Interface and functionalities .. _Voronoi-Otsu: https://haesleinhuepf.github.io/BioImageAnalysisNotebooks/20_image_segmentation/11_voronoi_otsu_labeling.html -* **Computing objects statistics** : +* **Computing objects statistics**: You can choose to compute various stats from the labels and save them to a **`.csv`** file for later use. Statistics include individual object details and general metrics. @@ -109,7 +109,7 @@ Interface and functionalities * Sphericity - Global metrics : + Global metrics: * Image size * Total image volume (pixels) @@ -118,7 +118,7 @@ Interface and functionalities * The number of labeled objects -* **Display options** : +* **Display options**: When running inference on a folder, you can choose to display the results in napari. If selected, you may choose the display quantity, and whether to display the original image alongside the results. @@ -151,7 +151,16 @@ Unsupervised model - WNet3D | The `WNet3D model` is a fully self-supervised model used to segment images without any labels. | It functions similarly to the above models, with a few notable differences. -.. _WNet3D model: https://arxiv.org/abs/1711.08506 +WNet3D has been tested on: + +* **MesoSPIM** data (whole-brain samples of mice imaged by mesoSPIM microscopy) with nuclei staining. +* Other microscopy (i.e., confocal) data with: + * **Sufficient contrast** between objects and background. + * **Low to medium crowding** of objects. If all objects are adjacent to each other, instance segmentation methods provided here may not work well. + +Noise and object size are less critical, though objects still have to fit within the field of view of the model. + +.. _WNet3D model: https://elifesciences.org/reviewed-preprints/99848 .. note:: Our provided, pre-trained model uses an input size of 64x64x64. As such, window inference is always enabled diff --git a/docs/source/guides/utils_module_guide.rst b/docs/source/guides/utils_module_guide.rst index 0dc0fbf0..fe458d99 100644 --- a/docs/source/guides/utils_module_guide.rst +++ b/docs/source/guides/utils_module_guide.rst @@ -11,6 +11,21 @@ See `Usage section 0: for slider in self.sliders: - self.recorded_parameters[ - slider.label.text() - ] = slider.slider_value + self.recorded_parameters[slider.label.text()] = ( + slider.slider_value + ) if len(self.counters) > 0: for counter in self.counters: - self.recorded_parameters[ - counter.label.text() - ] = counter.value() + self.recorded_parameters[counter.label.text()] = ( + counter.value() + ) def run_method_from_params(self, image): """Runs the method on the image with the RECORDED parameters set in the widget. @@ -327,10 +328,8 @@ def binary_connected( ) semantic = np.squeeze(volume) foreground = np.where(semantic > thres, volume, 0) # int(255 * thres) - segm = label(foreground) - segm = remove_small_objects(segm, thres_small) - - return segm + seg = label(foreground) + return remove_small_objects(seg, thres_small) def binary_watershed( @@ -419,15 +418,10 @@ def clear_small_objects(image, threshold, is_file_path=False): if is_file_path: image = imread(image) - # print(threshold) - labeled = label(image) result = remove_small_objects(labeled, threshold) - # print(np.sum(labeled)) - # print(np.sum(result)) - if np.sum(labeled) == np.sum(result): print("Warning : no objects were removed") @@ -551,9 +545,9 @@ def __init__(self, widget_parent=None): ) self.sliders[0].label.setText("Foreground probability threshold") - self.sliders[ - 0 - ].tooltips = "Probability threshold for foreground object" + self.sliders[0].tooltips = ( + "Probability threshold for foreground object" + ) self.sliders[0].setValue(500) self.sliders[1].label.setText("Seed probability threshold") @@ -652,9 +646,9 @@ def __init__(self, widget_parent=None): ) self.sliders[0].label.setText("Foreground probability threshold") - self.sliders[ - 0 - ].tooltips = "Probability threshold for foreground object" + self.sliders[0].tooltips = ( + "Probability threshold for foreground object" + ) self.sliders[0].setValue(800) self.counters[0].label.setText("Small objects removal") @@ -715,18 +709,18 @@ def __init__(self, widget_parent=None): widget_parent=widget_parent, ) self.counters[0].label.setText("Spot sigma") # closeness - self.counters[ - 0 - ].tooltips = "Determines how close detected objects can be" + self.counters[0].tooltips = ( + "Determines how close detected objects can be" + ) self.counters[0].setMaximum(100) - self.counters[0].setValue(2) + self.counters[0].setValue(0.65) self.counters[1].label.setText("Outline sigma") # smoothness - self.counters[ - 1 - ].tooltips = "Determines the smoothness of the segmentation" + self.counters[1].tooltips = ( + "Determines the smoothness of the segmentation" + ) self.counters[1].setMaximum(100) - self.counters[1].setValue(2) + self.counters[1].setValue(0.65) self.counters[2].label.setText("Small object removal") self.counters[2].tooltips = ( diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py index 0586c0b4..7c33adf4 100644 --- a/napari_cellseg3d/code_models/models/TEMPLATE_model.py +++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py @@ -2,6 +2,7 @@ Please note that custom model implementations are not fully supported out of the box yet, but might be in the future. """ + from abc import ABC, abstractmethod @@ -11,6 +12,7 @@ class ModelTemplate_(ABC): weights_file = ( "model_template.pth" # specify the file name of the weights file only ) + default_threshold = 0.5 # specify the default threshold for the model @abstractmethod def __init__( diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py index 58b932e8..ef1e7492 100644 --- a/napari_cellseg3d/code_models/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,4 +1,5 @@ """SegResNet wrapper for napari_cellseg3d.""" + from monai.networks.nets import SegResNetVAE @@ -6,6 +7,7 @@ class SegResNet_(SegResNetVAE): """SegResNet_ wrapper for napari_cellseg3d.""" weights_file = "SegResNet_latest.pth" + default_threshold = 0.3 def __init__( self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs diff --git a/napari_cellseg3d/code_models/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py index 286defb9..1335b388 100644 --- a/napari_cellseg3d/code_models/models/model_SwinUNetR.py +++ b/napari_cellseg3d/code_models/models/model_SwinUNetR.py @@ -1,4 +1,5 @@ """SwinUNetR wrapper for napari_cellseg3d.""" + from monai.networks.nets import SwinUNETR from napari_cellseg3d.utils import LOGGER @@ -10,6 +11,7 @@ class SwinUNETR_(SwinUNETR): """SwinUNETR wrapper for napari_cellseg3d.""" weights_file = "SwinUNetR_latest.pth" + default_threshold = 0.4 def __init__( self, diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 2aacc333..2735c871 100644 --- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,4 +1,5 @@ """TRAILMAP model, reimplemented in PyTorch.""" + from napari_cellseg3d.code_models.models.unet.model import UNet3D from napari_cellseg3d.utils import LOGGER as logger @@ -7,6 +8,7 @@ class TRAILMAP_MS_(UNet3D): """TRAILMAP_MS wrapper for napari_cellseg3d.""" weights_file = "TRAILMAP_MS_best_metric.pth" + default_threshold = 0.15 # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly TPH2 as of July 2022) diff --git a/napari_cellseg3d/code_models/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py index 2e2e618f..8089ad3a 100644 --- a/napari_cellseg3d/code_models/models/model_VNet.py +++ b/napari_cellseg3d/code_models/models/model_VNet.py @@ -1,4 +1,5 @@ """VNet wrapper for napari_cellseg3d.""" + from monai.networks.nets import VNet @@ -6,6 +7,7 @@ class VNet_(VNet): """VNet wrapper for napari_cellseg3d.""" weights_file = "VNet_latest.pth" + default_threshold = 0.15 def __init__(self, in_channels=1, out_channels=1, **kwargs): """Create a VNet model. diff --git a/napari_cellseg3d/code_models/models/model_WNet.py b/napari_cellseg3d/code_models/models/model_WNet.py index e80884e8..2e4c29b1 100644 --- a/napari_cellseg3d/code_models/models/model_WNet.py +++ b/napari_cellseg3d/code_models/models/model_WNet.py @@ -15,6 +15,7 @@ class WNet_(WNet_encoder): """ weights_file = "wnet_latest.pth" + default_threshold = 0.6 def __init__( self, diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py index 39f81392..dfb556fd 100644 --- a/napari_cellseg3d/code_models/models/model_test.py +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -1,4 +1,5 @@ """Model for testing purposes.""" + import torch from torch import nn @@ -7,6 +8,7 @@ class TestModel(nn.Module): """For tests only.""" weights_file = "test.pth" + default_threshold = 0.5 def __init__(self, **kwargs): """Create a TestModel model.""" diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index dcc18c05..c57c61c0 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -1,4 +1,5 @@ """Several image processing utilities.""" + from pathlib import Path from warnings import warn @@ -837,3 +838,145 @@ def _start(self): ) else: logger.warning("Please specify a layer or a folder") + + +class ThresholdGridSearchUtils(BasePluginUtils): + """Widget to run a grid search for thresholding.""" + + save_path = Path.home() / "cellseg3d" / "threshold_grid_search" + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """Creates a ThresholdGridSearchUtils widget. + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent=parent, + ) + self.do_binarize = False + self.do_remap = False + self.result_text = "" + self.values = {} + + self.data_panel = self._build_io_panel() + # disable folder choice + self.radio_buttons.setVisible(False) + self.radio_buttons.setEnabled(False) + + self.image_layer_loader.layer_list.label.setText("Prediction :") + self.label_layer_loader.layer_list.label.setText("Labels :") + + self.results_path = str(self.save_path) + self.results_filewidget.text_field.setText(self.results_path) + self.results_filewidget.check_ready() + + self.start_btn = ui.Button("Start", self._start) + self.result_display = ui.make_label(self.result_text, self) + self.image_layer_loader.layer_list.currentIndexChanged.connect( + self._reset + ) + self.label_layer_loader.layer_list.currentIndexChanged.connect( + self._reset + ) + + self.container = self._build() + + def _build(self): + container = ui.ContainerWidget() + + container.layout.addWidget(self.data_panel) + ui.add_widgets( + container.layout, + [ + self.start_btn, + self.result_display, + ], + ) + + ui.ScrollArea.make_scrollable( + container.layout, self, max_wh=[MAX_W, MAX_H] + ) + self._set_io_visibility() + return container + + def _reset(self): + self.values = {} + self.result_text = "" + self.result_display.setText("") + + def _check_ready(self): + image_data = self.image_layer_loader.layer_data() + label_data = self.label_layer_loader.layer_data() + if image_data is None: + self.result_display.setText("Please load a prediction layer") + return False + if label_data is None: + self.result_display.setText("Please load a labels layer") + return False + if label_data.shape != image_data.shape: + self.result_display.setText( + "Prediction and labels must have the same shape" + ) + return False + if ( + label_data.min() < 0 + or label_data.max() > 1 + or len(np.unique(label_data)) != 2 + ): + self.do_binarize = True + if image_data.min() < 0 or image_data.max() > 1: + self.do_remap = True + return True + + def _get_dice_graph(self): + max_dice = max(self.values.values()) + self.result_text += "Thre | Dice | Graph\n" + for tr, dice in self.values.items(): + bar = "°" * int((dice / max_dice) * 25) + self.result_text += f"{tr:.2f} | {dice:.3f} | {bar}\n" + + def _start(self): + utils.mkdir_from_str(self.results_path) + if not self._check_ready(): + return + + pred_data = self.image_layer_loader.layer_data().copy() + label_data = self.label_layer_loader.layer_data().copy() + if self.do_binarize: + logger.info("Labels values are not binary, binarizing") + label_data = to_semantic(label_data) + if self.do_remap: + logger.info( + "Prediction values are not from a model. Remapping between 0 and 1" + ) + pred_data = utils.remap_image(pred_data, new_max=1) + # find best threshold + search_space = np.arange(0, 1, 0.05) + for i in search_space: + i = i.round(2) + binarized = threshold(pred_data, i) + binarized = np.where(binarized > 0, 1, 0) + dice = utils.dice_coeff(binarized, label_data).round(3) + self.values[i] = dice + logger.info(f"Threshold : {i}, Dice : {dice}") + + best_threshold = max(self.values, key=self.values.get) + binarized = threshold(pred_data, best_threshold) + utils.save_layer( + self.results_path, + f"binarized_{utils.get_date_time()}.tif", + binarized, + ) + self.layer = utils.show_result( + self._viewer, + self.image_layer_loader.layer(), + binarized, + "binarized", + existing_layer=self.layer, + ) + self.result_test = f"Best threshold : {best_threshold}, Dice : {self.values[best_threshold]}\n" + self._get_dice_graph() + self.result_display.setText(self.result_text) diff --git a/napari_cellseg3d/code_plugins/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py index e143bfb3..43a91b33 100644 --- a/napari_cellseg3d/code_plugins/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,4 +1,5 @@ """Tiny plugin showing link to documentation and about page.""" + import pathlib from typing import TYPE_CHECKING @@ -43,7 +44,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.2.0'}\n\n" + f"You are using napari-cellseg3d v.{'0.2.1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index fd4c37e5..089a3db3 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,4 +1,5 @@ """Inference plugin for napari_cellseg3d.""" + from functools import partial from typing import TYPE_CHECKING @@ -144,7 +145,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): ) self.thresholding_slider = ui.Slider( - default=config.PostProcessConfig().thresholding.threshold_value + default=config.MODEL_LIST[ + self.model_choice.currentText() + ].default_threshold * 100, divide_factor=100.0, parent=self, @@ -408,6 +411,12 @@ def _load_weights_path(self): ) self._update_weights_path(file) + def _set_default_threshold(self): + # Whenever a model is selected, set the default threshold from the model file + model_name = self.model_choice.currentText() + threshold = config.MODEL_LIST[model_name].default_threshold + self.thresholding_slider.slider_value = threshold * 100 + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer.""" # ui.add_blank(self.view_results_container, view_results_layout) @@ -492,7 +501,8 @@ def _build(self): self.device_choice, ], ) - self.window_infer_params.setVisible(False) + self.use_window_choice.setChecked(True) + # self.window_infer_params.setVisible(False) inference_param_group_w.setLayout(inference_param_group_l) @@ -537,14 +547,18 @@ def _build(self): # self.instance_param_container, # instance segmentation ], ) + # self.thresholding_slider.container.setVisible(False) + self.thresholding_checkbox.setChecked(True) self._toggle_crf_choice() self.model_choice.currentIndexChanged.connect(self._toggle_crf_choice) + self.model_choice.currentIndexChanged.connect( + self._set_default_threshold + ) ModelFramework._show_io_element( self.save_stats_to_csv_box, self.use_instance_choice ) self.anisotropy_wdgt.container.setVisible(False) - self.thresholding_slider.container.setVisible(False) self.instance_widgets.setVisible(False) self.crf_widgets.setVisible(False) self.save_stats_to_csv_box.setVisible(False) diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py index 538a5410..9ca16c6e 100644 --- a/napari_cellseg3d/code_plugins/plugin_utilities.py +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -1,4 +1,5 @@ """Central plugin for all utilities.""" + from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -18,6 +19,7 @@ FragmentUtils, RemoveSmallUtils, StatsUtils, + ThresholdGridSearchUtils, ThresholdUtils, ToInstanceUtils, ToSemanticUtils, @@ -39,6 +41,7 @@ "CRF": CRFWidget, "Label statistics": StatsUtils, "Clear large labels": ArtifactRemovalUtils, + "Find best threshold": ThresholdGridSearchUtils, } @@ -62,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "crf", "stats", "artifacts", + "find_thresh", ] self._create_utils_widgets(attr_names) self.utils_choice = ui.DropdownMenu( diff --git a/setup.cfg b/setup.cfg index 04e7a881..7dd98e41 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari_cellseg3d -version = 0.2.0 +version = 0.2.1 [options] packages = find: