diff --git a/apps/application/migrations/0002_alter_chatrecord_dataset.py b/apps/application/migrations/0002_alter_chatrecord_dataset.py new file mode 100644 index 000000000..cd737e657 --- /dev/null +++ b/apps/application/migrations/0002_alter_chatrecord_dataset.py @@ -0,0 +1,20 @@ +# Generated by Django 4.1.10 on 2023-12-28 15:16 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'), + ('application', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='dataset', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='dataset.dataset', verbose_name='数据集'), + ), + ] diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index fb717b49c..a053b2394 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -18,6 +18,8 @@ from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import native_search, get_dynamics_model from common.event.common import poxy from common.util.file_util import get_file_content +from common.util.fork import ForkManage +from common.util.lock import try_lock, un_lock from dataset.models import Paragraph, Status, Document from embedding.models import SourceType from smartdoc.conf import PROJECT_DIR @@ -26,6 +28,14 @@ max_kb_error = logging.getLogger("max_kb_error") max_kb = logging.getLogger("max_kb") +class SyncWebDatasetArgs: + def __init__(self, lock_key: str, url: str, selector: str, handler): + self.lock_key = lock_key + self.url = url + self.selector = selector + self.handler = handler + + class ListenerManagement: embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_paragraph_signal = signal("embedding_by_paragraph") @@ -38,6 +48,7 @@ class ListenerManagement: enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph') disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph') init_embedding_model_signal = signal('init_embedding_model') + sync_web_dataset_signal = signal('sync_web_dataset') @staticmethod def embedding_by_problem(args): @@ -144,6 +155,18 @@ class ListenerManagement: def enable_embedding_by_paragraph(paragraph_id): VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) + @staticmethod + @poxy + def sync_web_dataset(args: SyncWebDatasetArgs): + if try_lock('sync_web_dataset' + args.lock_key): + try: + ForkManage(args.url, args.selector.split(" ")).fork(2, set(), + args.handler) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + finally: + un_lock('sync_web_dataset' + args.lock_key) + @staticmethod @poxy def init_embedding_model(ags): @@ -175,3 +198,5 @@ class ListenerManagement: ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) # 初始化向量化模型 ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model) + # 同步web站点知识库 + ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) diff --git a/apps/common/util/fork.py b/apps/common/util/fork.py index 1413437bb..8897d6dee 100644 --- a/apps/common/util/fork.py +++ b/apps/common/util/fork.py @@ -1,10 +1,19 @@ +import copy +import logging import re +import traceback from functools import reduce from typing import List, Set import requests import html2text as ht from bs4 import BeautifulSoup -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse, ParseResult + + +class ChildLink: + def __init__(self, url, tag): + self.url = url + self.tag = copy.deepcopy(tag) class ForkManage: @@ -13,30 +22,34 @@ class ForkManage: self.selector_list = selector_list def fork(self, level: int, exclude_link_url: Set[str], fork_handler): - self.fork_child(self.base_url, self.selector_list, level, exclude_link_url, fork_handler) + self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler) @staticmethod - def fork_child(base_url: str, selector_list: List[str], level: int, exclude_link_url: Set[str], fork_handler): + def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str], + fork_handler): if level < 0: return - response = Fork(base_url, selector_list).fork() - fork_handler(base_url, response) + else: + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + exclude_link_url.add(child_url) + response = Fork(child_link.url, selector_list).fork() + fork_handler(child_link, response) for child_link in response.child_link_list: - if not exclude_link_url.__contains__(child_link): - exclude_link_url.add(child_link) + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + if not exclude_link_url.__contains__(child_url): ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler) class Fork: class Response: - def __init__(self, html_content: str, child_link_list: List[str], status, message: str): - self.html_content = html_content + def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str): + self.content = content self.child_link_list = child_link_list self.status = status self.message = message @staticmethod - def success(html_content: str, child_link_list: List[str]): + def success(html_content: str, child_link_list: List[ChildLink]): return Fork.Response(html_content, child_link_list, 200, '') @staticmethod @@ -45,13 +58,17 @@ class Fork: def __init__(self, base_fork_url: str, selector_list: List[str]): self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') - self.base_fork_url = base_fork_url + self.base_fork_url = self.base_fork_url[:-1] self.selector_list = selector_list + self.urlparse = urlparse(self.base_fork_url) + self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='', + query='', + fragment='').geturl() def get_child_link_list(self, bf: BeautifulSoup): - pattern = "^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*|" + self.base_fork_url + pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + ").*" link_list = bf.find_all(name='a', href=re.compile(pattern)) - result = [self.parse_href(link.get('href')) for link in link_list] + result = [ChildLink(link.get('href'), link) for link in link_list] return result def get_content_html(self, bf: BeautifulSoup): @@ -65,23 +82,34 @@ class Fork: f = bf.find_all(**params) return "\n".join([str(row) for row in f]) - def parse_href(self, href: str): - if href.startswith(self.base_fork_url[:-1] if self.base_fork_url.endswith('/') else self.base_fork_url): - return href + @staticmethod + def reset_url(tag, field, base_fork_url): + field_value: str = tag[field] + if field_value.startswith("/"): + result = urlparse(base_fork_url) + result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='', + fragment='').geturl() else: - return urljoin(self.base_fork_url + '/' + (href if href.endswith('/') else href + '/'), ".") + result_url = urljoin( + base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'), + ".") + result_url = result_url[:-1] if result_url.endswith('/') else result_url + tag[field] = result_url def reset_beautiful_soup(self, bf: BeautifulSoup): - href_list = bf.find_all(href=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')) - for h in href_list: - h['href'] = urljoin( - self.base_fork_url + '/' + (h['href'] if h['href'].endswith('/') else h['href'] + '/'), - ".")[:-1] - src_list = bf.find_all(src=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')) - for s in src_list: - s['src'] = urljoin( - self.base_fork_url + '/' + (s['src'] if s['src'].endswith('/') else s['src'] + '/'), - ".")[:-1] + reset_config_list = [ + { + 'field': 'href', + }, + { + 'field': 'src', + } + ] + for reset_config in reset_config_list: + field = reset_config.get('field') + tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')}) + for tag in tag_list: + self.reset_url(tag, field, self.base_fork_url) return bf @staticmethod @@ -92,11 +120,14 @@ class Fork: def fork(self): try: + logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}') response = requests.get(self.base_fork_url) if response.status_code != 200: - raise Exception(response.status_code) + logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}") + return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}") bf = self.get_beautiful_soup(response) except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') return Fork.Response.error(str(e)) bf = self.reset_beautiful_soup(bf) link_list = self.get_child_link_list(bf) @@ -106,7 +137,6 @@ class Fork: def handler(base_url, response: Fork.Response): - print(base_url, response.status) + print(base_url.url, base_url.tag.text if base_url.tag else None, response.content) - -ForkManage('https://dataease.io/docs/v2/', ['.md-content']).fork(3, set(), handler) +# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler) diff --git a/apps/common/util/split_model.py b/apps/common/util/split_model.py index 91007ed9c..bc5f4df12 100644 --- a/apps/common/util/split_model.py +++ b/apps/common/util/split_model.py @@ -277,11 +277,11 @@ def filter_special_char(content: str): class SplitModel: - def __init__(self, content_level_pattern, with_filter=True, limit=1024): + def __init__(self, content_level_pattern, with_filter=True, limit=4096): self.content_level_pattern = content_level_pattern self.with_filter = with_filter - if limit is None or limit > 1024: - limit = 1024 + if limit is None or limit > 4096: + limit = 4096 if limit < 50: limit = 50 self.limit = limit @@ -337,13 +337,12 @@ class SplitModel: default_split_pattern = { 'md': [re.compile("^# .*"), re.compile('(? 0 else child_link.url + paragraphs = get_split_model('web.md').parse(response.content) + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + {'name': document_name, 'paragraphs': paragraphs, + 'meta': {'source_url': child_link.url, 'selector': selector}, + 'type': Type.web}, with_valid=True) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + + return handler + + def save_web(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + self.CreateWebSerializers(data=instance).is_valid(raise_exception=True) + user_id = self.data.get('user_id') + dataset_id = uuid.uuid1() + dataset = DataSet( + **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, + 'type': Type.web, 'meta': {'source_url': instance.get('url'), 'selector': instance.get('selector')}}) + dataset.save() + ListenerManagement.sync_web_dataset_signal.send( + SyncWebDatasetArgs(str(dataset_id), instance.get('url'), instance.get('selector'), + self.get_save_handler(dataset_id, instance.get('selector')))) + return {**DataSetSerializers(dataset).data, + 'document_list': []} + @staticmethod def get_response_body_api(): return openapi.Schema( @@ -298,12 +422,43 @@ class DataSetSerializers(serializers.ModelSerializer): } ) - class Edit(serializers.Serializer): + class MetaSerializer(serializers.Serializer): + class WebMeta(serializers.Serializer): + source_url = serializers.CharField(required=True) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + source_url = self.data.get('source_url') + response = Fork(source_url, []).fork() + if response.status == 500: + raise AppApiException(500, response.message) + + class BaseMeta(serializers.Serializer): + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + class Edit(serializers.Serializer): name = serializers.CharField(required=False) desc = serializers.CharField(required=False) + meta = serializers.DictField(required=False) application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + @staticmethod + def get_dataset_meta_valid_map(): + dataset_meta_valid_map = { + Type.base: DataSetSerializers.MetaSerializer.BaseMeta, + Type.web: DataSetSerializers.MetaSerializer.WebMeta + } + return dataset_meta_valid_map + + def is_valid(self, *, dataset: DataSet = None): + super().is_valid(raise_exception=True) + if 'meta' in self.data and self.data.get('meta') is not None: + dataset_meta_valid_map = self.get_dataset_meta_valid_map() + valid_class = dataset_meta_valid_map.get(dataset.type) + valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) + class HitTest(ApiMixin, serializers.Serializer): id = serializers.CharField(required=True) user_id = serializers.UUIDField(required=False) @@ -392,12 +547,14 @@ class DataSetSerializers(serializers.ModelSerializer): :return: """ self.is_valid() - DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True) _dataset = QuerySet(DataSet).get(id=self.data.get("id")) + DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset) if "name" in dataset: _dataset.name = dataset.get("name") if 'desc' in dataset: _dataset.desc = dataset.get("desc") + if 'meta' in dataset: + _dataset.meta = dataset.get('meta') if 'application_id_list' in dataset and dataset.get('application_id_list') is not None: application_id_list = dataset.get('application_id_list') # 当前用户可修改关联的知识库列表 @@ -429,6 +586,8 @@ class DataSetSerializers(serializers.ModelSerializer): properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + 'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="知识库元数据", + description="知识库元数据->web:{source_url:xxx,selector:'xxx'},base:{}"), 'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表", description="应用id列表", items=openapi.Schema(type=openapi.TYPE_STRING)) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 7e6f5dd04..519276cbf 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -24,7 +24,7 @@ from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.file_util import get_file_content from common.util.split_model import SplitModel, get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR @@ -243,7 +243,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): 'name': instance.get('name'), 'char_length': reduce(lambda x, y: x + y, [len(p.get('content')) for p in instance.get('paragraphs', [])], - 0)}) + 0), + 'meta': instance.get('meta') if instance.get('meta') is not None else {}, + 'type': instance.get('type') if instance.get('type') is not None else Type.base}) paragraph_model_dict_list = [ParagraphSerializers.Create( data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model( diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 74491c7cb..48a1ef434 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -37,7 +37,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer): 段落实例对象 """ content = serializers.CharField(required=True, validators=[ - validators.MaxLengthValidator(limit_value=1024, + validators.MaxLengthValidator(limit_value=4096, message="段落在1-1024个字符之间"), validators.MinLengthValidator(limit_value=1, message="段落在1-1024个字符之间"), diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 16078842e..49d7903d6 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -5,6 +5,7 @@ from . import views app_name = "dataset" urlpatterns = [ path('dataset', views.Dataset.as_view(), name="dataset"), + path('dataset/web', views.Dataset.CreateWebDataset.as_view()), path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 0721ee025..5f7fab5bc 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -23,6 +23,21 @@ from dataset.serializers.dataset_serializers import DataSetSerializers class Dataset(APIView): authentication_classes = [TokenAuth] + class CreateWebDataset(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建web站点知识库", + operation_id="创建web站点知识库", + request_body=DataSetSerializers.Create.CreateWebSerializers.get_request_body_api(), + responses=get_api_response( + DataSetSerializers.Create.CreateWebSerializers.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data)) + class Application(APIView): authentication_classes = [TokenAuth] @@ -58,9 +73,7 @@ class Dataset(APIView): ) @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) def post(self, request: Request): - s = DataSetSerializers.Create(data=request.data) - s.is_valid(raise_exception=True) - return result.success(s.save(request.user)) + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data)) class HitTest(APIView): authentication_classes = [TokenAuth] diff --git a/apps/setting/migrations/0002_alter_teammemberpermission_auth_target_type_and_more.py b/apps/setting/migrations/0002_alter_teammemberpermission_auth_target_type_and_more.py new file mode 100644 index 000000000..d11f5bdad --- /dev/null +++ b/apps/setting/migrations/0002_alter_teammemberpermission_auth_target_type_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.10 on 2023-12-28 15:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='teammemberpermission', + name='auth_target_type', + field=models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', max_length=128, verbose_name='授权目标'), + ), + migrations.AlterField( + model_name='teammemberpermission', + name='target', + field=models.UUIDField(verbose_name='数据集/应用id'), + ), + ]