From 344c336143bf77124affd2c45489fbe83f21a28b Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 17 Jan 2024 16:08:51 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20web=E7=AB=99=E7=82=B9=20=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E5=90=8C=E6=AD=A5,=E6=89=B9=E9=87=8F=E5=88=A0?= =?UTF-8?q?=E9=99=A4,=E6=89=B9=E9=87=8F=E5=AF=BC=E5=85=A5=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/event/listener_manage.py | 20 ++++- .../dataset/serializers/common_serializers.py | 29 +++++++ .../serializers/document_serializers.py | 80 ++++++++++++++++++- apps/dataset/urls.py | 1 + apps/dataset/views/document.py | 52 +++++++++++- 5 files changed, 175 insertions(+), 7 deletions(-) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index a053b2394..66fe14320 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -9,6 +9,7 @@ import logging import os import traceback +from typing import List import django.db.models from blinker import signal @@ -18,7 +19,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import native_search, get_dynamics_model from common.event.common import poxy from common.util.file_util import get_file_content -from common.util.fork import ForkManage +from common.util.fork import ForkManage, Fork from common.util.lock import try_lock, un_lock from dataset.models import Paragraph, Status, Document from embedding.models import SourceType @@ -36,6 +37,13 @@ class SyncWebDatasetArgs: self.handler = handler +class SyncWebDocumentArgs: + def __init__(self, source_url_list: List[str], selector: str, handler): + self.source_url_list = source_url_list + self.selector = selector + self.handler = handler + + class ListenerManagement: embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_paragraph_signal = signal("embedding_by_paragraph") @@ -49,6 +57,7 @@ 
class ListenerManagement: disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph') init_embedding_model_signal = signal('init_embedding_model') sync_web_dataset_signal = signal('sync_web_dataset') + sync_web_document_signal = signal('sync_web_document') @staticmethod def embedding_by_problem(args): @@ -155,6 +164,13 @@ class ListenerManagement: def enable_embedding_by_paragraph(paragraph_id): VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) + @staticmethod + @poxy + def sync_web_document(args: SyncWebDocumentArgs): + for source_url in args.source_url_list: + result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork() + args.handler(source_url, args.selector, result) + @staticmethod @poxy def sync_web_dataset(args: SyncWebDatasetArgs): @@ -200,3 +216,5 @@ class ListenerManagement: ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model) # 同步web站点知识库 ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) + # 同步web站点 文档 + ListenerManagement.sync_web_document_signal.connect(self.sync_web_document) diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 77f3bca85..1d8010136 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -10,9 +10,13 @@ import os from typing import List from django.db.models import QuerySet +from drf_yasg import openapi +from rest_framework import serializers from common.db.search import native_search from common.db.sql_execute import update_execute +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin from common.util.file_util import get_file_content from dataset.models import Paragraph from smartdoc.conf import PROJECT_DIR @@ -29,3 +33,28 @@ def list_paragraph(paragraph_list: List[str]): return [] return 
native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql'))) + + +class BatchSerializer(ApiMixin, serializers.Serializer): + id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + + def is_valid(self, *, model=None, raise_exception=False): + super().is_valid(raise_exception=True) + if model is not None: + id_list = self.data.get('id_list') + model_list = QuerySet(model).filter(id__in=id_list) + if len(model_list) != len(id_list): + model_id_list = [str(m.id) for m in model_list] + error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list)) + raise AppApiException(500, f"id不正确:{error_id_list}") + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="主键id列表", + description="主键id列表") + } + ) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 7665b4c82..653c49166 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -21,7 +21,7 @@ from rest_framework import serializers from common.db.search import native_search, native_page_search from common.event.common import work_thread_pool -from common.event.listener_manage import ListenerManagement +from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post @@ -29,10 +29,28 @@ from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import SplitModel, get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, 
Status +from dataset.serializers.common_serializers import BatchSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR +class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): + source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True)) + selector = serializers.CharField(required=False) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['source_url_list'], + properties={ + 'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="web站点url列表", description="web站点url列表", + items=openapi.Schema(type=openapi.TYPE_STRING)), + 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") + } + ) + + class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): name = serializers.CharField(required=True, validators=[ @@ -121,6 +139,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) document_id = self.data.get('document_id') document = QuerySet(Document).filter(id=document_id).first() + if document.type != Type.web: + return True try: document.status = Status.embedding document.save() @@ -301,6 +321,38 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): data={'dataset_id': dataset_id, 'document_id': document_id}).one( with_valid=True), document_id + @staticmethod + def get_sync_handler(dataset_id): + def handler(source_url: str, selector, response: Fork.Response): + if response.status == 200: + try: + paragraphs = get_split_model('web.md').parse(response.content) + # 插入 + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + {'name': source_url, 'paragraphs': paragraphs, + 'meta': {'source_url': source_url, 'selector': selector}, + 'type': Type.web}, with_valid=True) + except Exception as e: + 
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + else: + Document(name=source_url, + meta={'source_url': source_url, 'selector': selector}, + type=Type.web, + char_length=0, + status=Status.error).save() + + return handler + + def save_web(self, instance: Dict, with_valid=True): + if with_valid: + DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True) + self.is_valid(raise_exception=True) + dataset_id = self.data.get('dataset_id') + source_url_list = instance.get('source_url_list') + selector = instance.get('selector') + args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id)) + ListenerManagement.sync_web_document_signal.send(args) + @staticmethod def get_paragraph_model(document_model, paragraph_list: List): dataset_id = document_model.dataset_id @@ -331,8 +383,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): 'meta': instance.get('meta') if instance.get('meta') is not None else {}, 'type': instance.get('type') if instance.get('type') is not None else Type.base}) - return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if - 'paragraphs' in instance else []) + return DocumentSerializers.Create.get_paragraph_model(document_model, + instance.get('paragraphs') if + 'paragraphs' in instance else []) @staticmethod def get_request_body_api(): @@ -451,6 +504,27 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return native_search(query_set, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), + @staticmethod + def _batch_sync(document_id_list: List[str]): + for document_id in document_id_list: + DocumentSerializers.Sync(data={'document_id': document_id}).sync() + + def batch_sync(self, instance: Dict, with_valid=True): + if with_valid: + BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True) + 
self.is_valid(raise_exception=True) + # 异步同步 + work_thread_pool.submit(self._batch_sync, + instance.get('id_list')) + return True + + def batch_delete(self, instance: Dict, with_valid=True): + if with_valid: + BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True) + self.is_valid(raise_exception=True) + QuerySet(Document).filter(id__in=instance.get('id_list')).delete() + return True + def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int): data = file.read() diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 33f0579f5..9504741f9 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -12,6 +12,7 @@ urlpatterns = [ path('dataset//sync_web', views.Dataset.SyncWeb.as_view()), path('dataset//hit_test', views.Dataset.HitTest.as_view()), path('dataset//document', views.Document.as_view(), name='document'), + path('dataset//document/web', views.WebDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document//', views.Document.Page.as_view()), path('dataset//document/', views.Document.Operate.as_view(), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index cf5ca397e..ad42e5554 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -14,11 +14,29 @@ from rest_framework.views import APIView from rest_framework.views import Request from common.auth import TokenAuth, has_permissions -from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants -from common.event.common import work_thread_pool +from common.constants.permission_constants import Permission, Group, Operate from common.response import result from common.util.common import query_params_to_single_dict -from dataset.serializers.document_serializers import DocumentSerializers +from dataset.serializers.common_serializers import BatchSerializer +from dataset.serializers.document_serializers import 
DocumentSerializers, DocumentWebInstanceSerializer + + +class WebDocument(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建Web站点文档", + operation_id="创建Web站点文档", + request_body=DocumentWebInstanceSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True)) class Document(APIView): @@ -71,6 +89,34 @@ class Document(APIView): def post(self, request: Request, dataset_id: str): return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data)) + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量同步文档", + operation_id="批量同步文档", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_sync(request.data)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="批量删除文档", + operation_id="批量删除文档", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: 
Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_delete(request.data)) + class Refresh(APIView): authentication_classes = [TokenAuth]