-
-
Notifications
You must be signed in to change notification settings - Fork 619
/
main.py
427 lines (329 loc) · 15.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
import os
from functools import partial
from pathlib import Path
import fire
import torch
try:
from torch.cuda.amp import autocast, GradScaler
except ImportError:
raise RuntimeError("Please, use recent PyTorch version, e.g. >=1.6.0")
import dataflow as data
import utils
import vis
from py_config_runner import ConfigObject, get_params, InferenceConfigSchema, TrainvalConfigSchema
import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, Timer
from ignite.metrics import Accuracy, Frequency, TopKCategoricalAccuracy
from ignite.utils import manual_seed, setup_logger
def training(local_rank, config, logger, with_clearml):
    """Train the model with periodic validation, checkpointing and logging.

    Args:
        local_rank: index of the process; added to ``config.seed`` so every
            process gets its own seed.
        config: configuration object. This function reads ``seed``,
            ``train_loader``, ``val_loader``, ``train_eval_loader``,
            ``num_epochs`` and ``output_path``, plus the optional entries
            ``val_metrics``, ``val_interval``, ``start_by_validation``,
            ``es_patience``, ``mean`` and ``std``.
        logger: logger used for the metric reports and passed to the trainer.
        with_clearml: forwarded to the trainer/evaluator factories and to the
            checkpoint save handler.
    """
    rank = idist.get_rank()
    manual_seed(config.seed + local_rank)

    train_loader = config.train_loader
    val_loader = config.val_loader
    train_eval_loader = config.train_eval_loader

    model, optimizer, criterion = utils.initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler, config, logger, with_clearml)

    # Setup evaluators; "Error" is derived from the same Accuracy instance
    accuracy = Accuracy()
    val_metrics = {
        "Accuracy": accuracy,
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5),
        "Error": (1.0 - accuracy) * 100,
    }
    if ("val_metrics" in config) and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator = create_evaluator(model, val_metrics, config, with_clearml, tag="val")
    train_evaluator = create_evaluator(model, val_metrics, config, with_clearml, tag="train")

    val_interval = config.get("val_interval", 1)

    # Run validation on every val_interval epoch, at the end of the training
    # and at the beginning if config.start_by_validation is True
    event = Events.EPOCH_COMPLETED(every=val_interval)
    if config.num_epochs % val_interval != 0:
        event |= Events.COMPLETED
    if config.get("start_by_validation", False):
        event |= Events.STARTED

    @trainer.on(event)
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(val_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    score_metric_name = "Accuracy"
    if "es_patience" in config:
        common.add_early_stopping_by_val_score(config.es_patience, evaluator, trainer, metric_name=score_metric_name)

    # Store 2 best models by validation accuracy:
    common.gen_save_best_models_by_val_score(
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    # Setup Tensorboard logger (rank 0 only); stays None on the other ranks
    tb_logger = None
    if rank == 0:
        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={"training": train_evaluator, "validation": evaluator},
        )

        # Log predictions as images.
        # A custom event filter logs the images less frequently (to reduce storage size):
        # - only when the filter's second argument is 1 (first evaluator iteration),
        # - once every 3 validations and
        # - at the last training epoch
        def custom_event_filter(_, val_iteration):
            c1 = val_iteration == 1
            c2 = trainer.state.epoch % (val_interval * 3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        # Image denormalization function to plot predictions with images
        mean = config.get("mean", (0.485, 0.456, 0.406))
        std = config.get("std", (0.229, 0.224, 0.225))
        img_denormalize = partial(data.denormalize, mean=mean, std=std)

        tb_logger.attach(
            evaluator,
            log_handler=vis.predictions_gt_images_handler(
                img_denormalize_fn=img_denormalize, n_images=12, another_engine=trainer, prefix_tag="validation"
            ),
            event_name=Events.ITERATION_COMPLETED(event_filter=custom_event_filter),
        )
        tb_logger.attach(
            train_evaluator,
            log_handler=vis.predictions_gt_images_handler(
                img_denormalize_fn=img_denormalize, n_images=12, another_engine=trainer, prefix_tag="training"
            ),
            event_name=Events.ITERATION_COMPLETED(event_filter=custom_event_filter),
        )

    try:
        trainer.run(train_loader, max_epochs=config.num_epochs)
    finally:
        # Close the logger even when the run fails or is interrupted
        if tb_logger is not None:
            tb_logger.close()
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    """Build the training engine with AMP, gradient accumulation and common handlers.

    Args:
        model: model to optimize.
        optimizer: optimizer stepped every ``accumulation_steps`` iterations.
        criterion: callable computing the loss from ``(y_pred, y)``.
        train_sampler: sampler passed to ``common.setup_common_training_handlers``.
        config: configuration object. This function reads ``device`` and
            ``lr_scheduler``, plus the optional entries ``accumulation_steps``
            (default 1), ``model_output_transform`` (default identity),
            ``with_amp`` (default True), ``save_every_iters`` (default 1000)
            and ``resume_from`` (default None).
        logger: logger assigned to the trainer and used for progress reports.
        with_clearml: forwarded to ``utils.get_save_handler``.

    Returns:
        The configured ``Engine``; its state is restored from
        ``config.resume_from`` when that entry is set.

    Raises:
        AssertionError: if ``config.resume_from`` points to a missing file.
    """
    device = config.device
    prepare_batch = data.prepare_batch

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def training_step(engine, batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            # Divide so that gradients summed over the accumulation window average out
            loss = criterion(y_pred, y) / accumulation_steps

        output = {"supervised batch loss": loss.item(), "num_samples": len(x)}

        scaler.scale(loss).backward()

        # Step the optimizer only once per accumulation window
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    throughput_metric = Frequency(output_transform=lambda x: x["num_samples"])
    throughput_metric.attach(trainer, name="Throughput")

    # Average time per iteration, used for the ETA estimate below
    timer = Timer(average=True)
    timer.attach(
        trainer,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.ITERATION_COMPLETED(every=20))
    def log_progress():
        metrics = dict(trainer.state.metrics)
        epoch_length = trainer.state.epoch_length
        # 1-based position inside the current epoch. A plain
        # ``iteration % epoch_length`` wraps to 0 on the last iteration of an
        # epoch, which would print "Iter=0/N" and a full-epoch ETA.
        iter_in_epoch = (trainer.state.iteration - 1) % epoch_length + 1
        metrics["ETA (seconds)"] = int((epoch_length - iter_in_epoch) * timer.value())
        metrics_str = ", ".join([f"{k}: {v}" for k, v in metrics.items()])
        metrics_format = (
            f"[{trainer.state.epoch}/{trainer.state.max_epochs}] "
            + f"Iter={iter_in_epoch}/{epoch_length}: "
            + f"{metrics_str}"
        )
        trainer.logger.info(metrics_format)

    output_names = [
        "supervised batch loss",
    ]

    lr_scheduler = config.lr_scheduler

    # Objects written to (and restored from) training checkpoints
    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        # with_pbars=not with_clearml,
        with_pbars=False,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def create_evaluator(model, metrics, config, with_clearml, tag="val"):
    """Build an evaluation engine and attach the given metrics to it.

    Args:
        model: model switched to eval mode and called on every batch.
        metrics: mapping of metric name to metric object; each one is attached
            to the engine under its name.
        config: configuration object; ``device`` is read, along with the
            optional ``model_output_transform`` (default identity) and
            ``with_amp`` (default True) entries.
        with_clearml: when falsy, rank 0 additionally gets a progress bar.
        tag: label shown in the progress bar description.

    Returns:
        The configured ``Engine``; each step returns ``(y_pred, y)``.
    """
    output_transform = config.get("model_output_transform", lambda x: x)
    use_amp = config.get("with_amp", True)
    to_device = data.prepare_batch

    @torch.no_grad()
    def evaluate_step(engine, batch):
        model.eval()
        with autocast(enabled=use_amp):
            x, y = to_device(batch, device=config.device, non_blocking=True)
            y_pred = output_transform(model(x))
        return y_pred, y

    evaluator = Engine(evaluate_step)

    for metric_name in metrics:
        metrics[metric_name].attach(evaluator, metric_name)

    show_progress = idist.get_rank() == 0 and (not with_clearml)
    if show_progress:
        progress_bar = common.ProgressBar(desc=f"Evaluation ({tag})", persist=False)
        progress_bar.attach(evaluator)

    return evaluator
def setup_experiment_tracking(config, with_clearml, task_type="training"):
    """Prepare experiment tracking on rank 0 and share the output path.

    With ClearML, a task is initialised, the script and config files are
    uploaded as artifacts and a timestamped path under ``CLEARML_OUTPUT_PATH``
    (default ``/tmp``) is chosen. Without it, a timestamped folder under
    ``OUTPUT_PATH`` (default ``/tmp/output-imagenet``) is created and the
    script and config files are copied into it.

    Args:
        config: configuration object exposing ``config_filepath`` and
            ``script_filepath``.
        with_clearml: selects the ClearML branch when truthy.
        task_type: either ``"training"`` or ``"testing"``.

    Returns:
        ``Path`` built from the value broadcast from rank 0.

    Raises:
        AssertionError: if ``task_type`` is not one of the two accepted values.
    """
    from datetime import datetime

    assert task_type in ("training", "testing"), task_type

    timestamp_format = "%Y%m%d-%H%M%S"
    output_path = ""

    # Only rank 0 touches the tracking backend and the filesystem
    if idist.get_rank() == 0:
        script_fp = config.script_filepath
        config_fp = config.config_filepath

        if with_clearml:
            from clearml import Task

            if task_type == "training":
                schema = TrainvalConfigSchema
            else:
                schema = InferenceConfigSchema

            task = Task.init("ImageNet Training", config_fp.stem, task_type=task_type)
            task.connect_configuration(config_fp.as_posix())

            task.upload_artifact(script_fp.name, script_fp.as_posix())
            task.upload_artifact(config_fp.name, config_fp.as_posix())

            task.connect(get_params(config, schema))

            clearml_root = Path(os.environ.get("CLEARML_OUTPUT_PATH", "/tmp"))
            output_path = clearml_root / "clearml" / datetime.now().strftime(timestamp_format)
        else:
            import shutil

            local_root = Path(os.environ.get("OUTPUT_PATH", "/tmp/output-imagenet"))
            output_path = local_root / task_type / config_fp.stem / datetime.now().strftime(timestamp_format)
            output_path.mkdir(parents=True, exist_ok=True)

            # Keep a copy of the script and the config next to the results
            shutil.copyfile(script_fp.as_posix(), output_path / script_fp.name)
            shutil.copyfile(config_fp.as_posix(), output_path / config_fp.name)

        output_path = output_path.as_posix()

    shared_path = idist.broadcast(output_path, src=0)
    return Path(shared_path)
def run_training(config_filepath, backend="nccl", with_clearml=True):
    """Entry point launching a training experiment.

    Requires CUDA and cuDNN, loads and validates the configuration file, sets
    up experiment tracking and runs ``training`` through ``idist.Parallel``.

    Args:
        config_filepath (str): training configuration .py file
        backend (str): distributed backend: nccl, gloo or None to run without distributed config
        with_clearml (bool): if True, uses ClearML as experiment tracking system
    """
    cuda_available = torch.cuda.is_available()
    assert cuda_available, torch.cuda.is_available()
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True

    config_filepath = Path(config_filepath)
    assert config_filepath.exists(), f"File '{config_filepath.as_posix()}' is not found"

    with idist.Parallel(backend=backend) as parallel:
        logger = setup_logger(name="ImageNet Training", distributed_rank=idist.get_rank())

        config = ConfigObject(config_filepath)
        TrainvalConfigSchema.validate(config)
        config.script_filepath = Path(__file__)

        config.output_path = setup_experiment_tracking(config, with_clearml=with_clearml)

        basic_params = get_params(config, TrainvalConfigSchema)
        utils.log_basic_info(logger, basic_params)

        try:
            parallel.run(training, config, logger=logger, with_clearml=with_clearml)
        except KeyboardInterrupt:
            # A manual interruption is reported but not treated as a failure
            logger.info("Catched KeyboardInterrupt -> exit")
        except Exception as e:  # noqa
            # Log the traceback before propagating the error
            logger.exception("")
            raise e
def get_model_weights(config, logger, with_clearml):
    """Locate the trained weights and load them onto the CPU.

    With ClearML, rank 0 fetches a local copy of the model identified by
    ``config.weights_path`` while the other ranks wait at a barrier; the
    resulting path is then broadcast from rank 0. Without ClearML,
    ``config.weights_path`` is used directly as a file path.

    Args:
        config: configuration object exposing ``weights_path``.
        logger: logger used to report what is being loaded.
        with_clearml: selects the ClearML branch when truthy.

    Returns:
        The object stored in the weights file, as returned by ``torch.load``.

    Raises:
        AssertionError: if the resolved path does not exist.
    """
    path = ""
    if with_clearml:
        from clearml import Model

        if idist.get_rank() > 0:
            # Wait until rank 0 has the local copy
            idist.barrier()
        else:
            model_id = config.weights_path

            logger.info(f"Loading trained model: {model_id}")
            model = Model(model_id)
            assert model is not None, f"{model_id}"
            path = model.get_local_copy()
            idist.barrier()
        path = idist.broadcast(path, src=0)
    else:
        path = config.weights_path
        logger.info(f"Loading {path}")

    assert Path(path).exists(), f"{path} is not found"
    # Load onto the CPU instead of the device the checkpoint was saved from;
    # this matches the checkpoint loading done when resuming training.
    return torch.load(path, map_location="cpu")
def evaluation(local_rank, config, logger, with_clearml):
    """Evaluate a trained model on ``config.data_loader`` and log the metrics.

    Args:
        local_rank: index of the process; added to ``config.seed`` so every
            process gets its own seed.
        config: configuration object. This function reads ``seed``,
            ``data_loader``, ``model`` and ``output_path``, plus the optional
            ``val_metrics`` entry.
        logger: logger used for the metric report.
        with_clearml: forwarded to the weight loading and evaluator factory.
    """
    rank = idist.get_rank()
    device = idist.device()
    manual_seed(config.seed + local_rank)

    data_loader = config.data_loader
    model = config.model.to(device)

    # Load weights:
    state_dict = get_model_weights(config, logger, with_clearml)
    model.load_state_dict(state_dict)

    # Adapt model to dist config
    model = idist.auto_model(model)

    # Setup evaluators
    val_metrics = {
        "Accuracy": Accuracy(),
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5),
    }

    if ("val_metrics" in config) and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator = create_evaluator(model, val_metrics, config, with_clearml, tag="val")

    # Setup Tensorboard logger (rank 0 only); stays None on the other ranks
    tb_logger = None
    if rank == 0:
        tb_logger = common.TensorboardLogger(log_dir=config.output_path.as_posix())
        tb_logger.attach_output_handler(evaluator, event_name=Events.COMPLETED, tag="validation", metric_names="all")

    try:
        state = evaluator.run(data_loader)
        utils.log_metrics(logger, 0, state.times["COMPLETED"], "Validation", state.metrics)
    finally:
        # Close the logger even when the evaluation fails or is interrupted
        if tb_logger is not None:
            tb_logger.close()
def run_evaluation(config_filepath, backend="nccl", with_clearml=True):
    """Entry point launching a model evaluation:
        - compute validation metrics

    Requires CUDA and cuDNN, loads and validates the configuration file, sets
    up experiment tracking and runs ``evaluation`` through ``idist.Parallel``.

    Args:
        config_filepath (str): evaluation configuration .py file
        backend (str): distributed backend: nccl, gloo, horovod or None to run without distributed config
        with_clearml (bool): if True, uses ClearML as experiment tracking system
    """
    cuda_available = torch.cuda.is_available()
    assert cuda_available, torch.cuda.is_available()
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True

    config_filepath = Path(config_filepath)
    assert config_filepath.exists(), f"File '{config_filepath.as_posix()}' is not found"

    with idist.Parallel(backend=backend) as parallel:
        logger = setup_logger(name="ImageNet Evaluation", distributed_rank=idist.get_rank())

        config = ConfigObject(config_filepath)
        InferenceConfigSchema.validate(config)
        config.script_filepath = Path(__file__)

        config.output_path = setup_experiment_tracking(config, with_clearml=with_clearml, task_type="testing")

        basic_params = get_params(config, InferenceConfigSchema)
        utils.log_basic_info(logger, basic_params)

        try:
            parallel.run(evaluation, config, logger=logger, with_clearml=with_clearml)
        except KeyboardInterrupt:
            # A manual interruption is reported but not treated as a failure
            logger.info("Catched KeyboardInterrupt -> exit")
        except Exception as e:  # noqa
            # Log the traceback before propagating the error
            logger.exception("")
            raise e
if __name__ == "__main__":
    # Command-line sub-commands: "training" and "eval"
    commands = {"training": run_training, "eval": run_evaluation}
    fire.Fire(commands)