diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1195c4f..57ba1fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 All notable changes to this project will be documented in this file.
 
+## [0.5.5] UNRELEASED
+### Changed
+- Update metric typing: outputs are now described by `TypedDict` classes and optional metric arguments are keyword-only, which improves support for language servers.
+
 ## [0.5.4] 2024-03-04
 ### Fixed
 - Backward compatibility of `BERTScoreMrefs` with torchmetrics prior to 1.0.0.
diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py
index 9807287..40481d0 100644
--- a/src/aac_metrics/classes/bert_score_mrefs.py
+++ b/src/aac_metrics/classes/bert_score_mrefs.py
@@ -10,6 +10,7 @@
 from aac_metrics.functional.bert_score_mrefs import (
     DEFAULT_BERT_SCORE_MODEL,
     REDUCTIONS,
+    BERTScoreMRefsOuts,
     Reduction,
     _load_model_and_tokenizer,
     bert_score_mrefs,
@@ -17,7 +18,7 @@
 from aac_metrics.utils.globals import _get_device
 
 
-class BERTScoreMRefs(AACMetric):
+class BERTScoreMRefs(AACMetric[Union[BERTScoreMRefsOuts, Tensor]]):
     """BERTScore metric which supports multiple references.
 
     The implementation is based on the bert_score implementation of torchmetrics.
@@ -37,6 +38,7 @@ class BERTScoreMRefs(AACMetric):
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
         device: Union[str, torch.device, None] = "cuda_if_available",
         batch_size: int = 32,
@@ -79,7 +81,9 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(
+        self,
+    ) -> Union[BERTScoreMRefsOuts, Tensor]:
         return bert_score_mrefs(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py
index aa1c766..1d33c98 100644
--- a/src/aac_metrics/classes/bleu.py
+++ b/src/aac_metrics/classes/bleu.py
@@ -9,12 +9,13 @@
 from aac_metrics.functional.bleu import (
     BLEU_OPTIONS,
     BleuOption,
+    BLEUOuts,
     _bleu_compute,
     _bleu_update,
 )
 
 
-class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class BLEU(AACMetric[Union[BLEUOuts, Tensor]]):
     """BiLingual Evaluation Understudy metric class.
 
     - Paper: https://www.aclweb.org/anthology/P02-1040.pdf
@@ -32,6 +33,7 @@ class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         n: int = 4,
         option: BleuOption = "closest",
         verbose: int = 0,
@@ -52,7 +54,7 @@ def __init__(
         self._cooked_cands = []
         self._cooked_mrefs = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[BLEUOuts, Tensor]:
         return _bleu_compute(
             cooked_cands=self._cooked_cands,
             cooked_mrefs=self._cooked_mrefs,
diff --git a/src/aac_metrics/classes/cider_d.py b/src/aac_metrics/classes/cider_d.py
index 77e5eaa..9b1c0b8 100644
--- a/src/aac_metrics/classes/cider_d.py
+++ b/src/aac_metrics/classes/cider_d.py
@@ -1,18 +1,15 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Any, Callable, Union
+from typing import Callable, Union
 
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.cider_d import (
-    _cider_d_compute,
-    _cider_d_update,
-)
+from aac_metrics.functional.cider_d import CIDErDOuts, _cider_d_compute, _cider_d_update
 
 
-class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Any]], Tensor]]):
+class CIDErD(AACMetric[Union[CIDErDOuts, Tensor]]):
     """Consensus-based Image Description Evaluation metric class.
 
     - Paper: https://arxiv.org/pdf/1411.5726.pdf
@@ -30,6 +27,7 @@ class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Any]], Tensor]])
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         n: int = 4,
         sigma: float = 6.0,
         tokenizer: Callable[[str], list[str]] = str.split,
@@ -47,7 +45,7 @@ def __init__(
         self._cooked_cands = []
         self._cooked_mrefs = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[CIDErDOuts, Tensor]:
         return _cider_d_compute(
             cooked_cands=self._cooked_cands,
             cooked_mrefs=self._cooked_mrefs,
diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py
index ae5f859..684b543 100644
--- a/src/aac_metrics/classes/fense.py
+++ b/src/aac_metrics/classes/fense.py
@@ -2,28 +2,25 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.fense import fense, _load_models_and_tokenizer
+from aac_metrics.functional.fense import FENSEOuts, _load_models_and_tokenizer, fense
 from aac_metrics.functional.fer import (
-    BERTFlatClassifier,
     _ERROR_NAMES,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
 )
 from aac_metrics.functional.sbert_sim import DEFAULT_SBERT_SIM_MODEL
 
-
 pylog = logging.getLogger(__name__)
 
 
-class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class FENSE(AACMetric[Union[FENSEOuts, Tensor]]):
     """Fluency ENhanced Sentence-bert Evaluation (FENSE)
 
     - Paper: https://arxiv.org/abs/2110.04684
@@ -42,6 +39,7 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
         echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
         error_threshold: float = 0.9,
@@ -77,7 +75,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[FENSEOuts, Tensor]:
         return fense(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/fer.py b/src/aac_metrics/classes/fer.py
index 1c06592..9b863ba 100644
--- a/src/aac_metrics/classes/fer.py
+++ b/src/aac_metrics/classes/fer.py
@@ -2,28 +2,26 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.fer import (
-    BERTFlatClassifier,
-    fer,
-    _load_echecker_and_tokenizer,
     _ERROR_NAMES,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
+    FEROuts,
+    _load_echecker_and_tokenizer,
+    fer,
 )
 from aac_metrics.utils.globals import _get_device
 
-
 pylog = logging.getLogger(__name__)
 
 
-class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class FER(AACMetric[Union[FEROuts, Tensor]]):
     """Return Fluency Error Rate (FER) detected by a pre-trained BERT model.
 
     - Paper: https://arxiv.org/abs/2110.04684
@@ -42,6 +40,7 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]])
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
         error_threshold: float = 0.9,
         device: Union[str, torch.device, None] = "cuda_if_available",
@@ -72,7 +71,7 @@ def __init__(
 
         self._candidates = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[FEROuts, Tensor]:
         return fer(
             candidates=self._candidates,
             return_all_scores=self._return_all_scores,
diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py
index 05118be..5c36bd2 100644
--- a/src/aac_metrics/classes/meteor.py
+++ b/src/aac_metrics/classes/meteor.py
@@ -7,10 +7,10 @@
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.meteor import Language, meteor
+from aac_metrics.functional.meteor import Language, METEOROuts, meteor
 
 
-class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class METEOR(AACMetric[Union[METEOROuts, Tensor]]):
     """Metric for Evaluation of Translation with Explicit ORdering metric class.
 
     - Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389
@@ -29,6 +29,7 @@ class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         cache_path: Union[str, Path, None] = None,
         java_path: Union[str, Path, None] = None,
         java_max_memory: str = "2G",
@@ -52,7 +53,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[METEOROuts, Tensor]:
         return meteor(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/rouge_l.py b/src/aac_metrics/classes/rouge_l.py
index 5002d89..915a24c 100644
--- a/src/aac_metrics/classes/rouge_l.py
+++ b/src/aac_metrics/classes/rouge_l.py
@@ -6,13 +6,10 @@
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.rouge_l import (
-    _rouge_l_compute,
-    _rouge_l_update,
-)
+from aac_metrics.functional.rouge_l import ROUGELOuts, _rouge_l_compute, _rouge_l_update
 
 
-class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class ROUGEL(AACMetric[Union[ROUGELOuts, Tensor]]):
     """Recall-Oriented Understudy for Gisting Evaluation class.
 
     - Paper: https://aclanthology.org/W04-1013.pdf
@@ -30,6 +27,7 @@ class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         beta: float = 1.2,
         tokenizer: Callable[[str], list[str]] = str.split,
     ) -> None:
@@ -40,7 +38,7 @@ def __init__(
 
         self._rouge_l_scores = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[ROUGELOuts, Tensor]:
         return _rouge_l_compute(
             rouge_l_scs=self._rouge_l_scores,
             return_all_scores=self._return_all_scores,
diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py
index 76ba245..0ba90b2 100644
--- a/src/aac_metrics/classes/sbert_sim.py
+++ b/src/aac_metrics/classes/sbert_sim.py
@@ -2,27 +2,25 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from typing import Union
 
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.sbert_sim import (
-    sbert_sim,
-    _load_sbert,
     DEFAULT_SBERT_SIM_MODEL,
+    SBERTSimOuts,
+    _load_sbert,
+    sbert_sim,
 )
 from aac_metrics.utils.globals import _get_device
 
-
 pylog = logging.getLogger(__name__)
 
 
-class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SBERTSim(AACMetric[Union[SBERTSimOuts, Tensor]]):
     """Cosine-similarity of the Sentence-BERT embeddings.
 
     - Paper: https://arxiv.org/abs/1908.10084
@@ -41,6 +39,7 @@ class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tens
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
         device: Union[str, torch.device, None] = "cuda_if_available",
         batch_size: int = 32,
@@ -65,7 +64,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[SBERTSimOuts, Tensor]:
         return sbert_sim(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/spice.py b/src/aac_metrics/classes/spice.py
index 2cdff11..1d05591 100644
--- a/src/aac_metrics/classes/spice.py
+++ b/src/aac_metrics/classes/spice.py
@@ -2,20 +2,18 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from pathlib import Path
 from typing import Iterable, Optional, Union
 
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.spice import spice
-
+from aac_metrics.functional.spice import SPICEOuts, spice
 
 pylog = logging.getLogger(__name__)
 
 
-class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SPICE(AACMetric[Union[SPICEOuts, Tensor]]):
     """Semantic Propositional Image Caption Evaluation class.
 
     - Paper: https://arxiv.org/pdf/1607.08822.pdf
@@ -33,6 +31,7 @@ class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         cache_path: Union[str, Path, None] = None,
         java_path: Union[str, Path, None] = None,
         tmp_path: Union[str, Path, None] = None,
@@ -58,7 +57,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[SPICEOuts, Tensor]:
         return spice(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/spider.py b/src/aac_metrics/classes/spider.py
index 6605420..b37adc7 100644
--- a/src/aac_metrics/classes/spider.py
+++ b/src/aac_metrics/classes/spider.py
@@ -2,20 +2,18 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from pathlib import Path
 from typing import Iterable, Optional, Union
 
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.spider import spider
-
+from aac_metrics.functional.spider import SPIDErOuts, spider
 
 pylog = logging.getLogger(__name__)
 
 
-class SPIDEr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SPIDEr(AACMetric[Union[SPIDErOuts, Tensor]]):
     """SPIDEr class.
 
     - Paper: https://arxiv.org/pdf/1612.00370.pdf
@@ -33,6 +31,7 @@ class SPIDEr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         # CIDErD args
         n: int = 4,
         sigma: float = 6.0,
@@ -60,7 +59,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[SPIDErOuts, Tensor]:
         return spider(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py
index d03d6ea..afc3407 100644
--- a/src/aac_metrics/classes/spider_fl.py
+++ b/src/aac_metrics/classes/spider_fl.py
@@ -2,30 +2,27 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from pathlib import Path
 from typing import Iterable, Optional, Union
 
 import torch
-
 from torch import Tensor
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.fer import (
-    BERTFlatClassifier,
-    _load_echecker_and_tokenizer,
     _ERROR_NAMES,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
+    _load_echecker_and_tokenizer,
 )
-from aac_metrics.functional.spider_fl import spider_fl
+from aac_metrics.functional.spider_fl import SPIDErFLOuts, spider_fl
 from aac_metrics.utils.globals import _get_device
 
-
 pylog = logging.getLogger(__name__)
 
 
-class SPIDErFL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
+class SPIDErFL(AACMetric[Union[SPIDErFLOuts, Tensor]]):
     """SPIDErFL class.
 
     For more information, see :func:`~aac_metrics.functional.spider_fl.spider_fl`.
 
@@ -41,6 +38,7 @@ class SPIDErFL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tens def __init__( self, return_all_scores: bool = True, + *, # CIDErD args n: int = 4, sigma: float = 6.0, @@ -95,7 +93,7 @@ def __init__( self._candidates = [] self._mult_references = [] - def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + def compute(self) -> Union[SPIDErFLOuts, Tensor]: return spider_fl( candidates=self._candidates, mult_references=self._mult_references, diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py index 8b1582c..483450d 100644 --- a/src/aac_metrics/classes/spider_max.py +++ b/src/aac_metrics/classes/spider_max.py @@ -2,20 +2,18 @@ # -*- coding: utf-8 -*- import logging - from pathlib import Path from typing import Iterable, Optional, Union from torch import Tensor from aac_metrics.classes.base import AACMetric -from aac_metrics.functional.spider_max import spider_max - +from aac_metrics.functional.spider_max import SPIDErMaxOuts, spider_max pylog = logging.getLogger(__name__) -class SPIDErMax(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): +class SPIDErMax(AACMetric[Union[SPIDErMaxOuts, Tensor]]): """SPIDEr-max class. - Paper: https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf @@ -33,6 +31,7 @@ class SPIDErMax(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Ten def __init__( self, return_all_scores: bool = True, + *, return_all_cands_scores: bool = False, # CIDEr args n: int = 4, @@ -62,7 +61,7 @@ def __init__( self._mult_candidates = [] self._mult_references = [] - def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + def compute(self) -> Union[SPIDErMaxOuts, Tensor]: return spider_max( mult_candidates=self._mult_candidates, mult_references=self._mult_references, diff --git a/src/aac_metrics/classes/vocab.py b/src/aac_metrics/classes/vocab.py index 3a7e4c7..2fa7808 100644 --- a/src/aac_metrics/classes/vocab.py +++ b/src/aac_metrics/classes/vocab.py @@ -9,12 +9,12 @@ from torch import Tensor from aac_metrics.classes.base import AACMetric -from aac_metrics.functional.vocab import PopStrategy, vocab +from aac_metrics.functional.vocab import PopStrategy, VocabOuts, vocab pylog = logging.getLogger(__name__) -class Vocab(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): +class Vocab(AACMetric[Union[VocabOuts, Tensor]]): """VocabStats class. For more information, see :func:`~aac_metrics.functional.vocab.vocab`. 
@@ -30,6 +30,7 @@ class Vocab(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
     def __init__(
         self,
         return_all_scores: bool = True,
+        *,
         seed: Union[None, int, torch.Generator] = 1234,
         tokenizer: Callable[[str], list[str]] = str.split,
         dtype: torch.dtype = torch.float64,
@@ -47,7 +48,7 @@ def __init__(
         self._candidates = []
         self._mult_references = []
 
-    def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    def compute(self) -> Union[VocabOuts, Tensor]:
         return vocab(
             candidates=self._candidates,
             mult_references=self._mult_references,
diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py
index a9a07a5..a8f0e62 100644
--- a/src/aac_metrics/functional/bert_score_mrefs.py
+++ b/src/aac_metrics/functional/bert_score_mrefs.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Callable, Literal, Optional, Union
+from typing import Callable, Literal, Optional, TypedDict, Union
 
 import torch
 import torchmetrics
@@ -19,12 +19,23 @@
 DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
 REDUCTIONS = ("mean", "max", "min")
 Reduction = Union[Literal["mean", "max", "min"], Callable[..., Tensor]]
+_DEFAULT_SCORE_NAME = "bert_score.f1"
+BERTScoreMRefsScores = TypedDict(
+    "BERTScoreMRefsScores",
+    {
+        "bert_score.f1": Tensor,
+        "bert_score.precision": Tensor,
+        "bert_score.recall": Tensor,
+    },
+)
+BERTScoreMRefsOuts = tuple[BERTScoreMRefsScores, BERTScoreMRefsScores]
 
 
 def bert_score_mrefs(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
     tokenizer: Optional[Callable] = None,
     device: Union[str, torch.device, None] = "cuda_if_available",
@@ -36,7 +47,7 @@ def bert_score_mrefs(
     reduction: Reduction = "max",
     filter_nan: bool = True,
     verbose: int = 0,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BERTScoreMRefsOuts, Tensor]:
     """BERTScore metric which supports multiple references.
 
     The implementation is based on the bert_score implementation of torchmetrics.
@@ -174,7 +185,7 @@ def bert_score_mrefs(
     if return_all_scores:
         return corpus_scores, sents_scores
     else:
-        return corpus_scores["bert_score.f1"]
+        return corpus_scores[_DEFAULT_SCORE_NAME]
 
 
 def _load_model_and_tokenizer(
diff --git a/src/aac_metrics/functional/bleu.py b/src/aac_metrics/functional/bleu.py
index 43cbb8d..82f9684 100644
--- a/src/aac_metrics/functional/bleu.py
+++ b/src/aac_metrics/functional/bleu.py
@@ -15,18 +15,21 @@
 BLEU_OPTIONS = ("shortest", "average", "closest")
 BleuOption = Literal["shortest", "average", "closest"]
+BLEUScores = dict[str, Tensor]
+BLEUOuts = tuple[BLEUScores, BLEUScores]
 
 
 def bleu(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     n: int = 4,
     option: BleuOption = "closest",
     verbose: int = 0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_1_to_n: bool = False,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BLEUOuts, Tensor]:
     """BiLingual Evaluation Understudy function.
 
     - Paper: https://www.aclweb.org/anthology/P02-1040.pdf
@@ -70,11 +73,12 @@ def bleu_1(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     option: BleuOption = "closest",
     verbose: int = 0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_1_to_n: bool = False,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BLEUOuts, Tensor]:
     return bleu(
         candidates=candidates,
         mult_references=mult_references,
@@ -91,11 +95,12 @@ def bleu_2(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     option: BleuOption = "closest",
     verbose: int = 0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_1_to_n: bool = False,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BLEUOuts, Tensor]:
     return bleu(
         candidates=candidates,
         mult_references=mult_references,
@@ -112,11 +117,12 @@ def bleu_3(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     option: BleuOption = "closest",
     verbose: int = 0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_1_to_n: bool = False,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BLEUOuts, Tensor]:
     return bleu(
         candidates=candidates,
         mult_references=mult_references,
@@ -133,11 +139,12 @@ def bleu_4(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     option: BleuOption = "closest",
     verbose: int = 0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_1_to_n: bool = False,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[BLEUOuts, Tensor]:
     return bleu(
         candidates=candidates,
         mult_references=mult_references,
@@ -180,7 +187,7 @@ def _bleu_compute(
     option: BleuOption = "closest",
     verbose: int = 0,
     return_1_to_n: bool = False,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
+) -> Union[Tensor, BLEUOuts]:
     if option not in BLEU_OPTIONS:
         raise ValueError(f"Invalid option {option=}. (expected one of {BLEU_OPTIONS})")
 
diff --git a/src/aac_metrics/functional/cider_d.py b/src/aac_metrics/functional/cider_d.py
index 136b24e..35eb470 100644
--- a/src/aac_metrics/functional/cider_d.py
+++ b/src/aac_metrics/functional/cider_d.py
@@ -1,27 +1,30 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from collections import defaultdict, Counter
-from typing import Any, Callable, Mapping, Union
+from collections import Counter, defaultdict
+from typing import Callable, Mapping, TypedDict, Union
 
 import numpy as np
 import torch
-
 from torch import Tensor
 
 from aac_metrics.utils.checks import check_metric_inputs
 
+CIDErDScores = TypedDict("CIDErDScores", {"cider_d": Tensor})
+CIDErDOuts = tuple[CIDErDScores, CIDErDScores]
+
 
 def cider_d(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     n: int = 4,
     sigma: float = 6.0,
     tokenizer: Callable[[str], list[str]] = str.split,
     return_tfidf: bool = False,
     scale: float = 10.0,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Any]]]:
+) -> Union[Tensor, CIDErDOuts]:
     """Consensus-based Image Description Evaluation function.
 
     - Paper: https://arxiv.org/pdf/1411.5726.pdf
@@ -86,8 +89,8 @@ def _cider_d_compute(
     sigma: float,
     return_tfidf: bool,
     scale: float,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Any]]]:
-    if len(cooked_cands) <= 1:
+) -> Union[Tensor, CIDErDOuts]:
+    if len(cooked_cands) < 2:
         raise ValueError(
             f"CIDEr-D metric does not support less than 2 candidates with 2 references. (found {len(cooked_cands)} candidates, but expected > 1)"
         )
diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py
index 74e462c..43cb66e 100644
--- a/src/aac_metrics/functional/fense.py
+++ b/src/aac_metrics/functional/fense.py
@@ -2,28 +2,33 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
-from typing import Optional, Union
+from typing import Optional, TypedDict, Union
 
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 
 from aac_metrics.functional.fer import (
-    fer,
-    _load_echecker_and_tokenizer,
-    BERTFlatClassifier,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
+    FEROuts,
+    _load_echecker_and_tokenizer,
+    fer,
 )
 from aac_metrics.functional.sbert_sim import (
-    sbert_sim,
-    _load_sbert,
     DEFAULT_SBERT_SIM_MODEL,
+    SBERTSimOuts,
+    _load_sbert,
+    sbert_sim,
 )
 from aac_metrics.utils.checks import check_metric_inputs
 
+FENSEScores = TypedDict(
+    "FENSEScores", {"sbert_sim": Tensor, "fer": Tensor, "fense": Tensor}
+)
+FENSEOuts = tuple[FENSEScores, FENSEScores]
+
 pylog = logging.getLogger(__name__)
 
 
@@ -32,6 +37,7 @@ def fense(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     # SBERT args
     sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
     # FluencyError args
@@ -45,7 +51,7 @@
     # Other args
     penalty: float = 0.9,
     verbose: int = 0,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
+) -> Union[Tensor, FENSEOuts]:
     """Fluency ENhanced Sentence-bert Evaluation (FENSE)
 
     - Paper: https://arxiv.org/abs/2110.04684
@@ -83,7 +89,7 @@ def fense(
         reset_state=reset_state,
         verbose=verbose,
     )
-    sbert_sim_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = sbert_sim(  # type: ignore
+    sbert_sim_outs: SBERTSimOuts = sbert_sim(  # type: ignore
        candidates=candidates,
        mult_references=mult_references,
        return_all_scores=True,
@@ -93,7 +99,7 @@ def fense(
        reset_state=reset_state,
        verbose=verbose,
    )
-    fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = fer(  # type: ignore
+    fer_outs: FEROuts = fer(  # type: ignore
        candidates=candidates,
        return_all_scores=True,
        echecker=echecker,
@@ -114,10 +120,10 @@ def fense(
 
 
 def _fense_from_outputs(
-    sbert_sim_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
-    fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
+    sbert_sim_outs: SBERTSimOuts,
+    fer_outs: FEROuts,
     penalty: float = 0.9,
-) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
+) -> FENSEOuts:
     """Combines SBERT and FER outputs.
 
     Based on https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L121
diff --git a/src/aac_metrics/functional/fer.py b/src/aac_metrics/functional/fer.py
index 24f6095..605d5e4 100644
--- a/src/aac_metrics/functional/fer.py
+++ b/src/aac_metrics/functional/fer.py
@@ -5,18 +5,16 @@
 import logging
 import os
 import re
-import requests
-
 from collections import namedtuple
 from os import environ, makedirs
 from os.path import exists, expanduser, join
-from typing import Mapping, Optional, Union
+from typing import Mapping, Optional, TypedDict, Union
 
 import numpy as np
+import requests
 import torch
 import transformers
-
-from torch import nn, Tensor
+from torch import Tensor, nn
 from tqdm import tqdm
 from transformers import logging as tfmers_logging
 from transformers.models.auto.modeling_auto import AutoModel
@@ -26,8 +24,9 @@
 from aac_metrics.utils.checks import is_mono_sents
 from aac_metrics.utils.globals import _get_device
 
-
 DEFAULT_FER_MODEL = "echecker_clotho_audiocaps_base"
+FERScores = TypedDict("FERScores", {"fer": Tensor})
+FEROuts = tuple[FERScores, FERScores]
 
 
 _DEFAULT_PROXIES = {
@@ -101,6 +100,7 @@ def forward(
 def fer(
     candidates: list[str],
     return_all_scores: bool = True,
+    *,
     echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
     echecker_tokenizer: Optional[AutoTokenizer] = None,
     error_threshold: float = 0.9,
@@ -109,7 +109,7 @@ def fer(
     reset_state: bool = True,
     return_probs: bool = False,
     verbose: int = 0,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
+) -> Union[Tensor, FEROuts]:
     """Return Fluency Error Rate (FER) detected by a pre-trained BERT model.
 
     - Paper: https://arxiv.org/abs/2110.04684
diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py
index db6a67c..e395a6d 100644
--- a/src/aac_metrics/functional/meteor.py
+++ b/src/aac_metrics/functional/meteor.py
@@ -7,7 +7,7 @@
 import subprocess
 from pathlib import Path
 from subprocess import Popen
-from typing import Iterable, Literal, Optional, Union
+from typing import Iterable, Literal, Optional, TypedDict, Union
 
 import torch
 from torch import Tensor
@@ -22,12 +22,15 @@
 FNAME_METEOR_JAR = osp.join(DNAME_METEOR_CACHE, "meteor-1.5.jar")
 SUPPORTED_LANGUAGES = ("en", "cz", "de", "es", "fr")
 Language = Literal["en", "cz", "de", "es", "fr"]
+METEORScores = TypedDict("METEORScores", {"meteor": Tensor})
+METEOROuts = tuple[METEORScores, METEORScores]
 
 
 def meteor(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     java_max_memory: str = "2G",
@@ -36,7 +39,7 @@ def meteor(
     params: Optional[Iterable[float]] = None,
     weights: Optional[Iterable[float]] = None,
     verbose: int = 0,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[METEOROuts, Tensor]:
     """Metric for Evaluation of Translation with Explicit ORdering function.
 
     - Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389
diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py
index cdf7c3d..1243337 100644
--- a/src/aac_metrics/functional/mult_cands.py
+++ b/src/aac_metrics/functional/mult_cands.py
@@ -19,6 +19,7 @@ def mult_cands_metric(
     mult_candidates: list[list[str]],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     return_all_cands_scores: bool = False,
     selection: Selection = "max",
     reduction_fn: Callable[[Tensor], Tensor] = torch.mean,
diff --git a/src/aac_metrics/functional/rouge_l.py b/src/aac_metrics/functional/rouge_l.py
index 1106dfa..f23ae79 100644
--- a/src/aac_metrics/functional/rouge_l.py
+++ b/src/aac_metrics/functional/rouge_l.py
@@ -2,16 +2,17 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
-from typing import Callable, Union
+from typing import Callable, TypedDict, Union
 
 import numpy as np
 import torch
-
 from torch import Tensor
 
 from aac_metrics.utils.checks import check_metric_inputs
 
+ROUGELScores = TypedDict("ROUGELScores", {"rouge_l": Tensor})
+ROUGELOuts = tuple[ROUGELScores, ROUGELScores]
+
 pylog = logging.getLogger(__name__)
 
 
@@ -20,9 +21,10 @@ def rouge_l(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     beta: float = 1.2,
     tokenizer: Callable[[str], list[str]] = str.split,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[ROUGELOuts, Tensor]:
     """Recall-Oriented Understudy for Gisting Evaluation function.
 
     - Paper: https://aclanthology.org/W04-1013.pdf
@@ -62,7 +64,7 @@ def _rouge_l_update(
 def _rouge_l_compute(
     rouge_l_scs: list[float],
     return_all_scores: bool,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[ROUGELOuts, Tensor]:
     # Note: use numpy to compute mean because np.mean and torch.mean can give very small differences
     rouge_l_scores_np = np.array(rouge_l_scs)
     rouge_l_score_np = rouge_l_scores_np.mean()
diff --git a/src/aac_metrics/functional/sbert_sim.py b/src/aac_metrics/functional/sbert_sim.py
index b173d17..e73cbce 100644
--- a/src/aac_metrics/functional/sbert_sim.py
+++ b/src/aac_metrics/functional/sbert_sim.py
@@ -2,20 +2,20 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
-from typing import Union
+from typing import TypedDict, Union
 
 import numpy as np
 import torch
-
 from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.utils.checks import check_metric_inputs
 from aac_metrics.utils.globals import _get_device
 
-
 DEFAULT_SBERT_SIM_MODEL = "paraphrase-TinyBERT-L6-v2"
+SBERTSimScores = TypedDict("SBERTSimScores", {"sbert_sim": Tensor})
+SBERTSimOuts = tuple[SBERTSimScores, SBERTSimScores]
+
 pylog = logging.getLogger(__name__)
 
 
@@ -24,12 +24,13 @@ def sbert_sim(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
     device: Union[str, torch.device, None] = "cuda_if_available",
     batch_size: int = 32,
     reset_state: bool = True,
     verbose: int = 0,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
+) -> Union[Tensor, SBERTSimOuts]:
     """Cosine-similarity of the Sentence-BERT embeddings.
 
     - Paper: https://arxiv.org/abs/1908.10084
diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py
index d87e836..539d5d0 100644
--- a/src/aac_metrics/functional/spice.py
+++ b/src/aac_metrics/functional/spice.py
@@ -12,23 +12,20 @@
 import subprocess
 import tempfile
 import time
-
 from pathlib import Path
 from subprocess import CalledProcessError
 from tempfile import NamedTemporaryFile
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, Optional, TypedDict, Union
 
 import numpy as np
 import torch
-
 from torch import Tensor
 
 from aac_metrics.utils.checks import check_java_path, check_metric_inputs
-from aac_metrics.utils.globals import (
-    _get_cache_path,
-    _get_java_path,
-    _get_tmp_path,
-)
+from aac_metrics.utils.globals import _get_cache_path, _get_java_path, _get_tmp_path
+
+SPICEScores = TypedDict("SPICEScores", {"spice": Tensor})
+SPICEOuts = tuple[SPICEScores, SPICEScores]
 
 pylog = logging.getLogger(__name__)
 
@@ -43,6 +40,7 @@ def spice(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
@@ -52,7 +50,7 @@ def spice(
     separate_cache_dir: bool = True,
     use_shell: Optional[bool] = None,
     verbose: int = 0,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[SPICEOuts, Tensor]:
     """Semantic Propositional Image Caption Evaluation function.
 
     - Paper: https://arxiv.org/pdf/1607.08822.pdf
diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py
index f297557..fb0a537 100644
--- a/src/aac_metrics/functional/spider.py
+++ b/src/aac_metrics/functional/spider.py
@@ -2,19 +2,25 @@
 # -*- coding: utf-8 -*-
 
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Iterable, Optional, TypedDict, Union
 
 from torch import Tensor
 
-from aac_metrics.functional.cider_d import cider_d
-from aac_metrics.functional.spice import spice
+from aac_metrics.functional.cider_d import CIDErDOuts, cider_d
+from aac_metrics.functional.spice import SPICEOuts, spice
 from aac_metrics.utils.checks import check_metric_inputs
 
+SPIDErScores = TypedDict(
+    "SPIDErScores", {"spider": Tensor, "cider_d": Tensor, "spice": Tensor}
+)
+SPIDErOuts = tuple[SPIDErScores, SPIDErScores]
+
 
 def spider(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     # CIDErD args
     n: int = 4,
     sigma: float = 6.0,
@@ -28,7 +34,7 @@ def spider(
     java_max_memory: str = "8G",
     timeout: Union[None, int, Iterable[int]] = None,
     verbose: int = 0,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[SPIDErOuts, Tensor]:
     """SPIDEr function.
 
     - Paper: https://arxiv.org/pdf/1612.00370.pdf
@@ -63,7 +69,7 @@ def spider(
     check_metric_inputs(candidates, mult_references)
 
     sub_return_all_scores = True
-    cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = cider_d(  # type: ignore
+    cider_d_outs: CIDErDOuts = cider_d(  # type: ignore
         candidates=candidates,
         mult_references=mult_references,
         return_all_scores=sub_return_all_scores,
@@ -72,7 +78,7 @@ def spider(
         tokenizer=tokenizer,
         return_tfidf=return_tfidf,
     )
-    spice_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spice(  # type: ignore
+    spice_outs: SPICEOuts = spice(  # type: ignore
         candidates=candidates,
         mult_references=mult_references,
         return_all_scores=sub_return_all_scores,
@@ -93,9 +99,9 @@ def spider(
 
 
 def _spider_from_outputs(
-    cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
-    spice_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
-) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
+    cider_d_outs: CIDErDOuts,
+    spice_outs: SPICEOuts,
+) -> SPIDErOuts:
     """Combines CIDErD and SPICE outputs."""
     cider_d_outs_corpus, cider_d_outs_sents = cider_d_outs
     spice_outs_corpus, spice_outs_sents = spice_outs
diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py
index 65d092a..ac1b6df 100644
--- a/src/aac_metrics/functional/spider_fl.py
+++ b/src/aac_metrics/functional/spider_fl.py
@@ -2,24 +2,34 @@
 # -*- coding: utf-8 -*-
 
 import logging
-
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Iterable, Optional, TypedDict, Union
 
 import torch
-
 from torch import Tensor
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 
 from aac_metrics.functional.fer import (
-    fer,
-    _load_echecker_and_tokenizer,
-    BERTFlatClassifier,
     DEFAULT_FER_MODEL,
+    BERTFlatClassifier,
+    FEROuts,
+    _load_echecker_and_tokenizer,
+    fer,
 )
-from aac_metrics.functional.spider import spider
+from aac_metrics.functional.spider import SPIDErOuts, spider
 from aac_metrics.utils.checks import check_metric_inputs
 
+SPIDErFLScores = TypedDict(
+    "SPIDErFLScores",
+    {
+        "spider_fl": Tensor,
+        "spider": Tensor,
+        "cider_d": Tensor,
+        "spice": Tensor,
+        "fer": Tensor,
+    },
+)
+SPIDErFLOuts = tuple[SPIDErFLScores, SPIDErFLScores]
 
 pylog = logging.getLogger(__name__)
 
 
@@ -28,6 +38,7 @@ def spider_fl(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     # CIDErD args
     n: int = 4,
     sigma: float = 6.0,
@@ -51,7 +62,7 @@ def spider_fl(
     # Other args
     penalty: float = 0.9,
     verbose: int = 0,
-) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]:
+) -> Union[Tensor, SPIDErFLOuts]:
     """Combinaison of SPIDEr with Fluency Error detector.
 
     - Original implementation: https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48.
 
@@ -105,7 +116,7 @@ def spider_fl(
         reset_state=reset_state,
         verbose=verbose,
     )
-    spider_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spider(  # type: ignore
+    spider_outs: SPIDErOuts = spider(  # type: ignore
         candidates=candidates,
         mult_references=mult_references,
         return_all_scores=True,
@@ -121,7 +132,7 @@ def spider_fl(
         timeout=timeout,
         verbose=verbose,
     )
-    fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = fer(  # type: ignore
+    fer_outs: FEROuts = fer(  # type: ignore
         candidates=candidates,
         return_all_scores=True,
         echecker=echecker,
@@ -142,10 +153,10 @@ def spider_fl(
 
 
 def _spider_fl_from_outputs(
-    spider_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
-    fer_outs: tuple[dict[str, Tensor], dict[str, Tensor]],
+    spider_outs: SPIDErOuts,
+    fer_outs: FEROuts,
     penalty: float = 0.9,
-) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
+) -> SPIDErFLOuts:
     """Combines SPIDEr and FER outputs.
 
     Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48
diff --git a/src/aac_metrics/functional/spider_max.py b/src/aac_metrics/functional/spider_max.py
index 7fe216e..4c0fee4 100644
--- a/src/aac_metrics/functional/spider_max.py
+++ b/src/aac_metrics/functional/spider_max.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from pathlib import Path
-from typing import Callable, Iterable, Optional, Union
+from typing import Callable, Iterable, Optional, TypedDict, Union
 
 import torch
 from torch import Tensor
@@ -10,11 +10,18 @@
 from aac_metrics.functional.mult_cands import mult_cands_metric
 from aac_metrics.functional.spider import spider
 
+SPIDErMaxScores = TypedDict(
+    "SPIDErMaxScores",
+    {"spider_max": Tensor, "cider_d_max": Tensor, "spice_max": Tensor},
+)
+SPIDErMaxOuts = tuple[SPIDErMaxScores, SPIDErMaxScores]
+
 
 def spider_max(
     mult_candidates: list[list[str]],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
+    *,
     return_all_cands_scores: bool = False,
     # CIDEr args
     n: int = 4,
diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py
index 2e06740..cfe2565 100644
--- a/src/aac_metrics/functional/vocab.py
+++ b/src/aac_metrics/functional/vocab.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import logging
-from typing import Callable, Literal, Union
+from typing import Callable, Literal, TypedDict, Union
 
 import torch
 from torch import Tensor
@@ -14,18 +14,21 @@
 POP_STRATEGIES = ("max", "min")
 PopStrategy = Literal["max", "min"]
+VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor})
+VocabOuts = tuple[VocabScores, VocabScores]
 
 
 def vocab(
     candidates: list[str],
     mult_references: Union[list[list[str]], None],
     return_all_scores: bool = True,
+    *,
     seed: Union[None, int, torch.Generator] = 1234,
     tokenizer: Callable[[str], list[str]] = str.split,
     dtype: torch.dtype = torch.float64,
     pop_strategy: PopStrategy = "max",
     verbose: int = 0,
-) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+) -> Union[VocabOuts, Tensor]:
     """Compute vocabulary statistics.
 
     Returns the candidate corpus vocabulary length, the references vocabulary length, the average vocabulary length for single references, and the vocabulary ratios between candidates and references.
@@ -52,10 +55,13 @@ def vocab(
     del candidates
 
     vocab_cands_len = _corpus_vocab(tok_cands, dtype)
+    _, vocab_per_cand = _sent_vocab(tok_cands, dtype)
 
     if not return_all_scores:
         return vocab_cands_len
 
-    sents_scores = {}
+    sents_scores = {
+        "vocab.cands": vocab_per_cand,
+    }
     corpus_scores = {
         "vocab.cands": vocab_cands_len,
     }
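
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the keyword-only arguments and the new typed outputs introduced above are meant to be consumed, using BLEU as the example. `BLEU`, `bleu` and `BLEUOuts` come from the modules modified in this diff; the candidate/reference sentences are made up, the `update()`/`compute()` flow is assumed to follow the package's torchmetrics-style interface, and the `# type: ignore` on the annotated call mirrors the pattern used inside `spider` and `fense` above.

from aac_metrics.classes.bleu import BLEU
from aac_metrics.functional.bleu import BLEUOuts, bleu

# Hypothetical example data.
candidates = ["a man is speaking", "rain falls on a roof"]
mult_references = [
    ["a man speaks", "someone is talking"],
    ["heavy rain hits a rooftop", "it is raining hard"],
]

# Functional API: data arguments stay positional, while every option after
# `return_all_scores` (here `n` and `option`) must now be passed by keyword.
outs: BLEUOuts = bleu(  # type: ignore
    candidates,
    mult_references,
    return_all_scores=True,
    n=4,
    option="closest",
)
corpus_scores, sents_scores = outs  # both are BLEUScores, i.e. dict[str, Tensor]
for name, score in corpus_scores.items():
    print(f"{name}: {float(score):.4f}")

# Class API: the same keyword-only constraint applies to the constructor, and
# `compute()` returns a plain Tensor when `return_all_scores=False`.
metric = BLEU(return_all_scores=False, n=4)
metric.update(candidates, mult_references)
score = metric.compute()
print(float(score))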