Commit f0d862e

Merge pull request #56 from imoneoi/3.5_mistral
3.5 mistral
2 parents: 0323d57 + a76cada

File tree

90 files changed (+16211 additions, -1211 deletions)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -6,12 +6,15 @@ wandb/

 # Old
 old/
+temp/
+profiler/

 # Logs
 logs/

 # eval
 eval_results/
+evalplus_codegen/

 # All datasets
 dataset/

README.md

Lines changed: 199 additions & 113 deletions
Large diffs are not rendered by default.

assets/openchat.png

1.92 MB

ochat/config/__init__.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
from functools import partial

import torch
import transformers

from ochat.config.model_config import ModelConfig
from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
import ochat.models


_V3_2_PREFIXES = {
    # OpenAI mapping

    "user": "User:",
    "assistant": "Assistant:"
}


def _v3_2_role_prefix(from_role, condition):
    return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()


MODEL_CONFIG_MAP = {
    # OpenChat V3.2
    "openchat_v3.2": ModelConfig(
        # Model
        model_max_context=4096,
        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
                                       use_fast=False,
                                       legacy=False),
        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
                                          low_cpu_mem_usage=True,
                                          torch_dtype=torch.bfloat16),

        # Conversation Template
        conversation_template=partial(ConversationTemplate,
                                      role_prefix=_v3_2_role_prefix,
                                      eot="<|end_of_turn|>",
                                      inference_condition="GPT4")
    ),

    "openchat_v3.2_mistral": ModelConfig(
        serving_aliases=("openchat_3.5", ),

        # Model
        model_max_context=8192,
        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
                                       use_fast=False,
                                       legacy=True),  # Mistral uses legacy=True, see https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/tokenizer_config.json
        model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
                                          low_cpu_mem_usage=True,
                                          torch_dtype=torch.bfloat16),

        # Conversation Template
        conversation_template=partial(ConversationTemplate,
                                      role_prefix=_v3_2_role_prefix,
                                      eot="<|end_of_turn|>",
                                      inference_condition="GPT4 Correct")
    ),
}
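For orientation, a minimal sketch of how this registry could be consumed. The attribute access on ModelConfig matches the fields shown above; the model path "openchat/openchat_3.5" is an assumed example, not code from this commit.

# Hypothetical usage sketch (not part of this commit): look up a config by
# name and build its tokenizer and conversation template.
from ochat.config import MODEL_CONFIG_MAP

config = MODEL_CONFIG_MAP["openchat_v3.2_mistral"]

# model_tokenizer_create is a partial over AutoTokenizer.from_pretrained and
# still needs a model path; the path here is an assumed example.
tokenizer = config.model_tokenizer_create("openchat/openchat_3.5")

# conversation_template is a partial over ConversationTemplate, so only the
# tokenizer remains to be supplied at construction time.
template = config.conversation_template(tokenizer=tokenizer)

# The role prefix joins the condition string with the role mapping, e.g.
# _v3_2_role_prefix("user", "GPT4 Correct") == "GPT4 Correct User:"

Note the design choice visible in the two entries: the Llama-based V3.2 and the Mistral-based 3.5 share one template implementation and differ only in context length, model class, tokenizer flags, and the inference condition string ("GPT4" vs. "GPT4 Correct").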

ochat/config/conversation_template.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from typing import Optional, Callable, Iterable, List, Dict

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str

    weight: Optional[float] = None


class Conversation(BaseModel):
    items: List[Message]

    condition: str = ""
    system: str = ""


class ConversationTemplate(BaseModel):
    tokenizer: Callable

    # Prompt
    role_prefix: Callable
    eot: str

    inference_condition: Optional[str] = None

    # Private
    bos_tokens_: List[int]
    eot_tokens_: List[int]

    def __init__(self, **data):
        tokenizer = data["tokenizer"]
        eot = data["eot"]
        bos_tokens_ = tokenizer("").input_ids
        eot_tokens_ = tokenizer(eot, add_special_tokens=False).input_ids

        super().__init__(**data, bos_tokens_=bos_tokens_, eot_tokens_=eot_tokens_)

    def _safe_tokenize(self, strings: Iterable[str]) -> List[List[int]]:
        return self.tokenizer(strings, split_special_tokens=True, return_attention_mask=False, add_special_tokens=False).input_ids

    def tokenize_conversations(self, conversations: Iterable[Conversation], inference: bool = False, seq_level_weight: bool = False):
        # Pre-tokenize all conversations
        default_condition = self.inference_condition if inference else ""

        sys_mappings = set()
        role_mappings = set()
        all_text = []
        for conv in conversations:
            sys_mappings.add(conv.system)
            for msg in conv.items:
                role_mappings.add((msg.role, conv.condition or default_condition))
                all_text.append(msg.content)

        sys_mappings = list(sys_mappings)
        role_mappings = list(role_mappings)

        # Tokenize
        sys_mappings = dict(zip(sys_mappings, self._safe_tokenize(sys_mappings)))
        role_mappings = dict(zip(role_mappings, self._safe_tokenize([self.role_prefix(*args) for args in role_mappings])))
        all_text = self._safe_tokenize(all_text)

        # Convert
        result_tokens = []
        result_weights = []
        all_text_idx = 0
        for conv in conversations:
            tokens = []
            weights = []

            # bos tokens
            tokens.extend(self.bos_tokens_)
            weights.extend([0.] * len(self.bos_tokens_))

            # System
            if conv.system:
                system = sys_mappings[conv.system]
                tokens.extend(system)
                weights.extend([0.] * len(system))

                tokens.extend(self.eot_tokens_)
                weights.extend([0.] * len(self.eot_tokens_))

            # Messages
            last_idx = len(conv.items) - 1
            for idx, msg in enumerate(conv.items):
                # Prefix
                role = role_mappings[(msg.role, conv.condition or default_condition)]
                tokens.extend(role)
                weights.extend([0.] * len(role))

                # Message
                text = all_text[all_text_idx]
                all_text_idx += 1

                # weight
                w = None
                if not inference:
                    assert msg.weight is not None

                    w = msg.weight
                    if seq_level_weight:
                        w /= len(text) + len(self.eot_tokens_)

                # Message tokens
                tokens.extend(text)
                weights.extend([w] * len(text))

                if not (inference and idx == last_idx):  # Do not add EOT on last turn during inference
                    tokens.extend(self.eot_tokens_)
                    weights.extend([w] * len(self.eot_tokens_))

            # Append result
            result_tokens.append(tokens)
            result_weights.append(weights)

        # Sanity check
        assert all_text_idx == len(all_text)

        return result_tokens, result_weights
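To make the flow concrete, a toy walkthrough under stated assumptions: the role_prefix lambda mirrors _v3_2_role_prefix from ochat/config/__init__.py above, and the base Mistral tokenizer is used for illustration (there "<|end_of_turn|>" tokenizes to several subword ids, which the template handles since eot_tokens_ is a list; the released OpenChat tokenizer registers it as a single special token).

# Hypothetical usage sketch (not part of this commit).
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", use_fast=False, legacy=True)

template = ConversationTemplate(
    tokenizer=tokenizer,
    role_prefix=lambda role, cond: f"{cond} {'User:' if role == 'user' else 'Assistant:'}".strip(),
    eot="<|end_of_turn|>",
    inference_condition="GPT4 Correct",
)

conv = Conversation(
    condition="GPT4 Correct",  # training-time condition; "" would leave prefixes bare
    items=[
        Message(role="user", content="Hello", weight=0.),        # masked out of the loss
        Message(role="assistant", content="Hi there!", weight=1.),
    ],
)

tokens, weights = template.tokenize_conversations([conv], inference=False)
# tokens[0]  = BOS + "GPT4 Correct User:" + "Hello" + EOT
#              + "GPT4 Correct Assistant:" + "Hi there!" + EOT
# weights[0] = 0. for BOS and every role prefix; msg.weight for each message's
#              content tokens and its EOT.

Two details worth noting. First, the method deduplicates system and role-prefix strings across the whole batch and tokenizes each unique string once, so per-conversation work is just list concatenation. Second, with seq_level_weight=True each message's weight is divided by len(content tokens) + len(EOT tokens), so long and short weighted turns contribute equally to the loss.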
