diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py
index 42b4d0a14e9..f08dd322764 100644
--- a/api/apps/tenant_app.py
+++ b/api/apps/tenant_app.py
@@ -119,4 +119,4 @@ def agree(tenant_id):
         UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
         return get_json_result(data=True)
     except Exception as e:
-        return server_error_response(e)
+        return server_error_response(e)
\ No newline at end of file
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index a86e4aad87c..a2cef9c7969 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -314,20 +314,72 @@ def decorate_answer(answer):
            prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": answer, "reference": refs, "prompt": prompt}
 
+    # Original streaming code, kept commented out for reference
+    # if stream:
+    #     last_ans = ""
+    #     answer = ""
+    #     for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+    #         answer = ans
+    #         logging.info("answer_stream : {}".format(ans))
+    #         delta_ans = ans[len(last_ans):]
+    #         if num_tokens_from_string(delta_ans) < 16:
+    #             continue
+    #         last_ans = answer
+    #         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+    #     delta_ans = answer[len(last_ans):]
+    #     if delta_ans:
+    #         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+    #     yield decorate_answer(answer)
+
    if stream:
-        last_ans = ""
+        # logging.info("stream_mode : {}".format(msg[1:]))
        answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans):]
-        if delta_ans:
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        yield decorate_answer(answer)
+        for delta in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+            # Check whether this item is the total token count or a notification
+            if isinstance(delta, str):
+                if delta.isdigit():
+                    # Handle the total token count here (if needed)
+                    # total_tokens = int(delta)
+                    continue
+                elif "\n**ERROR**:" in delta:
+                    # Handle error messages
+                    answer += delta
+                    yield {"answer": answer, "reference": {}, "audio_binary": b''}  # no audio on errors
+                    continue
+
+                # Handle the incremental text
+                delta_ans = delta
+                # if num_tokens_from_string(delta_ans) < 16:
+                #     continue  # adjust this threshold as needed
+
+                # Extend the full answer
+                answer += delta_ans
+
+                # Generate audio
+                audio = tts(tts_mdl, delta_ans)
+                # logging.info(f"Generated audio for delta: {delta_ans}")
+                yield {"answer": delta_ans, "reference": {}, "audio_binary": audio}
+            elif isinstance(delta, dict):
+                # In case chat_streamly still returns a dict (not recommended),
+                # e.g.: {"new_text": "new content", "position": 10}
+                new_text = delta.get("new_text", "")
+                if not new_text:
+                    continue
+                if num_tokens_from_string(new_text) < 16:
+                    continue
+
+                # Extend the full answer
+                answer += new_text
+
+                # Generate audio
+                audio = tts(tts_mdl, new_text)
+                yield {"answer": answer, "reference": {}, "audio_binary": audio}
+
+        # Finally, decorate the answer
+        decorated_answer = decorate_answer(answer)
+        # logging.info(f"Final decorated answer: {decorated_answer}")
+        yield decorated_answer
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        logging.debug("User: {}|Assistant: {}".format(
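Note: with this change, chat_mdl.chat_streamly is expected to yield text increments (deltas) rather than the accumulated answer, so the consumer above concatenates them itself. The trailing token count is skipped by the isdigit() check; an int-valued count (as emitted by the OpenAI-style model below) matches neither isinstance branch and is likewise dropped. A minimal sketch of that contract, with fake_stream as a hypothetical stand-in for the real model:

    def fake_stream():
        # Stands in for chat_mdl.chat_streamly: yields increments, then a count.
        for piece in ["Hello", ", ", "world", "!"]:
            yield piece
        yield "9"  # trailing total-token count as a digit string

    answer = ""
    for delta in fake_stream():
        if isinstance(delta, str) and delta.isdigit():
            continue  # token count, not answer text
        answer += delta
    assert answer == "Hello, world!"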
diff --git a/poetry.toml b/poetry.toml
index 9a48dd825ab..9924e030b26 100644
--- a/poetry.toml
+++ b/poetry.toml
@@ -1,4 +1,6 @@
 [virtualenvs]
 in-project = true
 create = true
-prefer-active-python = true
\ No newline at end of file
+prefer-active-python = true
+[repositories.tuna]
+url = "https://pypi.tuna.tsinghua.edu.cn/simple"
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cf038cb433c..664a7f8968a 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -57,9 +57,50 @@ def chat(self, system, history, gen_conf):
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0
 
+    # def chat_streamly(self, system, history, gen_conf):
+    #     if system:
+    #         history.insert(0, {"role": "system", "content": system})
+    #     ans = ""
+    #     total_tokens = 0
+    #     try:
+    #         response = self.client.chat.completions.create(
+    #             model=self.model_name,
+    #             messages=history,
+    #             stream=True,
+    #             **gen_conf)
+    #         for resp in response:
+    #             if not resp.choices:
+    #                 continue
+    #             if not resp.choices[0].delta.content:
+    #                 resp.choices[0].delta.content = ""
+    #             ans += resp.choices[0].delta.content
+    #
+    #             if not hasattr(resp, "usage") or not resp.usage:
+    #                 total_tokens = (
+    #                     total_tokens
+    #                     + num_tokens_from_string(resp.choices[0].delta.content)
+    #                 )
+    #             elif isinstance(resp.usage, dict):
+    #                 total_tokens = resp.usage.get("total_tokens", total_tokens)
+    #             else:
+    #                 total_tokens = resp.usage.total_tokens
+    #
+    #             if resp.choices[0].finish_reason == "length":
+    #                 if is_chinese(ans):
+    #                     ans += LENGTH_NOTIFICATION_CN
+    #                 else:
+    #                     ans += LENGTH_NOTIFICATION_EN
+    #             yield ans
+    #
+    #     except openai.APIError as e:
+    #         yield ans + "\n**ERROR**: " + str(e)
+    #
+    #     yield total_tokens
+
    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
+
        ans = ""
        total_tokens = 0
        try:
@@ -71,30 +112,44 @@ def chat_streamly(self, system, history, gen_conf):
            for resp in response:
                if not resp.choices:
                    continue
-                if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
-                ans += resp.choices[0].delta.content
 
-                if not hasattr(resp, "usage") or not resp.usage:
-                    total_tokens = (
+                finish_reason = resp.choices[0].finish_reason
+                delta_content = resp.choices[0].delta.content if resp.choices[0].delta.content else ""
+
+                # If there is new text, accumulate it and emit only the increment
+                if delta_content:
+                    ans += delta_content
+
+                    # Update the token count
+                    if not hasattr(resp, "usage") or not resp.usage:
+                        total_tokens = (
                            total_tokens
                            + num_tokens_from_string(resp.choices[0].delta.content)
                        )
-                elif isinstance(resp.usage, dict):
-                    total_tokens = resp.usage.get("total_tokens", total_tokens)
-                else:
-                    total_tokens = resp.usage.total_tokens
+                    elif isinstance(resp.usage, dict):
+                        total_tokens = resp.usage.get("total_tokens", total_tokens)
+                    else:
+                        total_tokens = resp.usage.total_tokens
 
-                if resp.choices[0].finish_reason == "length":
+                    yield delta_content
+
+                # Check finish_reason even when delta_content is empty
+                if finish_reason == "length":
+                    # Append a notification when the answer was cut off by the length limit
                    if is_chinese(ans):
-                        ans += LENGTH_NOTIFICATION_CN
+                        notification = LENGTH_NOTIFICATION_CN
                    else:
-                        ans += LENGTH_NOTIFICATION_EN
-                yield ans
+                        notification = LENGTH_NOTIFICATION_EN
+                    yield notification
+
+                # If finish_reason is "stop" or another value, add handling here if needed
+                # (this example does nothing extra for "stop", which normally just
+                # means the answer finished cleanly)
 
        except openai.APIError as e:
+            # Return the error message
            yield ans + "\n**ERROR**: " + str(e)
 
+        # Finally, return the total token count
        yield total_tokens
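Note: the revised chat_streamly yields zero or more plain-text increments, then an optional length notification, and finally the running token total as an int; on an APIError it yields the partial answer with an "**ERROR**" marker appended before the count. A minimal consumer sketch under those assumptions (consume is a hypothetical helper, not part of the patch):

    def consume(stream):
        # stream: the generator returned by chat_streamly(system, history, gen_conf)
        answer, total_tokens = "", 0
        for item in stream:
            if isinstance(item, int):
                total_tokens = item   # trailing token count
            elif "\n**ERROR**:" in item:
                answer = item         # error text already includes the partial answer
            else:
                answer += item        # increment or length notification
        return answer, total_tokens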
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 74127b3cb5c..17963624fbf 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -37,7 +37,7 @@ def __init__(self, key, model_name):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")
-    
+
    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -92,7 +92,7 @@ def chat_streamly(self, system, history, gen_conf, image=""):
            yield ans + "\n**ERROR**: " + str(e)
 
        yield tk_count
-    
+
    def image2base64(self, image):
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
@@ -142,7 +142,7 @@ def chat_prompt(self, text, b64):
 class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
-            base_url="https://api.openai.com/v1"
+            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
@@ -162,6 +162,7 @@ def describe(self, image, max_tokens=300):
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+
 class AzureGptV4(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        api_key = json.loads(key).get('api_key', '')
@@ -220,7 +221,7 @@ def chat_prompt(self, text, b64):
            {"image": f"{b64}"},
            {"text": text},
        ]
-    
+
    def describe(self, image, max_tokens=300):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
@@ -302,7 +303,7 @@ def describe(self, image, max_tokens=1024):
 
        prompt = self.prompt(b64)
        prompt[0]["content"][1]["type"] = "text"
-        
+
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt,
@@ -341,7 +342,7 @@ def chat_streamly(self, system, history, gen_conf, image=""):
                his["content"] = self.chat_prompt(his["content"], image)
 
        response = self.client.chat.completions.create(
-            model=self.model_name, 
+            model=self.model_name,
            messages=history,
            max_tokens=gen_conf.get("max_tokens", 1000),
            temperature=gen_conf.get("temperature", 0.3),
@@ -484,6 +485,7 @@ def describe(self, image, max_tokens=300):
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+
 class GeminiCV(Base):
    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import client, GenerativeModel
@@ -492,21 +494,21 @@ def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
-        self.lang = lang 
+        self.lang = lang
 
    def describe(self, image, max_tokens=2048):
        from PIL.Image import open
-        gen_config = {'max_output_tokens':max_tokens}
+        gen_config = {'max_output_tokens': max_tokens}
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-        b64 = self.image2base64(image) 
-        img = open(BytesIO(base64.b64decode(b64))) 
-        input = [prompt,img]
+        b64 = self.image2base64(image)
+        img = open(BytesIO(base64.b64decode(b64)))
+        input = [prompt, img]
        res = self.model.generate_content(
            input,
            generation_config=gen_config,
        )
-        return res.text,res.usage_metadata.total_token_count
+        return res.text, res.usage_metadata.total_token_count
 
    def chat(self, system, history, gen_conf, image=""):
        from transformers import GenerationConfig
@@ -566,11 +568,11 @@ def chat_streamly(self, system, history, gen_conf, image=""):
 
 class OpenRouterCV(GptV4):
    def __init__(
-        self,
-        key,
-        model_name,
-        lang="Chinese",
-        base_url="https://openrouter.ai/api/v1",
+            self,
+            key,
+            model_name,
+            lang="Chinese",
+            base_url="https://openrouter.ai/api/v1",
    ):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
@@ -589,11 +591,11 @@ def describe(self, image, max_tokens=1024):
 
 class NvidiaCV(Base):
    def __init__(
-        self,
-        key,
-        model_name,
-        lang="Chinese",
-        base_url="https://ai.api.nvidia.com/v1/vlm",
+            self,
+            key,
+            model_name,
+            lang="Chinese",
+            base_url="https://ai.api.nvidia.com/v1/vlm",
    ):
        if not base_url:
            base_url = ("https://ai.api.nvidia.com/v1/vlm",)
@@ -632,11 +634,11 @@ def prompt(self, b64):
            {
                "role": "user",
                "content": (
-                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
-                    if self.lang.lower() == "chinese"
-                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                        "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+                )
+                + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]
@@ -652,7 +654,7 @@ def chat_prompt(self, text, b64):
 class StepFunCV(GptV4):
    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
        if not base_url:
-            base_url="https://api.stepfun.com/v1"
+            base_url = "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
@@ -684,18 +686,18 @@ class TogetherAICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)
 
 
 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
+    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", ):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)
 
 
 class HunyuanCV(Base):
-    def __init__(self, key, model_name, lang="Chinese",base_url=None):
+    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client
 
@@ -712,7 +714,7 @@ def describe(self, image, max_tokens=4096):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
-        
+
        b64 = self.image2base64(image)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": self.prompt(b64)}
"\n**ERROR**: " + str(e), 0 - + def prompt(self, b64): return [ { diff --git a/web/src/hooks/logic-hooks.ts b/web/src/hooks/logic-hooks.ts index 5c8afebb1ed..ee4a92e3c50 100644 --- a/web/src/hooks/logic-hooks.ts +++ b/web/src/hooks/logic-hooks.ts @@ -195,6 +195,7 @@ export const useSendMessageWithSse = ( .pipeThrough(new EventSourceParserStream()) .getReader(); + let accumulatedText = ''; while (true) { const x = await reader?.read(); if (x) { @@ -209,10 +210,13 @@ export const useSendMessageWithSse = ( const d = val?.data; if (typeof d !== 'boolean') { console.info('data:', d); - setAnswer({ + accumulatedText += d.answer || ''; + const updatedMessage = { ...d, + answer: accumulatedText, conversationId: body?.conversation_id, - }); + }; + setAnswer(updatedMessage); } } catch (e) { console.warn(e);