@@ -142,28 +142,27 @@ def compute(self) -> float:
142
142
# for which there are no GT examples and renormalize probabilities
143
143
# This will give wrong results, but allows getting a value during training
144
144
# where it cannot be guaranteed that each batch contains all classes.
145
+ curr_unique_label = np .unique (labels ).tolist ()
145
146
if self .parent .unique_labels is None :
146
- unique_labels = np .unique (labels ).tolist ()
147
- curr_unique_label = unique_labels
147
+ unique_labels = curr_unique_label
148
148
else :
149
149
# If we are testing on a small subset of data and by chance it does not
150
150
# contain all classes, we need to provide the groundtruth labels
151
151
# separately.
152
152
unique_labels = self .parent .unique_labels
153
- curr_unique_label = np .unique (labels ).tolist ()
154
-
155
153
probs = out .probs [..., unique_labels ]
156
154
probs /= probs .sum (axis = - 1 , keepdims = True ) # renormalize
157
155
check_type (probs , Float ["b n" ])
158
156
if len (unique_labels ) == 2 :
159
- # Binary mode: make it binary, otherwise sklearn complains.
157
+ # Binary mode: make it binary and assume positive class is 1, otherwise
158
+ # sklearn complains.
160
159
assert (
161
160
probs .shape [- 1 ] == 2
162
161
), f"Unique labels are binary but probs.shape is { probs .shape } "
163
162
probs = probs [..., 1 ]
164
163
mask = out .mask [..., 0 ].astype (np .float32 )
165
164
check_type (mask , Float ["b" ])
166
- if len (curr_unique_label ) > 1 :
165
+ if len (curr_unique_label ) == len ( unique_labels ) :
167
166
# See comment above about small data subsets.
168
167
return sklearn_metrics .roc_auc_score (
169
168
y_true = labels ,
0 commit comments