diff --git a/src/aac_metrics/classes/vocab.py b/src/aac_metrics/classes/vocab.py index 6b1c474..f6c70b5 100644 --- a/src/aac_metrics/classes/vocab.py +++ b/src/aac_metrics/classes/vocab.py @@ -71,6 +71,7 @@ def get_output_names(self) -> tuple[str, ...]: "vocab.precision", "vocab.recall", "vocab.f1", + "vocab.jaccard", ) def reset(self) -> None: diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index a31cc08..f207d56 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -80,11 +80,14 @@ def vocab( token for refs in tok_mrefs for ref in refs for token in ref ) - tp = corpus_vocab_cands.intersection(corpus_vocab_mrefs) - fn = corpus_vocab_mrefs.difference(corpus_vocab_cands) - vocab_precision = len(tp) / len(corpus_vocab_cands) - vocab_recall = len(tp) / (len(tp) + len(fn)) + inter = corpus_vocab_cands.intersection(corpus_vocab_mrefs) # True positives + diff = corpus_vocab_mrefs.difference(corpus_vocab_cands) # False negatives + union = corpus_vocab_cands.union(corpus_vocab_mrefs) + + vocab_precision = len(inter) / len(corpus_vocab_cands) + vocab_recall = len(inter) / (len(inter) + len(diff)) vocab_f1 = 2 * vocab_precision * vocab_recall / (vocab_precision + vocab_recall) + vocab_jaccard = len(inter) + len(union) vocab_mrefs_len_full = _corpus_vocab_size( [ref for refs in tok_mrefs for ref in refs], dtype @@ -128,6 +131,7 @@ def vocab( "vocab.precision": vocab_precision, "vocab.recall": vocab_recall, "vocab.f1": vocab_f1, + "vocab.jaccard": vocab_jaccard, } vocab_outs = corpus_scores, sents_scores