feat: add Tongyi Qianwen (Qwen) support
GaiZhenbiao committed Oct 19, 2023
1 parent 5dced7c commit 0fd73b9
Showing 5 changed files with 80 additions and 6 deletions.
57 changes: 57 additions & 0 deletions modules/models/Qwen.py
@@ -0,0 +1,57 @@
import logging

import colorama
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from .base_model import BaseLLMModel
from ..presets import MODEL_METADATA


class Qwen_Client(BaseLLMModel):
    def __init__(self, model_name, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_METADATA[model_name]["repo_id"],
            trust_remote_code=True,
            resume_download=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_METADATA[model_name]["repo_id"],
            device_map="auto",
            trust_remote_code=True,
            resume_download=True,
        ).eval()

    def generation_config(self):
        # 151643 is Qwen's <|endoftext|> token, used here as both EOS and padding.
        return GenerationConfig.from_dict({
            "chat_format": "chatml",
            "do_sample": True,
            "eos_token_id": 151643,
            "max_length": self.token_upper_limit,
            "max_new_tokens": 512,
            "max_window_size": 6144,
            "pad_token_id": 151643,
            "top_k": 0,
            "top_p": self.top_p,
            "transformers_version": "4.33.2",
            "trust_remote_code": True,
            "temperature": self.temperature,
        })

    def _get_glm_style_input(self):
        # Convert the flat message history into ([question, answer] pairs, latest
        # query) -- the input format expected by Qwen's chat()/chat_stream().
        history = [x["content"] for x in self.history]
        query = history.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        assert (
            len(history) % 2 == 0
        ), f"History should have an even number of turns; current history: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def get_answer_at_once(self):
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        response, history = self.model.chat(self.tokenizer, query, history=history)
        # The response's character length stands in for a token count here.
        return response, len(response)

    def get_answer_stream_iter(self):
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        # chat_stream yields the accumulated response text as it grows.
        for response in self.model.chat_stream(
            self.tokenizer,
            query,
            history,
        ):
            yield response
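
For reference, a standalone sketch of the pairing logic in _get_glm_style_input, using a made-up three-turn history (any alternating user/assistant history that ends on a user turn behaves the same way):

# Flat history: user, assistant, user (the newest, unanswered turn).
history = ["What is Qwen?", "Qwen is a series of LLMs from Alibaba Cloud.", "Is it open?"]
query = history.pop()                # the newest user turn becomes the query
assert len(history) % 2 == 0         # the remaining turns must pair up
pairs = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
print(pairs)   # [['What is Qwen?', 'Qwen is a series of LLMs from Alibaba Cloud.']]
print(query)   # Is it open?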
12 changes: 7 additions & 5 deletions modules/models/base_model.py
@@ -146,6 +146,7 @@ class ModelType(Enum):
     Spark = 12
     OpenAIInstruct = 13
     Claude = 14
+    Qwen = 15
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -181,7 +182,9 @@ def get_type(cls, model_name: str):
elif "星火大模型" in model_name_lower:
model_type = ModelType.Spark
elif "claude" in model_name_lower:
model_type = ModelType.Claude
model_type = ModelType.Claude
elif "qwen" in model_name_lower:
model_type = ModelType.Qwen
else:
model_type = ModelType.LLaMA
return model_type
@@ -656,14 +659,13 @@ def delete_first_conversation(self):
     def delete_last_conversation(self, chatbot):
         if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
             msg = "由于包含报错信息,只删除chatbot记录"
-            chatbot.pop()
+            chatbot = chatbot[:-1]
             return chatbot, self.history
         if len(self.history) > 0:
-            self.history.pop()
-            self.history.pop()
+            self.history = self.history[:-2]
         if len(chatbot) > 0:
             msg = "删除了一组chatbot对话"
-            chatbot.pop()
+            chatbot = chatbot[:-1]
         if len(self.all_token_counts) > 0:
             msg = "删除了一组对话的token计数记录"
             self.all_token_counts.pop()
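The delete_last_conversation change above replaces in-place pop() calls with slicing. A plausible motivation (an inference, not stated in the commit): pop() mutates the very list the caller still holds, e.g. Gradio's chatbot state, whereas slicing builds a fresh list and leaves the original alone:

shared = [["hi", "hello"], ["bye", "goodbye"]]
view = shared
view.pop()                       # in place: shared is shortened as well
assert shared == [["hi", "hello"]]
view = shared[:-1]               # slice copy: shared keeps its last element
assert view == [] and shared == [["hi", "hello"]]
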
5 changes: 4 additions & 1 deletion modules/models/models.py
@@ -116,9 +116,12 @@ def get_model(
             from .spark import Spark_Client
             model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
                 "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
-        elif model_type == ModelType.Claude:
+        elif model_type == ModelType.Claude:
             from .Claude import Claude_Client
             model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
+        elif model_type == ModelType.Qwen:
+            from .Qwen import Qwen_Client
+            model = Qwen_Client(model_name, user_name=user_name)
         elif model_type == ModelType.Unknown:
             raise ValueError(f"未知模型: {model_name}")
         logging.info(msg)
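Note that get_model imports each client module inside its own branch, so the heavy Qwen dependencies (transformers, auto-gptq) are only loaded when a Qwen model is actually selected. A stripped-down sketch of this lazy-import pattern, with hypothetical names (build_client is not part of the repository):

from enum import Enum

class ModelType(Enum):
    Qwen = 15

def build_client(model_type, model_name, user_name=""):
    if model_type is ModelType.Qwen:
        # Deferred import: transformers is pulled in only when this branch runs.
        from modules.models.Qwen import Qwen_Client
        return Qwen_Client(model_name, user_name=user_name)
    raise ValueError(f"Unknown model: {model_name}")
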
8 changes: 8 additions & 0 deletions modules/presets.py
@@ -87,6 +87,8 @@
"StableLM",
"MOSS",
"Llama-2-7B-Chat",
"Qwen 7B",
"Qwen 14B"
]

# Additional metadate for local models
@@ -98,6 +100,12 @@
"Llama-2-7B-Chat":{
"repo_id": "TheBloke/Llama-2-7b-Chat-GGUF",
"filelist": ["llama-2-7b-chat.Q6_K.gguf"],
},
"Qwen 7B": {
"repo_id": "Qwen/Qwen-7B-Chat-Int4",
},
"Qwen 14B": {
"repo_id": "Qwen/Qwen-14B-Chat-Int4",
}
}

4 changes: 4 additions & 0 deletions requirements_advanced.txt
@@ -6,3 +6,7 @@ sentence_transformers
 accelerate
 sentencepiece
 llama-cpp-python
+transformers_stream_generator
+einops
+optimum
+auto-gptq
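
The four new packages track Qwen's published requirements: transformers_stream_generator backs the chat_stream call used in Qwen.py, einops is needed by Qwen's remote modeling code, and optimum plus auto-gptq load the Int4 GPTQ-quantized checkpoints listed in presets.py. A minimal smoke test, assuming these packages are installed and a CUDA GPU is available (the GPTQ kernels require one):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Qwen/Qwen-7B-Chat-Int4"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto", trust_remote_code=True).eval()
response, history = model.chat(tokenizer, "Hello!", history=None)
print(response)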
