10 | 10 | "model_name", |
11 | 11 | [ |
12 | 12 | "meta-llama/Llama-3.2-1B-Instruct", |
13 | | - "Qwen/Qwen3-30B-A3B", |
| 13 | + # "Qwen/Qwen3-30B-A3B", TODO: This was broken, will address in another PR. |
14 | 14 | "deepseek-ai/DeepSeek-V3.1", |
15 | 15 | "openai/gpt-oss-20b", |
16 | 16 | ], |
17 | 17 | ) |
18 | | -def test_against_hf_chat_templates(model_name: str): |
| 18 | +def test_generation_against_hf_chat_templates(model_name: str): |
| 19 | + """Test generation prompt against HF chat templates""" |
19 | 20 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) |
20 | 21 | # not using get_tokenizer(model_name) |
21 | 22 | # because we want to test against the original tokenizer from HF, not the mirror |
@@ -44,12 +45,73 @@ def test_against_hf_chat_templates(model_name: str): |
44 | 45 | elif model_name.startswith("deepseek-ai"): |
45 | 46 | aug_convo = convo |
46 | 47 | elif model_name.startswith("openai"): |
| 48 | + # Thinking field should not be rendered in this case as it is not the last message. |
| 49 | + convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely." |
47 | 50 | aug_convo = convo |
48 | 51 | else: |
49 | 52 | raise ValueError(f"Unknown model name: {model_name}") |
50 | 53 |
51 | 54 | cookbook_tokens = cookbook_renderer.build_generation_prompt(aug_convo).to_ints() |
52 | 55 | hf_tokens = tokenizer.apply_chat_template(convo, add_generation_prompt=True) |
| 56 | + |
| 57 | + assert cookbook_tokens == hf_tokens, ( |
| 58 | + f"Cookbook tokens: {cookbook_tokens}\n" |
| 59 | + f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n" |
| 60 | + f"HF tokens: {hf_tokens}\n" |
| 61 | + f"HF string: {tokenizer.decode(hf_tokens)}" |
| 62 | + ) |
| 63 | + |
| 64 | + |
| 65 | +@pytest.mark.parametrize( |
| 66 | + "model_name", |
| 67 | + [ |
| 68 | + "meta-llama/Llama-3.2-1B-Instruct", |
| 69 | + "Qwen/Qwen3-30B-A3B", |
| 70 | + "deepseek-ai/DeepSeek-V3.1", |
| 71 | + "openai/gpt-oss-20b", |
| 72 | + ], |
| 73 | +) |
| 74 | +def test_supervised_example_against_hf_chat_templates(model_name: str): |
| 75 | + """Test supervised example against HF chat templates""" |
| 76 | + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) |
| 77 | + # not using get_tokenizer(model_name) |
| 78 | + # because we want to test against the original tokenizer from HF, not the mirror |
| 79 | + render_name = ( |
| 80 | + get_recommended_renderer_name(model_name) |
| 81 | + if not model_name.startswith("openai") |
| 82 | + else "gpt_oss_medium_reasoning" |
| 83 | + ) |
| 84 | + cookbook_renderer = get_renderer(render_name, tokenizer) |
| 85 | + convo: list[Message] = [ |
| 86 | + {"role": "user", "content": "Hello, how are you?"}, |
| 87 | + {"role": "assistant", "content": "I'm fine, thank you!"}, |
| 88 | + ] |
| 89 | + |
| 90 | + if model_name.startswith("meta"): |
| 91 | + today = date.today().strftime("%d %b %Y") |
| 92 | + system_msg: Message = { |
| 93 | + "role": "system", |
| 94 | + "content": f"Cutting Knowledge Date: December 2023\nToday Date: {today}\n\n", |
| 95 | + } |
| 96 | + aug_convo = [system_msg] + convo |
| 97 | + elif model_name.startswith("Qwen"): |
| 98 | + # HF includes thinking tags in assistant content for supervised examples. |
| 99 | + # list.copy() is shallow; rebuild the message dict so convo itself is not mutated.
| 100 | + aug_convo = [convo[0], {**convo[1], "content": "<think>\n\n</think>\n\n I'm fine, thank you!"}]
| 101 | + elif model_name.startswith("deepseek-ai"): |
| 102 | + aug_convo = convo |
| 103 | + elif model_name.startswith("openai"): |
| 104 | + # Test that the thinking field is rendered for GPT-OSS (last assistant message).
| 105 | + convo[1]["thinking"] = "The user is sharing a greeting. We should respond politely." |
| 106 | + aug_convo = convo |
| 107 | + else: |
| 108 | + raise ValueError(f"Unknown model name: {model_name}") |
| 109 | + |
| 110 | + cookbook_tokens_tensor, _ = cookbook_renderer.build_supervised_example(aug_convo) |
| 111 | + cookbook_tokens = cookbook_tokens_tensor.tolist() |
| 112 | + hf_output = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) |
| 113 | + hf_tokens = tokenizer.encode(hf_output.rstrip("\n"), add_special_tokens=False) |
| 114 | + |
53 | 115 | assert cookbook_tokens == hf_tokens, ( |
54 | 116 | f"Cookbook tokens: {cookbook_tokens}\n" |
55 | 117 | f"Cookbook string: {tokenizer.decode(cookbook_tokens)}\n" |
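
Note: the two gpt-oss branches above rely on the HF template rendering the thinking field only when the assistant message is last. A minimal standalone check of that behavior, using the same model and messages as the tests (assumes the HF hub is reachable; illustrative only, not part of the PR):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openai/gpt-oss-20b", use_fast=True)
reasoning = "The user is sharing a greeting. We should respond politely."
convo = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm fine, thank you!", "thinking": reasoning},
]
# Assistant message is last: the template should render the thinking field.
last = tok.apply_chat_template(convo, tokenize=False)
# A later user turn makes it non-final: the thinking field should be dropped.
not_last = tok.apply_chat_template(
    convo + [{"role": "user", "content": "Glad to hear it!"}], tokenize=False
)
print(reasoning in last)      # expected: True
print(reasoning in not_last)  # expected: False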
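
Similarly, the rstrip("\n") normalization assumes the HF-rendered string ends with a trailing newline that build_supervised_example does not emit. A quick way to inspect that per model (Qwen shown here since its tokenizer downloads without auth; illustrative only):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B", use_fast=True)
convo = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm fine, thank you!"},
]
text = tok.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
print(repr(text[-15:]))  # shows whether the rendering ends in "\n"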
@@ -112,4 +174,6 @@ def test_eot_parsing(model_name: str, renderer_name: str): |
112 | 174 | if __name__ == "__main__": |
113 | 175 | # test_generation_against_hf_chat_templates("meta-llama/Llama-3.2-1B-Instruct")
114 | 176 | # test_generation_against_hf_chat_templates("Qwen/Qwen2.5-VL-3B-Instruct")
| 177 | + test_generation_against_hf_chat_templates("openai/gpt-oss-20b") |
| 178 | + test_supervised_example_against_hf_chat_templates("openai/gpt-oss-20b") |
115 | 179 | test_eot_parsing("Qwen/Qwen3-30B-A3B", "qwen3") |
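
The parametrized matrix can also be run through pytest rather than this __main__ hook (test file path assumed): pytest tests/test_renderers.py -k hf_chat_templates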