import os
import shutil
-
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
-
from EmbedSeg.criterions import get_loss
from EmbedSeg.datasets import get_dataset
from EmbedSeg.models import get_model
from EmbedSeg.utils.utils import AverageMeter, Cluster, Cluster_3d, Logger, Visualizer, prepare_embedding_for_train_image
-
torch.backends.cudnn.benchmark = True
from matplotlib.colors import ListedColormap
import numpy as np
@@ -137,7 +134,7 @@ def train_3d(virtual_batch_multiplier, one_hot, n_sigma, args):


def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma,
-                     args):  # this is without virtual batches!
+                     zslice, args):  # this is without virtual batches!

    # define meters
    loss_meter = AverageMeter()
@@ -163,50 +160,22 @@ def train_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, gr
        loss_meter.update(loss.item())
        if display and i % display_it == 0:
            with torch.no_grad():
-                visualizer.display(im[0], key='image', title='Image')
+                visualizer.display(im[0, 0, zslice], key='image', title='Image')
                predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
                if one_hot:
                    instance = invert_one_hot(instances[0].cpu().detach().numpy())
                    visualizer.display(instance, key='groundtruth', title='Ground Truth')  # TODO
                    instance_ids = np.arange(instances.size(1))  # instances[0] --> DYX
                else:
-                    visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth')  # TODO
+                    visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth')  # TODO
                    instance_ids = instances[0].unique()
                    instance_ids = instance_ids[instance_ids != 0]

-                if display_embedding:
-                    center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
-                    sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
-                        prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
-                                                          pixel_x=pixel_x, pixel_y=pixel_y,
-                                                          predictions=predictions, instance_ids=instance_ids,
-                                                          center_images=center_images,
-                                                          output=output, instances=instances, n_sigma=n_sigma)
-                    if one_hot:
-                        visualizer.display(torch.max(instances[0], dim=0)[0], key='center', title='Center',
-                                           center_x=center_x,
-                                           center_y=center_y,
-                                           samples_x=samples_x, samples_y=samples_y,
-                                           sample_spatial_embedding_x=sample_spatial_embedding_x,
-                                           sample_spatial_embedding_y=sample_spatial_embedding_y,
-                                           sigma_x=sigma_x, sigma_y=sigma_y,
-                                           color_sample=color_sample_dic, color_embedding=color_embedding_dic)
-                    else:
-                        visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
-                                           center_y=center_y,
-                                           samples_x=samples_x, samples_y=samples_y,
-                                           sample_spatial_embedding_x=sample_spatial_embedding_x,
-                                           sample_spatial_embedding_y=sample_spatial_embedding_y,
-                                           sigma_x=sigma_x, sigma_y=sigma_y,
-                                           color_sample=color_sample_dic, color_embedding=color_embedding_dic)
-                visualizer.display(predictions.cpu(), key='prediction', title='Prediction')  # TODO
+                visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction')  # TODO

    return loss_meter.avg


-
-
-
def val(virtual_batch_multiplier, one_hot, n_sigma, args):
    # define meters
    loss_meter, iou_meter = AverageMeter(), AverageMeter()
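Note on the z-slice display change in this hunk: the new indexing implies particular tensor layouts that the diff itself never states. Below is a minimal sketch of the assumed shapes in plain PyTorch; (B, C, Z, Y, X) for the image, (B, Z, Y, X) for non-one-hot instance labels and (Z, Y, X) for the clustered predictions are inferences from the indexing above, not confirmed by the source.

# Sketch only: assumed tensor layouts behind im[0, 0, zslice], instances[0, zslice]
# and predictions.cpu()[zslice, ...]. Shapes are illustrative, not taken from the repo.
import torch

B, C, Z, Y, X = 1, 1, 48, 128, 128
zslice = Z // 2                                  # analogous to configs['display_zslice']

im = torch.rand(B, C, Z, Y, X)                   # raw image batch
instances = torch.randint(0, 5, (B, Z, Y, X))    # integer instance labels (not one-hot)
predictions = torch.randint(0, 5, (Z, Y, X))     # clustered prediction for one sample

print(im[0, 0, zslice].shape)                    # torch.Size([128, 128]) -> 2D plane for display
print(instances[0, zslice].shape)                # torch.Size([128, 128])
print(predictions[zslice, ...].shape)            # torch.Size([128, 128])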
@@ -248,7 +217,7 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,
                if one_hot:
                    instance = invert_one_hot(instances[0].cpu().detach().numpy())
                    visualizer.display(instance, key='groundtruth', title='Ground Truth')  # TODO
-                    instance_ids = np.arange(instances[0].size(1))
+                    instance_ids = np.arange(instances.size(1))
                else:
                    visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth')  # TODO
                    instance_ids = instances[0].unique()
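On the instance_ids fix in this hunk: for a one-hot encoding the number of objects is the channel count, which sits in dimension 1 of the batched tensor, not of the single sample. A small sketch, under the assumption (not stated in the diff) that one-hot instances are batched as (B, N, Y, X):

# Sketch only: why np.arange(instances.size(1)) counts objects while
# np.arange(instances[0].size(1)) would count image rows. Shape (B, N, Y, X) is assumed.
import numpy as np
import torch

B, N, Y, X = 1, 7, 128, 128                      # 7 hypothetical objects
instances = torch.zeros(B, N, Y, X)

print(np.arange(instances.size(1)))              # [0 1 2 3 4 5 6] -> one id per object channel
print(np.arange(instances[0].size(1)).shape)     # (128,) -> dimension 1 of a single sample is Y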
@@ -284,8 +253,6 @@ def val_vanilla(display, display_embedding, display_it, one_hot, grid_x, grid_y,

    return loss_meter.avg, iou_meter.avg

-
-
def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
    # define meters
    loss_meter, iou_meter = AverageMeter(), AverageMeter()
@@ -306,7 +273,7 @@ def val_3d(virtual_batch_multiplier, one_hot, n_sigma, args):
    return loss_meter.avg * virtual_batch_multiplier, iou_meter.avg


-def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, args):
+def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid_y, grid_z, pixel_x, pixel_y, pixel_z, n_sigma, zslice, args):
    # define meters
    loss_meter, iou_meter = AverageMeter(), AverageMeter()
    # put model into eval mode
@@ -322,44 +289,19 @@ def val_vanilla_3d(display, display_embedding, display_it, one_hot, grid_x, grid
            loss = loss.mean()
            if display and i % display_it == 0:
                with torch.no_grad():
-                    visualizer.display(im[0], key='image', title='Image')
+                    visualizer.display(im[0, 0, zslice], key='image', title='Image')
                    predictions = cluster.cluster_with_gt(output[0], instances[0], n_sigma=n_sigma)
                    if one_hot:
                        instance = invert_one_hot(instances[0].cpu().detach().numpy())
                        visualizer.display(instance, key='groundtruth', title='Ground Truth')  # TODO
-                        instance_ids = np.arange(instances[0].size(1))
+                        instance_ids = np.arange(instances.size(1))
                    else:
-                        visualizer.display(instances[0].cpu(), key='groundtruth', title='Ground Truth')  # TODO
+                        visualizer.display(instances[0, zslice].cpu(), key='groundtruth', title='Ground Truth')  # TODO
                        instance_ids = instances[0].unique()
                        instance_ids = instance_ids[instance_ids != 0]
-                    if (display_embedding):
-                        center_x, center_y, samples_x, samples_y, sample_spatial_embedding_x, \
-                        sample_spatial_embedding_y, sigma_x, sigma_y, color_sample_dic, color_embedding_dic = \
-                            prepare_embedding_for_train_image(one_hot=one_hot, grid_x=grid_x, grid_y=grid_y,
-                                                              pixel_x=pixel_x, pixel_y=pixel_y,
-                                                              predictions=predictions, instance_ids=instance_ids,
-                                                              center_images=center_images,
-                                                              output=output, instances=instances, n_sigma=n_sigma)
-                        if one_hot:
-                            visualizer.display(torch.max(instances[0], dim=0)[0].cpu(), key='center', title='Center',
-                                               # torch.max returns a tuple
-                                               center_x=center_x,
-                                               center_y=center_y,
-                                               samples_x=samples_x, samples_y=samples_y,
-                                               sample_spatial_embedding_x=sample_spatial_embedding_x,
-                                               sample_spatial_embedding_y=sample_spatial_embedding_y,
-                                               sigma_x=sigma_x, sigma_y=sigma_y,
-                                               color_sample=color_sample_dic, color_embedding=color_embedding_dic)
-                        else:
-                            visualizer.display(instances[0] > 0, key='center', title='Center', center_x=center_x,
-                                               center_y=center_y,
-                                               samples_x=samples_x, samples_y=samples_y,
-                                               sample_spatial_embedding_x=sample_spatial_embedding_x,
-                                               sample_spatial_embedding_y=sample_spatial_embedding_y,
-                                               sigma_x=sigma_x, sigma_y=sigma_y,
-                                               color_sample=color_sample_dic, color_embedding=color_embedding_dic)

-                    visualizer.display(predictions.cpu(), key='prediction', title='Prediction')  # TODO
+
+                    visualizer.display(predictions.cpu()[zslice, ...], key='prediction', title='Prediction')  # TODO

            loss_meter.update(loss.item())

@@ -375,13 +317,14 @@ def invert_one_hot(image):
    return instance


-def save_checkpoint(state, is_best, epoch, save_dir, name='checkpoint.pth'):
+def save_checkpoint(state, is_best, epoch, save_dir, save_checkpoint_frequency, name='checkpoint.pth'):
    print('=> saving checkpoint')
    file_name = os.path.join(save_dir, name)
    torch.save(state, file_name)
-    if (epoch % 10 == 0):
-        file_name2 = os.path.join(save_dir, str(epoch) + "_" + name)
-        torch.save(state, file_name2)
+    if (save_checkpoint_frequency is not None):
+        if (epoch % int(save_checkpoint_frequency) == 0):
+            file_name2 = os.path.join(save_dir, str(epoch) + "_" + name)
+            torch.save(state, file_name2)
    if is_best:
        shutil.copyfile(file_name, os.path.join(
            save_dir, 'best_iou_model.pth'))
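Behaviour of the updated save_checkpoint in this hunk, restated as a self-contained sketch: the rolling checkpoint.pth is written on every call, and an additional "<epoch>_checkpoint.pth" copy is written only when save_checkpoint_frequency is not None and the epoch is a multiple of it (previously hard-coded to every 10 epochs). The helper name below is hypothetical and only mirrors the path-selection logic:

# Hypothetical helper mirroring the periodic-save logic added above; it returns the
# paths that the real function would hand to torch.save / shutil.copyfile.
import os

def checkpoint_paths(epoch, save_dir, save_checkpoint_frequency, name='checkpoint.pth'):
    paths = [os.path.join(save_dir, name)]       # rolling checkpoint, overwritten every epoch
    if save_checkpoint_frequency is not None and epoch % int(save_checkpoint_frequency) == 0:
        paths.append(os.path.join(save_dir, str(epoch) + "_" + name))  # extra numbered copy
    return paths

print(checkpoint_paths(50, 'exp', 50))           # ['exp/checkpoint.pth', 'exp/50_checkpoint.pth']
print(checkpoint_paths(51, 'exp', 50))           # ['exp/checkpoint.pth']
print(checkpoint_paths(50, 'exp', None))         # ['exp/checkpoint.pth']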
@@ -409,7 +352,6 @@ def begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict,

    # train dataloader

-
    train_dataset = get_dataset(train_dataset_dict['name'], train_dataset_dict['kwargs'])
    train_dataset_it = torch.utils.data.DataLoader(train_dataset, batch_size=train_dataset_dict['batch_size'],
                                                   shuffle=True, drop_last=True,
@@ -459,8 +401,6 @@ def lambda_(epoch):
                         configs['pixel_x'], configs['one_hot'])

    # Visualizer
-
-
    visualizer = Visualizer(('image', 'groundtruth', 'prediction', 'center'), color_map)  # 5 keys

    # Logger
@@ -519,10 +459,10 @@ def lambda_(epoch):
            train_loss = train_vanilla_3d(display=configs['display'],
                                          display_embedding=configs['display_embedding'],
                                          display_it=configs['display_it'], one_hot=configs['one_hot'],
-                                          n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
+                                          n_sigma=loss_dict['lossOpts']['n_sigma'], zslice=configs['display_zslice'], grid_x=configs['grid_x'],
                                          grid_y=configs['grid_y'], grid_z=configs['grid_z'],
                                          pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'],
-                                          pixel_z=configs['pixel_z'], args=loss_dict['lossW'])
+                                          pixel_z=configs['pixel_z'], args=loss_dict['lossW'], )

        if (val_dataset_dict['virtual_batch_multiplier'] > 1):
            val_loss, val_iou = val_3d(virtual_batch_multiplier=val_dataset_dict['virtual_batch_multiplier'],
@@ -532,7 +472,7 @@ def lambda_(epoch):
            val_loss, val_iou = val_vanilla_3d(display=configs['display'],
                                               display_embedding=configs['display_embedding'],
                                               display_it=configs['display_it'], one_hot=configs['one_hot'],
-                                               n_sigma=loss_dict['lossOpts']['n_sigma'], grid_x=configs['grid_x'],
+                                               n_sigma=loss_dict['lossOpts']['n_sigma'], zslice=configs['display_zslice'], grid_x=configs['grid_x'],
                                               grid_y=configs['grid_y'], grid_z=configs['grid_z'],
                                               pixel_x=configs['pixel_x'], pixel_y=configs['pixel_y'], pixel_z=configs['pixel_z'],
                                               args=loss_dict['lossW'])
@@ -558,6 +498,6 @@ def lambda_(epoch):
            'optim_state_dict': optimizer.state_dict(),
            'logger_data': logger.data,
        }
-        save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'])
+        save_checkpoint(state, is_best, epoch, save_dir=configs['save_dir'], save_checkpoint_frequency=configs['save_checkpoint_frequency'])

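The new configs lookups in this commit ('display_zslice' for the two *_vanilla_3d calls, 'save_checkpoint_frequency' here) assume the configuration dictionary handed to begin_training now carries both keys. A hypothetical excerpt follows; only the two key names come from the diff, and the values and remaining keys are chosen purely for illustration:

# Hypothetical excerpt of the configs dictionary consumed by begin_training.
configs = {
    'save_dir': './experiment/demo-3d',
    'display': True,
    'display_it': 5,
    'display_zslice': 24,               # z-plane shown by the Visualizer during train/val
    'save_checkpoint_frequency': 50,    # extra "<epoch>_checkpoint.pth" every 50 epochs; None disables it
    # ... grid_*, pixel_*, one_hot, etc. as before ...
}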