From b1be111b7ae53edffedc9cae33c8da22dcd91c1a Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Mon, 4 Jul 2022 14:25:52 +0200 Subject: [PATCH 01/19] adding swinunetr & removing extra padding for inference --- .gitignore | 1 - napari_cellseg3d/model_framework.py | 2 + napari_cellseg3d/model_workers.py | 54 ++++++++++++++++++---- napari_cellseg3d/models/model_SwinUNetR.py | 18 ++++++++ napari_cellseg3d/models/model_VNet.py | 7 ++- napari_cellseg3d/plugin_model_inference.py | 31 ++++++++++++- napari_cellseg3d/plugin_model_training.py | 2 +- napari_cellseg3d/utils.py | 21 --------- setup.cfg | 5 +- 9 files changed, 101 insertions(+), 40 deletions(-) create mode 100644 napari_cellseg3d/models/model_SwinUNetR.py diff --git a/.gitignore b/.gitignore index f38ebe4d..74feefbe 100644 --- a/.gitignore +++ b/.gitignore @@ -99,4 +99,3 @@ venv/ /napari_cellseg3d/models/saved_weights/ /docs/res/logo/old_logo/ /reqs/ - diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 1baf6eed..648865bb 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -13,6 +13,7 @@ from napari_cellseg3d import utils from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet +from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS @@ -64,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "SegResNet": SegResNet, "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS, + "SwinUNetR": SwinUNetR, } """dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index eb2f3cb9..b71ca90d 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -179,6 +179,7 @@ def __init__( instance, use_window, window_infer_size, + window_overlap_percentage, keep_on_cpu, stats_csv, ): @@ -205,6 +206,8 @@ def __init__( * window_infer_size: size of window if use_window is True + * window_overlap_percentage: overlap of sliding windows if use_window is True + * keep_on_cpu: keep images on CPU or no * stats_csv: compute stats on cells and save them to a csv file @@ -228,6 +231,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size + self.window_overlap_percentage = window_overlap_percentage self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -350,8 +354,6 @@ def inference(self): # pad = utils.get_padding_dim(check, anisotropy_factor=zoom) # else: self.log("\nChecking dimensions...") - pad = utils.get_padding_dim(check) - # print(pad) dims = self.model_dict["segres_size"] model = self.model_dict["class"].get_net() @@ -365,6 +367,14 @@ def inference(self): out_channels=1, # dropout_prob=0.3, ) + elif self.model_dict["name"] == "SwinUNetR": + model = self.model_dict["class"].get_net()( + img_size=[dims, dims, dims], + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=False, + ) self.log_parameters() @@ -380,7 +390,6 @@ def inference(self): EnsureChannelFirstd(keys=["image"]), # Orientationd(keys=["image"], axcodes="PLI"), # anisotropic_transform, - SpatialPadd(keys=["image"], spatial_size=pad), EnsureTyped(keys=["image"]), ] ) @@ -437,10 +446,18 @@ def inference(self): # print(inputs.shape) inputs = inputs.to("cpu") + print(inputs.shape) - model_output = lambda inputs: post_process_transforms( - self.model_dict["class"].get_output(model, inputs) - ) + if self.model_dict["name"] == "SwinUNetR": + model_output = lambda inputs: post_process_transforms( + torch.sigmoid( + self.model_dict["class"].get_output(model, inputs) + ) + ) + else: + model_output = lambda inputs: post_process_transforms( + self.model_dict["class"].get_output(model, inputs) + ) if self.keep_on_cpu: dataset_device = "cpu" @@ -449,9 +466,10 @@ def inference(self): if self.use_window: window_size = self.window_infer_size + window_overlap = self.window_overlap_percentage else: window_size = None - + window_overlap = 0.25 outputs = sliding_window_inference( inputs, roi_size=window_size, @@ -459,12 +477,13 @@ def inference(self): predictor=model_output, sw_device=self.device, device=dataset_device, + overlap=window_overlap, ) - + print("done window infernce") out = outputs.detach().cpu() # del outputs # TODO fix memory ? # outputs = None - + print(out.shape) if self.transforms["zoom"][0]: zoom = self.transforms["zoom"][1] anisotropic_transform = Zoom( @@ -474,9 +493,11 @@ def inference(self): ) out = anisotropic_transform(out[0]) - out = post_process_transforms(out) + # out = post_process_transforms(out) out = np.array(out).astype(np.float32) + print(out.shape) out = np.squeeze(out) + print(out.shape) to_instance = out # avoid post processing since thresholding is done there anyway # batch_len = out.shape[1] @@ -825,6 +846,19 @@ def train(self): out_channels=1, dropout_prob=0.3, ) + elif model_name == "SwinUNetR": + if self.sampling: + size = self.sample_size + else: + size = check + print(f"Size of image : {size}") + model = model_class.get_net()( + img_size=utils.get_padding_dim(size), + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=True, + ) else: model = model_class.get_net() # get an instance of the model model = model.to(self.device) diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py new file mode 100644 index 00000000..93516340 --- /dev/null +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -0,0 +1,18 @@ +from monai.networks.nets import SwinUNETR + + +def get_weights_file(): + return "" + + +def get_net(): + return SwinUNETR + + +def get_output(model, input): + out = model(input) + return out + + +def get_validation(model, val_inputs): + return model(val_inputs) diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 0c5f0b75..0c854832 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -19,6 +19,11 @@ def get_validation(model, val_inputs): roi_size = (64, 64, 64) sw_batch_size = 1 val_outputs = sliding_window_inference( - val_inputs, roi_size, sw_batch_size, model, mode="gaussian" + val_inputs, + roi_size, + sw_batch_size, + model, + mode="gaussian", + overlap=0.7, ) return val_outputs diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index 711f4b49..21b1a39a 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -76,6 +76,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None + self.window_overlap_percentage = None ########################### # interface @@ -132,6 +133,17 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + # ( + # self.window_size_choice, + # self.lbl_window_size_choice, + # ) = ui.make_combobox(sizes_window, label="Window size and overlap") + # self.window_overlap = ui.make_n_spinboxes( + # max=1, + # default=0.7, + # step=0.05, + # double=True, + # ) + self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) @@ -144,6 +156,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.lbl_window_size_choice, horizontal=False, ) + # self.window_infer_params = ui.combine_blocks( + # self.window_overlap, + # self.window_infer_params, + # horizontal=False, + # ) ################## ################## @@ -210,7 +227,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Displays the image used for inference in the viewer" ) self.segres_size.setToolTip( - "Image size on which the SegResNet has been trained (default : 128)" + "Image size on which the model has been trained (default : 128)" ) thresh_desc = "Thresholding : all values in the image below the chosen probability threshold will be set to 0, and all others to 1." @@ -224,6 +241,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_size_choice.setToolTip( "Size of the window to run inference with (in pixels)" ) + + # self.window_overlap.setToolTip( + # "Amount of overlap between sliding windows" + # ) + self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) @@ -263,7 +285,10 @@ def check_ready(self): return False def toggle_display_segres_size(self): - if self.model_choice.currentText() == "SegResNet": + if ( + self.model_choice.currentText() == "SegResNet" + or self.model_choice.currentText() == "SwinUNetR" + ): self.segres_size.setVisible(True) else: self.segres_size.setVisible(False) @@ -575,6 +600,7 @@ def start(self): self.window_inference_size = int( self.window_size_choice.currentText() ) + # self.window_overlap_percentage = self.window_overlap.value() self.worker = InferenceWorker( device=device, @@ -587,6 +613,7 @@ def start(self): instance=self.instance_params, use_window=self.use_window_inference, window_infer_size=self.window_inference_size, + # window_overlap_percentage=self.window_overlap_percentage, keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 517cf8fc..1e9deed6 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -31,7 +31,7 @@ from napari_cellseg3d.model_workers import TrainingWorker NUMBER_TABS = 3 -DEFAULT_PATCH_SIZE = 60 +DEFAULT_PATCH_SIZE = 64 class Trainer(ModelFramework): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index bc725fc1..c4d87727 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -113,7 +113,6 @@ def resize(image, zoom_factors): def align_array_sizes(array_shape, target_shape): - index_differences = [] for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: @@ -331,7 +330,6 @@ def fill_list_in_between(lst, n, elem): Returns : Filled list """ - new_list = [] for i in range(len(lst)): temp_list = [lst[i]] @@ -605,25 +603,6 @@ def format_Warning(message, category, filename, lineno, line=""): ) -# def dice_coeff(y_true, y_pred): -# smooth = 1. -# y_true_f = y_true.flatten() -# y_pred_f = K.flatten(y_pred) -# intersection = K.sum(y_true_f * y_pred_f) -# score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) -# return score - - -# def dice_loss(y_true, y_pred): -# loss = 1 - dice_coeff(y_true, y_pred) -# return loss - - -# def bce_dice_loss(y_true, y_pred): -# loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) -# return loss - - def divide_imgs(images): H = -(-images.shape[1] // 412) W = -(-images.shape[2] // 412) diff --git a/setup.cfg b/setup.cfg index 129426e4..b6444ec7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,11 +50,8 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai>=0.9.0 - nibabel - scikit-image + monai[nibabel,scikit-image,itk,einops]>=0.9.0 pillow - itk>=5.2.0 matplotlib vispy>=0.9.6 From 21fee9be3066f3e17cf772ec6069de9ae6999c87 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 14:32:17 +0200 Subject: [PATCH 02/19] removed unnecessary code --- napari_cellseg3d/utils.py | 356 -------------------------------------- 1 file changed, 356 deletions(-) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index c4d87727..ddf19bff 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -601,359 +601,3 @@ def format_Warning(message, category, filename, lineno, line=""): + str(message) + "\n" ) - - -def divide_imgs(images): - H = -(-images.shape[1] // 412) - W = -(-images.shape[2] // 412) - - diveded_imgs = np.zeros((images.shape[0] * H * W, 512, 512, 1), np.float32) - # print(H, W) - - for z in range(images.shape[0]): - image = images[z] - for h in range(H): - for w in range(W): - cropped_img = np.zeros((512, 512, 1), np.float32) - cropped_img -= 1 - - if images.shape[1] < 412: - h = -1 - if images.shape[2] < 412: - w = -1 - - if h == -1: - if w == -1: - cropped_img[ - 50 : images.shape[1] + 50, - 50 : images.shape[2] + 50, - 0, - ] = image[0 : images.shape[1], 0 : images.shape[2], 0] - elif w == 0: - cropped_img[ - 50 : images.shape[1] + 50, 50:512, 0 - ] = image[0 : images.shape[1], 0:462, 0] - elif w == W - 1: - cropped_img[ - 50 : images.shape[1] + 50, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - 0 : images.shape[1], - w * 412 - 50 : images.shape[2], - 0, - ] - else: - cropped_img[50 : images.shape[1] + 50, :, 0] = image[ - 0 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - elif h == 0: - if w == -1: - cropped_img[ - 50:512, 50 : images.shape[2] + 50, 0 - ] = image[0:462, 0 : images.shape[2], 0] - elif w == 0: - cropped_img[50:512, 50:512, 0] = image[0:462, 0:462, 0] - elif w == W - 1: - cropped_img[ - 50:512, 0 : images.shape[2] - 412 * W - 50, 0 - ] = image[0:462, w * 412 - 50 : images.shape[2], 0] - else: - # cropped_img[50:512, :, 0] = image[0:462, w*412-50:(w+1)*412+50, 0] - try: - cropped_img[50:512, :, 0] = image[ - 0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0 - ] - except: - cropped_img[ - 50:512, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - 0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0 - ] - elif h == H - 1: - if w == -1: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 50 : images.shape[2] + 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - 0 : images.shape[2], - 0, - ] - elif w == 0: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, 50:512, 0 - ] = image[h * 412 - 50 : images.shape[1], 0:462, 0] - elif w == W - 1: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : images.shape[2], - 0, - ] - else: - try: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, :, 0 - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - else: - if w == -1: - cropped_img[:, 50 : images.shape[2] + 50, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - 0 : images.shape[2], - 0, - ] - elif w == 0: - # cropped_img[:, 50:512, 0] = image[h*412-50:(h+1)*412+50, 0:462, 0] - try: - cropped_img[:, 50:512, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0 - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50 + 412, - 50:512, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0 - ] - elif w == W - 1: - # cropped_img[:, 0:images.shape[2]-412*W-50, 0] = image[h*412-50:(h+1)*412+50, w*412-50:images.shape[2], 0] - try: - cropped_img[ - :, 0 : images.shape[2] - 412 * W - 50, 0 - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : images.shape[2], - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50 + 412, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : images.shape[2], - 0, - ] - else: - # cropped_img[:, :, 0] = image[h*412-50:(h+1)*412+50, w*412-50:(w+1)*412+50, 0] - try: - cropped_img[:, :, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - try: - cropped_img[ - :, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * (H - 1) - 50, - :, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - h = max(0, h) - w = max(0, w) - diveded_imgs[z * H * W + w * H + h] = cropped_img - # print(z*H*W+ w*H+h) - - return diveded_imgs - - -def merge_imgs(imgs, original_image_shape): - merged_imgs = np.zeros( - ( - original_image_shape[0], - original_image_shape[1], - original_image_shape[2], - 1, - ), - np.float32, - ) - H = -(-original_image_shape[1] // 412) - W = -(-original_image_shape[2] // 412) - - for z in range(original_image_shape[0]): - for h in range(H): - for w in range(W): - - if original_image_shape[1] < 412: - h = -1 - if original_image_shape[2] < 412: - w = -1 - - # print(z*H*W+ max(w, 0)*H+max(h, 0)) - if h == -1: - if w == -1: - merged_imgs[ - z, - 0 : original_image_shape[1], - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + 0][ - 50 : original_image_shape[1] + 50, - 50 : original_image_shape[2] + 50, - 0, - ] - elif w == 0: - merged_imgs[ - z, 0 : original_image_shape[1], 0:412, 0 - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, 50:462, 0 - ] - elif w == W - 1: - merged_imgs[ - z, - 0 : original_image_shape[1], - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - 0 : original_image_shape[1], - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, 50:462, 0 - ] - elif h == 0: - if w == -1: - merged_imgs[ - z, 0:412, 0 : original_image_shape[2], 0 - ] = imgs[z * H * W + 0 * H + h][ - 50:462, 50 : original_image_shape[2] + 50, 0 - ] - elif w == 0: - merged_imgs[z, 0:412, 0:412, 0] = imgs[ - z * H * W + w * H + h - ][50:462, 50:462, 0] - elif w == W - 1: - merged_imgs[ - z, 0:412, w * 412 : original_image_shape[2], 0 - ] = imgs[z * H * W + w * H + h][ - 50:462, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, 0:412, w * 412 : (w + 1) * 412, 0 - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - elif h == H - 1: - if w == -1: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50 : original_image_shape[2] + 50, - 0, - ] - elif w == 0: - merged_imgs[ - z, h * 412 : original_image_shape[1], 0:412, 0 - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50:462, - 0, - ] - elif w == W - 1: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50:462, - 0, - ] - else: - if w == -1: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + h][ - 50:462, 50 : original_image_shape[2] + 50, 0 - ] - elif w == 0: - merged_imgs[ - z, h * 412 : (h + 1) * 412, 0:412, 0 - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - elif w == W - 1: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + h][ - 50:462, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - - print(merged_imgs.shape) - return merged_imgs From cb5116d4883a81aed14966321aeda34dd80c1cef Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 17:52:58 +0200 Subject: [PATCH 03/19] :wrench move inherent model params where they are instantiated and fixed padding only for sliding window inference --- napari_cellseg3d/model_workers.py | 190 ++++++++++----------- napari_cellseg3d/models/model_SegResNet.py | 4 +- napari_cellseg3d/models/model_SwinUNetR.py | 7 +- napari_cellseg3d/plugin_model_inference.py | 37 +--- 4 files changed, 101 insertions(+), 137 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index b71ca90d..900cce75 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -43,7 +43,6 @@ # Qt from qtpy.QtCore import Signal - from napari_cellseg3d import utils from napari_cellseg3d import log_utility @@ -168,20 +167,19 @@ class InferenceWorker(GeneratorWorker): Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" def __init__( - self, - device, - model_dict, - weights_dict, - images_filepaths, - results_path, - filetype, - transforms, - instance, - use_window, - window_infer_size, - window_overlap_percentage, - keep_on_cpu, - stats_csv, + self, + device, + model_dict, + weights_dict, + images_filepaths, + results_path, + filetype, + transforms, + instance, + use_window, + window_infer_size, + keep_on_cpu, + stats_csv, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. @@ -206,8 +204,6 @@ def __init__( * window_infer_size: size of window if use_window is True - * window_overlap_percentage: overlap of sliding windows if use_window is True - * keep_on_cpu: keep images on CPU or no * stats_csv: compute stats on cells and save them to a csv file @@ -231,7 +227,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size - self.window_overlap_percentage = window_overlap_percentage + self.window_overlap_percentage = 0.8, self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -343,36 +339,25 @@ def inference(self): # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems # torch.backends.cudnn.benchmark = False - # TODO : better solution than loading first image always ? + self.log("\nChecking dimensions...") data_check = LoadImaged(keys=["image"])(images_dict[0]) - # print(data) check = data_check["image"].shape - # print(check) - # TODO remove - # z_aniso = 5 / 1.5 - # if zoom is not None : - # pad = utils.get_padding_dim(check, anisotropy_factor=zoom) - # else: - self.log("\nChecking dimensions...") - dims = self.model_dict["segres_size"] + pad = utils.get_padding_dim(check) + + dims = self.model_dict["model_input_size"] model = self.model_dict["class"].get_net() if self.model_dict["name"] == "SegResNet": - model = self.model_dict["class"].get_net()( + model = self.model_dict["class"].get_net( input_image_size=[ dims, dims, dims, - ], # TODO FIX ! find a better way & remove model-specific code - out_channels=1, - # dropout_prob=0.3, + ] ) elif self.model_dict["name"] == "SwinUNetR": - model = self.model_dict["class"].get_net()( + model = self.model_dict["class"].get_net( img_size=[dims, dims, dims], - in_channels=1, - out_channels=1, - feature_size=48, use_checkpoint=False, ) @@ -382,17 +367,29 @@ def inference(self): # print("FILEPATHS PRINT") # print(self.images_filepaths) - - load_transforms = Compose( - [ - LoadImaged(keys=["image"]), - # AddChanneld(keys=["image"]), #already done - EnsureChannelFirstd(keys=["image"]), - # Orientationd(keys=["image"], axcodes="PLI"), - # anisotropic_transform, - EnsureTyped(keys=["image"]), - ] - ) + if self.use_window: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + EnsureTyped(keys=["image"]), + ] + ) + else: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + SpatialPadd(keys=["image"], spatial_size=pad), + EnsureTyped(keys=["image"]), + ] + ) if not self.transforms["thresh"][0]: post_process_transforms = EnsureType() @@ -448,16 +445,9 @@ def inference(self): inputs = inputs.to("cpu") print(inputs.shape) - if self.model_dict["name"] == "SwinUNetR": - model_output = lambda inputs: post_process_transforms( - torch.sigmoid( - self.model_dict["class"].get_output(model, inputs) - ) - ) - else: - model_output = lambda inputs: post_process_transforms( - self.model_dict["class"].get_output(model, inputs) - ) + model_output = lambda inputs: post_process_transforms( + self.model_dict["class"].get_output(model, inputs) + ) if self.keep_on_cpu: dataset_device = "cpu" @@ -479,7 +469,6 @@ def inference(self): device=dataset_device, overlap=window_overlap, ) - print("done window infernce") out = outputs.detach().cpu() # del outputs # TODO fix memory ? # outputs = None @@ -519,14 +508,14 @@ def inference(self): # File output save name : original-name_model_date+time_number.filetype file_path = ( - self.results_path - + "/" - + f"Prediction_{image_id}_" - + original_filename - + "_" - + self.model_dict["name"] - + f"_{time}_" - + self.filetype + self.results_path + + "/" + + f"Prediction_{image_id}_" + + original_filename + + "_" + + self.model_dict["name"] + + f"_{time}_" + + self.filetype ) # print(filename) @@ -567,14 +556,14 @@ def method(image): instance_labels = method(to_instance) instance_filepath = ( - self.results_path - + "/" - + f"Instance_seg_labels_{image_id}_" - + original_filename - + "_" - + self.model_dict["name"] - + f"_{time}_" - + self.filetype + self.results_path + + "/" + + f"Instance_seg_labels_{image_id}_" + + original_filename + + "_" + + self.model_dict["name"] + + f"_{time}_" + + self.filetype ) imwrite(instance_filepath, instance_labels) @@ -617,23 +606,23 @@ class TrainingWorker(GeneratorWorker): Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" def __init__( - self, - device, - model_dict, - weights_path, - data_dicts, - validation_percent, - max_epochs, - loss_function, - learning_rate, - val_interval, - batch_size, - results_path, - sampling, - num_samples, - sample_size, - do_augmentation, - deterministic, + self, + device, + model_dict, + weights_path, + data_dicts, + validation_percent, + max_epochs, + loss_function, + learning_rate, + val_interval, + batch_size, + results_path, + sampling, + num_samples, + sample_size, + do_augmentation, + deterministic, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -841,9 +830,8 @@ def train(self): else: size = check print(f"Size of image : {size}") - model = model_class.get_net()( + model = model_class.get_net( input_image_size=utils.get_padding_dim(size), - out_channels=1, dropout_prob=0.3, ) elif model_name == "SwinUNetR": @@ -852,11 +840,8 @@ def train(self): else: size = check print(f"Size of image : {size}") - model = model_class.get_net()( + model = model_class.get_net( img_size=utils.get_padding_dim(size), - in_channels=1, - out_channels=1, - feature_size=48, use_checkpoint=True, ) else: @@ -868,10 +853,10 @@ def train(self): self.train_files, self.val_files = ( self.data_dicts[ - 0 : int(len(self.data_dicts) * self.validation_percent) + 0: int(len(self.data_dicts) * self.validation_percent) ], self.data_dicts[ - int(len(self.data_dicts) * self.validation_percent) : + int(len(self.data_dicts) * self.validation_percent): ], ) @@ -1032,10 +1017,10 @@ def train(self): if self.device.type == "cuda": self.log("Memory Usage:") alloc_mem = round( - torch.cuda.memory_allocated(0) / 1024**3, 1 + torch.cuda.memory_allocated(0) / 1024 ** 3, 1 ) reserved_mem = round( - torch.cuda.memory_reserved(0) / 1024**3, 1 + torch.cuda.memory_reserved(0) / 1024 ** 3, 1 ) self.log(f"Allocated: {alloc_mem}GB") self.log(f"Cached: {reserved_mem}GB") @@ -1117,7 +1102,7 @@ def train(self): yield train_report weights_filename = ( - f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1158,7 +1143,6 @@ def train(self): # self.close() - # def this_is_fine(self): # import numpy as np # diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 41dc3bde..eced51c9 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,8 +1,8 @@ from monai.networks.nets import SegResNetVAE -def get_net(): - return SegResNetVAE +def get_net(input_image_size, dropout_prob=None): + return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob) def get_weights_file(): diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 93516340..5e3546ee 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -1,3 +1,4 @@ +import torch from monai.networks.nets import SwinUNETR @@ -5,13 +6,13 @@ def get_weights_file(): return "" -def get_net(): - return SwinUNETR +def get_net(img_size, use_checkpoint=True): + return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint) def get_output(model, input): out = model(input) - return out + return torch.sigmoid(out) def get_validation(model, val_inputs): diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index 21b1a39a..b298bdef 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -76,7 +76,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None - self.window_overlap_percentage = None ########################### # interface @@ -98,9 +97,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ###################### ###################### # TODO : better way to handle SegResNet size reqs ? - self.segres_size = ui.IntIncrementCounter(min=1, max=1024, default=128) + self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128) self.model_choice.currentIndexChanged.connect( - self.toggle_display_segres_size + self.toggle_display_model_input_size ) self.model_choice.setCurrentIndex(0) @@ -133,16 +132,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] - # ( - # self.window_size_choice, - # self.lbl_window_size_choice, - # ) = ui.make_combobox(sizes_window, label="Window size and overlap") - # self.window_overlap = ui.make_n_spinboxes( - # max=1, - # default=0.7, - # step=0.05, - # double=True, - # ) self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" @@ -156,11 +145,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.lbl_window_size_choice, horizontal=False, ) - # self.window_infer_params = ui.combine_blocks( - # self.window_overlap, - # self.window_infer_params, - # horizontal=False, - # ) ################## ################## @@ -226,7 +210,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.show_original_checkbox.setToolTip( "Displays the image used for inference in the viewer" ) - self.segres_size.setToolTip( + self.model_input_size.setToolTip( "Image size on which the model has been trained (default : 128)" ) @@ -242,9 +226,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Size of the window to run inference with (in pixels)" ) - # self.window_overlap.setToolTip( - # "Amount of overlap between sliding windows" - # ) self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" @@ -284,14 +265,14 @@ def check_ready(self): warnings.warn("Image and label paths are not correctly set") return False - def toggle_display_segres_size(self): + def toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" or self.model_choice.currentText() == "SwinUNetR" ): - self.segres_size.setVisible(True) + self.model_input_size.setVisible(True) else: - self.segres_size.setVisible(False) + self.model_input_size.setVisible(False) def toggle_display_number(self): """Shows the choices for viewing results depending on whether :py:attr:`self.view_checkbox` is checked""" @@ -398,7 +379,7 @@ def build(self): self.model_choice, self.custom_weights_choice, self.weights_path_container, - self.segres_size, + self.model_input_size, ], ) self.weights_path_container.setVisible(False) @@ -551,7 +532,7 @@ def start(self): model_dict = { # gather model info "name": model_key, "class": self.get_model(model_key), - "segres_size": self.segres_size.value(), + "model_input_size": self.model_input_size.value(), } if self.custom_weights_choice.isChecked(): @@ -600,7 +581,6 @@ def start(self): self.window_inference_size = int( self.window_size_choice.currentText() ) - # self.window_overlap_percentage = self.window_overlap.value() self.worker = InferenceWorker( device=device, @@ -613,7 +593,6 @@ def start(self): instance=self.instance_params, use_window=self.use_window_inference, window_infer_size=self.window_inference_size, - # window_overlap_percentage=self.window_overlap_percentage, keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) From ede3f3974b8bcf950e0d54057c3497acf36544f3 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 17:53:23 +0200 Subject: [PATCH 04/19] :art: formatting --- napari_cellseg3d/model_workers.py | 105 +++++++++++---------- napari_cellseg3d/models/model_SegResNet.py | 4 +- napari_cellseg3d/models/model_SwinUNetR.py | 8 +- napari_cellseg3d/plugin_model_inference.py | 5 +- 4 files changed, 66 insertions(+), 56 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 900cce75..8127ec08 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -167,19 +167,19 @@ class InferenceWorker(GeneratorWorker): Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" def __init__( - self, - device, - model_dict, - weights_dict, - images_filepaths, - results_path, - filetype, - transforms, - instance, - use_window, - window_infer_size, - keep_on_cpu, - stats_csv, + self, + device, + model_dict, + weights_dict, + images_filepaths, + results_path, + filetype, + transforms, + instance, + use_window, + window_infer_size, + keep_on_cpu, + stats_csv, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. @@ -227,7 +227,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size - self.window_overlap_percentage = 0.8, + self.window_overlap_percentage = (0.8,) self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -508,14 +508,14 @@ def inference(self): # File output save name : original-name_model_date+time_number.filetype file_path = ( - self.results_path - + "/" - + f"Prediction_{image_id}_" - + original_filename - + "_" - + self.model_dict["name"] - + f"_{time}_" - + self.filetype + self.results_path + + "/" + + f"Prediction_{image_id}_" + + original_filename + + "_" + + self.model_dict["name"] + + f"_{time}_" + + self.filetype ) # print(filename) @@ -556,14 +556,14 @@ def method(image): instance_labels = method(to_instance) instance_filepath = ( - self.results_path - + "/" - + f"Instance_seg_labels_{image_id}_" - + original_filename - + "_" - + self.model_dict["name"] - + f"_{time}_" - + self.filetype + self.results_path + + "/" + + f"Instance_seg_labels_{image_id}_" + + original_filename + + "_" + + self.model_dict["name"] + + f"_{time}_" + + self.filetype ) imwrite(instance_filepath, instance_labels) @@ -606,23 +606,23 @@ class TrainingWorker(GeneratorWorker): Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" def __init__( - self, - device, - model_dict, - weights_path, - data_dicts, - validation_percent, - max_epochs, - loss_function, - learning_rate, - val_interval, - batch_size, - results_path, - sampling, - num_samples, - sample_size, - do_augmentation, - deterministic, + self, + device, + model_dict, + weights_path, + data_dicts, + validation_percent, + max_epochs, + loss_function, + learning_rate, + val_interval, + batch_size, + results_path, + sampling, + num_samples, + sample_size, + do_augmentation, + deterministic, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -853,10 +853,10 @@ def train(self): self.train_files, self.val_files = ( self.data_dicts[ - 0: int(len(self.data_dicts) * self.validation_percent) + 0 : int(len(self.data_dicts) * self.validation_percent) ], self.data_dicts[ - int(len(self.data_dicts) * self.validation_percent): + int(len(self.data_dicts) * self.validation_percent) : ], ) @@ -1017,10 +1017,10 @@ def train(self): if self.device.type == "cuda": self.log("Memory Usage:") alloc_mem = round( - torch.cuda.memory_allocated(0) / 1024 ** 3, 1 + torch.cuda.memory_allocated(0) / 1024**3, 1 ) reserved_mem = round( - torch.cuda.memory_reserved(0) / 1024 ** 3, 1 + torch.cuda.memory_reserved(0) / 1024**3, 1 ) self.log(f"Allocated: {alloc_mem}GB") self.log(f"Cached: {reserved_mem}GB") @@ -1102,7 +1102,7 @@ def train(self): yield train_report weights_filename = ( - f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth" + f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth" ) if metric > best_metric: @@ -1143,6 +1143,7 @@ def train(self): # self.close() + # def this_is_fine(self): # import numpy as np # diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index eced51c9..ee1dc9a8 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -2,7 +2,9 @@ def get_net(input_image_size, dropout_prob=None): - return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob) + return SegResNetVAE( + input_image_size, out_channels=1, dropout_prob=dropout_prob + ) def get_weights_file(): diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 5e3546ee..67d911b3 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -7,7 +7,13 @@ def get_weights_file(): def get_net(img_size, use_checkpoint=True): - return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint) + return SwinUNETR( + img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) def get_output(model, input): diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index b298bdef..eed839bf 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -97,7 +97,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ###################### ###################### # TODO : better way to handle SegResNet size reqs ? - self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128) + self.model_input_size = ui.IntIncrementCounter( + min=1, max=1024, default=128 + ) self.model_choice.currentIndexChanged.connect( self.toggle_display_model_input_size ) @@ -226,7 +228,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Size of the window to run inference with (in pixels)" ) - self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) From 9d0190c63cdce5f40d50e1c221922047cd30a10b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 11:14:54 +0200 Subject: [PATCH 05/19] Disable TRAILMAP Removed draft of TRAILMAP port to Pytorch --- napari_cellseg3d/model_framework.py | 4 ++-- napari_cellseg3d/models/model_TRAILMAP.py | 2 -- napari_cellseg3d/models/pretrained/pretrained_model_urls.json | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 648865bb..0f0304b5 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,7 +14,7 @@ from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS from napari_cellseg3d.plugin_base import BasePluginFolder @@ -63,7 +63,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.models_dict = { "VNet": VNet, "SegResNet": SegResNet, - "TRAILMAP": TRAILMAP, + # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS, "SwinUNetR": SwinUNetR, } diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index ec4cfdbb..98bc20af 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -92,7 +92,6 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): ), nn.BatchNorm3d(out_ch), nn.ReLU(), - # nn.ConvTranspose3d(out_ch, out_ch, kernel_size=2, stride=2), ) return encode @@ -117,6 +116,5 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): out = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - # nn.BatchNorm3d(out_ch), ) return out diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json index 86bc0f57..db70430a 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json @@ -1,6 +1,5 @@ { "TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz", - "TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", "SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" } \ No newline at end of file From a6d3cce35ceb20b0aa796b5f11d7fcd18fd48a48 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 12:43:53 +0200 Subject: [PATCH 06/19] Fixes + reqs update - Fixed model instantiation - Fix window overlap argument type error - Updated reqs.txt to have einops (MONAI optional dep.) --- napari_cellseg3d/model_workers.py | 9 +++++++-- napari_cellseg3d/models/model_SwinUNetR.py | 2 +- requirements.txt | 5 +---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 8127ec08..03ba2f95 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -227,7 +227,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size - self.window_overlap_percentage = (0.8,) + self.window_overlap_percentage = 0.8 self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -346,7 +346,7 @@ def inference(self): dims = self.model_dict["model_input_size"] - model = self.model_dict["class"].get_net() + if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net( input_image_size=[ @@ -360,6 +360,8 @@ def inference(self): img_size=[dims, dims, dims], use_checkpoint=False, ) + else: + model = self.model_dict["class"].get_net() self.log_parameters() @@ -445,6 +447,7 @@ def inference(self): inputs = inputs.to("cpu") print(inputs.shape) + # self.log("output") model_output = lambda inputs: post_process_transforms( self.model_dict["class"].get_output(model, inputs) ) @@ -460,6 +463,8 @@ def inference(self): else: window_size = None window_overlap = 0.25 + + # self.log("window") outputs = sliding_window_inference( inputs, roi_size=window_size, diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 67d911b3..532aeb89 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -3,7 +3,7 @@ def get_weights_file(): - return "" + return "Swin64_best_metric.pth" def get_net(img_size, use_checkpoint=True): diff --git a/requirements.txt b/requirements.txt index 9885d9f0..3ba73405 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,13 +13,10 @@ napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 dask-image>=0.6.0 -scikit-image>=0.19.2 matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai>=0.9.0 -nibabel +monai[nibabel,scikit-image,itk,einops]>=0.9.0 pillow -itk>=5.2.0 vispy>=0.9.6 From da713e54a653eebaf8b0571bf390f06ef1d7e81e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 13:51:18 +0200 Subject: [PATCH 07/19] UI overhaul + overlap parameter - Added overlap parameter for window - Improved UI code slightly --- napari_cellseg3d/interface.py | 29 +++++++++++++++++++--- napari_cellseg3d/model_framework.py | 1 + napari_cellseg3d/model_workers.py | 15 ++++++----- napari_cellseg3d/plugin_model_inference.py | 28 ++++++++++++++++++--- 4 files changed, 61 insertions(+), 12 deletions(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 9ec9f558..1b1cad54 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,5 +1,4 @@ from typing import Optional -from typing import Union from typing import List @@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget): widget.setVisible(checkbox.isChecked()) +def add_label(widget, label, label_before=True, horizontal=True): + if label_before: + return combine_blocks(widget, label, horizontal=horizontal) + else: + return combine_blocks(label, widget, horizontal=horizontal) + + class Button(QPushButton): """Class for a button with a title and connected to a function when clicked. Inherits from QPushButton. @@ -492,6 +498,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, + label: Optional[str] = None, ): """Args: min (Optional[int]): minimum value, defaults to 0 @@ -499,13 +506,25 @@ def __init__( default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None - fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" + fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed + label (Optional[str]): if provided, creates a label with the chosen title to use with the counter""" super().__init__(parent) set_spinbox(self, min, max, default, step, fixed) + if label is not None: + self.label = make_label(name=label) + + # def setToolTip(self, a0: str) -> None: + # self.setToolTip(a0) + # if self.label is not None: + # self.label.setToolTip(a0) + + def get_with_label(self, horizontal=True): + return add_label(self, self.label, horizontal=horizontal) + def set_precision(self, decimals): - """Sets the precision of the box to the speicifed number of decimals""" + """Sets the precision of the box to the specified number of decimals""" self.setDecimals(decimals) @classmethod @@ -533,6 +552,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, + label: Optional[str] = None, ): """Args: min (Optional[int]): minimum value, defaults to 0 @@ -544,6 +564,9 @@ def __init__( super().__init__(parent) set_spinbox(self, min, max, default, step, fixed) + self.label = None + if label is not None: + self.label = make_label(label, self) @classmethod def make_n( diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 0f0304b5..47208616 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,6 +14,7 @@ from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 03ba2f95..c3ea13d0 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -162,6 +162,9 @@ def __init__(self): super().__init__() +# TODO : use dataclass for config instead ? + + class InferenceWorker(GeneratorWorker): """A custom worker to run inference jobs in. Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" @@ -178,6 +181,7 @@ def __init__( instance, use_window, window_infer_size, + window_overlap, keep_on_cpu, stats_csv, ): @@ -227,7 +231,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size - self.window_overlap_percentage = 0.8 + self.window_overlap_percentage = window_overlap self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -346,7 +350,6 @@ def inference(self): dims = self.model_dict["model_input_size"] - if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net( input_image_size=[ @@ -445,7 +448,7 @@ def inference(self): # print(inputs.shape) inputs = inputs.to("cpu") - print(inputs.shape) + # print(inputs.shape) # self.log("output") model_output = lambda inputs: post_process_transforms( @@ -477,7 +480,7 @@ def inference(self): out = outputs.detach().cpu() # del outputs # TODO fix memory ? # outputs = None - print(out.shape) + # print(out.shape) if self.transforms["zoom"][0]: zoom = self.transforms["zoom"][1] anisotropic_transform = Zoom( @@ -489,9 +492,9 @@ def inference(self): # out = post_process_transforms(out) out = np.array(out).astype(np.float32) - print(out.shape) + # print(out.shape) out = np.squeeze(out) - print(out.shape) + # print(out.shape) to_instance = out # avoid post processing since thresholding is done there anyway # batch_len = out.shape[1] diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index eed839bf..85b2ca7d 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -76,6 +76,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None + self.window_overlap = 0.25 ########################### # interface @@ -130,7 +131,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): T=7, parent=self ) - self.window_infer_box = ui.make_checkbox("Use window inference") + self.window_infer_box = ui.CheckBox(title="Use window inference") self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] @@ -140,14 +141,29 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ) self.lbl_window_size_choice = self.window_size_choice.label - self.keep_data_on_cpu_box = ui.make_checkbox("Keep data on CPU") + self.window_overlap_counter = ui.DoubleIncrementCounter( + min=0, + max=1, + default=0.25, + step=0.05, + parent=self, + label="Overlap %", + ) - self.window_infer_params = ui.combine_blocks( + self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU") + + window_size_widgets = ui.combine_blocks( self.window_size_choice, self.lbl_window_size_choice, horizontal=False, ) + self.window_infer_params = ui.combine_blocks( + window_size_widgets, + self.window_overlap_counter.get_with_label(horizontal=False), + horizontal=False, + ) + ################## ################## # instance segmentation widgets @@ -228,6 +244,10 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Size of the window to run inference with (in pixels)" ) + self.window_overlap_counter.setToolTip( + "Percentage of overlap between windows to use when using sliding window" + ) + self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) @@ -582,6 +602,7 @@ def start(self): self.window_inference_size = int( self.window_size_choice.currentText() ) + self.window_overlap = self.window_overlap_counter.value() self.worker = InferenceWorker( device=device, @@ -594,6 +615,7 @@ def start(self): instance=self.instance_params, use_window=self.use_window_inference, window_infer_size=self.window_inference_size, + window_overlap=self.window_overlap, keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) From c3e70eea16fd6344ea814047ed7344d853979c3a Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Mon, 4 Jul 2022 14:25:52 +0200 Subject: [PATCH 08/19] adding swinunetr & removing extra padding for inference --- .gitignore | 1 - napari_cellseg3d/model_framework.py | 2 ++ napari_cellseg3d/model_workers.py | 24 +++++++++++----- napari_cellseg3d/models/model_SwinUNetR.py | 18 ++++++++++++ napari_cellseg3d/models/model_VNet.py | 7 ++++- napari_cellseg3d/plugin_model_inference.py | 32 +++++++++++++++++++--- napari_cellseg3d/plugin_model_training.py | 2 +- napari_cellseg3d/utils.py | 21 -------------- setup.cfg | 3 +- 9 files changed, 74 insertions(+), 36 deletions(-) create mode 100644 napari_cellseg3d/models/model_SwinUNetR.py diff --git a/.gitignore b/.gitignore index f38ebe4d..74feefbe 100644 --- a/.gitignore +++ b/.gitignore @@ -99,4 +99,3 @@ venv/ /napari_cellseg3d/models/saved_weights/ /docs/res/logo/old_logo/ /reqs/ - diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 1baf6eed..648865bb 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -13,6 +13,7 @@ from napari_cellseg3d import utils from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet +from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS @@ -64,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "SegResNet": SegResNet, "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS, + "SwinUNetR": SwinUNetR, } """dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 0878d926..93a93875 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -301,8 +301,6 @@ def log_parameters(self): f"Probability threshold is {self.instance_params['threshold']:.2f}\n" f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n" ) - # self.log(f"") - # self.log("\n") self.log("-" * 20) def load_folder(self): @@ -313,11 +311,7 @@ def load_folder(self): data_check = LoadImaged(keys=["image"])(images_dict[0]) check = data_check["image"].shape - # TODO remove - # z_aniso = 5 / 1.5 - # if zoom is not None : - # pad = utils.get_padding_dim(check, anisotropy_factor=zoom) - # else: + self.log("\nChecking dimensions...") pad = utils.get_padding_dim(check) @@ -1027,10 +1021,26 @@ def train(self): out_channels=1, dropout_prob=0.3, ) + elif model_name == "SwinUNetR": + if self.sampling: + size = self.sample_size + else: + size = check + print(f"Size of image : {size}") + model = model_class.get_net()( + img_size=utils.get_padding_dim(size), + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=True, + ) else: model = model_class.get_net() # get an instance of the model model = model.to(self.device) + + + epoch_loss_values = [] val_metric_values = [] diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py new file mode 100644 index 00000000..93516340 --- /dev/null +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -0,0 +1,18 @@ +from monai.networks.nets import SwinUNETR + + +def get_weights_file(): + return "" + + +def get_net(): + return SwinUNETR + + +def get_output(model, input): + out = model(input) + return out + + +def get_validation(model, val_inputs): + return model(val_inputs) diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 0c5f0b75..0c854832 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -19,6 +19,11 @@ def get_validation(model, val_inputs): roi_size = (64, 64, 64) sw_batch_size = 1 val_outputs = sliding_window_inference( - val_inputs, roi_size, sw_batch_size, model, mode="gaussian" + val_inputs, + roi_size, + sw_batch_size, + model, + mode="gaussian", + overlap=0.7, ) return val_outputs diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index e6265767..f11b666c 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -78,6 +78,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None + self.window_overlap_percentage = None ########################### # interface @@ -134,6 +135,17 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + # ( + # self.window_size_choice, + # self.lbl_window_size_choice, + # ) = ui.make_combobox(sizes_window, label="Window size and overlap") + # self.window_overlap = ui.make_n_spinboxes( + # max=1, + # default=0.7, + # step=0.05, + # double=True, + # ) + self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) @@ -146,6 +158,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.lbl_window_size_choice, horizontal=False, ) + # self.window_infer_params = ui.combine_blocks( + # self.window_overlap, + # self.window_infer_params, + # horizontal=False, + # ) ################## ################## @@ -216,7 +233,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Displays the image used for inference in the viewer" ) self.segres_size.setToolTip( - "Image size on which the SegResNet has been trained (default : 128)" + "Image size on which the model has been trained (default : 128)" ) thresh_desc = ( @@ -234,6 +251,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_size_choice.setToolTip( "Size of the window to run inference with (in pixels)" ) + + # self.window_overlap.setToolTip( + # "Amount of overlap between sliding windows" + # ) + self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) @@ -281,7 +303,10 @@ def check_ready(self): return False def toggle_display_segres_size(self): - if self.model_choice.currentText() == "SegResNet": + if ( + self.model_choice.currentText() == "SegResNet" + or self.model_choice.currentText() == "SwinUNetR" + ): self.segres_size.setVisible(True) else: self.segres_size.setVisible(False) @@ -600,6 +625,7 @@ def start(self, on_layer=False): self.window_inference_size = int( self.window_size_choice.currentText() ) + # self.window_overlap_percentage = self.window_overlap.value() if not on_layer: self.worker = InferenceWorker( @@ -724,8 +750,6 @@ def on_yield(data, widget): zoom = widget.zoom - # print(data["original"].shape) - # print(data["result"].shape) viewer.dims.ndisplay = 3 viewer.scale_bar.visible = True diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 517cf8fc..1e9deed6 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -31,7 +31,7 @@ from napari_cellseg3d.model_workers import TrainingWorker NUMBER_TABS = 3 -DEFAULT_PATCH_SIZE = 60 +DEFAULT_PATCH_SIZE = 64 class Trainer(ModelFramework): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 83b23e9e..c1666632 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -116,7 +116,6 @@ def resize(image, zoom_factors): def align_array_sizes(array_shape, target_shape): - index_differences = [] for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: @@ -334,7 +333,6 @@ def fill_list_in_between(lst, n, elem): Returns : Filled list """ - new_list = [] for i in range(len(lst)): temp_list = [lst[i]] @@ -608,25 +606,6 @@ def format_Warning(message, category, filename, lineno, line=""): ) -# def dice_coeff(y_true, y_pred): -# smooth = 1. -# y_true_f = y_true.flatten() -# y_pred_f = K.flatten(y_pred) -# intersection = K.sum(y_true_f * y_pred_f) -# score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) -# return score - - -# def dice_loss(y_true, y_pred): -# loss = 1 - dice_coeff(y_true, y_pred) -# return loss - - -# def bce_dice_loss(y_true, y_pred): -# loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) -# return loss - - def divide_imgs(images): H = -(-images.shape[1] // 412) W = -(-images.shape[2] // 412) diff --git a/setup.cfg b/setup.cfg index b423c9ea..a0aa6eee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,12 +50,13 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 + monai[nibabel,scikit-image,itk,einops]>=0.9.0 tqdm monai>=0.9.0 nibabel scikit-image pillow - itk>=5.2.0 + tqdm matplotlib vispy>=0.9.6 From 43807ec78c1b54fffe53f81b306a150ddc8701a1 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 14:32:17 +0200 Subject: [PATCH 09/19] removed unnecessary code --- napari_cellseg3d/utils.py | 356 -------------------------------------- 1 file changed, 356 deletions(-) diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index c1666632..498c4830 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -604,359 +604,3 @@ def format_Warning(message, category, filename, lineno, line=""): + str(message) + "\n" ) - - -def divide_imgs(images): - H = -(-images.shape[1] // 412) - W = -(-images.shape[2] // 412) - - diveded_imgs = np.zeros((images.shape[0] * H * W, 512, 512, 1), np.float32) - # print(H, W) - - for z in range(images.shape[0]): - image = images[z] - for h in range(H): - for w in range(W): - cropped_img = np.zeros((512, 512, 1), np.float32) - cropped_img -= 1 - - if images.shape[1] < 412: - h = -1 - if images.shape[2] < 412: - w = -1 - - if h == -1: - if w == -1: - cropped_img[ - 50 : images.shape[1] + 50, - 50 : images.shape[2] + 50, - 0, - ] = image[0 : images.shape[1], 0 : images.shape[2], 0] - elif w == 0: - cropped_img[ - 50 : images.shape[1] + 50, 50:512, 0 - ] = image[0 : images.shape[1], 0:462, 0] - elif w == W - 1: - cropped_img[ - 50 : images.shape[1] + 50, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - 0 : images.shape[1], - w * 412 - 50 : images.shape[2], - 0, - ] - else: - cropped_img[50 : images.shape[1] + 50, :, 0] = image[ - 0 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - elif h == 0: - if w == -1: - cropped_img[ - 50:512, 50 : images.shape[2] + 50, 0 - ] = image[0:462, 0 : images.shape[2], 0] - elif w == 0: - cropped_img[50:512, 50:512, 0] = image[0:462, 0:462, 0] - elif w == W - 1: - cropped_img[ - 50:512, 0 : images.shape[2] - 412 * W - 50, 0 - ] = image[0:462, w * 412 - 50 : images.shape[2], 0] - else: - # cropped_img[50:512, :, 0] = image[0:462, w*412-50:(w+1)*412+50, 0] - try: - cropped_img[50:512, :, 0] = image[ - 0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0 - ] - except: - cropped_img[ - 50:512, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - 0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0 - ] - elif h == H - 1: - if w == -1: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 50 : images.shape[2] + 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - 0 : images.shape[2], - 0, - ] - elif w == 0: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, 50:512, 0 - ] = image[h * 412 - 50 : images.shape[1], 0:462, 0] - elif w == W - 1: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : images.shape[2], - 0, - ] - else: - try: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, :, 0 - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - h * 412 - 50 : images.shape[1], - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - else: - if w == -1: - cropped_img[:, 50 : images.shape[2] + 50, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - 0 : images.shape[2], - 0, - ] - elif w == 0: - # cropped_img[:, 50:512, 0] = image[h*412-50:(h+1)*412+50, 0:462, 0] - try: - cropped_img[:, 50:512, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0 - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50 + 412, - 50:512, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0 - ] - elif w == W - 1: - # cropped_img[:, 0:images.shape[2]-412*W-50, 0] = image[h*412-50:(h+1)*412+50, w*412-50:images.shape[2], 0] - try: - cropped_img[ - :, 0 : images.shape[2] - 412 * W - 50, 0 - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : images.shape[2], - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * H - 50 + 412, - 0 : images.shape[2] - 412 * W - 50, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : images.shape[2], - 0, - ] - else: - # cropped_img[:, :, 0] = image[h*412-50:(h+1)*412+50, w*412-50:(w+1)*412+50, 0] - try: - cropped_img[:, :, 0] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - try: - cropped_img[ - :, - 0 : images.shape[2] - 412 * (W - 1) - 50, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - except: - cropped_img[ - 0 : images.shape[1] - 412 * (H - 1) - 50, - :, - 0, - ] = image[ - h * 412 - 50 : (h + 1) * 412 + 50, - w * 412 - 50 : (w + 1) * 412 + 50, - 0, - ] - h = max(0, h) - w = max(0, w) - diveded_imgs[z * H * W + w * H + h] = cropped_img - # print(z*H*W+ w*H+h) - - return diveded_imgs - - -def merge_imgs(imgs, original_image_shape): - merged_imgs = np.zeros( - ( - original_image_shape[0], - original_image_shape[1], - original_image_shape[2], - 1, - ), - np.float32, - ) - H = -(-original_image_shape[1] // 412) - W = -(-original_image_shape[2] // 412) - - for z in range(original_image_shape[0]): - for h in range(H): - for w in range(W): - - if original_image_shape[1] < 412: - h = -1 - if original_image_shape[2] < 412: - w = -1 - - # print(z*H*W+ max(w, 0)*H+max(h, 0)) - if h == -1: - if w == -1: - merged_imgs[ - z, - 0 : original_image_shape[1], - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + 0][ - 50 : original_image_shape[1] + 50, - 50 : original_image_shape[2] + 50, - 0, - ] - elif w == 0: - merged_imgs[ - z, 0 : original_image_shape[1], 0:412, 0 - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, 50:462, 0 - ] - elif w == W - 1: - merged_imgs[ - z, - 0 : original_image_shape[1], - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - 0 : original_image_shape[1], - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + 0][ - 50 : original_image_shape[1] + 50, 50:462, 0 - ] - elif h == 0: - if w == -1: - merged_imgs[ - z, 0:412, 0 : original_image_shape[2], 0 - ] = imgs[z * H * W + 0 * H + h][ - 50:462, 50 : original_image_shape[2] + 50, 0 - ] - elif w == 0: - merged_imgs[z, 0:412, 0:412, 0] = imgs[ - z * H * W + w * H + h - ][50:462, 50:462, 0] - elif w == W - 1: - merged_imgs[ - z, 0:412, w * 412 : original_image_shape[2], 0 - ] = imgs[z * H * W + w * H + h][ - 50:462, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, 0:412, w * 412 : (w + 1) * 412, 0 - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - elif h == H - 1: - if w == -1: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50 : original_image_shape[2] + 50, - 0, - ] - elif w == 0: - merged_imgs[ - z, h * 412 : original_image_shape[1], 0:412, 0 - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50:462, - 0, - ] - elif w == W - 1: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - h * 412 : original_image_shape[1], - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + h][ - 50 : original_image_shape[1] - 412 * H - 50, - 50:462, - 0, - ] - else: - if w == -1: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - 0 : original_image_shape[2], - 0, - ] = imgs[z * H * W + 0 * H + h][ - 50:462, 50 : original_image_shape[2] + 50, 0 - ] - elif w == 0: - merged_imgs[ - z, h * 412 : (h + 1) * 412, 0:412, 0 - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - elif w == W - 1: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - w * 412 : original_image_shape[2], - 0, - ] = imgs[z * H * W + w * H + h][ - 50:462, - 50 : original_image_shape[2] - 412 * W - 50, - 0, - ] - else: - merged_imgs[ - z, - h * 412 : (h + 1) * 412, - w * 412 : (w + 1) * 412, - 0, - ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0] - - print(merged_imgs.shape) - return merged_imgs From 58b831e6261842853012687bec0ac5a46abbb082 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 17:52:58 +0200 Subject: [PATCH 10/19] :wrench move inherent model params where they are instantiated and fixed padding only for sliding window inference --- napari_cellseg3d/model_workers.py | 8 ++++---- napari_cellseg3d/models/model_SegResNet.py | 4 ++-- napari_cellseg3d/models/model_SwinUNetR.py | 7 ++++--- napari_cellseg3d/plugin_model_inference.py | 16 ++++++++-------- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 93a93875..bdc54d7c 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -45,7 +45,6 @@ # Qt from qtpy.QtCore import Signal - from napari_cellseg3d import utils from napari_cellseg3d import log_utility @@ -231,6 +230,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size + self.window_overlap_percentage = 0.8, self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv ############################################ @@ -399,8 +399,10 @@ def model_output( if self.use_window: window_size = self.window_infer_size + window_overlap = self.window_overlap_percentage else: window_size = None + window_overlap = 0.25 outputs = sliding_window_inference( inputs, @@ -409,6 +411,7 @@ def model_output( predictor=model_output, sw_device=self.device, device=dataset_device, + overlap=window_overlap, ) out = outputs.detach().cpu() @@ -1029,9 +1032,6 @@ def train(self): print(f"Size of image : {size}") model = model_class.get_net()( img_size=utils.get_padding_dim(size), - in_channels=1, - out_channels=1, - feature_size=48, use_checkpoint=True, ) else: diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 41dc3bde..eced51c9 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,8 +1,8 @@ from monai.networks.nets import SegResNetVAE -def get_net(): - return SegResNetVAE +def get_net(input_image_size, dropout_prob=None): + return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob) def get_weights_file(): diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 93516340..5e3546ee 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -1,3 +1,4 @@ +import torch from monai.networks.nets import SwinUNETR @@ -5,13 +6,13 @@ def get_weights_file(): return "" -def get_net(): - return SwinUNETR +def get_net(img_size, use_checkpoint=True): + return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint) def get_output(model, input): out = model(input) - return out + return torch.sigmoid(out) def get_validation(model, val_inputs): diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index f11b666c..f79169a7 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -100,9 +100,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ###################### ###################### # TODO : better way to handle SegResNet size reqs ? - self.segres_size = ui.IntIncrementCounter(min=1, max=1024, default=128) + self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128) self.model_choice.currentIndexChanged.connect( - self.toggle_display_segres_size + self.toggle_display_model_input_size ) self.model_choice.setCurrentIndex(0) @@ -232,7 +232,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.show_original_checkbox.setToolTip( "Displays the image used for inference in the viewer" ) - self.segres_size.setToolTip( + self.model_input_size.setToolTip( "Image size on which the model has been trained (default : 128)" ) @@ -302,14 +302,14 @@ def check_ready(self): warnings.warn("Image and label paths are not correctly set") return False - def toggle_display_segres_size(self): + def toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" or self.model_choice.currentText() == "SwinUNetR" ): - self.segres_size.setVisible(True) + self.model_input_size.setVisible(True) else: - self.segres_size.setVisible(False) + self.model_input_size.setVisible(False) def toggle_display_number(self): """Shows the choices for viewing results depending on whether :py:attr:`self.view_checkbox` is checked""" @@ -418,7 +418,7 @@ def build(self): self.model_choice, self.custom_weights_choice, self.weights_path_container, - self.segres_size, + self.model_input_size, ], ) self.weights_path_container.setVisible(False) @@ -576,7 +576,7 @@ def start(self, on_layer=False): model_dict = { # gather model info "name": model_key, "class": self.get_model(model_key), - "segres_size": self.segres_size.value(), + "model_input_size": self.model_input_size.value(), } if self.custom_weights_choice.isChecked(): From 3c1e023f26b793ec58f7ec1b817577684729f5b7 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 22 Jul 2022 17:53:23 +0200 Subject: [PATCH 11/19] :art: formatting --- napari_cellseg3d/model_workers.py | 10 +++++++++- napari_cellseg3d/models/model_SegResNet.py | 4 +++- napari_cellseg3d/models/model_SwinUNetR.py | 8 +++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index bdc54d7c..a795d338 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -683,7 +683,7 @@ def inference(self): try: dims = self.model_dict["segres_size"] - model = self.model_dict["class"].get_net() + if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net()( input_image_size=[ @@ -694,6 +694,14 @@ def inference(self): out_channels=1, # dropout_prob=0.3, ) + elif self.model_dict["name"] == "SwinUNetR": + model = self.model_dict["class"].get_net( + img_size=[dims, dims, dims], + use_checkpoint=False, + ) + else: + model = self.model_dict["class"].get_net() + model = model.to(self.device) self.log_parameters() diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index eced51c9..ee1dc9a8 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -2,7 +2,9 @@ def get_net(input_image_size, dropout_prob=None): - return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob) + return SegResNetVAE( + input_image_size, out_channels=1, dropout_prob=dropout_prob + ) def get_weights_file(): diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 5e3546ee..67d911b3 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -7,7 +7,13 @@ def get_weights_file(): def get_net(img_size, use_checkpoint=True): - return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint) + return SwinUNETR( + img_size, + in_channels=1, + out_channels=1, + feature_size=48, + use_checkpoint=use_checkpoint, + ) def get_output(model, input): From 70a936a99a2cbb6bd743a4160211b071a3139446 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 11:14:54 +0200 Subject: [PATCH 12/19] Disable TRAILMAP Removed draft of TRAILMAP port to Pytorch --- napari_cellseg3d/model_framework.py | 4 ++-- napari_cellseg3d/models/model_TRAILMAP.py | 2 -- napari_cellseg3d/models/pretrained/pretrained_model_urls.json | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 648865bb..0f0304b5 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,7 +14,7 @@ from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR -from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS from napari_cellseg3d.plugin_base import BasePluginFolder @@ -63,7 +63,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.models_dict = { "VNet": VNet, "SegResNet": SegResNet, - "TRAILMAP": TRAILMAP, + # "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS, "SwinUNetR": SwinUNetR, } diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index ff28ecdd..7cdf9b80 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -92,7 +92,6 @@ def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): ), nn.BatchNorm3d(out_ch), nn.ReLU(), - # nn.ConvTranspose3d(out_ch, out_ch, kernel_size=2, stride=2), ) return encode @@ -117,6 +116,5 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): out = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), - # nn.BatchNorm3d(out_ch), ) return out diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json index 86bc0f57..db70430a 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json @@ -1,6 +1,5 @@ { "TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz", - "TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", "SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" } \ No newline at end of file From 09bdef04cb6f098687a6f1e93293821ef0a0e291 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 12:43:53 +0200 Subject: [PATCH 13/19] Fixes + reqs update - Fixed model instantiation - Fix window overlap argument type error - Updated reqs.txt to have einops (MONAI optional dep.) --- napari_cellseg3d/models/model_SwinUNetR.py | 2 +- requirements.txt | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py index 67d911b3..532aeb89 100644 --- a/napari_cellseg3d/models/model_SwinUNetR.py +++ b/napari_cellseg3d/models/model_SwinUNetR.py @@ -3,7 +3,7 @@ def get_weights_file(): - return "" + return "Swin64_best_metric.pth" def get_net(img_size, use_checkpoint=True): diff --git a/requirements.txt b/requirements.txt index 9885d9f0..3ba73405 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,13 +13,10 @@ napari[all]>=0.4.14 QtPy opencv-python>=4.5.5 dask-image>=0.6.0 -scikit-image>=0.19.2 matplotlib>=3.4.1 tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 -monai>=0.9.0 -nibabel +monai[nibabel,scikit-image,itk,einops]>=0.9.0 pillow -itk>=5.2.0 vispy>=0.9.6 From 55d22786f24f667ff61c1f665e2a229692fb4688 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 13:51:18 +0200 Subject: [PATCH 14/19] UI overhaul + overlap parameter - Added overlap parameter for window - Improved UI code slightly --- napari_cellseg3d/interface.py | 29 +++++++++- napari_cellseg3d/model_framework.py | 1 + napari_cellseg3d/model_workers.py | 64 ++++++++++++++++++---- napari_cellseg3d/plugin_model_inference.py | 31 +++++++++-- 4 files changed, 105 insertions(+), 20 deletions(-) diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 4931026e..dfb3be31 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,5 +1,4 @@ from typing import Optional -from typing import Union from typing import List @@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget): widget.setVisible(checkbox.isChecked()) +def add_label(widget, label, label_before=True, horizontal=True): + if label_before: + return combine_blocks(widget, label, horizontal=horizontal) + else: + return combine_blocks(label, widget, horizontal=horizontal) + + class Button(QPushButton): """Class for a button with a title and connected to a function when clicked. Inherits from QPushButton. @@ -494,6 +500,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, + label: Optional[str] = None, ): """Args: min (Optional[float]): minimum value, defaults to 0 @@ -501,13 +508,25 @@ def __init__( default (Optional[float]): default value, defaults to 0 step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None - fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" + fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed + label (Optional[str]): if provided, creates a label with the chosen title to use with the counter""" super().__init__(parent) set_spinbox(self, min, max, default, step, fixed) + if label is not None: + self.label = make_label(name=label) + + # def setToolTip(self, a0: str) -> None: + # self.setToolTip(a0) + # if self.label is not None: + # self.label.setToolTip(a0) + + def get_with_label(self, horizontal=True): + return add_label(self, self.label, horizontal=horizontal) + def set_precision(self, decimals): - """Sets the precision of the box to the speicifed number of decimals""" + """Sets the precision of the box to the specified number of decimals""" self.setDecimals(decimals) @classmethod @@ -535,6 +554,7 @@ def __init__( step=1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, + label: Optional[str] = None, ): """Args: min (Optional[int]): minimum value, defaults to 0 @@ -546,6 +566,9 @@ def __init__( super().__init__(parent) set_spinbox(self, min, max, default, step, fixed) + self.label = None + if label is not None: + self.label = make_label(label, self) @classmethod def make_n( diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 0f0304b5..47208616 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,6 +14,7 @@ from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR + # from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index a795d338..4fb47d10 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -164,6 +164,9 @@ def __init__(self): super().__init__() +# TODO : use dataclass for config instead ? + + class InferenceWorker(GeneratorWorker): """A custom worker to run inference jobs in. Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" @@ -179,6 +182,7 @@ def __init__( instance, use_window, window_infer_size, + window_overlap, keep_on_cpu, stats_csv, images_filepaths=None, @@ -230,7 +234,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size - self.window_overlap_percentage = 0.8, + self.window_overlap_percentage = window_overlap self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv ############################################ @@ -315,17 +319,53 @@ def load_folder(self): self.log("\nChecking dimensions...") pad = utils.get_padding_dim(check) - load_transforms = Compose( - [ - LoadImaged(keys=["image"]), - # AddChanneld(keys=["image"]), #already done - EnsureChannelFirstd(keys=["image"]), - # Orientationd(keys=["image"], axcodes="PLI"), - # anisotropic_transform, - SpatialPadd(keys=["image"], spatial_size=pad), - EnsureTyped(keys=["image"]), - ] - ) + dims = self.model_dict["model_input_size"] + + if self.model_dict["name"] == "SegResNet": + model = self.model_dict["class"].get_net( + input_image_size=[ + dims, + dims, + dims, + ] + ) + elif self.model_dict["name"] == "SwinUNetR": + model = self.model_dict["class"].get_net( + img_size=[dims, dims, dims], + use_checkpoint=False, + ) + else: + model = self.model_dict["class"].get_net() + + self.log_parameters() + + model.to(self.device) + + # print("FILEPATHS PRINT") + # print(self.images_filepaths) + if self.use_window: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + EnsureTyped(keys=["image"]), + ] + ) + else: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + SpatialPadd(keys=["image"], spatial_size=pad), + EnsureTyped(keys=["image"]), + ] + ) self.log("\nLoading dataset...") inference_ds = Dataset(data=images_dict, transform=load_transforms) diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index f79169a7..9251964b 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -78,7 +78,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None - self.window_overlap_percentage = None + self.window_overlap = 0.25 ########################### # interface @@ -131,7 +131,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): T=7, parent=self ) - self.window_infer_box = ui.make_checkbox("Use window inference") + self.window_infer_box = ui.CheckBox(title="Use window inference") self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] @@ -151,9 +151,18 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ) self.lbl_window_size_choice = self.window_size_choice.label - self.keep_data_on_cpu_box = ui.make_checkbox("Keep data on CPU") + self.window_overlap_counter = ui.DoubleIncrementCounter( + min=0, + max=1, + default=0.25, + step=0.05, + parent=self, + label="Overlap %", + ) - self.window_infer_params = ui.combine_blocks( + self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU") + + window_size_widgets = ui.combine_blocks( self.window_size_choice, self.lbl_window_size_choice, horizontal=False, @@ -164,6 +173,12 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # horizontal=False, # ) + self.window_infer_params = ui.combine_blocks( + window_size_widgets, + self.window_overlap_counter.get_with_label(horizontal=False), + horizontal=False, + ) + ################## ################## # instance segmentation widgets @@ -252,6 +267,10 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Size of the window to run inference with (in pixels)" ) + self.window_overlap_counter.setToolTip( + "Percentage of overlap between windows to use when using sliding window" + ) + # self.window_overlap.setToolTip( # "Amount of overlap between sliding windows" # ) @@ -625,7 +644,7 @@ def start(self, on_layer=False): self.window_inference_size = int( self.window_size_choice.currentText() ) - # self.window_overlap_percentage = self.window_overlap.value() + self.window_overlap = self.window_overlap_counter.value() if not on_layer: self.worker = InferenceWorker( @@ -639,6 +658,7 @@ def start(self, on_layer=False): instance=self.instance_params, use_window=self.use_window_inference, window_infer_size=self.window_inference_size, + window_overlap=self.window_overlap, keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) @@ -655,6 +675,7 @@ def start(self, on_layer=False): use_window=self.use_window_inference, window_infer_size=self.window_inference_size, keep_on_cpu=self.keep_on_cpu, + window_overlap=self.window_overlap, stats_csv=self.stats_to_csv, layer=layer, ) From 103058147f1bbff3c9373badbd693ba8c0529cd3 Mon Sep 17 00:00:00 2001 From: Maxime Vidal Date: Fri, 12 Aug 2022 17:14:26 +0200 Subject: [PATCH 15/19] name fix --- napari_cellseg3d/model_workers.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 4fb47d10..31785095 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -711,17 +711,9 @@ def inference(self): torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") - # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems - # torch.backends.cudnn.benchmark = False - - # TODO : better solution than loading first image always ? - # data_check = LoadImaged(keys=["image"])(images_dict[0]) - # print(data) - # check = data_check["image"].shape - # print(check) try: - dims = self.model_dict["segres_size"] + dims = self.model_dict["model_input_size"] if self.model_dict["name"] == "SegResNet": From 4babcffd51f7ceefc50115b0785f60a96a1ab7db Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 12 Aug 2022 17:19:51 +0200 Subject: [PATCH 16/19] Removed padding with window on layer Removed padding transform when using single layer inference and the sliding window inference option --- napari_cellseg3d/model_workers.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 31785095..c9cf2e51 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -398,19 +398,32 @@ def load_layer(self): # print(volume.shape) # print(volume.dtype) - - load_transforms = Compose( + if self.use_window: + load_transforms = Compose( [ ToTensor(), # anisotropic_transform, AddChannel(), - SpatialPad(spatial_size=pad), + # SpatialPad(spatial_size=pad), AddChannel(), EnsureType(), ], map_items=False, log_stats=True, ) + else: + load_transforms = Compose( + [ + ToTensor(), + # anisotropic_transform, + AddChannel(), + SpatialPad(spatial_size=pad), + AddChannel(), + EnsureType(), + ], + map_items=False, + log_stats=True, + ) self.log("\nLoading dataset...") input_image = load_transforms(volume) From 5c1f9c2c9ab28b7f68d410279f935d6774efd838 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 15 Aug 2022 09:24:54 +0200 Subject: [PATCH 17/19] Fixes - Fixed missing filename for saving layer inference - Fixed duplicate model init - Fixed error in model init - lint --- napari_cellseg3d/model_workers.py | 97 ++++++++++------------ napari_cellseg3d/plugin_model_inference.py | 5 +- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index c9cf2e51..ca08c403 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -319,27 +319,27 @@ def load_folder(self): self.log("\nChecking dimensions...") pad = utils.get_padding_dim(check) - dims = self.model_dict["model_input_size"] - - if self.model_dict["name"] == "SegResNet": - model = self.model_dict["class"].get_net( - input_image_size=[ - dims, - dims, - dims, - ] - ) - elif self.model_dict["name"] == "SwinUNetR": - model = self.model_dict["class"].get_net( - img_size=[dims, dims, dims], - use_checkpoint=False, - ) - else: - model = self.model_dict["class"].get_net() - - self.log_parameters() - - model.to(self.device) + # dims = self.model_dict["model_input_size"] + # + # if self.model_dict["name"] == "SegResNet": + # model = self.model_dict["class"].get_net( + # input_image_size=[ + # dims, + # dims, + # dims, + # ] + # ) + # elif self.model_dict["name"] == "SwinUNetR": + # model = self.model_dict["class"].get_net( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = self.model_dict["class"].get_net() + # + # self.log_parameters() + # + # model.to(self.device) # print("FILEPATHS PRINT") # print(self.images_filepaths) @@ -399,18 +399,18 @@ def load_layer(self): # print(volume.shape) # print(volume.dtype) if self.use_window: - load_transforms = Compose( - [ - ToTensor(), - # anisotropic_transform, - AddChannel(), - # SpatialPad(spatial_size=pad), - AddChannel(), - EnsureType(), - ], - map_items=False, - log_stats=True, - ) + load_transforms = Compose( + [ + ToTensor(), + # anisotropic_transform, + AddChannel(), + # SpatialPad(spatial_size=pad), + AddChannel(), + EnsureType(), + ], + map_items=False, + log_stats=True, + ) else: load_transforms = Compose( [ @@ -558,13 +558,12 @@ def save_image( ) imwrite(file_path, image) + filename = os.path.split(file_path)[1] if from_layer: - self.log(f"\nLayer prediction saved as :") + self.log(f"\nLayer prediction saved as : {filename}") else: - self.log(f"\nFile n°{i+1} saved as :") - filename = os.path.split(file_path)[1] - self.log(filename) + self.log(f"\nFile n°{i+1} saved as : {filename}") def aniso_transform(self, image): zoom = self.transforms["zoom"][1] @@ -680,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, data_dict = self.get_instance_result(out,from_layer=True) + instance_labels, data_dict = self.get_instance_result( + out, from_layer=True + ) - return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict) + return self.create_result_dict( + out, instance_labels, from_layer=True, data_dict=data_dict + ) def inference(self): """ @@ -724,20 +727,18 @@ def inference(self): torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") - try: dims = self.model_dict["model_input_size"] - + self.log(f"MODEL DIMS : {dims}") + self.log(self.model_dict["name"]) if self.model_dict["name"] == "SegResNet": - model = self.model_dict["class"].get_net()( + model = self.model_dict["class"].get_net( input_image_size=[ dims, dims, dims, ], # TODO FIX ! find a better way & remove model-specific code - out_channels=1, - # dropout_prob=0.3, ) elif self.model_dict["name"] == "SwinUNetR": model = self.model_dict["class"].get_net( @@ -772,10 +773,7 @@ def inference(self): AsDiscrete(threshold=t), EnsureType() ) - - self.log( - "\nLoading weights..." - ) + self.log("\nLoading weights...") if self.weights_dict["custom"]: weights = self.weights_dict["path"] @@ -1091,9 +1089,6 @@ def train(self): model = model_class.get_net() # get an instance of the model model = model.to(self.device) - - - epoch_loss_values = [] val_metric_values = [] diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index 9251964b..f9cd5615 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -100,7 +100,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ###################### ###################### # TODO : better way to handle SegResNet size reqs ? - self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128) + self.model_input_size = ui.IntIncrementCounter( + min=1, max=1024, default=128 + ) self.model_choice.currentIndexChanged.connect( self.toggle_display_model_input_size ) @@ -771,7 +773,6 @@ def on_yield(data, widget): zoom = widget.zoom - viewer.dims.ndisplay = 3 viewer.scale_bar.visible = True From bb90f9f19058356238d77e9dda6e447e7a1b10f4 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 15 Aug 2022 09:26:53 +0200 Subject: [PATCH 18/19] Added SWIN URL Prepared for SWIN download --- napari_cellseg3d/models/pretrained/pretrained_model_urls.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json index db70430a..ffb756ef 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json @@ -1,5 +1,6 @@ { "TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz", "SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", - "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" + "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz", + "SwinUNetR": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SWIN.tar.gz" } \ No newline at end of file From 158e0c263140a2c606749ed1328d626145dc91fd Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 15 Aug 2022 13:33:11 +0200 Subject: [PATCH 19/19] Fixed model_workers.py - Fixed error in training : get_net erroneous call --- napari_cellseg3d/model_workers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index ca08c403..4ac4d379 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -1070,7 +1070,7 @@ def train(self): else: size = check print(f"Size of image : {size}") - model = model_class.get_net()( + model = model_class.get_net( input_image_size=utils.get_padding_dim(size), out_channels=1, dropout_prob=0.3, @@ -1081,7 +1081,7 @@ def train(self): else: size = check print(f"Size of image : {size}") - model = model_class.get_net()( + model = model_class.get_net( img_size=utils.get_padding_dim(size), use_checkpoint=True, )