diff --git a/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py b/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py
index fd9ecec8b1..9272f780c7 100644
--- a/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py
+++ b/tests/torchtune/models/qwen2/test_qwen2_tokenizer.py
@@ -81,7 +81,7 @@ def test_tokenize_messages_gt_max_seq_len(self, messages):
         assert len(tokens) == 10
         assert len(mask) == 10
 
-    def test_tokenize_message_drop_eos(self, messages):
+    def test_tokenize_message_drop_eot_and_eos(self, messages):
        tokenizer = self.tokenizer()
 
         # fmt: off
@@ -99,6 +99,6 @@ def test_tokenize_message_drop_eos(self, messages):
         # fmt: on
         expected_mask = [True] * 67 + [False] * 120
 
-        tokens, mask = tokenizer.tokenize_messages(messages, add_eos=False)
+        tokens, mask = tokenizer.tokenize_messages(messages, add_end_tokens=False)
         assert tokens == expected_tokens
         assert mask == expected_mask
diff --git a/torchtune/models/qwen2/_tokenizer.py b/torchtune/models/qwen2/_tokenizer.py
index ef29b20d59..3ae03a49fb 100644
--- a/torchtune/models/qwen2/_tokenizer.py
+++ b/torchtune/models/qwen2/_tokenizer.py
@@ -325,11 +325,79 @@ def decode(
         text = "".join(sub_texts)
         return text
 
+    def _tokenize_header(self, message: Message) -> List[int]:
+        """
+        Tokenize the message header (im_start token, role, and trailing
+        newline) as a list of token ids.
+        """
+        return [self.im_start_id] + self.encode(
+            f"{message.role}\n", add_bos=False, add_eos=False
+        )
+
+    def _tokenize_body(self, message: Message) -> List[int]:
+        """
+        Tokenize the message content as a list of token ids.
+        """
+        tokenized_body = []
+        for item in message.content:
+            if item["type"] == "text":
+                tokenized_body += self.encode(
+                    item["content"], add_bos=False, add_eos=False
+                )
+            else:
+                raise RuntimeError(
+                    f"Unsupported message content type: {item['type']}"
+                )
+        return tokenized_body
+
+    def _tokenize_end(self, message: Message) -> List[int]:
+        """
+        Tokenize the message footer (im_end token and trailing newline)
+        as a list of token ids.
+        """
+        return [self.im_end_id] + self.encode("\n", add_bos=False, add_eos=False)
+
+    def tokenize_message(
+        self,
+        message: Message,
+        index: int,
+        num_messages: int,
+        *,
+        add_start_tokens: bool = True,
+    ) -> List[int]:
+        """
+        Tokenize a single message into a list of token ids.
+
+        The message footer (im_end token and trailing newline) is omitted for
+        ipython messages and for the final assistant message, leaving EOS
+        handling to ``tokenize_messages``.
+
+        Args:
+            message (Message): The message to tokenize.
+            index (int): The position of the message in the conversation.
+            num_messages (int): The total number of messages in the conversation.
+            add_start_tokens (bool): Whether to prepend the tokenized header
+                (im_start token and role) to the message. Default is True.
+
+        Returns:
+            List[int]: The list of token ids.
+ """ + tokenized_header = self._tokenize_header(message) if message.role != "ipython" and add_start_tokens else [] + tokenized_body = self._tokenize_body(message) + tokenized_end = self._tokenize_end(message) if message.role != "ipython" and (message.role != "assistant" or index != num_messages - 1) else [] + + tokenized_message = tokenized_header + tokenized_body + tokenized_end + return tokenized_message + def tokenize_messages( self, messages: List[Message], *, - add_eos: bool = True, + add_end_tokens: bool = True, ) -> Tuple[List[int], List[bool]]: """ Given a list of messages, return a list of tokens for the concatenated @@ -356,68 +415,37 @@ def tokenize_messages( else messages ) - tokenized_messages = [] + tokens = [] mask = [] - for index, message in enumerate(templated_messages): - tokens = [] - - # message header - if message.role != "ipython": - tokens.append(self.im_start_id) - tokens.extend( - self.encode(f"{message.role}\n", add_bos=False, add_eos=False) - ) - - # message content - for item in message.content: - if item["type"] == "text": - tokens.extend( - self.encode( - item["content"], - add_bos=False, - add_eos=False, - ) - ) - else: - raise RuntimeError( - f"Unsupported message content type: {item['type']}" - ) - - # message footer - if message.role != "ipython" and ( - message.role != "assistant" or index != len(messages) - 1 - ): - tokens.append(self.im_end_id) - tokens.extend(self.encode("\n", add_bos=False, add_eos=False)) - tokenized_messages.extend(tokens) - mask.extend([message.masked] * len(tokens)) + num_messages = len(templated_messages) + for i, message in enumerate(templated_messages): + tokenized_message = self.tokenize_message(message, index=i, num_messages=num_messages) + tokens = tokens + tokenized_message + mask = mask + ([message.masked] * len(tokenized_message)) - # Break out early if we reach max_seq_len - if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + if self.max_seq_len and len(tokens) >= self.max_seq_len: break - # Add the End-Of-Sequence token - if add_eos: - tokenized_messages.append(self.eos_id) - mask.append(mask[-1]) + if add_end_tokens: + tokens = tokens + [self.eos_id] + mask = mask + [mask[-1] if mask else True] - # Finally, truncate if necessary if self.max_seq_len: - tokenized_messages = truncate( - tokens=tokenized_messages, + tokens = truncate( + tokens=tokens, max_seq_len=self.max_seq_len, - eos_id=self.eos_id if add_eos else None, + eos_id=self.eos_id if add_end_tokens else None, truncation_type=self.truncation_type, ) mask = truncate( tokens=mask, max_seq_len=self.max_seq_len, - eos_id=True if add_eos else None, + eos_id=True if add_end_tokens else None, truncation_type=self.truncation_type, ) - return tokenized_messages, mask + return tokens, mask def __call__( self, sample: Mapping[str, Any], inference: bool = False @@ -436,7 +464,7 @@ def __call__( inference (bool): Whether the template is being used for inference or not. """ messages = sample.pop("messages") - tokens, mask = self.tokenize_messages(messages) + tokens, mask = self.tokenize_messages(messages, add_end_tokens=not inference) sample["tokens"] = tokens sample["mask"] = mask - return sample + return sample \ No newline at end of file