
Commit

Mod: Remove transformers maximal version by making FER metric compatible with version >= 4.31.
Labbeti committed Jan 3, 2024
1 parent 6fcff0d commit 5d710f9
Showing 6 changed files with 35 additions and 18 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@
 
 All notable changes to this project will be documented in this file.
 
+## [0.5.2] UNRELEASED
+### Changed
+- `aac-metrics` is now compatible with `transformers>=4.31`
+
 ## [0.5.1] 2023-12-20
 ### Added
 - Check sentences inputs for all metrics.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,5 +5,5 @@ numpy>=1.21.2
 pyyaml>=6.0
 tqdm>=4.64.0
 sentence-transformers>=2.2.2
-transformers<4.31.0
+transformers
 torchmetrics>=0.11.4
7 changes: 4 additions & 3 deletions src/aac_metrics/classes/fense.py
@@ -7,11 +7,12 @@
 
 import torch
 
+from sentence_transformers import SentenceTransformer
 from torch import Tensor
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.fense import fense, _load_models_and_tokenizer
-from aac_metrics.functional.fer import _ERROR_NAMES
+from aac_metrics.functional.fer import BERTFlatClassifier, _ERROR_NAMES
 
 
 pylog = logging.getLogger(__name__)
@@ -36,8 +37,8 @@ class FENSE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]]
     def __init__(
         self,
         return_all_scores: bool = True,
-        sbert_model: str = "paraphrase-TinyBERT-L6-v2",
-        echecker: str = "echecker_clotho_audiocaps_base",
+        sbert_model: Union[str, SentenceTransformer] = "paraphrase-TinyBERT-L6-v2",
+        echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
         error_threshold: float = 0.9,
         device: Union[str, torch.device, None] = "auto",
         batch_size: int = 32,
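With the widened parameter types above, FENSE can now be handed a preloaded SentenceTransformer (and a BERTFlatClassifier for the error checker) instead of a model name. The following is a minimal sketch of that usage, not part of this commit; it assumes the default model names from the signature above and the torchmetrics-style call interface used by the other metric classes in aac_metrics.

from sentence_transformers import SentenceTransformer

from aac_metrics.classes.fense import FENSE

# Load the SBERT backbone once and reuse it across metric instances.
sbert = SentenceTransformer("paraphrase-TinyBERT-L6-v2", device="cpu")

fense = FENSE(
    return_all_scores=True,
    sbert_model=sbert,  # a preloaded model now type-checks, not only a name string
    echecker="echecker_clotho_audiocaps_base",
    device="cpu",
    batch_size=32,
)

candidates = ["a man speaks while birds sing in the background"]
mult_references = [["a man is speaking and birds are singing"]]
corpus_scores, sents_scores = fense(candidates, mult_references)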
11 changes: 9 additions & 2 deletions src/aac_metrics/classes/fer.py
@@ -11,6 +11,7 @@
 
 from aac_metrics.classes.base import AACMetric
 from aac_metrics.functional.fer import (
+    BERTFlatClassifier,
     fer,
     _load_echecker_and_tokenizer,
     _ERROR_NAMES,
@@ -39,15 +40,21 @@ class FER(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]])
     def __init__(
         self,
         return_all_scores: bool = True,
-        echecker: str = "echecker_clotho_audiocaps_base",
+        echecker: Union[str, BERTFlatClassifier] = "echecker_clotho_audiocaps_base",
         error_threshold: float = 0.9,
         device: Union[str, torch.device, None] = "auto",
         batch_size: int = 32,
         reset_state: bool = True,
         return_probs: bool = False,
         verbose: int = 0,
     ) -> None:
-        echecker, echecker_tokenizer = _load_echecker_and_tokenizer(echecker, None, device, reset_state, verbose)  # type: ignore
+        echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
+            echecker=echecker,
+            echecker_tokenizer=None,
+            device=device,
+            reset_state=reset_state,
+            verbose=verbose,
+        )
 
         super().__init__()
         self._return_all_scores = return_all_scores
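Similarly, FER now accepts either the name of a pretrained error checker or a BERTFlatClassifier instance. Below is a sketch, again outside this commit, of loading the checker once with the (private) helper whose call was refactored above and then sharing it; treat this reuse pattern as an assumption rather than documented API.

from aac_metrics.classes.fer import FER
from aac_metrics.functional.fer import _load_echecker_and_tokenizer

# Load the error checker a single time; keyword arguments match the refactored call above.
echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
    echecker="echecker_clotho_audiocaps_base",
    echecker_tokenizer=None,
    device="cpu",
    reset_state=True,
    verbose=0,
)

# Pass the preloaded checker instead of its name; the new Union type allows this.
fer_metric = FER(echecker=echecker, device="cpu")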
4 changes: 3 additions & 1 deletion src/aac_metrics/functional/fense.py
@@ -146,7 +146,9 @@ def _load_models_and_tokenizer(
     verbose: int = 0,
 ) -> tuple[SentenceTransformer, BERTFlatClassifier, AutoTokenizer]:
     sbert_model = _load_sbert(
-        sbert_model=sbert_model, device=device, reset_state=reset_state
+        sbert_model=sbert_model,
+        device=device,
+        reset_state=reset_state,
     )
     echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
         echecker=echecker,
25 changes: 14 additions & 11 deletions src/aac_metrics/functional/fer.py
@@ -129,13 +129,6 @@ def fer(
         error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
         raise ValueError(error_msg)
 
-    version = transformers.__version__
-    major, minor, _patch = map(int, version.split("."))
-    if major > 4 or (major == 4 and minor > 30):
-        raise ValueError(
-            f"Invalid transformers version {version} for FER metric. Please use a version < 4.31.0."
-        )
-
     # Init models
     echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
         echecker=echecker,
@@ -382,17 +375,27 @@ def __load_pretrain_echecker(
         pylog.debug(f"Loading echecker model from '{file_path}'.")
 
     model_states = torch.load(file_path)
+    model_type = model_states["model_type"]
+    num_classes = model_states["num_classes"]
+    state_dict = model_states["state_dict"]
 
     if verbose >= 2:
         pylog.debug(
-            f"Loading echecker model type '{model_states['model_type']}' with '{model_states['num_classes']}' classes."
+            f"Loading echecker model type '{model_type}' with '{num_classes}' classes."
         )
 
     echecker = BERTFlatClassifier(
-        model_type=model_states["model_type"],
-        num_classes=model_states["num_classes"],
+        model_type=model_type,
+        num_classes=num_classes,
     )
-    echecker.load_state_dict(model_states["state_dict"])
+
+    # To support transformers > 4.31, because this lib changed BertEmbedding state_dict
+    version = transformers.__version__
+    major, minor, _patch = map(int, version.split("."))
+    if major > 4 or (major == 4 and minor >= 31):
+        state_dict.pop("encoder.embeddings.position_ids")
+
+    echecker.load_state_dict(state_dict)
     echecker.eval()
     echecker.to(device=device)
     return echecker
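For context on the last hunk: transformers 4.31 changed the BertEmbeddings state dict (as the new comment notes), so checkpoints saved with earlier versions still carry an `encoder.embeddings.position_ids` entry that a strict `load_state_dict()` rejects. The sketch below restates that guard as a standalone function for illustration only; the helper name is hypothetical, and `pop(..., None)` is used so the removal is a no-op when the key is already absent.

import torch
import transformers


def load_echecker_state_dict(file_path: str) -> dict:
    # Hypothetical helper for illustration; mirrors the checkpoint layout used above
    # ("model_type", "num_classes", "state_dict").
    model_states = torch.load(file_path, map_location="cpu")
    state_dict = model_states["state_dict"]

    # transformers >= 4.31 no longer keeps position_ids in the BERT state dict,
    # so drop the stale key before a strict load_state_dict().
    major, minor = map(int, transformers.__version__.split(".")[:2])
    if (major, minor) >= (4, 31):
        state_dict.pop("encoder.embeddings.position_ids", None)

    return state_dict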
