mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持向量模型
This commit is contained in:
parent
ead263da22
commit
b14a799350
|
|
@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message="类型只支持register|reset_password", code=500)
|
||||
], error_messages=ErrMessage.char("检索模式"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
    """Return the serializer class used to validate this step's parameters.

    :param manage: the pipeline manager for the current chat run (unused here;
        kept for interface compatibility with other steps)
    :return: the ``InstanceSerializer`` class declared on this step
    """
    return self.InstanceSerializer
|
||||
|
|
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
"""
|
||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||
|
|
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
:param exclude_paragraph_id_list: 需要排除段落id
|
||||
:param padding_problem_text 补全问题
|
||||
:param search_mode 检索模式
|
||||
:param user_id 用户id
|
||||
:return: 段落列表
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModelManage
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph, DataSet
|
||||
|
|
@ -23,10 +23,12 @@ from setting.models_provider import get_model
|
|||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
    """Fetch a model record by id and enforce its access permission.

    :param _id: primary key of the ``Model`` row to load
    :param user_id: id of the requesting user, compared with the owner when
        the model is private
    :return: the ``Model`` instance
    :raises Exception: when no such model exists, or when the model is
        private and owned by a different user
    """
    model = QuerySet(Model).filter(id=_id).first()
    if model is None:
        raise Exception("模型不存在")
    # A PRIVATE model may only be used by the user who owns it.
    is_private = model.permission_type == 'PRIVATE'
    owner_mismatch = str(model.user_id) != str(user_id)
    if is_private and owner_mismatch:
        raise Exception(f"无权限使用此模型:{model.name}")
    return model
|
||||
|
||||
|
||||
|
|
@ -44,14 +46,15 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
if len(dataset_id_list) == 0:
|
||||
return []
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id)
|
||||
model = get_model_by_id(model_id, user_id)
|
||||
self.context['model_name'] = model.name
|
||||
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModelManage
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Document, Paragraph, DataSet
|
||||
|
|
@ -56,7 +56,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
return get_none_result(question)
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(question)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
|
|
|
|||
|
|
@ -88,7 +88,9 @@ class ChatInfo:
|
|||
'no_references_setting': self.application.dataset_setting.get(
|
||||
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
||||
'status': 'ai_questioning',
|
||||
'value': '{question}'}
|
||||
'value': '{question}',
|
||||
},
|
||||
'user_id': self.application.user_id
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
|
|||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||
ModelSettingSerializer
|
||||
from application.serializers.chat_message_serializers import ChatInfo
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.constants.permission_constants import RoleConstants
|
||||
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
||||
from common.event import ListenerManagement
|
||||
|
|
@ -42,6 +43,7 @@ from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
|||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -242,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
application_id=application_id)]
|
||||
chat_model = None
|
||||
if model is not None:
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
|
||||
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
|
||||
chat_id = str(uuid.uuid1())
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
|
|
|
|||
|
|
@ -11,26 +11,31 @@ import time
|
|||
from common.cache.mem_cache import MemCache
|
||||
|
||||
|
||||
class ModelManage:
    """Process-wide cache of instantiated model objects, keyed by model id.

    Entries live for 30 minutes and their TTL is renewed on every access.
    Expired entries are swept lazily, throttled to at most once per minute.
    """
    cache = MemCache('model', {})
    # Timestamp of the last expired-entry sweep (throttle anchor).
    up_clear_time = time.time()

    @staticmethod
    def get_model(_id, get_model):
        """Return the cached model for ``_id``, building it on a cache miss.

        :param _id: cache key (model id)
        :param get_model: factory called with ``_id`` to build the instance
            when it is not cached
        :return: the model instance
        """
        model_instance = ModelManage.cache.get(_id)
        if model_instance is None:
            model_instance = get_model(_id)
            ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
            return model_instance
        # Cache hit: renew the entry's TTL, then opportunistically sweep.
        ModelManage.cache.touch(_id, timeout=60 * 30)
        ModelManage.clear_timeout_cache()
        return model_instance

    @staticmethod
    def clear_timeout_cache():
        """Sweep expired cache entries, at most once every 60 seconds."""
        if time.time() - ModelManage.up_clear_time > 60:
            ModelManage.cache.clear_timeout_data()
            # Bug fix: record the sweep time. Without this the throttle only
            # works once — 60s after startup every single cache hit would
            # trigger a full clear_timeout_data() scan.
            ModelManage.up_clear_time = time.time()

    @staticmethod
    def delete_key(_id):
        """Remove ``_id`` from the cache if it is present."""
        if ModelManage.cache.has_key(_id):
            ModelManage.cache.delete(_id)
|
||||
|
||||
|
||||
class VectorStore:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from django.db.models import QuerySet
|
|||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.config.embedding_config import EmbeddingModelManage
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.db.sql_execute import update_execute
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -140,14 +140,14 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
|||
raise Exception("知识库未向量模型不一致")
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception("知识库设置错误,请重新设置知识库")
|
||||
return EmbeddingModelManage.get_model(str(dataset_list[0].id),
|
||||
return ModelManage.get_model(str(dataset_list[0].id),
|
||||
lambda _id: get_model(dataset_list[0].embedding_mode))
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset_id(dataset_id: str):
    """Resolve the embedding model configured for a single dataset.

    :param dataset_id: primary key of the dataset
    :return: the embedding model instance (served from the shared cache)
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()

    def build(_id):
        # Only invoked on a cache miss.
        # NOTE(review): assumes the dataset row exists — a missing row would
        # fail here with an AttributeError; confirm callers validate the id.
        return get_model(dataset.embedding_mode)

    return ModelManage.get_model(dataset_id, build)
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset(dataset):
    """Resolve the embedding model for an already-loaded dataset row.

    :param dataset: a ``DataSet`` instance with ``embedding_mode`` available
    :return: the embedding model instance (served from the shared cache)
    """
    cache_key = str(dataset.id)
    return ModelManage.get_model(cache_key, lambda _id: get_model(dataset.embedding_mode))
|
||||
|
|
|
|||
Loading…
Reference in New Issue