Skip to content

Commit

Permalink
Mod/Fix: Update init fns with list_metrics_available and fix PopStrat…
Browse files Browse the repository at this point in the history
…egy typing in vocab.
  • Loading branch information
Labbeti committed Jun 7, 2024
1 parent 864c5ef commit f3389e5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from .classes.bert_score_mrefs import BERTScoreMRefs
from .classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
from .classes.cider_d import CIDErD
from .classes.evaluate import DCASE2023Evaluate, Evaluate, _get_metric_factory_classes
from .classes.evaluate import (
DCASE2023Evaluate,
Evaluate,
_get_metric_factory_classes,
_instantiate_metrics_classes,
)
from .classes.fense import FENSE
from .classes.fer import FER
from .classes.meteor import METEOR
Expand Down Expand Up @@ -69,6 +74,7 @@
"set_default_cache_path",
"set_default_java_path",
"set_default_tmp_path",
"list_metrics_available",
"load_metric",
]

Expand All @@ -86,13 +92,8 @@ def load_metric(name: str, **kwargs) -> AACMetric:
:param **kwargs: The optional keyword arguments passed to the metric factory.
:returns: The Metric object built.
"""
name = name.lower().strip()

factory = _get_metric_factory_classes(**kwargs)
if name not in factory:
raise ValueError(
f"Invalid argument {name=}. (expected one of {tuple(factory.keys())})"
)
if not isinstance(name, str):
raise TypeError(f"Invalid argument type {type(name)}. (expected str)")

metric = factory[name]()
metric = _instantiate_metrics_classes(name, **kwargs)
return metric
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

T = TypeVar("T")
POP_STRATEGIES = ("max", "min")
PopStrategy = Literal["max", "min"]
PopStrategy = Union[Literal["max", "min"], int]
VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor})
VocabOuts = tuple[VocabScores, VocabScores]

Expand Down

0 comments on commit f3389e5

Please sign in to comment.