diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index d5ea2b24..1ae9b2a0 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -1,4 +1,5 @@ import threading +import warnings from qtpy import QtCore from qtpy.QtGui import QTextCursor @@ -28,13 +29,15 @@ def write(self, message): try: if not hasattr(self, "flag"): self.flag = False - message = message.replace('\r', '').rstrip() + message = message.replace("\r", "").rstrip() if message: method = "replace_last_line" if self.flag else "append" - QtCore.QMetaObject.invokeMethod(self, - method, - QtCore.Qt.QueuedConnection, - QtCore.Q_ARG(str, message)) + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) self.flag = True else: self.flag = False @@ -77,3 +80,10 @@ def print_and_log(self, text, printing=True): ) finally: self.lock.release() + + def warn(self, warning): + self.lock.acquire() + try: + warnings.warn(warning) + finally: + self.lock.release() diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 19de2e00..eb2f3cb9 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -3,6 +3,7 @@ from pathlib import Path import importlib.util from typing import Optional +import warnings import numpy as np from tifffile import imwrite @@ -65,19 +66,28 @@ Path("/models/pretrained") ) + class WeightsDownloader: + """A utility class the downloads the weights of a model when needed.""" + + def __init__(self, log_widget: Optional[log_utility.Log] = None): + """ + Creates a WeightsDownloader, optionally with a log widget to display the progress. - def __init__(self, log_widget: Optional[log_utility.Log]= None): + Args: + log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print() + """ self.log_widget = log_widget - def download_weights(self,model_name: str): + def download_weights(self, model_name: str, model_weights_filename: str): """ - Downloads a specific pretrained model. - This code is adapted from DeepLabCut with permission from MWMathis. + Downloads a specific pretrained model. + This code is adapted from DeepLabCut with permission from MWMathis. - Args: - model_name (str): name of the model to download - """ + Args: + model_name (str): name of the model to download + model_weights_filename (str): name of the .pth file expected for the model + """ import json import tarfile import urllib.request @@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size): json_path = os.path.join( pretrained_folder_path, "pretrained_model_urls.json" ) + + check_path = os.path.join( + pretrained_folder_path, model_weights_filename + ) + if os.path.exists(check_path): + message = f"Weight file {model_weights_filename} already exists, skipping download step" + if self.log_widget is not None: + self.log_widget.print_and_log(message, printing=False) + print(message) + return + with open(json_path) as f: neturls = json.load(f) if model_name in neturls.keys(): @@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size): pbar = tqdm(unit="B", total=total_size, position=0) else: self.log_widget.print_and_log(start_message) - pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget) + pbar = tqdm( + unit="B", + total=total_size, + position=0, + file=self.log_widget, + ) - filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) + filename, _ = urllib.request.urlretrieve( + url, reporthook=show_progress + ) with tarfile.open(filename, mode="r:gz") as tar: tar.extractall(pretrained_folder_path) else: @@ -121,10 +149,12 @@ def show_progress(count, block_size, total_size): class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" + Separate from Worker instances as indicated `here`_""" # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + warn_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -185,6 +215,7 @@ def __init__( super().__init__(self.inference) self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal ########################################### ########################################### self.device = device @@ -204,7 +235,6 @@ def __init__( self.downloader = WeightsDownloader() """Download utility""" - @staticmethod def create_inference_dict(images_filepaths): """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths` @@ -225,6 +255,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -297,7 +331,7 @@ def inference(self): sys = platform.system() print(f"OS is {sys}") if sys == "Darwin": - torch.set_num_threads(1) # required for threading on macOS ? + torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") images_dict = self.create_inference_dict(self.images_filepaths) @@ -323,7 +357,11 @@ def inference(self): model = self.model_dict["class"].get_net() if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net()( - input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code + input_image_size=[ + dims, + dims, + dims, + ], # TODO FIX ! find a better way & remove model-specific code out_channels=1, # dropout_prob=0.3, ) @@ -372,8 +410,13 @@ def inference(self): if self.weights_dict["custom"]: weights = self.weights_dict["path"] else: - self.downloader.download_weights(self.model_dict["name"]) - weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file()) + self.downloader.download_weights( + self.model_dict["name"], + self.model_dict["class"].get_weights_file(), + ) + weights = os.path.join( + WEIGHTS_DIR, self.model_dict["class"].get_weights_file() + ) model.load_state_dict( torch.load( @@ -611,7 +654,10 @@ def __init__( super().__init__(self.train) self._signals = LogSignal() self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal + self._weight_error = False + ############################################# self.device = device self.model_dict = model_dict self.weights_path = weights_path @@ -633,7 +679,7 @@ def __init__( self.train_files = [] self.val_files = [] - + ####################################### self.downloader = WeightsDownloader() def set_download_log(self, widget): @@ -647,6 +693,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -690,6 +740,13 @@ def log_parameters(self): if self.weights_path is not None: self.log(f"Using weights from : {self.weights_path}") + if self._weight_error: + self.log( + ">>>>>>>>>>>>>>>>>\n" + "WARNING:\nChosen weights were incompatible with the model,\n" + "the model will be trained from random weights\n" + "<<<<<<<<<<<<<<<<<\n" + ) # self.log("\n") self.log("-" * 20) @@ -904,18 +961,27 @@ def train(self): if self.weights_path is not None: if self.weights_path == "use_pretrained": weights_file = model_class.get_weights_file() - self.downloader.download_weights(model_name) + self.downloader.download_weights(model_name, weights_file) weights = os.path.join(WEIGHTS_DIR, weights_file) self.weights_path = weights else: weights = os.path.join(self.weights_path) - model.load_state_dict( - torch.load( - weights, - map_location=self.device, + try: + model.load_state_dict( + torch.load( + weights, + map_location=self.device, + ) ) - ) + except RuntimeError: + warn = ( + "WARNING:\nIt seems the weights were incompatible with the model,\n" + "the model will be trained from random weights" + ) + self.log(warn) + self.warn(warn) + self._weight_error = True if self.device.type == "cuda": self.log("\nUsing GPU :") diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 059b3ee5..41dc3bde 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,7 +1,6 @@ from monai.networks.nets import SegResNetVAE - def get_net(): return SegResNetVAE diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index e6d2ed1a..ec4cfdbb 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -2,12 +2,12 @@ from torch import nn - def get_weights_file(): # model additionally trained on Mathis/Wyss mesoSPIM data return "TRAILMAP.pth" # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + def get_net(): return TRAILMAP(1, 1) @@ -120,4 +120,3 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): # nn.BatchNorm3d(out_ch), ) return out - diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 63089b95..0c5f0b75 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -2,7 +2,6 @@ from monai.networks.nets import VNet - def get_net(): return VNet() diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index f04479db..711f4b49 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -599,6 +599,7 @@ def start(self): self.worker.started.connect(self.on_start) self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.yielded.connect(yield_connect_show_res) self.worker.errored.connect( yield_connect_show_res diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 569ed054..517cf8fc 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -857,6 +857,7 @@ def start(self): [btn.setVisible(False) for btn in self.close_buttons] self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.started.connect(self.on_start) @@ -994,7 +995,7 @@ def make_csv(self): if len(self.loss_values) == 0 or self.loss_values is None: warnings.warn("No loss values to add to csv !") return - + self.df = pd.DataFrame( { "epoch": size_column, diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index d4d4e1f4..bc725fc1 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -978,5 +978,3 @@ def merge_imgs(imgs, original_image_shape): print(merged_imgs.shape) return merged_imgs - -