Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 01:33:05 +00:00)
feat: support embedding models
parent 9b81b89975
commit bd4303aee7

@@ -13,22 +13,45 @@ from django.db.models import QuerySet
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Paragraph
+from dataset.models import Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR


+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("Model does not exist")
+    return model
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The embedding models of the knowledge bases are inconsistent")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return dataset_list[0].embedding_mode_id
+
+
 class BaseSearchDatasetStep(ISearchDatasetStep):

     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                 search_mode: str = None,
                 **kwargs) -> List[ParagraphPipelineModel]:
         if len(dataset_id_list) == 0:
             return []
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
-        embedding_model = EmbeddingModel.get_embedding_model()
+        model_id = get_embedding_id(dataset_id_list)
+        model = get_model_by_id(model_id)
+        self.context['model_name'] = model.name
+        embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
         embedding_value = embedding_model.embed_query(exec_problem_text)
         vector = VectorStore.get_embedding_vector()
         embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,

@@ -101,7 +124,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
             'run_time': self.context['run_time'],
             'problem_text': step_args.get(
                 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
-            'model_name': EmbeddingModel.get_embedding_model().model_name,
+            'model_name': self.context.get('model_name'),
             'message_tokens': 0,
             'answer_tokens': 0,
             'cost': 0
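
Note: after this change the search step derives the embedding model from the datasets being queried instead of a process-wide singleton. A minimal sketch of the resolution flow, reusing only names defined in the hunks above (`dataset_ids` is a hypothetical input):

    def resolve_embedding_model(dataset_ids: list[str]):
        # All selected datasets must share one embedding model, or this raises.
        model_id = get_embedding_id(dataset_ids)
        # Fetch the setting.models.Model row (its name also feeds context['model_name']).
        model = get_model_by_id(model_id)
        # Cached by model id; the lambda only runs on a cache miss.
        return EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))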

@@ -13,20 +13,47 @@ from django.db.models import QuerySet
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
-from common.config.embedding_config import EmbeddingModel, VectorStore
+from common.config.embedding_config import VectorStore, EmbeddingModelManage
 from common.db.search import native_search
 from common.util.file_util import get_file_content
-from dataset.models import Document, Paragraph
+from dataset.models import Document, Paragraph, DataSet
 from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR


+def get_model_by_id(_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception("Model does not exist")
+    return get_model(model)
+
+
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The embedding models of the knowledge bases are inconsistent")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return dataset_list[0].embedding_mode_id
+
+
+def get_none_result(question):
+    return NodeResult(
+        {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
+         'directly_return': ''}, {})
+
+
 class BaseSearchDatasetNode(ISearchDatasetStepNode):
     def execute(self, dataset_id_list, dataset_setting, question,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
         self.context['question'] = question
-        embedding_model = EmbeddingModel.get_embedding_model()
+        if len(dataset_id_list) == 0:
+            return get_none_result(question)
+        model_id = get_embedding_id(dataset_id_list)
+        embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
         embedding_value = embedding_model.embed_query(question)
         vector = VectorStore.get_embedding_vector()
         exclude_document_id_list = [str(document.id) for document in

@@ -37,7 +64,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
                                               exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
                                               dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
         if embedding_list is None:
-            return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
+            return get_none_result(question)
         paragraph_list = self.list_paragraph(embedding_list, vector)
         result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
         return NodeResult({'paragraph_list': result,
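
Note: this node passes `get_model_by_id` directly as the cache loader, so the database row is only fetched on a cache miss, while the pipeline step above loads the row eagerly because it also records `model.name`. Both shapes satisfy the loader contract that `EmbeddingModelManage.get_model` appears to expect; a sketch under that assumption:

    from typing import Any, Callable

    # Assumed contract for the second argument of EmbeddingModelManage.get_model:
    # a callable mapping a model id to a ready-to-use embedding instance.
    Loader = Callable[[str], Any]

    def make_eager_loader(model) -> Loader:
        # Pipeline-step style: the row is already in hand, so the id is ignored.
        return lambda _id: get_model(model)

    # Flow-node style: the row is fetched and built only on a cache miss.
    lazy_loader: Loader = get_model_by_id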

@@ -26,7 +26,7 @@ from rest_framework import serializers
 from application.flow.workflow_manage import Flow
 from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
 from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.constants.authentication_type import AuthenticationType
 from common.db.search import get_dynamics_model, native_search, native_page_search
 from common.db.sql_execute import select_list

@@ -36,7 +36,7 @@ from common.util.common import valid_license
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import DataSet, Document, Image
-from dataset.serializers.common_serializers import list_paragraph
+from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
 from embedding.models import SearchMode
 from setting.models import AuthOperate
 from setting.models.model_management import Model

@@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
                                        QuerySet(Document).filter(
                                            dataset_id__in=dataset_id_list,
                                            is_active=False)]
+            model = get_embedding_model_by_dataset_id_list(dataset_id_list)
             # Vector store retrieval
             hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        SearchMode(self.data.get('search_mode')),
-                                       EmbeddingModel.get_embedding_model())
+                                       model)
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
             return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

@@ -41,3 +41,7 @@ class MemCache(LocMemCache):
                 delete_keys.append(key)
         for key in delete_keys:
             self._delete(key)
+
+    def clear_timeout_data(self):
+        for key in list(self._cache.keys()):
+            self.get(key)
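
Note: `clear_timeout_data` leans on lazy expiration: in `LocMemCache`, a `get` on an expired key evicts it as a side effect, so reading every key sweeps out stale entries (the `list(...)` snapshot above is needed because `get` can delete entries mid-iteration). A standalone toy illustrating the pattern, not MaxKB code:

    import time

    class TinyCache:
        """Toy cache with the same evict-on-read shape as LocMemCache."""

        def __init__(self):
            self._data = {}  # key -> (value, expires_at)

        def set(self, key, value, timeout):
            self._data[key] = (value, time.time() + timeout)

        def get(self, key):
            entry = self._data.get(key)
            if entry is None:
                return None
            value, expires_at = entry
            if time.time() >= expires_at:
                del self._data[key]  # expired: evict on read
                return None
            return value

        def clear_timeout_data(self):
            # Iterate over a snapshot of the keys: get() may delete while we walk.
            for key in list(self._data.keys()):
                self.get(key)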

@@ -6,33 +6,31 @@
 @date:2023/10/23 16:03
 @desc:
 """
-from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+import time

-from smartdoc.const import CONFIG
+from common.cache.mem_cache import MemCache


-class EmbeddingModel:
-    instance = None
+class EmbeddingModelManage:
+    cache = MemCache('model', {})
+    up_clear_time = time.time()

     @staticmethod
-    def get_embedding_model():
-        """
-        Get the embedding model
-        :return:
-        """
-        if EmbeddingModel.instance is None:
-            model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
-            cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
-            device = CONFIG.get('EMBEDDING_DEVICE')
-            encode_kwargs = {'normalize_embeddings': True}
-            e = HuggingFaceEmbeddings(
-                model_name=model_name,
-                cache_folder=cache_folder,
-                model_kwargs={'device': device},
-                encode_kwargs=encode_kwargs,
-            )
-            EmbeddingModel.instance = e
-        return EmbeddingModel.instance
+    def get_model(_id, get_model):
+        model_instance = EmbeddingModelManage.cache.get(_id)
+        if model_instance is None:
+            model_instance = get_model(_id)
+            EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+            return model_instance
+        # Renew the TTL
+        EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
+        EmbeddingModelManage.clear_timeout_cache()
+        return model_instance
+
+    @staticmethod
+    def clear_timeout_cache():
+        if time.time() - EmbeddingModelManage.up_clear_time > 60:
+            EmbeddingModelManage.cache.clear_timeout_data()


 class VectorStore:
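
Note: `EmbeddingModelManage` replaces the old fixed HuggingFace singleton with a per-model-id cache: build on miss, renew the 30-minute TTL on hit via `touch`, and sweep expired entries at most once a minute. A hedged usage sketch (the loader body and the id are hypothetical):

    def load_embedding_model(_id):
        # Hypothetical loader; in this commit it ends up in
        # setting.models_provider.get_model, which decrypts the stored
        # credential and builds the provider's embedding model.
        return object()

    model_id = 'hypothetical-model-id'
    instance = EmbeddingModelManage.get_model(model_id, load_embedding_model)  # miss: build + cache
    instance = EmbeddingModelManage.get_model(model_id, load_embedding_model)  # hit: TTL renewed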

@@ -15,8 +15,9 @@ from typing import List
 import django.db.models
 from blinker import signal
 from django.db.models import QuerySet
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import native_search, get_dynamics_model
 from common.event.common import poxy, embedding_poxy
 from common.util.file_util import get_file_content

@@ -89,11 +90,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_paragraph(paragraph_id):
+    def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
         """
         Embed a paragraph by paragraph id
-        :param paragraph_id: paragraph id
-        :return: None
+        @param paragraph_id: paragraph id
+        @param embedding_model: embedding model
         """
         max_kb.info(f"Start ---> embedding paragraph: {paragraph_id}")
         status = Status.success

@@ -107,7 +108,7 @@ class ListenerManagement:
             # Delete the paragraph's vector data
             VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
             # Batch embedding
-            VectorStore.get_embedding_vector().batch_save(data_list)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Error embedding paragraph {paragraph_id}: {str(e)}{traceback.format_exc()}')
             status = Status.error

@@ -117,10 +118,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_document(document_id):
+    def embedding_by_document(document_id, embedding_model: Embeddings):
         """
         Embed a document
-        :param document_id: document id
+        @param document_id: document id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"Start ---> embedding document: {document_id}")

@@ -138,7 +140,7 @@ class ListenerManagement:
             # Delete the document's vector data
             VectorStore.get_embedding_vector().delete_by_document_id(document_id)
             # Batch embedding
-            VectorStore.get_embedding_vector().batch_save(data_list)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Error embedding document {document_id}: {str(e)}{traceback.format_exc()}')
             status = Status.error

@@ -151,10 +153,11 @@ class ListenerManagement:

     @staticmethod
     @embedding_poxy
-    def embedding_by_dataset(dataset_id):
+    def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
         """
         Embed a knowledge base
-        :param dataset_id: knowledge base id
+        @param dataset_id: knowledge base id
+        @param embedding_model: embedding model
         :return: None
         """
         max_kb.info(f"Start ---> embedding dataset: {dataset_id}")

@@ -162,7 +165,7 @@ class ListenerManagement:
             document_list = QuerySet(Document).filter(dataset_id=dataset_id)
             max_kb.info(f"Dataset documents: {[d.name for d in document_list]}")
             for document in document_list:
-                ListenerManagement.embedding_by_document(document.id)
+                ListenerManagement.embedding_by_document(document.id, embedding_model)
         except Exception as e:
             max_kb_error.error(f'Error embedding dataset {dataset_id}: {str(e)}{traceback.format_exc()}')
         finally:

@@ -245,11 +248,6 @@ class ListenerManagement:
     def delete_embedding_by_dataset_id_list(source_ids: List[str]):
         VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)

-    @staticmethod
-    @poxy
-    def init_embedding_model(ages):
-        EmbeddingModel.get_embedding_model()
-
     def run(self):
         # Add vectors by problem id
         ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)

@@ -276,8 +274,7 @@ class ListenerManagement:
         ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
         # Enable paragraph embedding
         ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
-        # Initialize the embedding model
-        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)

         # Sync web-site knowledge bases
         ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
         # Sync web-site documents
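
Note: the listeners can grow an `embedding_model` parameter because blinker forwards keyword arguments from `send` to every connected receiver. A minimal standalone sketch of that mechanism:

    from blinker import signal

    embedding_by_dataset = signal('embedding_by_dataset')

    def on_embedding_by_dataset(sender, embedding_model=None):
        # sender is the positional argument to send(); extra kwargs arrive unchanged.
        print(sender, embedding_model)

    embedding_by_dataset.connect(on_embedding_by_dataset)
    embedding_by_dataset.send('dataset-42', embedding_model='fake-model')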

@@ -1,4 +1,4 @@
-# Generated by Django 4.2.13 on 2024-07-15 15:56
+# Generated by Django 4.2.13 on 2024-07-17 13:56

 import dataset.models.data_set
 from django.db import migrations, models

@@ -15,7 +15,7 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AddField(
             model_name='dataset',
-            name='embedding_mode_id',
+            name='embedding_mode',
             field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='Embedding model'),
         ),
     ]

@@ -49,8 +49,8 @@ class DataSet(AppModelMixin):
     user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="Owner")
     type = models.CharField(verbose_name='Type', max_length=1, choices=Type.choices,
                             default=Type.base)
-    embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="Embedding model",
-                                          default=default_model)
+    embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="Embedding model",
+                                       default=default_model)
     meta = models.JSONField(verbose_name="Metadata", default=dict)

     class Meta:
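
Note: the rename from `embedding_mode_id` to `embedding_mode` follows Django's ForeignKey convention: the related object lives under the field name, and the raw key under `<field>_id`. With the old name the raw attribute would have become `embedding_mode_id_id`. After the rename, the helpers elsewhere in this diff can use both forms:

    dataset = DataSet.objects.select_related('embedding_mode').first()
    dataset.embedding_mode     # the related Model instance (joined, no extra query)
    dataset.embedding_mode_id  # the raw foreign-key value, as read by get_embedding_id()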

@@ -14,6 +14,7 @@ from django.db.models import QuerySet
 from drf_yasg import openapi
 from rest_framework import serializers

+from common.config.embedding_config import EmbeddingModelManage
 from common.db.search import native_search
 from common.db.sql_execute import update_execute
 from common.exception.app_exception import AppApiException

@@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
+from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR


@@ -130,3 +132,18 @@ class ProblemParagraphManage:
         result = [problem_model for problem_model, is_create in problem_content_dict.values() if
                   is_create], problem_paragraph_mapping_list
         return result
+
+
+def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception("The embedding models of the knowledge bases are inconsistent")
+    if len(dataset_list) == 0:
+        raise Exception("Knowledge base configuration error, please reconfigure the knowledge base")
+    return EmbeddingModelManage.get_model(str(dataset_list[0].id),
+                                          lambda _id: get_model(dataset_list[0].embedding_mode))
+
+
+def get_embedding_model_by_dataset_id(dataset_id: str):
+    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
+    return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
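
Note: unlike the chat-pipeline helpers, these cache by dataset id rather than model id, so two datasets sharing one model may hold separate cached instances. A hedged usage sketch (the id literal is hypothetical; the embed calls are the standard langchain Embeddings interface):

    dataset_ids = ['11111111-2222-3333-4444-555555555555']  # hypothetical dataset id
    embedding_model = get_embedding_model_by_dataset_id_list(dataset_ids)
    doc_vectors = embedding_model.embed_documents(['paragraph one', 'paragraph two'])
    query_vector = embedding_model.embed_query('a user question')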

@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Dict, List
 from urllib.parse import urlparse

-from django.conf import settings
 from django.contrib.postgres.fields import ArrayField
 from django.core import validators
 from django.db import transaction, models

@@ -25,7 +24,7 @@ from drf_yasg import openapi
 from rest_framework import serializers

 from application.models import ApplicationDatasetMapping
-from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.config.embedding_config import VectorStore
 from common.db.search import get_dynamics_model, native_page_search, native_search
 from common.db.sql_execute import select_list
 from common.event import ListenerManagement, SyncWebDatasetArgs

@@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
-from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
+from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
+    get_embedding_model_by_dataset_id
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from embedding.models import SearchMode
 from setting.models import AuthOperate

@@ -359,8 +359,9 @@ class DataSetSerializers(serializers.ModelSerializer):

     @staticmethod
     def post_embedding_dataset(document_list, dataset_id):
+        model = get_embedding_model_by_dataset_id(dataset_id)
         # Send the embedding event
-        ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
+        ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
         return document_list

     def save_qa(self, instance: Dict, with_valid=True):

@@ -565,12 +566,13 @@ class DataSetSerializers(serializers.ModelSerializer):
                                        QuerySet(Document).filter(
                                            dataset_id=self.data.get('id'),
                                            is_active=False)]
+            model = get_embedding_model_by_dataset_id(self.data.get('id'))
             # Vector store retrieval
             hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        SearchMode(self.data.get('search_mode')),
-                                       EmbeddingModel.get_embedding_model())
+                                       model)
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
             return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

@@ -10,9 +10,8 @@ import threading
 from abc import ABC, abstractmethod
 from typing import List, Dict

-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.util.common import sub_array
 from embedding.models import SourceType, SearchMode


@@ -51,7 +50,7 @@ class BaseVectorStore(ABC):

     def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
-             embedding=None):
+             embedding: Embeddings):
         """
         Insert vector data
         :param source_id: resource id

@@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
         :param paragraph_id: paragraph id
         :return: bool
         """
-
-        if embedding is None:
-            embedding = EmbeddingModel.get_embedding_model()
         self.save_pre_handler()
         self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)

-    def batch_save(self, data_list: List[Dict], embedding=None):
+    def batch_save(self, data_list: List[Dict], embedding: Embeddings):
         # Acquire the lock
         lock.acquire()
         try:

@@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
         :param embedding: embedding handler
         :return: bool
         """
-        if embedding is None:
-            embedding = EmbeddingModel.get_embedding_model()
         self.save_pre_handler()
         result = sub_array(data_list)
         for child_array in result:

@@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
     @abstractmethod
     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         pass

     @abstractmethod
-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         pass

     def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         embedding_query = embedding.embed_query(query_text)

@@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
     def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         pass

     @abstractmethod

@@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         pass

-    @abstractmethod
-    def embed_documents(self, text_list: List[str]):
-        pass
-
-    @abstractmethod
-    def embed_query(self, text: str):
-        pass
-
     @abstractmethod
     def delete_by_dataset_id(self, dataset_id: str):
         pass
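
Note: retyping the vector-store API against the abstract `langchain_core.embeddings.Embeddings` (instead of the concrete `HuggingFaceEmbeddings`) lets any provider's model flow through `save`/`batch_save`/`search`. The interface is just two methods, so a test double is tiny; a sketch, not MaxKB code:

    from typing import List

    from langchain_core.embeddings import Embeddings

    class FakeEmbeddings(Embeddings):
        """Deterministic stand-in that satisfies the Embeddings interface."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self.embed_query(text) for text in texts]

        def embed_query(self, text: str) -> List[float]:
            # Toy 2-dim vector derived from the text; real models return hundreds of dims.
            return [float(len(text)), float(sum(map(ord, text)) % 997)]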

@@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
 from typing import Dict, List

 from django.db.models import QuerySet
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_core.embeddings import Embeddings

-from common.config.embedding_config import EmbeddingModel
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.util.file_util import get_file_content

@@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
         QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)

-    def embed_documents(self, text_list: List[str]):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_documents(text_list)
-
-    def embed_query(self, text: str):
-        embedding = EmbeddingModel.get_embedding_model()
-        return embedding.embed_query(text)
-
     def vector_is_create(self) -> bool:
         # Created at project startup by default; no need to create again
         return True

@@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):

     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
               is_active: bool,
-              embedding: HuggingFaceEmbeddings):
+              embedding: Embeddings):
         text_embedding = embedding.embed_query(text)
         embedding = Embedding(id=uuid.uuid1(),
                               dataset_id=dataset_id,

@@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
         embedding.save()
         return True

-    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
         embedding_list = [Embedding(id=uuid.uuid1(),

@@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
     def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
                  search_mode: SearchMode,
-                 embedding: HuggingFaceEmbeddings):
+                 embedding: Embeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
         exclude_dict = {}

@@ -6,3 +6,85 @@
 @date:2023/10/31 17:16
 @desc:
 """
+import json
+from typing import Dict
+
+from common.util.rsa_util import rsa_long_decrypt
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+
+
+def get_model_(provider, model_type, model_name, credential):
+    """
+    Get a model instance
+    @param provider:   provider
+    @param model_type: model type
+    @param model_name: model name
+    @param credential: credential information
+    @return: model instance
+    """
+    model = get_provider(provider).get_model(model_type, model_name,
+                                             json.loads(
+                                                 rsa_long_decrypt(credential)),
+                                             streaming=True)
+    return model
+
+
+def get_model(model):
+    """
+    Get a model instance
+    @param model: database Model row
+    @return: model instance
+    """
+    return get_model_(model.provider, model.model_type, model.model_name, model.credential)
+
+
+def get_provider(provider):
+    """
+    Get a provider instance
+    @param provider: provider string
+    @return: provider instance
+    """
+    return ModelProvideConstants[provider].value
+
+
+def get_model_list(provider, model_type):
+    """
+    Get the model list
+    @param provider:   provider string
+    @param model_type: model type
+    @return: model list
+    """
+    return get_provider(provider).get_model_list(model_type)
+
+
+def get_model_credential(provider, model_type, model_name):
+    """
+    Get a model credential instance
+    @param provider:   provider string
+    @param model_type: model type
+    @param model_name: model name
+    @return: credential instance
+    """
+    return get_provider(provider).get_model_credential(model_type, model_name)
+
+
+def get_model_type_list(provider):
+    """
+    Get the list of model types
+    @param provider: provider string
+    @return: model type list
+    """
+    return get_provider(provider).get_model_type_list()
+
+
+def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
+    """
+    Validate model credential parameters
+    @param provider:         provider string
+    @param model_type:       model type
+    @param model_name:       model name
+    @param model_credential: model credential data
+    @param raise_exception:  whether to raise on failure
+    @return: True | False
+    """
+    return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
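
Note: `get_model(model)` is the entry point the rest of this commit leans on: it decrypts the stored credential with `rsa_long_decrypt` and asks the provider registered in `ModelProvideConstants` to construct the instance. A hedged usage sketch (the filter value is a hypothetical placeholder):

    from django.db.models import QuerySet

    from setting.models import Model
    from setting.models_provider import get_model

    row = QuerySet(Model).filter(model_type='EMBEDDING').first()  # hypothetical filter
    if row is not None:
        embedding_model = get_model(row)  # provider + decrypted credential -> instance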

@@ -104,9 +104,6 @@ CACHES = {
     "token_cache": {
         'BACKEND': 'common.cache.file_cache.FileCache',
         'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache")  # folder path
     },
-    "chat_cache": {
-        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
-    }
 }