Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions tinker_cookbook/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Message(TypedDict):
role: Role
content: str
tool_calls: NotRequired[list[ToolCall]]
thinking: NotRequired[str]


class TrainOnWhat(StrEnum):
Expand Down Expand Up @@ -172,6 +173,7 @@ class RoleColonRenderer(Renderer):
"""

def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
assert message.get("thinking") is None, "Thinking tokens not supported in RoleColonRenderer"
ob_str = message["role"].capitalize() + ":"
# Observation (prompt) part
ac_str = " " + message["content"] + "\n\n"
Expand Down Expand Up @@ -253,6 +255,7 @@ class Llama3Renderer(Renderer):
"""

def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
assert message.get("thinking") is None, "CoT tokens not supported in Llama3"
ob_str = f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
# Observation (prompt) part
ac_str = f"{message['content']}<|eot_id|>"
Expand Down Expand Up @@ -328,6 +331,7 @@ class Qwen3Renderer(Renderer):
"""

def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
assert message.get("thinking") is None, "TODO: support CoT in Qwen3 renderer"
maybe_newline = "\n" if idx > 0 else ""
ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
ac_content = message["content"]
Expand Down Expand Up @@ -441,6 +445,7 @@ class Qwen3InstructRenderer(Qwen3Renderer):
"""

def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
assert message.get("thinking") is None, "CoT tokens not supported in Qwen3 instruct 2507"
maybe_newline = "\n" if idx > 0 else ""
ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
ac_content = message["content"]
Expand All @@ -464,6 +469,7 @@ class DeepSeekV3Renderer(Renderer):
"""

def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
assert message.get("thinking") is None, "TODO: support CoT in DsV3 renderer"
if message["role"] == "user":
role_token = self._get_special_token("User")
elif message["role"] == "assistant":
Expand Down Expand Up @@ -590,9 +596,26 @@ def _render_message(
# Action part
ac_str = ""
if message["role"] == "assistant":
# TODO: support other channels/tools
ac_str += "<|channel|>final"
ac_str += f"<|message|>{message['content']}"
# TODO: support commentary channel / tools

# Assistant channels. See https://cookbook.openai.com/articles/openai-harmony
thinking = message.get("thinking")
content = message.get("content", "")

# Analysis channel (CoT)
if thinking:
if is_last:
# Analysis channel only included in the last message. See https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot
ac_str += f"<|channel|>analysis<|message|>{thinking}<|end|><|start|>assistant"

# Final channel (Response Content)
ac_str += f"<|channel|>final<|message|>{content}"
else:
assert message.get("thinking") is None, (
"Thinking is only allowed for assistant messages"
)
ac_str += f"<|message|>{message['content']}"

if not is_last:
ac_str += "<|end|>"
else:
Expand Down
68 changes: 66 additions & 2 deletions tinker_cookbook/tests/test_renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
"model_name",
[
"meta-llama/Llama-3.2-1B-Instruct",
"Qwen/Qwen3-30B-A3B",
# "Qwen/Qwen3-30B-A3B", TODO: This was broken, will address in another PR.
"deepseek-ai/DeepSeek-V3.1",
"openai/gpt-oss-20b",
],
)
def test_against_hf_chat_templates(model_name: str):
def test_generation_against_hf_chat_templates(model_name: str):
"""Test generation prompt against HF chat templates"""
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# not using get_tokenizer(model_name)
# because we want to test against the original tokenizer from HF, not the mirror
Expand Down Expand Up @@ -44,12 +45,73 @@ def test_against_hf_chat_templates(model_name: str):
elif model_name.startswith("deepseek-ai"):
aug_convo = convo
elif model_name.startswith("openai"):
# Thinking field should not be rendered in this case as it is not the last message.
convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely."
aug_convo = convo
else:
raise ValueError(f"Unknown model name: {model_name}")

cookbook_tokens = cookbook_renderer.build_generation_prompt(aug_convo).to_ints()
hf_tokens = tokenizer.apply_chat_template(convo, add_generation_prompt=True)

assert cookbook_tokens == hf_tokens, (
f"Cookbook tokens: {cookbook_tokens}\n"
f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n"
f"HF tokens: {hf_tokens}\n"
f"HF string: {tokenizer.decode(hf_tokens)}"
)


@pytest.mark.parametrize(
"model_name",
[
"meta-llama/Llama-3.2-1B-Instruct",
"Qwen/Qwen3-30B-A3B",
"deepseek-ai/DeepSeek-V3.1",
"openai/gpt-oss-20b",
],
)
def test_supervised_example_against_hf_chat_templates(model_name: str):
"""Test supervised example against HF chat templates"""
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# not using get_tokenizer(model_name)
# because we want to test against the original tokenizer from HF, not the mirror
render_name = (
get_recommended_renderer_name(model_name)
if not model_name.startswith("openai")
else "gpt_oss_medium_reasoning"
)
cookbook_renderer = get_renderer(render_name, tokenizer)
convo: list[Message] = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm fine, thank you!"},
]

if model_name.startswith("meta"):
today = date.today().strftime("%d %b %Y")
system_msg: Message = {
"role": "system",
"content": f"Cutting Knowledge Date: December 2023\nToday Date: {today}\n\n",
}
aug_convo = [system_msg] + convo
elif model_name.startswith("Qwen"):
# HF includes thinking tags in assistant content for supervised examples.
aug_convo = convo.copy()
aug_convo[1]["content"] = "<think>\n\n</think>\n\n I'm fine, thank you!"
elif model_name.startswith("deepseek-ai"):
aug_convo = convo
elif model_name.startswith("openai"):
# Test thinking field for GPT-OSS is rendered.
convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely."
aug_convo = convo
else:
raise ValueError(f"Unknown model name: {model_name}")

cookbook_tokens_tensor, _ = cookbook_renderer.build_supervised_example(aug_convo)
cookbook_tokens = cookbook_tokens_tensor.tolist()
hf_output = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
hf_tokens = tokenizer.encode(hf_output.rstrip("\n"), add_special_tokens=False)

assert cookbook_tokens == hf_tokens, (
f"Cookbook tokens: {cookbook_tokens}\n"
f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n"
Expand Down Expand Up @@ -112,4 +174,6 @@ def test_eot_parsing(model_name: str, renderer_name: str):
if __name__ == "__main__":
# test_against_hf_chat_templates("meta-llama/Llama-3.2-1B-Instruct")
# test_against_hf_chat_templates("Qwen/Qwen2.5-VL-3B-Instruct")
test_generation_against_hf_chat_templates("openai/gpt-oss-20b")
test_supervised_example_against_hf_chat_templates("openai/gpt-oss-20b")
test_eot_parsing("Qwen/Qwen3-30B-A3B", "qwen3")