diff --git a/CHANGELOG.md b/CHANGELOG.md index f6a3f6e..b7edbe1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to this project will be documented in this file. +## [0.5.2] 2024-01-05 +### Changed +- `aac-metrics` is now compatible with `transformers>=4.31`. +- Rename default device value "auto" to "cuda_if_available". + ## [0.5.1] 2023-12-20 ### Added - Check sentences inputs for all metrics. diff --git a/CITATION.cff b/CITATION.cff index 71458d5..625d4a5 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.5.1 -date-released: '2023-12-20' +version: 0.5.2 +date-released: '2024-01-05' diff --git a/README.md b/README.md index 8579949..4c0a0e3 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci ### Other metrics | Metric name | Python Class | Origin | Range | Short description | |:---|:---|:---|:---|:---| -| Vocabulary | `Vocab` | text generation | [0, +$\infty$[ | Number of unique words in candidates. | +| Vocabulary | `Vocab` | text generation | [0, +∞[ | Number of unique words in candidates. | ### Future directions This package currently does not include all metrics dedicated to audio captioning. Feel free to do a pull request / or ask to me by email if you want to include them. Those metrics not included are listed here: @@ -146,15 +146,13 @@ numpy >= 1.21.2 pyyaml >= 6.0 tqdm >= 4.64.0 sentence-transformers >= 2.2.2 -transformers < 4.31.0 +transformers torchmetrics >= 0.11.4 ``` ### External requirements - `java` **>= 1.8 and <= 1.13** is required to compute METEOR, SPICE and use the PTBTokenizer. -Most of these functions can specify a java executable path with `java_path` argument. - -- `unzip` command to extract SPICE zipped files. +Most of these functions can specify a java executable path with `java_path` argument or by overriding `AAC_METRICS_JAVA_PATH` environment variable. ## Additional notes ### CIDEr or CIDEr-D? @@ -233,14 +231,14 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr ``` @software{ - Labbe_aac_metrics_2023, + Labbe_aac_metrics_2024, author = {LabbĂ©, Etienne}, license = {MIT}, - month = {12}, + month = {01}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.5.1}, - year = {2023}, + version = {0.5.2}, + year = {2024}, } ``` diff --git a/docs/aac_metrics.classes.rst b/docs/aac_metrics.classes.rst index d9e5a7d..d6fa491 100644 --- a/docs/aac_metrics.classes.rst +++ b/docs/aac_metrics.classes.rst @@ -13,11 +13,12 @@ Submodules :maxdepth: 4 aac_metrics.classes.base + aac_metrics.classes.bert_score_mrefs aac_metrics.classes.bleu aac_metrics.classes.cider_d aac_metrics.classes.evaluate aac_metrics.classes.fense - aac_metrics.classes.fluerr + aac_metrics.classes.fer aac_metrics.classes.meteor aac_metrics.classes.rouge_l aac_metrics.classes.sbert_sim @@ -25,3 +26,4 @@ Submodules aac_metrics.classes.spider aac_metrics.classes.spider_fl aac_metrics.classes.spider_max + aac_metrics.classes.vocab diff --git a/docs/aac_metrics.evaluate.rst b/docs/aac_metrics.evaluate.rst deleted file mode 100644 index 54c90ef..0000000 --- a/docs/aac_metrics.evaluate.rst +++ /dev/null @@ -1,7 +0,0 @@ -aac\_metrics.evaluate module -============================ - -.. automodule:: aac_metrics.evaluate - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/aac_metrics.functional.rst b/docs/aac_metrics.functional.rst index fdfb751..2e73c90 100644 --- a/docs/aac_metrics.functional.rst +++ b/docs/aac_metrics.functional.rst @@ -12,11 +12,12 @@ Submodules .. toctree:: :maxdepth: 4 + aac_metrics.functional.bert_score_mrefs aac_metrics.functional.bleu aac_metrics.functional.cider_d aac_metrics.functional.evaluate aac_metrics.functional.fense - aac_metrics.functional.fluerr + aac_metrics.functional.fer aac_metrics.functional.meteor aac_metrics.functional.mult_cands aac_metrics.functional.rouge_l @@ -25,3 +26,4 @@ Submodules aac_metrics.functional.spider aac_metrics.functional.spider_fl aac_metrics.functional.spider_max + aac_metrics.functional.vocab diff --git a/docs/aac_metrics.utils.rst b/docs/aac_metrics.utils.rst index ef6aa3b..d75b736 100644 --- a/docs/aac_metrics.utils.rst +++ b/docs/aac_metrics.utils.rst @@ -13,6 +13,8 @@ Submodules :maxdepth: 4 aac_metrics.utils.checks + aac_metrics.utils.cmdline aac_metrics.utils.collections + aac_metrics.utils.globals aac_metrics.utils.imports aac_metrics.utils.tokenization diff --git a/requirements-dev.txt b/requirements-dev.txt index a782ffb..f6065ae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,5 @@ scikit-image==0.19.2 matplotlib==3.5.2 ipykernel==6.9.1 twine==4.0.1 +sphinx==7.2.6 +sphinx-press-theme==0.8.0 diff --git a/requirements.txt b/requirements.txt index a87060f..19e52ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ numpy>=1.21.2 pyyaml>=6.0 tqdm>=4.64.0 sentence-transformers>=2.2.2 -transformers<4.31.0 +transformers torchmetrics>=0.11.4 diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index deae747..fa9aeb2 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,7 +10,7 @@ __maintainer__ = "Etienne LabbĂ© (Labbeti)" __name__ = "aac-metrics" __status__ = "Development" -__version__ = "0.5.1" +__version__ = "0.5.2" from .classes.base import AACMetric diff --git a/src/aac_metrics/classes/base.py b/src/aac_metrics/classes/base.py index 3df7b33..2a90de1 100644 --- a/src/aac_metrics/classes/base.py +++ b/src/aac_metrics/classes/base.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import math + from typing import Any, ClassVar, Generic, Optional, TypeVar, Union from torch import nn, Tensor @@ -19,9 +21,9 @@ class AACMetric(nn.Module, Generic[OutType]): is_differentiable: ClassVar[Optional[bool]] = False # The theorical minimal value of the main global score of the metric. - min_value: ClassVar[Optional[float]] = None + min_value: ClassVar[float] = -math.inf # The theorical maximal value of the main global score of the metric. - max_value: ClassVar[Optional[float]] = None + max_value: ClassVar[float] = math.inf def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py index c31c1d9..a9451ba 100644 --- a/src/aac_metrics/classes/bert_score_mrefs.py +++ b/src/aac_metrics/classes/bert_score_mrefs.py @@ -6,13 +6,15 @@ import torch from torch import nn, Tensor -from torchmetrics.text.bert import _DEFAULT_MODEL from aac_metrics.classes.base import AACMetric from aac_metrics.functional.bert_score_mrefs import ( bert_score_mrefs, _load_model_and_tokenizer, + DEFAULT_BERT_SCORE_MODEL, + REDUCTIONS, ) +from aac_metrics.utils.globals import _get_device class BERTScoreMRefs(AACMetric): @@ -35,8 +37,8 @@ class BERTScoreMRefs(AACMetric): def __init__( self, return_all_scores: bool = True, - model: Union[str, nn.Module] = _DEFAULT_MODEL, - device: Union[str, torch.device, None] = "auto", + model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL, + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, num_threads: int = 0, max_length: int = 64, @@ -46,6 +48,12 @@ def __init__( filter_nan: bool = True, verbose: int = 0, ) -> None: + if reduction not in REDUCTIONS: + raise ValueError( + f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})" + ) + + device = _get_device(device) model, tokenizer = _load_model_and_tokenizer( model=model, tokenizer=None, diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py index c6f1af2..07427cc 100644 --- a/src/aac_metrics/classes/bleu.py +++ b/src/aac_metrics/classes/bleu.py @@ -81,12 +81,12 @@ def update( mult_references: list[list[str]], ) -> None: self._cooked_cands, self._cooked_mrefs = _bleu_update( - candidates, - mult_references, - self._n, - self._tokenizer, - self._cooked_cands, - self._cooked_mrefs, + candidates=candidates, + mult_references=mult_references, + n=self._n, + tokenizer=self._tokenizer, + prev_cooked_cands=self._cooked_cands, + prev_cooked_mrefs=self._cooked_mrefs, ) diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py index a4eee40..9415e2f 100644 --- a/src/aac_metrics/classes/evaluate.py +++ b/src/aac_metrics/classes/evaluate.py @@ -55,7 +55,7 @@ def __init__( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> None: metrics = _instantiate_metrics_classes( @@ -127,7 +127,7 @@ def __init__( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> None: super().__init__( @@ -146,7 +146,7 @@ def _instantiate_metrics_classes( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> list[AACMetric]: if isinstance(metrics, str) and metrics in METRICS_SETS: @@ -179,7 +179,7 @@ def _get_metric_factory_classes( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, init_kwds: Optional[dict[str, Any]] = None, ) -> dict[str, Callable[[], AACMetric]]: diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py index f1bae17..ae5f859 100644 --- a/src/aac_metrics/classes/fense.py +++ b/src/aac_metrics/classes/fense.py @@ -7,11 +7,17 @@ import torch +from sentence_transformers import SentenceTransformer from torch import Tensor from aac_metrics.classes.base import AACMetric from aac_metrics.functional.fense import fense, _load_models_and_tokenizer -from aac_metrics.functional.fer import ERROR_NAMES +from aac_metrics.functional.fer import ( + BERTFlatClassifier, + _ERROR_NAMES, + DEFAULT_FER_MODEL, +) +from aac_metrics.functional.sbert_sim import DEFAULT_SBERT_SIM_MODEL pylog = logging.getLogger(__name__) @@ -36,10 +42,10 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor] def __init__( self, return_all_scores: bool = True, - sbert_model: str = "paraphrase-TinyBERT-L6-v2", - echecker: str = "echecker_clotho_audiocaps_base", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = False, @@ -99,9 +105,10 @@ def extra_repr(self) -> str: return repr_ def get_output_names(self) -> tuple[str, ...]: - return ("sbert_sim", "fer", "fense") + tuple( - f"fer.{name}_prob" for name in ERROR_NAMES - ) + output_names = ["sbert_sim", "fer", "fense"] + if self._return_probs: + output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES] + return tuple(output_names) def reset(self) -> None: self._candidates = [] diff --git a/src/aac_metrics/classes/fer.py b/src/aac_metrics/classes/fer.py index 4217a4f..1c06592 100644 --- a/src/aac_metrics/classes/fer.py +++ b/src/aac_metrics/classes/fer.py @@ -11,10 +11,13 @@ from aac_metrics.classes.base import AACMetric from aac_metrics.functional.fer import ( + BERTFlatClassifier, fer, _load_echecker_and_tokenizer, - ERROR_NAMES, + _ERROR_NAMES, + DEFAULT_FER_MODEL, ) +from aac_metrics.utils.globals import _get_device pylog = logging.getLogger(__name__) @@ -39,15 +42,22 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]) def __init__( self, return_all_scores: bool = True, - echecker: str = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = False, verbose: int = 0, ) -> None: - echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose) # type: ignore + device = _get_device(device) + echecker, echecker_tokenizer = _load_echecker_and_tokenizer( + echecker=echecker, + echecker_tokenizer=None, + device=device, + reset_state=reset_state, + verbose=verbose, + ) super().__init__() self._return_all_scores = return_all_scores @@ -82,7 +92,10 @@ def extra_repr(self) -> str: return repr_ def get_output_names(self) -> tuple[str, ...]: - return ("fer",) + tuple(f"fer.{name}_prob" for name in ERROR_NAMES) + output_names = ["fer"] + if self._return_probs: + output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES] + return tuple(output_names) def reset(self) -> None: self._candidates = [] diff --git a/src/aac_metrics/classes/fluerr.py b/src/aac_metrics/classes/fluerr.py deleted file mode 100644 index 0e84c5a..0000000 --- a/src/aac_metrics/classes/fluerr.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import logging - -from typing import Union - -import torch - -from torch import Tensor - -from aac_metrics.classes.base import AACMetric -from aac_metrics.functional.fluerr import ( - fluerr, - _load_echecker_and_tokenizer, - ERROR_NAMES, -) - - -pylog = logging.getLogger(__name__) - - -class FluErr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): - """Return fluency error rate detected by a pre-trained BERT model. - - - Paper: https://arxiv.org/abs/2110.04684 - - Original implementation: https://github.com/blmoistawinde/fense - - For more information, see :func:`~aac_metrics.functional.fluerr.fluerr`. - """ - - full_state_update = False - higher_is_better = False - is_differentiable = False - - min_value = -1.0 - max_value = 1.0 - - def __init__( - self, - return_all_scores: bool = True, - echecker: str = "echecker_clotho_audiocaps_base", - error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", - batch_size: int = 32, - reset_state: bool = True, - return_probs: bool = False, - verbose: int = 0, - ) -> None: - echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose) # type: ignore - - super().__init__() - self._return_all_scores = return_all_scores - self._echecker = echecker - self._echecker_tokenizer = echecker_tokenizer - self._error_threshold = error_threshold - self._device = device - self._batch_size = batch_size - self._reset_state = reset_state - self._return_probs = return_probs - self._verbose = verbose - - self._candidates = [] - - def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: - return fluerr( - candidates=self._candidates, - return_all_scores=self._return_all_scores, - echecker=self._echecker, - echecker_tokenizer=self._echecker_tokenizer, - error_threshold=self._error_threshold, - device=self._device, - batch_size=self._batch_size, - reset_state=self._reset_state, - return_probs=self._return_probs, - verbose=self._verbose, - ) - - def extra_repr(self) -> str: - return f"device={self._device}, batch_size={self._batch_size}" - - def get_output_names(self) -> tuple[str, ...]: - return ("fluerr",) + tuple(f"fluerr.{name}_prob" for name in ERROR_NAMES) - - def reset(self) -> None: - self._candidates = [] - return super().reset() - - def update( - self, - candidates: list[str], - *args, - **kwargs, - ) -> None: - self._candidates += candidates diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py index b32d0d4..76ba245 100644 --- a/src/aac_metrics/classes/sbert_sim.py +++ b/src/aac_metrics/classes/sbert_sim.py @@ -11,7 +11,12 @@ from torch import Tensor from aac_metrics.classes.base import AACMetric -from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert +from aac_metrics.functional.sbert_sim import ( + sbert_sim, + _load_sbert, + DEFAULT_SBERT_SIM_MODEL, +) +from aac_metrics.utils.globals import _get_device pylog = logging.getLogger(__name__) @@ -36,13 +41,18 @@ class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tens def __init__( self, return_all_scores: bool = True, - sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2", - device: Union[str, torch.device, None] = "auto", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, verbose: int = 0, ) -> None: - sbert_model = _load_sbert(sbert_model, device, reset_state) + device = _get_device(device) + sbert_model = _load_sbert( + sbert_model=sbert_model, + device=device, + reset_state=reset_state, + ) super().__init__() self._return_all_scores = return_all_scores diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py index 070dc7a..d03d6ea 100644 --- a/src/aac_metrics/classes/spider_fl.py +++ b/src/aac_metrics/classes/spider_fl.py @@ -15,8 +15,11 @@ from aac_metrics.functional.fer import ( BERTFlatClassifier, _load_echecker_and_tokenizer, + _ERROR_NAMES, + DEFAULT_FER_MODEL, ) from aac_metrics.functional.spider_fl import spider_fl +from aac_metrics.utils.globals import _get_device pylog = logging.getLogger(__name__) @@ -49,10 +52,10 @@ def __init__( java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, # FluencyError args - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = True, @@ -60,8 +63,13 @@ def __init__( penalty: float = 0.9, verbose: int = 0, ) -> None: + device = _get_device(device) echecker, echecker_tokenizer = _load_echecker_and_tokenizer( - echecker, echecker_tokenizer, device, reset_state, verbose + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) super().__init__() @@ -127,7 +135,10 @@ def extra_repr(self) -> str: return extra def get_output_names(self) -> tuple[str, ...]: - return ("cider_d", "spice", "spider", "spider_fl", "fer") + output_names = ["cider_d", "spice", "spider", "spider_fl", "fer"] + if self._return_probs: + output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES] + return tuple(output_names) def reset(self) -> None: self._candidates = [] diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py index 4e60640..8b1582c 100644 --- a/src/aac_metrics/classes/spider_max.py +++ b/src/aac_metrics/classes/spider_max.py @@ -89,7 +89,10 @@ def extra_repr(self) -> str: return repr_ def get_output_names(self) -> tuple[str, ...]: - return ("spider_max",) + output_names = ["spider_max"] + if self._return_all_cands_scores: + output_names += ["cider_d_all", "spice_all", "spider_all"] + return tuple(output_names) def reset(self) -> None: self._mult_candidates = [] diff --git a/src/aac_metrics/eval.py b/src/aac_metrics/eval.py index 94605cd..2302ab2 100644 --- a/src/aac_metrics/eval.py +++ b/src/aac_metrics/eval.py @@ -199,7 +199,7 @@ def _get_main_evaluate_args() -> Namespace: parser.add_argument( "--device", type=str, - default="auto", + default="cuda_if_available", help="Device used for model-based metrics. defaults to 'auto'.", ) parser.add_argument( diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py index d6e932e..2e1bdbd 100644 --- a/src/aac_metrics/functional/bert_score_mrefs.py +++ b/src/aac_metrics/functional/bert_score_mrefs.py @@ -16,13 +16,17 @@ from aac_metrics.utils.globals import _get_device +DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL +REDUCTIONS = ("mean", "max", "min") + + def bert_score_mrefs( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - model: Union[str, nn.Module] = _DEFAULT_MODEL, + model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL, tokenizer: Optional[Callable] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, num_threads: int = 0, max_length: int = 64, @@ -48,12 +52,12 @@ def bert_score_mrefs( :param tokenizer: The fast tokenizer used to split sentences into words. If None, use the tokenizer corresponding to the model argument. defaults to None. - :param device: The PyTorch device used to run the BERT model. defaults to "auto". + :param device: The PyTorch device used to run the BERT model. defaults to "cuda_if_available". :param batch_size: The batch size used in the model forward. :param num_threads: A number of threads to use for a dataloader. defaults to 0. :param max_length: Max length when encoding sentences to tensor ids. defaults to 64. :param idf: Whether or not using Inverse document frequency to ponderate the BERTScores. defaults to False. - :param reduction: The reduction function to apply between multiple references for each audio. defaults to "mean". + :param reduction: The reduction function to apply between multiple references for each audio. defaults to "max". :param filter_nan: If True, replace NaN scores by 0.0. defaults to True. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. @@ -132,7 +136,6 @@ def bert_score_mrefs( elif reduction == "min": reduction_fn = _min_reduce else: - REDUCTIONS = ("mean", "max", "min") raise ValueError( f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})" ) @@ -165,11 +168,11 @@ def bert_score_mrefs( def _load_model_and_tokenizer( - model: Union[str, nn.Module], - tokenizer: Optional[Callable], - device: Union[str, torch.device, None], - reset_state: bool, - verbose: int, + model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL, + tokenizer: Optional[Callable] = None, + device: Union[str, torch.device, None] = "cuda_if_available", + reset_state: bool = True, + verbose: int = 0, ) -> tuple[nn.Module, Optional[Callable]]: state = torch.random.get_rng_state() diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index 542ff80..772b1d5 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -85,7 +85,7 @@ def evaluate( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: """Evaluate candidates with multiple references with custom metrics. @@ -98,7 +98,7 @@ def evaluate( :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param device: The PyTorch device used to run FENSE and SPIDErFL models. - If None, it will try to detect use cuda if available. defaults to "auto". + If None, it will try to detect use cuda if available. defaults to "cuda_if_available". :param verbose: The verbose level. defaults to 0. :returns: A tuple contains the corpus and sentences scores. """ @@ -172,7 +172,7 @@ def dcase2023_evaluate( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: """Evaluate candidates with multiple references with the DCASE2023 Audio Captioning metrics. @@ -185,7 +185,7 @@ def dcase2023_evaluate( :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param device: The PyTorch device used to run FENSE and SPIDErFL models. - If None, it will try to detect use cuda if available. defaults to "auto". + If None, it will try to detect use cuda if available. defaults to "cuda_if_available". :param verbose: The verbose level. defaults to 0. :returns: A tuple contains the corpus and sentences scores. """ @@ -207,7 +207,7 @@ def _instantiate_metrics_functions( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, ) -> list[Callable]: if isinstance(metrics, str) and metrics in METRICS_SETS: @@ -245,7 +245,7 @@ def _get_metric_factory_functions( cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", verbose: int = 0, init_kwds: Optional[dict[str, Any]] = None, ) -> dict[str, Callable[[list[str], list[list[str]]], Any]]: diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py index 643960b..74e462c 100644 --- a/src/aac_metrics/functional/fense.py +++ b/src/aac_metrics/functional/fense.py @@ -15,8 +15,13 @@ fer, _load_echecker_and_tokenizer, BERTFlatClassifier, + DEFAULT_FER_MODEL, +) +from aac_metrics.functional.sbert_sim import ( + sbert_sim, + _load_sbert, + DEFAULT_SBERT_SIM_MODEL, ) -from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert from aac_metrics.utils.checks import check_metric_inputs @@ -28,12 +33,12 @@ def fense( mult_references: list[list[str]], return_all_scores: bool = True, # SBERT args - sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, # FluencyError args - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = False, @@ -60,7 +65,7 @@ def fense( defaults to None. :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. :param penalty: The penalty coefficient applied. Higher value means to lower the cos-sim scores when an error is detected. defaults to 0.9. - :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". + :param device: The PyTorch device used to run FENSE models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. @@ -138,15 +143,17 @@ def _fense_from_outputs( def _load_models_and_tokenizer( - sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2", - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", reset_state: bool = True, verbose: int = 0, ) -> tuple[SentenceTransformer, BERTFlatClassifier, AutoTokenizer]: sbert_model = _load_sbert( - sbert_model=sbert_model, device=device, reset_state=reset_state + sbert_model=sbert_model, + device=device, + reset_state=reset_state, ) echecker, echecker_tokenizer = _load_echecker_and_tokenizer( echecker=echecker, diff --git a/src/aac_metrics/functional/fer.py b/src/aac_metrics/functional/fer.py index 7a3eabe..24f6095 100644 --- a/src/aac_metrics/functional/fer.py +++ b/src/aac_metrics/functional/fer.py @@ -14,6 +14,7 @@ import numpy as np import torch +import transformers from torch import nn, Tensor from tqdm import tqdm @@ -26,12 +27,14 @@ from aac_metrics.utils.globals import _get_device -# config according to the settings on your computer, this should be default setting of shadowsocks -DEFAULT_PROXIES = { +DEFAULT_FER_MODEL = "echecker_clotho_audiocaps_base" + + +_DEFAULT_PROXIES = { "http": "socks5h://127.0.0.1:1080", "https": "socks5h://127.0.0.1:1080", } -PRETRAIN_ECHECKERS_DICT = { +_PRETRAIN_ECHECKERS_DICT = { "echecker_clotho_audiocaps_base": ( "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt", "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa", @@ -41,13 +44,7 @@ "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673", ), } - -RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"]) - -pylog = logging.getLogger(__name__) - - -ERROR_NAMES = ( +_ERROR_NAMES = ( "add_tail", "repeat_event", "repeat_adv", @@ -56,6 +53,10 @@ "error", ) +_RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"]) + +pylog = logging.getLogger(__name__) + class BERTFlatClassifier(nn.Module): def __init__(self, model_type: str, num_classes: int = 5) -> None: @@ -69,13 +70,19 @@ def __init__(self, model_type: str, num_classes: int = 5) -> None: @classmethod def from_pretrained( cls, - model_name: str = "echecker_clotho_audiocaps_base", - device: Union[str, torch.device, None] = "auto", + model_name: str = DEFAULT_FER_MODEL, + device: Union[str, torch.device, None] = "cuda_if_available", use_proxy: bool = False, proxies: Optional[dict[str, str]] = None, verbose: int = 0, ) -> "BERTFlatClassifier": - return __load_pretrain_echecker(model_name, device, use_proxy, proxies, verbose) + return __load_pretrain_echecker( + echecker_model=model_name, + device=device, + use_proxy=use_proxy, + proxies=proxies, + verbose=verbose, + ) def forward( self, @@ -94,10 +101,10 @@ def forward( def fer( candidates: list[str], return_all_scores: bool = True, - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = False, @@ -120,7 +127,7 @@ def fer( If None and echecker is not None, this value will be inferred with `echecker.model_type`. defaults to None. :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. - :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". + :param device: The PyTorch device used to run FENSE models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available". :param batch_size: The batch size of the echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. @@ -133,16 +140,20 @@ def fer( # Init models echecker, echecker_tokenizer = _load_echecker_and_tokenizer( - echecker, echecker_tokenizer, device, reset_state, verbose + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) # Compute and apply fluency error detection penalty probs_outs_sents = __detect_error_sents( - echecker, - echecker_tokenizer, # type: ignore - candidates, - batch_size, - device, + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + sents=candidates, + batch_size=batch_size, + device=device, ) fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float) @@ -174,11 +185,17 @@ def fer( return fer_score +def _use_new_echecker_loading() -> bool: + version = transformers.__version__ + major, minor, _patch = map(int, version.split(".")) + return major > 4 or (major == 4 and minor >= 31) + + # - Private functions def _load_echecker_and_tokenizer( - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", reset_state: bool = True, verbose: int = 0, ) -> tuple[BERTFlatClassifier, AutoTokenizer]: @@ -226,10 +243,10 @@ def __detect_error_sents( # batch_logits: (bsize, num_classes=6) # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69 probs = logits.sigmoid().transpose(0, 1).cpu().numpy() - probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs)) + probs_dic: dict[str, np.ndarray] = dict(zip(_ERROR_NAMES, probs)) else: - dic_lst_probs = {name: [] for name in ERROR_NAMES} + dic_lst_probs = {name: [] for name in _ERROR_NAMES} for i in range(0, len(sents), batch_size): batch = __infer_preprocess( @@ -257,11 +274,10 @@ def __detect_error_sents( def __check_download_resource( - remote: RemoteFileMetadata, + remote: _RemoteFileMetadata, use_proxy: bool = False, proxies: Optional[dict[str, str]] = None, ) -> str: - proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies data_home = __get_data_home() file_path = os.path.join(data_home, remote.filename) if not os.path.exists(file_path): @@ -286,10 +302,10 @@ def __infer_preprocess( def __download( - remote: RemoteFileMetadata, + remote: _RemoteFileMetadata, file_path: Optional[str] = None, use_proxy: bool = False, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, + proxies: Optional[dict[str, str]] = None, ) -> str: data_home = __get_data_home() file_path = __fetch_remote(remote, data_home, use_proxy, proxies) @@ -299,8 +315,12 @@ def __download( def __download_with_bar( url: str, file_path: str, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, + use_proxy: bool = False, + proxies: Optional[dict[str, str]] = None, ) -> str: + if use_proxy and proxies is None: + proxies = _DEFAULT_PROXIES + # Streaming, so we can iterate over the response. response = requests.get(url, stream=True, proxies=proxies) total_size_in_bytes = int(response.headers.get("content-length", 0)) @@ -317,31 +337,13 @@ def __download_with_bar( def __fetch_remote( - remote: RemoteFileMetadata, + remote: _RemoteFileMetadata, dirname: Optional[str] = None, use_proxy: bool = False, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, + proxies: Optional[dict[str, str]] = None, ) -> str: - """Helper function to download a remote dataset into path - Fetch a dataset pointed by remote's url, save into path using remote's - filename and ensure its integrity based on the SHA256 Checksum of the - downloaded file. - Parameters - ---------- - remote : RemoteFileMetadata - Named tuple containing remote dataset meta information: url, filename - and checksum - dirname : string - Directory to save the file to. - Returns - ------- - file_path: string - Full path of the created file. - """ - file_path = remote.filename if dirname is None else join(dirname, remote.filename) - proxies = None if not use_proxy else proxies - file_path = __download_with_bar(remote.url, file_path, proxies) + file_path = __download_with_bar(remote.url, file_path, use_proxy, proxies) checksum = __sha256(file_path) if remote.checksum != checksum: raise IOError( @@ -352,23 +354,10 @@ def __fetch_remote( return file_path -def __get_data_home(data_home: Optional[str] = None) -> str: # type: ignore - """Return the path of the scikit-learn data dir. - This folder is used by some large dataset loaders to avoid downloading the - data several times. - By default the data dir is set to a folder named 'fense_data' in the - user home folder. - Alternatively, it can be set by the 'FENSE_DATA' environment - variable or programmatically by giving an explicit folder path. The '~' - symbol is expanded to the user home folder. - If the folder does not already exist, it is automatically created. - Parameters - ---------- - data_home : str | None - The path to data dir. - """ +def __get_data_home(data_home: Optional[str] = None) -> str: if data_home is None: - data_home = environ.get("FENSE_DATA", join(torch.hub.get_dir(), "fense_data")) + DEFAULT_DATA_HOME = join(torch.hub.get_dir(), "fense_data") + data_home = environ.get("FENSE_DATA", DEFAULT_DATA_HOME) data_home: str data_home = expanduser(data_home) @@ -379,20 +368,20 @@ def __get_data_home(data_home: Optional[str] = None) -> str: # type: ignore def __load_pretrain_echecker( echecker_model: str, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", use_proxy: bool = False, proxies: Optional[dict[str, str]] = None, verbose: int = 0, ) -> BERTFlatClassifier: - if echecker_model not in PRETRAIN_ECHECKERS_DICT: + if echecker_model not in _PRETRAIN_ECHECKERS_DICT: raise ValueError( - f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})" + f"Invalid argument {echecker_model=}. (expected one of {tuple(_PRETRAIN_ECHECKERS_DICT.keys())})" ) device = _get_device(device) tfmers_logging.set_verbosity_error() # suppress loading warnings - url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model] - remote = RemoteFileMetadata( + url, checksum = _PRETRAIN_ECHECKERS_DICT[echecker_model] + remote = _RemoteFileMetadata( filename=f"{echecker_model}.ckpt", url=url, checksum=checksum ) file_path = __check_download_resource(remote, use_proxy, proxies) @@ -401,17 +390,25 @@ def __load_pretrain_echecker( pylog.debug(f"Loading echecker model from '{file_path}'.") model_states = torch.load(file_path) + model_type = model_states["model_type"] + num_classes = model_states["num_classes"] + state_dict = model_states["state_dict"] if verbose >= 2: pylog.debug( - f"Loading echecker model type '{model_states['model_type']}' with '{model_states['num_classes']}' classes." + f"Loading echecker model type '{model_type}' with '{num_classes}' classes." ) echecker = BERTFlatClassifier( - model_type=model_states["model_type"], - num_classes=model_states["num_classes"], + model_type=model_type, + num_classes=num_classes, ) - echecker.load_state_dict(model_states["state_dict"]) + + # To support transformers > 4.31, because this lib changed BertEmbedding state_dict + if _use_new_echecker_loading(): + state_dict.pop("encoder.embeddings.position_ids") + + echecker.load_state_dict(state_dict) echecker.eval() echecker.to(device=device) return echecker diff --git a/src/aac_metrics/functional/sbert_sim.py b/src/aac_metrics/functional/sbert_sim.py index 29d8c0b..b173d17 100644 --- a/src/aac_metrics/functional/sbert_sim.py +++ b/src/aac_metrics/functional/sbert_sim.py @@ -15,6 +15,8 @@ from aac_metrics.utils.globals import _get_device +DEFAULT_SBERT_SIM_MODEL = "paraphrase-TinyBERT-L6-v2" + pylog = logging.getLogger(__name__) @@ -22,8 +24,8 @@ def sbert_sim( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2", - device: Union[str, torch.device, None] = "auto", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, verbose: int = 0, @@ -39,7 +41,7 @@ def sbert_sim( Otherwise returns a scalar tensor containing the main global score. defaults to True. :param sbert_model: The sentence BERT model used to extract sentence embeddings for cosine-similarity. defaults to "paraphrase-TinyBERT-L6-v2". - :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". + :param device: The PyTorch device used to run FENSE models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available". :param batch_size: The batch size of the sBERT models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param verbose: The verbose level. defaults to 0. @@ -86,8 +88,8 @@ def sbert_sim( def _load_sbert( - sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2", - device: Union[str, torch.device, None] = "auto", + sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL, + device: Union[str, torch.device, None] = "cuda_if_available", reset_state: bool = True, ) -> SentenceTransformer: state = torch.random.get_rng_state() diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py index b3f4a97..65d092a 100644 --- a/src/aac_metrics/functional/spider_fl.py +++ b/src/aac_metrics/functional/spider_fl.py @@ -15,6 +15,7 @@ fer, _load_echecker_and_tokenizer, BERTFlatClassifier, + DEFAULT_FER_MODEL, ) from aac_metrics.functional.spider import spider from aac_metrics.utils.checks import check_metric_inputs @@ -40,10 +41,10 @@ def spider_fl( java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, # FluencyError args - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", + echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, error_threshold: float = 0.9, - device: Union[str, torch.device, None] = "auto", + device: Union[str, torch.device, None] = "cuda_if_available", batch_size: int = 32, reset_state: bool = True, return_probs: bool = True, @@ -86,7 +87,7 @@ def spider_fl( If None and echecker is not None, this value will be inferred with `echecker.model_type`. defaults to None. :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. - :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". + :param device: The PyTorch device used to run FENSE models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. diff --git a/src/aac_metrics/utils/globals.py b/src/aac_metrics/utils/globals.py index c428ed7..58ceaf7 100644 --- a/src/aac_metrics/utils/globals.py +++ b/src/aac_metrics/utils/globals.py @@ -67,7 +67,7 @@ def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str: def _get_device( - device: Union[str, torch.device, None] = None + device: Union[str, torch.device, None] = "cuda_if_available", ) -> Optional[torch.device]: value_name = "device" process_func = __DEFAULT_GLOBALS[value_name]["process"] @@ -86,14 +86,15 @@ def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str: def __get_default_value(value_name: str) -> Any: values = __DEFAULT_GLOBALS[value_name]["values"] process_func = __DEFAULT_GLOBALS[value_name]["process"] + default_val = None for source, value_or_env_varname in values.items(): if source.startswith("env"): - value = os.getenv(value_or_env_varname, None) + value = os.getenv(value_or_env_varname, default_val) else: value = value_or_env_varname - if value is not None: + if value != default_val: value = process_func(value) return value @@ -111,7 +112,7 @@ def __set_default_value( def __get_value(value_name: str, value: Any = None) -> Any: - if value is ... or value is None: + if value is None or value is ...: return __get_default_value(value_name) else: process_func = __DEFAULT_GLOBALS[value_name]["process"] @@ -131,7 +132,7 @@ def __process_path(value: Union[str, Path, None]) -> Union[str, None]: def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.device]: if value is None or value is ...: return None - if value == "auto": + if value == "cuda_if_available": value = "cuda" if torch.cuda.is_available() else "cpu" if isinstance(value, str): value = torch.device(value) @@ -150,7 +151,7 @@ def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.de "device": { "values": { "env": "AAC_METRICS_DEVICE", - "package": "auto", + "package": "cuda_if_available", }, "process": __process_device, }, diff --git a/tests/test_compare_fense.py b/tests/test_compare_fense.py index 6086d2d..eccba83 100644 --- a/tests/test_compare_fense.py +++ b/tests/test_compare_fense.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- import importlib +import logging import os.path as osp import sys import torch @@ -10,10 +11,16 @@ from typing import Any from unittest import TestCase +import transformers + from aac_metrics.classes.fense import FENSE +from aac_metrics.functional.fer import _use_new_echecker_loading from aac_metrics.eval import load_csv_file +pylog = logging.getLogger(__name__) + + class TestCompareFENSE(TestCase): # Set Up methods @classmethod @@ -23,21 +30,24 @@ def setUpClass(cls) -> None: device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using {device=}") - cls.src_sbert_sim = Evaluator( - device=device, - echecker_model="none", - ) - cls.src_fense = Evaluator( - device=device, - echecker_model="echecker_clotho_audiocaps_base", - ) - + echecker = "echecker_clotho_audiocaps_base" cls.new_fense = FENSE( return_all_scores=True, device=device, verbose=2, - echecker="echecker_clotho_audiocaps_base", + echecker=echecker, + ) + cls.src_sbert_sim = Evaluator( + device=device, + echecker_model="none", ) + if _use_new_echecker_loading(): + cls.src_fense = None + else: + cls.src_fense = Evaluator( + device=device, + echecker_model=echecker, + ) @classmethod def _get_src_evaluator_class(cls) -> Any: @@ -78,7 +88,6 @@ def _test_with_original_fense(self, fpath: str) -> None: cands, mrefs = load_csv_file(fpath) src_sbert_sim_score = self.src_sbert_sim.corpus_score(cands, mrefs).item() - src_fense_score = self.src_fense.corpus_score(cands, mrefs).item() outs: tuple = self.new_fense(cands, mrefs) # type: ignore corpus_outs, _sents_outs = outs @@ -90,11 +99,18 @@ def _test_with_original_fense(self, fpath: str) -> None: new_sbert_sim_score, "Invalid SBERTSim score with original implementation.", ) - self.assertEqual( - src_fense_score, - new_fense_score, - "Invalid FENSE score with original implementation.", - ) + + if self.src_fense is None: + pylog.warning( + f"Skipping test with original FENSE for the transformers version {transformers.__version__}" + ) + else: + src_fense_score = self.src_fense.corpus_score(cands, mrefs).item() + self.assertEqual( + src_fense_score, + new_fense_score, + "Invalid FENSE score with original implementation.", + ) if __name__ == "__main__":