Skip to content

Commit

Permalink
Mod: Update default model names and methods order.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Jan 4, 2024
1 parent f35b044 commit 414880b
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 59 deletions.
25 changes: 9 additions & 16 deletions src/aac_metrics/classes/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import torch

from torch import nn, Tensor
from torchmetrics.text.bert import _DEFAULT_MODEL

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bert_score_mrefs import (
bert_score_mrefs,
_load_model_and_tokenizer,
DEFAULT_BERT_SCORE_MODEL,
REDUCTIONS,
)
from aac_metrics.utils.globals import _get_device
Expand All @@ -37,7 +37,7 @@ class BERTScoreMRefs(AACMetric):
def __init__(
self,
return_all_scores: bool = True,
model: Union[str, nn.Module] = _DEFAULT_MODEL,
model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
num_threads: int = 0,
Expand Down Expand Up @@ -79,7 +79,6 @@ def __init__(
self._candidates = []
self._mult_references = []

# AACMetric methods
def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
return bert_score_mrefs(
candidates=self._candidates,
Expand Down Expand Up @@ -108,6 +107,13 @@ def extra_repr(self) -> str:
repr_ = ", ".join(f"{k}={v}" for k, v in hparams.items())
return repr_

def get_output_names(self) -> tuple[str, ...]:
    """Return the names of the score keys produced by this metric.

    These must match the keys of the corpus-scores dict returned by
    :func:`aac_metrics.functional.bert_score_mrefs.bert_score_mrefs`.
    """
    return (
        "bert_score.precision",
        # Fixed typo: was "bert_score.recalll" (three l's), which did not
        # match the "bert_score.recall" key emitted by bert_score_mrefs.
        "bert_score.recall",
        "bert_score.f1",
    )

def reset(self) -> None:
self._candidates = []
self._mult_references = []
Expand All @@ -120,16 +126,3 @@ def update(
) -> None:
self._candidates += candidates
self._mult_references += mult_references

# Other methods
@property
def device(self) -> torch.device:
try:
param = next(iter(self.parameters()))

def get_output_names(self) -> tuple[str, ...]:
return (
"bert_score.precision",
"bert_score.recalll",
"bert_score.f1",
)
12 changes: 6 additions & 6 deletions src/aac_metrics/classes/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def update(
mult_references: list[list[str]],
) -> None:
self._cooked_cands, self._cooked_mrefs = _bleu_update(
candidates,
mult_references,
self._n,
self._tokenizer,
self._cooked_cands,
self._cooked_mrefs,
candidates=candidates,
mult_references=mult_references,
n=self._n,
tokenizer=self._tokenizer,
cooked_cands=self._cooked_cands,
cooked_mrefs=self._cooked_mrefs,
)


Expand Down
18 changes: 12 additions & 6 deletions src/aac_metrics/classes/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.fense import fense, _load_models_and_tokenizer
from aac_metrics.functional.fer import BERTFlatClassifier, _ERROR_NAMES
from aac_metrics.functional.fer import (
BERTFlatClassifier,
_ERROR_NAMES,
DEFAULT_FER_MODEL,
)
from aac_metrics.functional.sbert_sim import DEFAULT_SBERT_SIM_MODEL


pylog = logging.getLogger(__name__)
Expand All @@ -37,8 +42,8 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
def __init__(
self,
return_all_scores: bool = True,
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
Expand Down Expand Up @@ -100,9 +105,10 @@ def extra_repr(self) -> str:
return repr_

def get_output_names(self) -> tuple[str, ...]:
return ("sbert_sim", "fer", "fense") + tuple(
f"fer.{name}_prob" for name in _ERROR_NAMES
)
output_names = ["sbert_sim", "fer", "fense"]
if self._return_probs:
output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
return tuple(output_names)

def reset(self) -> None:
self._candidates = []
Expand Down
10 changes: 8 additions & 2 deletions src/aac_metrics/classes/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
fer,
_load_echecker_and_tokenizer,
_ERROR_NAMES,
DEFAULT_FER_MODEL,
)
from aac_metrics.utils.globals import _get_device


pylog = logging.getLogger(__name__)
Expand All @@ -40,14 +42,15 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]])
def __init__(
self,
return_all_scores: bool = True,
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
reset_state: bool = True,
return_probs: bool = False,
verbose: int = 0,
) -> None:
device = _get_device(device)
echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
echecker=echecker,
echecker_tokenizer=None,
Expand Down Expand Up @@ -89,7 +92,10 @@ def extra_repr(self) -> str:
return repr_

def get_output_names(self) -> tuple[str, ...]:
return ("fer",) + tuple(f"fer.{name}_prob" for name in _ERROR_NAMES)
output_names = ["fer"]
if self._return_probs:
output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
return tuple(output_names)

def reset(self) -> None:
self._candidates = []
Expand Down
16 changes: 13 additions & 3 deletions src/aac_metrics/classes/sbert_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from torch import Tensor

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert
from aac_metrics.functional.sbert_sim import (
sbert_sim,
_load_sbert,
DEFAULT_SBERT_SIM_MODEL,
)
from aac_metrics.utils.globals import _get_device


pylog = logging.getLogger(__name__)
Expand All @@ -36,13 +41,18 @@ class SBERTSim(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tens
def __init__(
self,
return_all_scores: bool = True,
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
reset_state: bool = True,
verbose: int = 0,
) -> None:
sbert_model = _load_sbert(sbert_model, device, reset_state)
device = _get_device(device)
sbert_model = _load_sbert(
sbert_model=sbert_model,
device=device,
reset_state=reset_state,
)

super().__init__()
self._return_all_scores = return_all_scores
Expand Down
17 changes: 14 additions & 3 deletions src/aac_metrics/classes/spider_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
from aac_metrics.functional.fer import (
BERTFlatClassifier,
_load_echecker_and_tokenizer,
_ERROR_NAMES,
DEFAULT_FER_MODEL,
)
from aac_metrics.functional.spider_fl import spider_fl
from aac_metrics.utils.globals import _get_device


pylog = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,7 +52,7 @@ def __init__(
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
# FluencyError args
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
Expand All @@ -60,8 +63,13 @@ def __init__(
penalty: float = 0.9,
verbose: int = 0,
) -> None:
device = _get_device(device)
echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
echecker, echecker_tokenizer, device, reset_state, verbose
echecker=echecker,
echecker_tokenizer=echecker_tokenizer,
device=device,
reset_state=reset_state,
verbose=verbose,
)

super().__init__()
Expand Down Expand Up @@ -127,7 +135,10 @@ def extra_repr(self) -> str:
return extra

def get_output_names(self) -> tuple[str, ...]:
return ("cider_d", "spice", "spider", "spider_fl", "fer")
output_names = ["cider_d", "spice", "spider", "spider_fl", "fer"]
if self._return_probs:
output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
return tuple(output_names)

def reset(self) -> None:
self._candidates = []
Expand Down
5 changes: 4 additions & 1 deletion src/aac_metrics/classes/spider_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def extra_repr(self) -> str:
return repr_

def get_output_names(self) -> tuple[str, ...]:
return ("spider_max",)
output_names = ["spider_max"]
if self._return_all_cands_scores:
output_names += ["cider_d_all", "spice_all", "spider_all"]
return tuple(output_names)

def reset(self) -> None:
self._mult_candidates = []
Expand Down
13 changes: 7 additions & 6 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from aac_metrics.utils.globals import _get_device


DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
REDUCTIONS = ("mean", "max", "min")


def bert_score_mrefs(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
model: Union[str, nn.Module] = _DEFAULT_MODEL,
model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
tokenizer: Optional[Callable] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
Expand Down Expand Up @@ -167,11 +168,11 @@ def bert_score_mrefs(


def _load_model_and_tokenizer(
model: Union[str, nn.Module],
tokenizer: Optional[Callable],
device: Union[str, torch.device, None],
reset_state: bool,
verbose: int,
model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
tokenizer: Optional[Callable] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
reset_state: bool = True,
verbose: int = 0,
) -> tuple[nn.Module, Optional[Callable]]:
state = torch.random.get_rng_state()

Expand Down
15 changes: 10 additions & 5 deletions src/aac_metrics/functional/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
fer,
_load_echecker_and_tokenizer,
BERTFlatClassifier,
DEFAULT_FER_MODEL,
)
from aac_metrics.functional.sbert_sim import (
sbert_sim,
_load_sbert,
DEFAULT_SBERT_SIM_MODEL,
)
from aac_metrics.functional.sbert_sim import sbert_sim, _load_sbert
from aac_metrics.utils.checks import check_metric_inputs


Expand All @@ -28,9 +33,9 @@ def fense(
mult_references: list[list[str]],
return_all_scores: bool = True,
# SBERT args
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
# FluencyError args
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
Expand Down Expand Up @@ -138,8 +143,8 @@ def _fense_from_outputs(


def _load_models_and_tokenizer(
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
reset_state: bool = True,
Expand Down
9 changes: 6 additions & 3 deletions src/aac_metrics/functional/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from aac_metrics.utils.globals import _get_device


DEFAULT_FER_MODEL = "echecker_clotho_audiocaps_base"


_DEFAULT_PROXIES = {
"http": "socks5h://127.0.0.1:1080",
"https": "socks5h://127.0.0.1:1080",
Expand Down Expand Up @@ -67,7 +70,7 @@ def __init__(self, model_type: str, num_classes: int = 5) -> None:
@classmethod
def from_pretrained(
cls,
model_name: str = "echecker_clotho_audiocaps_base",
model_name: str = DEFAULT_FER_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
Expand Down Expand Up @@ -98,7 +101,7 @@ def forward(
def fer(
candidates: list[str],
return_all_scores: bool = True,
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
Expand Down Expand Up @@ -190,7 +193,7 @@ def _use_new_echecker_loading() -> bool:

# - Private functions
def _load_echecker_and_tokenizer(
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
reset_state: bool = True,
Expand Down
6 changes: 4 additions & 2 deletions src/aac_metrics/functional/sbert_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
from aac_metrics.utils.globals import _get_device


DEFAULT_SBERT_SIM_MODEL = "paraphrase-TinyBERT-L6-v2"

pylog = logging.getLogger(__name__)


def sbert_sim(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: int = 32,
reset_state: bool = True,
Expand Down Expand Up @@ -86,7 +88,7 @@ def sbert_sim(


def _load_sbert(
sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
reset_state: bool = True,
) -> SentenceTransformer:
Expand Down
Loading

0 comments on commit 414880b

Please sign in to comment.