Commit

Mod: Update typing, fix sbert name, change DCASE2023Evaluate.
Labbeti committed Jul 13, 2023
1 parent 015cc2a commit 407f1c2
Showing 19 changed files with 176 additions and 33 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,9 +5,15 @@ All notable changes to this project will be documented in this file.
## [0.4.4] UNRELEASED
### Added
- `Evaluate` class now implements `__hash__` and `tolist()` methods.
- BLEU 1-to-n classes and functions (`BLEU1` to `BLEU4`, `bleu_1` to `bleu_4`).

### Changed
- Function `get_install_info` now returns `package_path`.
- `AACMetric` now indicates the output type when using the `__call__` method.
- Renamed `AACEvaluate` to `DCASE2023Evaluate`, which uses the `dcase2023` metric set instead of the `all` metric set.

### Fixed
- `sbert_sim` name in internal instantiation functions.

## [0.4.3] 2023-06-15
### Changed
4 changes: 2 additions & 2 deletions src/aac_metrics/__init__.py
@@ -16,7 +16,7 @@
from .classes.base import AACMetric
from .classes.bleu import BLEU
from .classes.cider_d import CIDErD
from .classes.evaluate import AACEvaluate, _get_metric_factory_classes
from .classes.evaluate import DCASE2023Evaluate, _get_metric_factory_classes
from .classes.fense import FENSE
from .classes.meteor import METEOR
from .classes.rouge_l import ROUGEL
@@ -28,7 +28,7 @@
__all__ = [
"BLEU",
"CIDErD",
"AACEvaluate",
"DCASE2023Evaluate",
"FENSE",
"METEOR",
"ROUGEL",
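
For downstream code, the visible effect of this file's change is the public import path. A before/after sketch, assuming no compatibility alias for `AACEvaluate` is kept (none appears in this diff):

# aac-metrics <= 0.4.3
# from aac_metrics import AACEvaluate

# from this commit on
from aac_metrics import DCASE2023Evaluate
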
4 changes: 2 additions & 2 deletions src/aac_metrics/classes/__init__.py
@@ -3,7 +3,7 @@

from .bleu import BLEU
from .cider_d import CIDErD
from .evaluate import Evaluate, AACEvaluate
from .evaluate import DCASE2023Evaluate, Evaluate
from .fense import FENSE
from .fluerr import FluErr
from .meteor import METEOR
@@ -18,7 +18,7 @@
__all__ = [
"BLEU",
"CIDErD",
"AACEvaluate",
"DCASE2023Evaluate",
"Evaluate",
"FENSE",
"FluErr",
16 changes: 11 additions & 5 deletions src/aac_metrics/classes/base.py
@@ -1,12 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Optional
from typing import Any, Generic, Optional, TypeVar

from torch import nn

OutType = TypeVar("OutType")

class AACMetric(nn.Module):

class AACMetric(nn.Module, Generic[OutType]):
"""Base Metric module for AAC metrics. Similar to torchmetrics.Metric."""

# Global values
@@ -23,10 +25,10 @@ def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

# Public methods
def compute(self) -> Any:
return None
def compute(self) -> OutType:
return None # type: ignore

def forward(self, *args: Any, **kwargs: Any) -> Any:
def forward(self, *args: Any, **kwargs: Any) -> OutType:
self.update(*args, **kwargs)
output = self.compute()
self.reset()
@@ -37,3 +39,7 @@ def reset(self) -> None:

def update(self, *args, **kwargs) -> None:
pass

# Magic methods
def __call__(self, *args: Any, **kwds: Any) -> OutType:
return super().__call__(*args, **kwds)
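
The `Generic[OutType]` parameter, together with the typed `__call__` override above, lets a type checker infer the return type of a metric call from the class declaration alone. A minimal sketch of how a subclass binds `OutType` (the `CandidateCount` metric below is hypothetical, not part of the library):

import torch
from torch import Tensor

from aac_metrics.classes.base import AACMetric


class CandidateCount(AACMetric[Tensor]):
    """Hypothetical metric that just counts how many candidates it has seen."""

    def __init__(self) -> None:
        super().__init__()
        self._count = 0

    def compute(self) -> Tensor:
        return torch.as_tensor(self._count)

    def reset(self) -> None:
        self._count = 0

    def update(self, candidates: list[str], mult_references: list[list[str]]) -> None:
        self._count += len(candidates)


metric = CandidateCount()
count = metric(["a dog barks"], [["a dog is barking"]])  # type checkers now see `count: Tensor`
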
46 changes: 45 additions & 1 deletion src/aac_metrics/classes/bleu.py
@@ -13,7 +13,7 @@
)


class BLEU(AACMetric):
class BLEU(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""BiLingual Evaluation Understudy metric class.
- Paper: https://www.aclweb.org/anthology/P02-1040.pdf
@@ -85,3 +85,47 @@ def update(
self._cooked_cands,
self._cooked_mrefs,
)


class BLEU1(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(return_all_scores, 1, option, verbose, tokenizer)


class BLEU2(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(return_all_scores, 2, option, verbose, tokenizer)


class BLEU3(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(return_all_scores, 3, option, verbose, tokenizer)


class BLEU4(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(return_all_scores, 4, option, verbose, tokenizer)
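
`BLEU1` through `BLEU4` are thin subclasses that only pin `n`; everything else goes through the parent `BLEU` class. A usage sketch, assuming the call convention is `(candidates, mult_references)` as in the functional API (the captions are made up):

from aac_metrics.classes.bleu import BLEU1, BLEU4

candidates = ["a man speaks while a dog barks"]
mult_references = [["a man is talking and a dog is barking", "a dog barks near a talking man"]]

# With return_all_scores=False the metric is expected to return a single corpus-level Tensor.
bleu1 = BLEU1(return_all_scores=False)
bleu4 = BLEU4(return_all_scores=False)

print(float(bleu1(candidates, mult_references)))
print(float(bleu4(candidates, mult_references)))
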
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/cider_d.py
@@ -12,7 +12,7 @@
)


class CIDErD(AACMetric):
class CIDErD(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Consensus-based Image Description Evaluation metric class.
- Paper: https://arxiv.org/pdf/1411.5726.pdf
13 changes: 7 additions & 6 deletions src/aac_metrics/classes/evaluate.py
@@ -28,7 +28,7 @@
pylog = logging.getLogger(__name__)


class Evaluate(list[AACMetric], AACMetric):
class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Tensor]]]):
"""Evaluate candidates with multiple references with custom metrics.
For more information, see :func:`~aac_metrics.functional.evaluate.evaluate`.
@@ -105,8 +105,8 @@ def __hash__(self) -> int:
return data


class AACEvaluate(Evaluate):
"""Evaluate candidates with multiple references with all Audio Captioning metrics.
class DCASE2023Evaluate(Evaluate):
"""Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics.
For more information, see :func:`~aac_metrics.functional.evaluate.aac_evaluate`.
"""
@@ -117,15 +117,16 @@ def __init__(
cache_path: str = "$HOME/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> None:
super().__init__(
preprocess,
"aac",
"dcase2023",
cache_path,
java_path,
tmp_path,
"auto",
device,
verbose,
)

@@ -214,7 +215,7 @@ def _get_metric_factory_classes(
tmp_path=tmp_path,
verbose=verbose,
),
"sbert": lambda: SBERTSim(
"sbert_sim": lambda: SBERTSim(
return_all_scores=return_all_scores,
device=device,
verbose=verbose,
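
With the rename, the class-based entry point for the DCASE 2023 metric set looks roughly as below. The call convention `(candidates, mult_references)` and the default of `preprocess` are assumed rather than shown in this diff, and the Java-based metrics (METEOR, SPICE) must be set up for the full set to run:

from aac_metrics.classes.evaluate import DCASE2023Evaluate

candidates = ["rain falls on a metal roof"]
mult_references = [["rain is falling on a tin roof", "heavy rain hits a roof"]]

# device is now an argument instead of being hard-coded to "auto".
evaluate = DCASE2023Evaluate(preprocess=True, device="auto")

corpus_scores, sentence_scores = evaluate(candidates, mult_references)
print({name: float(score) for name, score in corpus_scores.items()})
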
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/fense.py
@@ -17,7 +17,7 @@
pylog = logging.getLogger(__name__)


class FENSE(AACMetric):
class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Fluency ENhanced Sentence-bert Evaluation (FENSE)
- Paper: https://arxiv.org/abs/2110.04684
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/fluerr.py
@@ -20,7 +20,7 @@
pylog = logging.getLogger(__name__)


class FluErr(AACMetric):
class FluErr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Return fluency error rate detected by a pre-trained BERT model.
- Paper: https://arxiv.org/abs/2110.04684
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/meteor.py
@@ -9,7 +9,7 @@
from aac_metrics.functional.meteor import meteor


class METEOR(AACMetric):
class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Metric for Evaluation of Translation with Explicit ORdering metric class.
- Paper: https://dl.acm.org/doi/pdf/10.5555/1626355.1626389
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/rouge_l.py
@@ -12,7 +12,7 @@
)


class ROUGEL(AACMetric):
class ROUGEL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Recall-Oriented Understudy for Gisting Evaluation class.
- Paper: https://aclanthology.org/W04-1013.pdf
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/sbert_sim.py
@@ -17,7 +17,7 @@
pylog = logging.getLogger(__name__)


class SBERTSim(AACMetric):
class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Cosine-similarity of the Sentence-BERT embeddings.
- Paper: https://arxiv.org/abs/1908.10084
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/spice.py
@@ -14,7 +14,7 @@
pylog = logging.getLogger(__name__)


class SPICE(AACMetric):
class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""Semantic Propositional Image Caption Evaluation class.
- Paper: https://arxiv.org/pdf/1607.08822.pdf
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/spider.py
@@ -14,7 +14,7 @@
pylog = logging.getLogger(__name__)


class SPIDEr(AACMetric):
class SPIDEr(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""SPIDEr class.
- Paper: https://arxiv.org/pdf/1612.00370.pdf
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/spider_fl.py
@@ -21,7 +21,7 @@
pylog = logging.getLogger(__name__)


class SPIDErFL(AACMetric):
class SPIDErFL(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""SPIDErFL class.
For more information, see :func:`~aac_metrics.functional.spider_fl.spider_fl`.
2 changes: 1 addition & 1 deletion src/aac_metrics/classes/spider_max.py
@@ -14,7 +14,7 @@
pylog = logging.getLogger(__name__)


class SPIDErMax(AACMetric):
class SPIDErMax(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]):
"""SPIDEr-max class.
- Paper: https://hal.archives-ouvertes.fr/hal-03810396/file/Labbe_DCASE2022.pdf
84 changes: 84 additions & 0 deletions src/aac_metrics/functional/bleu.py
@@ -66,6 +66,90 @@ def bleu(
)


def bleu_1(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
return_1_to_n: bool = False,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return bleu(
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
n=1,
option=option,
verbose=verbose,
tokenizer=tokenizer,
return_1_to_n=return_1_to_n,
)


def bleu_2(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
return_1_to_n: bool = False,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return bleu(
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
n=2,
option=option,
verbose=verbose,
tokenizer=tokenizer,
return_1_to_n=return_1_to_n,
)


def bleu_3(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
return_1_to_n: bool = False,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return bleu(
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
n=3,
option=option,
verbose=verbose,
tokenizer=tokenizer,
return_1_to_n=return_1_to_n,
)


def bleu_4(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
option: str = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
return_1_to_n: bool = False,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return bleu(
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
n=4,
option=option,
verbose=verbose,
tokenizer=tokenizer,
return_1_to_n=return_1_to_n,
)


def _bleu_update(
candidates: list[str],
mult_references: list[list[str]],
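
The new functional wrappers simply forward to `bleu` with a fixed `n`, so they share its signature and defaults. A short sketch with made-up captions:

from aac_metrics.functional.bleu import bleu_1, bleu_4

candidates = ["children shout and play in a park"]
mult_references = [["kids are shouting while playing outside", "children playing and yelling in a park"]]

# return_all_scores=False is expected to yield a single corpus-level Tensor;
# with the default True, a (corpus_scores, sentence_scores) pair of dicts is returned.
score_1 = bleu_1(candidates, mult_references, return_all_scores=False)
score_4 = bleu_4(candidates, mult_references, return_all_scores=False)
print(float(score_1), float(score_4))
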
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/evaluate.py
@@ -288,7 +288,7 @@ def _get_metric_factory_functions(
tmp_path=tmp_path,
verbose=verbose,
),
"sbert": partial(
"sbert_sim": partial(
sbert_sim,
return_all_scores=return_all_scores,
device=device,