Skip to content

Commit

Permalink
Mod: Update SPICE error files management.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Sep 25, 2023
1 parent cd87fc8 commit 85d9e9b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
39 changes: 24 additions & 15 deletions src/aac_metrics/functional/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,13 @@ def spice(
for i, (cand, refs) in enumerate(zip(candidates, mult_references))
]

in_file = NamedTemporaryFile(
mode="w", delete=False, dir=tmp_path, prefix="spice_inputs_", suffix=".json"
json_kwds: dict[str, Any] = dict(
mode="w",
delete=False,
dir=tmp_path,
suffix=".json",
)
in_file = NamedTemporaryFile(prefix="spice_inputs_", **json_kwds)
json.dump(input_data, in_file, indent=2)
in_file.close()

Expand All @@ -136,29 +140,28 @@ def spice(
else:
timeout_lst = list(timeout)

out_file = NamedTemporaryFile(
mode="w", delete=False, dir=tmp_path, prefix="spice_outputs_", suffix=".json"
)
out_file = NamedTemporaryFile(prefix="spice_outputs_", **json_kwds)
out_file.close()

txt_kwds: dict[str, Any] = dict(
mode="w",
delete=False,
dir=tmp_path,
suffix=".txt",
)

for i, timeout_i in enumerate(timeout_lst):
if verbose >= 3:
stdout = None
stderr = None
else:
common_kwds: dict[str, Any] = dict(
mode="w",
delete=True,
dir=tmp_path,
suffix=".txt",
)
stdout = NamedTemporaryFile(
prefix="spice_stdout_",
**common_kwds,
**txt_kwds,
)
stderr = NamedTemporaryFile(
prefix="spice_stderr_",
**common_kwds,
**txt_kwds,
)

spice_cmd = [
Expand Down Expand Up @@ -191,8 +194,10 @@ def spice(
)
if stdout is not None:
stdout.close()
os.remove(stdout.name)
if stderr is not None:
stderr.close()
os.remove(stderr.name)
break

except subprocess.TimeoutExpired as err:
Expand All @@ -205,8 +210,10 @@ def spice(
open(out_file.name, "w").close()
if stdout is not None:
stdout.close()
open(stdout.name, "w").close()
if stderr is not None:
stderr.close()
open(stderr.name, "w").close()
time.sleep(1.0)
else:
raise err
Expand All @@ -224,8 +231,10 @@ def spice(
out_file.name,
]
if stdout is not None:
stdout.close()
fpaths.append(stdout.name)
if stderr is not None:
stderr.close()
fpaths.append(stderr.name)

for fpath in fpaths:
Expand Down Expand Up @@ -262,8 +271,8 @@ def spice(
lines = file.readlines()
content = "\n".join(lines)
pylog.error(f"Content of '{fpath}':\n{content}")
except PermissionError:
pylog.warning(f"Cannot open file '{fpath}'.")
except PermissionError as err2:
pylog.warning(f"Cannot open file '{fpath}'. ({err2})")
else:
pylog.info(
f"Note: No temp file recorded. (found {stdout=} and {stderr=})"
Expand Down
11 changes: 4 additions & 7 deletions tests/test_doc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ def test_example_1(self) -> None:
"spider",
]
self.assertSetEqual(set(corpus_scores.keys()), set(expected_keys))

# print(corpus_scores["bleu_1"])
# print(torch.as_tensor(0.4278, dtype=torch.float64))
# print("END")

self.assertTrue(
torch.allclose(
corpus_scores["bleu_1"],
Expand Down Expand Up @@ -112,17 +107,19 @@ def test_example_3(self) -> None:
self.assertTrue(set(corpus_scores.keys()).issuperset({"cider_d"}))
self.assertTrue(set(sents_scores.keys()).issuperset({"cider_d"}))

dtype = torch.float64

self.assertTrue(
torch.allclose(
corpus_scores["cider_d"],
torch.as_tensor(0.9614, dtype=torch.float64),
torch.as_tensor(0.9614, dtype=dtype),
atol=0.0001,
)
)
self.assertTrue(
torch.allclose(
sents_scores["cider_d"],
torch.as_tensor([1.3641, 0.5587], dtype=torch.float64),
torch.as_tensor([1.3641, 0.5587], dtype=dtype),
atol=0.0001,
)
)
Expand Down

0 comments on commit 85d9e9b

Please sign in to comment.