
Commit 99e565e

add gepa readme + remove comments and fix aime task
1 parent 53c6c38 commit 99e565e

File tree

5 files changed: +748 -0 lines changed


pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -56,12 +56,17 @@ verifiers = [
     "verifiers",
     "openai",
 ]
+gepa = [
+    "gepa",
+    "datasets",
+]
 all = [
     "tinker_cookbook[vector-search]",
     "tinker_cookbook[wandb]",
     "tinker_cookbook[neptune-scale]",
     "tinker_cookbook[trackio]",
     "tinker_cookbook[verifiers]",
+    "tinker_cookbook[gepa]",
 ]
 
 [build-system]
Lines changed: 70 additions & 0 deletions
# GEPA: Prompt Optimization via LLM Reflection

Evolve system prompts through reflection: optimizer-free prompt optimization. GEPA evaluates the current prompt, has a teacher LLM reflect on failures, mutates the prompt, and keeps only improvements.

**Paper**: [arxiv.org/abs/2507.19457](https://arxiv.org/abs/2507.19457) | **Library**: [github.com/gepa-ai/gepa](https://github.com/gepa-ai/gepa)

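At a high level, the loop looks roughly like the sketch below. This is illustrative only, not the `gepa` library API; `evaluate_prompt` and `reflect_and_rewrite` are hypothetical stand-ins for the adapter and reflection model this recipe wires up.

```python
# Illustrative sketch of the GEPA loop; not the gepa library API.
# evaluate_prompt and reflect_and_rewrite are hypothetical stand-ins.
def gepa_sketch(seed_prompt, evaluate_prompt, reflect_and_rewrite, max_metric_calls):
    best_prompt = seed_prompt
    best_score = evaluate_prompt(best_prompt)  # score on a validation batch
    calls = 1
    while calls < max_metric_calls:
        # A teacher LLM reads execution traces for the current prompt and proposes a rewrite.
        candidate = reflect_and_rewrite(best_prompt)
        score = evaluate_prompt(candidate)
        calls += 1
        if score > best_score:  # keep only improvements
            best_prompt, best_score = candidate, score
    return best_prompt, best_score
```
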
## Running This Recipe
```bash
pip install tinker_cookbook[gepa]
```

### GSM8K

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=gsm8k \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=50
```

After optimization, expect `final/best_score` around 0.91.

### HotpotQA

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=hotpotqa \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=100
```

### AIME

```bash
python -m tinker_cookbook.recipes.gepa.train \
    task_name=aime \
    model_name="Qwen/Qwen3-4B-Instruct-2507" \
    reflection_model="deepseek-ai/DeepSeek-V3.1" \
    max_metric_calls=150 \
    eval_test=true
```

## Custom Tasks

Register via `TASK_REGISTRY`:

```python
from tinker_cookbook.recipes.gepa.tasks import GEPATask, register_task


@register_task("my_benchmark")
class MyBenchmarkTask(GEPATask):
    @property
    def name(self) -> str:
        return "my_benchmark"

    @property
    def seed_prompt(self) -> str:
        return "You are a helpful assistant..."

    def load_data(self, seed: int = 0):
        return train, val, test  # GEPADataInstance lists

    def score(self, response: str, answer: str, metadata: dict | None = None) -> float:
        return 1.0 if answer.strip() in response else 0.0
```

Then run: `python -m tinker_cookbook.recipes.gepa.train task_name=my_benchmark`
Lines changed: 215 additions & 0 deletions
from typing import Any, Callable, TypedDict

import tinker
from gepa.core.adapter import EvaluationBatch, GEPAAdapter

from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import Tokenizer

Scorer = Callable[[str, str, dict[str, Any] | None], float]


class TinkerDataInst(TypedDict):
    input: str
    answer: str
    metadata: dict[str, Any]


class TinkerTrajectory(TypedDict):
    data: TinkerDataInst
    response: str
    score: float
    error: str | None
    logprobs: list[float] | None
    tokens: list[int] | None


class TinkerRolloutOutput(TypedDict):
    response: str


TinkerReflectiveRecord = TypedDict(
    "TinkerReflectiveRecord",
    {
        "Inputs": str,
        "Generated Outputs": str,
        "Feedback": str,
    },
)


def default_scorer(response: str, answer: str, metadata: dict[str, Any] | None = None) -> float:
    return 1.0 if answer.lower().strip() in response.lower().strip() else 0.0


class TinkerReflectionLM:
    """Callable reflection LM for GEPA: takes a prompt string, returns the teacher model's text."""

    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        max_tokens: int = 4096,
        temperature: float = 0.3,
        system_prompt: str | None = None,
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt or (
            "You are an expert prompt engineer. Analyze the execution traces and "
            "suggest improvements to the system prompt to improve task performance."
        )
        self.sampling_params = tinker.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def __call__(self, prompt: str) -> str:
        renderer_name = self.renderer.__class__.__name__
        supports_system = "DeepSeek" not in renderer_name

        if supports_system:
            messages: list[renderers.Message] = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
        else:
            combined_content = f"{self.system_prompt}\n\n{prompt}"
            messages = [
                {"role": "user", "content": combined_content},
            ]

        model_input = self.renderer.build_generation_prompt(messages)

        future = self.sampling_client.sample(
            prompt=model_input,
            num_samples=1,
            sampling_params=self.sampling_params,
        )
        result = future.result()
        seq = result.sequences[0]
        parsed, _ = self.renderer.parse_response(seq.tokens)
        return parsed["content"]


class TinkerGEPAAdapter(GEPAAdapter[TinkerDataInst, TinkerTrajectory, TinkerRolloutOutput]):
    """GEPA adapter that rolls out candidate system prompts through a tinker sampling client."""

    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        scorer: Scorer | None = None,
        max_tokens: int = 2048,
        temperature: float = 0.7,
        failure_score: float = 0.0,
        component_name: str = "system_prompt",
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.scorer = scorer or default_scorer
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.failure_score = failure_score
        self.component_name = component_name

        self.sampling_params = tinker.SamplingParams(
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def _get_system_prompt(self, candidate: dict[str, str]) -> str:
        if self.component_name not in candidate:
            raise ValueError(
                f"Candidate missing '{self.component_name}'. Got: {list(candidate.keys())}"
            )
        return candidate[self.component_name]

    def evaluate(
        self,
        batch: list[TinkerDataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput]:
        system_prompt = self._get_system_prompt(candidate)

        futures = []
        for data in batch:
            messages: list[renderers.Message] = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": data["input"]},
            ]
            model_input = self.renderer.build_generation_prompt(messages)
            futures.append(
                self.sampling_client.sample(
                    prompt=model_input,
                    num_samples=1,
                    sampling_params=self.sampling_params,
                )
            )

        outputs: list[TinkerRolloutOutput] = []
        scores: list[float] = []
        trajectories: list[TinkerTrajectory] | None = [] if capture_traces else None

        for future, data in zip(futures, batch):
            result = future.result()
            seq = result.sequences[0]
            parsed, _ = self.renderer.parse_response(seq.tokens)
            response = parsed["content"]
            score = self.scorer(response, data["answer"], data.get("metadata"))

            outputs.append({"response": response})
            scores.append(score)

            if trajectories is not None:
                trajectories.append({
                    "data": data,
                    "response": response,
                    "score": score,
                    "error": None,
                    "logprobs": seq.logprobs,
                    "tokens": seq.tokens,
                })

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput],
        components_to_update: list[str],
    ) -> dict[str, list[TinkerReflectiveRecord]]:
        trajectories = eval_batch.trajectories
        assert trajectories is not None

        result: dict[str, list[TinkerReflectiveRecord]] = {}

        for comp in components_to_update:
            items: list[TinkerReflectiveRecord] = []

            for traj in trajectories:
                data = traj["data"]
                response = traj["response"]
                score = traj["score"]
                error = traj["error"]

                if error:
                    feedback = f"Error: {error}"
                elif score >= 1.0:
                    feedback = f"Correct. Expected: '{data['answer']}'"
                else:
                    feedback = f"Incorrect. Expected: '{data['answer']}'"
                    if data.get("metadata"):
                        hints = ", ".join(f"{k}={v}" for k, v in data["metadata"].items())
                        feedback += f" (context: {hints})"

                items.append({
                    "Inputs": data["input"],
                    "Generated Outputs": response[:1000] if response else "(empty)",
                    "Feedback": feedback,
                })

            result[comp] = items

        return result
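
For orientation, here is a rough sketch of how these pieces might be wired into a GEPA run. It is a hedged sketch only: the `gepa.optimize` keyword names and the `build_renderer_and_tokenizer` helper are assumptions, not code from this commit, and the tinker client calls follow the SDK's usual pattern rather than anything shown above.

# Hedged wiring sketch; gepa.optimize kwargs and build_renderer_and_tokenizer are assumptions.
import gepa
import tinker

service = tinker.ServiceClient()
task_client = service.create_sampling_client(base_model="Qwen/Qwen3-4B-Instruct-2507")
teacher_client = service.create_sampling_client(base_model="deepseek-ai/DeepSeek-V3.1")
renderer, tokenizer = build_renderer_and_tokenizer("Qwen/Qwen3-4B-Instruct-2507")  # hypothetical helper

adapter = TinkerGEPAAdapter(task_client, renderer, tokenizer)
reflection_lm = TinkerReflectionLM(teacher_client, renderer, tokenizer)

train = [{"input": "What is 2 + 2?", "answer": "4", "metadata": {}}]  # toy TinkerDataInst list
result = gepa.optimize(
    seed_candidate={"system_prompt": "You are a careful problem solver."},
    trainset=train,
    valset=train,
    adapter=adapter,
    reflection_lm=reflection_lm,
    max_metric_calls=10,
)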
