Commit 31db751

Colab demo : Inference (#80)
* Add basic demo notebook
* Functional demo Colab notebook
* Change branch for notebook
* Clear output
* Edit Colab notebook header
* Add "Launch in Colab"
* Add mention of demo in README
* Change branch to main
* Update colab_inference_demo.ipynb
1 parent 64821a0 commit 31db751

File tree: 4 files changed, +510 -0 lines changed


README.md (+2 lines)
@@ -99,6 +99,8 @@ and launching the plugin from the Plugins menu.
 You may use the test volume in the `examples` folder to test the inference and review tools.
 This should run in far less than five minutes on a modern computer.
 
+You may also find a demo Colab notebook in the `notebooks` folder.
+
 ## Issues
 
 **Help us make the code better by reporting issues and adding your feature requests!**
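For reference, a minimal sketch of what testing the inference tools on the bundled volume could look like locally. The plugin name string passed to napari and the way the widget is selected are assumptions, not taken from this commit:

```python
# Hedged sketch: open the test volume in napari with the CellSeg3D plugin docked.
# The plugin name "napari-cellseg3d" is an assumption; a widget_name argument may
# be required if the plugin exposes several dock widgets.
import napari
from tifffile import imread

volume = imread("examples/c5image.tif")  # 3D test volume added in this commit

viewer = napari.Viewer()
viewer.add_image(volume, name="c5image")
viewer.window.add_plugin_dock_widget("napari-cellseg3d")
napari.run()
```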

examples/c5image.tif (3.85 MB, binary file not shown)
New file (+188 lines):

@@ -0,0 +1,188 @@
"""Script to perform inference on a single image and run post-processing on the results, without napari."""
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import torch

from napari_cellseg3d.code_models.instance_segmentation import (
    clear_large_objects,
    clear_small_objects,
    threshold,
    volume_stats,
    voronoi_otsu,
)
from napari_cellseg3d.code_models.worker_inference import InferenceWorker
from napari_cellseg3d.config import (
    InferenceWorkerConfig,
    InstanceSegConfig,
    ModelInfo,
    SlidingWindowConfig,
)
from napari_cellseg3d.utils import resize

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class LogFixture:
    """Fixture for napari-less logging; replaces napari_cellseg3d.interface.Log in model workers.

    This redirects the output of the workers to stdout instead of a specialized widget.
    """

    def __init__(self):
        """Creates a LogFixture object."""
        super().__init__()

    def print_and_log(self, text, printing=None):
        """Prints and logs text."""
        print(text)

    def warn(self, warning):
        """Logs a warning."""
        logger.warning(warning)

    def error(self, e):
        """Raises the error reported by the worker."""
        raise e


WINDOW_SIZE = 64

MODEL_INFO = ModelInfo(
    name="SwinUNetR",
    model_input_size=64,
)

CONFIG = InferenceWorkerConfig(
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_info=MODEL_INFO,
    results_path=str(Path("./results").absolute()),
    compute_stats=False,
    sliding_window_config=SlidingWindowConfig(WINDOW_SIZE, 0.25),
)


@dataclass
class PostProcessConfig:
    """Config for post-processing."""

    threshold: float = 0.4
    spot_sigma: float = 0.55
    outline_sigma: float = 0.55
    isotropic_spot_sigma: float = 0.2
    isotropic_outline_sigma: float = 0.2
    anisotropy_correction: List[
        float
    ] = None  # TODO change to actual values, should be a ratio like [1, 1/5, 1]
    clear_small_size: int = 5
    clear_large_objects: int = 500


def inference_on_images(
    image: np.ndarray, config: InferenceWorkerConfig = CONFIG
):
    """Performs inference on an image with a minimal config.

    Args:
        image (np.ndarray): Image to perform inference on.
        config (InferenceWorkerConfig, optional): Config for InferenceWorker. Defaults to CONFIG, see above.
    """
    # instance_method = InstanceSegmentationWrapper(voronoi_otsu, {"spot_sigma": 0.7, "outline_sigma": 0.7})

    config.post_process_config.zoom.enabled = False
    config.post_process_config.thresholding.enabled = (
        False  # will need to be enabled and set to 0.5 for the test images
    )
    config.post_process_config.instance = InstanceSegConfig(
        enabled=False,
    )

    config.layer = image

    # redirect the worker's signals to stdout/logging instead of napari widgets
    log = LogFixture()
    worker = InferenceWorker(config)
    logger.debug(f"Worker config: {worker.config}")

    worker.log_signal.connect(log.print_and_log)
    worker.warn_signal.connect(log.warn)
    worker.error_signal.connect(log.error)

    worker.log_parameters()

    results = []
    # collect each InferenceResult yielded by the worker
    for result in worker.inference():
        results.append(result)

    return results


def post_processing(semantic_segmentation, config: PostProcessConfig = None):
    """Runs post-processing on inference results."""
    config = PostProcessConfig() if config is None else config
    # if config.anisotropy_correction is None:
    #     config.anisotropy_correction = [1, 1, 1 / 5]
    if config.anisotropy_correction is None:
        config.anisotropy_correction = [1, 1, 1]

    image = semantic_segmentation
    # apply threshold to semantic segmentation
    logger.info(f"Thresholding with {config.threshold}")
    image = threshold(image, config.threshold)
    logger.debug(f"Thresholded image shape: {image.shape}")
    # remove artifacts by clearing large objects
    logger.info(f"Clearing large objects with {config.clear_large_objects}")
    image = clear_large_objects(image, config.clear_large_objects)
    # run instance segmentation
    logger.info(
        f"Running instance segmentation with {config.spot_sigma} and {config.outline_sigma}"
    )
    labels = voronoi_otsu(
        image,
        spot_sigma=config.spot_sigma,
        outline_sigma=config.outline_sigma,
    )
    # clear small objects
    logger.info(f"Clearing small objects with {config.clear_small_size}")
    labels = clear_small_objects(labels, config.clear_small_size).astype(
        np.uint16
    )
    logger.debug(f"Labels shape: {labels.shape}")
    # get volume stats WITH ANISOTROPY
    logger.debug(f"NUMBER OF OBJECTS: {np.max(np.unique(labels)) - 1}")
    stats_not_resized = volume_stats(labels)
    ######## RUN WITH ANISOTROPY ########
    # keep the non-resized labels and stats for reference
    result_dict = {}
    result_dict["Not resized"] = {
        "labels": labels,
        "stats": stats_not_resized,
    }

    if config.anisotropy_correction != [1, 1, 1]:
        logger.info("Resizing image to correct anisotropy")
        image = resize(image, config.anisotropy_correction)
        logger.debug(f"Resized image shape: {image.shape}")
        logger.info("Running labels without anisotropy")
        labels_resized = voronoi_otsu(
            image,
            spot_sigma=config.isotropic_spot_sigma,
            outline_sigma=config.isotropic_outline_sigma,
        )
        logger.info(f"Clearing small objects with {config.clear_small_size}")
        labels_resized = clear_small_objects(
            labels_resized, config.clear_small_size
        ).astype(np.uint16)
        logger.debug(
            f"NUMBER OF OBJECTS: {np.max(np.unique(labels_resized)) - 1}"
        )
        logger.info("Getting volume stats without anisotropy")
        stats_resized = volume_stats(labels_resized)
        return labels_resized, stats_resized

    return labels, stats_not_resized
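A possible way to chain the two functions above on the test volume, assuming they run in the same Python session as the script. The `.result` attribute used to pull the semantic prediction out of the returned InferenceResult is an assumption, and the 0.5 threshold is taken from the comment in `inference_on_images` about the test images:

```python
# Usage sketch for the script above (hypothetical, not part of this commit).
# The `.result` attribute on InferenceResult is an assumption and may differ.
from tifffile import imread

image = imread("examples/c5image.tif")

results = inference_on_images(image)  # runs SwinUNetR with the sliding-window CONFIG above
prediction = results[0].result  # semantic prediction (assumed attribute name)

labels, stats = post_processing(
    prediction, config=PostProcessConfig(threshold=0.5)
)
```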
