Skip to content

Commit

Permalink
Mod: Rename paths.py to globals.py, update internal globals behaviour to handle non-path values and add temporary _get_device func.
Browse files (browse the repository at this point in the history)
  • Loading branch information
Labbeti committed Dec 8, 2023
1 parent 30bc672 commit 8914275
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .classes.spider_max import SPIDErMax
from .classes.vocab import Vocab
from .functional.evaluate import evaluate, dcase2023_evaluate
from .utils.paths import (
from .utils.globals import (
get_default_cache_path,
get_default_java_path,
get_default_tmp_path,
Expand Down
6 changes: 5 additions & 1 deletion src/aac_metrics/classes/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def __init__(
verbose: int = 0,
) -> None:
model, tokenizer = _load_model_and_tokenizer(
model, None, device, reset_state, verbose
model=model,
tokenizer=None,
device=device,
reset_state=reset_state,
verbose=verbose,
)

super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
check_spice_install,
)
from aac_metrics.utils.cmdline import _str_to_bool, _setup_logging
from aac_metrics.utils.paths import (
from aac_metrics.utils.globals import (
_get_cache_path,
_get_tmp_path,
get_default_cache_path,
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from aac_metrics.utils.checks import check_metric_inputs, check_java_path
from aac_metrics.utils.cmdline import _str_to_bool, _str_to_opt_str, _setup_logging
from aac_metrics.utils.paths import (
from aac_metrics.utils.globals import (
get_default_cache_path,
get_default_java_path,
get_default_tmp_path,
Expand Down
28 changes: 13 additions & 15 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformers import logging as tfmers_logging

from aac_metrics.utils.collections import flat_list, unflat_list, duplicate_list
from aac_metrics.utils.globals import _get_device


def bert_score_mrefs(
Expand Down Expand Up @@ -61,8 +62,13 @@ def bert_score_mrefs(
raise ValueError(
f"Invalid argument combinaison {model=} with {tokenizer=}."
)

model, tokenizer = _load_model_and_tokenizer(
model, tokenizer, device, reset_state, verbose
model=model,
tokenizer=tokenizer,
device=device,
reset_state=reset_state,
verbose=verbose,
)

elif isinstance(model, nn.Module):
Expand All @@ -76,11 +82,7 @@ def bert_score_mrefs(
f"Invalid argument type {type(model)=}. (expected str or nn.Module)"
)

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
flat_mrefs, sizes = flat_list(mult_references)
duplicated_cands = duplicate_list(candidates, sizes)

Expand Down Expand Up @@ -116,9 +118,9 @@ def bert_score_mrefs(
if reduction == "mean":
reduction_fn = torch.mean
elif reduction == "max":
reduction_fn = max_reduce
reduction_fn = _max_reduce
elif reduction == "min":
reduction_fn = min_reduce
reduction_fn = _min_reduce
else:
REDUCTIONS = ("mean", "max", "min")
raise ValueError(
Expand Down Expand Up @@ -161,11 +163,7 @@ def _load_model_and_tokenizer(
) -> tuple[nn.Module, Optional[Callable]]:
state = torch.random.get_rng_state()

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
if isinstance(model, str):
tfmers_verbosity = tfmers_logging.get_verbosity()
if verbose <= 1:
Expand All @@ -188,14 +186,14 @@ def _load_model_and_tokenizer(
return model, tokenizer # type: ignore


def max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
def _max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
if dim is None:
return x.max()
else:
return x.max(dim=dim).values


def min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
def _min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
if dim is None:
return x.min()
else:
Expand Down
29 changes: 9 additions & 20 deletions src/aac_metrics/functional/fer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from aac_metrics.utils.globals import _get_device


# config according to the settings on your computer, this should be default setting of shadowsocks
DEFAULT_PROXIES = {
Expand Down Expand Up @@ -182,13 +184,11 @@ def _load_echecker_and_tokenizer(
) -> tuple[BERTFlatClassifier, AutoTokenizer]:
state = torch.random.get_rng_state()

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
if isinstance(echecker, str):
echecker = __load_pretrain_echecker(echecker, device, verbose=verbose)
echecker = __load_pretrain_echecker(
echecker_model=echecker, device=device, verbose=verbose
)

if echecker_tokenizer is None:
echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore
Expand All @@ -211,10 +211,7 @@ def __detect_error_sents(
device: Union[str, torch.device, None],
max_len: int = 64,
) -> dict[str, np.ndarray]:
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)
device = _get_device(device)

if len(sents) <= batch_size:
batch = __infer_preprocess(
Expand Down Expand Up @@ -280,11 +277,7 @@ def __infer_preprocess(
device: Union[str, torch.device, None],
dtype: torch.dtype,
) -> Mapping[str, Tensor]:
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
texts = __text_preprocess(texts) # type: ignore
batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_len)
for k in ("input_ids", "attention_mask", "token_type_ids"):
Expand Down Expand Up @@ -396,11 +389,7 @@ def __load_pretrain_echecker(
f"Invalid argument {echecker_model=}. (expected one of {tuple(PRETRAIN_ECHECKERS_DICT.keys())})"
)

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
tfmers_logging.set_verbosity_error() # suppress loading warnings
url, checksum = PRETRAIN_ECHECKERS_DICT[echecker_model]
remote = RemoteFileMetadata(
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor

from aac_metrics.utils.checks import check_java_path
from aac_metrics.utils.paths import _get_cache_path, _get_java_path
from aac_metrics.utils.globals import _get_cache_path, _get_java_path


pylog = logging.getLogger(__name__)
Expand Down
8 changes: 3 additions & 5 deletions src/aac_metrics/functional/sbert_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from sentence_transformers import SentenceTransformer
from torch import Tensor

from aac_metrics.utils.globals import _get_device


pylog = logging.getLogger(__name__)

Expand Down Expand Up @@ -91,11 +93,7 @@ def _load_sbert(
) -> SentenceTransformer:
state = torch.random.get_rng_state()

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(device, str):
device = torch.device(device)

device = _get_device(device)
if isinstance(sbert_model, str):
sbert_model = SentenceTransformer(sbert_model, device=device) # type: ignore
sbert_model.to(device=device)
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch import Tensor

from aac_metrics.utils.checks import check_java_path
from aac_metrics.utils.paths import (
from aac_metrics.utils.globals import (
_get_cache_path,
_get_java_path,
_get_tmp_path,
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import aac_metrics

from aac_metrics.utils.checks import _get_java_version
from aac_metrics.utils.paths import (
from aac_metrics.utils.globals import (
get_default_cache_path,
get_default_java_path,
get_default_tmp_path,
Expand Down
Loading

0 comments on commit 8914275

Please sign in to comment.