Commit

Merge pull request #922 from OptimalScale/hymba-support-announcement
Hymba support announcement
research4pan authored Dec 9, 2024
2 parents 5723bcc + c31799f commit 3967232
Showing 5 changed files with 212 additions and 49 deletions.
1 change: 1 addition & 0 deletions docs/source/examples/DATASETS.md
@@ -153,6 +153,7 @@ Conversations should be formatted before feeding into the model. As of now, we'v
| `chatml` | `<\|im_start\|>system`<br>`You are a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`Who are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I am a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`How old are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|im_end\|>`<br> | [Link](./supported_conversation_template.md#chatml) |
| `deepseek` | `<\|begin▁of▁sentence\|>You are a chatbot developed by LMFlow team.`<br><br>`User: Who are you?`<br><br>`Assistant: I am a chatbot developed by LMFlow team.<\|end▁of▁sentence\|>User: How old are you?`<br><br>`Assistant: I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|end▁of▁sentence\|>` | [Link](./supported_conversation_template.md#deepseek) |
| `gemma` | `<bos>You are a chatbot developed by LMFlow team.<start_of_turn>user`<br>`Who are you?<end_of_turn>`<br>`<start_of_turn>model`<br>`I am a chatbot developed by LMFlow team.<end_of_turn>`<br>`<start_of_turn>user`<br>`How old are you?<end_of_turn>`<br>`<start_of_turn>model`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<end_of_turn>`<br> | [Link](./supported_conversation_template.md#gemma) |
| `hymba` | `<extra_id_0>System`<br>`You are a chatbot developed by LMFlow team.`<br>`<tool> {"name": "generate_qrcode", "description": "Generate a QR code for a given text", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "The text to encode in the QR code"}}, "required": ["text"]}} </tool>`<br><br>`<extra_id_1>User`<br>`Who are you?`<br>`<extra_id_1>Assistant`<br>`I am a chatbot developed by LMFlow team.`<br>`<extra_id_1>User`<br>`How old are you?`<br>`<extra_id_1>Assistant`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.</s>` | [Link](./supported_conversation_template.md#hymba) |
| `internlm2` | `<s><\|im_start\|>system`<br>`You are a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`Who are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I am a chatbot developed by LMFlow team.<\|im_end\|>`<br>`<\|im_start\|>user`<br>`How old are you?<\|im_end\|>`<br>`<\|im_start\|>assistant`<br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|im_end\|>`<br> | [Link](./supported_conversation_template.md#internlm2) |
| `llama3` | `<\|begin_of_text\|><\|start_header_id\|>system<\|end_header_id\|>`<br><br>`You are a chatbot developed by LMFlow team.<\|eot_id\|><\|start_header_id\|>user<\|end_header_id\|>`<br><br>`Who are you?<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>`<br><br>`I am a chatbot developed by LMFlow team.<\|eot_id\|><\|start_header_id\|>user<\|end_header_id\|>`<br><br>`How old are you?<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>`<br><br>`I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.<\|eot_id\|>` | [Link](./supported_conversation_template.md#llama-3) |
| `llama2` | `<s>[INST] <<SYS>>`<br>`You are a chatbot developed by LMFlow team.`<br>`<</SYS>>`<br><br>`Who are you? [/INST] I am a chatbot developed by LMFlow team.</s><s>[INST] How old are you? [/INST] I don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.</s>` | [Link](./supported_conversation_template.md#llama-2) |
45 changes: 45 additions & 0 deletions docs/source/examples/supported_conversation_template.md
@@ -5,6 +5,7 @@
- [ChatML](#chatml)
- [DeepSeek](#deepseek)
- [Gemma](#gemma)
- [Hymba](#hymba)
- [InternLM2](#internlm2)
- [Llama-2](#llama-2)
- [Llama-3](#llama-3)
@@ -154,6 +155,50 @@ As of now, Gemma does not support system messages officially. `ConversationTempl
```


## Hymba
**With a system message**
```
<extra_id_0>System\n{{system_message}}\n\n<extra_id_1>User\n{{user_message_0}}\n
```
```
<extra_id_0>System\n{{system_message}}\n<tool> {{tool_info}} </tool>\n\n<extra_id_1>User\n{{user_message_0}}\n
```

**Without a system message**
```{admonition} NOTICE
:class: warning
During training, Hymba always emits the system special tokens (`<extra_id_0>System`), even when no system message is provided.
```
```
<extra_id_0>System\n\n<extra_id_1>User\n{{user_message_0}}\n
```
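
The rule in the notice can be sketched in a few lines of Python. This is an illustrative sketch, not LMFlow's implementation: `format_system` and `SYSTEM_TEMPLATE` are hypothetical names, and the template string simply mirrors the format shown above.

```python
# Hypothetical constant mirroring the documented Hymba system segment.
SYSTEM_TEMPLATE = "<extra_id_0>System{content}\n\n"


def format_system(system: str = "", force_system: bool = True) -> str:
    """Return the system segment of the prompt.

    With force_system=True the special tokens are emitted even when the
    system message is empty, which is the behaviour the notice describes.
    """
    if system.strip():
        return SYSTEM_TEMPLATE.format(content="\n" + system.strip())
    return SYSTEM_TEMPLATE.format(content="") if force_system else ""
```

Calling `format_system()` with no arguments yields the bare `<extra_id_0>System` header followed by a blank line, matching the template above.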

**A complete conversation**
```
<extra_id_0>System\n{{system_message}}\n\n<extra_id_1>User\n{{user_message_0}}\n<extra_id_1>Assistant\n{{assistant_reply_0}}</s>
```

**Multiple rounds**
```
<extra_id_0>System\n{{system_message}}\n\n<extra_id_1>User\n{{user_message_0}}\n<extra_id_1>Assistant\n{{assistant_reply_0}}\n<extra_id_1>User\n{{user_message_1}}\n<extra_id_1>Assistant\n{{assistant_reply_1}}</s>
```

**jinja template**
[[Reference](https://huggingface.co/nvidia/Hymba-1.5B-Instruct/blob/c02a352b6f7c1138a197d7ae3fd72dcdff919eae/tokenizer_config.json#L40)]
```
{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}
```

**Filled Example**
```
<extra_id_0>System\nYou are a chatbot developed by LMFlow team.\n\n<extra_id_1>User\nWho are you?\n<extra_id_1>Assistant\nI am a chatbot developed by LMFlow team.\n<extra_id_1>User\nHow old are you?\n<extra_id_1>Assistant\nI don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.</s>
```
```
<extra_id_0>System\nYou are a chatbot developed by LMFlow team.\n<tool> {"name": "generate_qrcode", "description": "Generate a QR code for a given text", "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "The text to encode in the QR code"}}, "required": ["text"]}} </tool>\n\n<extra_id_1>User\nWho are you?\n<extra_id_1>Assistant\nI am a chatbot developed by LMFlow team.\n<extra_id_1>User\nHow old are you?\n<extra_id_1>Assistant\nI don't age like humans do. I exist as a piece of software, so I don't have a concept of age in the traditional sense.</s>
```
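
**Rendering sketch**
The layout documented in this section can be reproduced with plain string handling, which is handy for checking a dataset by eye before tokenization. The sketch below is an assumption-laden illustration, not LMFlow code: `render_hymba_prompt` is a hypothetical helper, tools are assumed to be JSON-serializable dicts, and `</s>` is assumed to be the end-of-sequence token, as in the filled examples.

```python
import json
from typing import Dict, List, Optional


def render_hymba_prompt(
    messages: List[Dict[str, str]],
    system: Optional[str] = None,
    tools: Optional[List[dict]] = None,
    eos_token: str = "</s>",
) -> str:
    """Render a finished conversation in the Hymba layout documented above."""
    # The system header is always present, even without a system message.
    prompt = "<extra_id_0>System"
    if system and system.strip():
        prompt += "\n" + system.strip()
    for tool in tools or []:
        prompt += "\n<tool> " + json.dumps(tool) + " </tool>"
    prompt += "\n\n"

    role_names = {"user": "User", "assistant": "Assistant", "tool": "Tool"}
    for message in messages:
        role = role_names[message["role"]]
        prompt += f"<extra_id_1>{role}\n{message['content'].strip()}\n"

    # In the documented examples the newline after the final turn is
    # replaced by the end-of-sequence token.
    if messages:
        prompt = prompt[:-1] + eos_token
    return prompt
```

Passing the two-round conversation from the filled example, with the same system message, should give back the first filled string above (with `\n` standing for real newlines).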


## InternLM2
**With a system message**
```
2 changes: 2 additions & 0 deletions src/lmflow/utils/conversation_template/__init__.py
@@ -6,6 +6,7 @@
from .chatml import CHATML_TEMPLATE
from .deepseek import DEEPSEEK_TEMPLATE
from .gemma import GEMMA_TEMPLATE
from .hymba import HYMBA_TEMPLATE
from .internlm import INTERNLM2_TEMPLATE
from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE, LLAMA3_TEMPLATE_FOR_TOOL
from .phi import PHI3_TEMPLATE
@@ -22,6 +23,7 @@
'empty': EMPTY_TEMPLATE,
'empty_no_special_tokens': EMPTY_NO_SPECIAL_TOKENS_TEMPLATE,
'gemma': GEMMA_TEMPLATE,
'hymba': HYMBA_TEMPLATE,
'internlm2': INTERNLM2_TEMPLATE,
'llama2': LLAMA2_TEMPLATE,
'llama3': LLAMA3_TEMPLATE,
90 changes: 41 additions & 49 deletions src/lmflow/utils/conversation_template/base.py
@@ -157,18 +157,20 @@ def format(self, **kwargs) -> list:
 class ConversationTemplate:
     user_formatter: Formatter
     assistant_formatter: Formatter
-    function_formatter: Optional[Formatter] = None,
-    observation_formatter: Optional[Formatter] = None,
+    function_formatter: Optional[Formatter] = None
+    observation_formatter: Optional[Formatter] = None
     system_formatter: Optional[Formatter] = None
+    force_system: bool = False
     tools_formatter: Optional[Formatter] = None
     separator: Optional[TemplateComponent] = None
+    remove_last_sep: bool = False
     special_starter: Optional[TemplateComponent] = None
     special_stopper: Optional[TemplateComponent] = None
     template_name: Optional[str] = None

     def __post_init__(self):
         if self.separator:
-            if self.separator.type not in ['string', 'token']:
+            if self.separator.type not in ['string', 'token', 'token_id']:
                 raise NotImplementedError(f"Component type {self.separator.type} cannot be used as a separator.")

         if self.special_starter:
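
Note on the first change in this hunk: it drops the trailing commas from two field defaults. In a Python class body, `= None,` assigns the one-element tuple `(None,)`, so the old defaults were truthy tuples rather than `None`. A minimal, self-contained sketch of the effect, assuming the class is a dataclass (as `__post_init__` suggests); `Before` and `After` are hypothetical stand-ins, with `str` in place of `Formatter`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Before:
    # Trailing comma: the default is the tuple (None,), not None.
    function_formatter: Optional[str] = None,


@dataclass
class After:
    # No trailing comma: the default really is None.
    function_formatter: Optional[str] = None


# A check such as `if template.function_formatter:` passes for Before
# even though no formatter was ever supplied.
before_default = Before().function_formatter
after_default = After().function_formatter
```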
@@ -181,7 +183,6 @@ def encode_conversation(
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[List[str]] = None,
-        remove_last_sep: bool = False,
         **kwargs
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r'''
@@ -219,27 +220,7 @@ def encode_conversation(
                 system = None

         encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)
-
-        if self.separator and remove_last_sep:
-            # For models that require a separator between messages,
-            # user can include the seperator at the end of each template
-            # and specify the separator. Auto formatting will remove the
-            # last separator once user specifies this option.
-            encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)
-
-        if self.special_starter:
-            # For models that has ONLY ONE bos token at the beginning of
-            # a conversation session (not a conversation pair), user can
-            # specify a special starter to add that starter to the very
-            # beginning of the conversation session.
-            # eg:
-            #   llama-2: <s> and </s> at every pair of conversation
-            #   v.s.
-            #   llama-3: <|begin_of_text|> only at the beginning of a session
-            encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)
-
-        if self.special_stopper:
-            encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
+        encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

         return encoded_pairs

@@ -256,7 +237,10 @@ def _encode(

         res_all = []

-        system_formatted = self.system_formatter.format(content=system) if system else []
+        if system:
+            system_formatted = self.system_formatter.format(content=system)
+        else:
+            system_formatted = self.system_formatter.format(content='') if self.force_system else []
         system_encoded = self._encode_template(system_formatted, tokenizer)

         for i in range(0, len(messages), 2):
@@ -317,6 +301,30 @@ def _encode_template(
                 raise NotImplementedError(f"Component type {component.type} is not supported yet.")
         return encoded_ids

+    def post_process_pairs(self, encoded_pairs, tokenizer):
+        if self.separator and self.remove_last_sep:
+            # For models that require a separator between messages,
+            # user can include the separator at the end of each template
+            # and specify the separator. Auto formatting will remove the
+            # last separator once user specifies this option.
+            encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)
+
+        if self.special_starter:
+            # For models that have ONLY ONE bos token at the beginning of
+            # a conversation session (not a conversation pair), user can
+            # specify a special starter to add that starter to the very
+            # beginning of the conversation session.
+            # eg:
+            #   llama-2: <s> and </s> at every pair of conversation
+            #   v.s.
+            #   llama-3: <|begin_of_text|> only at the beginning of a session
+            encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)
+
+        if self.special_stopper:
+            encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
+
+        return encoded_pairs
+
     def remove_last_separator(
         self,
         encoded_pairs: Sequence[Tuple[List[int], List[int]]],
@@ -327,6 +335,8 @@ def remove_last_separator(
             separator_ids = tokenizer.encode(self.separator.content, add_special_tokens=False)
         elif self.separator.type == 'token':
             separator_ids = self._ensure_id_list(tokenizer.convert_tokens_to_ids(self.separator.content))
+        elif self.separator.type == 'token_id':
+            separator_ids = self._ensure_id_list(self.separator.content)
         else:
             raise ValueError(f"Component type {self.separator.type} cannot be used as a separator.")

@@ -404,7 +414,6 @@ def encode_conversation(
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[List[str]] = None,
-        remove_last_sep: bool = False,
         **kwargs
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r'''
@@ -446,27 +455,7 @@ def encode_conversation(
             else:
                 system = ""
         encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)
-
-        if self.separator and remove_last_sep:
-            # For models that require a separator between messages,
-            # user can include the seperator at the end of each template
-            # and specify the separator. Auto formatting will remove the
-            # last separator once user specifies this option.
-            encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)
-
-        if self.special_starter:
-            # For models that has ONLY ONE bos token at the beginning of
-            # a conversation session (not a conversation pair), user can
-            # specify a special starter to add that starter to the very
-            # beginning of the conversation session.
-            # eg:
-            #   llama-2: <s> and </s> at every pair of conversation
-            #   v.s.
-            #   llama-3: <|begin_of_text|> only at the beginning of a session
-            encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)
-
-        if self.special_stopper:
-            encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
+        encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

         return encoded_pairs

@@ -484,7 +473,10 @@ def _encode(
         res_all = []
         # Concatenate the system and tools strings
         system = system + tools
-        system_formatted = self.system_formatter.format(content=system) if system else []
+        if system:
+            system_formatted = self.system_formatter.format(content=system)
+        else:
+            system_formatted = self.system_formatter.format(content='') if self.force_system else []
         system_encoded = self._encode_template(system_formatted, tokenizer)
         ls_for_save = []
         for i in range(0, len(messages), 1):
123 changes: 123 additions & 0 deletions src/lmflow/utils/conversation_template/hymba.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
from .base import StringFormatter, TemplateComponent, ConversationTemplateForTool
from typing import Dict, Set, Sequence, Literal, Union, List, Optional, Tuple

from transformers import PreTrainedTokenizer

# NOTE: 'contexts' are not used in sft
# {{'<extra_id_0>System'}}
# {% for message in messages %}
# {% if message['role'] == 'system' %}
# {{'\n' + message['content'].strip()}}
# {% if tools %}
# {{'\n'}}
# {% endif %}
# {% endif %}
# {% endfor %}
# {% if tools %}
# {% for tool in tools %}
# {{ '\n<tool> ' + tool|tojson + ' </tool>' }}
# {% endfor %}
# {% endif %}
# {{'\n\n'}}
# {% for message in messages %}
# {% if message['role'] == 'user' %}
# {{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}
# {% elif message['role'] == 'assistant' %}
# {{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}
# {% elif message['role'] == 'tool' %}
# {{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}
# {% endif %}
# {% endfor %}
# {%- if add_generation_prompt %}
# {{'<extra_id_1>Assistant\n'}}
# {%- endif %}


class HymbaConversationTemplate(ConversationTemplateForTool):
    def encode_conversation(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[List[str]] = None,
        **kwargs
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r'''
        Messages here should be guaranteed to be in pairs, with the first message being the user message and the second message being the assistant message.
        Data example:
        ```json
        {
            "conversation_id": 2,
            "system": "sysinfo1",
            "tools": ["tool_1_desc"],
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
                {
                    "role": "assistant",
                    "content": "Hello!"
                }
            ]
        }
        ```
        '''
        assert isinstance(messages, list), "Messages must be a list."

        # Each tool description is wrapped in <tool> ... </tool> on its own line.
        tools_out = ''
        if tools is not None:
            for tool in tools:
                tools_out += "\n<tool> " + tool + " </tool>"

        if system is None:
            system = ""
        else:
            if system.replace(" ",""): # has actual content
                if not self.system_formatter:
                    raise ValueError("Your dataset contains system message but no system formatter is provided. "
                                     "Consider either providing a system formatter or removing system prompt from your dataset.")
                system = '\n' + system
            else:
                system = ""
        encoded_pairs = self._encode(tokenizer, messages, system, tools_out, **kwargs)
        encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

        return encoded_pairs


HYMBA_TEMPLATE = HymbaConversationTemplate(
    template_name='hymba',
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<extra_id_1>User\n{{content}}\n')
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<extra_id_1>Assistant\n{{content}}\n')
        ]
    ),
    function_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<extra_id_1>Assistant\n{{content}}\n')
        ]
    ),
    observation_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<extra_id_1>Tool\n{{content}}\n')
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<extra_id_0>System{{content}}\n\n')
        ]
    ),
    separator=TemplateComponent(type='token_id', content=13),
    remove_last_sep=True,
    special_stopper=TemplateComponent(type='token', content='eos_token'),
    force_system=True
)
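
The combination of `separator=TemplateComponent(type='token_id', content=13)`, `remove_last_sep=True` and `special_stopper` above suggests that the separator ending the final assistant turn is dropped before the end-of-sequence token is appended. The sketch below illustrates that idea on plain lists of token IDs. It is not a copy of LMFlow's `remove_last_separator`, whose body is not shown in this diff; `strip_last_separator` and the `Pair` alias are hypothetical names.

```python
from typing import List, Sequence, Tuple

# One (user_ids, assistant_ids) pair per conversation round.
Pair = Tuple[List[int], List[int]]


def strip_last_separator(encoded_pairs: Sequence[Pair], separator_ids: List[int]) -> List[Pair]:
    """Drop `separator_ids` from the end of the final assistant turn, if present."""
    # Copy the input so the caller's lists are left untouched.
    pairs = [(list(user), list(assistant)) for user, assistant in encoded_pairs]
    if not pairs or not separator_ids:
        return pairs
    user_ids, assistant_ids = pairs[-1]
    if assistant_ids[-len(separator_ids):] == separator_ids:
        assistant_ids = assistant_ids[:-len(separator_ids)]
    pairs[-1] = (user_ids, assistant_ids)
    return pairs
```

For example, with made-up IDs, `strip_last_separator([([1, 13], [3, 13]), ([4, 13], [5, 6, 13])], [13])` leaves the first round untouched and trims only the last assistant turn, after which a stopper ID could be appended.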
