model_manager.py
"""
A helper is a wrapper that joins a DataSet with a Trainer so both can be configured and run together in a compact way.
"""
import argparse
import json
import pathlib
from os import listdir
import pytorch_lightning as ptl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TestTubeLogger
from config import makedirs, MODELS_DATA_PATH, RAW_DATA_PATH
from lightning_modules import L_ResNext50, L_WavenetTransformerClassifier, L_WavenetLSTMClassifier, L_GMMClassifier, \
L_WavenetClassifier, L_Conv1DClassifier, L_RNNClassifier
from loaders import CepstrumDataset, WaveformDataset, ExperimentDataset


class AbstractHelper:
    model_name = 'UnnamedHelper'
    dataset = ExperimentDataset
    lightning_module = L_ResNext50
    dataset_ratios = (0.5, 0.25, 0.25)

    def __init__(self, experiment_name, parser, data_path, label_filename, models_path, dummy_mode=False):
"""
The helper manages model interaction with data.
:param experiment_name: Name of the experiment
:param parser: Argparse to add further paramters to the cli
:param data_path: Absolute path where the data is stored
:param label_filename: File name of the csv file that contains the filename and label data.
:param models_path: Path to where model data is stored.
"""
self.models_path = models_path
self.experiment_name = experiment_name
train_dataset, test_dataset, eval_dataset, number_of_classes = self.dataset.init_sets(
data_path,
label_filename,
ratio=self.dataset_ratios,
dummy_mode=dummy_mode
)
parser = self.lightning_module.add_model_specific_args(parser, None)
hyperparams = parser.parse_args()
self.module = self.lightning_module(
hyperparams,
number_of_classes,
train_dataset,
eval_dataset,
test_dataset,
)
        gpus = json.loads(hyperparams.gpus)
        self.save_dir = models_path / self.model_name / experiment_name
        makedirs(self.save_dir)
        if 'distributed_backend' not in hyperparams:
            hyperparams.distributed_backend = 'dp'
        # todo: connect ddp, fix gpu specification
        # Logger specification:
        # TestTubeLogger is used for its nicer directory structure.
version = 1
logger = TestTubeLogger(
save_dir=self.save_dir,
            version=version  # fixed version so find_best_epoch can locate earlier checkpoints
)
ckpt_folder = self.save_dir / 'default' / 'version_{}'.format(version) / 'checkpoints'
resume_from_checkpoint = self.find_best_epoch(ckpt_folder)
        # Trainer with some optimizations.
        self.trainer = ptl.Trainer(
            gpus=gpus if len(gpus) else 0,
            profiler=False,  # enable occasionally when a run needs profiling
            auto_scale_batch_size=False,  # batch size is tuned manually
            auto_lr_find=False,  # the automatic LR search mostly diverges
            # distributed_backend='ddp',  # disabled: doesn't fit when run with ddp
            precision=32,  # 16-bit precision throws errors with these models
default_root_dir=self.save_dir,
logger=logger,
resume_from_checkpoint=resume_from_checkpoint,
callbacks=[
ModelCheckpoint(
save_last=True,
save_top_k=1,
verbose=True,
monitor='val_acc',
mode='max',
prefix=''
),
]
)

    @staticmethod
    def find_best_epoch(ckpt_folder):
        """
        Find the checkpoint with the highest epoch in the Test Tube file structure.
        Assumes the 'epoch={int}.ckpt' filename format.
        :param ckpt_folder: Directory where the checkpoints are saved.
        :return: String path to the checkpoint, or None if no checkpoint is found.
        """
        debug = True
        try:
            ckpt_files = listdir(ckpt_folder)  # list of filename strings
            # 'epoch={int}.ckpt' filename format
            epochs = [int(filename[6:-5]) for filename in ckpt_files if 'epoch=' in filename]
            if len(epochs) == 0:
                return None
            resume_from_checkpoint = str(ckpt_folder / 'epoch={}.ckpt'.format(max(epochs)))
            if debug:
                print('debug: loading from checkpoint epoch {}'.format(max(epochs)))
        except FileNotFoundError:
            resume_from_checkpoint = None
            if debug:
                print('debug: could not load from checkpoint (FileNotFoundError)')
        except ValueError as e:
            resume_from_checkpoint = None
            if debug:
                print('debug: could not load from checkpoint (ValueError)')
            print(e)
return resume_from_checkpoint
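
    # Expected checkpoint layout (produced by TestTubeLogger + ModelCheckpoint), e.g.:
    #   <models_path>/<model_name>/<experiment_name>/default/version_1/checkpoints/epoch=42.ckpt
    # find_best_epoch() resumes from the highest 'epoch=' number found in that
    # folder; 'epoch=42' above is a hypothetical example.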

    def train(self):
        self.trainer.fit(self.module)

    def test(self):
        self.trainer.test(self.module)

    def evaluate(self):
        """
        Run a set of evaluations to obtain metrics for the SID task.
        Metrics:
            Accuracy (per class and total)
            Recall (per class and total)
            Error
        :return: None (not implemented yet)
        """
        # self.trainer.run_evaluation(test=False)
        return

# Folder name of the music files after some processing (raw, svs, svd, svs+svd, etc.)
# raw_feature_name = 'svs'


class WavenetTransformerHelper(AbstractHelper):
    model_name = 'wavenet_transformer'
    dataset = WaveformDataset
    # source_feature_name = SingingVoiceSeparationOpenUnmixFeatureExtractor.feature_name
    lightning_module = L_WavenetTransformerClassifier


class WavenetLSTMHelper(AbstractHelper):
    model_name = 'wavenet_lstm'
    dataset = WaveformDataset
    # source_feature_name = SingingVoiceSeparationOpenUnmixFeatureExtractor.feature_name
    lightning_module = L_WavenetLSTMClassifier


class WavenetHelper(AbstractHelper):
    model_name = 'wavenet'
    dataset = WaveformDataset
    # source_feature_name = SingingVoiceSeparationOpenUnmixFeatureExtractor.feature_name
    lightning_module = L_WavenetClassifier


class Conv1dHelper(AbstractHelper):
    model_name = 'conv1d'
    dataset = WaveformDataset
    # source_feature_name = SingingVoiceSeparationOpenUnmixFeatureExtractor.feature_name
    lightning_module = L_Conv1DClassifier


class RNNHelper(AbstractHelper):
    model_name = 'rnn'
    dataset = WaveformDataset
    # source_feature_name = SingingVoiceSeparationOpenUnmixFeatureExtractor.feature_name
    lightning_module = L_RNNClassifier


class GMMClassifierHelper(AbstractHelper):
    model_name = 'gmm'
    dataset = CepstrumDataset
    # source_feature_name = MelSpectralCoefficientsFeatureExtractor.feature_name
    lightning_module = L_GMMClassifier
    dataset_ratios = (.7, .29, .01)

    def __init__(self, experiment_name, parser, data_path, label_filename, models_path, dummy_mode=False):
        super().__init__(experiment_name, parser, data_path, label_filename, models_path, dummy_mode)
        self.module.load_model(self.save_dir)

    def train(self):
        """
        The lightning_module of the GMM has special behaviour: it carries a
        trained boolean attribute and save_model methods, so training and
        evaluation are driven directly on the module instead of through the
        Trainer.
        :return: None
        """
        self.module.start_training()
        self.module.start_evaluation()


class ResNext50Helper(AbstractHelper):
    model_name = 'resnext50'
    dataset = CepstrumDataset
    # source_feature_name = MelSpectralCoefficientsFeatureExtractor.feature_name
    lightning_module = L_ResNext50


helpers = {
    ResNext50Helper.model_name: ResNext50Helper,
    WavenetTransformerHelper.model_name: WavenetTransformerHelper,
    WavenetLSTMHelper.model_name: WavenetLSTMHelper,
    GMMClassifierHelper.model_name: GMMClassifierHelper,
    WavenetHelper.model_name: WavenetHelper,
    Conv1dHelper.model_name: Conv1dHelper,
    RNNHelper.model_name: RNNHelper,
}
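
# A minimal sketch of driving a helper programmatically instead of via the CLI
# (the experiment name is a hypothetical example; add_cli_args is defined below):
#
#   parser = argparse.ArgumentParser(add_help=False)
#   add_cli_args(parser)
#   helper = helpers['resnext50'](
#       'my_experiment', parser, RAW_DATA_PATH, 'labels.csv', MODELS_DATA_PATH
#   )
#   helper.train()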


def add_cli_args(parser):
    parser.add_argument(
        '--data_path',
        help='Source path where input data files are stored',
        default=RAW_DATA_PATH
    )
    parser.add_argument(
        '--label_filename',
        help='File name of the label file (default is labels.csv)',
        default='labels.csv'
    )
    parser.add_argument(
        '--model_path',
        help='Path where the model data is going to be stored',
        default=MODELS_DATA_PATH
    )
    parser.add_argument(
        '--model',
        help='Model name (e.g. resnext50, gmm, wavenet_transformer)',
        default=GMMClassifierHelper.model_name,
        # required=True
    )
    parser.add_argument(
        '--experiment',
        help='Experiment identifier',
        default='unnamed_experiment'
    )
    parser.add_argument(
        '--mode',
        help='Can be train, test, evaluation or dummy (train on the dummy dataset)',
        default='train',
        required=False
    )
    parser.add_argument(
        '--gpus',
        help='JSON list of GPU ids to use, e.g. "[0, 1]" (default is none)',
        default='[]',
        type=str
    )


def parse_cli_args(args):
    model_name = args.model
    experiment_name = args.experiment
    data_path = pathlib.Path(args.data_path)
    models_path = pathlib.Path(args.model_path)
    label_filename = args.label_filename
    gpus = json.loads(args.gpus)
    mode = args.mode
    return model_name, experiment_name, data_path, models_path, label_filename, gpus, mode


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train a model on features from a data folder',
        add_help=False
    )
    add_cli_args(parser)
    args = parser.parse_args()
    model_name, experiment_name, data_path, models_path, label_filename, _, mode = parse_cli_args(args)
    helper_class = helpers[model_name]
    helper = helper_class(
        experiment_name,
        parser,
        data_path,
        label_filename,
        models_path,
        dummy_mode=(mode == 'dummy')
    )
    if mode == 'dummy' or mode == 'train':
        helper.train()
    elif mode == 'test':
        helper.test()
    elif mode == 'evaluation':
        helper.evaluate()
    else:
        raise NotImplementedError('mode not implemented: {}'.format(mode))
    print('helper ended')
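
# Example invocations (data and model paths fall back to the config defaults;
# the experiment name below is a hypothetical example):
#   python model_manager.py --model resnext50 --experiment exp1 --mode train
#   python model_manager.py --model gmm --experiment exp1 --mode test
#   python model_manager.py --mode dummy --gpus '[0]'  # quick smoke test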