feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-19 10:34:47 +08:00
parent ead263da22
commit b14a799350
7 changed files with 34 additions and 24 deletions

View File

@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:param user_id 用户id
:return: 段落列表
"""
pass

View File

@ -13,7 +13,7 @@ from django.db.models import QuerySet
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModelManage
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph, DataSet
@ -23,10 +23,12 @@ from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id):
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
@ -44,14 +46,15 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,

View File

@ -13,7 +13,7 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore, EmbeddingModelManage
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph, DataSet
@ -56,7 +56,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in

View File

@ -88,7 +88,9 @@ class ChatInfo:
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}'}
'value': '{question}',
},
'user_id': self.application.user_id
}

View File

@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
from common.config.embedding_config import ModelManage
from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement
@ -42,6 +43,7 @@ from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
@ -242,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
application_id=application_id)]
chat_model = None
if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,

View File

@ -11,26 +11,31 @@ import time
from common.cache.mem_cache import MemCache
class EmbeddingModelManage:
class ModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def get_model(_id, get_model):
model_instance = EmbeddingModelManage.cache.get(_id)
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
model_instance = get_model(_id)
EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
return model_instance
# 续期
EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
EmbeddingModelManage.clear_timeout_cache()
ModelManage.cache.touch(_id, timeout=60 * 30)
ModelManage.clear_timeout_cache()
return model_instance
@staticmethod
def clear_timeout_cache():
if time.time() - EmbeddingModelManage.up_clear_time > 60:
EmbeddingModelManage.cache.clear_timeout_data()
if time.time() - ModelManage.up_clear_time > 60:
ModelManage.cache.clear_timeout_data()
@staticmethod
def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
class VectorStore:

View File

@ -14,7 +14,7 @@ from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.config.embedding_config import EmbeddingModelManage
from common.config.embedding_config import ModelManage
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
@ -140,14 +140,14 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return EmbeddingModelManage.get_model(str(dataset_list[0].id),
return ModelManage.get_model(str(dataset_list[0].id),
lambda _id: get_model(dataset_list[0].embedding_mode))
def get_embedding_model_by_dataset_id(dataset_id: str):
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
def get_embedding_model_by_dataset(dataset):
return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))