From b8056f992a794e15e1f87ff0629fd694015964e0 Mon Sep 17 00:00:00 2001 From: Labbeti Date: Wed, 26 Jun 2024 16:47:36 +0200 Subject: [PATCH] Add: Vocab precision, recall and f1 in vocab metric. --- src/aac_metrics/classes/vocab.py | 3 +++ src/aac_metrics/functional/vocab.py | 30 +++++++++++++++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/aac_metrics/classes/vocab.py b/src/aac_metrics/classes/vocab.py index 9e9d845..6b1c474 100644 --- a/src/aac_metrics/classes/vocab.py +++ b/src/aac_metrics/classes/vocab.py @@ -68,6 +68,9 @@ def get_output_names(self) -> tuple[str, ...]: "vocab.mrefs_avg", "vocab.mrefs_std", "vocab.ratio_avg", + "vocab.precision", + "vocab.recall", + "vocab.f1", ) def reset(self) -> None: diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index fa41c37..a31cc08 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -55,8 +55,8 @@ def vocab( tok_cands = list(map(tokenizer, candidates)) del candidates - vocab_cands_len = _corpus_vocab(tok_cands, dtype) - _, vocab_per_cand = _sent_vocab(tok_cands, dtype) + vocab_cands_len = _corpus_vocab_size(tok_cands, dtype) + _, vocab_per_cand = _sent_vocab_sizes(tok_cands, dtype) if not return_all_scores: return vocab_cands_len @@ -75,7 +75,18 @@ def vocab( tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] del mult_references - vocab_mrefs_len_full = _corpus_vocab( + corpus_vocab_cands = set(token for cand in tok_cands for token in cand) + corpus_vocab_mrefs = set( + token for refs in tok_mrefs for ref in refs for token in ref + ) + + tp = corpus_vocab_cands.intersection(corpus_vocab_mrefs) + fn = corpus_vocab_mrefs.difference(corpus_vocab_cands) + vocab_precision = len(tp) / len(corpus_vocab_cands) + vocab_recall = len(tp) / (len(tp) + len(fn)) + vocab_f1 = 2 * vocab_precision * vocab_recall / (vocab_precision + vocab_recall) + + vocab_mrefs_len_full = _corpus_vocab_size( [ref for refs in tok_mrefs for ref in refs], dtype ) vocab_ratio_len_full = vocab_cands_len / vocab_mrefs_len_full @@ -103,7 +114,7 @@ def vocab( for i in range(num_try): popped_refs, _ = _sample_sentences_split(tok_mrefs, generator=generator) - vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype) + vocab_mrefs_len_i = _corpus_vocab_size(popped_refs, dtype) vocab_mrefs_lens[i] = vocab_mrefs_len_i vocab_mrefs_avg = vocab_mrefs_lens.mean() @@ -114,19 +125,22 @@ def vocab( "vocab.ratio_full": vocab_ratio_len_full, "vocab.mrefs_avg": vocab_mrefs_avg, "vocab.ratio_avg": vocab_len_ratio_avg, + "vocab.precision": vocab_precision, + "vocab.recall": vocab_recall, + "vocab.f1": vocab_f1, } vocab_outs = corpus_scores, sents_scores return vocab_outs # type: ignore -def _corpus_vocab(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor: - corpus_cands_vocab = set(token for sent in tok_sents for token in sent) - vocab_len = torch.as_tensor(len(corpus_cands_vocab), dtype=dtype) +def _corpus_vocab_size(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor: + corpus_vocab = set(token for sent in tok_sents for token in sent) + vocab_len = torch.as_tensor(len(corpus_vocab), dtype=dtype) return vocab_len -def _sent_vocab( +def _sent_vocab_sizes( tok_sents: list[list[str]], dtype: torch.dtype, ) -> tuple[Tensor, Tensor]: