-
Notifications
You must be signed in to change notification settings - Fork 13
/
early_stopping.py
94 lines (73 loc) · 2.84 KB
/
early_stopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
class EarlyStopping:
    """Early-stop the training when a validation metric stops improving.

    Training is stopped when either

    * the metric has been worse than the best value seen so far (by more
      than a relative tolerance) for ``wait_epochs`` consecutive calls, or
    * the metric has stayed within a tight relative band around the best
      value for ``wait_epochs + 1`` consecutive calls (it stopped changing).

    Whenever a new best (or equal) metric is observed, the model's
    ``state_dict`` is written to ``checkpoint_path`` with ``torch.save``.

    Parameters
    ----------
    wait_epochs: int, optional (default=2)
        Number of epochs to wait for the metric to improve before stopping
        the training.
    checkpoint_path: str or os.PathLike, optional (default='checkpoint.pt')
        File the best model state is saved to.
    """

    def __init__(self, wait_epochs=2, checkpoint_path='checkpoint.pt'):
        self.wait_epochs = wait_epochs
        self.checkpoint_path = checkpoint_path
        # Consecutive calls where the metric was worse than best by > tolerance.
        self.num_bad_scores = 0
        # Consecutive calls where the metric stayed within delta/3 of best.
        self.num_const_scores = 0
        self.best_score = None
        # Metric value of the most recently saved checkpoint.
        self.best_metric = 0

    def stop(self, metric, model, metric_type='better_decrease', delta=0.03):
        """Decide whether training should stop after this validation step.

        Parameters
        ----------
        metric: float
            Metric used to evaluate the validation performance.
        model: pytorch model
            Pytorch model instance; its ``state_dict`` is saved whenever the
            metric reaches a new best value.
        metric_type: str, optional (default='better_decrease')
            Specify the metric type, available options: better_decrease,
            better_increase. Any value other than 'better_decrease' is
            treated as 'better_increase'.
        delta: float, optional (default=0.03)
            The minimum change of a metric that is considered in the stopping
            decision, as a fraction of the metric's magnitude. The plateau
            check uses a third of this fraction.

        Returns
        -------
        Boolean
            True if training should be stopped, otherwise False.
        """
        self.delta = delta
        # Tolerances are relative to the metric's magnitude. abs() keeps them
        # non-negative for negative metrics (e.g. log-likelihood); a negative
        # tolerance would invert every comparison below.
        scale = abs(metric)
        tolerance = self.delta * scale

        if self.best_score is None:
            self.best_score = metric
            self.save_model_state(metric, model)
            return False

        # Plateau detection: metric stays in a tight band around the best one.
        if abs(metric - self.best_score) < self.delta / 3 * scale:
            self.num_const_scores += 1
            if self.num_const_scores >= self.wait_epochs + 1:
                print('\nTraining stoped by EarlyStopping')
                return True
        else:
            self.num_const_scores = 0

        if metric_type == 'better_decrease':
            significantly_worse = metric > self.best_score + tolerance
            slightly_worse = metric > self.best_score
        else:
            significantly_worse = metric < self.best_score - tolerance
            slightly_worse = metric < self.best_score

        if significantly_worse:
            self.num_bad_scores += 1
        elif slightly_worse:
            # Worse, but within tolerance: not counted as a bad epoch.
            self.num_bad_scores = 0
        else:
            # New best (or tie): remember it and checkpoint the model.
            self.best_score = metric
            self.save_model_state(metric, model)
            self.num_bad_scores = 0

        if self.num_bad_scores >= self.wait_epochs:
            print('\nTraining stoped by EarlyStopping')
            return True
        return False

    def save_model_state(self, metric, model):
        """Save ``model.state_dict()`` to the checkpoint file and record ``metric``."""
        # getattr fallback keeps instances created without __init__ working
        # with the original default path.
        path = getattr(self, 'checkpoint_path', 'checkpoint.pt')
        torch.save(model.state_dict(), path)
        self.best_metric = metric