feat: 支持 文档刷新 -> 重写向量化

This commit is contained in:
shaohuzhang1 2023-12-13 18:03:57 +08:00
parent fd5fcb16e1
commit d5c3f04ce4
4 changed files with 73 additions and 29 deletions

View File

@ -46,15 +46,21 @@ class ListenerManagement:
:param paragraph_id: 段落id
:return: None
"""
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
**{'problem.paragraph_id': paragraph_id}),
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})
status = Status.success
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
**{'problem.paragraph_id': paragraph_id}),
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
except Exception as e:
status = Status.error
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
@staticmethod
@poxy
@ -64,17 +70,23 @@ class ListenerManagement:
:param document_id: 文档id
:return: None
"""
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
**{'problem.document_id': document_id}),
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
status = Status.success
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
**{'problem.document_id': document_id}),
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
except Exception as e:
status = Status.error
# 修改状态
QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})
QuerySet(Document).filter(id=document_id).update(**{'status': status})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
@staticmethod
@poxy
@ -84,17 +96,23 @@ class ListenerManagement:
:param dataset_id: 数据集id
:return: None
"""
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
**{'problem.dataset_id': dataset_id}),
'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
status = Status.success
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
**{'problem.dataset_id': dataset_id}),
'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除数据集相关向量数据
VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
except Exception as e:
status = Status.error
# 修改文档 以及段落的状态
QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': status})
QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': status})
@staticmethod
def delete_embedding_by_document(document_id):

View File

@ -140,6 +140,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
_document.save()
return self.one()
def refresh(self, with_valid=True):
    """Re-trigger vectorization for this document.

    Emits the ``embedding_by_document_signal`` with the document id so the
    listener deletes and rebuilds the document's vector data.

    :param with_valid: when True, run serializer validation first
                       (raises on invalid input).
    :return: True
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    # Fire-and-forget: the listener performs the actual re-embedding.
    ListenerManagement.embedding_by_document_signal.send(self.data.get("document_id"))
    return True
@transaction.atomic
def delete(self):
document_id = self.data.get("document_id")

View File

@ -17,6 +17,7 @@ urlpatterns = [
name="document_operate"),
path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(),
name="document_operate"),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),

View File

@ -70,6 +70,24 @@ class Document(APIView):
def post(self, request: Request, dataset_id: str):
    """Batch-save documents into the given dataset from the request payload."""
    batch_serializer = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
    return result.success(batch_serializer.batch_save(request.data))
class Refresh(APIView):
    """PUT endpoint that re-vectorizes a single document.

    Requires MANAGE permission on the dataset; delegates the actual work to
    ``DocumentSerializers.Operate.refresh``.
    """
    authentication_classes = [TokenAuth]

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary="刷新文档向量库",
                         operation_id="刷新文档向量库",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         responses=result.get_default_response(),
                         tags=["数据集/文档"]
                         )
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def put(self, request: Request, dataset_id: str, document_id: str):
        # Validation of the ids happens inside the serializer's refresh().
        return result.success(
            DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
            ))
class Operate(APIView):
authentication_classes = [TokenAuth]