Skip to content

Commit

Permalink
Mod: Refactor internal vocab code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 18, 2024
1 parent 4e7fa82 commit 0ac44d9
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
# -*- coding: utf-8 -*-

import logging
from typing import Callable, Literal, TypedDict, Union
from typing import Callable, Literal, TypedDict, TypeVar, Union

import torch
from torch import Tensor
from torch import Generator, Tensor

from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents

pylog = logging.getLogger(__name__)


T = TypeVar("T")
POP_STRATEGIES = ("max", "min")
PopStrategy = Literal["max", "min"]
VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor})
Expand Down Expand Up @@ -85,27 +86,23 @@ def vocab(
generator = seed

if pop_strategy == "max":
n_samples = max(len(refs) for refs in tok_mrefs)
num_try = max(len(refs) for refs in tok_mrefs)
elif pop_strategy == "min":
n_samples = min(len(refs) for refs in tok_mrefs)
num_try = min(len(refs) for refs in tok_mrefs)
elif isinstance(pop_strategy, int):
n_samples = pop_strategy
num_try = pop_strategy
else:
raise ValueError(
f"Invalid argument {pop_strategy=}. (expected one of {POP_STRATEGIES} or an integer value)"
)

if verbose >= 2:
pylog.debug(f"Found {n_samples=} with {pop_strategy=}.")
pylog.debug(f"Found {num_try=} with {pop_strategy=}.")

vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype)
vocab_mrefs_lens = torch.empty((num_try,), dtype=dtype)

for i in range(n_samples):
indexes = [
int(torch.randint(0, len(refs), (), generator=generator).item())
for refs in tok_mrefs
]
popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)]
for i in range(num_try):
popped_refs, _ = _sample_sentences_split(tok_mrefs, generator=generator)
vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype)
vocab_mrefs_lens[i] = vocab_mrefs_len_i

Expand Down Expand Up @@ -138,3 +135,20 @@ def _sent_vocab(
)
sent_cands_vocab_len = sent_cands_vocabs_lens.mean()
return sent_cands_vocab_len, sent_cands_vocabs_lens


def _sample_sentences_split(
    mult_sentences: list[list[T]],
    generator: Union[Generator, None] = None,
) -> tuple[list[T], list[list[T]]]:
    """Pull one randomly chosen item out of every inner list.

    For each inner list of ``mult_sentences`` a single index is drawn with
    ``torch.randint`` over ``[0, len(inner_list))``. The item at that index
    is set aside and the remaining items, in their original order, form the
    leftover list for that group.

    Args:
        mult_sentences: Groups of items to sample from, one draw per group.
        generator: Optional torch random generator forwarded to
            ``torch.randint``. ``None`` is passed through unchanged.

    Returns:
        A pair ``(picked, leftovers)`` where ``picked[k]`` is the item drawn
        from ``mult_sentences[k]`` and ``leftovers[k]`` holds the other items
        of that same group.
    """
    picked: list[T] = []
    leftovers: list[list[T]] = []

    for group in mult_sentences:
        # Exactly one draw per group, in input order, so that a seeded
        # generator is consumed in a reproducible sequence.
        drawn = torch.randint(0, len(group), (), generator=generator)
        pos = int(drawn.item())

        picked.append(group[pos])
        # Everything before and after the drawn position, order preserved.
        leftovers.append([*group[:pos], *group[pos + 1 :]])

    return picked, leftovers

0 comments on commit 0ac44d9

Please sign in to comment.