
Commit 7f0fb3a
Add: BLEU-n to main init, preprocess custom callable argument.
Labbeti committed Jun 3, 2024
1 parent 62acc11 commit 7f0fb3a
Showing 4 changed files with 46 additions and 22 deletions.
6 changes: 5 additions & 1 deletion src/aac_metrics/__init__.py
@@ -17,7 +17,7 @@

 from .classes.base import AACMetric
 from .classes.bert_score_mrefs import BERTScoreMRefs
-from .classes.bleu import BLEU
+from .classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
 from .classes.cider_d import CIDErD
 from .classes.evaluate import DCASE2023Evaluate, Evaluate, _get_metric_factory_classes
 from .classes.fense import FENSE
@@ -44,6 +44,10 @@
"AACMetric",
"BERTScoreMRefs",
"BLEU",
"BLEU1",
"BLEU2",
"BLEU3",
"BLEU4",
"CIDErD",
"Evaluate",
"DCASE2023Evaluate",
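These exports make the per-order BLEU variants importable from the package root. A minimal usage sketch, assuming BLEU1 behaves like the existing BLEU class (instantiate, then call with candidates and multiple references):

from aac_metrics import BLEU1

candidates = ["a man is speaking near a door"]
mult_references = [["a man speaks", "someone talks near a door"]]

# Assumption: BLEU1 follows the same callable AACMetric convention
# as the existing BLEU class.
bleu_1 = BLEU1()
score = bleu_1(candidates, mult_references)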
5 changes: 3 additions & 2 deletions src/aac_metrics/classes/evaluate.py
@@ -45,7 +45,7 @@ class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Ten

     def __init__(
         self,
-        preprocess: bool = True,
+        preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
         metrics: Union[
             str, Iterable[str], Iterable[AACMetric]
         ] = DEFAULT_METRICS_SET_NAME,
@@ -165,7 +165,8 @@ def __init__(


 def _instantiate_metrics_classes(
-    metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac",
+    metrics: Union[str, Iterable[str], Iterable[AACMetric]] = DEFAULT_METRICS_SET_NAME,
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
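With the widened annotation, Evaluate accepts either the usual boolean or a custom list-to-list transform, and the added "*," makes cache_path, java_path, and tmp_path keyword-only in _instantiate_metrics_classes. A sketch of the new call pattern; lowercase_strip is a hypothetical helper written here for illustration, not part of the library:

from aac_metrics import Evaluate

def lowercase_strip(sents: list[str]) -> list[str]:
    # Hypothetical preprocessing: lowercase and trim each sentence.
    return [sent.lower().strip() for sent in sents]

# Pass the callable instead of True (PTB tokenizer) or False (no-op).
evaluate_metric = Evaluate(preprocess=lowercase_strip)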
43 changes: 26 additions & 17 deletions src/aac_metrics/functional/evaluate.py
@@ -8,7 +8,7 @@
 from typing import Any, Callable, Iterable, Optional, Union
 
 import torch
-from torch import Tensor
+from torch import Tensor, nn
 
 from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs
 from aac_metrics.functional.bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4
@@ -24,8 +24,9 @@
 from aac_metrics.functional.spider_max import spider_max
 from aac_metrics.functional.vocab import vocab
 from aac_metrics.utils.checks import check_metric_inputs
+from aac_metrics.utils.collections import flat_list_of_list, unflat_list_of_list
 from aac_metrics.utils.log_utils import warn_once
-from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents
+from aac_metrics.utils.tokenization import preprocess_mono_sents
 
 pylog = logging.getLogger(__name__)

@@ -83,7 +84,7 @@
 def evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     metrics: Union[
         str, Iterable[str], Iterable[Callable[[list, list], tuple]]
     ] = DEFAULT_METRICS_SET_NAME,
@@ -97,7 +98,7 @@ def evaluate(
     :param candidates: The list of sentences to evaluate.
     :param mult_references: The list of list of sentences used as target.
-    :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics.defaults to True.
+    :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics. defaults to True.
     :param metrics: The name of the metric list or the explicit list of metrics to compute. defaults to "default".
     :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`.
     :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`.
@@ -110,24 +111,31 @@ def evaluate(
     check_metric_inputs(candidates, mult_references)
 
     metrics = _instantiate_metrics_functions(
-        metrics, cache_path, java_path, tmp_path, device, verbose
+        metrics,
+        cache_path=cache_path,
+        java_path=java_path,
+        tmp_path=tmp_path,
+        device=device,
+        verbose=verbose,
     )
 
-    if preprocess:
-        common_kwds: dict[str, Any] = dict(
+    # Note: we use == here because preprocess is not necessarily a boolean
+    if preprocess == False:  # noqa: E712
+        preprocess = nn.Identity()
+
+    elif preprocess == True:  # noqa: E712
+        preprocess = partial(
+            preprocess_mono_sents,
             cache_path=cache_path,
             java_path=java_path,
             tmp_path=tmp_path,
             verbose=verbose,
         )
-        candidates = preprocess_mono_sents(
-            candidates,
-            **common_kwds,
-        )
-        mult_references = preprocess_mult_sents(
-            mult_references,
-            **common_kwds,
-        )
+
+    candidates = preprocess(candidates)
+    mult_references_flat, sizes = flat_list_of_list(mult_references)
+    mult_references_flat = preprocess(mult_references_flat)
+    mult_references = unflat_list_of_list(mult_references_flat, sizes)
 
     outs_corpus = {}
     outs_sents = {}
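A callable preprocess maps a flat list[str] to a list[str], so the nested references are flattened, transformed once, and restored to their original nesting. The helper bodies below are an assumption for illustration; only the names and call signatures of flat_list_of_list and unflat_list_of_list appear in this diff:

def flat_list_of_list(lst: list[list[str]]) -> tuple[list[str], list[int]]:
    # Flatten a list of lists and remember each sublist's length.
    flat = [item for sub in lst for item in sub]
    sizes = [len(sub) for sub in lst]
    return flat, sizes

def unflat_list_of_list(flat: list[str], sizes: list[int]) -> list[list[str]]:
    # Rebuild the nested structure from the flat list and the sizes.
    out, start = [], 0
    for size in sizes:
        out.append(flat[start:start + size])
        start += size
    return out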
@@ -174,7 +182,7 @@
 def dcase2023_evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
@@ -211,7 +219,7 @@ def dcase2023_evaluate(
 def dcase2024_evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
@@ -247,6 +255,7 @@ def dcase2024_evaluate(

 def _instantiate_metrics_functions(
     metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all",
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
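End to end, the functional API takes the same widened argument. A sketch, assuming evaluate returns the usual pair of corpus-level and sentence-level score dicts; strip_sents is a hypothetical transform defined here for illustration:

from aac_metrics.functional.evaluate import evaluate

candidates = ["a dog barks in the distance"]
mult_references = [["a dog is barking", "barking is heard far away"]]

def strip_sents(sents: list[str]) -> list[str]:
    # Any list[str] -> list[str] callable works: it is applied to the
    # candidates and to the flattened references alike.
    return [sent.strip() for sent in sents]

corpus_scores, sents_scores = evaluate(
    candidates,
    mult_references,
    preprocess=strip_sents,
)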
14 changes: 12 additions & 2 deletions src/aac_metrics/utils/checks.py
@@ -25,11 +25,21 @@ def check_metric_inputs(

     error_msgs = []
     if not is_mono_sents(candidates):
-        error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
+        if isinstance(candidates, list) and len(candidates) > 0:
+            clsname = (
+                f"{candidates.__class__.__name__}[{candidates[0].__class__.__name__}]"
+            )
+        else:
+            clsname = candidates.__class__.__name__
+
+        error_msg = f"Invalid candidates type. (expected list[str], found {clsname})"
         error_msgs.append(error_msg)
 
     if not is_mult_sents(mult_references):
-        error_msg = f"Invalid mult_references type. (expected list[list[str]], found {mult_references.__class__.__name__})"
+        clsname = mult_references.__class__.__name__
+        error_msg = (
+            f"Invalid mult_references type. (expected list[list[str]], found {clsname})"
+        )
         error_msgs.append(error_msg)
 
     if len(error_msgs) > 0:
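The richer message reports the element type, which makes shape mistakes such as passing nested candidates easier to spot. A sketch of the behavior, assuming check_metric_inputs raises ValueError when validation fails:

from aac_metrics.utils.checks import check_metric_inputs

try:
    # Wrong shape on purpose: candidates must be list[str], not list[list].
    check_metric_inputs([["nested", "candidate"]], [["a reference"]])
except ValueError as err:
    # The message now includes the element type, e.g.
    # "Invalid candidates type. (expected list[str], found list[list])"
    print(err)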
