Skip to content

Commit

Permalink
Add: Vocab precision, recall and f1 in vocab metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Jun 26, 2024
1 parent 0aa67ab commit b8056f9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/aac_metrics/classes/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def get_output_names(self) -> tuple[str, ...]:
"vocab.mrefs_avg",
"vocab.mrefs_std",
"vocab.ratio_avg",
"vocab.precision",
"vocab.recall",
"vocab.f1",
)

def reset(self) -> None:
Expand Down
30 changes: 22 additions & 8 deletions src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def vocab(
tok_cands = list(map(tokenizer, candidates))
del candidates

vocab_cands_len = _corpus_vocab(tok_cands, dtype)
_, vocab_per_cand = _sent_vocab(tok_cands, dtype)
vocab_cands_len = _corpus_vocab_size(tok_cands, dtype)
_, vocab_per_cand = _sent_vocab_sizes(tok_cands, dtype)
if not return_all_scores:
return vocab_cands_len

Expand All @@ -75,7 +75,18 @@ def vocab(
tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references]
del mult_references

vocab_mrefs_len_full = _corpus_vocab(
corpus_vocab_cands = set(token for cand in tok_cands for token in cand)
corpus_vocab_mrefs = set(
token for refs in tok_mrefs for ref in refs for token in ref
)

tp = corpus_vocab_cands.intersection(corpus_vocab_mrefs)
fn = corpus_vocab_mrefs.difference(corpus_vocab_cands)
vocab_precision = len(tp) / len(corpus_vocab_cands)
vocab_recall = len(tp) / (len(tp) + len(fn))
vocab_f1 = 2 * vocab_precision * vocab_recall / (vocab_precision + vocab_recall)

vocab_mrefs_len_full = _corpus_vocab_size(
[ref for refs in tok_mrefs for ref in refs], dtype
)
vocab_ratio_len_full = vocab_cands_len / vocab_mrefs_len_full
Expand Down Expand Up @@ -103,7 +114,7 @@ def vocab(

for i in range(num_try):
popped_refs, _ = _sample_sentences_split(tok_mrefs, generator=generator)
vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype)
vocab_mrefs_len_i = _corpus_vocab_size(popped_refs, dtype)
vocab_mrefs_lens[i] = vocab_mrefs_len_i

vocab_mrefs_avg = vocab_mrefs_lens.mean()
Expand All @@ -114,19 +125,22 @@ def vocab(
"vocab.ratio_full": vocab_ratio_len_full,
"vocab.mrefs_avg": vocab_mrefs_avg,
"vocab.ratio_avg": vocab_len_ratio_avg,
"vocab.precision": vocab_precision,
"vocab.recall": vocab_recall,
"vocab.f1": vocab_f1,
}

vocab_outs = corpus_scores, sents_scores
return vocab_outs # type: ignore


def _corpus_vocab(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor:
corpus_cands_vocab = set(token for sent in tok_sents for token in sent)
vocab_len = torch.as_tensor(len(corpus_cands_vocab), dtype=dtype)
def _corpus_vocab_size(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor:
    """Return the number of distinct tokens found across all sentences.

    Args:
        tok_sents: Tokenized sentences, each one a list of string tokens.
        dtype: Torch dtype of the returned tensor.

    Returns:
        A tensor built from a single Python int (the vocabulary size) with
        the requested dtype. An empty corpus yields 0.
    """
    # A set comprehension deduplicates tokens over the whole corpus, so a
    # token repeated in several sentences is counted once.
    corpus_vocab = {token for sent in tok_sents for token in sent}
    return torch.as_tensor(len(corpus_vocab), dtype=dtype)


def _sent_vocab(
def _sent_vocab_sizes(
tok_sents: list[list[str]],
dtype: torch.dtype,
) -> tuple[Tensor, Tensor]:
Expand Down

0 comments on commit b8056f9

Please sign in to comment.