-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathtrain.py
495 lines (444 loc) · 22.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gc
import time
import yaml
import json
from collections import deque
import shutil
from copy import deepcopy
import paddle
import paddle.nn.functional as F
from paddleseg.utils import (TimeAverager, calculate_eta, resume,
worker_init_fn, train_profiler, op_flops_funs,
init_ema_params, update_ema_model, logger)
from paddleseg.core.val import evaluate
from paddleseg.core.export import export, save_model_info, update_train_results
from paddleseg.utils.logger import setup_logger
def check_logits_losses(logits_list, losses):
    """
    Verify that the model produced exactly one logit per configured loss.

    Args:
        logits_list (list): The outputs returned by the model.
        losses (dict): Loss config; only ``losses['types']`` is inspected here.

    Raises:
        RuntimeError: If ``len(logits_list)`` differs from ``len(losses['types'])``.
    """
    num_logits = len(logits_list)
    num_loss_types = len(losses['types'])
    if num_logits == num_loss_types:
        return
    message = (
        'The length of logits_list should equal to the types of loss config: {} != {}.'
        .format(num_logits, num_loss_types))
    raise RuntimeError(message)
def loss_computation(logits_list, labels, edges, losses):
    """
    Compute the weighted loss terms for every model output.

    The i-th logit is paired with ``losses['types'][i]`` and scaled by
    ``losses['coef'][i]``. How the loss is called depends on its class name.

    Args:
        logits_list (list): The outputs returned by the model.
        labels: Targets passed to label-based losses.
        edges: Targets passed to a ``BCELoss`` whose ``edge_label`` is truthy.
        losses (dict): Loss config with the keys 'types' and 'coef'.

    Returns:
        list: The weighted loss terms. A ``MixedLoss`` contributes one entry
        per item it returns, so the result may be longer than ``logits_list``.

    Raises:
        RuntimeError: If the number of logits and loss types differ
            (raised by ``check_logits_losses``).
    """
    check_logits_losses(logits_list, losses)
    weighted_terms = []
    for idx, logits in enumerate(logits_list):
        criterion = losses['types'][idx]
        weight = losses['coef'][idx]
        kind = criterion.__class__.__name__
        if kind == 'BCELoss' and criterion.edge_label:
            # Edge-aware BCE is supervised by the edge map, not the labels.
            weighted_terms.append(weight * criterion(logits, edges))
        elif kind == 'MixedLoss':
            # MixedLoss yields several terms; each one gets the same weight.
            weighted_terms.extend(
                weight * term for term in criterion(logits, labels))
        elif kind == 'KLLoss':
            # KLLoss always compares the first two outputs, with the second
            # one detached, regardless of the current index.
            weighted_terms.append(
                weight * criterion(logits_list[0], logits_list[1].detach()))
        else:
            weighted_terms.append(weight * criterion(logits, labels))
    return weighted_terms
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          early_stop_intervals=None,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          use_ema=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
          precision='fp32',
          amp_level='O1',
          profiler_options=None,
          to_static_training=False,
          logger=setup_logger(__file__),
          print_mem_info=False,
          shuffle=True,
          **kwargs):
    """
    Launch training.

    Runs the iteration-based training loop: forward/backward/optimizer step,
    learning-rate scheduling, periodic logging, periodic evaluation and
    checkpointing (rank 0 only), optional EMA model tracking and optional
    early stopping.

    Args:
        model (nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        early_stop_intervals (int, optional): If not None, training stops once the number of
            evaluations whose mIoU is below the best mIoU (counted since the last improvement)
            reaches this value. Default: None.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        use_ema (bool, optional): Whether to maintain, evaluate and save an EMA copy of the model. Default: False.
        losses (dict, optional): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        test_config (dict, optional): Evaluation config.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
        amp_level (str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision,
            the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
            parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
        profiler_options (str, optional): The option of train profiler.
        to_static_training (bool, optional): Whether to use @to_static for training.
        logger (Logger, optional): Logger for logging. Default: setup_logger(__file__).
        print_mem_info (bool, optional): Whether to print memory info. Default: False.
        shuffle (bool, optional): Forwarded to the DistributedBatchSampler of the training loader. Default: True.
        **kwargs: Two optional keys are consumed: 'uniform_output_enabled' (default False), which
            enables the export / model-info / train-results calls next to each saved checkpoint,
            and 'cli_args' (default None), which is forwarded to those calls.
    """
    if use_ema:
        # The EMA model is a frozen copy that is only updated via update_ema_model.
        ema_model = deepcopy(model)
        ema_model.eval()
        for param in ema_model.parameters():
            param.stop_gradient = True

    uniform_output_enabled = kwargs.pop("uniform_output_enabled", False)
    cli_args = kwargs.pop("cli_args", None)
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    stop_count = 0
    stop_status = False
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        # A non-directory entry at save_dir is removed before creating the directory.
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir, exist_ok=True)

    # use amp
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(models=model,
                                                   optimizers=optimizer,
                                                   level='O2',
                                                   save_dtype='float32')

    if nranks > 1:
        paddle.distributed.fleet.init(is_collective=True)
        optimizer = paddle.distributed.fleet.distributed_optimizer(
            optimizer)  # The return is Fleet object
        ddp_model = paddle.distributed.fleet.distributed_model(model)

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=shuffle,
                                                      drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
        worker_init_fn=worker_init_fn,
    )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if to_static_training:
        model = paddle.jit.to_static(model)
        logger.info("Successfully applied @to_static")

    avg_loss = 0.0
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_ema_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    # Stays None when the loop below never processes a batch (e.g. the resumed
    # start_iter is already >= iters); the FLOPs step at the end checks this.
    images = None

    iter = start_iter
    while iter < iters and not stop_status:
        if iter == start_iter and use_ema:
            init_ema_params(ema_model, model)
        for data in loader:
            if iter == start_iter and "npu" in paddle.get_device():
                logger.info(
                    "sleep 1 second at the first iter for npu to fix nan loss")
                time.sleep(1)

            iter += 1
            if iter > iters:
                # NOTE(review): Paddle 2.1.2 drains the loader instead of
                # breaking — presumably a DataLoader workaround; confirm.
                version = paddle.__version__
                if version == '2.1.2':
                    continue
                else:
                    break
            reader_cost_averager.record(time.time() - batch_start)
            images = data['img']
            labels = data['label'].astype('int64')
            edges = None
            if 'edge' in data.keys():
                edges = data['edge'].astype('int64')
            if hasattr(model, 'data_format') and model.data_format == 'NHWC':
                images = images.transpose((0, 2, 3, 1))

            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2'}):
                    logits_list = ddp_model(images) if nranks > 1 else model(
                        images)
                    # A model-provided loss_computation takes precedence over
                    # the module-level one.
                    if nranks > 1 and hasattr(ddp_model._layers,
                                              'loss_computation'):
                        loss_list = ddp_model._layers.loss_computation(
                            logits_list, losses, data)
                    elif nranks == 1 and hasattr(model, 'loss_computation'):
                        loss_list = model.loss_computation(
                            logits_list, losses, data)
                    else:
                        loss_list = loss_computation(logits_list=logits_list,
                                                     labels=labels,
                                                     edges=edges,
                                                     losses=losses)
                    loss = sum(loss_list)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()  # do backward
                if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                    scaler.step(optimizer.user_defined_optimizer)
                else:
                    scaler.step(optimizer)  # update parameters
                scaler.update()  # update the loss scaling factor
            else:
                logits_list = ddp_model(images) if nranks > 1 else model(images)
                # A model-provided loss_computation takes precedence over the
                # module-level one.
                if nranks > 1 and hasattr(ddp_model._layers,
                                          'loss_computation'):
                    loss_list = ddp_model._layers.loss_computation(
                        logits_list, losses, data)
                elif nranks == 1 and hasattr(model, 'loss_computation'):
                    loss_list = model.loss_computation(logits_list, losses,
                                                       data)
                else:
                    loss_list = loss_computation(logits_list=logits_list,
                                                 labels=labels,
                                                 edges=edges,
                                                 losses=losses)
                loss = sum(loss_list)
                loss.backward()
                optimizer.step()

            lr = optimizer.get_lr()

            # update lr
            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                lr_sche = optimizer.user_defined_optimizer._learning_rate
            else:
                lr_sche = optimizer._learning_rate
            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                if isinstance(lr_sche, paddle.optimizer.lr.ReduceOnPlateau):
                    # ReduceOnPlateau is stepped with the current loss.
                    lr_sche.step(loss)
                else:
                    lr_sche.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()
            # Accumulate the running sums that are averaged at the next log step.
            avg_loss += float(loss)
            if not avg_loss_list:
                avg_loss_list = [float(l) for l in loss_list]
            else:
                for i in range(len(loss_list)):
                    avg_loss_list[i] += float(loss_list[i])
            batch_cost_averager.record(time.time() - batch_start,
                                       num_samples=batch_size)

            if (iter) % log_iters == 0:
                avg_loss /= log_iters
                avg_loss_list = [l / log_iters for l in avg_loss_list]
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                max_mem_reserved_str = ""
                max_mem_allocated_str = ""
                if paddle.device.is_compiled_with_cuda() and print_mem_info:
                    max_mem_reserved_str = f", max_mem_reserved: {paddle.device.cuda.max_memory_reserved() // (1024 ** 2)} MB"
                    max_mem_allocated_str = f", max_mem_allocated: {paddle.device.cuda.max_memory_allocated() // (1024 ** 2)} MB"
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec{}{} | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(),
                            max_mem_reserved_str, max_mem_allocated_str, eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, iter)
                    # Record each loss separately if there is more than one.
                    if len(avg_loss_list) > 1:
                        avg_loss_dict = {}
                        for i, value in enumerate(avg_loss_list):
                            avg_loss_dict['loss_' + str(i)] = value
                        for key, value in avg_loss_dict.items():
                            log_tag = 'Train/' + key
                            log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                # Reset the accumulators for the next logging window.
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if use_ema:
                update_ema_model(ema_model, model, step=iter)

            # Evaluation runs on every rank; saving below is rank 0 only.
            if (iter % save_interval == 0 or iter == iters) and (val_dataset
                                                                 is not None):
                num_workers = 1 if num_workers > 0 else 0

                if test_config is None:
                    test_config = {}

                mean_iou, acc, _, _, _ = evaluate(model,
                                                  val_dataset,
                                                  num_workers=num_workers,
                                                  precision=precision,
                                                  amp_level=amp_level,
                                                  **test_config)

                if use_ema:
                    ema_mean_iou, ema_acc, _, _, _ = evaluate(
                        ema_model,
                        val_dataset,
                        num_workers=num_workers,
                        precision=precision,
                        amp_level=amp_level,
                        **test_config)

                model.train()

            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                if uniform_output_enabled:
                    export(cli_args, model, current_save_dir)
                    gc.collect()
                if use_ema:
                    paddle.save(
                        ema_model.state_dict(),
                        os.path.join(current_save_dir, 'ema_model.pdparams'))
                    if uniform_output_enabled:
                        export(cli_args, ema_model, current_save_dir, use_ema)
                        gc.collect()

                # Drop the oldest checkpoint once more than
                # keep_checkpoint_max are kept (only when the limit is > 0).
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    states_dict = {'mIoU': mean_iou, 'Acc': acc, 'iter': iter}
                    paddle.save(
                        states_dict,
                        os.path.join(current_save_dir, 'model.pdstates'))
                    if uniform_output_enabled:
                        save_model_info(states_dict, current_save_dir)
                        update_train_results(cli_args,
                                             "iter_{}".format(iter),
                                             states_dict,
                                             done_flag=iter == iters)
                    if mean_iou > best_mean_iou:
                        stop_count = 0
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                        paddle.save(
                            states_dict,
                            os.path.join(best_model_dir, 'model.pdstates'))
                        if uniform_output_enabled:
                            export(cli_args, model, best_model_dir)
                            gc.collect()
                            save_model_info(states_dict, best_model_dir)
                            update_train_results(cli_args,
                                                 "best_model",
                                                 states_dict,
                                                 done_flag=iter == iters)
                    elif mean_iou < best_mean_iou:
                        stop_count += 1
                        # NOTE(review): the source indentation was lost; this
                        # `else` is assumed to pair with the early-stop check
                        # (it only affects which log line is printed) —
                        # confirm against the upstream file.
                        if early_stop_intervals is not None and stop_count >= early_stop_intervals:
                            stop_status = True
                            logger.info(
                                'Early stopping at iter {}. The best mean IoU is {:.4f}.'
                                .format(iter, best_mean_iou))
                        else:
                            logger.info(
                                '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                                .format(best_mean_iou, best_model_iter))
                    if use_ema:
                        ema_states_dict = {
                            'mIoU': ema_mean_iou,
                            'Acc': ema_acc,
                            'iter': iter
                        }
                        paddle.save(
                            ema_states_dict,
                            os.path.join(current_save_dir,
                                         'ema_model.pdstates'))
                        if ema_mean_iou > best_ema_mean_iou:
                            best_ema_mean_iou = ema_mean_iou
                            best_ema_model_iter = iter
                            best_ema_model_dir = os.path.join(
                                save_dir, "ema_best_model")
                            paddle.save(
                                ema_model.state_dict(),
                                os.path.join(best_ema_model_dir,
                                             'ema_model.pdparams'))
                            paddle.save(
                                ema_states_dict,
                                os.path.join(best_ema_model_dir,
                                             'ema_model.pdstates'))
                            if uniform_output_enabled:
                                export(cli_args, ema_model, best_ema_model_dir,
                                       use_ema)
                                gc.collect()
                                save_model_info(ema_states_dict,
                                                best_ema_model_dir)
                                update_train_results(cli_args,
                                                     "ema_best_model",
                                                     ema_states_dict,
                                                     done_flag=iter == iters,
                                                     ema=use_ema)
                        logger.info(
                            '[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                            .format(best_ema_mean_iou, best_ema_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
                        if use_ema:
                            log_writer.add_scalar('Evaluate/Ema_mIoU',
                                                  ema_mean_iou, iter)
                            log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
                                                  iter)
                # NOTE(review): stop_status is only ever set inside this
                # rank-0 branch, so other ranks presumably keep iterating in
                # multi-card runs — confirm. The nesting depth of the next two
                # statements was reconstructed from a de-indented source.
                if stop_status:
                    break
                # Presumably restores training mode after the export calls
                # above — confirm.
                model.train()
            batch_start = time.time()

    # Calculate flops. Skipped when no batch was processed, because the input
    # shape is taken from the last batch.
    if (images is not None and local_rank == 0
            and not (precision == 'fp16' and amp_level == 'O2')):
        _, c, h, w = images.shape
        _ = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn})

    # Sleep for a second to let dataloader release resources.
    time.sleep(1)
    if use_vdl:
        log_writer.close()