Skip to content

Commit 4b45f8a

Browse files
author
SrGonao
committed
Merge branch 'more-tests' of https://github.com/EleutherAI/delphi
2 parents 5f9ecc9 + 05bcf03 commit 4b45f8a

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

tests/e2e.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from delphi.__main__ import run
88
from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
9-
from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
9+
from delphi.log.result_analysis import get_metrics, load_data
1010

1111

1212
async def test():
@@ -58,23 +58,16 @@ async def test():
5858
end_time = time.time()
5959
print(f"Time taken: {end_time - start_time} seconds")
6060

61-
# Performs better than random guessing
6261
scores_path = Path.cwd() / "results" / run_cfg.name / "scores"
63-
hookpoint_firing_counts = torch.load(
64-
Path.cwd() / "results" / run_cfg.name / "log" / "hookpoint_firing_counts.pt",
65-
weights_only=True,
66-
)
67-
df = build_scores_df(scores_path, run_cfg.hookpoints, hookpoint_firing_counts)
68-
for score_type in df["score_type"].unique():
69-
score_df = df.query(f"score_type == '{score_type}'")
7062

71-
weighted_mean_metrics = latent_balanced_score_metrics(
72-
score_df, score_type, verbose=False
73-
)
63+
latent_df, _ = load_data(scores_path, run_cfg.hookpoints)
64+
processed_df = get_metrics(latent_df)
7465

75-
accuracy = weighted_mean_metrics["accuracy"]
66+
# Performs better than random guessing
67+
for score_type, df in processed_df.groupby("score_type"):
68+
accuracy = df["accuracy"].mean()
7669
assert accuracy > 0.55, f"Score type {score_type} has an accuracy of {accuracy}"
7770

7871

7972
if __name__ == "__main__":
80-
asyncio.run(test())
73+
asyncio.run(test())

0 commit comments

Comments
 (0)