Skip to content

Commit

Permalink
early stop in training
Browse files · Browse the repository at this point in the history
  • Loading branch information
Natooz committed Dec 10, 2022
1 parent 2d114a5 commit 18f5fdf
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 36 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author='Nathan Fradet',
url='https://github.com/Natooz/TorchToolkit',
packages=find_packages(exclude=("test",)),
version='0.0.3',
version='0.0.4',
license='MIT',
description='Useful functions to use with PyTorch',
long_description=long_description,
Expand Down
6 changes: 5 additions & 1 deletion test/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_train():
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from torchtoolkit.train import train, log_cuda_info, log_model_parameters
from torchtoolkit.train import train, log_cuda_info, log_model_parameters, ValidAccModeParameters
from torchtoolkit.data import create_subsets
from torchtoolkit.sampling import top_k

Expand Down Expand Up @@ -67,6 +67,10 @@ def forward_train(self, x: LongTensor, target: LongTensor, crit: Module, k: int
train(model, criterion, optimizer, dataloader_train, dataloader_valid, 100, 10, 10, tensorboard, logger=logger,
log_intvl=10, lr_scheduler=lr_scheduler, gradient_clip_norm=0.1, saving_dir=Path())

train(model, criterion, optimizer, dataloader_train, dataloader_valid, 50, 10, 10, tensorboard, logger=logger,
log_intvl=10, lr_scheduler=lr_scheduler, gradient_clip_norm=0.1, saving_dir=Path(),
iterator_kwargs={'early_stop_steps': 15, 'valid_acc_mode_parameters': ValidAccModeParameters(0.8, 5)})


if __name__ == '__main__':
test_train()
90 changes: 56 additions & 34 deletions torchtoolkit/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from logging import Logger
from pathlib import Path
from typing import List, Tuple, Callable
from typing import List, Callable
from contextlib import contextmanager
from functools import partial

Expand Down Expand Up @@ -70,48 +70,69 @@ def __null_context():
yield


class ValidAccModeParameters:
    """Parameters for the "validation accuracy" training mode of the Iterator.

    In this mode, training runs until the average of the last ``nb_past_steps``
    validation accuracy values reaches ``valid_acc_target``, bounded below by
    ``min_nb_steps`` and above by ``max_nb_steps`` training steps.
    """

    def __init__(self, valid_acc_target: float, nb_past_steps: int, min_nb_steps: int = 0,
                 max_nb_steps: float = float('inf')):
        """
        :param valid_acc_target: validation accuracy value to reach.
        :param nb_past_steps: number of past validation accuracy values used to compute the
            average validation accuracy. If this average reaches valid_acc_target, training
            can be stopped, provided the number of steps performed exceeds min_nb_steps.
        :param min_nb_steps: minimum number of training steps to perform. (default: 0)
        :param max_nb_steps: maximum number of training steps to perform.
            Annotated as float so the +inf default is well-typed; any int is accepted.
            (default: +inf)
        """
        self.valid_acc_target = valid_acc_target
        self.nb_past_steps = nb_past_steps
        self.min_nb_steps = min_nb_steps
        self.max_nb_steps = max_nb_steps


class Iterator:
def __init__(self, nb_steps: int = None, min_nb_steps: int = 0, max_nb_steps: int = float('inf'),
min_valid_acc: Tuple[float, int] = (None, None), pbar_desc: str = 'TRAINING'):
def __init__(self, nb_steps: int = None, early_stop_steps: int = float('inf'),
valid_acc_mode_parameters: ValidAccModeParameters = None, pbar_desc: str = 'TRAINING'):
"""Training iterator class.
Can work in two modes:
1. Number of steps: will be iterated a fixed number of times
2. Min valid accuracy: will be iterated till the model reaches a target validation
2. Valid accuracy: will be iterated till the model reaches a target validation
accuracy value, or if the number of training steps exceeds max_nb_steps.
:param nb_steps: number of training steps. (default None)
:param min_nb_steps: min number of training steps when working with min_valid_acc (default: 0)
:param max_nb_steps: max number of training steps when working with min_valid_acc (default: +inf)
:param min_valid_acc: a set of parameters to use to train a model in "validation accuracy mode".
The first is the minimum valid accuracy value to reach,
the second is the number of past valid accuracy values to use to compute the average validation
accuracy. If this average is > to the minimum valid acc to reach, the training can be stopped
if the number of steps is > to min_nb_steps given above.
(default: (None, None))
minimal validation accuracy value to reach before stopping the iteration. (default None)
:param valid_acc_mode_parameters: parameters to use to train a model in "validation accuracy mode".
(default: None)
:param early_stop_steps: will stop training if the validation accuracy did not increase in this last
number of training steps (default: inf)
:param pbar_desc: progress bar description. (default: TRAINING)
"""
assert nb_steps is not None or all(i is not None for i in min_valid_acc), \
assert nb_steps is not None or valid_acc_mode_parameters is not None, \
'You must give at least nb_steps or min_valid_acc argument to construct the iterator'
self.nb_steps = nb_steps
self.min_nb_steps = min_nb_steps
self.max_nb_steps = max_nb_steps
self._min_valid_acc, self._valid_acc_nb_steps = min_valid_acc
self.valid_acc_params = valid_acc_mode_parameters
self._past_valid_acc = []
self.pbar = tqdm(total=max_nb_steps if self._min_valid_acc is not None else nb_steps, desc=pbar_desc)
self.best_valid_step = 0 # for early stop
self._early_stop_steps = early_stop_steps
self.pbar = tqdm(total=nb_steps if self.valid_acc_params is None else self.valid_acc_params.max_nb_steps,
desc=pbar_desc)

def __iter__(self):
    """Starts (or restarts) the iteration: resets the step counter to 0 and returns self."""
    self.step = 0
    return self

def __next__(self):
if self._min_valid_acc is not None: # min valid acc mode
if self.__is_past_valid_acc_ok() or self.min_nb_steps < self.step < self.max_nb_steps:
# Early stop if valid acc did not increase
if self._early_stop_steps is not None and self.step - self.best_valid_step >= self._early_stop_steps:
raise StopIteration

# Validation target mode
elif self.valid_acc_params is not None:
if self.__is_past_valid_acc_ok() or \
self.valid_acc_params.min_nb_steps < self.step < self.valid_acc_params.max_nb_steps:
return self.__iter_update()
raise StopIteration

elif self.step < self.nb_steps: # nb_steps mode
elif self.step < self.nb_steps:
return self.__iter_update()

raise StopIteration

def __iter_update(self):
Expand All @@ -120,9 +141,11 @@ def __iter_update(self):
return self.step

def __is_past_valid_acc_ok(self) -> bool:
if len(self._past_valid_acc) < self._valid_acc_nb_steps:
if len(self._past_valid_acc) < self.valid_acc_params.nb_past_steps:
return False
return (sum(self._past_valid_acc[-self._valid_acc_nb_steps:]) / self._valid_acc_nb_steps) > self._min_valid_acc
return self.__mean_last_valid_acc(self.valid_acc_params.nb_past_steps) >= self.valid_acc_params.valid_acc_target

def __mean_last_valid_acc(self, nb_val: int) -> float: return sum(self._past_valid_acc[-nb_val:]) / nb_val

def update_valid_acc(self, valid_acc: float):
"""Stores the validation accuracy given in argument. Need to be called at each validation step in order
Expand Down Expand Up @@ -165,16 +188,11 @@ def train(model: Module, criterion: Module, optimizer: Optimizer, dataloader_tra
:param acc_func: accuracy function. (default: torchtoolkit.metrics.calculate_accuracy in greedy mode)
:param valid_metrics: custom metrics to run during validation phase, torchtoolkit.metrics.Metric. (default: None)
:param iterator_kwargs: parameters for the training iterator, to be given as a dictionary as:
- 'min_nb_steps': the minimum number of training steps to perform (default: 0)
- 'max_nb_steps': the maximum number of training step (default: +inf)
- 'min_valid_acc': Tuple[float, int] , a set of parameters to use
to training a model in "validation accuracy mode".
The first is the minimum valid accuracy value to reach,
the second is the number of past valid accuracy values to use to compute the average validation
accuracy. If this average is > to the minimum valid acc to reach, the training can be stopped
if the number of steps is > to min_nb_steps given above.
(default: (None, None))
default value of iterator_params is None, leading to the default value of the Iterator class.
- 'valid_acc_mode_parameters': parameters to train the model until it reaches a target
validation accuracy value, see ValidAccModeParameters (default: None)
- 'early_stop_steps': will stop training if the validation accuracy did not increase in
this last number of training steps (default: inf)
(default: None)
:param lr_scheduler: learning rate scheduler. (default: None)
:param device: device to run on (default: None --> select_device(use_cuda=True))
:param use_amp: to use Automatic Mixed Precision (AMP) during training. (default: True)
Expand All @@ -190,6 +208,7 @@ def train(model: Module, criterion: Module, optimizer: Optimizer, dataloader_tra
model = model.to(device)
model.train()
best_valid_loss = float('inf')
best_valid_loss_step = 0
last_loss_valid = last_acc_valid = 0 # use for pbar postfix
train_iter = iter(dataloader_train)
valid_iter = iter(dataloader_valid)
Expand Down Expand Up @@ -254,6 +273,8 @@ def train(model: Module, criterion: Module, optimizer: Optimizer, dataloader_tra
# Save model if loss as decreased
if saving_dir is not None and valid_loss < best_valid_loss:
best_valid_loss = valid_loss
best_valid_loss_step = training_step
iterator.best_valid_step = training_step
save({'training_step': training_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
Expand All @@ -263,7 +284,8 @@ def train(model: Module, criterion: Module, optimizer: Optimizer, dataloader_tra
model.train()

iterator.pbar.set_postfix({'train_loss': f'{last_loss_train:.4f}', 'train_acc': f'{last_acc_train:.4f}',
'valid_loss': f'{last_loss_valid:.4f}', 'valid_acc': f'{last_acc_valid:.4f}'},
'valid_loss': f'{last_loss_valid:.4f}', 'valid_acc': f'{last_acc_valid:.4f}',
'best_valid_acc_step': best_valid_loss_step},
refresh=False)
if logger is not None and training_step % log_intvl == 0:
logger.debug(str(iterator.pbar).encode('utf-8'))
Expand Down

0 comments on commit 18f5fdf

Please sign in to comment.