diff --git a/src/aac_metrics/download.py b/src/aac_metrics/download.py index 829a155..713e187 100644 --- a/src/aac_metrics/download.py +++ b/src/aac_metrics/download.py @@ -79,6 +79,7 @@ def download_metrics( + *, cache_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, clean_archives: bool = True, @@ -87,6 +88,7 @@ def download_metrics( spice: bool = True, fense: bool = True, bert_score: bool = True, + force: bool = False, verbose: int = 0, ) -> None: """Download the code needed for SPICE, METEOR, PTB Tokenizer and FENSE. @@ -99,6 +101,9 @@ def download_metrics( :param spice: If True, downloads the SPICE code in cache directory. defaults to True. :param fense: If True, downloads the FENSE models. defaults to True. :param bert_score: If True, downloads the BERTScore model. defaults to True. + :param force: If True, force to download files and extract archives again even if they are already present on disk. + Works only for PTBTokenizer, METEOR and SPICE files. + defaults to False. :param verbose: The verbose level. defaults to 0. """ if verbose >= 1: @@ -116,19 +121,36 @@ def download_metrics( pylog.debug(f" Temp directory: {tmp_path}") if ptb_tokenizer: - _download_ptb_tokenizer(cache_path, verbose) + _download_ptb_tokenizer( + cache_path, + force=force, + verbose=verbose, + ) if meteor: - _download_meteor(cache_path, verbose) + _download_meteor( + cache_path, + force=force, + verbose=verbose, + ) if spice: - _download_spice(cache_path, clean_archives, verbose) + _download_spice( + cache_path, + clean_archives=clean_archives, + force=force, + verbose=verbose, + ) if fense: - _download_fense(verbose) + _download_fense( + verbose=verbose, + ) if bert_score: - _download_bert_score(verbose) + _download_bert_score( + verbose=verbose, + ) if verbose >= 1: pylog.info("aac-metrics download finished.") @@ -136,6 +158,8 @@ def download_metrics( def _download_ptb_tokenizer( cache_path: str, + *, + force: bool = False, verbose: int = 0, ) -> None: # Download JAR file for tokenization @@ -150,19 +174,22 @@ def _download_ptb_tokenizer( fname = info["fname"] fpath = osp.join(stanford_nlp_dpath, fname) - if not osp.isfile(fpath): - if verbose >= 1: - pylog.info( - f"Downloading JAR source for '{name}' in directory {stanford_nlp_dpath}." - ) - download_url_to_file(url, fpath, progress=verbose >= 1) - else: + if not force and osp.isfile(fpath): if verbose >= 1: pylog.info(f"Stanford model file '{name}' is already downloaded.") + return None + + if verbose >= 1: + pylog.info( + f"Downloading JAR source for '{name}' in directory {stanford_nlp_dpath}." + ) + download_url_to_file(url, fpath, progress=verbose >= 1) def _download_meteor( cache_path: str, + *, + force: bool = False, verbose: int = 0, ) -> None: # Download JAR files for METEOR metric @@ -178,7 +205,7 @@ def _download_meteor( subdir = osp.dirname(fname) fpath = osp.join(meteor_dpath, fname) - if osp.isfile(fpath): + if not force and osp.isfile(fpath): if verbose >= 1: pylog.info(f"Meteor file '{name}' is already downloaded.") continue @@ -197,7 +224,9 @@ def _download_meteor( def _download_spice( cache_path: str, + *, clean_archives: bool = True, + force: bool = False, verbose: int = 0, ) -> None: """Download SPICE java code. @@ -229,8 +258,9 @@ def _download_spice( └── spice-1.0.jar """ try: - check_spice_install(cache_path) - return None + if not force: + check_spice_install(cache_path) + return None except (FileNotFoundError, NotADirectoryError, PermissionError): pass @@ -247,7 +277,7 @@ def _download_spice( fname = DATA_URLS[name]["fname"] fpath = osp.join(spice_cache_dpath, fname) - if osp.isfile(fpath): + if not force and osp.isfile(fpath): if verbose >= 1: pylog.info(f"File '{fpath}' is already downloaded for SPICE.") else: @@ -312,6 +342,7 @@ def _download_spice( def _download_fense( + *, verbose: int = 0, ) -> None: # Download models files for FENSE metric @@ -321,6 +352,7 @@ def _download_fense( def _download_bert_score( + *, verbose: int = 0, ) -> None: # Download models files for BERTScore metric @@ -376,6 +408,12 @@ def _get_main_download_args() -> Namespace: default=True, help="Download FENSE models.", ) + parser.add_argument( + "--force", + type=_str_to_bool, + default=False, + help="Force to download files and extract archives again.", + ) parser.add_argument("--verbose", type=int, default=1, help="Verbose level.") args = parser.parse_args() @@ -394,6 +432,7 @@ def _main_download() -> None: meteor=args.meteor, spice=args.spice, fense=args.fense, + force=args.force, verbose=args.verbose, ) diff --git a/src/aac_metrics/functional/spice.py b/src/aac_metrics/functional/spice.py index 539d5d0..5e1c083 100644 --- a/src/aac_metrics/functional/spice.py +++ b/src/aac_metrics/functional/spice.py @@ -273,6 +273,7 @@ def check_spice_install(cache_path: str) -> None: for fname in expected_jar_in_lib: if fname not in names: files_not_found.append(fname) + if len(files_not_found) > 0: raise FileNotFoundError( f"Missing {len(files_not_found)} files in SPICE lib directory. (missing {', '.join(files_not_found)})"