diff --git a/.gitignore b/.gitignore
index f38ebe4d..74feefbe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,4 +99,3 @@ venv/
 /napari_cellseg3d/models/saved_weights/
 /docs/res/logo/old_logo/
 /reqs/
-
diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py
index 4931026e..dfb3be31 100644
--- a/napari_cellseg3d/interface.py
+++ b/napari_cellseg3d/interface.py
@@ -1,5 +1,4 @@
 from typing import Optional
-from typing import Union
 from typing import List
 
 
@@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget):
     widget.setVisible(checkbox.isChecked())
 
 
+def add_label(widget, label, label_before=True, horizontal=True):
+    if label_before:
+        return combine_blocks(widget, label, horizontal=horizontal)
+    else:
+        return combine_blocks(label, widget, horizontal=horizontal)
+
+
 class Button(QPushButton):
     """Class for a button with a title and connected to a function when clicked. Inherits from QPushButton.
 
@@ -494,6 +500,7 @@ def __init__(
         step=1,
         parent: Optional[QWidget] = None,
         fixed: Optional[bool] = True,
+        label: Optional[str] = None,
     ):
         """Args:
             min (Optional[float]): minimum value, defaults to 0
@@ -501,13 +508,25 @@ def __init__(
             default (Optional[float]): default value, defaults to 0
             step (Optional[float]): step value, defaults to 1
             parent: parent widget, defaults to None
-            fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed"""
+            fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed
+            label (Optional[str]): if provided, creates a label with the chosen title to use with the counter"""
 
         super().__init__(parent)
         set_spinbox(self, min, max, default, step, fixed)
 
+        if label is not None:
+            self.label = make_label(name=label)
+
+    # def setToolTip(self, a0: str) -> None:
+    #     self.setToolTip(a0)
+    #     if self.label is not None:
+    #         self.label.setToolTip(a0)
+
+    def get_with_label(self, horizontal=True):
+        return add_label(self, self.label, horizontal=horizontal)
+
     def set_precision(self, decimals):
-        """Sets the precision of the box to the speicifed number of decimals"""
+        """Sets the precision of the box to the specified number of decimals"""
         self.setDecimals(decimals)
 
     @classmethod
@@ -535,6 +554,7 @@ def __init__(
         step=1,
         parent: Optional[QWidget] = None,
         fixed: Optional[bool] = True,
+        label: Optional[str] = None,
     ):
         """Args:
             min (Optional[int]): minimum value, defaults to 0
@@ -546,6 +566,9 @@ def __init__(
 
         super().__init__(parent)
         set_spinbox(self, min, max, default, step, fixed)
+        self.label = None
+        if label is not None:
+            self.label = make_label(label, self)
 
     @classmethod
     def make_n(
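For context, a minimal usage sketch (not part of the patch) of the `label` keyword and `get_with_label()` helper added to the counters in interface.py above; it assumes a running Qt application as provided by napari:

    # hedged sketch: building a labeled counter with the new API
    from napari_cellseg3d import interface as ui

    overlap = ui.DoubleIncrementCounter(
        min=0, max=1, default=0.25, step=0.05, label="Overlap %"
    )
    # wraps the counter and its label into one container widget
    container = overlap.get_with_label(horizontal=False)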
"""dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 0878d926..4ac4d379 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -45,7 +45,6 @@ # Qt from qtpy.QtCore import Signal - from napari_cellseg3d import utils from napari_cellseg3d import log_utility @@ -165,6 +164,9 @@ def __init__(self): super().__init__() +# TODO : use dataclass for config instead ? + + class InferenceWorker(GeneratorWorker): """A custom worker to run inference jobs in. Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" @@ -180,6 +182,7 @@ def __init__( instance, use_window, window_infer_size, + window_overlap, keep_on_cpu, stats_csv, images_filepaths=None, @@ -231,6 +234,7 @@ def __init__( self.instance_params = instance self.use_window = use_window self.window_infer_size = window_infer_size + self.window_overlap_percentage = window_overlap self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv ############################################ @@ -301,8 +305,6 @@ def log_parameters(self): f"Probability threshold is {self.instance_params['threshold']:.2f}\n" f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n" ) - # self.log(f"") - # self.log("\n") self.log("-" * 20) def load_folder(self): @@ -313,25 +315,57 @@ def load_folder(self): data_check = LoadImaged(keys=["image"])(images_dict[0]) check = data_check["image"].shape - # TODO remove - # z_aniso = 5 / 1.5 - # if zoom is not None : - # pad = utils.get_padding_dim(check, anisotropy_factor=zoom) - # else: + self.log("\nChecking dimensions...") pad = utils.get_padding_dim(check) - load_transforms = Compose( - [ - LoadImaged(keys=["image"]), - # AddChanneld(keys=["image"]), #already done - EnsureChannelFirstd(keys=["image"]), - # Orientationd(keys=["image"], axcodes="PLI"), - # anisotropic_transform, - SpatialPadd(keys=["image"], spatial_size=pad), - EnsureTyped(keys=["image"]), - ] - ) + # dims = self.model_dict["model_input_size"] + # + # if self.model_dict["name"] == "SegResNet": + # model = self.model_dict["class"].get_net( + # input_image_size=[ + # dims, + # dims, + # dims, + # ] + # ) + # elif self.model_dict["name"] == "SwinUNetR": + # model = self.model_dict["class"].get_net( + # img_size=[dims, dims, dims], + # use_checkpoint=False, + # ) + # else: + # model = self.model_dict["class"].get_net() + # + # self.log_parameters() + # + # model.to(self.device) + + # print("FILEPATHS PRINT") + # print(self.images_filepaths) + if self.use_window: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + EnsureTyped(keys=["image"]), + ] + ) + else: + load_transforms = Compose( + [ + LoadImaged(keys=["image"]), + # AddChanneld(keys=["image"]), #already done + EnsureChannelFirstd(keys=["image"]), + # Orientationd(keys=["image"], axcodes="PLI"), + # anisotropic_transform, + SpatialPadd(keys=["image"], spatial_size=pad), + EnsureTyped(keys=["image"]), + ] + ) self.log("\nLoading dataset...") inference_ds = Dataset(data=images_dict, transform=load_transforms) @@ -364,19 +398,32 @@ def load_layer(self): # print(volume.shape) # print(volume.dtype) - - load_transforms = Compose( - [ - ToTensor(), - # anisotropic_transform, - AddChannel(), - SpatialPad(spatial_size=pad), - AddChannel(), - EnsureType(), - ], - 
@@ -405,8 +452,10 @@
 
         if self.use_window:
             window_size = self.window_infer_size
+            window_overlap = self.window_overlap_percentage
         else:
             window_size = None
+            window_overlap = 0.25
 
         outputs = sliding_window_inference(
             inputs,
@@ -415,6 +464,7 @@
             predictor=model_output,
             sw_device=self.device,
             device=dataset_device,
+            overlap=window_overlap,
         )
 
         out = outputs.detach().cpu()
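The hunk above threads the new overlap setting into MONAI's sliding-window inferer. A self-contained, hedged sketch of the call (the stand-in predictor is illustrative; `overlap` is the fraction by which neighbouring windows overlap):

    import torch
    from monai.inferers import sliding_window_inference

    predictor = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)  # stand-in model
    inputs = torch.rand(1, 1, 96, 96, 96)

    out = sliding_window_inference(
        inputs,
        roi_size=(64, 64, 64),  # window_infer_size chosen in the UI
        sw_batch_size=1,
        predictor=predictor,
        overlap=0.25,           # the new window_overlap parameter
    )
    print(out.shape)            # torch.Size([1, 1, 96, 96, 96])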
@@ -508,13 +558,12 @@
         )
 
         imwrite(file_path, image)
+        filename = os.path.split(file_path)[1]
 
         if from_layer:
-            self.log(f"\nLayer prediction saved as :")
+            self.log(f"\nLayer prediction saved as : {filename}")
         else:
-            self.log(f"\nFile n°{i+1} saved as :")
-            filename = os.path.split(file_path)[1]
-            self.log(filename)
+            self.log(f"\nFile n°{i+1} saved as : {filename}")
 
     def aniso_transform(self, image):
         zoom = self.transforms["zoom"][1]
@@ -630,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):
 
         self.save_image(out, from_layer=True)
 
-        instance_labels, data_dict = self.get_instance_result(out,from_layer=True)
+        instance_labels, data_dict = self.get_instance_result(
+            out, from_layer=True
+        )
 
-        return self.create_result_dict(out, instance_labels, from_layer=True, data_dict=data_dict)
+        return self.create_result_dict(
+            out, instance_labels, from_layer=True, data_dict=data_dict
+        )
 
     def inference(self):
         """
@@ -674,29 +727,27 @@
                 torch.set_num_threads(1)  # required for threading on macOS ?
                 self.log("Number of threads has been set to 1 for macOS")
 
-            # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
-            #     torch.backends.cudnn.benchmark = False
-
-            # TODO : better solution than loading first image always ?
-            # data_check = LoadImaged(keys=["image"])(images_dict[0])
-            # print(data)
-            # check = data_check["image"].shape
-            # print(check)
-
             try:
-                dims = self.model_dict["segres_size"]
+                dims = self.model_dict["model_input_size"]
+                self.log(f"MODEL DIMS : {dims}")
+                self.log(self.model_dict["name"])
 
-                model = self.model_dict["class"].get_net()
                 if self.model_dict["name"] == "SegResNet":
-                    model = self.model_dict["class"].get_net()(
+                    model = self.model_dict["class"].get_net(
                         input_image_size=[
                             dims,
                             dims,
                             dims,
                         ],  # TODO FIX ! find a better way & remove model-specific code
-                        out_channels=1,
-                        # dropout_prob=0.3,
                     )
+                elif self.model_dict["name"] == "SwinUNetR":
+                    model = self.model_dict["class"].get_net(
+                        img_size=[dims, dims, dims],
+                        use_checkpoint=False,
+                    )
+                else:
+                    model = self.model_dict["class"].get_net()
+
                 model = model.to(self.device)
 
                 self.log_parameters()
@@ -722,10 +773,7 @@
                     AsDiscrete(threshold=t), EnsureType()
                 )
 
-
-                self.log(
-                    "\nLoading weights..."
-                )
+                self.log("\nLoading weights...")
 
                 if self.weights_dict["custom"]:
                     weights = self.weights_dict["path"]
@@ -1022,11 +1070,21 @@
                 else:
                     size = check
                 print(f"Size of image : {size}")
-                model = model_class.get_net()(
+                model = model_class.get_net(
                     input_image_size=utils.get_padding_dim(size),
                     out_channels=1,
                     dropout_prob=0.3,
                 )
+            elif model_name == "SwinUNetR":
+                if self.sampling:
+                    size = self.sample_size
+                else:
+                    size = check
+                print(f"Size of image : {size}")
+                model = model_class.get_net(
+                    img_size=utils.get_padding_dim(size),
+                    use_checkpoint=True,
+                )
             else:
                 model = model_class.get_net()  # get an instance of the model
             model = model.to(self.device)
diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py
index 41dc3bde..ee1dc9a8 100644
--- a/napari_cellseg3d/models/model_SegResNet.py
+++ b/napari_cellseg3d/models/model_SegResNet.py
@@ -1,8 +1,10 @@
 from monai.networks.nets import SegResNetVAE
 
 
-def get_net():
-    return SegResNetVAE
+def get_net(input_image_size, dropout_prob=None):
+    return SegResNetVAE(
+        input_image_size, out_channels=1, dropout_prob=dropout_prob
+    )
 
 
 def get_weights_file():
diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/models/model_SwinUNetR.py
new file mode 100644
index 00000000..532aeb89
--- /dev/null
+++ b/napari_cellseg3d/models/model_SwinUNetR.py
@@ -0,0 +1,25 @@
+import torch
+from monai.networks.nets import SwinUNETR
+
+
+def get_weights_file():
+    return "Swin64_best_metric.pth"
+
+
+def get_net(img_size, use_checkpoint=True):
+    return SwinUNETR(
+        img_size,
+        in_channels=1,
+        out_channels=1,
+        feature_size=48,
+        use_checkpoint=use_checkpoint,
+    )
+
+
+def get_output(model, input):
+    out = model(input)
+    return torch.sigmoid(out)
+
+
+def get_validation(model, val_inputs):
+    return model(val_inputs)
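A hedged smoke test for the new SwinUNetR wrapper above (assumes torch and MONAI are installed; MONAI's SwinUNETR expects spatial dimensions divisible by 32, so a 64³ input works):

    import torch
    from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR

    net = SwinUNetR.get_net(img_size=[64, 64, 64], use_checkpoint=False)
    dummy = torch.rand(1, 1, 64, 64, 64)    # (batch, channel, D, H, W)
    out = SwinUNetR.get_output(net, dummy)  # sigmoid probabilities in [0, 1]
    print(out.shape)                        # torch.Size([1, 1, 64, 64, 64])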
"http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", - "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" + "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz", + "SwinUNetR": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SWIN.tar.gz" } \ No newline at end of file diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index e6265767..f9cd5615 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -78,6 +78,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.keep_on_cpu = False self.use_window_inference = False self.window_inference_size = None + self.window_overlap = 0.25 ########################### # interface @@ -99,9 +100,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ###################### ###################### # TODO : better way to handle SegResNet size reqs ? - self.segres_size = ui.IntIncrementCounter(min=1, max=1024, default=128) + self.model_input_size = ui.IntIncrementCounter( + min=1, max=1024, default=128 + ) self.model_choice.currentIndexChanged.connect( - self.toggle_display_segres_size + self.toggle_display_model_input_size ) self.model_choice.setCurrentIndex(0) @@ -130,22 +133,53 @@ def __init__(self, viewer: "napari.viewer.Viewer"): T=7, parent=self ) - self.window_infer_box = ui.make_checkbox("Use window inference") + self.window_infer_box = ui.CheckBox(title="Use window inference") self.window_infer_box.clicked.connect(self.toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] + # ( + # self.window_size_choice, + # self.lbl_window_size_choice, + # ) = ui.make_combobox(sizes_window, label="Window size and overlap") + # self.window_overlap = ui.make_n_spinboxes( + # max=1, + # default=0.7, + # step=0.05, + # double=True, + # ) + self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) self.lbl_window_size_choice = self.window_size_choice.label - self.keep_data_on_cpu_box = ui.make_checkbox("Keep data on CPU") + self.window_overlap_counter = ui.DoubleIncrementCounter( + min=0, + max=1, + default=0.25, + step=0.05, + parent=self, + label="Overlap %", + ) - self.window_infer_params = ui.combine_blocks( + self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU") + + window_size_widgets = ui.combine_blocks( self.window_size_choice, self.lbl_window_size_choice, horizontal=False, ) + # self.window_infer_params = ui.combine_blocks( + # self.window_overlap, + # self.window_infer_params, + # horizontal=False, + # ) + + self.window_infer_params = ui.combine_blocks( + window_size_widgets, + self.window_overlap_counter.get_with_label(horizontal=False), + horizontal=False, + ) ################## ################## @@ -215,8 +249,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.show_original_checkbox.setToolTip( "Displays the image used for inference in the viewer" ) - self.segres_size.setToolTip( - "Image size on which the SegResNet has been trained (default : 128)" + self.model_input_size.setToolTip( + "Image size on which the model has been trained (default : 128)" ) thresh_desc = ( @@ -234,6 +268,15 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_size_choice.setToolTip( "Size of the window to run inference with (in pixels)" ) + + self.window_overlap_counter.setToolTip( + "Percentage of overlap between windows to use when using sliding window" + ) + + # self.window_overlap.setToolTip( + # 
"Amount of overlap between sliding windows" + # ) + self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) @@ -280,11 +323,14 @@ def check_ready(self): warnings.warn("Image and label paths are not correctly set") return False - def toggle_display_segres_size(self): - if self.model_choice.currentText() == "SegResNet": - self.segres_size.setVisible(True) + def toggle_display_model_input_size(self): + if ( + self.model_choice.currentText() == "SegResNet" + or self.model_choice.currentText() == "SwinUNetR" + ): + self.model_input_size.setVisible(True) else: - self.segres_size.setVisible(False) + self.model_input_size.setVisible(False) def toggle_display_number(self): """Shows the choices for viewing results depending on whether :py:attr:`self.view_checkbox` is checked""" @@ -393,7 +439,7 @@ def build(self): self.model_choice, self.custom_weights_choice, self.weights_path_container, - self.segres_size, + self.model_input_size, ], ) self.weights_path_container.setVisible(False) @@ -551,7 +597,7 @@ def start(self, on_layer=False): model_dict = { # gather model info "name": model_key, "class": self.get_model(model_key), - "segres_size": self.segres_size.value(), + "model_input_size": self.model_input_size.value(), } if self.custom_weights_choice.isChecked(): @@ -600,6 +646,7 @@ def start(self, on_layer=False): self.window_inference_size = int( self.window_size_choice.currentText() ) + self.window_overlap = self.window_overlap_counter.value() if not on_layer: self.worker = InferenceWorker( @@ -613,6 +660,7 @@ def start(self, on_layer=False): instance=self.instance_params, use_window=self.use_window_inference, window_infer_size=self.window_inference_size, + window_overlap=self.window_overlap, keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) @@ -629,6 +677,7 @@ def start(self, on_layer=False): use_window=self.use_window_inference, window_infer_size=self.window_inference_size, keep_on_cpu=self.keep_on_cpu, + window_overlap=self.window_overlap, stats_csv=self.stats_to_csv, layer=layer, ) @@ -724,9 +773,6 @@ def on_yield(data, widget): zoom = widget.zoom - # print(data["original"].shape) - # print(data["result"].shape) - viewer.dims.ndisplay = 3 viewer.scale_bar.visible = True diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 517cf8fc..1e9deed6 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -31,7 +31,7 @@ from napari_cellseg3d.model_workers import TrainingWorker NUMBER_TABS = 3 -DEFAULT_PATCH_SIZE = 60 +DEFAULT_PATCH_SIZE = 64 class Trainer(ModelFramework): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 83b23e9e..498c4830 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -116,7 +116,6 @@ def resize(image, zoom_factors): def align_array_sizes(array_shape, target_shape): - index_differences = [] for i in range(len(target_shape)): if target_shape[i] != array_shape[i]: @@ -334,7 +333,6 @@ def fill_list_in_between(lst, n, elem): Returns : Filled list """ - new_list = [] for i in range(len(lst)): temp_list = [lst[i]] @@ -606,378 +604,3 @@ def format_Warning(message, category, filename, lineno, line=""): + str(message) + "\n" ) - - -# def dice_coeff(y_true, y_pred): -# smooth = 1. -# y_true_f = y_true.flatten() -# y_pred_f = K.flatten(y_pred) -# intersection = K.sum(y_true_f * y_pred_f) -# score = (2. 
diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py
index 83b23e9e..498c4830 100644
--- a/napari_cellseg3d/utils.py
+++ b/napari_cellseg3d/utils.py
@@ -116,7 +116,6 @@ def resize(image, zoom_factors):
 
 
 def align_array_sizes(array_shape, target_shape):
-
     index_differences = []
     for i in range(len(target_shape)):
         if target_shape[i] != array_shape[i]:
@@ -334,7 +333,6 @@ def fill_list_in_between(lst, n, elem):
 
     Returns : Filled list
     """
-
     new_list = []
     for i in range(len(lst)):
         temp_list = [lst[i]]
@@ -606,378 +604,3 @@ def format_Warning(message, category, filename, lineno, line=""):
         + str(message)
         + "\n"
     )
-
-
-# def dice_coeff(y_true, y_pred):
-#     smooth = 1.
-#     y_true_f = y_true.flatten()
-#     y_pred_f = K.flatten(y_pred)
-#     intersection = K.sum(y_true_f * y_pred_f)
-#     score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
-#     return score
-
-
-# def dice_loss(y_true, y_pred):
-#     loss = 1 - dice_coeff(y_true, y_pred)
-#     return loss
-
-
-# def bce_dice_loss(y_true, y_pred):
-#     loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
-#     return loss
-
-
-def divide_imgs(images):
-    H = -(-images.shape[1] // 412)
-    W = -(-images.shape[2] // 412)
-
-    diveded_imgs = np.zeros((images.shape[0] * H * W, 512, 512, 1), np.float32)
-    # print(H, W)
-
-    for z in range(images.shape[0]):
-        image = images[z]
-        for h in range(H):
-            for w in range(W):
-                cropped_img = np.zeros((512, 512, 1), np.float32)
-                cropped_img -= 1
-
-                if images.shape[1] < 412:
-                    h = -1
-                if images.shape[2] < 412:
-                    w = -1
-
-                if h == -1:
-                    if w == -1:
-                        cropped_img[
-                            50 : images.shape[1] + 50,
-                            50 : images.shape[2] + 50,
-                            0,
-                        ] = image[0 : images.shape[1], 0 : images.shape[2], 0]
-                    elif w == 0:
-                        cropped_img[
-                            50 : images.shape[1] + 50, 50:512, 0
-                        ] = image[0 : images.shape[1], 0:462, 0]
-                    elif w == W - 1:
-                        cropped_img[
-                            50 : images.shape[1] + 50,
-                            0 : images.shape[2] - 412 * W - 50,
-                            0,
-                        ] = image[
-                            0 : images.shape[1],
-                            w * 412 - 50 : images.shape[2],
-                            0,
-                        ]
-                    else:
-                        cropped_img[50 : images.shape[1] + 50, :, 0] = image[
-                            0 : images.shape[1],
-                            w * 412 - 50 : (w + 1) * 412 + 50,
-                            0,
-                        ]
-                elif h == 0:
-                    if w == -1:
-                        cropped_img[
-                            50:512, 50 : images.shape[2] + 50, 0
-                        ] = image[0:462, 0 : images.shape[2], 0]
-                    elif w == 0:
-                        cropped_img[50:512, 50:512, 0] = image[0:462, 0:462, 0]
-                    elif w == W - 1:
-                        cropped_img[
-                            50:512, 0 : images.shape[2] - 412 * W - 50, 0
-                        ] = image[0:462, w * 412 - 50 : images.shape[2], 0]
-                    else:
-                        # cropped_img[50:512, :, 0] = image[0:462, w*412-50:(w+1)*412+50, 0]
-                        try:
-                            cropped_img[50:512, :, 0] = image[
-                                0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0
-                            ]
-                        except:
-                            cropped_img[
-                                50:512,
-                                0 : images.shape[2] - 412 * (W - 1) - 50,
-                                0,
-                            ] = image[
-                                0:462, w * 412 - 50 : (w + 1) * 412 + 50, 0
-                            ]
-                elif h == H - 1:
-                    if w == -1:
-                        cropped_img[
-                            0 : images.shape[1] - 412 * H - 50,
-                            50 : images.shape[2] + 50,
-                            0,
-                        ] = image[
-                            h * 412 - 50 : images.shape[1],
-                            0 : images.shape[2],
-                            0,
-                        ]
-                    elif w == 0:
-                        cropped_img[
-                            0 : images.shape[1] - 412 * H - 50, 50:512, 0
-                        ] = image[h * 412 - 50 : images.shape[1], 0:462, 0]
-                    elif w == W - 1:
-                        cropped_img[
-                            0 : images.shape[1] - 412 * H - 50,
-                            0 : images.shape[2] - 412 * W - 50,
-                            0,
-                        ] = image[
-                            h * 412 - 50 : images.shape[1],
-                            w * 412 - 50 : images.shape[2],
-                            0,
-                        ]
-                    else:
-                        try:
-                            cropped_img[
-                                0 : images.shape[1] - 412 * H - 50, :, 0
-                            ] = image[
-                                h * 412 - 50 : images.shape[1],
-                                w * 412 - 50 : (w + 1) * 412 + 50,
-                                0,
-                            ]
-                        except:
-                            cropped_img[
-                                0 : images.shape[1] - 412 * H - 50,
-                                0 : images.shape[2] - 412 * (W - 1) - 50,
-                                0,
-                            ] = image[
-                                h * 412 - 50 : images.shape[1],
-                                w * 412 - 50 : (w + 1) * 412 + 50,
-                                0,
-                            ]
-                else:
-                    if w == -1:
-                        cropped_img[:, 50 : images.shape[2] + 50, 0] = image[
-                            h * 412 - 50 : (h + 1) * 412 + 50,
-                            0 : images.shape[2],
-                            0,
-                        ]
-                    elif w == 0:
-                        # cropped_img[:, 50:512, 0] = image[h*412-50:(h+1)*412+50, 0:462, 0]
-                        try:
-                            cropped_img[:, 50:512, 0] = image[
-                                h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0
-                            ]
-                        except:
-                            cropped_img[
-                                0 : images.shape[1] - 412 * H - 50 + 412,
-                                50:512,
-                                0,
-                            ] = image[
-                                h * 412 - 50 : (h + 1) * 412 + 50, 0:462, 0
-                            ]
-                    elif w == W - 1:
-                        # cropped_img[:, 0:images.shape[2]-412*W-50, 0] = image[h*412-50:(h+1)*412+50, w*412-50:images.shape[2], 0]
-                        try:
-                            cropped_img[
-                                :, 0 : images.shape[2] - 412 * W - 50, 0
-                            ] = image[
-                                h * 412 - 50 : (h + 1) * 412 + 50,
-                                w * 412 - 50 : images.shape[2],
-                                0,
-                            ]
-                        except:
-                            cropped_img[
-                                0 : images.shape[1] - 412 * H - 50 + 412,
-                                0 : images.shape[2] - 412 * W - 50,
-                                0,
-                            ] = image[
-                                h * 412 - 50 : (h + 1) * 412 + 50,
-                                w * 412 - 50 : images.shape[2],
-                                0,
-                            ]
-                    else:
-                        # cropped_img[:, :, 0] = image[h*412-50:(h+1)*412+50, w*412-50:(w+1)*412+50, 0]
-                        try:
-                            cropped_img[:, :, 0] = image[
-                                h * 412 - 50 : (h + 1) * 412 + 50,
-                                w * 412 - 50 : (w + 1) * 412 + 50,
-                                0,
-                            ]
-                        except:
-                            try:
-                                cropped_img[
-                                    :,
-                                    0 : images.shape[2] - 412 * (W - 1) - 50,
-                                    0,
-                                ] = image[
-                                    h * 412 - 50 : (h + 1) * 412 + 50,
-                                    w * 412 - 50 : (w + 1) * 412 + 50,
-                                    0,
-                                ]
-                            except:
-                                cropped_img[
-                                    0 : images.shape[1] - 412 * (H - 1) - 50,
-                                    :,
-                                    0,
-                                ] = image[
-                                    h * 412 - 50 : (h + 1) * 412 + 50,
-                                    w * 412 - 50 : (w + 1) * 412 + 50,
-                                    0,
-                                ]
-
-                h = max(0, h)
-                w = max(0, w)
-                diveded_imgs[z * H * W + w * H + h] = cropped_img
-                # print(z*H*W+ w*H+h)
-
-    return diveded_imgs
-
-
-def merge_imgs(imgs, original_image_shape):
-    merged_imgs = np.zeros(
-        (
-            original_image_shape[0],
-            original_image_shape[1],
-            original_image_shape[2],
-            1,
-        ),
-        np.float32,
-    )
-
-    H = -(-original_image_shape[1] // 412)
-    W = -(-original_image_shape[2] // 412)
-
-    for z in range(original_image_shape[0]):
-        for h in range(H):
-            for w in range(W):
-
-                if original_image_shape[1] < 412:
-                    h = -1
-                if original_image_shape[2] < 412:
-                    w = -1
-
-                # print(z*H*W+ max(w, 0)*H+max(h, 0))
-                if h == -1:
-                    if w == -1:
-                        merged_imgs[
-                            z,
-                            0 : original_image_shape[1],
-                            0 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + 0 * H + 0][
-                            50 : original_image_shape[1] + 50,
-                            50 : original_image_shape[2] + 50,
-                            0,
-                        ]
-                    elif w == 0:
-                        merged_imgs[
-                            z, 0 : original_image_shape[1], 0:412, 0
-                        ] = imgs[z * H * W + w * H + 0][
-                            50 : original_image_shape[1] + 50, 50:462, 0
-                        ]
-                    elif w == W - 1:
-                        merged_imgs[
-                            z,
-                            0 : original_image_shape[1],
-                            w * 412 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + w * H + 0][
-                            50 : original_image_shape[1] + 50,
-                            50 : original_image_shape[2] - 412 * W - 50,
-                            0,
-                        ]
-                    else:
-                        merged_imgs[
-                            z,
-                            0 : original_image_shape[1],
-                            w * 412 : (w + 1) * 412,
-                            0,
-                        ] = imgs[z * H * W + w * H + 0][
-                            50 : original_image_shape[1] + 50, 50:462, 0
-                        ]
-                elif h == 0:
-                    if w == -1:
-                        merged_imgs[
-                            z, 0:412, 0 : original_image_shape[2], 0
-                        ] = imgs[z * H * W + 0 * H + h][
-                            50:462, 50 : original_image_shape[2] + 50, 0
-                        ]
-                    elif w == 0:
-                        merged_imgs[z, 0:412, 0:412, 0] = imgs[
-                            z * H * W + w * H + h
-                        ][50:462, 50:462, 0]
-                    elif w == W - 1:
-                        merged_imgs[
-                            z, 0:412, w * 412 : original_image_shape[2], 0
-                        ] = imgs[z * H * W + w * H + h][
-                            50:462,
-                            50 : original_image_shape[2] - 412 * W - 50,
-                            0,
-                        ]
-                    else:
-                        merged_imgs[
-                            z, 0:412, w * 412 : (w + 1) * 412, 0
-                        ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0]
-                elif h == H - 1:
-                    if w == -1:
-                        merged_imgs[
-                            z,
-                            h * 412 : original_image_shape[1],
-                            0 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + 0 * H + h][
-                            50 : original_image_shape[1] - 412 * H - 50,
-                            50 : original_image_shape[2] + 50,
-                            0,
-                        ]
-                    elif w == 0:
-                        merged_imgs[
-                            z, h * 412 : original_image_shape[1], 0:412, 0
-                        ] = imgs[z * H * W + w * H + h][
-                            50 : original_image_shape[1] - 412 * H - 50,
-                            50:462,
-                            0,
-                        ]
-                    elif w == W - 1:
-                        merged_imgs[
-                            z,
-                            h * 412 : original_image_shape[1],
-                            w * 412 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + w * H + h][
-                            50 : original_image_shape[1] - 412 * H - 50,
-                            50 : original_image_shape[2] - 412 * W - 50,
-                            0,
-                        ]
-                    else:
-                        merged_imgs[
-                            z,
-                            h * 412 : original_image_shape[1],
-                            w * 412 : (w + 1) * 412,
-                            0,
-                        ] = imgs[z * H * W + w * H + h][
-                            50 : original_image_shape[1] - 412 * H - 50,
-                            50:462,
-                            0,
-                        ]
-                else:
-                    if w == -1:
-                        merged_imgs[
-                            z,
-                            h * 412 : (h + 1) * 412,
-                            0 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + 0 * H + h][
-                            50:462, 50 : original_image_shape[2] + 50, 0
-                        ]
-                    elif w == 0:
-                        merged_imgs[
-                            z, h * 412 : (h + 1) * 412, 0:412, 0
-                        ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0]
-                    elif w == W - 1:
-                        merged_imgs[
-                            z,
-                            h * 412 : (h + 1) * 412,
-                            w * 412 : original_image_shape[2],
-                            0,
-                        ] = imgs[z * H * W + w * H + h][
-                            50:462,
-                            50 : original_image_shape[2] - 412 * W - 50,
-                            0,
-                        ]
-                    else:
-                        merged_imgs[
-                            z,
-                            h * 412 : (h + 1) * 412,
-                            w * 412 : (w + 1) * 412,
-                            0,
-                        ] = imgs[z * H * W + w * H + h][50:462, 50:462, 0]
-
-    print(merged_imgs.shape)
-    return merged_imgs
diff --git a/requirements.txt b/requirements.txt
index 9885d9f0..3ba73405 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,13 +13,10 @@ napari[all]>=0.4.14
 QtPy
 opencv-python>=4.5.5
 dask-image>=0.6.0
-scikit-image>=0.19.2
 matplotlib>=3.4.1
 tifffile>=2022.2.9
 imageio-ffmpeg>=0.4.5
 torch>=1.11
-monai>=0.9.0
-nibabel
+monai[nibabel,scikit-image,itk,einops]>=0.9.0
 pillow
-itk>=5.2.0
 vispy>=0.9.6
diff --git a/setup.cfg b/setup.cfg
index b423c9ea..a0aa6eee 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,12 +50,13 @@ install_requires =
     tifffile>=2022.2.9
    imageio-ffmpeg>=0.4.5
     torch>=1.11
+    monai[nibabel,scikit-image,itk,einops]>=0.9.0
     tqdm
     monai>=0.9.0
     nibabel
     scikit-image
     pillow
-    itk>=5.2.0
+    tqdm
     matplotlib
     vispy>=0.9.6
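Finally, a hedged sanity check for the consolidated `monai[nibabel,scikit-image,itk,einops]` requirement above: each optional dependency bundled through the MONAI extras should be importable after installation (scikit-image imports as `skimage`):

    import importlib.util

    for module in ("monai", "nibabel", "skimage", "itk", "einops"):
        found = importlib.util.find_spec(module) is not None
        print(f"{module}: {'ok' if found else 'MISSING'}")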