Add: DCASE2024 metric class and function.
Labbeti committed Apr 17, 2024
1 parent 51c9203 commit ca35528
Showing 4 changed files with 72 additions and 9 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -3,8 +3,11 @@
All notable changes to this project will be documented in this file.

## [0.5.5] UNRELEASED
### Added
- DCASE2024 challenge metric set, class and functions.

### Changed
- Update metric typing for language servers.
- Improve metric output typing for language servers.

## [0.5.4] 2024-03-04
### Fixed
8 changes: 4 additions & 4 deletions README.md
@@ -69,13 +69,13 @@ print(corpus_scores)
# dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider"
# {"bleu_1": tensor(0.4278), "bleu_2": ..., ...}
```
### Evaluate DCASE2023 metrics
To compute metrics for the DCASE2023 challenge, just set the argument `metrics="dcase2023"` in `evaluate` function call.
### Evaluate DCASE2024 metrics
To compute metrics for the DCASE2023 challenge, just set the argument `metrics="dcase2024"` in `evaluate` function call.

```python
corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2023")
corpus_scores, _ = evaluate(candidates, mult_references, metrics="dcase2024")
print(corpus_scores)
# dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fluerr"
# dict containing the score of each metric: "meteor", "cider_d", "spice", "spider", "spider_fl", "fer", "fense", "vocab"
```
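
`corpus_scores` maps each metric name to a tensor, so a single score can be read directly. A minimal follow-up sketch, assuming the keys listed above:

```python
# Read one corpus-level score from the returned dict of tensors.
spider_fl = corpus_scores["spider_fl"].item()
print(f"SPIDEr-FL: {spider_fl:.4f}")
```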

### Evaluate a specific metric
31 changes: 27 additions & 4 deletions src/aac_metrics/classes/evaluate.py
@@ -4,12 +4,10 @@
import logging
import pickle
import zlib

from pathlib import Path
from typing import Any, Callable, Iterable, Optional, Union

import torch

from torch import Tensor

from aac_metrics.classes.base import AACMetric
@@ -23,16 +21,15 @@
from aac_metrics.classes.sbert_sim import SBERTSim
from aac_metrics.classes.spice import SPICE
from aac_metrics.classes.spider import SPIDEr
from aac_metrics.classes.spider_max import SPIDErMax
from aac_metrics.classes.spider_fl import SPIDErFL
from aac_metrics.classes.spider_max import SPIDErMax
from aac_metrics.classes.vocab import Vocab
from aac_metrics.functional.evaluate import (
DEFAULT_METRICS_SET_NAME,
METRICS_SETS,
evaluate,
)


pylog = logging.getLogger(__name__)


@@ -141,6 +138,32 @@ def __init__(
)


class DCASE2024Evaluate(Evaluate):
"""Evaluate candidates with multiple references with DCASE2024 Audio Captioning metrics.
For more information, see :func:`~aac_metrics.functional.evaluate.dcase2024_evaluate`.
"""

def __init__(
self,
preprocess: bool = True,
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
) -> None:
super().__init__(
preprocess=preprocess,
metrics="dcase2024",
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
device=device,
verbose=verbose,
)


def _instantiate_metrics_classes(
metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac",
cache_path: Union[str, Path, None] = None,
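For reference, a minimal usage sketch of the new `DCASE2024Evaluate` class. It assumes the class is called on candidates and multiple references like the other `Evaluate` classes in this module; the captions below are illustrative placeholders, not data from the repository.

```python
from aac_metrics.classes.evaluate import DCASE2024Evaluate

# Illustrative inputs: one candidate caption per audio clip, with several references each.
candidates = ["a man is speaking", "rain falls on a roof"]
mult_references = [
    ["a man speaks", "someone is talking"],
    ["rain is falling", "water drips onto a rooftop"],
]

# Assumed callable interface, mirroring the existing Evaluate classes.
evaluator = DCASE2024Evaluate(device="cuda_if_available", verbose=0)
corpus_scores, sents_scores = evaluator(candidates, mult_references)
print(corpus_scores["fense"])
```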
37 changes: 37 additions & 0 deletions src/aac_metrics/functional/evaluate.py
@@ -206,6 +206,43 @@ def dcase2023_evaluate(
)


def dcase2024_evaluate(
candidates: list[str],
mult_references: list[list[str]],
preprocess: bool = True,
cache_path: Union[str, Path, None] = None,
java_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
verbose: int = 0,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
"""Evaluate candidates with multiple references with the DCASE2024 Audio Captioning metrics.
:param candidates: The list of sentences to evaluate.
:param mult_references: The list of list of sentences used as target.
:param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics.
defaults to True.
:param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`.
:param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`.
:param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`.
:param device: The PyTorch device used to run FENSE and SPIDErFL models.
If None, it will try to use CUDA if available. defaults to "cuda_if_available".
:param verbose: The verbose level. defaults to 0.
:returns: A tuple containing the corpus scores and the sentence scores.
"""
return evaluate(
candidates=candidates,
mult_references=mult_references,
preprocess=preprocess,
metrics="dcase2024",
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
device=device,
verbose=verbose,
)


def _instantiate_metrics_functions(
metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all",
cache_path: Union[str, Path, None] = None,
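The new functional entry point can also be called directly. A minimal sketch based on the signature above; the captions are illustrative placeholders:

```python
from aac_metrics.functional.evaluate import dcase2024_evaluate

candidates = ["birds are singing in the distance"]
mult_references = [["birds chirp far away", "distant birdsong can be heard"]]

# Returns (corpus_scores, sentence_scores), both dicts of tensors.
corpus_scores, sents_scores = dcase2024_evaluate(candidates, mult_references)
print(corpus_scores["spider_fl"])
```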
