Skip to content

Commit 0341ada

Browse files
committed
Using common loop for train/eval/test + Other improvements.
1 parent e488d8d commit 0341ada

File tree

2 files changed

+134
-128
lines changed

2 files changed

+134
-128
lines changed

pytorch_common/train_utils.py

Lines changed: 133 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,11 @@ def train_model(model, config, train_loader, val_loader, optimizer,
2626
so as to be able to change it on the fly without modifying config.
2727
- `start_epoch` may be provided if a trained checkpoint is loaded
2828
into the model and training is to be resumed from that point.
29-
- `decouple_fn_train` is a function which takes in the batch
30-
and returns the separated out inputs and targets.
31-
It may be specified if this process deviates from the
29+
- `decouple_fn_train` and `decouple_fn_eval` are functions which
30+
take in the batch and return the separated out inputs
31+
(and targets for training/evaluation).
32+
They may be specified if this process deviates from the
3233
default behavior (see `decouple_batch_train`).
33-
- `decouple_fn_eval` is a function which takes in the batch
34-
and returns just the inputs.
35-
It may be specified if this process deviates from the
36-
default behavior (see `decouple_batch_eval`).
3734
3835
NOTE: Training may be paused at any time with a keyboard interrupt.
3936
However, please avoid interrupting after an epoch is finished
@@ -49,32 +46,32 @@ def train_model(model, config, train_loader, val_loader, optimizer,
4946
train_losses = train_epoch(
5047
model=model,
5148
dataloader=train_loader,
52-
loss_criterion=loss_criterion_train,
5349
device=config.device,
50+
loss_criterion=loss_criterion_train,
5451
epoch=epoch,
5552
optimizer=optimizer,
5653
scheduler=scheduler if config.use_scheduler_after_step else None,
5754
decouple_fn=decouple_fn_train
5855
)
5956

6057
# Test on training set
61-
_, eval_metrics_train = test_epoch(
58+
_, eval_metrics_train, _, _ = evaluate_epoch(
6259
model=model,
6360
dataloader=train_loader,
64-
loss_criterion=loss_criterion_test,
6561
device=config.device,
62+
loss_criterion=loss_criterion_test,
6663
eval_criteria=eval_criteria,
6764
decouple_fn=decouple_fn_eval
6865
)
6966
# Add train losses+eval metrics, and log them
7067
train_logger.add_and_log_metrics(train_losses, eval_metrics_train)
7168

7269
# Test on val set
73-
val_losses, eval_metrics_val = test_epoch(
70+
val_losses, eval_metrics_val, _, _ = evaluate_epoch(
7471
model=model,
7572
dataloader=val_loader,
76-
loss_criterion=loss_criterion_test,
7773
device=config.device,
74+
loss_criterion=loss_criterion_test,
7875
eval_criteria=eval_criteria,
7976
decouple_fn=decouple_fn_eval
8077
)
@@ -151,175 +148,184 @@ def train_model(model, config, train_loader, val_loader, optimizer,
151148
}
152149
return return_dict
153150

154-
def train_epoch(model, dataloader, loss_criterion, device, epoch,
151+
def train_epoch(model, dataloader, device, loss_criterion, epoch,
155152
optimizer, scheduler=None, decouple_fn=None):
156153
'''
157154
Perform one training epoch.
158155
See `perform_one_epoch()` for more details.
159156
'''
160-
return perform_one_epoch(True, model, dataloader, loss_criterion,
161-
device, epoch=epoch, optimizer=optimizer,
157+
return perform_one_epoch('train', model, dataloader, device,
158+
loss_criterion=loss_criterion,
159+
epoch=epoch, optimizer=optimizer,
162160
scheduler=scheduler, decouple_fn=decouple_fn)
163161

164162
@torch.no_grad()
165-
def test_epoch(model, dataloader, loss_criterion, device,
166-
eval_criteria, return_outputs=False, decouple_fn=None):
163+
def evaluate_epoch(model, dataloader, device, loss_criterion,
164+
eval_criteria, decouple_fn=None):
167165
'''
168166
Perform one evaluation epoch.
169167
See `perform_one_epoch()` for more details.
170168
'''
171-
return perform_one_epoch(False, model, dataloader, loss_criterion,
172-
device, eval_criteria=eval_criteria,
173-
return_outputs=return_outputs,
169+
return perform_one_epoch('eval', model, dataloader, device,
170+
loss_criterion=loss_criterion,
171+
eval_criteria=eval_criteria,
172+
decouple_fn=decouple_fn)
173+
174+
@torch.no_grad()
175+
def get_all_predictions(model, dataloader, device, threshold_prob=None, decouple_fn=None):
176+
'''
177+
Make predictions on entire dataset and return raw outputs and optionally
178+
class predictions and probabilities if it's a classification model.
179+
See `perform_one_epoch()` for more details.
180+
'''
181+
return perform_one_epoch('test', model, dataloader, device,
182+
threshold_prob=threshold_prob,
174183
decouple_fn=decouple_fn)
175184

176185
@timing
177-
def perform_one_epoch(do_training, model, dataloader, loss_criterion,
178-
device, epoch=None, optimizer=None, scheduler=None,
179-
eval_criteria=None, return_outputs=False, decouple_fn=None):
180-
'''
181-
Common loop for one training or evaluation epoch on the entire dataset.
182-
Return the loss per example for each iteration, and all eval criteria
183-
if evaluation to be performed.
184-
:param do_training: If training is to be performed (otherwise evaluation)
186+
def perform_one_epoch(phase, model, dataloader, device, loss_criterion=None,
187+
epoch=None, optimizer=None, scheduler=None, eval_criteria=None,
188+
threshold_prob=None, decouple_fn=None):
189+
'''
190+
Common loop for one training / evaluation / testing epoch on the entire dataset.
191+
- For training, returns the loss per example for each iteration.
192+
- For evaluation, returns the loss per example for each iteration and all eval criteria.
193+
- For testing, returns raw model outputs, and optionally class predictions and
194+
probabilities if it's a classification model.
195+
196+
:param phase: Type of pass to perform over data
197+
Choices = 'train' | 'eval' | 'test'
185198
:param scheduler: Pass this only if it's a scheduler that requires taking a step
186199
after each batch iteration (e.g. CyclicLR), otherwise None
187-
:return_outputs: For evaluation (`do_training=False`), whether to return
188-
the targets and model outputs
189200
190-
If `do_training` is True, params `optimizer` and `epoch` must be provided.
191-
Otherwise for evaluation, param `eval_criteria` must be provided.
192-
At a time, either only training or only evaluation will be performed.
201+
If `phase=='train'`, params `optimizer`, `epoch` and `loss_criterion` must be provided.
202+
If `phase=='eval'`, params `eval_criteria` and `loss_criterion` must be provided.
203+
204+
At a time, only one of training / evaluation / testing will be performed.
205+
For a given phase, all arguments that pertain to other phases will be ignored.
193206
'''
194-
do_eval = not do_training # Whether to perform evaluation
195-
if do_training:
196-
for param_name, param in zip(['epoch', 'optimizer'], [epoch, optimizer]):
207+
ALLOWED_PHASES = ['train', 'eval', 'test']
208+
209+
# Check presence of required arguments
210+
if phase == 'train':
211+
for param_name, param in zip(['epoch', 'optimizer', 'loss_criterion'],
212+
[epoch, optimizer, loss_criterion]):
197213
assert param is not None, f'Param "{param_name}" must not be None for training.'
198-
else:
199-
assert eval_criteria is not None, 'Param "eval_criteria" must not be None for evaluation.'
214+
elif phase == 'eval':
215+
for param_name, param in zip(['eval_criteria', 'loss_criterion'],
216+
[eval_criteria, loss_criterion]):
217+
assert param is not None, f'Param "{param_name}" must not be None for evaluation.'
218+
elif phase != 'test':
219+
raise ValueError(f'Param "phase" ("{phase}") must be one of {ALLOWED_PHASES}.')
220+
221+
# Mode for retaining gradients / graph
222+
MODE = phase == 'train'
200223

201-
# Set training decoupling function to extract inputs and targets from batch
224+
# Set decoupling function to extract inputs (and optionally targets) from batch
202225
if decouple_fn is None:
203-
decouple_fn = decouple_batch_train
226+
decouple_fn = decouple_batch_test if phase == 'test' else decouple_batch_train
204227

205228
# Set model in training/eval mode as required
206-
model.train(mode=do_training)
229+
model.train(mode=MODE)
207230

231+
# Get required dataloader params
208232
num_batches, num_examples = len(dataloader), len(dataloader.dataset)
209233
batch_size = dataloader.batch_size
210234

211235
# Print 50 times in an epoch (or every time, if num_batches < 50)
212236
batches_to_print = np.unique(np.linspace(0, num_batches, num=50, endpoint=True, dtype=int))
213237

214-
# Store all losses, targets, and outputs
215-
loss_hist, targets_hist, outputs_hist = [], [], []
238+
# Store all required items to be returned
239+
loss_hist, targets_hist, outputs_hist, preds_hist, probs_hist = [], [], [], [], []
216240

217241
# Enable gradient computation if training to be performed else disable it.
218-
# Technically not required if this function is called from `test_epoch()`
219-
# because of decorator, but just being sure.
220-
with torch.set_grad_enabled(do_training):
242+
# Technically not required if this function is called from other supported
243+
# functions, e.g. `evaluate_epoch()` (because of decorator), but just being sure.
244+
with torch.set_grad_enabled(MODE):
221245
for batch_idx, batch in enumerate(islice(dataloader, num_batches)):
222-
# Get inputs and targets
223-
inputs, targets = send_batch_to_device(decouple_fn(batch), device)
246+
# Get inputs for testing
247+
if phase == 'test':
248+
inputs = send_batch_to_device(decouple_fn(batch), device)
249+
else: # Get inputs and targets for training/evaluation
250+
inputs, targets = send_batch_to_device(decouple_fn(batch), device)
224251

225252
# Reset gradients to zero
226-
if do_training:
253+
if phase == 'train':
227254
optimizer.zero_grad()
228255
model.zero_grad()
229256

230257
# Get model outputs
231258
outputs = get_model_outputs_only(model(inputs))
232259

233-
# Compute and store loss
234-
loss = loss_criterion(outputs, targets)
235-
loss_value = loss.item()
236-
loss_hist.append(loss_value)
260+
# Store items for testing + print progress
261+
if phase == 'test':
262+
outputs = send_batch_to_device(outputs, 'cpu')
263+
outputs_hist.extend(outputs)
237264

238-
# Perform training only steps
239-
if do_training:
240-
# Backprop + clip gradients + take scheduler step
241-
loss.backward()
242-
nn.utils.clip_grad_norm_(model.parameters(), 1.)
243-
optimizer.step()
244-
if scheduler is not None:
245-
take_scheduler_step(scheduler, loss_value)
265+
# Get class predictions and probabilities
266+
if model.model_type == 'classification':
267+
preds, probs = model.predict_proba(outputs, threshold_prob)
268+
preds_hist.extend(preds)
269+
probs_hist.extend(probs)
246270

247271
# Print progress
248272
if batch_idx in batches_to_print:
249-
logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
250-
epoch, (batch_idx+1) * batch_size, num_examples,
251-
100. * (batch_idx+1) / num_batches, loss_value))
252-
253-
else: # Store targets and model outputs on CPU for evaluation
254-
outputs, targets = send_batch_to_device((outputs, targets), 'cpu')
255-
outputs_hist.append(outputs)
256-
targets_hist.append(targets)
273+
logging.info('{}/{} ({:.0f}%) complete.'.format(
274+
(batch_idx+1) * batch_size, num_examples,
275+
100. * (batch_idx+1) / num_batches))
276+
277+
else: # Perform training / evaluation
278+
# Compute and store loss
279+
loss = loss_criterion(outputs, targets)
280+
loss_value = loss.item()
281+
loss_hist.append(loss_value)
282+
283+
# Perform training
284+
if phase == 'train':
285+
# Backprop + clip gradients + take scheduler step
286+
loss.backward()
287+
nn.utils.clip_grad_norm_(model.parameters(), 1.)
288+
optimizer.step()
289+
if scheduler is not None:
290+
take_scheduler_step(scheduler, loss_value)
291+
292+
# Print progress
293+
if batch_idx in batches_to_print:
294+
logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
295+
epoch, (batch_idx+1) * batch_size, num_examples,
296+
100. * (batch_idx+1) / num_batches, loss_value))
297+
298+
else: # Store items for evaluation
299+
outputs, targets = send_batch_to_device((outputs, targets), 'cpu')
300+
outputs_hist.append(outputs)
301+
targets_hist.append(targets)
257302

258303
# Reset gradients back to zero
259-
if do_training:
304+
if phase == 'train':
260305
optimizer.zero_grad()
261306
model.zero_grad()
262307

263-
else: # Perform evaluation on whole dataset
308+
elif phase == 'eval': # Perform evaluation on whole dataset
264309
outputs_hist = torch.cat(outputs_hist, dim=0)
265310
targets_hist = torch.cat(targets_hist, dim=0)
266311

267312
# Compute all evaluation criteria
268313
eval_metrics = {eval_criterion: eval_fn(outputs_hist, targets_hist) \
269314
for eval_criterion, eval_fn in eval_criteria.items()}
270315

271-
# Return necessary variables
272-
return_list = [loss_hist]
273-
if do_eval:
274-
return_list.append(eval_metrics)
275-
if return_outputs:
276-
return_list.extend([outputs_hist, targets_hist])
277-
return return_list
278-
279-
@timing
280-
@torch.no_grad()
281-
def get_all_predictions(model, dataloader, device, threshold_prob=None, decouple_fn=None):
282-
'''
283-
Make predictions on entire dataset and return raw outputs and optionally
284-
class predictions and probabilities if it's a classification model.
285-
'''
286-
# Set evaluation decoupling function to get inputs from the batch
287-
if decouple_fn is None:
288-
decouple_fn = decouple_batch_eval
289-
290-
model.eval()
291-
292-
num_batches, num_examples = len(dataloader), len(dataloader.dataset)
293-
batch_size = dataloader.batch_size
294-
295-
# Print 50 times in an epoch (or every time, if num_batches < 50)
296-
batches_to_print = np.unique(np.linspace(0, num_batches, num=50, endpoint=True, dtype=int))
297-
298-
outputs_hist, preds_hist, probs_hist = [], [], []
299-
for batch_idx, batch in enumerate(islice(dataloader, num_batches)):
300-
inputs = send_batch_to_device(decouple_fn(batch), device)
301-
302-
# Get model outputs
303-
outputs = get_model_outputs_only(model(inputs))
304-
outputs = send_batch_to_device(outputs, 'cpu')
305-
outputs_hist.extend(outputs)
306-
307-
if model.model_type == 'classification': # Get class predictions and probabilities
308-
preds, probs = model.predict_proba(outputs, threshold_prob)
309-
preds_hist.extend(preds)
310-
probs_hist.extend(probs)
311-
312-
# Print progress
313-
if batch_idx in batches_to_print:
314-
logging.info('{}/{} ({:.0f}%) complete.'.format(
315-
(batch_idx+1) * batch_size, num_examples,
316-
100. * (batch_idx+1) / num_batches))
317-
318-
outputs_hist = torch.stack(outputs_hist, dim=0)
319-
if model.model_type == 'classification':
320-
preds_hist = torch.stack(preds_hist, dim=0)
321-
probs_hist = torch.stack(probs_hist, dim=0)
322-
return outputs_hist, preds_hist, probs_hist
316+
else: # Get outputs, predictions, probabilities
317+
outputs_hist = torch.stack(outputs_hist, dim=0)
318+
if model.model_type == 'classification':
319+
preds_hist = torch.stack(preds_hist, dim=0)
320+
probs_hist = torch.stack(probs_hist, dim=0)
321+
322+
# Return necessary items
323+
if phase == 'train':
324+
return loss_hist
325+
elif phase == 'eval':
326+
return loss_hist, eval_metrics, outputs_hist, targets_hist
327+
else:
328+
return outputs_hist, preds_hist, probs_hist
323329

324330
def decouple_batch_train(batch):
325331
'''
@@ -336,7 +342,7 @@ def decouple_batch_train(batch):
336342
inputs, targets = batch[:2]
337343
return inputs, targets
338344

339-
def decouple_batch_eval(batch):
345+
def decouple_batch_test(batch):
340346
'''
341347
Extract and return just the inputs
342348
from a batch assuming it's the first
@@ -600,10 +606,10 @@ def validate_checkpoint_type(checkpoint_type, checkpoint_file=None):
600606
Check that the passed `checkpoint_type` is valid and matches that
601607
obtained from `checkpoint_file`, if provided.
602608
'''
603-
allowed_checkpoint_types = ['state', 'model']
604-
assert checkpoint_type in allowed_checkpoint_types, \
609+
ALLOWED_CHECKPOINT_TYPES = ['state', 'model']
610+
assert checkpoint_type in ALLOWED_CHECKPOINT_TYPES, \
605611
(f'Param "checkpoint_type" ("{checkpoint_type}") '
606-
f'must be one of {allowed_checkpoint_types}.')
612+
f'must be one of {ALLOWED_CHECKPOINT_TYPES}.')
607613

608614
# Check that provided checkpoint_type matches that of checkpoint_file
609615
if checkpoint_file is not None:

pytorch_common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _set_pooler(self):
665665
from transformers.configuration_auto import CONFIG_MAPPING
666666

667667
# Get a list of all supported model types ('bert', 'distilbert', etc.)
668-
self.supported_model_types = CONFIG_MAPPING.keys()
668+
self.supported_model_types = list(CONFIG_MAPPING.keys())
669669

670670
# Use default pooler if not supported
671671
if self.model_type not in self.supported_model_types:

0 commit comments

Comments
 (0)