mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-28 05:42:51 +00:00
feat: 支持向量模型
This commit is contained in:
parent
0e8d4eab12
commit
dcf5892b96
|
|
@ -47,22 +47,26 @@ class SyncWebDocumentArgs:
|
|||
|
||||
|
||||
class UpdateProblemArgs:
    """Payload describing a change to a problem's text.

    The constructor stores its three arguments unchanged on the instance.

    :param problem_id: identifier of the problem that was edited
    :param problem_content: the problem's new text
    :param embedding_model: embedding model passed along with the event
        (presumably used by the listener to re-embed the text; confirm
        against the handler)
    """

    def __init__(self, problem_id: str, problem_content: str, embedding_model: "Embeddings"):
        self.problem_id = problem_id
        self.problem_content = problem_content
        self.embedding_model = embedding_model
|
||||
|
||||
|
||||
class UpdateEmbeddingDatasetIdArgs:
    """Payload for moving the stored vectors of some paragraphs to another dataset.

    :param paragraph_id_list: ids of the paragraphs being moved
    :param target_dataset_id: id of the dataset the paragraphs now belong to
    :param target_embedding_model: embedding model of the target dataset, or
        ``None``. ``update_embedding_dataset_id`` only rewrites the stored
        ``dataset_id`` when this is ``None`` and re-embeds the paragraphs with
        the given model otherwise.
    """

    # ``target_embedding_model`` defaults to None so that two-argument callers
    # keep working, matching UpdateEmbeddingDocumentIdArgs.
    def __init__(self, paragraph_id_list: List[str], target_dataset_id: str,
                 target_embedding_model: "Embeddings" = None):
        self.paragraph_id_list = paragraph_id_list
        self.target_dataset_id = target_dataset_id
        self.target_embedding_model = target_embedding_model
|
||||
|
||||
|
||||
class UpdateEmbeddingDocumentIdArgs:
    """Payload for moving the stored vectors of some paragraphs to another document.

    :param paragraph_id_list: ids of the paragraphs being moved
    :param target_document_id: id of the document the paragraphs now belong to
    :param target_dataset_id: id of the dataset that owns the target document
    :param target_embedding_model: embedding model of the target dataset, or
        ``None``. ``update_embedding_document_id`` only rewrites the stored
        ids when this is ``None`` and re-embeds the paragraphs with the given
        model otherwise.
    """

    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
                 target_embedding_model: "Embeddings" = None):
        self.paragraph_id_list = paragraph_id_list
        self.target_document_id = target_document_id
        self.target_dataset_id = target_dataset_id
        self.target_embedding_model = target_embedding_model
|
||||
|
||||
|
||||
class ListenerManagement:
|
||||
|
|
@ -88,6 +92,36 @@ class ListenerManagement:
|
|||
def embedding_by_problem(args, embedding_model: Embeddings):
|
||||
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
|
||||
|
||||
@staticmethod
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
    """Load the embedding text for the given paragraphs and hand it to the batch embedder.

    Rows are fetched with ``list_embedding_text.sql`` and passed, together
    with the ids and the model, to ``embedding_by_paragraph_data_list``.
    Any exception raised here is logged and not re-raised.

    :param paragraph_id_list: ids of the paragraphs to embed
    :param embedding_model: model forwarded to the batch embedder
    """
    try:
        sql_file = os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')
        # One query set per alias used by the SQL template.
        problem_query_set = QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
            **{'paragraph.id__in': paragraph_id_list})
        paragraph_query_set = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
        query_sets = {'problem': problem_query_set, 'paragraph': paragraph_query_set}
        data_list = native_search(query_sets, select_string=get_file_content(sql_file))
        ListenerManagement.embedding_by_paragraph_data_list(data_list,
                                                            paragraph_id_list=paragraph_id_list,
                                                            embedding_model=embedding_model)
    except Exception as e:
        max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
|
||||
|
||||
@staticmethod
@embedding_poxy
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
    """Replace the stored vectors of the given paragraphs and record the outcome.

    Existing vectors for ``paragraph_id_list`` are deleted, ``data_list`` is
    embedded and saved in one batch, and the ``status`` column of the
    paragraphs is then updated. Exceptions are logged, not re-raised.

    :param data_list: rows to embed, passed unchanged to ``batch_save``
    :param paragraph_id_list: ids of the paragraphs being (re-)embedded
    :param embedding_model: model passed to ``batch_save``
    """
    max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
    # Must be bound before the try block: the finally clause reads it on the
    # success path too, where the except branch never runs.
    status = Status.success
    try:
        # Drop whatever vectors are currently stored for these paragraphs.
        VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
        # Embed and store all rows in one batch.
        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
    except Exception as e:
        max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
        status = Status.error
    finally:
        QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
        max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
|
||||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
|
||||
|
|
@ -227,14 +261,22 @@ class ListenerManagement:
|
|||
|
||||
@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
    """Bring the stored vectors of moved paragraphs in line with their new dataset.

    With no target model, only the ``dataset_id`` recorded on the existing
    vectors is rewritten. With a target model, the paragraphs are re-embedded
    through ``embedding_by_paragraph_list`` instead.

    :param args: paragraph ids, target dataset id and optional target model
    """
    if args.target_embedding_model is None:
        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                   {'dataset_id': args.target_dataset_id})
    else:
        # Exactly one of the two paths runs; there is no unconditional
        # metadata update ahead of the re-embed.
        ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                       embedding_model=args.target_embedding_model)
|
||||
|
||||
@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
    """Bring the stored vectors of moved paragraphs in line with their new document.

    With no target model, only the ``document_id`` and ``dataset_id``
    recorded on the existing vectors are rewritten. With a target model, the
    paragraphs are re-embedded through ``embedding_by_paragraph_list``
    instead.

    :param args: paragraph ids, target document/dataset ids and optional target model
    """
    if args.target_embedding_model is None:
        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                   {'document_id': args.target_document_id,
                                                                    'dataset_id': args.target_dataset_id})
    else:
        # Exactly one of the two paths runs; there is no unconditional
        # metadata update ahead of the re-embed.
        ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                       embedding_model=args.target_embedding_model)
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_source_ids(source_ids: List[str]):
|
||||
|
|
|
|||
|
|
@ -147,3 +147,7 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
|||
def get_embedding_model_by_dataset_id(dataset_id: str):
    """Return the embedding model configured for the dataset with this id.

    The dataset row is fetched together with its ``embedding_mode`` relation
    and the lookup is delegated to ``EmbeddingModelManage.get_model``, keyed
    by ``dataset_id``.

    :param dataset_id: id of the dataset whose embedding model is wanted
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()

    def build_model(_id):
        # Called by EmbeddingModelManage when it needs the model built
        # (presumably on a cache miss; confirm in EmbeddingModelManage).
        return get_model(dataset.embedding_mode)

    return EmbeddingModelManage.get_model(dataset_id, build_model)
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset(dataset):
    """Return the embedding model configured for an already-loaded dataset.

    The lookup is delegated to ``EmbeddingModelManage.get_model``, keyed by
    the string form of ``dataset.id``.

    :param dataset: dataset object exposing ``id`` and ``embedding_mode``
    """
    model_key = str(dataset.id)

    def build_model(_id):
        # Called by EmbeddingModelManage when it needs the model built
        # (presumably on a cache miss; confirm in EmbeddingModelManage).
        return get_model(dataset.embedding_mode)

    return EmbeddingModelManage.get_model(model_key, build_model)
|
||||
|
|
|
|||
|
|
@ -235,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
meta={})
|
||||
else:
|
||||
document_list.update(dataset_id=target_dataset_id)
|
||||
# 修改向量信息
|
||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_dataset_id))
|
||||
model = None
|
||||
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
|
||||
model = get_embedding_model_by_dataset_id(target_dataset_id)
|
||||
|
||||
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||
# 修改段落信息
|
||||
paragraph_list.update(dataset_id=target_dataset_id)
|
||||
# 修改向量信息
|
||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||
pid_list,
|
||||
target_dataset_id, model))
|
||||
|
||||
@staticmethod
|
||||
def get_target_dataset_problem(target_dataset_id: str,
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
|
|||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||
ProblemParagraphManage, get_embedding_model_by_dataset_id
|
||||
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
|
@ -338,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改mapping
|
||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||
['document_id'])
|
||||
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_document_id, target_dataset_id))
|
||||
target_document_id, target_dataset_id, target_embedding_model=None))
|
||||
# 修改段落信息
|
||||
paragraph_list.update(document_id=target_document_id)
|
||||
# 不同数据集迁移
|
||||
|
|
@ -368,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改mapping
|
||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||
['problem_id', 'dataset_id', 'document_id'])
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_document_id, target_dataset_id))
|
||||
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
|
||||
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
|
||||
embedding_model = None
|
||||
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
|
||||
embedding_model = get_embedding_model_by_dataset(target_dataset)
|
||||
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||
# 修改段落信息
|
||||
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
pid_list,
|
||||
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
|
||||
|
||||
update_document_char_length(document_id)
|
||||
update_document_char_length(target_document_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
|
|||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
content = instance.get('content')
|
||||
problem = QuerySet(Problem).filter(id=problem_id,
|
||||
dataset_id=dataset_id).first()
|
||||
QuerySet(DataSet).filter(id=dataset_id)
|
||||
problem.content = content
|
||||
problem.save()
|
||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
|
||||
|
|
|
|||
Loading…
Reference in New Issue