Skip to content

Commit

Permalink
Fix model download function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeylau committed Jun 27, 2022
1 parent 04e0af2 commit 9beb8a1
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 41 deletions.
7 changes: 3 additions & 4 deletions napari_cellseg3d/models/TRAILMAP_MS.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from napari_cellseg3d import utils
import os

modelname = "TRAILMAP_MS"
target_dir = os.path.join("models","pretrained")

def get_weights_file():
utils.download_model(modelname, target_dir)
return "TRAILMAP_MS_best_metric_epoch_26.pth" #model additionally trained on Mathis/Wyss mesoSPIM data
# model additionally trained on Mathis/Wyss mesoSPIM data
target_dir = utils.download_model("TRAILMAP_MS")
return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth")


def get_net():
Expand Down
6 changes: 2 additions & 4 deletions napari_cellseg3d/models/model_SegResNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from napari_cellseg3d import utils
import os

modelname = "SegResNet"
target_dir = os.path.join("models","pretrained")

def get_net():
return SegResNetVAE


def get_weights_file():
utils.download_model(modelname, target_dir)
return "SegResNet.pth"
target_dir = utils.download_model("SegResNet")
return os.path.join(target_dir, "SegResNet.pth")


def get_output(model, input):
Expand Down
7 changes: 3 additions & 4 deletions napari_cellseg3d/models/model_TRAILMAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from napari_cellseg3d import utils
import os

modelname = "TRAILMAP"
target_dir = os.path.join("models","pretrained")

def get_weights_file():
utils.download_model(modelname, target_dir)
return "TRAILMAP_PyTorch.pth" #original model from Liqun Luo lab, transfered to pytorch
# original model from Liqun Luo lab, transfered to pytorch
target_dir = utils.download_model("TRAILMAP")
return os.path.join(target_dir, "TRAILMAP_PyTorch.pth")


def get_net():
Expand Down
6 changes: 2 additions & 4 deletions napari_cellseg3d/models/model_VNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from napari_cellseg3d import utils
import os

modelname = "VNet"
target_dir = os.path.join("models","pretrained")

def get_net():
return VNet()


def get_weights_file():
utils.download_model(modelname, target_dir)
return "VNet_40e.pth"
target_dir = utils.download_model("VNet")
return os.path.join(target_dir, "VNet_40e.pth")


def get_output(model, input):
Expand Down
34 changes: 9 additions & 25 deletions napari_cellseg3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def merge_imgs(imgs, original_image_shape):
return merged_imgs


def download_model(modelname, target_dir):
def download_model(modelname):
"""
Downloads a specific pretained model.
This code is adapted from DeepLabCut with permission from MWMathis
Expand All @@ -995,24 +995,11 @@ def download_model(modelname, target_dir):
def show_progress(count, block_size, total_size):
pbar.update(block_size)

def tarfilenamecutting(tarf):
"""' auxfun to extract folder path
ie. /xyz-trainsetxyshufflez/
"""
for memberid, member in enumerate(tarf.getmembers()):
if memberid == 0:
parent = str(member.path)
l = len(parent) + 1
if member.path.startswith(parent):
member.path = member.path[l:]
yield member

#TODO: fix error in line 1021;
cellseg3d_path = os.path.split(importlib.util.find_spec("napari-cellseg3d").origin)[0]
json_path = os.path.join(cellseg3d_path, "models", "pretrained", "pretrained_model_urls.json")
cellseg3d_path = os.path.split(importlib.util.find_spec("napari_cellseg3d").origin)[0]
pretrained_folder_path = os.path.join(cellseg3d_path, "models", "pretrained")
json_path = os.path.join(pretrained_folder_path, "pretrained_model_urls.json")
with open(json_path) as f:
neturls = json.load(f)

if modelname in neturls.keys():
url = neturls[modelname]
response = urllib.request.urlopen(url)
Expand All @@ -1021,12 +1008,9 @@ def tarfilenamecutting(tarf):
pbar = tqdm(unit="B", total=total_size, position=0)
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
with tarfile.open(filename, mode="r:gz") as tar:
tar.extractall(target_dir, members=tarfilenamecutting(tar))
tar.extractall(pretrained_folder_path)
return pretrained_folder_path
else:
models = [
fn
for fn in neturls.keys()
if "VNet_" not in fn and "SegResNet" not in fn and "TRAILMAP_" not in fn
]
print("Model does not exist: ", modelname)
#print("Pick one of the following: ", models)
raise ValueError(
f"Unknown model. `modelname` should be one of {', '.join(neturls)}"
)

0 comments on commit 9beb8a1

Please sign in to comment.