Commit ca68b85

add gepa readme + remove comments and fix aime task
1 parent 53c6c38 commit ca68b85

File tree

5 files changed: +763 -0 lines changed


pyproject.toml

Lines changed: 5 additions & 0 deletions
```diff
@@ -56,12 +56,17 @@ verifiers = [
     "verifiers",
     "openai",
 ]
+gepa = [
+    "gepa",
+    "datasets",
+]
 all = [
     "tinker_cookbook[vector-search]",
     "tinker_cookbook[wandb]",
     "tinker_cookbook[neptune-scale]",
     "tinker_cookbook[trackio]",
     "tinker_cookbook[verifiers]",
+    "tinker_cookbook[gepa]",
 ]

 [build-system]
```
Lines changed: 70 additions & 0 deletions
# GEPA: Prompt Optimization via LLM Reflection

Evolve system prompts through reflection: optimizer-free prompt optimization. GEPA evaluates prompts, has a teacher LLM reflect on failures, mutates the prompt, and keeps improvements.

**Paper**: [arxiv.org/abs/2507.19457](https://arxiv.org/abs/2507.19457) | **Library**: [github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa)
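To make the loop concrete, here is a minimal, self-contained sketch of the evaluate, reflect, mutate, keep cycle. Everything in it is a toy stand-in (the scoring and mutation functions are invented for illustration); the actual recipe delegates this loop to the `gepa` library and uses Tinker sampling clients for the policy and reflection models.

```python
import random

# Toy stand-ins: `evaluate` rewards prompts that contain certain hints, and
# `reflect_and_mutate` plays the role of the teacher LLM proposing a revision.
def evaluate(prompt: str) -> float:
    hints = ("step by step", "verify the final answer", "be concise")
    return sum(hint in prompt.lower() for hint in hints) / len(hints)

def reflect_and_mutate(prompt: str, rng: random.Random) -> str:
    suggestion = rng.choice(["Think step by step.", "Verify the final answer.", "Be concise."])
    return f"{prompt} {suggestion}"

def optimize_prompt(seed: str, max_metric_calls: int = 20) -> str:
    rng = random.Random(0)
    best, best_score = seed, evaluate(seed)
    for _ in range(max_metric_calls - 1):
        candidate = reflect_and_mutate(best, rng)  # reflect on traces, propose a mutation
        score = evaluate(candidate)                # score the mutated prompt
        if score > best_score:                     # keep only improvements
            best, best_score = candidate, score
    return best

print(optimize_prompt("You are a helpful assistant."))
```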
## Running This Recipe

```bash
pip install tinker_cookbook[gepa]
```

### GSM8K

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=gsm8k \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=50
```

After optimization, expect `final/best_score` around 0.91.

### HotpotQA

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=hotpotqa \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=100
```

### AIME

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=aime \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=150 \
    eval_test=true
```

## Custom Tasks

Register via `TASK_REGISTRY`:

```python
from tinker_cookbook.recipes.gepa.tasks import GEPATask, register_task

@register_task("my_benchmark")
class MyBenchmarkTask(GEPATask):
    @property
    def name(self) -> str:
        return "my_benchmark"

    @property
    def seed_prompt(self) -> str:
        return "You are a helpful assistant..."

    def load_data(self, seed: int = 0):
        return train, val, test  # GEPADataInstance lists

    def score(self, response: str, answer: str, metadata: dict | None = None) -> float:
        return 1.0 if answer.strip() in response else 0.0
```

Then: `python -m tinker_cookbook.recipes.gepa.train task_name=my_benchmark`
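For concreteness, the instances returned by `load_data` are simple records pairing an input with a reference answer and optional metadata. The field names below mirror the adapter's `TinkerDataInst` and the examples are made up; check the `GEPADataInstance` definition in the recipe for the exact type.

```python
# Hypothetical GEPADataInstance-style records for the custom task above.
train = [
    {"input": "What is 2 + 2?", "answer": "4", "metadata": {"difficulty": "easy"}},
    {"input": "What is 7 * 6?", "answer": "42", "metadata": {"difficulty": "easy"}},
]
val = [{"input": "What is 13 * 17?", "answer": "221", "metadata": {"difficulty": "medium"}}]
test = [{"input": "What is 19 + 23?", "answer": "42", "metadata": {"difficulty": "medium"}}]
```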
Lines changed: 221 additions & 0 deletions
```python
from typing import Any, Callable, TypedDict

import tinker
from gepa.core.adapter import EvaluationBatch, GEPAAdapter

from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import Tokenizer

# A scorer maps (response, expected answer, optional metadata) to a float score.
Scorer = Callable[[str, str, dict[str, Any] | None], float]


class TinkerDataInst(TypedDict):
    input: str
    answer: str
    metadata: dict[str, Any]


class TinkerTrajectory(TypedDict):
    data: TinkerDataInst
    response: str
    score: float
    error: str | None
    logprobs: list[float] | None
    tokens: list[int] | None


class TinkerRolloutOutput(TypedDict):
    response: str


# Record format consumed by GEPA's reflection step.
TinkerReflectiveRecord = TypedDict(
    "TinkerReflectiveRecord",
    {
        "Inputs": str,
        "Generated Outputs": str,
        "Feedback": str,
    },
)


def default_scorer(response: str, answer: str, metadata: dict[str, Any] | None = None) -> float:
    return 1.0 if answer.lower().strip() in response.lower().strip() else 0.0
```
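`default_scorer` above is a case-insensitive substring check. Any callable matching the `Scorer` signature can be supplied to the adapter instead; as a hypothetical illustration (not part of this commit), a numeric-answer scorer might look like this:

```python
import re
from typing import Any


def numeric_scorer(response: str, answer: str, metadata: dict[str, Any] | None = None) -> float:
    """Hypothetical scorer: full credit if the last number in the response equals the answer."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response)
    if not numbers:
        return 0.0
    try:
        return 1.0 if float(numbers[-1]) == float(answer) else 0.0
    except ValueError:
        return 0.0
```

The module then defines the reflection LM wrapper and the GEPA adapter itself: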
```python
class TinkerReflectionLM:
    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        max_tokens: int = 4096,
        temperature: float = 0.3,
        system_prompt: str | None = None,
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt or (
            "You are an expert prompt engineer. Analyze the execution traces and "
            "suggest improvements to the system prompt to improve task performance."
        )
        self.sampling_params = tinker.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def __call__(self, prompt: str) -> str:
        # Some renderers (e.g. DeepSeek) do not support a system role; fold the
        # system prompt into the user message in that case.
        renderer_name = self.renderer.__class__.__name__
        supports_system = "DeepSeek" not in renderer_name

        if supports_system:
            messages: list[renderers.Message] = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
        else:
            combined_content = f"{self.system_prompt}\n\n{prompt}"
            messages = [
                {"role": "user", "content": combined_content},
            ]

        model_input = self.renderer.build_generation_prompt(messages)

        future = self.sampling_client.sample(
            prompt=model_input,
            num_samples=1,
            sampling_params=self.sampling_params,
        )
        result = future.result()
        seq = result.sequences[0]
        parsed, _ = self.renderer.parse_response(seq.tokens)
        return parsed["content"]


class TinkerGEPAAdapter(GEPAAdapter[TinkerDataInst, TinkerTrajectory, TinkerRolloutOutput]):
    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        scorer: Scorer | None = None,
        max_tokens: int = 2048,
        temperature: float = 0.7,
        failure_score: float = 0.0,
        component_name: str = "system_prompt",
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.scorer = scorer or default_scorer
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.failure_score = failure_score
        self.component_name = component_name

        self.sampling_params = tinker.SamplingParams(
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def _get_system_prompt(self, candidate: dict[str, str]) -> str:
        if self.component_name not in candidate:
            raise ValueError(
                f"Candidate missing '{self.component_name}'. Got: {list(candidate.keys())}"
            )
        return candidate[self.component_name]

    def evaluate(
        self,
        batch: list[TinkerDataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput]:
        system_prompt = self._get_system_prompt(candidate)

        # Launch all sampling requests first, then collect the results.
        futures = []
        for data in batch:
            messages: list[renderers.Message] = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": data["input"]},
            ]
            model_input = self.renderer.build_generation_prompt(messages)
            futures.append(
                self.sampling_client.sample(
                    prompt=model_input,
                    num_samples=1,
                    sampling_params=self.sampling_params,
                )
            )

        outputs: list[TinkerRolloutOutput] = []
        scores: list[float] = []
        trajectories: list[TinkerTrajectory] | None = [] if capture_traces else None

        for future, data in zip(futures, batch):
            result = future.result()
            seq = result.sequences[0]
            parsed, _ = self.renderer.parse_response(seq.tokens)
            response = parsed["content"]
            score = self.scorer(response, data["answer"], data.get("metadata"))

            outputs.append({"response": response})
            scores.append(score)

            if trajectories is not None:
                trajectories.append(
                    {
                        "data": data,
                        "response": response,
                        "score": score,
                        "error": None,
                        "logprobs": seq.logprobs,
                        "tokens": seq.tokens,
                    }
                )

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput],
        components_to_update: list[str],
    ) -> dict[str, list[TinkerReflectiveRecord]]:
        trajectories = eval_batch.trajectories
        assert trajectories is not None

        result: dict[str, list[TinkerReflectiveRecord]] = {}

        for comp in components_to_update:
            items: list[TinkerReflectiveRecord] = []

            for traj in trajectories:
                data = traj["data"]
                response = traj["response"]
                score = traj["score"]
                error = traj["error"]

                # Turn each rollout into human-readable feedback for the reflection LM.
                if error:
                    feedback = f"Error: {error}"
                elif score >= 1.0:
                    feedback = f"Correct. Expected: '{data['answer']}'"
                else:
                    feedback = f"Incorrect. Expected: '{data['answer']}'"
                    if data.get("metadata"):
                        hints = ", ".join(f"{k}={v}" for k, v in data["metadata"].items())
                        feedback += f" (context: {hints})"

                items.append(
                    {
                        "Inputs": data["input"],
                        "Generated Outputs": response[:1000] if response else "(empty)",
                        "Feedback": feedback,
                    }
                )

            result[comp] = items

        return result
```
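Finally, a rough sketch of how these pieces could be wired into the optimizer. This is illustrative rather than the recipe's actual `train` module; in particular, the keyword names for `gepa.optimize` and the `best_candidate` attribute on its result follow the upstream `gepa` README and are assumptions here.

```python
import gepa


def run_gepa(policy_client, policy_renderer, policy_tokenizer,
             reflection_client, reflection_renderer, reflection_tokenizer,
             train, val, seed_prompt: str, max_metric_calls: int = 50) -> str:
    """Hypothetical wiring: build the adapter and reflection LM, hand them to gepa.optimize.

    The sampling clients, renderers, and tokenizers are assumed to be constructed the way
    the recipe's train module does it; gepa.optimize keyword names are assumptions.
    """
    adapter = TinkerGEPAAdapter(
        sampling_client=policy_client, renderer=policy_renderer, tokenizer=policy_tokenizer
    )
    reflection_lm = TinkerReflectionLM(
        sampling_client=reflection_client, renderer=reflection_renderer, tokenizer=reflection_tokenizer
    )
    result = gepa.optimize(
        seed_candidate={"system_prompt": seed_prompt},
        trainset=train,  # lists of TinkerDataInst
        valset=val,
        adapter=adapter,
        reflection_lm=reflection_lm,  # plain callable: prompt str -> critique str
        max_metric_calls=max_metric_calls,
    )
    return result.best_candidate["system_prompt"]
```

A task-specific `Scorer` (such as the numeric one sketched earlier) would be passed via `TinkerGEPAAdapter(scorer=...)`.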
