Skip to content

Commit

Permalink
Version 0.4.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Apr 19, 2023
1 parent 2f0514f commit b74e2ad
Show file tree
Hide file tree
Showing 14 changed files with 1,528 additions and 43 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/python-package-pip.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:

steps:
# --- INSTALLATIONS ---

- name: Checkout repository
uses: actions/checkout@v2
with:
Expand All @@ -41,14 +40,14 @@ jobs:

- name: Install package
run: |
python -m pip install -e .[dev]
python -m pip install "aac-metrics[dev] @ git+https://github.com/Labbeti/aac-metrics@dev"
- name: Load cache of external code and data
uses: actions/cache@master
id: cache_external
with:
path: /home/runner/.cache/aac-metrics/*
key: ${{ runner.os }}-${{ hashFiles('install_spice.sh') }}
key: ${{ runner.os }}-${{ hashFiles('src/aac_metrics/download.py') }}
restore-keys: |
${{ runner.os }}-
Expand All @@ -67,7 +66,7 @@ jobs:
- name: Print Java version
run: |
java --version
java -version
- name: Install external code if needed
if: steps.cache_external.outputs.cache-hit != 'true'
Expand Down
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

All notable changes to this project will be documented in this file.

## [0.4.2] 2023-04-19
### Fixed
- File `install_spice.sh` is now in `src/aac_metrics` directory to fix download from a pip installation. ([#3](https://github.com/Labbeti/aac-metrics/issues/3))
- Java version retriever to avoid exception when java version is correct. ([#2](https://github.com/Labbeti/aac-metrics/issues/2))

## [0.4.1] 2023-04-13
### Deleted
- Old unused files `package_tree.rst`, `fluency_error.py`, `sbert.py` and `spider_err.py`.
Expand All @@ -14,7 +19,7 @@ All notable changes to this project will be documented in this file.
- Rename `SPIDErErr` to `SPIDErFL` to match DCASE2023 metric name.
- Rename `SBERT` to `SBERTSim` to avoid confusion with SBERT model name.
- Rename `FluencyError` to `FluErr`.
- Check if Java executable version between 8 and 11.
- Check if Java executable version between 8 and 11. ([#1](https://github.com/Labbeti/aac-metrics/issues/1))

### Fixed
- `SPIDErFL` sentences scores outputs when using `return_all_scores=True`.
Expand Down
4 changes: 2 additions & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ keywords:
- captioning
- audio-captioning
license: MIT
version: 0.4.1
date-released: '2023-04-13'
version: 0.4.2
date-released: '2023-04-19'
4 changes: 2 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ recursive-include src *.py
global-exclude *.pyc
global-exclude __pycache__

include install_spice.sh
recursive-include examples *.csv
include src/aac_metrics/install_spice.sh
recursive-include data *.csv
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Install the pip package:
pip install aac-metrics
```

Download the external code and models needed for METEOR, SPICE, SPIDEr, SPIDEr-max, PTBTokenizer, SBERT, FluencyError, FENSE and SPIDEr-FL:
Download the external code and models needed for METEOR, SPICE, SPIDEr, SPIDEr-max, PTBTokenizer, SBERTSim, FluencyError, FENSE and SPIDEr-FL:
```bash
aac-metrics-download
```
Expand Down Expand Up @@ -122,7 +122,7 @@ torch >= 1.10.1
numpy >= 1.21.2
pyyaml >= 6.0
tqdm >= 4.64.0
sentence-transformers>=2.2.2
sentence-transformers >= 2.2.2
```

### External requirements
Expand Down Expand Up @@ -219,7 +219,7 @@ If you use this software, please consider cite it as below :
month = {4},
title = {{aac-metrics}},
url = {https://github.com/Labbeti/aac-metrics/},
version = {0.4.1},
version = {0.4.2},
year = {2023},
}
```
Expand Down
465 changes: 465 additions & 0 deletions data/example_1.csv

Large diffs are not rendered by default.

913 changes: 913 additions & 0 deletions data/example_2.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__status__ = "Development"
__version__ = "0.4.1"
__version__ = "0.4.2"


from .classes.base import AACMetric
Expand Down
4 changes: 2 additions & 2 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def download(
os.makedirs(spice_jar_dpath, exist_ok=True)
os.makedirs(spice_cache_path, exist_ok=True)

script_path = osp.join(osp.dirname(__file__), "..", "..", "install_spice.sh")
script_path = osp.join(osp.dirname(__file__), "install_spice.sh")
if not osp.isfile(script_path):
raise FileNotFoundError(
f"Cannot find script '{osp.basename(script_path)}'."
Expand All @@ -151,7 +151,7 @@ def download(
if fense:
# Download models files for FENSE metric
if verbose >= 1:
pylog.info("Downloading sBert and Bert error detector for FENSE metric...")
pylog.info("Downloading SBERT and BERT error detector for FENSE metric...")
_ = FENSE(device="cpu")


Expand Down
55 changes: 55 additions & 0 deletions src/aac_metrics/install_spice.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/bin/bash

DEFAULT_SPICE_ROOT="$HOME/.cache/aac-metrics/spice"

if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then
echo "Install all files for running the java SPICE program in the SPICE_ROOT directory."
echo "The default spice root path is \"${DEFAULT_SPICE_ROOT}\"."
echo "Usage: $0 [SPICE_ROOT]"
exit 0
fi

dpath_spice="$1"
if [ "$dpath_spice" = "" ]; then
dpath_spice="${DEFAULT_SPICE_ROOT}"
fi

if [ ! -d "$dpath_spice" ]; then
echo "Error: SPICE_ROOT \"$dpath_spice\" is not a directory."
exit 1
fi

fname_zip="SPICE-1.0.zip"
fpath_zip="$dpath_spice/$fname_zip"
bn0=`basename $0`

echo "[$bn0] Start installation of SPICE metric java code in directory \"$dpath_spice\"..."

if [ ! -f "$fpath_zip" ]; then
echo "[$bn0] Zip file not found, downloading from https://panderson.me..."
wget https://panderson.me/images/SPICE-1.0.zip -P "$dpath_spice"
fi

dpath_unzip="$dpath_spice/SPICE-1.0"
if [ ! -d "$dpath_unzip" ]; then
echo "[$bn0] Unzipping file $dpath_zip..."
unzip $fpath_zip -d "$dpath_spice"

echo "[$bn0] Downloading Stanford models..."
bash $dpath_unzip/get_stanford_models.sh
fi

dpath_lib="$dpath_spice/lib"
if [ ! -d "$dpath_lib" ]; then
echo "[$bn0] Moving lib directory to \"$dpath_spice\"..."
mv "$dpath_unzip/lib" "$dpath_spice"
fi

fpath_jar="$dpath_spice/spice-1.0.jar"
if [ ! -f "$fpath_jar" ]; then
echo "[$bn0] Moving spice-1.0.jar file to \"$dpath_spice\"..."
mv "$dpath_unzip/spice-1.0.jar" "$dpath_spice"
fi

echo "[$bn0] SPICE metric Java code is installed."
exit 0
62 changes: 42 additions & 20 deletions src/aac_metrics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-

import logging
import re
import subprocess

from dataclasses import dataclass
from functools import cache
from pathlib import Path
from subprocess import CalledProcessError
Expand All @@ -12,6 +14,7 @@

pylog = logging.getLogger(__name__)

VERSION_PATTERN = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+).*"
MIN_JAVA_MAJOR_VERSION = 8
MAX_JAVA_MAJOR_VERSION = 11

Expand Down Expand Up @@ -41,42 +44,61 @@ def check_metric_inputs(


def check_java_path(java_path: Union[str, Path]) -> bool:
version = _get_java_version(str(java_path))
valid = _check_java_version(version, MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION)
if not valid:
pylog.error(
f"Using Java version {version} is not officially supported by aac-metrics package and will not work for METEOR and SPICE metrics."
f"(expected major version in range [{MIN_JAVA_MAJOR_VERSION}, {MAX_JAVA_MAJOR_VERSION}])"
)
return valid


def _get_java_version(java_path: str) -> str:
"""Returns True if the java path is valid."""
if not isinstance(java_path, (str, Path)):
return False
if not isinstance(java_path, str):
raise TypeError(f"Invalid argument type {type(java_path)=}. (expected str)")

output = ""
output = "INVALID"
try:
output = subprocess.check_output(
[str(java_path), "--version"],
[java_path, "-version"],
stderr=subprocess.STDOUT,
)
output = output.decode().strip()
version = output.split("\n")[0]
major_version = int(version.split(" ")[1].split(".")[0])
version = output.split(" ")[2][1:-1]

except (
CalledProcessError,
PermissionError,
FileNotFoundError,
) as err:
pylog.error(f"Invalid java path. (from {java_path=} and found error={err})")
return False
raise ValueError(f"Invalid java path. (from {java_path=} and found {err=})")

except (
IndexError,
ValueError,
) as err:
pylog.error(f"Invalid java version. (found {output=} and {err=})")
return False
except IndexError as err:
raise ValueError(
f"Invalid java version. (from {java_path=} and found {output=} and {err=})"
)

if not (MIN_JAVA_MAJOR_VERSION <= major_version <= MAX_JAVA_MAJOR_VERSION):
pylog.error(
f"Using Java version {version} is not officially supported by aac-metrics package and could not work for METEOR and SPICE metrics."
f"(found {major_version=} but expected in [{MIN_JAVA_MAJOR_VERSION}, {MAX_JAVA_MAJOR_VERSION}])"
return version


def _check_java_version(version: str, min_major: int, max_major: int) -> bool:
result = re.match(VERSION_PATTERN, version)
if result is None:
raise ValueError(
f"Invalid Java version {version=}. (expected version with pattern={VERSION_PATTERN})"
)
return False

return True
major_version = int(result["major"])
minor_version = int(result["minor"])

if (
major_version == 1 and minor_version <= 8
): # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.PATCH"
major_version = minor_version

return min_major <= major_version <= max_major


def is_mono_sents(sents: Any) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_compare_cet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def test_example_0(self) -> None:
self._test_with_example(cands, mrefs)

def test_example_1(self) -> None:
fpath = Path(__file__).parent.parent.joinpath("examples", "example_1.csv")
fpath = Path(__file__).parent.parent.joinpath("data", "example_1.csv")
cands, mrefs = load_csv_file(fpath)
self._test_with_example(cands, mrefs)

def test_example_2(self) -> None:
fpath = Path(__file__).parent.parent.joinpath("examples", "example_2.csv")
fpath = Path(__file__).parent.parent.joinpath("data", "example_2.csv")
cands, mrefs = load_csv_file(fpath)
self._test_with_example(cands, mrefs)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_compare_fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ def _get_src_evaluator_class(cls) -> Any:

# Tests methods
def test_example_1_fense(self) -> None:
fpath = osp.join(osp.dirname(__file__), "..", "examples", "example_1.csv")
fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv")
self._test_with_original_fense(fpath)

def test_example_1_sbert_sim(self) -> None:
fpath = osp.join(osp.dirname(__file__), "..", "examples", "example_1.csv")
fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv")
self._test_with_original_sbert_sim(fpath)

def test_example_2_fense(self) -> None:
fpath = osp.join(osp.dirname(__file__), "..", "examples", "example_2.csv")
fpath = osp.join(osp.dirname(__file__), "..", "data", "example_2.csv")
self._test_with_original_fense(fpath)

def test_example_2_sbert_sim(self) -> None:
fpath = osp.join(osp.dirname(__file__), "..", "examples", "example_2.csv")
fpath = osp.join(osp.dirname(__file__), "..", "data", "example_2.csv")
self._test_with_original_sbert_sim(fpath)

def test_output_size(self) -> None:
fpath = osp.join(osp.dirname(__file__), "..", "examples", "example_1.csv")
fpath = osp.join(osp.dirname(__file__), "..", "data", "example_1.csv")
cands, mrefs = load_csv_file(fpath)

self.new_fense._return_all_scores = True
Expand Down
28 changes: 27 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

from unittest import TestCase

from aac_metrics.utils.checks import is_mono_sents, is_mult_sents
from aac_metrics.utils.checks import (
is_mono_sents,
is_mult_sents,
_check_java_version,
MIN_JAVA_MAJOR_VERSION,
MAX_JAVA_MAJOR_VERSION,
)
from aac_metrics.utils.collections import flat_list, unflat_list


Expand Down Expand Up @@ -35,6 +41,26 @@ def test_misc_functions_1(self) -> None:
self.assertEqual(len(lst), len(unflat))
self.assertListEqual(lst, unflat)

def test_check_java_versions(self) -> None:
test_set = [
("1.0.0", False),
("1.7.0", False),
("1.8.0", True),
("1.9.0", False),
("1.10.0", False),
("9.0.0", True),
("10.0.0", True),
("11.0.0", True),
("12.0.0", False),
("17.0.0", False),
("20.0.0", False),
]
for version, expected in test_set:
output = _check_java_version(
version, MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION
)
self.assertEqual(output, expected, f"{version=}")


if __name__ == "__main__":
unittest.main()

0 comments on commit b74e2ad

Please sign in to comment.