Commit 1353169: Version 0.5.1

Labbeti committed Dec 20, 2023 · 1 parent 45139d6
Showing 33 changed files with 323 additions and 733 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@
 
 All notable changes to this project will be documented in this file.
 
+## [0.5.1] 2023-12-20
+### Added
+- Check sentences inputs for all metrics.
+
+### Fixed
+- Fix `BERTScoreMRefs` metric with 1 candidate and 1 reference.
+
 ## [0.5.0] 2023-12-08
 ### Added
 - New `Vocab` metric to compute vocabulary size and vocabulary ratio.
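A minimal sketch of the fixed case, using only the public API touched by this commit (the exact return shape and output keys are assumptions; models are downloaded on first use):

```python
from aac_metrics import BERTScoreMRefs

# Exactly 1 candidate with exactly 1 reference: the case fixed in 0.5.1.
# torchmetrics' bert_score returns bare floats here instead of lists,
# which broke the per-sentence score handling before this release.
candidates = ["a man is speaking"]
mult_references = [["a man speaks loudly"]]

metric = BERTScoreMRefs()
corpus_scores, sents_scores = metric(candidates, mult_references)

# The new input check also rejects malformed inputs early, e.g.:
# metric(["a", "b"], [["x"]])  # size mismatch -> raises an error
```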
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -19,5 +19,5 @@ keywords:
 - captioning
 - audio-captioning
 license: MIT
-version: 0.5.0
-date-released: '2023-12-08'
+version: 0.5.1
+date-released: '2023-12-20'
2 changes: 1 addition & 1 deletion README.md
@@ -239,7 +239,7 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr
     month = {12},
     title = {{aac-metrics}},
     url = {https://github.com/Labbeti/aac-metrics/},
-    version = {0.5.0},
+    version = {0.5.1},
     year = {2023},
 }
 ```
7 changes: 0 additions & 7 deletions docs/aac_metrics.classes.fluerr.rst

This file was deleted.

7 changes: 7 additions & 0 deletions docs/aac_metrics.utils.globals.rst
@@ -0,0 +1,7 @@
+aac\_metrics.utils.globals module
+=================================
+
+.. automodule:: aac_metrics.utils.globals
+   :members:
+   :undoc-members:
+   :show-inheritance:
7 changes: 0 additions & 7 deletions docs/aac_metrics.utils.paths.rst

This file was deleted.

6 changes: 4 additions & 2 deletions src/aac_metrics/__init__.py
@@ -10,10 +10,11 @@
 __maintainer__ = "Etienne Labbé (Labbeti)"
 __name__ = "aac-metrics"
 __status__ = "Development"
-__version__ = "0.5.0"
+__version__ = "0.5.1"
 
 
 from .classes.base import AACMetric
+from .classes.bert_score_mrefs import BERTScoreMRefs
 from .classes.bleu import BLEU
 from .classes.cider_d import CIDErD
 from .classes.evaluate import Evaluate, DCASE2023Evaluate, _get_metric_factory_classes
@@ -28,7 +29,7 @@
 from .classes.spider_max import SPIDErMax
 from .classes.vocab import Vocab
 from .functional.evaluate import evaluate, dcase2023_evaluate
-from .utils.paths import (
+from .utils.globals import (
     get_default_cache_path,
     get_default_java_path,
     get_default_tmp_path,
@@ -40,6 +41,7 @@

 __all__ = [
     "AACMetric",
+    "BERTScoreMRefs",
     "BLEU",
     "CIDErD",
     "Evaluate",
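Two import-level changes land in this file: `BERTScoreMRefs` becomes a top-level export, and the path helpers now come from `aac_metrics.utils.globals` instead of `aac_metrics.utils.paths`. A short sketch of the updated imports (printed values depend on the local setup):

```python
from aac_metrics import BERTScoreMRefs, get_default_cache_path, get_default_java_path

# Path helpers kept their names; only their home module changed
# (aac_metrics.utils.paths -> aac_metrics.utils.globals).
print(get_default_cache_path())
print(get_default_java_path())

metric = BERTScoreMRefs()  # now importable without the full classes path
```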
2 changes: 2 additions & 0 deletions src/aac_metrics/classes/__init__.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from .bert_score_mrefs import BERTScoreMRefs
 from .bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
 from .cider_d import CIDErD
 from .evaluate import DCASE2023Evaluate, Evaluate
@@ -17,6 +18,7 @@


 __all__ = [
+    "BERTScoreMRefs",
     "BLEU",
     "BLEU1",
     "BLEU2",
6 changes: 5 additions & 1 deletion src/aac_metrics/classes/bert_score_mrefs.py
@@ -47,7 +47,11 @@ def __init__(
         verbose: int = 0,
     ) -> None:
         model, tokenizer = _load_model_and_tokenizer(
-            model, None, device, reset_state, verbose
+            model=model,
+            tokenizer=None,
+            device=device,
+            reset_state=reset_state,
+            verbose=verbose,
         )
 
         super().__init__()
9 changes: 8 additions & 1 deletion src/aac_metrics/classes/fense.py
@@ -46,7 +46,14 @@ def __init__(
         penalty: float = 0.9,
         verbose: int = 0,
     ) -> None:
-        sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(sbert_model, echecker, None, device, reset_state, verbose)  # type: ignore
+        sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(
+            sbert_model=sbert_model,
+            echecker=echecker,
+            echecker_tokenizer=None,
+            device=device,
+            reset_state=reset_state,
+            verbose=verbose,
+        )
 
         super().__init__()
         self._return_all_scores = return_all_scores
2 changes: 1 addition & 1 deletion src/aac_metrics/download.py
@@ -25,7 +25,7 @@
     check_spice_install,
 )
 from aac_metrics.utils.cmdline import _str_to_bool, _setup_logging
-from aac_metrics.utils.paths import (
+from aac_metrics.utils.globals import (
     _get_cache_path,
     _get_tmp_path,
     get_default_cache_path,
2 changes: 1 addition & 1 deletion src/aac_metrics/eval.py
@@ -19,7 +19,7 @@
 )
 from aac_metrics.utils.checks import check_metric_inputs, check_java_path
 from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, _setup_logging
-from aac_metrics.utils.paths import (
+from aac_metrics.utils.globals import (
     get_default_cache_path,
     get_default_java_path,
     get_default_tmp_path,
2 changes: 2 additions & 0 deletions src/aac_metrics/functional/__init__.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from .bert_score_mrefs import bert_score_mrefs
 from .bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4
 from .cider_d import cider_d
 from .evaluate import dcase2023_evaluate, evaluate
@@ -17,6 +18,7 @@


 __all__ = [
+    "bert_score_mrefs",
     "bleu",
     "bleu_1",
     "bleu_2",
42 changes: 25 additions & 17 deletions src/aac_metrics/functional/bert_score_mrefs.py
@@ -11,7 +11,9 @@
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers import logging as tfmers_logging
 
+from aac_metrics.utils.checks import check_metric_inputs
 from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list
+from aac_metrics.utils.globals import _get_device
 
 
 def bert_score_mrefs(
@@ -56,13 +58,20 @@ def bert_score_mrefs(
     :param verbose: The verbose level. defaults to 0.
     :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
     """
+    check_metric_inputs(candidates, mult_references)
+
     if isinstance(model, str):
         if tokenizer is not None:
             raise ValueError(
                 f"Invalid argument combinaison {model=} with {tokenizer=}."
             )
 
         model, tokenizer = _load_model_and_tokenizer(
-            model, tokenizer, device, reset_state, verbose
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            reset_state=reset_state,
+            verbose=verbose,
         )
 
     elif isinstance(model, nn.Module):
@@ -76,21 +85,18 @@
f"Invalid argument type {type(model)=}. (expected str or nn.Module)"
)

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
flat_mrefs, sizes = flat_list(mult_references)
duplicated_cands = duplicate_list(candidates, sizes)
assert len(duplicated_cands) == len(flat_mrefs)

tfmers_verbosity = tfmers_logging.get_verbosity()
if verbose <= 1:
tfmers_logging.set_verbosity_error()

sents_scores = bert_score(
duplicated_cands,
flat_mrefs,
preds=duplicated_cands,
target=flat_mrefs,
model_name_or_path=None,
model=model, # type: ignore
user_tokenizer=tokenizer,
Expand All @@ -105,6 +111,12 @@ def bert_score_mrefs(
     # Restore previous verbosity level
     tfmers_logging.set_verbosity(tfmers_verbosity)
 
+    # note: torchmetrics returns a float if input contains 1 cand and 1 ref, even in list
+    if len(duplicated_cands) == 1 and all(
+        isinstance(v, float) for v in sents_scores.values()
+    ):
+        sents_scores = {k: [v] for k, v in sents_scores.items()}
+
     # sents_scores keys: "precision", "recall", "f1"
     sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()}  # type: ignore
 
@@ -116,9 +128,9 @@
     if reduction == "mean":
         reduction_fn = torch.mean
     elif reduction == "max":
-        reduction_fn = max_reduce
+        reduction_fn = _max_reduce
     elif reduction == "min":
-        reduction_fn = min_reduce
+        reduction_fn = _min_reduce
     else:
         REDUCTIONS = ("mean", "max", "min")
         raise ValueError(
@@ -161,11 +173,7 @@ def _load_model_and_tokenizer(
 ) -> tuple[nn.Module, Optional[Callable]]:
     state = torch.random.get_rng_state()
 
-    if device == "auto":
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-    if isinstance(device, str):
-        device = torch.device(device)
-
+    device = _get_device(device)
     if isinstance(model, str):
         tfmers_verbosity = tfmers_logging.get_verbosity()
         if verbose <= 1:
@@ -188,14 +196,14 @@
     return model, tokenizer  # type: ignore
 
 
-def max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
+def _max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
     if dim is None:
         return x.max()
     else:
         return x.max(dim=dim).values
 
 
-def min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
+def _min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
     if dim is None:
         return x.min()
     else:
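Both device-resolution blocks removed above were inlined copies of the same logic; they are now centralized in `aac_metrics.utils.globals._get_device`. That module is not shown in this diff, so the sketch below only mirrors the removed inline code, and its exact signature is an assumption:

```python
from typing import Optional, Union

import torch


def _get_device(device: Union[str, torch.device, None]) -> Optional[torch.device]:
    # "auto" resolves to CUDA when available, else CPU (as in the removed code).
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Plain strings are converted to torch.device; None passes through unchanged.
    if isinstance(device, str):
        device = torch.device(device)
    return device
```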
8 changes: 4 additions & 4 deletions src/aac_metrics/functional/bleu.py
@@ -11,6 +11,8 @@

 from torch import Tensor
 
+from aac_metrics.utils.checks import check_metric_inputs
+
 
 pylog = logging.getLogger(__name__)
 
@@ -158,10 +160,8 @@ def _bleu_update(
     prev_cooked_cands: list,
     prev_cooked_mrefs: list,
 ) -> tuple[list, list[tuple]]:
-    if len(candidates) != len(mult_references):
-        raise ValueError(
-            f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})"
-        )
+    check_metric_inputs(candidates, mult_references)
+
     new_cooked_mrefs = [
         __cook_references(refs, None, n, tokenizer) for refs in mult_references
     ]
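BLEU and CIDEr-D (below) now delegate validation to the shared `aac_metrics.utils.checks.check_metric_inputs`, in line with the changelog entry "Check sentences inputs for all metrics". Its implementation is not part of this diff; judging from the ValueError it replaces, it presumably validates at least types and matching lengths, along these lines (a hypothetical sketch, not the actual helper):

```python
from typing import Any


def check_metric_inputs(candidates: Any, mult_references: Any) -> None:
    # Hypothetical reconstruction; the real helper lives in aac_metrics.utils.checks.
    if not (isinstance(candidates, list) and all(isinstance(c, str) for c in candidates)):
        raise ValueError("Invalid candidates type. (expected list[str])")
    if not (
        isinstance(mult_references, list)
        and all(
            isinstance(refs, list) and all(isinstance(ref, str) for ref in refs)
            for refs in mult_references
        )
    ):
        raise ValueError("Invalid mult_references type. (expected list[list[str]])")
    if len(candidates) != len(mult_references):
        raise ValueError(
            f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})"
        )
```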
7 changes: 3 additions & 4 deletions src/aac_metrics/functional/cider_d.py
@@ -9,6 +9,8 @@

 from torch import Tensor
 
+from aac_metrics.utils.checks import check_metric_inputs
+
 
 def cider_d(
     candidates: list[str],
@@ -66,10 +68,7 @@ def _cider_d_update(
     prev_cooked_cands: list[Counter],
     prev_cooked_mrefs: list[list[Counter]],
 ) -> tuple[list, list]:
-    if len(candidates) != len(mult_references):
-        raise ValueError(
-            f"Invalid number of candidates and references. (found {len(candidates)=} != {len(mult_references)=})"
-        )
+    check_metric_inputs(candidates, mult_references)
     new_cooked_mrefs = [
         [__cook_sentence(ref, n, tokenizer) for ref in refs] for refs in mult_references
     ]
17 changes: 10 additions & 7 deletions src/aac_metrics/functional/fense.py
@@ -1,11 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-"""FENSE metric functional API.
-Based on original implementation in https://github.com/blmoistawinde/fense/
-"""
-
 import logging
 
 from typing import Optional, Union
@@ -22,6 +17,7 @@
     BERTFlatClassifier,
 )
 from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert
+from aac_metrics.utils.checks import check_metric_inputs
 
 
 pylog = logging.getLogger(__name__)
@@ -71,6 +67,7 @@ def fense(
     :param verbose: The verbose level. defaults to 0.
     :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
     """
+    check_metric_inputs(candidates, mult_references)
 
     # Init models
     sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(
@@ -148,8 +145,14 @@ def _load_models_and_tokenizer(
     reset_state: bool = True,
     verbose: int = 0,
 ) -> tuple[SentenceTransformer, BERTFlatClassifier, AutoTokenizer]:
-    sbert_model = _load_sbert(sbert_model, device, reset_state)
+    sbert_model = _load_sbert(
+        sbert_model=sbert_model, device=device, reset_state=reset_state
+    )
     echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
-        echecker, echecker_tokenizer, device, reset_state, verbose
+        echecker=echecker,
+        echecker_tokenizer=echecker_tokenizer,
+        device=device,
+        reset_state=reset_state,
+        verbose=verbose,
+    )
     return sbert_model, echecker, echecker_tokenizer
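With the check added above, `fense` now fails fast on malformed inputs instead of erroring deep inside the SBERT/error-checker pipeline. A usage sketch, assuming `fense` is re-exported from `aac_metrics.functional` like the other metrics and that `return_all_scores` defaults to True:

```python
from aac_metrics.functional import fense

candidates = ["birds chirp near a stream"]
mult_references = [["birds are chirping", "a stream flows while birds chirp"]]

# check_metric_inputs(candidates, mult_references) runs first,
# so shape/type errors surface before any model is loaded or run.
corpus_scores, sents_scores = fense(candidates, mult_references)
```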