diff --git a/.github/workflows/python-package-pip.yaml b/.github/workflows/python-package-pip.yaml index 44d6383..489876f 100644 --- a/.github/workflows/python-package-pip.yaml +++ b/.github/workflows/python-package-pip.yaml @@ -8,13 +8,21 @@ on: pull_request: branches: [ main, dev ] +env: + CACHE_NUMBER: 0 # increase to reset cache manually + +# Cancel workflow if a new push occurs +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest] + os: [ubuntu-latest,windows-latest] python-version: ["3.9"] java-version: ["11"] @@ -39,9 +47,10 @@ jobs: java-package: jre - name: Install package + shell: bash # note: ${GITHUB_REF##*/} gives the branch name run: | - python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}" + python -m pip install "aac-metrics[${GITHUB_REF_NAME}] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}" - name: Load cache of external code and data uses: actions/cache@master diff --git a/.gitignore b/.gitignore index 5d9d203..8e8a81c 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,5 @@ tests/fense tmp/ tmp*/ *.mdb +core-python* +core-srun* diff --git a/CHANGELOG.md b/CHANGELOG.md index ba46ba8..170c2e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,14 +2,25 @@ All notable changes to this project will be documented in this file. -## [0.4.4] UNRELEASED +## [0.4.4] 2023-08-14 +### Added +- `Evaluate` class now implements `__hash__` and `tolist()` methods. +- `BLEU` 1 to n classes and functions. +- Get and set global user paths for cache, java and tmp. + ### Changed -- TODO +- Function `get_install_info` now returns `package_path`. +- `AACMetric` now indicates the output type when using the `__call__` method. +- Rename `AACEvaluate` to `DCASE2023Evaluate` and use `dcase2023` metric set instead of `all` metric set. 
+ +### Fixed +- `sbert_sim` name in internal instantiation functions. +- Path management for Windows. ## [0.4.3] 2023-06-15 ### Changed -- `AACMetric` is no longer a subclass of `torchmetrics.Metric` even when it is installed. It avoid dependency to this package and remove potential errors due to Metric. -- Java 12 and 13 are now allowed. +- `AACMetric` is no longer a subclass of `torchmetrics.Metric` even when it is installed. It avoids dependency on this package and removes potential errors due to the Metric base class. +- Java 12 and 13 are now allowed in this package. ### Fixed - Output name `sbert_sim` in FENSE and SBERTSim classes. diff --git a/CITATION.cff b/CITATION.cff index 82dad3a..93cdbe0 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.4.3 -date-released: '2023-06-15' +version: 0.4.4 +date-released: '2023-08-14' diff --git a/README.md b/README.md index b0a12dc..8132dbd 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,11 @@ Install the pip package: pip install aac-metrics ``` +If you want to check whether the package is installed and see its version, you can use this command: +```bash +aac-metrics-info +``` + Download the external code and models needed for METEOR, SPICE, SPIDEr, SPIDEr-max, PTBTokenizer, SBERTSim, FluencyError, FENSE and SPIDEr-FL: ```bash aac-metrics-download ``` @@ -114,8 +119,10 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci | SPIDEr-FL [[9]](#spider-fl) | `SPIDErFL` | audio captioning | [0, 5.5] | Combines SPIDEr and Fluency Error | ## Requirements +This package has been developed for Ubuntu 20.04, and it is expected to work on most Linux distributions. ### Python packages + The pip requirements are automatically installed when using `pip install` on this repository. 
``` torch >= 1.10.1 @@ -126,7 +133,7 @@ sentence-transformers >= 2.2.2 ``` ### External requirements -- `java` **>= 1.8 and <= 1.11** is required to compute METEOR, SPICE and use the PTBTokenizer. +- `java` **>= 1.8 and <= 1.13** is required to compute METEOR, SPICE and use the PTBTokenizer. Most of these functions can specify a java executable path with `java_path` argument. - `unzip` command to extract SPICE zipped files. @@ -191,18 +198,14 @@ arXiv: 1612.00370. [Online]. Available: http://arxiv.org/abs/1612.00370 ## Citation If you use **SPIDEr-max**, you can cite the following paper using BibTex : ``` -@inproceedings{labbe:hal-03810396, - TITLE = {{Is my automatic audio captioning system so bad? spider-max: a metric to consider several caption candidates}}, - AUTHOR = {Labb{\'e}, Etienne and Pellegrini, Thomas and Pinquier, Julien}, - URL = {https://hal.archives-ouvertes.fr/hal-03810396}, - BOOKTITLE = {{Workshop DCASE}}, - ADDRESS = {Nancy, France}, - YEAR = {2022}, - MONTH = Nov, - KEYWORDS = {audio captioning ; evaluation metric ; beam search ; multiple candidates}, - PDF = {https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf}, - HAL_ID = {hal-03810396}, - HAL_VERSION = {v1}, +@inproceedings{Labbe2022, + title = {Is my Automatic Audio Captioning System so Bad? 
SPIDEr-max: A Metric to Consider Several Caption Candidates}, + author = {Labb\'{e}, Etienne and Pellegrini, Thomas and Pinquier, Julien}, + year = 2022, + month = {November}, + booktitle = {Proceedings of the 7th Detection and Classification of Acoustic Scenes and Events 2022 Workshop (DCASE2022)}, + address = {Nancy, France}, + url = {https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Labbe_46.pdf} } ``` @@ -212,10 +215,10 @@ If you use this software, please consider cite it as below : Labbe_aac-metrics_2023, author = {Labbé, Etienne}, license = {MIT}, - month = {6}, + month = {8}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.4.3}, + version = {0.4.4}, year = {2023}, } ``` diff --git a/docs/aac_metrics.utils.paths.rst b/docs/aac_metrics.utils.paths.rst new file mode 100644 index 0000000..0bfcc36 --- /dev/null +++ b/docs/aac_metrics.utils.paths.rst @@ -0,0 +1,7 @@ +aac\_metrics.utils.paths module +=============================== + +.. automodule:: aac_metrics.utils.paths + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/conf.py b/docs/conf.py index ba0b825..be57d0e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -101,3 +101,7 @@ def setup(app) -> None: app.add_css_file("my_theme.css") + + +# TODO: to be used with sphinx>=7.1 +maximum_signature_line_length = 10 diff --git a/docs/requirements.txt b/docs/requirements.txt index 26625d5..d25ec2e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1 +1 @@ -sphinx-press-theme==0.8.0 +sphinx-press-theme>=0.8.0 diff --git a/docs/spider_max.rst b/docs/spider_max.rst index 87e136d..edd2e89 100644 --- a/docs/spider_max.rst +++ b/docs/spider_max.rst @@ -75,7 +75,7 @@ Here is 2 examples with the 5 candidates generated by the beam search algorithm, (Audio file id "jid4t-FzUn0" from AudioCaps testing subset) -Even with very similar candidates, the SPIDEr scores varies drastically. 
To adress this issue, we proposed a SPIDEr-max metric which take the maximum value of several candidates for the same audio. SPIDEr-max demonstrate that SPIDEr can exceed state-of-the-art scores on AudioCaps and Clotho and even [human scores on AudioCaps](https://hal.archives-ouvertes.fr/hal-03810396). +Even with very similar candidates, the SPIDEr scores vary drastically. To address this issue, we proposed a SPIDEr-max metric which takes the maximum value of several candidates for the same audio. SPIDEr-max demonstrates that SPIDEr can exceed state-of-the-art scores on AudioCaps and Clotho and even `human scores on AudioCaps <https://hal.archives-ouvertes.fr/hal-03810396>`_. How ? ##### @@ -95,6 +95,6 @@ This usage is very similar to other captioning metrics, with the main difference corpus_scores, sents_scores = spider_max(mult_candidates, mult_references) print(corpus_scores) - # {"spider": tensor(0.1), ...} + # {"spider_max": tensor(0.1), ...} print(sents_scores) - # {"spider": tensor([0.9, ...]), ...} + # {"spider_max": tensor([0.9, ...]), ...} diff --git a/pyproject.toml b/pyproject.toml index 18a97cd..889d64f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dynamic = ["version"] [project.urls] Homepage = "https://pypi.org/project/aac-metrics/" Documentation = "https://aac-metrics.readthedocs.io/" -Repository = "https://github.com//Labbeti/aac-metrics.git" +Repository = "https://github.com/Labbeti/aac-metrics.git" Changelog = "https://github.com/Labbeti/aac-metrics/blob/main/CHANGELOG.md" [project.scripts] @@ -50,6 +50,7 @@ dev = [ "scikit-image==0.19.2", "matplotlib==3.5.2", "torchmetrics>=0.10", + "transformers<4.31.0", ] [tool.setuptools.packages.find] diff --git a/requirements.txt b/requirements.txt index ec3e4e0..cff75af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy>=1.21.2 pyyaml>=6.0 tqdm>=4.64.0 sentence-transformers>=2.2.2 +transformers<4.31.0 diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index 7fc7d9e..f614bf3 100644 --- 
a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,25 +10,33 @@ __license__ = "MIT" __maintainer__ = "Etienne Labbé (Labbeti)" __status__ = "Development" -__version__ = "0.4.3" +__version__ = "0.4.4" from .classes.base import AACMetric from .classes.bleu import BLEU from .classes.cider_d import CIDErD -from .classes.evaluate import AACEvaluate, _get_metric_factory_classes +from .classes.evaluate import DCASE2023Evaluate, _get_metric_factory_classes from .classes.fense import FENSE from .classes.meteor import METEOR from .classes.rouge_l import ROUGEL from .classes.spice import SPICE from .classes.spider import SPIDEr from .functional.evaluate import dcase2023_evaluate, evaluate +from .utils.paths import ( + get_default_cache_path, + get_default_java_path, + get_default_tmp_path, + set_default_cache_path, + set_default_java_path, + set_default_tmp_path, +) __all__ = [ "BLEU", "CIDErD", - "AACEvaluate", + "DCASE2023Evaluate", "FENSE", "METEOR", "ROUGEL", @@ -36,6 +44,13 @@ "SPIDEr", "dcase2023_evaluate", "evaluate", + "get_default_cache_path", + "get_default_java_path", + "get_default_tmp_path", + "set_default_cache_path", + "set_default_java_path", + "set_default_tmp_path", + "load_metric", ] diff --git a/src/aac_metrics/classes/__init__.py b/src/aac_metrics/classes/__init__.py index 23d3fa6..c2614ed 100644 --- a/src/aac_metrics/classes/__init__.py +++ b/src/aac_metrics/classes/__init__.py @@ -3,7 +3,7 @@ from .bleu import BLEU from .cider_d import CIDErD -from .evaluate import Evaluate, AACEvaluate +from .evaluate import DCASE2023Evaluate, Evaluate from .fense import FENSE from .fluerr import FluErr from .meteor import METEOR @@ -18,7 +18,7 @@ __all__ = [ "BLEU", "CIDErD", - "AACEvaluate", + "DCASE2023Evaluate", "Evaluate", "FENSE", "FluErr", diff --git a/src/aac_metrics/classes/base.py b/src/aac_metrics/classes/base.py index aa9253c..0f0882b 100644 --- a/src/aac_metrics/classes/base.py +++ b/src/aac_metrics/classes/base.py @@ -1,13 +1,15 @@ 
#!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Any, Optional +from typing import Any, Generic, Optional, TypeVar from torch import nn +OutType = TypeVar("OutType") -class AACMetric(nn.Module): - """Base Metric module used when torchmetrics is not installed.""" + +class AACMetric(nn.Module, Generic[OutType]): + """Base Metric module for AAC metrics. Similar to torchmetrics.Metric.""" # Global values full_state_update: Optional[bool] = False @@ -23,10 +25,10 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) # Public methods - def compute(self) -> Any: - return None + def compute(self) -> OutType: + return None # type: ignore - def forward(self, *args: Any, **kwargs: Any) -> Any: + def forward(self, *args: Any, **kwargs: Any) -> OutType: self.update(*args, **kwargs) output = self.compute() self.reset() @@ -37,3 +39,7 @@ def reset(self) -> None: def update(self, *args, **kwargs) -> None: pass + + # Magic methods + def __call__(self, *args: Any, **kwds: Any) -> OutType: + return super().__call__(*args, **kwds) diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py index d6b82c4..9d6030a 100644 --- a/src/aac_metrics/classes/bleu.py +++ b/src/aac_metrics/classes/bleu.py @@ -13,7 +13,7 @@ ) -class BLEU(AACMetric): +class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """BiLingual Evaluation Understudy metric class. 
- Paper: https://www.aclweb.org/anthology/P02-1040.pdf @@ -85,3 +85,47 @@ def update( self._cooked_cands, self._cooked_mrefs, ) + + +class BLEU1(BLEU): + def __init__( + self, + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__(return_all_scores, 1, option, verbose, tokenizer) + + +class BLEU2(BLEU): + def __init__( + self, + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__(return_all_scores, 2, option, verbose, tokenizer) + + +class BLEU3(BLEU): + def __init__( + self, + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__(return_all_scores, 3, option, verbose, tokenizer) + + +class BLEU4(BLEU): + def __init__( + self, + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + ) -> None: + super().__init__(return_all_scores, 4, option, verbose, tokenizer) diff --git a/src/aac_metrics/classes/cider_d.py b/src/aac_metrics/classes/cider_d.py index 1372907..c2bf68d 100644 --- a/src/aac_metrics/classes/cider_d.py +++ b/src/aac_metrics/classes/cider_d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Callable, Union +from typing import Any, Callable, Union from torch import Tensor @@ -12,7 +12,7 @@ ) -class CIDErD(AACMetric): +class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Any]], Tensor]]): """Consensus-based Image Description Evaluation metric class. 
- Paper: https://arxiv.org/pdf/1411.5726.pdf diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py index 7a7dea8..0e264a3 100644 --- a/src/aac_metrics/classes/evaluate.py +++ b/src/aac_metrics/classes/evaluate.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- import logging +import pickle +import zlib from typing import Callable, Iterable, Union @@ -26,7 +28,7 @@ pylog = logging.getLogger(__name__) -class Evaluate(list[AACMetric], AACMetric): +class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Tensor]]]): """Evaluate candidates with multiple references with custom metrics. For more information, see :func:`~aac_metrics.functional.evaluate.evaluate`. @@ -40,19 +42,19 @@ def __init__( self, preprocess: bool = True, metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac", - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> None: metrics = _instantiate_metrics_classes( - metrics, - cache_path, - java_path, - tmp_path, - device, - verbose, + metrics=metrics, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, ) list.__init__(self, metrics) @@ -69,15 +71,15 @@ def __init__( def compute(self) -> tuple[dict[str, Tensor], dict[str, Tensor]]: return evaluate( - self._candidates, - self._mult_references, - self._preprocess, - self, - self._cache_path, - self._java_path, - self._tmp_path, - self._device, - self._verbose, + candidates=self._candidates, + mult_references=self._mult_references, + preprocess=self._preprocess, + metrics=self, + cache_path=self._cache_path, + java_path=self._java_path, + tmp_path=self._tmp_path, + device=self._device, + verbose=self._verbose, ) def reset(self) -> None: @@ -93,9 +95,18 @@ def update( self._candidates += candidates self._mult_references += 
mult_references + def tolist(self) -> list[AACMetric]: + return list(self) + + def __hash__(self) -> int: + # note: assume that all metrics can be pickled + data = pickle.dumps(self) + data = zlib.adler32(data) + return data -class AACEvaluate(Evaluate): - """Evaluate candidates with multiple references with all Audio Captioning metrics. + +class DCASE2023Evaluate(Evaluate): + """Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics. For more information, see :func:`~aac_metrics.functional.evaluate.aac_evaluate`. """ @@ -103,27 +114,28 @@ class AACEvaluate(Evaluate): def __init__( self, preprocess: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., + device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> None: super().__init__( - preprocess, - "aac", - cache_path, - java_path, - tmp_path, - "auto", - verbose, + preprocess=preprocess, + metrics="dcase2023", + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, ) def _instantiate_metrics_classes( metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac", - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> list[AACMetric]: @@ -136,12 +148,12 @@ def _instantiate_metrics_classes( metrics = list(metrics) # type: ignore metric_factory = _get_metric_factory_classes( - True, - cache_path, - java_path, - tmp_path, - device, - verbose, + return_all_scores=True, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, ) metrics_inst: list[AACMetric] = [] @@ -154,13 +166,16 @@ def _instantiate_metrics_classes( def _get_metric_factory_classes( return_all_scores: bool = True, - 
cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> dict[str, Callable[[], AACMetric]]: return { + "bleu": lambda: BLEU( + return_all_scores=return_all_scores, + ), "bleu_1": lambda: BLEU( return_all_scores=return_all_scores, n=1, @@ -203,7 +218,7 @@ def _get_metric_factory_classes( tmp_path=tmp_path, verbose=verbose, ), - "sbert": lambda: SBERTSim( + "sbert_sim": lambda: SBERTSim( return_all_scores=return_all_scores, device=device, verbose=verbose, diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py index 72876f8..da0b318 100644 --- a/src/aac_metrics/classes/fense.py +++ b/src/aac_metrics/classes/fense.py @@ -17,7 +17,7 @@ pylog = logging.getLogger(__name__) -class FENSE(AACMetric): +class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Fluency ENhanced Sentence-bert Evaluation (FENSE) - Paper: https://arxiv.org/abs/2110.04684 diff --git a/src/aac_metrics/classes/fluerr.py b/src/aac_metrics/classes/fluerr.py index aef3bdb..fdfb43d 100644 --- a/src/aac_metrics/classes/fluerr.py +++ b/src/aac_metrics/classes/fluerr.py @@ -20,7 +20,7 @@ pylog = logging.getLogger(__name__) -class FluErr(AACMetric): +class FluErr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Return fluency error rate detected by a pre-trained BERT model. - Paper: https://arxiv.org/abs/2110.04684 diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py index 57aa8b0..2d8aff6 100644 --- a/src/aac_metrics/classes/meteor.py +++ b/src/aac_metrics/classes/meteor.py @@ -9,7 +9,7 @@ from aac_metrics.functional.meteor import meteor -class METEOR(AACMetric): +class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Metric for Evaluation of Translation with Explicit ORdering metric class. 
- Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389 @@ -28,8 +28,8 @@ class METEOR(AACMetric): def __init__( self, return_all_scores: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", + cache_path: str = ..., + java_path: str = ..., java_max_memory: str = "2G", language: str = "en", verbose: int = 0, diff --git a/src/aac_metrics/classes/rouge_l.py b/src/aac_metrics/classes/rouge_l.py index a1c32dd..7cc79c0 100644 --- a/src/aac_metrics/classes/rouge_l.py +++ b/src/aac_metrics/classes/rouge_l.py @@ -12,7 +12,7 @@ ) -class ROUGEL(AACMetric): +class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Recall-Oriented Understudy for Gisting Evaluation class. - Paper: https://aclanthology.org/W04-1013.pdf diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py index 56b4e2a..39704dc 100644 --- a/src/aac_metrics/classes/sbert_sim.py +++ b/src/aac_metrics/classes/sbert_sim.py @@ -17,7 +17,7 @@ pylog = logging.getLogger(__name__) -class SBERTSim(AACMetric): +class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Cosine-similarity of the Sentence-BERT embeddings. - Paper: https://arxiv.org/abs/1908.10084 diff --git a/src/aac_metrics/classes/spice.py b/src/aac_metrics/classes/spice.py index 7ff4a53..b9308be 100644 --- a/src/aac_metrics/classes/spice.py +++ b/src/aac_metrics/classes/spice.py @@ -14,7 +14,7 @@ pylog = logging.getLogger(__name__) -class SPICE(AACMetric): +class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """Semantic Propositional Image Caption Evaluation class. 
- Paper: https://arxiv.org/pdf/1607.08822.pdf @@ -32,9 +32,9 @@ class SPICE(AACMetric): def __init__( self, return_all_scores: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, diff --git a/src/aac_metrics/classes/spider.py b/src/aac_metrics/classes/spider.py index 2260c5a..3df89b8 100644 --- a/src/aac_metrics/classes/spider.py +++ b/src/aac_metrics/classes/spider.py @@ -14,7 +14,7 @@ pylog = logging.getLogger(__name__) -class SPIDEr(AACMetric): +class SPIDEr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """SPIDEr class. - Paper: https://arxiv.org/pdf/1612.00370.pdf @@ -36,9 +36,9 @@ def __init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py index 08d2a50..24edf96 100644 --- a/src/aac_metrics/classes/spider_fl.py +++ b/src/aac_metrics/classes/spider_fl.py @@ -21,7 +21,7 @@ pylog = logging.getLogger(__name__) -class SPIDErFL(AACMetric): +class SPIDErFL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """SPIDErFL class. For more information, see :func:`~aac_metrics.functional.spider_fl.spider_fl`. 
@@ -41,9 +41,9 @@ def __init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py index cbf43ae..dd0bad4 100644 --- a/src/aac_metrics/classes/spider_max.py +++ b/src/aac_metrics/classes/spider_max.py @@ -14,7 +14,7 @@ pylog = logging.getLogger(__name__) -class SPIDErMax(AACMetric): +class SPIDErMax(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]): """SPIDEr-max class. - Paper: https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf @@ -37,9 +37,9 @@ def __init__( n: int = 4, sigma: float = 6.0, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index 9957d07..041f09b 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -13,15 +13,25 @@ from torch.hub import download_url_to_file from aac_metrics.classes.fense import FENSE -from aac_metrics.functional.meteor import FNAME_METEOR_JAR -from aac_metrics.functional.spice import FNAME_SPICE_JAR, DNAME_SPICE_CACHE +from aac_metrics.functional.meteor import DNAME_METEOR_CACHE +from aac_metrics.functional.spice import ( + FNAME_SPICE_JAR, + DNAME_SPICE_LOCAL_CACHE, + DNAME_SPICE_CACHE, +) +from aac_metrics.utils.paths import ( + _get_cache_path, + _get_tmp_path, + get_default_cache_path, + get_default_tmp_path, +) from aac_metrics.utils.tokenization import FNAME_STANFORD_CORENLP_3_4_1_JAR pylog = 
logging.getLogger(__name__) -JAR_URLS = { +DATA_URLS = { "meteor": { "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/meteor/meteor-1.5.jar", "fname": "meteor-1.5.jar", @@ -30,10 +40,30 @@ "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/meteor/data/paraphrase-en.gz", "fname": osp.join("data", "paraphrase-en.gz"), }, + "meteor_data_fr": { + "url": "https://github.com/cmu-mtlab/meteor/raw/master/data/paraphrase-fr.gz", + "fname": osp.join("data", "paraphrase-fr.gz"), + }, + "meteor_data_de": { + "url": "https://github.com/cmu-mtlab/meteor/raw/master/data/paraphrase-de.gz", + "fname": osp.join("data", "paraphrase-de.gz"), + }, + "meteor_data_es": { + "url": "https://github.com/cmu-mtlab/meteor/raw/master/data/paraphrase-es.gz", + "fname": osp.join("data", "paraphrase-es.gz"), + }, + "meteor_data_cz": { + "url": "https://github.com/cmu-mtlab/meteor/raw/master/data/paraphrase-cz.gz", + "fname": osp.join("data", "paraphrase-cz.gz"), + }, "spice": { "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/spice/spice-1.0.jar", "fname": "spice-1.0.jar", }, + "spice_zip": { + "url": "https://panderson.me/images/SPICE-1.0.zip", + "fname": "SPICE-1.0.zip", + }, "stanford_nlp": { "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar", "fname": "stanford-corenlp-3.4.1.jar", @@ -44,115 +74,158 @@ def download( - cache_path: str = "$HOME/.cache", - tmp_path: str = "/tmp", + cache_path: str = ..., + tmp_path: str = ..., ptb_tokenizer: bool = True, meteor: bool = True, spice: bool = True, fense: bool = True, verbose: int = 0, ) -> None: - """Download the code needed for SPICE, METEOR and PTB Tokenizer. + """Download the code needed for SPICE, METEOR, PTB Tokenizer and FENSE. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param tmp_path: The path to a temporary directory. defaults to "/tmp". 
+ :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param tmp_path: The path to a temporary directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param ptb_tokenizer: If True, downloads the PTBTokenizer code in cache directory. defaults to True. :param meteor: If True, downloads the METEOR code in cache directory. defaults to True. :param spice: If True, downloads the SPICE code in cache directory. defaults to True. :param fense: If True, downloads the FENSE models. defaults to True. :param verbose: The verbose level. defaults to 0. """ - cache_path = osp.expandvars(cache_path) - tmp_path = osp.expandvars(tmp_path) + cache_path = _get_cache_path(cache_path) + tmp_path = _get_tmp_path(tmp_path) os.makedirs(cache_path, exist_ok=True) os.makedirs(tmp_path, exist_ok=True) - if ptb_tokenizer: - # Download JAR file for tokenization - stanford_nlp_dpath = osp.join( - cache_path, osp.dirname(FNAME_STANFORD_CORENLP_3_4_1_JAR) - ) - os.makedirs(stanford_nlp_dpath, exist_ok=True) + if verbose >= 2: + pylog.debug("AAC setup:") + pylog.debug(f" Cache directory: {cache_path}") + pylog.debug(f" Temp directory: {tmp_path}") - name = "stanford_nlp" - info = JAR_URLS[name] - url = info["url"] - fname = info["fname"] - fpath = osp.join(stanford_nlp_dpath, fname) - if not osp.isfile(fpath): - if verbose >= 1: - pylog.info( - f"Downloading jar source for '{name}' in directory {stanford_nlp_dpath}." 
- ) - download_url_to_file(url, fpath, progress=verbose >= 1) - else: - if verbose >= 1: - pylog.info(f"Stanford model file '{name}' is already downloaded.") + if ptb_tokenizer: + _download_ptb_tokenizer(cache_path, verbose) if meteor: - # Download JAR files for METEOR metric - meteor_dpath = osp.join(cache_path, osp.dirname(FNAME_METEOR_JAR)) - os.makedirs(meteor_dpath, exist_ok=True) - - for name in ("meteor", "meteor_data"): - info = JAR_URLS[name] - url = info["url"] - fname = info["fname"] - subdir = osp.dirname(fname) - fpath = osp.join(meteor_dpath, fname) - - if not osp.isfile(fpath): - if verbose >= 1: - pylog.info( - f"Downloading jar source for '{name}' in directory {meteor_dpath}." - ) - if subdir not in ("", "."): - os.makedirs(osp.join(meteor_dpath, subdir), exist_ok=True) - - download_url_to_file( - url, - fpath, - progress=verbose >= 1, - ) - else: - if verbose >= 1: - pylog.info(f"Meteor file '{name}' is already downloaded.") + _download_meteor(cache_path, verbose) if spice: - # Download JAR files for SPICE metric - spice_jar_dpath = osp.join(cache_path, osp.dirname(FNAME_SPICE_JAR)) - spice_cache_path = osp.join(cache_path, DNAME_SPICE_CACHE) + _download_spice(cache_path, verbose) - os.makedirs(spice_jar_dpath, exist_ok=True) - os.makedirs(spice_cache_path, exist_ok=True) + if fense: + _download_fense(verbose) - script_path = osp.join(osp.dirname(__file__), "install_spice.sh") - if not osp.isfile(script_path): - raise FileNotFoundError( - f"Cannot find script '{osp.basename(script_path)}'." 
- ) +def _download_ptb_tokenizer( + cache_path: str = ..., + verbose: int = 0, +) -> None: + # Download JAR file for tokenization + stanford_nlp_dpath = osp.join( + cache_path, osp.dirname(FNAME_STANFORD_CORENLP_3_4_1_JAR) + ) + os.makedirs(stanford_nlp_dpath, exist_ok=True) + + name = "stanford_nlp" + info = DATA_URLS[name] + url = info["url"] + fname = info["fname"] + fpath = osp.join(stanford_nlp_dpath, fname) + if not osp.isfile(fpath): if verbose >= 1: pylog.info( - f"Downloading JAR sources for SPICE metric into '{spice_jar_dpath}'..." + f"Downloading JAR source for '{name}' in directory {stanford_nlp_dpath}." ) + download_url_to_file(url, fpath, progress=verbose >= 1) + else: + if verbose >= 1: + pylog.info(f"Stanford model file '{name}' is already downloaded.") - command = ["bash", script_path, spice_jar_dpath] - try: - subprocess.check_call( - command, - stdout=None if verbose >= 2 else subprocess.DEVNULL, - stderr=None if verbose >= 2 else subprocess.DEVNULL, - ) - except (CalledProcessError, PermissionError) as err: - pylog.error(err) - if fense: - # Download models files for FENSE metric +def _download_meteor( + cache_path: str = ..., + verbose: int = 0, +) -> None: + # Download JAR files for METEOR metric + meteor_dpath = osp.join(cache_path, DNAME_METEOR_CACHE) + os.makedirs(meteor_dpath, exist_ok=True) + + meteors_names = [name for name in DATA_URLS.keys() if name.startswith("meteor")] + + for name in meteors_names: + info = DATA_URLS[name] + url = info["url"] + fname = info["fname"] + subdir = osp.dirname(fname) + fpath = osp.join(meteor_dpath, fname) + + if osp.isfile(fpath): + if verbose >= 1: + pylog.info(f"Meteor file '{name}' is already downloaded.") + continue + if verbose >= 1: - pylog.info("Downloading SBERT and BERT error detector for FENSE metric...") - _ = FENSE(device="cpu") + pylog.info(f"Downloading source for '{fname}' in directory {meteor_dpath}.") + if subdir not in ("", "."): + os.makedirs(osp.dirname(fpath), exist_ok=True) + + 
download_url_to_file( + url, + fpath, + progress=verbose >= 1, + ) + + +def _download_spice( + cache_path: str = ..., + verbose: int = 0, +) -> None: + # Download JAR files for SPICE metric + spice_cache_dpath = osp.join(cache_path, DNAME_SPICE_CACHE) + spice_jar_dpath = osp.join(cache_path, osp.dirname(FNAME_SPICE_JAR)) + spice_local_cache_path = osp.join(cache_path, DNAME_SPICE_LOCAL_CACHE) + + os.makedirs(spice_jar_dpath, exist_ok=True) + os.makedirs(spice_local_cache_path, exist_ok=True) + + spice_zip_url = DATA_URLS["spice_zip"]["url"] + spice_zip_fpath = osp.join(spice_cache_dpath, DATA_URLS["spice_zip"]["fname"]) + + if osp.isfile(spice_zip_fpath): + if verbose >= 1: + pylog.info(f"SPICE ZIP file '{spice_zip_fpath}' is already downloaded.") + else: + if verbose >= 1: + pylog.info(f"Downloading SPICE ZIP file '{spice_zip_fpath}'...") + download_url_to_file(spice_zip_url, spice_zip_fpath, progress=verbose > 0) + + script_path = osp.join(osp.dirname(__file__), "install_spice.sh") + if not osp.isfile(script_path): + raise FileNotFoundError(f"Cannot find script '{osp.basename(script_path)}'.") + + if verbose >= 1: + pylog.info( + f"Downloading JAR sources for SPICE metric into '{spice_jar_dpath}'..." 
+ ) + + command = ["bash", script_path, spice_jar_dpath] + try: + subprocess.check_call( + command, + stdout=None if verbose >= 2 else subprocess.DEVNULL, + stderr=None if verbose >= 2 else subprocess.DEVNULL, + ) + except (CalledProcessError, PermissionError) as err: + pylog.error(err) + + +def _download_fense( + verbose: int = 0, +) -> None: + # Download models files for FENSE metric + if verbose >= 1: + pylog.info("Downloading SBERT and BERT error detector for FENSE metric...") + _ = FENSE(device="cpu") def _get_main_download_args() -> Namespace: @@ -163,13 +236,13 @@ def _get_main_download_args() -> Namespace: parser.add_argument( "--cache_path", type=str, - default="$HOME/.cache", + default=get_default_cache_path(), help="Cache directory path.", ) parser.add_argument( "--tmp_path", type=str, - default="/tmp", + default=get_default_tmp_path(), help="Temporary directory path.", ) parser.add_argument( @@ -215,13 +288,13 @@ def _main_download() -> None: pkg_logger.setLevel(level) download( - args.cache_path, - args.tmp_path, - args.ptb_tokenizer, - args.meteor, - args.spice, - args.fense, - args.verbose, + cache_path=args.cache_path, + tmp_path=args.tmp_path, + ptb_tokenizer=args.ptb_tokenizer, + meteor=args.meteor, + spice=args.spice, + fense=args.fense, + verbose=args.verbose, ) diff --git a/src/aac_metrics/evaluate.py b/src/aac_metrics/evaluate.py index 815ec08..fcfba37 100644 --- a/src/aac_metrics/evaluate.py +++ b/src/aac_metrics/evaluate.py @@ -11,8 +11,17 @@ import yaml -from aac_metrics.functional.evaluate import evaluate, METRICS_SETS +from aac_metrics.functional.evaluate import ( + evaluate, + METRICS_SETS, + DEFAULT_METRICS_SET_NAME, +) from aac_metrics.utils.checks import check_metric_inputs, check_java_path +from aac_metrics.utils.paths import ( + get_default_cache_path, + get_default_java_path, + get_default_tmp_path, +) pylog = logging.getLogger(__name__) @@ -152,27 +161,27 @@ def _get_main_evaluate_args() -> Namespace: parser.add_argument( 
"--metrics_set_name", type=str, - default="default", + default=DEFAULT_METRICS_SET_NAME, choices=tuple(METRICS_SETS.keys()), help=f"The metrics set to compute. Can be one of {tuple(METRICS_SETS.keys())}. defaults to 'default'.", ) parser.add_argument( "--cache_path", type=str, - default="$HOME/.cache", - help="Cache directory path. defaults to '$HOME/.cache'.", + default=get_default_cache_path(), + help=f"Cache directory path. defaults to '{get_default_cache_path()}'.", ) parser.add_argument( "--java_path", type=str, - default="java", - help="Java executable path. defaults to 'java'.", + default=get_default_java_path(), + help=f"Java executable path. defaults to '{get_default_java_path()}'.", ) parser.add_argument( "--tmp_path", type=str, - default="/tmp", - help="Temporary directory path. defaults to '/tmp'.", + default=get_default_tmp_path(), + help=f"Temporary directory path. defaults to '{get_default_tmp_path()}'.", ) parser.add_argument("--verbose", type=int, default=0, help="Verbose level.") diff --git a/src/aac_metrics/functional/bleu.py b/src/aac_metrics/functional/bleu.py index 28ad8f0..5414c28 100644 --- a/src/aac_metrics/functional/bleu.py +++ b/src/aac_metrics/functional/bleu.py @@ -66,6 +66,90 @@ def bleu( ) +def bleu_1( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + return_1_to_n: bool = False, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bleu( + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, + n=1, + option=option, + verbose=verbose, + tokenizer=tokenizer, + return_1_to_n=return_1_to_n, + ) + + +def bleu_2( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + return_1_to_n: bool = 
False, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bleu( + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, + n=2, + option=option, + verbose=verbose, + tokenizer=tokenizer, + return_1_to_n=return_1_to_n, + ) + + +def bleu_3( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + return_1_to_n: bool = False, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bleu( + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, + n=3, + option=option, + verbose=verbose, + tokenizer=tokenizer, + return_1_to_n=return_1_to_n, + ) + + +def bleu_4( + candidates: list[str], + mult_references: list[list[str]], + return_all_scores: bool = True, + option: str = "closest", + verbose: int = 0, + tokenizer: Callable[[str], list[str]] = str.split, + return_1_to_n: bool = False, +) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: + return bleu( + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, + n=4, + option=option, + verbose=verbose, + tokenizer=tokenizer, + return_1_to_n=return_1_to_n, + ) + + def _bleu_update( candidates: list[str], mult_references: list[list[str]], diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index bdb8789..dd0fba8 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -21,6 +21,7 @@ from aac_metrics.functional.spice import spice from aac_metrics.functional.spider import spider from aac_metrics.functional.spider_fl import spider_fl +from aac_metrics.utils.checks import check_metric_inputs from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents @@ -65,6 +66,7 @@ "spider_fl", # includes cider_d, 
spice, spider, fluerr ), } +DEFAULT_METRICS_SET_NAME = "default" def evaluate( @@ -73,10 +75,10 @@ def evaluate( preprocess: bool = True, metrics: Union[ str, Iterable[str], Iterable[Callable[[list, list], tuple]] - ] = "default", - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + ] = DEFAULT_METRICS_SET_NAME, + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: @@ -86,32 +88,34 @@ def evaluate( :param mult_references: The list of list of sentences used as target. :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics.defaults to True. :param metrics: The name of the metric list or the explicit list of metrics to compute. defaults to "default". - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param device: The PyTorch device used to run FENSE and SPIDErFL models. If None, it will try to detect use cuda if available. defaults to "auto". :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores. 
""" + check_metric_inputs(candidates, mult_references) + metrics = _instantiate_metrics_functions( metrics, cache_path, java_path, tmp_path, device, verbose ) if preprocess: + common_kwds = dict( + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + verbose=verbose, + ) candidates = preprocess_mono_sents( candidates, - cache_path, - java_path, - tmp_path, - verbose=verbose, + **common_kwds, ) mult_references = preprocess_mult_sents( mult_references, - cache_path, - java_path, - tmp_path, - verbose=verbose, + **common_kwds, ) outs_corpus = {} @@ -157,44 +161,44 @@ def dcase2023_evaluate( candidates: list[str], mult_references: list[list[str]], preprocess: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: - """Evaluate candidates with multiple references with all Audio Captioning metrics. + """Evaluate candidates with multiple references with the DCASE2023 Audio Captioning metrics. :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics. defaults to True. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. 
defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param device: The PyTorch device used to run FENSE and SPIDErFL models. If None, it will try to detect use cuda if available. defaults to "auto". :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores. """ return evaluate( - candidates, - mult_references, - preprocess, - "dcase2023", - cache_path, - java_path, - tmp_path, - device, - verbose, + candidates=candidates, + mult_references=mult_references, + preprocess=preprocess, + metrics="dcase2023", + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, ) def _instantiate_metrics_functions( metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all", - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> list[Callable]: @@ -212,12 +216,12 @@ def _instantiate_metrics_functions( ) metric_factory = _get_metric_factory_functions( - True, - cache_path, - java_path, - tmp_path, - device, - verbose, + return_all_scores=True, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, ) metrics_inst: list[Callable] = [] @@ -230,13 +234,17 @@ def _instantiate_metrics_functions( def _get_metric_factory_functions( return_all_scores: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, ) -> dict[str, Callable[[list[str], list[list[str]]], Any]]: return { + "bleu": partial( + bleu, + return_all_scores=return_all_scores, + ), "bleu_1": partial( bleu, return_all_scores=return_all_scores, @@ -288,13 +296,13 @@ def 
_get_metric_factory_functions( tmp_path=tmp_path, verbose=verbose, ), - "sbert": partial( + "sbert_sim": partial( sbert_sim, return_all_scores=return_all_scores, device=device, verbose=verbose, ), - "fluerr": partial( + "fluerr": partial( # type: ignore fluerr, return_all_scores=return_all_scores, device=device, diff --git a/src/aac_metrics/functional/fense.py b/src/aac_metrics/functional/fense.py index 7c5a362..68f2871 100644 --- a/src/aac_metrics/functional/fense.py +++ b/src/aac_metrics/functional/fense.py @@ -66,7 +66,7 @@ def fense( :param penalty: The penalty coefficient applied. Higher value means to lower the cos-sim scores when an error is detected. defaults to 0.9. :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. - :param reset_state: If True, reset the state of the PyTorch global generator after the pre-trained model are built. defaults to True. + :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. diff --git a/src/aac_metrics/functional/fluerr.py b/src/aac_metrics/functional/fluerr.py index 4c1c629..65c66ec 100644 --- a/src/aac_metrics/functional/fluerr.py +++ b/src/aac_metrics/functional/fluerr.py @@ -123,7 +123,7 @@ def fluerr( :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the echecker models. defaults to 32. 
- :param reset_state: If True, reset the state of the PyTorch global generator after the pre-trained model are built. defaults to True. + :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. @@ -193,7 +193,7 @@ def _load_echecker_and_tokenizer( echecker = __load_pretrain_echecker(echecker, device, verbose=verbose) if echecker_tokenizer is None: - echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) + echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore echecker = echecker.eval() for p in echecker.parameters(): @@ -231,10 +231,10 @@ def __detect_error_sents( # batch_logits: (bsize, num_classes=6) # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69 probs = logits.sigmoid().transpose(0, 1).cpu().numpy() - probs_dic = dict(zip(ERROR_NAMES, probs)) + probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs)) else: - probs_dic = {name: [] for name in ERROR_NAMES} + dic_lst_probs = {name: [] for name in ERROR_NAMES} for i in range(0, len(sents), batch_size): batch = __infer_preprocess( @@ -251,10 +251,12 @@ def __detect_error_sents( # classes: add_tail, repeat_event, repeat_adv, remove_conj, remove_verb, error probs = batch_logits.sigmoid().cpu().numpy() - for j, name in enumerate(probs_dic.keys()): - probs_dic[name].append(probs[:, j]) + for j, name in enumerate(dic_lst_probs.keys()): + dic_lst_probs[name].append(probs[:, j]) - probs_dic = {name: np.concatenate(probs) for name, probs in probs_dic.items()} + probs_dic = { + name: np.concatenate(probs) for name, probs in dic_lst_probs.items() 
+ } return probs_dic @@ -412,6 +414,12 @@ def __load_pretrain_echecker( pylog.debug(f"Loading echecker model from '{file_path}'.") model_states = torch.load(file_path) + + if verbose >= 2: + pylog.debug( + f"Loading echecker model type '{model_states['model_type']}' with '{model_states['num_classes']}' classes." + ) + echecker = BERTFlatClassifier( model_type=model_states["model_type"], num_classes=model_states["num_classes"], diff --git a/src/aac_metrics/functional/meteor.py b/src/aac_metrics/functional/meteor.py index da0dd3d..79df12a 100644 --- a/src/aac_metrics/functional/meteor.py +++ b/src/aac_metrics/functional/meteor.py @@ -15,12 +15,14 @@ from torch import Tensor from aac_metrics.utils.checks import check_java_path +from aac_metrics.utils.paths import _get_cache_path, _get_java_path pylog = logging.getLogger(__name__) -FNAME_METEOR_JAR = osp.join("aac-metrics", "meteor", "meteor-1.5.jar") +DNAME_METEOR_CACHE = osp.join("aac-metrics", "meteor") +FNAME_METEOR_JAR = osp.join(DNAME_METEOR_CACHE, "meteor-1.5.jar") SUPPORTED_LANGUAGES = ("en", "cz", "de", "es", "fr") @@ -28,8 +30,8 @@ def meteor( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", + cache_path: str = ..., + java_path: str = ..., java_max_memory: str = "2G", language: str = "en", verbose: int = 0, @@ -44,18 +46,19 @@ def meteor( :param return_all_scores: If True, returns a tuple containing the globals and locals scores. Otherwise returns a scalar tensor containing the main global score. defaults to True. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. 
defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param java_max_memory: The maximal java memory used. defaults to "2G". - :param language: The language used for stem, synonym and paraphrase matching. defaults to "en". + :param language: The language used for stem, synonym and paraphrase matching. + Can be one of ("en", "cz", "de", "es", "fr"). + defaults to "en". :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ - cache_path = osp.expandvars(cache_path) - java_path = osp.expandvars(java_path) + cache_path = _get_cache_path(cache_path) + java_path = _get_java_path(java_path) meteor_jar_fpath = osp.join(cache_path, FNAME_METEOR_JAR) - language = "en" # supported: en cz de es fr if __debug__: if not osp.isfile(meteor_jar_fpath): diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py index a86d15e..98ae01c 100644 --- a/src/aac_metrics/functional/mult_cands.py +++ b/src/aac_metrics/functional/mult_cands.py @@ -76,20 +76,20 @@ def mult_cands_metric( if selection == "max": indexes = all_sents_scores[metric_out_name].argmax(dim=0).unsqueeze(dim=0) outs_sents = { - k: scores.gather(0, indexes).squeeze(dim=0) + f"{k}_{selection}": scores.gather(0, indexes).squeeze(dim=0) for k, scores in all_sents_scores.items() } elif selection == "min": indexes = all_sents_scores[metric_out_name].argmin(dim=0).unsqueeze(dim=0) outs_sents = { - k: scores.gather(0, indexes).squeeze(dim=0) + f"{k}_{selection}": scores.gather(0, indexes).squeeze(dim=0) for k, scores in all_sents_scores.items() } elif selection == "mean": selected_scores = all_sents_scores[metric_out_name].mean(dim=0) - outs_sents = {metric_out_name: selected_scores} + outs_sents = {f"{metric_out_name}_{selection}": selected_scores} else: raise ValueError( diff --git a/src/aac_metrics/functional/rouge_l.py b/src/aac_metrics/functional/rouge_l.py 
index 0fad55e..66aa10e 100644 --- a/src/aac_metrics/functional/rouge_l.py +++ b/src/aac_metrics/functional/rouge_l.py @@ -144,6 +144,7 @@ def __my_lcs(string: list[str], sub: list[str]) -> int: sub, string = string, sub lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] + # lengths shape: (len(string)+1, len(sub)+1) for j in range(1, len(sub) + 1): for i in range(1, len(string) + 1): diff --git a/src/aac_metrics/functional/sbert_sim.py b/src/aac_metrics/functional/sbert_sim.py index fc7e5df..82c8ccc 100644 --- a/src/aac_metrics/functional/sbert_sim.py +++ b/src/aac_metrics/functional/sbert_sim.py @@ -42,7 +42,7 @@ def sbert_sim( :param sbert_model: The sentence BERT model used to extract sentence embeddings for cosine-similarity. defaults to "paraphrase-TinyBERT-L6-v2". :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the sBERT models. defaults to 32. - :param reset_state: If True, reset the state of the PyTorch global generator after the pre-trained model are built. defaults to True. + :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. 
""" diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index a16aefd..d6b6307 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -21,22 +21,28 @@ from torch import Tensor from aac_metrics.utils.checks import check_java_path +from aac_metrics.utils.paths import ( + _get_cache_path, + _get_java_path, + _get_tmp_path, +) pylog = logging.getLogger(__name__) -DNAME_SPICE_CACHE = osp.join("aac-metrics", "spice", "cache") -FNAME_SPICE_JAR = osp.join("aac-metrics", "spice", "spice-1.0.jar") +DNAME_SPICE_CACHE = osp.join("aac-metrics", "spice") +DNAME_SPICE_LOCAL_CACHE = osp.join(DNAME_SPICE_CACHE, "cache") +FNAME_SPICE_JAR = osp.join(DNAME_SPICE_CACHE, "spice-1.0.jar") def spice( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -52,9 +58,9 @@ def spice( :param return_all_scores: If True, returns a tuple containing the globals and locals scores. Otherwise returns a scalar tensor containing the main global score. defaults to True. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. 
:param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. defaults to None. @@ -70,9 +76,9 @@ def spice( :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ - cache_path = osp.expandvars(cache_path) - java_path = osp.expandvars(java_path) - tmp_path = osp.expandvars(tmp_path) + cache_path = _get_cache_path(cache_path) + java_path = _get_java_path(java_path) + tmp_path = _get_tmp_path(tmp_path) spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR) @@ -94,7 +100,7 @@ def spice( if separate_cache_dir: spice_cache = tempfile.mkdtemp(dir=tmp_path) else: - spice_cache = osp.join(cache_path, DNAME_SPICE_CACHE) + spice_cache = osp.join(cache_path, DNAME_SPICE_LOCAL_CACHE) del cache_path if verbose >= 2: diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py index 085cf0b..01d2f5b 100644 --- a/src/aac_metrics/functional/spider.py +++ b/src/aac_metrics/functional/spider.py @@ -19,9 +19,9 @@ def spider( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -41,9 +41,9 @@ def spider( :param tokenizer: The fast tokenizer used to split sentences into words. defaults to str.split. :param return_tfidf: If True, returns the list of dictionaries containing the tf-idf scores of n-grams in the sents_score output. defaults to False. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. 
defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param java_max_memory: The maximal java memory used. defaults to "8G". :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. diff --git a/src/aac_metrics/functional/spider_fl.py b/src/aac_metrics/functional/spider_fl.py index 77c385f..ed9e007 100644 --- a/src/aac_metrics/functional/spider_fl.py +++ b/src/aac_metrics/functional/spider_fl.py @@ -35,9 +35,9 @@ def spider_fl( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -67,9 +67,9 @@ def spider_fl( :param tokenizer: The fast tokenizer used to split sentences into words. defaults to str.split. :param return_tfidf: If True, returns the list of dictionaries containing the tf-idf scores of n-grams in the sents_score output. defaults to False. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. 
+ :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param java_max_memory: The maximal java memory used. defaults to "8G". :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. @@ -83,7 +83,7 @@ def spider_fl( :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. :param device: The PyTorch device used to run FENSE models. If "auto", it will use cuda if available. defaults to "auto". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. - :param reset_state: If True, reset the state of the PyTorch global generator after the pre-trained model are built. defaults to True. + :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. :param penalty: The penalty coefficient applied. Higher value means to lower the cos-sim scores when an error is detected. defaults to 0.9. :param verbose: The verbose level. defaults to 0. 
diff --git a/src/aac_metrics/functional/spider_max.py b/src/aac_metrics/functional/spider_max.py index cfa5b29..25b4d21 100644 --- a/src/aac_metrics/functional/spider_max.py +++ b/src/aac_metrics/functional/spider_max.py @@ -22,9 +22,9 @@ def spider_max( tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, @@ -32,7 +32,7 @@ def spider_max( ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: """SPIDEr-max function. - - Paper: https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf + - Paper: https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Labbe_46.pdf Compute the maximal SPIDEr score accross multiple candidates. @@ -41,16 +41,16 @@ def spider_max( :param return_all_scores: If True, returns a tuple containing the globals and locals scores. Otherwise returns a scalar tensor containing the main global score. defaults to True. - :param return_all_cands_scores: If True, returns all multiple candidates scores in sents_scores outputs as tensor of shape (n_audoi, n_cands_per_audio). + :param return_all_cands_scores: If True, returns all multiple candidates scores in sents_scores outputs as tensor of shape (n_audio, n_cands_per_audio). defaults to False. :param n: Maximal number of n-grams taken into account. defaults to 4. :param sigma: Standard deviation parameter used for gaussian penalty. defaults to 6.0. :param tokenizer: The fast tokenizer used to split sentences into words. defaults to str.split. :param return_tfidf: If True, returns the list of dictionaries containing the tf-idf scores of n-grams in the sents_score output. defaults to False. - :param cache_path: The path to the external code directory. 
defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param java_max_memory: The maximal java memory used. defaults to "8G". :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. diff --git a/src/aac_metrics/info.py b/src/aac_metrics/info.py index 15eb915..b852b2f 100644 --- a/src/aac_metrics/info.py +++ b/src/aac_metrics/info.py @@ -4,6 +4,7 @@ import platform import sys +from pathlib import Path from typing import Dict import torch @@ -11,6 +12,17 @@ import aac_metrics +from aac_metrics.utils.paths import ( + get_default_cache_path, + get_default_java_path, + get_default_tmp_path, +) + + +def get_package_repository_path() -> str: + """Return the absolute path where the source code of this package is installed.""" + return str(Path(__file__).parent.parent.parent) + def get_install_info() -> Dict[str, str]: """Return a dictionary containing the version python, the os name, the architecture name and the versions of the following packages: aac_datasets, torch, torchaudio.""" @@ -20,11 +32,15 @@ def get_install_info() -> Dict[str, str]: "os": platform.system(), "architecture": platform.architecture()[0], "torch": str(torch.__version__), + "package_path": get_package_repository_path(), + "cache_path": get_default_cache_path(), + "java_path": get_default_java_path(), + "tmp_path": get_default_tmp_path(), } def print_install_info() -> None: - """Print packages versions and 
architecture info.""" + """Show main packages versions.""" install_info = get_install_info() print(yaml.dump(install_info, sort_keys=False)) diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py index da71f2a..2054fb1 100644 --- a/src/aac_metrics/utils/checks.py +++ b/src/aac_metrics/utils/checks.py @@ -112,8 +112,3 @@ def _check_java_version(version: str, min_major: int, max_major: int) -> bool: major_version = minor_version return min_major <= major_version <= max_major - - -@cache -def _warn_once(msg: str) -> None: - pylog.warning(msg) diff --git a/src/aac_metrics/utils/imports.py b/src/aac_metrics/utils/imports.py index b756eb3..cd8bc28 100644 --- a/src/aac_metrics/utils/imports.py +++ b/src/aac_metrics/utils/imports.py @@ -7,6 +7,7 @@ @cache def _package_is_available(package_name: str) -> bool: + """Returns True if package is installed.""" try: return find_spec(package_name) is not None except AttributeError: diff --git a/src/aac_metrics/utils/paths.py b/src/aac_metrics/utils/paths.py new file mode 100644 index 0000000..04bf7c4 --- /dev/null +++ b/src/aac_metrics/utils/paths.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +import os +import os.path as osp +import tempfile + +from typing import Optional, Union + + +pylog = logging.getLogger(__name__) + + +__DEFAULT_PATHS: dict[str, dict[str, Optional[str]]] = { + "cache": { + "user": None, + "env": "AAC_METRICS_CACHE_PATH", + "package": osp.expanduser(osp.join("~", ".cache")), + }, + "java": { + "user": None, + "env": "AAC_METRICS_JAVA_PATH", + "package": "java", + }, + "tmp": { + "user": None, + "env": "AAC_METRICS_TMP_PATH", + "package": tempfile.gettempdir(), + }, +} + + +# Public functions +def get_default_cache_path() -> str: + """Returns the default cache directory path. + + If :func:`~aac_metrics.utils.path.set_default_cache_path` has been used before with a string argument, it will return the value given to this function. 
+ Else if the environment variable AAC_METRICS_CACHE_PATH has been set to a string, it will return its value. + Else it will be equal to "~/.cache" by default. + """ + return __get_default_path("cache") + + +def get_default_java_path() -> str: + """Returns the default java executable path. + + If :func:`~aac_metrics.utils.paths.set_default_java_path` has been used before with a string argument, it will return the value given to this function. + Else if the environment variable AAC_METRICS_JAVA_PATH has been set to a string, it will return its value. + Else it will be equal to "java" by default. + """ + return __get_default_path("java") + + +def get_default_tmp_path() -> str: + """Returns the default temporary directory path. + + If :func:`~aac_metrics.utils.paths.set_default_tmp_path` has been used before with a string argument, it will return the value given to this function. + Else if the environment variable AAC_METRICS_TMP_PATH has been set to a string, it will return its value. + Else it will be equal to the value returned by :func:`~tempfile.gettempdir()` by default. + """ + return __get_default_path("tmp") + + +def set_default_cache_path(cache_path: Optional[str]) -> None: + """Override default cache directory path.""" + __set_default_path("cache", cache_path) + + +def set_default_java_path(java_path: Optional[str]) -> None: + """Override default java executable path.""" + __set_default_path("java", java_path) + + +def set_default_tmp_path(tmp_path: Optional[str]) -> None: + """Override default temporary directory path.""" + __set_default_path("tmp", tmp_path) + + +# Private functions +def _get_cache_path(cache_path: Union[str, None] = ...) -> str: + return __get_path("cache", cache_path) + + +def _get_java_path(java_path: Union[str, None] = ...) -> str: + return __get_path("java", java_path) + + +def _get_tmp_path(tmp_path: Union[str, None] = ...) 
-> str: + return __get_path("tmp", tmp_path) + + +def __get_default_path(path_name: str) -> str: + paths = __DEFAULT_PATHS[path_name] + + for name, path_or_var in paths.items(): + if path_or_var is None: + continue + + if name.startswith("env"): + path = os.getenv(path_or_var, None) + else: + path = path_or_var + + if path is not None: + path = __process_path(path) + return path + + pylog.error(f"Paths values: {paths}") + raise RuntimeError( + f"Invalid default path for {path_name=}. (all default paths are None)" + ) + + +def __set_default_path( + path_name: str, + path: Optional[str], +) -> None: + if path is not ... and path is not None: + path = __process_path(path) + __DEFAULT_PATHS[path_name]["user"] = path + + +def __get_path(path_name: str, path: Union[str, None] = ...) -> str: + if path is ... or path is None: + return __get_default_path(path_name) + else: + path = __process_path(path) + return path + + +def __process_path(path: str) -> str: + path = osp.expanduser(path) + path = osp.expandvars(path) + return path diff --git a/src/aac_metrics/utils/tokenization.py b/src/aac_metrics/utils/tokenization.py index cfb4116..183145b 100644 --- a/src/aac_metrics/utils/tokenization.py +++ b/src/aac_metrics/utils/tokenization.py @@ -12,6 +12,11 @@ from aac_metrics.utils.checks import check_java_path from aac_metrics.utils.collections import flat_list, unflat_list +from aac_metrics.utils.paths import ( + _get_cache_path, + _get_java_path, + _get_tmp_path, +) pylog = logging.getLogger(__name__) @@ -46,19 +51,21 @@ def ptb_tokenize_batch( sentences: Iterable[str], audio_ids: Optional[Iterable[Hashable]] = None, - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., punctuations: Iterable[str] = PTB_PUNCTUATIONS, + normalize_apostrophe: bool = False, verbose: int = 0, ) -> list[list[str]]: """Use PTB Tokenizer to process sentences. 
Should be used only with all the sentences of a subset due to slow computation. :param sentences: The sentences to tokenize. :param audio_ids: The optional audio names. None will use the audio index as name. defaults to None. - :param cache_path: The path to the external directory containing the JAR program. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: The path to a temporary directory. defaults to "/tmp". + :param cache_path: The path to the external directory containing the JAR program. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: The path to a temporary directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. + :param normalize_apostrophe: If True, add apostrophes for French language. defaults to False. :param verbose: The verbose level. defaults to 0. :returns: The sentences tokenized as list[list[str]]. 
""" @@ -66,9 +73,9 @@ def ptb_tokenize_batch( if len(sentences) == 0: return [] - cache_path = osp.expandvars(cache_path) - java_path = osp.expandvars(java_path) - tmp_path = osp.expandvars(tmp_path) + cache_path = _get_cache_path(cache_path) + java_path = _get_java_path(java_path) + tmp_path = _get_tmp_path(tmp_path) # Based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30 @@ -105,7 +112,7 @@ def ptb_tokenize_batch( ] # ====================================================== - # prepare data for PTB AACTokenizer + # Prepare data for PTB AACTokenizer # ====================================================== if audio_ids is None: audio_ids = list(range(len(sentences))) @@ -118,9 +125,18 @@ def ptb_tokenize_batch( ) sentences = "\n".join(sentences) + if normalize_apostrophe: + replaces = { + " s ": " s'", + "'": "' ", + "' ": "' ", + " '": "'", + } + for old, new in replaces.items(): + sentences = sentences.replace(old, new) # ====================================================== - # save sentences to temporary file + # Save sentences to temporary file # ====================================================== tmp_file = tempfile.NamedTemporaryFile( delete=False, @@ -132,7 +148,7 @@ def ptb_tokenize_batch( tmp_file.close() # ====================================================== - # tokenize sentence + # Tokenize sentence # ====================================================== cmd.append(osp.basename(tmp_file.name)) p_tokenizer = subprocess.Popen( @@ -148,7 +164,7 @@ def ptb_tokenize_batch( os.remove(tmp_file.name) # ====================================================== - # create dictionary for tokenized captions + # Create dictionary for tokenized captions # ====================================================== outs: Any = [None for _ in range(len(lines))] if len(audio_ids) != len(lines): @@ -173,10 +189,11 @@ def ptb_tokenize_batch( def 
preprocess_mono_sents( sentences: list[str], - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., punctuations: Iterable[str] = PTB_PUNCTUATIONS, + normalize_apostrophe: bool = False, verbose: int = 0, ) -> list[str]: """Tokenize sentences using PTB Tokenizer then merge them by space. @@ -187,13 +204,22 @@ def preprocess_mono_sents( If you want to process multiple sentences (list[list[str]]), use `preprocess_mult_sents` instead. :param sentences: The list of sentences to process. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. + :param normalize_apostrophe: If True, add apostrophes for French language. defaults to False. + :param verbose: The verbose level. defaults to 0. :returns: The sentences processed by the tokenizer. 
""" tok_sents = ptb_tokenize_batch( - sentences, None, cache_path, java_path, tmp_path, punctuations, verbose + sentences=sentences, + audio_ids=None, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + punctuations=punctuations, + normalize_apostrophe=normalize_apostrophe, + verbose=verbose, ) sentences = [" ".join(sent) for sent in tok_sents] return sentences @@ -201,30 +227,34 @@ def preprocess_mono_sents( def preprocess_mult_sents( mult_sentences: list[list[str]], - cache_path: str = "$HOME/.cache", - java_path: str = "java", - tmp_path: str = "/tmp", + cache_path: str = ..., + java_path: str = ..., + tmp_path: str = ..., punctuations: Iterable[str] = PTB_PUNCTUATIONS, + normalize_apostrophe: bool = False, verbose: int = 0, ) -> list[list[str]]: """Tokenize multiple sentences using PTB Tokenizer with only one call then merge them by space. :param mult_sentences: The list of list of sentences to process. - :param cache_path: The path to the external code directory. defaults to "$HOME/.cache". - :param java_path: The path to the java executable. defaults to "java". - :param tmp_path: Temporary directory path. defaults to "/tmp". + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. + :param normalize_apostrophe: If True, add apostrophes for French language. defaults to False. + :param verbose: The verbose level. defaults to 0. :returns: The multiple sentences processed by the tokenizer. 
""" # Flat list flatten_sents, sizes = flat_list(mult_sentences) flatten_sents = preprocess_mono_sents( - flatten_sents, - cache_path, - java_path, - tmp_path, - punctuations, - verbose, + sentences=flatten_sents, + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + punctuations=punctuations, + normalize_apostrophe=normalize_apostrophe, + verbose=verbose, ) mult_sentences = unflat_list(flatten_sents, sizes) return mult_sentences diff --git a/tests/test_compare_cet.py b/tests/test_compare_cet.py index 8ce62b5..42eca79 100644 --- a/tests/test_compare_cet.py +++ b/tests/test_compare_cet.py @@ -2,9 +2,11 @@ # -*- coding: utf-8 -*- import importlib +import os import os.path as osp import subprocess import sys +import tempfile import unittest from pathlib import Path @@ -23,6 +25,8 @@ class TestCompareCaptionEvaluationTools(TestCase): # Set Up methods @classmethod def setUpClass(cls) -> None: + if os.name == "nt": + return None cls.evaluate_metrics_from_lists = cls._import_cet_eval_func() @classmethod @@ -55,8 +59,8 @@ def _import_cet_eval_func( sys.path.append(cet_path) # Override cache and tmp dir to avoid outputs in source code. spice_module = importlib.import_module("coco_caption.pycocoevalcap.spice.spice") - spice_module.CACHE_DIR = "/tmp" # type: ignore - spice_module.TEMP_DIR = "/tmp" # type: ignore + spice_module.CACHE_DIR = tempfile.gettempdir() # type: ignore + spice_module.TEMP_DIR = tempfile.gettempdir() # type: ignore eval_metrics_module = importlib.import_module("eval_metrics") evaluate_metrics_from_lists = eval_metrics_module.evaluate_metrics_from_lists return evaluate_metrics_from_lists @@ -97,6 +101,9 @@ def _get_example_0(self) -> tuple[list[str], list[list[str]]]: return cands, mrefs def _test_with_example(self, cands: list[str], mrefs: list[list[str]]) -> None: + if os.name == "nt": + return None + corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020") ( cet_global_scores,