Skip to content

Commit 598b72f

Browse files
authored
fix: chat bugs (#3308)
1 parent 03ec0f3 commit 598b72f

File tree

19 files changed

+83
-68
lines changed

19 files changed

+83
-68
lines changed

apps/application/chat_pipeline/step/chat_step/i_chat_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class InstanceSerializer(serializers.Serializer):
7575
no_references_setting = NoReferencesSetting(required=True,
7676
label=_("No reference segment settings"))
7777

78-
user_id = serializers.UUIDField(required=True, label=_("User ID"))
78+
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
7979

8080
model_setting = serializers.DictField(required=True, allow_null=True,
8181
label=_("Model settings"))
@@ -102,7 +102,7 @@ def execute(self, message_list: List[BaseMessage],
102102
chat_id, problem_text,
103103
post_response_handler: PostResponseHandler,
104104
model_id: str = None,
105-
user_id: str = None,
105+
workspace_id: str = None,
106106
paragraph_list=None,
107107
manage: PipelineManage = None,
108108
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
2727
from application.flow.tools import Reasoning
2828
from application.models import ApplicationChatUserStats, ChatUserType
29-
from models_provider.tools import get_model_instance_by_model_user_id
29+
from models_provider.tools import get_model_instance_by_model_workspace_id
3030

3131

3232
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -157,7 +157,7 @@ def execute(self, message_list: List[BaseMessage],
157157
problem_text,
158158
post_response_handler: PostResponseHandler,
159159
model_id: str = None,
160-
user_id: str = None,
160+
workspace_id: str = None,
161161
paragraph_list=None,
162162
manage: PipelineManage = None,
163163
padding_problem_text: str = None,
@@ -167,8 +167,8 @@ def execute(self, message_list: List[BaseMessage],
167167
model_params_setting=None,
168168
model_setting=None,
169169
**kwargs):
170-
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
171-
**model_params_setting) if model_id is not None else None
170+
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
171+
**model_params_setting) if model_id is not None else None
172172
if stream:
173173
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
174174
paragraph_list,

apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class InstanceSerializer(serializers.Serializer):
2727
label=_("History Questions"))
2828
# 大语言模型
2929
model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
30-
user_id = serializers.UUIDField(required=True, label=_("User ID"))
30+
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
3131
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
3232
label=_("Question completion prompt"))
3333

@@ -50,6 +50,6 @@ def _run(self, manage: PipelineManage):
5050
@abstractmethod
5151
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
5252
problem_optimization_prompt=None,
53-
user_id=None,
53+
workspace_id=None,
5454
**kwargs):
5555
pass

apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
1515
from application.models import ChatRecord
1616
from common.utils.split_model import flat_map
17-
from models_provider.tools import get_model_instance_by_model_user_id
17+
from models_provider.tools import get_model_instance_by_model_workspace_id
1818

1919
prompt = _(
2020
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
@@ -23,9 +23,9 @@
2323
class BaseResetProblemStep(IResetProblemStep):
2424
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
2525
problem_optimization_prompt=None,
26-
user_id=None,
26+
workspace_id=None,
2727
**kwargs) -> str:
28-
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
28+
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None
2929
if chat_model is None:
3030
return problem_text
3131
start_index = len(history_chat_record) - 3

apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class InstanceSerializer(serializers.Serializer):
4444
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
4545
message=_("The type only supports embedding|keywords|blend"), code=500)
4646
], label=_("Retrieval Mode"))
47-
user_id = serializers.UUIDField(required=True, label=_("User ID"))
47+
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
4848

4949
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
5050
return self.InstanceSerializer
@@ -58,19 +58,19 @@ def _run(self, manage: PipelineManage):
5858
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
5959
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
6060
search_mode: str = None,
61-
user_id=None,
61+
workspace_id=None,
6262
**kwargs) -> List[ParagraphPipelineModel]:
6363
"""
6464
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
6565
:param similarity: 相关性
6666
:param top_n: 查询多少条
6767
:param problem_text: 用户问题
68-
:param knowledge_id_list: 需要查询的数据集id列表
68+
:param knowledge_id_list: 需要查询的数据集id列表
6969
:param exclude_document_id_list: 需要排除的文档id
7070
:param exclude_paragraph_id_list: 需要排除段落id
7171
:param padding_problem_text 补全问题
7272
:param search_mode 检索模式
73-
:param user_id 用户id
73+
:param workspace_id 工作空间id
7474
:return: 段落列表
7575
"""
7676
pass

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
from models_provider.tools import get_model
2626

2727

28-
def get_model_by_id(_id, user_id):
29-
model = QuerySet(Model).filter(id=_id).first()
28+
def get_model_by_id(_id, workspace_id):
29+
model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING").first()
3030
if model is None:
3131
raise Exception(_("Model does not exist"))
32-
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
33-
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
34-
raise Exception(message)
32+
if model.workspace_id is not None:
33+
if model.workspace_id != workspace_id:
34+
raise Exception(_("Model does not exist"))
3535
return model
3636

3737

@@ -50,13 +50,13 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
5050
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
5151
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
5252
search_mode: str = None,
53-
user_id=None,
53+
workspace_id=None,
5454
**kwargs) -> List[ParagraphPipelineModel]:
5555
if len(knowledge_id_list) == 0:
5656
return []
5757
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
5858
model_id = get_embedding_id(knowledge_id_list)
59-
model = get_model_by_id(model_id, user_id)
59+
model = get_model_by_id(model_id, workspace_id)
6060
self.context['model_name'] = model.name
6161
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
6262
embedding_value = embedding_model.embed_query(exec_problem_text)

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import re
1212
import time
1313
from functools import reduce
14-
from types import AsyncGeneratorType
1514
from typing import List, Dict
1615

1716
from django.db.models import QuerySet
@@ -24,7 +23,7 @@
2423
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
2524
from application.flow.tools import Reasoning
2625
from models_provider.models import Model
27-
from models_provider.tools import get_model_credential, get_model_instance_by_model_user_id
26+
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
2827

2928
tool_message_template = """
3029
<details>
@@ -206,8 +205,9 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
206205
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
207206
'reasoning_content_start': '<think>'}
208207
self.context['model_setting'] = model_setting
209-
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
210-
**model_params_setting)
208+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
209+
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
210+
**model_params_setting)
211211
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
212212
self.runtime_node_id)
213213
self.context['history_message'] = history_message

apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
1010
from common.utils.common import bytes_to_uploaded_file
1111
from oss.serializers.file import FileSerializer
12-
from models_provider.tools import get_model_instance_by_model_user_id
12+
from models_provider.tools import get_model_instance_by_model_workspace_id
1313

1414

1515
class BaseImageGenerateNode(IImageGenerateNode):
@@ -25,8 +25,9 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
2525
**kwargs) -> NodeResult:
2626
print(model_params_setting)
2727
application = self.workflow_manage.work_flow_post_handler.chat_info.application
28-
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
29-
**model_params_setting)
28+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
29+
tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
30+
**model_params_setting)
3031
history_message = self.get_history_message(history_chat_record, dialogue_number)
3132
self.context['history_message'] = history_message
3233
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from application.flow.i_step_node import NodeResult, INode
1212
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
1313
from knowledge.models import File
14-
from models_provider.tools import get_model_instance_by_model_user_id
14+
from models_provider.tools import get_model_instance_by_model_workspace_id
1515

1616

1717
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@@ -79,9 +79,9 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
7979
# 处理不正确的参数
8080
if image is None or not isinstance(image, list):
8181
image = []
82-
print(model_params_setting)
83-
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
84-
**model_params_setting)
82+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
83+
image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
84+
**model_params_setting)
8585
# 执行详情中的历史消息不需要图片内容
8686
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
8787
self.context['history_message'] = history_message

apps/application/flow/step_node/question_node/impl/base_question_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from application.flow.i_step_node import NodeResult, INode
1919
from application.flow.step_node.question_node.i_question_node import IQuestionNode
2020
from models_provider.models import Model
21-
from models_provider.tools import get_model_instance_by_model_user_id, get_model_credential
21+
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
2222

2323

2424
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@@ -87,8 +87,9 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
8787
**kwargs) -> NodeResult:
8888
if model_params_setting is None:
8989
model_params_setting = get_default_model_params_setting(model_id)
90-
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
91-
**model_params_setting)
90+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
91+
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
92+
**model_params_setting)
9293
history_message = self.get_history_message(history_chat_record, dialogue_number)
9394
self.context['history_message'] = history_message
9495
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from application.flow.i_step_node import NodeResult
1414
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
15-
from models_provider.tools import get_model_instance_by_model_user_id
15+
from models_provider.tools import get_model_instance_by_model_workspace_id
1616

1717

1818
def merge_reranker_list(reranker_list, result=None):
@@ -78,8 +78,9 @@ def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
7878
self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
7979
document in documents]
8080
self.context['question'] = question
81-
reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
82-
self.flow_params_serializer.data.get('user_id'),
81+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
82+
reranker_model = get_model_instance_by_model_workspace_id(reranker_model_id,
83+
workspace_id,
8384
top_n=top_n)
8485
result = reranker_model.compress_documents(
8586
documents,

apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from common.utils.common import get_file_content
2020
from knowledge.models import Document, Paragraph, Knowledge, SearchMode
2121
from maxkb.conf import PROJECT_DIR
22-
from models_provider.tools import get_model_instance_by_model_user_id
22+
from models_provider.tools import get_model_instance_by_model_workspace_id
2323

2424

2525
def get_embedding_id(dataset_id_list):
@@ -67,7 +67,8 @@ def execute(self, dataset_id_list, dataset_setting, question,
6767
if len(dataset_id_list) == 0:
6868
return get_none_result(question)
6969
model_id = get_embedding_id(dataset_id_list)
70-
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
70+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
71+
embedding_model = get_model_instance_by_model_workspace_id(model_id, workspace_id)
7172
embedding_value = embedding_model.embed_query(question)
7273
vector = VectorStore.get_embedding_vector()
7374
exclude_document_id_list = [str(document.id) for document in

apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
1010
from common.utils.common import split_and_transcribe, any_to_mp3
1111
from knowledge.models import File
12-
from models_provider.tools import get_model_instance_by_model_user_id
12+
from models_provider.tools import get_model_instance_by_model_workspace_id
1313

1414

1515
class BaseSpeechToTextNode(ISpeechToTextNode):
@@ -20,7 +20,8 @@ def save_context(self, details, workflow_manage):
2020
self.answer_text = details.get('answer')
2121

2222
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
23-
stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
23+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
24+
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
2425
audio_list = audio
2526
self.context['audio_list'] = audio
2627

apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from application.flow.i_step_node import NodeResult
88
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
9+
from models_provider.tools import get_model_instance_by_model_workspace_id
910
from oss.serializers.file import FileSerializer
10-
from models_provider.tools import get_model_instance_by_model_user_id
1111

1212

1313
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
@@ -42,8 +42,9 @@ def execute(self, tts_model_id, chat_id,
4242
content, model_params_setting=None,
4343
**kwargs) -> NodeResult:
4444
self.context['content'] = content
45-
model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
46-
**model_params_setting)
45+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
46+
model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id,
47+
**model_params_setting)
4748
audio_byte = model.text_to_speech(content)
4849
# 需要把这个音频文件存储到数据库中
4950
file_name = 'generated_audio.mp3'

0 commit comments

Comments
 (0)