Skip to content

Commit

Permalink
Add: Force option for download script.
Browse files Browse the repository at this point in the history
  • Loading branch information
LABBE Etienne committed May 22, 2024
1 parent 1e8bcec commit 152d977
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
71 changes: 55 additions & 16 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@


def download_metrics(
*,
cache_path: Union[str, Path, None] = None,
tmp_path: Union[str, Path, None] = None,
clean_archives: bool = True,
Expand All @@ -87,6 +88,7 @@ def download_metrics(
spice: bool = True,
fense: bool = True,
bert_score: bool = True,
force: bool = False,
verbose: int = 0,
) -> None:
"""Download the code needed for SPICE, METEOR, PTB Tokenizer and FENSE.
Expand All @@ -99,6 +101,9 @@ def download_metrics(
:param spice: If True, downloads the SPICE code in cache directory. defaults to True.
:param fense: If True, downloads the FENSE models. defaults to True.
:param bert_score: If True, downloads the BERTScore model. defaults to True.
:param force: If True, downloads files and extracts archives again even if they are already present on disk.
Only applies to the PTBTokenizer, METEOR and SPICE files; FENSE and BERTScore downloads are not affected.
defaults to False.
:param verbose: The verbose level. defaults to 0.
"""
if verbose >= 1:
Expand All @@ -116,26 +121,45 @@ def download_metrics(
pylog.debug(f" Temp directory: {tmp_path}")

if ptb_tokenizer:
_download_ptb_tokenizer(cache_path, verbose)
_download_ptb_tokenizer(
cache_path,
force=force,
verbose=verbose,
)

if meteor:
_download_meteor(cache_path, verbose)
_download_meteor(
cache_path,
force=force,
verbose=verbose,
)

if spice:
_download_spice(cache_path, clean_archives, verbose)
_download_spice(
cache_path,
clean_archives=clean_archives,
force=force,
verbose=verbose,
)

if fense:
_download_fense(verbose)
_download_fense(
verbose=verbose,
)

if bert_score:
_download_bert_score(verbose)
_download_bert_score(
verbose=verbose,
)

if verbose >= 1:
pylog.info("aac-metrics download finished.")


def _download_ptb_tokenizer(
cache_path: str,
*,
force: bool = False,
verbose: int = 0,
) -> None:
# Download JAR file for tokenization
Expand All @@ -150,19 +174,22 @@ def _download_ptb_tokenizer(
fname = info["fname"]
fpath = osp.join(stanford_nlp_dpath, fname)

if not osp.isfile(fpath):
if verbose >= 1:
pylog.info(
f"Downloading JAR source for '{name}' in directory {stanford_nlp_dpath}."
)
download_url_to_file(url, fpath, progress=verbose >= 1)
else:
if not force and osp.isfile(fpath):
if verbose >= 1:
pylog.info(f"Stanford model file '{name}' is already downloaded.")
return None

if verbose >= 1:
pylog.info(
f"Downloading JAR source for '{name}' in directory {stanford_nlp_dpath}."
)
download_url_to_file(url, fpath, progress=verbose >= 1)


def _download_meteor(
cache_path: str,
*,
force: bool = False,
verbose: int = 0,
) -> None:
# Download JAR files for METEOR metric
Expand All @@ -178,7 +205,7 @@ def _download_meteor(
subdir = osp.dirname(fname)
fpath = osp.join(meteor_dpath, fname)

if osp.isfile(fpath):
if not force and osp.isfile(fpath):
if verbose >= 1:
pylog.info(f"Meteor file '{name}' is already downloaded.")
continue
Expand All @@ -197,7 +224,9 @@ def _download_meteor(

def _download_spice(
cache_path: str,
*,
clean_archives: bool = True,
force: bool = False,
verbose: int = 0,
) -> None:
"""Download SPICE java code.
Expand Down Expand Up @@ -229,8 +258,9 @@ def _download_spice(
└── spice-1.0.jar
"""
try:
check_spice_install(cache_path)
return None
if not force:
check_spice_install(cache_path)
return None
except (FileNotFoundError, NotADirectoryError, PermissionError):
pass

Expand All @@ -247,7 +277,7 @@ def _download_spice(
fname = DATA_URLS[name]["fname"]
fpath = osp.join(spice_cache_dpath, fname)

if osp.isfile(fpath):
if not force and osp.isfile(fpath):
if verbose >= 1:
pylog.info(f"File '{fpath}' is already downloaded for SPICE.")
else:
Expand Down Expand Up @@ -312,6 +342,7 @@ def _download_spice(


def _download_fense(
*,
verbose: int = 0,
) -> None:
# Download models files for FENSE metric
Expand All @@ -321,6 +352,7 @@ def _download_fense(


def _download_bert_score(
*,
verbose: int = 0,
) -> None:
# Download models files for BERTScore metric
Expand Down Expand Up @@ -376,6 +408,12 @@ def _get_main_download_args() -> Namespace:
default=True,
help="Download FENSE models.",
)
parser.add_argument(
"--force",
type=_str_to_bool,
default=False,
help="Force to download files and extract archives again.",
)
parser.add_argument("--verbose", type=int, default=1, help="Verbose level.")

args = parser.parse_args()
Expand All @@ -394,6 +432,7 @@ def _main_download() -> None:
meteor=args.meteor,
spice=args.spice,
fense=args.fense,
force=args.force,
verbose=args.verbose,
)

Expand Down
1 change: 1 addition & 0 deletions src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def check_spice_install(cache_path: str) -> None:
for fname in expected_jar_in_lib:
if fname not in names:
files_not_found.append(fname)

if len(files_not_found) > 0:
raise FileNotFoundError(
f"Missing {len(files_not_found)} files in SPICE lib directory. (missing {', '.join(files_not_found)})"
Expand Down

0 comments on commit 152d977

Please sign in to comment.