From 272449effa328bef116a22797e12330849fd3b9f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 9 Jul 2022 10:23:44 +0200 Subject: [PATCH 1/7] Moved weights download + enhancements - Weight download now always happens in worker thread to avoid freezing UI - Download progress bar now shows on the log in the plugin - Models .py files changed to be as simple as possible (so users can more easily add custom models) --- .gitignore | 1 + napari_cellseg3d/interface.py | 1 + napari_cellseg3d/log_utility.py | 33 ++++++++ napari_cellseg3d/model_workers.py | 89 +++++++++++++++++++--- napari_cellseg3d/models/TRAILMAP_MS.py | 3 +- napari_cellseg3d/models/model_SegResNet.py | 3 +- napari_cellseg3d/models/model_TRAILMAP.py | 3 +- napari_cellseg3d/models/model_VNet.py | 3 +- napari_cellseg3d/plugin_model_inference.py | 2 +- napari_cellseg3d/plugin_model_training.py | 6 ++ napari_cellseg3d/utils.py | 41 +--------- 11 files changed, 125 insertions(+), 60 deletions(-) diff --git a/.gitignore b/.gitignore index d61b3f93..3c0e712a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ __pycache__/ *.tif napari_cellseg3d/_tests/res/*.csv *.pth +*.db # Distribution / packaging .Python diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 7f39e3bb..82445605 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,6 +1,7 @@ from typing import Optional from typing import Union + from qtpy.QtCore import Qt from qtpy.QtCore import QUrl from qtpy.QtGui import QDesktopServices diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index c4288f1f..c6aab5bb 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -1,5 +1,6 @@ import threading +from qtpy import QtCore from qtpy.QtGui import QTextCursor from qtpy.QtWidgets import QTextEdit @@ -22,6 +23,38 @@ def __init__(self, parent): # def receive_log(self, text): # self.print_and_log(text) + def write(self, message): + self.lock.acquire() + try: + if not hasattr(self, "flag"): + self.flag = False + message = message.replace('\r', '').rstrip() + if message: + method = "replace_last_line" if self.flag else "append" + QtCore.QMetaObject.invokeMethod(self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message)) + self.flag = True + else: + self.flag = False + + finally: + self.lock.release() + + @QtCore.Slot(str) + def replace_last_line(self, text): + self.lock.acquire() + try: + cursor = self.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + cursor.insertBlock() + self.setTextCursor(cursor) + self.insertPlainText(text) + finally: + self.lock.release() def print_and_log(self, text): """Utility used to both print to terminal and log text to a QTextEdit diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 2cc6cb03..0853ba3d 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -1,9 +1,13 @@ import os import platform from pathlib import Path +import importlib.util +from typing import Optional import numpy as np +from tifffile import imwrite import torch +from tqdm import tqdm # MONAI from monai.data import CacheDataset @@ -37,9 +41,10 @@ # Qt from qtpy.QtCore import Signal -from tifffile import imwrite + from napari_cellseg3d import utils +from napari_cellseg3d import log_utility # local from napari_cellseg3d.model_instance_seg import binary_connected @@ -57,9 +62,61 @@ # https://napari-staging-site.github.io/guides/stable/threading.html WEIGHTS_DIR = os.path.dirname(os.path.realpath(__file__)) + str( - Path("/models/saved_weights") + Path("/models/pretrained") ) +class WeightsDownloader: + + def __init__(self, log_widget: Optional[log_utility.Log]= None): + self.log_widget = log_widget + + def download_weights(self,model_name: str): + """ + Downloads a specific pretrained model. + This code is adapted from DeepLabCut with permission from MWMathis. + + Args: + model_name (str): name of the model to download + """ + import json + import tarfile + import urllib.request + + def show_progress(count, block_size, total_size): + pbar.update(block_size) + + cellseg3d_path = os.path.split( + importlib.util.find_spec("napari_cellseg3d").origin + )[0] + pretrained_folder_path = os.path.join( + cellseg3d_path, "models", "pretrained" + ) + json_path = os.path.join( + pretrained_folder_path, "pretrained_model_urls.json" + ) + with open(json_path) as f: + neturls = json.load(f) + if model_name in neturls.keys(): + url = neturls[model_name] + response = urllib.request.urlopen(url) + + start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...." + total_size = int(response.getheader("Content-Length")) + if self.log_widget is None: + print(start_message) + pbar = tqdm(unit="B", total=total_size, position=0) + else: + self.log_widget.print_and_log(start_message) + pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget) + + filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) + with tarfile.open(filename, mode="r:gz") as tar: + tar.extractall(pretrained_folder_path) + else: + raise ValueError( + f"Unknown model. `modelname` should be one of {', '.join(neturls)}" + ) + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. @@ -142,9 +199,12 @@ def __init__( self.window_infer_size = window_infer_size self.keep_on_cpu = keep_on_cpu self.stats_to_csv = stats_csv - """These attributes are all arguments of :py:func:~inference, please see that for reference""" + self.downloader = WeightsDownloader() + """Download utility""" + + @staticmethod def create_inference_dict(images_filepaths): """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths` @@ -154,6 +214,9 @@ def create_inference_dict(images_filepaths): data_dicts = [{"image": image_name} for image_name in images_filepaths] return data_dicts + def set_download_log(self, widget): + self.downloader.log_widget = widget + def log(self, text): """Sends a signal that ``text`` should be logged @@ -233,8 +296,8 @@ def inference(self): """ sys = platform.system() print(f"OS is {sys}") - if sys == "Darwin": # required for macOS ? - torch.set_num_threads(1) + if sys == "Darwin": + torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") images_dict = self.create_inference_dict(self.images_filepaths) @@ -260,7 +323,7 @@ def inference(self): model = self.model_dict["class"].get_net() if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net()( - input_image_size=[dims, dims, dims], # TODO FIX ! + input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code out_channels=1, # dropout_prob=0.3, ) @@ -304,12 +367,13 @@ def inference(self): # print(weights) self.log( "\nLoading weights..." - ) # TODO add try/except for invalid weights + ) # TODO add try/except for invalid weights for proper reset if self.weights_dict["custom"]: weights = self.weights_dict["path"] else: - weights = os.path.join(WEIGHTS_DIR, self.weights_dict["path"]) + self.downloader.download_weights(self.model_dict["name"]) + weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file()) model.load_state_dict( torch.load( @@ -544,8 +608,6 @@ def __init__( """ - - print("init") super().__init__(self.train) self._signals = LogSignal() self.log_signal = self._signals.log_signal @@ -571,7 +633,11 @@ def __init__( self.train_files = [] self.val_files = [] - print("end init") + + self.downloader = WeightsDownloader() + + def set_download_log(self, widget): + self.downloader.log_widget = widget def log(self, text): """Sends a signal that ``text`` should be logged @@ -838,6 +904,7 @@ def train(self): if self.weights_path is not None: if self.weights_path == "use_pretrained": weights_file = model_class.get_weights_file() + self.downloader.download_weights(model_name) weights = os.path.join(WEIGHTS_DIR, weights_file) self.weights_path = weights else: diff --git a/napari_cellseg3d/models/TRAILMAP_MS.py b/napari_cellseg3d/models/TRAILMAP_MS.py index 9905c71a..ff82b5f0 100644 --- a/napari_cellseg3d/models/TRAILMAP_MS.py +++ b/napari_cellseg3d/models/TRAILMAP_MS.py @@ -8,8 +8,7 @@ def get_weights_file(): # model additionally trained on Mathis/Wyss mesoSPIM data - target_dir = utils.download_model("TRAILMAP_MS") - return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth") + return "TRAILMAP_MS_best_metric_epoch_26.pth" def get_net(): diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 98e07c49..14a9bd58 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -10,8 +10,7 @@ def get_net(): def get_weights_file(): - target_dir = utils.download_model("SegResNet") - return os.path.join(target_dir, "SegResNet.pth") + return "SegResNet.pth" def get_output(model, input): diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index 0c056032..14afc33d 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -6,8 +6,7 @@ def get_weights_file(): # original model from Liqun Luo lab, transfered to pytorch - target_dir = utils.download_model("TRAILMAP") - return os.path.join(target_dir, "TRAILMAP_PyTorch.pth") + return "TRAILMAP_PyTorch.pth" def get_net(): diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 2b3d758b..cde7cf22 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -11,8 +11,7 @@ def get_net(): def get_weights_file(): - target_dir = utils.download_model("VNet") - return os.path.join(target_dir, "VNet_40e.pth") + return "VNet_40e.pth" def get_output(model, input): diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index 7b9842b6..22d94634 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -534,7 +534,6 @@ def start(self): else: weights_dict = { "custom": False, - "path": self.get_model(model_key).get_weights_file(), } if self.anisotropy_wdgt.is_enabled(): @@ -591,6 +590,7 @@ def start(self): keep_on_cpu=self.keep_on_cpu, stats_csv=self.stats_to_csv, ) + self.worker.set_download_log(self.log) yield_connect_show_res = lambda data: self.on_yield( data, diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 99136641..4cf9bc31 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -851,6 +851,7 @@ def start(self): do_augmentation=self.augment_choice.isChecked(), deterministic=seed_dict, ) + self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] @@ -988,6 +989,11 @@ def on_yield(data, widget): def make_csv(self): size_column = range(1, self.max_epochs + 1) + + if len(self.loss_values) == 0 or self.loss_values is None: + warnings.warn("No loss values to add to csv !") + return + self.df = pd.DataFrame( { "epoch": size_column, diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 9fdf25c5..5c125283 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,9 +1,9 @@ -import importlib.util import os import warnings from datetime import datetime from pathlib import Path + import cv2 import numpy as np from dask_image.imread import imread as dask_imread @@ -969,42 +969,3 @@ def merge_imgs(imgs, original_image_shape): return merged_imgs -def download_model(modelname): - """ - Downloads a specific pretained model. - This code is adapted from DeepLabCut with permission from MWMathis - """ - import json - import tarfile - import urllib.request - - def show_progress(count, block_size, total_size): - pbar.update(block_size) - - cellseg3d_path = os.path.split( - importlib.util.find_spec("napari_cellseg3d").origin - )[0] - pretrained_folder_path = os.path.join( - cellseg3d_path, "models", "pretrained" - ) - json_path = os.path.join( - pretrained_folder_path, "pretrained_model_urls.json" - ) - with open(json_path) as f: - neturls = json.load(f) - if modelname in neturls.keys(): - url = neturls[modelname] - response = urllib.request.urlopen(url) - print( - f"Downloading the model from the M.W. Mathis Lab server {url}...." - ) - total_size = int(response.getheader("Content-Length")) - pbar = tqdm(unit="B", total=total_size, position=0) - filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) - with tarfile.open(filename, mode="r:gz") as tar: - tar.extractall(pretrained_folder_path) - return pretrained_folder_path - else: - raise ValueError( - f"Unknown model. `modelname` should be one of {', '.join(neturls)}" - ) From 34da28b2c2b1d34ce67ed0a1948d3e7f4a6c3564 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 14 Jul 2022 18:11:46 +0200 Subject: [PATCH 2/7] Fixed model names mismatch - TRAILMAP is now TRAILMAP_MS, as intended initially - TRAILMAP_MS is now TRAILMAP_PyTorch.py, was incorrectly renamed - Fixed "Pytorch" typos - Improved logging function - Fixed redundant imports Co-Authored-By: Mackenzie Mathis <28102185+MMathisLab@users.noreply.github.com> --- napari_cellseg3d/dev_scripts/weight_conversion.py | 2 +- napari_cellseg3d/log_utility.py | 8 +++++--- napari_cellseg3d/model_framework.py | 8 ++++---- napari_cellseg3d/model_workers.py | 4 ++-- napari_cellseg3d/models/model_SegResNet.py | 3 --- .../{model_TRAILMAP.py => model_TRAILMAP_MS.py} | 7 ++----- .../{TRAILMAP_MS.py => model_TRAILMAP_PyTorch.py} | 12 +++++------- napari_cellseg3d/models/model_VNet.py | 3 --- .../models/pretrained/pretrained_model_urls.json | 2 +- napari_cellseg3d/plugin_model_training.py | 2 +- 10 files changed, 21 insertions(+), 30 deletions(-) rename napari_cellseg3d/models/{model_TRAILMAP.py => model_TRAILMAP_MS.py} (64%) rename napari_cellseg3d/models/{TRAILMAP_MS.py => model_TRAILMAP_PyTorch.py} (94%) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py index cb9c3e47..a752a847 100644 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ b/napari_cellseg3d/dev_scripts/weight_conversion.py @@ -3,7 +3,7 @@ import torch -from napari_cellseg3d.models.model_TRAILMAP import get_net +from napari_cellseg3d.models.model_TRAILMAP_PyTorch import get_net from napari_cellseg3d.models.unet.model import UNet3D # not sure this actually works when put here diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index c6aab5bb..d5ea2b24 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -56,18 +56,20 @@ def replace_last_line(self, text): finally: self.lock.release() - def print_and_log(self, text): + def print_and_log(self, text, printing=True): """Utility used to both print to terminal and log text to a QTextEdit item in a thread-safe manner. Use only for important user info. Args: text (str): Text to be printed and logged + printing (bool): Whether to print the message as well or not using print(). Defaults to True. """ self.lock.acquire() try: - print(text) - # causes issue if you clik on terminal (tied to CMD QuickEdit mode) + if printing: + print(text) + # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) self.moveCursor(QTextCursor.End) self.insertPlainText(f"\n{text}") self.verticalScrollBar().setValue( diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 7cb0542f..851adb78 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,9 +14,9 @@ from napari_cellseg3d import utils from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet -from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +from napari_cellseg3d.models import model_TRAILMAP_PyTorch as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet -from napari_cellseg3d.models import TRAILMAP_MS as TMAP +from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS from napari_cellseg3d.plugin_base import BasePluginFolder warnings.formatwarning = utils.format_Warning @@ -63,8 +63,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.models_dict = { "VNet": VNet, "SegResNet": SegResNet, - "TRAILMAP pre-trained": TRAILMAP, - "TRAILMAP_MS": TMAP, + "TRAILMAP_PyTorch": TRAILMAP, + "TRAILMAP_MS": TRAILMAP_MS, } """dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 0853ba3d..19de2e00 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -114,7 +114,7 @@ def show_progress(count, block_size, total_size): tar.extractall(pretrained_folder_path) else: raise ValueError( - f"Unknown model. `modelname` should be one of {', '.join(neturls)}" + f"Unknown model: {model_name}. Should be one of {', '.join(neturls)}" ) @@ -695,7 +695,7 @@ def log_parameters(self): self.log("-" * 20) def train(self): - """Trains the Pytorch model for the given number of epochs, with the selected model and data, + """Trains the PyTorch model for the given number of epochs, with the selected model and data, using the chosen batch size, validation interval, loss function, and number of samples. Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 14a9bd58..059b3ee5 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,8 +1,5 @@ -import os - from monai.networks.nets import SegResNetVAE -from napari_cellseg3d import utils def get_net(): diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP_MS.py similarity index 64% rename from napari_cellseg3d/models/model_TRAILMAP.py rename to napari_cellseg3d/models/model_TRAILMAP_MS.py index 14afc33d..b6533da7 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP_MS.py @@ -1,12 +1,9 @@ -import os - -from napari_cellseg3d import utils from napari_cellseg3d.models.unet.model import UNet3D def get_weights_file(): - # original model from Liqun Luo lab, transfered to pytorch - return "TRAILMAP_PyTorch.pth" + # original model from Liqun Luo lab, transferred to pytorch + return "TRAILMAP_MS_best_metric_epoch_26.pth" def get_net(): diff --git a/napari_cellseg3d/models/TRAILMAP_MS.py b/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py similarity index 94% rename from napari_cellseg3d/models/TRAILMAP_MS.py rename to napari_cellseg3d/models/model_TRAILMAP_PyTorch.py index ff82b5f0..4cd5c0eb 100644 --- a/napari_cellseg3d/models/TRAILMAP_MS.py +++ b/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py @@ -1,18 +1,15 @@ -import os - import torch from torch import nn -from napari_cellseg3d import utils def get_weights_file(): # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_MS_best_metric_epoch_26.pth" - + return "TRAILMAP_PyTorch.pth" + # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them def get_net(): - return TRAILMAP_MS(1, 1) + return TRAILMAP_PyTorch(1, 1) def get_output(model, input): @@ -26,7 +23,7 @@ def get_validation(model, val_inputs): return model(val_inputs) -class TRAILMAP_MS(nn.Module): +class TRAILMAP_PyTorch(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input @@ -123,3 +120,4 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): # nn.BatchNorm3d(out_ch), ) return out + diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index cde7cf22..63089b95 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -1,9 +1,6 @@ -import os - from monai.inferers import sliding_window_inference from monai.networks.nets import VNet -from napari_cellseg3d import utils def get_net(): diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json index 86bc0f57..5d079a93 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json @@ -1,6 +1,6 @@ { "TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz", - "TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", + "TRAILMAP_PyTorch": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", "SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" } \ No newline at end of file diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 4cf9bc31..226da4d8 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -35,7 +35,7 @@ class Trainer(ModelFramework): - """A plugin to train pre-defined Pytorch models for one-channel segmentation directly in napari. + """A plugin to train pre-defined PyTorch models for one-channel segmentation directly in napari. Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" From e28ca386339db28100663ea1ea7154d5a301d615 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 15 Jul 2022 14:14:16 +0200 Subject: [PATCH 3/7] Fixed model name error - Replaced all TRAILMAP_PyTorch with TRAILMAP --- napari_cellseg3d/dev_scripts/weight_conversion.py | 2 +- napari_cellseg3d/model_framework.py | 4 ++-- napari_cellseg3d/models/model_TRAILMAP_PyTorch.py | 6 +++--- .../models/pretrained/pretrained_model_urls.json | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py index a752a847..cb9c3e47 100644 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ b/napari_cellseg3d/dev_scripts/weight_conversion.py @@ -3,7 +3,7 @@ import torch -from napari_cellseg3d.models.model_TRAILMAP_PyTorch import get_net +from napari_cellseg3d.models.model_TRAILMAP import get_net from napari_cellseg3d.models.unet.model import UNet3D # not sure this actually works when put here diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/model_framework.py index 851adb78..93107488 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/model_framework.py @@ -14,7 +14,7 @@ from napari_cellseg3d import utils from napari_cellseg3d.log_utility import Log from napari_cellseg3d.models import model_SegResNet as SegResNet -from napari_cellseg3d.models import model_TRAILMAP_PyTorch as TRAILMAP +from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP from napari_cellseg3d.models import model_VNet as VNet from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS from napari_cellseg3d.plugin_base import BasePluginFolder @@ -63,7 +63,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.models_dict = { "VNet": VNet, "SegResNet": SegResNet, - "TRAILMAP_PyTorch": TRAILMAP, + "TRAILMAP": TRAILMAP, "TRAILMAP_MS": TRAILMAP_MS, } """dict: dictionary of available models, with string for widget display as key diff --git a/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py b/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py index 4cd5c0eb..e6d2ed1a 100644 --- a/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py +++ b/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py @@ -5,11 +5,11 @@ def get_weights_file(): # model additionally trained on Mathis/Wyss mesoSPIM data - return "TRAILMAP_PyTorch.pth" + return "TRAILMAP.pth" # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them def get_net(): - return TRAILMAP_PyTorch(1, 1) + return TRAILMAP(1, 1) def get_output(model, input): @@ -23,7 +23,7 @@ def get_validation(model, val_inputs): return model(val_inputs) -class TRAILMAP_PyTorch(nn.Module): +class TRAILMAP(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv0 = self.encoderBlock(in_ch, 32, 3) # input diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json index 5d079a93..86bc0f57 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/models/pretrained/pretrained_model_urls.json @@ -1,6 +1,6 @@ { "TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz", - "TRAILMAP_PyTorch": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", + "TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz", "SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz", "VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz" } \ No newline at end of file From eb07b495863f7546489a2f0102233b87e9e84a13 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 15 Jul 2022 14:25:29 +0200 Subject: [PATCH 4/7] Update model_TRAILMAP.py Renamed to correct name --- .../models/{model_TRAILMAP_PyTorch.py => model_TRAILMAP.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename napari_cellseg3d/models/{model_TRAILMAP_PyTorch.py => model_TRAILMAP.py} (100%) diff --git a/napari_cellseg3d/models/model_TRAILMAP_PyTorch.py b/napari_cellseg3d/models/model_TRAILMAP.py similarity index 100% rename from napari_cellseg3d/models/model_TRAILMAP_PyTorch.py rename to napari_cellseg3d/models/model_TRAILMAP.py From 8d143d281299102d06335ab471698dadd6d10b18 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sun, 17 Jul 2022 02:33:51 +0200 Subject: [PATCH 5/7] Update model_TRAILMAP_MS.py --- napari_cellseg3d/models/model_TRAILMAP_MS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/napari_cellseg3d/models/model_TRAILMAP_MS.py b/napari_cellseg3d/models/model_TRAILMAP_MS.py index b6533da7..1ee50158 100644 --- a/napari_cellseg3d/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/models/model_TRAILMAP_MS.py @@ -2,7 +2,7 @@ def get_weights_file(): - # original model from Liqun Luo lab, transferred to pytorch + # original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly cFOS as of July 2022) return "TRAILMAP_MS_best_metric_epoch_26.pth" From e26285c050634ee3fdcd89bf6482ccad2a666311 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 18 Jul 2022 16:12:59 +0200 Subject: [PATCH 6/7] Skip download + weights error handling - Skips download if the weights file is already present - Skips state_dict loading step in training if the weights are found to be incompatible --- napari_cellseg3d/log_utility.py | 12 ++-- napari_cellseg3d/model_workers.py | 84 ++++++++++++++++------ napari_cellseg3d/models/model_SegResNet.py | 1 - napari_cellseg3d/models/model_TRAILMAP.py | 3 +- napari_cellseg3d/models/model_VNet.py | 1 - napari_cellseg3d/plugin_model_training.py | 2 +- napari_cellseg3d/utils.py | 2 - 7 files changed, 73 insertions(+), 32 deletions(-) diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index d5ea2b24..516c1c47 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -28,13 +28,15 @@ def write(self, message): try: if not hasattr(self, "flag"): self.flag = False - message = message.replace('\r', '').rstrip() + message = message.replace("\r", "").rstrip() if message: method = "replace_last_line" if self.flag else "append" - QtCore.QMetaObject.invokeMethod(self, - method, - QtCore.Qt.QueuedConnection, - QtCore.Q_ARG(str, message)) + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) self.flag = True else: self.flag = False diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index 19de2e00..a1d39c52 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -3,6 +3,7 @@ from pathlib import Path import importlib.util from typing import Optional +import warnings import numpy as np from tifffile import imwrite @@ -65,19 +66,28 @@ Path("/models/pretrained") ) + class WeightsDownloader: + """A utility class the downloads the weights of a model when needed.""" + + def __init__(self, log_widget: Optional[log_utility.Log] = None): + """ + Creates a WeightsDownloader, optionally with a log widget to display the progress. - def __init__(self, log_widget: Optional[log_utility.Log]= None): + Args: + log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print() + """ self.log_widget = log_widget - def download_weights(self,model_name: str): + def download_weights(self, model_name: str, model_weights_filename: str): """ - Downloads a specific pretrained model. - This code is adapted from DeepLabCut with permission from MWMathis. + Downloads a specific pretrained model. + This code is adapted from DeepLabCut with permission from MWMathis. - Args: - model_name (str): name of the model to download - """ + Args: + model_name (str): name of the model to download + model_weights_filename (str): name of the .pth file expected for the model + """ import json import tarfile import urllib.request @@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size): json_path = os.path.join( pretrained_folder_path, "pretrained_model_urls.json" ) + + check_path = os.path.join( + pretrained_folder_path, model_weights_filename + ) + if os.path.exists(check_path): + message = f"Weight file {model_weights_filename} already exists, skipping download step" + if self.log_widget is not None: + self.log_widget.print_and_log(message, printing=False) + print(message) + return + with open(json_path) as f: neturls = json.load(f) if model_name in neturls.keys(): @@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size): pbar = tqdm(unit="B", total=total_size, position=0) else: self.log_widget.print_and_log(start_message) - pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget) + pbar = tqdm( + unit="B", + total=total_size, + position=0, + file=self.log_widget, + ) - filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) + filename, _ = urllib.request.urlretrieve( + url, reporthook=show_progress + ) with tarfile.open(filename, mode="r:gz") as tar: tar.extractall(pretrained_folder_path) else: @@ -204,7 +232,6 @@ def __init__( self.downloader = WeightsDownloader() """Download utility""" - @staticmethod def create_inference_dict(images_filepaths): """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths` @@ -297,7 +324,7 @@ def inference(self): sys = platform.system() print(f"OS is {sys}") if sys == "Darwin": - torch.set_num_threads(1) # required for threading on macOS ? + torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") images_dict = self.create_inference_dict(self.images_filepaths) @@ -323,7 +350,11 @@ def inference(self): model = self.model_dict["class"].get_net() if self.model_dict["name"] == "SegResNet": model = self.model_dict["class"].get_net()( - input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code + input_image_size=[ + dims, + dims, + dims, + ], # TODO FIX ! find a better way & remove model-specific code out_channels=1, # dropout_prob=0.3, ) @@ -372,8 +403,13 @@ def inference(self): if self.weights_dict["custom"]: weights = self.weights_dict["path"] else: - self.downloader.download_weights(self.model_dict["name"]) - weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file()) + self.downloader.download_weights( + self.model_dict["name"], + self.model_dict["class"].get_weights_file(), + ) + weights = os.path.join( + WEIGHTS_DIR, self.model_dict["class"].get_weights_file() + ) model.load_state_dict( torch.load( @@ -904,18 +940,26 @@ def train(self): if self.weights_path is not None: if self.weights_path == "use_pretrained": weights_file = model_class.get_weights_file() - self.downloader.download_weights(model_name) + self.downloader.download_weights(model_name, weights_file) weights = os.path.join(WEIGHTS_DIR, weights_file) self.weights_path = weights else: weights = os.path.join(self.weights_path) - model.load_state_dict( - torch.load( - weights, - map_location=self.device, + try: + model.load_state_dict( + torch.load( + weights, + map_location=self.device, + ) ) - ) + except RuntimeError: + warn = ( + "WARNING:\nIt seems the weights were incompatible with the model,\n" + "the model will be trained from random weights" + ) + self.log(warn) + warnings.warn(warn) if self.device.type == "cuda": self.log("\nUsing GPU :") diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/models/model_SegResNet.py index 059b3ee5..41dc3bde 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/models/model_SegResNet.py @@ -1,7 +1,6 @@ from monai.networks.nets import SegResNetVAE - def get_net(): return SegResNetVAE diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/models/model_TRAILMAP.py index e6d2ed1a..ec4cfdbb 100644 --- a/napari_cellseg3d/models/model_TRAILMAP.py +++ b/napari_cellseg3d/models/model_TRAILMAP.py @@ -2,12 +2,12 @@ from torch import nn - def get_weights_file(): # model additionally trained on Mathis/Wyss mesoSPIM data return "TRAILMAP.pth" # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them + def get_net(): return TRAILMAP(1, 1) @@ -120,4 +120,3 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): # nn.BatchNorm3d(out_ch), ) return out - diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/models/model_VNet.py index 63089b95..0c5f0b75 100644 --- a/napari_cellseg3d/models/model_VNet.py +++ b/napari_cellseg3d/models/model_VNet.py @@ -2,7 +2,6 @@ from monai.networks.nets import VNet - def get_net(): return VNet() diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 569ed054..79e489d1 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -994,7 +994,7 @@ def make_csv(self): if len(self.loss_values) == 0 or self.loss_values is None: warnings.warn("No loss values to add to csv !") return - + self.df = pd.DataFrame( { "epoch": size_column, diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index d4d4e1f4..bc725fc1 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -978,5 +978,3 @@ def merge_imgs(imgs, original_image_shape): print(merged_imgs.shape) return merged_imgs - - From 487560d1cf644494cb59e5339a6c1d7214b8680b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Mon, 18 Jul 2022 16:47:58 +0200 Subject: [PATCH 7/7] Worker warning + weights error handling - Added signal to emit warning from worker (not functional for now) - Added clearer warning when weights are not compatible in training --- napari_cellseg3d/log_utility.py | 8 +++++++ napari_cellseg3d/model_workers.py | 28 +++++++++++++++++++--- napari_cellseg3d/plugin_model_inference.py | 1 + napari_cellseg3d/plugin_model_training.py | 1 + 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py index 516c1c47..1ae9b2a0 100644 --- a/napari_cellseg3d/log_utility.py +++ b/napari_cellseg3d/log_utility.py @@ -1,4 +1,5 @@ import threading +import warnings from qtpy import QtCore from qtpy.QtGui import QTextCursor @@ -79,3 +80,10 @@ def print_and_log(self, text, printing=True): ) finally: self.lock.release() + + def warn(self, warning): + self.lock.acquire() + try: + warnings.warn(warning) + finally: + self.lock.release() diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/model_workers.py index a1d39c52..eb2f3cb9 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/model_workers.py @@ -149,10 +149,12 @@ def show_progress(count, block_size, total_size): class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. - Separate from Worker instances as indicated `here`_""" + Separate from Worker instances as indicated `here`_""" # TODO link ? log_signal = Signal(str) """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + warn_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" # Should not be an instance variable but a class variable, not defined in __init__, see # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect @@ -213,6 +215,7 @@ def __init__( super().__init__(self.inference) self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal ########################################### ########################################### self.device = device @@ -252,6 +255,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -647,7 +654,10 @@ def __init__( super().__init__(self.train) self._signals = LogSignal() self.log_signal = self._signals.log_signal + self.warn_signal = self._signals.warn_signal + self._weight_error = False + ############################################# self.device = device self.model_dict = model_dict self.weights_path = weights_path @@ -669,7 +679,7 @@ def __init__( self.train_files = [] self.val_files = [] - + ####################################### self.downloader = WeightsDownloader() def set_download_log(self, widget): @@ -683,6 +693,10 @@ def log(self, text): """ self.log_signal.emit(text) + def warn(self, warning): + """Sends a warning to main thread""" + self.warn_signal.emit(warning) + def log_parameters(self): self.log("-" * 20) @@ -726,6 +740,13 @@ def log_parameters(self): if self.weights_path is not None: self.log(f"Using weights from : {self.weights_path}") + if self._weight_error: + self.log( + ">>>>>>>>>>>>>>>>>\n" + "WARNING:\nChosen weights were incompatible with the model,\n" + "the model will be trained from random weights\n" + "<<<<<<<<<<<<<<<<<\n" + ) # self.log("\n") self.log("-" * 20) @@ -959,7 +980,8 @@ def train(self): "the model will be trained from random weights" ) self.log(warn) - warnings.warn(warn) + self.warn(warn) + self._weight_error = True if self.device.type == "cuda": self.log("\nUsing GPU :") diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/plugin_model_inference.py index f04479db..711f4b49 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/plugin_model_inference.py @@ -599,6 +599,7 @@ def start(self): self.worker.started.connect(self.on_start) self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.yielded.connect(yield_connect_show_res) self.worker.errored.connect( yield_connect_show_res diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/plugin_model_training.py index 79e489d1..517cf8fc 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/plugin_model_training.py @@ -857,6 +857,7 @@ def start(self): [btn.setVisible(False) for btn in self.close_buttons] self.worker.log_signal.connect(self.log.print_and_log) + self.worker.warn_signal.connect(self.log.warn) self.worker.started.connect(self.on_start)