fix: 修复模型没权限使用时报错 (#825)

This commit is contained in:
shaohuzhang1 2024-07-22 11:44:38 +08:00 committed by GitHub
parent fb5ad9a06b
commit a404a5c6e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 64 additions and 75 deletions

View File

@ -53,8 +53,7 @@ class IChatStep(IBaseChatPipelineStep):
# 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表"))
# 大语言模型
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
# 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id
@ -73,6 +72,8 @@ class IChatStep(IBaseChatPipelineStep):
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
message_list: List = self.initial_data.get('message_list')
@ -91,7 +92,8 @@ class IChatStep(IBaseChatPipelineStep):
def execute(self, message_list: List[BaseMessage],
chat_id, problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
model_id: str = None,
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,

View File

@ -26,6 +26,7 @@ from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, Post
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.response import result
from setting.models_provider.tools import get_model_instance_by_model_user_id
def add_access_num(client_id=None, client_type=None):
@ -101,7 +102,8 @@ class BaseChatStep(IChatStep):
chat_id,
problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
model_id: str = None,
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
@ -109,6 +111,7 @@ class BaseChatStep(IChatStep):
client_id=None, client_type=None,
no_references_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id)
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,

View File

@ -111,7 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))

View File

@ -32,6 +32,7 @@ class IChatNode(INode):
def _run(self):
    # Entry point called by the workflow engine: forwards both the
    # node-level parameters and the flow-level parameters (chat_id,
    # user_id, stream, ...) to execute() as keyword arguments.
    return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
chat_record_id,
**kwargs) -> NodeResult:
pass

View File

@ -6,21 +6,17 @@
@date2024/6/4 14:30
@desc:
"""
import json
import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from common.util.rsa_util import rsa_long_decrypt
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult:
model = QuerySet(Model).filter(id=model_id).first()
if model is None:
raise Exception("模型不存在")
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -6,21 +6,17 @@
@date2024/6/4 14:30
@desc:
"""
import json
import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from common.util.rsa_util import rsa_long_decrypt
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult:
model = QuerySet(Model).filter(id=model_id).first()
if model is None:
raise Exception("模型不存在")
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -13,25 +13,15 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore, ModelManage
from common.config.embedding_config import VectorStore
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.tools import get_model_instance_by_model_user_id
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
    """Fetch the Model row with id ``_id`` and check ``user_id`` may use it.

    Raises Exception when the model does not exist, or when it is
    PRIVATE and owned by a different user. Returns the Model row.
    """
    model = QuerySet(Model).filter(id=_id).first()
    if model is None:
        raise Exception("模型不存在")
    # PRIVATE models are usable only by their owner; ids are compared
    # as strings because user_id may arrive as UUID or str.
    if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
        raise Exception(f"无权限使用此模型:{model.name}")
    return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
@ -55,8 +45,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in

View File

@ -639,10 +639,11 @@ class ApplicationSerializer(serializers.Serializer):
application.model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('model_id'),
user_id=application.user_id).first()
id=instance.get('model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if 'work_flow' in instance:
# 当前用户可修改关联的知识库列表
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in

View File

@ -267,14 +267,6 @@ class ChatMessageSerializer(serializers.Serializer):
@staticmethod
def re_open_chat_simple(chat_id, application):
model = QuerySet(Model).filter(id=application.model_id).first()
chat_model = None
if model is not None:
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
# 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
@ -285,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
return ChatInfo(chat_id, None, dataset_id_list, exclude_document_id_list, application)
@staticmethod
def re_open_chat_work_flow(chat_id, application):

View File

@ -7,7 +7,6 @@
@desc:
"""
import datetime
import json
import os
import re
import uuid
@ -38,13 +37,11 @@ from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
chat_cache = caches['model_cache']
@ -238,16 +235,12 @@ class ChatSerializers(serializers.Serializer):
def open_simple(self, application):
application_id = self.data.get('application_id')
model = QuerySet(Model).filter(id=application.model_id).first()
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
application_id=application_id)]
chat_model = None
if model is not None:
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
ChatInfo(chat_id, None, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
@ -318,24 +311,14 @@ class ChatSerializers(serializers.Serializer):
user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model_id = self.data.get('model_id')
if model_id is not None and len(model_id) > 0:
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list')
application = Application(id=None, dialogue_number=3, model=model,
application = Application(id=None, dialogue_number=3, model_id=model_id,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'),
user_id=user_id)
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
ChatInfo(chat_id, None, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,

View File

@ -56,6 +56,11 @@ class Model(AppModelMixin):
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
default=PermissionType.PRIVATE)
def is_permission(self, user_id):
    """Return True when ``user_id`` is allowed to use this model.

    A PRIVATE model may only be used by its owner; every other
    permission type (e.g. PUBLIC) is usable by anyone.

    :param user_id: id of the requesting user (UUID or str)
    :return: bool — True if access is permitted
    """
    # Bug fix: the original compared with '==', which denied the
    # OWNER of a private model and granted everyone else. A private
    # model must be rejected only for NON-owners, hence '!='.
    # Ids are normalized to str since user_id may be a UUID.
    if self.permission_type == PermissionType.PRIVATE and str(user_id) != str(self.user_id):
        return False
    return True
class Meta:
db_table = "model"
unique_together = ['name', 'user_id']

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tools.py
@date2024/7/22 11:18
@desc:
"""
from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from setting.models import Model
from setting.models_provider import get_model
def get_model_by_id(_id, user_id):
    """Look up a Model row by id, enforcing the access rule for ``user_id``.

    Raises Exception if no such model exists, or if the model is
    PRIVATE and belongs to someone other than ``user_id``.
    Returns the Model row on success.
    """
    found = QuerySet(Model).filter(id=_id).first()
    if found is None:
        raise Exception("模型不存在")
    # A private model is reserved for its owner; compare ids as
    # strings since user_id may be a UUID object.
    owner_mismatch = str(found.user_id) != str(user_id)
    if found.permission_type == 'PRIVATE' and owner_mismatch:
        raise Exception(f"无权限使用此模型:{found.name}")
    return found
def get_model_instance_by_model_user_id(model_id, user_id):
    """
    Resolve a (possibly cached) model instance for the given model/user pair.

    First validates existence and permission via get_model_by_id,
    then asks ModelManage for an instance keyed by model_id.
    @param model_id: model id
    @param user_id:  user id
    @return: model instance
    """
    checked_model = get_model_by_id(model_id, user_id)
    # The factory closes over the row we already fetched and ignores
    # the cache key ModelManage passes to it.
    build = lambda _key: get_model(checked_model)
    return ModelManage.get_model(model_id, build)