@@ -58,6 +58,7 @@ def correct_argument_order(func: Callable) -> Callable:
     return corrected_func
 
 
+@torch.no_grad()
 def get_mse_loss(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     """
     Compute MSE loss.
@@ -75,14 +76,56 @@ def auc_score(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     return auc(fpr, tpr)
 
 
+@torch.no_grad()
+def top_k_accuracy_scores(
+    eval_metrics: List[object], results: _StringDict, y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs
+) -> None:
+    """
+    A batched implementation of `top_k_accuracy_score()`.
+
+    It takes all top-k-accuracy-based metrics as input
+    and computes them in a single pass. Computed
+    separately, the top-k indices would have to be
+    recomputed for each k; here, one `torch.topk()`
+    call with the largest k serves every requested k.
+    """
+    assert len(y_predicted) == len(y_true)
+
+    ks = []
+    for eval_metric in eval_metrics:
+        k = int(match(eval_metric.criterion, TOP_K_ACCURACY_REGEX))  # Extract k
+        ks.append(k)
+
+    max_k = max(ks)
+    _, top_indices = torch.topk(y_predicted, max_k, dim=1)  # Compute the top `max_k` predicted classes
+    top_indices = top_indices.t()  # Transpose to (max_k, batch_size) so rows are ordered by prediction rank
+    correct_max_k = top_indices.eq(
+        y_true.long().view(1, -1).expand_as(top_indices)
+    )  # Get correct predictions in top max_k
+
+    # Compute top-k accuracy for all k's
+    for i, k in enumerate(ks):
+        correct_k = correct_max_k[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # Correct predictions in top k
+        top_k_accuracy = correct_k / len(y_true)  # Divide by the batch size
+        results[eval_metrics[i].criterion] = top_k_accuracy.item()
+
+
113
+@torch.no_grad()
 def top_k_accuracy_score(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     """
     Compute the top-k accuracy score
     in a multi-class setting.
 
     Conversion to numpy is expensive in this
     case. Stick to using PyTorch tensors.
+
+    Note: If this metric is needed for more than
+    one k, prefer the much more efficient
+    `top_k_accuracy_scores()` instead.
     """
+    assert len(y_predicted) == len(y_true)
+
     k = int(match(kwargs["criterion"], TOP_K_ACCURACY_REGEX))  # Extract k
     _, topk_indices = torch.topk(y_predicted, k, dim=1)  # Compute the top-k predicted classes
     correct_examples = torch.eq(y_true[..., None].long(), topk_indices).any(dim=1)
@@ -191,7 +234,7 @@ def __repr__(self):
     },
     "auc": {"preprocess_fn": prob_class1, "eval_fn": auc_score, "model_type": "classification"},
     "top_k_accuracy": {
-        "eval_fn": top_k_accuracy_score,
+        "eval_fn": top_k_accuracy_score,  # Not actually used; `top_k_accuracy_scores()` is preferred for efficiency
         "regex": TOP_K_ACCURACY_REGEX,
         "model_type": "classification",
     },
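
A minimal standalone sketch of the batched top-k computation this commit adds. The metric objects, criterion strings, and regex below are stand-ins, since the project's `match()` helper, `TOP_K_ACCURACY_REGEX`, and eval-metric class are not shown in this diff; treat those names as assumptions, not the project's API:

import re
from dataclasses import dataclass
from typing import Dict, List

import torch

# Assumed criterion format, e.g. "top_5_accuracy"; the project's regex may differ.
TOP_K_REGEX_SKETCH = r"top_(\d+)_accuracy"

@dataclass
class MetricSketch:  # Hypothetical stand-in for the project's eval-metric objects
    criterion: str

@torch.no_grad()
def top_k_accuracy_scores_sketch(
    eval_metrics: List[MetricSketch],
    results: Dict[str, float],
    y_predicted: torch.Tensor,  # (N, num_classes) scores
    y_true: torch.Tensor,       # (N,) integer class labels
) -> None:
    ks = [int(re.match(TOP_K_REGEX_SKETCH, m.criterion).group(1)) for m in eval_metrics]
    # One topk() call with the largest k: row i of the transposed index matrix
    # holds the rank-(i + 1) prediction for every example, so the first k rows
    # cover the top-k predictions for any k <= max(ks).
    _, top_indices = torch.topk(y_predicted, max(ks), dim=1)
    top_indices = top_indices.t()  # (max_k, N)
    correct_max_k = top_indices.eq(y_true.long().view(1, -1).expand_as(top_indices))
    for metric, k in zip(eval_metrics, ks):
        correct_k = correct_max_k[:k].reshape(-1).float().sum()
        results[metric.criterion] = (correct_k / len(y_true)).item()

# Toy check: 4 examples, 3 classes, no ties.
y_pred = torch.tensor([[0.10, 0.70, 0.20],
                       [0.80, 0.15, 0.05],
                       [0.30, 0.25, 0.45],
                       [0.30, 0.50, 0.20]])
y_true = torch.tensor([1, 0, 1, 0])
results: Dict[str, float] = {}
top_k_accuracy_scores_sketch(
    [MetricSketch("top_1_accuracy"), MetricSketch("top_2_accuracy")], results, y_pred, y_true
)
print(results)  # {'top_1_accuracy': 0.5, 'top_2_accuracy': 0.75}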
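
The `@torch.no_grad()` decorators are the other substantive change here: metric code never needs gradients, and without the decorator a metric computed on a tensor that requires grad would silently extend the autograd graph and hold onto memory during evaluation. A quick illustration of the effect (not project code):

import torch

y_pred = torch.randn(8, 10, requires_grad=True)  # e.g. model output still attached to the graph
y_true = torch.randn(8, 10)

mse_with_graph = ((y_pred - y_true) ** 2).mean()
print(mse_with_graph.requires_grad)  # True: the metric drags the graph along

with torch.no_grad():  # same effect as decorating the metric function
    mse_detached = ((y_pred - y_true) ** 2).mean()
print(mse_detached.requires_grad)  # False: no graph retained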