Skip to content

Commit

Permalink
Add/Mod: Add use_shell option for METEOR and SPICE metrics. Shell is …
Browse files Browse the repository at this point in the history
…now used only on Windows devices.
  • Loading branch information
Labbeti committed Aug 31, 2023
1 parent a743367 commit 0f7ad13
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ All notable changes to this project will be documented in this file.
- Path management for Windows.

## [0.4.3] 2023-06-15
### Added
- Argument `use_shell` for `METEOR` and `SPICE` metrics.

### Changed
- `AACMetric` is no longer a subclass of `torchmetrics.Metric` even when it is installed. It avoid dependency to this package and remove potential errors due to Metric base class.
- Java 12 and 13 are now allowed in this package.
Expand Down
5 changes: 4 additions & 1 deletion src/aac_metrics/classes/meteor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Union
from typing import Optional, Union

from torch import Tensor

Expand Down Expand Up @@ -32,6 +32,7 @@ def __init__(
java_path: str = ...,
java_max_memory: str = "2G",
language: str = "en",
use_shell: Optional[bool] = None,
verbose: int = 0,
) -> None:
super().__init__()
Expand All @@ -40,6 +41,7 @@ def __init__(
self._java_path = java_path
self._java_max_memory = java_max_memory
self._language = language
self._use_shell = use_shell
self._verbose = verbose

self._candidates = []
Expand All @@ -54,6 +56,7 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
java_path=self._java_path,
java_max_memory=self._java_max_memory,
language=self._language,
use_shell=self._use_shell,
verbose=self._verbose,
)

Expand Down
3 changes: 3 additions & 0 deletions src/aac_metrics/classes/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
separate_cache_dir: bool = True,
use_shell: Optional[bool] = None,
verbose: int = 0,
) -> None:
super().__init__()
Expand All @@ -50,6 +51,7 @@ def __init__(
self._java_max_memory = java_max_memory
self._timeout = timeout
self._separate_cache_dir = separate_cache_dir
self._use_shell = use_shell
self._verbose = verbose

self._candidates = []
Expand All @@ -67,6 +69,7 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
java_max_memory=self._java_max_memory,
timeout=self._timeout,
separate_cache_dir=self._separate_cache_dir,
use_shell=self._use_shell,
verbose=self._verbose,
)

Expand Down
16 changes: 13 additions & 3 deletions src/aac_metrics/functional/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import logging
import os.path as osp
import platform
import subprocess

from subprocess import Popen
from typing import Union
from typing import Optional, Union

import torch

Expand All @@ -34,6 +35,7 @@ def meteor(
java_path: str = ...,
java_max_memory: str = "2G",
language: str = "en",
use_shell: Optional[bool] = None,
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
"""Metric for Evaluation of Translation with Explicit ORdering function.
Expand All @@ -52,6 +54,9 @@ def meteor(
:param language: The language used for stem, synonym and paraphrase matching.
Can be one of ("en", "cz", "de", "es", "fr").
defaults to "en".
:param use_shell: Optional argument to force use os-specific shell for the java subprogram.
If None, it will use shell only on Windows OS.
defaults to None.
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
Expand All @@ -60,6 +65,9 @@ def meteor(

meteor_jar_fpath = osp.join(cache_path, FNAME_METEOR_JAR)

if use_shell is None:
use_shell = platform.system() == "Windows"

if __debug__:
if not osp.isfile(meteor_jar_fpath):
raise FileNotFoundError(
Expand Down Expand Up @@ -94,14 +102,16 @@ def meteor(
]

if verbose >= 2:
pylog.debug(f"Start METEOR process with command '{' '.join(meteor_cmd)}'...")
pylog.debug(
f"Run METEOR java code with: {' '.join(meteor_cmd)} and {use_shell=}"
)

meteor_process = Popen(
meteor_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
shell=use_shell,
)

n_candidates = len(candidates)
Expand Down
14 changes: 12 additions & 2 deletions src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math
import os
import os.path as osp
import platform
import shutil
import subprocess
import tempfile
Expand Down Expand Up @@ -47,6 +48,7 @@ def spice(
java_max_memory: str = "8G",
timeout: Union[None, int, Iterable[int]] = None,
separate_cache_dir: bool = True,
use_shell: Optional[bool] = None,
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
"""Semantic Propositional Image Caption Evaluation function.
Expand All @@ -72,6 +74,9 @@ def spice(
:param separate_cache_dir: If True, the SPICE cache files will be stored into in a new temporary directory.
This removes potential freezes when multiple instances of SPICE are running in the same cache dir.
defaults to True.
:param use_shell: Optional argument to force use os-specific shell for the java subprogram.
If None, it will use shell only on Windows OS.
defaults to None.
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
Expand All @@ -82,6 +87,9 @@ def spice(

spice_fpath = osp.join(cache_path, FNAME_SPICE_JAR)

if use_shell is None:
use_shell = platform.system() == "Windows"

if __debug__:
if not osp.isfile(spice_fpath):
raise FileNotFoundError(
Expand Down Expand Up @@ -169,15 +177,17 @@ def spice(
spice_cmd += ["-threads", str(n_threads)]

if verbose >= 2:
pylog.debug(f"Run SPICE java code with: {' '.join(spice_cmd)}")
pylog.debug(
f"Run SPICE java code with: {' '.join(spice_cmd)} and {use_shell=}"
)

try:
subprocess.check_call(
spice_cmd,
stdout=stdout,
stderr=stderr,
timeout=timeout_i,
shell=True,
shell=use_shell,
)
if stdout is not None:
stdout.close()
Expand Down

0 comments on commit 0f7ad13

Please sign in to comment.