Skip to content

Commit

Permalink
Merge pull request #24 from AdaptiveMotorControlLab/cy/download-impro…
Browse files Browse the repository at this point in the history
…vements

Improved pretrained weights usage
  • Loading branch information
MMathisLab authored Jul 21, 2022
2 parents 1709528 + 487560d commit ec8ab8b
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 34 deletions.
20 changes: 15 additions & 5 deletions napari_cellseg3d/log_utility.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import warnings

from qtpy import QtCore
from qtpy.QtGui import QTextCursor
Expand Down Expand Up @@ -28,13 +29,15 @@ def write(self, message):
try:
if not hasattr(self, "flag"):
self.flag = False
message = message.replace('\r', '').rstrip()
message = message.replace("\r", "").rstrip()
if message:
method = "replace_last_line" if self.flag else "append"
QtCore.QMetaObject.invokeMethod(self,
method,
QtCore.Qt.QueuedConnection,
QtCore.Q_ARG(str, message))
QtCore.QMetaObject.invokeMethod(
self,
method,
QtCore.Qt.QueuedConnection,
QtCore.Q_ARG(str, message),
)
self.flag = True
else:
self.flag = False
Expand Down Expand Up @@ -77,3 +80,10 @@ def print_and_log(self, text, printing=True):
)
finally:
self.lock.release()

def warn(self, warning):
self.lock.acquire()
try:
warnings.warn(warning)
finally:
self.lock.release()
110 changes: 88 additions & 22 deletions napari_cellseg3d/model_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
import importlib.util
from typing import Optional
import warnings

import numpy as np
from tifffile import imwrite
Expand Down Expand Up @@ -65,19 +66,28 @@
Path("/models/pretrained")
)


class WeightsDownloader:
"""A utility class the downloads the weights of a model when needed."""

def __init__(self, log_widget: Optional[log_utility.Log] = None):
"""
Creates a WeightsDownloader, optionally with a log widget to display the progress.
def __init__(self, log_widget: Optional[log_utility.Log]= None):
Args:
log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
"""
self.log_widget = log_widget

def download_weights(self,model_name: str):
def download_weights(self, model_name: str, model_weights_filename: str):
"""
Downloads a specific pretrained model.
This code is adapted from DeepLabCut with permission from MWMathis.
Downloads a specific pretrained model.
This code is adapted from DeepLabCut with permission from MWMathis.
Args:
model_name (str): name of the model to download
"""
Args:
model_name (str): name of the model to download
model_weights_filename (str): name of the .pth file expected for the model
"""
import json
import tarfile
import urllib.request
Expand All @@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
json_path = os.path.join(
pretrained_folder_path, "pretrained_model_urls.json"
)

check_path = os.path.join(
pretrained_folder_path, model_weights_filename
)
if os.path.exists(check_path):
message = f"Weight file {model_weights_filename} already exists, skipping download step"
if self.log_widget is not None:
self.log_widget.print_and_log(message, printing=False)
print(message)
return

with open(json_path) as f:
neturls = json.load(f)
if model_name in neturls.keys():
Expand All @@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
pbar = tqdm(unit="B", total=total_size, position=0)
else:
self.log_widget.print_and_log(start_message)
pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget)
pbar = tqdm(
unit="B",
total=total_size,
position=0,
file=self.log_widget,
)

filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
filename, _ = urllib.request.urlretrieve(
url, reporthook=show_progress
)
with tarfile.open(filename, mode="r:gz") as tar:
tar.extractall(pretrained_folder_path)
else:
Expand All @@ -121,10 +149,12 @@ def show_progress(count, block_size, total_size):
class LogSignal(WorkerBaseSignals):
"""Signal to send messages to be logged from another thread.
Separate from Worker instances as indicated `here`_"""
Separate from Worker instances as indicated `here`_""" # TODO link ?

log_signal = Signal(str)
"""qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
warn_signal = Signal(str)
"""qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread"""

# Should not be an instance variable but a class variable, not defined in __init__, see
# https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
Expand Down Expand Up @@ -185,6 +215,7 @@ def __init__(
super().__init__(self.inference)
self._signals = LogSignal() # add custom signals
self.log_signal = self._signals.log_signal
self.warn_signal = self._signals.warn_signal
###########################################
###########################################
self.device = device
Expand All @@ -204,7 +235,6 @@ def __init__(
self.downloader = WeightsDownloader()
"""Download utility"""


@staticmethod
def create_inference_dict(images_filepaths):
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
Expand All @@ -225,6 +255,10 @@ def log(self, text):
"""
self.log_signal.emit(text)

def warn(self, warning):
"""Sends a warning to main thread"""
self.warn_signal.emit(warning)

def log_parameters(self):

self.log("-" * 20)
Expand Down Expand Up @@ -297,7 +331,7 @@ def inference(self):
sys = platform.system()
print(f"OS is {sys}")
if sys == "Darwin":
torch.set_num_threads(1) # required for threading on macOS ?
torch.set_num_threads(1) # required for threading on macOS ?
self.log("Number of threads has been set to 1 for macOS")

images_dict = self.create_inference_dict(self.images_filepaths)
Expand All @@ -323,7 +357,11 @@ def inference(self):
model = self.model_dict["class"].get_net()
if self.model_dict["name"] == "SegResNet":
model = self.model_dict["class"].get_net()(
input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code
input_image_size=[
dims,
dims,
dims,
], # TODO FIX ! find a better way & remove model-specific code
out_channels=1,
# dropout_prob=0.3,
)
Expand Down Expand Up @@ -372,8 +410,13 @@ def inference(self):
if self.weights_dict["custom"]:
weights = self.weights_dict["path"]
else:
self.downloader.download_weights(self.model_dict["name"])
weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file())
self.downloader.download_weights(
self.model_dict["name"],
self.model_dict["class"].get_weights_file(),
)
weights = os.path.join(
WEIGHTS_DIR, self.model_dict["class"].get_weights_file()
)

model.load_state_dict(
torch.load(
Expand Down Expand Up @@ -611,7 +654,10 @@ def __init__(
super().__init__(self.train)
self._signals = LogSignal()
self.log_signal = self._signals.log_signal
self.warn_signal = self._signals.warn_signal

self._weight_error = False
#############################################
self.device = device
self.model_dict = model_dict
self.weights_path = weights_path
Expand All @@ -633,7 +679,7 @@ def __init__(

self.train_files = []
self.val_files = []

#######################################
self.downloader = WeightsDownloader()

def set_download_log(self, widget):
Expand All @@ -647,6 +693,10 @@ def log(self, text):
"""
self.log_signal.emit(text)

def warn(self, warning):
"""Sends a warning to main thread"""
self.warn_signal.emit(warning)

def log_parameters(self):

self.log("-" * 20)
Expand Down Expand Up @@ -690,6 +740,13 @@ def log_parameters(self):

if self.weights_path is not None:
self.log(f"Using weights from : {self.weights_path}")
if self._weight_error:
self.log(
">>>>>>>>>>>>>>>>>\n"
"WARNING:\nChosen weights were incompatible with the model,\n"
"the model will be trained from random weights\n"
"<<<<<<<<<<<<<<<<<\n"
)

# self.log("\n")
self.log("-" * 20)
Expand Down Expand Up @@ -904,18 +961,27 @@ def train(self):
if self.weights_path is not None:
if self.weights_path == "use_pretrained":
weights_file = model_class.get_weights_file()
self.downloader.download_weights(model_name)
self.downloader.download_weights(model_name, weights_file)
weights = os.path.join(WEIGHTS_DIR, weights_file)
self.weights_path = weights
else:
weights = os.path.join(self.weights_path)

model.load_state_dict(
torch.load(
weights,
map_location=self.device,
try:
model.load_state_dict(
torch.load(
weights,
map_location=self.device,
)
)
)
except RuntimeError:
warn = (
"WARNING:\nIt seems the weights were incompatible with the model,\n"
"the model will be trained from random weights"
)
self.log(warn)
self.warn(warn)
self._weight_error = True

if self.device.type == "cuda":
self.log("\nUsing GPU :")
Expand Down
1 change: 0 additions & 1 deletion napari_cellseg3d/models/model_SegResNet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from monai.networks.nets import SegResNetVAE



def get_net():
return SegResNetVAE

Expand Down
3 changes: 1 addition & 2 deletions napari_cellseg3d/models/model_TRAILMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from torch import nn



def get_weights_file():
# model additionally trained on Mathis/Wyss mesoSPIM data
return "TRAILMAP.pth"
# FIXME currently incorrect, find good weights from TRAILMAP_test and upload them


def get_net():
return TRAILMAP(1, 1)

Expand Down Expand Up @@ -120,4 +120,3 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"):
# nn.BatchNorm3d(out_ch),
)
return out

1 change: 0 additions & 1 deletion napari_cellseg3d/models/model_VNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from monai.networks.nets import VNet



def get_net():
return VNet()

Expand Down
1 change: 1 addition & 0 deletions napari_cellseg3d/plugin_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ def start(self):

self.worker.started.connect(self.on_start)
self.worker.log_signal.connect(self.log.print_and_log)
self.worker.warn_signal.connect(self.log.warn)
self.worker.yielded.connect(yield_connect_show_res)
self.worker.errored.connect(
yield_connect_show_res
Expand Down
3 changes: 2 additions & 1 deletion napari_cellseg3d/plugin_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ def start(self):
[btn.setVisible(False) for btn in self.close_buttons]

self.worker.log_signal.connect(self.log.print_and_log)
self.worker.warn_signal.connect(self.log.warn)

self.worker.started.connect(self.on_start)

Expand Down Expand Up @@ -994,7 +995,7 @@ def make_csv(self):
if len(self.loss_values) == 0 or self.loss_values is None:
warnings.warn("No loss values to add to csv !")
return

self.df = pd.DataFrame(
{
"epoch": size_column,
Expand Down
2 changes: 0 additions & 2 deletions napari_cellseg3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,5 +978,3 @@ def merge_imgs(imgs, original_image_shape):

print(merged_imgs.shape)
return merged_imgs


0 comments on commit ec8ab8b

Please sign in to comment.