diff --git a/CHANGELOG.md b/CHANGELOG.md index ae7fa7b..f6a3f6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ All notable changes to this project will be documented in this file. +## [0.5.1] 2023-12-20 +### Added +- Check sentence inputs for all metrics. + +### Fixed +- Fix `BERTScoreMRefs` metric with 1 candidate and 1 reference. + ## [0.5.0] 2023-12-08 ### Added - New `Vocab` metric to compute vocabulary size and vocabulary ratio. diff --git a/CITATION.cff b/CITATION.cff index e9ecb6d..71458d5 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.5.0 -date-released: '2023-12-08' +version: 0.5.1 +date-released: '2023-12-20' diff --git a/README.md b/README.md index dec1684..8579949 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ If you use this software, please consider citing it as "Labbe, E. (2023). aac-metr month = {12}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.5.0}, + version = {0.5.1}, year = {2023}, } ``` diff --git a/docs/aac_metrics.classes.fluerr.rst b/docs/aac_metrics.classes.fluerr.rst deleted file mode 100644 index 64fbf9d..0000000 --- a/docs/aac_metrics.classes.fluerr.rst +++ /dev/null @@ -1,7 +0,0 @@ -aac\_metrics.classes.fluerr module -================================== - -.. automodule:: aac_metrics.classes.fluerr - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/aac_metrics.utils.globals.rst b/docs/aac_metrics.utils.globals.rst new file mode 100644 index 0000000..8a7e68b --- /dev/null +++ b/docs/aac_metrics.utils.globals.rst @@ -0,0 +1,7 @@ +aac\_metrics.utils.globals module +================================= + +.. automodule:: aac_metrics.utils.globals + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/aac_metrics.utils.paths.rst b/docs/aac_metrics.utils.paths.rst deleted file mode 100644 index 0bfcc36..0000000 --- a/docs/aac_metrics.utils.paths.rst +++ /dev/null @@ -1,7 +0,0 @@ -aac\_metrics.utils.paths module -=============================== - -.. 
automodule:: aac_metrics.utils.paths - :members: - :undoc-members: - :show-inheritance: diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index 199bdbb..deae747 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,10 +10,11 @@ __maintainer__ = "Etienne Labbé (Labbeti)" __name__ = "aac-metrics" __status__ = "Development" -__version__ = "0.5.0" +__version__ = "0.5.1" from .classes.base import AACMetric +from .classes.bert_score_mrefs import BERTScoreMRefs from .classes.bleu import BLEU from .classes.cider_d import CIDErD from .classes.evaluate import Evaluate, DCASE2023Evaluate, _get_metric_factory_classes @@ -28,7 +29,7 @@ from .classes.spider_max import SPIDErMax from .classes.vocab import Vocab from .functional.evaluate import evaluate, dcase2023_evaluate -from .utils.paths import ( +from .utils.globals import ( get_default_cache_path, get_default_java_path, get_default_tmp_path, @@ -40,6 +41,7 @@ __all__ = [ "AACMetric", + "BERTScoreMRefs", "BLEU", "CIDErD", "Evaluate", diff --git a/src/aac_metrics/classes/__init__.py b/src/aac_metrics/classes/__init__.py index 1058fae..d3a17a5 100644 --- a/src/aac_metrics/classes/__init__.py +++ b/src/aac_metrics/classes/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from .bert_score_mrefs import BERTScoreMRefs from .bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4 from .cider_d import CIDErD from .evaluate import DCASE2023Evaluate, Evaluate @@ -17,6 +18,7 @@ __all__ = [ + "BERTScoreMRefs", "BLEU", "BLEU1", "BLEU2", diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py index 4f07c06..c31c1d9 100644 --- a/src/aac_metrics/classes/bert_score_mrefs.py +++ b/src/aac_metrics/classes/bert_score_mrefs.py @@ -47,7 +47,11 @@ def __init__( verbose: int = 0, ) -> None: model, tokenizer = _load_model_and_tokenizer( - model, None, device, reset_state, verbose + model=model, + tokenizer=None, + device=device, + reset_state=reset_state, + verbose=verbose, ) super().__init__() diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py index 1253b9b..f1bae17 100644 --- a/src/aac_metrics/classes/fense.py +++ b/src/aac_metrics/classes/fense.py @@ -46,7 +46,14 @@ def __init__( penalty: float = 0.9, verbose: int = 0, ) -> None: - sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(sbert_model, echecker, None, device, reset_state, verbose) # type: ignore + sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer( + sbert_model=sbert_model, + echecker=echecker, + echecker_tokenizer=None, + device=device, + reset_state=reset_state, + verbose=verbose, + ) super().__init__() self._return_all_scores = return_all_scores diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index e1e9e38..36bace7 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -25,7 +25,7 @@ check_spice_install, ) from aac_metrics.utils.cmdline import _str_to_bool, _setup_logging -from aac_metrics.utils.paths import ( +from aac_metrics.utils.globals import ( _get_cache_path, _get_tmp_path, get_default_cache_path, diff --git a/src/aac_metrics/eval.py b/src/aac_metrics/eval.py index df733b8..94605cd 100644 --- a/src/aac_metrics/eval.py +++ b/src/aac_metrics/eval.py @@ -19,7 +19,7 @@ ) from aac_metrics.utils.checks import check_metric_inputs, check_java_path from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, _setup_logging -from aac_metrics.utils.paths import ( +from 
aac_metrics.utils.globals import ( get_default_cache_path, get_default_java_path, get_default_tmp_path, diff --git a/src/aac_metrics/functional/__init__.py b/src/aac_metrics/functional/__init__.py index 04819ea..0bf5fc0 100644 --- a/src/aac_metrics/functional/__init__.py +++ b/src/aac_metrics/functional/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from .bert_score_mrefs import bert_score_mrefs from .bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4 from .cider_d import cider_d from .evaluate import dcase2023_evaluate, evaluate @@ -17,6 +18,7 @@ __all__ = [ + "bert_score_mrefs", "bleu", "bleu_1", "bleu_2", diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py index d239652..d6e932e 100644 --- a/src/aac_metrics/functional/bert_score_mrefs.py +++ b/src/aac_metrics/functional/bert_score_mrefs.py @@ -11,7 +11,9 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers import logging as tfmers_logging +from aac_metrics.utils.checks import check_metric_inputs from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list +from aac_metrics.utils.globals import _get_device def bert_score_mrefs( @@ -56,13 +58,20 @@ def bert_score_mrefs( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ + check_metric_inputs(candidates, mult_references) + if isinstance(model, str): if tokenizer is not None: raise ValueError( f"Invalid argument combinaison {model=} with {tokenizer=}." ) + model, tokenizer = _load_model_and_tokenizer( - model, tokenizer, device, reset_state, verbose + model=model, + tokenizer=tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) elif isinstance(model, nn.Module): @@ -76,21 +85,18 @@ def bert_score_mrefs( f"Invalid argument type {type(model)=}. 
(expected str or nn.Module)" ) - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) flat_mrefs, sizes = flat_list(mult_references) duplicated_cands = duplicate_list(candidates, sizes) + assert len(duplicated_cands) == len(flat_mrefs) tfmers_verbosity = tfmers_logging.get_verbosity() if verbose <= 1: tfmers_logging.set_verbosity_error() sents_scores = bert_score( - duplicated_cands, - flat_mrefs, + preds=duplicated_cands, + target=flat_mrefs, model_name_or_path=None, model=model, # type: ignore user_tokenizer=tokenizer, @@ -105,6 +111,12 @@ def bert_score_mrefs( # Restore previous verbosity level tfmers_logging.set_verbosity(tfmers_verbosity) + # note: torchmetrics returns a float if input contains 1 cand and 1 ref, even in list + if len(duplicated_cands) == 1 and all( + isinstance(v, float) for v in sents_scores.values() + ): + sents_scores = {k: [v] for k, v in sents_scores.items()} + # sents_scores keys: "precision", "recall", "f1" sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()} # type: ignore @@ -116,9 +128,9 @@ def bert_score_mrefs( if reduction == "mean": reduction_fn = torch.mean elif reduction == "max": - reduction_fn = max_reduce + reduction_fn = _max_reduce elif reduction == "min": - reduction_fn = min_reduce + reduction_fn = _min_reduce else: REDUCTIONS = ("mean", "max", "min") raise ValueError( @@ -161,11 +173,7 @@ def _load_model_and_tokenizer( ) -> tuple[nn.Module, Optional[Callable]]: state = torch.random.get_rng_state() - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) if isinstance(model, str): tfmers_verbosity = tfmers_logging.get_verbosity() if verbose <= 1: @@ -188,14 +196,14 @@ def _load_model_and_tokenizer( return model, tokenizer # type: ignore -def max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: +def _max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: if dim is None: return x.max() else: return x.max(dim=dim).values -def min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: +def _min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor: if dim is None: return x.min() else: diff --git a/src/aac_metrics/functional/bleu.py b/src/aac_metrics/functional/bleu.py index 5414c28..cc9c898 100644 --- a/src/aac_metrics/functional/bleu.py +++ b/src/aac_metrics/functional/bleu.py @@ -11,6 +11,8 @@ from torch import Tensor +from aac_metrics.utils.checks import check_metric_inputs + pylog = logging.getLogger(__name__) @@ -158,10 +160,8 @@ def _bleu_update( prev_cooked_cands: list, prev_cooked_mrefs: list, ) -> tuple[list, list[tuple]]: - if len(candidates) != len(mult_references): - raise ValueError( - f"Invalid number of candidates and references. 
(found {len(candidates)=} != {len(mult_references)=})" - ) + check_metric_inputs(candidates, mult_references) + new_cooked_mrefs = [ __cook_references(refs, None, n, tokenizer) for refs in mult_references ] diff --git a/src/aac_metrics/functional/cider_d.py b/src/aac_metrics/functional/cider_d.py index 35df385..136b24e 100644 --- a/src/aac_metrics/functional/cider_d.py +++ b/src/aac_metrics/functional/cider_d.py @@ -9,6 +9,8 @@ from torch import Tensor +from aac_metrics.utils.checks import check_metric_inputs + def cider_d( candidates: list[str], @@ -66,10 +68,7 @@ def _cider_d_update( prev_cooked_cands: list[Counter], prev_cooked_mrefs: list[list[Counter]], ) -> tuple[list, list]: - if len(candidates) != len(mult_references): - raise ValueError( - f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})" - ) + check_metric_inputs(candidates, mult_references) new_cooked_mrefs = [ [__cook_sentence(ref, n, tokenizer) for ref in refs] for refs in mult_references ] diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py index b5070c5..643960b 100644 --- a/src/aac_metrics/functional/fense.py +++ b/src/aac_metrics/functional/fense.py @@ -1,11 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -"""FENSE metric functional API. - -Based on original implementation in https://github.com/blmoistawinde/fense/ -""" - import logging from typing import Optional, Union @@ -22,6 +17,7 @@ BERTFlatClassifier, ) from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert +from aac_metrics.utils.checks import check_metric_inputs pylog = logging.getLogger(__name__) @@ -71,6 +67,7 @@ def fense( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ + check_metric_inputs(candidates, mult_references) # Init models sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer( @@ -148,8 +145,14 @@ def _load_models_and_tokenizer( reset_state: bool = True, verbose: int = 0, ) -> tuple[SentenceTransformer, BERTFlatClassifier, AutoTokenizer]: - sbert_model = _load_sbert(sbert_model, device, reset_state) + sbert_model = _load_sbert( + sbert_model=sbert_model, device=device, reset_state=reset_state + ) echecker, echecker_tokenizer = _load_echecker_and_tokenizer( - echecker, echecker_tokenizer, device, reset_state, verbose + echecker=echecker, + echecker_tokenizer=echecker_tokenizer, + device=device, + reset_state=reset_state, + verbose=verbose, ) return sbert_model, echecker, echecker_tokenizer diff --git a/src/aac_metrics/functional/fer.py b/src/aac_metrics/functional/fer.py index b30fdaa..7a3eabe 100644 --- a/src/aac_metrics/functional/fer.py +++ b/src/aac_metrics/functional/fer.py @@ -1,10 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -BASED ON https://github.com/blmoistawinde/fense/ -""" - import hashlib import logging import os @@ -26,6 +22,9 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from aac_metrics.utils.checks import is_mono_sents +from aac_metrics.utils.globals import _get_device + # config according to the settings on your computer, this should be default setting of shadowsocks DEFAULT_PROXIES = { @@ -128,6 +127,9 @@ def fer( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" + if not is_mono_sents(candidates): + error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})" + raise ValueError(error_msg) # Init models echecker, echecker_tokenizer = _load_echecker_and_tokenizer( @@ -182,13 +184,11 @@ def _load_echecker_and_tokenizer( ) -> tuple[BERTFlatClassifier, AutoTokenizer]: state = torch.random.get_rng_state() - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) if isinstance(echecker, str): - echecker = __load_pretrain_echecker(echecker, device, verbose=verbose) + echecker = __load_pretrain_echecker( + echecker_model=echecker, device=device, verbose=verbose + ) if echecker_tokenizer is None: echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore @@ -211,10 +211,7 @@ def __detect_error_sents( device: Union[str, torch.device, None], max_len: int = 64, ) -> dict[str, np.ndarray]: - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) + device = _get_device(device) if len(sents) <= batch_size: batch = __infer_preprocess( @@ -280,11 +277,7 @@ def __infer_preprocess( device: Union[str, torch.device, None], dtype: torch.dtype, ) -> Mapping[str, Tensor]: - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) texts = __text_preprocess(texts) # type: ignore batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_len) for k in ("input_ids", "attention_mask", "token_type_ids"): @@ -396,11 +389,7 @@ def __load_pretrain_echecker( f"Invalid argument {echecker_model=}. 
(expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})" ) - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) tfmers_logging.set_verbosity_error() # suppress loading warnings url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model] remote = RemoteFileMetadata( diff --git a/src/aac_metrics/functional/fluerr.py b/src/aac_metrics/functional/fluerr.py deleted file mode 100644 index 07914c9..0000000 --- a/src/aac_metrics/functional/fluerr.py +++ /dev/null @@ -1,450 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -BASED ON https://github.com/blmoistawinde/fense/ -""" - -import hashlib -import logging -import os -import re -import requests - -from collections import namedtuple -from os import environ, makedirs -from os.path import exists, expanduser, join -from typing import Mapping, Optional, Union - -import numpy as np -import torch - -from torch import nn, Tensor -from tqdm import tqdm -from transformers import logging as tfmers_logging -from transformers.models.auto.modeling_auto import AutoModel -from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast - - -# config according to the settings on your computer, this should be default setting of shadowsocks -DEFAULT_PROXIES = { - "http": "socks5h://127.0.0.1:1080", - "https": "socks5h://127.0.0.1:1080", -} -PRETRAIN_ECHECKERS_DICT = { - "echecker_clotho_audiocaps_base": ( - "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt", - "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa", - ), - "echecker_clotho_audiocaps_tiny": ( - "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_tiny.ckpt", - "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673", - ), -} - -RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"]) - -pylog = logging.getLogger(__name__) - - -ERROR_NAMES = ( - "add_tail", - "repeat_event", - "repeat_adv", - "remove_conj", - "remove_verb", - "error", -) - - -class BERTFlatClassifier(nn.Module): - def __init__(self, model_type: str, num_classes: int = 5) -> None: - super().__init__() - self.model_type = model_type - self.num_classes = num_classes - self.encoder = AutoModel.from_pretrained(model_type) - self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob) - self.clf = nn.Linear(self.encoder.config.hidden_size, num_classes) - - @classmethod - def from_pretrained( - cls, - model_name: str = "echecker_clotho_audiocaps_base", - device: Union[str, torch.device, None] = "auto", - use_proxy: bool = False, - proxies: Optional[dict[str, str]] = None, - verbose: int = 0, - ) -> "BERTFlatClassifier": - return __load_pretrain_echecker(model_name, device, use_proxy, proxies, verbose) - - def forward( - self, - input_ids: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - token_type_ids: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - outputs = self.encoder(input_ids, attention_mask, token_type_ids) - x = outputs.last_hidden_state[:, 0, :] - x = self.dropout(x) - logits = self.clf(x) - return logits - - -def fluerr( - candidates: list[str], - return_all_scores: bool = True, - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", - echecker_tokenizer: Optional[AutoTokenizer] = None, - error_threshold: float = 0.9, - device: 
Union[str, torch.device, None] = "auto", - batch_size: int = 32, - reset_state: bool = True, - return_probs: bool = False, - verbose: int = 0, -) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: - """Return fluency error detected by a pre-trained BERT model. - - - Paper: https://arxiv.org/abs/2110.04684 - - Original implementation: https://github.com/blmoistawinde/fense - - :param candidates: The list of sentences to evaluate. - :param mult_references: The list of list of sentences used as target. - :param return_all_scores: If True, returns a tuple containing the globals and locals scores. - Otherwise returns a scalar tensor containing the main global score. - defaults to True. - :param echecker: The echecker model used to detect fluency errors. - Can be "echecker_clotho_audiocaps_base", "echecker_clotho_audiocaps_tiny", "none" or None. - defaults to "echecker_clotho_audiocaps_base". - :param echecker_tokenizer: The tokenizer of the echecker model. - If None and echecker is not None, this value will be inferred with `echecker.model_type`. - defaults to None. - :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. - :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". - :param batch_size: The batch size of the echecker models. defaults to 32. - :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. - :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. - :param verbose: The verbose level. defaults to 0. - :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
- """ - - # Init models - echecker, echecker_tokenizer = _load_echecker_and_tokenizer( - echecker, echecker_tokenizer, device, reset_state, verbose - ) - - # Compute and apply fluency error detection penalty - probs_outs_sents = __detect_error_sents( - echecker, - echecker_tokenizer, # type: ignore - candidates, - batch_size, - device, - ) - fluerr_scores = (probs_outs_sents["error"] > error_threshold).astype(float) - - fluerr_scores = torch.from_numpy(fluerr_scores) - fluerr_score = fluerr_scores.mean() - - if return_all_scores: - fluerr_outs_corpus = { - "fluerr": fluerr_score, - } - fluerr_outs_sents = { - "fluerr": fluerr_scores, - } - - if return_probs: - probs_outs_sents = { - f"fluerr.{k}_prob": v for k, v in probs_outs_sents.items() - } - probs_outs_sents = { - k: torch.from_numpy(v) for k, v in probs_outs_sents.items() - } - probs_outs_corpus = {k: v.mean() for k, v in probs_outs_sents.items()} - - fluerr_outs_corpus = probs_outs_corpus | fluerr_outs_corpus - fluerr_outs_sents = probs_outs_sents | fluerr_outs_sents - - fluerr_outs = fluerr_outs_corpus, fluerr_outs_sents - - return fluerr_outs - else: - return fluerr_score - - -# - Private functions -def _load_echecker_and_tokenizer( - echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base", - echecker_tokenizer: Optional[AutoTokenizer] = None, - device: Union[str, torch.device, None] = "auto", - reset_state: bool = True, - verbose: int = 0, -) -> tuple[BERTFlatClassifier, AutoTokenizer]: - state = torch.random.get_rng_state() - - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - - if isinstance(echecker, str): - echecker = __load_pretrain_echecker(echecker, device, verbose=verbose) - - if echecker_tokenizer is None: - echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore - - echecker = echecker.eval() - for p in echecker.parameters(): - p.requires_grad_(False) - - if reset_state: - torch.random.set_rng_state(state) - - return echecker, echecker_tokenizer # type: ignore - - -def __detect_error_sents( - echecker: BERTFlatClassifier, - echecker_tokenizer: PreTrainedTokenizerFast, - sents: list[str], - batch_size: int, - device: Union[str, torch.device, None], - max_len: int = 64, -) -> dict[str, np.ndarray]: - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - - if len(sents) <= batch_size: - batch = __infer_preprocess( - echecker_tokenizer, - sents, - max_len=max_len, - device=device, - dtype=torch.long, - ) - logits: Tensor = echecker(**batch) - assert not logits.requires_grad - # batch_logits: (bsize, num_classes=6) - # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69 - probs = logits.sigmoid().transpose(0, 1).cpu().numpy() - probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs)) - - else: - dic_lst_probs = {name: [] for name in ERROR_NAMES} - - for i in range(0, len(sents), batch_size): - batch = __infer_preprocess( - echecker_tokenizer, - sents[i : i + batch_size], - max_len=max_len, - device=device, - dtype=torch.long, - ) - - batch_logits: Tensor = echecker(**batch) - assert not batch_logits.requires_grad - # batch_logits: (bsize, num_classes=6) - # classes: add_tail, repeat_event, repeat_adv, remove_conj, remove_verb, error - probs = batch_logits.sigmoid().cpu().numpy() - - for j, name in 
enumerate(dic_lst_probs.keys()): - dic_lst_probs[name].append(probs[:, j]) - - probs_dic = { - name: np.concatenate(probs) for name, probs in dic_lst_probs.items() - } - - return probs_dic - - -def __check_download_resource( - remote: RemoteFileMetadata, - use_proxy: bool = False, - proxies: Optional[dict[str, str]] = None, -) -> str: - proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies - data_home = __get_data_home() - file_path = os.path.join(data_home, remote.filename) - if not os.path.exists(file_path): - # currently don't capture error at this level, assume download success - file_path = __download(remote, data_home, use_proxy, proxies) - return file_path - - -def __infer_preprocess( - tokenizer: PreTrainedTokenizerFast, - texts: list[str], - max_len: int, - device: Union[str, torch.device, None], - dtype: torch.dtype, -) -> Mapping[str, Tensor]: - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - - texts = __text_preprocess(texts) # type: ignore - batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_len) - for k in ("input_ids", "attention_mask", "token_type_ids"): - batch[k] = torch.as_tensor(batch[k], device=device, dtype=dtype) # type: ignore - return batch - - -def __download( - remote: RemoteFileMetadata, - file_path: Optional[str] = None, - use_proxy: bool = False, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, -) -> str: - data_home = __get_data_home() - file_path = __fetch_remote(remote, data_home, use_proxy, proxies) - return file_path - - -def __download_with_bar( - url: str, - file_path: str, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, -) -> str: - # Streaming, so we can iterate over the response. - response = requests.get(url, stream=True, proxies=proxies) - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 KB - progress_bar = tqdm(total=total_size_in_bytes, unit="B", unit_scale=True) - with open(file_path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise Exception("ERROR, something went wrong with the downloading") - return file_path - - -def __fetch_remote( - remote: RemoteFileMetadata, - dirname: Optional[str] = None, - use_proxy: bool = False, - proxies: Optional[dict[str, str]] = DEFAULT_PROXIES, -) -> str: - """Helper function to download a remote dataset into path - Fetch a dataset pointed by remote's url, save into path using remote's - filename and ensure its integrity based on the SHA256 Checksum of the - downloaded file. - Parameters - ---------- - remote : RemoteFileMetadata - Named tuple containing remote dataset meta information: url, filename - and checksum - dirname : string - Directory to save the file to. - Returns - ------- - file_path: string - Full path of the created file. 
- """ - - file_path = remote.filename if dirname is None else join(dirname, remote.filename) - proxies = None if not use_proxy else proxies - file_path = __download_with_bar(remote.url, file_path, proxies) - checksum = __sha256(file_path) - if remote.checksum != checksum: - raise IOError( - "{} has an SHA256 checksum ({}) " - "differing from expected ({}), " - "file may be corrupted.".format(file_path, checksum, remote.checksum) - ) - return file_path - - -def __get_data_home(data_home: Optional[str] = None) -> str: # type: ignore - """Return the path of the scikit-learn data dir. - This folder is used by some large dataset loaders to avoid downloading the - data several times. - By default the data dir is set to a folder named 'fense_data' in the - user home folder. - Alternatively, it can be set by the 'FENSE_DATA' environment - variable or programmatically by giving an explicit folder path. The '~' - symbol is expanded to the user home folder. - If the folder does not already exist, it is automatically created. - Parameters - ---------- - data_home : str | None - The path to data dir. - """ - if data_home is None: - data_home = environ.get("FENSE_DATA", join(torch.hub.get_dir(), "fense_data")) - - data_home: str - data_home = expanduser(data_home) - if not exists(data_home): - makedirs(data_home) - return data_home - - -def __load_pretrain_echecker( - echecker_model: str, - device: Union[str, torch.device, None] = "auto", - use_proxy: bool = False, - proxies: Optional[dict[str, str]] = None, - verbose: int = 0, -) -> BERTFlatClassifier: - if echecker_model not in PRETRAIN_ECHECKERS_DICT: - raise ValueError( - f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})" - ) - - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - - tfmers_logging.set_verbosity_error() # suppress loading warnings - url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model] - remote = RemoteFileMetadata( - filename=f"{echecker_model}.ckpt", url=url, checksum=checksum - ) - file_path = __check_download_resource(remote, use_proxy, proxies) - - if verbose >= 2: - pylog.debug(f"Loading echecker model from '{file_path}'.") - - model_states = torch.load(file_path) - - if verbose >= 2: - pylog.debug( - f"Loading echecker model type '{model_states['model_type']}' with '{model_states['num_classes']}' classes." 
) - - echecker = BERTFlatClassifier( - model_type=model_states["model_type"], - num_classes=model_states["num_classes"], - ) - echecker.load_state_dict(model_states["state_dict"]) - echecker.eval() - echecker.to(device=device) - return echecker - - -def __sha256(path: str) -> str: - """Calculate the sha256 hash of the file at path.""" - sha256hash = hashlib.sha256() - chunk_size = 8192 - with open(path, "rb") as f: - while True: - buffer = f.read(chunk_size) - if not buffer: - break - sha256hash.update(buffer) - return sha256hash.hexdigest() - - -def __text_preprocess(inp: Union[str, list[str]]) -> Union[str, list[str]]: - if isinstance(inp, str): - return re.sub(r"[^\w\s]", "", inp).lower() - else: - return [re.sub(r"[^\w\s]", "", x).lower() for x in inp] diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py index 381c821..1abea6d 100644 --- a/src/aac_metrics/functional/meteor.py +++ b/src/aac_metrics/functional/meteor.py @@ -1,8 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# ORIGINAL CODE FROM https://github.com/tylin/coco-caption - import logging import os.path as osp import platform @@ -16,8 +14,8 @@ from torch import Tensor -from aac_metrics.utils.checks import check_java_path -from aac_metrics.utils.paths import _get_cache_path, _get_java_path +from aac_metrics.utils.checks import check_java_path, check_metric_inputs +from aac_metrics.utils.globals import _get_cache_path, _get_java_path pylog = logging.getLogger(__name__) @@ -45,6 +43,7 @@ def meteor( - Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389 - Documentation: https://www.cs.cmu.edu/~alavie/METEOR/README.html + - Original implementation: https://github.com/tylin/coco-caption :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. @@ -69,6 +68,8 @@ def meteor( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ + check_metric_inputs(candidates, mult_references) + cache_path = _get_cache_path(cache_path) java_path = _get_java_path(java_path) @@ -87,11 +88,6 @@ def meteor( f"Invalid Java executable to compute METEOR score. ({java_path})" ) - if len(candidates) != len(mult_references): - raise ValueError( - f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})" - ) - - if language not in SUPPORTED_LANGUAGES: raise ValueError( f"Invalid argument {language=}. (expected one of {SUPPORTED_LANGUAGES})" ) diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py index 851dba0..5145714 100644 --- a/src/aac_metrics/functional/mult_cands.py +++ b/src/aac_metrics/functional/mult_cands.py @@ -8,6 +8,8 @@ from torch import Tensor +from aac_metrics.utils.checks import is_mult_sents + SELECTIONS = ("max", "min", "mean") @@ -34,10 +36,13 @@ def mult_cands_metric( :param **kwargs: The keywords arguments given to the metric call. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ - if selection not in SELECTIONS: - raise ValueError( - f"Invalid argument {selection=}. (expected one of {SELECTIONS})" - ) + if not is_mult_sents(mult_candidates): + error_msg = f"Invalid mult_candidates type. (expected list[list[str]], found {mult_candidates.__class__.__name__})" + raise ValueError(error_msg) + + if not is_mult_sents(mult_references): + error_msg = f"Invalid mult_references type. 
(expected list[list[str]], found {mult_references.__class__.__name__})" + raise ValueError(error_msg) if len(mult_candidates) <= 0: raise ValueError( @@ -48,6 +53,11 @@ def mult_cands_metric( f"Number of candidate and mult_references are different ({len(mult_candidates)} != {len(mult_references)})." ) + if selection not in SELECTIONS: + raise ValueError( + f"Invalid argument {selection=}. (expected one of {SELECTIONS})" + ) + n_cands_per_audio = len(mult_candidates[0]) if not all(len(cands) == n_cands_per_audio for cands in mult_candidates): raise ValueError( diff --git a/src/aac_metrics/functional/rouge_l.py b/src/aac_metrics/functional/rouge_l.py index 66aa10e..1106dfa 100644 --- a/src/aac_metrics/functional/rouge_l.py +++ b/src/aac_metrics/functional/rouge_l.py @@ -1,16 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) -# -# Creation Date : 2015-01-07 06:03 -# Author : Ramakrishna Vedantam - -# ================================================================= -# This code was pulled from https://github.com/tylin/coco-caption -# Image-specific names and comments have been changed to be audio-specific -# ================================================================= - import logging from typing import Callable, Union @@ -20,6 +10,8 @@ from torch import Tensor +from aac_metrics.utils.checks import check_metric_inputs + pylog = logging.getLogger(__name__) @@ -34,6 +26,8 @@ def rouge_l( """Recall-Oriented Understudy for Gisting Evaluation function. - Paper: https://aclanthology.org/W04-1013.pdf + - Original Author: Ramakrishna Vedantam + - Original implementation: https://github.com/tylin/coco-caption :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. @@ -55,10 +49,8 @@ def _rouge_l_update( tokenizer: Callable[[str], list[str]], prev_rouge_l_scores: list[float], ) -> list[float]: - if len(candidates) != len(mult_references): - raise ValueError( - f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})" - ) + check_metric_inputs(candidates, mult_references) + new_rouge_l_scores = [ __calc_score(cand, refs, beta, tokenizer) for cand, refs in zip(candidates, mult_references) diff --git a/src/aac_metrics/functional/sbert_sim.py b/src/aac_metrics/functional/sbert_sim.py index 82c8ccc..29d8c0b 100644 --- a/src/aac_metrics/functional/sbert_sim.py +++ b/src/aac_metrics/functional/sbert_sim.py @@ -1,10 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -BASED ON https://github.com/blmoistawinde/fense/ -""" - import logging from typing import Union @@ -15,6 +11,9 @@ from sentence_transformers import SentenceTransformer from torch import Tensor +from aac_metrics.utils.checks import check_metric_inputs +from aac_metrics.utils.globals import _get_device + pylog = logging.getLogger(__name__) @@ -46,6 +45,8 @@ def sbert_sim( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" + check_metric_inputs(candidates, mult_references) + # Init models sbert_model = _load_sbert(sbert_model, device, reset_state) @@ -91,11 +92,7 @@ def _load_sbert( ) -> SentenceTransformer: state = torch.random.get_rng_state() - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(device, str): - device = torch.device(device) - + device = _get_device(device) if isinstance(sbert_model, str): sbert_model = SentenceTransformer(sbert_model, device=device) # type: ignore sbert_model.to(device=device) diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index c72c1c1..faddb64 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -23,8 +23,8 @@ from torch import Tensor -from aac_metrics.utils.checks import check_java_path -from aac_metrics.utils.paths import ( +from aac_metrics.utils.checks import check_java_path, check_metric_inputs +from aac_metrics.utils.globals import ( _get_cache_path, _get_java_path, _get_tmp_path, @@ -82,6 +82,7 @@ def spice( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ + check_metric_inputs(candidates, mult_references) cache_path = _get_cache_path(cache_path) java_path = _get_java_path(java_path) diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py index c2b0a0d..f297557 100644 --- a/src/aac_metrics/functional/spider.py +++ b/src/aac_metrics/functional/spider.py @@ -8,6 +8,7 @@ from aac_metrics.functional.cider_d import cider_d from aac_metrics.functional.spice import spice +from aac_metrics.utils.checks import check_metric_inputs def spider( @@ -59,14 +60,9 @@ def spider( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ - - if len(candidates) != len(mult_references): - raise ValueError( - f"Number of candidates and mult_references are different (found {len(candidates)} != {len(mult_references)})." - ) + check_metric_inputs(candidates, mult_references) sub_return_all_scores = True - cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = cider_d( # type: ignore candidates=candidates, mult_references=mult_references, diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py index 9c71686..b3f4a97 100644 --- a/src/aac_metrics/functional/spider_fl.py +++ b/src/aac_metrics/functional/spider_fl.py @@ -1,10 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -Original based on https://github.com/blmoistawinde/fense/ -""" - import logging from pathlib import Path @@ -21,6 +17,7 @@ BERTFlatClassifier, ) from aac_metrics.functional.spider import spider +from aac_metrics.utils.checks import check_metric_inputs pylog = logging.getLogger(__name__) @@ -56,7 +53,7 @@ def spider_fl( ) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: """Combinaison of SPIDEr with Fluency Error detector. - Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48. + - Original implementation: https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48. .. warning:: This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. @@ -97,6 +94,8 @@ def spider_fl( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" + check_metric_inputs(candidates, mult_references) + # Init models echecker, echecker_tokenizer = _load_echecker_and_tokenizer( echecker=echecker, diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index 760e02a..f742617 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -9,6 +9,8 @@ from torch import Tensor +from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents + pylog = logging.getLogger(__name__) @@ -39,6 +41,12 @@ def vocab( :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ + if mult_references is not None: + check_metric_inputs(candidates, mult_references) + elif not is_mono_sents(candidates): + error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})" + raise ValueError(error_msg) + tok_cands = list(map(tokenizer, candidates)) del candidates diff --git a/src/aac_metrics/info.py b/src/aac_metrics/info.py index a0a6cf4..2ba4058 100644 --- a/src/aac_metrics/info.py +++ b/src/aac_metrics/info.py @@ -13,7 +13,7 @@ import aac_metrics from aac_metrics.utils.checks import _get_java_version -from aac_metrics.utils.paths import ( +from aac_metrics.utils.globals import ( get_default_cache_path, get_default_java_path, get_default_tmp_path, diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py index 061d36f..cf86afc 100644 --- a/src/aac_metrics/utils/checks.py +++ b/src/aac_metrics/utils/checks.py @@ -38,15 +38,13 @@ def check_metric_inputs( same_len = len(candidates) == len(mult_references) if not same_len: - raise ValueError( - f"Invalid number of candidates ({len(candidates)}) with the number of references ({len(mult_references)})." - ) + error_msg = f"Invalid number of candidates ({len(candidates)}) with the number of references ({len(mult_references)})." + raise ValueError(error_msg) at_least_1_ref_per_cand = all(len(refs) > 0 for refs in mult_references) if not at_least_1_ref_per_cand: - raise ValueError( - "Invalid number of references per candidate. (found at least 1 empty list of references)" - ) + error_msg = "Invalid number of references per candidate. (found at least 1 empty list of references)" + raise ValueError(error_msg) def check_java_path(java_path: Union[str, Path]) -> bool: diff --git a/src/aac_metrics/utils/globals.py b/src/aac_metrics/utils/globals.py new file mode 100644 index 0000000..c428ed7 --- /dev/null +++ b/src/aac_metrics/utils/globals.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os +import os.path as osp +import tempfile + +from pathlib import Path +from typing import Any, Optional, Union + +import torch + + +pylog = logging.getLogger(__name__) + + +# Public functions +def get_default_cache_path() -> str: + """Returns the default cache directory path. + + If :func:`~aac_metrics.utils.globals.set_default_cache_path` has been used before with a string argument, it will return the value given to this function. + Else if the environment variable AAC_METRICS_CACHE_PATH has been set to a string, it will return its value. + Else it will be equal to "~/.cache" by default. + """ + return __get_default_value("cache") + + +def get_default_java_path() -> str: + """Returns the default java executable path. + + If :func:`~aac_metrics.utils.globals.set_default_java_path` has been used before with a string argument, it will return the value given to this function. 
+ Else if the environment variable AAC_METRICS_JAVA_PATH has been set to a string, it will return its value. + Else it will be equal to "java" by default. + """ + return __get_default_value("java") + + +def get_default_tmp_path() -> str: + """Returns the default temporary directory path. + + If :func:`~aac_metrics.utils.globals.set_default_tmp_path` has been used before with a string argument, it will return the value given to this function. + Else if the environment variable AAC_METRICS_TMP_PATH has been set to a string, it will return its value. + Else it will be equal to the value returned by :func:`~tempfile.gettempdir()` by default. + """ + return __get_default_value("tmp") + + +def set_default_cache_path(cache_path: Union[str, Path, None]) -> None: + """Override default cache directory path.""" + __set_default_value("cache", cache_path) + + +def set_default_java_path(java_path: Union[str, Path, None]) -> None: + """Override default java executable path.""" + __set_default_value("java", java_path) + + +def set_default_tmp_path(tmp_path: Union[str, Path, None]) -> None: + """Override default temporary directory path.""" + __set_default_value("tmp", tmp_path) + + +# Private functions +def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str: + return __get_value("cache", cache_path) + + +def _get_device( + device: Union[str, torch.device, None] = None +) -> Optional[torch.device]: + value_name = "device" + process_func = __DEFAULT_GLOBALS[value_name]["process"] + device = process_func(device) + return device # type: ignore + + +def _get_java_path(java_path: Union[str, Path, None] = None) -> str: + return __get_value("java", java_path) + + +def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str: + return __get_value("tmp", tmp_path) + + +def __get_default_value(value_name: str) -> Any: + values = __DEFAULT_GLOBALS[value_name]["values"] + process_func = __DEFAULT_GLOBALS[value_name]["process"] + + for source, value_or_env_varname in values.items(): + if source.startswith("env"): + value = os.getenv(value_or_env_varname, None) + else: + value = value_or_env_varname + + if value is not None: + value = process_func(value) + return value + + pylog.error(f"Values: {values}") + raise RuntimeError( + f"Invalid default value for value_name={value_name}. (all default values are None)" + ) + + +def __set_default_value( + value_name: str, + value: Any, +) -> None: + __DEFAULT_GLOBALS[value_name]["values"]["user"] = value + + +def __get_value(value_name: str, value: Any = None) -> Any: + if value is ... 
or value is None: + return __get_default_value(value_name) + else: + process_func = __DEFAULT_GLOBALS[value_name]["process"] + value = process_func(value) + return value + + +def __process_path(value: Union[str, Path, None]) -> Union[str, None]: + if value is None or value is ...: + return None + value = str(value) + value = osp.expanduser(value) + value = osp.expandvars(value) + return value + + +def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.device]: + if value is None or value is ...: + return None + if value == "auto": + value = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(value, str): + value = torch.device(value) + return value + + +__DEFAULT_GLOBALS = { + "cache": { + "values": { + "user": None, + "env": "AAC_METRICS_CACHE_PATH", + "package": osp.join("~", ".cache"), + }, + "process": __process_path, + }, + "device": { + "values": { + "env": "AAC_METRICS_DEVICE", + "package": "auto", + }, + "process": __process_device, + }, + "java": { + "values": { + "user": None, + "env": "AAC_METRICS_JAVA_PATH", + "package": "java", + }, + "process": __process_path, + }, + "tmp": { + "values": { + "user": None, + "env": "AAC_METRICS_TMP_PATH", + "package": tempfile.gettempdir(), + }, + "process": __process_path, + }, +} diff --git a/src/aac_metrics/utils/paths.py b/src/aac_metrics/utils/paths.py deleted file mode 100644 index 43af618..0000000 --- a/src/aac_metrics/utils/paths.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import logging -import os -import os.path as osp -import tempfile - -from pathlib import Path -from typing import Union, overload - - -pylog = logging.getLogger(__name__) - - -__DEFAULT_GLOBALS: dict[str, dict[str, Union[str, None]]] = { - "cache": { - "user": None, - "env": "AAC_METRICS_CACHE_PATH", - "package": osp.expanduser(osp.join("~", ".cache")), - }, - "java": { - "user": None, - "env": "AAC_METRICS_JAVA_PATH", - "package": "java", - }, - "tmp": { - "user": None, - "env": "AAC_METRICS_TMP_PATH", - "package": tempfile.gettempdir(), - }, -} - - -# Public functions -def get_default_cache_path() -> str: - """Returns the default cache directory path. - - If :func:`~aac_metrics.utils.path.set_default_cache_path` has been used before with a string argument, it will return the value given to this function. - Else if the environment variable AAC_METRICS_CACHE_PATH has been set to a string, it will return its value. - Else it will be equal to "~/.cache" by default. - """ - return __get_default_value("cache") - - -def get_default_java_path() -> str: - """Returns the default java executable path. - - If :func:`~aac_metrics.utils.path.set_default_java_path` has been used before with a string argument, it will return the value given to this function. - Else if the environment variable AAC_METRICS_JAVA_PATH has been set to a string, it will return its value. - Else it will be equal to "java" by default. - """ - return __get_default_value("java") - - -def get_default_tmp_path() -> str: - """Returns the default temporary directory path. - - If :func:`~aac_metrics.utils.path.set_default_tmp_path` has been used before with a string argument, it will return the value given to this function. - Else if the environment variable AAC_METRICS_TMP_PATH has been set to a string, it will return its value. - Else it will be equal to the value returned by :func:`~tempfile.gettempdir()` by default. 
- """ - return __get_default_value("tmp") - - -def set_default_cache_path(cache_path: Union[str, Path, None]) -> None: - """Override default cache directory path.""" - __set_default_value("cache", cache_path) - - -def set_default_java_path(java_path: Union[str, Path, None]) -> None: - """Override default java executable path.""" - __set_default_value("java", java_path) - - -def set_default_tmp_path(tmp_path: Union[str, Path, None]) -> None: - """Override default temporary directory path.""" - __set_default_value("tmp", tmp_path) - - -# Private functions -def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str: - return __get_value("cache", cache_path) - - -def _get_java_path(java_path: Union[str, Path, None] = None) -> str: - return __get_value("java", java_path) - - -def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str: - return __get_value("tmp", tmp_path) - - -def __get_default_value(value_name: str) -> str: - values = __DEFAULT_GLOBALS[value_name] - - for source, value_or_env_varname in values.items(): - if value_or_env_varname is None: - continue - - if source.startswith("env"): - path = os.getenv(value_or_env_varname, None) - else: - path = value_or_env_varname - - if path is not None: - path = __process_value(path) - return path - - pylog.error(f"Paths values: {values}") - raise RuntimeError( - f"Invalid default path for {value_name=}. (all default paths are None)" - ) - - -def __set_default_value( - value_name: str, - value: Union[str, Path, None], -) -> None: - value = __process_value(value) - __DEFAULT_GLOBALS[value_name]["user"] = value - - -def __get_value(value_name: str, value: Union[str, Path, None] = None) -> str: - if value is ... or value is None: - return __get_default_value(value_name) - else: - value = __process_value(value) - return value - - -@overload -def __process_value(value: None) -> None: - ... - - -@overload -def __process_value(value: Union[str, Path]) -> str: - ... - - -def __process_value(value: Union[str, Path, None]) -> Union[str, None]: - value = str(value) - value = osp.expanduser(value) - value = osp.expandvars(value) - return value diff --git a/src/aac_metrics/utils/tokenization.py b/src/aac_metrics/utils/tokenization.py index 0236895..a84ce7a 100644 --- a/src/aac_metrics/utils/tokenization.py +++ b/src/aac_metrics/utils/tokenization.py @@ -13,7 +13,7 @@ from aac_metrics.utils.checks import check_java_path, is_mono_sents from aac_metrics.utils.collections import flat_list, unflat_list -from aac_metrics.utils.paths import ( +from aac_metrics.utils.globals import ( _get_cache_path, _get_java_path, _get_tmp_path, diff --git a/tests/test_compare_cet.py b/tests/test_compare_cet.py index 7c838e6..5058341 100644 --- a/tests/test_compare_cet.py +++ b/tests/test_compare_cet.py @@ -17,7 +17,7 @@ from aac_metrics.functional.evaluate import evaluate from aac_metrics.eval import load_csv_file -from aac_metrics.utils.paths import ( +from aac_metrics.utils.globals import ( get_default_tmp_path, ) from aac_metrics.download import _download_spice