"""GEPA adapter for Tinker.

Wraps a tinker.SamplingClient so GEPA can run rollouts, score responses, and
build reflective datasets for system-prompt optimization.
"""

from typing import Any, Callable, TypedDict

import tinker
from gepa.core.adapter import EvaluationBatch, GEPAAdapter

from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import Tokenizer

# A scorer maps (response, reference answer, optional metadata) to a float score.
Scorer = Callable[[str, str, dict[str, Any] | None], float]

class TinkerDataInst(TypedDict):
    """A single task example: model input, reference answer, and extra context."""

    input: str
    answer: str
    metadata: dict[str, Any]


class TinkerTrajectory(TypedDict):
    """One captured rollout: the example, the sampled response, and its score."""

    data: TinkerDataInst
    response: str
    score: float
    error: str | None
    logprobs: list[float] | None
    tokens: list[int] | None


class TinkerRolloutOutput(TypedDict):
    """The minimal rollout output GEPA sees when traces are not captured."""

    response: str

# Functional TypedDict syntax is required here because the keys contain
# spaces, matching the record format GEPA expects in reflective datasets.
TinkerReflectiveRecord = TypedDict(
    "TinkerReflectiveRecord",
    {
        "Inputs": str,
        "Generated Outputs": str,
        "Feedback": str,
    },
)

def default_scorer(response: str, answer: str, metadata: dict[str, Any] | None = None) -> float:
    """Case-insensitive substring match: 1.0 if the answer appears in the response."""
    return 1.0 if answer.lower().strip() in response.lower().strip() else 0.0

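# Task-specific scorers can be passed to TinkerGEPAAdapter in place of
# default_scorer. A hypothetical sketch for numeric answers (the regex and
# tolerance are illustrative, not part of this module):
#
#     import re
#
#     def numeric_scorer(response: str, answer: str,
#                        metadata: dict[str, Any] | None = None) -> float:
#         match = re.search(r"-?\d+(?:\.\d+)?", response)
#         if match is None:
#             return 0.0
#         return 1.0 if abs(float(match.group()) - float(answer)) < 1e-6 else 0.0
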
class TinkerReflectionLM:
    """Callable reflection model for GEPA, backed by a tinker.SamplingClient."""

    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        max_tokens: int = 4096,
        temperature: float = 0.3,
        system_prompt: str | None = None,
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt or (
            "You are an expert prompt engineer. Analyze the execution traces and "
            "suggest improvements to the system prompt to improve task performance."
        )
        self.sampling_params = tinker.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def __call__(self, prompt: str) -> str:
        # DeepSeek chat templates have no system role, so fold the system
        # prompt into the user turn for those renderers.
        renderer_name = self.renderer.__class__.__name__
        supports_system = "DeepSeek" not in renderer_name

        messages: list[renderers.Message]
        if supports_system:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
        else:
            messages = [
                {"role": "user", "content": f"{self.system_prompt}\n\n{prompt}"},
            ]

        model_input = self.renderer.build_generation_prompt(messages)

        future = self.sampling_client.sample(
            prompt=model_input,
            num_samples=1,
            sampling_params=self.sampling_params,
        )
        result = future.result()
        seq = result.sequences[0]
        parsed, _ = self.renderer.parse_response(seq.tokens)
        return parsed["content"]

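# Usage sketch (assumes a sampling client, renderer, and tokenizer already
# exist; the prompt text is illustrative):
#
#     reflection_lm = TinkerReflectionLM(sampling_client, renderer, tokenizer)
#     advice = reflection_lm("Here are failing traces: ... Propose a better prompt.")
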
class TinkerGEPAAdapter(GEPAAdapter[TinkerDataInst, TinkerTrajectory, TinkerRolloutOutput]):
    """GEPA adapter that evaluates candidate system prompts via Tinker sampling."""

    def __init__(
        self,
        sampling_client: tinker.SamplingClient,
        renderer: renderers.Renderer,
        tokenizer: Tokenizer,
        scorer: Scorer | None = None,
        max_tokens: int = 2048,
        temperature: float = 0.7,
        failure_score: float = 0.0,
        component_name: str = "system_prompt",
    ):
        self.sampling_client = sampling_client
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.scorer = scorer or default_scorer
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.failure_score = failure_score
        self.component_name = component_name

        self.sampling_params = tinker.SamplingParams(
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=self.renderer.get_stop_sequences(),
        )

    def _get_system_prompt(self, candidate: dict[str, str]) -> str:
        if self.component_name not in candidate:
            raise ValueError(
                f"Candidate missing '{self.component_name}'. Got: {list(candidate.keys())}"
            )
        return candidate[self.component_name]

    def evaluate(
        self,
        batch: list[TinkerDataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput]:
        """Roll out the candidate system prompt on a batch and score each response."""
        system_prompt = self._get_system_prompt(candidate)

        # Submit all sampling requests first so the rollouts run concurrently.
        futures = []
        for data in batch:
            messages: list[renderers.Message] = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": data["input"]},
            ]
            model_input = self.renderer.build_generation_prompt(messages)
            futures.append(
                self.sampling_client.sample(
                    prompt=model_input,
                    num_samples=1,
                    sampling_params=self.sampling_params,
                )
            )

        outputs: list[TinkerRolloutOutput] = []
        scores: list[float] = []
        trajectories: list[TinkerTrajectory] | None = [] if capture_traces else None

        for future, data in zip(futures, batch):
            result = future.result()
            seq = result.sequences[0]
            parsed, _ = self.renderer.parse_response(seq.tokens)
            response = parsed["content"]
            score = self.scorer(response, data["answer"], data.get("metadata"))

            outputs.append({"response": response})
            scores.append(score)

            if trajectories is not None:
                trajectories.append(
                    {
                        "data": data,
                        "response": response,
                        "score": score,
                        "error": None,
                        "logprobs": seq.logprobs,
                        "tokens": seq.tokens,
                    }
                )

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[TinkerTrajectory, TinkerRolloutOutput],
        components_to_update: list[str],
    ) -> dict[str, list[TinkerReflectiveRecord]]:
        """Turn captured trajectories into per-component feedback records for reflection."""
        trajectories = eval_batch.trajectories
        assert trajectories is not None, "evaluate() must be called with capture_traces=True"

        result: dict[str, list[TinkerReflectiveRecord]] = {}

        for comp in components_to_update:
            items: list[TinkerReflectiveRecord] = []

            for traj in trajectories:
                data = traj["data"]
                response = traj["response"]
                score = traj["score"]
                error = traj["error"]

                if error:
                    feedback = f"Error: {error}"
                elif score >= 1.0:
                    feedback = f"Correct. Expected: '{data['answer']}'"
                else:
                    feedback = f"Incorrect. Expected: '{data['answer']}'"
                    if data.get("metadata"):
                        hints = ", ".join(f"{k}={v}" for k, v in data["metadata"].items())
                        feedback += f" (context: {hints})"

                items.append(
                    {
                        "Inputs": data["input"],
                        # Truncate long generations to keep the reflection prompt small.
                        "Generated Outputs": response[:1000] if response else "(empty)",
                        "Feedback": feedback,
                    }
                )

            result[comp] = items

        return result
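
# End-to-end wiring sketch. Hedged example, not a tested entry point: the
# model name and dataset are placeholders, and the gepa.optimize keyword
# names should be checked against the installed GEPA version.
#
#     import gepa
#     import tinker
#     from tinker_cookbook import renderers
#
#     service_client = tinker.ServiceClient()
#     sampling_client = service_client.create_sampling_client(
#         base_model="meta-llama/Llama-3.1-8B-Instruct"
#     )
#     tokenizer = ...  # tokenizer for the same base model
#     renderer = renderers.get_renderer("llama3", tokenizer)
#
#     adapter = TinkerGEPAAdapter(sampling_client, renderer, tokenizer)
#     reflection_lm = TinkerReflectionLM(sampling_client, renderer, tokenizer)
#
#     trainset: list[TinkerDataInst] = [
#         {"input": "What is 2 + 2?", "answer": "4", "metadata": {}},
#     ]
#
#     result = gepa.optimize(
#         seed_candidate={"system_prompt": "You are a helpful assistant."},
#         trainset=trainset,
#         valset=trainset,
#         adapter=adapter,
#         reflection_lm=reflection_lm,
#         max_metric_calls=100,
#     )
#     print(result.best_candidate["system_prompt"])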