@@ -32,13 +32,6 @@ def __init__(
3232 def forward (self , x ):
3333 return self .model (x )
3434
35- def _log_metrics (self ):
36- if self .trainer .is_global_zero :
37- str_metrics = ''
38- for key , val in self .trainer .logged_metrics .items ():
39- str_metrics += f'\n \t { key } : { val } '
40- logger .info (str_metrics )
41-
4235 def training_step (self , batch , batch_idx ):
4336 images , target = batch
4437 output = self (images )
@@ -48,7 +41,7 @@ def training_step(self, batch, batch_idx):
4841 self .log ('train_acc1' , acc1 , on_step = True , prog_bar = True , on_epoch = True )
4942 self .log ('train_acc5' , acc5 , on_step = True , on_epoch = True )
5043 return loss_train
51-
44+
5245 def validation_step (self , batch , batch_idx ):
5346 images , target = batch
5447 output = self (images )
@@ -59,9 +52,6 @@ def validation_step(self, batch, batch_idx):
5952 self .log ('val_acc5' , acc5 , on_epoch = True )
6053 return loss_val
6154
62- def on_validation_end (self ):
63- self ._log_metrics ()
64-
6555 @staticmethod
6656 def __accuracy (output , target , topk = (1 , )):
6757 """Computes the accuracy over the k top predictions for the specified values of k"""
@@ -89,6 +79,16 @@ def test_step(self, batch, batch_idx):
8979 self .log ('test_acc5' , acc5 , on_epoch = True )
9080 return loss_test
9181
82+ def _log_metrics (self ):
83+ if self .trainer .is_global_zero :
84+ str_metrics = ''
85+ for key , val in self .trainer .logged_metrics .items ():
86+ str_metrics += f'\n \t { key } : { val } '
87+ logger .info (str_metrics )
88+
89+ def on_validation_end (self ):
90+ self ._log_metrics ()
91+
9292 def on_test_end (self ):
9393 self ._log_metrics ()
9494
0 commit comments