Skip to content

Commit

Permalink
Merge pull request #12 from AdaptiveMotorControlLab/feature/swinunetr
Browse files Browse the repository at this point in the history
[WIP] Adding SwinUNetR
  • Loading branch information
MMathisLab authored Aug 15, 2022
2 parents fde3c71 + 158e0c2 commit 7363823
Show file tree
Hide file tree
Showing 14 changed files with 249 additions and 469 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,3 @@ venv/
/napari_cellseg3d/models/saved_weights/
/docs/res/logo/old_logo/
/reqs/

29 changes: 26 additions & 3 deletions napari_cellseg3d/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional
from typing import Union
from typing import List


Expand Down Expand Up @@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget):
widget.setVisible(checkbox.isChecked())


def add_label(widget, label, label_before=True, horizontal=True):
if label_before:
return combine_blocks(widget, label, horizontal=horizontal)
else:
return combine_blocks(label, widget, horizontal=horizontal)


class Button(QPushButton):
"""Class for a button with a title and connected to a function when clicked. Inherits from QPushButton.
Expand Down Expand Up @@ -494,20 +500,33 @@ def __init__(
step=1,
parent: Optional[QWidget] = None,
fixed: Optional[bool] = True,
label: Optional[str] = None,
):
"""Args:
min (Optional[float]): minimum value, defaults to 0
max (Optional[float]): maximum value, defaults to 10
default (Optional[float]): default value, defaults to 0
step (Optional[float]): step value, defaults to 1
parent: parent widget, defaults to None
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed"""
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed
label (Optional[str]): if provided, creates a label with the chosen title to use with the counter"""

super().__init__(parent)
set_spinbox(self, min, max, default, step, fixed)

if label is not None:
self.label = make_label(name=label)

# def setToolTip(self, a0: str) -> None:
# self.setToolTip(a0)
# if self.label is not None:
# self.label.setToolTip(a0)

def get_with_label(self, horizontal=True):
return add_label(self, self.label, horizontal=horizontal)

def set_precision(self, decimals):
"""Sets the precision of the box to the speicifed number of decimals"""
"""Sets the precision of the box to the specified number of decimals"""
self.setDecimals(decimals)

@classmethod
Expand Down Expand Up @@ -535,6 +554,7 @@ def __init__(
step=1,
parent: Optional[QWidget] = None,
fixed: Optional[bool] = True,
label: Optional[str] = None,
):
"""Args:
min (Optional[int]): minimum value, defaults to 0
Expand All @@ -546,6 +566,9 @@ def __init__(

super().__init__(parent)
set_spinbox(self, min, max, default, step, fixed)
self.label = None
if label is not None:
self.label = make_label(label, self)

@classmethod
def make_n(
Expand Down
7 changes: 5 additions & 2 deletions napari_cellseg3d/model_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from napari_cellseg3d import utils
from napari_cellseg3d.log_utility import Log
from napari_cellseg3d.models import model_SegResNet as SegResNet
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR

# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
from napari_cellseg3d.models import model_VNet as VNet
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
from napari_cellseg3d.plugin_base import BasePluginFolder
Expand Down Expand Up @@ -62,8 +64,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
self.models_dict = {
"VNet": VNet,
"SegResNet": SegResNet,
"TRAILMAP": TRAILMAP,
# "TRAILMAP": TRAILMAP,
"TRAILMAP_MS": TRAILMAP_MS,
"SwinUNetR": SwinUNetR,
}
"""dict: dictionary of available models, with string for widget display as key
Expand Down
172 changes: 115 additions & 57 deletions napari_cellseg3d/model_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
# Qt
from qtpy.QtCore import Signal


from napari_cellseg3d import utils
from napari_cellseg3d import log_utility

Expand Down Expand Up @@ -165,6 +164,9 @@ def __init__(self):
super().__init__()


# TODO : use dataclass for config instead ?


class InferenceWorker(GeneratorWorker):
"""A custom worker to run inference jobs in.
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
Expand All @@ -180,6 +182,7 @@ def __init__(
instance,
use_window,
window_infer_size,
window_overlap,
keep_on_cpu,
stats_csv,
images_filepaths=None,
Expand Down Expand Up @@ -231,6 +234,7 @@ def __init__(
self.instance_params = instance
self.use_window = use_window
self.window_infer_size = window_infer_size
self.window_overlap_percentage = window_overlap
self.keep_on_cpu = keep_on_cpu
self.stats_to_csv = stats_csv
############################################
Expand Down Expand Up @@ -301,8 +305,6 @@ def log_parameters(self):
f"Probability threshold is {self.instance_params['threshold']:.2f}\n"
f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n"
)
# self.log(f"")
# self.log("\n")
self.log("-" * 20)

def load_folder(self):
Expand All @@ -313,25 +315,57 @@ def load_folder(self):
data_check = LoadImaged(keys=["image"])(images_dict[0])

check = data_check["image"].shape
# TODO remove
# z_aniso = 5 / 1.5
# if zoom is not None :
# pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
# else:

self.log("\nChecking dimensions...")
pad = utils.get_padding_dim(check)

load_transforms = Compose(
[
LoadImaged(keys=["image"]),
# AddChanneld(keys=["image"]), #already done
EnsureChannelFirstd(keys=["image"]),
# Orientationd(keys=["image"], axcodes="PLI"),
# anisotropic_transform,
SpatialPadd(keys=["image"], spatial_size=pad),
EnsureTyped(keys=["image"]),
]
)
# dims = self.model_dict["model_input_size"]
#
# if self.model_dict["name"] == "SegResNet":
# model = self.model_dict["class"].get_net(
# input_image_size=[
# dims,
# dims,
# dims,
# ]
# )
# elif self.model_dict["name"] == "SwinUNetR":
# model = self.model_dict["class"].get_net(
# img_size=[dims, dims, dims],
# use_checkpoint=False,
# )
# else:
# model = self.model_dict["class"].get_net()
#
# self.log_parameters()
#
# model.to(self.device)

# print("FILEPATHS PRINT")
# print(self.images_filepaths)
if self.use_window:
load_transforms = Compose(
[
LoadImaged(keys=["image"]),
# AddChanneld(keys=["image"]), #already done
EnsureChannelFirstd(keys=["image"]),
# Orientationd(keys=["image"], axcodes="PLI"),
# anisotropic_transform,
EnsureTyped(keys=["image"]),
]
)
else:
load_transforms = Compose(
[
LoadImaged(keys=["image"]),
# AddChanneld(keys=["image"]), #already done
EnsureChannelFirstd(keys=["image"]),
# Orientationd(keys=["image"], axcodes="PLI"),
# anisotropic_transform,
SpatialPadd(keys=["image"], spatial_size=pad),
EnsureTyped(keys=["image"]),
]
)

self.log("\nLoading dataset...")
inference_ds = Dataset(data=images_dict, transform=load_transforms)
Expand Down Expand Up @@ -364,19 +398,32 @@ def load_layer(self):

# print(volume.shape)
# print(volume.dtype)

load_transforms = Compose(
[
ToTensor(),
# anisotropic_transform,
AddChannel(),
SpatialPad(spatial_size=pad),
AddChannel(),
EnsureType(),
],
map_items=False,
log_stats=True,
)
if self.use_window:
load_transforms = Compose(
[
ToTensor(),
# anisotropic_transform,
AddChannel(),
# SpatialPad(spatial_size=pad),
AddChannel(),
EnsureType(),
],
map_items=False,
log_stats=True,
)
else:
load_transforms = Compose(
[
ToTensor(),
# anisotropic_transform,
AddChannel(),
SpatialPad(spatial_size=pad),
AddChannel(),
EnsureType(),
],
map_items=False,
log_stats=True,
)

self.log("\nLoading dataset...")
input_image = load_transforms(volume)
Expand Down Expand Up @@ -405,8 +452,10 @@ def model_output(

if self.use_window:
window_size = self.window_infer_size
window_overlap = self.window_overlap_percentage
else:
window_size = None
window_overlap = 0.25

outputs = sliding_window_inference(
inputs,
Expand All @@ -415,6 +464,7 @@ def model_output(
predictor=model_output,
sw_device=self.device,
device=dataset_device,
overlap=window_overlap,
)

out = outputs.detach().cpu()
Expand Down Expand Up @@ -508,13 +558,12 @@ def save_image(
)

imwrite(file_path, image)
filename = os.path.split(file_path)[1]

if from_layer:
self.log(f"\nLayer prediction saved as :")
self.log(f"\nLayer prediction saved as : {filename}")
else:
self.log(f"\nFile n°{i+1} saved as :")
filename = os.path.split(file_path)[1]
self.log(filename)
self.log(f"\nFile n°{i+1} saved as : {filename}")

def aniso_transform(self, image):
zoom = self.transforms["zoom"][1]
Expand Down Expand Up @@ -630,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):

self.save_image(out, from_layer=True)

instance_labels, data_dict = self.get_instance_result(out,from_layer=True)
instance_labels, data_dict = self.get_instance_result(
out, from_layer=True
)

return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
return self.create_result_dict(
out, instance_labels, from_layer=True, data_dict=data_dict
)

def inference(self):
"""
Expand Down Expand Up @@ -674,29 +727,27 @@ def inference(self):
torch.set_num_threads(1) # required for threading on macOS ?
self.log("Number of threads has been set to 1 for macOS")

# if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
# torch.backends.cudnn.benchmark = False

# TODO : better solution than loading first image always ?
# data_check = LoadImaged(keys=["image"])(images_dict[0])
# print(data)
# check = data_check["image"].shape
# print(check)

try:
dims = self.model_dict["segres_size"]
dims = self.model_dict["model_input_size"]
self.log(f"MODEL DIMS : {dims}")
self.log(self.model_dict["name"])

model = self.model_dict["class"].get_net()
if self.model_dict["name"] == "SegResNet":
model = self.model_dict["class"].get_net()(
model = self.model_dict["class"].get_net(
input_image_size=[
dims,
dims,
dims,
], # TODO FIX ! find a better way & remove model-specific code
out_channels=1,
# dropout_prob=0.3,
)
elif self.model_dict["name"] == "SwinUNetR":
model = self.model_dict["class"].get_net(
img_size=[dims, dims, dims],
use_checkpoint=False,
)
else:
model = self.model_dict["class"].get_net()
model = model.to(self.device)

self.log_parameters()

Expand All @@ -722,10 +773,7 @@ def inference(self):
AsDiscrete(threshold=t), EnsureType()
)


self.log(
"\nLoading weights..."
)
self.log("\nLoading weights...")

if self.weights_dict["custom"]:
weights = self.weights_dict["path"]
Expand Down Expand Up @@ -1022,11 +1070,21 @@ def train(self):
else:
size = check
print(f"Size of image : {size}")
model = model_class.get_net()(
model = model_class.get_net(
input_image_size=utils.get_padding_dim(size),
out_channels=1,
dropout_prob=0.3,
)
elif model_name == "SwinUNetR":
if self.sampling:
size = self.sample_size
else:
size = check
print(f"Size of image : {size}")
model = model_class.get_net(
img_size=utils.get_padding_dim(size),
use_checkpoint=True,
)
else:
model = model_class.get_net() # get an instance of the model
model = model.to(self.device)
Expand Down
Loading

0 comments on commit 7363823

Please sign in to comment.