From e3c161d3a5b9f29afe8bdab0dd0bb1a3067f0ca3 Mon Sep 17 00:00:00 2001 From: Labbeti Date: Tue, 10 Oct 2023 11:35:28 +0200 Subject: [PATCH] Version 0.4.6 --- .github/workflows/python-package-pip.yaml | 2 +- CHANGELOG.md | 16 + CITATION.cff | 4 +- MANIFEST.in | 1 - README.md | 38 ++- docs/installation.rst | 1 + pyproject.toml | 2 +- src/aac_metrics/__init__.py | 28 +- src/aac_metrics/__main__.py | 2 +- src/aac_metrics/classes/evaluate.py | 47 +-- src/aac_metrics/download.py | 192 +++++++++--- src/aac_metrics/eval.py | 13 +- src/aac_metrics/evaluate.py | 237 --------------- src/aac_metrics/functional/evaluate.py | 39 ++- src/aac_metrics/functional/spice.py | 352 +++++++++++++++------- src/aac_metrics/functional/spider.py | 18 +- src/aac_metrics/info.py | 11 + src/aac_metrics/install_spice.sh | 55 ---- src/aac_metrics/utils/checks.py | 32 +- src/aac_metrics/utils/tokenization.py | 44 +-- tests/test_compare_cet.py | 25 +- tests/test_compare_fense.py | 49 +-- tests/test_doc_examples.py | 20 +- 23 files changed, 603 insertions(+), 625 deletions(-) delete mode 100644 src/aac_metrics/evaluate.py delete mode 100644 src/aac_metrics/install_spice.sh diff --git a/.github/workflows/python-package-pip.yaml b/.github/workflows/python-package-pip.yaml index 1503230..b8aaf46 100644 --- a/.github/workflows/python-package-pip.yaml +++ b/.github/workflows/python-package-pip.yaml @@ -50,7 +50,7 @@ jobs: - name: Install package shell: bash # note: ${GITHUB_REF##*/} gives the branch name - # note 2: dev is not the branch here, but the dev dependencies + # note 2: dev is NOT the branch here, but the dev dependencies run: | python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 58ca8a5..ccdc785 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,22 @@ All notable changes to this project will be documented in this file. +## [0.4.6] 2023-10-10 +### Added +- Argument `clean_archives` for `SPICE` download. + +### Changed +- Check if newline character is in the sentences before ptb tokenization. ([#6](https://github.com/Labbeti/aac-metrics/issues/6)) +- `SPICE` no longer requires bash script files for installation. + +### Fixed +- Maximal version of `transformers` dependancy set to 4.31.0 to avoid error with `FENSE` and `FluErr` metrics. +- `SPICE` crash message and error output files. +- Default value for `Evaluate` `metrics` argument. + +### Deleted +- Remove now useless `use_shell` option for download. + ## [0.4.5] 2023-09-12 ### Added - Argument `use_shell` for `METEOR` and `SPICE` metrics and `download` function to fix Windows-OS specific error. diff --git a/CITATION.cff b/CITATION.cff index 7cf6afe..872db92 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.4.5 -date-released: '2023-09-12' +version: 0.4.6 +date-released: '2023-10-10' diff --git a/MANIFEST.in b/MANIFEST.in index ccb7020..7542274 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,5 +3,4 @@ recursive-include src *.py global-exclude *.pyc global-exclude __pycache__ -include src/aac_metrics/install_spice.sh recursive-include data *.csv diff --git a/README.md b/README.md index 11094e1..1d32314 100644 --- a/README.md +++ b/README.md @@ -103,10 +103,10 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci | Metric | Python Class | Origin | Range | Short description | |:---|:---|:---|:---|:---| | BLEU [[1]](#bleu) | `BLEU` | machine translation | [0, 1] | Precision of n-grams | -| ROUGE-L [[2]](#rouge-l) | `ROUGEL` | machine translation | [0, 1] | FScore of the longest common subsequence | +| ROUGE-L [[2]](#rouge-l) | `ROUGEL` | text summarization | [0, 1] | FScore of the longest common subsequence | | METEOR [[3]](#meteor) | `METEOR` | machine translation | [0, 1] | Cosine-similarity of frequencies with synonyms matching | | CIDEr-D [[4]](#cider) | `CIDErD` | image captioning | [0, 10] | Cosine-similarity of TF-IDF computed on n-grams | -| SPICE [[5]](#spice) | `SPICE` | image captioning | [0, 1] | FScore of semantic graph | +| SPICE [[5]](#spice) | `SPICE` | image captioning | [0, 1] | FScore of a semantic graph | | SPIDEr [[6]](#spider) | `SPIDEr` | image captioning | [0, 5.5] | Mean of CIDEr-D and SPICE | ### AAC-specific metrics @@ -114,16 +114,20 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci |:---|:---|:---|:---|:---| | SPIDEr-max [[7]](#spider-max) | `SPIDErMax` | audio captioning | [0, 5.5] | Max of SPIDEr scores for multiples candidates | | SBERT-sim [[8]](#spider-max) | `SBERTSim` | audio captioning | [-1, 1] | Cosine-similarity of **Sentence-BERT embeddings** | -| Fluency Error [[8]](#spider-max) | `FluErr` | audio captioning | [0, 1] | Use a pretrained model to detect fluency errors in sentences | -| FENSE [[8]](#fense) | `FENSE` | audio captioning | [-1, 1] | Combines SBERT-sim and Fluency Error | -| SPIDEr-FL [[9]](#spider-fl) | `SPIDErFL` | audio captioning | [0, 5.5] | Combines SPIDEr and Fluency Error | +| Fluency error rate [[8]](#spider-max) | `FluErr` | audio captioning | [0, 1] | Detect fluency errors in sentences with a pretrained model | +| FENSE [[8]](#fense) | `FENSE` | audio captioning | [-1, 1] | Combines SBERT-sim and Fluency Error rate | +| SPIDEr-FL [[9]](#spider-fl) | `SPIDErFL` | audio captioning | [0, 5.5] | Combines SPIDEr and Fluency Error rate | + +### AAC metrics not implemented +- CB-Score [[10]](#cb-score) +- SPICE+ [[11]](#spice-plus) +- ACES [[12]](#aces) (can be found here: https://github.com/GlJS/ACES) ## Requirements This package has been developped for Ubuntu 20.04, and it is expected to work on most Linux distributions. Windows is not officially supported. ### Python packages - The pip requirements are automatically installed when using `pip install` on this repository. ``` torch >= 1.10.1 @@ -141,11 +145,14 @@ Most of these functions can specify a java executable path with `java_path` argu - `unzip` command to extract SPICE zipped files. ## Additional notes -### CIDEr or CIDEr-D ? +### CIDEr or CIDEr-D? The CIDEr metric differs from CIDEr-D because it applies a stemmer to each word before computing the n-grams of the sentences. In AAC, only the CIDEr-D is reported and used for SPIDEr in [caption-evaluation-tools](https://github.com/audio-captioning/caption-evaluation-tools), but some papers called it "CIDEr". -### Does metrics work on multi-GPU ? -No. Most of these metrics use numpy or external java programs to run, which prevents multi-GPU testing for now. +### Do metrics work on multi-GPU? +No. Most of these metrics use numpy or external java programs to run, which prevents multi-GPU testing in parallel. + +### Do metrics work on Windows/Mac OS? +Maybe. Most of the metrics only need python to run, which can be done on Windows. However, you might expect errors with METEOR metric, SPICE-based metrics and PTB tokenizer, since they requires an external java program to run. ## SPIDEr-max metric SPIDEr-max [[7]](#spider-max) is a metric based on SPIDEr that takes into account multiple candidates for the same audio. It computes the maximum of the SPIDEr scores for each candidate to balance the high sensitivity to the frequency of the words generated by the model. For more detail, please see the [documentation about SPIDEr-max](https://aac-metrics.readthedocs.io/en/stable/spider_max.html). @@ -197,6 +204,15 @@ arXiv: 1612.00370. [Online]. Available: http://arxiv.org/abs/1612.00370 #### SPIDEr-FL [9] DCASE website task6a description: https://dcase.community/challenge2023/task-automated-audio-captioning#evaluation +#### CB-score +[11] I. Martín-Morató, M. Harju, and A. Mesaros, “A Summarization Approach to Evaluating Audio Captioning,” Nov. 2022. [Online]. Available: https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Martin-Morato_35.pdf + +#### SPICE-plus +[10] F. Gontier, R. Serizel, and C. Cerisara, “SPICE+: Evaluation of Automatic Audio Captioning Systems with Pre-Trained Language Models,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10097021. + +#### ACES +[12] G. Wijngaard, E. Formisano, B. L. Giordano, M. Dumontier, “ACES: Evaluating Automated Audio Captioning Models on the Semantics of Sounds”, in EUSIPCO 2023, 2023. + ## Citation If you use **SPIDEr-max**, you can cite the following paper using BibTex : ``` @@ -217,10 +233,10 @@ If you use this software, please consider cite it as below : Labbe_aac-metrics_2023, author = {Labbé, Etienne}, license = {MIT}, - month = {9}, + month = {10}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.4.5}, + version = {0.4.6}, year = {2023}, } ``` diff --git a/docs/installation.rst b/docs/installation.rst index 0070cff..d618409 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -26,3 +26,4 @@ The python requirements are automatically installed when using pip on this repos pyyaml>=6.0 tqdm>=4.64.0 sentence-transformers>=2.2.2 + transformers<4.31.0 diff --git a/pyproject.toml b/pyproject.toml index 9964a66..5ce6ac1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "pyyaml>=6.0", "tqdm>=4.64.0", "sentence-transformers>=2.2.2", + "transformers<4.31.0", ] dynamic = ["version"] @@ -50,7 +51,6 @@ dev = [ "scikit-image==0.19.2", "matplotlib==3.5.2", "torchmetrics>=0.10", - "transformers<4.31.0", ] [tool.setuptools.packages.find] diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index 99fccf0..a4aa875 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -1,28 +1,32 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -"""Audio Captioning metrics package. -""" +"""Metrics for evaluating Automated Audio Captioning systems, designed for PyTorch. """ + -__name__ = "aac-metrics" __author__ = "Etienne Labbé (Labbeti)" __author_email__ = "labbeti.pub@gmail.com" __license__ = "MIT" __maintainer__ = "Etienne Labbé (Labbeti)" +__name__ = "aac-metrics" __status__ = "Development" -__version__ = "0.4.5" +__version__ = "0.4.6" from .classes.base import AACMetric from .classes.bleu import BLEU from .classes.cider_d import CIDErD -from .classes.evaluate import DCASE2023Evaluate, _get_metric_factory_classes +from .classes.evaluate import Evaluate, DCASE2023Evaluate, _get_metric_factory_classes +from .classes.fluerr import FluErr from .classes.fense import FENSE from .classes.meteor import METEOR from .classes.rouge_l import ROUGEL +from .classes.sbert_sim import SBERTSim from .classes.spice import SPICE from .classes.spider import SPIDEr -from .functional.evaluate import dcase2023_evaluate, evaluate +from .classes.spider_fl import SPIDErFL +from .classes.spider_max import SPIDErMax +from .functional.evaluate import evaluate, dcase2023_evaluate from .utils.paths import ( get_default_cache_path, get_default_java_path, @@ -34,16 +38,22 @@ __all__ = [ + "AACMetric", "BLEU", "CIDErD", + "Evaluate", "DCASE2023Evaluate", "FENSE", + "FluErr", "METEOR", "ROUGEL", + "SBERTSim", "SPICE", "SPIDEr", - "dcase2023_evaluate", + "SPIDErFL", + "SPIDErMax", "evaluate", + "dcase2023_evaluate", "get_default_cache_path", "get_default_java_path", "get_default_tmp_path", @@ -58,8 +68,8 @@ def load_metric(name: str, **kwargs) -> AACMetric: """Load a metric class by name. :param name: The name of the metric. - Must be one of ("bleu_1", "bleu_2", "bleu_3", "bleu_4", "meteor", "rouge_l", "cider_d", "spice", "spider", "fense"). - :param **kwargs: The keyword optional arguments passed to the metric. + Can be one of ("bleu_1", "bleu_2", "bleu_3", "bleu_4", "meteor", "rouge_l", "cider_d", "spice", "spider", "fense"). + :param **kwargs: The keyword optional arguments passed to the metric factory. :returns: The Metric object built. """ name = name.lower().strip() diff --git a/src/aac_metrics/__main__.py b/src/aac_metrics/__main__.py index 9bc5e95..03968fd 100644 --- a/src/aac_metrics/__main__.py +++ b/src/aac_metrics/__main__.py @@ -6,7 +6,7 @@ def _print_usage() -> None: print( "Command line usage :\n" "- Download models and external code : aac-metrics-download ...\n" - "- Print scores from candidate and references file : aac-metrics-evaluate -i [FILEPATH]\n" + "- Print scores from candidate and references file : aac-metrics-eval -i [FILEPATH]\n" "- Print package version : aac-metrics-info\n" "- Show this usage page : aac-metrics\n" ) diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py index d633c57..3df9836 100644 --- a/src/aac_metrics/classes/evaluate.py +++ b/src/aac_metrics/classes/evaluate.py @@ -5,7 +5,7 @@ import pickle import zlib -from typing import Callable, Iterable, Union +from typing import Any, Callable, Iterable, Union import torch @@ -22,7 +22,11 @@ from aac_metrics.classes.spice import SPICE from aac_metrics.classes.spider import SPIDEr from aac_metrics.classes.spider_fl import SPIDErFL -from aac_metrics.functional.evaluate import METRICS_SETS, evaluate +from aac_metrics.functional.evaluate import ( + DEFAULT_METRICS_SET_NAME, + METRICS_SETS, + evaluate, +) pylog = logging.getLogger(__name__) @@ -41,7 +45,9 @@ class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Ten def __init__( self, preprocess: bool = True, - metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac", + metrics: Union[ + str, Iterable[str], Iterable[AACMetric] + ] = DEFAULT_METRICS_SET_NAME, cache_path: str = ..., java_path: str = ..., tmp_path: str = ..., @@ -171,74 +177,79 @@ def _get_metric_factory_classes( tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, + init_kwds: dict[str, Any] = ..., ) -> dict[str, Callable[[], AACMetric]]: - return { + if init_kwds is ...: + init_kwds = {} + + init_kwds = init_kwds | dict(return_all_scores=return_all_scores) + + factory = { "bleu": lambda: BLEU( - return_all_scores=return_all_scores, + **init_kwds, ), "bleu_1": lambda: BLEU( - return_all_scores=return_all_scores, n=1, + **init_kwds, ), "bleu_2": lambda: BLEU( - return_all_scores=return_all_scores, n=2, ), "bleu_3": lambda: BLEU( - return_all_scores=return_all_scores, n=3, + **init_kwds, ), "bleu_4": lambda: BLEU( - return_all_scores=return_all_scores, n=4, + **init_kwds, ), "meteor": lambda: METEOR( - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, verbose=verbose, + **init_kwds, ), "rouge_l": lambda: ROUGEL( - return_all_scores=return_all_scores, + **init_kwds, ), "cider_d": lambda: CIDErD( - return_all_scores=return_all_scores, + **init_kwds, ), "spice": lambda: SPICE( - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, verbose=verbose, + **init_kwds, ), "spider": lambda: SPIDEr( - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, verbose=verbose, + **init_kwds, ), "sbert_sim": lambda: SBERTSim( - return_all_scores=return_all_scores, device=device, verbose=verbose, + **init_kwds, ), "fluerr": lambda: FluErr( - return_all_scores=return_all_scores, device=device, verbose=verbose, ), "fense": lambda: FENSE( - return_all_scores=return_all_scores, device=device, verbose=verbose, + **init_kwds, ), "spider_fl": lambda: SPIDErFL( - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, device=device, verbose=verbose, + **init_kwds, ), } + return factory diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index bfdb103..9aea94d 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -4,13 +4,11 @@ import logging import os import os.path as osp -import platform -import subprocess +import shutil import sys from argparse import ArgumentParser, Namespace -from subprocess import CalledProcessError -from typing import Optional +from zipfile import ZipFile from torch.hub import download_url_to_file @@ -20,6 +18,7 @@ FNAME_SPICE_JAR, DNAME_SPICE_LOCAL_CACHE, DNAME_SPICE_CACHE, + check_spice_install, ) from aac_metrics.utils.paths import ( _get_cache_path, @@ -62,23 +61,27 @@ "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/spice/spice-1.0.jar", "fname": "spice-1.0.jar", }, + "stanford_nlp": { + "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar", + "fname": "stanford-corenlp-3.4.1.jar", + }, "spice_zip": { "url": "https://panderson.me/images/SPICE-1.0.zip", "fname": "SPICE-1.0.zip", }, - "stanford_nlp": { - "url": "https://github.com/tylin/coco-caption/raw/master/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar", - "fname": "stanford-corenlp-3.4.1.jar", + "spice_corenlp_zip": { + "url": "http://nlp.stanford.edu/software/stanford-corenlp-full-2015-12-09.zip", + "fname": osp.join("SPICE-1.0", "stanford-corenlp-full-2015-12-09.zip"), }, } _TRUE_VALUES = ("true", "1", "t") _FALSE_VALUES = ("false", "0", "f") -def download( +def download_metrics( cache_path: str = ..., tmp_path: str = ..., - use_shell: Optional[bool] = None, + clean_archives: bool = True, ptb_tokenizer: bool = True, meteor: bool = True, spice: bool = True, @@ -89,9 +92,7 @@ def download( :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param tmp_path: The path to a temporary directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. - :param use_shell: Optional argument to force use os-specific shell for the bash installation script program. - If None, it will use shell only on Windows OS. - defaults to None. + :param clean_archives: If True, remove all archives files. defaults to True. :param ptb_tokenizer: If True, downloads the PTBTokenizer code in cache directory. defaults to True. :param meteor: If True, downloads the METEOR code in cache directory. defaults to True. :param spice: If True, downloads the SPICE code in cache directory. defaults to True. @@ -99,7 +100,7 @@ def download( :param verbose: The verbose level. defaults to 0. """ if verbose >= 1: - pylog.info(f"aac-metrics download started.") + pylog.info("aac-metrics download started.") cache_path = _get_cache_path(cache_path) tmp_path = _get_tmp_path(tmp_path) @@ -119,13 +120,13 @@ def download( _download_meteor(cache_path, verbose) if spice: - _download_spice(cache_path, use_shell, verbose) + _download_spice(cache_path, clean_archives, verbose) if fense: _download_fense(verbose) if verbose >= 1: - pylog.info(f"aac-metrics download finished.") + pylog.info("aac-metrics download finished.") def _download_ptb_tokenizer( @@ -143,6 +144,7 @@ def _download_ptb_tokenizer( url = info["url"] fname = info["fname"] fpath = osp.join(stanford_nlp_dpath, fname) + if not osp.isfile(fpath): if verbose >= 1: pylog.info( @@ -190,9 +192,43 @@ def _download_meteor( def _download_spice( cache_path: str, - use_shell: Optional[bool] = None, + clean_archives: bool = True, verbose: int = 0, ) -> None: + """Download SPICE java code. + + Target SPICE directory tree: + + spice + ├── cache + ├── lib + │ ├── ejml-0.23.jar + │ ├── fst-2.47.jar + │ ├── guava-19.0.jar + │ ├── hamcrest-core-1.3.jar + │ ├── jackson-core-2.5.3.jar + │ ├── javassist-3.19.0-GA.jar + │ ├── json-simple-1.1.1.jar + │ ├── junit-4.12.jar + │ ├── lmdbjni-0.4.6.jar + │ ├── lmdbjni-linux64-0.4.6.jar + │ ├── lmdbjni-osx64-0.4.6.jar + │ ├── lmdbjni-win64-0.4.6.jar + │ ├── Meteor-1.5.jar + │ ├── objenesis-2.4.jar + │ ├── SceneGraphParser-1.0.jar + │ ├── slf4j-api-1.7.12.jar + │ ├── slf4j-simple-1.7.21.jar + │ ├── stanford-corenlp-3.6.0.jar + │ └── stanford-corenlp-3.6.0-models.jar + └── spice-1.0.jar + """ + try: + check_spice_install(cache_path) + return None + except (FileNotFoundError, NotADirectoryError): + pass + # Download JAR files for SPICE metric spice_cache_dpath = osp.join(cache_path, DNAME_SPICE_CACHE) spice_jar_dpath = osp.join(cache_path, osp.dirname(FNAME_SPICE_JAR)) @@ -201,40 +237,73 @@ def _download_spice( os.makedirs(spice_jar_dpath, exist_ok=True) os.makedirs(spice_local_cache_path, exist_ok=True) - spice_zip_url = DATA_URLS["spice_zip"]["url"] - spice_zip_fpath = osp.join(spice_cache_dpath, DATA_URLS["spice_zip"]["fname"]) + for name in ("spice_zip", "spice_corenlp_zip"): + url = DATA_URLS[name]["url"] + fname = DATA_URLS[name]["fname"] + fpath = osp.join(spice_cache_dpath, fname) - if osp.isfile(spice_zip_fpath): - if verbose >= 1: - pylog.info(f"SPICE ZIP file '{spice_zip_fpath}' is already downloaded.") - else: - if verbose >= 1: - pylog.info(f"Downloading SPICE ZIP file '{spice_zip_fpath}'...") - download_url_to_file(spice_zip_url, spice_zip_fpath, progress=verbose > 0) + if osp.isfile(fpath): + if verbose >= 1: + pylog.info(f"File '{fpath}' is already downloaded for SPICE.") + else: + if verbose >= 1: + pylog.info(f"Downloading file '{fpath}' for SPICE...") - script_path = osp.join(osp.dirname(__file__), "install_spice.sh") - if not osp.isfile(script_path): - raise FileNotFoundError(f"Cannot find script '{osp.basename(script_path)}'.") + dpath = osp.dirname(fpath) + os.makedirs(dpath, exist_ok=True) + download_url_to_file(url, fpath, progress=verbose > 0) - if verbose >= 1: - pylog.info( - f"Downloading JAR sources for SPICE metric into '{spice_jar_dpath}'..." - ) + if fname.endswith(".zip"): + if verbose >= 1: + pylog.info(f"Extracting {fname} to {spice_cache_dpath}...") + + with ZipFile(fpath, "r") as file: + file.extractall(spice_cache_dpath) + + spice_lib_dpath = osp.join(spice_cache_dpath, "lib") + spice_unzip_dpath = osp.join(spice_cache_dpath, "SPICE-1.0") + corenlp_dpath = osp.join(spice_cache_dpath, "stanford-corenlp-full-2015-12-09") + + # Note: order matter here + to_move = [ + ("f", osp.join(spice_unzip_dpath, "spice-1.0.jar"), spice_cache_dpath), + ("f", osp.join(corenlp_dpath, "stanford-corenlp-3.6.0.jar"), spice_lib_dpath), + ( + "f", + osp.join(corenlp_dpath, "stanford-corenlp-3.6.0-models.jar"), + spice_lib_dpath, + ), + ] + for name in os.listdir(osp.join(spice_unzip_dpath, "lib")): + if not name.endswith(".jar"): + continue + fpath = osp.join(spice_unzip_dpath, "lib", name) + to_move.append(("f", fpath, spice_lib_dpath)) - if use_shell is None: - use_shell = platform.system() == "Windows" + os.makedirs(spice_lib_dpath, exist_ok=True) - command = ["bash", script_path, spice_jar_dpath] - try: - subprocess.check_call( - command, - stdout=None if verbose >= 2 else subprocess.DEVNULL, - stderr=None if verbose >= 2 else subprocess.DEVNULL, - shell=use_shell, - ) - except (CalledProcessError, PermissionError) as err: - pylog.error("Cannot install SPICE java source code.") - raise err + for i, (_src_type, src_path, parent_tgt_dpath) in enumerate(to_move): + tgt_path = osp.join(parent_tgt_dpath, osp.basename(src_path)) + + if osp.exists(tgt_path): + if verbose >= 1: + pylog.info( + f"Target '{tgt_path}' already exists. ({i+1}/{len(to_move)})" + ) + else: + if verbose >= 1: + pylog.info( + f"Moving '{src_path}' to '{parent_tgt_dpath}'... ({i+1}/{len(to_move)})" + ) + shutil.move(src_path, parent_tgt_dpath) + + shutil.rmtree(corenlp_dpath) + if clean_archives: + spice_zip_fname = DATA_URLS["spice_zip"]["fname"] + spice_zip_fpath = osp.join(spice_cache_dpath, spice_zip_fname) + + os.remove(spice_zip_fpath) + shutil.rmtree(spice_unzip_dpath) def _download_fense( @@ -263,6 +332,12 @@ def _get_main_download_args() -> Namespace: default=get_default_tmp_path(), help="Temporary directory path.", ) + parser.add_argument( + "--clean_archives", + type=_str_to_bool, + default=True, + help="If True, remove all archives files. defaults to True.", + ) parser.add_argument( "--ptb_tokenizer", type=_str_to_bool, @@ -293,21 +368,38 @@ def _get_main_download_args() -> Namespace: return args -def _main_download() -> None: +def _setup_logging(verbose: int = 1) -> None: format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" handler = logging.StreamHandler(sys.stdout) handler.setFormatter(logging.Formatter(format_)) pkg_logger = logging.getLogger("aac_metrics") - pkg_logger.addHandler(handler) + found = False + for handler in pkg_logger.handlers: + if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout: + found = True + break + if not found: + pkg_logger.addHandler(handler) + + if verbose <= 0: + level = logging.WARNING + elif verbose == 1: + level = logging.INFO + else: + level = logging.DEBUG + pkg_logger.setLevel(level) + + +def _main_download() -> None: args = _get_main_download_args() - level = logging.INFO if args.verbose <= 1 else logging.DEBUG - pkg_logger.setLevel(level) + _setup_logging(args.verbose) - download( + download_metrics( cache_path=args.cache_path, tmp_path=args.tmp_path, + clean_archives=args.clean_archives, ptb_tokenizer=args.ptb_tokenizer, meteor=args.meteor, spice=args.spice, diff --git a/src/aac_metrics/eval.py b/src/aac_metrics/eval.py index 7c52883..cd19be9 100644 --- a/src/aac_metrics/eval.py +++ b/src/aac_metrics/eval.py @@ -3,7 +3,6 @@ import csv import logging -import sys from argparse import ArgumentParser, Namespace from pathlib import Path @@ -22,6 +21,7 @@ get_default_java_path, get_default_tmp_path, ) +from aac_metrics.download import _setup_logging pylog = logging.getLogger(__name__) @@ -190,20 +190,13 @@ def _get_main_evaluate_args() -> Namespace: def _main_eval() -> None: - format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter(format_)) - pkg_logger = logging.getLogger("aac_metrics") - pkg_logger.addHandler(handler) - args = _get_main_evaluate_args() + _setup_logging(args.verbose) + if not check_java_path(args.java_path): raise RuntimeError(f"Invalid Java executable. ({args.java_path})") - level = logging.INFO if args.verbose <= 1 else logging.DEBUG - pkg_logger.setLevel(level) - if args.verbose >= 1: pylog.info(f"Load file {args.input_file}...") diff --git a/src/aac_metrics/evaluate.py b/src/aac_metrics/evaluate.py deleted file mode 100644 index fcfba37..0000000 --- a/src/aac_metrics/evaluate.py +++ /dev/null @@ -1,237 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import csv -import logging -import sys - -from argparse import ArgumentParser, Namespace -from pathlib import Path -from typing import Iterable, Union - -import yaml - -from aac_metrics.functional.evaluate import ( - evaluate, - METRICS_SETS, - DEFAULT_METRICS_SET_NAME, -) -from aac_metrics.utils.checks import check_metric_inputs, check_java_path -from aac_metrics.utils.paths import ( - get_default_cache_path, - get_default_java_path, - get_default_tmp_path, -) - - -pylog = logging.getLogger(__name__) - - -def load_csv_file( - fpath: Union[str, Path], - cands_columns: Union[str, Iterable[str]] = ("caption_predicted",), - mrefs_columns: Union[str, Iterable[str]] = ( - "caption_1", - "caption_2", - "caption_3", - "caption_4", - "caption_5", - ), - load_mult_cands: bool = False, - strict: bool = True, -) -> tuple[list, list[list[str]]]: - """Load candidates and mult_references from a CSV file. - - :param fpath: The filepath to the CSV file. - :param cands_columns: The columns of the candidates. defaults to ("captions_predicted",). - :param mrefs_columns: The columns of the multiple references. defaults to ("caption_1", "caption_2", "caption_3", "caption_4", "caption_5"). - :param load_mult_cands: If True, load multiple candidates from file. defaults to False. - :returns: A tuple of (candidates, mult_references) loaded from file. - """ - if isinstance(cands_columns, str): - cands_columns = [cands_columns] - else: - cands_columns = list(cands_columns) - - if isinstance(mrefs_columns, str): - mrefs_columns = [mrefs_columns] - else: - mrefs_columns = list(mrefs_columns) - - with open(fpath, "r") as file: - reader = csv.DictReader(file) - fieldnames = reader.fieldnames - data = list(reader) - - if fieldnames is None: - raise ValueError(f"Cannot read fieldnames in CSV file {fpath=}.") - - file_cands_columns = [column for column in cands_columns if column in fieldnames] - file_mrefs_columns = [column for column in mrefs_columns if column in fieldnames] - - if strict: - if len(file_cands_columns) != len(cands_columns): - raise ValueError( - f"Cannot find all candidates columns {cands_columns=} in file '{fpath}'." - ) - if len(file_mrefs_columns) != len(mrefs_columns): - raise ValueError( - f"Cannot find all references columns {mrefs_columns=} in file '{fpath}'." - ) - - if (load_mult_cands and len(file_cands_columns) <= 0) or ( - not load_mult_cands and len(file_cands_columns) != 1 - ): - raise ValueError( - f"Cannot find candidate column in file. ({cands_columns=} not found in {fieldnames=})" - ) - if len(file_mrefs_columns) <= 0: - raise ValueError( - f"Cannot find references columns in file. ({mrefs_columns=} not found in {fieldnames=})" - ) - - if load_mult_cands: - mult_candidates = _load_columns(data, file_cands_columns) - mult_references = _load_columns(data, file_mrefs_columns) - return mult_candidates, mult_references - else: - file_cand_column = file_cands_columns[0] - candidates = [data_i[file_cand_column] for data_i in data] - mult_references = _load_columns(data, file_mrefs_columns) - return candidates, mult_references - - -def _load_columns(data: list[dict[str, str]], columns: list[str]) -> list[list[str]]: - mult_sentences = [] - for data_i in data: - raw_sents = [data_i[column] for column in columns] - sents = [] - for raw_sent in raw_sents: - # Refs columns can be list[str] - if "[" in raw_sent and "]" in raw_sent: - try: - sent = eval(raw_sent) - assert isinstance(sent, list) and all( - isinstance(sent_i, str) for sent_i in sent - ) - sents += sent - except (SyntaxError, NameError): - sents.append(raw_sent) - else: - sents.append(raw_sent) - - mult_sentences.append(sents) - return mult_sentences - - -def _get_main_evaluate_args() -> Namespace: - parser = ArgumentParser(description="Evaluate an output file.") - - parser.add_argument( - "--input_file", - "-i", - type=str, - default="", - help="The input file path containing the candidates and references.", - required=True, - ) - parser.add_argument( - "--cand_columns", - "-cc", - type=str, - nargs="+", - default=("caption_predicted", "preds", "cands"), - help="The column names of the candidates in the CSV file. defaults to ('caption_predicted', 'preds', 'cands').", - ) - parser.add_argument( - "--mrefs_columns", - "-rc", - type=str, - nargs="+", - default=( - "caption_1", - "caption_2", - "caption_3", - "caption_4", - "caption_5", - "captions", - ), - help="The column names of the candidates in the CSV file. defaults to ('caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5', 'captions').", - ) - parser.add_argument( - "--metrics_set_name", - type=str, - default=DEFAULT_METRICS_SET_NAME, - choices=tuple(METRICS_SETS.keys()), - help=f"The metrics set to compute. Can be one of {tuple(METRICS_SETS.keys())}. defaults to 'default'.", - ) - parser.add_argument( - "--cache_path", - type=str, - default=get_default_cache_path(), - help=f"Cache directory path. defaults to '{get_default_cache_path()}'.", - ) - parser.add_argument( - "--java_path", - type=str, - default=get_default_java_path(), - help=f"Java executable path. defaults to '{get_default_java_path()}'.", - ) - parser.add_argument( - "--tmp_path", - type=str, - default=get_default_tmp_path(), - help=f"Temporary directory path. defaults to '{get_default_tmp_path()}'.", - ) - parser.add_argument("--verbose", type=int, default=0, help="Verbose level.") - - args = parser.parse_args() - return args - - -def _main_evaluate() -> None: - format_ = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter(format_)) - pkg_logger = logging.getLogger("aac_metrics") - pkg_logger.addHandler(handler) - - args = _get_main_evaluate_args() - - if not check_java_path(args.java_path): - raise RuntimeError(f"Invalid Java executable. ({args.java_path})") - - level = logging.INFO if args.verbose <= 1 else logging.DEBUG - pkg_logger.setLevel(level) - - if args.verbose >= 1: - pylog.info(f"Load file {args.input_file}...") - - candidates, mult_references = load_csv_file( - args.input_file, args.cand_columns, args.mrefs_columns - ) - check_metric_inputs(candidates, mult_references) - - refs_lens = list(map(len, mult_references)) - if args.verbose >= 1: - pylog.info( - f"Found {len(candidates)} candidates, {len(mult_references)} references and [{min(refs_lens)}, {max(refs_lens)}] references per candidate." - ) - - corpus_scores, _sents_scores = evaluate( - candidates=candidates, - mult_references=mult_references, - preprocess=True, - metrics=args.metrics_set_name, - cache_path=args.cache_path, - java_path=args.java_path, - tmp_path=args.tmp_path, - verbose=args.verbose, - ) - - corpus_scores = {k: v.item() for k, v in corpus_scores.items()} - pylog.info(f"Global scores:\n{yaml.dump(corpus_scores, sort_keys=False)}") - - -if __name__ == "__main__": - _main_evaluate() diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index e5c1775..4d0212f 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -239,88 +239,95 @@ def _get_metric_factory_functions( tmp_path: str = ..., device: Union[str, torch.device, None] = "auto", verbose: int = 0, + init_kwds: dict[str, Any] = ..., ) -> dict[str, Callable[[list[str], list[list[str]]], Any]]: - return { + if init_kwds is ...: + init_kwds = {} + + init_kwds = init_kwds | dict(return_all_scores=return_all_scores) + + factory = { "bleu": partial( bleu, - return_all_scores=return_all_scores, + **init_kwds, ), "bleu_1": partial( bleu, - return_all_scores=return_all_scores, n=1, + **init_kwds, ), "bleu_2": partial( bleu, - return_all_scores=return_all_scores, n=2, + **init_kwds, ), "bleu_3": partial( bleu, - return_all_scores=return_all_scores, n=3, + **init_kwds, ), "bleu_4": partial( bleu, - return_all_scores=return_all_scores, n=4, + **init_kwds, ), "meteor": partial( meteor, - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, verbose=verbose, + **init_kwds, ), "rouge_l": partial( rouge_l, - return_all_scores=return_all_scores, + **init_kwds, ), "cider_d": partial( cider_d, - return_all_scores=return_all_scores, + **init_kwds, ), "spice": partial( spice, - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, verbose=verbose, + **init_kwds, ), "spider": partial( spider, - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, verbose=verbose, + **init_kwds, ), "sbert_sim": partial( sbert_sim, - return_all_scores=return_all_scores, device=device, verbose=verbose, + **init_kwds, ), - "fluerr": partial( # type: ignore + "fluerr": partial( fluerr, - return_all_scores=return_all_scores, device=device, verbose=verbose, + **init_kwds, ), "fense": partial( fense, - return_all_scores=return_all_scores, device=device, verbose=verbose, + **init_kwds, ), "spider_fl": partial( spider_fl, - return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, device=device, verbose=verbose, + **init_kwds, ), } + return factory diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index f673821..359f5e5 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import copy import json import logging import math @@ -85,16 +86,21 @@ def spice( java_path = _get_java_path(java_path) tmp_path = _get_tmp_path(tmp_path) + # Sometimes the java program can freeze, so timeout has been added to avoid using job time. + if timeout is None or isinstance(timeout, (int, float)): + timeout_lst = [timeout] + else: + timeout_lst = list(timeout) + timeout_lst: list[Optional[int]] + spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR) if use_shell is None: use_shell = platform.system() == "Windows" if __debug__: - if not osp.isfile(spice_fpath): - raise FileNotFoundError( - f"Cannot find JAR file '{spice_fpath}' for SPICE metric. Maybe run 'aac-metrics-download' or specify another 'cache_path' directory." - ) + check_spice_install(cache_path) + if not check_java_path(java_path): raise RuntimeError( f"Invalid Java executable to compute SPICE score. ({java_path})" @@ -124,120 +130,57 @@ def spice( for i, (cand, refs) in enumerate(zip(candidates, mult_references)) ] - in_file = NamedTemporaryFile( - mode="w", delete=False, dir=tmp_path, prefix="spice_inputs_", suffix=".json" + json_kwds: dict[str, Any] = dict( + mode="w", + delete=False, + dir=tmp_path, + suffix=".json", ) + in_file = NamedTemporaryFile(prefix="spice_inputs_", **json_kwds) json.dump(input_data, in_file, indent=2) in_file.close() - # Sometimes the java program can freeze, so timeout has been added to avoid using job time. - if timeout is None or isinstance(timeout, (int, float)): - timeout_lst = [timeout] - else: - timeout_lst = list(timeout) - - out_file = NamedTemporaryFile( - mode="w", delete=False, dir=tmp_path, prefix="spice_outputs_", suffix=".json" - ) + out_file = NamedTemporaryFile(prefix="spice_outputs_", **json_kwds) out_file.close() - for i, timeout_i in enumerate(timeout_lst): - if verbose >= 3: - stdout = None - stderr = None - else: - common_kwds: dict[str, Any] = dict( - mode="w", - delete=True, - dir=tmp_path, - suffix=".txt", - ) - stdout = NamedTemporaryFile( - prefix="spice_stdout_", - **common_kwds, - ) - stderr = NamedTemporaryFile( - prefix="spice_stderr_", - **common_kwds, - ) + spice_cmd = [ + java_path, + "-jar", + f"-Xmx{java_max_memory}", + spice_fpath, + in_file.name, + "-cache", + spice_cache, + "-out", + out_file.name, + "-subset", + ] + if n_threads is not None: + spice_cmd += ["-threads", str(n_threads)] + + fpaths = [ + java_path, + spice_fpath, + in_file.name, + spice_cache, + out_file.name, + ] - spice_cmd = [ - java_path, - "-jar", - f"-Xmx{java_max_memory}", - spice_fpath, - in_file.name, - "-cache", - spice_cache, - "-out", + for i, timeout_i in enumerate(timeout_lst): + success = __run_spice( + i, + timeout_i, + timeout_lst, + spice_cmd, + tmp_path, out_file.name, - "-subset", - ] - if n_threads is not None: - spice_cmd += ["-threads", str(n_threads)] - - if verbose >= 2: - pylog.debug( - f"Run SPICE java code with: {' '.join(spice_cmd)} and {use_shell=}" - ) - - try: - subprocess.check_call( - spice_cmd, - stdout=stdout, - stderr=stderr, - timeout=timeout_i, - shell=use_shell, - ) - if stdout is not None: - stdout.close() - if stderr is not None: - stderr.close() + fpaths, + use_shell, + verbose, + ) + if success: break - except subprocess.TimeoutExpired as err: - pylog.warning( - f"Timeout SPICE java program with {timeout_i=}s (nb timeouts done={i+1}/{len(timeout_lst)})." - ) - - if i < len(timeout_lst) - 1: - # Clear out files - open(out_file.name, "w").close() - if stdout is not None: - stdout.close() - if stderr is not None: - stderr.close() - time.sleep(1.0) - else: - raise err - - except (CalledProcessError, PermissionError) as err: - pylog.error("Invalid SPICE call.") - pylog.error(f"Full command: '{' '.join(spice_cmd)}'") - if ( - stdout is not None - and stderr is not None - and osp.isfile(stdout.name) - and osp.isfile(stderr.name) - ): - pylog.error( - f"For more information, see temp files '{stdout.name}' and '{stderr.name}'." - ) - with open(stdout.name, "r") as file: - lines = file.readlines() - content = "\n".join(lines) - pylog.error(f"Content of '{stdout.name}':\n{content}") - - with open(stderr.name, "r") as file: - lines = file.readlines() - content = "\n".join(lines) - pylog.error(f"Content of '{stderr.name}':\n{content}") - else: - pylog.info( - f"Note: No temp file recorded. (found {stdout=} and {stderr=})" - ) - raise err - if verbose >= 2: pylog.debug("SPICE java code finished.") @@ -250,11 +193,11 @@ def spice( if separate_cache_dir: shutil.rmtree(spice_cache) - imgId_to_scores = {} spice_scores = [] for item in results: - imgId_to_scores[item["image_id"]] = item["scores"] - spice_scores.append(__float_convert(item["scores"]["All"]["f"])) + # item keys: "image_id", "scores" + spice_scores_i = __float_convert(item["scores"]["All"]["f"]) + spice_scores.append(spice_scores_i) spice_scores = np.array(spice_scores) # Note: use numpy to compute mean because np.mean and torch.mean can give very small differences @@ -276,6 +219,193 @@ def spice( return spice_score +def check_spice_install(cache_path: str) -> None: + """Check if SPICE is installed in cache directory. + + Raises FileNotFoundError or NotADirectoryError exception if something is missing. + """ + spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR) + if not osp.isfile(spice_fpath): + raise FileNotFoundError( + f"Cannot find JAR file '{spice_fpath}' for SPICE metric. Maybe run 'aac-metrics-download' or specify another 'cache_path' directory." + ) + + local_cache_dpath = osp.join(cache_path, DNAME_SPICE_CACHE, "cache") + if not osp.isdir(local_cache_dpath): + raise NotADirectoryError( + f"Cannot find cache local directory '{local_cache_dpath}' for SPICE metric. Maybe run 'aac-metrics-download' or specify another 'cache_path' directory." + ) + + lib_dpath = osp.join(cache_path, DNAME_SPICE_CACHE, "lib") + if not osp.isdir(lib_dpath): + raise NotADirectoryError( + f"Cannot find lib directory '{lib_dpath}' for SPICE metric. Maybe run 'aac-metrics-download' or specify another 'cache_path' directory." + ) + + expected_jar_in_lib = [ + "ejml-0.23.jar", + "fst-2.47.jar", + "guava-19.0.jar", + "hamcrest-core-1.3.jar", + "jackson-core-2.5.3.jar", + "javassist-3.19.0-GA.jar", + "json-simple-1.1.1.jar", + "junit-4.12.jar", + "lmdbjni-0.4.6.jar", + "lmdbjni-linux64-0.4.6.jar", + "lmdbjni-osx64-0.4.6.jar", + "lmdbjni-win64-0.4.6.jar", + "Meteor-1.5.jar", + "objenesis-2.4.jar", + "SceneGraphParser-1.0.jar", + "slf4j-api-1.7.12.jar", + "slf4j-simple-1.7.21.jar", + "stanford-corenlp-3.6.0.jar", + "stanford-corenlp-3.6.0-models.jar", + ] + names = os.listdir(lib_dpath) + files_not_found = [] + for fname in expected_jar_in_lib: + if fname not in names: + files_not_found.append(fname) + if len(files_not_found) > 0: + raise FileNotFoundError( + f"Missing {len(files_not_found)} files in SPICE lib directory. (missing {', '.join(files_not_found)})" + ) + + +def __run_spice( + i: int, + timeout_i: Optional[int], + timeout_lst: list[Optional[int]], + spice_cmd: list[str], + tmp_path: str, + out_path: str, + paths: list[str], + use_shell: bool, + verbose: int, +) -> bool: + success = False + txt_kwds: dict[str, Any] = dict( + mode="w", + delete=False, + dir=tmp_path, + suffix=".txt", + ) + + if verbose >= 3: + stdout = None + stderr = None + else: + stdout = NamedTemporaryFile( + prefix="spice_stdout_", + **txt_kwds, + ) + stderr = NamedTemporaryFile( + prefix="spice_stderr_", + **txt_kwds, + ) + + if verbose >= 2: + pylog.debug(f"Run SPICE java code with: {' '.join(spice_cmd)} and {use_shell=}") + + try: + subprocess.check_call( + spice_cmd, + stdout=stdout, + stderr=stderr, + timeout=timeout_i, + shell=use_shell, + ) + if stdout is not None: + stdout.close() + os.remove(stdout.name) + if stderr is not None: + stderr.close() + os.remove(stderr.name) + + success = True + + except subprocess.TimeoutExpired as err: + pylog.warning( + f"Timeout SPICE java program with {timeout_i=}s (nb timeouts done={i+1}/{len(timeout_lst)})." + ) + + if i < len(timeout_lst) - 1: + # Clear out files + open(out_path, "w").close() + if stdout is not None: + stdout.close() + open(stdout.name, "w").close() + if stderr is not None: + stderr.close() + open(stderr.name, "w").close() + time.sleep(1.0) + else: + raise err + + except (CalledProcessError, PermissionError) as err: + pylog.error("Invalid SPICE call.") + pylog.error(f"Full command: '{' '.join(spice_cmd)}'") + pylog.error(f"Error: {err}") + + paths = copy.copy(paths) + if stdout is not None: + stdout.close() + paths.append(stdout.name) + if stderr is not None: + stderr.close() + paths.append(stderr.name) + + for path in paths: + rights = __get_access_rights(path) + pylog.error(f"{path} :\t {rights}") + + if ( + stdout is not None + and stderr is not None + and osp.isfile(stdout.name) + and osp.isfile(stderr.name) + ): + pylog.error( + f"For more information, see temp files '{stdout.name}' and '{stderr.name}'." + ) + + for path in (stdout.name, stderr.name): + try: + with open(path, "r") as file: + lines = file.readlines() + content = "\n".join(lines) + pylog.error(f"Content of '{path}':\n{content}") + except PermissionError as err2: + pylog.warning(f"Cannot open file '{path}'. ({err2})") + else: + pylog.info(f"Note: No temp file recorded. (found {stdout=} and {stderr=})") + raise err + + return success + + +def __get_access_rights(path: str) -> str: + info = {"t": "-", "r": "-", "w": "-", "x": "-"} + if osp.islink(path): + info["t"] = "l" + elif osp.isfile(path): + info["t"] = "f" + elif osp.isdir(path): + info["t"] = "d" + + if os.access(path, os.R_OK): + info["r"] = "r" + if os.access(path, os.W_OK): + info["w"] = "w" + if os.access(path, os.X_OK): + info["x"] = "x" + + rights = "".join(info.values()) + return rights + + def __float_convert(obj: Any) -> float: try: return float(obj) diff --git a/src/aac_metrics/functional/spider.py b/src/aac_metrics/functional/spider.py index 8bcb568..b1f65ea 100644 --- a/src/aac_metrics/functional/spider.py +++ b/src/aac_metrics/functional/spider.py @@ -64,19 +64,21 @@ def spider( f"Number of candidates and mult_references are different (found {len(candidates)} != {len(mult_references)})." ) - cider_d_outs: tuple = cider_d( # type: ignore - candidates, - mult_references, - True, + return_all_scores = True + + cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = cider_d( # type: ignore + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, n=n, sigma=sigma, tokenizer=tokenizer, return_tfidf=return_tfidf, ) - spice_outs: tuple = spice( # type: ignore - candidates, - mult_references, - True, + spice_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spice( # type: ignore + candidates=candidates, + mult_references=mult_references, + return_all_scores=return_all_scores, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, diff --git a/src/aac_metrics/info.py b/src/aac_metrics/info.py index b852b2f..a0a6cf4 100644 --- a/src/aac_metrics/info.py +++ b/src/aac_metrics/info.py @@ -12,6 +12,7 @@ import aac_metrics +from aac_metrics.utils.checks import _get_java_version from aac_metrics.utils.paths import ( get_default_cache_path, get_default_java_path, @@ -24,8 +25,17 @@ def get_package_repository_path() -> str: return str(Path(__file__).parent.parent.parent) +def get_java_version() -> str: + try: + java_version = _get_java_version(get_default_java_path()) + return java_version + except ValueError: + return "UNKNOWN" + + def get_install_info() -> Dict[str, str]: """Return a dictionary containing the version python, the os name, the architecture name and the versions of the following packages: aac_datasets, torch, torchaudio.""" + return { "aac_metrics": aac_metrics.__version__, "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", @@ -36,6 +46,7 @@ def get_install_info() -> Dict[str, str]: "cache_path": get_default_cache_path(), "java_path": get_default_java_path(), "tmp_path": get_default_tmp_path(), + "java_version": get_java_version(), } diff --git a/src/aac_metrics/install_spice.sh b/src/aac_metrics/install_spice.sh deleted file mode 100644 index 3107e8b..0000000 --- a/src/aac_metrics/install_spice.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -DEFAULT_SPICE_ROOT="$HOME/.cache/aac-metrics/spice" - -if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then - echo "Install all files for running the java SPICE program in the SPICE_ROOT directory." - echo "The default spice root path is \"${DEFAULT_SPICE_ROOT}\"." - echo "Usage: $0 [SPICE_ROOT]" - exit 0 -fi - -dpath_spice="$1" -if [ "$dpath_spice" = "" ]; then - dpath_spice="${DEFAULT_SPICE_ROOT}" -fi - -if [ ! -d "$dpath_spice" ]; then - echo "Error: SPICE_ROOT \"$dpath_spice\" is not a directory." - exit 1 -fi - -fname_zip="SPICE-1.0.zip" -fpath_zip="$dpath_spice/$fname_zip" -bn0=`basename $0` - -echo "[$bn0] Start installation of SPICE metric java code in directory \"$dpath_spice\"..." - -if [ ! -f "$fpath_zip" ]; then - echo "[$bn0] Zip file not found, downloading from https://panderson.me..." - wget https://panderson.me/images/SPICE-1.0.zip -P "$dpath_spice" -fi - -dpath_unzip="$dpath_spice/SPICE-1.0" -if [ ! -d "$dpath_unzip" ]; then - echo "[$bn0] Unzipping file $dpath_zip..." - unzip $fpath_zip -d "$dpath_spice" - - echo "[$bn0] Downloading Stanford models..." - bash $dpath_unzip/get_stanford_models.sh -fi - -dpath_lib="$dpath_spice/lib" -if [ ! -d "$dpath_lib" ]; then - echo "[$bn0] Moving lib directory to \"$dpath_spice\"..." - mv "$dpath_unzip/lib" "$dpath_spice" -fi - -fpath_jar="$dpath_spice/spice-1.0.jar" -if [ ! -f "$fpath_jar" ]; then - echo "[$bn0] Moving spice-1.0.jar file to \"$dpath_spice\"..." - mv "$dpath_unzip/spice-1.0.jar" "$dpath_spice" -fi - -echo "[$bn0] SPICE metric Java code is installed." -exit 0 diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py index 2054fb1..6164266 100644 --- a/src/aac_metrics/utils/checks.py +++ b/src/aac_metrics/utils/checks.py @@ -5,10 +5,10 @@ import re import subprocess -from functools import cache from pathlib import Path from subprocess import CalledProcessError from typing import Any, Union +from typing_extensions import TypeGuard pylog = logging.getLogger(__name__) @@ -23,11 +23,18 @@ def check_metric_inputs( mult_references: Any, ) -> None: """Raises ValueError if candidates and mult_references does not have a valid type and size.""" + + error_msgs = [] if not is_mono_sents(candidates): - raise ValueError("Invalid candidates type. (expected list[str])") + error_msg = "Invalid candidates type. (expected list[str])" + error_msgs.append(error_msg) if not is_mult_sents(mult_references): - raise ValueError("Invalid mult_references type. (expected list[list[str]])") + error_msg = "Invalid mult_references type. (expected list[list[str]])" + error_msgs.append(error_msg) + + if len(error_msgs) > 0: + raise ValueError("\n".join(error_msgs)) same_len = len(candidates) == len(mult_references) if not same_len: @@ -53,18 +60,20 @@ def check_java_path(java_path: Union[str, Path]) -> bool: return valid -def is_mono_sents(sents: Any) -> bool: - """Returns True if input is list[str].""" - return isinstance(sents, list) and all(isinstance(sent, str) for sent in sents) +def is_mono_sents(sents: Any) -> TypeGuard[list[str]]: + """Returns True if input is list[str] containing sentences.""" + valid = isinstance(sents, list) and all(isinstance(sent, str) for sent in sents) + return valid -def is_mult_sents(mult_sents: Any) -> bool: - """Returns True if input is list[list[str]].""" - return ( +def is_mult_sents(mult_sents: Any) -> TypeGuard[list[list[str]]]: + """Returns True if input is list[list[str]] containing multiple sentences.""" + valid = ( isinstance(mult_sents, list) and all(isinstance(sents, list) for sents in mult_sents) and all(isinstance(sent, str) for sents in mult_sents for sent in sents) ) + return valid def _get_java_version(java_path: str) -> str: @@ -106,9 +115,8 @@ def _check_java_version(version: str, min_major: int, max_major: int) -> bool: major_version = int(result["major"]) minor_version = int(result["minor"]) - if ( - major_version == 1 and minor_version <= 8 - ): # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH" + if major_version == 1 and minor_version <= 8: + # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH" major_version = minor_version return min_major <= major_version <= max_major diff --git a/src/aac_metrics/utils/tokenization.py b/src/aac_metrics/utils/tokenization.py index 183145b..bc16d14 100644 --- a/src/aac_metrics/utils/tokenization.py +++ b/src/aac_metrics/utils/tokenization.py @@ -10,7 +10,7 @@ from typing import Any, Hashable, Iterable, Optional -from aac_metrics.utils.checks import check_java_path +from aac_metrics.utils.checks import check_java_path, is_mono_sents from aac_metrics.utils.collections import flat_list, unflat_list from aac_metrics.utils.paths import ( _get_cache_path, @@ -69,20 +69,31 @@ def ptb_tokenize_batch( :param verbose: The verbose level. defaults to 0. :returns: The sentences tokenized as list[list[str]]. """ + # Originally based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30 + sentences = list(sentences) + + if not is_mono_sents(sentences): + raise ValueError("Invalid argument sentences. (not a list[str] of sentences)") + if len(sentences) == 0: return [] cache_path = _get_cache_path(cache_path) java_path = _get_java_path(java_path) tmp_path = _get_tmp_path(tmp_path) - - # Based on https://github.com/audio-captioning/caption-evaluation-tools/blob/c1798df4c91e29fe689b1ccd4ce45439ec966417/caption/pycocoevalcap/tokenizer/ptbtokenizer.py#L30 + punctuations = list(punctuations) stanford_fpath = osp.join(cache_path, FNAME_STANFORD_CORENLP_3_4_1_JAR) # Sanity checks if __debug__: + newlines_count = sum(sent.count("\n") for sent in sentences) + if newlines_count > 0: + raise ValueError( + f"Invalid argument sentences for tokenization. (found {newlines_count} newlines character '\\n')" + ) + if not osp.isdir(cache_path): raise RuntimeError(f"Cannot find cache directory at {cache_path=}.") if not osp.isdir(tmp_path): @@ -111,9 +122,6 @@ def ptb_tokenize_batch( "-lowerCase", ] - # ====================================================== - # Prepare data for PTB AACTokenizer - # ====================================================== if audio_ids is None: audio_ids = list(range(len(sentences))) else: @@ -135,9 +143,6 @@ def ptb_tokenize_batch( for old, new in replaces.items(): sentences = sentences.replace(old, new) - # ====================================================== - # Save sentences to temporary file - # ====================================================== tmp_file = tempfile.NamedTemporaryFile( delete=False, dir=tmp_path, @@ -147,9 +152,6 @@ def ptb_tokenize_batch( tmp_file.write(sentences.encode()) tmp_file.close() - # ====================================================== - # Tokenize sentence - # ====================================================== cmd.append(osp.basename(tmp_file.name)) p_tokenizer = subprocess.Popen( cmd, @@ -157,28 +159,28 @@ def ptb_tokenize_batch( stdout=subprocess.PIPE, stderr=subprocess.DEVNULL if verbose <= 2 else None, ) - token_lines = p_tokenizer.communicate(input=sentences.rstrip().encode())[0] + encoded_sentences = sentences.rstrip().encode() + token_lines = p_tokenizer.communicate(input=encoded_sentences)[0] token_lines = token_lines.decode() lines = token_lines.split("\n") # remove temp file os.remove(tmp_file.name) - # ====================================================== - # Create dictionary for tokenized captions - # ====================================================== - outs: Any = [None for _ in range(len(lines))] if len(audio_ids) != len(lines): raise RuntimeError( f"PTB tokenize error: expected {len(audio_ids)} lines in output file but found {len(lines)}." + f"Maybe check if there is any newline character '\\n' in your sentences or disable preprocessing tokenization." ) - punctuations = list(punctuations) + outs: Any = [None for _ in range(len(lines))] for k, line in zip(audio_ids, lines): tokenized_caption = [ w for w in line.rstrip().split(" ") if w not in punctuations ] outs[k] = tokenized_caption - assert all(out is not None for out in outs) + assert all( + out is not None for out in outs + ), "INTERNAL ERROR: PTB tokenizer output is invalid." if verbose >= 2: duration = time.perf_counter() - start_time @@ -199,9 +201,9 @@ def preprocess_mono_sents( """Tokenize sentences using PTB Tokenizer then merge them by space. .. warning:: - PTB tokenizer is a java program that takes a list[str] as input, so calling several times `preprocess_mono_sents` is slow on list[list[str]]. + PTB tokenizer is a java program that takes a list[str] as input, so calling several times this function is slow on list[list[str]]. - If you want to process multiple sentences (list[list[str]]), use `preprocess_mult_sents` instead. + If you want to process multiple sentences (list[list[str]]), use :func:`~aac_metrics.utils.tokenization.preprocess_mult_sents` instead. :param sentences: The list of sentences to process. :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. diff --git a/tests/test_compare_cet.py b/tests/test_compare_cet.py index ba66e51..01e6794 100644 --- a/tests/test_compare_cet.py +++ b/tests/test_compare_cet.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import importlib -import os import os.path as osp import platform import subprocess @@ -17,7 +16,9 @@ from aac_metrics.functional.evaluate import evaluate from aac_metrics.eval import load_csv_file -from aac_metrics.utils.paths import get_default_tmp_path +from aac_metrics.utils.paths import ( + get_default_tmp_path, +) class TestCompareCaptionEvaluationTools(TestCase): @@ -102,28 +103,28 @@ def _get_example_0(self) -> tuple[list[str], list[list[str]]]: return cands, mrefs def _test_with_example(self, cands: list[str], mrefs: list[list[str]]) -> None: - if os.name == "nt": - # Skip this setup on windows + if platform.system() == "Windows": return None + corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020") + + self.assertIsInstance(corpus_scores, dict) + + for name, score in corpus_scores.items(): + self.assertIsInstance(score, Tensor, f"Invalid score type for {name=}") + self.assertEqual(score.ndim, 0, f"Invalid score ndim for {name=}") + cet_outs = self.__class__.evaluate_metrics_from_lists(cands, mrefs) cet_global_scores, _cet_sents_scores = cet_outs - corpus_scores, _ = evaluate(cands, mrefs, metrics="dcase2020") - cet_global_scores = {k.lower(): v for k, v in cet_global_scores.items()} cet_global_scores = { (k if k != "cider" else "cider_d"): v for k, v in cet_global_scores.items() } - self.assertIsInstance(corpus_scores, dict) self.assertIsInstance(cet_global_scores, dict) - - for name, score in corpus_scores.items(): - self.assertIsInstance(score, Tensor, f"Invalid score type for {name=}") - self.assertEqual(score.ndim, 0, f"Invalid score ndim for {name=}") - self.assertListEqual(list(corpus_scores.keys()), list(cet_global_scores.keys())) + for metric_name, v1 in corpus_scores.items(): v1 = v1.item() v2 = cet_global_scores[metric_name] diff --git a/tests/test_compare_fense.py b/tests/test_compare_fense.py index 749c7f1..6086d2d 100644 --- a/tests/test_compare_fense.py +++ b/tests/test_compare_fense.py @@ -10,7 +10,6 @@ from typing import Any from unittest import TestCase -from aac_metrics.classes.sbert_sim import SBERTSim from aac_metrics.classes.fense import FENSE from aac_metrics.eval import load_csv_file @@ -33,13 +32,8 @@ def setUpClass(cls) -> None: echecker_model="echecker_clotho_audiocaps_base", ) - cls.new_sbert_sim = SBERTSim( - return_all_scores=False, - device=device, - verbose=2, - ) cls.new_fense = FENSE( - return_all_scores=False, + return_all_scores=True, device=device, verbose=2, echecker="echecker_clotho_audiocaps_base", @@ -57,24 +51,17 @@ def test_example_1_fense(self) -> None: fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv") self._test_with_original_fense(fpath) - def test_example_1_sbert_sim(self) -> None: - fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv") - self._test_with_original_sbert_sim(fpath) - def test_example_2_fense(self) -> None: fpath = osp.join(osp.dirname(__file__), "..", "data", "example_2.csv") self._test_with_original_fense(fpath) - def test_example_2_sbert_sim(self) -> None: - fpath = osp.join(osp.dirname(__file__), "..", "data", "example_2.csv") - self._test_with_original_sbert_sim(fpath) - def test_output_size(self) -> None: fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv") cands, mrefs = load_csv_file(fpath) self.new_fense._return_all_scores = True - corpus_scores, sents_scores = self.new_fense(cands, mrefs) + outs: tuple = self.new_fense(cands, mrefs) # type: ignore + corpus_scores, sents_scores = outs self.new_fense._return_all_scores = False for name, score in corpus_scores.items(): @@ -90,34 +77,24 @@ def test_output_size(self) -> None: def _test_with_original_fense(self, fpath: str) -> None: cands, mrefs = load_csv_file(fpath) - src_fense_score = self.src_fense.corpus_score(cands, mrefs).item() - new_fense_score = self.new_fense(cands, mrefs).item() - - print(f"{fpath=}") - print(f"{src_fense_score=}") - print(f"{new_fense_score=}") - - self.assertEqual( - src_fense_score, - new_fense_score, - "Invalid FENSE score with original implementation.", - ) - - def _test_with_original_sbert_sim(self, fpath: str) -> None: - cands, mrefs = load_csv_file(fpath) - src_sbert_sim_score = self.src_sbert_sim.corpus_score(cands, mrefs).item() - new_sbert_sim_score = self.new_sbert_sim(cands, mrefs).item() + src_fense_score = self.src_fense.corpus_score(cands, mrefs).item() - print(f"{fpath=}") - print(f"{src_sbert_sim_score=}") - print(f"{new_sbert_sim_score=}") + outs: tuple = self.new_fense(cands, mrefs) # type: ignore + corpus_outs, _sents_outs = outs + new_sbert_sim_score = corpus_outs["sbert_sim"].item() + new_fense_score = corpus_outs["fense"].item() self.assertEqual( src_sbert_sim_score, new_sbert_sim_score, "Invalid SBERTSim score with original implementation.", ) + self.assertEqual( + src_fense_score, + new_fense_score, + "Invalid FENSE score with original implementation.", + ) if __name__ == "__main__": diff --git a/tests/test_doc_examples.py b/tests/test_doc_examples.py index 2ead343..a52e640 100644 --- a/tests/test_doc_examples.py +++ b/tests/test_doc_examples.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import os +import platform import unittest from unittest import TestCase @@ -18,7 +18,7 @@ class TestReadmeExamples(TestCase): def test_example_1(self) -> None: - if os.name == "nt": + if platform.system() == "Windows": return None candidates: list[str] = ["a man is speaking", "rain falls"] @@ -48,11 +48,6 @@ def test_example_1(self) -> None: "spider", ] self.assertSetEqual(set(corpus_scores.keys()), set(expected_keys)) - - # print(corpus_scores["bleu_1"]) - # print(torch.as_tensor(0.4278, dtype=torch.float64)) - # print("END") - self.assertTrue( torch.allclose( corpus_scores["bleu_1"], @@ -63,7 +58,7 @@ def test_example_1(self) -> None: ) def test_example_2(self) -> None: - if os.name == "nt": + if platform.system() == "Windows": return None candidates: list[str] = ["a man is speaking", "rain falls"] @@ -84,9 +79,6 @@ def test_example_2(self) -> None: self.assertTrue(set(corpus_scores.keys()).issuperset(expected_keys)) def test_example_3(self) -> None: - if os.name == "nt": - return None - candidates: list[str] = ["a man is speaking", "rain falls"] mult_references: list[list[str]] = [ [ @@ -110,17 +102,19 @@ def test_example_3(self) -> None: self.assertTrue(set(corpus_scores.keys()).issuperset({"cider_d"})) self.assertTrue(set(sents_scores.keys()).issuperset({"cider_d"})) + dtype = torch.float64 + self.assertTrue( torch.allclose( corpus_scores["cider_d"], - torch.as_tensor(0.9614, dtype=torch.float64), + torch.as_tensor(0.9614, dtype=dtype), atol=0.0001, ) ) self.assertTrue( torch.allclose( sents_scores["cider_d"], - torch.as_tensor([1.3641, 0.5587], dtype=torch.float64), + torch.as_tensor([1.3641, 0.5587], dtype=dtype), atol=0.0001, ) )