Skip to content

Commit 1f8fe3e

Browse files
authored
Merge pull request #7 from juglab/v0.2.3
v0.2.3
2 parents b7b965e + ff56e4f commit 1f8fe3e

File tree

9 files changed

+77
-1969
lines changed

9 files changed

+77
-1969
lines changed

EmbedSeg/test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def begin_evaluating(test_configs, verbose=True, mask_region = None, mask_intens
6161
if(test_configs['name']=='2d'):
6262
test(verbose = verbose, grid_x = test_configs['grid_x'], grid_y = test_configs['grid_y'],
6363
pixel_x = test_configs['pixel_x'], pixel_y = test_configs['pixel_y'],
64-
one_hot = test_configs['dataset']['kwargs']['one_hot'], avg_bg = avg_bg)
64+
one_hot = test_configs['dataset']['kwargs']['one_hot'], avg_bg = avg_bg, n_sigma=n_sigma)
6565
elif(test_configs['name']=='3d'):
6666
test_3d(verbose=verbose,
6767
grid_x=test_configs['grid_x'], grid_y=test_configs['grid_y'], grid_z=test_configs['grid_z'],
@@ -70,7 +70,7 @@ def begin_evaluating(test_configs, verbose=True, mask_region = None, mask_intens
7070

7171

7272

73-
def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = False, avg_bg = 0):
73+
def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = False, avg_bg = 0, n_sigma = 2):
7474
"""
7575
:param verbose: if True, then average precision is printed out for each image
7676
:param grid_y:
@@ -126,7 +126,7 @@ def test(verbose, grid_y=1024, grid_x=1024, pixel_y=1, pixel_x=1, one_hot = Fals
126126

127127
center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, sample_spatial_embedding_y, sigma_x, sigma_y, \
128128
color_sample_dic, color_embedding_dic = prepare_embedding_for_test_image(instance_map = instance_map, output = output, grid_x = grid_x, grid_y = grid_y,
129-
pixel_x = pixel_x, pixel_y =pixel_y, predictions =predictions)
129+
pixel_x = pixel_x, pixel_y =pixel_y, predictions =predictions, n_sigma = n_sigma)
130130

131131
base, _ = os.path.splitext(os.path.basename(sample['im_name'][0]))
132132
imageFileNames.append(base)

EmbedSeg/train.py

Lines changed: 20 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import os
22
import shutil
3-
43
import torch
54
from matplotlib import pyplot as plt
65
from tqdm import tqdm
7-
86
from EmbedSeg.criterions import get_loss
97
from EmbedSeg.datasets import get_dataset
108
from EmbedSeg.models import get_model
119
from EmbedSeg.utils.utils import AverageMeter, Cluster, Cluster_3d, Logger, Visualizer, prepare_embedding_for_train_image
12-
1310
torch.backends.cudnn.benchmark = True
1411
from matplotlib.colors import ListedColormap
1512
import numpy as np
@@ -137,7 +134,7 @@ def train_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
137134

138135

139136
def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma,
140-
args): # this is without virtual batches!
137+
zslice, args): # this is without virtual batches!
141138

142139
# define meters
143140
loss_meter = AverageMeter()
@@ -163,50 +160,22 @@ def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, gr
163160
loss_meter.update(loss.item())
164161
if display and i % display_it == 0:
165162
with torch.no_grad():
166-
visualizer.display(im[0], key='image', title='Image')
163+
visualizer.display(im[0, 0, zslice], key='image', title='Image')
167164
predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
168165
if one_hot:
169166
instance = invert_one_hot(instances[0].cpu().detach().numpy())
170167
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
171168
instance_ids = np.arange(instances.size(1)) # instances[0] --> DYX
172169
else:
173-
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
170+
visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth') # TODO
174171
instance_ids = instances[0].unique()
175172
instance_ids = instance_ids[instance_ids != 0]
176173

177-
if display_embedding:
178-
center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
179-
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
180-
prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
181-
pixel_x=pixel_x, pixel_y=pixel_y,
182-
predictions=predictions, instance_ids=instance_ids,
183-
center_images=center_images,
184-
output=output, instances=instances, n_sigma=n_sigma)
185-
if one_hot:
186-
visualizer.display(torch.max(instances[0], dim=0)[0], key='center', title='Center',
187-
center_x=center_x,
188-
center_y=center_y,
189-
samples_x=samples_x, samples_y=samples_y,
190-
sample_spatial_embedding_x=sample_spatial_embedding_x,
191-
sample_spatial_embedding_y=sample_spatial_embedding_y,
192-
sigma_x=sigma_x, sigma_y=sigma_y,
193-
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
194-
else:
195-
visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
196-
center_y=center_y,
197-
samples_x=samples_x, samples_y=samples_y,
198-
sample_spatial_embedding_x=sample_spatial_embedding_x,
199-
sample_spatial_embedding_y=sample_spatial_embedding_y,
200-
sigma_x=sigma_x, sigma_y=sigma_y,
201-
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
202-
visualizer.display(predictions.cpu(), key='prediction', title='Prediction') # TODO
174+
visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction') # TODO
203175

204176
return loss_meter.avg
205177

206178

207-
208-
209-
210179
def val(virtual_batch_multiplier, one_hot, n_sigma, args):
211180
# define meters
212181
loss_meter, iou_meter = AverageMeter(), AverageMeter()
@@ -248,7 +217,7 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,
248217
if one_hot:
249218
instance = invert_one_hot(instances[0].cpu().detach().numpy())
250219
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
251-
instance_ids = np.arange(instances[0].size(1))
220+
instance_ids = np.arange(instances.size(1))
252221
else:
253222
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
254223
instance_ids = instances[0].unique()
@@ -284,8 +253,6 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,
284253

285254
return loss_meter.avg, iou_meter.avg
286255

287-
288-
289256
def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
290257
# define meters
291258
loss_meter, iou_meter = AverageMeter(), AverageMeter()
@@ -306,7 +273,7 @@ def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
306273
return loss_meter.avg * virtual_batch_multiplier, iou_meter.avg
307274

308275

309-
def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, args):
276+
def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, zslice, args):
310277
# define meters
311278
loss_meter, iou_meter = AverageMeter(), AverageMeter()
312279
# put model into eval mode
@@ -322,44 +289,19 @@ def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid
322289
loss = loss.mean()
323290
if display and i % display_it == 0:
324291
with torch.no_grad():
325-
visualizer.display(im[0], key='image', title='Image')
292+
visualizer.display(im[0, 0, zslice], key='image', title='Image')
326293
predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
327294
if one_hot:
328295
instance = invert_one_hot(instances[0].cpu().detach().numpy())
329296
visualizer.display(instance, key='groundtruth', title='Ground Truth') # TODO
330-
instance_ids = np.arange(instances[0].size(1))
297+
instance_ids = np.arange(instances.size(1))
331298
else:
332-
visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth') # TODO
299+
visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth') # TODO
333300
instance_ids = instances[0].unique()
334301
instance_ids = instance_ids[instance_ids != 0]
335-
if (display_embedding):
336-
center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
337-
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
338-
prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
339-
pixel_x=pixel_x, pixel_y=pixel_y,
340-
predictions=predictions, instance_ids=instance_ids,
341-
center_images=center_images,
342-
output=output, instances=instances, n_sigma=n_sigma)
343-
if one_hot:
344-
visualizer.display(torch.max(instances[0], dim=0)[0].cpu(), key='center', title='Center',
345-
# torch.max returns a tuple
346-
center_x=center_x,
347-
center_y=center_y,
348-
samples_x=samples_x, samples_y=samples_y,
349-
sample_spatial_embedding_x=sample_spatial_embedding_x,
350-
sample_spatial_embedding_y=sample_spatial_embedding_y,
351-
sigma_x=sigma_x, sigma_y=sigma_y,
352-
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
353-
else:
354-
visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
355-
center_y=center_y,
356-
samples_x=samples_x, samples_y=samples_y,
357-
sample_spatial_embedding_x=sample_spatial_embedding_x,
358-
sample_spatial_embedding_y=sample_spatial_embedding_y,
359-
sigma_x=sigma_x, sigma_y=sigma_y,
360-
color_sample=color_sample_dic, color_embedding=color_embedding_dic)
361302

362-
visualizer.display(predictions.cpu(), key='prediction', title='Prediction') # TODO
303+
304+
visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction') # TODO
363305

364306
loss_meter.update(loss.item())
365307

@@ -375,13 +317,14 @@ def invert_one_hot(image):
375317
return instance
376318

377319

378-
def save_checkpoint(state, is_best, epoch, save_dir, name='checkpoint.pth'):
320+
def save_checkpoint(state, is_best, epoch, save_dir, save_checkpoint_frequency, name='checkpoint.pth'):
379321
print('=> saving checkpoint')
380322
file_name = os.path.join(save_dir, name)
381323
torch.save(state, file_name)
382-
if (epoch % 10 == 0):
383-
file_name2 = os.path.join(save_dir, str(epoch) + "_" + name)
384-
torch.save(state, file_name2)
324+
if(save_checkpoint_frequency is not None):
325+
if (epoch % int(save_checkpoint_frequency) == 0):
326+
file_name2 = os.path.join(save_dir, str(epoch) + "_" + name)
327+
torch.save(state, file_name2)
385328
if is_best:
386329
shutil.copyfile(file_name, os.path.join(
387330
save_dir, 'best_iou_model.pth'))
@@ -409,7 +352,6 @@ def begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict,
409352

410353
# train dataloader
411354

412-
413355
train_dataset = get_dataset(train_dataset_dict['name'], train_dataset_dict['kwargs'])
414356
train_dataset_it = torch.utils.data.DataLoader(train_dataset, batch_size=train_dataset_dict['batch_size'],
415357
shuffle=True, drop_last=True,
@@ -459,8 +401,6 @@ def lambda_(epoch):
459401
configs['pixel_x'], configs['one_hot'])
460402

461403
# Visualizer
462-
463-
464404
visualizer = Visualizer(('image', 'groundtruth', 'prediction', 'center'), color_map) # 5 keys
465405

466406
# Logger
@@ -519,10 +459,10 @@ def lambda_(epoch):
519459
train_loss = train_vanilla_3d(display=configs['display'],
520460
display_embedding=configs['display_embedding'],
521461
display_it=configs['display_it'], one_hot=configs['one_hot'],
522-
n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
462+
n_sigma=loss_dict['lossOpts']['n_sigma'], zslice = configs['display_zslice'], grid_x=configs['grid_x'],
523463
grid_y=configs['grid_y'], grid_z=configs['grid_z'],
524464
pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'],
525-
pixel_z=configs['pixel_z'], args=loss_dict['lossW'])
465+
pixel_z=configs['pixel_z'], args=loss_dict['lossW'], )
526466

527467
if (val_dataset_dict['virtual_batch_multiplier'] > 1):
528468
val_loss, val_iou = val_3d(virtual_batch_multiplier=val_dataset_dict['virtual_batch_multiplier'],
@@ -532,7 +472,7 @@ def lambda_(epoch):
532472
val_loss, val_iou = val_vanilla_3d(display=configs['display'],
533473
display_embedding=configs['display_embedding'],
534474
display_it=configs['display_it'], one_hot=configs['one_hot'],
535-
n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
475+
n_sigma=loss_dict['lossOpts']['n_sigma'], zslice = configs['display_zslice'], grid_x=configs['grid_x'],
536476
grid_y=configs['grid_y'], grid_z=configs['grid_z'],
537477
pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'], pixel_z=configs['pixel_z'],
538478
args=loss_dict['lossW'])
@@ -558,6 +498,6 @@ def lambda_(epoch):
558498
'optim_state_dict': optimizer.state_dict(),
559499
'logger_data': logger.data,
560500
}
561-
save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'])
501+
save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'], save_checkpoint_frequency=configs['save_checkpoint_frequency'])
562502

563503

EmbedSeg/utils/create_dicts.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ def create_configs(save_dir,
302302
anisotropy_factor = None,
303303
l_y = 1,
304304
l_x = 1,
305-
305+
save_checkpoint_frequency = None,
306+
display_zslice = None
306307
):
307308
"""
308309
Creates `configs` dictionary from parameters.
@@ -337,6 +338,9 @@ def create_configs(save_dir,
337338
Pixel size in y
338339
pixel_x: int
339340
Pixel size in x
341+
save_checkpoint_frequency: int
342+
Save model weights after 'n' epochs (in addition to last and best model weights)
343+
Default is None
340344
"""
341345
if (n_z is None):
342346
l_z = None
@@ -358,7 +362,10 @@ def create_configs(save_dir,
358362
pixel_z = l_z,
359363
pixel_y = l_y,
360364
pixel_x = l_x,
361-
one_hot=one_hot)
365+
one_hot=one_hot,
366+
save_checkpoint_frequency=save_checkpoint_frequency,
367+
display_zslice = display_zslice
368+
)
362369
print(
363370
"`configs` dictionary successfully created with: "
364371
"\n -- n_epochs equal to {}, "

EmbedSeg/utils/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,9 @@ def prepare_embedding_for_train_image(one_hot, grid_x, grid_y, pixel_x, pixel_y,
419419
sample_spatial_embedding_y[id.item()] = add_samples(samples_spatial_embeddings, 1, grid_y - 1,
420420
pixel_y)
421421

422-
centre_mask = in_mask & center_images[0]
423-
if (centre_mask.sum().eq(1)):
424-
center = xym_s[centre_mask.expand_as(xym_s)].view(2, 1, 1)
422+
center_mask = in_mask & center_images[0].byte()
423+
if (center_mask.sum().eq(1)):
424+
center = xym_s[center_mask.expand_as(xym_s)].view(2, 1, 1)
425425
else:
426426
xy_in = xym_s[in_mask.expand_as(xym_s)].view(2, -1)
427427
center = xy_in.mean(1).view(2, 1, 1) # 2 x 1 x 1
@@ -444,7 +444,7 @@ def prepare_embedding_for_train_image(one_hot, grid_x, grid_y, pixel_x, pixel_y,
444444
sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic
445445

446446

447-
def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel_x, pixel_y, predictions):
447+
def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel_x, pixel_y, predictions, n_sigma):
448448
instance_ids = instance_map.unique()
449449
instance_ids = instance_ids[instance_ids != 0]
450450

@@ -454,7 +454,7 @@ def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel
454454
height, width = instance_map.size(0), instance_map.size(1)
455455
xym_s = xym[:, 0:height, 0:width].contiguous()
456456
spatial_emb = torch.tanh(output[0, 0:2]).cpu() + xym_s
457-
sigma = output[0, 2:2 + 2] # 2/3 Y X replace last + 2 with n_sigma parameter IMP TODO
457+
sigma = output[0, 2:2 + n_sigma]
458458
color_sample = sns.color_palette("dark")
459459
color_embedding = sns.color_palette("bright")
460460
color_sample_dic = {}
@@ -495,7 +495,7 @@ def prepare_embedding_for_test_image(instance_map, output, grid_x, grid_y, pixel
495495
center_y[id.item()] = degrid(center[1], grid_y - 1, pixel_y)
496496

497497
# sigma
498-
s = sigma[in_mask.expand_as(sigma)].view(2, -1).mean(1) # TODO view(2, -1) should become nsigma, -1
498+
s = sigma[in_mask.expand_as(sigma)].view(n_sigma, -1).mean(1)
499499
s = torch.exp(s * 10)
500500
sigma_x_tmp = 0.5 / s[0]
501501
sigma_y_tmp = 0.5 / s[1]

0 commit comments

Comments (0)