Skip to content

Commit

Permalink
Fix: SPICE download and logging setup.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Sep 26, 2023
1 parent df12588 commit 5f2f37f
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 60 deletions.
39 changes: 10 additions & 29 deletions src/aac_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Audio Captioning metrics package.
"""Metrics for evaluating Automated Audio Captioning systems, designed for PyTorch.
"""

__name__ = "aac-metrics"
__author__ = "Etienne Labbé (Labbeti)"
__author_email__ = "[email protected]"
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__status__ = "Development"
__version__ = "0.4.5"


from .classes.base import AACMetric
from .classes.bleu import BLEU
Expand All @@ -31,35 +23,24 @@
set_default_java_path,
set_default_tmp_path,
)
from . import download, eval, info


__all__ = [
"BLEU",
"CIDErD",
"DCASE2023Evaluate",
"FENSE",
"METEOR",
"ROUGEL",
"SPICE",
"SPIDEr",
"dcase2023_evaluate",
"evaluate",
"get_default_cache_path",
"get_default_java_path",
"get_default_tmp_path",
"set_default_cache_path",
"set_default_java_path",
"set_default_tmp_path",
"load_metric",
]
__name__ = "aac-metrics"
__author__ = "Etienne Labbé (Labbeti)"
__author_email__ = "[email protected]"
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__status__ = "Development"
__version__ = "0.4.5"


def load_metric(name: str, **kwargs) -> AACMetric:
"""Load a metric class by name.
:param name: The name of the metric.
Must be one of ("bleu_1", "bleu_2", "bleu_3", "bleu_4", "meteor", "rouge_l", "cider_d", "spice", "spider", "fense").
:param **kwargs: The keyword optional arguments passed to the metric.
:param **kwargs: The keyword optional arguments passed to the metric factory.
:returns: The Metric object built.
"""
name = name.lower().strip()
Expand Down
21 changes: 19 additions & 2 deletions src/aac_metrics/classes/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import zlib

from typing import Callable, Iterable, Union
from typing import Any, Callable, Iterable, Union

import torch

Expand Down Expand Up @@ -171,14 +171,20 @@ def _get_metric_factory_classes(
tmp_path: str = ...,
device: Union[str, torch.device, None] = "auto",
verbose: int = 0,
init_kwds: dict[str, Any] = ...,
) -> dict[str, Callable[[], AACMetric]]:
return {
if init_kwds is ...:
init_kwds = {}

factory = {
"bleu": lambda: BLEU(
return_all_scores=return_all_scores,
**init_kwds,
),
"bleu_1": lambda: BLEU(
return_all_scores=return_all_scores,
n=1,
**init_kwds,
),
"bleu_2": lambda: BLEU(
return_all_scores=return_all_scores,
Expand All @@ -187,41 +193,49 @@ def _get_metric_factory_classes(
"bleu_3": lambda: BLEU(
return_all_scores=return_all_scores,
n=3,
**init_kwds,
),
"bleu_4": lambda: BLEU(
return_all_scores=return_all_scores,
n=4,
**init_kwds,
),
"meteor": lambda: METEOR(
return_all_scores=return_all_scores,
cache_path=cache_path,
java_path=java_path,
verbose=verbose,
**init_kwds,
),
"rouge_l": lambda: ROUGEL(
return_all_scores=return_all_scores,
**init_kwds,
),
"cider_d": lambda: CIDErD(
return_all_scores=return_all_scores,
**init_kwds,
),
"spice": lambda: SPICE(
return_all_scores=return_all_scores,
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
verbose=verbose,
**init_kwds,
),
"spider": lambda: SPIDEr(
return_all_scores=return_all_scores,
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
verbose=verbose,
**init_kwds,
),
"sbert_sim": lambda: SBERTSim(
return_all_scores=return_all_scores,
device=device,
verbose=verbose,
**init_kwds,
),
"fluerr": lambda: FluErr(
return_all_scores=return_all_scores,
Expand All @@ -232,6 +246,7 @@ def _get_metric_factory_classes(
return_all_scores=return_all_scores,
device=device,
verbose=verbose,
**init_kwds,
),
"spider_fl": lambda: SPIDErFL(
return_all_scores=return_all_scores,
Expand All @@ -240,5 +255,7 @@ def _get_metric_factory_classes(
tmp_path=tmp_path,
device=device,
verbose=verbose,
**init_kwds,
),
}
return factory
67 changes: 39 additions & 28 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
FNAME_SPICE_JAR,
DNAME_SPICE_LOCAL_CACHE,
DNAME_SPICE_CACHE,
check_spice_install,
)
from aac_metrics.utils.paths import (
_get_cache_path,
Expand Down Expand Up @@ -67,19 +68,17 @@
"spice_zip": {
"url": "https://panderson.me/images/SPICE-1.0.zip",
"fname": "SPICE-1.0.zip",
"extract_to": ".",
},
"spice_corenlp_zip": {
"url": "http://nlp.stanford.edu/software/stanford-corenlp-full-2015-12-09.zip",
"fname": osp.join("SPICE-1.0", "stanford-corenlp-full-2015-12-09.zip"),
"extract_to": "lib",
},
}
_TRUE_VALUES = ("true", "1", "t")
_FALSE_VALUES = ("false", "0", "f")


def download(
def download_metrics(
cache_path: str = ...,
tmp_path: str = ...,
clean_archives: bool = True,
Expand Down Expand Up @@ -224,6 +223,12 @@ def _download_spice(
│ └── stanford-corenlp-3.6.0-models.jar
└── spice-1.0.jar
"""
try:
check_spice_install(cache_path)
return None
except (FileNotFoundError, NotADirectoryError):
pass

# Download JAR files for SPICE metric
spice_cache_dpath = osp.join(cache_path, DNAME_SPICE_CACHE)
spice_jar_dpath = osp.join(cache_path, osp.dirname(FNAME_SPICE_JAR))
Expand All @@ -249,47 +254,42 @@ def _download_spice(
download_url_to_file(url, fpath, progress=verbose > 0)

if fname.endswith(".zip"):
parent_tgt_dpath = osp.join(
spice_cache_dpath, DATA_URLS[name]["extract_to"]
)
os.makedirs(parent_tgt_dpath, exist_ok=True)

if verbose >= 1:
pylog.info(f"Extracting {fname} to {parent_tgt_dpath}...")
pylog.info(f"Extracting {fname} to {spice_cache_dpath}...")

with ZipFile(fpath, "r") as file:
file.extractall(parent_tgt_dpath)
file.extractall(spice_cache_dpath)

spice_lib_dpath = osp.join(spice_cache_dpath, "lib")
spice_unzip_dpath = osp.join(spice_cache_dpath, "SPICE-1.0")
corenlp_dname = "stanford-corenlp-full-2015-12-09"
corenlp_dpath = osp.join(spice_lib_dpath, corenlp_dname)
corenlp_dpath = osp.join(spice_cache_dpath, "stanford-corenlp-full-2015-12-09")

# Note: order matters here
to_move = [
("f", osp.join(spice_unzip_dpath, "spice-1.0.jar"), spice_cache_dpath),
("d", osp.join(spice_unzip_dpath, "lib"), spice_cache_dpath),
("f", osp.join(corenlp_dpath, "stanford-corenlp-3.6.0.jar"), spice_lib_dpath),
(
"f",
osp.join(corenlp_dpath, "stanford-corenlp-3.6.0-models.jar"),
spice_lib_dpath,
),
("d", osp.join(spice_unzip_dpath, "lib"), spice_cache_dpath),
("f", osp.join(spice_unzip_dpath, "spice-1.0.jar"), spice_cache_dpath),
]
for src_type, src_path, parent_tgt_dpath in to_move:
tgt_path = osp.join(parent_tgt_dpath, osp.basename(src_path))

if osp.exists(tgt_path):
if src_type == "f":
if verbose >= 1:
pylog.info(f"Target file '{tgt_path}' already exists.")
elif src_type == "d":
if verbose >= 1:
pylog.info(f"Moving all objects in '{src_path}' to '{tgt_path}'...")
for name in os.listdir(src_path):
shutil.move(osp.join(src_path, name), tgt_path)
os.rmdir(src_path)
else:
raise ValueError(f"Invalid type value {src_type}.")
if verbose >= 1:
pylog.info(f"Target '{tgt_path}' already exists.")
# if src_type == "f":
# elif src_type == "d":
# if verbose >= 1:
# pylog.info(f"Moving all objects in '{src_path}' to '{tgt_path}'...")
# for name in os.listdir(src_path):
# shutil.move(osp.join(src_path, name), tgt_path)
# os.rmdir(src_path)
# else:
# raise ValueError(f"Invalid type value {src_type}.")
else:
pylog.info(f"Moving '{src_path}' to '{parent_tgt_dpath}'...")
shutil.move(src_path, parent_tgt_dpath)
Expand Down Expand Up @@ -370,10 +370,21 @@ def _setup_logging(verbose: int = 1) -> None:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(format_))
pkg_logger = logging.getLogger("aac_metrics")
if handler not in pkg_logger.handlers:

found = False
for handler in pkg_logger.handlers:
if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout:
found = True
break
if not found:
pkg_logger.addHandler(handler)

level = logging.INFO if verbose <= 1 else logging.DEBUG
if verbose <= 0:
level = logging.WARNING
elif verbose == 1:
level = logging.INFO
else:
level = logging.DEBUG
pkg_logger.setLevel(level)


Expand All @@ -382,7 +393,7 @@ def _main_download() -> None:

_setup_logging(args.verbose)

download(
download_metrics(
cache_path=args.cache_path,
tmp_path=args.tmp_path,
clean_archives=args.clean_archives,
Expand Down
Loading

0 comments on commit 5f2f37f

Please sign in to comment.