feat: web站点数据集文档同步

This commit is contained in:
shaohuzhang1 2024-01-03 11:51:48 +08:00
parent edbc8561c7
commit 82c8e322fb
5 changed files with 193 additions and 14 deletions

View File

@ -99,6 +99,9 @@ class PermissionConstants(Enum):
DATASET_READ = Permission(group=Group.DATASET, operate=Operate.READ,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
DATASET_EDIT = Permission(group=Group.DATASET, operate=Operate.EDIT,
roles=[RoleConstants.ADMIN, RoleConstants.USER])
APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ,
roles=[RoleConstants.ADMIN, RoleConstants.USER])

View File

@ -8,6 +8,7 @@
"""
import logging
import os.path
import re
import traceback
import uuid
from functools import reduce
@ -483,6 +484,97 @@ class DataSetSerializers(serializers.ModelSerializer):
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]
class SyncWeb(ApiMixin, serializers.Serializer):
    """Synchronize a 'web site' dataset: re-fork its source URL and create or
    refresh one Document per discovered child page."""
    id = serializers.CharField(required=True)
    user_id = serializers.UUIDField(required=False)
    # BUGFIX: the original pattern "^replace|complete$" parses as
    # (^replace)|(complete$) because '|' binds looser than the anchors, so
    # inputs like "replacement" or "not-complete" were accepted. Grouping the
    # alternatives restricts validation to the two exact literals.
    sync_type = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(replace|complete)$"),
                                  message="replace|complete", code=500)
    ])

    def is_valid(self, *, raise_exception=False):
        """Field validation plus existence/type checks on the target dataset.

        :raises AppApiException: 300 when the id is unknown, 500 when the
            dataset is not of web type.
        """
        super().is_valid(raise_exception=True)
        first = QuerySet(DataSet).filter(id=self.data.get("id")).first()
        if first is None:
            raise AppApiException(300, "id不存在")
        if first.type != Type.web:
            raise AppApiException(500, "只有web站点类型才支持同步")

    def sync(self, with_valid=True):
        """Dispatch to replace_sync/complete_sync based on sync_type.

        :param with_valid: run is_valid first when True
        :return: True once the sync has been dispatched
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        sync_type = self.data.get('sync_type')
        dataset_id = self.data.get('id')
        dataset = QuerySet(DataSet).get(id=dataset_id)
        # 'replace' -> replace_sync, 'complete' -> complete_sync. getattr is
        # the idiomatic spelling of the original __getattribute__ call.
        getattr(self, sync_type + '_sync')(dataset)
        return True

    @staticmethod
    def get_sync_handler(dataset):
        """Build the per-child-link callback used by the fork worker."""

        def handler(child_link: ChildLink, response: Fork.Response):
            if response.status == 200:
                try:
                    # Prefer the link's anchor text as the document name,
                    # falling back to the raw URL when the tag is empty.
                    document_name = child_link.tag.text if child_link.tag is not None and len(
                        child_link.tag.text.strip()) > 0 else child_link.url
                    paragraphs = get_split_model('web.md').parse(response.content)
                    first = QuerySet(Document).filter(meta__source_url=child_link.url).first()
                    if first is not None:
                        # Document already exists: delegate to document-level sync
                        DocumentSerializers.Sync(data={'document_id': first.id}).sync()
                        # then request re-embedding.
                        ListenerManagement.embedding_by_document_signal.send(first.id)
                    else:
                        # New page: insert a fresh document with its paragraphs.
                        DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
                            {'name': document_name, 'paragraphs': paragraphs,
                             'meta': {'source_url': child_link.url, 'selector': dataset.meta.get('selector')},
                             'type': Type.web}, with_valid=True)
                except Exception as e:
                    # Best-effort per page: log and continue with the next link.
                    logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

        return handler

    def replace_sync(self, dataset):
        """Replace sync: re-fork the site and upsert documents in place."""
        url = dataset.meta.get('source_url')
        selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
        ListenerManagement.sync_web_dataset_signal.send(
            SyncWebDatasetArgs(str(dataset.id), url, selector,
                               self.get_sync_handler(dataset)))

    def complete_sync(self, dataset):
        """Complete sync: delete every document/paragraph/problem (and their
        embeddings) under the dataset, then run a replace sync from scratch."""
        # Delete documents
        QuerySet(Document).filter(dataset=dataset).delete()
        # Delete paragraphs
        QuerySet(Paragraph).filter(dataset=dataset).delete()
        # Delete problems
        QuerySet(Problem).filter(dataset=dataset).delete()
        # Delete embeddings
        ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
        # Re-sync
        self.replace_sync(dataset)

    @staticmethod
    def get_request_params_api():
        """Swagger parameters for the sync endpoint."""
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='sync_type',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='同步类型->replace:替换同步,complete:完整同步')
                ]
class Operate(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)

View File

@ -6,7 +6,9 @@
@date2023/9/22 13:43
@desc:
"""
import logging
import os
import traceback
import uuid
from functools import reduce
from typing import List, Dict
@ -23,8 +25,9 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -100,6 +103,58 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
title="文档列表", description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
class Sync(ApiMixin, serializers.Serializer):
    """Re-fetch a single web-type document from its source URL and rebuild
    its paragraphs, problems and embeddings."""
    document_id = serializers.UUIDField(required=True)

    def is_valid(self, *, raise_exception=False):
        """Check the document exists and is of web type.

        :raises AppApiException: 500 when the id is unknown or the document
            is not of web type.
        """
        super().is_valid(raise_exception=True)
        document_id = self.data.get('document_id')
        first = QuerySet(Document).filter(id=document_id).first()
        if first is None:
            raise AppApiException(500, "文档id不存在")
        if first.type != Type.web:
            raise AppApiException(500, "只有web站点类型才支持同步")

    def sync(self, with_valid=True):
        """Fetch the source page; on success replace the document's contents,
        on failure mark the document as errored. Always returns True."""
        if with_valid:
            self.is_valid(raise_exception=True)
        document_id = self.data.get('document_id')
        document = QuerySet(Document).filter(id=document_id).first()
        try:
            # Mark as in-progress before the (slow) network fetch.
            document.status = Status.embedding
            document.save()
            source_url = document.meta.get('source_url')
            selector_list = document.meta.get('selector').split(" ") if 'selector' in document.meta else []
            result = Fork(source_url, selector_list).fork()
            if result.status == 200:
                # Delete paragraphs
                QuerySet(model=Paragraph).filter(document_id=document_id).delete()
                # Delete problems
                QuerySet(model=Problem).filter(document_id=document_id).delete()
                # Delete vector-store entries
                ListenerManagement.delete_embedding_by_document_signal.send(document_id)
                paragraphs = get_split_model('web.md').parse(result.content)
                # sum() replaces the original reduce(lambda ..., 0) fold.
                document.char_length = sum(len(p.get('content')) for p in paragraphs)
                document.save()
                document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
                paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
                problem_model_list = document_paragraph_model.get('problem_model_list')
                # FIX: the original used conditional *expressions* purely for
                # their side effect; plain if-statements are the idiom.
                if len(paragraph_model_list) > 0:
                    QuerySet(Paragraph).bulk_create(paragraph_model_list)
                if len(problem_model_list) > 0:
                    QuerySet(Problem).bulk_create(problem_model_list)
            else:
                # Non-200 fetch: flag the document as errored.
                document.status = Status.error
                document.save()
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            document.status = Status.error
            document.save()
        return True
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
@ -146,6 +201,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
document = QuerySet(Document).filter(id=document_id).first()
if document.type == Type.web:
# 如果是web站点,就是先同步
DocumentSerializers.Sync(data={'document_id': document_id}).sync()
ListenerManagement.embedding_by_document_signal.send(document_id)
return True
@ -236,21 +296,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
with_valid=True), document_id
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
document_model = Document(
**{'dataset_id': dataset_id,
'id': uuid.uuid1(),
'name': instance.get('name'),
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0),
'meta': instance.get('meta') if instance.get('meta') is not None else {},
'type': instance.get('type') if instance.get('type') is not None else Type.base})
def get_paragraph_model(document_model, paragraph_list: List):
dataset_id = document_model.dataset_id
paragraph_model_dict_list = [ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
'paragraphs' in instance else [])]
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
paragraph_model_list = []
problem_model_list = []
@ -263,6 +313,21 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list}
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
    """Build an unsaved Document model plus its paragraph/problem models
    from a raw document dict.

    :param dataset_id: owning dataset id
    :param instance: dict with optional 'name', 'paragraphs', 'meta', 'type'
    :return: {'document': ..., 'paragraph_model_list': ..., 'problem_model_list': ...}
    """
    # dict.get with a default is semantically identical to the original
    # "instance.get('paragraphs') if 'paragraphs' in instance else []".
    paragraphs = instance.get('paragraphs', [])
    document_model = Document(
        **{'dataset_id': dataset_id,
           'id': uuid.uuid1(),
           'name': instance.get('name'),
           # Total characters across all paragraph contents
           # (sum() replaces the original reduce(lambda ..., 0) fold).
           'char_length': sum(len(p.get('content')) for p in paragraphs),
           'meta': instance.get('meta') if instance.get('meta') is not None else {},
           'type': instance.get('type') if instance.get('type') is not None else Type.base})
    return DocumentSerializers.Create.get_paragraph_model(document_model, paragraphs)
@staticmethod
def get_request_body_api():
    """Return the OpenAPI request-body schema: a single document instance."""
    body_schema = DocumentInstanceSerializer.get_request_body_api()
    return body_schema

View File

@ -9,6 +9,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),

View File

@ -23,6 +23,24 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView):
authentication_classes = [TokenAuth]
class SyncWeb(APIView):
    """PUT endpoint that triggers a web-site dataset synchronization."""
    authentication_classes = [TokenAuth]

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary="同步Web站点知识库",
                         operation_id="同步Web站点知识库",
                         manual_parameters=DataSetSerializers.SyncWeb.get_request_params_api(),
                         responses=result.get_default_response(),
                         tags=["知识库"])
    @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                                    dynamic_tag=keywords.get('dataset_id')),
                     lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
                                             dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND)
    def put(self, request: Request, dataset_id: str):
        # NOTE(review): requiring MANAGE *and* DELETE for a sync looks
        # intentional (complete sync deletes documents) — confirm with the
        # permission model's owners.
        # Assemble the serializer payload from path + query params, then
        # delegate the work to the dataset-level SyncWeb serializer.
        payload = {'sync_type': request.query_params.get('sync_type'),
                   'id': dataset_id,
                   'user_id': str(request.user.id)}
        return result.success(DataSetSerializers.SyncWeb(data=payload).sync())
class CreateWebDataset(APIView):
authentication_classes = [TokenAuth]