From f3389e5b91278da291461260dac5fe67ab9652b1 Mon Sep 17 00:00:00 2001 From: Labbeti Date: Fri, 7 Jun 2024 10:35:55 +0200 Subject: [PATCH] Mod/Fix: Update init fns with list_metrics_available and fix PopStrategy typing in vocab. --- src/aac_metrics/__init__.py | 19 ++++++++++--------- src/aac_metrics/functional/vocab.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/aac_metrics/__init__.py b/src/aac_metrics/__init__.py index ddaa6fd..be2ce33 100644 --- a/src/aac_metrics/__init__.py +++ b/src/aac_metrics/__init__.py @@ -19,7 +19,12 @@ from .classes.bert_score_mrefs import BERTScoreMRefs from .classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4 from .classes.cider_d import CIDErD -from .classes.evaluate import DCASE2023Evaluate, Evaluate, _get_metric_factory_classes +from .classes.evaluate import ( + DCASE2023Evaluate, + Evaluate, + _get_metric_factory_classes, + _instantiate_metrics_classes, +) from .classes.fense import FENSE from .classes.fer import FER from .classes.meteor import METEOR @@ -69,6 +74,7 @@ "set_default_cache_path", "set_default_java_path", "set_default_tmp_path", + "list_metrics_available", "load_metric", ] @@ -86,13 +92,8 @@ def load_metric(name: str, **kwargs) -> AACMetric: :param **kwargs: The optional keyword arguments passed to the metric factory. :returns: The Metric object built. """ - name = name.lower().strip() - - factory = _get_metric_factory_classes(**kwargs) - if name not in factory: - raise ValueError( - f"Invalid argument {name=}. (expected one of {tuple(factory.keys())})" - ) + if not isinstance(name, str): + raise TypeError(f"Invalid argument type {type(name)}. (expected str)") - metric = factory[name]() + metric = _instantiate_metrics_classes(name, **kwargs) return metric diff --git a/src/aac_metrics/functional/vocab.py b/src/aac_metrics/functional/vocab.py index 2782704..fa41c37 100644 --- a/src/aac_metrics/functional/vocab.py +++ b/src/aac_metrics/functional/vocab.py @@ -14,7 +14,7 @@ T = TypeVar("T") POP_STRATEGIES = ("max", "min") -PopStrategy = Literal["max", "min"] +PopStrategy = Union[Literal["max", "min"], int] VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor}) VocabOuts = tuple[VocabScores, VocabScores]