tests/torchtune/models/qwen2/test_qwen2_tokenizer.py (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ def test_tokenize_messages_gt_max_seq_len(self, messages):
        assert len(tokens) == 10
        assert len(mask) == 10

    def test_tokenize_message_drop_eos(self, messages):
    def test_tokenize_message_drop_eot_and_eos(self, messages):
        tokenizer = self.tokenizer()

        # fmt: off
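The renamed test's body is truncated here. As a rough sketch of the behavior the new name describes (hypothetical assertions, not the file's actual ones), dropping end tokens should leave off both the end-of-turn and end-of-sequence ids:

```python
# Hypothetical check: with add_end_tokens=False, neither <|im_end|> (EOT)
# nor EOS should be appended after the final assistant message.
tokens, mask = tokenizer.tokenize_messages(messages, add_end_tokens=False)
assert tokens[-1] not in (tokenizer.im_end_id, tokenizer.eos_id)
```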
torchtune/models/qwen2/_tokenizer.py (129 changes: 80 additions & 49 deletions)
@@ -325,6 +325,69 @@ def decode(
text = "".join(sub_texts)
return text

    def _tokenize_header(self, message: Message) -> List[int]:
        """
        Tokenize header start, message role, and header end as list of ids
        """
        return [self.im_start_id] + self.encode(
            f"{message.role}\n", add_bos=False, add_eos=False
        )

    def _tokenize_body(self, message: Message) -> List[int]:
        """
        Tokenize message content as list of ids
        """
        tokenized_body = []
        for item in message.content:
            if item["type"] == "text":
                tokenized_body += self.encode(
                    item["content"],
                    add_bos=False,
                    add_eos=False,
                )
            else:
                raise RuntimeError(f"Unsupported message content type: {item['type']}")
        return tokenized_body

    def _tokenize_end(self, message: Message) -> List[int]:
        return [self.im_end_id] + self.encode("\n", add_bos=False, add_eos=False)

    def tokenize_message(
        self,
        message: Message,
        index: int,
        num_messages: int,
        *,
        add_start_tokens: bool = True,
    ) -> List[int]:
        """
        Tokenize a message into a list of token ids.

        Args:
            message (Message): The message to tokenize.
            index (int): The index of the current message.
            num_messages (int): The total number of messages.
            add_start_tokens (bool): Whether to prepend a tokenized header to the message. Default is True.

        Returns:
            List[int]: The list of token ids.
        """
        tokenized_header = (
            self._tokenize_header(message)
            if message.role != "ipython" and add_start_tokens
            else []
        )
        tokenized_body = self._tokenize_body(message)
        tokenized_end = (
            self._tokenize_end(message)
            if message.role != "ipython"
            and (message.role != "assistant" or index != num_messages - 1)
            else []
        )

        tokenized_message = tokenized_header + tokenized_body + tokenized_end
        return tokenized_message
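Taken together, these helpers produce standard ChatML framing. A minimal usage sketch, assuming `tokenizer` is an already-constructed `Qwen2Tokenizer`:

```python
from torchtune.data import Message

# A user turn followed by the final assistant turn.
user = Message(role="user", content="Hello")
assistant = Message(role="assistant", content="Hi there")

# Decoded, the pieces compose as:
#   user      -> "<|im_start|>user\nHello<|im_end|>\n"
#   assistant -> "<|im_start|>assistant\nHi there"
# The last assistant message gets no end tokens (so generation can continue
# an open turn), and "ipython" messages get neither header nor end tokens.
ids = tokenizer.tokenize_message(user, index=0, num_messages=2)
ids += tokenizer.tokenize_message(assistant, index=1, num_messages=2)
```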

    def tokenize_messages(
        self,
        messages: list[Message],
@@ -337,14 +400,11 @@ def tokenize_messages(

        Args:
            messages (list[Message]): The message list to tokenize.
            add_end_tokens (bool): Wether to add the tokenizer's end of message
                tokens, such as eos_id. Default is True.
            add_end_tokens (bool): Whether to append end token ids (end-of-seq,
                end-of-turn, end-of-message) at the end of the last assistant
                message. This value should be set to False for generation.
                Default is True.

        Returns:
            tuple[list[int], list[bool]]: The list of token ids and the list of masks.

        Raises:
            RuntimeError: If a message contains non-text content
        """
        assert not isinstance(self.prompt_template, ChatMLTemplate), (
            "Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message."
@@ -356,56 +416,27 @@ def tokenize_messages(
            else messages
        )

        tokenized_messages = []
        tokens = []
        mask = []
        for index, message in enumerate(templated_messages):
            tokens = []

            # message header
            if message.role != "ipython":
                tokens.append(self.im_start_id)
                tokens.extend(
                    self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
                )

            # message content
            for item in message.content:
                if item["type"] == "text":
                    tokens.extend(
                        self.encode(
                            item["content"],
                            add_bos=False,
                            add_eos=False,
                        )
                    )
                else:
                    raise RuntimeError(
                        f"Unsupported message content type: {item['type']}"
                    )

            # message footer
            if message.role != "ipython" and (
                message.role != "assistant" or index != len(messages) - 1
            ):
                tokens.append(self.im_end_id)
                tokens.extend(self.encode("\n", add_bos=False, add_eos=False))

            tokenized_messages.extend(tokens)
            mask.extend([message.masked] * len(tokens))

            # Break out early if we reach max_seq_len
            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
        num_messages = len(templated_messages)
        for i, message in enumerate(templated_messages):
            tokenized_message = self.tokenize_message(
                message, index=i, num_messages=num_messages
            )
            tokens = tokens + tokenized_message
            mask = mask + ([message.masked] * len(tokenized_message))

            if self.max_seq_len and len(tokens) >= self.max_seq_len:
                break

        # Add the End-Of-Sequence token
        if add_end_tokens:
            tokenized_messages.append(self.eos_id)
            mask.append(mask[-1])
            tokens = tokens + [self.eos_id]
            mask = mask + [mask[-1] if mask else True]
Contributor commented:

> Noob q: why is this different than what we're doing for Llama3? ref

Author (@ariG23498) commented on May 13, 2025:

> If we added True to the mask whenever there was an end token, this part of
> the test fails:
>
> ```python
> expected_mask = [True] * 67 + [False] * 121
> ```
>
> due to the fact that it searches for False.
>
> Thanks for catching this, I wanted to ask about the test itself, but forgot.
> Do you think I should change this test?

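To make the tradeoff concrete, here is a minimal sketch (values taken from the test quoted above) of why the EOS entry extends the existing mask run instead of being hard-coded, with the `if mask` guard covering an empty input:

```python
# The appended EOS inherits the mask of the final message, so the expected
# pattern from the test keeps a single trailing False run.
mask = [True] * 67 + [False] * 120  # masks for the message tokens
mask = mask + [mask[-1] if mask else True]  # EOS entry; True only when empty
assert mask == [True] * 67 + [False] * 121
```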
        # Finally, truncate if necessary
        if self.max_seq_len:
            tokenized_messages = truncate(
                tokens=tokenized_messages,
            tokens = truncate(
                tokens=tokens,
                max_seq_len=self.max_seq_len,
                eos_id=self.eos_id if add_end_tokens else None,
                truncation_type=self.truncation_type,
@@ -417,7 +448,7 @@ def tokenize_messages(
                truncation_type=self.truncation_type,
            )

        return tokenized_messages, mask
        return tokens, mask

    def __call__(
        self, sample: Mapping[str, Any], inference: bool = False
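Finally, a usage sketch of the refactored entry point for a generation-style prompt. The paths are placeholders, and `qwen2_tokenizer` is torchtune's builder for this tokenizer:

```python
from torchtune.data import Message
from torchtune.models.qwen2 import qwen2_tokenizer

# Placeholder paths; point these at a real Qwen2 vocab/merges pair.
tokenizer = qwen2_tokenizer(
    path="/checkpoints/qwen2/vocab.json",
    merges_file="/checkpoints/qwen2/merges.txt",
)

messages = [
    Message(role="user", content="What does ChatML look like?"),
    Message(role="assistant", content=""),
]

# add_end_tokens=False leaves the final assistant turn open (no <|im_end|>,
# no EOS), which is what generation expects.
tokens, mask = tokenizer.tokenize_messages(messages, add_end_tokens=False)
```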