diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py
index 42b4d0a14e9..f08dd322764 100644
--- a/api/apps/tenant_app.py
+++ b/api/apps/tenant_app.py
@@ -119,4 +119,4 @@ def agree(tenant_id):
UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL})
return get_json_result(data=True)
except Exception as e:
- return server_error_response(e)
+ return server_error_response(e)
\ No newline at end of file
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index a86e4aad87c..a2cef9c7969 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -314,20 +314,72 @@ def decorate_answer(answer):
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": answer, "reference": refs, "prompt": prompt}
+
+    # Previous streaming implementation, kept commented out for reference:
+ # if stream:
+ # last_ans = ""
+ # answer = ""
+ # for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+ # answer = ans
+ # logging.info("answer_stream : {}".format(ans))
+ # delta_ans = ans[len(last_ans):]
+ # if num_tokens_from_string(delta_ans) < 16:
+ # continue
+ # last_ans = answer
+ # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+ # delta_ans = answer[len(last_ans):]
+ # if delta_ans:
+ # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+ # yield decorate_answer(answer)
+
if stream:
- last_ans = ""
+ # logging.info("stream_mode : {}".format(msg[1:]))
answer = ""
- for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
- answer = ans
- delta_ans = ans[len(last_ans):]
- if num_tokens_from_string(delta_ans) < 16:
- continue
- last_ans = answer
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
- delta_ans = answer[len(last_ans):]
- if delta_ans:
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
- yield decorate_answer(answer)
+ for delta in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+            # Check whether this item is the total token count or a notification
+            if isinstance(delta, int):
+                # Total token count arrives as an int; handle it here if needed
+                # total_tokens = delta
+                continue
+            if isinstance(delta, str):
+                if delta.isdigit():
+                    # Handle a total token count sent as a string, if needed
+                    # total_tokens = int(delta)
+                    continue
+ elif "\n**ERROR**:" in delta:
+                    # Surface the error message to the caller
+                    answer += delta
+                    yield {"answer": answer, "reference": {}, "audio_binary": b''}  # no audio on errors
+ continue
+
+                # Process the incremental text; must stay inside the str branch
+                # so the elif below remains valid
+                delta_ans = delta
+                # if num_tokens_from_string(delta_ans) < 16:
+                #     continue  # adjust this threshold as needed
+
+                # Accumulate the full answer
+                answer += delta_ans
+
+                # Generate audio for this delta
+                audio = tts(tts_mdl, delta_ans)
+                # logging.info(f"Generated audio for delta: {delta_ans}")
+                yield {"answer": delta_ans, "reference": {}, "audio_binary": audio}
+ elif isinstance(delta, dict):
+                # Fallback in case chat_streamly still returns dicts (not recommended),
+                # e.g. {"new_text": "newly added text", "position": 10}
+ new_text = delta.get("new_text", "")
+ if not new_text:
+ continue
+ if num_tokens_from_string(new_text) < 16:
+ continue
+
+                # Accumulate the full answer
+ answer += new_text
+
+                # Generate audio for the new text
+ audio = tts(tts_mdl, new_text)
+ yield {"answer": answer, "reference": {}, "audio_binary": audio}
+
+        # Finally, decorate the answer with references and timing info
+ decorated_answer = decorate_answer(answer)
+ # logging.info(f"Final decorated answer: {decorated_answer}")
+ yield decorated_answer
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
logging.debug("User: {}|Assistant: {}".format(
diff --git a/poetry.toml b/poetry.toml
index 9a48dd825ab..9924e030b26 100644
--- a/poetry.toml
+++ b/poetry.toml
@@ -1,4 +1,6 @@
[virtualenvs]
in-project = true
create = true
-prefer-active-python = true
\ No newline at end of file
+prefer-active-python = true
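+
+# Tsinghua University PyPI mirror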
+[repositories.tuna]
+url = "https://pypi.tuna.tsinghua.edu.cn/simple"
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index cf038cb433c..664a7f8968a 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -57,9 +57,50 @@ def chat(self, system, history, gen_conf):
except openai.APIError as e:
return "**ERROR**: " + str(e), 0
+ # def chat_streamly(self, system, history, gen_conf):
+ # if system:
+ # history.insert(0, {"role": "system", "content": system})
+ # ans = ""
+ # total_tokens = 0
+ # try:
+ # response = self.client.chat.completions.create(
+ # model=self.model_name,
+ # messages=history,
+ # stream=True,
+ # **gen_conf)
+ # for resp in response:
+ # if not resp.choices:
+ # continue
+ # if not resp.choices[0].delta.content:
+ # resp.choices[0].delta.content = ""
+ # ans += resp.choices[0].delta.content
+ #
+ # if not hasattr(resp, "usage") or not resp.usage:
+ # total_tokens = (
+ # total_tokens
+ # + num_tokens_from_string(resp.choices[0].delta.content)
+ # )
+ # elif isinstance(resp.usage, dict):
+ # total_tokens = resp.usage.get("total_tokens", total_tokens)
+ # else:
+ # total_tokens = resp.usage.total_tokens
+ #
+ # if resp.choices[0].finish_reason == "length":
+ # if is_chinese(ans):
+ # ans += LENGTH_NOTIFICATION_CN
+ # else:
+ # ans += LENGTH_NOTIFICATION_EN
+ # yield ans
+ #
+ # except openai.APIError as e:
+ # yield ans + "\n**ERROR**: " + str(e)
+ #
+ # yield total_tokens
+
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
+
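+        # Yields incremental text deltas rather than the cumulative answer,
+        # then a length notification if the output was truncated, and finally
+        # the total token count.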
ans = ""
total_tokens = 0
try:
@@ -71,30 +112,44 @@ def chat_streamly(self, system, history, gen_conf):
for resp in response:
if not resp.choices:
continue
- if not resp.choices[0].delta.content:
- resp.choices[0].delta.content = ""
- ans += resp.choices[0].delta.content
- if not hasattr(resp, "usage") or not resp.usage:
- total_tokens = (
+ finish_reason = resp.choices[0].finish_reason
+                delta_content = resp.choices[0].delta.content or ""
+
+                # If there is new text, accumulate it and yield the increment
+ if delta_content:
+ ans += delta_content
+
+                    # Update the token count
+ if not hasattr(resp, "usage") or not resp.usage:
+ total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
- elif isinstance(resp.usage, dict):
- total_tokens = resp.usage.get("total_tokens", total_tokens)
- else:
- total_tokens = resp.usage.total_tokens
+ elif isinstance(resp.usage, dict):
+ total_tokens = resp.usage.get("total_tokens", total_tokens)
+ else:
+ total_tokens = resp.usage.total_tokens
- if resp.choices[0].finish_reason == "length":
+ yield delta_content
+
+                # Check finish_reason even when delta_content is empty
+ if finish_reason == "length":
+                    # Append a notification when the output was truncated by length
if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
+ notification = LENGTH_NOTIFICATION_CN
else:
- ans += LENGTH_NOTIFICATION_EN
- yield ans
+ notification = LENGTH_NOTIFICATION_EN
+ yield notification
+
+                # Logic for finish_reason == "stop" (or other values) could be added here;
+                # "stop" gets no extra handling, since it normally means the answer completed.
except openai.APIError as e:
+            # Surface the error message
yield ans + "\n**ERROR**: " + str(e)
+        # Finally, yield the total token count
yield total_tokens
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 74127b3cb5c..17963624fbf 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -37,7 +37,7 @@ def __init__(self, key, model_name):
def describe(self, image, max_tokens=300):
raise NotImplementedError("Please implement encode method!")
-
+
def chat(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -92,7 +92,7 @@ def chat_streamly(self, system, history, gen_conf, image=""):
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
-
+
def image2base64(self, image):
if isinstance(image, bytes):
return base64.b64encode(image).decode("utf-8")
@@ -142,7 +142,7 @@ def chat_prompt(self, text, b64):
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url:
- base_url="https://api.openai.com/v1"
+ base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -162,6 +162,7 @@ def describe(self, image, max_tokens=300):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+
class AzureGptV4(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
api_key = json.loads(key).get('api_key', '')
@@ -220,7 +221,7 @@ def chat_prompt(self, text, b64):
{"image": f"{b64}"},
{"text": text},
]
-
+
def describe(self, image, max_tokens=300):
from http import HTTPStatus
from dashscope import MultiModalConversation
@@ -302,7 +303,7 @@ def describe(self, image, max_tokens=1024):
prompt = self.prompt(b64)
prompt[0]["content"][1]["type"] = "text"
-
+
res = self.client.chat.completions.create(
model=self.model_name,
messages=prompt,
@@ -341,7 +342,7 @@ def chat_streamly(self, system, history, gen_conf, image=""):
his["content"] = self.chat_prompt(his["content"], image)
response = self.client.chat.completions.create(
- model=self.model_name,
+ model=self.model_name,
messages=history,
max_tokens=gen_conf.get("max_tokens", 1000),
temperature=gen_conf.get("temperature", 0.3),
@@ -484,6 +485,7 @@ def describe(self, image, max_tokens=300):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+
class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel
@@ -492,21 +494,21 @@ def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
- self.lang = lang
+ self.lang = lang
def describe(self, image, max_tokens=2048):
from PIL.Image import open
- gen_config = {'max_output_tokens':max_tokens}
+ gen_config = {'max_output_tokens': max_tokens}
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
- b64 = self.image2base64(image)
- img = open(BytesIO(base64.b64decode(b64)))
- input = [prompt,img]
+ b64 = self.image2base64(image)
+ img = open(BytesIO(base64.b64decode(b64)))
+ input = [prompt, img]
res = self.model.generate_content(
input,
generation_config=gen_config,
)
- return res.text,res.usage_metadata.total_token_count
+ return res.text, res.usage_metadata.total_token_count
def chat(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig
@@ -566,11 +568,11 @@ def chat_streamly(self, system, history, gen_conf, image=""):
class OpenRouterCV(GptV4):
def __init__(
- self,
- key,
- model_name,
- lang="Chinese",
- base_url="https://openrouter.ai/api/v1",
+ self,
+ key,
+ model_name,
+ lang="Chinese",
+ base_url="https://openrouter.ai/api/v1",
):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
@@ -589,11 +591,11 @@ def describe(self, image, max_tokens=1024):
class NvidiaCV(Base):
def __init__(
- self,
- key,
- model_name,
- lang="Chinese",
- base_url="https://ai.api.nvidia.com/v1/vlm",
+ self,
+ key,
+ model_name,
+ lang="Chinese",
+ base_url="https://ai.api.nvidia.com/v1/vlm",
):
if not base_url:
base_url = ("https://ai.api.nvidia.com/v1/vlm",)
@@ -632,11 +634,11 @@ def prompt(self, b64):
{
"role": "user",
"content": (
- "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
- if self.lang.lower() == "chinese"
- else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
- )
- + f' ',
+ "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+ )
+ + f' ',
}
]
@@ -652,7 +654,7 @@ def chat_prompt(self, text, b64):
class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url:
- base_url="https://api.stepfun.com/v1"
+ base_url = "https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -684,18 +686,18 @@ class TogetherAICV(GptV4):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
- super().__init__(key, model_name,lang,base_url)
+ super().__init__(key, model_name, lang, base_url)
class YiCV(GptV4):
- def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
+ def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", ):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
- super().__init__(key, model_name,lang,base_url)
+ super().__init__(key, model_name, lang, base_url)
class HunyuanCV(Base):
- def __init__(self, key, model_name, lang="Chinese",base_url=None):
+ def __init__(self, key, model_name, lang="Chinese", base_url=None):
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -712,7 +714,7 @@ def describe(self, image, max_tokens=4096):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
-
+
b64 = self.image2base64(image)
req = models.ChatCompletionsRequest()
params = {"Model": self.model_name, "Messages": self.prompt(b64)}
@@ -724,7 +726,7 @@ def describe(self, image, max_tokens=4096):
return ans, response.Usage.TotalTokens
except TencentCloudSDKException as e:
return ans + "\n**ERROR**: " + str(e), 0
-
+
def prompt(self, b64):
return [
{
diff --git a/web/src/hooks/logic-hooks.ts b/web/src/hooks/logic-hooks.ts
index 5c8afebb1ed..ee4a92e3c50 100644
--- a/web/src/hooks/logic-hooks.ts
+++ b/web/src/hooks/logic-hooks.ts
@@ -195,6 +195,7 @@ export const useSendMessageWithSse = (
.pipeThrough(new EventSourceParserStream())
.getReader();
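+      // Each SSE event now carries only the incremental text, so the full
+      // answer is accumulated on the client.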
+ let accumulatedText = '';
while (true) {
const x = await reader?.read();
if (x) {
@@ -209,10 +210,13 @@ export const useSendMessageWithSse = (
const d = val?.data;
if (typeof d !== 'boolean') {
console.info('data:', d);
- setAnswer({
+ accumulatedText += d.answer || '';
+ const updatedMessage = {
...d,
+ answer: accumulatedText,
conversationId: body?.conversation_id,
- });
+ };
+ setAnswer(updatedMessage);
}
} catch (e) {
console.warn(e);