Skip to content

Commit

Permalink
CRF adjustments (#106)
Browse files Browse the repository at this point in the history
* Remove optional CRF dep+update warning

* Update crf.py

* Update pyproject.toml

---------

Co-authored-by: Mackenzie Mathis <[email protected]>
  • Loading branch information
C-Achard and MMathisLab authored Dec 23, 2024
1 parent 8c6c306 commit a72e931
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 49 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ The strength of our approach is we can match supervised model performance with p

![FIG1 (1)](https://github.com/user-attachments/assets/0d970b45-79ff-4c58-861f-e1e7dc9abc65)

**Figure 1. Performance of 3D Semantic and Instance Segmentation Models.**
**a:** Raw mesoSPIM whole-brain sample, volumes and corresponding ground truth labels from somatosensory (S1) and visual (V1) cortical regions.
**Figure 1. Performance of 3D Semantic and Instance Segmentation Models.**
**a:** Raw mesoSPIM whole-brain sample, volumes and corresponding ground truth labels from somatosensory (S1) and visual (V1) cortical regions.
**b:** Evaluation of instance segmentation performance for baseline
thresholding-only, supervised models: Cellpose, StartDist, SwinUNetR, SegResNet, and our self-supervised model WNet3D over three data subsets.
F1-score is computed from the Intersection over Union (IoU) with ground truth labels, then averaged. Error bars represent 50% Confidence Intervals
(CIs).
**c:** View of 3D instance labels from supervised models, as noted, for visual cortex volume in b evaluation.
(CIs).
**c:** View of 3D instance labels from supervised models, as noted, for visual cortex volume in b evaluation.
**d:** Illustration of our WNet3D architecture showcasing the dual 3D U-Net structure with our modifications.


Expand Down Expand Up @@ -141,7 +141,7 @@ Before testing, install all requirements using ``pip install napari-cellseg3d[te

To run tests locally:

- Locally : run ``pytest`` in the plugin folder
- Locally : run ``pytest napari_cellseg3d\_tests`` in the plugin folder.
- Locally with coverage : In the plugin folder, run ``coverage run --source=napari_cellseg3d -m pytest`` then ``coverage xml`` to generate a .xml coverage file.
- With tox : run ``tox`` in the plugin folder (will simulate tests with several python and OS configs, requires substantial storage space)

Expand Down
58 changes: 18 additions & 40 deletions napari_cellseg3d/code_models/crf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implements the CRF post-processing step for the WNet3D.
"""Implements the CRF post-processing step for WNet3D.
The CRF requires the following parameters:
Expand All @@ -16,8 +16,10 @@
Philipp Krähenbühl and Vladlen Koltun
NIPS 2011
Implemented using the pydense library available at https://github.com/lucasb-eyer/pydensecrf.
Implemented using the pydensecrf library available at https://github.com/lucasb-eyer/pydensecrf.
However, this is not maintained, thus we maintain this pacakge at https://github.com/AdaptiveMotorControlLab/pydensecrf.
"""

import importlib

import numpy as np
Expand All @@ -28,47 +30,19 @@

spec = importlib.util.find_spec("pydensecrf")
CRF_INSTALLED = spec is not None
if not CRF_INSTALLED:
logger.info(
"pydensecrf not installed, CRF post-processing will not be available. "
"Please install by running : pip install pydensecrf "
"This is not a hard requirement, you do not need it to install it unless you want to use the CRF post-processing step. "
)
else:
if CRF_INSTALLED:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import (
create_pairwise_bilateral,
create_pairwise_gaussian,
unary_from_softmax,
)

__author__ = "Yves Paychère, Colin Hofmann, Cyril Achard"
__credits__ = [
"Yves Paychère",
"Colin Hofmann",
"Cyril Achard",
"Philipp Krähenbühl",
"Vladlen Koltun",
"Liang-Chieh Chen",
"George Papandreou",
"Iasonas Kokkinos",
"Kevin Murphy",
"Alan L. Yuille",
"Xide Xia",
"Brian Kulis",
"Lucas Beyer",
]


def correct_shape_for_crf(image, desired_dims=4):
"""Corrects the shape of the image to be compatible with the CRF post-processing step."""
logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}")
logger.debug(f"Image shape: {image.shape}")
if len(image.shape) > desired_dims:
# if image.shape[0] > 1:
# raise ValueError(
# f"Image shape {image.shape} might have several channels"
# )
image = np.squeeze(image, axis=0)
elif len(image.shape) < desired_dims:
image = np.expand_dims(image, axis=0)
Expand All @@ -77,7 +51,7 @@ def correct_shape_for_crf(image, desired_dims=4):


def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5):
"""CRF post-processing step for the W-Net, applied to a batch of images.
"""CRF post-processing step for the WNet3D, applied to a batch of images.
Args:
images (np.ndarray): Array of shape (N, C, H, W, D) containing the input images.
Expand Down Expand Up @@ -105,7 +79,7 @@ def crf_batch(images, probs, sa, sb, sg, w1, w2, n_iter=5):


def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):
"""Implements the CRF post-processing step for the W-Net.
"""Implements the CRF post-processing step for the WNet3D.
Inspired by https://arxiv.org/abs/1210.5644, https://arxiv.org/abs/1606.00915 and https://arxiv.org/abs/1711.08506.
Implemented using the pydensecrf library.
Expand All @@ -124,14 +98,15 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):
np.ndarray: Array of shape (K, H, W, D) containing the refined class probabilities for each pixel.
"""
if not CRF_INSTALLED:
logger.info(
"pydensecrf not installed, therefore CRF post-processing will not be available! Please install the package. "
"Please install by running: pip install pydensecrf2 "
)
return None

d = dcrf.DenseCRF(
image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0]
)
# print(f"Image shape : {image.shape}")
# print(f"Prob shape : {prob.shape}")
# d = dcrf.DenseCRF(262144, 3) # npoints, nlabels

# Get unary potentials from softmax probabilities
U = unary_from_softmax(prob)
Expand Down Expand Up @@ -165,7 +140,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):


def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info):
"""Implements the CRF post-processing step for the W-Net.
"""Implements the CRF post-processing step for the WNet3D.
Args:
image (np.ndarray): Array of shape (C, H, W, D) containing the input image.
Expand Down Expand Up @@ -202,7 +177,7 @@ def crf_with_config(image, prob, config: CRFConfig = None, log=logger.info):


class CRFWorker(GeneratorWorker):
"""Worker for the CRF post-processing step for the W-Net."""
"""Worker for the CRF post-processing step for the WNet3D."""

def __init__(
self,
Expand Down Expand Up @@ -230,9 +205,12 @@ def __init__(
self.log = log

def _run_crf_job(self):
"""Runs the CRF post-processing step for the W-Net."""
"""Runs the CRF post-processing step for the WNet3D."""
if not CRF_INSTALLED:
raise ImportError("pydensecrf is not installed.")
logger.info(
"pydensecrf not installed, therefore CRF post-processing will not be available! Please install the package. "
"Please install by running: pip install pydensecrf2 "
)

if len(self.images) != len(self.labels):
raise ValueError("Number of images and labels must be the same.")
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"pyclesperanto-prototype",
"tqdm",
"matplotlib",
"pydensecrf2",
]
dynamic = ["version", "entry-points"]

Expand Down Expand Up @@ -123,9 +124,6 @@ profile = "black"
line_length = 79

[project.optional-dependencies]
crf = [
# "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master",
]
pyqt5 = [
"pyqt5",
]
Expand Down Expand Up @@ -164,7 +162,6 @@ test = [
"coverage",
"tox",
"twine",
# "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master",
"onnx",
"onnxruntime",
]

0 comments on commit a72e931

Please sign in to comment.