Skip to content

Commit

Permalink
Merge pull request #6 from AdaptiveMotorControlLab/mwm/download_models
Browse files Browse the repository at this point in the history
Mwm/download models
  • Loading branch information
MMathisLab authored Jun 28, 2022
2 parents 31b5fd7 + 9beb8a1 commit 9941222
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 100 deletions.
4 changes: 1 addition & 3 deletions docs/res/guides/custom_model_template.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ To add a custom model, you will need a **.py** file with the following structure

def get_weights_file():
return "weights_file.pth" # name of the weights file for the model,
# which should be in *napari_cellseg3d/models/saved_weights*
# which should be in *napari_cellseg3d/models/pretrained*


def get_output(model, input):
Expand All @@ -35,5 +35,3 @@ To add a custom model, you will need a **.py** file with the following structure
def ModelClass(x1,x2...):
# your Pytorch model here...
return results # should return as [C, N, D,H,W]


7 changes: 5 additions & 2 deletions napari_cellseg3d/models/TRAILMAP_MS.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch
from torch import nn
from napari_cellseg3d import utils
import os


def get_weights_file():
# return "TMP_TEST_40e.pth"
return "TRAILMAP_DFl_best.pth"
# model additionally trained on Mathis/Wyss mesoSPIM data
target_dir = utils.download_model("TRAILMAP_MS")
return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth")


def get_net():
Expand Down
5 changes: 4 additions & 1 deletion napari_cellseg3d/models/model_SegResNet.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from monai.networks.nets import SegResNetVAE
from napari_cellseg3d import utils
import os


def get_net():
return SegResNetVAE


def get_weights_file():
return "SegResNet.pth"
target_dir = utils.download_model("SegResNet")
return os.path.join(target_dir, "SegResNet.pth")


def get_output(model, input):
Expand Down
7 changes: 5 additions & 2 deletions napari_cellseg3d/models/model_TRAILMAP.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from napari_cellseg3d.models.unet.model import UNet3D
from napari_cellseg3d import utils
import os


def get_weights_file():
# return "TMP_TEST_40e.pth"
return "trailmaptorchpretrained.pth"
# original model from Liqun Luo lab, transfered to pytorch
target_dir = utils.download_model("TRAILMAP")
return os.path.join(target_dir, "TRAILMAP_PyTorch.pth")


def get_net():
Expand Down
6 changes: 4 additions & 2 deletions napari_cellseg3d/models/model_VNet.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from monai.inferers import sliding_window_inference
from monai.networks.nets import VNet
from napari_cellseg3d import utils
import os


def get_net():
return VNet()


def get_weights_file():
# return "dice_VNet.pth"
return "VNet_40e.pth"
target_dir = utils.download_model("VNet")
return os.path.join(target_dir, "VNet_40e.pth")


def get_output(model, input):
Expand Down
6 changes: 6 additions & 0 deletions napari_cellseg3d/models/pretrained/pretrained_model_urls.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz",
"TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
"SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz",
"VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz"
}
6 changes: 0 additions & 6 deletions napari_cellseg3d/models/pretrained/pretrained_model_urls.yaml

This file was deleted.

14 changes: 1 addition & 13 deletions napari_cellseg3d/plugin_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):

self.build()

###########################################
if utils.ENABLE_TEST_MODE():
# TODO : remove/disable once done
if self.as_folder:
self.image_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_png/sample"
if self.crop_label_choice.isChecked():
self.label_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_png/sample_labels"
else:
self.image_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif/volumes/images.tif"
if self.crop_label_choice.isChecked():
self.label_path = "C:/Users/Cyril/Desktop/Proj_bachelor/data/visual_tif/labels/testing_im.tif"

###########################################


def toggle_label_path(self):
if self.crop_label_choice.isChecked():
Expand Down
18 changes: 1 addition & 17 deletions napari_cellseg3d/plugin_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):

self.build()

######################################
# TODO test remove
import glob
import os

if utils.ENABLE_TEST_MODE():
ground_directory = "C:/Users/Cyril/Desktop/Proj_bachelor/data/cropped_visual/train/lab"
# ground_directory = "C:/Users/Cyril/Desktop/test/labels"
pred_directory = "C:/Users/Cyril/Desktop/test/pred"
# pred_directory = "C:/Users/Cyril/Desktop/test"
self.images_filepaths = sorted(
glob.glob(os.path.join(ground_directory, "*.tif"))
)
self.labels_filepaths = sorted(
glob.glob(os.path.join(pred_directory, "*.tif"))
)
###############################################################################


def build(self):
"""Builds the layout of the widget."""
Expand Down
39 changes: 1 addition & 38 deletions napari_cellseg3d/plugin_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,44 +132,7 @@ def __init__(

self.save_as_zip = False
"""Whether to zip results folder once done. Creates a zipped copy of the results folder."""
######################
######################
######################
# TEST TODO REMOVE
import glob

if utils.ENABLE_TEST_MODE():
directory = os.path.dirname(os.path.realpath(__file__)) + str(
Path("/models/dataset/volumes")
)
self.data_path = directory

lab_directory = os.path.dirname(os.path.realpath(__file__)) + str(
Path("/models/dataset/lab_sem")
)
self.label_path = lab_directory

self.images_filepaths = sorted(
glob.glob(os.path.join(directory, "*.tif"))
)

self.labels_filepaths = sorted(
glob.glob(os.path.join(lab_directory, "*.tif"))
)

if results_path == "":
self.results_path = "C:/Users/Cyril/Desktop/test/models"
else:
self.results_path = results_path

if data_path != "":
self.data_path = data_path

if label_path != "":
self.label_path = label_path
#######################
#######################
#######################


# recover default values
self.num_samples = samples
Expand Down
50 changes: 34 additions & 16 deletions napari_cellseg3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,14 @@
from skimage.filters import gaussian
from tifffile import imread as tfl_imread
from tqdm import tqdm
import importlib.util

"""
utils.py
====================================
Definitions of utility functions and variables
"""

##################
##################
# dev util
def ENABLE_TEST_MODE():
path = Path(os.path.expanduser("~"))
# print(path)
print("TEST MODE ENABLED, DEV ONLY")
if path == Path("C:/Users/Cyril"):
return True
return False


##################
##################


def normalize_x(image):
"""Normalizes the values of an image array to be between [-1;1] rather than [0;255]
Expand Down Expand Up @@ -996,3 +981,36 @@ def merge_imgs(imgs, original_image_shape):

print(merged_imgs.shape)
return merged_imgs


def download_model(modelname):
"""
Downloads a specific pretained model.
This code is adapted from DeepLabCut with permission from MWMathis
"""
import json
import urllib.request
import tarfile

def show_progress(count, block_size, total_size):
pbar.update(block_size)

cellseg3d_path = os.path.split(importlib.util.find_spec("napari_cellseg3d").origin)[0]
pretrained_folder_path = os.path.join(cellseg3d_path, "models", "pretrained")
json_path = os.path.join(pretrained_folder_path, "pretrained_model_urls.json")
with open(json_path) as f:
neturls = json.load(f)
if modelname in neturls.keys():
url = neturls[modelname]
response = urllib.request.urlopen(url)
print(f"Downloading the model from the M.W. Mathis Lab server {url}....")
total_size = int(response.getheader("Content-Length"))
pbar = tqdm(unit="B", total=total_size, position=0)
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
with tarfile.open(filename, mode="r:gz") as tar:
tar.extractall(pretrained_folder_path)
return pretrained_folder_path
else:
raise ValueError(
f"Unknown model. `modelname` should be one of {', '.join(neturls)}"
)

0 comments on commit 9941222

Please sign in to comment.