Skip to content

Commit

Permalink
Mod: Update internal typing and global paths.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Nov 3, 2023
1 parent 7d3297d commit 8300f62
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import zlib

from typing import Any, Callable, Iterable, Union
from typing import Any, Callable, Iterable, Optional, Union

import torch

Expand Down Expand Up @@ -178,9 +178,9 @@ def _get_metric_factory_classes(
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
init_kwds: dict[str, Any] = ...,
init_kwds: Optional[dict[str, Any]] = None,
) -> dict[str, Callable[[], AACMetric]]:
if init_kwds is ... or init_kwds is None:
if init_kwds is None or init_kwds is ...:
init_kwds = {}

init_kwds = init_kwds | dict(return_all_scores=return_all_scores)
Expand Down
2 changes: 1 addition & 1 deletion src/aac_metrics/functional/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _get_metric_factory_functions(
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
init_kwds: Optional[dict[str, Any]] = ...,
init_kwds: Optional[dict[str, Any]] = None,
) -> dict[str, Callable[[list[str], list[list[str]]], Any]]:
if init_kwds is None or init_kwds is ...:
init_kwds = {}
Expand Down
83 changes: 46 additions & 37 deletions src/aac_metrics/utils/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import tempfile

from pathlib import Path
from typing import Optional, Union
from typing import Union, overload


pylog = logging.getLogger(__name__)


__DEFAULT_PATHS: dict[str, dict[str, Optional[str]]] = {
__DEFAULT_GLOBALS: dict[str, dict[str, Union[str, None]]] = {
"cache": {
"user": None,
"env": "AAC_METRICS_CACHE_PATH",
Expand All @@ -40,7 +40,7 @@ def get_default_cache_path() -> str:
Else if the environment variable AAC_METRICS_CACHE_PATH has been set to a string, it will return its value.
Else it will be equal to "~/.cache" by default.
"""
return __get_default_path("cache")
return __get_default_value("cache")


def get_default_java_path() -> str:
Expand All @@ -50,7 +50,7 @@ def get_default_java_path() -> str:
Else if the environment variable AAC_METRICS_JAVA_PATH has been set to a string, it will return its value.
Else it will be equal to "java" by default.
"""
return __get_default_path("java")
return __get_default_value("java")


def get_default_tmp_path() -> str:
Expand All @@ -60,78 +60,87 @@ def get_default_tmp_path() -> str:
Else if the environment variable AAC_METRICS_TMP_PATH has been set to a string, it will return its value.
Else it will be equal to the value returned by :func:`~tempfile.gettempdir()` by default.
"""
return __get_default_path("tmp")
return __get_default_value("tmp")


def set_default_cache_path(cache_path: Union[str, Path, None]) -> None:
"""Override default cache directory path."""
__set_default_path("cache", cache_path)
__set_default_value("cache", cache_path)


def set_default_java_path(java_path: Union[str, Path, None]) -> None:
"""Override default java executable path."""
__set_default_path("java", java_path)
__set_default_value("java", java_path)


def set_default_tmp_path(tmp_path: Union[str, Path, None]) -> None:
"""Override default temporary directory path."""
__set_default_path("tmp", tmp_path)
__set_default_value("tmp", tmp_path)


# Private functions
def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str:
return __get_path("cache", cache_path)
return __get_value("cache", cache_path)


def _get_java_path(java_path: Union[str, Path, None] = None) -> str:
return __get_path("java", java_path)
return __get_value("java", java_path)


def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str:
return __get_path("tmp", tmp_path)
return __get_value("tmp", tmp_path)


def __get_default_path(path_name: str) -> str:
paths = __DEFAULT_PATHS[path_name]
def __get_default_value(value_name: str) -> str:
values = __DEFAULT_GLOBALS[value_name]

for name, path_or_var in paths.items():
if path_or_var is None:
for source, value_or_env_varname in values.items():
if value_or_env_varname is None:
continue

if name.startswith("env"):
path = os.getenv(path_or_var, None)
if source.startswith("env"):
path = os.getenv(value_or_env_varname, None)
else:
path = path_or_var
path = value_or_env_varname

if path is not None:
path = __process_path(path)
path = __process_value(path)
return path

pylog.error(f"Paths values: {paths}")
pylog.error(f"Paths values: {values}")
raise RuntimeError(
f"Invalid default path for {path_name=}. (all default paths are None)"
f"Invalid default path for {value_name=}. (all default paths are None)"
)


def __set_default_path(
path_name: str,
path: Union[str, Path, None],
def __set_default_value(
value_name: str,
value: Union[str, Path, None],
) -> None:
if path is not ... and path is not None:
path = __process_path(path)
__DEFAULT_PATHS[path_name]["user"] = path
value = __process_value(value)
__DEFAULT_GLOBALS[value_name]["user"] = value


def __get_path(path_name: str, path: Union[str, Path, None] = None) -> str:
if path is ... or path is None:
return __get_default_path(path_name)
def __get_value(value_name: str, value: Union[str, Path, None] = None) -> str:
if value is ... or value is None:
return __get_default_value(value_name)
else:
path = __process_path(path)
return path
value = __process_value(value)
return value


def __process_path(path: Union[str, Path]) -> str:
path = str(path)
path = osp.expanduser(path)
path = osp.expandvars(path)
return path
@overload
def __process_value(value: None) -> None:
...


@overload
def __process_value(value: Union[str, Path]) -> str:
...


def __process_value(value: Union[str, Path, None]) -> Union[str, None]:
value = str(value)
value = osp.expanduser(value)
value = osp.expandvars(value)
return value

0 comments on commit 8300f62

Please sign in to comment.