Skip to content

Commit

Permalink
Add: BERTScore metric to func and class factories.
Browse files (browse the repository at this point in the history)
  • Loading branch information
Labbeti committed Nov 15, 2023
1 parent b070a71 commit 89d09cf
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/aac_metrics/classes/evaluate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import Tensor

from aac_metrics.classes.base import AACMetric
from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs
from aac_metrics.classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
from aac_metrics.classes.cider_d import CIDErD
from aac_metrics.classes.fense import FENSE
Expand Down Expand Up @@ -251,6 +252,11 @@ def _get_metric_factory_classes(
**init_kwds,
),
"vocab": lambda: Vocab(
verbose=verbose,
**init_kwds,
),
"bert_score": lambda: BERTScoreMRefs(
verbose=verbose,
**init_kwds,
),
}
Expand Down
3 changes: 3 additions & 0 deletions src/aac_metrics/classes/vocab.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,13 +37,15 @@ def __init__(
tokenizer: Callable[[str], list[str]] = str.split,
dtype: torch.dtype = torch.float64,
pop_strategy: str = "max",
verbose: int = 0,
) -> None:
super().__init__()
self._return_all_scores = return_all_scores
self._seed = seed
self._tokenizer = tokenizer
self._dtype = dtype
self._pop_strategy = pop_strategy
self._verbose = verbose

self._candidates = []
self._mult_references = []
Expand All @@ -57,6 +59,7 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
tokenizer=self._tokenizer,
dtype=self._dtype,
pop_strategy=self._pop_strategy,
verbose=self._verbose,
)

def get_output_names(self) -> tuple[str, ...]:
Expand Down
5 changes: 1 addition & 4 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Optional, Union

import torch

Expand All @@ -14,9 +14,6 @@
from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list


T = TypeVar("T")


def bert_score_mrefs(
candidates: list[str],
mult_references: list[list[str]],
Expand Down
7 changes: 7 additions & 0 deletions src/aac_metrics/functional/evaluate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,6 +12,7 @@

from torch import Tensor

from aac_metrics.functional.bert_score_mrefs import bert_score_mrefs
from aac_metrics.functional.bleu import bleu, bleu_1, bleu_2, bleu_3, bleu_4
from aac_metrics.functional.cider_d import cider_d
from aac_metrics.functional.fense import fense
Expand Down Expand Up @@ -67,6 +68,7 @@
"fense", # includes sbert, fer
"spider_fl", # includes cider_d, spice, spider, fer
"vocab",
"bert_score",
),
}
DEFAULT_METRICS_SET_NAME = "default"
Expand Down Expand Up @@ -332,6 +334,11 @@ def _get_metric_factory_functions(
),
"vocab": partial(
vocab,
verbose=verbose,
**init_kwds,
),
"bert_score": partial(
bert_score_mrefs,
**init_kwds,
),
}
Expand Down
5 changes: 5 additions & 0 deletions src/aac_metrics/functional/vocab.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,6 +21,7 @@ def vocab(
tokenizer: Callable[[str], list[str]] = str.split,
dtype: torch.dtype = torch.float64,
pop_strategy: str = "max",
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
"""Compute vocabulary statistics.
Expand All @@ -35,6 +36,7 @@ def vocab(
:param tokenizer: The function used to split a sentence into tokens. defaults to str.split.
:param dtype: Torch floating point dtype for numerical precision. defaults to torch.float64.
:param pop_strategy: Strategy to compute average reference vocab. defaults to "max".
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
tok_cands = list(map(tokenizer, candidates))
Expand Down Expand Up @@ -79,6 +81,9 @@ def vocab(
f"Invalid argument {pop_strategy=}. (expected one of {POP_STRATEGIES} or an integer value)"
)

if verbose >= 2:
pylog.debug(f"Found {n_samples=} with {pop_strategy=}.")

vocab_mrefs_lens = torch.empty((n_samples,), dtype=dtype)

for i in range(n_samples):
Expand Down

0 comments on commit 89d09cf

Please sign in to comment.