from delphi.__main__ import run
from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
- from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
+ from delphi.log.result_analysis import get_metrics, load_data


async def test():
@@ -58,23 +58,16 @@ async def test():
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")

-    # Performs better than random guessing
    scores_path = Path.cwd() / "results" / run_cfg.name / "scores"
-    hookpoint_firing_counts = torch.load(
-        Path.cwd() / "results" / run_cfg.name / "log" / "hookpoint_firing_counts.pt",
-        weights_only=True,
-    )
-    df = build_scores_df(scores_path, run_cfg.hookpoints, hookpoint_firing_counts)
-    for score_type in df["score_type"].unique():
-        score_df = df.query(f"score_type == '{score_type}'")

-        weighted_mean_metrics = latent_balanced_score_metrics(
-            score_df, score_type, verbose=False
-        )
+    latent_df, _ = load_data(scores_path, run_cfg.hookpoints)
+    processed_df = get_metrics(latent_df)

-        accuracy = weighted_mean_metrics["accuracy"]
+    # Performs better than random guessing
+    for score_type, df in processed_df.groupby("score_type"):
+        accuracy = df["accuracy"].mean()

        assert accuracy > 0.55, f"Score type {score_type} has an accuracy of {accuracy}"

if __name__ == "__main__":
-    asyncio.run(test())
+    asyncio.run(test())
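For reference, a minimal sketch of the new `load_data`/`get_metrics` flow in isolation, assuming a completed run whose scores were written under `results/<run_name>/scores` as in the test above; the run name and hookpoint list are placeholders, and only the calls visible in this diff are assumed.

```python
from pathlib import Path

from delphi.log.result_analysis import get_metrics, load_data

# Hypothetical run name and hookpoint list; substitute your own run config values.
scores_path = Path.cwd() / "results" / "my_run" / "scores"
hookpoints = ["layers.5"]

# load_data returns a tuple; the test only uses the first element (the latent df).
latent_df, _ = load_data(scores_path, hookpoints)
processed_df = get_metrics(latent_df)

# Report mean accuracy per score type, mirroring the assertion in the test.
for score_type, df in processed_df.groupby("score_type"):
    print(f"{score_type}: mean accuracy = {df['accuracy'].mean():.3f}")
```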