From aeff8ab0b49163201a526aeba78d2ba3de07dcf3 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 18 Jul 2024 19:11:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 2 ++ .../search_dataset_node/impl/base_search_dataset_node.py | 9 ++++++--- apps/application/serializers/chat_message_serializers.py | 4 +++- apps/application/serializers/chat_serializers.py | 7 +++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 0aa620e26..2107f6f8f 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -111,6 +111,8 @@ class FlowParamsSerializer(serializers.Serializer): client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) + user_id = serializers.CharField(required=True, error_messages=ErrMessage.char("用户id")) + class INode: def __init__(self, node, workflow_params, workflow_manage): diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 1a30a29ea..1d315291a 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -23,11 +23,13 @@ from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR -def get_model_by_id(_id): +def get_model_by_id(_id, user_id): model = QuerySet(Model).filter(id=_id).first() if model is None: raise Exception("模型不存在") - return get_model(model) + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") + return model def get_embedding_id(dataset_id_list): @@ -53,7 +55,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): if len(dataset_id_list) == 0: return get_none_result(question) model_id = get_embedding_id(dataset_id_list) - embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id) + model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id')) + embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() exclude_document_id_list = [str(document.id) for document in diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 41f19bc0d..eb5b2fe76 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -221,11 +221,13 @@ class ChatMessageSerializer(serializers.Serializer): stream = self.data.get('stream') client_id = self.data.get('client_id') client_type = self.data.get('client_type') + user_id = chat_info.application.user_id work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': chat_info.chat_record_list, 'question': message, 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), 'stream': stream, - 're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type)) + 're_chat': re_chat, + 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type)) r = work_flow_manage.run() return r diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 48b0ca9ff..6c8d090d6 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -260,6 +260,7 @@ class ChatSerializers(serializers.Serializer): class OpenWorkFlowChat(serializers.Serializer): work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def open(self): self.is_valid(raise_exception=True) @@ -270,7 +271,8 @@ class ChatSerializers(serializers.Serializer): dataset_setting={}, model_setting={}, problem_optimization=None, - type=ApplicationTypeChoices.WORK_FLOW + type=ApplicationTypeChoices.WORK_FLOW, + user_id=self.data.get('user_id') ) work_flow_version = WorkFlowVersion(work_flow=work_flow) chat_cache.set(chat_id, @@ -333,7 +335,8 @@ class ChatSerializers(serializers.Serializer): application = Application(id=None, dialogue_number=3, model=model, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), - problem_optimization=self.data.get('problem_optimization')) + problem_optimization=self.data.get('problem_optimization'), + user_id=user_id) chat_cache.set(chat_id, ChatInfo(chat_id, chat_model, dataset_id_list, [str(document.id) for document in