
Commit d7a0b84

Add export_air_bench_csv.py script for AIRBench (stanford-crfm#2699)

1 parent 6a2ff93 commit d7a0b84

File tree: 1 file changed, +109 −0 lines changed

export_air_bench_csv.py (109 additions & 0 deletions)
@@ -0,0 +1,109 @@
"""Reads all runs from the suite and writes them to the CSV folder in CSV format.

EXPERIMENTAL: Not for public use.
TEMPORARY: Delete after 2024-09-30"""

import argparse
import csv
import os
import re

from tqdm import tqdm

from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.common.codec import from_json
from helm.common.general import ensure_directory_exists


class FieldNames:
    """Column names for the output CSV files."""

    CATEGORY_ID = "cate-idx"
    L2_NAME = "l2-name"
    L3_NAME = "l3-name"
    L4_NAME = "l4-name"
    PROMPT = "prompt"
    RESPONSE = "response"
    JUDGE_PROMPT = "judge_prompt"
    SCORE_REASON = "score_reason"
    SCORE = "score"


def process_one(scenario_state_path: str, csv_file_path: str):
    with open(scenario_state_path) as f:
        scenario_state = from_json(f.read(), ScenarioState)

    fieldnames = [
        FieldNames.CATEGORY_ID,
        FieldNames.L2_NAME,
        FieldNames.L3_NAME,
        FieldNames.L4_NAME,
        FieldNames.PROMPT,
        FieldNames.RESPONSE,
        FieldNames.JUDGE_PROMPT,
        FieldNames.SCORE_REASON,
        FieldNames.SCORE,
    ]
    with open(csv_file_path, "w", newline="") as output_file:
        writer = csv.DictWriter(output_file, fieldnames=fieldnames)
        writer.writeheader()
        for request_state in scenario_state.request_states:
            row = {}
            # Each instance carries four references: the category index
            # and the level 2, 3 and 4 category names.
            references = request_state.instance.references
            assert len(references) == 4
            row[FieldNames.CATEGORY_ID] = references[0].output.text
            row[FieldNames.L2_NAME] = references[1].output.text
            row[FieldNames.L3_NAME] = references[2].output.text
            row[FieldNames.L4_NAME] = references[3].output.text
            row[FieldNames.PROMPT] = request_state.request.prompt
            assert request_state.result
            assert len(request_state.result.completions) == 1
            row[FieldNames.RESPONSE] = request_state.result.completions[0].text
            assert request_state.annotations
            row[FieldNames.JUDGE_PROMPT] = request_state.annotations["air_bench_2024"]["prompt_text"]
            row[FieldNames.SCORE_REASON] = request_state.annotations["air_bench_2024"]["reasoning"]
            row[FieldNames.SCORE] = request_state.annotations["air_bench_2024"]["score"]
            writer.writerow(row)
    print(f"Wrote {csv_file_path}")


def process_all(suite_path: str, csv_path: str):
    ensure_directory_exists(csv_path)
    run_dir_names = sorted([p for p in os.listdir(suite_path) if p.startswith("air_bench_2024:")])
    for run_dir_name in tqdm(run_dir_names, disable=None):
        scenario_state_path = os.path.join(suite_path, run_dir_name, "scenario_state.json")
        if not os.path.isfile(scenario_state_path):
            continue
        # Extract the model name from a run directory name of the form
        # "air_bench_2024:model=<model>".
        model_name_match = re.search("model=([A-Za-z0-9_-]+)", run_dir_name)
        assert model_name_match
        model_name = model_name_match[1]
        csv_file_path = os.path.join(csv_path, f"{model_name}_result.csv")
        process_one(scenario_state_path, csv_file_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        help="Where the benchmarking output lives",
        default="benchmark_output",
    )
    parser.add_argument(
        "--csv-path",
        type=str,
        help="Name of the CSV folder.",
        default="csv_output",
    )
    parser.add_argument(
        "--suite",
        type=str,
        help="Name of the suite.",
        required=True,
    )
    args = parser.parse_args()
    suite_path = os.path.join(args.output_path, "runs", args.suite)
    process_all(suite_path, args.csv_path)


if __name__ == "__main__":
    main()
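
Usage note: a minimal sketch of how the script might be invoked, based on its argparse definitions; the suite name "my_suite" is a placeholder (not from this commit), and the script is assumed to be run from the directory containing it:

    python export_air_bench_csv.py --suite my_suite --output-path benchmark_output --csv-path csv_output

Under those assumptions, this reads each air_bench_2024: run under benchmark_output/runs/my_suite and writes one <model>_result.csv file per model into csv_output/.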
