diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index e7a438031af..a092a50cad7 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -302,19 +302,16 @@ def decorate_answer(answer): if delta.isdigit(): # 处理总令牌数(如果需要) total_tokens = int(delta) - # logging.info(f"Total tokens used: {total_tokens}") continue # elif delta in [LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN]: # # 处理长度通知信息 # answer += delta - # # logging.info(f"Length notification: {delta}") # audio = tts(tts_mdl, delta) # yield {"answer": answer, "reference": {}, "audio_binary": audio} # continue elif "\n**ERROR**:" in delta: # 处理错误信息 answer += delta - # logging.error(f"Error in response: {delta}") yield {"answer": answer, "reference": {}, "audio_binary": b''} # 错误时不生成音频 continue @@ -344,7 +341,6 @@ def decorate_answer(answer): # 生成音频 audio = tts(tts_mdl, new_text) - # logging.info(f"Generated audio for new_text: {new_text}") yield {"answer": answer, "reference": {}, "audio_binary": audio} # 最终装饰答案 diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index cf038cb433c..664a7f8968a 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -57,9 +57,50 @@ def chat(self, system, history, gen_conf): except openai.APIError as e: return "**ERROR**: " + str(e), 0 + # def chat_streamly(self, system, history, gen_conf): + # if system: + # history.insert(0, {"role": "system", "content": system}) + # ans = "" + # total_tokens = 0 + # try: + # response = self.client.chat.completions.create( + # model=self.model_name, + # messages=history, + # stream=True, + # **gen_conf) + # for resp in response: + # if not resp.choices: + # continue + # if not resp.choices[0].delta.content: + # resp.choices[0].delta.content = "" + # ans += resp.choices[0].delta.content + # + # if not hasattr(resp, "usage") or not resp.usage: + # total_tokens = ( + # total_tokens + # + num_tokens_from_string(resp.choices[0].delta.content) + # ) + # elif isinstance(resp.usage, dict): + # total_tokens = resp.usage.get("total_tokens", total_tokens) + # else: + # total_tokens = resp.usage.total_tokens + # + # if resp.choices[0].finish_reason == "length": + # if is_chinese(ans): + # ans += LENGTH_NOTIFICATION_CN + # else: + # ans += LENGTH_NOTIFICATION_EN + # yield ans + # + # except openai.APIError as e: + # yield ans + "\n**ERROR**: " + str(e) + # + # yield total_tokens + def chat_streamly(self, system, history, gen_conf): if system: history.insert(0, {"role": "system", "content": system}) + ans = "" total_tokens = 0 try: @@ -71,30 +112,44 @@ def chat_streamly(self, system, history, gen_conf): for resp in response: if not resp.choices: continue - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - ans += resp.choices[0].delta.content - if not hasattr(resp, "usage") or not resp.usage: - total_tokens = ( + finish_reason = resp.choices[0].finish_reason + delta_content = resp.choices[0].delta.content if resp.choices[0].delta.content else "" + + # 如果有新增文本,累积并输出增量 + if delta_content: + ans += delta_content + + # 更新令牌计数 + if not hasattr(resp, "usage") or not resp.usage: + total_tokens = ( total_tokens + num_tokens_from_string(resp.choices[0].delta.content) ) - elif isinstance(resp.usage, dict): - total_tokens = resp.usage.get("total_tokens", total_tokens) - else: - total_tokens = resp.usage.total_tokens + elif isinstance(resp.usage, dict): + total_tokens = resp.usage.get("total_tokens", total_tokens) + else: + total_tokens = resp.usage.total_tokens - if resp.choices[0].finish_reason == "length": + yield delta_content + + # 即使delta_content为空,也要检查finish_reason + if finish_reason == "length": + # 长度受限时添加提示信息 if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN + notification = LENGTH_NOTIFICATION_CN else: - ans += LENGTH_NOTIFICATION_EN - yield ans + notification = LENGTH_NOTIFICATION_EN + yield notification + + # 如果finish_reason为"stop"或其他值,可以在此添加相应逻辑 + # (本示例中未对"stop"做额外处理,因为通常这意味着回答正常结束) except openai.APIError as e: + # 返回错误信息 yield ans + "\n**ERROR**: " + str(e) + # 最终返回总令牌数 yield total_tokens