
Commit aadd32e

Add Metric.from_mask helper method (facebookresearch#3411)
1 parent 94f1b9c

File tree

parlai/core/metrics.py
parlai/core/torch_generator_agent.py
tests/test_metrics.py

3 files changed: +134 −15 lines changed

parlai/core/metrics.py

Lines changed: 23 additions & 1 deletion

@@ -24,6 +24,7 @@
     Optional,
     Set,
     Tuple,
+    Type,
     Union,
 )
 
@@ -272,7 +273,7 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
         """
         Construct many of a Metric from the base parts.
 
-        Useful if you separately compute numerators and denomenators, etc.
+        Useful if you separately compute numerators and denominators, etc.
         """
         lengths = [len(o) for o in objs]
         objs = list(objs)  # convert from tuple for inplace modification
@@ -286,6 +287,27 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
             raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
         return [cls(*items) for items in zip(*objs)]
 
+    @classmethod
+    def from_mask(
+        cls, metric_per_token: torch.Tensor, mask: torch.Tensor, MyMetric: Type[Metric]
+    ) -> List[Metric]:
+        """
+        From token-level metrics, returns an aggregate MyMetric per example in the batch.
+
+        :param metric_per_token:
+            a (batchsize x num_tokens) Tensor
+        :param mask:
+            a (batchsize x num_tokens) Tensor to mask out tokens that should
+            *not* be considered in the aggregate metric calculation.
+        :param MyMetric:
+            a subclass of Metric
+        :return:
+            a list of MyMetric objects of length batchsize
+        """
+        tokens_per_ex = mask.long().sum(dim=-1)
+        metric_per_ex = (metric_per_token * mask).sum(dim=-1)
+        metrics = MyMetric.many(metric_per_ex, tokens_per_ex)
+        return metrics
+
 
 class FixedMetric(Metric):
     """

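To make the new helper concrete, here is a minimal usage sketch. The tensors and expected values are invented for illustration; they are not part of the commit.

import torch

from parlai.core.metrics import AverageMetric, Metric

# Hypothetical toy batch: 2 examples, 4 tokens each.
metric_per_token = torch.FloatTensor([[1, 2, 3, 4], [5, 5, 5, 5]])
# Mask out the last two tokens of the second example.
mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 0, 0]])

# One AverageMetric per example: (1+2+3+4)/4 = 2.5 and (5+5)/2 = 5.0.
metrics = Metric.from_mask(metric_per_token, mask, AverageMetric)
print([m.value() for m in metrics])  # [2.5, 5.0]

Passing the metric class as an argument lets the same masking logic back any Metric subclass (AverageMetric, PPLMetric, and so on), which is exactly how compute_loss below reuses it.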
parlai/core/torch_generator_agent.py

Lines changed: 19 additions & 12 deletions

@@ -34,7 +34,7 @@
 from parlai.utils.misc import warn_once
 from parlai.utils.io import PathManager
 import parlai.utils.logging as logging
-from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
+from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
 from parlai.utils.fp16 import FP16SafeCrossEntropy
 import parlai.utils.fsdp as fsdp_utils
 from parlai.utils.torch import (
@@ -710,28 +710,35 @@ def compute_loss(self, batch, return_output=False):
         model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
         scores, preds, *_ = model_output
         score_view = scores.reshape(-1, scores.size(-1))
-        loss = self.criterion(score_view, batch.label_vec.view(-1))
-        loss = loss.view(scores.shape[:-1]).sum(dim=1)
-        # save loss to metrics
+        loss_flattened = self.criterion(score_view, batch.label_vec.view(-1))
+        loss_per_token = loss_flattened.view(scores.shape[:-1])
         notnull = batch.label_vec.ne(self.NULL_IDX)
-        target_tokens = notnull.long().sum(dim=-1)
-        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
 
+        # save loss to metrics
         # cross entropy loss
-        self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
+        self.record_local_metric(
+            'loss', Metric.from_mask(loss_per_token, notnull, AverageMetric)
+        )
         # perplexity
-        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
+        self.record_local_metric(
+            'ppl', Metric.from_mask(loss_per_token, notnull, PPLMetric)
+        )
         # token-wise accuracy
         self.record_local_metric(
-            'token_acc', AverageMetric.many(correct, target_tokens)
+            'token_acc',
+            Metric.from_mask(batch.label_vec == preds, notnull, AverageMetric),
         )
         # utterance-wise exact match
+        num_target_tokens = notnull.long().sum(dim=-1)
+        num_tokens_correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
         self.record_local_metric(
-            'token_em', AverageMetric.many(correct == target_tokens)
+            'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens)
         )
+
         # actually do backwards loss
+        loss = loss_per_token.sum(dim=1)
         loss = loss.sum()
-        loss /= target_tokens.sum()  # average loss per token
+        loss /= num_target_tokens.sum()  # average loss per token
         if return_output:
             return (loss, model_output)
         else:
@@ -1440,7 +1447,7 @@ def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType:
 
     def get_output_from_current_step(self):
         """
-        Get the outputput at the current step.
+        Get the output at the current step.
         """
         return self.outputs[-1]
 

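One step of reasoning the diff leaves implicit: the old code summed the loss over all token positions without masking, which was safe only because ParlAI's cross entropy criterion is built with ignore_index=NULL_IDX and so already emits zero loss at padding positions; from_mask zeroes those positions explicitly instead. A toy sketch of that equivalence (tensors invented for illustration, assuming zero loss at padding):

import torch

# Hypothetical per-token losses for 2 examples x 3 tokens; the last token of
# example 2 is padding, where an ignore_index criterion already yields 0 loss.
loss_per_token = torch.FloatTensor([[0.5, 0.3, 0.2], [0.4, 0.6, 0.0]])
notnull = torch.BoolTensor([[True, True, True], [True, True, False]])

# Old path: sum every position (padding contributes 0 anyway).
old_numerator = loss_per_token.sum(dim=1)

# New path: mask explicitly, as Metric.from_mask does internally.
new_numerator = (loss_per_token * notnull).sum(dim=-1)

print(torch.equal(old_numerator, new_numerator))  # True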
tests/test_metrics.py

Lines changed: 92 additions & 2 deletions

@@ -13,6 +13,7 @@
     AverageMetric,
     SumMetric,
     FixedMetric,
+    Metric,
     Metrics,
     GlobalAverageMetric,
     MacroAverageMetric,
@@ -28,6 +29,7 @@
     WeightedF1Metric,
     AUCMetrics,
 )
+from parlai.core.torch_generator_agent import PPLMetric
 import parlai.utils.testing as testing_utils
 
 
@@ -70,7 +72,6 @@ def test_sum_metric_additions(self):
         self.assertAlmostEqual(actual_output, output, places=6)
 
     def test_average_metric_inputs(self):
-
         passing_inputs_and_outputs = [
             ((2, 4), 0.5),
             ((17.0, 10.0), 1.7),
@@ -91,7 +92,6 @@ def test_average_metric_inputs(self):
             AverageMetric(input_[0], input_[1])
 
     def test_average_metric_additions(self):
-
         input_pairs_and_outputs = [
             ((2, 4), (1.5, 1), 0.7),
             (
@@ -120,6 +120,96 @@ def test_macroaverage_additions(self):
         assert (m1 + m2) == AverageMetric(4, 7)
         assert MacroAverageMetric({'a': m1, 'b': m2}) == 0.5 * (1.0 / 3 + 3.0 / 4)
 
+    def test_average_metric_from_mask(self) -> None:
+        # first test case. batchsize=3, num_tokens=10
+        token_values_1 = torch.FloatTensor(
+            [
+                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+                [-10, -8, -6, -4, -2, 0, 2, 4, 6, 8],
+                [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0],
+            ]
+        )
+        token_mask_1 = torch.LongTensor(
+            [
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
+            ]
+        )
+        output_1 = [
+            AverageMetric(55, 10),
+            AverageMetric(-30, 6),
+            AverageMetric(12.5, 5),
+        ]
+
+        # second test case. batchsize=4, num_tokens=5
+        token_values_2 = torch.FloatTensor(
+            [
+                [1, 2, 3, 4, 5],
+                [1.5, 0, -1, 3, -4],
+                [-3, -2, -1, 0, 1],
+                [4, 5, 6, 7, 8],
+            ]
+        )
+        token_mask_2 = torch.LongTensor(
+            [
+                [1, 1, 1, 1, 1],
+                [1, 1, 1, 0, 0],
+                [1, 0, 1, 0, 1],
+                [0, 0, 0, 0, 0],
+            ]
+        )
+        output_2 = [
+            AverageMetric(15, 5),
+            AverageMetric(0.5, 3),
+            AverageMetric(-3, 3),
+            AverageMetric(0, 0),
+        ]
+
+        input_and_outputs = [
+            (token_values_1, token_mask_1, output_1),
+            (token_values_2, token_mask_2, output_2),
+        ]
+
+        for token_values, token_mask, output in input_and_outputs:
+            actual_output = Metric.from_mask(token_values, token_mask, AverageMetric)
+            self.assertEqual(len(actual_output), len(output))
+            # Because Metric.from_mask() calls Metric.many(), which in turn
+            # converts tensors to lists, it is possible for the actual and expected
+            # outputs to be close to each other but not exactly equal.
+            for a, o in zip(actual_output, output):
+                self.assertIsInstance(a, type(o))
+                self.assertAlmostEqual(a.value(), o.value(), places=6)
+
+    def test_ppl_metric_from_mask(self) -> None:
+        # batchsize=3, num_tokens=10
+        token_values = torch.FloatTensor(
+            [
+                [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+                [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+                [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+            ]
+        )
+        token_mask = torch.LongTensor(
+            [
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            ]
+        )
+        output = [
+            PPLMetric(4.5, 10),
+            PPLMetric(0.6, 6),
+            PPLMetric(0, 0),
+        ]
+        actual_output = Metric.from_mask(token_values, token_mask, PPLMetric)
+
+        self.assertEqual(len(actual_output), len(output))
+        # Because Metric.from_mask() calls Metric.many(), which in turn
+        # converts tensors to lists, it is possible for the actual and expected
+        # outputs to be close to each other but not exactly equal.
+        for a, o in zip(actual_output, output):
+            self.assertIsInstance(a, type(o))
+            self.assertAlmostEqual(a.value(), o.value(), places=6)
+
 
 class TestMetrics(unittest.TestCase):
     """

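As a hand check on the expected values above (arithmetic only, not part of the commit): in the PPL test, the second row keeps six tokens of loss 0.1 each, giving PPLMetric(0.6, 6), i.e. an average loss of 0.1; assuming PPLMetric exponentiates the average loss, as perplexity conventionally does, its reported value is exp(0.1) ≈ 1.105. The sketch below reproduces that row's masked sum and count:

import math

import torch

# Second row of the PPL test: ten tokens of loss 0.1, first six unmasked.
values = torch.FloatTensor([0.1] * 10)
mask = torch.LongTensor([1] * 6 + [0] * 4)

total = float((values * mask).sum())  # ~0.6, matching PPLMetric(0.6, 6)
count = int(mask.sum())               # 6
print(total / count, math.exp(total / count))  # ~0.1 and ~1.105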