@@ -32,13 +32,6 @@ def __init__(
32
32
def forward (self , x ):
33
33
return self .model (x )
34
34
35
- def _log_metrics (self ):
36
- if self .trainer .is_global_zero :
37
- str_metrics = ''
38
- for key , val in self .trainer .logged_metrics .items ():
39
- str_metrics += f'\n \t { key } : { val } '
40
- logger .info (str_metrics )
41
-
42
35
def training_step (self , batch , batch_idx ):
43
36
images , target = batch
44
37
output = self (images )
@@ -48,7 +41,7 @@ def training_step(self, batch, batch_idx):
48
41
self .log ('train_acc1' , acc1 , on_step = True , prog_bar = True , on_epoch = True )
49
42
self .log ('train_acc5' , acc5 , on_step = True , on_epoch = True )
50
43
return loss_train
51
-
44
+
52
45
def validation_step (self , batch , batch_idx ):
53
46
images , target = batch
54
47
output = self (images )
@@ -59,9 +52,6 @@ def validation_step(self, batch, batch_idx):
59
52
self .log ('val_acc5' , acc5 , on_epoch = True )
60
53
return loss_val
61
54
62
- def on_validation_end (self ):
63
- self ._log_metrics ()
64
-
65
55
@staticmethod
66
56
def __accuracy (output , target , topk = (1 , )):
67
57
"""Computes the accuracy over the k top predictions for the specified values of k"""
@@ -89,6 +79,16 @@ def test_step(self, batch, batch_idx):
89
79
self .log ('test_acc5' , acc5 , on_epoch = True )
90
80
return loss_test
91
81
82
+ def _log_metrics (self ):
83
+ if self .trainer .is_global_zero :
84
+ str_metrics = ''
85
+ for key , val in self .trainer .logged_metrics .items ():
86
+ str_metrics += f'\n \t { key } : { val } '
87
+ logger .info (str_metrics )
88
+
89
+ def on_validation_end (self ):
90
+ self ._log_metrics ()
91
+
92
92
def on_test_end (self ):
93
93
self ._log_metrics ()
94
94
0 commit comments