Skip to content

Commit

Permalink
Fix: CET unittest for windows and internal typing for fense unitest.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Sep 18, 2023
1 parent 8b70bfe commit 9f8b245
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 10 additions & 1 deletion tests/test_compare_cet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

from aac_metrics.functional.evaluate import evaluate
from aac_metrics.eval import load_csv_file
from aac_metrics.utils.paths import get_default_tmp_path, set_default_tmp_path
from aac_metrics.utils.paths import (
get_default_tmp_path,
set_default_cache_path,
set_default_tmp_path,
)


class TestCompareCaptionEvaluationTools(TestCase):
Expand All @@ -30,6 +34,11 @@ def setUpClass(cls) -> None:
tmp_path = osp.join(".", "tmp")
os.makedirs(tmp_path, exist_ok=True)
set_default_tmp_path(tmp_path)

cache_path = osp.join(".", "cache")
os.makedirs(cache_path, exist_ok=True)
set_default_cache_path(cache_path)

cls.evaluate_metrics_from_lists = cls._import_cet_eval_func()

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions tests/test_compare_fense.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def test_output_size(self) -> None:
cands, mrefs = load_csv_file(fpath)

self.new_fense._return_all_scores = True
corpus_scores, sents_scores = self.new_fense(cands, mrefs)
outs: tuple = self.new_fense(cands, mrefs) # type: ignore
corpus_scores, sents_scores = outs
self.new_fense._return_all_scores = False

for name, score in corpus_scores.items():
Expand All @@ -79,7 +80,8 @@ def _test_with_original_fense(self, fpath: str) -> None:
src_sbert_sim_score = self.src_sbert_sim.corpus_score(cands, mrefs).item()
src_fense_score = self.src_fense.corpus_score(cands, mrefs).item()

corpus_outs, _sents_outs = self.new_fense(cands, mrefs)
outs: tuple = self.new_fense(cands, mrefs) # type: ignore
corpus_outs, _sents_outs = outs
new_sbert_sim_score = corpus_outs["sbert_sim"].item()
new_fense_score = corpus_outs["fense"].item()

Expand Down

0 comments on commit 9f8b245

Please sign in to comment.