diff --git a/CHANGELOG.md b/CHANGELOG.md
index 52384c9..c0b1926 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,9 +5,15 @@ All notable changes to this project will be documented in this file.
 ## [0.4.4] UNRELEASED
 ### Added
 - `Evaluate` class now implements a `__hash__` and `tolist()` methods.
+- `BLEU1` to `BLEU4` classes and `bleu_1` to `bleu_4` functions.
 
 ### Changed
 - Function `get_install_info` now returns `package_path`.
+- `AACMetric` now indicates the output type when using the `__call__` method.
+- Renamed `AACEvaluate` to `DCASE2023Evaluate`; it now uses the `dcase2023` metric set instead of the `all` metric set.
+
+### Fixed
+- `sbert_sim` name in internal instantiation functions.
 
 ## [0.4.3] 2023-06-15
 ### Changed
diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py
index 7fc7d9e..0285f4d 100644
--- a/src/aac_metrics/__init__.py
+++ b/src/aac_metrics/__init__.py
@@ -16,7 +16,7 @@
 from .classes.base import AACMetric
 from .classes.bleu import BLEU
 from .classes.cider_d import CIDErD
-from .classes.evaluate import AACEvaluate, _get_metric_factory_classes
+from .classes.evaluate import DCASE2023Evaluate, _get_metric_factory_classes
 from .classes.fense import FENSE
 from .classes.meteor import METEOR
 from .classes.rouge_l import ROUGEL
@@ -28,7 +28,7 @@
 __all__ = [
     "BLEU",
     "CIDErD",
-    "AACEvaluate",
+    "DCASE2023Evaluate",
     "FENSE",
     "METEOR",
     "ROUGEL",
diff --git a/src/aac_metrics/classes/__init__.py b/src/aac_metrics/classes/__init__.py
index 23d3fa6..c2614ed 100644
--- a/src/aac_metrics/classes/__init__.py
+++ b/src/aac_metrics/classes/__init__.py
@@ -3,7 +3,7 @@
 
 from .bleu import BLEU
 from .cider_d import CIDErD
-from .evaluate import Evaluate, AACEvaluate
+from .evaluate import DCASE2023Evaluate, Evaluate
 from .fense import FENSE
 from .fluerr import FluErr
 from .meteor import METEOR
@@ -18,7 +18,7 @@
 __all__ = [
     "BLEU",
     "CIDErD",
-    "AACEvaluate",
+    "DCASE2023Evaluate",
     "Evaluate",
     "FENSE",
     "FluErr",
diff --git a/src/aac_metrics/classes/base.py b/src/aac_metrics/classes/base.py
index 9794623..0f0882b 100644
--- a/src/aac_metrics/classes/base.py
+++ b/src/aac_metrics/classes/base.py
@@ -1,12 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from typing import Any, Optional
+from typing import Any, Generic, Optional, TypeVar
 
 from torch import nn
 
+OutType = TypeVar("OutType")
 
-class AACMetric(nn.Module):
+
+class AACMetric(nn.Module, Generic[OutType]):
     """Base Metric module for AAC metrics. Similar to torchmetrics.Metric."""
 
     # Global values
@@ -23,10 +25,10 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
 
     # Public methods
-    def compute(self) -> Any:
-        return None
+    def compute(self) -> OutType:
+        return None  # type: ignore
 
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
+    def forward(self, *args: Any, **kwargs: Any) -> OutType:
         self.update(*args, **kwargs)
         output = self.compute()
         self.reset()
@@ -37,3 +39,7 @@ def reset(self) -> None:
 
     def update(self, *args, **kwargs) -> None:
         pass
+
+    # Magic methods
+    def __call__(self, *args: Any, **kwds: Any) -> OutType:
+        return super().__call__(*args, **kwds)
diff --git a/src/aac_metrics/classes/bleu.py b/src/aac_metrics/classes/bleu.py
index d6b82c4..9d6030a 100644
--- a/src/aac_metrics/classes/bleu.py
+++ b/src/aac_metrics/classes/bleu.py
@@ -13,7 +13,7 @@
 )
 
 
-class BLEU(AACMetric):
+class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """BiLingual Evaluation Understudy metric class.
 
     - Paper: https://www.aclweb.org/anthology/P02-1040.pdf
@@ -85,3 +85,47 @@ def update(
             self._cooked_cands,
             self._cooked_mrefs,
         )
+
+
+class BLEU1(BLEU):
+    def __init__(
+        self,
+        return_all_scores: bool = True,
+        option: str = "closest",
+        verbose: int = 0,
+        tokenizer: Callable[[str], list[str]] = str.split,
+    ) -> None:
+        super().__init__(return_all_scores, 1, option, verbose, tokenizer)
+
+
+class BLEU2(BLEU):
+    def __init__(
+        self,
+        return_all_scores: bool = True,
+        option: str = "closest",
+        verbose: int = 0,
+        tokenizer: Callable[[str], list[str]] = str.split,
+    ) -> None:
+        super().__init__(return_all_scores, 2, option, verbose, tokenizer)
+
+
+class BLEU3(BLEU):
+    def __init__(
+        self,
+        return_all_scores: bool = True,
+        option: str = "closest",
+        verbose: int = 0,
+        tokenizer: Callable[[str], list[str]] = str.split,
+    ) -> None:
+        super().__init__(return_all_scores, 3, option, verbose, tokenizer)
+
+
+class BLEU4(BLEU):
+    def __init__(
+        self,
+        return_all_scores: bool = True,
+        option: str = "closest",
+        verbose: int = 0,
+        tokenizer: Callable[[str], list[str]] = str.split,
+    ) -> None:
+        super().__init__(return_all_scores, 4, option, verbose, tokenizer)
diff --git a/src/aac_metrics/classes/cider_d.py b/src/aac_metrics/classes/cider_d.py
index 1372907..a9ad37d 100644
--- a/src/aac_metrics/classes/cider_d.py
+++ b/src/aac_metrics/classes/cider_d.py
@@ -12,7 +12,7 @@
 )
 
 
-class CIDErD(AACMetric):
+class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Consensus-based Image Description Evaluation metric class.
 
     - Paper: https://arxiv.org/pdf/1411.5726.pdf
diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py
index dbf2a22..0f67401 100644
--- a/src/aac_metrics/classes/evaluate.py
+++ b/src/aac_metrics/classes/evaluate.py
@@ -28,7 +28,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class Evaluate(list[AACMetric], AACMetric):
+class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Tensor]]]):
     """Evaluate candidates with multiple references with custom metrics.
 
     For more information, see :func:`~aac_metrics.functional.evaluate.evaluate`.
@@ -105,8 +105,8 @@ def __hash__(self) -> int:
         return data
 
 
-class AACEvaluate(Evaluate):
-    """Evaluate candidates with multiple references with all Audio Captioning metrics.
+class DCASE2023Evaluate(Evaluate):
+    """Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics.
 
     For more information, see :func:`~aac_metrics.functional.evaluate.aac_evaluate`.
     """
@@ -117,15 +117,16 @@ def __init__(
         cache_path: str = "$HOME/.cache",
         java_path: str = "java",
         tmp_path: str = "/tmp",
+        device: Union[str, torch.device, None] = "auto",
         verbose: int = 0,
     ) -> None:
         super().__init__(
             preprocess,
-            "aac",
+            "dcase2023",
             cache_path,
             java_path,
             tmp_path,
-            "auto",
+            device,
             verbose,
         )
 
@@ -214,7 +215,7 @@ def _get_metric_factory_classes(
             tmp_path=tmp_path,
             verbose=verbose,
         ),
-        "sbert": lambda: SBERTSim(
+        "sbert_sim": lambda: SBERTSim(
            return_all_scores=return_all_scores,
            device=device,
            verbose=verbose,
diff --git a/src/aac_metrics/classes/fense.py b/src/aac_metrics/classes/fense.py
index 72876f8..da0b318 100644
--- a/src/aac_metrics/classes/fense.py
+++ b/src/aac_metrics/classes/fense.py
@@ -17,7 +17,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class FENSE(AACMetric):
+class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Fluency ENhanced Sentence-bert Evaluation (FENSE)
 
     - Paper: https://arxiv.org/abs/2110.04684
diff --git a/src/aac_metrics/classes/fluerr.py b/src/aac_metrics/classes/fluerr.py
index aef3bdb..fdfb43d 100644
--- a/src/aac_metrics/classes/fluerr.py
+++ b/src/aac_metrics/classes/fluerr.py
@@ -20,7 +20,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class FluErr(AACMetric):
+class FluErr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Return fluency error rate detected by a pre-trained BERT model.
 
     - Paper: https://arxiv.org/abs/2110.04684
diff --git a/src/aac_metrics/classes/meteor.py b/src/aac_metrics/classes/meteor.py
index 57aa8b0..969d663 100644
--- a/src/aac_metrics/classes/meteor.py
+++ b/src/aac_metrics/classes/meteor.py
@@ -9,7 +9,7 @@
 from aac_metrics.functional.meteor import meteor
 
 
-class METEOR(AACMetric):
+class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Metric for Evaluation of Translation with Explicit ORdering metric class.
 
     - Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389
diff --git a/src/aac_metrics/classes/rouge_l.py b/src/aac_metrics/classes/rouge_l.py
index a1c32dd..7cc79c0 100644
--- a/src/aac_metrics/classes/rouge_l.py
+++ b/src/aac_metrics/classes/rouge_l.py
@@ -12,7 +12,7 @@
 )
 
 
-class ROUGEL(AACMetric):
+class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Recall-Oriented Understudy for Gisting Evaluation class.
 
     - Paper: https://aclanthology.org/W04-1013.pdf
diff --git a/src/aac_metrics/classes/sbert_sim.py b/src/aac_metrics/classes/sbert_sim.py
index 56b4e2a..39704dc 100644
--- a/src/aac_metrics/classes/sbert_sim.py
+++ b/src/aac_metrics/classes/sbert_sim.py
@@ -17,7 +17,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class SBERTSim(AACMetric):
+class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Cosine-similarity of the Sentence-BERT embeddings.
 
     - Paper: https://arxiv.org/abs/1908.10084
diff --git a/src/aac_metrics/classes/spice.py b/src/aac_metrics/classes/spice.py
index 7ff4a53..5143258 100644
--- a/src/aac_metrics/classes/spice.py
+++ b/src/aac_metrics/classes/spice.py
@@ -14,7 +14,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class SPICE(AACMetric):
+class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """Semantic Propositional Image Caption Evaluation class.
 
     - Paper: https://arxiv.org/pdf/1607.08822.pdf
diff --git a/src/aac_metrics/classes/spider.py b/src/aac_metrics/classes/spider.py
index 2260c5a..7f7a5a7 100644
--- a/src/aac_metrics/classes/spider.py
+++ b/src/aac_metrics/classes/spider.py
@@ -14,7 +14,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class SPIDEr(AACMetric):
+class SPIDEr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """SPIDEr class.
 
     - Paper: https://arxiv.org/pdf/1612.00370.pdf
diff --git a/src/aac_metrics/classes/spider_fl.py b/src/aac_metrics/classes/spider_fl.py
index 08d2a50..77252ed 100644
--- a/src/aac_metrics/classes/spider_fl.py
+++ b/src/aac_metrics/classes/spider_fl.py
@@ -21,7 +21,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class SPIDErFL(AACMetric):
+class SPIDErFL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """SPIDErFL class.
 
     For more information, see :func:`~aac_metrics.functional.spider_fl.spider_fl`.
diff --git a/src/aac_metrics/classes/spider_max.py b/src/aac_metrics/classes/spider_max.py
index cbf43ae..fbf24d7 100644
--- a/src/aac_metrics/classes/spider_max.py
+++ b/src/aac_metrics/classes/spider_max.py
@@ -14,7 +14,7 @@
 pylog = logging.getLogger(__name__)
 
 
-class SPIDErMax(AACMetric):
+class SPIDErMax(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
     """SPIDEr-max class.
 
     - Paper: https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf
diff --git a/src/aac_metrics/functional/bleu.py b/src/aac_metrics/functional/bleu.py
index 28ad8f0..5414c28 100644
--- a/src/aac_metrics/functional/bleu.py
+++ b/src/aac_metrics/functional/bleu.py
@@ -66,6 +66,90 @@ def bleu(
     )
 
 
+def bleu_1(
+    candidates: list[str],
+    mult_references: list[list[str]],
+    return_all_scores: bool = True,
+    option: str = "closest",
+    verbose: int = 0,
+    tokenizer: Callable[[str], list[str]] = str.split,
+    return_1_to_n: bool = False,
+) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    return bleu(
+        candidates=candidates,
+        mult_references=mult_references,
+        return_all_scores=return_all_scores,
+        n=1,
+        option=option,
+        verbose=verbose,
+        tokenizer=tokenizer,
+        return_1_to_n=return_1_to_n,
+    )
+
+
+def bleu_2(
+    candidates: list[str],
+    mult_references: list[list[str]],
+    return_all_scores: bool = True,
+    option: str = "closest",
+    verbose: int = 0,
+    tokenizer: Callable[[str], list[str]] = str.split,
+    return_1_to_n: bool = False,
+) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    return bleu(
+        candidates=candidates,
+        mult_references=mult_references,
+        return_all_scores=return_all_scores,
+        n=2,
+        option=option,
+        verbose=verbose,
+        tokenizer=tokenizer,
+        return_1_to_n=return_1_to_n,
+    )
+
+
+def bleu_3(
+    candidates: list[str],
+    mult_references: list[list[str]],
+    return_all_scores: bool = True,
+    option: str = "closest",
+    verbose: int = 0,
+    tokenizer: Callable[[str], list[str]] = str.split,
+    return_1_to_n: bool = False,
+) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    return bleu(
+        candidates=candidates,
+        mult_references=mult_references,
+        return_all_scores=return_all_scores,
+        n=3,
+        option=option,
+        verbose=verbose,
+        tokenizer=tokenizer,
+        return_1_to_n=return_1_to_n,
+    )
+
+
+def bleu_4(
+    candidates: list[str],
+    mult_references: list[list[str]],
+    return_all_scores: bool = True,
+    option: str = "closest",
+    verbose: int = 0,
+    tokenizer: Callable[[str], list[str]] = str.split,
+    return_1_to_n: bool = False,
+) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
+    return bleu(
+        candidates=candidates,
+        mult_references=mult_references,
+        return_all_scores=return_all_scores,
+        n=4,
+        option=option,
+        verbose=verbose,
+        tokenizer=tokenizer,
+        return_1_to_n=return_1_to_n,
+    )
+
+
 def _bleu_update(
     candidates: list[str],
     mult_references: list[list[str]],
diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py
index ad22af5..51c36ed 100644
--- a/src/aac_metrics/functional/evaluate.py
+++ b/src/aac_metrics/functional/evaluate.py
@@ -288,7 +288,7 @@ def _get_metric_factory_functions(
             tmp_path=tmp_path,
             verbose=verbose,
         ),
-        "sbert": partial(
+        "sbert_sim": partial(
             sbert_sim,
             return_all_scores=return_all_scores,
             device=device,
diff --git a/src/aac_metrics/functional/fluerr.py b/src/aac_metrics/functional/fluerr.py
index 4c1c629..b61b8a7 100644
--- a/src/aac_metrics/functional/fluerr.py
+++ b/src/aac_metrics/functional/fluerr.py
@@ -193,7 +193,7 @@ def _load_echecker_and_tokenizer(
         echecker = __load_pretrain_echecker(echecker, device, verbose=verbose)
 
     if echecker_tokenizer is None:
-        echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type)
+        echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type)  # type: ignore
 
     echecker = echecker.eval()
     for p in echecker.parameters():
@@ -231,10 +231,10 @@ def __detect_error_sents(
         # batch_logits: (bsize, num_classes=6)
         # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69
         probs = logits.sigmoid().transpose(0, 1).cpu().numpy()
-        probs_dic = dict(zip(ERROR_NAMES, probs))
+        probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs))
 
     else:
-        probs_dic = {name: [] for name in ERROR_NAMES}
+        dic_lst_probs = {name: [] for name in ERROR_NAMES}
 
         for i in range(0, len(sents), batch_size):
             batch = __infer_preprocess(
@@ -251,10 +251,12 @@ def __detect_error_sents(
 
             # classes: add_tail, repeat_event, repeat_adv, remove_conj, remove_verb, error
             probs = batch_logits.sigmoid().cpu().numpy()
-            for j, name in enumerate(probs_dic.keys()):
-                probs_dic[name].append(probs[:, j])
+            for j, name in enumerate(dic_lst_probs.keys()):
+                dic_lst_probs[name].append(probs[:, j])
 
-        probs_dic = {name: np.concatenate(probs) for name, probs in probs_dic.items()}
+        probs_dic = {
+            name: np.concatenate(probs) for name, probs in dic_lst_probs.items()
+        }
 
     return probs_dic
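
The `Generic[OutType]` parameter added to `AACMetric` in classes/base.py is what lets `compute`, `forward` and the new `__call__` override advertise a concrete return type. Below is a minimal sketch of the pattern from a user's point of view; the `ConstantMetric` subclass is hypothetical and exists only to illustrate how the type parameter propagates, it is not part of this patch.

import torch
from torch import Tensor

from aac_metrics.classes.base import AACMetric


class ConstantMetric(AACMetric[Tensor]):
    # Hypothetical metric: always returns a scalar zero tensor. A real metric
    # would accumulate state in update() and reduce it in compute().
    def compute(self) -> Tensor:
        return torch.zeros(())


metric = ConstantMetric()
score = metric()  # static checkers now infer `score: Tensor` instead of `Any`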
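
The snippet below sketches how the renamed evaluator and the new BLEU entry points are expected to be used once the patch is applied. The constructor and function signatures come from the hunks above; the positional `(candidates, mult_references)` call convention and the example captions are assumptions rather than code taken from the repository.

from aac_metrics.classes.bleu import BLEU1
from aac_metrics.classes.evaluate import DCASE2023Evaluate
from aac_metrics.functional.bleu import bleu_1

candidates = ["a man speaks while a dog barks in the background"]
mult_references = [[
    "a man is speaking and a dog barks",
    "someone talks over the sound of a barking dog",
]]

# Functional form added by this patch: bleu_1 forwards to bleu(..., n=1) and,
# with return_all_scores=True (the default), returns corpus and sentence dicts.
corpus_scores, sents_scores = bleu_1(candidates, mult_references)

# Class form: BLEU1 fixes n=1; with return_all_scores=False the typed __call__
# should return a single tensor.
bleu1_metric = BLEU1(return_all_scores=False)
bleu1_score = bleu1_metric(candidates, mult_references)

# Renamed evaluator with the new explicit `device` argument ("auto" by default);
# it is annotated to return a (corpus_scores, sentence_scores) pair of dicts.
evaluator = DCASE2023Evaluate(device="auto")
corpus_scores, sents_scores = evaluator(candidates, mult_references)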