From ca355283609d0f9050785594393a7b0b09b04c64 Mon Sep 17 00:00:00 2001
From: Labbeti
Date: Wed, 17 Apr 2024 12:37:41 +0200
Subject: [PATCH] Add: DCASE2024 metric class and function.

---
 CHANGELOG.md                           |  5 +++-
 README.md                              |  8 +++---
 src/aac_metrics/classes/evaluate.py    | 31 ++++++++++++++++++---
 src/aac_metrics/functional/evaluate.py | 37 ++++++++++++++++++++++++++
 4 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57ba1fc..b9b25a5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,8 +3,11 @@ All notable changes to this project will be documented in this file.
 
 ## [0.5.5] UNRELEASED
 
+### Added
+- DCASE2024 challenge metric set, class and functions.
+
 ### Changed
-- Update metric typing for language servers.
+- Improve metric output typing for language servers.
 
 ## [0.5.4] 2024-03-04
 ### Fixed

diff --git a/README.md b/README.md
index 13b8a7a..4259e87 100644
--- a/README.md
+++ b/README.md
@@ -69,13 +69,13 @@ print(corpus_scores)
 # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider"
 # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...}
 ```
 
-### Evaluate DCASE2023 metrics
-To compute metrics for the DCASE2023 challenge, just set the argument `metrics="dcase2023"` in `evaluate` function call.
+### Evaluate DCASE2024 metrics
+To compute metrics for the DCASE2024 challenge, just set the argument `metrics="dcase2024"` in the `evaluate` function call.
 
 ```python
-corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2023")
+corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2024")
 print(corpus_scores)
-# dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr"
+# dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fer", "fense", "vocab"
 ```
 
 ### Evaluate a specific metric

diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py
index 9415e2f..1a022e7 100644
--- a/src/aac_metrics/classes/evaluate.py
+++ b/src/aac_metrics/classes/evaluate.py
@@ -4,12 +4,10 @@
 import logging
 import pickle
 import zlib
-
 from pathlib import Path
 from typing import Any, Callable, Iterable, Optional, Union
 
 import torch
-
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
@@ -23,8 +21,8 @@
 from aac_metrics.classes.sbert_sim import SBERTSim
 from aac_metrics.classes.spice import SPICE
 from aac_metrics.classes.spider import SPIDEr
-from aac_metrics.classes.spider_max import SPIDErMax
 from aac_metrics.classes.spider_fl import SPIDErFL
+from aac_metrics.classes.spider_max import SPIDErMax
 from aac_metrics.classes.vocab import Vocab
 from aac_metrics.functional.evaluate import (
     DEFAULT_METRICS_SET_NAME,
@@ -32,7 +30,6 @@
     evaluate,
 )
 
-
 pylog = logging.getLogger(__name__)
 
 
@@ -141,6 +138,32 @@ def __init__(
         )
 
 
+class DCASE2024Evaluate(Evaluate):
+    """Evaluate candidates with multiple references with DCASE2024 Audio Captioning metrics.
+
+    For more information, see :func:`~aac_metrics.functional.evaluate.dcase2024_evaluate`.
+ """ + + def __init__( + self, + preprocess: bool = True, + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, + device: Union[str, torch.device, None] = "cuda_if_available", + verbose: int = 0, + ) -> None: + super().__init__( + preprocess=preprocess, + metrics="dcase2024", + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, + ) + + def _instantiate_metrics_classes( metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac", cache_path: Union[str, Path, None] = None, diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py index 70b5bde..4e622ff 100644 --- a/src/aac_metrics/functional/evaluate.py +++ b/src/aac_metrics/functional/evaluate.py @@ -206,6 +206,43 @@ def dcase2023_evaluate( ) +def dcase2024_evaluate( + candidates: list[str], + mult_references: list[list[str]], + preprocess: bool = True, + cache_path: Union[str, Path, None] = None, + java_path: Union[str, Path, None] = None, + tmp_path: Union[str, Path, None] = None, + device: Union[str, torch.device, None] = "cuda_if_available", + verbose: int = 0, +) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """Evaluate candidates with multiple references with the DCASE2024 Audio Captioning metrics. + + :param candidates: The list of sentences to evaluate. + :param mult_references: The list of list of sentences used as target. + :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics. + defaults to True. + :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. + :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. + :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. + :param device: The PyTorch device used to run FENSE and SPIDErFL models. + If None, it will try to detect use cuda if available. defaults to "cuda_if_available". + :param verbose: The verbose level. defaults to 0. + :returns: A tuple contains the corpus and sentences scores. + """ + return evaluate( + candidates=candidates, + mult_references=mult_references, + preprocess=preprocess, + metrics="dcase2024", + cache_path=cache_path, + java_path=java_path, + tmp_path=tmp_path, + device=device, + verbose=verbose, + ) + + def _instantiate_metrics_functions( metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all", cache_path: Union[str, Path, None] = None,