Skip to content

Commit

Permalink
Add: BERTScore model download.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Nov 15, 2023
1 parent 5dd0572 commit b070a71
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/aac_metrics/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import aac_metrics

from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs
from aac_metrics.classes.fense import FENSE
from aac_metrics.functional.meteor import DNAME_METEOR_CACHE
from aac_metrics.functional.spice import (
Expand Down Expand Up @@ -88,6 +89,7 @@ def download_metrics(
meteor: bool = True,
spice: bool = True,
fense: bool = True,
bert_score: bool = True,
verbose: int = 0,
) -> None:
"""Download the code needed for SPICE, METEOR, PTB Tokenizer and FENSE.
Expand All @@ -99,6 +101,7 @@ def download_metrics(
:param meteor: If True, downloads the METEOR code in cache directory. defaults to True.
:param spice: If True, downloads the SPICE code in cache directory. defaults to True.
:param fense: If True, downloads the FENSE models. defaults to True.
:param bert_score: If True, downloads the BERTScore model. defaults to True.
:param verbose: The verbose level. defaults to 0.
"""
if verbose >= 1:
Expand Down Expand Up @@ -127,6 +130,9 @@ def download_metrics(
if fense:
_download_fense(verbose)

if bert_score:
_download_bert_score(verbose)

if verbose >= 1:
pylog.info("aac-metrics download finished.")

Expand Down Expand Up @@ -317,6 +323,15 @@ def _download_fense(
_ = FENSE(device="cpu")


def _download_bert_score(
verbose: int = 0,
) -> None:
# Download models files for BERTScore metric
if verbose >= 1:
pylog.info("Downloading BERT model for BERTScore metric...")
_ = BERTScoreMRefs(device="cpu")


def _get_main_download_args() -> Namespace:
parser = ArgumentParser(
description="Download models and external code to evaluate captions."
Expand Down

0 comments on commit b070a71

Please sign in to comment.