Commit 0fd73b9 (1 parent: 5dced7c)
Showing 5 changed files with 80 additions and 6 deletions.
The main addition is a new file implementing a `Qwen_Client` class:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import logging
import colorama
from .base_model import BaseLLMModel
from ..presets import MODEL_METADATA


class Qwen_Client(BaseLLMModel):
    def __init__(self, model_name, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        # Load the tokenizer and weights from the Hugging Face repo recorded in
        # MODEL_METADATA; trust_remote_code is required for Qwen's custom code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_METADATA[model_name]["repo_id"],
            trust_remote_code=True,
            resume_download=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_METADATA[model_name]["repo_id"],
            device_map="auto",
            trust_remote_code=True,
            resume_download=True,
        ).eval()

    def generation_config(self):
        # Build a fresh GenerationConfig on every call so the user-adjustable
        # settings (top_p, temperature, token limit) always take effect.
        return GenerationConfig.from_dict({
            "chat_format": "chatml",
            "do_sample": True,
            "eos_token_id": 151643,
            "max_length": self.token_upper_limit,
            "max_new_tokens": 512,
            "max_window_size": 6144,
            "pad_token_id": 151643,
            "top_k": 0,
            "top_p": self.top_p,
            "transformers_version": "4.33.2",
            "trust_remote_code": True,
            "temperature": self.temperature,
        })

    def _get_glm_style_input(self):
        # Convert the flat message history into the [user, assistant] pair
        # format expected by Qwen's chat API; the last entry is the new query.
        # Example: ["hi", "hello!", "how are you?"] becomes
        # history=[["hi", "hello!"]] and query="how are you?".
        history = [x["content"] for x in self.history]
        query = history.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        assert (
            len(history) % 2 == 0
        ), f"History should be of even length. Current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def get_answer_at_once(self):
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        response, history = self.model.chat(self.tokenizer, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        # chat_stream yields progressively longer partial responses.
        for response in self.model.chat_stream(
            self.tokenizer,
            query,
            history,
        ):
            yield response
```
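
For orientation, here is a minimal usage sketch, not part of the diff: it assumes the code runs inside the repository package (so the relative imports resolve) and that MODEL_METADATA contains a Qwen entry under the key used below; the model key, user name, and messages are all illustrative.

```python
# Hypothetical usage sketch (not part of the commit). "Qwen 7B Chat" is an
# assumed MODEL_METADATA key; substitute whatever key the presets define.
client = Qwen_Client("Qwen 7B Chat", user_name="demo")

# BaseLLMModel normally manages self.history; it is filled in by hand here so
# the pairing done by _get_glm_style_input is visible: an odd-length list of
# messages whose last element is the pending query.
client.history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Summarize the Qwen model in one line."},
]

# Blocking call: returns the full reply plus a length count.
reply, length = client.get_answer_at_once()
print(reply)

# Streaming call: yields progressively longer partial replies.
for partial in client.get_answer_stream_iter():
    print(partial, end="\r")
```

Note that `get_answer_at_once` and `get_answer_stream_iter` take no query argument; they read the conversation from `self.history`, which is why the sketch populates the history by hand.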