From a404a5c6e97b70702b22aaa9c3ae2a9e33fe3906 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Mon, 22 Jul 2024 11:44:38 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=A8=A1=E5=9E=8B?=
 =?UTF-8?q?=E6=B2=A1=E6=9D=83=E9=99=90=E4=BD=BF=E7=94=A8=E6=97=B6=E6=8A=A5?=
 =?UTF-8?q?=E9=94=99=20(#825)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../step/chat_step/i_chat_step.py             |  8 +++--
 .../step/chat_step/impl/base_chat_step.py     |  5 ++-
 apps/application/flow/i_step_node.py          |  2 +-
 .../ai_chat_step_node/i_chat_node.py          |  3 +-
 .../ai_chat_step_node/impl/base_chat_node.py  | 14 ++------
 .../question_node/impl/base_question_node.py  | 14 ++------
 .../impl/base_search_dataset_node.py          | 17 ++--------
 .../serializers/application_serializers.py    |  5 +--
 .../serializers/chat_message_serializers.py   | 10 +-----
 .../serializers/chat_serializers.py           | 23 ++-----------
 apps/setting/models/model_management.py       |  5 +++
 apps/setting/models_provider/tools.py         | 33 +++++++++++++++++++
 12 files changed, 64 insertions(+), 75 deletions(-)
 create mode 100644 apps/setting/models_provider/tools.py

diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index 8fbac34c9..18b08ea7a 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -53,8 +53,7 @@ class IChatStep(IBaseChatPipelineStep):
         # 对话列表
         message_list = serializers.ListField(required=True, child=MessageField(required=True),
                                              error_messages=ErrMessage.list("对话列表"))
-        # 大语言模型
-        chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
+        model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
         # 段落列表
         paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
         # 对话id
@@ -73,6 +72,8 @@ class IChatStep(IBaseChatPipelineStep):
         # 未查询到引用分段
         no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
 
+        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
+
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
         message_list: List = self.initial_data.get('message_list')
@@ -91,7 +92,8 @@ class IChatStep(IBaseChatPipelineStep):
 
     def execute(self, message_list: List[BaseMessage], chat_id, problem_text,
                 post_response_handler: PostResponseHandler,
-                chat_model: BaseChatModel = None,
+                model_id: str = None,
+                user_id: str = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
                 padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index f7dbe5835..750335e2d 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -26,6 +26,7 @@ from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, Post
 from application.models.api_key_model import ApplicationPublicAccessClient
 from common.constants.authentication_type import AuthenticationType
 from common.response import result
+from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
 def add_access_num(client_id=None, client_type=None):
@@ -101,7 +102,8 @@ class BaseChatStep(IChatStep):
     def execute(self, message_list: List[BaseMessage],
                 chat_id, problem_text,
                 post_response_handler: PostResponseHandler,
-                chat_model: BaseChatModel = None,
+                model_id: str = None,
+                user_id: str = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
                 padding_problem_text: str = None,
@@ -109,6 +111,7 @@ class BaseChatStep(IChatStep):
                 client_id=None, client_type=None,
                 no_references_setting=None,
                 **kwargs):
+        chat_model = get_model_instance_by_model_user_id(model_id, user_id)
         if stream:
             return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                        paragraph_list,
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index 3a558ace0..0d741bd39 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -111,7 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
 
     client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
 
-    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
+    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
 
     re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
 
diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
index d0dfbaef9..3b941922f 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
@@ -32,6 +32,7 @@ class IChatNode(INode):
     def _run(self):
         return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
 
-    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
+                chat_record_id,
                 **kwargs) -> NodeResult:
         pass
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
index 4443d851b..d8d087b04 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -6,21 +6,17 @@
     @date:2024/6/4 14:30
     @desc:
 """
-import json
 import time
 from functools import reduce
 from typing import List, Dict
 
-from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
 
 from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
-from common.util.rsa_util import rsa_long_decrypt
-from setting.models import Model
-from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
 class BaseChatNode(IChatNode):
     def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                 **kwargs) -> NodeResult:
-        model = QuerySet(Model).filter(id=model_id).first()
-        if model is None:
-            raise Exception("模型不存在")
-        chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
-                                                                           json.loads(
-                                                                               rsa_long_decrypt(model.credential)),
-                                                                           streaming=True)
+        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
         history_message = self.get_history_message(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py
index d8257b138..f5855361e 100644
--- a/apps/application/flow/step_node/question_node/impl/base_question_node.py
+++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py
@@ -6,21 +6,17 @@
     @date:2024/6/4 14:30
     @desc:
 """
-import json
 import time
 from functools import reduce
 from typing import List, Dict
 
-from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
 
 from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.question_node.i_question_node import IQuestionNode
-from common.util.rsa_util import rsa_long_decrypt
-from setting.models import Model
-from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
 class BaseQuestionNode(IQuestionNode):
     def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                 **kwargs) -> NodeResult:
-        model = QuerySet(Model).filter(id=model_id).first()
-        if model is None:
-            raise Exception("模型不存在")
-        chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
-                                                                           json.loads(
-                                                                               rsa_long_decrypt(model.credential)),
-                                                                           streaming=True)
+        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
         history_message = self.get_history_message(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
index 61634a371..60bec0a4a 100644
--- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
+++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
@@ -13,25 +13,15 @@
 from django.db.models import QuerySet
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
-from common.config.embedding_config import VectorStore, ModelManage
+from common.config.embedding_config import VectorStore
 from common.db.search import native_search
 from common.util.file_util import get_file_content
 from dataset.models import Document, Paragraph, DataSet
 from embedding.models import SearchMode
-from setting.models import Model
-from setting.models_provider import get_model
+from setting.models_provider.tools import get_model_instance_by_model_user_id
 from smartdoc.conf import PROJECT_DIR
 
 
-def get_model_by_id(_id, user_id):
-    model = QuerySet(Model).filter(id=_id).first()
-    if model is None:
-        raise Exception("模型不存在")
-    if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
-        raise Exception(f"无权限使用此模型:{model.name}")
-    return model
-
-
 def get_embedding_id(dataset_id_list):
     dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
     if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
@@ -55,8 +45,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
         if len(dataset_id_list) == 0:
             return get_none_result(question)
         model_id = get_embedding_id(dataset_id_list)
-        model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
-        embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
+        embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
         embedding_value = embedding_model.embed_query(question)
         vector = VectorStore.get_embedding_vector()
         exclude_document_id_list = [str(document.id) for document in
diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index e13c7219b..7e85ffc4a 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -639,10 +639,11 @@ class ApplicationSerializer(serializers.Serializer):
                 application.model_id = None
             else:
                 model = QuerySet(Model).filter(
-                    id=instance.get('model_id'),
-                    user_id=application.user_id).first()
+                    id=instance.get('model_id')).first()
                 if model is None:
                     raise AppApiException(500, "模型不存在")
+                if not model.is_permission(application.user_id):
+                    raise AppApiException(500, f"没有权限使用该模型:{model.name}")
             if 'work_flow' in instance:
                 # 当前用户可修改关联的知识库列表
                 application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 5f2754918..ecfd30070 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -267,14 +267,6 @@ class ChatMessageSerializer(serializers.Serializer):
 
     @staticmethod
     def re_open_chat_simple(chat_id, application):
-        model = QuerySet(Model).filter(id=application.model_id).first()
-        chat_model = None
-        if model is not None:
-            # 对话模型
-            chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
-                                                                               json.loads(
-                                                                                   rsa_long_decrypt(model.credential)),
-                                                                               streaming=True)
         # 数据集id列表
         dataset_id_list = [str(row.dataset_id) for row in
                            QuerySet(ApplicationDatasetMapping).filter(
@@ -285,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
                                QuerySet(Document).filter(
                                    dataset_id__in=dataset_id_list,
                                    is_active=False)]
-        return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
+        return ChatInfo(chat_id, None, dataset_id_list, exclude_document_id_list, application)
 
     @staticmethod
     def re_open_chat_work_flow(chat_id, application):
diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py
index 7e4018fb9..0398fac13 100644
--- a/apps/application/serializers/chat_serializers.py
+++ b/apps/application/serializers/chat_serializers.py
@@ -7,7 +7,6 @@
     @desc:
 """
 import datetime
-import json
 import os
 import re
 import uuid
@@ -38,13 +37,11 @@ from common.util.common import post
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.lock import try_lock, un_lock
-from common.util.rsa_util import rsa_long_decrypt
 from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
 from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
 from dataset.serializers.paragraph_serializers import ParagraphSerializers
 from setting.models import Model
 from setting.models_provider import get_model
-from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
 from smartdoc.conf import PROJECT_DIR
 
 chat_cache = caches['model_cache']
@@ -238,16 +235,12 @@ class ChatSerializers(serializers.Serializer):
 
         def open_simple(self, application):
             application_id = self.data.get('application_id')
-            model = QuerySet(Model).filter(id=application.model_id).first()
             dataset_id_list = [str(row.dataset_id) for row in
                                QuerySet(ApplicationDatasetMapping).filter(
                                    application_id=application_id)]
-            chat_model = None
-            if model is not None:
-                chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
             chat_id = str(uuid.uuid1())
             chat_cache.set(chat_id,
-                           ChatInfo(chat_id, chat_model, dataset_id_list,
+                           ChatInfo(chat_id, None, dataset_id_list,
                                     [str(document.id) for document in
                                      QuerySet(Document).filter(
                                          dataset_id__in=dataset_id_list,
@@ -318,24 +311,14 @@ class ChatSerializers(serializers.Serializer):
             user_id = self.is_valid(raise_exception=True)
             chat_id = str(uuid.uuid1())
             model_id = self.data.get('model_id')
-            if model_id is not None and len(model_id) > 0:
-                model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
-                chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
-                                                                                   json.loads(
-                                                                                       rsa_long_decrypt(
-                                                                                           model.credential)),
-                                                                                   streaming=True)
-            else:
-                model = None
-                chat_model = None
             dataset_id_list = self.data.get('dataset_id_list')
-            application = Application(id=None, dialogue_number=3, model=model,
+            application = Application(id=None, dialogue_number=3, model_id=model_id,
                                       dataset_setting=self.data.get('dataset_setting'),
                                       model_setting=self.data.get('model_setting'),
                                       problem_optimization=self.data.get('problem_optimization'),
                                       user_id=user_id)
             chat_cache.set(chat_id,
-                           ChatInfo(chat_id, chat_model, dataset_id_list,
+                           ChatInfo(chat_id, None, dataset_id_list,
                                     [str(document.id) for document in
                                      QuerySet(Document).filter(
                                          dataset_id__in=dataset_id_list,
diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py
index 4c47eadfc..e696e0630 100644
--- a/apps/setting/models/model_management.py
+++ b/apps/setting/models/model_management.py
@@ -56,6 +56,11 @@ class Model(AppModelMixin):
     permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
                                        default=PermissionType.PRIVATE)
 
+    def is_permission(self, user_id):
+        if self.permission_type == PermissionType.PRIVATE and str(user_id) != str(self.user_id):
+            return False
+        return True
+
     class Meta:
         db_table = "model"
         unique_together = ['name', 'user_id']
diff --git a/apps/setting/models_provider/tools.py b/apps/setting/models_provider/tools.py
new file mode 100644
index 000000000..293b9f7ef
--- /dev/null
+++ b/apps/setting/models_provider/tools.py
@@ -0,0 +1,33 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: tools.py
+    @date:2024/7/22 11:18
+    @desc:
+"""
+from django.db.models import QuerySet
+
+from common.config.embedding_config import ModelManage
+from setting.models import Model
+from setting.models_provider import get_model
+
+
+def get_model_by_id(_id, user_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("模型不存在")
+    if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
+        raise Exception(f"无权限使用此模型:{model.name}")
+    return model
+
+
+def get_model_instance_by_model_user_id(model_id, user_id):
+    """
+    获取模型实例,根据模型相关数据
+    @param model_id: 模型id
+    @param user_id: 用户id
+    @return: 模型实例
+    """
+    model = get_model_by_id(model_id, user_id)
+    return ModelManage.get_model(model_id, lambda _id: get_model(model))
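
For readers skimming the patch, the sketch below shows how the relocated permission check is expected to be exercised after this change: pipeline steps now carry only model_id and user_id, and the model instance is resolved lazily at execution time, so a permission failure surfaces as an exception from get_model_instance_by_model_user_id instead of a serializer error. Only get_model_instance_by_model_user_id (added in apps/setting/models_provider/tools.py above) comes from this patch; the run_chat_step wrapper and the .stream() call are hypothetical simplifications, assuming a Django environment where the MaxKB apps are importable.

    # Minimal usage sketch (assumptions noted above; not part of the patch itself).
    from setting.models_provider.tools import get_model_instance_by_model_user_id


    def run_chat_step(model_id: str, user_id: str, message_list):
        # Resolve the model lazily; raises "模型不存在" for an unknown id and
        # "无权限使用此模型:..." when a PRIVATE model belongs to another user.
        chat_model = get_model_instance_by_model_user_id(model_id, user_id)
        # Hypothetical call: LangChain chat models expose a streaming Runnable interface.
        return chat_model.stream(message_list)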