Skip to content

Commit e787dec

Browse files
akuznetsa and The kauldron Authors
authored and committed
Adding NormalizeToRange op and support for unique_labels in the RocAuc computation.
PiperOrigin-RevId: 691053530
1 parent 53d078f commit e787dec

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

kauldron/metrics/classification.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,28 +142,27 @@ def compute(self) -> float:
142142
# for which there are no GT examples and renormalize probabilities
143143
# This will give wrong results, but allows getting a value during training
144144
# where it cannot be guaranteed that each batch contains all classes.
145+
curr_unique_label = np.unique(labels).tolist()
145146
if self.parent.unique_labels is None:
146-
unique_labels = np.unique(labels).tolist()
147-
curr_unique_label = unique_labels
147+
unique_labels = curr_unique_label
148148
else:
149149
# If we are testing on a small subset of data and by chance it does not
150150
# contain all classes, we need to provide the groundtruth labels
151151
# separately.
152152
unique_labels = self.parent.unique_labels
153-
curr_unique_label = np.unique(labels).tolist()
154-
155153
probs = out.probs[..., unique_labels]
156154
probs /= probs.sum(axis=-1, keepdims=True) # renormalize
157155
check_type(probs, Float["b n"])
158156
if len(unique_labels) == 2:
159-
# Binary mode: make it binary, otherwise sklearn complains.
157+
# Binary mode: make it binary and assume positive class is 1, otherwise
158+
# sklearn complains.
160159
assert (
161160
probs.shape[-1] == 2
162161
), f"Unique labels are binary but probs.shape is {probs.shape}"
163162
probs = probs[..., 1]
164163
mask = out.mask[..., 0].astype(np.float32)
165164
check_type(mask, Float["b"])
166-
if len(curr_unique_label) > 1:
165+
if len(curr_unique_label) == len(unique_labels):
167166
# See comment above about small data subsets.
168167
return sklearn_metrics.roc_auc_score(
169168
y_true=labels,

kauldron/metrics/classification_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,37 @@ def test_roc():
4747
s3.merge(s0)
4848
with pytest.raises(ValueError, match='from different metrics'):
4949
s0.merge(s3)
50+
51+
52+
def test_roc_with_binary_labels():
53+
metric = metrics.RocAuc()
54+
55+
logits = jnp.asarray([
56+
[1.0, 0.0],
57+
[0.3, 0.4],
58+
[0.0, 2.0],
59+
])
60+
labels = jnp.asarray([[1], [0], [1]], dtype=jnp.int32)
61+
62+
s0 = metric.get_state(logits=logits, labels=labels)
63+
64+
x = s0.compute()
65+
np.testing.assert_allclose(x, 0.5)
66+
67+
68+
69+
def test_roc_with_unique_labels():
70+
metric = metrics.RocAuc(unique_labels=[0, 1, 2])
71+
72+
logits = jnp.asarray([
73+
[1.0, 0.0, 0.0],
74+
[0.3, 0.4, 0.3],
75+
[0.0, 2.0, 8.0],
76+
])
77+
labels = jnp.asarray([[2], [0], [2]], dtype=jnp.int32)
78+
79+
s0 = metric.get_state(logits=logits, labels=labels)
80+
x = s0.compute()
81+
np.testing.assert_allclose(x, 0.0)
82+
83+

0 commit comments

Comments
 (0)