Skip to content

Commit

Permalink
Add: Vocab jaccard output.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Jun 26, 2024
1 parent b8056f9 commit 19702a4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/aac_metrics/classes/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_output_names(self) -> tuple[str, ...]:
"vocab.precision",
"vocab.recall",
"vocab.f1",
"vocab.jaccard",
)

def reset(self) -> None:
Expand Down
12 changes: 8 additions & 4 deletions src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ def vocab(
token for refs in tok_mrefs for ref in refs for token in ref
)

tp = corpus_vocab_cands.intersection(corpus_vocab_mrefs)
fn = corpus_vocab_mrefs.difference(corpus_vocab_cands)
vocab_precision = len(tp) / len(corpus_vocab_cands)
vocab_recall = len(tp) / (len(tp) + len(fn))
inter = corpus_vocab_cands.intersection(corpus_vocab_mrefs) # True positives
diff = corpus_vocab_mrefs.difference(corpus_vocab_cands) # False negatives
union = corpus_vocab_cands.union(corpus_vocab_mrefs)

vocab_precision = len(inter) / len(corpus_vocab_cands)
vocab_recall = len(inter) / (len(inter) + len(diff))
vocab_f1 = 2 * vocab_precision * vocab_recall / (vocab_precision + vocab_recall)
vocab_jaccard = len(inter) + len(union)

vocab_mrefs_len_full = _corpus_vocab_size(
[ref for refs in tok_mrefs for ref in refs], dtype
Expand Down Expand Up @@ -128,6 +131,7 @@ def vocab(
"vocab.precision": vocab_precision,
"vocab.recall": vocab_recall,
"vocab.f1": vocab_f1,
"vocab.jaccard": vocab_jaccard,
}

vocab_outs = corpus_scores, sents_scores
Expand Down

0 comments on commit 19702a4

Please sign in to comment.