@@ -37,20 +37,22 @@ def before_init(self):
37
37
38
38
def before_epoch_start (self , epoch ):
39
39
self .raw_output = []
40
- self .predicts = np . array ([]). astype ( np . float )
41
- self .targets = np . array ([]). astype ( np . int )
40
+ self .predicts = []
41
+ self .targets = []
42
42
43
43
def after_val_iteration_ended (self , predicts , data , ** ignore ):
44
44
raw_output = predicts .detach ().cpu ().numpy ()
45
45
predicts = np .argmax (raw_output , axis = 1 )
46
- predicts = predicts .reshape ([- 1 ])
47
- targets = data [1 ].cpu ().numpy ().reshape ([- 1 ])
46
+ targets = data [1 ].cpu ().numpy ()
48
47
49
48
self .raw_output .append (raw_output )
50
- self .predicts = np . concatenate (( self . predicts , predicts ) )
51
- self .targets = np . concatenate (( self . targets , targets ) )
49
+ self .predicts . append ( predicts )
50
+ self .targets . append ( targets )
52
51
53
52
def after_epoch_end (self , val_loss , ** ignore ):
53
+ self .predicts = np .concatenate (self .predicts )
54
+ self .targets = np .concatenate (self .targets )
55
+
54
56
self ._save_results ()
55
57
self .accuracy and self ._accuracy ()
56
58
self .confusion_matrix and self ._confusion_matrix ()
@@ -119,5 +121,5 @@ def _kappa_score(self):
119
121
120
122
def _save_results (self ):
121
123
file_name = self .plugin_file (f"result.{ self .current_epoch } .npz" )
122
- raw_output = np .stack (self .raw_output )
124
+ raw_output = np .concatenate (self .raw_output )
123
125
np .savez_compressed (file_name , predicts = self .predicts , targets = self .targets , raw_output = raw_output )
0 commit comments