Commit 15bedd0

Support analysis channel in gptoss renderer (#51)

1 parent 3772e47 · commit 15bedd0

2 files changed: +92 -5 lines changed
tinker_cookbook/renderers.py

Lines changed: 26 additions & 3 deletions
@@ -33,6 +33,7 @@ class Message(TypedDict):
     role: Role
     content: str
     tool_calls: NotRequired[list[ToolCall]]
+    thinking: NotRequired[str]


 class TrainOnWhat(StrEnum):
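The new field is optional, so existing call sites keep working. A minimal sketch of a conversation that uses it (the message text is borrowed from the tests further down; everything else is illustrative):

# Hypothetical conversation using the new optional "thinking" field.
convo: list[Message] = [
    {"role": "user", "content": "Hello, how are you?"},
    {
        "role": "assistant",
        "content": "I'm fine, thank you!",
        "thinking": "The user is sharing a greeting. We should respond politely.",
    },
]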
@@ -172,6 +173,7 @@ class RoleColonRenderer(Renderer):
     """

     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "Thinking tokens not supported in RoleColonRenderer"
         ob_str = message["role"].capitalize() + ":"
         # Observation (prompt) part
         ac_str = " " + message["content"] + "\n\n"
@@ -253,6 +255,7 @@ class Llama3Renderer(Renderer):
     """

     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "CoT tokens not supported in Llama3"
         ob_str = f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
         # Observation (prompt) part
         ac_str = f"{message['content']}<|eot_id|>"
@@ -328,6 +331,7 @@ class Qwen3Renderer(Renderer):
     """

     def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "TODO: support CoT in Qwen3 renderer"
         maybe_newline = "\n" if idx > 0 else ""
         ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
         ac_content = message["content"]
@@ -441,6 +445,7 @@ class Qwen3InstructRenderer(Qwen3Renderer):
     """

     def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "CoT tokens not supported in Qwen3 instruct 2507"
         maybe_newline = "\n" if idx > 0 else ""
         ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
         ac_content = message["content"]
@@ -464,6 +469,7 @@ class DeepSeekV3Renderer(Renderer):
     """

     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "TODO: support CoT in DsV3 renderer"
         if message["role"] == "user":
             role_token = self._get_special_token("User")
         elif message["role"] == "assistant":
@@ -590,9 +596,26 @@ def _render_message(
         # Action part
         ac_str = ""
         if message["role"] == "assistant":
-            # TODO: support other channels/tools
-            ac_str += "<|channel|>final"
-        ac_str += f"<|message|>{message['content']}"
+            # TODO: support commentary channel / tools
+
+            # Assistant channels. See https://cookbook.openai.com/articles/openai-harmony
+            thinking = message.get("thinking")
+            content = message.get("content", "")
+
+            # Analysis channel (CoT)
+            if thinking:
+                if is_last:
+                    # Analysis channel only included in the last message. See https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot
+                    ac_str += f"<|channel|>analysis<|message|>{thinking}<|end|><|start|>assistant"
+
+            # Final channel (Response Content)
+            ac_str += f"<|channel|>final<|message|>{content}"
+        else:
+            assert message.get("thinking") is None, (
+                "Thinking is only allowed for assistant messages"
+            )
+            ac_str += f"<|message|>{message['content']}"
+
         if not is_last:
             ac_str += "<|end|>"
         else:
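To make the new branch concrete: for the last assistant message carrying thinking, the action string now interleaves the analysis and final channels. A minimal sketch of the resulting string, assuming is_last is True (the role header itself is emitted separately as part of the observation):

message = {
    "role": "assistant",
    "content": "I'm fine, thank you!",
    "thinking": "The user is sharing a greeting. We should respond politely.",
}
# Analysis (CoT) channel, then a fresh assistant header, then the final channel:
ac_str = (
    f"<|channel|>analysis<|message|>{message['thinking']}<|end|><|start|>assistant"
    f"<|channel|>final<|message|>{message['content']}"
)
# For earlier (non-last) assistant messages the analysis channel is dropped and
# only the final channel is rendered, followed by "<|end|>".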

tinker_cookbook/tests/test_renderers.py

Lines changed: 66 additions & 2 deletions
@@ -10,12 +10,13 @@
     "model_name",
     [
         "meta-llama/Llama-3.2-1B-Instruct",
-        "Qwen/Qwen3-30B-A3B",
+        # "Qwen/Qwen3-30B-A3B",  TODO: This was broken, will address in another PR.
         "deepseek-ai/DeepSeek-V3.1",
         "openai/gpt-oss-20b",
     ],
 )
-def test_against_hf_chat_templates(model_name: str):
+def test_generation_against_hf_chat_templates(model_name: str):
+    """Test generation prompt against HF chat templates"""
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
     # not using get_tokenizer(model_name)
     # because we want to test against the original tokenizer from HF, not the mirror
@@ -44,12 +45,73 @@ def test_against_hf_chat_templates(model_name: str):
     elif model_name.startswith("deepseek-ai"):
         aug_convo = convo
     elif model_name.startswith("openai"):
+        # Thinking field should not be rendered in this case as it is not the last message.
+        convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely."
         aug_convo = convo
     else:
         raise ValueError(f"Unknown model name: {model_name}")

     cookbook_tokens = cookbook_renderer.build_generation_prompt(aug_convo).to_ints()
     hf_tokens = tokenizer.apply_chat_template(convo, add_generation_prompt=True)
+
+    assert cookbook_tokens == hf_tokens, (
+        f"Cookbook tokens: {cookbook_tokens}\n"
+        f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n"
+        f"HF tokens: {hf_tokens}\n"
+        f"HF string: {tokenizer.decode(hf_tokens)}"
+    )
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "meta-llama/Llama-3.2-1B-Instruct",
+        "Qwen/Qwen3-30B-A3B",
+        "deepseek-ai/DeepSeek-V3.1",
+        "openai/gpt-oss-20b",
+    ],
+)
+def test_supervised_example_against_hf_chat_templates(model_name: str):
+    """Test supervised example against HF chat templates"""
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+    # not using get_tokenizer(model_name)
+    # because we want to test against the original tokenizer from HF, not the mirror
+    render_name = (
+        get_recommended_renderer_name(model_name)
+        if not model_name.startswith("openai")
+        else "gpt_oss_medium_reasoning"
+    )
+    cookbook_renderer = get_renderer(render_name, tokenizer)
+    convo: list[Message] = [
+        {"role": "user", "content": "Hello, how are you?"},
+        {"role": "assistant", "content": "I'm fine, thank you!"},
+    ]
+
+    if model_name.startswith("meta"):
+        today = date.today().strftime("%d %b %Y")
+        system_msg: Message = {
+            "role": "system",
+            "content": f"Cutting Knowledge Date: December 2023\nToday Date: {today}\n\n",
+        }
+        aug_convo = [system_msg] + convo
+    elif model_name.startswith("Qwen"):
+        # HF includes thinking tags in assistant content for supervised examples.
+        aug_convo = convo.copy()
+        aug_convo[1]["content"] = "<think>\n\n</think>\n\n I'm fine, thank you!"
+    elif model_name.startswith("deepseek-ai"):
+        aug_convo = convo
+    elif model_name.startswith("openai"):
+        # Test thinking field for GPT-OSS is rendered.
+        convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely."
+        aug_convo = convo
+    else:
+        raise ValueError(f"Unknown model name: {model_name}")
+
+    cookbook_tokens_tensor, _ = cookbook_renderer.build_supervised_example(aug_convo)
+    cookbook_tokens = cookbook_tokens_tensor.tolist()
+    hf_output = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
+    hf_tokens = tokenizer.encode(hf_output.rstrip("\n"), add_special_tokens=False)
+
     assert cookbook_tokens == hf_tokens, (
         f"Cookbook tokens: {cookbook_tokens}\n"
         f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n"
@@ -112,4 +174,6 @@ def test_eot_parsing(model_name: str, renderer_name: str):
 if __name__ == "__main__":
     # test_against_hf_chat_templates("meta-llama/Llama-3.2-1B-Instruct")
     # test_against_hf_chat_templates("Qwen/Qwen2.5-VL-3B-Instruct")
+    test_generation_against_hf_chat_templates("openai/gpt-oss-20b")
+    test_supervised_example_against_hf_chat_templates("openai/gpt-oss-20b")
     test_eot_parsing("Qwen/Qwen3-30B-A3B", "qwen3")
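Beyond the tests, a minimal usage sketch of the new path. The renderer name gpt_oss_medium_reasoning, get_renderer, and build_supervised_example come from the diff; the import path and the conversation itself are illustrative assumptions:

from transformers import AutoTokenizer

from tinker_cookbook.renderers import Message, get_renderer  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b", use_fast=True)
renderer = get_renderer("gpt_oss_medium_reasoning", tokenizer)

convo: list[Message] = [
    {"role": "user", "content": "Hello, how are you?"},
    {
        "role": "assistant",
        "content": "I'm fine, thank you!",
        "thinking": "The user is sharing a greeting. We should respond politely.",
    },
]

# Supervised example: the assistant is the last message, so its analysis
# channel (the "thinking" text) is rendered before the final channel.
tokens_tensor, _weights = renderer.build_supervised_example(convo)
print(tokenizer.decode(tokens_tensor.tolist()))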
