Skip to content

Commit

Permalink
Add: Options params and weights for METEOR metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Nov 2, 2023
1 parent c5017ec commit a9cfaa8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/aac_metrics/classes/meteor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Optional, Union
from typing import Iterable, Optional, Union

from torch import Tensor

Expand Down Expand Up @@ -33,6 +33,8 @@ def __init__(
java_max_memory: str = "2G",
language: str = "en",
use_shell: Optional[bool] = None,
params: Optional[Iterable[float]] = None,
weights: Optional[Iterable[float]] = None,
verbose: int = 0,
) -> None:
super().__init__()
Expand All @@ -42,6 +44,8 @@ def __init__(
self._java_max_memory = java_max_memory
self._language = language
self._use_shell = use_shell
self._params = params
self._weights = weights
self._verbose = verbose

self._candidates = []
Expand All @@ -57,6 +61,8 @@ def compute(self) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
java_max_memory=self._java_max_memory,
language=self._language,
use_shell=self._use_shell,
params=self._params,
weights=self._weights,
verbose=self._verbose,
)

Expand Down
28 changes: 27 additions & 1 deletion src/aac_metrics/functional/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import subprocess

from subprocess import Popen
from typing import Optional, Union
from typing import Iterable, Optional, Union

import torch

Expand All @@ -36,6 +36,8 @@ def meteor(
java_max_memory: str = "2G",
language: str = "en",
use_shell: Optional[bool] = None,
params: Optional[Iterable[float]] = None,
weights: Optional[Iterable[float]] = None,
verbose: int = 0,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
"""Metric for Evaluation of Translation with Explicit ORdering function.
Expand All @@ -57,6 +59,12 @@ def meteor(
:param use_shell: Optional argument to force use os-specific shell for the java subprogram.
If None, it will use shell only on Windows OS.
defaults to None.
:param params: List of 4 parameters (alpha, beta gamma delta) used in METEOR metric.
If None, it will use the default of the java program, which is (0.85, 0.2, 0.6, 0.75).
defaults to None.
:param weights: List of 4 parameters (w1, w2, w3, w4) used in METEOR metric.
If None, it will use the default of the java program, which is (1.0 1.0 0.6 0.8).
defaults to None.
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
Expand Down Expand Up @@ -101,6 +109,24 @@ def meteor(
"-norm",
]

if params is not None:
params = list(params)
if len(params) != 4:
raise ValueError(
f"Invalid argument {params=}. (expected 4 params but found {len(params)})"
)
params_arg = " ".join(map(str, params))
meteor_cmd += ["-p", f"{params_arg}"]

if weights is not None:
weights = list(weights)
if len(weights) != 4:
raise ValueError(
f"Invalid argument {weights=}. (expected 4 params but found {len(weights)})"
)
weights_arg = " ".join(map(str, weights))
meteor_cmd += ["-w", f"{weights_arg}"]

if verbose >= 2:
pylog.debug(
f"Run METEOR java code with: {' '.join(meteor_cmd)} and {use_shell=}"
Expand Down

0 comments on commit a9cfaa8

Please sign in to comment.