Version 0.5.4
Labbeti committed Mar 4, 2024
1 parent e513a6f commit bc47a84
Showing 16 changed files with 128 additions and 96 deletions.
36 changes: 36 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,36 @@
# exclude: ""

repos:
  # Format Code
  - repo: https://github.com/ambv/black
    rev: 22.8.0
    hooks:
      - id: black

  # Sort imports
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black"]

  # Formatting, Whitespace, etc
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.2.3
    hooks:
      - id: trailing-whitespace
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: check-ast
      - id: check-json
      - id: check-merge-conflict
      - id: check-xml
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: mixed-line-ending
        args: ['--fix=no']
      - id: flake8
        # args: ['--ignore=E203,E501,F811,E712,W503']
        args: ['--config=.flake8']
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@

All notable changes to this project will be documented in this file.

## [0.5.4] 2024-03-04
### Fixed
- Backward compatibility of `BERTScoreMRefs` with torchmetrics versions prior to 1.0.0.

### Deleted
- `Version` class; `packaging.version.Version` is used instead.

## [0.5.3] 2024-01-09
### Fixed
- Fix `BERTScoreMRefs` computation when all candidates have the same number of references.
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -19,5 +19,5 @@ keywords:
- captioning
- audio-captioning
license: MIT
version: 0.5.3
date-released: '2024-01-09'
version: 0.5.4
date-released: '2024-03-04'
32 changes: 25 additions & 7 deletions README.md
@@ -100,6 +100,24 @@ print(sents_scores)

Each metric also exists as a Python class, like `aac_metrics.classes.cider_d.CIDErD`.
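
As a minimal sketch (assuming the class-based metrics follow the same call convention as the `Evaluate` example shown below), a single metric class can be used directly:

```python
from aac_metrics.classes.cider_d import CIDErD

# Hypothetical captions; replace with your model outputs and references.
candidates: list[str] = ["a man is speaking", "rain falls on a roof"]
mult_references: list[list[str]] = [
    ["a man speaks", "someone is talking"],
    ["heavy rain hits a rooftop", "it is raining outside"],
]

cider_d = CIDErD()
# Metric classes are assumed to be callable like the functional API and to
# return (corpus_scores, sentence_scores) dictionaries of tensors.
corpus_scores, sents_scores = cider_d(candidates, mult_references)
print(corpus_scores["cider_d"])
```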

## Which metric(s) should I choose for Automated Audio Captioning?
To evaluate audio captioning systems, I would recommend computing the `SPIDEr`, `FENSE` and `Vocab` metrics. `SPIDEr` is useful for comparison with the rest of the literature, but it is highly sensitive to n-gram matching and can overestimate models trained with reinforcement learning. `FENSE` is more consistent and less variable than `SPIDEr`, but it relies on a model that was not trained on audio captions. `Vocab` gives an insight into the diversity of the model outputs. To compute all of these metrics at once, you can use, for example, the `Evaluate` class:

```python
from aac_metrics import Evaluate

evaluate = Evaluate(metrics=["spider", "fense", "vocab"])

candidates: list[str] = ...
mult_references: list[list[str]] = ...

corpus_scores, _ = evaluate(candidates, mult_references)

vocab_size = corpus_scores["vocab"]
spider_score = corpus_scores["spider"]
fense_score = corpus_scores["fense"]
```

## Metrics
### Legacy metrics
| Metric name | Python Class | Origin | Range | Short description |
@@ -187,7 +205,7 @@ SPIDEr-max [[7]](#spider-max) is a metric based on SPIDEr that takes into accoun
[6] S. Liu, Z. Zhu, N. Ye, S. Guadarrama, and K. Murphy, “Improved Image Captioning via Policy Gradient optimization of SPIDEr,” 2017 IEEE International Conference on Computer Vision (ICCV), pp. 873–881, Oct. 2017, arXiv: 1612.00370. [Online]. Available: http://arxiv.org/abs/1612.00370

#### BERTScore
[7] T. Zhang*, V. Kishore*, F. Wu*, K. Q. Weinberger, and Y. Artzi, “BERTScore: Evaluating Text Generation with BERT,” 2020. [Online]. Available: https://openreview.net/forum?id=SkeHuCVFDr

#### SPIDEr-max
[8] E. Labbé, T. Pellegrini, and J. Pinquier, “Is my automatic audio captioning system so bad? spider-max: a metric to consider several caption candidates,” Nov. 2022. [Online]. Available: https://hal.archives-ouvertes.fr/hal-03810396
@@ -199,19 +217,19 @@ SPIDEr-max [[7]](#spider-max) is a metric based on SPIDEr that takes into accoun
[10] DCASE website task6a description: https://dcase.community/challenge2023/task-automated-audio-captioning#evaluation

#### CB-score
[11] I. Martín-Morató, M. Harju, and A. Mesaros, “A Summarization Approach to Evaluating Audio Captioning,” Nov. 2022. [Online]. Available: https://dcase.community/documents/workshop2022/proceedings/DCASE2022Workshop_Martin-Morato_35.pdf

#### SPICE-plus
[12] F. Gontier, R. Serizel, and C. Cerisara, “SPICE+: Evaluation of Automatic Audio Captioning Systems with Pre-Trained Language Models,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10097021.

#### ACES
[13] G. Wijngaard, E. Formisano, B. L. Giordano, M. Dumontier, “ACES: Evaluating Automated Audio Captioning Models on the Semantics of Sounds”, in EUSIPCO 2023, 2023.

#### SBF
[14] R. Mahfuz, Y. Guo, A. K. Sridhar, and E. Visser, Detecting False Alarms and Misses in Audio Captions. 2023. [Online]. Available: https://arxiv.org/pdf/2309.03326.pdf

#### s2v
[15] S. Bhosale, R. Chakraborty, and S. K. Kopparapu, “A Novel Metric For Evaluating Audio Caption Similarity,” in ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2023, pp. 1–5. doi: 10.1109/ICASSP49357.2023.10096526.

## Citation
If you use **SPIDEr-max**, you can cite the following paper using BibTeX:
Expand All @@ -234,10 +252,10 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr
Labbe_aac_metrics_2024,
author = {Labbé, Etienne},
license = {MIT},
month = {01},
month = {03},
title = {{aac-metrics}},
url = {https://github.com/Labbeti/aac-metrics/},
version = {0.5.3},
version = {0.5.4},
year = {2024},
}
```
7 changes: 0 additions & 7 deletions docs/aac_metrics.utils.imports.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/aac_metrics.utils.rst
@@ -16,5 +16,4 @@ Submodules
aac_metrics.utils.cmdline
aac_metrics.utils.collections
aac_metrics.utils.globals
aac_metrics.utils.imports
aac_metrics.utils.tokenization
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -103,5 +103,5 @@ def setup(app) -> None:
    app.add_css_file("my_theme.css")


# TODO: to be used with sphinx>=7.1
# Only works with sphinx>=7.1
maximum_signature_line_length = 10
7 changes: 4 additions & 3 deletions docs/installation.rst
@@ -4,13 +4,13 @@ Installation
To install the package, simply run:

.. code-block:: bash

    pip install aac-metrics

Then download the external tools needed for SPICE, PTBTokenizer, METEOR and FENSE:

.. code-block:: bash

    aac-metrics-download
@@ -26,4 +26,5 @@ The python requirements are automatically installed when using pip on this repos
pyyaml>=6.0
tqdm>=4.64.0
sentence-transformers>=2.2.2
transformers<4.31.0
transformers
packaging
11 changes: 6 additions & 5 deletions requirements-dev.txt
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-

pytest==7.1.2
flake8==4.0.1
black==22.8.0
scikit-image==0.19.2
matplotlib==3.5.2
flake8==4.0.1
ipykernel==6.9.1
twine==4.0.1
matplotlib==3.5.2
pre-commit
pytest==7.1.2
scikit-image==0.19.2
sphinx==7.2.6
sphinx-press-theme==0.8.0
twine==4.0.1
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-

torch>=1.10.1
numpy>=1.21.2
packaging>=23
pyyaml>=6.0
tqdm>=4.64.0
sentence-transformers>=2.2.2
transformers
torch>=1.10.1
torchmetrics>=0.11.4
tqdm>=4.64.0
transformers
11 changes: 6 additions & 5 deletions src/aac_metrics/__init__.py
@@ -6,20 +6,22 @@

__author__ = "Etienne Labbé (Labbeti)"
__author_email__ = "[email protected]"
__docs__ = "Audio Captioning Metrics"
__docs_url__ = "https://aac-metrics.readthedocs.io/en/stable/"
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__name__ = "aac-metrics"
__status__ = "Development"
__version__ = "0.5.3"
__version__ = "0.5.4"


from .classes.base import AACMetric
from .classes.bert_score_mrefs import BERTScoreMRefs
from .classes.bleu import BLEU
from .classes.cider_d import CIDErD
from .classes.evaluate import Evaluate, DCASE2023Evaluate, _get_metric_factory_classes
from .classes.fer import FER
from .classes.evaluate import DCASE2023Evaluate, Evaluate, _get_metric_factory_classes
from .classes.fense import FENSE
from .classes.fer import FER
from .classes.meteor import METEOR
from .classes.rouge_l import ROUGEL
from .classes.sbert_sim import SBERTSim
@@ -28,7 +30,7 @@
from .classes.spider_fl import SPIDErFL
from .classes.spider_max import SPIDErMax
from .classes.vocab import Vocab
from .functional.evaluate import evaluate, dcase2023_evaluate
from .functional.evaluate import dcase2023_evaluate, evaluate
from .utils.globals import (
    get_default_cache_path,
    get_default_java_path,
@@ -38,7 +40,6 @@
    set_default_tmp_path,
)


__all__ = [
"AACMetric",
"BERTScoreMRefs",
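For context, a minimal usage sketch of the top-level names re-exported above; the exact set of metrics returned by `evaluate` with its default settings is an assumption here:

```python
from aac_metrics import evaluate, get_default_cache_path

candidates = ["a dog barks in the distance"]
mult_references = [["a dog is barking", "a dog barks far away"]]

# evaluate() is assumed to return corpus-level and sentence-level score dicts.
corpus_scores, sents_scores = evaluate(candidates, mult_references)
print(sorted(corpus_scores.keys()))

# The globals helpers expose the paths used for cached models and external tools.
print(get_default_cache_path())
```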
27 changes: 17 additions & 10 deletions src/aac_metrics/functional/bert_score_mrefs.py
@@ -4,18 +4,18 @@
from typing import Callable, Optional, Union

import torch

from torch import nn, Tensor
from torchmetrics.functional.text.bert import bert_score, _DEFAULT_MODEL
import torchmetrics
from packaging.version import Version
from torch import Tensor, nn
from torchmetrics.functional.text.bert import _DEFAULT_MODEL, bert_score
from transformers import logging as tfmers_logging
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers import logging as tfmers_logging

from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list
from aac_metrics.utils.collections import duplicate_list, flat_list, unflat_list
from aac_metrics.utils.globals import _get_device


DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
REDUCTIONS = ("mean", "max", "min")

@@ -32,7 +32,7 @@ def bert_score_mrefs(
    max_length: int = 64,
    reset_state: bool = True,
    idf: bool = False,
    reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
    reduction: Union[str, Callable[..., Tensor]] = "max",
    filter_nan: bool = True,
    verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
@@ -144,9 +144,16 @@ def bert_score_mrefs(
        reduction_fn = reduction

    if len(sizes) > 0 and all(size == sizes[0] for size in sizes):
        sents_scores = {
            k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items()
        }
        torchmetrics_version = Version(torchmetrics.__version__)
        if torchmetrics_version < Version("1.0.0"):
            # backward compatibility
            sents_scores = {
                k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1) for k, v in sents_scores.items()  # type: ignore
            }
        else:
            sents_scores = {
                k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items()  # type: ignore
            }
    else:
        sents_scores = {
            k: torch.stack([reduction_fn(torch.as_tensor(vi, dtype=dtype)) for vi in v])
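The hunk above is the core of the 0.5.4 fix: it branches on the installed torchmetrics version using `packaging.version.Version`. A self-contained sketch of the same idea follows; the score layout (one row per candidate, one column per reference) and the hard-coded `max` reduction are illustrative assumptions, not the library's exact code:

```python
import torch
import torchmetrics
from packaging.version import Version
from torch import Tensor

_OLD_TORCHMETRICS = Version(torchmetrics.__version__) < Version("1.0.0")


def reduce_mrefs(scores: dict, dtype: torch.dtype = torch.float32) -> dict[str, Tensor]:
    """Reduce per-reference scores to one value per candidate (max over references)."""
    if _OLD_TORCHMETRICS:
        # Backward-compatible path: old torchmetrics yields nested Python lists
        # of floats, so build a (n_cands, n_refs) tensor with as_tensor first.
        return {k: torch.as_tensor(v, dtype=dtype).max(dim=1).values for k, v in scores.items()}
    # Modern path: each entry is already a tensor of per-reference scores,
    # so stacking gives the same (n_cands, n_refs) shape directly.
    return {k: torch.stack(v).max(dim=1).values for k, v in scores.items()}


# Dummy scores for 2 candidates with 3 references each:
raw = [[0.7, 0.8, 0.6], [0.5, 0.9, 0.4]]
dummy = {"f1": raw if _OLD_TORCHMETRICS else [torch.tensor(r) for r in raw]}
print(reduce_mrefs(dummy))  # {'f1': tensor([0.8000, 0.9000])}
```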
14 changes: 8 additions & 6 deletions src/aac_metrics/functional/evaluate.py
@@ -3,13 +3,11 @@

import logging
import time

from functools import partial
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, Union

import torch

from torch import Tensor

from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs
@@ -28,7 +26,6 @@
from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents


pylog = logging.getLogger(__name__)


@@ -155,8 +152,8 @@ def evaluate(
            set(outs_sents_i.keys()).intersection(outs_sents.keys())
        )
        if len(corpus_overlap) > 0 or len(sents_overlap) > 0:
            pylog.warning(
                f"Found overlapping metric outputs names. (found {corpus_overlap=} and {sents_overlap=})"
            warn_once(
                f"Found overlapping metric outputs names. (found {corpus_overlap=} and {sents_overlap=} at least twice)"
            )

        outs_corpus |= outs_corpus_i
@@ -352,3 +349,8 @@ def _get_metric_factory_functions(
        ),
    }
    return factory


@cache
def warn_once(msg: str) -> None:
    pylog.warning(msg)
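The new `warn_once` helper relies on `functools.cache` memoizing on the message string, so each distinct warning is emitted at most once per process. A self-contained sketch of the same pattern with a usage example:

```python
import logging
from functools import cache

pylog = logging.getLogger(__name__)


@cache
def warn_once(msg: str) -> None:
    # functools.cache memoizes on msg, so a repeated call with the same string
    # is a cache hit and the logging call below never runs again.
    pylog.warning(msg)


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    warn_once("Found overlapping metric outputs names.")  # logged once
    warn_once("Found overlapping metric outputs names.")  # suppressed
```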
26 changes: 9 additions & 17 deletions src/aac_metrics/utils/checks.py
@@ -2,18 +2,16 @@
# -*- coding: utf-8 -*-

import logging
import re
import subprocess

from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Union
from typing_extensions import TypeGuard

from packaging.version import Version
from typing_extensions import TypeGuard

pylog = logging.getLogger(__name__)

VERSION_PATTERN = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+).*"
MIN_JAVA_MAJOR_VERSION = 8
MAX_JAVA_MAJOR_VERSION = 13

@@ -103,18 +101,12 @@ def _get_java_version(java_path: str) -> str:
    return version


def _check_java_version(version: str, min_major: int, max_major: int) -> bool:
    result = re.match(VERSION_PATTERN, version)
    if result is None:
        raise ValueError(
            f"Invalid Java version {version=}. (expected version with pattern={VERSION_PATTERN})"
        )

    major_version = int(result["major"])
    minor_version = int(result["minor"])
def _check_java_version(version_str: str, min_major: int, max_major: int) -> bool:
    version = Version(version_str)

    if major_version == 1 and minor_version <= 8:
        # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
        major_version = minor_version
    if version.major == 1 and version.minor <= 8:
        # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.MICRO"
        version_str = ".".join(map(str, (version.minor, version.micro)))
        version = Version(version_str)

    return min_major <= major_version <= max_major
    return Version(f"{min_major}") <= version < Version(f"{max_major + 1}")
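The rewritten check above replaces the old regex with `packaging.version.Version`. A standalone sketch, assuming the caller has already extracted a bare version string (e.g. "1.8.0" or "11.0.2") from the `java -version` output:

```python
from packaging.version import Version

MIN_JAVA_MAJOR_VERSION = 8
MAX_JAVA_MAJOR_VERSION = 13


def check_java_version(version_str: str, min_major: int, max_major: int) -> bool:
    """Return True if the Java version lies within [min_major, max_major]."""
    version = Version(version_str)
    if version.major == 1 and version.minor <= 8:
        # Java <= 8 reports "1.MAJOR.MINOR"; normalize it to "MAJOR.MINOR".
        version = Version(f"{version.minor}.{version.micro}")
    return Version(str(min_major)) <= version < Version(str(max_major + 1))


assert check_java_version("1.8.0", MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION)       # legacy "1.x" scheme
assert check_java_version("11.0.2", MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION)      # modern scheme
assert not check_java_version("17.0.1", MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION)  # above the supported range
```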