Skip to content

Commit

Permalink
Mod/Fix: Update kwds args pass to internal private funcs and fix _get…
Browse files Browse the repository at this point in the history
…_device behaviour.
  • Loading branch information
Labbeti committed Dec 11, 2023
1 parent 8914275 commit 10c6045
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
9 changes: 8 additions & 1 deletion src/aac_metrics/classes/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,14 @@ def __init__(
penalty: float = 0.9,
verbose: int = 0,
) -> None:
sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(sbert_model, echecker, None, device, reset_state, verbose) # type: ignore
sbert_model, echecker, echecker_tokenizer = _load_models_and_tokenizer(
sbert_model=sbert_model,
echecker=echecker,
echecker_tokenizer=None,
device=device,
reset_state=reset_state,
verbose=verbose,
)

super().__init__()
self._return_all_scores = return_all_scores
Expand Down
10 changes: 8 additions & 2 deletions src/aac_metrics/functional/fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,14 @@ def _load_models_and_tokenizer(
reset_state: bool = True,
verbose: int = 0,
) -> tuple[SentenceTransformer, BERTFlatClassifier, AutoTokenizer]:
sbert_model = _load_sbert(sbert_model, device, reset_state)
sbert_model = _load_sbert(
sbert_model=sbert_model, device=device, reset_state=reset_state
)
echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
echecker, echecker_tokenizer, device, reset_state, verbose
echecker=echecker,
echecker_tokenizer=echecker_tokenizer,
device=device,
reset_state=reset_state,
verbose=verbose,
)
return sbert_model, echecker, echecker_tokenizer
10 changes: 7 additions & 3 deletions src/aac_metrics/utils/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ def __get_default_value(value_name: str) -> Any:
values = __DEFAULT_GLOBALS[value_name]["values"]
process_func = __DEFAULT_GLOBALS[value_name]["process"]
for source, value_or_env_varname in values.items():
if value_or_env_varname is None:
continue

if source.startswith("env"):
value = os.getenv(value_or_env_varname, None)
else:
Expand Down Expand Up @@ -149,6 +146,13 @@ def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.de
},
"process": __process_path,
},
"device": {
"values": {
"env": "AAC_METRICS_DEVICE",
"package": "auto",
},
"process": __process_device,
},
"java": {
"values": {
"user": None,
Expand Down

0 comments on commit 10c6045

Please sign in to comment.