From bd4303aee7fb9a45ed875c06553f614e9e508a0f Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 17 Jul 2024 17:01:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/base_search_dataset_step.py | 31 ++++++- .../impl/base_search_dataset_node.py | 35 +++++++- .../serializers/application_serializers.py | 7 +- apps/common/cache/mem_cache.py | 4 + apps/common/config/embedding_config.py | 42 +++++----- apps/common/event/listener_manage.py | 33 ++++---- ...e_id.py => 0006_dataset_embedding_mode.py} | 4 +- apps/dataset/models/data_set.py | 4 +- .../dataset/serializers/common_serializers.py | 19 ++++- .../serializers/dataset_serializers.py | 12 +-- apps/embedding/vector/base_vector.py | 28 ++----- apps/embedding/vector/pg_vector.py | 17 +--- apps/setting/models_provider/__init__.py | 82 +++++++++++++++++++ apps/smartdoc/settings/base.py | 3 - 14 files changed, 223 insertions(+), 98 deletions(-) rename apps/dataset/migrations/{0006_dataset_embedding_mode_id.py => 0006_dataset_embedding_mode.py} (86%) diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 3dd9f8300..d893237da 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -13,22 +13,45 @@ from django.db.models import QuerySet from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore, EmbeddingModelManage from common.db.search import native_search from common.util.file_util import get_file_content -from dataset.models import Paragraph +from dataset.models import Paragraph, DataSet from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR +def get_model_by_id(_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + return model + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return dataset_list[0].embedding_mode_id + + class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, **kwargs) -> List[ParagraphPipelineModel]: + if len(dataset_id_list) == 0: + return [] exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text - embedding_model = EmbeddingModel.get_embedding_model() + model_id = get_embedding_id(dataset_id_list) + model = get_model_by_id(model_id) + self.context['model_name'] = model.name + embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, @@ -101,7 +124,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep): 'run_time': self.context['run_time'], 'problem_text': step_args.get( 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), - 'model_name': EmbeddingModel.get_embedding_model().model_name, + 'model_name': self.context.get('model_name'), 'message_tokens': 0, 'answer_tokens': 0, 'cost': 0 diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 20e0af9fc..1a30a29ea 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -13,20 +13,47 @@ from django.db.models import QuerySet from application.flow.i_step_node import NodeResult from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode -from common.config.embedding_config import EmbeddingModel, VectorStore +from common.config.embedding_config import VectorStore, EmbeddingModelManage from common.db.search import native_search from common.util.file_util import get_file_content -from dataset.models import Document, Paragraph +from dataset.models import Document, Paragraph, DataSet from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR +def get_model_by_id(_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + return get_model(model) + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return dataset_list[0].embedding_mode_id + + +def get_none_result(question): + return NodeResult( + {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '', + 'directly_return': ''}, {}) + + class BaseSearchDatasetNode(ISearchDatasetStepNode): def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: self.context['question'] = question - embedding_model = EmbeddingModel.get_embedding_model() + if len(dataset_id_list) == 0: + return get_none_result(question) + model_id = get_embedding_id(dataset_id_list) + embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id) embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() exclude_document_id_list = [str(document.id) for document in @@ -37,7 +64,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): exclude_paragraph_id_list, True, dataset_setting.get('top_n'), dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) if embedding_list is None: - return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {}) + return get_none_result(question) paragraph_list = self.list_paragraph(embedding_list, vector) result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] return NodeResult({'paragraph_list': result, diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 390456305..ed77f7ff7 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -26,7 +26,7 @@ from rest_framework import serializers from application.flow.workflow_manage import Flow from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list @@ -36,7 +36,7 @@ from common.util.common import valid_license from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document, Image -from dataset.serializers.common_serializers import list_paragraph +from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from embedding.models import SearchMode from setting.models import AuthOperate from setting.models.model_management import Model @@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer): QuerySet(Document).filter( dataset_id__in=dataset_id_list, is_active=False)] + model = get_embedding_model_by_dataset_id_list(dataset_id_list) # 向量库检索 hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), SearchMode(self.data.get('search_mode')), - EmbeddingModel.get_embedding_model()) + model) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), diff --git a/apps/common/cache/mem_cache.py b/apps/common/cache/mem_cache.py index 86a3ce040..5afb1e562 100644 --- a/apps/common/cache/mem_cache.py +++ b/apps/common/cache/mem_cache.py @@ -41,3 +41,7 @@ class MemCache(LocMemCache): delete_keys.append(key) for key in delete_keys: self._delete(key) + + def clear_timeout_data(self): + for key in self._cache.keys(): + self.get(key) diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index 4d652bf95..e784c16a1 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -6,33 +6,31 @@ @date:2023/10/23 16:03 @desc: """ -from langchain_huggingface.embeddings import HuggingFaceEmbeddings +import time -from smartdoc.const import CONFIG +from common.cache.mem_cache import MemCache -class EmbeddingModel: - instance = None +class EmbeddingModelManage: + cache = MemCache('model', {}) + up_clear_time = time.time() @staticmethod - def get_embedding_model(): - """ - 获取向量化模型 - :return: - """ - if EmbeddingModel.instance is None: - model_name = CONFIG.get('EMBEDDING_MODEL_NAME') - cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') - device = CONFIG.get('EMBEDDING_DEVICE') - encode_kwargs = {'normalize_embeddings': True} - e = HuggingFaceEmbeddings( - model_name=model_name, - cache_folder=cache_folder, - model_kwargs={'device': device}, - encode_kwargs=encode_kwargs, - ) - EmbeddingModel.instance = e - return EmbeddingModel.instance + def get_model(_id, get_model): + model_instance = EmbeddingModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30) + return model_instance + # 续期 + EmbeddingModelManage.cache.touch(_id, timeout=60 * 30) + EmbeddingModelManage.clear_timeout_cache() + return model_instance + + @staticmethod + def clear_timeout_cache(): + if time.time() - EmbeddingModelManage.up_clear_time > 60: + EmbeddingModelManage.cache.clear_timeout_data() class VectorStore: diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index dea53cb13..2df7b3415 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -15,8 +15,9 @@ from typing import List import django.db.models from blinker import signal from django.db.models import QuerySet +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model from common.event.common import poxy, embedding_poxy from common.util.file_util import get_file_content @@ -89,11 +90,11 @@ class ListenerManagement: @staticmethod @embedding_poxy - def embedding_by_paragraph(paragraph_id): + def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): """ 向量化段落 根据段落id - :param paragraph_id: 段落id - :return: None + @param paragraph_id: 段落id + @param embedding_model: 向量模型 """ max_kb.info(f"开始--->向量化段落:{paragraph_id}") status = Status.success @@ -107,7 +108,7 @@ class ListenerManagement: # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) except Exception as e: max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -117,10 +118,11 @@ class ListenerManagement: @staticmethod @embedding_poxy - def embedding_by_document(document_id): + def embedding_by_document(document_id, embedding_model: Embeddings): """ 向量化文档 - :param document_id: 文档id + @param document_id: 文档id + @param embedding_model 向量模型 :return: None """ max_kb.info(f"开始--->向量化文档:{document_id}") @@ -138,7 +140,7 @@ class ListenerManagement: # 删除文档向量数据 VectorStore.get_embedding_vector().delete_by_document_id(document_id) # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) except Exception as e: max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -151,10 +153,11 @@ class ListenerManagement: @staticmethod @embedding_poxy - def embedding_by_dataset(dataset_id): + def embedding_by_dataset(dataset_id, embedding_model: Embeddings): """ 向量化知识库 - :param dataset_id: 知识库id + @param dataset_id: 知识库id + @param embedding_model 向量模型 :return: None """ max_kb.info(f"开始--->向量化数据集:{dataset_id}") @@ -162,7 +165,7 @@ class ListenerManagement: document_list = QuerySet(Document).filter(dataset_id=dataset_id) max_kb.info(f"数据集文档:{[d.name for d in document_list]}") for document in document_list: - ListenerManagement.embedding_by_document(document.id) + ListenerManagement.embedding_by_document(document.id, embedding_model) except Exception as e: max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') finally: @@ -245,11 +248,6 @@ class ListenerManagement: def delete_embedding_by_dataset_id_list(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) - @staticmethod - @poxy - def init_embedding_model(ages): - EmbeddingModel.get_embedding_model() - def run(self): # 添加向量 根据问题id ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem) @@ -276,8 +274,7 @@ class ListenerManagement: ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph) # 启动段落向量 ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) - # 初始化向量化模型 - ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model) + # 同步web站点知识库 ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) # 同步web站点 文档 diff --git a/apps/dataset/migrations/0006_dataset_embedding_mode_id.py b/apps/dataset/migrations/0006_dataset_embedding_mode.py similarity index 86% rename from apps/dataset/migrations/0006_dataset_embedding_mode_id.py rename to apps/dataset/migrations/0006_dataset_embedding_mode.py index 77f546c60..2248d8e36 100644 --- a/apps/dataset/migrations/0006_dataset_embedding_mode_id.py +++ b/apps/dataset/migrations/0006_dataset_embedding_mode.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.13 on 2024-07-15 15:56 +# Generated by Django 4.2.13 on 2024-07-17 13:56 import dataset.models.data_set from django.db import migrations, models @@ -15,7 +15,7 @@ class Migration(migrations.Migration): operations = [ migrations.AddField( model_name='dataset', - name='embedding_mode_id', + name='embedding_mode', field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'), ), ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index 69fed2a09..a34854978 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -49,8 +49,8 @@ class DataSet(AppModelMixin): user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, default=Type.base) - embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", - default=default_model) + embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", + default=default_model) meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 099d1f24e..045588f6d 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -14,6 +14,7 @@ from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers +from common.config.embedding_config import EmbeddingModelManage from common.db.search import native_search from common.db.sql_execute import update_execute from common.exception.app_exception import AppApiException @@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork -from dataset.models import Paragraph, Problem, ProblemParagraphMapping +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet +from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR @@ -130,3 +132,18 @@ class ProblemParagraphManage: result = [problem_model for problem_model, is_create in problem_content_dict.values() if is_create], problem_paragraph_mapping_list return result + + +def get_embedding_model_by_dataset_id_list(dataset_id_list: List): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return EmbeddingModelManage.get_model(str(dataset_list[0].id), + lambda _id: get_model(dataset_list[0].embedding_mode)) + + +def get_embedding_model_by_dataset_id(dataset_id: str): + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id) + return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode)) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 0d87744c1..9f551c75e 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -15,7 +15,6 @@ from functools import reduce from typing import Dict, List from urllib.parse import urlparse -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.core import validators from django.db import transaction, models @@ -25,7 +24,7 @@ from drf_yasg import openapi from rest_framework import serializers from application.models import ApplicationDatasetMapping -from common.config.embedding_config import VectorStore, EmbeddingModel +from common.config.embedding_config import VectorStore from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list from common.event import ListenerManagement, SyncWebDatasetArgs @@ -37,7 +36,8 @@ from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping -from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage +from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_by_dataset_id from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from embedding.models import SearchMode from setting.models import AuthOperate @@ -359,8 +359,9 @@ class DataSetSerializers(serializers.ModelSerializer): @staticmethod def post_embedding_dataset(document_list, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_dataset_signal.send(dataset_id) + ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model) return document_list def save_qa(self, instance: Dict, with_valid=True): @@ -565,12 +566,13 @@ class DataSetSerializers(serializers.ModelSerializer): QuerySet(Document).filter( dataset_id=self.data.get('id'), is_active=False)] + model = get_embedding_model_by_dataset_id(self.data.get('id')) # 向量库检索 hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), SearchMode(self.data.get('search_mode')), - EmbeddingModel.get_embedding_model()) + model) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 2bfd0e977..281ca12f2 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -10,9 +10,8 @@ import threading from abc import ABC, abstractmethod from typing import List, Dict -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import EmbeddingModel from common.util.common import sub_array from embedding.models import SourceType, SearchMode @@ -51,7 +50,7 @@ class BaseVectorStore(ABC): def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding=None): + embedding: Embeddings): """ 插入向量数据 :param source_id: 资源id @@ -64,13 +63,10 @@ class BaseVectorStore(ABC): :param paragraph_id 段落id :return: bool """ - - if embedding is None: - embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) - def batch_save(self, data_list: List[Dict], embedding=None): + def batch_save(self, data_list: List[Dict], embedding: Embeddings): # 获取锁 lock.acquire() try: @@ -80,8 +76,6 @@ class BaseVectorStore(ABC): :param embedding: 向量化处理器 :return: bool """ - if embedding is None: - embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() result = sub_array(data_list) for child_array in result: @@ -94,17 +88,17 @@ class BaseVectorStore(ABC): @abstractmethod def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): pass @abstractmethod - def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings): pass def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_list: list[str], is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): if dataset_id_list is None or len(dataset_id_list) == 0: return [] embedding_query = embedding.embed_query(query_text) @@ -123,7 +117,7 @@ class BaseVectorStore(ABC): def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, search_mode: SearchMode, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): pass @abstractmethod @@ -142,14 +136,6 @@ class BaseVectorStore(ABC): def update_by_source_ids(self, source_ids: List[str], instance: Dict): pass - @abstractmethod - def embed_documents(self, text_list: List[str]): - pass - - @abstractmethod - def embed_query(self, text: str): - pass - @abstractmethod def delete_by_dataset_id(self, dataset_id: str): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 5c0d04536..0e718bd6e 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -13,9 +13,8 @@ from abc import ABC, abstractmethod from typing import Dict, List from django.db.models import QuerySet -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_core.embeddings import Embeddings -from common.config.embedding_config import EmbeddingModel from common.db.search import generate_sql_by_query_dict from common.db.sql_execute import select_list from common.util.file_util import get_file_content @@ -33,14 +32,6 @@ class PGVector(BaseVectorStore): def update_by_source_ids(self, source_ids: List[str], instance: Dict): QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) - def embed_documents(self, text_list: List[str]): - embedding = EmbeddingModel.get_embedding_model() - return embedding.embed_documents(text_list) - - def embed_query(self, text: str): - embedding = EmbeddingModel.get_embedding_model() - return embedding.embed_query(text) - def vector_is_create(self) -> bool: # 项目启动默认是创建好的 不需要再创建 return True @@ -50,7 +41,7 @@ class PGVector(BaseVectorStore): def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): text_embedding = embedding.embed_query(text) embedding = Embedding(id=uuid.uuid1(), dataset_id=dataset_id, @@ -64,7 +55,7 @@ class PGVector(BaseVectorStore): embedding.save() return True - def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) embedding_list = [Embedding(id=uuid.uuid1(), @@ -83,7 +74,7 @@ class PGVector(BaseVectorStore): def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, similarity: float, search_mode: SearchMode, - embedding: HuggingFaceEmbeddings): + embedding: Embeddings): if dataset_id_list is None or len(dataset_id_list) == 0: return [] exclude_dict = {} diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py index 53b7001e5..285e90f32 100644 --- a/apps/setting/models_provider/__init__.py +++ b/apps/setting/models_provider/__init__.py @@ -6,3 +6,85 @@ @date:2023/10/31 17:16 @desc: """ +import json +from typing import Dict + +from common.util.rsa_util import rsa_long_decrypt +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants + + +def get_model_(provider, model_type, model_name, credential): + """ + 获取模型实例 + @param provider: 供应商 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param credential: 认证信息 + @return: 模型实例 + """ + model = get_provider(provider).get_model(model_type, model_name, + json.loads( + rsa_long_decrypt(credential)), + streaming=True) + return model + + +def get_model(model): + """ + 获取模型实例 + @param model: model 数据库Model实例对象 + @return: 模型实例 + """ + return get_model_(model.provider, model.model_type, model.model_name, model.credential) + + +def get_provider(provider): + """ + 获取供应商实例 + @param provider: 供应商字符串 + @return: 供应商实例 + """ + return ModelProvideConstants[provider].value + + +def get_model_list(provider, model_type): + """ + 获取模型列表 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @return: 模型列表 + """ + return get_provider(provider).get_model_list(model_type) + + +def get_model_credential(provider, model_type, model_name): + """ + 获取模型认证实例 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @return: 认证实例对象 + """ + return get_provider(provider).get_model_credential(model_type, model_name) + + +def get_model_type_list(provider): + """ + 获取模型类型列表 + @param provider: 供应商字符串 + @return: 模型类型列表 + """ + return get_provider(provider).get_model_type_list() + + +def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): + """ + 校验模型认证参数 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param model_credential: 模型认证数据 + @param raise_exception: 是否抛出错误 + @return: True|False + """ + return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception) diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index 04e8810e5..88c9b5380 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -104,9 +104,6 @@ CACHES = { "token_cache": { 'BACKEND': 'common.cache.file_cache.FileCache', 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 - }, - "chat_cache": { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', } }