From a5f056f9712a8accfedeff6609b556c4d45a897f Mon Sep 17 00:00:00 2001 From: Labbeti Date: Tue, 12 Sep 2023 11:30:40 +0200 Subject: [PATCH] Version 0.4.5 --- .github/workflows/python-package-pip.yaml | 4 +- CHANGELOG.md | 11 + CITATION.cff | 4 +- README.md | 26 +-- docs/aac_metrics.eval.rst | 7 + docs/usage.rst | 24 +-- pyproject.toml | 2 +- src/aac_metrics/__init__.py | 9 +- src/aac_metrics/classes/bleu.py | 13 +- src/aac_metrics/classes/cider_d.py | 26 +-- src/aac_metrics/classes/evaluate.py | 2 +- src/aac_metrics/classes/fense.py | 28 +-- src/aac_metrics/classes/fluerr.py | 22 +- src/aac_metrics/classes/meteor.py | 21 +- src/aac_metrics/classes/rouge_l.py | 14 +- src/aac_metrics/classes/sbert_sim.py | 16 +- src/aac_metrics/classes/spice.py | 25 ++- src/aac_metrics/classes/spider.py | 6 +- src/aac_metrics/classes/spider_fl.py | 6 +- src/aac_metrics/classes/spider_max.py | 8 +- src/aac_metrics/download.py | 28 ++- src/aac_metrics/eval.py | 237 ++++++++++++++++++++++ src/aac_metrics/functional/cider_d.py | 3 + src/aac_metrics/functional/evaluate.py | 6 +- src/aac_metrics/functional/fense.py | 4 +- src/aac_metrics/functional/fluerr.py | 4 +- src/aac_metrics/functional/meteor.py | 15 +- src/aac_metrics/functional/spice.py | 44 ++-- src/aac_metrics/functional/spider.py | 9 +- src/aac_metrics/functional/spider_fl.py | 9 +- src/aac_metrics/functional/spider_max.py | 9 +- tests/test_compare_cet.py | 21 +- tests/test_compare_fense.py | 2 +- tests/test_doc_examples.py | 130 ++++++++++++ 34 files changed, 629 insertions(+), 166 deletions(-) create mode 100644 docs/aac_metrics.eval.rst create mode 100644 src/aac_metrics/eval.py create mode 100644 tests/test_doc_examples.py diff --git a/.github/workflows/python-package-pip.yaml b/.github/workflows/python-package-pip.yaml index 489876f..1503230 100644 --- a/.github/workflows/python-package-pip.yaml +++ b/.github/workflows/python-package-pip.yaml @@ -10,6 +10,7 @@ on: env: CACHE_NUMBER: 0 # increase to reset cache manually + TMPDIR: '/tmp' # Cancel workflow if a new push occurs concurrency: @@ -49,8 +50,9 @@ jobs: - name: Install package shell: bash # note: ${GITHUB_REF##*/} gives the branch name + # note 2: dev is not the branch here, but the dev dependencies run: | - python -m pip install "aac-metrics[${GITHUB_REF_NAME}] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}" + python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}" - name: Load cache of external code and data uses: actions/cache@master diff --git a/CHANGELOG.md b/CHANGELOG.md index 170c2e1..58ca8a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,17 @@ All notable changes to this project will be documented in this file. +## [0.4.5] 2023-09-12 +### Added +- Argument `use_shell` for `METEOR` and `SPICE` metrics and `download` function to fix Windows-OS specific error. + +### Changed +- Rename `evaluate.py` script to `eval.py`. + +### Fixed +- Workflow on main branch. +- Examples in README and doc with at least 2 sentences, and add a warning on all metrics that requires at least 2 candidates. + ## [0.4.4] 2023-08-14 ### Added - `Evaluate` class now implements a `__hash__` and `tolist()` methods. 
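The `use_shell` argument introduced in the changelog above only appears inside the diff hunks further down, so a minimal usage sketch follows. It assumes the external Java tools were already fetched with `aac-metrics-download` and relies only on the signatures added in this patch (`download`, `aac_metrics.functional.meteor.meteor`, `aac_metrics.functional.spice.spice`); leaving `use_shell=None` lets the library enable the shell on Windows only.

```python
# Minimal sketch of the new ``use_shell`` argument (0.4.5).
# Assumes the Java tools were fetched beforehand with ``aac-metrics-download``.
# With use_shell=None (the default) a shell is only used on Windows,
# mirroring the platform.system() == "Windows" check added in this patch.
from aac_metrics.download import download
from aac_metrics.functional.meteor import meteor
from aac_metrics.functional.spice import spice
from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents

download(use_shell=None)  # pass True/False to override the OS-based default

candidates = preprocess_mono_sents(["a man is speaking", "rain falls"])
mult_references = preprocess_mult_sents([
    ["a man speaks.", "someone speaks."],
    ["rain is falling hard on a surface"],
])

meteor_scores, _ = meteor(
    candidates, mult_references, return_all_scores=True, use_shell=None
)
spice_scores, _ = spice(
    candidates, mult_references, return_all_scores=True, use_shell=None
)
```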
diff --git a/CITATION.cff b/CITATION.cff index 93cdbe0..7cf6afe 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.4.4 -date-released: '2023-08-14' +version: 0.4.5 +date-released: '2023-09-12' diff --git a/README.md b/README.md index 8132dbd..11094e1 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ aac-metrics-download ``` Notes: -- The external code for SPICE, METEOR and PTBTokenizer is stored in `$HOME/.cache/aac-metrics`. -- The weights of the FENSE fluency error detector and the the SBERT model are respectively stored by default in `$HOME/.cache/torch/hub/fense_data` and `$HOME/.cache/torch/sentence_transformers`. +- The external code for SPICE, METEOR and PTBTokenizer is stored in `~/.cache/aac-metrics`. +- The weights of the FENSE fluency error detector and the the SBERT model are respectively stored by default in `~/.cache/torch/hub/fense_data` and `~/.cache/torch/sentence_transformers`. ## Usage ### Evaluate default metrics @@ -59,13 +59,13 @@ The full evaluation pipeline to compute AAC metrics can be done with `aac_metric ```python from aac_metrics import evaluate -candidates: list[str] = ["a man is speaking"] -mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"]] +candidates: list[str] = ["a man is speaking", "rain falls"] +mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ["rain is falling hard on a surface"]] corpus_scores, _ = evaluate(candidates, mult_references) print(corpus_scores) # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider" -# {"bleu_1": tensor(0.7), "bleu_2": ..., ...} +# {"bleu_1": tensor(0.4278), "bleu_2": ..., ...} ``` ### Evaluate DCASE2023 metrics To compute metrics for the DCASE2023 challenge, just set the argument `metrics="dcase2023"` in `evaluate` function call. @@ -83,17 +83,17 @@ Evaluate a specific metric can be done using the `aac_metrics.functional.= 1.21.2 pyyaml >= 6.0 tqdm >= 4.64.0 sentence-transformers >= 2.2.2 +transformers < 4.31.0 ``` ### External requirements @@ -215,10 +217,10 @@ If you use this software, please consider cite it as below : Labbe_aac-metrics_2023, author = {Labbé, Etienne}, license = {MIT}, - month = {8}, + month = {9}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.4.4}, + version = {0.4.5}, year = {2023}, } ``` diff --git a/docs/aac_metrics.eval.rst b/docs/aac_metrics.eval.rst new file mode 100644 index 0000000..4da4361 --- /dev/null +++ b/docs/aac_metrics.eval.rst @@ -0,0 +1,7 @@ +aac\_metrics.eval module +======================== + +.. automodule:: aac_metrics.eval + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/usage.rst b/docs/usage.rst index c655780..3223761 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -4,19 +4,19 @@ Usage Evaluate default AAC metrics ############################ -The full evaluation process to compute AAC metrics can be done with `aac_metrics.aac_evaluate` function. +The full evaluation process to compute AAC metrics can be done with `aac_metrics.dcase2023_evaluate` function. .. code-block:: python - from aac_metrics import aac_evaluate + from aac_metrics import evaluate - candidates: list[str] = ["a man is speaking", ...] 
- mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ...] + candidates: list[str] = ["a man is speaking", "rain falls"] + mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ["rain is falling hard on a surface"]] - corpus_scores, _ = aac_evaluate(candidates, mult_references) + corpus_scores, _ = evaluate(candidates, mult_references) print(corpus_scores) # dict containing the score of each aac metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider" - # {"bleu_1": tensor(0.7), "bleu_2": ..., ...} + # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...} Evaluate a specific metric @@ -25,24 +25,24 @@ Evaluate a specific metric Evaluate a specific metric can be done using the `aac_metrics.functional..` function or the `aac_metrics.classes..` class. .. warning:: - Unlike `aac_evaluate`, the tokenization with PTBTokenizer is not done with these functions, but you can do it manually with `preprocess_mono_sents` and `preprocess_mult_sents` functions. + Unlike `dcase2023_evaluate`, the tokenization with PTBTokenizer is not done with these functions, but you can do it manually with `preprocess_mono_sents` and `preprocess_mult_sents` functions. .. code-block:: python - + from aac_metrics.functional import cider_d from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents - candidates: list[str] = ["a man is speaking", ...] - mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ...] + candidates: list[str] = ["a man is speaking", "rain falls"] + mult_references: list[list[str]] = [["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"], ["rain is falling hard on a surface"]] candidates = preprocess_mono_sents(candidates) mult_references = preprocess_mult_sents(mult_references) corpus_scores, sents_scores = cider_d(candidates, mult_references) print(corpus_scores) - # {"cider_d": tensor(0.1)} + # {"cider_d": tensor(0.9614)} print(sents_scores) - # {"cider_d": tensor([0.9, ...])} + # {"cider_d": tensor([1.3641, 0.5587])} Each metrics also exists as a python class version, like `aac_metrics.classes.cider_d.CIDErD`. 
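The usage page above ends by naming the class-based API without showing it. Below is a minimal sketch of that variant, assuming only the torchmetrics-style `update`/`compute` methods visible in `src/aac_metrics/classes/cider_d.py` in this patch; as with the functional call, PTB tokenization is applied beforehand.

```python
# Minimal sketch of the class-based variant named above, assuming the
# update()/compute() protocol shown in src/aac_metrics/classes/cider_d.py.
from aac_metrics.classes.cider_d import CIDErD
from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents

candidates: list[str] = ["a man is speaking", "rain falls"]
mult_references: list[list[str]] = [
    ["a man speaks.", "someone speaks.", "a man is speaking while a bird is chirping in the background"],
    ["rain is falling hard on a surface"],
]

cider = CIDErD()
cider.update(preprocess_mono_sents(candidates), preprocess_mult_sents(mult_references))
outputs = cider.compute()
# Depending on the return_all_scores setting, ``outputs`` is either a
# (corpus_scores, sents_scores) pair of dicts or a scalar tensor.
```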
diff --git a/pyproject.toml b/pyproject.toml index 889d64f..9964a66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ Changelog = "https://github.com/Labbeti/aac-metrics/blob/main/CHANGELOG.md" [project.scripts] aac-metrics = "aac_metrics.__main__:_print_usage" aac-metrics-download = "aac_metrics.download:_main_download" -aac-metrics-evaluate = "aac_metrics.evaluate:_main_evaluate" +aac-metrics-eval = "aac_metrics.eval:_main_eval" aac-metrics-info = "aac_metrics.info:print_install_info" [project.optional-dependencies] diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index f614bf3..99fccf0 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,7 +10,7 @@ __license__ = "MIT" __maintainer__ = "Etienne Labbé (Labbeti)" __status__ = "Development" -__version__ = "0.4.4" +__version__ = "0.4.5" from .classes.base import AACMetric @@ -65,9 +65,10 @@ def load_metric(name: str, **kwargs) -> AACMetric: name = name.lower().strip() factory = _get_metric_factory_classes(**kwargs) - if name in factory: - return factory[name]() - else: + if name not in factory: raise ValueError( f"Invalid argument {name=}. (expected one of {tuple(factory.keys())})" ) + + metric = factory[name]() + return metric diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py index 9d6030a..926945a 100644 --- a/src/aac_metrics/classes/bleu.py +++ b/src/aac_metrics/classes/bleu.py @@ -53,12 +53,13 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return _bleu_compute( - self._cooked_cands, - self._cooked_mrefs, - self._return_all_scores, - self._n, - self._option, - self._verbose, + cooked_cands=self._cooked_cands, + cooked_mrefs=self._cooked_mrefs, + return_all_scores=self._return_all_scores, + n=self._n, + option=self._option, + verbose=self._verbose, + return_1_to_n=False, ) def extra_repr(self) -> str: diff --git a/src/aac_metrics/classes/cider_d.py b/src/aac_metrics/classes/cider_d.py index c2bf68d..b22c7f5 100644 --- a/src/aac_metrics/classes/cider_d.py +++ b/src/aac_metrics/classes/cider_d.py @@ -49,13 +49,13 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return _cider_d_compute( - self._cooked_cands, - self._cooked_mrefs, - self._return_all_scores, - self._n, - self._sigma, - self._return_tfidf, - self._scale, + cooked_cands=self._cooked_cands, + cooked_mrefs=self._cooked_mrefs, + return_all_scores=self._return_all_scores, + n=self._n, + sigma=self._sigma, + return_tfidf=self._return_tfidf, + scale=self._scale, ) def extra_repr(self) -> str: @@ -75,10 +75,10 @@ def update( mult_references: list[list[str]], ) -> None: self._cooked_cands, self._cooked_mrefs = _cider_d_update( - candidates, - mult_references, - self._n, - self._tokenizer, - self._cooked_cands, - self._cooked_mrefs, + candidates=candidates, + mult_references=mult_references, + n=self._n, + tokenizer=self._tokenizer, + prev_cooked_cands=self._cooked_cands, + prev_cooked_mrefs=self._cooked_mrefs, ) diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py index 0e264a3..d633c57 100644 --- a/src/aac_metrics/classes/evaluate.py +++ b/src/aac_metrics/classes/evaluate.py @@ -108,7 +108,7 @@ def __hash__(self) -> int: class DCASE2023Evaluate(Evaluate): """Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics. - For more information, see :func:`~aac_metrics.functional.evaluate.aac_evaluate`. 
+ For more information, see :func:`~aac_metrics.functional.evaluate.dcase2023_evaluate`. """ def __init__( diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py index da0b318..e0fc73d 100644 --- a/src/aac_metrics/classes/fense.py +++ b/src/aac_metrics/classes/fense.py @@ -42,7 +42,7 @@ def __init__( device: Union[str, torch.device, None] = "auto", batch_size: int = 32, reset_state: bool = True, - return_probs: bool = True, + return_probs: bool = False, penalty: float = 0.9, verbose: int = 0, ) -> None: @@ -66,19 +66,19 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return fense( - self._candidates, - self._mult_references, - self._return_all_scores, - self._sbert_model, - self._echecker, - self._echecker_tokenizer, - self._error_threshold, - self._device, - self._batch_size, - self._reset_state, - self._return_probs, - self._penalty, - self._verbose, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + sbert_model=self._sbert_model, + echecker=self._echecker, + echecker_tokenizer=self._echecker_tokenizer, + error_threshold=self._error_threshold, + device=self._device, + batch_size=self._batch_size, + reset_state=self._reset_state, + return_probs=self._return_probs, + penalty=self._penalty, + verbose=self._verbose, ) def extra_repr(self) -> str: diff --git a/src/aac_metrics/classes/fluerr.py b/src/aac_metrics/classes/fluerr.py index fdfb43d..0e84c5a 100644 --- a/src/aac_metrics/classes/fluerr.py +++ b/src/aac_metrics/classes/fluerr.py @@ -44,7 +44,7 @@ def __init__( device: Union[str, torch.device, None] = "auto", batch_size: int = 32, reset_state: bool = True, - return_probs: bool = True, + return_probs: bool = False, verbose: int = 0, ) -> None: echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose) # type: ignore @@ -64,16 +64,16 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return fluerr( - self._candidates, - self._return_all_scores, - self._echecker, - self._echecker_tokenizer, - self._error_threshold, - self._device, - self._batch_size, - self._reset_state, - self._return_probs, - self._verbose, + candidates=self._candidates, + return_all_scores=self._return_all_scores, + echecker=self._echecker, + echecker_tokenizer=self._echecker_tokenizer, + error_threshold=self._error_threshold, + device=self._device, + batch_size=self._batch_size, + reset_state=self._reset_state, + return_probs=self._return_probs, + verbose=self._verbose, ) def extra_repr(self) -> str: diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py index 2d8aff6..98e517e 100644 --- a/src/aac_metrics/classes/meteor.py +++ b/src/aac_metrics/classes/meteor.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Union +from typing import Optional, Union from torch import Tensor @@ -32,6 +32,7 @@ def __init__( java_path: str = ..., java_max_memory: str = "2G", language: str = "en", + use_shell: Optional[bool] = None, verbose: int = 0, ) -> None: super().__init__() @@ -40,6 +41,7 @@ def __init__( self._java_path = java_path self._java_max_memory = java_max_memory self._language = language + self._use_shell = use_shell self._verbose = verbose self._candidates = [] @@ -47,14 +49,15 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return meteor( - self._candidates, - 
self._mult_references, - self._return_all_scores, - self._cache_path, - self._java_path, - self._java_max_memory, - self._language, - self._verbose, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + cache_path=self._cache_path, + java_path=self._java_path, + java_max_memory=self._java_max_memory, + language=self._language, + use_shell=self._use_shell, + verbose=self._verbose, ) def extra_repr(self) -> str: diff --git a/src/aac_metrics/classes/rouge_l.py b/src/aac_metrics/classes/rouge_l.py index 7cc79c0..503c912 100644 --- a/src/aac_metrics/classes/rouge_l.py +++ b/src/aac_metrics/classes/rouge_l.py @@ -42,8 +42,8 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return _rouge_l_compute( - self._rouge_l_scores, - self._return_all_scores, + rouge_l_scs=self._rouge_l_scores, + return_all_scores=self._return_all_scores, ) def extra_repr(self) -> str: @@ -62,9 +62,9 @@ def update( mult_references: list[list[str]], ) -> None: self._rouge_l_scores = _rouge_l_update( - candidates, - mult_references, - self._beta, - self._tokenizer, - self._rouge_l_scores, + candidates=candidates, + mult_references=mult_references, + beta=self._beta, + tokenizer=self._tokenizer, + prev_rouge_l_scores=self._rouge_l_scores, ) diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py index 39704dc..99eaf52 100644 --- a/src/aac_metrics/classes/sbert_sim.py +++ b/src/aac_metrics/classes/sbert_sim.py @@ -57,14 +57,14 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return sbert_sim( - self._candidates, - self._mult_references, - self._return_all_scores, - self._sbert_model, - self._device, - self._batch_size, - self._reset_state, - self._verbose, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + sbert_model=self._sbert_model, + device=self._device, + batch_size=self._batch_size, + reset_state=self._reset_state, + verbose=self._verbose, ) def extra_repr(self) -> str: diff --git a/src/aac_metrics/classes/spice.py b/src/aac_metrics/classes/spice.py index b9308be..17f1f9c 100644 --- a/src/aac_metrics/classes/spice.py +++ b/src/aac_metrics/classes/spice.py @@ -39,6 +39,7 @@ def __init__( java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, separate_cache_dir: bool = True, + use_shell: Optional[bool] = None, verbose: int = 0, ) -> None: super().__init__() @@ -50,6 +51,7 @@ def __init__( self._java_max_memory = java_max_memory self._timeout = timeout self._separate_cache_dir = separate_cache_dir + self._use_shell = use_shell self._verbose = verbose self._candidates = [] @@ -57,17 +59,18 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return spice( - self._candidates, - self._mult_references, - self._return_all_scores, - self._cache_path, - self._java_path, - self._tmp_path, - self._n_threads, - self._java_max_memory, - self._timeout, - self._separate_cache_dir, - self._verbose, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + cache_path=self._cache_path, + java_path=self._java_path, + tmp_path=self._tmp_path, + n_threads=self._n_threads, + java_max_memory=self._java_max_memory, + timeout=self._timeout, + separate_cache_dir=self._separate_cache_dir, + use_shell=self._use_shell, + verbose=self._verbose, ) def extra_repr(self) 
-> str: diff --git a/src/aac_metrics/classes/spider.py b/src/aac_metrics/classes/spider.py index 3df89b8..0ecb30d 100644 --- a/src/aac_metrics/classes/spider.py +++ b/src/aac_metrics/classes/spider.py @@ -61,9 +61,9 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return spider( - self._candidates, - self._mult_references, - self._return_all_scores, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, # CIDEr args n=self._n, sigma=self._sigma, diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py index 24edf96..f11f078 100644 --- a/src/aac_metrics/classes/spider_fl.py +++ b/src/aac_metrics/classes/spider_fl.py @@ -88,9 +88,9 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return spider_fl( - self._candidates, - self._mult_references, - self._return_all_scores, + candidates=self._candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, # CIDEr args n=self._n, sigma=self._sigma, diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py index dd0bad4..a43c730 100644 --- a/src/aac_metrics/classes/spider_max.py +++ b/src/aac_metrics/classes/spider_max.py @@ -63,10 +63,10 @@ def __init__( def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: return spider_max( - self._mult_candidates, - self._mult_references, - self._return_all_scores, - self._return_all_cands_scores, + mult_candidates=self._mult_candidates, + mult_references=self._mult_references, + return_all_scores=self._return_all_scores, + return_all_cands_scores=self._return_all_cands_scores, n=self._n, sigma=self._sigma, cache_path=self._cache_path, diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index 041f09b..bfdb103 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -4,11 +4,13 @@ import logging import os import os.path as osp +import platform import subprocess import sys from argparse import ArgumentParser, Namespace from subprocess import CalledProcessError +from typing import Optional from torch.hub import download_url_to_file @@ -76,6 +78,7 @@ def download( cache_path: str = ..., tmp_path: str = ..., + use_shell: Optional[bool] = None, ptb_tokenizer: bool = True, meteor: bool = True, spice: bool = True, @@ -86,12 +89,18 @@ def download( :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param tmp_path: The path to a temporary directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. + :param use_shell: Optional argument to force use os-specific shell for the bash installation script program. + If None, it will use shell only on Windows OS. + defaults to None. :param ptb_tokenizer: If True, downloads the PTBTokenizer code in cache directory. defaults to True. :param meteor: If True, downloads the METEOR code in cache directory. defaults to True. :param spice: If True, downloads the SPICE code in cache directory. defaults to True. :param fense: If True, downloads the FENSE models. defaults to True. :param verbose: The verbose level. defaults to 0. 
""" + if verbose >= 1: + pylog.info(f"aac-metrics download started.") + cache_path = _get_cache_path(cache_path) tmp_path = _get_tmp_path(tmp_path) @@ -110,14 +119,17 @@ def download( _download_meteor(cache_path, verbose) if spice: - _download_spice(cache_path, verbose) + _download_spice(cache_path, use_shell, verbose) if fense: _download_fense(verbose) + if verbose >= 1: + pylog.info(f"aac-metrics download finished.") + def _download_ptb_tokenizer( - cache_path: str = ..., + cache_path: str, verbose: int = 0, ) -> None: # Download JAR file for tokenization @@ -143,7 +155,7 @@ def _download_ptb_tokenizer( def _download_meteor( - cache_path: str = ..., + cache_path: str, verbose: int = 0, ) -> None: # Download JAR files for METEOR metric @@ -177,7 +189,8 @@ def _download_meteor( def _download_spice( - cache_path: str = ..., + cache_path: str, + use_shell: Optional[bool] = None, verbose: int = 0, ) -> None: # Download JAR files for SPICE metric @@ -208,15 +221,20 @@ def _download_spice( f"Downloading JAR sources for SPICE metric into '{spice_jar_dpath}'..." ) + if use_shell is None: + use_shell = platform.system() == "Windows" + command = ["bash", script_path, spice_jar_dpath] try: subprocess.check_call( command, stdout=None if verbose >= 2 else subprocess.DEVNULL, stderr=None if verbose >= 2 else subprocess.DEVNULL, + shell=use_shell, ) except (CalledProcessError, PermissionError) as err: - pylog.error(err) + pylog.error("Cannot install SPICE java source code.") + raise err def _download_fense( diff --git a/src/aac_metrics/eval.py b/src/aac_metrics/eval.py new file mode 100644 index 0000000..7c52883 --- /dev/null +++ b/src/aac_metrics/eval.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv +import logging +import sys + +from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import Iterable, Union + +import yaml + +from aac_metrics.functional.evaluate import ( + evaluate, + METRICS_SETS, + DEFAULT_METRICS_SET_NAME, +) +from aac_metrics.utils.checks import check_metric_inputs, check_java_path +from aac_metrics.utils.paths import ( + get_default_cache_path, + get_default_java_path, + get_default_tmp_path, +) + + +pylog = logging.getLogger(__name__) + + +def load_csv_file( + fpath: Union[str, Path], + cands_columns: Union[str, Iterable[str]] = ("caption_predicted",), + mrefs_columns: Union[str, Iterable[str]] = ( + "caption_1", + "caption_2", + "caption_3", + "caption_4", + "caption_5", + ), + load_mult_cands: bool = False, + strict: bool = True, +) -> tuple[list, list[list[str]]]: + """Load candidates and mult_references from a CSV file. + + :param fpath: The filepath to the CSV file. + :param cands_columns: The columns of the candidates. defaults to ("captions_predicted",). + :param mrefs_columns: The columns of the multiple references. defaults to ("caption_1", "caption_2", "caption_3", "caption_4", "caption_5"). + :param load_mult_cands: If True, load multiple candidates from file. defaults to False. + :returns: A tuple of (candidates, mult_references) loaded from file. 
+ """ + if isinstance(cands_columns, str): + cands_columns = [cands_columns] + else: + cands_columns = list(cands_columns) + + if isinstance(mrefs_columns, str): + mrefs_columns = [mrefs_columns] + else: + mrefs_columns = list(mrefs_columns) + + with open(fpath, "r") as file: + reader = csv.DictReader(file) + fieldnames = reader.fieldnames + data = list(reader) + + if fieldnames is None: + raise ValueError(f"Cannot read fieldnames in CSV file {fpath=}.") + + file_cands_columns = [column for column in cands_columns if column in fieldnames] + file_mrefs_columns = [column for column in mrefs_columns if column in fieldnames] + + if strict: + if len(file_cands_columns) != len(cands_columns): + raise ValueError( + f"Cannot find all candidates columns {cands_columns=} in file '{fpath}'." + ) + if len(file_mrefs_columns) != len(mrefs_columns): + raise ValueError( + f"Cannot find all references columns {mrefs_columns=} in file '{fpath}'." + ) + + if (load_mult_cands and len(file_cands_columns) <= 0) or ( + not load_mult_cands and len(file_cands_columns) != 1 + ): + raise ValueError( + f"Cannot find candidate column in file. ({cands_columns=} not found in {fieldnames=})" + ) + if len(file_mrefs_columns) <= 0: + raise ValueError( + f"Cannot find references columns in file. ({mrefs_columns=} not found in {fieldnames=})" + ) + + if load_mult_cands: + mult_candidates = _load_columns(data, file_cands_columns) + mult_references = _load_columns(data, file_mrefs_columns) + return mult_candidates, mult_references + else: + file_cand_column = file_cands_columns[0] + candidates = [data_i[file_cand_column] for data_i in data] + mult_references = _load_columns(data, file_mrefs_columns) + return candidates, mult_references + + +def _load_columns(data: list[dict[str, str]], columns: list[str]) -> list[list[str]]: + mult_sentences = [] + for data_i in data: + raw_sents = [data_i[column] for column in columns] + sents = [] + for raw_sent in raw_sents: + # Refs columns can be list[str] + if "[" in raw_sent and "]" in raw_sent: + try: + sent = eval(raw_sent) + assert isinstance(sent, list) and all( + isinstance(sent_i, str) for sent_i in sent + ) + sents += sent + except (SyntaxError, NameError): + sents.append(raw_sent) + else: + sents.append(raw_sent) + + mult_sentences.append(sents) + return mult_sentences + + +def _get_main_evaluate_args() -> Namespace: + parser = ArgumentParser(description="Evaluate an output file.") + + parser.add_argument( + "--input_file", + "-i", + type=str, + default="", + help="The input file path containing the candidates and references.", + required=True, + ) + parser.add_argument( + "--cand_columns", + "-cc", + type=str, + nargs="+", + default=("caption_predicted", "preds", "cands"), + help="The column names of the candidates in the CSV file. defaults to ('caption_predicted', 'preds', 'cands').", + ) + parser.add_argument( + "--mrefs_columns", + "-rc", + type=str, + nargs="+", + default=( + "caption_1", + "caption_2", + "caption_3", + "caption_4", + "caption_5", + "captions", + ), + help="The column names of the candidates in the CSV file. defaults to ('caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5', 'captions').", + ) + parser.add_argument( + "--metrics_set_name", + type=str, + default=DEFAULT_METRICS_SET_NAME, + choices=tuple(METRICS_SETS.keys()), + help=f"The metrics set to compute. Can be one of {tuple(METRICS_SETS.keys())}. 
defaults to 'default'.", + ) + parser.add_argument( + "--cache_path", + type=str, + default=get_default_cache_path(), + help=f"Cache directory path. defaults to '{get_default_cache_path()}'.", + ) + parser.add_argument( + "--java_path", + type=str, + default=get_default_java_path(), + help=f"Java executable path. defaults to '{get_default_java_path()}'.", + ) + parser.add_argument( + "--tmp_path", + type=str, + default=get_default_tmp_path(), + help=f"Temporary directory path. defaults to '{get_default_tmp_path()}'.", + ) + parser.add_argument("--verbose", type=int, default=0, help="Verbose level.") + + args = parser.parse_args() + return args + + +def _main_eval() -> None: + format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter(format_)) + pkg_logger = logging.getLogger("aac_metrics") + pkg_logger.addHandler(handler) + + args = _get_main_evaluate_args() + + if not check_java_path(args.java_path): + raise RuntimeError(f"Invalid Java executable. ({args.java_path})") + + level = logging.INFO if args.verbose <= 1 else logging.DEBUG + pkg_logger.setLevel(level) + + if args.verbose >= 1: + pylog.info(f"Load file {args.input_file}...") + + candidates, mult_references = load_csv_file( + args.input_file, args.cand_columns, args.mrefs_columns + ) + check_metric_inputs(candidates, mult_references) + + refs_lens = list(map(len, mult_references)) + if args.verbose >= 1: + pylog.info( + f"Found {len(candidates)} candidates, {len(mult_references)} references and [{min(refs_lens)}, {max(refs_lens)}] references per candidate." + ) + + corpus_scores, _sents_scores = evaluate( + candidates=candidates, + mult_references=mult_references, + preprocess=True, + metrics=args.metrics_set_name, + cache_path=args.cache_path, + java_path=args.java_path, + tmp_path=args.tmp_path, + verbose=args.verbose, + ) + + corpus_scores = {k: v.item() for k, v in corpus_scores.items()} + pylog.info(f"Global scores:\n{yaml.dump(corpus_scores, sort_keys=False)}") + + +if __name__ == "__main__": + _main_eval() diff --git a/src/aac_metrics/functional/cider_d.py b/src/aac_metrics/functional/cider_d.py index 6577407..35df385 100644 --- a/src/aac_metrics/functional/cider_d.py +++ b/src/aac_metrics/functional/cider_d.py @@ -24,6 +24,9 @@ def cider_d( - Paper: https://arxiv.org/pdf/1411.5726.pdf + .. warning:: + This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. + :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. :param return_all_scores: If True, returns a tuple containing the globals and locals scores. diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index dd0fba8..e5c1775 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -94,7 +94,7 @@ def evaluate( :param device: The PyTorch device used to run FENSE and SPIDErFL models. If None, it will try to detect use cuda if available. defaults to "auto". :param verbose: The verbose level. defaults to 0. - :returns: A tuple of globals and locals scores. + :returns: A tuple contains the corpus and sentences scores. 
""" check_metric_inputs(candidates, mult_references) @@ -103,7 +103,7 @@ def evaluate( ) if preprocess: - common_kwds = dict( + common_kwds: dict[str, Any] = dict( cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, @@ -179,7 +179,7 @@ def dcase2023_evaluate( :param device: The PyTorch device used to run FENSE and SPIDErFL models. If None, it will try to detect use cuda if available. defaults to "auto". :param verbose: The verbose level. defaults to 0. - :returns: A tuple of globals and locals scores. + :returns: A tuple contains the corpus and sentences scores. """ return evaluate( candidates=candidates, diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py index 68f2871..778b43b 100644 --- a/src/aac_metrics/functional/fense.py +++ b/src/aac_metrics/functional/fense.py @@ -40,7 +40,7 @@ def fense( device: Union[str, torch.device, None] = "auto", batch_size: int = 32, reset_state: bool = True, - return_probs: bool = True, + return_probs: bool = False, # Other args penalty: float = 0.9, verbose: int = 0, @@ -67,7 +67,7 @@ def fense( :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. - :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. + :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ diff --git a/src/aac_metrics/functional/fluerr.py b/src/aac_metrics/functional/fluerr.py index 65c66ec..07914c9 100644 --- a/src/aac_metrics/functional/fluerr.py +++ b/src/aac_metrics/functional/fluerr.py @@ -101,7 +101,7 @@ def fluerr( device: Union[str, torch.device, None] = "auto", batch_size: int = 32, reset_state: bool = True, - return_probs: bool = True, + return_probs: bool = False, verbose: int = 0, ) -> Union[Tensor, tuple[dict[str, Tensor], dict[str, Tensor]]]: """Return fluency error detected by a pre-trained BERT model. @@ -124,7 +124,7 @@ def fluerr( :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. - :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. + :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py index 79df12a..8ea5dd4 100644 --- a/src/aac_metrics/functional/meteor.py +++ b/src/aac_metrics/functional/meteor.py @@ -5,10 +5,11 @@ import logging import os.path as osp +import platform import subprocess from subprocess import Popen -from typing import Union +from typing import Optional, Union import torch @@ -34,6 +35,7 @@ def meteor( java_path: str = ..., java_max_memory: str = "2G", language: str = "en", + use_shell: Optional[bool] = None, verbose: int = 0, ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """Metric for Evaluation of Translation with Explicit ORdering function. @@ -52,6 +54,9 @@ def meteor( :param language: The language used for stem, synonym and paraphrase matching. Can be one of ("en", "cz", "de", "es", "fr"). defaults to "en". + :param use_shell: Optional argument to force use os-specific shell for the java subprogram. + If None, it will use shell only on Windows OS. + defaults to None. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ @@ -60,6 +65,9 @@ def meteor( meteor_jar_fpath = osp.join(cache_path, FNAME_METEOR_JAR) + if use_shell is None: + use_shell = platform.system() == "Windows" + if __debug__: if not osp.isfile(meteor_jar_fpath): raise FileNotFoundError( @@ -94,13 +102,16 @@ def meteor( ] if verbose >= 2: - pylog.debug(f"Start METEOR process with command '{' '.join(meteor_cmd)}'...") + pylog.debug( + f"Run METEOR java code with: {' '.join(meteor_cmd)} and {use_shell=}" + ) meteor_process = Popen( meteor_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=use_shell, ) n_candidates = len(candidates) diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index d6b6307..f673821 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -6,6 +6,7 @@ import math import os import os.path as osp +import platform import shutil import subprocess import tempfile @@ -47,6 +48,7 @@ def spice( java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, separate_cache_dir: bool = True, + use_shell: Optional[bool] = None, verbose: int = 0, ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """Semantic Propositional Image Caption Evaluation function. @@ -72,6 +74,9 @@ def spice( :param separate_cache_dir: If True, the SPICE cache files will be stored into in a new temporary directory. This removes potential freezes when multiple instances of SPICE are running in the same cache dir. defaults to True. + :param use_shell: Optional argument to force use os-specific shell for the java subprogram. + If None, it will use shell only on Windows OS. + defaults to None. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" @@ -82,6 +87,9 @@ def spice( spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR) + if use_shell is None: + use_shell = platform.system() == "Windows" + if __debug__: if not osp.isfile(spice_fpath): raise FileNotFoundError( @@ -138,19 +146,19 @@ def spice( stdout = None stderr = None else: - stdout = NamedTemporaryFile( + common_kwds: dict[str, Any] = dict( mode="w", delete=True, dir=tmp_path, - prefix="spice_stdout_", suffix=".txt", ) + stdout = NamedTemporaryFile( + prefix="spice_stdout_", + **common_kwds, + ) stderr = NamedTemporaryFile( - mode="w", - delete=True, - dir=tmp_path, prefix="spice_stderr_", - suffix=".txt", + **common_kwds, ) spice_cmd = [ @@ -169,7 +177,9 @@ def spice( spice_cmd += ["-threads", str(n_threads)] if verbose >= 2: - pylog.debug(f"Run SPICE java code with: {' '.join(spice_cmd)}") + pylog.debug( + f"Run SPICE java code with: {' '.join(spice_cmd)} and {use_shell=}" + ) try: subprocess.check_call( @@ -177,6 +187,7 @@ def spice( stdout=stdout, stderr=stderr, timeout=timeout_i, + shell=use_shell, ) if stdout is not None: stdout.close() @@ -209,17 +220,18 @@ def spice( and osp.isfile(stdout.name) and osp.isfile(stderr.name) ): - stdout_crashlog = stdout.name.replace( - "spice_stdout", "CRASH_spice_stdout" - ) - stderr_crashlog = stderr.name.replace( - "spice_stderr", "CRASH_spice_stderr" - ) - shutil.copy(stdout.name, stdout_crashlog) - shutil.copy(stderr.name, stderr_crashlog) pylog.error( - f"For more information, see temp files '{stdout_crashlog}' and '{stderr_crashlog}'." + f"For more information, see temp files '{stdout.name}' and '{stderr.name}'." ) + with open(stdout.name, "r") as file: + lines = file.readlines() + content = "\n".join(lines) + pylog.error(f"Content of '{stdout.name}':\n{content}") + + with open(stderr.name, "r") as file: + lines = file.readlines() + content = "\n".join(lines) + pylog.error(f"Content of '{stderr.name}':\n{content}") else: pylog.info( f"Note: No temp file recorded. (found {stdout=} and {stderr=})" diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py index 01d2f5b..8bcb568 100644 --- a/src/aac_metrics/functional/spider.py +++ b/src/aac_metrics/functional/spider.py @@ -31,6 +31,9 @@ def spider( - Paper: https://arxiv.org/pdf/1612.00370.pdf + .. warning:: + This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. + :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. :param return_all_scores: If True, returns a tuple containing the globals and locals scores. @@ -44,10 +47,14 @@ def spider( :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. - :param java_max_memory: The maximal java memory used. defaults to "8G". :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. defaults to None. + :param java_max_memory: The maximal java memory used. defaults to "8G". + :param timeout: The number of seconds before killing the java subprogram. + If a list is given, it will restart the program if the i-th timeout is reached. 
+ If None, no timeout will be used. + defaults to None. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py index ed9e007..441df68 100644 --- a/src/aac_metrics/functional/spider_fl.py +++ b/src/aac_metrics/functional/spider_fl.py @@ -57,6 +57,9 @@ def spider_fl( Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48. + .. warning:: + This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. + :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. :param return_all_scores: If True, returns a tuple containing the globals and locals scores. @@ -70,10 +73,14 @@ def spider_fl( :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. - :param java_max_memory: The maximal java memory used. defaults to "8G". :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. defaults to None. + :param java_max_memory: The maximal java memory used. defaults to "8G". + :param timeout: The number of seconds before killing the java subprogram. + If a list is given, it will restart the program if the i-th timeout is reached. + If None, no timeout will be used. + defaults to None. :param echecker: The echecker model used to detect fluency errors. Can be "echecker_clotho_audiocaps_base", "echecker_clotho_audiocaps_tiny", "none" or None. defaults to "echecker_clotho_audiocaps_base". diff --git a/src/aac_metrics/functional/spider_max.py b/src/aac_metrics/functional/spider_max.py index 25b4d21..e1f79c7 100644 --- a/src/aac_metrics/functional/spider_max.py +++ b/src/aac_metrics/functional/spider_max.py @@ -32,9 +32,12 @@ def spider_max( ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """SPIDEr-max function. + Compute the maximal SPIDEr score accross multiple candidates. + - Paper: https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Labbe_46.pdf - Compute the maximal SPIDEr score accross multiple candidates. + .. warning:: + This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. :param mult_candidates: The list of list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. @@ -55,6 +58,10 @@ def spider_max( :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. defaults to None. + :param timeout: The number of seconds before killing the java subprogram. + If a list is given, it will restart the program if the i-th timeout is reached. + If None, no timeout will be used. + defaults to None. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" diff --git a/tests/test_compare_cet.py b/tests/test_compare_cet.py index 42eca79..ba66e51 100644 --- a/tests/test_compare_cet.py +++ b/tests/test_compare_cet.py @@ -4,9 +4,9 @@ import importlib import os import os.path as osp +import platform import subprocess import sys -import tempfile import unittest from pathlib import Path @@ -16,7 +16,8 @@ from torch import Tensor from aac_metrics.functional.evaluate import evaluate -from aac_metrics.evaluate import load_csv_file +from aac_metrics.eval import load_csv_file +from aac_metrics.utils.paths import get_default_tmp_path class TestCompareCaptionEvaluationTools(TestCase): @@ -25,8 +26,6 @@ class TestCompareCaptionEvaluationTools(TestCase): # Set Up methods @classmethod def setUpClass(cls) -> None: - if os.name == "nt": - return None cls.evaluate_metrics_from_lists = cls._import_cet_eval_func() @classmethod @@ -37,6 +36,7 @@ def _import_cet_eval_func( Tuple[Dict[str, float], Dict[int, Dict[str, float]]], ]: cet_path = osp.join(osp.dirname(__file__), "caption-evaluation-tools") + use_shell = platform.system() == "Windows" stanford_fpath = osp.join( cet_path, @@ -53,14 +53,15 @@ def _import_cet_eval_func( stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=osp.join(cet_path, "coco_caption"), + shell=use_shell, ) # Append cet_path to allow imports of "caption" in eval_metrics.py. sys.path.append(cet_path) # Override cache and tmp dir to avoid outputs in source code. spice_module = importlib.import_module("coco_caption.pycocoevalcap.spice.spice") - spice_module.CACHE_DIR = tempfile.gettempdir() # type: ignore - spice_module.TEMP_DIR = tempfile.gettempdir() # type: ignore + spice_module.CACHE_DIR = get_default_tmp_path() # type: ignore + spice_module.TEMP_DIR = get_default_tmp_path() # type: ignore eval_metrics_module = importlib.import_module("eval_metrics") evaluate_metrics_from_lists = eval_metrics_module.evaluate_metrics_from_lists return evaluate_metrics_from_lists @@ -102,13 +103,13 @@ def _get_example_0(self) -> tuple[list[str], list[list[str]]]: def _test_with_example(self, cands: list[str], mrefs: list[list[str]]) -> None: if os.name == "nt": + # Skip this setup on windows return None + cet_outs = self.__class__.evaluate_metrics_from_lists(cands, mrefs) + cet_global_scores, _cet_sents_scores = cet_outs + corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020") - ( - cet_global_scores, - _cet_sents_scores, - ) = self.__class__.evaluate_metrics_from_lists(cands, mrefs) cet_global_scores = {k.lower(): v for k, v in cet_global_scores.items()} cet_global_scores = { diff --git a/tests/test_compare_fense.py b/tests/test_compare_fense.py index 7c28467..749c7f1 100644 --- a/tests/test_compare_fense.py +++ b/tests/test_compare_fense.py @@ -12,7 +12,7 @@ from aac_metrics.classes.sbert_sim import SBERTSim from aac_metrics.classes.fense import FENSE -from aac_metrics.evaluate import load_csv_file +from aac_metrics.eval import load_csv_file class TestCompareFENSE(TestCase): diff --git a/tests/test_doc_examples.py b/tests/test_doc_examples.py new file mode 100644 index 0000000..2ead343 --- /dev/null +++ b/tests/test_doc_examples.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import unittest + +from unittest import TestCase + +import torch + +from aac_metrics import evaluate +from aac_metrics.functional import cider_d +from aac_metrics.utils.tokenization import ( + preprocess_mono_sents, + preprocess_mult_sents, +) + + +class TestReadmeExamples(TestCase): + def test_example_1(self) -> None: + if os.name 
== "nt": + return None + + candidates: list[str] = ["a man is speaking", "rain falls"] + mult_references: list[list[str]] = [ + [ + "a man speaks.", + "someone speaks.", + "a man is speaking while a bird is chirping in the background", + ], + ["rain is falling hard on a surface"], + ] + + corpus_scores, _ = evaluate(candidates, mult_references) + # print(corpus_scores) + # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider" + # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...} + + expected_keys = [ + "bleu_1", + "bleu_2", + "bleu_3", + "bleu_4", + "rouge_l", + "meteor", + "cider_d", + "spice", + "spider", + ] + self.assertSetEqual(set(corpus_scores.keys()), set(expected_keys)) + + # print(corpus_scores["bleu_1"]) + # print(torch.as_tensor(0.4278, dtype=torch.float64)) + # print("END") + + self.assertTrue( + torch.allclose( + corpus_scores["bleu_1"], + torch.as_tensor(0.4278, dtype=torch.float64), + atol=0.0001, + ), + f"{corpus_scores['bleu_1']=}", + ) + + def test_example_2(self) -> None: + if os.name == "nt": + return None + + candidates: list[str] = ["a man is speaking", "rain falls"] + mult_references: list[list[str]] = [ + [ + "a man speaks.", + "someone speaks.", + "a man is speaking while a bird is chirping in the background", + ], + ["rain is falling hard on a surface"], + ] + + corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2023") + # print(corpus_scores) + # dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr" + + expected_keys = ["meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr"] + self.assertTrue(set(corpus_scores.keys()).issuperset(expected_keys)) + + def test_example_3(self) -> None: + if os.name == "nt": + return None + + candidates: list[str] = ["a man is speaking", "rain falls"] + mult_references: list[list[str]] = [ + [ + "a man speaks.", + "someone speaks.", + "a man is speaking while a bird is chirping in the background", + ], + ["rain is falling hard on a surface"], + ] + + candidates = preprocess_mono_sents(candidates) + mult_references = preprocess_mult_sents(mult_references) + + outputs: tuple[dict, dict] = cider_d(candidates, mult_references) # type: ignore + corpus_scores, sents_scores = outputs + # print(corpus_scores) + # {"cider_d": tensor(0.9614)} + # print(sents_scores) + # {"cider_d": tensor([1.3641, 0.5587])} + + self.assertTrue(set(corpus_scores.keys()).issuperset({"cider_d"})) + self.assertTrue(set(sents_scores.keys()).issuperset({"cider_d"})) + + self.assertTrue( + torch.allclose( + corpus_scores["cider_d"], + torch.as_tensor(0.9614, dtype=torch.float64), + atol=0.0001, + ) + ) + self.assertTrue( + torch.allclose( + sents_scores["cider_d"], + torch.as_tensor([1.3641, 0.5587], dtype=torch.float64), + atol=0.0001, + ) + ) + + +if __name__ == "__main__": + unittest.main()