|
| 1 | +"""Script to perform inference on a single image and run post-processing on the results, withot napari.""" |
| 2 | +import logging |
| 3 | +from dataclasses import dataclass |
| 4 | +from pathlib import Path |
| 5 | +from typing import List |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | + |
| 10 | +from napari_cellseg3d.code_models.instance_segmentation import ( |
| 11 | + clear_large_objects, |
| 12 | + clear_small_objects, |
| 13 | + threshold, |
| 14 | + volume_stats, |
| 15 | + voronoi_otsu, |
| 16 | +) |
| 17 | +from napari_cellseg3d.code_models.worker_inference import InferenceWorker |
| 18 | +from napari_cellseg3d.config import ( |
| 19 | + InferenceWorkerConfig, |
| 20 | + InstanceSegConfig, |
| 21 | + ModelInfo, |
| 22 | + SlidingWindowConfig, |
| 23 | +) |
| 24 | +from napari_cellseg3d.utils import resize |
| 25 | + |
| 26 | +logger = logging.getLogger(__name__) |
| 27 | +logging.basicConfig(level=logging.INFO) |
| 28 | + |
| 29 | + |
| 30 | +class LogFixture: |
| 31 | + """Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers. |
| 32 | +
|
| 33 | + This allows to redirect the output of the workers to stdout instead of a specialized widget. |
| 34 | + """ |
| 35 | + |
| 36 | + def __init__(self): |
| 37 | + """Creates a LogFixture object.""" |
| 38 | + super(LogFixture, self).__init__() |
| 39 | + |
| 40 | + def print_and_log(self, text, printing=None): |
| 41 | + """Prints and logs text.""" |
| 42 | + print(text) |
| 43 | + |
| 44 | + def warn(self, warning): |
| 45 | + """Logs warning.""" |
| 46 | + logger.warning(warning) |
| 47 | + |
| 48 | + def error(self, e): |
| 49 | + """Logs error.""" |
| 50 | + raise (e) |
| 51 | + |
| 52 | + |
| 53 | +WINDOW_SIZE = 64 |
| 54 | + |
| 55 | +MODEL_INFO = ModelInfo( |
| 56 | + name="SwinUNetR", |
| 57 | + model_input_size=64, |
| 58 | +) |
| 59 | + |
| 60 | +CONFIG = InferenceWorkerConfig( |
| 61 | + device="cuda" if torch.cuda.is_available() else "cpu", |
| 62 | + model_info=MODEL_INFO, |
| 63 | + results_path=str(Path("./results").absolute()), |
| 64 | + compute_stats=False, |
| 65 | + sliding_window_config=SlidingWindowConfig(WINDOW_SIZE, 0.25), |
| 66 | +) |
| 67 | + |
| 68 | + |
| 69 | +@dataclass |
| 70 | +class PostProcessConfig: |
| 71 | + """Config for post-processing.""" |
| 72 | + |
| 73 | + threshold: float = 0.4 |
| 74 | + spot_sigma: float = 0.55 |
| 75 | + outline_sigma: float = 0.55 |
| 76 | + isotropic_spot_sigma: float = 0.2 |
| 77 | + isotropic_outline_sigma: float = 0.2 |
| 78 | + anisotropy_correction: List[ |
| 79 | + float |
| 80 | + ] = None # TODO change to actual values, should be a ratio like [1,1/5,1] |
| 81 | + clear_small_size: int = 5 |
| 82 | + clear_large_objects: int = 500 |
| 83 | + |
| 84 | + |
| 85 | +def inference_on_images( |
| 86 | + image: np.array, config: InferenceWorkerConfig = CONFIG |
| 87 | +): |
| 88 | + """This function provides inference on an image with minimal config. |
| 89 | +
|
| 90 | + Args: |
| 91 | + image (np.array): Image to perform inference on. |
| 92 | + config (InferenceWorkerConfig, optional): Config for InferenceWorker. Defaults to CONFIG, see above. |
| 93 | + """ |
| 94 | + # instance_method = InstanceSegmentationWrapper(voronoi_otsu, {"spot_sigma": 0.7, "outline_sigma": 0.7}) |
| 95 | + |
| 96 | + config.post_process_config.zoom.enabled = False |
| 97 | + config.post_process_config.thresholding.enabled = ( |
| 98 | + False # will need to be enabled and set to 0.5 for the test images |
| 99 | + ) |
| 100 | + config.post_process_config.instance = InstanceSegConfig( |
| 101 | + enabled=False, |
| 102 | + ) |
| 103 | + |
| 104 | + config.layer = image |
| 105 | + |
| 106 | + log = LogFixture() |
| 107 | + worker = InferenceWorker(config) |
| 108 | + logger.debug(f"Worker config: {worker.config}") |
| 109 | + |
| 110 | + worker.log_signal.connect(log.print_and_log) |
| 111 | + worker.warn_signal.connect(log.warn) |
| 112 | + worker.error_signal.connect(log.error) |
| 113 | + |
| 114 | + worker.log_parameters() |
| 115 | + |
| 116 | + results = [] |
| 117 | + # append the InferenceResult when yielded by worker to results |
| 118 | + for result in worker.inference(): |
| 119 | + results.append(result) |
| 120 | + |
| 121 | + return results |
| 122 | + |
| 123 | + |
| 124 | +def post_processing(semantic_segmentation, config: PostProcessConfig = None): |
| 125 | + """Run post-processing on inference results.""" |
| 126 | + config = PostProcessConfig() if config is None else config |
| 127 | + # if config.anisotropy_correction is None: |
| 128 | + # config.anisotropy_correction = [1, 1, 1 / 5] |
| 129 | + if config.anisotropy_correction is None: |
| 130 | + config.anisotropy_correction = [1, 1, 1] |
| 131 | + |
| 132 | + image = semantic_segmentation |
| 133 | + # apply threshold to semantic segmentation |
| 134 | + logger.info(f"Thresholding with {config.threshold}") |
| 135 | + image = threshold(image, config.threshold) |
| 136 | + logger.debug(f"Thresholded image shape: {image.shape}") |
| 137 | + # remove artifacts by clearing large objects |
| 138 | + logger.info(f"Clearing large objects with {config.clear_large_objects}") |
| 139 | + image = clear_large_objects(image, config.clear_large_objects) |
| 140 | + # run instance segmentation |
| 141 | + logger.info( |
| 142 | + f"Running instance segmentation with {config.spot_sigma} and {config.outline_sigma}" |
| 143 | + ) |
| 144 | + labels = voronoi_otsu( |
| 145 | + image, |
| 146 | + spot_sigma=config.spot_sigma, |
| 147 | + outline_sigma=config.outline_sigma, |
| 148 | + ) |
| 149 | + # clear small objects |
| 150 | + logger.info(f"Clearing small objects with {config.clear_small_size}") |
| 151 | + labels = clear_small_objects(labels, config.clear_small_size).astype( |
| 152 | + np.uint16 |
| 153 | + ) |
| 154 | + logger.debug(f"Labels shape: {labels.shape}") |
| 155 | + # get volume stats WITH ANISOTROPY |
| 156 | + logger.debug(f"NUMBER OF OBJECTS: {np.max(np.unique(labels))-1}") |
| 157 | + stats_not_resized = volume_stats(labels) |
| 158 | + ######## RUN WITH ANISOTROPY ######## |
| 159 | + result_dict = {} |
| 160 | + result_dict["Not resized"] = { |
| 161 | + "labels": labels, |
| 162 | + "stats": stats_not_resized, |
| 163 | + } |
| 164 | + |
| 165 | + if config.anisotropy_correction != [1, 1, 1]: |
| 166 | + logger.info("Resizing image to correct anisotropy") |
| 167 | + image = resize(image, config.anisotropy_correction) |
| 168 | + logger.debug(f"Resized image shape: {image.shape}") |
| 169 | + logger.info("Running labels without anisotropy") |
| 170 | + labels_resized = voronoi_otsu( |
| 171 | + image, |
| 172 | + spot_sigma=config.isotropic_spot_sigma, |
| 173 | + outline_sigma=config.isotropic_outline_sigma, |
| 174 | + ) |
| 175 | + logger.info( |
| 176 | + f"Clearing small objects with {config.clear_large_objects}" |
| 177 | + ) |
| 178 | + labels_resized = clear_small_objects( |
| 179 | + labels_resized, config.clear_small_size |
| 180 | + ).astype(np.uint16) |
| 181 | + logger.debug( |
| 182 | + f"NUMBER OF OBJECTS: {np.max(np.unique(labels_resized))-1}" |
| 183 | + ) |
| 184 | + logger.info("Getting volume stats without anisotropy") |
| 185 | + stats_resized = volume_stats(labels_resized) |
| 186 | + return labels_resized, stats_resized |
| 187 | + |
| 188 | + return labels, stats_not_resized |
0 commit comments