Mod: Avoid test with original fense when transformers version is too recent.
Labbeti committed Jan 3, 2024
1 parent e4a2ed7 commit 4737718
Showing 2 changed files with 44 additions and 19 deletions.
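Both changes hinge on the same version gate: starting with transformers 4.31, the BertEmbeddings state_dict no longer contains "encoder.embeddings.position_ids", so the original echecker checkpoint can no longer be loaded as-is by the upstream fense package. A minimal sketch of an equivalent check written with packaging.version (an assumption for illustration; the commit parses transformers.__version__ by hand, as shown in the fer.py diff below):

import transformers
from packaging.version import Version  # assumption: packaging is installed (it is a transformers dependency)


def _use_new_echecker_loading() -> bool:
    # True when the installed transformers no longer stores
    # "encoder.embeddings.position_ids" in the BertEmbeddings state_dict.
    return Version(transformers.__version__) >= Version("4.31.0")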
18 changes: 14 additions & 4 deletions src/aac_metrics/functional/fer.py
@@ -73,7 +73,13 @@ def from_pretrained(
proxies: Optional[dict[str, str]] = None,
verbose: int = 0,
) -> "BERTFlatClassifier":
return __load_pretrain_echecker(model_name, device, use_proxy, proxies, verbose)
return __load_pretrain_echecker(
echecker_model=model_name,
device=device,
use_proxy=use_proxy,
proxies=proxies,
verbose=verbose,
)

def forward(
self,
@@ -176,6 +182,12 @@ def fer(
return fer_score


def _use_new_echecker_loading() -> bool:
version = transformers.__version__
major, minor, _patch = map(int, version.split("."))
return major > 4 or (major == 4 and minor >= 31)


# - Private functions
def _load_echecker_and_tokenizer(
echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
@@ -390,9 +402,7 @@ def __load_pretrain_echecker(
)

# To support transformers > 4.31, because this lib changed BertEmbedding state_dict
version = transformers.__version__
major, minor, _patch = map(int, version.split("."))
if major > 4 or (major == 4 and minor >= 31):
if _use_new_echecker_loading():
state_dict.pop("encoder.embeddings.position_ids")

echecker.load_state_dict(state_dict)
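The pop is needed because the pretrained echecker checkpoint still carries the legacy key while models built by transformers >= 4.31 no longer declare it, which makes a strict load_state_dict fail on an unexpected key. A version-agnostic variant (a sketch, not what the commit does; it reuses the echecker and state_dict names from the hunk above) could decide by inspecting the instantiated model instead of the library version:

# Sketch: drop the legacy key only when the model built by the installed
# transformers no longer declares it, avoiding the hard-coded 4.31 boundary.
legacy_key = "encoder.embeddings.position_ids"
if legacy_key in state_dict and legacy_key not in echecker.state_dict():
    state_dict.pop(legacy_key)
echecker.load_state_dict(state_dict)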
45 changes: 30 additions & 15 deletions tests/test_compare_fense.py
@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-

import importlib
import logging
import os.path as osp
import sys
import torch
@@ -10,10 +11,16 @@
from typing import Any
from unittest import TestCase

import transformers

from aac_metrics.classes.fense import FENSE
from aac_metrics.functional.fer import _use_new_echecker_loading
from aac_metrics.eval import load_csv_file


pylog = logging.getLogger(__name__)


class TestCompareFENSE(TestCase):
# Set Up methods
@classmethod
@@ -23,21 +30,23 @@ def setUpClass(cls) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device=}")

cls.src_sbert_sim = Evaluator(
device=device,
echecker_model="none",
)
cls.src_fense = Evaluator(
device=device,
echecker_model="echecker_clotho_audiocaps_base",
)

cls.new_fense = FENSE(
return_all_scores=True,
device=device,
verbose=2,
echecker="echecker_clotho_audiocaps_base",
)
cls.src_sbert_sim = Evaluator(
device=device,
echecker_model="none",
)
if _use_new_echecker_loading():
cls.src_fense = None
else:
cls.src_fense = Evaluator(
device=device,
echecker_model="echecker_clotho_audiocaps_base",
)

@classmethod
def _get_src_evaluator_class(cls) -> Any:
@@ -78,7 +87,6 @@ def _test_with_original_fense(self, fpath: str) -> None:
cands, mrefs = load_csv_file(fpath)

src_sbert_sim_score = self.src_sbert_sim.corpus_score(cands, mrefs).item()
src_fense_score = self.src_fense.corpus_score(cands, mrefs).item()

outs: tuple = self.new_fense(cands, mrefs) # type: ignore
corpus_outs, _sents_outs = outs
@@ -90,11 +98,18 @@
new_sbert_sim_score,
"Invalid SBERTSim score with original implementation.",
)
self.assertEqual(
src_fense_score,
new_fense_score,
"Invalid FENSE score with original implementation.",
)

if self.src_fense is None:
pylog.warning(
f"Skipping test with original FENSE for the transformers version {transformers.__version__}"
)
else:
src_fense_score = self.src_fense.corpus_score(cands, mrefs).item()
self.assertEqual(
src_fense_score,
new_fense_score,
"Invalid FENSE score with original implementation.",
)


if __name__ == "__main__":

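The updated test logs a warning and silently omits the comparison against the original FENSE when the installed transformers is too recent. An alternative sketch using unittest's skip machinery (an assumption, not part of the commit) would surface the omission in the test report instead:

# Sketch: report the missing comparison as a skipped test rather than a warning.
if self.src_fense is None:
    self.skipTest(
        f"Original FENSE unavailable with transformers=={transformers.__version__}"
    )
src_fense_score = self.src_fense.corpus_score(cands, mrefs).item()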