mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: web站点 批量同步,批量删除,批量导入接口
This commit is contained in:
parent
91b613c4da
commit
344c336143
|
|
@ -9,6 +9,7 @@
|
|||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import django.db.models
|
||||
from blinker import signal
|
||||
|
|
@ -18,7 +19,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
|
|||
from common.db.search import native_search, get_dynamics_model
|
||||
from common.event.common import poxy
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import ForkManage
|
||||
from common.util.fork import ForkManage, Fork
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from dataset.models import Paragraph, Status, Document
|
||||
from embedding.models import SourceType
|
||||
|
|
@ -36,6 +37,13 @@ class SyncWebDatasetArgs:
|
|||
self.handler = handler
|
||||
|
||||
|
||||
class SyncWebDocumentArgs:
    """Argument bundle passed along the sync-web-document signal.

    Attributes:
        source_url_list: URLs of the web pages to fetch and import.
        selector: space-separated CSS selector string used to extract content.
        handler: callback invoked per URL as handler(url, selector, result).
    """

    def __init__(self, source_url_list: List[str], selector: str, handler):
        # Simple value object: store everything verbatim.
        self.handler = handler
        self.selector = selector
        self.source_url_list = source_url_list
|
||||
|
||||
|
||||
class ListenerManagement:
|
||||
embedding_by_problem_signal = signal("embedding_by_problem")
|
||||
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
|
||||
|
|
@ -49,6 +57,7 @@ class ListenerManagement:
|
|||
disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
|
||||
init_embedding_model_signal = signal('init_embedding_model')
|
||||
sync_web_dataset_signal = signal('sync_web_dataset')
|
||||
sync_web_document_signal = signal('sync_web_document')
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_problem(args):
|
||||
|
|
@ -155,6 +164,13 @@ class ListenerManagement:
|
|||
def enable_embedding_by_paragraph(paragraph_id):
|
||||
VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
|
||||
|
||||
@staticmethod
@poxy
def sync_web_document(args: SyncWebDocumentArgs):
    """Fetch every URL in ``args.source_url_list`` and feed each result to ``args.handler``.

    Runs off the request thread via the ``@poxy`` decorator.

    :param args: bundle of URLs, the CSS selector string, and the per-URL callback.
    """
    # Hoisted out of the loop: the selector split is loop-invariant.
    selector_list = args.selector.split(' ')
    for source_url in args.source_url_list:
        result = Fork(base_fork_url=source_url, selector_list=selector_list).fork()
        args.handler(source_url, args.selector, result)
|
||||
|
||||
@staticmethod
|
||||
@poxy
|
||||
def sync_web_dataset(args: SyncWebDatasetArgs):
|
||||
|
|
@ -200,3 +216,5 @@ class ListenerManagement:
|
|||
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
|
||||
# 同步web站点知识库
|
||||
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
||||
# 同步web站点 文档
|
||||
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,13 @@ import os
|
|||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.db.search import native_search
|
||||
from common.db.sql_execute import update_execute
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
|
@ -29,3 +33,28 @@ def list_paragraph(paragraph_list: List[str]):
|
|||
return []
|
||||
return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
|
||||
|
||||
|
||||
class BatchSerializer(ApiMixin, serializers.Serializer):
    """Request serializer for batch operations addressed by a list of primary keys."""

    # UUID primary keys the batch operation targets.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))

    def is_valid(self, *, model=None, raise_exception=False):
        """Validate the payload; when *model* is given, verify every id exists.

        :param model: optional Django model class whose table is checked for the ids.
        :param raise_exception: kept for signature compatibility; field validation
            always raises (``super().is_valid`` is called with ``raise_exception=True``).
        :raises AppApiException: 500 when one or more ids have no matching row.
        """
        super().is_valid(raise_exception=True)
        if model is not None:
            id_list = self.data.get('id_list')
            model_list = QuerySet(model).filter(id__in=id_list)
            if len(model_list) != len(id_list):
                # Set membership instead of list.__contains__ in a filter/lambda:
                # idiomatic, and O(1) per lookup instead of a linear scan per id.
                exist_id_set = {str(m.id) for m in model_list}
                error_id_list = [row_id for row_id in id_list if row_id not in exist_id_set]
                raise AppApiException(500, f"id不正确:{error_id_list}")

    @staticmethod
    def get_request_body_api():
        """Swagger request-body schema: an object with an ``id_list`` string array."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            properties={
                'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
                                          title="主键id列表",
                                          description="主键id列表")
            }
        )
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from rest_framework import serializers
|
|||
|
||||
from common.db.search import native_search, native_page_search
|
||||
from common.event.common import work_thread_pool
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
|
|
@ -29,10 +29,28 @@ from common.util.file_util import get_file_content
|
|||
from common.util.fork import Fork
|
||||
from common.util.split_model import SplitModel, get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
|
||||
from dataset.serializers.common_serializers import BatchSerializer
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
    """Request serializer for creating documents from external web pages."""

    # Source page URLs to fetch.
    source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True))
    # Space-separated CSS selector used to extract page content (optional).
    selector = serializers.CharField(required=False)

    @staticmethod
    def get_request_body_api():
        """Swagger request-body schema for the web-document create endpoint."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['source_url_list'],
            properties={
                # Fixed copy-pasted titles: these previously read "段落列表" / "文档名称",
                # which describe other fields, not a URL list / CSS selector.
                'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="源地址列表", description="源地址列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
            }
        )
|
||||
|
||||
|
||||
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
|
||||
name = serializers.CharField(required=True,
|
||||
validators=[
|
||||
|
|
@ -121,6 +139,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
self.is_valid(raise_exception=True)
|
||||
document_id = self.data.get('document_id')
|
||||
document = QuerySet(Document).filter(id=document_id).first()
|
||||
if document.type != Type.web:
|
||||
return True
|
||||
try:
|
||||
document.status = Status.embedding
|
||||
document.save()
|
||||
|
|
@ -301,6 +321,38 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True), document_id
|
||||
|
||||
@staticmethod
def get_sync_handler(dataset_id):
    """Build the per-URL callback used when syncing web documents into *dataset_id*."""

    def handler(source_url: str, selector, response: Fork.Response):
        # Guard clause: a failed fetch is recorded as an error-status document.
        if response.status != 200:
            # NOTE(review): no dataset linkage is set on this Document — confirm
            # the model supplies it elsewhere, otherwise the row is orphaned.
            Document(name=source_url,
                     meta={'source_url': source_url, 'selector': selector},
                     type=Type.web,
                     char_length=0,
                     status=Status.error).save()
            return
        try:
            # Split the fetched content into paragraphs and persist the document.
            paragraphs = get_split_model('web.md').parse(response.content)
            DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                {'name': source_url, 'paragraphs': paragraphs,
                 'meta': {'source_url': source_url, 'selector': selector},
                 'type': Type.web}, with_valid=True)
        except Exception as e:
            # Best-effort per URL: log the failure and keep processing the rest.
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler
|
||||
|
||||
def save_web(self, instance: Dict, with_valid=True):
    """Queue an asynchronous sync of the given web URLs into this serializer's dataset.

    :param instance: payload with 'source_url_list' and optional 'selector'.
    :param with_valid: validate both the payload and this serializer first.
    """
    if with_valid:
        DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    sync_args = SyncWebDocumentArgs(instance.get('source_url_list'),
                                    instance.get('selector'),
                                    self.get_sync_handler(dataset_id))
    # Fire-and-forget: the listener performs the actual fetching off-thread.
    ListenerManagement.sync_web_document_signal.send(sync_args)
|
||||
|
||||
@staticmethod
|
||||
def get_paragraph_model(document_model, paragraph_list: List):
|
||||
dataset_id = document_model.dataset_id
|
||||
|
|
@ -331,8 +383,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
'meta': instance.get('meta') if instance.get('meta') is not None else {},
|
||||
'type': instance.get('type') if instance.get('type') is not None else Type.base})
|
||||
|
||||
return DocumentSerializers.Create.get_paragraph_model(document_model, instance.get('paragraphs') if
|
||||
'paragraphs' in instance else [])
|
||||
return DocumentSerializers.Create.get_paragraph_model(document_model,
|
||||
instance.get('paragraphs') if
|
||||
'paragraphs' in instance else [])
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
@ -451,6 +504,27 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
|
||||
|
||||
@staticmethod
def _batch_sync(document_id_list: List[str]):
    """Synchronize each document in turn (intended to run on a worker thread)."""
    for doc_id in document_id_list:
        DocumentSerializers.Sync(data={'document_id': doc_id}).sync()
|
||||
|
||||
def batch_sync(self, instance: Dict, with_valid=True):
    """Validate the id list and kick off document syncing asynchronously.

    :param instance: payload with 'id_list' of document ids.
    :return: True once the work is queued (the sync itself runs off-thread).
    """
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        self.is_valid(raise_exception=True)
    # Hand off to the shared worker pool; the request returns immediately.
    work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
    return True
|
||||
|
||||
def batch_delete(self, instance: Dict, with_valid=True):
    """Validate the id list and delete the matching documents in one query.

    :param instance: payload with 'id_list' of document ids.
    :return: True after deletion.
    """
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        self.is_valid(raise_exception=True)
    doc_id_list = instance.get('id_list')
    QuerySet(Document).filter(id__in=doc_id_list).delete()
    return True
|
||||
|
||||
|
||||
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
|
||||
data = file.read()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ urlpatterns = [
|
|||
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
|
||||
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
|
||||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
|
||||
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
|
||||
|
|
|
|||
|
|
@ -14,11 +14,29 @@ from rest_framework.views import APIView
|
|||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
|
||||
from common.event.common import work_thread_pool
|
||||
from common.constants.permission_constants import Permission, Group, Operate
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.document_serializers import DocumentSerializers
|
||||
from dataset.serializers.common_serializers import BatchSerializer
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer
|
||||
|
||||
|
||||
class WebDocument(APIView):
    """Endpoint for creating documents from external web pages."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建Web站点文档",
                         operation_id="创建Web站点文档",
                         request_body=DocumentWebInstanceSerializer.get_request_body_api(),
                         manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()),
                         tags=["知识库/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        """Create web documents; save_web queues the actual fetch asynchronously."""
        create_serializer = DocumentSerializers.Create(data={'dataset_id': dataset_id})
        return result.success(create_serializer.save_web(request.data, with_valid=True))
|
||||
|
||||
|
||||
class Document(APIView):
|
||||
|
|
@ -71,6 +89,34 @@ class Document(APIView):
|
|||
def post(self, request: Request, dataset_id: str):
|
||||
return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data))
|
||||
|
||||
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量同步文档",
                     operation_id="批量同步文档",
                     request_body=BatchSerializer.get_request_body_api(),
                     manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
    """Batch-sync documents; body carries {'id_list': [...]}.

    NOTE(review): @action(methods=['POST']) on a ``put`` handler looks
    inconsistent — confirm which verb the route actually expects.
    """
    batch = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
    return result.success(batch.batch_sync(request.data))
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="批量删除文档",
                     operation_id="批量删除文档",
                     request_body=BatchSerializer.get_request_body_api(),
                     manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def delete(self, request: Request, dataset_id: str):
    """Batch-delete documents; body carries {'id_list': [...]}."""
    batch = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
    return result.success(batch.batch_delete(request.data))
|
||||
|
||||
class Refresh(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue