From b030c8eef36a61391507f50f34417efa86aef9af Mon Sep 17 00:00:00 2001
From: Labbeti
Date: Tue, 5 Mar 2024 13:41:35 +0100
Subject: [PATCH] Mod: Update argument typing for bert score, bleu, meteor and
 vocab metrics.

---
 src/aac_metrics/classes/bert_score_mrefs.py    | 14 +++++------
 src/aac_metrics/classes/bleu.py                |  3 ++-
 src/aac_metrics/classes/meteor.py              |  4 ++--
 src/aac_metrics/classes/vocab.py               |  7 ++----
 .../functional/bert_score_mrefs.py             |  7 +++---
 src/aac_metrics/functional/bleu.py             | 24 +++++++++----------
 src/aac_metrics/functional/meteor.py           |  8 +++----
 src/aac_metrics/functional/mult_cands.py       | 13 +++++-----
 src/aac_metrics/functional/spider_max.py       |  3 +--
 src/aac_metrics/functional/vocab.py            | 12 +++++-----
 src/aac_metrics/utils/checks.py                |  6 +++++
 src/aac_metrics/utils/globals.py               |  8 +++---
 12 files changed, 54 insertions(+), 55 deletions(-)

diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py
index 369103b..9807287 100644
--- a/src/aac_metrics/classes/bert_score_mrefs.py
+++ b/src/aac_metrics/classes/bert_score_mrefs.py
@@ -1,18 +1,18 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Callable, Union
+from typing import Union
 
 import torch
-
-from torch import nn, Tensor
+from torch import Tensor, nn
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.bert_score_mrefs import (
-    bert_score_mrefs,
-    _load_model_and_tokenizer,
     DEFAULT_BERT_SCORE_MODEL,
     REDUCTIONS,
+    Reduction,
+    _load_model_and_tokenizer,
+    bert_score_mrefs,
 )
 from aac_metrics.utils.globals import _get_device
 
@@ -44,7 +44,7 @@ def __init__(
         max_length: int = 64,
         reset_state: bool = True,
         idf: bool = False,
-        reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
+        reduction: Reduction = "max",
         filter_nan: bool = True,
         verbose: int = 0,
     ) -> None:
@@ -110,7 +110,7 @@ def extra_repr(self) -> str:
 
     def get_output_names(self) -> tuple[str, ...]:
         return (
             "bert_score.precision",
-            "bert_score.recalll",
+            "bert_score.recall",
             "bert_score.f1",
         )
diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py
index 07427cc..aa1c766 100644
--- a/src/aac_metrics/classes/bleu.py
+++ b/src/aac_metrics/classes/bleu.py
@@ -8,6 +8,7 @@
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.bleu import (
     BLEU_OPTIONS,
+    BleuOption,
     _bleu_compute,
     _bleu_update,
 )
@@ -32,7 +33,7 @@ def __init__(
         self,
         return_all_scores: bool = True,
         n: int = 4,
-        option: str = "closest",
+        option: BleuOption = "closest",
         verbose: int = 0,
         tokenizer: Callable[[str], list[str]] = str.split,
     ) -> None:
diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py
index fe5e6cf..05118be 100644
--- a/src/aac_metrics/classes/meteor.py
+++ b/src/aac_metrics/classes/meteor.py
@@ -7,7 +7,7 @@
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.meteor import meteor
+from aac_metrics.functional.meteor import Language, meteor
 
 
 class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
@@ -32,7 +32,7 @@ def __init__(
         cache_path: Union[str, Path, None] = None,
         java_path: Union[str, Path, None] = None,
         java_max_memory: str = "2G",
-        language: str = "en",
+        language: Language = "en",
         use_shell: Optional[bool] = None,
         params: Optional[Iterable[float]] = None,
         weights: Optional[Iterable[float]] = None,
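The three class-level diffs above replace plain `str` option arguments with `Literal` aliases imported from the functional modules (the `Vocab` class below gets the same treatment via `PopStrategy`). A minimal sketch of the effect at the constructor boundary; the captions are invented, and this assumes the branch of `aac_metrics` from this patch is installed (plus Java for METEOR at compute time):

```python
from aac_metrics.classes.bleu import BLEU
from aac_metrics.classes.meteor import METEOR

# `option` is now BleuOption = Literal["shortest", "average", "closest"]:
# a typo such as option="nearest" is flagged by mypy/pyright at check time
# instead of surfacing as a runtime error inside the BLEU computation.
bleu = BLEU(n=4, option="closest")

# `language` is likewise narrowed to Literal["en", "cz", "de", "es", "fr"].
meteor = METEOR(language="en")
```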
diff --git a/src/aac_metrics/classes/vocab.py b/src/aac_metrics/classes/vocab.py
index 3f3f16c..3a7e4c7 100644
--- a/src/aac_metrics/classes/vocab.py
+++ b/src/aac_metrics/classes/vocab.py
@@ -3,16 +3,13 @@
 
 import logging
 import math
-
 from typing import Callable, Union
 
 import torch
-
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
-from aac_metrics.functional.vocab import vocab
-
+from aac_metrics.functional.vocab import PopStrategy, vocab
 
 pylog = logging.getLogger(__name__)
 
@@ -36,7 +33,7 @@ def __init__(
         seed: Union[None, int, torch.Generator] = 1234,
         tokenizer: Callable[[str], list[str]] = str.split,
         dtype: torch.dtype = torch.float64,
-        pop_strategy: str = "max",
+        pop_strategy: PopStrategy = "max",
         verbose: int = 0,
     ) -> None:
         super().__init__()
diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py
index 7752579..a9a07a5 100644
--- a/src/aac_metrics/functional/bert_score_mrefs.py
+++ b/src/aac_metrics/functional/bert_score_mrefs.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Callable, Optional, Union
+from typing import Callable, Literal, Optional, Union
 
 import torch
 import torchmetrics
@@ -18,6 +18,7 @@
 
 DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
 REDUCTIONS = ("mean", "max", "min")
+Reduction = Union[Literal["mean", "max", "min"], Callable[..., Tensor]]
 
 
 def bert_score_mrefs(
@@ -32,7 +33,7 @@ def bert_score_mrefs(
     max_length: int = 64,
     reset_state: bool = True,
     idf: bool = False,
-    reduction: Union[str, Callable[..., Tensor]] = "max",
+    reduction: Reduction = "max",
     filter_nan: bool = True,
     verbose: int = 0,
 ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
@@ -62,7 +63,7 @@ def bert_score_mrefs(
     :param verbose: The verbose level. defaults to 0.
     :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
     """
-    check_metric_inputs(candidates, mult_references)
+    check_metric_inputs(candidates, mult_references, min_length=1)
 
     if isinstance(model, str):
         if tokenizer is not None:
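The functional side pairs the existing runtime tuple `REDUCTIONS` with the new `Reduction` alias, which also admits any callable returning a `Tensor`. A small usage sketch; the sentences are made up, and running it downloads the default BERT model on first use:

```python
from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs

candidates = ["a man speaks while a dog barks"]
mult_references = [["a man is speaking", "a dog barks near a man"]]

# reduction="max" is one of the three literals; a custom Callable[..., Tensor]
# is also accepted by the new annotation. The min_length=1 call above makes
# empty candidate lists fail fast instead of crashing later.
corpus_scores, sents_scores = bert_score_mrefs(
    candidates,
    mult_references,
    reduction="max",
)
print(corpus_scores["bert_score.f1"])
```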
""" - check_metric_inputs(candidates, mult_references) + check_metric_inputs(candidates, mult_references, min_length=1) if isinstance(model, str): if tokenizer is not None: diff --git a/src/aac_metrics/functional/bleu.py b/src/aac_metrics/functional/bleu.py index cc9c898..43cbb8d 100644 --- a/src/aac_metrics/functional/bleu.py +++ b/src/aac_metrics/functional/bleu.py @@ -3,20 +3,18 @@ import logging import math - from collections import Counter -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch - from torch import Tensor from aac_metrics.utils.checks import check_metric_inputs - pylog = logging.getLogger(__name__) BLEU_OPTIONS = ("shortest", "average", "closest") +BleuOption = Literal["shortest", "average", "closest"] def bleu( @@ -24,7 +22,7 @@ def bleu( mult_references: list[list[str]], return_all_scores: bool = True, n: int = 4, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, return_1_to_n: bool = False, @@ -72,7 +70,7 @@ def bleu_1( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, return_1_to_n: bool = False, @@ -93,7 +91,7 @@ def bleu_2( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, return_1_to_n: bool = False, @@ -114,7 +112,7 @@ def bleu_3( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, return_1_to_n: bool = False, @@ -135,7 +133,7 @@ def bleu_4( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, tokenizer: Callable[[str], list[str]] = str.split, return_1_to_n: bool = False, @@ -179,7 +177,7 @@ def _bleu_compute( cooked_mrefs: list, return_all_scores: bool = True, n: int = 4, - option: str = "closest", + option: BleuOption = "closest", verbose: int = 0, return_1_to_n: bool = False, ) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: @@ -189,7 +187,7 @@ def _bleu_compute( bleu_1_to_n_score, bleu_1_to_n_scores = __compute_bleu_score( cooked_cands, cooked_mrefs, - n, + n=n, option=option, verbose=verbose, ) @@ -300,7 +298,7 @@ def __compute_bleu_score( cooked_cands: list, cooked_mrefs: list, n: int, - option: Optional[str] = "closest", + option: BleuOption = "closest", verbose: int = 0, ) -> tuple[list[float], list[list[float]]]: SMALL = 1e-9 @@ -373,7 +371,7 @@ def __compute_bleu_score( def __single_reflen( reflens: list[int], - option: Optional[str] = None, + option: BleuOption, testlen: Optional[int] = None, ) -> float: if option == "shortest": diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py index 1abea6d..db6a67c 100644 --- a/src/aac_metrics/functional/meteor.py +++ b/src/aac_metrics/functional/meteor.py @@ -5,25 +5,23 @@ import os.path as osp import platform import subprocess - from pathlib import Path from subprocess import Popen -from typing import Iterable, Optional, Union +from typing import Iterable, Literal, Optional, Union import torch - from torch 
diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py
index 1abea6d..db6a67c 100644
--- a/src/aac_metrics/functional/meteor.py
+++ b/src/aac_metrics/functional/meteor.py
@@ -5,34 +5,32 @@
 import os.path as osp
 import platform
 import subprocess
-
 from pathlib import Path
 from subprocess import Popen
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
-
 from torch import Tensor
 
 from aac_metrics.utils.checks import check_java_path, check_metric_inputs
 from aac_metrics.utils.globals import _get_cache_path, _get_java_path
 
-
 pylog = logging.getLogger(__name__)
 
 DNAME_METEOR_CACHE = osp.join("aac-metrics", "meteor")
 FNAME_METEOR_JAR = osp.join(DNAME_METEOR_CACHE, "meteor-1.5.jar")
 SUPPORTED_LANGUAGES = ("en", "cz", "de", "es", "fr")
+Language = Literal["en", "cz", "de", "es", "fr"]
 
 
 def meteor(
     candidates: list[str],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     java_max_memory: str = "2G",
-    language: str = "en",
+    language: Language = "en",
     use_shell: Optional[bool] = None,
     params: Optional[Iterable[float]] = None,
     weights: Optional[Iterable[float]] = None,
diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py
index e5b7115..cdf7c3d 100644
--- a/src/aac_metrics/functional/mult_cands.py
+++ b/src/aac_metrics/functional/mult_cands.py
@@ -1,17 +1,16 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Callable, Union
+from typing import Callable, Literal, Union
 
 import torch
 import tqdm
-
 from torch import Tensor
 
 from aac_metrics.utils.checks import is_mult_sents
 
-
 SELECTIONS = ("max", "min", "mean")
+Selection = Literal["max", "min", "mean"]
 
 
 def mult_cands_metric(
@@ -21,17 +20,17 @@ def mult_cands_metric(
     mult_candidates: list[list[str]],
     mult_references: list[list[str]],
     return_all_scores: bool = True,
     return_all_cands_scores: bool = False,
-    selection: str = "max",
-    reduction: Callable[[Tensor], Tensor] = torch.mean,
+    selection: Selection = "max",
+    reduction_fn: Callable[[Tensor], Tensor] = torch.mean,
     **kwargs,
 ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
     """Multiple candidates metric wrapper.
 
     :param mult_candidates: The list of list of sentences to evaluate.
     :param mult_references: The references input.
     :param selection: The selection to apply. Can be "max", "min" or "mean". defaults to "max".
-    :param reduction: The reduction function to apply to local scores. defaults to torch.mean.
+    :param reduction_fn: The reduction function to apply to local scores. defaults to torch.mean.
     :param **kwargs: The keywords arguments given to the metric call.
     :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
     """
@@ -113,7 +112,7 @@ def mult_cands_metric(
         f"{k}_all": scores.transpose(0, 1) for k, scores in all_sents_scores.items()
     }
 
-    reduction_fn = reduction
+    reduction_fn = reduction_fn
     outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()}
 
     if return_all_scores:
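Besides introducing `Selection`, the wrapper renames its `reduction` keyword to `reduction_fn`, so external callers still passing `reduction=` will now get a `TypeError`; the only in-repo call site, `spider_max`, is updated in the next diff. A hedged sketch of the renamed keyword through that wrapper; the inputs are made up, the argument order assumes the usual candidates-then-references convention, and SPIDEr needs the package's Java/SPICE setup at runtime:

```python
from aac_metrics.functional.spider_max import spider_max

mult_candidates = [["a dog barks", "a dog is barking loudly"]]
mult_references = [["a dog barks loudly", "loud barking of a dog"]]

# spider_max forwards selection="max" and reduction_fn=torch.mean to
# mult_cands_metric, as the following spider_max.py hunk shows.
corpus_scores, sents_scores = spider_max(mult_candidates, mult_references)
```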
""" @@ -113,7 +112,7 @@ def mult_cands_metric( f"{k}_all": scores.transpose(0, 1) for k, scores in all_sents_scores.items() } - reduction_fn = reduction + reduction_fn = reduction_fn outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()} if return_all_scores: diff --git a/src/aac_metrics/functional/spider_max.py b/src/aac_metrics/functional/spider_max.py index aefe6d3..7fe216e 100644 --- a/src/aac_metrics/functional/spider_max.py +++ b/src/aac_metrics/functional/spider_max.py @@ -5,7 +5,6 @@ from typing import Callable, Iterable, Optional, Union import torch - from torch import Tensor from aac_metrics.functional.mult_cands import mult_cands_metric @@ -74,7 +73,7 @@ def spider_max( return_all_scores=return_all_scores, return_all_cands_scores=return_all_cands_scores, selection="max", - reduction=torch.mean, + reduction_fn=torch.mean, # CIDEr args n=n, sigma=sigma, diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index f742617..2e06740 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -2,19 +2,20 @@ # -*- coding: utf-8 -*- import logging - -from typing import Callable, Union +from typing import Callable, Literal, Union import torch - from torch import Tensor from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents - pylog = logging.getLogger(__name__) +POP_STRATEGIES = ("max", "min") +PopStrategy = Literal["max", "min"] + + def vocab( candidates: list[str], mult_references: Union[list[list[str]], None], @@ -22,7 +23,7 @@ def vocab( seed: Union[None, int, torch.Generator] = 1234, tokenizer: Callable[[str], list[str]] = str.split, dtype: torch.dtype = torch.float64, - pop_strategy: str = "max", + pop_strategy: PopStrategy = "max", verbose: int = 0, ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """Compute vocabulary statistics. @@ -84,7 +85,6 @@ def vocab( elif isinstance(pop_strategy, int): n_samples = pop_strategy else: - POP_STRATEGIES = ("max", "min") raise ValueError( f"Invalid argument {pop_strategy=}. (expected one of {POP_STRATEGIES} or an integer value)" ) diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py index 6a5d27a..6e06554 100644 --- a/src/aac_metrics/utils/checks.py +++ b/src/aac_metrics/utils/checks.py @@ -19,6 +19,7 @@ def check_metric_inputs( candidates: Any, mult_references: Any, + min_length: int = 0, ) -> None: """Raises ValueError if candidates and mult_references does not have a valid type and size.""" @@ -44,6 +45,11 @@ def check_metric_inputs( error_msg = "Invalid number of references per candidate. (found at least 1 empty list of references)" raise ValueError(error_msg) + if len(candidates) < min_length: + raise ValueError( + f"Invalid number of sentences in candidates and references. 
diff --git a/src/aac_metrics/utils/globals.py b/src/aac_metrics/utils/globals.py
index 58ceaf7..bd74cb5 100644
--- a/src/aac_metrics/utils/globals.py
+++ b/src/aac_metrics/utils/globals.py
@@ -5,12 +5,12 @@
 import os
 import os.path as osp
 import tempfile
-
 from pathlib import Path
 from typing import Any, Optional, Union
 
 import torch
 
 
+_CUDA_IF_AVAILABLE: str = "cuda_if_available"
 pylog = logging.getLogger(__name__)
 
@@ -67,7 +67,7 @@ def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str:
 
 
 def _get_device(
-    device: Union[str, torch.device, None] = "cuda_if_available",
+    device: Union[str, torch.device, None] = _CUDA_IF_AVAILABLE,
 ) -> Optional[torch.device]:
     value_name = "device"
     process_func = __DEFAULT_GLOBALS[value_name]["process"]
@@ -132,7 +132,7 @@ def __process_path(value: Union[str, Path, None]) -> Union[str, None]:
 def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.device]:
     if value is None or value is ...:
         return None
-    if value == "cuda_if_available":
+    if value == _CUDA_IF_AVAILABLE:
         value = "cuda" if torch.cuda.is_available() else "cpu"
     if isinstance(value, str):
         value = torch.device(value)
@@ -151,7 +151,7 @@ def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.de
     "device": {
         "values": {
             "env": "AAC_METRICS_DEVICE",
-            "package": "cuda_if_available",
+            "package": _CUDA_IF_AVAILABLE,
         },
         "process": __process_device,
     },
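Finally, the repeated "cuda_if_available" magic string is hoisted into the `_CUDA_IF_AVAILABLE` constant, referenced by both the default argument of `_get_device` and `__process_device`. A sketch of the (private) helper's behavior, assuming no `AAC_METRICS_DEVICE` environment override is set:

```python
import torch

from aac_metrics.utils.globals import _get_device

# "cuda_if_available" resolves to cuda when a GPU is visible, else cpu.
device = _get_device("cuda_if_available")
assert device in (torch.device("cuda"), torch.device("cpu"))

# Explicit device strings are passed through to torch.device.
assert _get_device("cpu") == torch.device("cpu")
```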