feat: web数据集 (web dataset)
This commit is contained in: parent 89a74dd862 · commit 64c8cc6b39
@@ -0,0 +1,20 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
        ('application', '0001_initial'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='dataset',
            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='dataset.dataset', verbose_name='数据集'),
        ),
    ]
@@ -18,6 +18,8 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
@@ -26,6 +28,14 @@ max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


class SyncWebDatasetArgs:
    def __init__(self, lock_key: str, url: str, selector: str, handler):
        self.lock_key = lock_key
        self.url = url
        self.selector = selector
        self.handler = handler


class ListenerManagement:
    embedding_by_problem_signal = signal("embedding_by_problem")
    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@@ -38,6 +48,7 @@ class ListenerManagement:
    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
    init_embedding_model_signal = signal('init_embedding_model')
    sync_web_dataset_signal = signal('sync_web_dataset')

    @staticmethod
    def embedding_by_problem(args):
@@ -144,6 +155,18 @@ class ListenerManagement:
    def enable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    @poxy
    def sync_web_dataset(args: SyncWebDatasetArgs):
        if try_lock('sync_web_dataset' + args.lock_key):
            try:
                ForkManage(args.url, args.selector.split(" ")).fork(2, set(),
                                                                    args.handler)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            finally:
                un_lock('sync_web_dataset' + args.lock_key)

    @staticmethod
    @poxy
    def init_embedding_model(ags):
@@ -175,3 +198,5 @@ class ListenerManagement:
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
        # 初始化向量化模型
        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
        # 同步web站点知识库
        ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
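The listener above is wired with blinker-style signals (signal(...) / connect / send). A minimal standalone sketch of the new sync_web_dataset wiring, with illustrative names and a stub receiver instead of MaxKB's ForkManage and lock helpers:

from blinker import signal

sync_web_dataset_signal = signal('sync_web_dataset_demo')


class SyncWebDatasetArgs:
    def __init__(self, lock_key: str, url: str, selector: str, handler):
        self.lock_key = lock_key
        self.url = url
        self.selector = selector
        self.handler = handler


def sync_web_dataset(args: SyncWebDatasetArgs):
    # In MaxKB the real receiver wraps the crawl in try_lock/un_lock and calls
    # ForkManage(args.url, args.selector.split(" ")).fork(2, set(), args.handler).
    print('would crawl', args.url, 'with selectors', args.selector)


sync_web_dataset_signal.connect(sync_web_dataset)
sync_web_dataset_signal.send(
    SyncWebDatasetArgs('dataset-id', 'https://example.com/docs/', '.md-content', print))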
@@ -1,10 +1,19 @@
import copy
import logging
import re
import traceback
from functools import reduce
from typing import List, Set

import requests
import html2text as ht
from bs4 import BeautifulSoup
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse, ParseResult


class ChildLink:
    def __init__(self, url, tag):
        self.url = url
        self.tag = copy.deepcopy(tag)


class ForkManage:
@@ -13,30 +22,34 @@ class ForkManage:
        self.selector_list = selector_list

    def fork(self, level: int, exclude_link_url: Set[str], fork_handler):
        self.fork_child(self.base_url, self.selector_list, level, exclude_link_url, fork_handler)
        self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler)

    @staticmethod
    def fork_child(base_url: str, selector_list: List[str], level: int, exclude_link_url: Set[str], fork_handler):
    def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str],
                   fork_handler):
        if level < 0:
            return
        response = Fork(base_url, selector_list).fork()
        fork_handler(base_url, response)
        else:
            child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
            exclude_link_url.add(child_url)
            response = Fork(child_link.url, selector_list).fork()
            fork_handler(child_link, response)
            for child_link in response.child_link_list:
                if not exclude_link_url.__contains__(child_link):
                    exclude_link_url.add(child_link)
                child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
                if not exclude_link_url.__contains__(child_url):
                    ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler)


class Fork:
    class Response:
        def __init__(self, html_content: str, child_link_list: List[str], status, message: str):
            self.html_content = html_content
        def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str):
            self.content = content
            self.child_link_list = child_link_list
            self.status = status
            self.message = message

        @staticmethod
        def success(html_content: str, child_link_list: List[str]):
        def success(html_content: str, child_link_list: List[ChildLink]):
            return Fork.Response(html_content, child_link_list, 200, '')

        @staticmethod
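The revised fork_child normalises every URL (dropping a trailing '/') before recording it in the exclude set, so a page linked both with and without the slash is visited only once. A simplified, self-contained sketch of that traversal, using an in-memory site map instead of HTTP (all names here are illustrative):

from typing import Callable, Dict, List, Set

SITE: Dict[str, List[str]] = {
    'https://example.com/docs': ['https://example.com/docs/intro/', 'https://example.com/docs/api'],
    'https://example.com/docs/intro': ['https://example.com/docs'],   # cycles back to the root
    'https://example.com/docs/api': [],
}


def crawl(url: str, level: int, visited: Set[str], handler: Callable[[str], None]) -> None:
    if level < 0:
        return
    normalised = url[:-1] if url.endswith('/') else url   # treat '.../docs' and '.../docs/' as one page
    visited.add(normalised)
    handler(normalised)
    for child in SITE.get(normalised, []):
        child_normalised = child[:-1] if child.endswith('/') else child
        if child_normalised not in visited:
            crawl(child, level - 1, visited, handler)


crawl('https://example.com/docs/', 2, set(), print)   # prints docs, intro, api -- each exactly once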
@@ -45,13 +58,17 @@ class Fork:

    def __init__(self, base_fork_url: str, selector_list: List[str]):
        self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.')
        self.base_fork_url = base_fork_url
        self.base_fork_url = self.base_fork_url[:-1]
        self.selector_list = selector_list
        self.urlparse = urlparse(self.base_fork_url)
        self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='',
                                    query='',
                                    fragment='').geturl()

    def get_child_link_list(self, bf: BeautifulSoup):
        pattern = "^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*|" + self.base_fork_url
        pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + ").*"
        link_list = bf.find_all(name='a', href=re.compile(pattern))
        result = [self.parse_href(link.get('href')) for link in link_list]
        result = [ChildLink(link.get('href'), link) for link in link_list]
        return result

    def get_content_html(self, bf: BeautifulSoup):
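The rewritten pattern in get_child_link_list keeps an anchor when its href is either not an absolute/special link (so it is relative to the current page) or an absolute link inside the forked sub-tree. A quick check of that behaviour, with an illustrative base URL (Fork derives the real one from the URL it is given):

import re

base_fork_url = 'https://example.com/docs'   # illustrative
pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + base_fork_url + ").*"
link_filter = re.compile(pattern)

candidates = [
    '/docs/install',                      # site-relative            -> kept
    'getting-started/',                   # document-relative        -> kept
    'https://example.com/docs/intro',     # absolute, same sub-tree  -> kept
    'https://other.example.org/page',     # absolute, foreign site   -> dropped
    'mailto:team@example.com',            # non-navigable scheme     -> dropped
    '#section-2',                         # in-page anchor           -> dropped
]
kept = [href for href in candidates if link_filter.search(href)]   # BeautifulSoup applies the pattern the same way
# kept == ['/docs/install', 'getting-started/', 'https://example.com/docs/intro']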
@@ -65,23 +82,34 @@ class Fork:
        f = bf.find_all(**params)
        return "\n".join([str(row) for row in f])

    def parse_href(self, href: str):
        if href.startswith(self.base_fork_url[:-1] if self.base_fork_url.endswith('/') else self.base_fork_url):
            return href
    @staticmethod
    def reset_url(tag, field, base_fork_url):
        field_value: str = tag[field]
        if field_value.startswith("/"):
            result = urlparse(base_fork_url)
            result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='',
                                     fragment='').geturl()
        else:
            return urljoin(self.base_fork_url + '/' + (href if href.endswith('/') else href + '/'), ".")
            result_url = urljoin(
                base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'),
                ".")
        result_url = result_url[:-1] if result_url.endswith('/') else result_url
        tag[field] = result_url

    def reset_beautiful_soup(self, bf: BeautifulSoup):
        href_list = bf.find_all(href=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*'))
        for h in href_list:
            h['href'] = urljoin(
                self.base_fork_url + '/' + (h['href'] if h['href'].endswith('/') else h['href'] + '/'),
                ".")[:-1]
        src_list = bf.find_all(src=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*'))
        for s in src_list:
            s['src'] = urljoin(
                self.base_fork_url + '/' + (s['src'] if s['src'].endswith('/') else s['src'] + '/'),
                ".")[:-1]
        reset_config_list = [
            {
                'field': 'href',
            },
            {
                'field': 'src',
            }
        ]
        for reset_config in reset_config_list:
            field = reset_config.get('field')
            tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')})
            for tag in tag_list:
                self.reset_url(tag, field, self.base_fork_url)
        return bf

    @staticmethod
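reset_url (driven by the reset_beautiful_soup loop above) rewrites relative href/src attributes into absolute URLs so the stored page keeps working links. A small sketch of the two branches it implements, with an illustrative base URL:

from urllib.parse import ParseResult, urljoin, urlparse

base_fork_url = 'https://example.com/docs/v2'   # illustrative


def reset(field_value: str) -> str:
    # Mirrors Fork.reset_url: root-relative values are re-rooted on the site origin,
    # everything else is resolved relative to the page that was forked; trailing '/' is dropped.
    if field_value.startswith('/'):
        parts = urlparse(base_fork_url)
        result_url = ParseResult(scheme=parts.scheme, netloc=parts.netloc, path=field_value,
                                 params='', query='', fragment='').geturl()
    else:
        result_url = urljoin(
            base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'), '.')
    return result_url[:-1] if result_url.endswith('/') else result_url


print(reset('/assets/logo.png'))   # https://example.com/assets/logo.png
print(reset('getting-started/'))   # https://example.com/docs/v2/getting-started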
@@ -92,11 +120,14 @@ class Fork:

    def fork(self):
        try:
            logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}')
            response = requests.get(self.base_fork_url)
            if response.status_code != 200:
                raise Exception(response.status_code)
                logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}")
                return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}")
            bf = self.get_beautiful_soup(response)
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            return Fork.Response.error(str(e))
        bf = self.reset_beautiful_soup(bf)
        link_list = self.get_child_link_list(bf)
@@ -106,7 +137,6 @@ class Fork:


def handler(base_url, response: Fork.Response):
    print(base_url, response.status)
    print(base_url.url, base_url.tag.text if base_url.tag else None, response.content)


ForkManage('https://dataease.io/docs/v2/', ['.md-content']).fork(3, set(), handler)
# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler)
@@ -277,11 +277,11 @@ def filter_special_char(content: str):

class SplitModel:

    def __init__(self, content_level_pattern, with_filter=True, limit=1024):
    def __init__(self, content_level_pattern, with_filter=True, limit=4096):
        self.content_level_pattern = content_level_pattern
        self.with_filter = with_filter
        if limit is None or limit > 1024:
            limit = 1024
        if limit is None or limit > 4096:
            limit = 4096
        if limit < 50:
            limit = 50
        self.limit = limit
@@ -337,13 +337,12 @@ class SplitModel:

default_split_pattern = {
    'md': [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
           re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
           re.compile("(?<!#)######(?!#).*"),
           re.compile("(?<! )- .*")],
           re.compile("(?<!#)######(?!#).*")],
    'default': [re.compile("(?<!\n)\n\n.+")]
}


def get_split_model(filename: str, with_filter: bool, limit: int):
def get_split_model(filename: str, with_filter: bool = False, limit: int = 4096):
    """
    根据文件名称获取分段模型
    :param limit: 每段大小
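The 'md' pattern list now stops at the six heading levels; the dropped '(?<! )- .*' entry treated every top-level list bullet as a split point. A quick check of which lines the remaining level patterns match (each heading depth binds to exactly one pattern, bullets to none):

import re

md_level_patterns = [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
                     re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
                     re.compile("(?<!#)######(?!#).*")]

for line in ["# Title", "## Section", "### Sub-section", "- a list item"]:
    levels = [i + 1 for i, p in enumerate(md_level_patterns) if p.search(line)]
    print(line, '-> heading level', levels or None)
# "# Title" -> [1], "## Section" -> [2], "### Sub-section" -> [3], "- a list item" -> None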
@@ -0,0 +1,38 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('dataset', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='dataset',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        migrations.AddField(
            model_name='dataset',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        migrations.AddField(
            model_name='document',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        migrations.AddField(
            model_name='document',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        migrations.AlterField(
            model_name='dataset',
            name='name',
            field=models.CharField(max_length=150, verbose_name='数据集名称'),
        ),
    ]
@@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-12-29 17:49

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='paragraph',
            name='content',
            field=models.CharField(max_length=4096, verbose_name='段落内容'),
        ),
    ]
@@ -21,6 +21,12 @@ class Status(models.TextChoices):
    error = 2, '导入失败'


class Type(models.TextChoices):
    base = 0, '通用类型'

    web = 1, 'web站点类型'


class DataSet(AppModelMixin):
    """
    数据集表
@@ -29,6 +35,10 @@ class DataSet(AppModelMixin):
    name = models.CharField(max_length=150, verbose_name="数据集名称")
    desc = models.CharField(max_length=256, verbose_name="数据库描述")
    user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
    type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
                            default=Type.base)

    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "dataset"
@@ -46,6 +56,11 @@ class Document(AppModelMixin):
                              default=Status.embedding)
    is_active = models.BooleanField(default=True)

    type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
                            default=Type.base)

    meta = models.JSONField(verbose_name="元数据", default=dict)

    class Meta:
        db_table = "document"
@@ -57,7 +72,7 @@ class Paragraph(AppModelMixin):
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
    dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
    content = models.CharField(max_length=1024, verbose_name="段落内容")
    content = models.CharField(max_length=4096, verbose_name="段落内容")
    title = models.CharField(max_length=256, verbose_name="标题", default="")
    hit_num = models.IntegerField(verbose_name="命中数量", default=0)
    star_num = models.IntegerField(verbose_name="点赞数", default=0)
@@ -6,11 +6,13 @@
    @date:2023/9/21 16:14
    @desc:
"""
import logging
import os.path
import traceback
import uuid
from functools import reduce
from itertools import groupby
from typing import Dict
from urllib.parse import urlparse

from django.contrib.postgres.fields import ArrayField
from django.core import validators
@@ -23,17 +25,18 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from common.util.fork import ChildLink, Fork, ForkManage
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
from users.models import User

"""
# __exact 精确等于 like ‘aaa’
@@ -187,30 +190,105 @@ class DataSetSerializers(serializers.ModelSerializer):
            return DataSetSerializers.Operate.get_response_body_api()

    class Create(ApiMixin, serializers.Serializer):
        """
        创建序列化对象
        """
        name = serializers.CharField(required=True,
                                     validators=[
                                         validators.MaxLengthValidator(limit_value=20,
                                                                       message="知识库名称在1-20个字符之间"),
                                         validators.MinLengthValidator(limit_value=1,
                                                                       message="知识库名称在1-20个字符之间")
                                     ])
        user_id = serializers.UUIDField(required=True)

        desc = serializers.CharField(required=True,
                                     validators=[
                                         validators.MaxLengthValidator(limit_value=256,
                                                                       message="知识库名称在1-256个字符之间"),
                                         validators.MinLengthValidator(limit_value=1,
                                                                       message="知识库名称在1-256个字符之间")
                                     ])
        class CreateBaseSerializers(ApiMixin, serializers.Serializer):
            """
            创建通用数据集序列化对象
            """
            name = serializers.CharField(required=True,
                                         validators=[
                                             validators.MaxLengthValidator(limit_value=20,
                                                                           message="知识库名称在1-20个字符之间"),
                                             validators.MinLengthValidator(limit_value=1,
                                                                           message="知识库名称在1-20个字符之间")
                                         ])

            documents = DocumentInstanceSerializer(required=False, many=True)
            desc = serializers.CharField(required=True,
                                         validators=[
                                             validators.MaxLengthValidator(limit_value=256,
                                                                           message="知识库名称在1-256个字符之间"),
                                             validators.MinLengthValidator(limit_value=1,
                                                                           message="知识库名称在1-256个字符之间")
                                         ])

            def is_valid(self, *, raise_exception=False):
                super().is_valid(raise_exception=True)
                return True
            documents = DocumentInstanceSerializer(required=False, many=True)

            def is_valid(self, *, raise_exception=False):
                super().is_valid(raise_exception=True)
                return True

        class CreateWebSerializers(serializers.Serializer):
            """
            创建web站点序列化对象
            """
            name = serializers.CharField(required=True,
                                         validators=[
                                             validators.MaxLengthValidator(limit_value=20,
                                                                           message="知识库名称在1-20个字符之间"),
                                             validators.MinLengthValidator(limit_value=1,
                                                                           message="知识库名称在1-20个字符之间")
                                         ])

            desc = serializers.CharField(required=True,
                                         validators=[
                                             validators.MaxLengthValidator(limit_value=256,
                                                                           message="知识库名称在1-256个字符之间"),
                                             validators.MinLengthValidator(limit_value=1,
                                                                           message="知识库名称在1-256个字符之间")
                                         ])
            url = serializers.CharField(required=True)

            selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

            def is_valid(self, *, raise_exception=False):
                super().is_valid(raise_exception=True)
                return True

            @staticmethod
            def get_response_body_api():
                return openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
                              'update_time', 'create_time', 'document_list'],
                    properties={
                        'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                             description="id", default="xx"),
                        'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                               description="名称", default="测试知识库"),
                        'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
                                               description="描述", default="测试知识库描述"),
                        'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
                                                  description="所属用户id", default="user_xxxx"),
                        'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
                                                      description="字符数", default=10),
                        'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
                                                         description="文档数量", default=1),
                        'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                      description="修改时间",
                                                      default="1970-01-01 00:00:00"),
                        'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                      description="创建时间",
                                                      default="1970-01-01 00:00:00"
                                                      ),
                        'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
                                                        description="文档列表",
                                                        items=DocumentSerializers.Operate.get_response_body_api())
                    }
                )

            @staticmethod
            def get_request_body_api():
                return openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    required=['name', 'desc', 'url'],
                    properties={
                        'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
                        'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
                        'url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
                        'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
                    }
                )

        @staticmethod
        def post_embedding_dataset(document_list, dataset_id):
@@ -220,16 +298,21 @@ class DataSetSerializers(serializers.ModelSerializer):

        @post(post_function=post_embedding_dataset)
        @transaction.atomic
        def save(self, user: User):
        def save(self, instance: Dict, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
                self.CreateBaseSerializers(data=instance).is_valid()
            dataset_id = uuid.uuid1()
            user_id = self.data.get('user_id')

            dataset = DataSet(
                **{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
                **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})

            document_model_list = []
            paragraph_model_list = []
            problem_model_list = []
            # 插入文档
            for document in self.data.get('documents') if 'documents' in self.data else []:
            for document in instance.get('documents') if 'documents' in instance else []:
                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
                                                                                                        document)
                document_model_list.append(document_paragraph_dict_model.get('document'))
@@ -252,6 +335,47 @@ class DataSetSerializers(serializers.ModelSerializer):
                    'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
                        with_valid=True)}, dataset_id

        @staticmethod
        def get_last_url_path(url):
            parsed_url = urlparse(url)
            if parsed_url.path is None or len(parsed_url.path) == 0:
                return url
            else:
                return parsed_url.path.split("/")[-1]

        @staticmethod
        def get_save_handler(dataset_id, selector):
            def handler(child_link: ChildLink, response: Fork.Response):
                if response.status == 200:
                    try:
                        document_name = child_link.tag.text if child_link.tag is not None and len(
                            child_link.tag.text.strip()) > 0 else child_link.url
                        paragraphs = get_split_model('web.md').parse(response.content)
                        DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                            {'name': document_name, 'paragraphs': paragraphs,
                             'meta': {'source_url': child_link.url, 'selector': selector},
                             'type': Type.web}, with_valid=True)
                    except Exception as e:
                        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

            return handler

        def save_web(self, instance: Dict, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
                self.CreateWebSerializers(data=instance).is_valid(raise_exception=True)
            user_id = self.data.get('user_id')
            dataset_id = uuid.uuid1()
            dataset = DataSet(
                **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
                   'type': Type.web, 'meta': {'source_url': instance.get('url'), 'selector': instance.get('selector')}})
            dataset.save()
            ListenerManagement.sync_web_dataset_signal.send(
                SyncWebDatasetArgs(str(dataset_id), instance.get('url'), instance.get('selector'),
                                   self.get_save_handler(dataset_id, instance.get('selector'))))
            return {**DataSetSerializers(dataset).data,
                    'document_list': []}

        @staticmethod
        def get_response_body_api():
            return openapi.Schema(
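For orientation, these are the metadata shapes written by this code path (values below are illustrative): save_web stores the crawl parameters on the web-type dataset, and get_save_handler stamps each crawled page's source onto the document it creates.

# Stored on the DataSet row created by save_web:
dataset_meta = {'source_url': 'https://dataease.io/docs/v2/', 'selector': '.md-content'}

# Payload handed to DocumentSerializers.Create(...).save(...) for one successfully crawled page:
document_payload = {
    'name': 'Getting started',    # the link text, or the page URL when the text is empty
    'paragraphs': [],             # output of get_split_model('web.md').parse(response.content)
    'meta': {'source_url': 'https://dataease.io/docs/v2/getting-started', 'selector': '.md-content'},
    'type': '1',                  # Type.web, per the ('1', 'web站点类型') choice in the migration above
}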
@@ -298,12 +422,43 @@ class DataSetSerializers(serializers.ModelSerializer):
                }
            )

    class Edit(serializers.Serializer):
    class MetaSerializer(serializers.Serializer):
        class WebMeta(serializers.Serializer):
            source_url = serializers.CharField(required=True)
            selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

            def is_valid(self, *, raise_exception=False):
                super().is_valid(raise_exception=True)
                source_url = self.data.get('source_url')
                response = Fork(source_url, []).fork()
                if response.status == 500:
                    raise AppApiException(500, response.message)

        class BaseMeta(serializers.Serializer):
            def is_valid(self, *, raise_exception=False):
                super().is_valid(raise_exception=True)

    class Edit(serializers.Serializer):
        name = serializers.CharField(required=False)
        desc = serializers.CharField(required=False)
        meta = serializers.DictField(required=False)
        application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))

        @staticmethod
        def get_dataset_meta_valid_map():
            dataset_meta_valid_map = {
                Type.base: DataSetSerializers.MetaSerializer.BaseMeta,
                Type.web: DataSetSerializers.MetaSerializer.WebMeta
            }
            return dataset_meta_valid_map

        def is_valid(self, *, dataset: DataSet = None):
            super().is_valid(raise_exception=True)
            if 'meta' in self.data and self.data.get('meta') is not None:
                dataset_meta_valid_map = self.get_dataset_meta_valid_map()
                valid_class = dataset_meta_valid_map.get(dataset.type)
                valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)

    class HitTest(ApiMixin, serializers.Serializer):
        id = serializers.CharField(required=True)
        user_id = serializers.UUIDField(required=False)
@@ -392,12 +547,14 @@ class DataSetSerializers(serializers.ModelSerializer):
            :return:
            """
            self.is_valid()
            DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True)
            _dataset = QuerySet(DataSet).get(id=self.data.get("id"))
            DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
            if "name" in dataset:
                _dataset.name = dataset.get("name")
            if 'desc' in dataset:
                _dataset.desc = dataset.get("desc")
            if 'meta' in dataset:
                _dataset.meta = dataset.get('meta')
            if 'application_id_list' in dataset and dataset.get('application_id_list') is not None:
                application_id_list = dataset.get('application_id_list')
                # 当前用户可修改关联的知识库列表
@@ -429,6 +586,8 @@ class DataSetSerializers(serializers.ModelSerializer):
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
                    'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
                    'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="知识库元数据",
                                           description="知识库元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
                    'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表",
                                                          description="应用id列表",
                                                          items=openapi.Schema(type=openapi.TYPE_STRING))
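Following the 'meta' description in the schema above, an edit request for a web-type dataset would look like this (values illustrative; a base-type dataset sends meta as an empty object):

edit_body = {
    'name': 'DataEase docs',
    'desc': 'Docs crawled from the DataEase site',
    'meta': {'source_url': 'https://dataease.io/docs/v2/', 'selector': '.md-content'},
    'application_id_list': [],
}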
@@ -24,7 +24,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@@ -243,7 +243,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                'name': instance.get('name'),
                'char_length': reduce(lambda x, y: x + y,
                                      [len(p.get('content')) for p in instance.get('paragraphs', [])],
                                      0)})
                                      0),
                'meta': instance.get('meta') if instance.get('meta') is not None else {},
                'type': instance.get('type') if instance.get('type') is not None else Type.base})

        paragraph_model_dict_list = [ParagraphSerializers.Create(
            data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
@@ -37,7 +37,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    段落实例对象
    """
    content = serializers.CharField(required=True, validators=[
        validators.MaxLengthValidator(limit_value=1024,
        validators.MaxLengthValidator(limit_value=4096,
                                      message="段落在1-1024个字符之间"),
        validators.MinLengthValidator(limit_value=1,
                                      message="段落在1-1024个字符之间"),
@@ -5,6 +5,7 @@ from . import views
app_name = "dataset"
urlpatterns = [
    path('dataset', views.Dataset.as_view(), name="dataset"),
    path('dataset/web', views.Dataset.CreateWebDataset.as_view()),
    path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
    path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
    path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
@@ -23,6 +23,21 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView):
    authentication_classes = [TokenAuth]

    class CreateWebDataset(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="创建web站点知识库",
                             operation_id="创建web站点知识库",
                             request_body=DataSetSerializers.Create.CreateWebSerializers.get_request_body_api(),
                             responses=get_api_response(
                                 DataSetSerializers.Create.CreateWebSerializers.get_response_body_api()),
                             tags=["知识库"]
                             )
        @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
        def post(self, request: Request):
            return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data))

    class Application(APIView):
        authentication_classes = [TokenAuth]
@@ -58,9 +73,7 @@ class Dataset(APIView):
                         )
    @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
    def post(self, request: Request):
        s = DataSetSerializers.Create(data=request.data)
        s.is_valid(raise_exception=True)
        return result.success(s.save(request.user))
        return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data))

    class HitTest(APIView):
        authentication_classes = [TokenAuth]
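Putting the routing and view together, creating a web dataset is a single POST to the new dataset/web route. A hedged sketch of the call; the host, the /api prefix, and the AUTHORIZATION header format are assumptions about the deployment, not something this diff shows:

import requests

resp = requests.post(
    'http://localhost:8080/api/dataset/web',     # assumed host and API prefix
    headers={'AUTHORIZATION': '<user token>'},   # assumed header format for TokenAuth
    json={
        'name': 'DataEase docs',
        'desc': 'Documentation crawled from the DataEase site',
        'url': 'https://dataease.io/docs/v2/',
        'selector': '.md-content',               # optional CSS selector(s), space separated
    },
)
print(resp.status_code, resp.json())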
@@ -0,0 +1,23 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('setting', '0001_initial'),
    ]

    operations = [
        migrations.AlterField(
            model_name='teammemberpermission',
            name='auth_target_type',
            field=models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', max_length=128, verbose_name='授权目标'),
        ),
        migrations.AlterField(
            model_name='teammemberpermission',
            name='target',
            field=models.UUIDField(verbose_name='数据集/应用id'),
        ),
    ]