diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 549abfaf2..97da29643 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), message="类型只支持register|reset_password", code=500) ], error_messages=ErrMessage.char("检索模式")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer @@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, + user_id=None, **kwargs) -> List[ParagraphPipelineModel]: """ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 @@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): :param exclude_paragraph_id_list: 需要排除段落id :param padding_problem_text 补全问题 :param search_mode 检索模式 + :param user_id 用户id :return: 段落列表 """ pass diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 3dd9f8300..0d43c3b2a 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -13,22 +13,48 @@ from django.db.models import QuerySet from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore, ModelManage from common.db.search import native_search from common.util.file_util import get_file_content -from dataset.models import Paragraph +from dataset.models import Paragraph, DataSet from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR +def get_model_by_id(_id, user_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") + return model + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return dataset_list[0].embedding_mode_id + + class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, + user_id=None, **kwargs) -> List[ParagraphPipelineModel]: + if len(dataset_id_list) == 0: + return [] exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text - embedding_model = EmbeddingModel.get_embedding_model() + model_id = get_embedding_id(dataset_id_list) + model = get_model_by_id(model_id, user_id) + self.context['model_name'] = model.name + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, @@ -101,7 +127,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep): 'run_time': self.context['run_time'], 'problem_text': step_args.get( 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), - 'model_name': EmbeddingModel.get_embedding_model().model_name, + 'model_name': self.context.get('model_name'), 'message_tokens': 0, 'answer_tokens': 0, 'cost': 0 diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 98bc5dcd1..3a558ace0 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer): client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 191b8c15b..61634a371 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -13,20 +13,50 @@ from django.db.models import QuerySet from application.flow.i_step_node import NodeResult from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode -from common.config.embedding_config import EmbeddingModel, VectorStore +from common.config.embedding_config import VectorStore, ModelManage from common.db.search import native_search from common.util.file_util import get_file_content -from dataset.models import Document, Paragraph +from dataset.models import Document, Paragraph, DataSet from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR +def get_model_by_id(_id, user_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") + return model + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return dataset_list[0].embedding_mode_id + + +def get_none_result(question): + return NodeResult( + {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '', + 'directly_return': ''}, {}) + + class BaseSearchDatasetNode(ISearchDatasetStepNode): def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: self.context['question'] = question - embedding_model = EmbeddingModel.get_embedding_model() + if len(dataset_id_list) == 0: + return get_none_result(question) + model_id = get_embedding_id(dataset_id_list) + model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id')) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() exclude_document_id_list = [str(document.id) for document in @@ -37,7 +67,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): exclude_paragraph_id_list, True, dataset_setting.get('top_n'), dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) if embedding_list is None: - return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {}) + return get_none_result(question) paragraph_list = self.list_paragraph(embedding_list, vector) result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] result = sorted(result, key=lambda p: p.get('similarity'), reverse=True) diff --git a/apps/application/migrations/0010_alter_chatrecord_details.py b/apps/application/migrations/0010_alter_chatrecord_details.py new file mode 100644 index 000000000..e46278009 --- /dev/null +++ b/apps/application/migrations/0010_alter_chatrecord_details.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:52 + +import application.models.application +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0009_application_type_application_work_flow_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='details', + field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'), + ), + ] diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index d689448e1..e13c7219b 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -26,7 +26,7 @@ from rest_framework import serializers from application.flow.workflow_manage import Flow from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list @@ -36,7 +36,7 @@ from common.util.common import valid_license from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document, Image -from dataset.serializers.common_serializers import list_paragraph +from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from embedding.models import SearchMode from setting.models import AuthOperate from setting.models.model_management import Model @@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer): QuerySet(Document).filter( dataset_id__in=dataset_id_list, is_active=False)] + model = get_embedding_model_by_dataset_id_list(dataset_id_list) # 向量库检索 hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), SearchMode(self.data.get('search_mode')), - EmbeddingModel.get_embedding_model()) + model) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), @@ -522,12 +523,14 @@ class ApplicationSerializer(serializers.Serializer): if not QuerySet(Application).filter(id=self.data.get('application_id')).exists(): raise AppApiException(500, '不存在的应用id') - def list_model(self, with_valid=True): + def list_model(self, model_type=None, with_valid=True): if with_valid: self.is_valid() + if model_type is None: + model_type = "LLM" application = QuerySet(Application).filter(id=self.data.get("application_id")).first() return ModelSerializer.Query( - data={'user_id': application.user_id}).list( + data={'user_id': application.user_id, 'model_type': model_type}).list( with_valid=True) def delete(self, with_valid=True): diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 41f19bc0d..5f2754918 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -88,7 +88,9 @@ class ChatInfo: 'no_references_setting': self.application.dataset_setting.get( 'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else { 'status': 'ai_questioning', - 'value': '{question}'} + 'value': '{question}', + }, + 'user_id': self.application.user_id } @@ -221,11 +223,13 @@ class ChatMessageSerializer(serializers.Serializer): stream = self.data.get('stream') client_id = self.data.get('client_id') client_type = self.data.get('client_type') + user_id = chat_info.application.user_id work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': chat_info.chat_record_list, 'question': message, 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), 'stream': stream, - 're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type)) + 're_chat': re_chat, + 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type)) r = work_flow_manage.run() return r diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 402eff691..7e4018fb9 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo +from common.config.embedding_config import ModelManage from common.constants.permission_constants import RoleConstants from common.db.search import native_search, native_page_search, page_search, get_dynamics_model from common.event import ListenerManagement @@ -39,8 +40,10 @@ from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock from common.util.rsa_util import rsa_long_decrypt from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping +from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model +from setting.models_provider import get_model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from smartdoc.conf import PROJECT_DIR @@ -241,12 +244,7 @@ class ChatSerializers(serializers.Serializer): application_id=application_id)] chat_model = None if model is not None: - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - rsa_long_decrypt( - model.credential)), - streaming=True) - + chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model)) chat_id = str(uuid.uuid1()) chat_cache.set(chat_id, ChatInfo(chat_id, chat_model, dataset_id_list, @@ -259,6 +257,7 @@ class ChatSerializers(serializers.Serializer): class OpenWorkFlowChat(serializers.Serializer): work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def open(self): self.is_valid(raise_exception=True) @@ -269,7 +268,8 @@ class ChatSerializers(serializers.Serializer): dataset_setting={}, model_setting={}, problem_optimization=None, - type=ApplicationTypeChoices.WORK_FLOW + type=ApplicationTypeChoices.WORK_FLOW, + user_id=self.data.get('user_id') ) work_flow_version = WorkFlowVersion(work_flow=work_flow) chat_cache.set(chat_id, @@ -332,7 +332,8 @@ class ChatSerializers(serializers.Serializer): application = Application(id=None, dialogue_number=3, model=model, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), - problem_optimization=self.data.get('problem_optimization')) + problem_optimization=self.data.get('problem_optimization'), + user_id=user_id) chat_cache.set(chat_id, ChatInfo(chat_id, chat_model, dataset_id_list, [str(document.id) for document in @@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer): raise AppApiException(500, "文档id不正确") @staticmethod - def post_embedding_paragraph(chat_record, paragraph_id): + def post_embedding_paragraph(chat_record, paragraph_id, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id) + ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model) return chat_record @post(post_function=post_embedding_paragraph) @@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer): chat_record.improve_paragraph_id_list.append(paragraph.id) # 添加标注 chat_record.save() - return ChatRecordSerializerModel(chat_record).data, paragraph.id + return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 6e46931a6..e153f6279 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin): } ) + class Model(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='model_type', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='模型类型'), + ] + class ApiKey(ApiMixin): @staticmethod def get_request_params_api(): diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index c5fc165d3..28fe412ad 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -175,7 +175,7 @@ class Application(APIView): @swagger_auto_schema(operation_summary="获取模型列表", operation_id="获取模型列表", tags=["应用"], - manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + manual_parameters=ApplicationApi.Model.get_request_params_api()) @has_permissions(ViewPermission( [RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, @@ -185,7 +185,7 @@ class Application(APIView): return result.success( ApplicationSerializer.Operate( data={'application_id': application_id, - 'user_id': request.user.id}).list_model()) + 'user_id': request.user.id}).list_model(request.query_params.get('model_type'))) class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/common/cache/mem_cache.py b/apps/common/cache/mem_cache.py index 86a3ce040..5afb1e562 100644 --- a/apps/common/cache/mem_cache.py +++ b/apps/common/cache/mem_cache.py @@ -41,3 +41,7 @@ class MemCache(LocMemCache): delete_keys.append(key) for key in delete_keys: self._delete(key) + + def clear_timeout_data(self): + for key in self._cache.keys(): + self.get(key) diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index 4d652bf95..b33dd83a1 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -6,33 +6,36 @@ @date:2023/10/23 16:03 @desc: """ -from langchain_huggingface.embeddings import HuggingFaceEmbeddings +import time -from smartdoc.const import CONFIG +from common.cache.mem_cache import MemCache -class EmbeddingModel: - instance = None +class ModelManage: + cache = MemCache('model', {}) + up_clear_time = time.time() @staticmethod - def get_embedding_model(): - """ - 获取向量化模型 - :return: - """ - if EmbeddingModel.instance is None: - model_name = CONFIG.get('EMBEDDING_MODEL_NAME') - cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') - device = CONFIG.get('EMBEDDING_DEVICE') - encode_kwargs = {'normalize_embeddings': True} - e = HuggingFaceEmbeddings( - model_name=model_name, - cache_folder=cache_folder, - model_kwargs={'device': device}, - encode_kwargs=encode_kwargs, - ) - EmbeddingModel.instance = e - return EmbeddingModel.instance + def get_model(_id, get_model): + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 30) + return model_instance + # 续期 + ModelManage.cache.touch(_id, timeout=60 * 30) + ModelManage.clear_timeout_cache() + return model_instance + + @staticmethod + def clear_timeout_cache(): + if time.time() - ModelManage.up_clear_time > 60: + ModelManage.cache.clear_timeout_data() + + @staticmethod + def delete_key(_id): + if ModelManage.cache.has_key(_id): + ModelManage.cache.delete(_id) class VectorStore: diff --git a/apps/common/event/common.py b/apps/common/event/common.py index e35123758..9f4a945bf 100644 --- a/apps/common/event/common.py +++ b/apps/common/event/common.py @@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3) def poxy(poxy_function): - def inner(args): - work_thread_pool.submit(poxy_function, args) + def inner(args, **keywords): + work_thread_pool.submit(poxy_function, args, **keywords) return inner def embedding_poxy(poxy_function): - def inner(args): - embedding_thread_pool.submit(poxy_function, args) + def inner(args, **keywords): + embedding_thread_pool.submit(poxy_function, args, **keywords) return inner diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index dea53cb13..0de80acee 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -15,8 +15,9 @@ from typing import List import django.db.models from blinker import signal from django.db.models import QuerySet +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model from common.event.common import poxy, embedding_poxy from common.util.file_util import get_file_content @@ -46,22 +47,26 @@ class SyncWebDocumentArgs: class UpdateProblemArgs: - def __init__(self, problem_id: str, problem_content: str): + def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings): self.problem_id = problem_id self.problem_content = problem_content + self.embedding_model = embedding_model class UpdateEmbeddingDatasetIdArgs: - def __init__(self, paragraph_id_list: List[str], target_dataset_id: str): + def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings): self.paragraph_id_list = paragraph_id_list self.target_dataset_id = target_dataset_id + self.target_embedding_model = target_embedding_model class UpdateEmbeddingDocumentIdArgs: - def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str): + def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str, + target_embedding_model: Embeddings = None): self.paragraph_id_list = paragraph_id_list self.target_document_id = target_document_id self.target_dataset_id = target_dataset_id + self.target_embedding_model = target_embedding_model class ListenerManagement: @@ -84,16 +89,46 @@ class ListenerManagement: delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list") @staticmethod - def embedding_by_problem(args): - VectorStore.get_embedding_vector().save(**args) + def embedding_by_problem(args, embedding_model: Embeddings): + VectorStore.get_embedding_vector().save(**args, embedding=embedding_model) + + @staticmethod + def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings): + try: + data_list = native_search( + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id__in': paragraph_id_list}), + 'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list, + embedding_model=embedding_model) + except Exception as e: + max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') @staticmethod @embedding_poxy - def embedding_by_paragraph(paragraph_id): + def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings): + max_kb.info(f'开始--->向量化段落:{paragraph_id_list}') + try: + # 删除段落 + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list) + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) + except Exception as e: + max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') + status = Status.error + finally: + QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status}) + max_kb.info(f'结束--->向量化段落:{paragraph_id_list}') + + @staticmethod + @embedding_poxy + def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): """ 向量化段落 根据段落id - :param paragraph_id: 段落id - :return: None + @param paragraph_id: 段落id + @param embedding_model: 向量模型 """ max_kb.info(f"开始--->向量化段落:{paragraph_id}") status = Status.success @@ -107,7 +142,7 @@ class ListenerManagement: # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) except Exception as e: max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -117,12 +152,15 @@ class ListenerManagement: @staticmethod @embedding_poxy - def embedding_by_document(document_id): + def embedding_by_document(document_id, embedding_model: Embeddings): """ 向量化文档 - :param document_id: 文档id + @param document_id: 文档id + @param embedding_model 向量模型 :return: None """ + if not try_lock('embedding' + str(document_id)): + return max_kb.info(f"开始--->向量化文档:{document_id}") QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding}) @@ -138,7 +176,7 @@ class ListenerManagement: # 删除文档向量数据 VectorStore.get_embedding_vector().delete_by_document_id(document_id) # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) except Exception as e: max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -148,21 +186,24 @@ class ListenerManagement: **{'status': status, 'update_time': datetime.datetime.now()}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) max_kb.info(f"结束--->向量化文档:{document_id}") + un_lock('embedding' + str(document_id)) @staticmethod @embedding_poxy - def embedding_by_dataset(dataset_id): + def embedding_by_dataset(dataset_id, embedding_model: Embeddings): """ 向量化知识库 - :param dataset_id: 知识库id + @param dataset_id: 知识库id + @param embedding_model 向量模型 :return: None """ max_kb.info(f"开始--->向量化数据集:{dataset_id}") try: + ListenerManagement.delete_embedding_by_dataset(dataset_id) document_list = QuerySet(Document).filter(dataset_id=dataset_id) max_kb.info(f"数据集文档:{[d.name for d in document_list]}") for document in document_list: - ListenerManagement.embedding_by_document(document.id) + ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model) except Exception as e: max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') finally: @@ -224,14 +265,22 @@ class ListenerManagement: @staticmethod def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): - VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, - {'dataset_id': args.target_dataset_id}) + if args.target_embedding_model is None: + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'dataset_id': args.target_dataset_id}) + else: + ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, + embedding_model=args.target_embedding_model) @staticmethod def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): - VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, - {'document_id': args.target_document_id, - 'dataset_id': args.target_dataset_id}) + if args.target_embedding_model is None: + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'document_id': args.target_document_id, + 'dataset_id': args.target_dataset_id}) + else: + ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, + embedding_model=args.target_embedding_model) @staticmethod def delete_embedding_by_source_ids(source_ids: List[str]): @@ -245,11 +294,6 @@ class ListenerManagement: def delete_embedding_by_dataset_id_list(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) - @staticmethod - @poxy - def init_embedding_model(ages): - EmbeddingModel.get_embedding_model() - def run(self): # 添加向量 根据问题id ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem) @@ -276,8 +320,7 @@ class ListenerManagement: ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph) # 启动段落向量 ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) - # 初始化向量化模型 - ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model) + # 同步web站点知识库 ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) # 同步web站点 文档 diff --git a/apps/dataset/migrations/0006_dataset_embedding_mode.py b/apps/dataset/migrations/0006_dataset_embedding_mode.py new file mode 100644 index 000000000..2248d8e36 --- /dev/null +++ b/apps/dataset/migrations/0006_dataset_embedding_mode.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.13 on 2024-07-17 13:56 + +import dataset.models.data_set +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0005_model_permission_type'), + ('dataset', '0005_file'), + ] + + operations = [ + migrations.AddField( + model_name='dataset', + name='embedding_mode', + field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'), + ), + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index ca3b05e0c..a34854978 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -9,9 +9,11 @@ import uuid from django.db import models +from django.db.models import QuerySet from common.db.sql_execute import select_one from common.mixins.app_model_mixin import AppModelMixin +from setting.models import Model from users.models import User @@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices): directly_return = 'directly_return', '直接返回' +def default_model(): + return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') + + class DataSet(AppModelMixin): """ 数据集表 @@ -43,7 +49,8 @@ class DataSet(AppModelMixin): user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, default=Type.base) - + embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", + default=default_model) meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 099d1f24e..649d04de4 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -14,6 +14,7 @@ from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers +from common.config.embedding_config import ModelManage from common.db.search import native_search from common.db.sql_execute import update_execute from common.exception.app_exception import AppApiException @@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork -from dataset.models import Paragraph, Problem, ProblemParagraphMapping +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR @@ -130,3 +132,22 @@ class ProblemParagraphManage: result = [problem_model for problem_model, is_create in problem_content_dict.values() if is_create], problem_paragraph_mapping_list return result + + +def get_embedding_model_by_dataset_id_list(dataset_id_list: List): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return ModelManage.get_model(str(dataset_list[0].id), + lambda _id: get_model(dataset_list[0].embedding_mode)) + + +def get_embedding_model_by_dataset_id(dataset_id: str): + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() + return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode)) + + +def get_embedding_model_by_dataset(dataset): + return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode)) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 50a6a9b99..946398482 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -15,7 +15,6 @@ from functools import reduce from typing import Dict, List from urllib.parse import urlparse -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.core import validators from django.db import transaction, models @@ -25,7 +24,7 @@ from drf_yasg import openapi from rest_framework import serializers from application.models import ApplicationDatasetMapping -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list from common.event import ListenerManagement, SyncWebDatasetArgs @@ -37,7 +36,8 @@ from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping -from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage +from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_by_dataset_id from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from embedding.models import SearchMode from setting.models import AuthOperate @@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer): max_length=256, min_length=1) + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + documents = DocumentInstanceSerializer(required=False, many=True) def is_valid(self, *, raise_exception=False): @@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer): max_length=256, min_length=1) + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + file_list = serializers.ListSerializer(required=True, error_messages=ErrMessage.list("文件列表"), child=serializers.FileField(required=True, @@ -296,6 +300,8 @@ class DataSetSerializers(serializers.ModelSerializer): min_length=1) source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), ) + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("选择器")) @@ -347,6 +353,8 @@ class DataSetSerializers(serializers.ModelSerializer): properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + 'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id", + description="向量模型id"), 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"), 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") @@ -355,8 +363,9 @@ class DataSetSerializers(serializers.ModelSerializer): @staticmethod def post_embedding_dataset(document_list, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_dataset_signal.send(dataset_id) + ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model) return document_list def save_qa(self, instance: Dict, with_valid=True): @@ -365,7 +374,8 @@ class DataSetSerializers(serializers.ModelSerializer): self.CreateQASerializers(data=instance).is_valid() file_list = instance.get('file_list') document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list]) - dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list} + dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list, + 'embedding_mode_id': instance.get('embedding_mode_id')} return self.save(dataset_instance, with_valid=True) @valid_license(model=DataSet, count=50, @@ -381,7 +391,8 @@ class DataSetSerializers(serializers.ModelSerializer): if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists(): raise AppApiException(500, "知识库名称重复!") dataset = DataSet( - **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id}) + **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, + 'embedding_mode_id': instance.get('embedding_mode_id')}) document_model_list = [] paragraph_model_list = [] @@ -452,7 +463,8 @@ class DataSetSerializers(serializers.ModelSerializer): dataset = DataSet( **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, 'type': Type.web, - 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}}) + 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'), + 'embedding_mode_id': instance.get('embedding_mode_id')}}) dataset.save() ListenerManagement.sync_web_dataset_signal.send( SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'), @@ -500,6 +512,8 @@ class DataSetSerializers(serializers.ModelSerializer): properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + 'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型', + description='向量模型'), 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", items=DocumentSerializers().Create.get_request_body_api() ) @@ -557,12 +571,13 @@ class DataSetSerializers(serializers.ModelSerializer): QuerySet(Document).filter( dataset_id=self.data.get('id'), is_active=False)] + model = get_embedding_model_by_dataset_id(self.data.get('id')) # 向量库检索 hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), SearchMode(self.data.get('search_mode')), - EmbeddingModel.get_embedding_model()) + model) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), @@ -730,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer): def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id')) + model = get_embedding_model_by_dataset_id(self.data.get('id')) + ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model) def list_application(self, with_valid=True): if with_valid: @@ -769,6 +785,7 @@ class DataSetSerializers(serializers.ModelSerializer): QuerySet(ApplicationDatasetMapping).filter( dataset_id=self.data.get('id'))]))} + @transaction.atomic def edit(self, dataset: Dict, user_id: str): """ 修改知识库 @@ -782,6 +799,8 @@ class DataSetSerializers(serializers.ModelSerializer): raise AppApiException(500, "知识库名称重复!") _dataset = QuerySet(DataSet).get(id=self.data.get("id")) DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset) + if 'embedding_mode_id' in dataset: + _dataset.embedding_mode_id = dataset.get('embedding_mode_id') if "name" in dataset: _dataset.name = dataset.get("name") if 'desc' in dataset: diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 07a39b578..f89c12dc0 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -41,7 +41,8 @@ from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image -from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage +from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR @@ -234,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): meta={}) else: document_list.update(dataset_id=target_dataset_id) - # 修改向量信息 - ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( - [paragraph.id for paragraph in paragraph_list], - target_dataset_id)) + model = None + if dataset.embedding_mode_id != target_dataset.embedding_mode_id: + model = get_embedding_model_by_dataset_id(target_dataset_id) + + pid_list = [paragraph.id for paragraph in paragraph_list] # 修改段落信息 paragraph_list.update(dataset_id=target_dataset_id) + # 修改向量信息 + ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( + pid_list, + target_dataset_id, model)) @staticmethod def get_target_dataset_problem(target_dataset_id: str, @@ -392,7 +398,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): problem_paragraph_mapping_list) > 0 else None # 向量化 if with_embedding: - ListenerManagement.embedding_by_document_signal.send(document_id) + model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) else: document.status = Status.error document.save() @@ -405,6 +412,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): class Operate(ApiMixin, serializers.Serializer): document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( "文档id")) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id")) @staticmethod def get_request_params_api(): @@ -530,7 +538,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") - ListenerManagement.embedding_by_document_signal.send(document_id) + model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id')) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) @transaction.atomic def delete(self): @@ -599,8 +608,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return True @staticmethod - def post_embedding(result, document_id): - ListenerManagement.embedding_by_document_signal.send(document_id) + def post_embedding(result, document_id, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) return result @staticmethod @@ -646,7 +656,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_id = str(document_model.id) return DocumentSerializers.Operate( data={'dataset_id': dataset_id, 'document_id': document_id}).one( - with_valid=True), document_id + with_valid=True), document_id, dataset_id @staticmethod def get_sync_handler(dataset_id): @@ -803,9 +813,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api()) @staticmethod - def post_embedding(document_list): + def post_embedding(document_list, dataset_id): for document_dict in document_list: - ListenerManagement.embedding_by_document_signal.send(document_dict.get('id')) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model) return document_list @post(post_function=post_embedding) @@ -846,7 +857,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return [], query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]}) return native_search(query_set, select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), + with_search_one=False), dataset_id @staticmethod def _batch_sync(document_id_list: List[str]): diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 61ae860b6..36f625ad7 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage -from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping +from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ - ProblemParagraphManage + ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from embedding.models import SourceType @@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): paragraph_id=self.data.get('paragraph_id'), dataset_id=self.data.get('dataset_id')) problem_paragraph_mapping.save() + model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) if with_embedding: ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, 'is_active': True, @@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), 'dataset_id': self.data.get('dataset_id'), - }) + }, embedding_model=model) return ProblemSerializers.Operate( data={'dataset_id': self.data.get('dataset_id'), @@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): problem_id=problem.id) problem_paragraph_mapping.save() if with_embedding: + model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, 'is_active': True, 'source_type': SourceType.PROBLEM, @@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), 'dataset_id': self.data.get('dataset_id'), - }) + }, embedding_model=model) def un_association(self, with_valid=True): if with_valid: @@ -336,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改mapping QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['document_id']) + # 修改向量段落信息 ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( [paragraph.id for paragraph in paragraph_list], - target_document_id, target_dataset_id)) + target_document_id, target_dataset_id, target_embedding_model=None)) # 修改段落信息 paragraph_list.update(document_id=target_document_id) # 不同数据集迁移 @@ -366,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改mapping QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['problem_id', 'dataset_id', 'document_id']) - # 修改向量段落信息 - ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( - [paragraph.id for paragraph in paragraph_list], - target_document_id, target_dataset_id)) + target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first() + dataset = QuerySet(DataSet).filter(id=dataset_id).first() + embedding_model = None + if target_dataset.embedding_mode_id != dataset.embedding_mode_id: + embedding_model = get_embedding_model_by_dataset(target_dataset) + pid_list = [paragraph.id for paragraph in paragraph_list] # 修改段落信息 paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id) + # 修改向量段落信息 + ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( + pid_list, + target_document_id, target_dataset_id, target_embedding_model=embedding_model)) + update_document_char_length(document_id) update_document_char_length(target_document_id) @@ -454,13 +464,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): raise AppApiException(500, "段落id不存在") @staticmethod - def post_embedding(paragraph, instance): + def post_embedding(paragraph, instance, dataset_id): if 'is_active' in instance and instance.get('is_active') is not None: s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get( 'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal) s.send(paragraph.get('id')) else: - ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id')) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model) return paragraph @post(post_embedding) @@ -508,7 +519,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): _paragraph.save() update_document_char_length(self.data.get('document_id')) - return self.one(), instance + return self.one(), instance, self.data.get('dataset_id') def get_problem_list(self): ProblemParagraphMapping(ProblemParagraphMapping) @@ -582,7 +593,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改长度 update_document_char_length(document_id) if with_embedding: - ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id)) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model) return ParagraphSerializers.Operate( data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one( with_valid=True) diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 5d00d5be4..34064f9a9 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content -from dataset.models import Problem, Paragraph, ProblemParagraphMapping +from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet +from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id from smartdoc.conf import PROJECT_DIR @@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): content = instance.get('content') problem = QuerySet(Problem).filter(id=problem_id, dataset_id=dataset_id).first() + QuerySet(DataSet).filter(id=dataset_id) problem.content = content problem.save() - ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content)) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model)) diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index e2bd10e09..adf10e3ff 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -52,7 +52,6 @@ class Dataset(APIView): @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="创建QA知识库", operation_id="创建QA知识库", - manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(), responses=get_api_response( DataSetSerializers.Create.CreateQASerializers.get_response_body_api()), diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 2bfd0e977..281ca12f2 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -10,9 +10,8 @@ import threading from abc import ABC, abstractmethod from typing import List, Dict -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import EmbeddingModel from common.util.common import sub_array from embedding.models import SourceType, SearchMode @@ -51,7 +50,7 @@ class BaseVectorStore(ABC): def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding=None): + embedding: Embeddings): """ 插入向量数据 :param source_id: 资源id @@ -64,13 +63,10 @@ class BaseVectorStore(ABC): :param paragraph_id 段落id :return: bool """ - - if embedding is None: - embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) - def batch_save(self, data_list: List[Dict], embedding=None): + def batch_save(self, data_list: List[Dict], embedding: Embeddings): # 获取锁 lock.acquire() try: @@ -80,8 +76,6 @@ class BaseVectorStore(ABC): :param embedding: 向量化处理器 :return: bool """ - if embedding is None: - embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() result = sub_array(data_list) for child_array in result: @@ -94,17 +88,17 @@ class BaseVectorStore(ABC): @abstractmethod def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): pass @abstractmethod - def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings): pass def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_list: list[str], is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): if dataset_id_list is None or len(dataset_id_list) == 0: return [] embedding_query = embedding.embed_query(query_text) @@ -123,7 +117,7 @@ class BaseVectorStore(ABC): def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, search_mode: SearchMode, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): pass @abstractmethod @@ -142,14 +136,6 @@ class BaseVectorStore(ABC): def update_by_source_ids(self, source_ids: List[str], instance: Dict): pass - @abstractmethod - def embed_documents(self, text_list: List[str]): - pass - - @abstractmethod - def embed_query(self, text: str): - pass - @abstractmethod def delete_by_dataset_id(self, dataset_id: str): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index e9a62ae57..935461b00 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -13,9 +13,8 @@ from abc import ABC, abstractmethod from typing import Dict, List from django.db.models import QuerySet -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import EmbeddingModel from common.db.search import generate_sql_by_query_dict from common.db.sql_execute import select_list from common.util.file_util import get_file_content @@ -33,14 +32,6 @@ class PGVector(BaseVectorStore): def update_by_source_ids(self, source_ids: List[str], instance: Dict): QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) - def embed_documents(self, text_list: List[str]): - embedding = EmbeddingModel.get_embedding_model() - return embedding.embed_documents(text_list) - - def embed_query(self, text: str): - embedding = EmbeddingModel.get_embedding_model() - return embedding.embed_query(text) - def vector_is_create(self) -> bool: # 项目启动默认是创建好的 不需要再创建 return True @@ -50,7 +41,7 @@ class PGVector(BaseVectorStore): def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): text_embedding = embedding.embed_query(text) embedding = Embedding(id=uuid.uuid1(), dataset_id=dataset_id, @@ -64,7 +55,7 @@ class PGVector(BaseVectorStore): embedding.save() return True - def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) embedding_list = [Embedding(id=uuid.uuid1(), @@ -83,7 +74,7 @@ class PGVector(BaseVectorStore): def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, search_mode: SearchMode, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): if dataset_id_list is None or len(dataset_id_list) == 0: return [] exclude_dict = {} diff --git a/apps/setting/migrations/0005_model_permission_type.py b/apps/setting/migrations/0005_model_permission_type.py new file mode 100644 index 000000000..dba081a19 --- /dev/null +++ b/apps/setting/migrations/0005_model_permission_type.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:23 +import json + +from django.db import migrations, models +from django.db.models import QuerySet + +from common.util.rsa_util import rsa_long_encrypt +from setting.models import Status, PermissionType +from smartdoc.const import CONFIG + +default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab' + + +def save_default_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') + model_name = CONFIG.get('EMBEDDING_MODEL_NAME') + credential = {'cache_folder': cache_folder} + model_credential_str = json.dumps(credential) + model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS, + model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab', + provider='model_local_provider', + credential=rsa_long_encrypt(model_credential_str), meta={}, + permission_type=PermissionType.PUBLIC) + model.save() + + +def reverse_code_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + QuerySet(ModelModel).filter(id=default_embedding_model_id).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0004_alter_model_credential'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='permission_type', + field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20, + verbose_name='权限类型'), + ), + migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model) + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index 5bdd1b296..4c47eadfc 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -22,6 +22,13 @@ class Status(models.TextChoices): DOWNLOAD = "DOWNLOAD", '下载中' + PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载' + + +class PermissionType(models.TextChoices): + PUBLIC = "PUBLIC", '公开' + PRIVATE = "PRIVATE", "私有" + class Model(AppModelMixin): """ @@ -46,6 +53,9 @@ class Model(AppModelMixin): meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices, + default=PermissionType.PRIVATE) + class Meta: db_table = "model" unique_together = ['name', 'user_id'] diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py index 53b7001e5..285e90f32 100644 --- a/apps/setting/models_provider/__init__.py +++ b/apps/setting/models_provider/__init__.py @@ -6,3 +6,85 @@ @date:2023/10/31 17:16 @desc: """ +import json +from typing import Dict + +from common.util.rsa_util import rsa_long_decrypt +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants + + +def get_model_(provider, model_type, model_name, credential): + """ + 获取模型实例 + @param provider: 供应商 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param credential: 认证信息 + @return: 模型实例 + """ + model = get_provider(provider).get_model(model_type, model_name, + json.loads( + rsa_long_decrypt(credential)), + streaming=True) + return model + + +def get_model(model): + """ + 获取模型实例 + @param model: model 数据库Model实例对象 + @return: 模型实例 + """ + return get_model_(model.provider, model.model_type, model.model_name, model.credential) + + +def get_provider(provider): + """ + 获取供应商实例 + @param provider: 供应商字符串 + @return: 供应商实例 + """ + return ModelProvideConstants[provider].value + + +def get_model_list(provider, model_type): + """ + 获取模型列表 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @return: 模型列表 + """ + return get_provider(provider).get_model_list(model_type) + + +def get_model_credential(provider, model_type, model_name): + """ + 获取模型认证实例 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @return: 认证实例对象 + """ + return get_provider(provider).get_model_credential(model_type, model_name) + + +def get_model_type_list(provider): + """ + 获取模型类型列表 + @param provider: 供应商字符串 + @return: 模型类型列表 + """ + return get_provider(provider).get_model_type_list() + + +def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): + """ + 校验模型认证参数 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param model_credential: 模型认证数据 + @param raise_exception: 是否抛出错误 + @return: True|False + """ + return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception) diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index aa6c525e0..8a9ab5e2b 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -61,7 +61,7 @@ class IModelProvider(ABC): def get_model_list(self, model_type): if model_type is None: raise AppApiException(500, '模型类型不能为空') - return self.get_model_info_manage().get_model_list() + return self.get_model_info_manage().get_model_list_by_model_type(model_type) def get_model_credential(self, model_type, model_name): model_info = self.get_model_info_manage().get_model_info(model_type, model_name) @@ -191,6 +191,9 @@ class ModelInfoManage: def get_model_list(self): return [model.to_dict() for model in self.model_list] + def get_model_list_by_model_type(self, model_type): + return [model.to_dict() for model in self.model_list if model.model_type == model_type] + def get_model_type_list(self): return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if len([model for model in self.model_list if model.model_type == _type.name]) > 0] diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 1db811f9e..c9e9659c3 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider +from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider class ModelProvideConstants(Enum): @@ -31,3 +32,4 @@ class ModelProvideConstants(Enum): model_xf_provider = XunFeiModelProvider() model_deepseek_provider = DeepSeekModelProvider() model_gemini_provider = GeminiModelProvider() + model_local_provider = LocalModelProvider() diff --git a/apps/setting/models_provider/impl/local_model_provider/__init__.py b/apps/setting/models_provider/impl/local_model_provider/__init__.py new file mode 100644 index 000000000..90a8d72c3 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/7/10 17:48 + @desc: +""" diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py new file mode 100644 index 000000000..a631196eb --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 11:06 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'EMBEDDING': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['cache_folder']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField('模型目录', required=True) diff --git a/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg b/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg new file mode 100644 index 000000000..62930faab --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py new file mode 100644 index 000000000..65cc57322 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: zhipu_model_provider.py + @date:2024/04/19 13:5 + @desc: +""" +import os +from typing import Dict + +from pydantic import BaseModel + +from common.exception.app_exception import AppApiException +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding +from smartdoc.conf import PROJECT_DIR + +embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, + LocalEmbeddingCredential(), LocalEmbedding) + +model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) + .append_default_model_info(embedding_text2vec_base_chinese) + .build()) + + +class LocalModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon', + 'local_icon_svg'))) diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py new file mode 100644 index 000000000..92cdd7390 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 14:06 + @desc: +""" +from typing import Dict + +from langchain_huggingface import HuggingFaceEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + ) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py new file mode 100644 index 000000000..e0eeabe59 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:10 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + return self + + api_base = forms.TextInputField('API 域名', required=True) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py new file mode 100644 index 000000000..d1a68ebc7 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:02 + @desc: +""" +from typing import Dict, List + +from langchain_community.embeddings import OllamaEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return OllamaEmbedding( + model=model_name, + base_url=model_credential.get('api_base'), + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using an Ollama deployed embedding model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [f"{text}" for text in texts] + embeddings = self._embed(instruction_pairs) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a Ollama deployed embedding model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = f"{text}" + embedding = self._embed([instruction_pair])[0] + return embedding diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index 3839d5fcb..eb01d38d7 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -20,7 +20,9 @@ from common.forms import BaseForm from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage +from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential +from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel from smartdoc.conf import PROJECT_DIR @@ -88,14 +90,25 @@ model_info_list = [ ModelInfo( 'phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', - ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel) + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), +] +ollama_embedding_model_credential = OllamaEmbeddingModelCredential() +embedding_model_info = [ + ModelInfo( + 'nomic-embed-text', + '一个具有大令牌上下文窗口的高性能开放嵌入模型。', + ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list( + embedding_model_info).append_default_model_info( ModelInfo( 'phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', - ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build() + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo( + 'nomic-embed-text', + '一个具有大令牌上下文窗口的高性能开放嵌入模型。', + ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build() def get_base_url(url: str): @@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]: temp = "" if len(temp) > 0: - print(temp) rows = [t for t in temp.split("\n") if len(t) > 0] for row in rows: yield convert_to_down_model_chunk(row, index) @@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider): os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon', 'ollama_icon_svg'))) - def get_dialogue_number(self): - return 2 - @staticmethod def get_base_model_list(api_base): base_url = get_base_url(api_base) @@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider): return r.json() def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: - api_base = model_credential.get('api_base') + api_base = model_credential.get('api_base', '') base_url = get_base_url(api_base) r = requests.request( method="POST", diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py new file mode 100644 index 000000000..d49d22e22 --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 16:45 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py new file mode 100644 index 000000000..5ac1f8e6f --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 17:44 + @desc: +""" +from typing import Dict + +from langchain_community.embeddings import OpenAIEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return OpenAIEmbeddingModel( + api_key=model_credential.get('api_key'), + model=model_name, + openai_api_base=model_credential.get('api_base'), + ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 6d12869ce..fb4c89d7b 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -11,7 +11,9 @@ import os from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from smartdoc.conf import PROJECT_DIR @@ -58,11 +60,17 @@ model_info_list = [ ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel) ] +open_ai_embedding_credential = OpenAIEmbeddingCredential() +model_info_embedding_list = [ + ModelInfo('text-embedding-ada-002', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + OpenAIEmbeddingModel)] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel - )).build() + )).append_model_info_list(model_info_embedding_list).append_default_model_info( + model_info_embedding_list[0]).build() class OpenAIModelProvider(IModelProvider): diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index b5d85e8a3..a0f28e326 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -7,12 +7,14 @@ @desc: """ import json +import re import threading import time import uuid from typing import Dict -from django.db.models import QuerySet +from django.core import validators +from django.db.models import QuerySet, Q from rest_framework import serializers from application.models import Application @@ -36,6 +38,9 @@ class ModelPullManage: for chunk in response: down_model_chunk[chunk.digest] = chunk.to_dict() if time.time() - timestamp > 5: + model_new = QuerySet(Model).filter(id=model.id).first() + if model_new.status == Status.PAUSE_DOWNLOAD: + return QuerySet(Model).filter(id=model.id).update( meta={"down_model_chunk": list(down_model_chunk.values())}) timestamp = time.time() @@ -72,7 +77,7 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') name = self.data.get('name') - model_query_set = QuerySet(Model).filter(user_id=user_id) + model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) query_params = {} if name is not None: query_params['name__contains'] = name @@ -85,7 +90,8 @@ class ModelSerializer(serializers.Serializer): return [ {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, - 'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in + 'model_name': model.model_name, 'status': model.status, 'meta': model.meta, + 'permission_type': model.permission_type} for model in model_query_set.filter(**query_params).order_by("-create_time")] class Edit(serializers.Serializer): @@ -96,6 +102,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) @@ -135,6 +146,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) @@ -165,10 +181,12 @@ class ModelSerializer(serializers.Serializer): provider = self.data.get('provider') model_type = self.data.get('model_type') model_name = self.data.get('model_name') + permission_type = self.data.get('permission_type') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=rsa_long_encrypt(model_credential_str), - provider=provider, model_type=model_type, model_name=model_name) + provider=provider, model_type=model_type, model_name=model_name, + permission_type=permission_type) model.save() if status == Status.DOWNLOAD: thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) @@ -184,7 +202,8 @@ class ModelSerializer(serializers.Serializer): 'meta': model.meta, 'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, model.model_name).encryption_dict( - credential)} + credential), + 'permission_type': model.permission_type} class Operate(serializers.Serializer): id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) @@ -210,7 +229,8 @@ class ModelSerializer(serializers.Serializer): return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, - 'meta': model.meta, } + 'meta': model.meta + } def delete(self, with_valid=True): if with_valid: @@ -221,6 +241,12 @@ class ModelSerializer(serializers.Serializer): QuerySet(Model).filter(id=self.data.get('id')).delete() return True + def pause_download(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD) + return True + def edit(self, instance: Dict, user_id: str, with_valid=True): if with_valid: self.is_valid(raise_exception=True) @@ -245,7 +271,7 @@ class ModelSerializer(serializers.Serializer): model.status = Status.DOWNLOAD else: raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] + update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index f68ac5be4..7544fdf25 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin): 'provider': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", + description="PUBLIC|PRIVATE"), 'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), @@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin): description="供应商"), 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, title="模型证书信息", - description="模型证书信息") + description="模型证书信息"), + } ) diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 2a7cdd68a..650e2a2b3 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -16,6 +16,7 @@ urlpatterns = [ name="provider/model_form"), path('model', views.Model.as_view(), name='model'), path('model/', views.Model.Operate.as_view(), name='model/operate'), + path('model//pause_download', views.Model.PauseDownload.as_view(), name='model/operate'), path('model//meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'), path('valid//', views.Valid.as_view()) diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 7ba0304fc..9108aa15a 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -69,6 +69,18 @@ class Model(APIView): return result.success( ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True)) + class PauseDownload(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="暂停模型下载", + operation_id="暂停模型下载", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download()) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index 04e8810e5..88c9b5380 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -104,9 +104,6 @@ CACHES = { "token_cache": { 'BACKEND': 'common.cache.file_cache.FileCache', 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 - }, - "chat_cache": { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', } } diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index df06d2768..6e61db1df 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -129,11 +129,12 @@ const getDatasetDetail: (dataset_id: string, loading?: Ref) => Promise< "desc": true } */ -const putDataset: (dataset_id: string, data: any) => Promise> = ( - dataset_id, - data: any -) => { - return put(`${prefix}/${dataset_id}`, data) +const putDataset: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}`, data, undefined, loading) } /** * 获取知识库 可关联的应用列表 diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index bb98984f8..15f6517ff 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref) => Promise { return get(`${prefix}/${model_id}/meta`, {}, loading) } - +/** + * 暂停下载 + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const pauseDownload: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading) +} const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( model_id, loading @@ -147,5 +158,6 @@ export default { updateModel, deleteModel, getModelById, - getModelMetaById + getModelMetaById, + pauseDownload } diff --git a/ui/src/api/type/dataset.ts b/ui/src/api/type/dataset.ts index 6ec73c323..a30c5c98e 100644 --- a/ui/src/api/type/dataset.ts +++ b/ui/src/api/type/dataset.ts @@ -3,6 +3,7 @@ interface datasetData { desc: String documents?: Array type?: String + embedding_mode_id?: String } export type { datasetData } diff --git a/ui/src/api/type/model.ts b/ui/src/api/type/model.ts index e07b36159..fcca438d5 100644 --- a/ui/src/api/type/model.ts +++ b/ui/src/api/type/model.ts @@ -53,6 +53,7 @@ interface Model { * 模型类型 */ model_type: string + permission_type: 'PUBLIC' | 'PRIVATE' /** * 基础模型 */ @@ -68,7 +69,7 @@ interface Model { /** * 状态 */ - status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' + status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD' /** * 元数据 */ diff --git a/ui/src/components/card-checkbox/index.vue b/ui/src/components/card-checkbox/index.vue index b008aa9c8..1d4377326 100644 --- a/ui/src/components/card-checkbox/index.vue +++ b/ui/src/components/card-checkbox/index.vue @@ -18,7 +18,8 @@ - + + @@ -40,7 +41,7 @@ const toModelValue = computed(() => (props.valueField ? props.data[props.valueFi // set: (val) => val // }) -const emit = defineEmits(['update:modelValue']) +const emit = defineEmits(['update:modelValue', 'change']) const checked = () => { const value = props.modelValue ? props.modelValue : [] @@ -53,6 +54,10 @@ const checked = () => { emit('update:modelValue', [...value, toModelValue.value]) } } + +function checkboxChange() { + emit('change') +} diff --git a/ui/src/enums/model.ts b/ui/src/enums/model.ts new file mode 100644 index 000000000..342f0ccb9 --- /dev/null +++ b/ui/src/enums/model.ts @@ -0,0 +1,8 @@ +export enum PermissionType { + PRIVATE = '私有', + PUBLIC = '公用' +} +export enum PermissionDesc { + PRIVATE = '仅自己使用', + PUBLIC = '所有用户都可使用,不能编辑' +} diff --git a/ui/src/layout/components/breadcrumb/index.vue b/ui/src/layout/components/breadcrumb/index.vue index d8d62ccdf..68e73d937 100644 --- a/ui/src/layout/components/breadcrumb/index.vue +++ b/ui/src/layout/components/breadcrumb/index.vue @@ -99,7 +99,7 @@ - + diff --git a/ui/src/views/application/component/CreateApplicationDialog.vue b/ui/src/views/application/component/CreateApplicationDialog.vue index 17068b2f0..35964e1fd 100644 --- a/ui/src/views/application/component/CreateApplicationDialog.vue +++ b/ui/src/views/application/component/CreateApplicationDialog.vue @@ -63,10 +63,10 @@