Skip to content

Commit

Permalink
Del: Sentence vocab code in Vocab metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Oct 14, 2023
1 parent dacd89f commit ee2ed63
Showing 1 changed file with 1 addition and 17 deletions.
18 changes: 1 addition & 17 deletions src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,9 @@ def vocab(
if not return_all_scores:
return vocab_cands_len

sent_vocab_cands_len, sent_vocab_cands_lens = _sent_vocab(tok_cands, dtype)

sents_scores = {
"sent_vocab.cands": sent_vocab_cands_lens,
}
sents_scores = {}
corpus_scores = {
"vocab.cands": vocab_cands_len,
"sent_vocab.cands": sent_vocab_cands_len,
}

if mult_references is not None:
Expand Down Expand Up @@ -85,7 +80,6 @@ def vocab(
)

vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype)
sent_vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype)

for i in range(n_samples):
indexes = [
Expand All @@ -96,24 +90,14 @@ def vocab(
vocab_mrefs_len_i = _corpus_vocab(popped_refs)
vocab_mrefs_lens[i] = vocab_mrefs_len_i

_sent_vocab_mrefs_lens_i, sent_vocab_mrefs_len_i = _sent_vocab(
popped_refs, dtype
)
sent_vocab_mrefs_lens[i] = sent_vocab_mrefs_len_i

vocab_mrefs_avg = vocab_mrefs_lens.mean()
sent_vocab_mrefs_avg = sent_vocab_mrefs_lens.mean()

vocab_len_ratio_avg = vocab_cands_len / vocab_mrefs_avg
sent_vocab_len_ratio_avg = sent_vocab_cands_len / sent_vocab_mrefs_avg

corpus_scores |= {
"vocab.mrefs_full": vocab_mrefs_len_full,
"vocab.ratio_full": vocab_ratio_len_full,
"vocab.mrefs_avg": vocab_mrefs_avg,
"vocab.ratio_avg": vocab_len_ratio_avg,
"sent_vocab.mrefs_avg": sent_vocab_mrefs_avg,
"sent_vocab.ratio_avg": sent_vocab_len_ratio_avg,
}

return corpus_scores, sents_scores
Expand Down

0 comments on commit ee2ed63

Please sign in to comment.