Skip to content

Commit

Permalink
Fix: Resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Dec 9, 2024
1 parent 6903acc commit 3a64029
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 17 deletions.
4 changes: 0 additions & 4 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,16 @@ def decorate_answer(answer):
if delta.isdigit():
# 处理总令牌数(如果需要)
total_tokens = int(delta)
# logging.info(f"Total tokens used: {total_tokens}")
continue
# elif delta in [LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN]:
# # 处理长度通知信息
# answer += delta
# # logging.info(f"Length notification: {delta}")
# audio = tts(tts_mdl, delta)
# yield {"answer": answer, "reference": {}, "audio_binary": audio}
# continue
elif "\n**ERROR**:" in delta:
# 处理错误信息
answer += delta
# logging.error(f"Error in response: {delta}")
yield {"answer": answer, "reference": {}, "audio_binary": b''} # 错误时不生成音频
continue

Expand Down Expand Up @@ -344,7 +341,6 @@ def decorate_answer(answer):

# 生成音频
audio = tts(tts_mdl, new_text)
# logging.info(f"Generated audio for new_text: {new_text}")
yield {"answer": answer, "reference": {}, "audio_binary": audio}

# 最终装饰答案
Expand Down
81 changes: 68 additions & 13 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,50 @@ def chat(self, system, history, gen_conf):
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

# def chat_streamly(self, system, history, gen_conf):
# if system:
# history.insert(0, {"role": "system", "content": system})
# ans = ""
# total_tokens = 0
# try:
# response = self.client.chat.completions.create(
# model=self.model_name,
# messages=history,
# stream=True,
# **gen_conf)
# for resp in response:
# if not resp.choices:
# continue
# if not resp.choices[0].delta.content:
# resp.choices[0].delta.content = ""
# ans += resp.choices[0].delta.content
#
# if not hasattr(resp, "usage") or not resp.usage:
# total_tokens = (
# total_tokens
# + num_tokens_from_string(resp.choices[0].delta.content)
# )
# elif isinstance(resp.usage, dict):
# total_tokens = resp.usage.get("total_tokens", total_tokens)
# else:
# total_tokens = resp.usage.total_tokens
#
# if resp.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# yield ans
#
# except openai.APIError as e:
# yield ans + "\n**ERROR**: " + str(e)
#
# yield total_tokens

def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})

ans = ""
total_tokens = 0
try:
Expand All @@ -71,30 +112,44 @@ def chat_streamly(self, system, history, gen_conf):
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content

if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
finish_reason = resp.choices[0].finish_reason
delta_content = resp.choices[0].delta.content if resp.choices[0].delta.content else ""

# 如果有新增文本,累积并输出增量
if delta_content:
ans += delta_content

# 更新令牌计数
if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else:
total_tokens = resp.usage.total_tokens
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else:
total_tokens = resp.usage.total_tokens

if resp.choices[0].finish_reason == "length":
yield delta_content

# 即使delta_content为空,也要检查finish_reason
if finish_reason == "length":
# 长度受限时添加提示信息
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
notification = LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
notification = LENGTH_NOTIFICATION_EN
yield notification

# 如果finish_reason为"stop"或其他值,可以在此添加相应逻辑
# (本示例中未对"stop"做额外处理,因为通常这意味着回答正常结束)

except openai.APIError as e:
# 返回错误信息
yield ans + "\n**ERROR**: " + str(e)

# 最终返回总令牌数
yield total_tokens


Expand Down

0 comments on commit 3a64029

Please sign in to comment.