Skip to content

Commit

Permalink
Add: paths functions to select java, tmp or cache dirs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Aug 1, 2023
1 parent 8d425e9 commit b8a7808
Show file tree
Hide file tree
Showing 16 changed files with 220 additions and 75 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/python-package-pip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ jobs:
java-version: ${{ matrix.java-version }}
java-package: jre

- name: Tests
run: |
echo $GITHUB_REF_NAME
echo ${GITHUB_REF_NAME}
echo %GITHUB_REF_NAME%
echo %{GITHUB_REF_NAME}%
echo %GITHUB_REF_NAME
- name: Install package
# note: ${GITHUB_REF##*/} gives the branch name
# python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@${GITHUB_REF##*/}"
Expand Down
24 changes: 12 additions & 12 deletions src/aac_metrics/classes/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(
self,
preprocess: bool = True,
metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac",
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> None:
Expand Down Expand Up @@ -114,9 +114,9 @@ class DCASE2023Evaluate(Evaluate):
def __init__(
self,
preprocess: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> None:
Expand All @@ -133,9 +133,9 @@ def __init__(

def _instantiate_metrics_classes(
metrics: Union[str, Iterable[str], Iterable[AACMetric]] = "aac",
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> list[AACMetric]:
Expand Down Expand Up @@ -166,9 +166,9 @@ def _instantiate_metrics_classes(

def _get_metric_factory_classes(
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> dict[str, Callable[[], AACMetric]]:
Expand Down
4 changes: 2 additions & 2 deletions src/aac_metrics/classes/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class METEOR(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor
def __init__(
self,
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
cache_path: str = ...,
java_path: str = ...,
java_max_memory: str = "2G",
language: str = "en",
verbose: int = 0,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class SPICE(AACMetric[Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]
def __init__(
self,
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def __init__(
n: int = 4,
sigma: float = 6.0,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spider_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def __init__(
n: int = 4,
sigma: float = 6.0,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/classes/spider_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def __init__(
n: int = 4,
sigma: float = 6.0,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
9 changes: 5 additions & 4 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from aac_metrics.classes.fense import FENSE
from aac_metrics.functional.meteor import FNAME_METEOR_JAR
from aac_metrics.functional.spice import FNAME_SPICE_JAR, DNAME_SPICE_CACHE
from aac_metrics.utils.path import _process_cache_path, _process_tmp_path
from aac_metrics.utils.tokenization import FNAME_STANFORD_CORENLP_3_4_1_JAR


Expand Down Expand Up @@ -44,8 +45,8 @@


def download(
cache_path: str = "~/.cache",
tmp_path: str = "/tmp",
cache_path: str = ...,
tmp_path: str = ...,
ptb_tokenizer: bool = True,
meteor: bool = True,
spice: bool = True,
Expand All @@ -62,8 +63,8 @@ def download(
:param fense: If True, downloads the FENSE models. defaults to True.
:param verbose: The verbose level. defaults to 0.
"""
cache_path = osp.expandvars(osp.expanduser(cache_path))
tmp_path = osp.expandvars(osp.expanduser(tmp_path))
cache_path = _process_cache_path(cache_path)
tmp_path = _process_tmp_path(tmp_path)

os.makedirs(cache_path, exist_ok=True)
os.makedirs(tmp_path, exist_ok=True)
Expand Down
24 changes: 12 additions & 12 deletions src/aac_metrics/functional/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def evaluate(
metrics: Union[
str, Iterable[str], Iterable[Callable[[list, list], tuple]]
] = "default",
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
Expand Down Expand Up @@ -157,9 +157,9 @@ def dcase2023_evaluate(
candidates: list[str],
mult_references: list[list[str]],
preprocess: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
Expand Down Expand Up @@ -192,9 +192,9 @@ def dcase2023_evaluate(

def _instantiate_metrics_functions(
metrics: Union[str, Iterable[str], Iterable[Callable[[list, list], tuple]]] = "all",
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> list[Callable]:
Expand Down Expand Up @@ -230,9 +230,9 @@ def _instantiate_metrics_functions(

def _get_metric_factory_functions(
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
) -> dict[str, Callable[[list[str], list[list[str]]], Any]]:
Expand Down
14 changes: 8 additions & 6 deletions src/aac_metrics/functional/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch import Tensor

from aac_metrics.utils.checks import check_java_path
from aac_metrics.utils.path import _process_cache_path, _process_java_path


pylog = logging.getLogger(__name__)
Expand All @@ -28,8 +29,8 @@ def meteor(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
cache_path: str = ...,
java_path: str = ...,
java_max_memory: str = "2G",
language: str = "en",
verbose: int = 0,
Expand All @@ -47,15 +48,16 @@ def meteor(
:param cache_path: The path to the external code directory. defaults to "~/.cache".
:param java_path: The path to the java executable. defaults to "java".
:param java_max_memory: The maximal java memory used. defaults to "2G".
:param language: The language used for stem, synonym and paraphrase matching. defaults to "en".
:param language: The language used for stem, synonym and paraphrase matching.
Can be one of ("en", "cz", "de", "es", "fr").
defaults to "en".
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
cache_path = osp.expandvars(osp.expanduser(cache_path))
java_path = osp.expandvars(osp.expanduser(java_path))
cache_path = _process_cache_path(cache_path)
java_path = _process_java_path(java_path)

meteor_jar_fpath = osp.join(cache_path, FNAME_METEOR_JAR)
language = "en" # supported: en cz de es fr

if __debug__:
if not osp.isfile(meteor_jar_fpath):
Expand Down
17 changes: 11 additions & 6 deletions src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from torch import Tensor

from aac_metrics.utils.checks import check_java_path
from aac_metrics.utils.path import (
_process_cache_path,
_process_java_path,
_process_tmp_path,
)


pylog = logging.getLogger(__name__)
Expand All @@ -34,9 +39,9 @@ def spice(
candidates: list[str],
mult_references: list[list[str]],
return_all_scores: bool = True,
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down Expand Up @@ -70,9 +75,9 @@ def spice(
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""

cache_path = osp.expandvars(osp.expanduser(cache_path))
java_path = osp.expandvars(osp.expanduser(java_path))
tmp_path = osp.expandvars(osp.expanduser(tmp_path))
cache_path = _process_cache_path(cache_path)
java_path = _process_java_path(java_path)
tmp_path = _process_tmp_path(tmp_path)

spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR)

Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/functional/spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def spider(
tokenizer: Callable[[str], list[str]] = str.split,
return_tfidf: bool = False,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/functional/spider_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def spider_fl(
tokenizer: Callable[[str], list[str]] = str.split,
return_tfidf: bool = False,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/aac_metrics/functional/spider_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def spider_max(
tokenizer: Callable[[str], list[str]] = str.split,
return_tfidf: bool = False,
# SPICE args
cache_path: str = "~/.cache",
java_path: str = "java",
tmp_path: str = "/tmp",
cache_path: str = ...,
java_path: str = ...,
tmp_path: str = ...,
n_threads: Optional[int] = None,
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
Expand Down
Loading

0 comments on commit b8a7808

Please sign in to comment.