Commit dcb1bf9

add correctness tests for temperature-scaled logprobs returned by the sampling engine

1 parent 5f8b59c commit dcb1bf9

1 file changed: +368 −0 lines changed
@@ -0,0 +1,368 @@
"""
Validate temperature scaling in sampling by comparing pairwise logprob differences.

Two complementary checks ensure correctness across temperatures and sequence positions:
1. Temperature scaling: Verifies (log p_τ(i) - log p_τ(j)) ≈ (1/τ) * (log p_1(i) - log p_1(j))
2. Sequence-level consistency: Validates multi-token sampling returns accurate logprobs at each step.
"""

from __future__ import annotations

import asyncio
from typing import Sequence

import chz
import numpy as np
import tinker

from tinker_cookbook.tokenizer_utils import get_tokenizer

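# Background for the pairwise check (standard softmax temperature algebra):
# sampling at temperature τ rescales logits l_i to l_i / τ, so
#     log p_τ(i) = l_i / τ - log Z_τ,
# where Z_τ is the normalizing constant. Z_τ is unknown to the client, but it
# cancels in the difference between any two tokens:
#     log p_τ(i) - log p_τ(j) = (l_i - l_j) / τ = (1/τ) * (log p_1(i) - log p_1(j)).
# _pairwise_ratio_metrics below tests exactly this identity, so the partition
# function never needs to be recovered.
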
def _default_temperatures() -> list[float]:
    return [0.5, 0.7, 1.0, 1.2, 1.5, 1.8]


@chz.chz
class Config:
    base_model: str
    prompt: str = (
        "Explain temperature scaling in language model sampling, include a brief "
        "example, and discuss calibration vs diversity trade-offs."
    )
    temperatures: list[float] = chz.field(default_factory=_default_temperatures)
    baseline_temperature: float = 1.0
    num_trials: int = 20
    check_sequence_consistency: bool = True
    consistency_check_length: int = 20
    consistency_check_temp: float = 0.5
    seed: int | None = 42
    base_url: str | None = None


async def _sample_next_token(
    sampling_client: tinker.SamplingClient,
    model_input: tinker.ModelInput,
    *,
    temperature: float,
    max_tokens: int,
    seed: int | None,
) -> tuple[list[int], list[float]]:
    resp = await sampling_client.sample_async(
        prompt=model_input,
        num_samples=1,
        sampling_params=tinker.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
        ),
    )
    seq = resp.sequences[0]
    if seq.logprobs is None:
        raise RuntimeError("Sampling response did not include logprobs")
    return seq.tokens, seq.logprobs


async def _collect_sampled_token_logprobs(
    sampling_client: tinker.SamplingClient,
    model_input: tinker.ModelInput,
    *,
    temperature: float,
    num_trials: int,
    max_tokens: int,
    seed: int | None,
) -> dict[int, float]:
    """Collect token_id -> logprob at a given temperature over several trials."""
    out: dict[int, float] = {}
    base = 0 if seed is None else seed
    for i in range(num_trials):
        s = base + i if seed is not None else None
        tokens, lps = await _sample_next_token(
            sampling_client,
            model_input,
            temperature=temperature,
            max_tokens=max_tokens,
            seed=s,
        )
        if not tokens:
            continue
        t = tokens[0]
        # Keep the first logprob observed for each distinct first token.
        out.setdefault(t, lps[0])
    return out


async def _compute_logp1_for_tokens(
    sampling_client: tinker.SamplingClient,
    prompt_tokens: list[int],
    tokens: Sequence[int],
) -> dict[int, float]:
    """Compute baseline log p_1(token|prompt) for each token via compute_logprobs_async."""
    res: dict[int, float] = {}
    for tok in tokens:
        seq = tinker.ModelInput.from_ints(prompt_tokens + [tok])
        lps = await sampling_client.compute_logprobs_async(seq)
        # Logprob at position len(prompt_tokens), i.e. the candidate token
        # conditioned on the full prompt.
        lp = lps[len(prompt_tokens)]
        if lp is None:
            raise RuntimeError(
                "compute_logprobs_async did not return a logprob for the sampled token"
            )
        res[tok] = lp
    return res


def _pairwise_ratio_metrics(
    base_logp: dict[int, float],
    temp_logp: dict[int, float],
    temperature: float,
) -> dict[str, float]:
    """Compare pairwise logprob differences: (log p_τ(i) - log p_τ(j)) vs (1/τ) * (log p_1(i) - log p_1(j))."""
    common = sorted(set(base_logp) & set(temp_logp))
    if len(common) < 2:
        return {
            "tokens": float(len(common)),
            "pairs": 0.0,
            "mean_abs_err": float("nan"),
            "max_abs_err": float("nan"),
        }
    base_diffs: list[float] = []
    temp_diffs: list[float] = []
    inv_tau = 1.0 / max(temperature, 1e-9)
    for a in range(len(common)):
        for b in range(a + 1, len(common)):
            i, j = common[a], common[b]
            base_diffs.append(inv_tau * (base_logp[i] - base_logp[j]))
            temp_diffs.append(temp_logp[i] - temp_logp[j])
    x = np.array(base_diffs, dtype=float)
    y = np.array(temp_diffs, dtype=float)
    abs_err = np.abs(y - x)
    mean_abs_err = float(np.mean(abs_err))
    max_abs_err = float(np.max(abs_err))
    return {
        "tokens": float(len(common)),
        "pairs": float(len(base_diffs)),
        "mean_abs_err": mean_abs_err,
        "max_abs_err": max_abs_err,
    }


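# Illustrative self-check (not wired into the test flow): on synthetic logits
# the pairwise identity holds exactly, so _pairwise_ratio_metrics should report
# near-zero error without any sampling backend. The logit values here are
# arbitrary placeholders.
def _pairwise_metrics_selfcheck(temperature: float = 0.7) -> dict[str, float]:
    logits = np.array([2.0, 0.5, -1.0, -3.0])
    # Log-softmax at τ = 1 and at τ = temperature.
    logp1 = logits - np.log(np.sum(np.exp(logits)))
    logpt = logits / temperature - np.log(np.sum(np.exp(logits / temperature)))
    base = {i: float(logp1[i]) for i in range(len(logits))}
    temp = {i: float(logpt[i]) for i in range(len(logits))}
    return _pairwise_ratio_metrics(base, temp, temperature)

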
# ============================================================================
# Sequence-level consistency validation
# ============================================================================


async def _sample_sequence_oneshot(
    sampling_client: tinker.SamplingClient,
    prompt_tokens: list[int],
    *,
    temperature: float,
    max_tokens: int,
    seed: int | None,
) -> tuple[list[int], list[float]]:
    """Sample a sequence in one call with max_tokens > 1."""
    model_input = tinker.ModelInput.from_ints(prompt_tokens)
    resp = await sampling_client.sample_async(
        prompt=model_input,
        num_samples=1,
        sampling_params=tinker.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
        ),
    )
    seq = resp.sequences[0]
    if seq.logprobs is None:
        raise RuntimeError("Sampling response did not include logprobs")
    return seq.tokens, seq.logprobs


async def _resample_tokens_individually(
    sampling_client: tinker.SamplingClient,
    prompt_tokens: list[int],
    *,
    temperature: float,
    length: int,
    seed: int | None,
) -> tuple[list[int], list[float]]:
    """Sample tokens one at a time, feeding each back into the prefix.

    This mimics what max_tokens > 1 should do internally: sample token i,
    append it to the context, then sample token i+1.
    """
    tokens: list[int] = []
    logprobs: list[float] = []
    current_prefix = prompt_tokens.copy()

    for i in range(length):
        model_input = tinker.ModelInput.from_ints(current_prefix)
        # Increment the seed for each position to get different random states
        pos_seed = (seed + i) if seed is not None else None

        resp = await sampling_client.sample_async(
            prompt=model_input,
            num_samples=1,
            sampling_params=tinker.SamplingParams(
                max_tokens=1,
                temperature=temperature,
                seed=pos_seed,
            ),
        )
        seq = resp.sequences[0]
        if not seq.tokens or seq.logprobs is None:
            break

        tok = seq.tokens[0]
        logprob = seq.logprobs[0]
        tokens.append(tok)
        logprobs.append(logprob)
        current_prefix.append(tok)

    return tokens, logprobs


def _compare_logprobs(
    sampled_logprobs: list[float],
    computed_logprobs: list[float],
) -> dict[str, float]:
    """Compare sampled vs recomputed logprobs."""
    min_len = min(len(sampled_logprobs), len(computed_logprobs))
    if min_len == 0:
        return {
            "length": 0.0,
            "mean_diff": float("nan"),
            "max_diff": float("nan"),
        }

    diffs = [abs(sampled_logprobs[i] - computed_logprobs[i]) for i in range(min_len)]

    return {
        "length": float(min_len),
        "mean_diff": float(np.mean(diffs)),
        "max_diff": float(np.max(diffs)),
    }


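# Note: _resample_tokens_individually and _compare_logprobs are standalone
# helpers for a step-by-step variant of the sequence check; the entry point
# below uses the match-and-retry strategy in validate_sequence_consistency.

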
async def validate_sequence_consistency(
    sampling_client: tinker.SamplingClient,
    prompt_tokens: list[int],
    *,
    temperature: float,
    length: int,
    seed: int | None,
    tokenizer,
) -> None:
    """Validate that sample_async(max_tokens > 1) returns accurate per-step logprobs.

    Generates a sequence, then resamples each position individually to find matching
    tokens and compare their logprobs, validating correctness at each step.
    """
    print("\n" + "=" * 75)
    print("SEQUENCE-LEVEL CONSISTENCY CHECK (multi-token logprob validation)")
    print("=" * 75)
    print(
        f"Generate with max_tokens={length} at temp={temperature}, then resample "
        "each position individually to verify logprob consistency."
    )
    print(f"{'Temp':>8} {'Length':>8} {'Matches':>8} {'Mean Diff':>12} {'Max Diff':>12}")
    print("-" * 75)

    tau = temperature
    gen_tokens, gen_logprobs = await _sample_sequence_oneshot(
        sampling_client, prompt_tokens, temperature=tau, max_tokens=length, seed=seed
    )

    matching_diffs: list[float] = []
    # Logprobs are only returned for sampled tokens, so retry each position a few
    # times until the resample produces the same token as the one-shot run.
    num_attempts_per_position = 5

    for i in range(len(gen_tokens)):
        prefix = prompt_tokens + gen_tokens[:i]
        model_input = tinker.ModelInput.from_ints(prefix)

        for attempt in range(num_attempts_per_position):
            resp = await sampling_client.sample_async(
                prompt=model_input,
                num_samples=1,
                sampling_params=tinker.SamplingParams(
                    max_tokens=1,
                    temperature=tau,
                    seed=(seed + 1000 * (i + 1) + attempt) if seed is not None else None,
                ),
            )
            seq = resp.sequences[0]
            if not seq.tokens or seq.logprobs is None:
                continue

            if seq.tokens[0] == gen_tokens[i]:
                matching_diffs.append(abs(gen_logprobs[i] - seq.logprobs[0]))
                break

    if len(matching_diffs) == 0:
        print(f"{tau:>8.3f} {len(gen_tokens):>8} {0:>8} {'N/A':>12} {'N/A':>12}")
        return

    mean_diff = float(np.mean(matching_diffs))
    max_diff = float(np.max(matching_diffs))
    print(
        f"{tau:>8.3f} {len(gen_tokens):>8} {len(matching_diffs):>8} {mean_diff:>12.6f} {max_diff:>12.6f}"
    )
    print()


async def main(cfg: Config) -> None:
    tokenizer = get_tokenizer(cfg.base_model)
    prompt_tokens = tokenizer.encode(cfg.prompt)
    model_input = tinker.ModelInput.from_ints(prompt_tokens)

    service = tinker.ServiceClient(base_url=cfg.base_url)
    sampler = service.create_sampling_client(base_model=cfg.base_model)

    print("\n" + "=" * 75)
    print("TEMPERATURE SCALING VALIDATION")
    print("=" * 75)

    base_seen = await _collect_sampled_token_logprobs(
        sampler,
        model_input,
        temperature=cfg.baseline_temperature,
        num_trials=cfg.num_trials,
        max_tokens=1,
        seed=cfg.seed,
    )
    base_logp = await _compute_logp1_for_tokens(sampler, prompt_tokens, list(base_seen))

    print(f"Model: {cfg.base_model}, {cfg.num_trials} trials per temperature")
    print(f"{'Temp':>8} {'Unique Tokens':>15} {'Pairs':>8} {'Mean Diff':>12} {'Max Diff':>12}")
    print("-" * 75)

    for tau in cfg.temperatures:
        temp_seen = await _collect_sampled_token_logprobs(
            sampler,
            model_input,
            temperature=tau,
            num_trials=cfg.num_trials,
            max_tokens=1,
            seed=cfg.seed,
        )
        # Tokens first seen at this temperature still need a baseline logprob.
        missing = [t for t in temp_seen if t not in base_logp]
        if missing:
            base_logp.update(await _compute_logp1_for_tokens(sampler, prompt_tokens, missing))
        metrics = _pairwise_ratio_metrics(base_logp, temp_seen, tau)

        mean_diff = metrics["mean_abs_err"]
        max_diff = metrics["max_abs_err"]
        print(
            f"{tau:>8.3f} {int(metrics['tokens']):>15} {int(metrics['pairs']):>8} "
            f"{mean_diff:>12.6f} {max_diff:>12.6f}"
        )

    if cfg.check_sequence_consistency:
        await validate_sequence_consistency(
            sampler,
            prompt_tokens,
            temperature=cfg.consistency_check_temp,
            length=cfg.consistency_check_length,
            seed=cfg.seed,
            tokenizer=tokenizer,
        )

    print()


if __name__ == "__main__":
    asyncio.run(chz.nested_entrypoint(main))
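
# Example invocation (illustrative; the script filename and model name are
# placeholders, and the key=value overrides assume chz's CLI parsing):
#
#   python validate_sampling_logprobs.py base_model=meta-llama/Llama-3.1-8B num_trials=30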
