diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py
index e4106ed7e..6e1df31d2 100644
--- a/apps/common/constants/permission_constants.py
+++ b/apps/common/constants/permission_constants.py
@@ -99,6 +99,9 @@ class PermissionConstants(Enum):
     DATASET_READ = Permission(group=Group.DATASET, operate=Operate.READ,
                               roles=[RoleConstants.ADMIN, RoleConstants.USER])
 
+    DATASET_EDIT = Permission(group=Group.DATASET, operate=Operate.EDIT,
+                              roles=[RoleConstants.ADMIN, RoleConstants.USER])
+
     APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ,
                                   roles=[RoleConstants.ADMIN, RoleConstants.USER])
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 2f9737b35..3baa7d5c7 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -8,6 +8,7 @@
 """
 import logging
 import os.path
+import re
 import traceback
 import uuid
 from functools import reduce
@@ -483,6 +484,97 @@ class DataSetSerializers(serializers.ModelSerializer):
         return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
                  'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]
 
+    class SyncWeb(ApiMixin, serializers.Serializer):
+        id = serializers.CharField(required=True)
+        user_id = serializers.UUIDField(required=False)
+        sync_type = serializers.CharField(required=True, validators=[
+            validators.RegexValidator(regex=re.compile("^(replace|complete)$"),
+                                      message="replace|complete", code=500)
+        ])
+
+        def is_valid(self, *, raise_exception=False):
+            super().is_valid(raise_exception=True)
+            first = QuerySet(DataSet).filter(id=self.data.get("id")).first()
+            if first is None:
+                raise AppApiException(300, "id does not exist")
+            if first.type != Type.web:
+                raise AppApiException(500, "Only web site type supports synchronization")
+
+        def sync(self, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            sync_type = self.data.get('sync_type')
+            dataset_id = self.data.get('id')
+            dataset = QuerySet(DataSet).get(id=dataset_id)
+            self.__getattribute__(sync_type + '_sync')(dataset)
+            return True
+
+        @staticmethod
+        def get_sync_handler(dataset):
+            def handler(child_link: ChildLink, response: Fork.Response):
+                if response.status == 200:
+                    try:
+                        document_name = child_link.tag.text if child_link.tag is not None and len(
+                            child_link.tag.text.strip()) > 0 else child_link.url
+                        paragraphs = get_split_model('web.md').parse(response.content)
+                        first = QuerySet(Document).filter(meta__source_url=child_link.url).first()
+                        if first is not None:
+                            # The document already exists, refresh it via document sync
+                            DocumentSerializers.Sync(data={'document_id': first.id}).sync()
+                            # Send the embedding instruction
+                            ListenerManagement.embedding_by_document_signal.send(first.id)
+                        else:
+                            # Insert a new document
+                            DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
+                                {'name': document_name, 'paragraphs': paragraphs,
+                                 'meta': {'source_url': child_link.url, 'selector': dataset.meta.get('selector')},
+                                 'type': Type.web}, with_valid=True)
+                    except Exception as e:
+                        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+
+            return handler
+
+        def replace_sync(self, dataset):
+            """
+            Replace sync
+            :return:
+            """
+            url = dataset.meta.get('source_url')
+            selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
+            ListenerManagement.sync_web_dataset_signal.send(
+                SyncWebDatasetArgs(str(dataset.id), url, selector,
+                                   self.get_sync_handler(dataset)))
+
+        def complete_sync(self, dataset):
+            """
+            Full sync: delete all documents under the current dataset, then sync again
+            :return:
+            """
+            # Delete documents
+            QuerySet(Document).filter(dataset=dataset).delete()
+            # Delete paragraphs
+            QuerySet(Paragraph).filter(dataset=dataset).delete()
+            # Delete problems
+            QuerySet(Problem).filter(dataset=dataset).delete()
+            # Delete embeddings
+            ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
+            # Sync
+            self.replace_sync(dataset)
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='Knowledge base id'),
+                    openapi.Parameter(name='sync_type',
+                                      in_=openapi.IN_QUERY,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='Sync type -> replace: replace sync, complete: full sync')
+                    ]
+
     class Operate(ApiMixin, serializers.Serializer):
         id = serializers.CharField(required=True)
         user_id = serializers.UUIDField(required=False)
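For reference, a minimal standalone sketch of the two techniques `SyncWeb.sync` relies on: an anchored regex gate on `sync_type`, followed by name-based dispatch to `<sync_type>_sync`. `WebSyncSketch` and its prints are illustrative stand-ins, not part of the patch. The grouping parentheses matter because `RegexValidator` applies `re.search`, so the ungrouped `^replace|complete$` would also accept values like "replacement" or "incomplete", and the dispatch would then fail with an `AttributeError`.

```python
import re

# Anchored alternation: the whole value must be exactly one of the two words.
# Without the group, "^replace|complete$" under re.search also matches
# "replacement" (prefix) and "incomplete" (suffix).
SYNC_TYPE = re.compile("^(replace|complete)$")


class WebSyncSketch:
    """Illustrative stand-in for DataSetSerializers.SyncWeb."""

    def replace_sync(self):
        print("re-fork the site, update existing documents, insert new ones")

    def complete_sync(self):
        print("delete all documents first, then replace_sync")

    def sync(self, sync_type: str):
        if SYNC_TYPE.search(sync_type) is None:
            raise ValueError("sync_type must be 'replace' or 'complete'")
        # Name-based dispatch, as in SyncWeb.sync: resolve "<sync_type>_sync"
        getattr(self, sync_type + '_sync')()


WebSyncSketch().sync('complete')  # -> delete all documents first, then replace_sync
```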
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 519276cbf..bc0a4363e 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -6,7 +6,9 @@
 @date:2023/9/22 13:43
 @desc:
 """
+import logging
 import os
+import traceback
 import uuid
 from functools import reduce
 from typing import List, Dict
@@ -23,8 +25,9 @@ from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.common import post
 from common.util.file_util import get_file_content
+from common.util.fork import Fork
 from common.util.split_model import SplitModel, get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
 from smartdoc.conf import PROJECT_DIR
 
@@ -100,6 +103,58 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                                        title="Document list", description="Document list",
                                        items=DocumentSerializers.Operate.get_response_body_api())
 
+    class Sync(ApiMixin, serializers.Serializer):
+        document_id = serializers.UUIDField(required=True)
+
+        def is_valid(self, *, raise_exception=False):
+            super().is_valid(raise_exception=True)
+            document_id = self.data.get('document_id')
+            first = QuerySet(Document).filter(id=document_id).first()
+            if first is None:
+                raise AppApiException(500, "Document id does not exist")
+            if first.type != Type.web:
+                raise AppApiException(500, "Only web site type supports synchronization")
+
+        def sync(self, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id = self.data.get('document_id')
+            document = QuerySet(Document).filter(id=document_id).first()
+            try:
+                document.status = Status.embedding
+                document.save()
+                source_url = document.meta.get('source_url')
+                selector_list = document.meta.get('selector').split(" ") if 'selector' in document.meta else []
+                result = Fork(source_url, selector_list).fork()
+                if result.status == 200:
+                    # Delete paragraphs
+                    QuerySet(model=Paragraph).filter(document_id=document_id).delete()
+                    # Delete problems
+                    QuerySet(model=Problem).filter(document_id=document_id).delete()
+                    # Delete from the vector store
+                    ListenerManagement.delete_embedding_by_document_signal.send(document_id)
+                    paragraphs = get_split_model('web.md').parse(result.content)
+                    document.char_length = reduce(lambda x, y: x + y,
+                                                  [len(p.get('content')) for p in paragraphs],
+                                                  0)
+                    document.save()
+                    document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
+
+                    paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
+                    problem_model_list = document_paragraph_model.get('problem_model_list')
+                    # Bulk insert paragraphs
+                    QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
+                    # Bulk insert problems
+                    QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
+                else:
+                    document.status = Status.error
+                    document.save()
+            except Exception as e:
+                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+                document.status = Status.error
+                document.save()
+            return True
+
     class Operate(ApiMixin, serializers.Serializer):
         document_id = serializers.UUIDField(required=True)
 
@@ -146,6 +201,11 @@
         if with_valid:
             self.is_valid(raise_exception=True)
         document_id = self.data.get("document_id")
+        document = QuerySet(Document).filter(id=document_id).first()
+        if document.type == Type.web:
+            # Web site documents are synced first
+            DocumentSerializers.Sync(data={'document_id': document_id}).sync()
+
         ListenerManagement.embedding_by_document_signal.send(document_id)
         return True
 
@@ -236,21 +296,11 @@
             with_valid=True), document_id
 
     @staticmethod
-    def get_document_paragraph_model(dataset_id, instance: Dict):
-        document_model = Document(
-            **{'dataset_id': dataset_id,
-               'id': uuid.uuid1(),
-               'name': instance.get('name'),
-               'char_length': reduce(lambda x, y: x + y,
-                                     [len(p.get('content')) for p in instance.get('paragraphs', [])],
-                                     0),
-               'meta': instance.get('meta') if instance.get('meta') is not None else {},
-               'type': instance.get('type') if instance.get('type') is not None else Type.base})
-
+    def get_paragraph_model(document_model, paragraph_list: List):
+        dataset_id = document_model.dataset_id
         paragraph_model_dict_list = [ParagraphSerializers.Create(
             data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
-            dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
-                                                                        'paragraphs' in instance else [])]
+            dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
 
         paragraph_model_list = []
         problem_model_list = []
@@ -263,6 +313,21 @@
         return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
                 'problem_model_list': problem_model_list}
 
+    @staticmethod
+    def get_document_paragraph_model(dataset_id, instance: Dict):
+        document_model = Document(
+            **{'dataset_id': dataset_id,
+               'id': uuid.uuid1(),
+               'name': instance.get('name'),
+               'char_length': reduce(lambda x, y: x + y,
+                                     [len(p.get('content')) for p in instance.get('paragraphs', [])],
+                                     0),
+               'meta': instance.get('meta') if instance.get('meta') is not None else {},
+               'type': instance.get('type') if instance.get('type') is not None else Type.base})
+
+        return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if
+                                                              'paragraphs' in instance else [])
+
     @staticmethod
     def get_request_body_api():
         return DocumentInstanceSerializer.get_request_body_api()
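Both sync paths pivot on `Document.meta`: web documents are created with `{'source_url': ..., 'selector': ...}`, `Document.Sync` reads the URL and selector back out of it, and the dataset-level handler decides between update and insert with a JSONField key lookup on `meta__source_url`. A small sketch of that decision, assuming a configured Django environment for this project; the URL is a placeholder:

```python
from django.db.models import QuerySet

from dataset.models.data_set import Document

url = "https://example.com/docs/page.html"  # placeholder child-link URL

# JSONField key transform: match documents whose meta JSON holds
# a "source_url" equal to the forked page's URL.
existing = QuerySet(Document).filter(meta__source_url=url).first()
if existing is None:
    print("not imported yet -> DocumentSerializers.Create inserts a new web document")
else:
    print(f"already imported as {existing.id} -> DocumentSerializers.Sync refreshes it")
```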
diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py
index 49d7903d6..33f0579f5 100644
--- a/apps/dataset/urls.py
+++ b/apps/dataset/urls.py
@@ -9,6 +9,7 @@ urlpatterns = [
     path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
     path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
     path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
+    path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
     path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
     path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
     path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py
index 5f7fab5bc..b6e63948b 100644
--- a/apps/dataset/views/dataset.py
+++ b/apps/dataset/views/dataset.py
@@ -23,6 +23,24 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
 
 class Dataset(APIView):
     authentication_classes = [TokenAuth]
 
+    class SyncWeb(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @swagger_auto_schema(operation_summary="Sync web site knowledge base",
+                             operation_id="Sync web site knowledge base",
+                             manual_parameters=DataSetSerializers.SyncWeb.get_request_params_api(),
+                             responses=result.get_default_response(),
+                             tags=["Knowledge Base"])
+        @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                                        dynamic_tag=keywords.get('dataset_id')),
+                         lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
+                                                 dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND)
+        def put(self, request: Request, dataset_id: str):
+            return result.success(DataSetSerializers.SyncWeb(
+                data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id,
+                      'user_id': str(request.user.id)}).sync())
+
     class CreateWebDataset(APIView):
         authentication_classes = [TokenAuth]
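Putting the pieces together, the new route is invoked with a PUT and the sync type in the query string, matching `get_request_params_api` above. A hypothetical client call follows; the base URL, dataset id, and token are placeholders, and the exact path prefix depends on how the project mounts apps/dataset/urls.py:

```python
import requests

BASE_URL = "http://localhost:8080/api"  # placeholder deployment prefix
DATASET_ID = "<dataset-uuid>"           # placeholder id of a web-type dataset
TOKEN = "<token-from-login>"            # placeholder TokenAuth credential

response = requests.put(
    f"{BASE_URL}/dataset/{DATASET_ID}/sync_web",
    params={"sync_type": "replace"},  # or "complete" to drop all documents first
    headers={"AUTHORIZATION": TOKEN},
)
print(response.status_code, response.json())
```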