feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-18 15:44:48 +08:00
parent 0e8d4eab12
commit dcf5892b96
5 changed files with 83 additions and 21 deletions

View File

@ -47,22 +47,26 @@ class SyncWebDocumentArgs:
class UpdateProblemArgs:
    """Argument bundle sent with the "update problem" event.

    Carries the id of the problem being edited, its new text, and the
    embedding model object handed over by the caller.
    """

    def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
        # Plain value holder: every argument is stored unchanged.
        self.embedding_model = embedding_model
        self.problem_content = problem_content
        self.problem_id = problem_id
class UpdateEmbeddingDatasetIdArgs:
    """Argument bundle for moving paragraph vectors to another dataset.

    Args:
        paragraph_id_list: ids of the paragraphs whose vectors are affected.
        target_dataset_id: id of the dataset the paragraphs now belong to.
        target_embedding_model: embedding model to re-embed with, or ``None``.
            ``ListenerManagement.update_embedding_dataset_id`` treats ``None``
            as "only rewrite the stored dataset_id" and anything else as
            "re-embed the paragraphs with this model".
    """

    def __init__(self, paragraph_id_list: List[str], target_dataset_id: str,
                 target_embedding_model: Embeddings = None):
        # Defaults to None so the signature matches UpdateEmbeddingDocumentIdArgs
        # and two-argument call sites keep working.
        self.paragraph_id_list = paragraph_id_list
        self.target_dataset_id = target_dataset_id
        self.target_embedding_model = target_embedding_model
class UpdateEmbeddingDocumentIdArgs:
    """Argument bundle for moving paragraph vectors to another document.

    Holds the affected paragraph ids, the destination document and dataset
    ids, and an optional embedding model (``None`` when omitted).
    """

    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
                 target_embedding_model: Embeddings = None):
        # Plain value holder: every argument is stored unchanged.
        self.target_embedding_model = target_embedding_model
        self.target_dataset_id = target_dataset_id
        self.target_document_id = target_document_id
        self.paragraph_id_list = paragraph_id_list
class ListenerManagement:
@ -88,6 +92,36 @@ class ListenerManagement:
def embedding_by_problem(args, embedding_model: Embeddings):
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
@staticmethod
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
    """Load the embedding rows for the given paragraphs and embed them.

    Runs ``list_embedding_text.sql`` through ``native_search`` with two
    querysets (``problem`` and ``paragraph``), both filtered by
    ``paragraph_id_list``, then passes the result to
    ``embedding_by_paragraph_data_list``.

    Any exception is logged to ``max_kb_error`` and swallowed, so this
    method never raises.
    """
    try:
        sql_file = os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')
        # Dynamic model exposing the 'paragraph.id' column for the problem side of the query.
        problem_model = get_dynamics_model({'paragraph.id': django.db.models.CharField()})
        query_set_dict = {
            'problem': QuerySet(problem_model).filter(**{'paragraph.id__in': paragraph_id_list}),
            'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list),
        }
        data_list = native_search(query_set_dict, select_string=get_file_content(sql_file))
        ListenerManagement.embedding_by_paragraph_data_list(
            data_list,
            paragraph_id_list=paragraph_id_list,
            embedding_model=embedding_model,
        )
    except Exception as e:
        max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
    """Replace the stored vectors of the given paragraphs and record the outcome.

    Deletes the existing vectors for ``paragraph_id_list``, batch-saves
    ``data_list`` using ``embedding_model``, and finally writes the resulting
    status onto the matching ``Paragraph`` rows.

    Failures in the vector store are logged to ``max_kb_error`` and recorded
    as ``Status.error``; they are not re-raised.
    """
    max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
    # Must be bound before the try: the finally block always reads `status`,
    # and previously it was only assigned on failure, so every successful run
    # raised UnboundLocalError and never updated the paragraph status.
    # NOTE(review): assumes the Status enum has a `success` member next to
    # `error` — confirm against dataset.models.
    status = Status.success
    try:
        # Drop the paragraphs' existing vectors first.
        VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
        # Embed and store all rows in one batch.
        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
    except Exception as e:
        max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
        status = Status.error
    finally:
        QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
        max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
@ -227,14 +261,22 @@ class ListenerManagement:
@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
    """Apply a dataset move to the stored vectors of the given paragraphs.

    When ``args.target_embedding_model`` is ``None`` only the ``dataset_id``
    stored with the existing vectors is rewritten. Otherwise the paragraphs
    are re-embedded with the supplied model via ``embedding_by_paragraph_list``.
    """
    target_model = args.target_embedding_model
    if target_model is not None:
        # A model was supplied: rebuild the vectors with it.
        ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                       embedding_model=target_model)
        return
    # No model supplied: update the stored dataset_id in place.
    VectorStore.get_embedding_vector().update_by_paragraph_ids(
        args.paragraph_id_list,
        {'dataset_id': args.target_dataset_id})
@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
    """Apply a document move to the stored vectors of the given paragraphs.

    When ``args.target_embedding_model`` is ``None`` only the ``document_id``
    and ``dataset_id`` stored with the existing vectors are rewritten.
    Otherwise the paragraphs are re-embedded with the supplied model via
    ``embedding_by_paragraph_list``.
    """
    target_model = args.target_embedding_model
    if target_model is not None:
        # A model was supplied: rebuild the vectors with it.
        ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                       embedding_model=target_model)
        return
    # No model supplied: update the stored ids in place.
    new_location = {'document_id': args.target_document_id,
                    'dataset_id': args.target_dataset_id}
    VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, new_location)
@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):

View File

@ -147,3 +147,7 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
def get_embedding_model_by_dataset_id(dataset_id: str):
    """Return the embedding model for the dataset with the given id.

    Fetches the dataset (joining its ``embedding_mode`` relation) and asks
    ``EmbeddingModelManage.get_model`` for the model under the key
    ``dataset_id``, supplying a callback that builds it from
    ``dataset.embedding_mode``.
    """
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()

    # NOTE(review): `.first()` returns None when no dataset matches, in which
    # case this callback raises AttributeError if it is invoked.
    def build_model(_id):
        return get_model(dataset.embedding_mode)

    return EmbeddingModelManage.get_model(dataset_id, build_model)
def get_embedding_model_by_dataset(dataset):
    """Return the embedding model for an already-loaded dataset object.

    Asks ``EmbeddingModelManage.get_model`` for the model under the key
    ``str(dataset.id)``, supplying a callback that builds it from
    ``dataset.embedding_mode``. No dataset lookup is performed here.
    """
    model_key = str(dataset.id)

    def build_model(_id):
        return get_model(dataset.embedding_mode)

    return EmbeddingModelManage.get_model(model_key, build_model)

View File

@ -235,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
# 修改向量信息
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_dataset_id))
model = None
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
model = get_embedding_model_by_dataset_id(target_dataset_id)
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id)
# 修改向量信息
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
pid_list,
target_dataset_id, model))
@staticmethod
def get_target_dataset_problem(target_dataset_id: str,

View File

@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage, get_embedding_model_by_dataset_id
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
@ -338,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['document_id'])
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
target_document_id, target_dataset_id, target_embedding_model=None))
# 修改段落信息
paragraph_list.update(document_id=target_document_id)
# 不同数据集迁移
@ -368,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['problem_id', 'dataset_id', 'document_id'])
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model = None
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
embedding_model = get_embedding_model_by_dataset(target_dataset)
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
pid_list,
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
update_document_char_length(document_id)
update_document_char_length(target_document_id)

View File

@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from smartdoc.conf import PROJECT_DIR
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
content = instance.get('content')
problem = QuerySet(Problem).filter(id=problem_id,
dataset_id=dataset_id).first()
QuerySet(DataSet).filter(id=dataset_id)
problem.content = content
problem.save()
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))