diff --git a/CHANGELOG.md b/CHANGELOG.md index b7edbe1..ac4eef3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,15 @@ All notable changes to this project will be documented in this file. +## [0.5.3] 2024-01-09 +### Fixed +- Fix `BERTScoreMrefs` computation when all lists of multiple references have the same size. +- Check for empty timeout list in `SPICE` metric. + ## [0.5.2] 2024-01-05 ### Changed - `aac-metrics` is now compatible with `transformers>=4.31`. -- Rename default device value "auto" to "cuda_if_available". +- Rename default device value `"auto"` to `"cuda_if_available"`. ## [0.5.1] 2023-12-20 ### Added diff --git a/CITATION.cff b/CITATION.cff index 625d4a5..191aff9 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -19,5 +19,5 @@ keywords: - captioning - audio-captioning license: MIT -version: 0.5.2 -date-released: '2024-01-05' +version: 0.5.3 +date-released: '2024-01-09' diff --git a/README.md b/README.md index 4c0a0e3..1ad898e 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,7 @@ If you use this software, please consider cite it as "Labbe, E. (2013). aac-metr month = {01}, title = {{aac-metrics}}, url = {https://github.com/Labbeti/aac-metrics/}, - version = {0.5.2}, + version = {0.5.3}, year = {2024}, } ``` diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index fa9aeb2..e38722c 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -10,7 +10,7 @@ __maintainer__ = "Etienne Labbé (Labbeti)" __name__ = "aac-metrics" __status__ = "Development" -__version__ = "0.5.2" +__version__ = "0.5.3" from .classes.base import AACMetric @@ -68,12 +68,17 @@ ] +def list_metrics_available() -> list[str]: + """Returns the list of metrics that can be loaded from its name.""" + factory = _get_metric_factory_classes() + return list(factory.keys()) + + def load_metric(name: str, **kwargs) -> AACMetric: """Load a metric class by name. :param name: The name of the metric. 
- Can be one of ("bleu_1", "bleu_2", "bleu_3", "bleu_4", "meteor", "rouge_l", "cider_d", "spice", "spider", "fense"). - :param **kwargs: The keyword optional arguments passed to the metric factory. + :param **kwargs: The optional keyword arguments passed to the metric factory. :returns: The Metric object built. """ name = name.lower().strip() diff --git a/src/aac_metrics/classes/bert_score_mrefs.py b/src/aac_metrics/classes/bert_score_mrefs.py index a9451ba..369103b 100644 --- a/src/aac_metrics/classes/bert_score_mrefs.py +++ b/src/aac_metrics/classes/bert_score_mrefs.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from typing import Union +from typing import Callable, Union import torch @@ -44,7 +44,7 @@ def __init__( max_length: int = 64, reset_state: bool = True, idf: bool = False, - reduction: str = "max", + reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max", filter_nan: bool = True, verbose: int = 0, ) -> None: diff --git a/src/aac_metrics/functional/bert_score_mrefs.py b/src/aac_metrics/functional/bert_score_mrefs.py index 2e1bdbd..2ccdc69 100644 --- a/src/aac_metrics/functional/bert_score_mrefs.py +++ b/src/aac_metrics/functional/bert_score_mrefs.py @@ -32,7 +32,7 @@ def bert_score_mrefs( max_length: int = 64, reset_state: bool = True, idf: bool = False, - reduction: str = "max", + reduction: Union[str, Callable[[Tensor, ...], Tensor]] = "max", filter_nan: bool = True, verbose: int = 0, ) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]: @@ -129,21 +129,23 @@ def bert_score_mrefs( dtype = torch.float32 - if reduction == "mean": - reduction_fn = torch.mean - elif reduction == "max": - reduction_fn = _max_reduce - elif reduction == "min": - reduction_fn = _min_reduce + if isinstance(reduction, str): + if reduction == "mean": + reduction_fn = torch.mean + elif reduction == "max": + reduction_fn = _max_reduce + elif reduction == "min": + reduction_fn = _min_reduce + else: + raise ValueError( + f"Invalid argument 
{reduction=}. (expected one of {REDUCTIONS})" + ) else: - raise ValueError( - f"Invalid argument {reduction=}. (expected one of {REDUCTIONS})" - ) + reduction_fn = reduction if len(sizes) > 0 and all(size == sizes[0] for size in sizes): sents_scores = { - k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1) - for k, v in sents_scores.items() + k: reduction_fn(torch.stack(v), dim=1) for k, v in sents_scores.items() } else: sents_scores = { diff --git a/src/aac_metrics/functional/mult_cands.py b/src/aac_metrics/functional/mult_cands.py index 5145714..e5b7115 100644 --- a/src/aac_metrics/functional/mult_cands.py +++ b/src/aac_metrics/functional/mult_cands.py @@ -113,7 +113,8 @@ def mult_cands_metric( f"{k}_all": scores.transpose(0, 1) for k, scores in all_sents_scores.items() } - outs_corpus = {k: reduction(scores) for k, scores in outs_sents.items()} + reduction_fn = reduction + outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()} if return_all_scores: return outs_corpus, outs_sents diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index faddb64..d87e836 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -93,7 +93,12 @@ def spice( timeout_lst = [timeout] else: timeout_lst = list(timeout) + timeout_lst: list[Optional[int]] + if len(timeout_lst) == 0: + raise ValueError( + f"Invalid argument {timeout_lst=}. 
(cannot call SPICE with empty number of timeouts)" + ) spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR) @@ -170,15 +175,15 @@ def spice( for i, timeout_i in enumerate(timeout_lst): success = __run_spice( - i, - timeout_i, - timeout_lst, - spice_cmd, - tmp_path, - out_file.name, - fpaths, - use_shell, - verbose, + i=i, + timeout_i=timeout_i, + timeout_lst=timeout_lst, + spice_cmd=spice_cmd, + tmp_path=tmp_path, + out_path=out_file.name, + paths=fpaths, + use_shell=use_shell, + verbose=verbose, ) if success: break diff --git a/tests/test_all.py b/tests/test_all.py index 5323f4e..a071420 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import platform import unittest from unittest import TestCase