mirror of https://github.com/1Panel-dev/MaxKB.git
feat: document sync for web-site datasets
This commit is contained in:
parent edbc8561c7
commit 82c8e322fb
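In brief: this commit adds a SyncWeb action for web-site datasets (PUT dataset/<dataset_id>/sync_web?sync_type=replace|complete) and a per-document Sync serializer that re-crawls the source URL and re-embeds its paragraphs. A minimal sketch of calling the new endpoint with requests; the /api prefix, host, and token header are assumptions for illustration, not part of this commit:

import requests

BASE_URL = "http://localhost:8080/api"  # hypothetical deployment URL
TOKEN = "your-api-token"                # hypothetical auth token

def sync_web_dataset(dataset_id: str, sync_type: str = "replace") -> dict:
    # PUT dataset/<dataset_id>/sync_web; sync_type is 'replace' or 'complete'.
    response = requests.put(
        f"{BASE_URL}/dataset/{dataset_id}/sync_web",
        params={"sync_type": sync_type},
        headers={"AUTHORIZATION": TOKEN},
    )
    response.raise_for_status()
    return response.json()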
@@ -99,6 +99,9 @@ class PermissionConstants(Enum):
    DATASET_READ = Permission(group=Group.DATASET, operate=Operate.READ,
                              roles=[RoleConstants.ADMIN, RoleConstants.USER])

    DATASET_EDIT = Permission(group=Group.DATASET, operate=Operate.EDIT,
                              roles=[RoleConstants.ADMIN, RoleConstants.USER])

    APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ,
                                  roles=[RoleConstants.ADMIN, RoleConstants.USER])
@@ -8,6 +8,7 @@
"""
|
||||
import logging
|
||||
import os.path
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from functools import reduce
|
||||
|
|
@@ -483,6 +484,97 @@ class DataSetSerializers(serializers.ModelSerializer):
            return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
                     'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]

    class SyncWeb(ApiMixin, serializers.Serializer):
        id = serializers.CharField(required=True)
        user_id = serializers.UUIDField(required=False)
        sync_type = serializers.CharField(required=True, validators=[
            validators.RegexValidator(regex=re.compile("^(replace|complete)$"),
                                      message="replace|complete", code=500)
        ])

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            first = QuerySet(DataSet).filter(id=self.data.get("id")).first()
            if first is None:
                raise AppApiException(300, "id does not exist")
            if first.type != Type.web:
                raise AppApiException(500, "only web-site datasets support syncing")

        def sync(self, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
            sync_type = self.data.get('sync_type')
            dataset_id = self.data.get('id')
            dataset = QuerySet(DataSet).get(id=dataset_id)
            self.__getattribute__(sync_type + '_sync')(dataset)
            return True

        @staticmethod
        def get_sync_handler(dataset):
            def handler(child_link: ChildLink, response: Fork.Response):
                if response.status == 200:
                    try:
                        document_name = child_link.tag.text if child_link.tag is not None and len(
                            child_link.tag.text.strip()) > 0 else child_link.url
                        paragraphs = get_split_model('web.md').parse(response.content)
                        first = QuerySet(Document).filter(meta__source_url=child_link.url).first()
                        if first is not None:
                            # The document already exists, so sync it in place
                            DocumentSerializers.Sync(data={'document_id': first.id}).sync()
                            # Send the embedding signal
                            ListenerManagement.embedding_by_document_signal.send(first.id)
                        else:
                            # Otherwise insert a new document
                            DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
                                {'name': document_name, 'paragraphs': paragraphs,
                                 'meta': {'source_url': child_link.url, 'selector': dataset.meta.get('selector')},
                                 'type': Type.web}, with_valid=True)
                    except Exception as e:
                        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

            return handler

        def replace_sync(self, dataset):
            """
            Replace sync: re-crawl the source and sync documents in place
            :return:
            """
            url = dataset.meta.get('source_url')
            selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
            ListenerManagement.sync_web_dataset_signal.send(
                SyncWebDatasetArgs(str(dataset.id), url, selector,
                                   self.get_sync_handler(dataset)))

        def complete_sync(self, dataset):
            """
            Complete sync: delete every document under the dataset, then sync again
            :return:
            """
            # Delete documents
            QuerySet(Document).filter(dataset=dataset).delete()
            # Delete paragraphs
            QuerySet(Paragraph).filter(dataset=dataset).delete()
            # Delete problems
            QuerySet(Problem).filter(dataset=dataset).delete()
            # Delete embeddings
            ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
            # Sync
            self.replace_sync(dataset)

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='knowledge base id'),
                    openapi.Parameter(name='sync_type',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='sync type -> replace: replace sync, complete: complete sync')
                    ]

    class Operate(ApiMixin, serializers.Serializer):
        id = serializers.CharField(required=True)
        user_id = serializers.UUIDField(required=False)
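A note on the dispatch in SyncWeb.sync above: self.__getattribute__(sync_type + '_sync')(dataset) resolves 'replace' to replace_sync and 'complete' to complete_sync, which is why the sync_type validator only admits those two values. A standalone toy sketch of the same getattr-based dispatch (illustration only, not part of the diff):

class Dispatcher:
    def replace_sync(self, dataset_id):
        return f"replace-sync {dataset_id}"

    def complete_sync(self, dataset_id):
        return f"complete-sync {dataset_id}"

    def sync(self, sync_type, dataset_id):
        # 'replace' -> replace_sync, 'complete' -> complete_sync
        return getattr(self, sync_type + '_sync')(dataset_id)

print(Dispatcher().sync('replace', 'ds-1'))   # replace-sync ds-1
print(Dispatcher().sync('complete', 'ds-1'))  # complete-sync ds-1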
@@ -6,7 +6,9 @@
@date:2023/9/22 13:43
@desc:
"""
import logging
import os
import traceback
import uuid
from functools import reduce
from typing import List, Dict
@@ -23,8 +25,9 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@@ -100,6 +103,58 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                                             title="Document list", description="Document list",
                                             items=DocumentSerializers.Operate.get_response_body_api())

    class Sync(ApiMixin, serializers.Serializer):
        document_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            first = QuerySet(Document).filter(id=document_id).first()
            if first is None:
                raise AppApiException(500, "document id does not exist")
            if first.type != Type.web:
                raise AppApiException(500, "only web-site documents support syncing")

        def sync(self, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            document = QuerySet(Document).filter(id=document_id).first()
            try:
                document.status = Status.embedding
                document.save()
                source_url = document.meta.get('source_url')
                selector_list = document.meta.get('selector').split(" ") if 'selector' in document.meta else []
                result = Fork(source_url, selector_list).fork()
                if result.status == 200:
                    # Delete paragraphs
                    QuerySet(model=Paragraph).filter(document_id=document_id).delete()
                    # Delete problems
                    QuerySet(model=Problem).filter(document_id=document_id).delete()
                    # Delete embeddings
                    ListenerManagement.delete_embedding_by_document_signal.send(document_id)
                    paragraphs = get_split_model('web.md').parse(result.content)
                    document.char_length = reduce(lambda x, y: x + y,
                                                  [len(p.get('content')) for p in paragraphs],
                                                  0)
                    document.save()
                    document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)

                    paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
                    problem_model_list = document_paragraph_model.get('problem_model_list')
                    # Bulk-insert paragraphs
                    QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
                    # Bulk-insert problems
                    QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
                else:
                    document.status = Status.error
                    document.save()
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
                document.status = Status.error
                document.save()
            return True

    class Operate(ApiMixin, serializers.Serializer):
        document_id = serializers.UUIDField(required=True)
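Side note on char_length in Sync.sync above: the reduce is a plain summation of paragraph content lengths. A standalone check (illustration only, not part of the diff):

from functools import reduce

paragraphs = [{'content': 'abc'}, {'content': 'de'}]
char_length = reduce(lambda x, y: x + y,
                     [len(p.get('content')) for p in paragraphs],
                     0)
assert char_length == sum(len(p.get('content')) for p in paragraphs) == 5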
@@ -146,6 +201,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
        if with_valid:
            self.is_valid(raise_exception=True)
        document_id = self.data.get("document_id")
        document = QuerySet(Document).filter(id=document_id).first()
        if document.type == Type.web:
            # For web-site documents, sync the source first
            DocumentSerializers.Sync(data={'document_id': document_id}).sync()

        ListenerManagement.embedding_by_document_signal.send(document_id)
        return True
@@ -236,21 +296,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                with_valid=True), document_id

        @staticmethod
-        def get_document_paragraph_model(dataset_id, instance: Dict):
-            document_model = Document(
-                **{'dataset_id': dataset_id,
-                   'id': uuid.uuid1(),
-                   'name': instance.get('name'),
-                   'char_length': reduce(lambda x, y: x + y,
-                                         [len(p.get('content')) for p in instance.get('paragraphs', [])],
-                                         0),
-                   'meta': instance.get('meta') if instance.get('meta') is not None else {},
-                   'type': instance.get('type') if instance.get('type') is not None else Type.base})
-
+        def get_paragraph_model(document_model, paragraph_list: List):
+            dataset_id = document_model.dataset_id
            paragraph_model_dict_list = [ParagraphSerializers.Create(
                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
-                dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
-                                                                            'paragraphs' in instance else [])]
+                dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]

            paragraph_model_list = []
            problem_model_list = []
@@ -263,6 +313,21 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
            return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
                    'problem_model_list': problem_model_list}

        @staticmethod
        def get_document_paragraph_model(dataset_id, instance: Dict):
            document_model = Document(
                **{'dataset_id': dataset_id,
                   'id': uuid.uuid1(),
                   'name': instance.get('name'),
                   'char_length': reduce(lambda x, y: x + y,
                                         [len(p.get('content')) for p in instance.get('paragraphs', [])],
                                         0),
                   'meta': instance.get('meta') if instance.get('meta') is not None else {},
                   'type': instance.get('type') if instance.get('type') is not None else Type.base})

            return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if
                                                                  'paragraphs' in instance else [])

        @staticmethod
        def get_request_body_api():
            return DocumentInstanceSerializer.get_request_body_api()
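Restating the refactor in the two hunks above: building the Document row is now separated from building the paragraph/problem models, so Sync can call Create.get_paragraph_model(document, paragraphs) on an existing document, while the create path still goes through get_document_paragraph_model. A toy sketch of the same split with hypothetical names (illustration only, not part of the diff):

from typing import Dict, List

def build_paragraph_models(document: Dict, paragraph_list: List[Dict]) -> List[Dict]:
    # Reusable for both new and already-existing documents.
    return [{'document_id': document['id'], 'content': p['content']} for p in paragraph_list]

def build_document_paragraph_model(dataset_id: str, instance: Dict) -> List[Dict]:
    # Construct the document first, then delegate paragraph building.
    document = {'id': 'doc-1', 'dataset_id': dataset_id, 'name': instance.get('name')}
    return build_paragraph_models(document, instance.get('paragraphs', []))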
@@ -9,6 +9,7 @@ urlpatterns = [
    path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
    path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
    path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
    path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
    path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
    path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
    path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
@@ -23,6 +23,24 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView):
    authentication_classes = [TokenAuth]

    class SyncWeb(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['PUT'], detail=False)
        @swagger_auto_schema(operation_summary="Sync a web-site knowledge base",
                             operation_id="Sync a web-site knowledge base",
                             manual_parameters=DataSetSerializers.SyncWeb.get_request_params_api(),
                             responses=result.get_default_response(),
                             tags=["Knowledge base"])
        @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                                        dynamic_tag=keywords.get('dataset_id')),
                         lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
                                                 dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND)
        def put(self, request: Request, dataset_id: str):
            return result.success(DataSetSerializers.SyncWeb(
                data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id,
                      'user_id': str(request.user.id)}).sync())

    class CreateWebDataset(APIView):
        authentication_classes = [TokenAuth]
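The permission check on put composes two predicates with compare=CompareConstants.AND: the caller needs both DATASET MANAGE and DATASET DELETE on the target dataset. A standalone sketch of AND-composed permission predicates with hypothetical permission strings (illustration only, not part of the diff):

from typing import Callable, List, Set

def and_permissions(checks: List[Callable[[Set[str]], bool]]) -> Callable[[Set[str]], bool]:
    # Grant access only if every predicate passes.
    return lambda granted: all(check(granted) for check in checks)

requires = and_permissions([
    lambda granted: 'DATASET:MANAGE:ds-1' in granted,
    lambda granted: 'DATASET:DELETE:ds-1' in granted,
])

print(requires({'DATASET:MANAGE:ds-1', 'DATASET:DELETE:ds-1'}))  # True
print(requires({'DATASET:MANAGE:ds-1'}))                         # False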