mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: 修复模型没权限使用时报错 (#825)
This commit is contained in:
parent
fb5ad9a06b
commit
a404a5c6e9
|
|
@ -53,8 +53,7 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
# 对话列表
|
||||
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
||||
error_messages=ErrMessage.list("对话列表"))
|
||||
# 大语言模型
|
||||
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
|
||||
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
|
||||
# 对话id
|
||||
|
|
@ -73,6 +72,8 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
# 未查询到引用分段
|
||||
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
message_list: List = self.initial_data.get('message_list')
|
||||
|
|
@ -91,7 +92,8 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
def execute(self, message_list: List[BaseMessage],
|
||||
chat_id, problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
model_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, Post
|
|||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.response import result
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def add_access_num(client_id=None, client_type=None):
|
||||
|
|
@ -101,7 +102,8 @@ class BaseChatStep(IChatStep):
|
|||
chat_id,
|
||||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
model_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
|
|
@ -109,6 +111,7 @@ class BaseChatStep(IChatStep):
|
|||
client_id=None, client_type=None,
|
||||
no_references_setting=None,
|
||||
**kwargs):
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, user_id)
|
||||
if stream:
|
||||
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class IChatNode(INode):
|
|||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,21 +6,17 @@
|
|||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
|
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
|
|||
class BaseChatNode(IChatNode):
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
|
|
|
|||
|
|
@ -6,21 +6,17 @@
|
|||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
|
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
|
|||
class BaseQuestionNode(IQuestionNode):
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
|
|
|
|||
|
|
@ -13,25 +13,15 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Document, Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
raise Exception(f"无权限使用此模型:{model.name}")
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
|
|
@ -55,8 +45,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
if len(dataset_id_list) == 0:
|
||||
return get_none_result(question)
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_value = embedding_model.embed_query(question)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
|
|
|
|||
|
|
@ -639,10 +639,11 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
application.model_id = None
|
||||
else:
|
||||
model = QuerySet(Model).filter(
|
||||
id=instance.get('model_id'),
|
||||
user_id=application.user_id).first()
|
||||
id=instance.get('model_id')).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
if not model.is_permission(application.user_id):
|
||||
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
|
||||
if 'work_flow' in instance:
|
||||
# 当前用户可修改关联的知识库列表
|
||||
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
|
||||
|
|
|
|||
|
|
@ -267,14 +267,6 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
|
||||
@staticmethod
|
||||
def re_open_chat_simple(chat_id, application):
|
||||
model = QuerySet(Model).filter(id=application.model_id).first()
|
||||
chat_model = None
|
||||
if model is not None:
|
||||
# 对话模型
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
# 数据集id列表
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
|
|
@ -285,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
|
||||
return ChatInfo(chat_id, None, dataset_id_list, exclude_document_id_list, application)
|
||||
|
||||
@staticmethod
|
||||
def re_open_chat_work_flow(chat_id, application):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
@desc:
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
|
|
@ -38,13 +37,11 @@ from common.util.common import post
|
|||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
chat_cache = caches['model_cache']
|
||||
|
|
@ -238,16 +235,12 @@ class ChatSerializers(serializers.Serializer):
|
|||
|
||||
def open_simple(self, application):
|
||||
application_id = self.data.get('application_id')
|
||||
model = QuerySet(Model).filter(id=application.model_id).first()
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
application_id=application_id)]
|
||||
chat_model = None
|
||||
if model is not None:
|
||||
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
|
||||
chat_id = str(uuid.uuid1())
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
ChatInfo(chat_id, None, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
|
|
@ -318,24 +311,14 @@ class ChatSerializers(serializers.Serializer):
|
|||
user_id = self.is_valid(raise_exception=True)
|
||||
chat_id = str(uuid.uuid1())
|
||||
model_id = self.data.get('model_id')
|
||||
if model_id is not None and len(model_id) > 0:
|
||||
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
else:
|
||||
model = None
|
||||
chat_model = None
|
||||
dataset_id_list = self.data.get('dataset_id_list')
|
||||
application = Application(id=None, dialogue_number=3, model=model,
|
||||
application = Application(id=None, dialogue_number=3, model_id=model_id,
|
||||
dataset_setting=self.data.get('dataset_setting'),
|
||||
model_setting=self.data.get('model_setting'),
|
||||
problem_optimization=self.data.get('problem_optimization'),
|
||||
user_id=user_id)
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
ChatInfo(chat_id, None, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
|
|
|
|||
|
|
@ -56,6 +56,11 @@ class Model(AppModelMixin):
|
|||
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
|
||||
default=PermissionType.PRIVATE)
|
||||
|
||||
def is_permission(self, user_id):
|
||||
if self.permission_type == PermissionType.PRIVATE and str(user_id) == str(self.user_id):
|
||||
return False
|
||||
return True
|
||||
|
||||
class Meta:
|
||||
db_table = "model"
|
||||
unique_together = ['name', 'user_id']
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: tools.py
|
||||
@date:2024/7/22 11:18
|
||||
@desc:
|
||||
"""
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.config.embedding_config import ModelManage
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
raise Exception(f"无权限使用此模型:{model.name}")
|
||||
return model
|
||||
|
||||
|
||||
def get_model_instance_by_model_user_id(model_id, user_id):
|
||||
"""
|
||||
获取模型实例,根据模型相关数据
|
||||
@param model_id: 模型id
|
||||
@param user_id: 用户id
|
||||
@return: 模型实例
|
||||
"""
|
||||
model = get_model_by_id(model_id, user_id)
|
||||
return ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
Loading…
Reference in New Issue