-
Notifications
You must be signed in to change notification settings - Fork 43
/
train.py
127 lines (113 loc) · 5.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import collections
import os
import transformers
from sacred import Experiment
import data_loader.data_loader as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
import utils.visualizer as module_vis
from parse_config import ConfigParser
from trainer import Trainer
from utils.util import replace_nested_dict_item
# Sacred experiment named 'train'. `run` is registered as its main function
# (via `@ex.main`), the parsed configuration is attached to it in the
# `__main__` section, and `run` also hands it to the Trainer as `writer`
# when config['trainer']['neptune'] is truthy.
ex = Experiment('train')
@ex.main
def run():
    """Build every training component from the global ``config`` and train.

    Registered as the main function of the sacred experiment ``ex``. The
    function takes no arguments: it reads the module-level ``config`` object
    that is created in the ``__main__`` section, so it must only be called
    after that object exists.

    Steps, in order: optional visualizer, tokenizer, train/val data loaders,
    model, loss and metric handles, optimizer, optional learning-rate
    scheduler, and finally ``Trainer(...).train()``.
    """
    logger = config.get_logger('train')

    # Tokenizer parallelism is controlled through this environment variable;
    # it is not an argument of `from_pretrained`, so it is only set here.
    os.environ['TOKENIZERS_PARALLELISM'] = "false"

    # TODO: improve Create identity (do nothing) visualiser?
    if config['visualizer']['type'] != "":
        visualizer = config.initialize(
            name='visualizer',
            module=module_vis,
            exp_name=config['name'],
            web_dir=config._web_log_dir
        )
    else:
        visualizer = None

    # build tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config['arch']['args']['text_params']['model'])

    # setup data_loader instances (both are lists of loaders)
    data_loader, valid_data_loader = init_dataloaders(config, module_data)
    print('Train dataset: ', [x.n_samples for x in data_loader], ' samples')
    print('Val dataset: ', [x.n_samples for x in valid_data_loader], ' samples')

    # build model architecture, then log it
    model = config.initialize('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss = config.initialize(name="loss", module=module_loss)
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer over the parameters that require gradients
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.initialize('optimizer', transformers, trainable_params)

    # The scheduler is optional: it is only built when the config has an
    # 'lr_scheduler' entry whose type exists in `transformers`. Otherwise
    # training proceeds without one.
    lr_scheduler = None
    if 'lr_scheduler' in config._config:
        scheduler_type = config._config['lr_scheduler']['type']
        if hasattr(transformers, scheduler_type):
            lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer)
        else:
            logger.warning('lr scheduler not found: %s', scheduler_type)

    # The sacred experiment doubles as the writer when neptune is enabled.
    writer = ex if config['trainer']['neptune'] else None

    trainer = Trainer(model, loss, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler,
                      visualizer=visualizer,
                      writer=writer,
                      tokenizer=tokenizer,
                      max_samples_per_epoch=config['trainer']['max_samples_per_epoch'])
    trainer.train()
def init_dataloaders(config, module_data):
    """
    Create the training loaders and the matching validation loaders.

    The training loaders are built first from ``config['data_loader']``.
    Every 'split' entry in the loader arguments is then rewritten to 'val'
    (the config is modified in place) and the loaders are built a second
    time to obtain the validation set.

    ``config['data_loader']`` may be either a single mapping containing
    'type' and 'args', or a list of such entries. Any other shape raises
    ``ValueError``.

    Returns a pair ``(data_loader, valid_data_loader)``; both are lists.
    """
    loader_cfg = config["data_loader"]

    def build(**extra):
        # Single construction point so both phases use the same call.
        return config.initialize("data_loader", module_data, **extra)

    # Case 1: one loader described by a single mapping.
    if "type" in loader_cfg and "args" in loader_cfg:
        train_loaders = [build()]
        loader_cfg['args'] = replace_nested_dict_item(loader_cfg['args'], 'split', 'val')
        return train_loaders, [build()]

    if not isinstance(loader_cfg, list):
        raise ValueError("Check data_loader config, not correct format.")

    # Case 2: several loaders, addressed by their position in the list.
    train_loaders = [build(index=pos) for pos in range(len(loader_cfg))]

    val_entries = []
    for entry in loader_cfg:
        entry['args'] = replace_nested_dict_item(entry['args'], 'split', 'val')
        val_entries.append(entry)
    config._config['data_loader'] = val_entries

    val_loaders = [build(index=pos) for pos in range(len(val_entries))]
    return train_loaders, val_loaders
if __name__ == '__main__':
    # Base command-line interface; the parser itself is handed to
    # ConfigParser, which is responsible for parsing it.
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-c', '--config', type=str, default=None,
                        help='config file path (default: None)')
    parser.add_argument('-r', '--resume', type=str, default=None,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('-d', '--device', type=str, default=None,
                        help='indices of GPUs to enable (default: all)')
    parser.add_argument('-o', '--observe', action='store_true',
                        help='Whether to observe (neptune)')

    # Extra flags that override single values of the json configuration.
    # `target` is the key path of the overridden value inside the config.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(flags=['--lr', '--learning_rate'], type=float,
                   target=('optimizer', 'args', 'lr')),
        CustomArgs(flags=['--bs', '--batch_size'], type=int,
                   target=('data_loader', 'args', 'batch_size')),
    ]

    # `config` is a module-level name on purpose: run() reads it as a global.
    config = ConfigParser(parser, options)
    ex.add_config(config._config)

    if not config['trainer']['neptune']:
        # Plain run, bypassing sacred's runner.
        run()
    else:
        # delete this error if you have added your own neptune credentials neptune.ai
        # NOTE(review): NeptuneObserver is not imported anywhere in this file;
        # an import has to be added as well before the lines below can work.
        raise ValueError('Neptune credentials not set up yet.')
        ex.observers.append(NeptuneObserver(
            api_token='INSERT TOKEN',
            project_name='INSERT PROJECT NAME'))
        ex.run()