diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py
index ded7837..ddaa6fd 100644
--- a/src/aac_metrics/__init__.py
+++ b/src/aac_metrics/__init__.py
@@ -17,7 +17,7 @@
 from .classes.base import AACMetric
 from .classes.bert_score_mrefs import BERTScoreMRefs
-from .classes.bleu import BLEU
+from .classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
 from .classes.cider_d import CIDErD
 from .classes.evaluate import DCASE2023Evaluate, Evaluate, _get_metric_factory_classes
 from .classes.fense import FENSE
@@ -44,6 +44,10 @@
     "AACMetric",
     "BERTScoreMRefs",
     "BLEU",
+    "BLEU1",
+    "BLEU2",
+    "BLEU3",
+    "BLEU4",
     "CIDErD",
     "Evaluate",
     "DCASE2023Evaluate",
diff --git a/src/aac_metrics/classes/evaluate.py b/src/aac_metrics/classes/evaluate.py
index 1a022e7..02f1b8b 100644
--- a/src/aac_metrics/classes/evaluate.py
+++ b/src/aac_metrics/classes/evaluate.py
@@ -45,7 +45,7 @@ class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Ten
 
     def __init__(
         self,
-        preprocess: bool = True,
+        preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
         metrics: Union[
             str, Iterable[str], Iterable[AACMetric]
         ] = DEFAULT_METRICS_SET_NAME,
@@ -165,7 +165,8 @@
 
 
 def _instantiate_metrics_classes(
-    metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac",
+    metrics: Union[str, Iterable[str], Iterable[AACMetric]] = DEFAULT_METRICS_SET_NAME,
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
diff --git a/src/aac_metrics/functional/evaluate.py b/src/aac_metrics/functional/evaluate.py
index 25f5aad..794c828 100644
--- a/src/aac_metrics/functional/evaluate.py
+++ b/src/aac_metrics/functional/evaluate.py
@@ -8,7 +8,7 @@
 from typing import Any, Callable, Iterable, Optional, Union
 
 import torch
-from torch import Tensor
+from torch import Tensor, nn
 
 from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs
 from aac_metrics.functional.bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4
@@ -24,8 +24,9 @@
 from aac_metrics.functional.spider_max import spider_max
 from aac_metrics.functional.vocab import vocab
 from aac_metrics.utils.checks import check_metric_inputs
+from aac_metrics.utils.collections import flat_list_of_list, unflat_list_of_list
 from aac_metrics.utils.log_utils import warn_once
-from aac_metrics.utils.tokenization import preprocess_mono_sents, preprocess_mult_sents
+from aac_metrics.utils.tokenization import preprocess_mono_sents
 
 pylog = logging.getLogger(__name__)
 
@@ -83,7 +84,7 @@
 def evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     metrics: Union[
         str, Iterable[str], Iterable[Callable[[list, list], tuple]]
     ] = DEFAULT_METRICS_SET_NAME,
@@ -97,7 +98,7 @@
 
     :param candidates: The list of sentences to evaluate.
     :param mult_references: The list of list of sentences used as target.
-    :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics.defaults to True.
+    :param preprocess: If True, the candidates and references will be passed as input to the PTB stanford tokenizer before computing metrics. defaults to True.
    :param metrics: The name of the metric list or the explicit list of metrics to compute. defaults to "default".
    :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`.
    :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`.
@@ -110,24 +111,31 @@
     check_metric_inputs(candidates, mult_references)
     metrics = _instantiate_metrics_functions(
-        metrics, cache_path, java_path, tmp_path, device, verbose
+        metrics,
+        cache_path=cache_path,
+        java_path=java_path,
+        tmp_path=tmp_path,
+        device=device,
+        verbose=verbose,
     )
 
-    if preprocess:
-        common_kwds: dict[str, Any] = dict(
+    # Note: we use == here because preprocess is not necessarily a boolean
+    if preprocess == False:  # noqa: E712
+        preprocess = nn.Identity()
+
+    elif preprocess == True:  # noqa: E712
+        preprocess = partial(
+            preprocess_mono_sents,
             cache_path=cache_path,
             java_path=java_path,
             tmp_path=tmp_path,
             verbose=verbose,
         )
-        candidates = preprocess_mono_sents(
-            candidates,
-            **common_kwds,
-        )
-        mult_references = preprocess_mult_sents(
-            mult_references,
-            **common_kwds,
-        )
+
+    candidates = preprocess(candidates)
+    mult_references_flat, sizes = flat_list_of_list(mult_references)
+    mult_references_flat = preprocess(mult_references_flat)
+    mult_references = unflat_list_of_list(mult_references_flat, sizes)
 
     outs_corpus = {}
     outs_sents = {}
 
@@ -174,7 +182,7 @@
 def dcase2023_evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
@@ -211,7 +219,7 @@
 def dcase2024_evaluate(
     candidates: list[str],
     mult_references: list[list[str]],
-    preprocess: bool = True,
+    preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
@@ -247,6 +255,7 @@
 
 def _instantiate_metrics_functions(
     metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all",
+    *,
     cache_path: Union[str, Path, None] = None,
     java_path: Union[str, Path, None] = None,
     tmp_path: Union[str, Path, None] = None,
diff --git a/src/aac_metrics/utils/checks.py b/src/aac_metrics/utils/checks.py
index 6e06554..fc57256 100644
--- a/src/aac_metrics/utils/checks.py
+++ b/src/aac_metrics/utils/checks.py
@@ -25,11 +25,21 @@ def check_metric_inputs(
     error_msgs = []
 
     if not is_mono_sents(candidates):
-        error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
+        if isinstance(candidates, list) and len(candidates) > 0:
+            clsname = (
+                f"{candidates.__class__.__name__}[{candidates[0].__class__.__name__}]"
+            )
+        else:
+            clsname = candidates.__class__.__name__
+
+        error_msg = f"Invalid candidates type. (expected list[str], found {clsname})"
         error_msgs.append(error_msg)
 
     if not is_mult_sents(mult_references):
-        error_msg = f"Invalid mult_references type. (expected list[list[str]], found {mult_references.__class__.__name__})"
+        clsname = mult_references.__class__.__name__
+        error_msg = (
+            f"Invalid mult_references type. (expected list[list[str]], found {clsname})"
+        )
         error_msgs.append(error_msg)
 
     if len(error_msgs) > 0:
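Usage note: with this diff applied, `evaluate` (and the `Evaluate` class) accept either a bool or any `Callable[[list[str]], list[str]]` for `preprocess`. A minimal sketch of the new call pattern; `lowercase_only` is a hypothetical preprocessor written for illustration, not part of the library:

```python
from aac_metrics.functional.evaluate import evaluate

candidates = ["a man speaks while a dog barks nearby"]
mult_references = [["a man is speaking", "a dog barks while a man talks"]]

# preprocess=True (the default) still routes candidates and references through
# the PTB tokenizer; preprocess=False is now turned into nn.Identity().
corpus_scores, sent_scores = evaluate(candidates, mult_references)

# Any list[str] -> list[str] callable is also accepted, e.g. a trivial
# lowercasing pass that avoids spawning the Java tokenizer entirely:
def lowercase_only(sents: list[str]) -> list[str]:
    return [sent.lower() for sent in sents]

corpus_scores, sent_scores = evaluate(
    candidates, mult_references, preprocess=lowercase_only
)
```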
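The preprocessing refactor replaces `preprocess_mult_sents` with a flatten/preprocess/unflatten round trip, so one mono-sentence callable handles both candidates and references. The real helpers live in `aac_metrics.utils.collections`; the sketch below is a plausible re-implementation inferred only from their call sites in this diff:

```python
from typing import TypeVar

T = TypeVar("T")

def flat_list_of_list(lst: list[list[T]]) -> tuple[list[T], list[int]]:
    # Flatten a nested list and remember each sublist length so the
    # structure can be rebuilt after a flat transformation.
    flat = [item for sub in lst for item in sub]
    sizes = [len(sub) for sub in lst]
    return flat, sizes

def unflat_list_of_list(flat: list[T], sizes: list[int]) -> list[list[T]]:
    # Inverse of flat_list_of_list: cut the flat list back into sublists.
    out: list[list[T]] = []
    start = 0
    for size in sizes:
        out.append(flat[start : start + size])
        start += size
    return out

refs = [["a man speaks"], ["a dog barks", "dogs bark loudly"]]
flat, sizes = flat_list_of_list(refs)
assert unflat_list_of_list(flat, sizes) == refs  # round trip is lossless
```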
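Finally, the `checks.py` change makes type errors self-describing by reporting the element type of a malformed `candidates` list. An example of the improved message, inferred from the diff (what `check_metric_inputs` ultimately raises is not shown in this hunk):

```python
from aac_metrics.utils.checks import check_metric_inputs

# Passing pre-tokenized candidates (list[list[str]]) where list[str] is
# expected now produces:
#   "Invalid candidates type. (expected list[str], found list[list])"
# instead of the old, less informative "found list".
check_metric_inputs([["tokenized", "candidate"]], [["a reference"]])
```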