Skip to content

Commit

Permalink
Merge pull request #10 from AdaptiveMotorControlLab/bugfix/local-inst…
Browse files Browse the repository at this point in the history
…allation

Bug fix in local installation & minor fixes in launch_review
  • Loading branch information
MMathisLab authored Jun 29, 2022
2 parents 9941222 + d4080ee commit 9006a19
Show file tree
Hide file tree
Showing 20 changed files with 110 additions and 92 deletions.
2 changes: 1 addition & 1 deletion napari_cellseg3d/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union
from typing import Optional
from typing import Union

from qtpy.QtCore import Qt
from qtpy.QtCore import QUrl
Expand Down
111 changes: 68 additions & 43 deletions napari_cellseg3d/launch_review.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from magicgui import magicgui
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.backends.backend_qt5agg import \
FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from monai.transforms import Zoom
from qtpy.QtWidgets import QSizePolicy
from scipy import ndimage
from monai.transforms import Zoom
from tifffile import imwrite

from napari_cellseg3d import utils
Expand Down Expand Up @@ -147,6 +147,7 @@ def launch_review(

layer = view1.layers[0]
layer1 = view1.layers[1]

# if not as_folder:
# r_path = os.path.dirname(r_path)

Expand All @@ -164,20 +165,19 @@ def file_widget(
dirname = Path(r_path)
# def saver():
out_dir = file_widget.dirname.value

# print("The directory is:", out_dir)

def quicksave():
if not as_folder:
if viewer.layers["labels"] is not None:
time = utils.get_date_time()
name = str(out_dir) + "/labels_reviewed_" + time + ".tif"
name = os.path.join(str(out_dir), "labels_reviewed.tif")
dat = viewer.layers["labels"].data
imwrite(name, data=dat)

else:
if viewer.layers["labels"] is not None:
time = utils.get_date_time()
dir_name = str(out_dir) + "/labels_reviewed_" + time
dir_name = os.path.join(str(out_dir), "labels_reviewed")
dat = viewer.layers["labels"].data
utils.save_stack(dat, dir_name, filetype=filetype)

Expand Down Expand Up @@ -206,17 +206,17 @@ def quicksave():
xy_axes = canvas.figure.add_subplot(3, 1, 1)
canvas.figure.suptitle("Shift-click on image for plot \n", fontsize=8)
xy_axes.imshow(np.zeros((100, 100), np.int16))
xy_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
xy_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
xy_axes.set_xlabel("x axis")
xy_axes.set_ylabel("y axis")
yz_axes = canvas.figure.add_subplot(3, 1, 2)
yz_axes.imshow(np.zeros((100, 100), np.int16))
yz_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
yz_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
yz_axes.set_xlabel("y axis")
yz_axes.set_ylabel("z axis")
zx_axes = canvas.figure.add_subplot(3, 1, 3)
zx_axes.imshow(np.zeros((100, 100), np.int16))
zx_axes.scatter(50, 50, s=10, c="red", alpha=0.25)
zx_axes.scatter(50, 50, s=10, c="green", alpha=0.25)
zx_axes.set_xlabel("x axis")
zx_axes.set_ylabel("z axis")

Expand All @@ -234,17 +234,33 @@ def update_canvas_canvas(viewer, event):

if "shift" in event.modifiers:
try:
m_point = np.round(viewer.cursor.position).astype(int)
print(m_point)

crop_big = crop_img(
[m_point[0], m_point[1], m_point[2]],
cursor_position = np.round(viewer.cursor.position).astype(int)
print(cursor_position)

cropped_volume = crop_volume_around_point(
[
cursor_position[0],
cursor_position[1],
cursor_position[2],
],
viewer.layers["volume"],
)

xy_axes.imshow(crop_big[50], cmap="inferno", vmin=200, vmax=2000)
yz_axes.imshow(crop_big.transpose(1, 0, 2)[50], cmap="inferno", vmin=200, vmax=2000)
zx_axes.imshow(crop_big.transpose(2, 0, 1)[50], cmap="inferno", vmin=200, vmax=2000)
xy_axes.imshow(
cropped_volume[50], cmap="inferno", vmin=200, vmax=2000
)
yz_axes.imshow(
cropped_volume.transpose(1, 0, 2)[50],
cmap="inferno",
vmin=200,
vmax=2000,
)
zx_axes.imshow(
cropped_volume.transpose(2, 0, 1)[50],
cmap="inferno",
vmin=200,
vmax=2000,
)
canvas.draw_idle()
except Exception as e:
print(e)
Expand All @@ -262,40 +278,49 @@ def update_button(axis_event):

view1.dims.events.current_step.connect(update_button)

def crop_img(points, layer):
def crop_volume_around_point(points, layer):

if zoom_factor != [1,1,1]:
im = np.array(layer.data, dtype=np.int16)
image = Zoom(
if zoom_factor != [1, 1, 1]:
vol = np.array(layer.data, dtype=np.int16)
volume = Zoom(
zoom_factor,
keep_size=False,
padding_mode="empty",
)(np.expand_dims(im, axis=0))
image = image[0]
)(np.expand_dims(vol, axis=0))
volume = volume[0]
# image = ndimage.zoom(layer.data, zoom_factor, mode="nearest") # cleaner but much slower...
else :
image = layer.data

min_vals = [x - 50 for x in points]
max_vals = [x + 50 for x in points]
yohaku_minus = [n if n < 0 else 0 for n in min_vals]
yohaku_plus = [
x - image.shape[i] if image.shape[i] < x else 0
for i, x in enumerate(max_vals)
else:
volume = layer.data

min_coordinates = [point - 50 for point in points]
max_coordinates = [point + 50 for point in points]
inferior_bound = [
min_coordinate if min_coordinate < 0 else 0
for min_coordinate in min_coordinates
]
superior_bound = [
max_coordinate - volume.shape[i]
if volume.shape[i] < max_coordinate
else 0
for i, max_coordinate in enumerate(max_coordinates)
]

crop_slice = tuple(
slice(np.maximum(0, n), x) for n, x in zip(min_vals, max_vals)
slice(np.maximum(0, min_coordinate), max_coordinate)
for min_coordinate, max_coordinate in zip(
min_coordinates, max_coordinates
)
)

if as_folder:
crop_temp = image[crop_slice].persist().compute()
crop_temp = volume[crop_slice].persist().compute()
else:

crop_temp = layer.data[crop_slice]
cropped_img = np.zeros((100, 100, 100), np.int16)
cropped_img[
-yohaku_minus[0] : 100 - yohaku_plus[0],
-yohaku_minus[1] : 100 - yohaku_plus[1],
-yohaku_minus[2] : 100 - yohaku_plus[2],

cropped_volume = np.zeros((100, 100, 100), np.int16)
cropped_volume[
-inferior_bound[0] : 100 - superior_bound[0],
-inferior_bound[1] : 100 - superior_bound[1],
-inferior_bound[2] : 100 - superior_bound[2],
] = crop_temp
return cropped_img
return cropped_volume
1 change: 0 additions & 1 deletion napari_cellseg3d/model_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import napari
import torch

# Qt
from qtpy.QtWidgets import QLineEdit
from qtpy.QtWidgets import QProgressBar
Expand Down
3 changes: 1 addition & 2 deletions napari_cellseg3d/model_instance_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from __future__ import print_function

import numpy as np
from skimage.measure import label

# from skimage.measure import marching_cubes
# from skimage.measure import mesh_surface_area
from skimage.measure import label
from skimage.measure import regionprops
from skimage.morphology import remove_small_objects
from skimage.segmentation import watershed
Expand Down
4 changes: 0 additions & 4 deletions napari_cellseg3d/model_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import torch

# MONAI
from monai.data import CacheDataset
from monai.data import DataLoader
Expand All @@ -30,17 +29,14 @@
from monai.transforms import SpatialPadd
from monai.transforms import Zoom
from monai.utils import set_determinism

# threads
from napari.qt.threading import GeneratorWorker
from napari.qt.threading import WorkerBaseSignals

# Qt
from qtpy.QtCore import Signal
from tifffile import imwrite

from napari_cellseg3d import utils

# local
from napari_cellseg3d.model_instance_seg import binary_connected
from napari_cellseg3d.model_instance_seg import binary_watershed
Expand Down
4 changes: 3 additions & 1 deletion napari_cellseg3d/models/TRAILMAP_MS.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os

import torch
from torch import nn

from napari_cellseg3d import utils
import os


def get_weights_file():
Expand Down
4 changes: 3 additions & 1 deletion napari_cellseg3d/models/model_SegResNet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

from monai.networks.nets import SegResNetVAE

from napari_cellseg3d import utils
import os


def get_net():
Expand Down
5 changes: 3 additions & 2 deletions napari_cellseg3d/models/model_TRAILMAP.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from napari_cellseg3d.models.unet.model import UNet3D
from napari_cellseg3d import utils
import os

from napari_cellseg3d import utils
from napari_cellseg3d.models.unet.model import UNet3D


def get_weights_file():
# original model from Liqun Luo lab, transfered to pytorch
Expand Down
4 changes: 3 additions & 1 deletion napari_cellseg3d/models/model_VNet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os

from monai.inferers import sliding_window_inference
from monai.networks.nets import VNet

from napari_cellseg3d import utils
import os


def get_net():
Expand Down
3 changes: 0 additions & 3 deletions napari_cellseg3d/plugin_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from magicgui import magicgui
from magicgui.widgets import Container
from magicgui.widgets import Slider

# Qt
from qtpy.QtWidgets import QSizePolicy
from tifffile import imwrite
Expand Down Expand Up @@ -72,8 +71,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):

self.build()



def toggle_label_path(self):
if self.crop_label_choice.isChecked():
self.lbl_label.setVisible(True)
Expand Down
1 change: 0 additions & 1 deletion napari_cellseg3d/plugin_dock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import napari
import pandas as pd

# Qt
from qtpy.QtWidgets import QVBoxLayout
from qtpy.QtWidgets import QWidget
Expand Down
8 changes: 4 additions & 4 deletions napari_cellseg3d/plugin_helper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import napari
import pathlib

import napari
from qtpy.QtCore import QSize
from qtpy.QtGui import QIcon
from qtpy.QtGui import QPixmap
# Qt
from qtpy.QtWidgets import QVBoxLayout
from qtpy.QtWidgets import QWidget
from qtpy.QtGui import QPixmap
from qtpy.QtGui import QIcon
from qtpy.QtCore import QSize

# local
from napari_cellseg3d import interface as ui
Expand Down
7 changes: 2 additions & 5 deletions napari_cellseg3d/plugin_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import matplotlib.pyplot as plt
import napari
import numpy as np
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.backends.backend_qt5agg import \
FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from monai.transforms import SpatialPad
from monai.transforms import ToTensor
Expand Down Expand Up @@ -67,8 +66,6 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):

self.build()



def build(self):
"""Builds the layout of the widget."""

Expand Down
2 changes: 0 additions & 2 deletions napari_cellseg3d/plugin_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import napari
import numpy as np
import pandas as pd

# Qt
from qtpy.QtWidgets import QSizePolicy

Expand Down Expand Up @@ -77,7 +76,6 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
self.use_window_inference = False
self.window_inference_size = None


###########################
# interface

Expand Down
8 changes: 2 additions & 6 deletions napari_cellseg3d/plugin_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@
import numpy as np
import pandas as pd
import torch
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.backends.backend_qt5agg import \
FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

# MONAI
from monai.losses import DiceCELoss
from monai.losses import DiceFocalLoss
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from monai.losses import GeneralizedDiceLoss
from monai.losses import TverskyLoss

# Qt
from qtpy.QtWidgets import QSizePolicy

Expand Down Expand Up @@ -132,7 +129,6 @@ def __init__(

self.save_as_zip = False
"""Whether to zip results folder once done. Creates a zipped copy of the results folder."""


# recover default values
self.num_samples = samples
Expand Down
Loading

0 comments on commit 9006a19

Please sign in to comment.