Version 0.4.3
Labbeti committed Jun 15, 2023
1 parent b74e2ad · commit f087ae9
Showing 19 changed files with 115 additions and 1,538 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/python-package-pip.yaml
@@ -26,7 +26,7 @@ jobs:
submodules: recursive

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
@@ -39,8 +39,9 @@ jobs:
java-package: jre

- name: Install package
# note: ${GITHUB_REF##*/} gives the branch name
run: |
python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@dev"
python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}"
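As a side note on the new install step: `${GITHUB_REF##*/}` keeps only the text after the last `/` of the ref, i.e. the branch name. A rough Python equivalent (illustration only; the fallback value is hypothetical):

```python
import os

# ${GITHUB_REF##*/} strips everything up to the last "/", leaving the branch name.
ref = os.environ.get("GITHUB_REF", "refs/heads/dev")  # hypothetical fallback value
branch = ref.rsplit("/", 1)[-1]
print(branch)  # -> "dev" for refs/heads/dev
```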
- name: Load cache of external code and data
uses: actions/cache@master
@@ -52,6 +53,10 @@
${{ runner.os }}-
# --- TESTS ---
- name: Compile python files
run: |
python -m compileall src
- name: Lint with flake8
run: |
python -m flake8 --config .flake8 --exit-zero --show-source --statistics src
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,20 @@

All notable changes to this project will be documented in this file.

## [0.4.4] UNRELEASED
### Changed
- TODO

## [0.4.3] 2023-06-15
### Changed
- `AACMetric` is no longer a subclass of `torchmetrics.Metric`, even when torchmetrics is installed. This avoids a dependency on that package and removes potential errors caused by `Metric`.
- Java 12 and 13 are now allowed.

### Fixed
- Output name `sbert_sim` in FENSE and SBERTSim classes.
- `Evaluate` class instantiation with `torchmetrics` >= 0.11.
- `evaluate.py` script when using a verbose mode != 0.

## [0.4.2] 2023-04-19
### Fixed
- File `install_spice.sh` is now in `src/aac_metrics` directory to fix download from a pip installation. ([#3](https://github.com/Labbeti/aac-metrics/issues/3))
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -19,5 +19,5 @@ keywords:
- captioning
- audio-captioning
license: MIT
version: 0.4.2
date-released: '2023-04-19'
version: 0.4.3
date-released: '2023-06-15'
8 changes: 2 additions & 6 deletions README.md
@@ -138,10 +138,6 @@ The CIDEr metric differs from CIDEr-D because it applies a stemmer to each word
### Does metrics work on multi-GPU ?
No. Most of these metrics use numpy or external java programs to run, which prevents multi-GPU testing for now.

### Is torchmetrics needed for this package ?
No. But if torchmetrics is installed, all metrics classes will inherit from the base class `torchmetrics.Metric`.
It is because most of the metrics does not use PyTorch tensors to compute scores and numpy and strings cannot be added to states of `torchmetrics.Metric`.

## SPIDEr-max metric
SPIDEr-max [[7]](#spider-max) is a metric based on SPIDEr that takes into account multiple candidates for the same audio. It computes the maximum of the SPIDEr scores for each candidate to balance the high sensitivity to the frequency of the words generated by the model. For more detail, please see the [documentation about SPIDEr-max](https://aac-metrics.readthedocs.io/en/stable/spider_max.html).
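As a toy illustration of that reduction (the scores are made up, and the corpus value is shown here as a plain average of the per-audio maxima, which is an assumption of this sketch rather than a statement about the exact aggregation in aac-metrics):

```python
# Made-up SPIDEr scores for several candidates generated for each audio clip.
spider_scores = [
    [0.21, 0.34, 0.28],  # audio 1, 3 candidates
    [0.40, 0.39, 0.45],  # audio 2, 3 candidates
]
spider_max = [max(scores) for scores in spider_scores]  # per-audio maximum
corpus_spider_max = sum(spider_max) / len(spider_max)   # simple corpus average
print(spider_max, corpus_spider_max)  # [0.34, 0.45] 0.395
```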

@@ -216,10 +212,10 @@ If you use this software, please consider cite it as below :
Labbe_aac-metrics_2023,
author = {Labbé, Etienne},
license = {MIT},
month = {4},
month = {6},
title = {{aac-metrics}},
url = {https://github.com/Labbeti/aac-metrics/},
version = {0.4.2},
version = {0.4.3},
year = {2023},
}
```
465 changes: 0 additions & 465 deletions examples/example_1.csv

This file was deleted.

913 changes: 0 additions & 913 deletions examples/example_2.csv

This file was deleted.

55 changes: 0 additions & 55 deletions install_spice.sh

This file was deleted.

9 changes: 5 additions & 4 deletions pyproject.toml
@@ -31,10 +31,10 @@ dependencies = [
dynamic = ["version"]

[project.urls]
homepage = "https://pypi.org/project/aac-metrics/"
documentation = "https://aac-metrics.readthedocs.io/"
repository = "https://github.com//Labbeti/aac-metrics.git"
changelog = "https://github.com/Labbeti/aac-metrics/blob/main/CHANGELOG.md"
Homepage = "https://pypi.org/project/aac-metrics/"
Documentation = "https://aac-metrics.readthedocs.io/"
Repository = "https://github.com//Labbeti/aac-metrics.git"
Changelog = "https://github.com/Labbeti/aac-metrics/blob/main/CHANGELOG.md"

[project.scripts]
aac-metrics = "aac_metrics.__main__:_print_usage"
@@ -49,6 +49,7 @@ dev = [
"black==22.8.0",
"scikit-image==0.19.2",
"matplotlib==3.5.2",
"torchmetrics>=0.10",
]

[tool.setuptools.packages.find]
2 changes: 1 addition & 1 deletion src/aac_metrics/__init__.py
@@ -10,7 +10,7 @@
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__status__ = "Development"
__version__ = "0.4.2"
__version__ = "0.4.3"


from .classes.base import AACMetric
59 changes: 25 additions & 34 deletions src/aac_metrics/classes/base.py
@@ -3,46 +3,37 @@

from typing import Any, Optional

from aac_metrics.utils.imports import _TORCHMETRICS_AVAILABLE
from torch import nn


if _TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as __BaseMetric # type: ignore
class AACMetric(nn.Module):
"""Base Metric module used when torchmetrics is not installed."""

class AACMetric(__BaseMetric): # type: ignore
# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None
# Global values
full_state_update: Optional[bool] = False
higher_is_better: Optional[bool] = None
is_differentiable: Optional[bool] = False

else:
from torch import nn
# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None

class AACMetric(nn.Module):
"""Base Metric module used when torchmetrics is not installed."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# Global values
full_state_update: Optional[bool] = False
higher_is_better: Optional[bool] = None
is_differentiable: Optional[bool] = False
# Public methods
def compute(self) -> Any:
return None

# The theorical minimal value of the main global score of the metric.
min_value: Optional[float] = None
# The theorical maximal value of the main global score of the metric.
max_value: Optional[float] = None
def forward(self, *args: Any, **kwargs: Any) -> Any:
self.update(*args, **kwargs)
output = self.compute()
self.reset()
return output

# Public methods
def compute(self) -> Any:
return None
def reset(self) -> None:
pass

def forward(self, *args: Any, **kwargs: Any) -> Any:
self.update(*args, **kwargs)
output = self.compute()
self.reset()
return output

def reset(self) -> None:
pass

def update(self, *args, **kwargs) -> None:
pass
def update(self, *args, **kwargs) -> None:
pass
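To make the new contract concrete, here is a minimal, hypothetical subclass of the rewritten `AACMetric` (not part of aac-metrics); as defined above, `forward` chains `update`, `compute` and `reset`, so instances can be called like any `nn.Module`:

```python
from typing import Any

import torch

from aac_metrics import AACMetric  # re-exported in __init__.py, see the diff above


class CandidateCount(AACMetric):
    """Toy metric: counts how many candidates were seen."""

    min_value = 0.0
    max_value = None

    def __init__(self) -> None:
        super().__init__()
        self._count = 0

    def update(self, candidates: list[str], mult_references: list[list[str]]) -> None:
        self._count += len(candidates)

    def compute(self) -> Any:
        return torch.as_tensor(self._count)

    def reset(self) -> None:
        self._count = 0


metric = CandidateCount()
score = metric(["a man is speaking"], [["a man speaks", "someone talks"]])
print(score)  # tensor(1); forward() ran update, compute, then reset
```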
4 changes: 2 additions & 2 deletions src/aac_metrics/classes/evaluate.py
@@ -26,7 +26,7 @@
pylog = logging.getLogger(__name__)


class Evaluate(AACMetric, list[AACMetric]):
class Evaluate(list[AACMetric], AACMetric):
"""Evaluate candidates with multiple references with custom metrics.
For more information, see :func:`~aac_metrics.functional.evaluate.evaluate`.
@@ -55,8 +55,8 @@ def __init__(
verbose,
)

AACMetric.__init__(self)
list.__init__(self, metrics)
AACMetric.__init__(self)
self._preprocess = preprocess
self._cache_path = cache_path
self._java_path = java_path
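The base-class swap above changes the method resolution order of `Evaluate`, and the two `__init__` calls are reordered to match. A toy sketch of what the declared order changes (it does not reproduce the exact `torchmetrics >= 0.11` failure mentioned in the changelog):

```python
class A:
    """Stand-in for a metric base class."""


class ListFirst(list, A):   # mirrors the new Evaluate(list[AACMetric], AACMetric)
    pass


class BaseFirst(A, list):   # mirrors the old Evaluate(AACMetric, list[AACMetric])
    pass


print([c.__name__ for c in ListFirst.__mro__])  # ['ListFirst', 'list', 'A', 'object']
print([c.__name__ for c in BaseFirst.__mro__])  # ['BaseFirst', 'A', 'list', 'object']
```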
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/fense.py
@@ -85,7 +85,7 @@ def extra_repr(self) -> str:
return f"error_threshold={self._error_threshold}, penalty={self._penalty}, device={self._device}, batch_size={self._batch_size}"

def get_output_names(self) -> tuple[str, ...]:
return ("sbert.sim", "fluerr", "fense") + tuple(
return ("sbert_sim", "fluerr", "fense") + tuple(
f"fluerr.{name}_prob" for name in ERROR_NAMES
)

2 changes: 1 addition & 1 deletion src/aac_metrics/classes/sbert_sim.py
@@ -71,7 +71,7 @@ def extra_repr(self) -> str:
return f"device={self._device}, batch_size={self._batch_size}"

def get_output_names(self) -> tuple[str, ...]:
return ("sbert.sim",)
return ("sbert_sim",)

def reset(self) -> None:
self._candidates = []
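With the two fixes above, the sentence-BERT similarity is exposed under `sbert_sim` in both classes. A quick check, assuming the classes are importable from these module paths and can be built with default arguments (construction loads an SBERT model):

```python
from aac_metrics.classes.fense import FENSE
from aac_metrics.classes.sbert_sim import SBERTSim

print(SBERTSim().get_output_names())  # ('sbert_sim',)
print(FENSE().get_output_names())     # starts with ('sbert_sim', 'fluerr', 'fense', ...)
```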
16 changes: 8 additions & 8 deletions src/aac_metrics/evaluate.py
@@ -210,14 +210,14 @@ def _main_evaluate() -> None:
)

corpus_scores, _sents_scores = evaluate(
candidates,
mult_references,
True,
args.metrics_set_name,
args.cache_path,
args.java_path,
args.tmp_path,
args.verbose,
candidates=candidates,
mult_references=mult_references,
preprocess=True,
metrics=args.metrics_set_name,
cache_path=args.cache_path,
java_path=args.java_path,
tmp_path=args.tmp_path,
verbose=args.verbose,
)

corpus_scores = {k: v.item() for k, v in corpus_scores.items()}
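The same keyword-argument style works from user code. A minimal sketch using the functional entry point referenced by the `Evaluate` docstring; the captions are made up, non-default paths are omitted, and actually running it requires the external Java tools that aac-metrics downloads:

```python
from aac_metrics.functional.evaluate import evaluate

candidates = ["a man is speaking"]                          # made-up example caption
mult_references = [["a man speaks", "someone is talking"]]  # made-up references

corpus_scores, sents_scores = evaluate(
    candidates=candidates,
    mult_references=mult_references,
    preprocess=True,
    verbose=0,
)
print({k: v.item() for k, v in corpus_scores.items()})
```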
24 changes: 15 additions & 9 deletions src/aac_metrics/functional/bleu.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import math

from collections import Counter
@@ -11,6 +12,8 @@
from torch import Tensor


pylog = logging.getLogger(__name__)

BLEU_OPTIONS = ("shortest", "average", "closest")


@@ -165,7 +168,7 @@ def __cook_references(
for ref in refs:
rl, counts = __cook_sentence(ref, n, tokenizer)
reflen.append(rl)
for (ngram, count) in counts.items():
for ngram, count in counts.items():
maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)

# Calculate effective reference sentence length.
@@ -176,7 +179,7 @@

# lhuang: N.B.: leave reflen computaiton to the very end!!
# lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
return (reflen, maxcounts)
return reflen, maxcounts


def __cook_candidate(
@@ -203,7 +206,7 @@ def __cook_candidate(
result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
result["correct"] = [0] * n

for (ngram, count) in counts.items():
for ngram, count in counts.items():
result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

return result
@@ -221,7 +224,10 @@ def __compute_bleu_score(
bleu_list = [[] for _ in range(n)]

if option is None:
option = "average" if len(cooked_mrefs) == 1 else "closest"
if len(cooked_mrefs) == 1:
option = "average"
else:
option = "closest"

global_cands_len = 0
global_mrefs_len = 0
@@ -254,8 +260,8 @@
for k in range(n):
bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

if verbose > 1:
print(comps, reflen)
if verbose > 2:
pylog.debug(comps, reflen)

totalcomps["reflen"] = global_mrefs_len
totalcomps["testlen"] = global_cands_len
@@ -274,9 +280,9 @@
for k in range(n):
bleus[k] *= math.exp(1 - 1 / ratio)

if verbose > 0:
print(totalcomps)
print("ratio:", ratio)
if verbose > 2:
pylog.debug(totalcomps)
pylog.debug("ratio:", ratio)

return bleus, bleu_list

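The BLEU internals now go through the module logger rather than `print`, and only when `verbose > 2`. A sketch of how those records could be surfaced, assuming the module keeps the `aac_metrics.functional.bleu` import path:

```python
import logging

# Route the package's debug records (including the BLEU ones above) to the console.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("aac_metrics.functional.bleu").setLevel(logging.DEBUG)
```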