@@ -26,14 +26,11 @@ def train_model(model, config, train_loader, val_loader, optimizer,
26
26
so as to be able to change it on the fly without modifying config.
27
27
- `start_epoch` may be provided if a trained checkpoint is loaded
28
28
into the model and training is to be resumed from that point.
29
- - `decouple_fn_train` is a function which takes in the batch
30
- and returns the separated out inputs and targets.
31
- It may be specified if this process deviates from the
29
+ - `decouple_fn_train` and `decouple_fn_eval` are functions which
30
+ take in the batch and return the separated out inputs
31
+ (and targets for training/evaluation).
32
+ They may be specified if this process deviates from the
32
33
default behavior (see `decouple_batch_train`).
33
- - `decouple_fn_eval` is a function which takes in the batch
34
- and returns just the inputs.
35
- It may be specified if this process deviates from the
36
- default behavior (see `decouple_batch_eval`).
37
34
38
35
NOTE: Training may be paused at any time with a keyboard interrupt.
39
36
However, please avoid interrupting after an epoch is finished
@@ -49,32 +46,32 @@ def train_model(model, config, train_loader, val_loader, optimizer,
49
46
train_losses = train_epoch (
50
47
model = model ,
51
48
dataloader = train_loader ,
52
- loss_criterion = loss_criterion_train ,
53
49
device = config .device ,
50
+ loss_criterion = loss_criterion_train ,
54
51
epoch = epoch ,
55
52
optimizer = optimizer ,
56
53
scheduler = scheduler if config .use_scheduler_after_step else None ,
57
54
decouple_fn = decouple_fn_train
58
55
)
59
56
60
57
# Test on training set
61
- _ , eval_metrics_train = test_epoch (
58
+ _ , eval_metrics_train , _ , _ = evaluate_epoch (
62
59
model = model ,
63
60
dataloader = train_loader ,
64
- loss_criterion = loss_criterion_test ,
65
61
device = config .device ,
62
+ loss_criterion = loss_criterion_test ,
66
63
eval_criteria = eval_criteria ,
67
64
decouple_fn = decouple_fn_eval
68
65
)
69
66
# Add train losses+eval metrics, and log them
70
67
train_logger .add_and_log_metrics (train_losses , eval_metrics_train )
71
68
72
69
# Test on val set
73
- val_losses , eval_metrics_val = test_epoch (
70
+ val_losses , eval_metrics_val , _ , _ = evaluate_epoch (
74
71
model = model ,
75
72
dataloader = val_loader ,
76
- loss_criterion = loss_criterion_test ,
77
73
device = config .device ,
74
+ loss_criterion = loss_criterion_test ,
78
75
eval_criteria = eval_criteria ,
79
76
decouple_fn = decouple_fn_eval
80
77
)
@@ -151,175 +148,184 @@ def train_model(model, config, train_loader, val_loader, optimizer,
151
148
}
152
149
return return_dict
153
150
154
- def train_epoch (model , dataloader , loss_criterion , device , epoch ,
151
+ def train_epoch (model , dataloader , device , loss_criterion , epoch ,
155
152
optimizer , scheduler = None , decouple_fn = None ):
156
153
'''
157
154
Perform one training epoch.
158
155
See `perform_one_epoch()` for more details.
159
156
'''
160
- return perform_one_epoch (True , model , dataloader , loss_criterion ,
161
- device , epoch = epoch , optimizer = optimizer ,
157
+ return perform_one_epoch ('train' , model , dataloader , device ,
158
+ loss_criterion = loss_criterion ,
159
+ epoch = epoch , optimizer = optimizer ,
162
160
scheduler = scheduler , decouple_fn = decouple_fn )
163
161
164
162
@torch .no_grad ()
165
- def test_epoch (model , dataloader , loss_criterion , device ,
166
- eval_criteria , return_outputs = False , decouple_fn = None ):
163
+ def evaluate_epoch (model , dataloader , device , loss_criterion ,
164
+ eval_criteria , decouple_fn = None ):
167
165
'''
168
166
Perform one evaluation epoch.
169
167
See `perform_one_epoch()` for more details.
170
168
'''
171
- return perform_one_epoch (False , model , dataloader , loss_criterion ,
172
- device , eval_criteria = eval_criteria ,
173
- return_outputs = return_outputs ,
169
+ return perform_one_epoch ('eval' , model , dataloader , device ,
170
+ loss_criterion = loss_criterion ,
171
+ eval_criteria = eval_criteria ,
172
+ decouple_fn = decouple_fn )
173
+
174
+ @torch .no_grad ()
175
+ def get_all_predictions (model , dataloader , device , threshold_prob = None , decouple_fn = None ):
176
+ '''
177
+ Make predictions on entire dataset and return raw outputs and optionally
178
+ class predictions and probabilities if it's a classification model.
179
+ See `perform_one_epoch()` for more details.
180
+ '''
181
+ return perform_one_epoch ('test' , model , dataloader , device ,
182
+ threshold_prob = threshold_prob ,
174
183
decouple_fn = decouple_fn )
175
184
176
185
@timing
177
- def perform_one_epoch (do_training , model , dataloader , loss_criterion ,
178
- device , epoch = None , optimizer = None , scheduler = None ,
179
- eval_criteria = None , return_outputs = False , decouple_fn = None ):
180
- '''
181
- Common loop for one training or evaluation epoch on the entire dataset.
182
- Return the loss per example for each iteration, and all eval criteria
183
- if evaluation to be performed.
184
- :param do_training: If training is to be performed (otherwise evaluation)
186
+ def perform_one_epoch (phase , model , dataloader , device , loss_criterion = None ,
187
+ epoch = None , optimizer = None , scheduler = None , eval_criteria = None ,
188
+ threshold_prob = None , decouple_fn = None ):
189
+ '''
190
+ Common loop for one training / evaluation / testing epoch on the entire dataset.
191
+ - For training, returns the loss per example for each iteration.
192
+ - For evaluation, returns the loss per example for each iteration and all eval criteria.
193
+ - For testing, returns raw model outputs, and optionally class predictions and
194
+ probabilities if it's a classification model.
195
+
196
+ :param phase: Type of pass to perform over data
197
+ Choices = 'train' | 'eval' | 'test'
185
198
:param scheduler: Pass this only if it's a scheduler that requires taking a step
186
199
after each batch iteration (e.g. CyclicLR), otherwise None
187
- :return_outputs: For evaluation (`do_training=False`), whether to return
188
- the targets and model outputs
189
200
190
- If `do_training` is True, params `optimizer` and `epoch` must be provided.
191
- Otherwise for evaluation, param `eval_criteria` must be provided.
192
- At a time, either only training or only evaluation will be performed.
201
+ If `mode=='train'`, params `optimizer` and `epoch` must be provided.
202
+ If `mode=='eval'`, param `eval_criteria` must be provided.
203
+
204
+ At a time, only one of training / evaluation / testing will be performed.
205
+ For a given mode, all arguments that pertain to other modes will be ignored.
193
206
'''
194
- do_eval = not do_training # Whether to perform evaluation
195
- if do_training :
196
- for param_name , param in zip (['epoch' , 'optimizer' ], [epoch , optimizer ]):
207
+ ALLOWED_PHASES = ['train' , 'eval' , 'test' ]
208
+
209
+ # Check presence of required arguments
210
+ if phase == 'train' :
211
+ for param_name , param in zip (['epoch' , 'optimizer' , 'loss_criterion' ],
212
+ [epoch , optimizer , loss_criterion ]):
197
213
assert param is not None , f'Param "{ param_name } " must not be None for training.'
198
- else :
199
- assert eval_criteria is not None , 'Param "eval_criteria" must not be None for evaluation.'
214
+ elif phase == 'eval' :
215
+ for param_name , param in zip (['eval_criteria' , 'loss_criterion' ],
216
+ [eval_criteria , loss_criterion ]):
217
+ assert param is not None , f'Param "{ param_name } " must not be None for evaluation.'
218
+ elif phase != 'test' :
219
+ raise ValueError (f'Param "phase" ("{ phase } ") must be one of { ALLOWED_PHASES } .' )
220
+
221
+ # Mode for retaining gradients / graph
222
+ MODE = phase == 'train'
200
223
201
- # Set training decoupling function to extract inputs and targets from batch
224
+ # Set decoupling function to extract inputs ( and optionally targets) from batch
202
225
if decouple_fn is None :
203
- decouple_fn = decouple_batch_train
226
+ decouple_fn = decouple_batch_test if phase == 'test' else decouple_batch_train
204
227
205
228
# Set model in training/eval mode as required
206
- model .train (mode = do_training )
229
+ model .train (mode = MODE )
207
230
231
+ # Get required dataloader params
208
232
num_batches , num_examples = len (dataloader ), len (dataloader .dataset )
209
233
batch_size = dataloader .batch_size
210
234
211
235
# Print 50 times in an epoch (or every time, if num_batches < 50)
212
236
batches_to_print = np .unique (np .linspace (0 , num_batches , num = 50 , endpoint = True , dtype = int ))
213
237
214
- # Store all losses, targets, and outputs
215
- loss_hist , targets_hist , outputs_hist = [], [], []
238
+ # Store all required items to be returned
239
+ loss_hist , targets_hist , outputs_hist , preds_hist , probs_hist = [], [], [], [], []
216
240
217
241
# Enable gradient computation if training to be performed else disable it.
218
- # Technically not required if this function is called from `test_epoch()`
219
- # because of decorator, but just being sure.
220
- with torch .set_grad_enabled (do_training ):
242
+ # Technically not required if this function is called from other supported
243
+ # functions, e.g. `evaluate_epoch()` ( because of decorator) , but just being sure.
244
+ with torch .set_grad_enabled (MODE ):
221
245
for batch_idx , batch in enumerate (islice (dataloader , num_batches )):
222
- # Get inputs and targets
223
- inputs , targets = send_batch_to_device (decouple_fn (batch ), device )
246
+ # Get inputs for testing
247
+ if phase == 'test' :
248
+ inputs = send_batch_to_device (decouple_fn (batch ), device )
249
+ else : # Get inputs and targets for training/evaluation
250
+ inputs , targets = send_batch_to_device (decouple_fn (batch ), device )
224
251
225
252
# Reset gradients to zero
226
- if do_training :
253
+ if phase == 'train' :
227
254
optimizer .zero_grad ()
228
255
model .zero_grad ()
229
256
230
257
# Get model outputs
231
258
outputs = get_model_outputs_only (model (inputs ))
232
259
233
- # Compute and store loss
234
- loss = loss_criterion ( outputs , targets )
235
- loss_value = loss . item ( )
236
- loss_hist . append ( loss_value )
260
+ # Store items for testing + print progress
261
+ if phase == 'test' :
262
+ outputs = send_batch_to_device ( outputs , 'cpu' )
263
+ outputs_hist . extend ( outputs )
237
264
238
- # Perform training only steps
239
- if do_training :
240
- # Backprop + clip gradients + take scheduler step
241
- loss .backward ()
242
- nn .utils .clip_grad_norm_ (model .parameters (), 1. )
243
- optimizer .step ()
244
- if scheduler is not None :
245
- take_scheduler_step (scheduler , loss_value )
265
+ # Get class predictions and probabilities
266
+ if model .model_type == 'classification' :
267
+ preds , probs = model .predict_proba (outputs , threshold_prob )
268
+ preds_hist .extend (preds )
269
+ probs_hist .extend (probs )
246
270
247
271
# Print progess
248
272
if batch_idx in batches_to_print :
249
- logging .info ('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}' .format (
250
- epoch , (batch_idx + 1 ) * batch_size , num_examples ,
251
- 100. * (batch_idx + 1 ) / num_batches , loss_value ))
252
-
253
- else : # Store targets and model outputs on CPU for evaluation
254
- outputs , targets = send_batch_to_device ((outputs , targets ), 'cpu' )
255
- outputs_hist .append (outputs )
256
- targets_hist .append (targets )
273
+ logging .info ('{}/{} ({:.0f}%) complete.' .format (
274
+ (batch_idx + 1 ) * batch_size , num_examples ,
275
+ 100. * (batch_idx + 1 ) / num_batches ))
276
+
277
+ else : # Perform training / evaluation
278
+ # Compute and store loss
279
+ loss = loss_criterion (outputs , targets )
280
+ loss_value = loss .item ()
281
+ loss_hist .append (loss_value )
282
+
283
+ # Perform training
284
+ if phase == 'train' :
285
+ # Backprop + clip gradients + take scheduler step
286
+ loss .backward ()
287
+ nn .utils .clip_grad_norm_ (model .parameters (), 1. )
288
+ optimizer .step ()
289
+ if scheduler is not None :
290
+ take_scheduler_step (scheduler , loss_value )
291
+
292
+ # Print progess
293
+ if batch_idx in batches_to_print :
294
+ logging .info ('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}' .format (
295
+ epoch , (batch_idx + 1 ) * batch_size , num_examples ,
296
+ 100. * (batch_idx + 1 ) / num_batches , loss_value ))
297
+
298
+ else : # Store items for evaluation
299
+ outputs , targets = send_batch_to_device ((outputs , targets ), 'cpu' )
300
+ outputs_hist .append (outputs )
301
+ targets_hist .append (targets )
257
302
258
303
# Reset gradients back to zero
259
- if do_training :
304
+ if phase == 'train' :
260
305
optimizer .zero_grad ()
261
306
model .zero_grad ()
262
307
263
- else : # Perform evaluation on whole dataset
308
+ elif phase == 'eval' : # Perform evaluation on whole dataset
264
309
outputs_hist = torch .cat (outputs_hist , dim = 0 )
265
310
targets_hist = torch .cat (targets_hist , dim = 0 )
266
311
267
312
# Compute all evaluation criteria
268
313
eval_metrics = {eval_criterion : eval_fn (outputs_hist , targets_hist ) \
269
314
for eval_criterion , eval_fn in eval_criteria .items ()}
270
315
271
- # Return necessary variables
272
- return_list = [loss_hist ]
273
- if do_eval :
274
- return_list .append (eval_metrics )
275
- if return_outputs :
276
- return_list .extend ([outputs_hist , targets_hist ])
277
- return return_list
278
-
279
- @timing
280
- @torch .no_grad ()
281
- def get_all_predictions (model , dataloader , device , threshold_prob = None , decouple_fn = None ):
282
- '''
283
- Make predictions on entire dataset and return raw outputs and optionally
284
- class predictions and probabilities if it's a classification model.
285
- '''
286
- # Set evaluation decoupling function to get inputs from the batch
287
- if decouple_fn is None :
288
- decouple_fn = decouple_batch_eval
289
-
290
- model .eval ()
291
-
292
- num_batches , num_examples = len (dataloader ), len (dataloader .dataset )
293
- batch_size = dataloader .batch_size
294
-
295
- # Print 50 times in an epoch (or every time, if num_batches < 50)
296
- batches_to_print = np .unique (np .linspace (0 , num_batches , num = 50 , endpoint = True , dtype = int ))
297
-
298
- outputs_hist , preds_hist , probs_hist = [], [], []
299
- for batch_idx , batch in enumerate (islice (dataloader , num_batches )):
300
- inputs = send_batch_to_device (decouple_fn (batch ), device )
301
-
302
- # Get model outputs
303
- outputs = get_model_outputs_only (model (inputs ))
304
- outputs = send_batch_to_device (outputs , 'cpu' )
305
- outputs_hist .extend (outputs )
306
-
307
- if model .model_type == 'classification' : # Get class predictions and probabilities
308
- preds , probs = model .predict_proba (outputs , threshold_prob )
309
- preds_hist .extend (preds )
310
- probs_hist .extend (probs )
311
-
312
- # Print progess
313
- if batch_idx in batches_to_print :
314
- logging .info ('{}/{} ({:.0f}%) complete.' .format (
315
- (batch_idx + 1 ) * batch_size , num_examples ,
316
- 100. * (batch_idx + 1 ) / num_batches ))
317
-
318
- outputs_hist = torch .stack (outputs_hist , dim = 0 )
319
- if model .model_type == 'classification' :
320
- preds_hist = torch .stack (preds_hist , dim = 0 )
321
- probs_hist = torch .stack (probs_hist , dim = 0 )
322
- return outputs_hist , preds_hist , probs_hist
316
+ else : # Get outputs, predictions, probabilities
317
+ outputs_hist = torch .stack (outputs_hist , dim = 0 )
318
+ if model .model_type == 'classification' :
319
+ preds_hist = torch .stack (preds_hist , dim = 0 )
320
+ probs_hist = torch .stack (probs_hist , dim = 0 )
321
+
322
+ # Return necessary items
323
+ if phase == 'train' :
324
+ return loss_hist
325
+ elif phase == 'eval' :
326
+ return loss_hist , eval_metrics , outputs_hist , targets_hist
327
+ else :
328
+ return outputs_hist , preds_hist , probs_hist
323
329
324
330
def decouple_batch_train (batch ):
325
331
'''
@@ -336,7 +342,7 @@ def decouple_batch_train(batch):
336
342
inputs , targets = batch [:2 ]
337
343
return inputs , targets
338
344
339
- def decouple_batch_eval (batch ):
345
+ def decouple_batch_test (batch ):
340
346
'''
341
347
Extract and return just the inputs
342
348
from a batch assuming it's the first
@@ -600,10 +606,10 @@ def validate_checkpoint_type(checkpoint_type, checkpoint_file=None):
600
606
Check that the passed `checkpoint_type` is valid and matches that
601
607
obtained from `checkpoint_file`, if provided.
602
608
'''
603
- allowed_checkpoint_types = ['state' , 'model' ]
604
- assert checkpoint_type in allowed_checkpoint_types , \
609
+ ALLOWED_CHECKPOINT_TYPES = ['state' , 'model' ]
610
+ assert checkpoint_type in ALLOWED_CHECKPOINT_TYPES , \
605
611
(f'Param "checkpoint_type" ("{ checkpoint_type } ") '
606
- f'must be one of { allowed_checkpoint_types } .' )
612
+ f'must be one of { ALLOWED_CHECKPOINT_TYPES } .' )
607
613
608
614
# Check that provided checkpoint_type matches that of checkpoint_file
609
615
if checkpoint_file is not None :
0 commit comments