From ea508c8076c0bb7f6fc14848eb1866aa96316266 Mon Sep 17 00:00:00 2001
From: Labbeti
Date: Thu, 21 Dec 2023 18:20:04 +0100
Subject: [PATCH] Add: Check transformers version in FER metric and refactor
 internal code.

---
 src/aac_metrics/functional/fer.py | 109 ++++++++++++------------------
 1 file changed, 45 insertions(+), 64 deletions(-)

diff --git a/src/aac_metrics/functional/fer.py b/src/aac_metrics/functional/fer.py
index 7a3eabe..a15097d 100644
--- a/src/aac_metrics/functional/fer.py
+++ b/src/aac_metrics/functional/fer.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+import transformers
 
 from torch import nn, Tensor
 from tqdm import tqdm
@@ -26,12 +27,11 @@
 from aac_metrics.utils.globals import _get_device
 
 
-# config according to the settings on your computer, this should be default setting of shadowsocks
-DEFAULT_PROXIES = {
+_DEFAULT_PROXIES = {
     "http": "socks5h://127.0.0.1:1080",
     "https": "socks5h://127.0.0.1:1080",
 }
-PRETRAIN_ECHECKERS_DICT = {
+_PRETRAIN_ECHECKERS_DICT = {
     "echecker_clotho_audiocaps_base": (
         "https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt",
         "1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa",
@@ -41,13 +41,7 @@
         "90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673",
     ),
 }
-
-RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
-
-pylog = logging.getLogger(__name__)
-
-
-ERROR_NAMES = (
+_ERROR_NAMES = (
     "add_tail",
     "repeat_event",
     "repeat_adv",
@@ -56,6 +50,10 @@
     "error",
 )
 
+_RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
+
+pylog = logging.getLogger(__name__)
+
 
 class BERTFlatClassifier(nn.Module):
     def __init__(self, model_type: str, num_classes: int = 5) -> None:
@@ -131,18 +129,29 @@ def fer(
         error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
         raise ValueError(error_msg)
 
+    version = transformers.__version__
+    # Parse only the first two components: patch can be non-numeric (e.g. "4.31.0.dev0").
+    major, minor = map(int, version.split(".")[:2])
+    if major > 4 or (major == 4 and minor > 30):
+        raise ValueError(
+            f"Invalid transformers version {version} for FER metric. Please use a version < 4.31.0."
+        )
+
     # Init models
     echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
-        echecker, echecker_tokenizer, device, reset_state, verbose
+        echecker=echecker,
+        echecker_tokenizer=echecker_tokenizer,
+        device=device,
+        reset_state=reset_state,
+        verbose=verbose,
     )
 
     # Compute and apply fluency error detection penalty
     probs_outs_sents = __detect_error_sents(
-        echecker,
-        echecker_tokenizer,  # type: ignore
-        candidates,
-        batch_size,
-        device,
+        echecker=echecker,
+        echecker_tokenizer=echecker_tokenizer,
+        sents=candidates,
+        batch_size=batch_size,
+        device=device,
     )
     fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float)
 
@@ -226,10 +235,10 @@ def __detect_error_sents(
         # batch_logits: (bsize, num_classes=6)
         # note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69
         probs = logits.sigmoid().transpose(0, 1).cpu().numpy()
-        probs_dic: dict[str, np.ndarray] = dict(zip(ERROR_NAMES, probs))
+        probs_dic: dict[str, np.ndarray] = dict(zip(_ERROR_NAMES, probs))
 
     else:
-        dic_lst_probs = {name: [] for name in ERROR_NAMES}
+        dic_lst_probs = {name: [] for name in _ERROR_NAMES}
 
         for i in range(0, len(sents), batch_size):
             batch = __infer_preprocess(
@@ -257,11 +266,10 @@
 
 
 def __check_download_resource(
-    remote: RemoteFileMetadata,
+    remote: _RemoteFileMetadata,
     use_proxy: bool = False,
     proxies: Optional[dict[str, str]] = None,
 ) -> str:
-    proxies = DEFAULT_PROXIES if use_proxy and proxies is None else proxies
     data_home = __get_data_home()
     file_path = os.path.join(data_home, remote.filename)
     if not os.path.exists(file_path):
@@ -286,10 +294,10 @@ def __infer_preprocess(
 
 
 def __download(
-    remote: RemoteFileMetadata,
+    remote: _RemoteFileMetadata,
     file_path: Optional[str] = None,
     use_proxy: bool = False,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
     data_home = __get_data_home()
     file_path = __fetch_remote(remote, data_home, use_proxy, proxies)
@@ -299,8 +307,12 @@
 def __download_with_bar(
     url: str,
     file_path: str,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    use_proxy: bool = False,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
+    if use_proxy and proxies is None:
+        proxies = _DEFAULT_PROXIES
+
     # Streaming, so we can iterate over the response.
     response = requests.get(url, stream=True, proxies=proxies)
     total_size_in_bytes = int(response.headers.get("content-length", 0))
@@ -317,31 +329,13 @@
 
 
 def __fetch_remote(
-    remote: RemoteFileMetadata,
+    remote: _RemoteFileMetadata,
     dirname: Optional[str] = None,
     use_proxy: bool = False,
-    proxies: Optional[dict[str, str]] = DEFAULT_PROXIES,
+    proxies: Optional[dict[str, str]] = None,
 ) -> str:
-    """Helper function to download a remote dataset into path
-    Fetch a dataset pointed by remote's url, save into path using remote's
-    filename and ensure its integrity based on the SHA256 Checksum of the
-    downloaded file.
-
-    Parameters
-    ----------
-    remote : RemoteFileMetadata
-        Named tuple containing remote dataset meta information: url, filename
-        and checksum
-    dirname : string
-        Directory to save the file to.
-    Returns
-    -------
-    file_path: string
-        Full path of the created file.
-    """
-    file_path = remote.filename if dirname is None else join(dirname, remote.filename)
-    proxies = None if not use_proxy else proxies
-    file_path = __download_with_bar(remote.url, file_path, proxies)
+    file_path = remote.filename if dirname is None else join(dirname, remote.filename)
+    file_path = __download_with_bar(remote.url, file_path, use_proxy, proxies)
     checksum = __sha256(file_path)
     if remote.checksum != checksum:
         raise IOError(
@@ -352,23 +346,10 @@ def __fetch_remote(
     return file_path
 
 
-def __get_data_home(data_home: Optional[str] = None) -> str:  # type: ignore
-    """Return the path of the scikit-learn data dir.
-    This folder is used by some large dataset loaders to avoid downloading the
-    data several times.
-    By default the data dir is set to a folder named 'fense_data' in the
-    user home folder.
-    Alternatively, it can be set by the 'FENSE_DATA' environment
-    variable or programmatically by giving an explicit folder path. The '~'
-    symbol is expanded to the user home folder.
-    If the folder does not already exist, it is automatically created.
-    Parameters
-    ----------
-    data_home : str | None
-        The path to data dir.
-    """
+def __get_data_home(data_home: Optional[str] = None) -> str:
     if data_home is None:
-        data_home = environ.get("FENSE_DATA", join(torch.hub.get_dir(), "fense_data"))
+        DEFAULT_DATA_HOME = join(torch.hub.get_dir(), "fense_data")
+        data_home = environ.get("FENSE_DATA", DEFAULT_DATA_HOME)
 
     data_home: str
     data_home = expanduser(data_home)
@@ -384,15 +365,15 @@ def __load_pretrain_echecker(
     proxies: Optional[dict[str, str]] = None,
    verbose: int = 0,
 ) -> BERTFlatClassifier:
-    if echecker_model not in PRETRAIN_ECHECKERS_DICT:
+    if echecker_model not in _PRETRAIN_ECHECKERS_DICT:
         raise ValueError(
-            f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})"
+            f"Invalid argument {echecker_model=}. (expected one of {tuple(_PRETRAIN_ECHECKERS_DICT.keys())})"
         )
 
     device = _get_device(device)
     tfmers_logging.set_verbosity_error()  # suppress loading warnings
-    url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model]
-    remote = RemoteFileMetadata(
+    url, checksum = _PRETRAIN_ECHECKERS_DICT[echecker_model]
+    remote = _RemoteFileMetadata(
         filename=f"{echecker_model}.ckpt", url=url, checksum=checksum
     )
     file_path = __check_download_resource(remote, use_proxy, proxies)
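
Note (not part of the patch): a minimal sketch of how the new version guard in
fer() behaves. The helper name _is_supported_transformers_version is
hypothetical and only mirrors the parsing rule added above, i.e. transformers
must be older than 4.31.0:

    import transformers

    def _is_supported_transformers_version(version: str) -> bool:
        # Same rule as the guard in fer(): compare only the first two numeric
        # components, so suffixes like ".dev0" do not break parsing.
        major, minor = map(int, version.split(".")[:2])
        return not (major > 4 or (major == 4 and minor > 30))

    assert _is_supported_transformers_version("4.30.2")
    assert not _is_supported_transformers_version("4.31.0")
    assert not _is_supported_transformers_version("5.0.0")
    print("supported:", _is_supported_transformers_version(transformers.__version__))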
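
A second note on the proxy refactor: the proxies default changed from
DEFAULT_PROXIES to None in every helper, and the fallback moved from
__check_download_resource into __download_with_bar. One behavioral consequence
is that an explicitly passed proxies dict is now used even when
use_proxy=False, whereas the old code discarded it. A minimal sketch of the
resulting resolution rule (the helper name _resolve_proxies is hypothetical):

    from typing import Optional

    _DEFAULT_PROXIES = {
        "http": "socks5h://127.0.0.1:1080",
        "https": "socks5h://127.0.0.1:1080",
    }

    def _resolve_proxies(
        use_proxy: bool, proxies: Optional[dict[str, str]]
    ) -> Optional[dict[str, str]]:
        # Mirrors the logic now inside __download_with_bar: an explicit proxies
        # dict always wins; use_proxy=True without explicit proxies falls back
        # to the local SOCKS defaults; otherwise no proxy is used.
        if use_proxy and proxies is None:
            return _DEFAULT_PROXIES
        return proxies

    assert _resolve_proxies(False, None) is None
    assert _resolve_proxies(True, None) == _DEFAULT_PROXIES
    assert _resolve_proxies(False, {"http": "socks5h://10.0.0.1:1080"}) is not None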