diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 1451cd88e60..6c3012f8e3d 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -75,7 +75,7 @@ class InstanceSerializer(serializers.Serializer): no_references_setting = NoReferencesSetting(required=True, label=_("No reference segment settings")) - user_id = serializers.UUIDField(required=True, label=_("User ID")) + workspace_id = serializers.CharField(required=True, label=_("Workspace ID")) model_setting = serializers.DictField(required=True, allow_null=True, label=_("Model settings")) @@ -102,7 +102,7 @@ def execute(self, message_list: List[BaseMessage], chat_id, problem_text, post_response_handler: PostResponseHandler, model_id: str = None, - user_id: str = None, + workspace_id: str = None, paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None, diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 13cab54e277..9ce08d546d8 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -26,7 +26,7 @@ from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.flow.tools import Reasoning from application.models import ApplicationChatUserStats, ChatUserType -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None): @@ -157,7 +157,7 @@ def execute(self, message_list: List[BaseMessage], problem_text, post_response_handler: PostResponseHandler, model_id: str = None, - user_id: str = None, + workspace_id: str = None, paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, @@ -167,8 +167,8 @@ def execute(self, message_list: List[BaseMessage], model_params_setting=None, model_setting=None, **kwargs): - chat_model = get_model_instance_by_model_user_id(model_id, user_id, - **model_params_setting) if model_id is not None else None + chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, + **model_params_setting) if model_id is not None else None if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index 5c740eda4dc..a0e06204364 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -27,7 +27,7 @@ class InstanceSerializer(serializers.Serializer): label=_("History Questions")) # 大语言模型 model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id")) - user_id = serializers.UUIDField(required=True, label=_("User ID")) + workspace_id = serializers.CharField(required=True, label=_("User ID")) problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, label=_("Question completion prompt")) @@ -50,6 +50,6 @@ def _run(self, manage: PipelineManage): @abstractmethod def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, problem_optimization_prompt=None, - user_id=None, + workspace_id=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index 8e4a0cdbf71..391e35196fa 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -14,7 +14,7 @@ from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep from application.models import ChatRecord from common.utils.split_model import flat_map -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id prompt = _( "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag") @@ -23,9 +23,9 @@ class BaseResetProblemStep(IResetProblemStep): def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, problem_optimization_prompt=None, - user_id=None, + workspace_id=None, **kwargs) -> str: - chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None + chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None if chat_model is None: return problem_text start_index = len(history_chat_record) - 3 diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index bb08ca9e229..2e96b1aabc6 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -44,7 +44,7 @@ class InstanceSerializer(serializers.Serializer): validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), message=_("The type only supports embedding|keywords|blend"), code=500) ], label=_("Retrieval Mode")) - user_id = serializers.UUIDField(required=True, label=_("User ID")) + workspace_id = serializers.CharField(required=True, label=_("Workspace ID")) def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer @@ -58,19 +58,19 @@ def _run(self, manage: PipelineManage): def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, - user_id=None, + workspace_id=None, **kwargs) -> List[ParagraphPipelineModel]: """ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 :param similarity: 相关性 :param top_n: 查询多少条 :param problem_text: 用户问题 - :param knowledge_id_list: 需要查询的数据集id列表 + :param knowledge_id_list: 需要查询的数据集id列表 :param exclude_document_id_list: 需要排除的文档id :param exclude_paragraph_id_list: 需要排除段落id :param padding_problem_text 补全问题 :param search_mode 检索模式 - :param user_id 用户id + :param workspace_id 工作空间id :return: 段落列表 """ pass diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 88c58cd048f..9c494912386 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -25,13 +25,13 @@ from models_provider.tools import get_model -def get_model_by_id(_id, user_id): - model = QuerySet(Model).filter(id=_id).first() +def get_model_by_id(_id, workspace_id): + model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING").first() if model is None: raise Exception(_("Model does not exist")) - if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): - message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name) - raise Exception(message) + if model.workspace_id is not None: + if model.workspace_id != workspace_id: + raise Exception(_("Model does not exist")) return model @@ -50,13 +50,13 @@ class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, - user_id=None, + workspace_id=None, **kwargs) -> List[ParagraphPipelineModel]: if len(knowledge_id_list) == 0: return [] exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text model_id = get_embedding_id(knowledge_id_list) - model = get_model_by_id(model_id, user_id) + model = get_model_by_id(model_id, workspace_id) self.context['model_name'] = model.name embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index a9780f0a741..47abc1ebc08 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -11,7 +11,6 @@ import re import time from functools import reduce -from types import AsyncGeneratorType from typing import List, Dict from django.db.models import QuerySet @@ -24,7 +23,7 @@ from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.tools import Reasoning from models_provider.models import Model -from models_provider.tools import get_model_credential, get_model_instance_by_model_user_id +from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id tool_message_template = """ <details> @@ -206,8 +205,9 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>', 'reasoning_content_start': '<think>'} self.context['model_setting'] = model_setting - chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), - **model_params_setting) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, self.runtime_node_id) self.context['history_message'] = history_message diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py index 2feba6fd12b..3c5f0a851e0 100644 --- a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -9,7 +9,7 @@ from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode from common.utils.common import bytes_to_uploaded_file from oss.serializers.file import FileSerializer -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id class BaseImageGenerateNode(IImageGenerateNode): @@ -25,8 +25,9 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t **kwargs) -> NodeResult: print(model_params_setting) application = self.workflow_manage.work_flow_post_handler.chat_info.application - tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), - **model_params_setting) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 9acfaced222..b4813cc12d1 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -11,7 +11,7 @@ from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode from knowledge.models import File -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): @@ -79,9 +79,9 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist # 处理不正确的参数 if image is None or not isinstance(image, list): image = [] - print(model_params_setting) - image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), - **model_params_setting) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, + **model_params_setting) # 执行详情中的历史消息不需要图片内容 history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) self.context['history_message'] = history_message diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index c1463223084..8c848e4acd0 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -18,7 +18,7 @@ from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.question_node.i_question_node import IQuestionNode from models_provider.models import Model -from models_provider.tools import get_model_instance_by_model_user_id, get_model_credential +from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): @@ -87,8 +87,9 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record **kwargs) -> NodeResult: if model_params_setting is None: model_params_setting = get_default_model_params_setting(model_id) - chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), - **model_params_setting) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id, + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py index 88c5b326ae7..0639f21cf3b 100644 --- a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -12,7 +12,7 @@ from application.flow.i_step_node import NodeResult from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id def merge_reranker_list(reranker_list, result=None): @@ -78,8 +78,9 @@ def execute(self, question, reranker_setting, reranker_list, reranker_model_id, self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for document in documents] self.context['question'] = question - reranker_model = get_model_instance_by_model_user_id(reranker_model_id, - self.flow_params_serializer.data.get('user_id'), + workspace_id = self.workflow_manage.get_body().get('workspace_id') + reranker_model = get_model_instance_by_model_workspace_id(reranker_model_id, + workspace_id, top_n=top_n) result = reranker_model.compress_documents( documents, diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 3edbd77fe51..c4214efc1c1 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -19,7 +19,7 @@ from common.utils.common import get_file_content from knowledge.models import Document, Paragraph, Knowledge, SearchMode from maxkb.conf import PROJECT_DIR -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id def get_embedding_id(dataset_id_list): @@ -67,7 +67,8 @@ def execute(self, dataset_id_list, dataset_setting, question, if len(dataset_id_list) == 0: return get_none_result(question) model_id = get_embedding_id(dataset_id_list) - embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + embedding_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() exclude_document_id_list = [str(document.id) for document in diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py index 403457b87b7..1ddbc9fd54f 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -9,7 +9,7 @@ from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode from common.utils.common import split_and_transcribe, any_to_mp3 from knowledge.models import File -from models_provider.tools import get_model_instance_by_model_user_id +from models_provider.tools import get_model_instance_by_model_workspace_id class BaseSpeechToTextNode(ISpeechToTextNode): @@ -20,7 +20,8 @@ def save_context(self, details, workflow_manage): self.answer_text = details.get('answer') def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: - stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id')) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id) audio_list = audio self.context['audio_list'] = audio diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py index 473143bd1de..6c10a0d360b 100644 --- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -6,8 +6,8 @@ from application.flow.i_step_node import NodeResult from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode +from models_provider.tools import get_model_instance_by_model_workspace_id from oss.serializers.file import FileSerializer -from models_provider.tools import get_model_instance_by_model_user_id def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"): @@ -42,8 +42,9 @@ def execute(self, tts_model_id, chat_id, content, model_params_setting=None, **kwargs) -> NodeResult: self.context['content'] = content - model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'), - **model_params_setting) + workspace_id = self.workflow_manage.get_body().get('workspace_id') + model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id, + **model_params_setting) audio_byte = model.text_to_speech(content) # 需要把这个音频文件存储到数据库中 file_name = 'generated_audio.mp3' diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index aafaf382308..43acf5d4766 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -538,6 +538,7 @@ def to_application(application, workspace_id, user_id): class ApplicationOperateSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=True, label=_("Application ID")) user_id = serializers.UUIDField(required=True, label=_("User ID")) + workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -682,7 +683,6 @@ def edit(self, instance: Dict, with_valid=True): for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: application.__setattr__(update_key, instance.get(update_key)) - print(application.name) application.save() if 'knowledge_id_list' in instance: @@ -690,11 +690,11 @@ def edit(self, instance: Dict, with_valid=True): # 当前用户可修改关联的知识库列表 application_knowledge_id_list = [str(knowledge.id) for knowledge in self.list_knowledge(with_valid=False)] - for dataset_id in knowledge_id_list: - if not application_knowledge_id_list.__contains__(dataset_id): + for knowledge_id in knowledge_id_list: + if not application_knowledge_id_list.__contains__(knowledge_id): message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'), - dataset_id=dataset_id) - raise AppApiException(500, message) + dataset_id=knowledge_id) + raise AppApiException(500, str(message)) self.save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id) return self.one(with_valid=False) @@ -707,8 +707,8 @@ def one(self, with_valid=True): knowledge_list = self.list_knowledge(with_valid=False) mapping_knowledge_id_list = [akm.knowledge_id for akm in QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)] - knowledge_id_list = [d.get('id') for d in - list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.get('id')), + knowledge_id_list = [d.id for d in + list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.id), knowledge_list))] return {**ApplicationSerializerModel(application).data, 'knowledge_id_list': knowledge_id_list} @@ -729,5 +729,5 @@ def save_application_knowledge_mapping(application_knowledge_id_list, knowledge_ application_id=application_id).delete() # 插入 QuerySet(ApplicationKnowledgeMapping).bulk_create( - [ApplicationKnowledgeMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in + [ApplicationKnowledgeMapping(application_id=application_id, knowledge_id=knowledge_id) for knowledge_id in knowledge_id_list]) if len(knowledge_id_list) > 0 else None diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index 29d720d09de..acb20b0a49c 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -98,7 +98,7 @@ def to_base_pipeline_manage_params(self): self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, 'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding', 'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting), - 'user_id': self.application.user_id, + 'workspace_id': self.application.workspace_id, 'application_id': self.application.id } diff --git a/apps/application/views/application.py b/apps/application/views/application.py index ee8a8cd52a9..2bb72d7ab94 100644 --- a/apps/application/views/application.py +++ b/apps/application/views/application.py @@ -130,6 +130,7 @@ class Export(APIView): def post(self, request: Request, workspace_id: str, application_id: str): return ApplicationOperateSerializer( data={'application_id': application_id, + 'workspace_id': workspace_id, 'user_id': request.user.id}).export(request.data) class Operate(APIView): @@ -148,11 +149,12 @@ class Operate(APIView): RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) @log(menu='Application', operate='Deleting application', get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')), - + ) def delete(self, request: Request, workspace_id: str, application_id: str): return result.success(ApplicationOperateSerializer( - data={'application_id': application_id, 'user_id': request.user.id}).delete( + data={'application_id': application_id, 'user_id': request.user.id, + 'workspace_id': workspace_id, }).delete( with_valid=True)) @extend_schema( @@ -173,7 +175,8 @@ def delete(self, request: Request, workspace_id: str, application_id: str): def put(self, request: Request, workspace_id: str, application_id: str): return result.success( ApplicationOperateSerializer( - data={'application_id': application_id, 'user_id': request.user.id}).edit( + data={'application_id': application_id, 'user_id': request.user.id, + 'workspace_id': workspace_id, }).edit( request.data)) @extend_schema( @@ -190,7 +193,8 @@ def put(self, request: Request, workspace_id: str, application_id: str): RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.ADMIN) def get(self, request: Request, workspace_id: str, application_id: str): return result.success(ApplicationOperateSerializer( - data={'application_id': application_id, 'user_id': request.user.id}).one()) + data={'application_id': application_id, 'user_id': request.user.id, + 'workspace_id': workspace_id, }).one()) class Publish(APIView): authentication_classes = [TokenAuth] @@ -207,9 +211,10 @@ class Publish(APIView): ) @log(menu='Application', operate='Publishing an application', get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')), - + ) def put(self, request: Request, workspace_id: str, application_id: str): return result.success( ApplicationOperateSerializer( - data={'application_id': application_id, 'user_id': request.user.id}).publish(request.data)) + data={'application_id': application_id, 'user_id': request.user.id, + 'workspace_id': workspace_id, }).publish(request.data)) diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index c7eacabc4f1..33ad6c5aaca 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -366,7 +366,7 @@ def open_simple(self, application): chat_user_id = self.data.get("chat_user_id") chat_user_type = self.data.get("chat_user_type") debug = self.data.get("debug") - knowledge_id_list = [str(row.dataset_id) for row in + knowledge_id_list = [str(row.knowledge_id) for row in QuerySet(ApplicationKnowledgeMapping).filter( application_id=application_id)] chat_id = str(uuid.uuid7()) diff --git a/apps/models_provider/tools.py b/apps/models_provider/tools.py index 7d58aac760d..3c28a7a66ce 100644 --- a/apps/models_provider/tools.py +++ b/apps/models_provider/tools.py @@ -103,21 +103,25 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict raise_exception) -def get_model_by_id(_id, user_id): +def get_model_by_id(_id, workspace_id): model = QuerySet(Model).filter(id=_id).first() # 手动关闭数据库连接 connection.close() if model is None: raise Exception(_('Model does not exist')) + if model.workspace_id: + if model.workspace_id != workspace_id: + raise Exception(_('Model does not exist')) + return model -def get_model_instance_by_model_user_id(model_id, user_id, **kwargs): +def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): """ 获取模型实例,根据模型相关数据 - @param model_id: 模型id - @param user_id: 用户id - @return: 模型实例 + @param model_id: 模型id + @param workspace_id: 工作空间id + @return: 模型实例 """ - model = get_model_by_id(model_id, user_id) + model = get_model_by_id(model_id, workspace_id) return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))