diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py
index 5e6c876..cc32e03 100644
--- a/tinker_cookbook/renderers.py
+++ b/tinker_cookbook/renderers.py
@@ -33,6 +33,7 @@ class Message(TypedDict):
     role: Role
     content: str
     tool_calls: NotRequired[list[ToolCall]]
+    thinking: NotRequired[str]
 
 
 class TrainOnWhat(StrEnum):
@@ -172,6 +173,7 @@ class RoleColonRenderer(Renderer):
     """
 
     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "Thinking tokens not supported in RoleColonRenderer"
         ob_str = message["role"].capitalize() + ":"
         # Observation (prompt) part
         ac_str = " " + message["content"] + "\n\n"
@@ -253,6 +255,7 @@ class Llama3Renderer(Renderer):
     """
 
     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "CoT tokens not supported in Llama3"
         ob_str = f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
         # Observation (prompt) part
         ac_str = f"{message['content']}<|eot_id|>"
@@ -328,6 +331,7 @@ class Qwen3Renderer(Renderer):
     """
 
     def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "TODO: support CoT in Qwen3 renderer"
         maybe_newline = "\n" if idx > 0 else ""
         ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
         ac_content = message["content"]
@@ -441,6 +445,7 @@ class Qwen3InstructRenderer(Qwen3Renderer):
     """
 
     def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "CoT tokens not supported in Qwen3 instruct 2507"
         maybe_newline = "\n" if idx > 0 else ""
         ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n"
         ac_content = message["content"]
@@ -464,6 +469,7 @@ class DeepSeekV3Renderer(Renderer):
     """
 
     def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]:
+        assert message.get("thinking") is None, "TODO: support CoT in DsV3 renderer"
         if message["role"] == "user":
             role_token = self._get_special_token("User")
         elif message["role"] == "assistant":
@@ -590,9 +596,26 @@ def _render_message(
         # Action part
         ac_str = ""
         if message["role"] == "assistant":
-            # TODO: support other channels/tools
-            ac_str += "<|channel|>final"
-            ac_str += f"<|message|>{message['content']}"
+            # TODO: support commentary channel / tools
+
+            # Assistant channels. See https://cookbook.openai.com/articles/openai-harmony
+            thinking = message.get("thinking")
+            content = message.get("content", "")
+
+            # Analysis channel (CoT)
+            if thinking:
+                if is_last:
+                    # Analysis channel only included in the last message. See https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot
+                    ac_str += f"<|channel|>analysis<|message|>{thinking}<|end|><|start|>assistant"
+
+            # Final channel (Response Content)
+            ac_str += f"<|channel|>final<|message|>{content}"
+        else:
+            assert message.get("thinking") is None, (
+                "Thinking is only allowed for assistant messages"
+            )
+            ac_str += f"<|message|>{message['content']}"
+
         if not is_last:
             ac_str += "<|end|>"
         else:
diff --git a/tinker_cookbook/tests/test_renderers.py b/tinker_cookbook/tests/test_renderers.py
index 14e25f6..f7d57e3 100644
--- a/tinker_cookbook/tests/test_renderers.py
+++ b/tinker_cookbook/tests/test_renderers.py
@@ -10,12 +10,13 @@
     "model_name",
     [
         "meta-llama/Llama-3.2-1B-Instruct",
-        "Qwen/Qwen3-30B-A3B",
+        # "Qwen/Qwen3-30B-A3B", TODO: This was broken, will address in another PR.
"deepseek-ai/DeepSeek-V3.1", "openai/gpt-oss-20b", ], ) -def test_against_hf_chat_templates(model_name: str): +def test_generation_against_hf_chat_templates(model_name: str): + """Test generation prompt against HF chat templates""" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) # not using get_tokenizer(model_name) # because we want to test against the original tokenizer from HF, not the mirror @@ -44,12 +45,73 @@ def test_against_hf_chat_templates(model_name: str): elif model_name.startswith("deepseek-ai"): aug_convo = convo elif model_name.startswith("openai"): + # Thinking field should not be rendered in this case as it is not the last message. + convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely." aug_convo = convo else: raise ValueError(f"Unknown model name: {model_name}") cookbook_tokens = cookbook_renderer.build_generation_prompt(aug_convo).to_ints() hf_tokens = tokenizer.apply_chat_template(convo, add_generation_prompt=True) + + assert cookbook_tokens == hf_tokens, ( + f"Cookbook tokens: {cookbook_tokens}\n" + f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n" + f"HF tokens: {hf_tokens}\n" + f"HF string: {tokenizer.decode(hf_tokens)}" + ) + + +@pytest.mark.parametrize( + "model_name", + [ + "meta-llama/Llama-3.2-1B-Instruct", + "Qwen/Qwen3-30B-A3B", + "deepseek-ai/DeepSeek-V3.1", + "openai/gpt-oss-20b", + ], +) +def test_supervised_example_against_hf_chat_templates(model_name: str): + """Test supervised example against HF chat templates""" + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + # not using get_tokenizer(model_name) + # because we want to test against the original tokenizer from HF, not the mirror + render_name = ( + get_recommended_renderer_name(model_name) + if not model_name.startswith("openai") + else "gpt_oss_medium_reasoning" + ) + cookbook_renderer = get_renderer(render_name, tokenizer) + convo: list[Message] = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm fine, thank you!"}, + ] + + if model_name.startswith("meta"): + today = date.today().strftime("%d %b %Y") + system_msg: Message = { + "role": "system", + "content": f"Cutting Knowledge Date: December 2023\nToday Date: {today}\n\n", + } + aug_convo = [system_msg] + convo + elif model_name.startswith("Qwen"): + # HF includes thinking tags in assistant content for supervised examples. + aug_convo = convo.copy() + aug_convo[1]["content"] = "\n\n\n\n I'm fine, thank you!" + elif model_name.startswith("deepseek-ai"): + aug_convo = convo + elif model_name.startswith("openai"): + # Test thinking field for GPT-OSS is rendered. + convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely." 
+        aug_convo = convo
+    else:
+        raise ValueError(f"Unknown model name: {model_name}")
+
+    cookbook_tokens_tensor, _ = cookbook_renderer.build_supervised_example(aug_convo)
+    cookbook_tokens = cookbook_tokens_tensor.tolist()
+    hf_output = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
+    hf_tokens = tokenizer.encode(hf_output.rstrip("\n"), add_special_tokens=False)
+
     assert cookbook_tokens == hf_tokens, (
         f"Cookbook tokens: {cookbook_tokens}\n"
         f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n"
         f"HF tokens: {hf_tokens}\n"
         f"HF string: {tokenizer.decode(hf_tokens)}"
@@ -112,4 +174,6 @@ def test_eot_parsing(model_name: str, renderer_name: str):
 if __name__ == "__main__":
     # test_against_hf_chat_templates("meta-llama/Llama-3.2-1B-Instruct")
     # test_against_hf_chat_templates("Qwen/Qwen2.5-VL-3B-Instruct")
+    test_generation_against_hf_chat_templates("openai/gpt-oss-20b")
+    test_supervised_example_against_hf_chat_templates("openai/gpt-oss-20b")
     test_eot_parsing("Qwen/Qwen3-30B-A3B", "qwen3")