
Commit

Merge pull request #424 from MouseLand/updates
Updates
carsen-stringer authored Jan 19, 2022
2 parents f90f046 + fb5432e commit e3f7f5f
Showing 13 changed files with 573 additions and 211 deletions.
43 changes: 24 additions & 19 deletions cellpose/core.py
@@ -40,12 +40,11 @@ def parse_model_string(pretrained_model):
     else:
         model_str = os.path.split(pretrained_model)[-1]
     if len(model_str)>3 and model_str[:4]=='unet':
-        core_logger.info(f'parsing model string {model_str} to get unet options')
         nclasses = max(2, int(model_str[4]))
     elif len(model_str)>7 and model_str[:8]=='cellpose':
-        core_logger.info(f'parsing model string {model_str} to get cellpose options')
         nclasses = 3
     else:
-        return None
+        return True, True, False
     ostrs = model_str.split('_')[2::2]
     residual_on = ostrs[0]=='on'
     style_on = ostrs[1]=='on'
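Note on the convention parsed here: the model filename encodes network options as underscore-separated tokens that alternate between an option name and an on/off flag, which is why the split takes every second token starting at index 2. A minimal sketch of that convention (the example filename and the third `concatenation` option are illustrative assumptions, not taken from this diff):

    # hypothetical filename following the cellpose model-string convention
    model_str = 'cellpose_residual_on_style_on_concatenation_off'
    ostrs = model_str.split('_')[2::2]   # every second token -> ['on', 'on', 'off']
    residual_on = ostrs[0] == 'on'       # True
    style_on = ostrs[1] == 'on'          # True
    concatenation = ostrs[2] == 'on'     # False (assumed third option)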
@@ -682,7 +681,7 @@ def loss_fn(self, lbl, y):

     def train(self, train_data, train_labels, train_files=None,
               test_data=None, test_labels=None, test_files=None,
-              channels=None, normalize=True, pretrained_model=None, save_path=None, save_every=50, save_each=False,
+              channels=None, normalize=True, save_path=None, save_every=50, save_each=False,
               learning_rate=0.2, n_epochs=500, momentum=0.9, weight_decay=0.00001, batch_size=8, rescale=False):
         """ train function uses 0-1 mask label and boundary pixels for training """

Expand Down Expand Up @@ -715,8 +714,7 @@ def train(self, train_data, train_labels, train_files=None,
             del train_data[::8], train_classes[::8], train_labels[::8]

         model_path = self._train_net(train_data, train_classes,
-                                     test_data, test_classes,
-                                     pretrained_model, save_path, save_every, save_each,
+                                     test_data, test_classes, save_path, save_every, save_each,
                                      learning_rate, n_epochs, momentum, weight_decay,
                                      batch_size, rescale)

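Since `pretrained_model` is gone from `train()`, starting weights now come from the model object itself. A hedged usage sketch of the post-change API (the synthetic data and the 'cyto' starting point are assumptions for illustration, not part of this diff):

    import numpy as np
    from cellpose import models

    # tiny synthetic stand-ins for real images and integer label masks
    train_data = [np.random.rand(64, 64).astype(np.float32) for _ in range(2)]
    train_labels = [np.zeros((64, 64), dtype=np.uint16) for _ in range(2)]

    model = models.CellposeModel(gpu=False, model_type='cyto')
    model.train(train_data, train_labels, channels=[0, 0],
                save_path='./models', learning_rate=0.2,
                n_epochs=5, batch_size=8)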
@@ -842,9 +840,9 @@ def _set_criterion(self):
     # Restored defaults. Need to make sure rescale is properly turned off and omni turned on when using CLI.
     def _train_net(self, train_data, train_labels,
                    test_data=None, test_labels=None,
-                   pretrained_model=None, save_path=None, save_every=100, save_each=False,
+                   save_path=None, save_every=100, save_each=False,
                    learning_rate=0.2, n_epochs=500, momentum=0.9, weight_decay=0.00001,
-                   SGD=True, batch_size=8, rescale=True, netstr='cellpose'):
+                   SGD=True, batch_size=8, rescale=True, netstr=None):
         """ train function uses loss function self.loss_fn in models.py"""

         d = datetime.datetime.now()
@@ -870,13 +868,12 @@ def _train_net(self, train_data, train_labels,

         nchan = train_data[0].shape[0]
         core_logger.info('>>>> training network with %d channel input <<<<'%nchan)
         core_logger.info('>>>> saving every %d epochs'%save_every)
         core_logger.info('>>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f'%(self.learning_rate, self.batch_size, weight_decay))
-        core_logger.info('>>>> ntrain = %d'%nimg)
-        core_logger.info('>>>> rescale is %d'%rescale)

         if test_data is not None:
-            core_logger.info('>>>> ntest = %d'%len(test_data))
-            core_logger.info(train_data[0].shape)
+            core_logger.info(f'>>>> ntrain = {nimg}, ntest = {len(test_data)}')
+        else:
+            core_logger.info(f'>>>> ntrain = {nimg}')

         tic = time.time()

@@ -891,7 +888,6 @@ def _train_net(self, train_data, train_labels,
             LR = np.append(LR, self.learning_rate*np.ones(max(0,self.n_epochs-10)))
         else:
             LR = self.learning_rate * np.ones(self.n_epochs)
-

         lavg, nsum = 0, 0
@@ -913,7 +909,10 @@ def _train_net(self, train_data, train_labels,
         for iepoch in range(self.n_epochs):

             np.random.seed(iepoch)
-            rperm = np.random.permutation(nimg)
+            if nimg < batch_size:
+                rperm = np.random.choice(nimg, batch_size)
+            else:
+                rperm = np.random.permutation(nimg)
             if SGD:
                 self._set_learning_rate(LR[iepoch])

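The new branch keeps training usable on very small datasets: when there are fewer images than one batch, `np.random.choice` draws indices with replacement so a full batch can still be assembled, whereas a plain permutation would come up short. A standalone sketch with assumed toy sizes:

    import numpy as np

    nimg, batch_size = 3, 8                          # assumed toy sizes
    np.random.seed(0)
    if nimg < batch_size:
        rperm = np.random.choice(nimg, batch_size)   # with replacement, length 8
    else:
        rperm = np.random.permutation(nimg)
    print(len(rperm))                                # 8: enough indices for one batch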
@@ -961,11 +960,17 @@ def _train_net(self, train_data, train_labels,
             if iepoch==self.n_epochs-1 or iepoch%save_every==1:
                 # save model at the end
                 if save_each: #separate files as model progresses
-                    file_name = '{}_{}_{}_{}'.format(self.net_type, file_label,
-                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"),
-                                                     'epoch_'+str(iepoch))
+                    if netstr is None:
+                        file_name = '{}_{}_{}_{}'.format(self.net_type, file_label,
+                                                         d.strftime("%Y_%m_%d_%H_%M_%S.%f"),
+                                                         'epoch_'+str(iepoch))
+                    else:
+                        file_name = '{}_{}'.format(netstr, 'epoch_'+str(iepoch))
                 else:
-                    file_name = '{}_{}_{}'.format(self.net_type, file_label, d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+                    if netstr is None:
+                        file_name = '{}_{}_{}'.format(self.net_type, file_label, d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+                    else:
+                        file_name = netstr
                 file_name = os.path.join(file_path, file_name)
                 ksave += 1
                 core_logger.info(f'saving network parameters to {file_name}')
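The effect of the new `netstr` argument is a stable, caller-chosen filename in place of the timestamped default. An illustrative reconstruction of the two naming paths (the `file_label` value and `'my_model'` are assumed):

    import datetime

    d = datetime.datetime.now()
    net_type = 'cellpose'
    file_label = 'residual_on_style_on_concatenation_off'   # assumed
    netstr = 'my_model'                                     # caller-supplied, or None

    if netstr is None:
        file_name = '{}_{}_{}'.format(net_type, file_label,
                                      d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
    else:
        file_name = netstr                                  # stable name
    print(file_name)                                        # 'my_model'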
3 changes: 2 additions & 1 deletion cellpose/dynamics.py
@@ -601,7 +601,7 @@ def labels_to_flows(labels, files=None, use_gpu=False, device=None, omni=False,r

     if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: # flows need to be recomputed

-        dynamics_logger.info('NOTE: computing flows for labels (could be done before to save time)')
+        dynamics_logger.info('computing flows for labels')

         # compute flows; labels are fixed in masks_to_flows, so they need to be passed back
         labels, dist, heat, veci = map(list,zip(*[masks_to_flows(labels[n][0],use_gpu=use_gpu, device=device, omni=omni) for n in trange(nimg)]))
@@ -1064,6 +1064,7 @@ def compute_masks(dP, cellprob, bd=None, p=None, inds=None, niter=200, mask_thre
             inds = np.stack(np.nonzero(cp_mask)).T
             mask = omnipose.core.get_masks(p,bd,cellprob,cp_mask,inds,nclasses,cluster=cluster,
                                            diam_threshold=diam_threshold,verbose=verbose)
+            mask = mask.astype(np.uint32)
         else:
             mask = get_masks(p, iscell=cp_mask, flows=dP, use_gpu=use_gpu)

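The added cast pins the omnipose-derived mask to an unsigned 32-bit dtype; among other things, that leaves headroom when label IDs exceed what narrower integer types can hold. A toy illustration (values assumed):

    import numpy as np

    mask = np.array([0, 70000])          # assumed label IDs
    print(mask.astype(np.uint16))        # [   0 4464] -- 70000 wraps around
    print(mask.astype(np.uint32))        # [    0 70000] -- preserved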