diff --git a/cellpose/core.py b/cellpose/core.py
index d27d0a0b..012a9509 100644
--- a/cellpose/core.py
+++ b/cellpose/core.py
@@ -40,12 +40,11 @@ def parse_model_string(pretrained_model):
else:
model_str = os.path.split(pretrained_model)[-1]
if len(model_str)>3 and model_str[:4]=='unet':
- core_logger.info(f'parsing model string {model_str} to get unet options')
nclasses = max(2, int(model_str[4]))
elif len(model_str)>7 and model_str[:8]=='cellpose':
- core_logger.info(f'parsing model string {model_str} to get cellpose options')
+ nclasses = 3
else:
- return None
+ return True, True, False
ostrs = model_str.split('_')[2::2]
residual_on = ostrs[0]=='on'
style_on = ostrs[1]=='on'
@@ -682,7 +681,7 @@ def loss_fn(self, lbl, y):
def train(self, train_data, train_labels, train_files=None,
test_data=None, test_labels=None, test_files=None,
- channels=None, normalize=True, pretrained_model=None, save_path=None, save_every=50, save_each=False,
+ channels=None, normalize=True, save_path=None, save_every=50, save_each=False,
learning_rate=0.2, n_epochs=500, momentum=0.9, weight_decay=0.00001, batch_size=8, rescale=False):
""" train function uses 0-1 mask label and boundary pixels for training """
@@ -715,8 +714,7 @@ def train(self, train_data, train_labels, train_files=None,
del train_data[::8], train_classes[::8], train_labels[::8]
model_path = self._train_net(train_data, train_classes,
- test_data, test_classes,
- pretrained_model, save_path, save_every, save_each,
+ test_data, test_classes, save_path, save_every, save_each,
learning_rate, n_epochs, momentum, weight_decay,
batch_size, rescale)
@@ -842,9 +840,9 @@ def _set_criterion(self):
# Restored defaults. Need to make sure rescale is properly turned off and omni turned on when using CLI.
def _train_net(self, train_data, train_labels,
test_data=None, test_labels=None,
- pretrained_model=None, save_path=None, save_every=100, save_each=False,
+ save_path=None, save_every=100, save_each=False,
learning_rate=0.2, n_epochs=500, momentum=0.9, weight_decay=0.00001,
- SGD=True, batch_size=8, rescale=True, netstr='cellpose'):
+ SGD=True, batch_size=8, rescale=True, netstr=None):
""" train function uses loss function self.loss_fn in models.py"""
d = datetime.datetime.now()
@@ -870,13 +868,12 @@ def _train_net(self, train_data, train_labels,
nchan = train_data[0].shape[0]
core_logger.info('>>>> training network with %d channel input <<<<'%nchan)
- core_logger.info('>>>> saving every %d epochs'%save_every)
core_logger.info('>>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f'%(self.learning_rate, self.batch_size, weight_decay))
- core_logger.info('>>>> ntrain = %d'%nimg)
- core_logger.info('>>>> rescale is %d'%rescale)
+
if test_data is not None:
- core_logger.info('>>>> ntest = %d'%len(test_data))
- core_logger.info(train_data[0].shape)
+ core_logger.info(f'>>>> ntrain = {nimg}, ntest = {len(test_data)}')
+ else:
+ core_logger.info(f'>>>> ntrain = {nimg}')
tic = time.time()
@@ -891,7 +888,6 @@ def _train_net(self, train_data, train_labels,
LR = np.append(LR, self.learning_rate*np.ones(max(0,self.n_epochs-10)))
else:
LR = self.learning_rate * np.ones(self.n_epochs)
-
lavg, nsum = 0, 0
@@ -913,7 +909,10 @@ def _train_net(self, train_data, train_labels,
for iepoch in range(self.n_epochs):
np.random.seed(iepoch)
- rperm = np.random.permutation(nimg)
+ if nimg < batch_size:
+ rperm = np.random.choice(nimg, batch_size)
+ else:
+ rperm = np.random.permutation(nimg)
if SGD:
self._set_learning_rate(LR[iepoch])
@@ -961,11 +960,17 @@ def _train_net(self, train_data, train_labels,
if iepoch==self.n_epochs-1 or iepoch%save_every==1:
# save model at the end
if save_each: #separate files as model progresses
- file_name = '{}_{}_{}_{}'.format(self.net_type, file_label,
- d.strftime("%Y_%m_%d_%H_%M_%S.%f"),
- 'epoch_'+str(iepoch))
+ if netstr is None:
+ file_name = '{}_{}_{}_{}'.format(self.net_type, file_label,
+ d.strftime("%Y_%m_%d_%H_%M_%S.%f"),
+ 'epoch_'+str(iepoch))
+ else:
+ file_name = '{}_{}'.format(netstr, 'epoch_'+str(iepoch))
else:
- file_name = '{}_{}_{}'.format(self.net_type, file_label, d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+ if netstr is None:
+ file_name = '{}_{}_{}'.format(self.net_type, file_label, d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+ else:
+ file_name = netstr
file_name = os.path.join(file_path, file_name)
ksave += 1
core_logger.info(f'saving network parameters to {file_name}')
diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py
index 5fc9f708..8b7c8c88 100644
--- a/cellpose/dynamics.py
+++ b/cellpose/dynamics.py
@@ -601,7 +601,7 @@ def labels_to_flows(labels, files=None, use_gpu=False, device=None, omni=False,r
if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: # flows need to be recomputed
- dynamics_logger.info('NOTE: computing flows for labels (could be done before to save time)')
+ dynamics_logger.info('computing flows for labels')
# compute flows; labels are fixed in masks_to_flows, so they need to be passed back
labels, dist, heat, veci = map(list,zip(*[masks_to_flows(labels[n][0],use_gpu=use_gpu, device=device, omni=omni) for n in trange(nimg)]))
@@ -1064,6 +1064,7 @@ def compute_masks(dP, cellprob, bd=None, p=None, inds=None, niter=200, mask_thre
inds = np.stack(np.nonzero(cp_mask)).T
mask = omnipose.core.get_masks(p,bd,cellprob,cp_mask,inds,nclasses,cluster=cluster,
diam_threshold=diam_threshold,verbose=verbose)
+ mask = mask.astype(np.uint32)
else:
mask = get_masks(p, iscell=cp_mask, flows=dP, use_gpu=use_gpu)
diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py
index 20bd252e..231d94c9 100644
--- a/cellpose/gui/gui.py
+++ b/cellpose/gui/gui.py
@@ -4,7 +4,7 @@
from tqdm import tqdm
from PyQt5 import QtGui, QtCore, Qt, QtWidgets
-from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit
+from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox
import pyqtgraph as pg
from pyqtgraph import GraphicsScene
@@ -15,7 +15,7 @@
from . import guiparts, menus, io
from .. import models, core, dynamics
from ..utils import download_url_to_file, masks_to_outlines
-from ..io import save_server
+from ..io import OMNI_INSTALLED, save_server, get_image_files, imsave, imread
from ..transforms import resize_image, normalize99 #fixed import
from ..plot import disk
@@ -34,7 +34,6 @@
SERVER_UPLOAD = False
#Define possible models; can we make a master list in another file to use in models and main?
-MODEL_NAMES = ['cyto', 'nuclei', 'cyto2', 'cyto2_omni', 'bact_omni']
class QHLine(QFrame):
def __init__(self):
@@ -116,8 +115,10 @@ def make_cmap(cm=0):
cmap = pg.ColorMap(pos=np.linspace(0.0,255,256), color=color)
return cmap
+global logger
def run(image=None):
from ..io import logger_setup
+ global logger
logger, log_file = logger_setup()
# Always start by initializing Qt (only once per application)
warnings.filterwarnings("ignore")
@@ -175,7 +176,11 @@ def __init__(self, image=None):
menus.mainmenu(self)
menus.editmenu(self)
+ #menus.modelmenu(self)
+ self.model_strings = models.MODEL_NAMES
menus.helpmenu(self)
+ if OMNI_INSTALLED:
+ menus.omnimenu(self)
self.setStyleSheet("QMainWindow {background: 'black';}")
self.stylePressed = ("QPushButton {Text-align: left; "
@@ -205,10 +210,11 @@ def __init__(self, image=None):
# ---- drawing area ---- #
self.win = pg.GraphicsLayoutWidget()
- self.l0.addWidget(self.win, 0,3, b, 20)
+ self.l0.addWidget(self.win, 0, 8, b, 30)
self.win.scene().sigMouseClicked.connect(self.plot_clicked)
self.win.scene().sigMouseMoved.connect(self.mouse_moved)
self.make_viewbox()
+ self.l0.setColumnStretch(8, 1)
bwrmap = make_bwr()
self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
self.cmap = []
@@ -220,6 +226,8 @@ def __init__(self, image=None):
if MATPLOTLIB:
self.colormap = (plt.get_cmap('gist_ncar')(np.linspace(0.0,.9,1000)) * 255).astype(np.uint8)
+ np.random.seed(42) # make colors stable
+ self.colormap = self.colormap[np.random.permutation(1000)]
else:
np.random.seed(42) # make colors stable
self.colormap = ((np.random.rand(1000,3)*0.8+0.1)*255).astype(np.uint8)
@@ -231,6 +239,8 @@ def __init__(self, image=None):
self.filename = image
io._load_image(self, self.filename)
+ # training from segmentation
+ self.training = False
self.setAcceptDrops(True)
self.win.show()
self.show()
@@ -265,22 +275,22 @@ def make_buttons(self):
label = QLabel('Views:')#[\u2191 \u2193]')
label.setStyleSheet(self.headings)
label.setFont(self.boldfont)
- self.l0.addWidget(label, 0,0,1,1)
+ self.l0.addWidget(label, 0,0,1,4)
label = QLabel('[up/down or W/S]')
label.setStyleSheet(label_style)
label.setFont(self.smallfont)
- self.l0.addWidget(label, 1,0,1,1)
+ self.l0.addWidget(label, 1,0,1,4)
label = QLabel('[pageup/down]')
label.setStyleSheet(label_style)
label.setFont(self.smallfont)
- self.l0.addWidget(label, 1,1,1,1)
+ self.l0.addWidget(label, 1,4,1,4)
b=2
self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B
- self.RGBChoose = guiparts.RGBRadioButtons(self, b,1)
+ self.RGBChoose = guiparts.RGBRadioButtons(self, b,4)
self.RGBDropDown = QComboBox()
self.RGBDropDown.addItems(["RGB","gray","spectral","red","green","blue"])
self.RGBDropDown.setFont(self.medfont)
@@ -288,7 +298,7 @@ def make_buttons(self):
self.RGBDropDown.setFixedWidth(60)
self.RGBDropDown.setStyleSheet(self.dropdowns)
- self.l0.addWidget(self.RGBDropDown, b,0,1,1)
+ self.l0.addWidget(self.RGBDropDown, b,0,1,4)
b+=3
self.resize = -1
@@ -297,12 +307,12 @@ def make_buttons(self):
b+=1
line = QHLine()
line.setStyleSheet('color: white;')
- self.l0.addWidget(line, b,0,1,2)
+ self.l0.addWidget(line, b,0,1,8)
b+=1
label = QLabel('Drawing:')
label.setStyleSheet(self.headings)
label.setFont(self.boldfont)
- self.l0.addWidget(label, b,0,1,2)
+ self.l0.addWidget(label, b,0,1,8)
b+=1
self.brush_size = 3
@@ -312,11 +322,11 @@ def make_buttons(self):
self.BrushChoose.setFixedWidth(60)
self.BrushChoose.setStyleSheet(self.dropdowns)
self.BrushChoose.setFont(self.medfont)
- self.l0.addWidget(self.BrushChoose, b, 1,1,1)
+ self.l0.addWidget(self.BrushChoose, b, 4,1,4)
label = QLabel('brush size: [, .]')
label.setStyleSheet(label_style)
label.setFont(self.medfont)
- self.l0.addWidget(label, b,0,1,1)
+ self.l0.addWidget(label, b,0,1,4)
# cross-hair
self.vLine = pg.InfiniteLine(angle=90, movable=False)
@@ -328,7 +338,7 @@ def make_buttons(self):
self.SCheckBox.setStyleSheet(self.checkstyle)
self.SCheckBox.setFont(self.medfont)
self.SCheckBox.toggled.connect(self.autosave_on)
- self.l0.addWidget(self.SCheckBox, b,0,1,2)
+ self.l0.addWidget(self.SCheckBox, b,0,1,4)
b+=1
# turn on crosshairs
@@ -336,9 +346,9 @@ def make_buttons(self):
self.CHCheckBox.setStyleSheet(self.checkstyle)
self.CHCheckBox.setFont(self.medfont)
self.CHCheckBox.toggled.connect(self.cross_hairs)
- self.l0.addWidget(self.CHCheckBox, b,0,1,1)
+ self.l0.addWidget(self.CHCheckBox, b,0,1,4)
- b+=1
+ b-=1
# turn off masks
self.layer_off = False
self.masksOn = True
@@ -347,7 +357,7 @@ def make_buttons(self):
self.MCheckBox.setFont(self.medfont)
self.MCheckBox.setChecked(True)
self.MCheckBox.toggled.connect(self.toggle_masks)
- self.l0.addWidget(self.MCheckBox, b,0,1,2)
+ self.l0.addWidget(self.MCheckBox, b,4,1,4)
b+=1
# turn off outlines
@@ -355,7 +365,7 @@ def make_buttons(self):
self.OCheckBox = QCheckBox('outlines on [Z]')
self.OCheckBox.setStyleSheet(self.checkstyle)
self.OCheckBox.setFont(self.medfont)
- self.l0.addWidget(self.OCheckBox, b,0,1,2)
+ self.l0.addWidget(self.OCheckBox, b,4,1,4)
self.OCheckBox.setChecked(False)
self.OCheckBox.toggled.connect(self.toggle_masks)
@@ -364,7 +374,7 @@ def make_buttons(self):
# send to server
self.ServerButton = QPushButton(' send manual seg. to server')
self.ServerButton.clicked.connect(lambda: save_server(self))
- self.l0.addWidget(self.ServerButton, b,0,1,2)
+ self.l0.addWidget(self.ServerButton, b,0,1,8)
self.ServerButton.setEnabled(False)
self.ServerButton.setStyleSheet(self.styleInactive)
self.ServerButton.setFont(self.boldfont)
@@ -372,12 +382,12 @@ def make_buttons(self):
b+=1
line = QHLine()
line.setStyleSheet('color: white;')
- self.l0.addWidget(line, b,0,1,2)
+ self.l0.addWidget(line, b,0,1,8)
b+=1
label = QLabel('Segmentation:')
label.setStyleSheet(self.headings)
label.setFont(self.boldfont)
- self.l0.addWidget(label, b,0,1,2)
+ self.l0.addWidget(label, b,0,1,8)
b+=1
self.diameter = 30
@@ -385,7 +395,7 @@ def make_buttons(self):
label.setStyleSheet(label_style)
label.setFont(self.medfont)
label.setToolTip('you can manually enter the approximate diameter for your cells, \nor press “calibrate” to let the model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking “scale disk on”)')
- self.l0.addWidget(label, b, 0,1,2)
+ self.l0.addWidget(label, b, 0,1,8)
self.Diameter = QLineEdit()
self.Diameter.setToolTip('you can manually enter the approximate diameter for your cells, \nor press “calibrate” to let the model estimate it. \nThe size is represented by a disk at the bottom of the view window \n(can turn this disk off by unchecking “scale disk on”)')
self.Diameter.setText(str(self.diameter))
@@ -393,12 +403,12 @@ def make_buttons(self):
self.Diameter.returnPressed.connect(self.compute_scale)
self.Diameter.setFixedWidth(50)
b+=1
- self.l0.addWidget(self.Diameter, b, 0,1,2)
+ self.l0.addWidget(self.Diameter, b, 0,1,4)
# recompute model
self.SizeButton = QPushButton(' calibrate')
self.SizeButton.clicked.connect(self.calibrate_size)
- self.l0.addWidget(self.SizeButton, b,1,1,1)
+ self.l0.addWidget(self.SizeButton, b,4,1,4)
self.SizeButton.setEnabled(False)
self.SizeButton.setStyleSheet(self.styleInactive)
self.SizeButton.setFont(self.boldfont)
@@ -412,7 +422,7 @@ def make_buttons(self):
self.ScaleOn.setChecked(True)
self.ScaleOn.setToolTip('see current diameter as red disk at bottom')
self.ScaleOn.toggled.connect(self.toggle_scale)
- self.l0.addWidget(self.ScaleOn, b,0,1,2)
+ self.l0.addWidget(self.ScaleOn, b,0,1,4)
# use GPU
b+=1
@@ -421,23 +431,28 @@ def make_buttons(self):
self.useGPU.setFont(self.medfont)
self.useGPU.setToolTip('if you have specially installed the cuda version of mxnet, then you can activate this, but it won’t give huge speedups when running single 2D images in the GUI.')
self.check_gpu()
- self.l0.addWidget(self.useGPU, b,0,1,1)
+ self.l0.addWidget(self.useGPU, b,0,1,4)
# fast mode
self.NetAvg = QComboBox()
self.NetAvg.addItems(['average 4 nets', '+ resample (slow)', 'run 1 net (fast)', ])
self.NetAvg.setFont(self.medfont)
self.NetAvg.setToolTip('average 4 different fit networks (default) + resample for smooth masks (slow) or run 1 network (fast)')
- self.l0.addWidget(self.NetAvg, b,1,1,1)
+ self.l0.addWidget(self.NetAvg, b,4,1,4)
b+=1
# choose models
self.ModelChoose = QComboBox()
- self.ModelChoose.addItems(MODEL_NAMES) #added omnipose model names
- self.ModelChoose.setFixedWidth(70)
+ if len(self.model_strings) > len(models.MODEL_NAMES):
+ current_index = len(models.MODEL_NAMES)
+ else:
+ current_index = 0
+ self.ModelChoose.addItems(self.model_strings) #added omnipose model names
+ self.ModelChoose.setFixedWidth(150)
self.ModelChoose.setStyleSheet(self.dropdowns)
self.ModelChoose.setFont(self.medfont)
- self.l0.addWidget(self.ModelChoose, b, 1,1,1)
+ self.ModelChoose.setCurrentIndex(current_index)
+ self.l0.addWidget(self.ModelChoose, b, 4,1,4)
label = QLabel('model: ')
label.setStyleSheet(label_style)
label.setFont(self.medfont)
@@ -446,7 +461,7 @@ def make_buttons(self):
and two omnipose models: bact_omni and cyto2_omni'
label.setToolTip(tipstr)
self.ModelChoose.setToolTip(tipstr)
- self.l0.addWidget(label, b, 0,1,1)
+ self.l0.addWidget(label, b, 0,1,4)
b+=1
# choose channel
@@ -455,7 +470,7 @@ def make_buttons(self):
self.ChannelChoose[1].addItems(['none','red','green','blue'])
cstr = ['chan to segment:', 'chan2 (optional): ']
for i in range(2):
- self.ChannelChoose[i].setFixedWidth(70)
+ #self.ChannelChoose[i].setFixedWidth(70)
self.ChannelChoose[i].setStyleSheet(self.dropdowns)
self.ChannelChoose[i].setFont(self.medfont)
label = QLabel(cstr[i])
@@ -467,8 +482,8 @@ def make_buttons(self):
else:
label.setToolTip('if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option')
self.ChannelChoose[i].setToolTip('if cytoplasm model is chosen, and you also have a nuclear channel, then choose the nuclear channel for this option')
- self.l0.addWidget(label, b, 0,1,1)
- self.l0.addWidget(self.ChannelChoose[i], b, 1,1,1)
+ self.l0.addWidget(label, b, 0,1,4)
+ self.l0.addWidget(self.ChannelChoose[i], b, 4,1,4)
b+=1
# use inverted image for running cellpose
@@ -476,40 +491,22 @@ def make_buttons(self):
self.invert = QCheckBox('invert grayscale')
self.invert.setStyleSheet(self.checkstyle)
self.invert.setFont(self.medfont)
- self.l0.addWidget(self.invert, b,0,1,2)
+ self.l0.addWidget(self.invert, b,0,1,4)
- # use omnipose mask recontruction
- b+=1
- self.omni = QCheckBox('omni mask alg')
- self.omni.setStyleSheet(self.checkstyle)
- self.omni.setFont(self.medfont)
- self.omni.setChecked(False)
- self.omni.setToolTip('use Omnipose mask recontruction algorithm (fix over-segmentation)')
- # self.omni.toggled.connect(self.compute_model)
- self.l0.addWidget(self.omni, b,0,1,2)
- # use DBSCAN clustering
- b+=1
- self.cluster = QCheckBox('cluster masks')
- self.cluster.setStyleSheet(self.checkstyle)
- self.cluster.setFont(self.medfont)
- self.cluster.setChecked(False)
- self.cluster.setToolTip('force DBSCAN clustering when omni is enabled')
- # self.cluster.toggled.connect(self.compute_model)
- self.l0.addWidget(self.cluster, b,0,1,2)
b+=1
- # recompute model
+ # recompute segmentation
self.ModelButton = QPushButton(' run segmentation')
self.ModelButton.clicked.connect(self.compute_model)
- self.l0.addWidget(self.ModelButton, b,0,1,2)
+ self.l0.addWidget(self.ModelButton, b,0,1,8)
self.ModelButton.setEnabled(False)
self.ModelButton.setStyleSheet(self.styleInactive)
self.ModelButton.setFont(self.boldfont)
b+=1
self.progress = QProgressBar(self)
self.progress.setStyleSheet('color: gray;')
- self.l0.addWidget(self.progress, b,0,1,2)
+ self.l0.addWidget(self.progress, b,0,1,8)
# post-hoc paramater tuning
@@ -518,7 +515,7 @@ def make_buttons(self):
label.setToolTip('threshold on flow match to accept a mask (set lower to get more cells)')
label.setStyleSheet(label_style)
label.setFont(self.medfont)
- self.l0.addWidget(label, b, 0,1,2)
+ self.l0.addWidget(label, b, 0,1,8)
b+=1
self.threshold = 0.4
@@ -527,7 +524,7 @@ def make_buttons(self):
self.threshslider.setMinimum(1.0)
self.threshslider.setMaximum(30.0)
self.threshslider.setValue(31 - 4)
- self.l0.addWidget(self.threshslider, b, 0,1,2)
+ self.l0.addWidget(self.threshslider, b, 0,1,8)
self.threshslider.valueChanged.connect(self.compute_cprob)
self.threshslider.setStyleSheet(guiparts.horizontal_slider_style())
self.threshslider.setEnabled(False)
@@ -538,7 +535,7 @@ def make_buttons(self):
(set lower to include more pixels)')
label.setStyleSheet(label_style)
label.setFont(self.medfont)
- self.l0.addWidget(label, b, 0,1,2)
+ self.l0.addWidget(label, b, 0,1,8)
b+=1
self.probslider = QSlider()
@@ -547,7 +544,7 @@ def make_buttons(self):
self.probslider.setMaximum(6.0)
self.probslider.setValue(0.0)
self.cellprob = 0.0
- self.l0.addWidget(self.probslider, b, 0,1,2)
+ self.l0.addWidget(self.probslider, b, 0,1,8)
self.probslider.valueChanged.connect(self.compute_cprob)
self.probslider.setStyleSheet(guiparts.horizontal_slider_style())
self.probslider.setEnabled(False)
@@ -555,19 +552,21 @@ def make_buttons(self):
b+=1
line = QHLine()
line.setStyleSheet('color: white;')
- self.l0.addWidget(line, b,0,1,2)
-
- self.autobtn = QCheckBox('auto-adjust')
- self.autobtn.setStyleSheet(self.checkstyle)
- self.autobtn.setFont(self.medfont)
- self.autobtn.setChecked(True)
- self.l0.addWidget(self.autobtn, b+2,0,1,1)
+ self.l0.addWidget(line, b,0,1,8)
+
b+=1
label = QLabel('Image saturation:')
label.setStyleSheet(self.headings)
label.setFont(self.boldfont)
- self.l0.addWidget(label, b,0,1,2)
+ self.l0.addWidget(label, b,0,1,8)
+
+ b+=1
+ self.autobtn = QCheckBox('auto-adjust')
+ self.autobtn.setStyleSheet(self.checkstyle)
+ self.autobtn.setFont(self.medfont)
+ self.autobtn.setChecked(True)
+ self.l0.addWidget(self.autobtn, b,0,1,4)
b+=1
self.slider = guiparts.RangeSlider(self)
@@ -576,28 +575,39 @@ def make_buttons(self):
self.slider.setLow(0)
self.slider.setHigh(255)
self.slider.setTickPosition(QSlider.TicksRight)
- self.l0.addWidget(self.slider, b,1,1,1)
+ self.l0.addWidget(self.slider, b,0,1,8)
+
+ b+=1
+ self.l0.addWidget(QLabel(''),b,0,1,4)
self.l0.setRowStretch(b, 1)
-
- b+=2
+
+ b+=1
+ self.quadrant_label = QLabel('image quadrants:')
+ self.quadrant_label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.quadrant_label.setStyleSheet(label_style)
+ self.quadrant_label.setFont(self.medfont)
+ self.l0.addWidget(self.quadrant_label, b, 1,1,4)
+ guiparts.make_quadrants(self, b)
+
+ b+=3
# add z position underneath
self.currentZ = 0
label = QLabel('Z:')
label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
label.setStyleSheet(label_style)
- self.l0.addWidget(label, b, 0,1,1)
+ self.l0.addWidget(label, b, 4,1,1)
self.zpos = QLineEdit()
self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
self.zpos.setText(str(self.currentZ))
self.zpos.returnPressed.connect(self.compute_scale)
self.zpos.setFixedWidth(60)
- self.l0.addWidget(self.zpos, b, 1,1,1)
+ self.l0.addWidget(self.zpos, b, 5,1,3)
# add scrollbar underneath
self.scroll = QScrollBar(QtCore.Qt.Horizontal)
self.scroll.setMaximum(10)
self.scroll.valueChanged.connect(self.move_in_Z)
- self.l0.addWidget(self.scroll, b,3,1,20)
+ self.l0.addWidget(self.scroll, b,8,1,30)
return b
def keyPressEvent(self, event):
@@ -722,7 +732,7 @@ def calibrate_size(self):
diams, _ = self.model.sz.eval(self.stack[self.currentZ].copy(), invert=self.invert.isChecked(),
channels=self.get_channels(), progress=self.progress)
diams = np.maximum(5.0, diams)
- print('estimated diameter of cells using %s model = %0.1f pixels'%
+ logger.info('estimated diameter of cells using %s model = %0.1f pixels'%
(self.current_model, diams))
self.Diameter.setText('%0.1f'%diams)
self.diameter = diams
@@ -764,13 +774,9 @@ def undo_remove_action(self):
self.undo_remove_cell()
def get_files(self):
- images = []
- images.extend(glob.glob(os.path.dirname(self.filename) + '/*.png'))
- images.extend(glob.glob(os.path.dirname(self.filename) + '/*.jpg'))
- images.extend(glob.glob(os.path.dirname(self.filename) + '/*.jpeg'))
- images.extend(glob.glob(os.path.dirname(self.filename) + '/*.tif'))
- images.extend(glob.glob(os.path.dirname(self.filename) + '/*.tiff'))
- images = natsorted(images)
+ folder = os.path.dirname(self.filename)
+ mask_filter = '_masks'
+ images = get_image_files(folder, mask_filter)
fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
f0 = os.path.split(self.filename)[-1]
idx = np.nonzero(np.array(fnames)==f0)[0][0]
@@ -781,10 +787,10 @@ def get_prev_image(self):
idx = (idx-1)%len(images)
io._load_image(self, filename=images[idx])
- def get_next_image(self):
+ def get_next_image(self, load_seg=True):
images, idx = self.get_files()
idx = (idx+1)%len(images)
- io._load_image(self, filename=images[idx])
+ io._load_image(self, filename=images[idx], load_seg=load_seg)
def dragEnterEvent(self, event):
if event.mimeData().hasUrls():
@@ -849,7 +855,6 @@ def make_viewbox(self):
self.p0.addItem(self.layer)
self.p0.addItem(self.scale)
- guiparts.make_quadrants(self)
def reset(self):
# ---- start sets of points ---- #
@@ -925,7 +930,6 @@ def clear_all(self):
self.outpix = np.zeros((self.NZ,self.Ly,self.Lx), np.uint16)
self.cellcolors = [np.array([255,255,255])]
self.ncells = 0
- print('removed all cells')
self.toggle_removals()
self.update_plot()
@@ -972,7 +976,7 @@ def remove_cell(self, idx):
del self.cellcolors[idx]
del self.zdraw[idx-1]
self.ncells -= 1
- print('removed cell %d'%(idx-1))
+ print('>>> removed cell %d'%(idx-1))
if self.ncells==0:
self.ClearButton.setEnabled(False)
if self.NZ==1:
@@ -987,7 +991,6 @@ def merge_cells(self, idx):
ar1, ac1 = np.nonzero(self.cellpix[z]==self.selected)
touching = np.logical_and((ar0[:,np.newaxis] - ar1)==1,
(ac0[:,np.newaxis] - ac1)==1).sum()
- print(touching)
ar = np.hstack((ar0, ar1))
ac = np.hstack((ac0, ac1))
if touching:
@@ -1004,7 +1007,7 @@ def merge_cells(self, idx):
color = self.cellcolors[self.prev_selected]
self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
self.remove_cell(self.selected)
- print('merged two cells')
+ print('>>> merged two cells')
self.update_plot()
io._save_sets(self)
self.undo.setEnabled(False)
@@ -1022,7 +1025,7 @@ def undo_remove_cell(self):
self.ncells+=1
self.ismanual = np.append(self.ismanual, self.removed_cell[0])
self.zdraw.append([])
- print('added back removed cell')
+ print('>>> added back removed cell')
self.update_plot()
io._save_sets(self)
self.removed_cell = []
@@ -1035,14 +1038,17 @@ def remove_stroke(self, delete_points=True):
cZ = stroke[0,0]
outpix = self.outpix[cZ][stroke[:,1],stroke[:,2]]>0
self.layers[cZ][stroke[~outpix,1],stroke[~outpix,2]] = np.array([0,0,0,0])
+ #if self.masksOn:
+ cellpix = self.cellpix[cZ][stroke[:,1], stroke[:,2]]
+ ccol = np.array(self.cellcolors.copy())
+ if self.selected > 0:
+ ccol[self.selected] = np.array([255,255,255])
+ col2mask = ccol[cellpix]
if self.masksOn:
- cellpix = self.cellpix[cZ][stroke[:,1], stroke[:,2]]
- ccol = np.array(self.cellcolors.copy())
- if self.selected > 0:
- ccol[self.selected] = np.array([255,255,255])
- col2mask = ccol[cellpix]
col2mask = np.concatenate((col2mask, self.opacity*(cellpix[:,np.newaxis]>0)), axis=-1)
- self.layers[cZ][stroke[:,1], stroke[:,2], :] = col2mask
+ else:
+ col2mask = np.concatenate((col2mask, 0*(cellpix[:,np.newaxis]>0)), axis=-1)
+ self.layers[cZ][stroke[:,1], stroke[:,2], :] = col2mask
if self.outlinesOn:
self.layers[cZ][stroke[outpix,1],stroke[outpix,2]] = np.array(self.outcolor)
if delete_points:
@@ -1134,9 +1140,7 @@ def add_set(self):
while len(self.strokes) > 0:
self.remove_stroke(delete_points=False)
if len(self.current_point_set) > 8:
- np.random.seed(42) # make colors stable
- col_rand = np.random.randint(1000)
- color = self.colormap[col_rand,:3]
+ color = self.colormap[self.ncells,:3]
median = self.add_mask(points=self.current_point_set, color=color)
if median is not None:
self.removed_cell = []
@@ -1229,8 +1233,8 @@ def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
self.cellpix[z][vr, vc] = idx
self.cellpix[z][ar, ac] = idx
self.outpix[z][vr, vc] = idx
+ self.layers[z][ar, ac, :3] = color
if self.masksOn:
- self.layers[z][ar, ac, :3] = color
self.layers[z][ar, ac, -1] = self.opacity
if self.outlinesOn:
self.layers[z][vr, vc] = np.array(self.outcolor)
@@ -1278,9 +1282,9 @@ def compute_saturation(self):
# compute percentiles from stack
self.saturation = []
for n in range(len(self.stack)):
- # changed to use omnipose convention
- self.saturation.append([np.percentile(self.stack[n].astype(np.float32),0.01),
- np.percentile(self.stack[n].astype(np.float32),99.99)])
+ # reverted for cellular images, maybe there can be an option?
+ self.saturation.append([np.percentile(self.stack[n].astype(np.float32),1),
+ np.percentile(self.stack[n].astype(np.float32),99)])
def chanchoose(self, image):
if image.ndim > 2:
@@ -1293,12 +1297,107 @@ def chanchoose(self, image):
image = image[:,:,chanid].astype(np.float32)
return image
- def initialize_model(self):
+ def get_model_path(self):
self.current_model = self.ModelChoose.currentText()
- print(self.current_model)
- self.model = models.Cellpose(gpu=self.useGPU.isChecked(),
- torch=self.torch,
- model_type=self.current_model)
+ if self.current_model in models.MODEL_NAMES:
+ self.current_model_path = models.model_path(self.current_model, 0, self.torch)
+ else:
+ self.current_model_path = os.fspath(models.MODEL_DIR.joinpath(self.current_model))
+
+ def initialize_model(self):
+ self.get_model_path()
+ if self.current_model in models.MODEL_NAMES:
+ self.model = models.Cellpose(gpu=self.useGPU.isChecked(),
+ torch=self.torch,
+ model_type=self.current_model)
+ self.SizeButton.setEnabled(True)
+ else:
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+ torch=True,
+ pretrained_model=self.current_model_path)
+ self.SizeButton.setEnabled(False)
+
+ def add_model(self):
+ io._add_model(self)
+ #a_list = ["abc", "def", "ghi"]
+ #textfile = open("a_file.txt", "w")
+ #for element in a_list:
+ # textfile.write(element + "\n")
+ #textfile.close()
+ return
+
+ def remove_model(self):
+ io._remove_model(self)
+ return
+
+ def new_model(self):
+ if self.NZ!=1:
+ print('ERROR: cannot train model on 3D data')
+ return
+
+ # do not save current masks, could be from bad model
+ #print('GUI_INFO: saving current masks to add to training')
+ #io._save_sets(self)
+
+ # train model
+ image_names = self.get_files()[0]
+ self.train_data, self.train_labels, self.train_files = io._get_train_set(image_names)
+ TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
+ train = TW.exec_()
+ if train:
+ logger.info(f'training with {[os.path.split(f)[1] for f in self.train_files]}')
+ self.get_model_path()
+ self.channels = self.get_channels()
+ logger.info(f'training with chan (cyto) = {self.ChannelChoose[0].currentText()}, chan2 (nuclei)={self.ChannelChoose[1].currentText()}')
+
+ if self.training:
+ # currently in training mode, need to remove new model path
+ print(f'GUI_INFO: removing previous model ({os.path.split(self.new_model_path)[-1]}) from gui')
+ io._remove_model(self, self.new_model_ind)
+ else:
+ self.training = True
+ self.endtrain.setEnabled(True)
+ self.SizeButton.setEnabled(False)
+ self.train_model()
+
+ else:
+ print('GUI_INFO: training cancelled')
+
+
+ def train_model(self):
+ logger.info(f'training new model starting at model {self.current_model_path}')
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+ torch=True,
+ pretrained_model=self.current_model_path)
+ self.SizeButton.setEnabled(False)
+ save_path = os.path.dirname(self.filename)
+ d = datetime.datetime.now()
+ netstr = self.current_model + d.strftime("_%Y%m%d_%H%M%S")
+ self.new_model_path = self.model.retrain(self.train_data, self.train_labels, self.train_files,
+ channels=self.channels, save_path=save_path,
+ learning_rate=self.learning_rate, n_epochs=self.n_epochs,
+ weight_decay=self.weight_decay,
+ netstr=netstr)
+
+ # run model on next image
+ io._add_model(self, self.new_model_path)
+ self.new_model_ind = len(self.model_strings)-1
+ print(f'GUI_INFO: model saved to {self.new_model_path} and loaded in gui')
+ self.autorun = True
+ if self.autorun:
+ self.get_next_image(load_seg=True)
+ if self.train_files[0] == self.filename:
+ print(f'GUI_INFO: trained on all images + masks in folder --> auto-end training')
+ self.end_train()
+ #self.get_next_image(load_seg=True)
+ return
+ self.compute_model()
+ logger.info(f'!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!')
+
+ def end_train(self):
+ self.endtrain.setEnabled(False)
+ self.training = False
+ print('done')
def compute_cprob(self):
rerun = False
@@ -1313,11 +1412,11 @@ def compute_cprob(self):
if self.threshold==3.0 or self.NZ>1:
thresh = None
- print('computing masks with cell prob=%0.3f, no flow error threshold'%
+ logger.info('computing masks with cell prob=%0.3f, no flow error threshold'%
(self.cellprob))
else:
thresh = self.threshold
- print('computing masks with cell prob=%0.3f, flow error threshold=%0.3f'%
+ logger.info('computing masks with cell prob=%0.3f, flow error threshold=%0.3f'%
(self.cellprob, thresh))
maski = dynamics.compute_masks(self.flows[4][:-1],
self.flows[4][-1],
@@ -1325,8 +1424,8 @@ def compute_cprob(self):
mask_threshold=self.cellprob,
flow_threshold=thresh,
resize=self.cellpix.shape[-2:],
- omni=self.omni.isChecked(),
- cluster=self.cluster.isChecked())[0]
+ omni=OMNI_INSTALLED and self.omni.isChecked(),
+ cluster=OMNI_INSTALLED and self.cluster.isChecked())[0]
self.masksOn = True
self.MCheckBox.setChecked(True)
@@ -1334,19 +1433,18 @@ def compute_cprob(self):
# self.OCheckBox.setChecked(True)
if maski.ndim<3:
maski = maski[np.newaxis,...]
- print('%d cells found'%(len(np.unique(maski)[1:])))
+ logger.info('%d cells found'%(len(np.unique(maski)[1:])))
io._masks_to_gui(self, maski, outlines=None)
self.show()
def compute_model(self):
self.progress.setValue(0)
- if 1:
+ try:
tic=time.time()
self.clear_all()
self.flows = [[],[],[]]
self.initialize_model()
-
- print('using model %s'%self.current_model)
+ logger.info('using model %s'%self.current_model)
self.progress.setValue(10)
do_3D = False
if self.NZ > 1:
@@ -1366,16 +1464,16 @@ def compute_model(self):
self.Diameter.setText('%0.1f'%self.diameter)
# allow omni to be togged manually or forced by model
- self.omni.setChecked(self.omni.isChecked() or omni_model)
+ if OMNI_INSTALLED:
+ self.omni.setChecked(self.omni.isChecked() or omni_model)
+ self.cluster.setChecked(self.cluster.isChecked() or omni_model)
- net_avg = self.NetAvg.currentIndex()<2
+ net_avg = self.NetAvg.currentIndex()<2 and self.current_model in models.MODEL_NAMES
resample = self.NetAvg.currentIndex()==1
- # print(data.shape,channels,self.diameter,resample,do_3D,net_avg)
- print('net_avg',net_avg)
- masks, flows, _, _ = self.model.eval(data, channels=channels,
- diameter=self.diameter, invert=self.invert.isChecked(),
- net_avg=net_avg, augment=False, resample=resample,
- do_3D=do_3D, progress=self.progress, omni=self.omni.isChecked())
+ masks, flows = self.model.eval(data, channels=channels,
+ diameter=self.diameter, invert=self.invert.isChecked(),
+ net_avg=net_avg, augment=False, resample=resample,
+ do_3D=do_3D, progress=self.progress, omni=OMNI_INSTALLED and self.omni.isChecked())[:2]
except Exception as e:
print('NET ERROR: %s'%e)
self.progress.setValue(0)
@@ -1401,7 +1499,7 @@ def compute_model(self):
self.flows.append(flows[3].squeeze()) #p
self.flows.append(np.concatenate((flows[1], flows[2][np.newaxis,...]), axis=0)) #dP, dist/prob
- print('%d cells found with cellpose net in %0.3f sec'%(len(np.unique(masks)[1:]), time.time()-tic))
+ logger.info('%d cells found with model in %0.3f sec'%(len(np.unique(masks)[1:]), time.time()-tic))
self.progress.setValue(80)
z=0
self.masksOn = True
@@ -1416,17 +1514,16 @@ def compute_model(self):
if not do_3D:
self.threshslider.setEnabled(True)
self.probslider.setEnabled(True)
- else: #except Exception as e:
+ except Exception as e:
print('ERROR: %s'%e)
def enable_buttons(self):
- #self.X2Up.setEnabled(True)
- #self.X2Down.setEnabled(True)
self.ModelButton.setEnabled(True)
self.SizeButton.setEnabled(True)
self.ModelButton.setStyleSheet(self.styleUnpressed)
self.SizeButton.setStyleSheet(self.styleUnpressed)
+ #self.newmodel.setEnabled(True)
self.loadMasks.setEnabled(True)
self.saveSet.setEnabled(True)
self.savePNG.setEnabled(True)
diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py
index 1dd0a600..b47367c7 100644
--- a/cellpose/gui/guiparts.py
+++ b/cellpose/gui/guiparts.py
@@ -1,19 +1,79 @@
from PyQt5 import QtGui, QtCore, QtWidgets
from PyQt5.QtGui import QPainter, QPixmap
-from PyQt5.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel
+from PyQt5.QtWidgets import QApplication, QRadioButton, QWidget, QDialog, QButtonGroup, QSlider, QStyle, QStyleOptionSlider, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox
import pyqtgraph as pg
from pyqtgraph import functions as fn
from pyqtgraph import Point
import numpy as np
import pathlib
-def make_quadrants(parent):
+class TrainWindow(QDialog):
+ def __init__(self, parent, model_strings):
+ super().__init__(parent)
+ self.setGeometry(100,100,300,300)
+ self.setWindowTitle('train settings')
+ self.win = QWidget(self)
+ self.l0 = QGridLayout()
+ self.win.setLayout(self.l0)
+
+ yoff = 0
+ qlabel = QLabel('train model using images + masks available in current folder')
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, yoff,0,1,2)
+
+ # choose initial model
+ yoff+=1
+ self.ModelChoose = QComboBox()
+ self.ModelChoose.addItems(model_strings)
+ self.ModelChoose.setFixedWidth(150)
+ self.ModelChoose.setCurrentIndex(0)
+ self.l0.addWidget(self.ModelChoose, yoff, 1,1,1)
+ qlabel = QLabel('initial model: ')
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, yoff,0,1,1)
+
+ # choose parameters
+ labels = ['learning_rate', 'weight_decay', 'n_epochs']
+ values = [0.025, 0.0001, 100]
+ self.edits = []
+ yoff += 1
+ for i, (label, value) in enumerate(zip(labels, values)):
+ qlabel = QLabel(label)
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+ self.l0.addWidget(qlabel, i+yoff,0,1,1)
+ self.edits.append(QLineEdit())
+ self.edits[-1].setText(str(value))
+ self.l0.addWidget(self.edits[-1], i+yoff, 1,1,1)
+
+ yoff+=len(labels)
+ self.autorun = QCheckBox('auto-run trained model on next image in folder')
+ self.autorun.setChecked(True)
+ self.l0.addWidget(self.autorun, yoff, 0, 1, 2)
+
+ # click button
+ yoff+=1
+ QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
+ self.buttonBox = QDialogButtonBox(QBtn)
+ self.buttonBox.accepted.connect(lambda: self.accept(parent))
+ self.buttonBox.rejected.connect(self.reject)
+ self.l0.addWidget(self.buttonBox, yoff, 0, 1,3)
+
+ def accept(self, parent):
+ parent.autorun = self.autorun.isChecked()
+ parent.learning_rate = float(self.edits[0].text())
+ parent.weight_decay = float(self.edits[1].text())
+ parent.n_epochs = int(self.edits[2].text())
+ parent.ModelChoose.setCurrentIndex(self.ModelChoose.currentIndex())
+ self.done(1)
+ # return
+
+def make_quadrants(parent, yp):
""" make quadrant buttons """
parent.quadbtns = QButtonGroup(parent)
for b in range(9):
btn = QuadButton(b, ' '+str(b+1), parent)
parent.quadbtns.addButton(btn, b)
- parent.l0.addWidget(btn, 0 + parent.quadbtns.button(b).ypos, 29 + parent.quadbtns.button(b).xpos, 1, 1)
+ parent.l0.addWidget(btn, yp + parent.quadbtns.button(b).ypos, 5+parent.quadbtns.button(b).xpos, 1, 1)
btn.setEnabled(True)
b += 1
parent.quadbtns.setExclusive(True)
@@ -271,7 +331,7 @@ def __init__(self, parent=None, row=0, col=0):
button.setChecked(True)
self.addButton(button, b)
button.toggled.connect(lambda: self.btnpress(parent))
- self.parent.l0.addWidget(button, row+b,col,1,1)
+ self.parent.l0.addWidget(button, row+b,col,1,3)
self.setExclusive(True)
#self.buttons.
@@ -559,7 +619,7 @@ def __init__(self, parent=None, *args):
self.hover_control = QStyle.SC_None
self.click_offset = 0
- self.setOrientation(QtCore.Qt.Vertical)
+ self.setOrientation(QtCore.Qt.Horizontal)
self.setTickPosition(QSlider.TicksRight)
self.setStyleSheet(\
"QSlider::handle:horizontal {\
diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py
index 910b9131..1f4709bb 100644
--- a/cellpose/gui/io.py
+++ b/cellpose/gui/io.py
@@ -1,11 +1,12 @@
-import os, datetime, gc, warnings, glob
+import os, datetime, gc, warnings, glob, shutil
from natsort import natsorted
import numpy as np
import cv2
import tifffile
import logging
+import fastremap
-from .. import utils, plot, transforms
+from .. import utils, plot, transforms, models
from ..io import imread, imsave, outlines_to_text
try:
@@ -24,7 +25,82 @@
# WIP to make GUI use N-color masks. Tricky thing is that only the display should be
# reduced to N colors; selection and editing should act on unique labels.
-def _load_image(parent, filename=None):
+def _init_model_list(parent):
+ models.MODEL_DIR.mkdir(parents=True, exist_ok=True)
+ parent.model_list_path = os.fspath(models.MODEL_DIR.joinpath('gui_models.txt'))
+ parent.model_strings = models.MODEL_NAMES.copy()
+ if not os.path.exists(parent.model_list_path):
+ textfile = open(parent.model_list_path, 'w')
+ textfile.close()
+ else:
+ with open(parent.model_list_path, 'r') as textfile:
+ lines = [line.rstrip() for line in textfile]
+ if len(lines) > 0:
+ parent.model_strings.extend(lines)
+
+def _add_model(parent, filename=None):
+ if filename is None:
+ name = QFileDialog.getOpenFileName(
+ parent, "Add model to GUI"
+ )
+ filename = name[0]
+ fname = os.path.split(filename)[-1]
+ shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
+ print(f'GUI_INFO: {filename} copied to models folder {os.fspath(models.MODEL_DIR)}')
+ with open(parent.model_list_path, 'a') as textfile:
+ textfile.write(fname + '\n')
+ parent.ModelChoose.addItems([fname])
+ parent.model_strings.append(fname)
+ parent.ModelChoose.setCurrentIndex(len(parent.model_strings) - 1)
+
+def _remove_model(parent, ind=None):
+ if ind is None:
+ ind = parent.ModelChoose.currentIndex()
+ if ind > len(models.MODEL_NAMES)-1:
+ print(f'GUI_INFO: deleting {parent.model_strings[ind]} from GUI')
+ parent.ModelChoose.removeItem(ind)
+ del parent.model_strings[ind]
+ custom_strings = parent.model_strings[len(models.MODEL_NAMES):]
+ if len(custom_strings) > 0:
+ with open(parent.model_list_path, 'w') as textfile:
+ for fname in custom_strings:
+ textfile.write(fname + '\n')
+ parent.ModelChoose.setCurrentIndex(len(parent.model_strings) - 1)
+ else:
+ # write empty file
+ textfile = open(parent.model_list_path, 'w')
+ textfile.close()
+ parent.ModelChoose.setCurrentIndex(0)
+ else:
+ print('ERROR: cannot remove built-in model, select custom model to delete')
+
+
+def _get_train_set(image_names):
+ """ get training data and labels for images in current folder image_names"""
+ train_data, train_labels, train_files = [], [], []
+ for image_name_full in image_names:
+ image_name = os.path.splitext(image_name_full)[0]
+ label_name = None
+ if os.path.exists(image_name + '_seg.npy'):
+ dat = np.load(image_name + '_seg.npy', allow_pickle=True).item()
+ masks = dat['masks']
+ imsave(image_name + '_masks.tif', masks)
+ label_name = image_name + '_masks.tif'
+ else:
+ mask_filter = '_masks'
+ if os.path.exists(image_name + mask_filter + '.tif'):
+ label_name = image_name + mask_filter + '.tif'
+ elif os.path.exists(image_name + mask_filter + '.tiff'):
+ label_name = image_name + mask_filter + '.tiff'
+ elif os.path.exists(image_name + mask_filter + '.png'):
+ label_name = image_name + mask_filter + '.png'
+ if label_name is not None:
+ train_files.append(image_name_full)
+ train_data.append(imread(image_name_full))
+ train_labels.append(imread(label_name))
+ return train_data, train_labels, train_files
+
+def _load_image(parent, filename=None, load_seg=True):
""" load image with filename; if None, open QFileDialog """
if filename is None:
name = QFileDialog.getOpenFileName(
@@ -32,24 +108,23 @@ def _load_image(parent, filename=None):
)
filename = name[0]
manual_file = os.path.splitext(filename)[0]+'_seg.npy'
- if os.path.isfile(manual_file):
- print(manual_file)
- _load_seg(parent, manual_file, image=imread(filename), image_file=filename)
- return
- elif os.path.isfile(os.path.splitext(filename)[0]+'_manual.npy'):
- manual_file = os.path.splitext(filename)[0]+'_manual.npy'
- _load_seg(parent, manual_file, image=imread(filename), image_file=filename)
- return
+ if load_seg:
+ if os.path.isfile(manual_file):
+ _load_seg(parent, manual_file, image=imread(filename), image_file=filename)
+ return
+ elif os.path.isfile(os.path.splitext(filename)[0]+'_manual.npy'):
+ manual_file = os.path.splitext(filename)[0]+'_manual.npy'
+ _load_seg(parent, manual_file, image=imread(filename), image_file=filename)
+ return
try:
image = imread(filename)
parent.loaded = True
except:
- print('images not compatible')
+ print('ERROR: images not compatible')
if parent.loaded:
parent.reset()
parent.filename = filename
- print(filename)
filename = os.path.split(parent.filename)[-1]
_initialize_images(parent, image, resize=parent.resize, X2=0)
parent.clear_all()
@@ -118,7 +193,6 @@ def _initialize_images(parent, image, resize, X2):
parent.stack[k] = img
parent.imask=0
- print(parent.NZ, parent.stack[0].shape)
parent.Ly, parent.Lx = img.shape[0], img.shape[1]
parent.stack = np.array(parent.stack)
parent.layers = 0*np.ones((parent.NZ,parent.Ly,parent.Lx,4), np.uint8)
@@ -142,7 +216,7 @@ def _load_seg(parent, filename=None, image=None, image_file=None):
parent.loaded = True
except:
parent.loaded = False
- print('not NPY')
+ print('ERROR: not NPY')
return
parent.reset()
@@ -175,8 +249,7 @@ def _load_seg(parent, filename=None, image=None, image_file=None):
return
else:
parent.filename = image_file
- print(parent.filename)
-
+
if 'X2' in dat:
parent.X2 = dat['X2']
else:
@@ -213,15 +286,14 @@ def _load_seg(parent, filename=None, image=None, image_file=None):
if dat['masks'].min()==-1:
dat['masks'] += 1
dat['outlines'] += 1
+ parent.ncells = dat['masks'].max()
if 'colors' in dat:
colors = dat['colors']
else:
- col_rand = np.random.randint(0, 1000, (dat['masks'].max(),))
- colors = parent.colormap[col_rand,:3]
+ colors = parent.colormap[:parent.ncells,:3]
parent.cellpix = dat['masks']
parent.outpix = dat['outlines']
parent.cellcolors.extend(colors)
- parent.ncells = parent.cellpix.max()
parent.draw_masks()
if 'est_diam' in dat:
parent.Diameter.setText('%0.1f'%dat['est_diam'])
@@ -235,7 +307,7 @@ def _load_seg(parent, filename=None, image=None, image_file=None):
else:
parent.zdraw = [None for n in range(parent.ncells)]
parent.loaded = True
- print('%d masks found'%(parent.ncells))
+ print(f'GUI_INFO: {parent.ncells} masks found in {filename}')
else:
parent.clear_all()
@@ -301,7 +373,7 @@ def _load_masks(parent, filename=None):
if masks.shape[0]!=parent.NZ:
print('ERROR: masks are not same depth (number of planes) as image stack')
return
- print('%d masks found'%(len(np.unique(masks))-1))
+ print(f'GUI_INFO: {len(np.unique(masks))-1} masks found in {filename}')
_masks_to_gui(parent, masks, outlines)
@@ -312,12 +384,9 @@ def _masks_to_gui(parent, masks, outlines=None):
# get unique values
shape = masks.shape
- if NCOLOR:
- masks = ncolor.label(masks)
- else:
- _, masks = np.unique(masks, return_inverse=True)
- masks = np.reshape(masks, shape)
- masks = masks.astype(np.uint16) if masks.max()<2**16-1 else masks.astype(np.uint32)
+ fastremap.renumber(masks, in_place=True)
+ masks = np.reshape(masks, shape)
+ masks = masks.astype(np.uint16) if masks.max()<2**16-1 else masks.astype(np.uint32)
parent.cellpix = masks
# get outlines
@@ -326,8 +395,8 @@ def _masks_to_gui(parent, masks, outlines=None):
for z in range(parent.NZ):
outlines = utils.masks_to_outlines(masks[z])
parent.outpix[z] = outlines * masks[z]
- if z%50==0:
- print('plane %d outlines processed'%z)
+ if z%50==0 and parent.NZ > 1:
+ print('GUI_INFO: plane %d outlines processed'%z)
else:
parent.outpix = outlines
shape = parent.outpix.shape
@@ -335,12 +404,7 @@ def _masks_to_gui(parent, masks, outlines=None):
parent.outpix = np.reshape(parent.outpix, shape)
parent.ncells = parent.cellpix.max()
- np.random.seed(42) #try to make a bit more stable
-
- if NCOLOR:
- colors = parent.colormap[np.linspace(0,255,parent.ncells).astype(int), :3]
- else:
- colors = parent.colormap[np.random.randint(0,1000,size=parent.ncells), :3]
+ colors = parent.colormap[:parent.ncells, :3]
parent.cellcolors = list(np.concatenate((np.array([[255,255,255]]), colors), axis=0).astype(np.uint8))
parent.draw_masks()
@@ -356,17 +420,17 @@ def _save_png(parent):
filename = parent.filename
base = os.path.splitext(filename)[0]
if parent.NZ==1:
- print('saving 2D masks to png')
+ print('GUI_INFO: saving 2D masks to png')
imsave(base + '_cp_masks.png', parent.cellpix[0])
else:
- print('saving 3D masks to tiff')
+ print('GUI_INFO: saving 3D masks to tiff')
imsave(base + '_cp_masks.tif', parent.cellpix)
def _save_outlines(parent):
filename = parent.filename
base = os.path.splitext(filename)[0]
if parent.NZ==1:
- print('saving 2D outlines to text file, see docs for info to load into ImageJ')
+ print('GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ')
outlines = utils.outlines_list(parent.cellpix[0])
outlines_to_text(base, outlines)
else:
@@ -400,8 +464,7 @@ def _save_sets(parent):
'ismanual': parent.ismanual,
'X2': parent.X2,
'filename': parent.filename,
- 'flows': parent.flows})
+ 'flows': parent.flows,
+ 'model_path': parent.current_model_path if hasattr(parent, 'current_model_path') else 0})
#print(parent.point_sets)
- print('--- %d ROIs saved chan1 %s, chan2 %s'%(parent.ncells,
- parent.ChannelChoose[0].currentText(),
- parent.ChannelChoose[1].currentText()))
+ print('GUI_INFO: %d ROIs saved to %s'%(parent.ncells, base + '_seg.npy'))
diff --git a/cellpose/gui/menus.py b/cellpose/gui/menus.py
index 2efbace5..5fc358a8 100644
--- a/cellpose/gui/menus.py
+++ b/cellpose/gui/menus.py
@@ -86,6 +86,37 @@ def editmenu(parent):
parent.remcell.setEnabled(False)
edit_menu.addAction(parent.remcell)
+ parent.mergecell = QAction('FYI: Merge cells by Alt+Click', parent)
+ parent.mergecell.setEnabled(False)
+ edit_menu.addAction(parent.mergecell)
+
+def modelmenu(parent):
+ main_menu = parent.menuBar()
+ io._init_model_list(parent)
+ model_menu = main_menu.addMenu("&Models")
+ parent.addmodel = QAction('Add custom torch model to GUI', parent)
+ #parent.addmodel.setShortcut("Ctrl+A")
+ parent.addmodel.triggered.connect(parent.add_model)
+ parent.addmodel.setEnabled(True)
+ model_menu.addAction(parent.addmodel)
+
+ parent.removemodel = QAction('Remove selected custom model from GUI', parent)
+ #parent.removemodel.setShortcut("Ctrl+R")
+ parent.removemodel.triggered.connect(parent.remove_model)
+ parent.removemodel.setEnabled(True)
+ model_menu.addAction(parent.removemodel)
+
+ parent.newmodel = QAction('&Train new model with image+masks in folder', parent)
+ parent.newmodel.setShortcut("Ctrl+T")
+ parent.newmodel.triggered.connect(parent.new_model)
+ parent.newmodel.setEnabled(False)
+ model_menu.addAction(parent.newmodel)
+
+ parent.endtrain = QAction('End training', parent)
+ parent.endtrain.triggered.connect(parent.end_train)
+ parent.endtrain.setEnabled(False)
+ model_menu.addAction(parent.endtrain)
+
def helpmenu(parent):
main_menu = parent.menuBar()
help_menu = main_menu.addMenu("&Help")
@@ -103,3 +134,16 @@ def helpmenu(parent):
openGUI.setShortcut("Ctrl+G")
openGUI.triggered.connect(parent.gui_window)
help_menu.addAction(openGUI)
+
+def omnimenu(parent):
+ main_menu = parent.menuBar()
+ omni_menu = main_menu.addMenu("&Omnipose")
+ # use omnipose mask recontruction
+ parent.omni = QAction('use Omnipose mask recontruction algorithm (fix over-segmentation)', parent, checkable=True)
+ parent.omni.setChecked(False)
+ omni_menu.addAction(parent.omni)
+
+ # use DBSCAN clustering
+ parent.cluster = QAction('force DBSCAN clustering when omni is enabled', parent, checkable=True)
+ parent.cluster.setChecked(False)
+ omni_menu.addAction(parent.cluster)
\ No newline at end of file
diff --git a/cellpose/io.py b/cellpose/io.py
index 52d0b318..aa9f4563 100644
--- a/cellpose/io.py
+++ b/cellpose/io.py
@@ -9,7 +9,7 @@
from . import utils, plot, transforms
try:
- import omnipose.utils.format_labels as format_labels
+ from omnipose.utils import format_labels
import ncolor
OMNI_INSTALLED = True
except:
@@ -155,6 +155,9 @@ def get_label_files(image_names, mask_filter, imf=None):
label_names = [label_names[n] + mask_filter + '.tiff' for n in range(nimg)]
elif os.path.exists(label_names[0] + mask_filter + '.png'):
label_names = [label_names[n] + mask_filter + '.png' for n in range(nimg)]
+ # todo, allow _seg.npy
+ #elif os.path.exists(label_names[0] + '_seg.npy'):
+ # io_logger.info('labels found as _seg.npy files, converting to tif')
else:
raise ValueError('labels not provided with correct --mask_filter')
if not all([os.path.exists(label) for label in label_names]):
diff --git a/cellpose/models.py b/cellpose/models.py
index cc57d0a7..ab8f4602 100644
--- a/cellpose/models.py
+++ b/cellpose/models.py
@@ -21,7 +21,10 @@
_MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
_MODEL_DIR_DEFAULT = pathlib.Path.home().joinpath('.cellpose', 'models')
MODEL_DIR = pathlib.Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
-MODEL_NAMES = ['cyto','nuclei','bact','cyto2','bact_omni','cyto2_omni']
+if OMNI_INSTALLED:
+ MODEL_NAMES = ['cyto','nuclei','cyto2','bact','bact_omni','cyto2_omni']
+else:
+ MODEL_NAMES = ['cyto','nuclei','cyto2']
def model_path(model_type, model_index, use_torch):
torch_str = 'torch' if use_torch else ''
@@ -771,7 +774,8 @@ def train(self, train_data, train_labels, train_files=None,
channels=None, normalize=True, pretrained_model=None,
save_path=None, save_every=100, save_each=False,
learning_rate=0.2, n_epochs=500, momentum=0.9, SGD=True,
- weight_decay=0.00001, batch_size=8, rescale=False, omni=False):
+ weight_decay=0.00001, batch_size=8, rescale=True, omni=False,
+ netstr=None):
""" train network with images train_data
@@ -832,6 +836,9 @@ def train(self, train_data, train_labels, train_files=None,
if True it assumes you will fit a size model after training or resize your images accordingly,
if False it will try to train the model to be scale-invariant (works worse)
+ netstr: str (default, None)
+ name of network, otherwise saved with name as params + training start time
+
"""
if rescale:
models_logger.info(f'Training with rescale = {rescale:.2f}')
@@ -847,9 +854,94 @@ def train(self, train_data, train_labels, train_files=None,
model_path = self._train_net(train_data, train_flows,
test_data, test_flows,
- pretrained_model, save_path, save_every, save_each,
+ save_path, save_every, save_each,
+ learning_rate, n_epochs, momentum, weight_decay, SGD,
+ batch_size, rescale, netstr)
+ self.pretrained_model = model_path
+ return model_path
+
+ def retrain(self, train_data, train_labels, train_files=None,
+ test_data=None, test_labels=None, test_files=None,
+ channels=None, normalize=True,
+ save_path=None, save_every=100, save_each=False,
+ learning_rate=0.025, n_epochs=100, momentum=0.9, SGD=True,
+ weight_decay=0.0001, batch_size=8, rescale=True, omni=False, netstr=None):
+
+ """ retrain network with images train_data
+
+ Parameters
+ ------------------
+
+ train_data: list of arrays (2D or 3D)
+ images for training
+
+ train_labels: list of arrays (2D or 3D)
+ labels for train_data, where 0=no masks; 1,2,...=mask labels
+ can include flows as additional images
+
+ train_files: list of strings
+ file names for images in train_data (to save flows for future runs)
+
+ test_data: list of arrays (2D or 3D)
+ images for testing
+
+ test_labels: list of arrays (2D or 3D)
+ labels for test_data, where 0=no masks; 1,2,...=mask labels;
+ can include flows as additional images
+
+ test_files: list of strings
+ file names for images in test_data (to save flows for future runs)
+
+ channels: list of ints (default, None)
+ channels to use for training
+
+ normalize: bool (default, True)
+ normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
+
+ save_path: string (default, None)
+ where to save trained model, if None it is not saved
+
+ save_every: int (default, 100)
+ save network every [save_every] epochs
+
+ learning_rate: float (default, 0.2)
+ learning rate for training
+
+ n_epochs: int (default, 500)
+ how many times to go through whole training set during training
+
+ weight_decay: float (default, 0.00001)
+
+ SGD: bool (default, True) use SGD as optimization instead of RAdam
+
+ batch_size: int (optional, default 8)
+ number of 224x224 patches to run simultaneously on the GPU
+ (can make smaller or bigger depending on GPU memory usage)
+
+ rescale: bool (default, True)
+ whether or not to rescale images to diam_mean during training,
+ if True it assumes you will fit a size model after training or resize your images accordingly,
+ if False it will try to train the model to be scale-invariant (works worse)
+
+ netstr: str (default, None)
+ name of network, otherwise saved with name as params + training start time
+
+ """
+ train_data, train_labels, test_data, test_labels, run_test = transforms.reshape_train_test(train_data, train_labels,
+ test_data, test_labels,
+ channels, normalize, omni)
+ # check if train_labels have flows
+ train_flows = dynamics.labels_to_flows(train_labels, files=train_files, use_gpu=self.gpu, device=self.device, omni=omni)
+ if run_test:
+ test_flows = dynamics.labels_to_flows(test_labels, files=test_files, use_gpu=self.gpu, device=self.device)
+ else:
+ test_flows = None
+
+ model_path = self._train_net(train_data, train_flows,
+ test_data, test_flows,
+ save_path, save_every, save_each,
learning_rate, n_epochs, momentum, weight_decay, SGD,
- batch_size, rescale)
+ batch_size, rescale, netstr)
self.pretrained_model = model_path
return model_path
diff --git a/cellpose/plot.py b/cellpose/plot.py
index 2ff9ba87..83353a94 100644
--- a/cellpose/plot.py
+++ b/cellpose/plot.py
@@ -43,7 +43,7 @@ def dx_to_circ(dP,transparency=False,mask=None):
"""
dP = np.array(dP)
- mag = np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2,axis=0)),omni=1), 0, 1.)
+ mag = np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2,axis=0))), 0, 1.)
angles = np.arctan2(dP[1], dP[0])+np.pi
a = 2
r = ((np.cos(angles)+1)/a)
diff --git a/cellpose/transforms.py b/cellpose/transforms.py
index 5f9c436e..20aef407 100644
--- a/cellpose/transforms.py
+++ b/cellpose/transforms.py
@@ -444,7 +444,6 @@ def reshape_train_test(train_data, train_labels, test_data, test_labels, channel
return
if not run_test:
- transforms_logger.info('NOTE: test data not provided OR labels incorrect OR not same number of channels as train data')
test_data, test_labels = None, None
return train_data, train_labels, test_data, test_labels, run_test
@@ -500,7 +499,6 @@ def reshape_and_normalize_data(train_data, test_data=None, channels=None, normal
if normalize:
data[i] = normalize_img(data[i], axis=0, omni=omni)
nchan = [data[i].shape[0] for i in range(nimg)]
- transforms_logger.info('%s channels = %d'%(['train', 'test'][test], nchan[0]))
run_test = True
return train_data, test_data, run_test
diff --git a/environment.yml b/environment.yml
index df410bbd..3f229753 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,8 +15,7 @@ dependencies:
- google-cloud-storage
- tqdm
- tifffile
- - scikit-image
- - scikit-learn
+ - fastremap
- cellpose
diff --git a/setup.py b/setup.py
index c7459484..1faba3de 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,6 @@
install_deps = ['numpy>=1.20.0', 'scipy', 'natsort',
'tifffile', 'tqdm', 'numba',
'torch>=1.6',
- 'torch_optimizer',
'opencv-python-headless',
'fastremap'
]
@@ -25,7 +24,8 @@
omni_deps = [
'scikit-image',
'scikit-learn',
- 'edt','fastremap','torch_optimizer',
+ 'edt',
+ 'torch_optimizer',
'ncolor'
]
diff --git a/tests/test_train.py b/tests/test_train.py
index 8a88eb30..d96788b8 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -17,7 +17,7 @@ def test_class_train(data_dir):
model = models.CellposeModel(pretrained_model=None, diam_mean=30)
cpmodel_path = model.train(images, labels, train_files=image_names,
test_data=test_images, test_labels=test_labels, test_files=image_names_test,
- channels=[2,1], save_path=train_dir, n_epochs=10)
+ channels=[2,1], save_path=train_dir, n_epochs=3)
print('>>>> model trained and saved to %s'%cpmodel_path)
def test_cli_train(data_dir):
@@ -28,7 +28,7 @@ def test_cli_train(data_dir):
train_dir = str(data_dir.joinpath('2D').joinpath('train'))
model_dir = str(data_dir.joinpath('2D').joinpath('train').joinpath('models'))
shutil.rmtree(model_dir, ignore_errors=True)
- cmd = 'python -m cellpose --train --train_size --n_epochs 10 --dir %s --mask_filter _cyto_masks --pretrained_model None --chan 2 --chan2 1 --diameter 40'%train_dir
+ cmd = 'python -m cellpose --train --train_size --n_epochs 3 --dir %s --mask_filter _cyto_masks --pretrained_model None --chan 2 --chan2 1 --diameter 40'%train_dir
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
except Exception as e:
@@ -53,7 +53,7 @@ def test_cli_train_pretrained(data_dir):
train_dir = str(data_dir.joinpath('2D').joinpath('train'))
model_dir = str(data_dir.joinpath('2D').joinpath('train').joinpath('models'))
shutil.rmtree(model_dir, ignore_errors=True)
- cmd = 'python -m cellpose --train --train_size --n_epochs 10 --dir %s --mask_filter _cyto_masks --pretrained_model cyto --chan 2 --chan2 1 --diameter 30'%train_dir
+ cmd = 'python -m cellpose --train --train_size --n_epochs 3 --dir %s --mask_filter _cyto_masks --pretrained_model cyto --chan 2 --chan2 1 --diameter 30'%train_dir
try:
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
except Exception as e: