fix: web站点 批量同步,批量删除,批量导入接口

This commit is contained in:
shaohuzhang1 2024-01-17 16:08:51 +08:00
parent 91b613c4da
commit 344c336143
5 changed files with 175 additions and 7 deletions

View File

@ -9,6 +9,7 @@
import logging
import os
import traceback
from typing import List
import django.db.models
from blinker import signal
@ -18,7 +19,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage
from common.util.fork import ForkManage, Fork
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
@ -36,6 +37,13 @@ class SyncWebDatasetArgs:
self.handler = handler
class SyncWebDocumentArgs:
    """Argument bundle carried by the sync_web_document signal.

    Holds the urls to crawl, the css selector to apply on each page and the
    callback that consumes every crawl result.
    """

    def __init__(self, source_url_list: List[str], selector: str, handler):
        # Callback invoked as handler(source_url, selector, fork_response).
        self.handler = handler
        self.selector = selector
        self.source_url_list = source_url_list
class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@ -49,6 +57,7 @@ class ListenerManagement:
disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
init_embedding_model_signal = signal('init_embedding_model')
sync_web_dataset_signal = signal('sync_web_dataset')
sync_web_document_signal = signal('sync_web_document')
@staticmethod
def embedding_by_problem(args):
@ -155,6 +164,13 @@ class ListenerManagement:
def enable_embedding_by_paragraph(paragraph_id):
VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
@staticmethod
@poxy
def sync_web_document(args: SyncWebDocumentArgs):
    """Crawl every url in ``args`` and hand each result to ``args.handler``."""
    for url in args.source_url_list:
        # The selector string is space-separated into a selector list per url.
        fork_result = Fork(base_fork_url=url,
                           selector_list=args.selector.split(' ')).fork()
        args.handler(url, args.selector, fork_result)
@staticmethod
@poxy
def sync_web_dataset(args: SyncWebDatasetArgs):
@ -200,3 +216,5 @@ class ListenerManagement:
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
# 同步web站点知识库
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
# 同步web站点 文档
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)

View File

@ -10,9 +10,13 @@ import os
from typing import List
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from smartdoc.conf import PROJECT_DIR
@ -29,3 +33,28 @@ def list_paragraph(paragraph_list: List[str]):
return []
return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
class BatchSerializer(ApiMixin, serializers.Serializer):
    """Generic batch-operation serializer validating a list of primary-key ids."""
    # Primary-key UUIDs targeted by the batch operation.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))

    def is_valid(self, *, model=None, raise_exception=False):
        """Validate the payload.

        When ``model`` is given, additionally verify that every id in
        ``id_list`` exists in that model's table; raise AppApiException
        listing the unknown ids otherwise.
        """
        # Field-level validation always raises on error; ``raise_exception``
        # is kept only for signature compatibility with the base class.
        super().is_valid(raise_exception=True)
        if model is not None:
            id_list = self.data.get('id_list')
            model_list = QuerySet(model).filter(id__in=id_list)
            if len(model_list) != len(id_list):
                # Use a set for O(1) membership tests instead of the original
                # O(n) per-id list scan via list.__contains__.
                exist_id_set = {str(m.id) for m in model_list}
                error_id_list = [row_id for row_id in id_list if row_id not in exist_id_set]
                raise AppApiException(500, f"id不正确:{error_id_list}")

    @staticmethod
    def get_request_body_api():
        """Swagger schema describing the batch request body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
                                          title="主键id列表",
                                          description="主键id列表")
            }
        )

View File

@ -21,7 +21,7 @@ from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.event.common import work_thread_pool
from common.event.listener_manage import ListenerManagement
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
@ -29,10 +29,28 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
    """Request serializer for creating web-site documents from a list of urls."""
    # Urls of the web pages to crawl.
    source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True))
    # Optional selector restricting which part of each page is extracted.
    selector = serializers.CharField(required=False)

    @staticmethod
    def get_request_body_api():
        """Swagger schema for the web-document creation body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['source_url_list'],
            properties={
                # Bug fix: the original title/description strings were
                # copy-pasted from another serializer ("段落列表"/"文档名称")
                # and did not describe these fields.
                'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="source_url列表",
                                                  description="源地址url列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
            }
        )
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=True,
validators=[
@ -121,6 +139,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
document = QuerySet(Document).filter(id=document_id).first()
if document.type != Type.web:
return True
try:
document.status = Status.embedding
document.save()
@ -301,6 +321,38 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id
@staticmethod
def get_sync_handler(dataset_id):
    """Build the callback invoked with the crawl result of one source url.

    On HTTP 200 the fetched content is split into paragraphs and stored as a
    new web document of the dataset; on any other status a placeholder
    document in error state is stored instead.
    """
    def handler(source_url: str, selector, response: Fork.Response):
        if response.status == 200:
            try:
                paragraphs = get_split_model('web.md').parse(response.content)
                # Insert the crawled document together with its paragraphs.
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': source_url, 'paragraphs': paragraphs,
                     'meta': {'source_url': source_url, 'selector': selector},
                     'type': Type.web}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        else:
            # Bug fix: the error placeholder was saved without a dataset_id,
            # leaving it orphaned from the dataset it was requested for;
            # attach it to the target dataset like the success branch does.
            Document(name=source_url,
                     dataset_id=dataset_id,
                     meta={'source_url': source_url, 'selector': selector},
                     type=Type.web,
                     char_length=0,
                     status=Status.error).save()
    return handler
def save_web(self, instance: Dict, with_valid=True):
    """Asynchronously import the given web urls as documents of the dataset.

    Validation runs first (optional), then a SyncWebDocumentArgs bundle is
    dispatched on the sync_web_document signal; the actual crawling happens
    in the listener.
    """
    if with_valid:
        DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    sync_args = SyncWebDocumentArgs(instance.get('source_url_list'),
                                    instance.get('selector'),
                                    self.get_sync_handler(dataset_id))
    ListenerManagement.sync_web_document_signal.send(sync_args)
@staticmethod
def get_paragraph_model(document_model, paragraph_list: List):
dataset_id = document_model.dataset_id
@ -331,8 +383,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'meta': instance.get('meta') if instance.get('meta') is not None else {},
'type': instance.get('type') if instance.get('type') is not None else Type.base})
return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if
'paragraphs' in instance else [])
return DocumentSerializers.Create.get_paragraph_model(document_model,
instance.get('paragraphs') if
'paragraphs' in instance else [])
@staticmethod
def get_request_body_api():
@ -451,6 +504,27 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
@staticmethod
def _batch_sync(document_id_list: List[str]):
    """Run a web sync for each document id, one after another."""
    for doc_id in document_id_list:
        DocumentSerializers.Sync(data={'document_id': doc_id}).sync()
def batch_sync(self, instance: Dict, with_valid=True):
    """Queue an asynchronous sync for the given document ids; returns True."""
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        self.is_valid(raise_exception=True)
    # Run the per-document sync on the worker pool so the request returns fast.
    work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
    return True
def batch_delete(self, instance: Dict, with_valid=True):
    """Delete every document whose id is in instance['id_list']; returns True."""
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        self.is_valid(raise_exception=True)
    doc_id_list = instance.get('id_list')
    # NOTE(review): only Document rows are deleted here; confirm that related
    # paragraphs/embeddings are removed by cascade or by another listener.
    QuerySet(Document).filter(id__in=doc_id_list).delete()
    return True
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
data = file.read()

View File

@ -12,6 +12,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),

View File

@ -14,11 +14,29 @@ from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
from common.event.common import work_thread_pool
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.document_serializers import DocumentSerializers
from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer
class WebDocument(APIView):
    """Endpoint for creating documents from web pages of a dataset."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建Web站点文档",
                         operation_id="创建Web站点文档",
                         request_body=DocumentWebInstanceSerializer.get_request_body_api(),
                         manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
                         tags=["知识库/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        """Import the posted source urls as web documents of the dataset."""
        serializer = DocumentSerializers.Create(data={'dataset_id': dataset_id})
        return result.success(serializer.save_web(request.data, with_valid=True))
class Document(APIView):
@ -71,6 +89,34 @@ class Document(APIView):
def post(self, request: Request, dataset_id: str):
return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data))
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量同步文档",
                     operation_id="批量同步文档",
                     request_body=BatchSerializer.get_request_body_api(),
                     manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
    """Trigger an asynchronous sync for every document id in the body."""
    batch = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
    return result.success(batch.batch_sync(request.data))
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="批量删除文档",
                     operation_id="批量删除文档",
                     request_body=BatchSerializer.get_request_body_api(),
                     manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def delete(self, request: Request, dataset_id: str):
    """Delete every document id listed in the request body."""
    batch = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
    return result.success(batch.batch_delete(request.data))
class Refresh(APIView):
authentication_classes = [TokenAuth]