Version 0.5.2
Labbeti committed Jan 5, 2024
1 parent 1353169 commit bc1a25e
Showing 29 changed files with 276 additions and 286 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,11 @@

All notable changes to this project will be documented in this file.

+## [0.5.2] 2024-01-05
+### Changed
+- `aac-metrics` is now compatible with `transformers>=4.31`.
+- Rename default device value "auto" to "cuda_if_available".

## [0.5.1] 2023-12-20
### Added
- Check sentences inputs for all metrics.
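
A minimal sketch of the renamed device default, for reference. This is hypothetical usage: `FER` is one of the classes whose signature changes in this commit, and the resolution behavior is assumed rather than documented here.

```python
from aac_metrics.classes.fer import FER

# "cuda_if_available" replaces the old "auto" value. It presumably resolves to
# torch.device("cuda") when CUDA is available and falls back to the CPU otherwise.
fer_metric = FER(device="cuda_if_available")  # spelling out the new default
```
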
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -19,5 +19,5 @@ keywords:
- captioning
- audio-captioning
license: MIT
-version: 0.5.1
-date-released: '2023-12-20'
+version: 0.5.2
+date-released: '2024-01-05'
16 changes: 7 additions & 9 deletions README.md
@@ -124,7 +124,7 @@ Each metrics also exists as a python class version, like `aac_metrics.classes.ci
### Other metrics
| Metric name | Python Class | Origin | Range | Short description |
|:---|:---|:---|:---|:---|
-| Vocabulary | `Vocab` | text generation | [0, +$\infty$[ | Number of unique words in candidates. |
+| Vocabulary | `Vocab` | text generation | [0, +∞[ | Number of unique words in candidates. |
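
A hedged sketch of the new `Vocab` metric in use; the `(candidates, mult_references)` call convention is assumed from the other class-based metrics:

```python
from aac_metrics.classes.vocab import Vocab

candidates = ["a man is speaking", "rain falls on a window"]
mult_references = [["a man speaks"], ["heavy rain hits a window"]]

# Vocab counts unique words in the candidates, hence the [0, +∞[ range above.
vocab = Vocab()
score = vocab(candidates, mult_references)  # assumed signature
```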

### Future directions
This package currently does not include all metrics dedicated to audio captioning. Feel free to do a pull request / or ask to me by email if you want to include them. Those metrics not included are listed here:
@@ -146,15 +146,13 @@ numpy >= 1.21.2
pyyaml >= 6.0
tqdm >= 4.64.0
sentence-transformers >= 2.2.2
-transformers < 4.31.0
+transformers
torchmetrics >= 0.11.4
```

### External requirements
- `java` **>= 1.8 and <= 1.13** is required to compute METEOR, SPICE and use the PTBTokenizer.
-Most of these functions can specify a java executable path with `java_path` argument.
-
-- `unzip` command to extract SPICE zipped files.
+Most of these functions can specify a java executable path with `java_path` argument or by overriding `AAC_METRICS_JAVA_PATH` environment variable.
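
Both configuration routes, sketched below. The paths are placeholders, and `METEOR` stands in for any Java-backed metric; its `java_path` keyword is assumed from the sentence above.

```python
import os

from aac_metrics.classes.meteor import METEOR

# Option 1 (assumed keyword): point a single metric at a specific executable.
meteor = METEOR(java_path="/usr/lib/jvm/java-11-openjdk/bin/java")

# Option 2: override the environment variable for the whole process.
os.environ["AAC_METRICS_JAVA_PATH"] = "/usr/lib/jvm/java-11-openjdk/bin/java"
```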

## Additional notes
### CIDEr or CIDEr-D?
@@ -233,14 +231,14 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr

```
@software{
-Labbe_aac_metrics_2023,
+Labbe_aac_metrics_2024,
author = {Labbé, Etienne},
license = {MIT},
-month = {12},
+month = {01},
title = {{aac-metrics}},
url = {https://github.com/Labbeti/aac-metrics/},
-version = {0.5.1},
-year = {2023},
+version = {0.5.2},
+year = {2024},
}
```

4 changes: 3 additions & 1 deletion docs/aac_metrics.classes.rst
@@ -13,15 +13,17 @@ Submodules
:maxdepth: 4

aac_metrics.classes.base
+aac_metrics.classes.bert_score_mrefs
aac_metrics.classes.bleu
aac_metrics.classes.cider_d
aac_metrics.classes.evaluate
aac_metrics.classes.fense
-aac_metrics.classes.fluerr
+aac_metrics.classes.fer
aac_metrics.classes.meteor
aac_metrics.classes.rouge_l
aac_metrics.classes.sbert_sim
aac_metrics.classes.spice
aac_metrics.classes.spider
aac_metrics.classes.spider_fl
aac_metrics.classes.spider_max
+aac_metrics.classes.vocab
7 changes: 0 additions & 7 deletions docs/aac_metrics.evaluate.rst

This file was deleted.

4 changes: 3 additions & 1 deletion docs/aac_metrics.functional.rst
@@ -12,11 +12,12 @@ Submodules
.. toctree::
:maxdepth: 4

+aac_metrics.functional.bert_score_mrefs
aac_metrics.functional.bleu
aac_metrics.functional.cider_d
aac_metrics.functional.evaluate
aac_metrics.functional.fense
-aac_metrics.functional.fluerr
+aac_metrics.functional.fer
aac_metrics.functional.meteor
aac_metrics.functional.mult_cands
aac_metrics.functional.rouge_l
@@ -25,3 +26,4 @@ Submodules
aac_metrics.functional.spider
aac_metrics.functional.spider_fl
aac_metrics.functional.spider_max
+aac_metrics.functional.vocab
2 changes: 2 additions & 0 deletions docs/aac_metrics.utils.rst
@@ -13,6 +13,8 @@ Submodules
:maxdepth: 4

aac_metrics.utils.checks
+aac_metrics.utils.cmdline
+aac_metrics.utils.collections
aac_metrics.utils.globals
aac_metrics.utils.imports
aac_metrics.utils.tokenization
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -7,3 +7,5 @@ scikit-image==0.19.2
matplotlib==3.5.2
ipykernel==6.9.1
twine==4.0.1
+sphinx==7.2.6
+sphinx-press-theme==0.8.0
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,5 +5,5 @@ numpy>=1.21.2
pyyaml>=6.0
tqdm>=4.64.0
sentence-transformers>=2.2.2
-transformers<4.31.0
+transformers
torchmetrics>=0.11.4
2 changes: 1 addition & 1 deletion src/aac_metrics/__init__.py
@@ -10,7 +10,7 @@
__maintainer__ = "Etienne Labbé (Labbeti)"
__name__ = "aac-metrics"
__status__ = "Development"
-__version__ = "0.5.1"
+__version__ = "0.5.2"


from .classes.base import AACMetric
6 changes: 4 additions & 2 deletions src/aac_metrics/classes/base.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

+import math

from typing import Any, ClassVar, Generic, Optional, TypeVar, Union

from torch import nn, Tensor
@@ -19,9 +21,9 @@ class AACMetric(nn.Module, Generic[OutType]):
is_differentiable: ClassVar[Optional[bool]] = False

# The theorical minimal value of the main global score of the metric.
-min_value: ClassVar[Optional[float]] = None
+min_value: ClassVar[float] = -math.inf
# The theorical maximal value of the main global score of the metric.
-max_value: ClassVar[Optional[float]] = None
+max_value: ClassVar[float] = math.inf

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
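
With `min_value`/`max_value` now plain floats defaulting to infinities, unbounded metrics no longer report `None`. A hypothetical subclass under the new contract:

```python
import math
from typing import ClassVar

from aac_metrics.classes.base import AACMetric

# Hypothetical metric bounded below but not above, e.g. a raw word count:
# the unbounded side is now expressed with math.inf instead of None.
class WordCountMetric(AACMetric):
    min_value: ClassVar[float] = 0.0
    max_value: ClassVar[float] = math.inf
```
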
14 changes: 11 additions & 3 deletions src/aac_metrics/classes/bert_score_mrefs.py
@@ -6,13 +6,15 @@
import torch

from torch import nn, Tensor
-from torchmetrics.text.bert import _DEFAULT_MODEL

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bert_score_mrefs import (
bert_score_mrefs,
_load_model_and_tokenizer,
+DEFAULT_BERT_SCORE_MODEL,
+REDUCTIONS,
)
+from aac_metrics.utils.globals import _get_device


class BERTScoreMRefs(AACMetric):
@@ -35,8 +37,8 @@ def __init__(
def __init__(
self,
return_all_scores: bool = True,
-model: Union[str, nn.Module] = _DEFAULT_MODEL,
-device: Union[str, torch.device, None] = "auto",
+model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
+device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
num_threads: int = 0,
max_length: int = 64,
@@ -46,6 +48,12 @@ def __init__(
filter_nan: bool = True,
verbose: int = 0,
) -> None:
+if reduction not in REDUCTIONS:
+    raise ValueError(
+        f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})"
+    )
+
+device = _get_device(device)
model, tokenizer = _load_model_and_tokenizer(
model=model,
tokenizer=None,
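
The constructor now validates `reduction` and resolves the device up front. A hedged sketch, assuming "mean" is among the accepted `REDUCTIONS` and "median" is not:

```python
from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs

# Valid: the device string is resolved via _get_device at construction time.
metric = BERTScoreMRefs(reduction="mean", device="cuda_if_available")

# Unknown reductions now fail fast with ValueError instead of erroring later.
try:
    BERTScoreMRefs(reduction="median")
except ValueError as err:
    print(err)
```
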
12 changes: 6 additions & 6 deletions src/aac_metrics/classes/bleu.py
@@ -81,12 +81,12 @@ def update(
mult_references: list[list[str]],
) -> None:
self._cooked_cands, self._cooked_mrefs = _bleu_update(
-candidates,
-mult_references,
-self._n,
-self._tokenizer,
-self._cooked_cands,
-self._cooked_mrefs,
+candidates=candidates,
+mult_references=mult_references,
+n=self._n,
+tokenizer=self._tokenizer,
+prev_cooked_cands=self._cooked_cands,
+prev_cooked_mrefs=self._cooked_mrefs,
)


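The `_bleu_update` call above is behavior-preserving, only rewritten with keyword arguments. Hedged usage of the class, where the `BLEU` name and the accumulate-then-compute convention are assumptions:

```python
from aac_metrics.classes.bleu import BLEU  # class name assumed from the module path

bleu = BLEU()
bleu.update(
    candidates=["a dog barks in the distance"],
    mult_references=[["a dog is barking", "a dog barks far away"]],
)
score = bleu.compute()  # assumed torchmetrics-style accumulation
```
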
8 changes: 4 additions & 4 deletions src/aac_metrics/classes/evaluate.py
@@ -55,7 +55,7 @@ def __init__(
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
) -> None:
metrics = _instantiate_metrics_classes(
@@ -127,7 +127,7 @@ def __init__(
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
) -> None:
super().__init__(
@@ -146,7 +146,7 @@ def _instantiate_metrics_classes(
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
) -> list[AACMetric]:
if isinstance(metrics, str) and metrics in METRICS_SETS:
@@ -179,7 +179,7 @@ def _get_metric_factory_classes(
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
init_kwds: Optional[dict[str, Any]] = None,
) -> dict[str, Callable[[], AACMetric]]:
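
All four construction paths above now share the same "cuda_if_available" default. A hedged sketch of the top-level evaluator:

```python
from aac_metrics.classes.evaluate import Evaluate

# The metrics argument keeps its own default set (assumed); only the device
# is written out explicitly here, matching the new default value.
evaluate = Evaluate(device="cuda_if_available")
```
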
21 changes: 14 additions & 7 deletions src/aac_metrics/classes/fense.py
@@ -7,11 +7,17 @@

import torch

+from sentence_transformers import SentenceTransformer
from torch import Tensor

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.fense import fense, _load_models_and_tokenizer
-from aac_metrics.functional.fer import ERROR_NAMES
+from aac_metrics.functional.fer import (
+    BERTFlatClassifier,
+    _ERROR_NAMES,
+    DEFAULT_FER_MODEL,
+)
+from aac_metrics.functional.sbert_sim import DEFAULT_SBERT_SIM_MODEL


pylog = logging.getLogger(__name__)
@@ -36,10 +42,10 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
def __init__(
self,
return_all_scores: bool = True,
-sbert_model: str = "paraphrase-TinyBERT-L6-v2",
-echecker: str = "echecker_clotho_audiocaps_base",
+sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
+echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
error_threshold: float = 0.9,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
reset_state: bool = True,
return_probs: bool = False,
@@ -99,9 +105,10 @@ def extra_repr(self) -> str:
return repr_

def get_output_names(self) -> tuple[str, ...]:
return ("sbert_sim", "fer", "fense") + tuple(
f"fer.{name}_prob" for name in ERROR_NAMES
)
output_names = ["sbert_sim", "fer", "fense"]
if self._return_probs:
output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
return tuple(output_names)

def reset(self) -> None:
self._candidates = []
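
`get_output_names` now mirrors `return_probs`, so per-error probability keys are only advertised when they will actually be returned. Hedged usage; the exact error names come from `_ERROR_NAMES` and are not spelled out here:

```python
from aac_metrics.classes.fense import FENSE

fense = FENSE(return_probs=True)
print(fense.get_output_names())  # ("sbert_sim", "fer", "fense", "fer.<error>_prob", ...)

fense = FENSE(return_probs=False)
print(fense.get_output_names())  # ("sbert_sim", "fer", "fense")
```
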
23 changes: 18 additions & 5 deletions src/aac_metrics/classes/fer.py
@@ -11,10 +11,13 @@

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.fer import (
+BERTFlatClassifier,
fer,
_load_echecker_and_tokenizer,
-ERROR_NAMES,
+_ERROR_NAMES,
+DEFAULT_FER_MODEL,
)
+from aac_metrics.utils.globals import _get_device


pylog = logging.getLogger(__name__)
@@ -39,15 +42,22 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]])
def __init__(
self,
return_all_scores: bool = True,
-echecker: str = "echecker_clotho_audiocaps_base",
+echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
error_threshold: float = 0.9,
-device: Union[str, torch.device, None] = "auto",
+device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
reset_state: bool = True,
return_probs: bool = False,
verbose: int = 0,
) -> None:
-echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose) # type: ignore
+device = _get_device(device)
+echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
+    echecker=echecker,
+    echecker_tokenizer=None,
+    device=device,
+    reset_state=reset_state,
+    verbose=verbose,
+)

super().__init__()
self._return_all_scores = return_all_scores
@@ -82,7 +92,10 @@ def extra_repr(self) -> str:
return repr_

def get_output_names(self) -> tuple[str, ...]:
return ("fer",) + tuple(f"fer.{name}_prob" for name in ERROR_NAMES)
output_names = ["fer"]
if self._return_probs:
output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
return tuple(output_names)

def reset(self) -> None:
self._candidates = []
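
`_get_device` centralizes the translation of the new device string. A hedged sketch of the helper; it is private, and its exact fallback behavior is assumed:

```python
from aac_metrics.utils.globals import _get_device

device = _get_device("cuda_if_available")
print(device)  # presumably torch.device("cuda") when available, else torch.device("cpu")
```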