From 0ac44d93e688aacaea7bef51bdfacb87fad47ad8 Mon Sep 17 00:00:00 2001 From: Labbeti Date: Thu, 18 Apr 2024 17:16:34 +0200 Subject: [PATCH] Mod: Refactor internal vocab code. --- src/aac_metrics/functional/vocab.py | 40 +++++++++++++++++++---------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index cfe2565..ff5b3cb 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -2,16 +2,17 @@ # -*- coding: utf-8 -*- import logging -from typing import Callable, Literal, TypedDict, Union +from typing import Callable, Literal, TypedDict, TypeVar, Union import torch -from torch import Tensor +from torch import Generator, Tensor from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents pylog = logging.getLogger(__name__) +T = TypeVar("T") POP_STRATEGIES = ("max", "min") PopStrategy = Literal["max", "min"] VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor}) @@ -85,27 +86,23 @@ def vocab( generator = seed if pop_strategy == "max": - n_samples = max(len(refs) for refs in tok_mrefs) + num_try = max(len(refs) for refs in tok_mrefs) elif pop_strategy == "min": - n_samples = min(len(refs) for refs in tok_mrefs) + num_try = min(len(refs) for refs in tok_mrefs) elif isinstance(pop_strategy, int): - n_samples = pop_strategy + num_try = pop_strategy else: raise ValueError( f"Invalid argument {pop_strategy=}. 
(expected one of {POP_STRATEGIES} or an integer value)" ) if verbose >= 2: - pylog.debug(f"Found {n_samples=} with {pop_strategy=}.") + pylog.debug(f"Found {num_try=} with {pop_strategy=}.") - vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype) + vocab_mrefs_lens = torch.empty((num_try,), dtype=dtype) - for i in range(n_samples): - indexes = [ - int(torch.randint(0, len(refs), (), generator=generator).item()) - for refs in tok_mrefs - ] - popped_refs = [refs[idx] for idx, refs in zip(indexes, tok_mrefs)] + for i in range(num_try): + popped_refs, _ = _sample_sentences_split(tok_mrefs, generator=generator) vocab_mrefs_len_i = _corpus_vocab(popped_refs, dtype) vocab_mrefs_lens[i] = vocab_mrefs_len_i @@ -138,3 +135,20 @@ def _sent_vocab( ) sent_cands_vocab_len = sent_cands_vocabs_lens.mean() return sent_cands_vocab_len, sent_cands_vocabs_lens + + +def _sample_sentences_split( + mult_sentences: list[list[T]], + generator: Union[Generator, None] = None, +) -> tuple[list[T], list[list[T]]]: + candidates: list[T] = [] + mult_references: list[list[T]] = [] + + for sents in mult_sentences: + idx = int(torch.randint(0, len(sents), (), generator=generator).item()) + cand = sents[idx] + refs = [sent for i, sent in enumerate(sents) if i != idx] + candidates.append(cand) + mult_references.append(refs) + + return candidates, mult_references