From b14a799350f284e200c7febed78caba9725a044b Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Fri, 19 Jul 2024 10:34:47 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../i_search_dataset_step.py | 3 +++ .../impl/base_search_dataset_step.py | 11 +++++++---- .../impl/base_search_dataset_node.py | 4 ++-- .../serializers/chat_message_serializers.py | 4 +++- .../serializers/chat_serializers.py | 9 +++------ apps/common/config/embedding_config.py | 19 ++++++++++++------- .../dataset/serializers/common_serializers.py | 8 ++++---- 7 files changed, 34 insertions(+), 24 deletions(-) diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 549abfaf2..97da29643 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), message="类型只支持register|reset_password", code=500) ], error_messages=ErrMessage.char("检索模式")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer @@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, + user_id=None, **kwargs) -> List[ParagraphPipelineModel]: """ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 @@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): :param exclude_paragraph_id_list: 需要排除段落id :param padding_problem_text 补全问题 :param search_mode 检索模式 + :param user_id 用户id :return: 段落列表 """ pass diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index d893237da..0d43c3b2a 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -13,7 +13,7 @@ from django.db.models import QuerySet from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep -from common.config.embedding_config import VectorStore, EmbeddingModelManage +from common.config.embedding_config import VectorStore, ModelManage from common.db.search import native_search from common.util.file_util import get_file_content from dataset.models import Paragraph, DataSet @@ -23,10 +23,12 @@ from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR -def get_model_by_id(_id): +def get_model_by_id(_id, user_id): model = QuerySet(Model).filter(id=_id).first() if model is None: raise Exception("模型不存在") + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") return model @@ -44,14 +46,15 @@ class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, + user_id=None, **kwargs) -> List[ParagraphPipelineModel]: if len(dataset_id_list) == 0: return [] exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text model_id = get_embedding_id(dataset_id_list) - model = get_model_by_id(model_id) + model = get_model_by_id(model_id, user_id) self.context['model_name'] = model.name - embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model)) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index e6b70749a..61634a371 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -13,7 +13,7 @@ from django.db.models import QuerySet from application.flow.i_step_node import NodeResult from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode -from common.config.embedding_config import VectorStore, EmbeddingModelManage +from common.config.embedding_config import VectorStore, ModelManage from common.db.search import native_search from common.util.file_util import get_file_content from dataset.models import Document, Paragraph, DataSet @@ -56,7 +56,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): return get_none_result(question) model_id = get_embedding_id(dataset_id_list) model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id')) - embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model)) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() exclude_document_id_list = [str(document.id) for document in diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index eb5b2fe76..5f2754918 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -88,7 +88,9 @@ class ChatInfo: 'no_references_setting': self.application.dataset_setting.get( 'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else { 'status': 'ai_questioning', - 'value': '{question}'} + 'value': '{question}', + }, + 'user_id': self.application.user_id } diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 6c8d090d6..7e4018fb9 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo +from common.config.embedding_config import ModelManage from common.constants.permission_constants import RoleConstants from common.db.search import native_search, native_page_search, page_search, get_dynamics_model from common.event import ListenerManagement @@ -42,6 +43,7 @@ from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model +from setting.models_provider import get_model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from smartdoc.conf import PROJECT_DIR @@ -242,12 +244,7 @@ class ChatSerializers(serializers.Serializer): application_id=application_id)] chat_model = None if model is not None: - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - rsa_long_decrypt( - model.credential)), - streaming=True) - + chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model)) chat_id = str(uuid.uuid1()) chat_cache.set(chat_id, ChatInfo(chat_id, chat_model, dataset_id_list, diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index e784c16a1..b33dd83a1 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -11,26 +11,31 @@ import time from common.cache.mem_cache import MemCache -class EmbeddingModelManage: +class ModelManage: cache = MemCache('model', {}) up_clear_time = time.time() @staticmethod def get_model(_id, get_model): - model_instance = EmbeddingModelManage.cache.get(_id) + model_instance = ModelManage.cache.get(_id) if model_instance is None: model_instance = get_model(_id) - EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30) + ModelManage.cache.set(_id, model_instance, timeout=60 * 30) return model_instance # 续期 - EmbeddingModelManage.cache.touch(_id, timeout=60 * 30) - EmbeddingModelManage.clear_timeout_cache() + ModelManage.cache.touch(_id, timeout=60 * 30) + ModelManage.clear_timeout_cache() return model_instance @staticmethod def clear_timeout_cache(): - if time.time() - EmbeddingModelManage.up_clear_time > 60: - EmbeddingModelManage.cache.clear_timeout_data() + if time.time() - ModelManage.up_clear_time > 60: + ModelManage.cache.clear_timeout_data() + + @staticmethod + def delete_key(_id): + if ModelManage.cache.has_key(_id): + ModelManage.cache.delete(_id) class VectorStore: diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 5c2eb853b..649d04de4 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -14,7 +14,7 @@ from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers -from common.config.embedding_config import EmbeddingModelManage +from common.config.embedding_config import ModelManage from common.db.search import native_search from common.db.sql_execute import update_execute from common.exception.app_exception import AppApiException @@ -140,14 +140,14 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List): raise Exception("知识库未向量模型不一致") if len(dataset_list) == 0: raise Exception("知识库设置错误,请重新设置知识库") - return EmbeddingModelManage.get_model(str(dataset_list[0].id), + return ModelManage.get_model(str(dataset_list[0].id), lambda _id: get_model(dataset_list[0].embedding_mode)) def get_embedding_model_by_dataset_id(dataset_id: str): dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() - return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode)) + return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode)) def get_embedding_model_by_dataset(dataset): - return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode)) + return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))