Skip to content

Commit

Permalink
Version 0.5.3
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Jan 9, 2024
1 parent bc1a25e commit e513a6f
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 32 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

All notable changes to this project will be documented in this file.

## [0.5.3] 2024-01-09
### Fixed
- Fix `BERTScoreMrefs` computation when all multiple references sizes are equal.
- Check for empty timeout list in `SPICE` metric.

## [0.5.2] 2024-01-05
### Changed
- `aac-metrics` is now compatible with `transformers>=4.31`.
- Rename default device value "auto" to "cuda_if_available".
- Rename default device value `"auto"` to `"cuda_if_available"`.

## [0.5.1] 2023-12-20
### Added
Expand Down
4 changes: 2 additions & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ keywords:
- captioning
- audio-captioning
license: MIT
version: 0.5.2
date-released: '2024-01-05'
version: 0.5.3
date-released: '2024-01-09'
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr
month = {01},
title = {{aac-metrics}},
url = {https://github.com/Labbeti/aac-metrics/},
version = {0.5.2},
version = {0.5.3},
year = {2024},
}
```
Expand Down
11 changes: 8 additions & 3 deletions src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__maintainer__ = "Etienne Labbé (Labbeti)"
__name__ = "aac-metrics"
__status__ = "Development"
__version__ = "0.5.2"
__version__ = "0.5.3"


from .classes.base import AACMetric
Expand Down Expand Up @@ -68,12 +68,17 @@
]


def list_metrics_available() -> list[str]:
"""Returns the list of metrics that can be loaded from its name."""
factory = _get_metric_factory_classes()
return list(factory.keys())


def load_metric(name: str, **kwargs) -> AACMetric:
"""Load a metric class by name.
:param name: The name of the metric.
Can be one of ("bleu_1", "bleu_2", "bleu_3", "bleu_4", "meteor", "rouge_l", "cider_d", "spice", "spider", "fense").
:param **kwargs: The keyword optional arguments passed to the metric factory.
:param **kwargs: The optional keyword arguments passed to the metric factory.
:returns: The Metric object built.
"""
name = name.lower().strip()
Expand Down
4 changes: 2 additions & 2 deletions src/aac_metrics/classes/bert_score_mrefs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Union
from typing import Callable, Union

import torch

Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: str = "max",
reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
filter_nan: bool = True,
verbose: int = 0,
) -> None:
Expand Down
26 changes: 14 additions & 12 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def bert_score_mrefs(
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: str = "max",
reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max",
filter_nan: bool = True,
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
Expand Down Expand Up @@ -129,21 +129,23 @@ def bert_score_mrefs(

dtype = torch.float32

if reduction == "mean":
reduction_fn = torch.mean
elif reduction == "max":
reduction_fn = _max_reduce
elif reduction == "min":
reduction_fn = _min_reduce
if isinstance(reduction, str):
if reduction == "mean":
reduction_fn = torch.mean
elif reduction == "max":
reduction_fn = _max_reduce
elif reduction == "min":
reduction_fn = _min_reduce
else:
raise ValueError(
f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})"
)
else:
raise ValueError(
f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})"
)
reduction_fn = reduction

if len(sizes) > 0 and all(size == sizes[0] for size in sizes):
sents_scores = {
k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1)
for k, v in sents_scores.items()
k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items()
}
else:
sents_scores = {
Expand Down
3 changes: 2 additions & 1 deletion src/aac_metrics/functional/mult_cands.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def mult_cands_metric(
f"{k}_all": scores.transpose(0, 1) for k, scores in all_sents_scores.items()
}

outs_corpus = {k: reduction(scores) for k, scores in outs_sents.items()}
reduction_fn = reduction
outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()}

if return_all_scores:
return outs_corpus, outs_sents
Expand Down
23 changes: 14 additions & 9 deletions src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def spice(
timeout_lst = [timeout]
else:
timeout_lst = list(timeout)

timeout_lst: list[Optional[int]]
if len(timeout_lst) == 0:
raise ValueError(
f"Invalid argument {timeout_lst=}. (cannot call SPICE with empty number of timeouts)"
)

spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR)

Expand Down Expand Up @@ -170,15 +175,15 @@ def spice(

for i, timeout_i in enumerate(timeout_lst):
success = __run_spice(
i,
timeout_i,
timeout_lst,
spice_cmd,
tmp_path,
out_file.name,
fpaths,
use_shell,
verbose,
i=i,
timeout_i=timeout_i,
timeout_lst=timeout_lst,
spice_cmd=spice_cmd,
tmp_path=tmp_path,
out_path=out_file.name,
paths=fpaths,
use_shell=use_shell,
verbose=verbose,
)
if success:
break
Expand Down
1 change: 0 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import platform
import unittest

from unittest import TestCase
Expand Down

0 comments on commit e513a6f

Please sign in to comment.