-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_exam.py
209 lines (178 loc) · 7.24 KB
/
run_exam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
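Run CPA exam sessions using prompt methods generate_prompt_011 through
generate_prompt_020 for every completion parameter set, saving each session's
results under results/questions-02/sessions-001/.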
"""
# imports
import datetime
import itertools
import json
import time
from pathlib import Path
from typing import Iterator, Optional

# packages
import openai
import pandas
import tqdm
# set the key; strip surrounding whitespace because a trailing newline in the
# key file would otherwise end up inside the HTTP Authorization header value
openai.api_key = (
    (Path(__file__).parent / ".openai_key").read_text(encoding="utf-8").strip()
)
# local imports
from question_data import parse_question_source
from prompts import *
# Completion model queried for every question and recorded in each session's
# exam_data.json under "model_name".
MODEL_NAME: str = "text-davinci-003"
def get_parameter_sets() -> Iterator[dict]:
    """Yield every combination of completion parameters to evaluate.

    The grid is the Cartesian product of the value lists below, iterated with
    the first key (``temperature``) varying slowest and the last key
    (``presence_penalty``) varying fastest. Each yielded dict is suitable for
    passing as ``**kwargs`` to ``openai.Completion.create``.

    Yields:
        A new ``dict`` mapping parameter name to value, with keys in the
        order listed in the grid below.
    """
    parameter_grid = {
        "temperature": [0.0, 0.5, 1.0],
        "max_tokens": [256],  # alternatives: 16, 128
        "top_p": [1],  # alternatives: 0.75
        "best_of": [1, 2],  # alternatives: 4
        "frequency_penalty": [0],
        "presence_penalty": [0],
    }
    parameter_names = list(parameter_grid)
    for combination in itertools.product(*parameter_grid.values()):
        # a fresh dict per combination so callers may store or mutate it safely
        yield dict(zip(parameter_names, combination))
def get_next_session_path(base_path: Optional[Path] = None) -> Path:
    """Create and return the first unused ``cpa-exam-NNN`` session directory.

    Args:
        base_path: Directory under which session directories are created.
            Defaults to ``<repo>/results/questions-02/sessions-001``. It is
            created (including missing parents) if necessary.

    Returns:
        The newly created session directory, e.g. ``base_path / "cpa-exam-001"``.
        Numbers whose name already exists (as a directory or a file) are
        skipped.
    """
    if base_path is None:
        base_path = (
            Path(__file__).parent.parent / "results" / "questions-02" / "sessions-001"
        )
    # parents=True so a fresh checkout without results/questions-02 still works
    base_path.mkdir(parents=True, exist_ok=True)

    session_number = 1
    while True:
        session_path = base_path / f"cpa-exam-{session_number:03d}"
        # Creating the directory is the existence check: mkdir is atomic, so
        # two concurrent runs can never both claim the same session number.
        try:
            session_path.mkdir()
        except FileExistsError:
            session_number += 1
            continue
        return session_path
# (label, delay in seconds) used after the first and second failed API calls;
# a third consecutive failure gives up on the question.
_RETRY_SCHEDULE = (("First", 5), ("Second", 10))


def _create_completion(prompt, parameter_kwargs: dict):
    """Send a single completion request for ``prompt`` using ``parameter_kwargs``."""
    return openai.Completion.create(
        model=MODEL_NAME,
        prompt=prompt,
        **parameter_kwargs,
    )


def _query_with_retries(prompt, parameter_kwargs: dict, question, progress_bar):
    """Query the model, retrying per ``_RETRY_SCHEDULE``.

    Any ``Exception`` raised by the API call triggers a retry; exceptions that
    do not derive from ``Exception`` (e.g. ``KeyboardInterrupt``) propagate.

    Returns:
        The API response, or ``None`` if every attempt failed.
    """
    for label, delay in _RETRY_SCHEDULE:
        try:
            return _create_completion(prompt, parameter_kwargs)
        except Exception as error:  # broad on purpose: any API/network failure is retried
            progress_bar.set_description(f"{label} error, retrying in {delay}: {error}")
            time.sleep(delay)
    try:
        return _create_completion(prompt, parameter_kwargs)
    except Exception as error:
        print(f"Third error, skipping question {question}: {error}")
        return None


def _save_exam_data(session_path: Path, exam_data: dict) -> None:
    """Checkpoint ``exam_data`` to ``session_path / "exam_data.json"``.

    The JSON is written to a temporary file and then renamed over the target,
    so an interruption or serialization error mid-write leaves the previous
    checkpoint intact instead of a truncated file.
    """
    output_path = session_path / "exam_data.json"
    temp_path = session_path / "exam_data.json.tmp"
    with open(temp_path, "wt", encoding="utf-8") as output_file:
        json.dump(exam_data, output_file)
    temp_path.replace(output_path)


def _run_session(prompt_method, parameter_kwargs: dict, question_list, question_set_name: str) -> None:
    """Answer every question with one prompt method and parameter set.

    Results go to a new session directory from ``get_next_session_path`` and
    are checkpointed after each question, then once more with ``end_time``.
    """
    session_path = get_next_session_path()
    prompt_name = str(prompt_method.__name__)
    exam_data = {
        "model_name": MODEL_NAME,
        "question_set": question_set_name,
        "prompt_method": prompt_name,
        "parameters": parameter_kwargs,
        "start_time": datetime.datetime.now().isoformat(),
        "end_time": None,
        "questions": [],
    }
    # context manager so the bar is closed even if a question raises
    with tqdm.tqdm(question_list, desc="Questions") as question_prog_bar:
        for question in question_prog_bar:
            question_prog_bar.set_description(
                f"Q{question['question_number']}, prompt method {prompt_name}, parameters: {parameter_kwargs}"
            )
            prompt = prompt_method(question)
            question_data = {
                "question_input": question,
                "model_prompt": prompt,
                # None when all retries failed; the question is still recorded
                "model_response": _query_with_retries(
                    prompt, parameter_kwargs, question, question_prog_bar
                ),
            }
            exam_data["questions"].append(question_data)
            _save_exam_data(session_path, exam_data)

    exam_data["end_time"] = datetime.datetime.now().isoformat()
    _save_exam_data(session_path, exam_data)


def main():
    """Run one exam session per (parameter set, sample, prompt method) combination."""
    question_file = Path(__file__).parent.parent / "data" / "questions_02.txt"
    question_list = parse_question_source(question_file)
    question_set_name = question_file.name
    # number of repeated sessions per parameter set / prompt method
    num_samples_per_set = 1
    # generate_prompt_001 .. generate_prompt_010 are only relevant for the test
    # REG section, so they are not run here.
    prompt_list = [
        generate_prompt_011,
        generate_prompt_012,
        generate_prompt_013,
        generate_prompt_014,
        generate_prompt_015,
        generate_prompt_016,
        generate_prompt_017,
        generate_prompt_018,
        generate_prompt_019,
        generate_prompt_020,
    ]
    for parameter_kwargs in get_parameter_sets():
        for _ in range(num_samples_per_set):
            for prompt_method in prompt_list:
                _run_session(prompt_method, parameter_kwargs, question_list, question_set_name)
# Run the full exam sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()