feat: web数据集

This commit is contained in:
shaohuzhang1 2023-12-29 18:02:23 +08:00
parent 89a74dd862
commit 64c8cc6b39
13 changed files with 417 additions and 74 deletions

View File

@ -0,0 +1,20 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models

import django.db.models.deletion


class Migration(migrations.Migration):
    """Make ChatRecord.dataset optional and detach records when their dataset is deleted."""

    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
        ('application', '0001_initial'),
    ]

    operations = [
        # Allow NULL and switch to SET_NULL so chat records survive dataset deletion.
        migrations.AlterField(
            model_name='chatrecord',
            name='dataset',
            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='dataset.dataset', verbose_name='数据集'),
        ),
    ]

View File

@ -18,6 +18,8 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
@ -26,6 +28,14 @@ max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
class SyncWebDatasetArgs:
    """Argument bundle sent through the ``sync_web_dataset`` signal.

    lock_key: suffix appended to the 'sync_web_dataset' lock name
        (callers pass the dataset id — see DataSetSerializers.save_web).
    url: root URL of the web site to fork.
    selector: space-separated CSS selector string used to extract page content.
    handler: callback invoked per fetched page as ``handler(child_link, response)``.
    """

    def __init__(self, lock_key: str, url: str, selector: str, handler):
        self.lock_key = lock_key
        self.url = url
        self.selector = selector
        # Invoked by ForkManage.fork for every page it downloads.
        self.handler = handler
class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@ -38,6 +48,7 @@ class ListenerManagement:
enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
init_embedding_model_signal = signal('init_embedding_model')
sync_web_dataset_signal = signal('sync_web_dataset')
@staticmethod
def embedding_by_problem(args):
@ -144,6 +155,18 @@ class ListenerManagement:
def enable_embedding_by_paragraph(paragraph_id):
VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
@staticmethod
@poxy
def sync_web_dataset(args: SyncWebDatasetArgs):
    """Crawl ``args.url`` two levels deep, feeding each fetched page to ``args.handler``.

    Guarded by a per-dataset lock ('sync_web_dataset' + lock_key) so the same
    dataset is never synced concurrently; if the lock is held, this is a no-op.
    Any crawl error is logged and the lock is always released.
    """
    lock_key = 'sync_web_dataset' + args.lock_key
    if try_lock(lock_key):
        try:
            # selector is optional on the API (allow_null/allow_blank), so it may
            # be None or empty — fall back to an empty selector list instead of
            # crashing on .split() and silently aborting the sync.
            selector_list = args.selector.split(" ") if args.selector else []
            ForkManage(args.url, selector_list).fork(2, set(),
                                                     args.handler)
        except Exception as e:
            max_kb_error.error(f'{str(e)}:{traceback.format_exc()}')
        finally:
            un_lock(lock_key)
@staticmethod
@poxy
def init_embedding_model(ags):
@ -175,3 +198,5 @@ class ListenerManagement:
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
# 初始化向量化模型
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
# 同步web站点知识库
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)

View File

@ -1,10 +1,19 @@
import copy
import logging
import re
import traceback
from functools import reduce
from typing import List, Set
import requests
import html2text as ht
from bs4 import BeautifulSoup
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse, ParseResult
class ChildLink:
    """A hyperlink discovered while forking a page.

    Holds the target URL together with a deep copy of the anchor tag it was
    found in, so the link stays intact even if the originating soup is
    mutated afterwards (e.g. by URL rewriting).
    """

    def __init__(self, url, tag):
        # Snapshot the tag first; deepcopy(None) is a harmless no-op for the
        # synthetic root link created in ForkManage.fork.
        self.tag = copy.deepcopy(tag)
        self.url = url
class ForkManage:
@ -13,30 +22,34 @@ class ForkManage:
self.selector_list = selector_list
def fork(self, level: int, exclude_link_url: Set[str], fork_handler):
self.fork_child(self.base_url, self.selector_list, level, exclude_link_url, fork_handler)
self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler)
@staticmethod
def fork_child(base_url: str, selector_list: List[str], level: int, exclude_link_url: Set[str], fork_handler):
def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str],
fork_handler):
if level < 0:
return
response = Fork(base_url, selector_list).fork()
fork_handler(base_url, response)
else:
child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
exclude_link_url.add(child_url)
response = Fork(child_link.url, selector_list).fork()
fork_handler(child_link, response)
for child_link in response.child_link_list:
if not exclude_link_url.__contains__(child_link):
exclude_link_url.add(child_link)
child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
if not exclude_link_url.__contains__(child_url):
ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler)
class Fork:
class Response:
def __init__(self, html_content: str, child_link_list: List[str], status, message: str):
self.html_content = html_content
def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str):
self.content = content
self.child_link_list = child_link_list
self.status = status
self.message = message
@staticmethod
def success(html_content: str, child_link_list: List[str]):
def success(html_content: str, child_link_list: List[ChildLink]):
return Fork.Response(html_content, child_link_list, 200, '')
@staticmethod
@ -45,13 +58,17 @@ class Fork:
def __init__(self, base_fork_url: str, selector_list: List[str]):
self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.')
self.base_fork_url = base_fork_url
self.base_fork_url = self.base_fork_url[:-1]
self.selector_list = selector_list
self.urlparse = urlparse(self.base_fork_url)
self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='',
query='',
fragment='').geturl()
def get_child_link_list(self, bf: BeautifulSoup):
pattern = "^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*|" + self.base_fork_url
pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + ").*"
link_list = bf.find_all(name='a', href=re.compile(pattern))
result = [self.parse_href(link.get('href')) for link in link_list]
result = [ChildLink(link.get('href'), link) for link in link_list]
return result
def get_content_html(self, bf: BeautifulSoup):
@ -65,23 +82,34 @@ class Fork:
f = bf.find_all(**params)
return "\n".join([str(row) for row in f])
def parse_href(self, href: str):
if href.startswith(self.base_fork_url[:-1] if self.base_fork_url.endswith('/') else self.base_fork_url):
return href
@staticmethod
def reset_url(tag, field, base_fork_url):
field_value: str = tag[field]
if field_value.startswith("/"):
result = urlparse(base_fork_url)
result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='',
fragment='').geturl()
else:
return urljoin(self.base_fork_url + '/' + (href if href.endswith('/') else href + '/'), ".")
result_url = urljoin(
base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'),
".")
result_url = result_url[:-1] if result_url.endswith('/') else result_url
tag[field] = result_url
def reset_beautiful_soup(self, bf: BeautifulSoup):
href_list = bf.find_all(href=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*'))
for h in href_list:
h['href'] = urljoin(
self.base_fork_url + '/' + (h['href'] if h['href'].endswith('/') else h['href'] + '/'),
".")[:-1]
src_list = bf.find_all(src=re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*'))
for s in src_list:
s['src'] = urljoin(
self.base_fork_url + '/' + (s['src'] if s['src'].endswith('/') else s['src'] + '/'),
".")[:-1]
reset_config_list = [
{
'field': 'href',
},
{
'field': 'src',
}
]
for reset_config in reset_config_list:
field = reset_config.get('field')
tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')})
for tag in tag_list:
self.reset_url(tag, field, self.base_fork_url)
return bf
@staticmethod
@ -92,11 +120,14 @@ class Fork:
def fork(self):
try:
logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}')
response = requests.get(self.base_fork_url)
if response.status_code != 200:
raise Exception(response.status_code)
logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}")
return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}")
bf = self.get_beautiful_soup(response)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return Fork.Response.error(str(e))
bf = self.reset_beautiful_soup(bf)
link_list = self.get_child_link_list(bf)
@ -106,7 +137,6 @@ class Fork:
def handler(base_url, response: Fork.Response):
print(base_url, response.status)
print(base_url.url, base_url.tag.text if base_url.tag else None, response.content)
ForkManage('https://dataease.io/docs/v2/', ['.md-content']).fork(3, set(), handler)
# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler)

View File

@ -277,11 +277,11 @@ def filter_special_char(content: str):
class SplitModel:
def __init__(self, content_level_pattern, with_filter=True, limit=1024):
def __init__(self, content_level_pattern, with_filter=True, limit=4096):
self.content_level_pattern = content_level_pattern
self.with_filter = with_filter
if limit is None or limit > 1024:
limit = 1024
if limit is None or limit > 4096:
limit = 4096
if limit < 50:
limit = 50
self.limit = limit
@ -337,13 +337,12 @@ class SplitModel:
default_split_pattern = {
'md': [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<!#)######(?!#).*"),
re.compile("(?<! )- .*")],
re.compile("(?<!#)######(?!#).*")],
'default': [re.compile("(?<!\n)\n\n.+")]
}
def get_split_model(filename: str, with_filter: bool, limit: int):
def get_split_model(filename: str, with_filter: bool = False, limit: int = 4096):
"""
根据文件名称获取分段模型
:param limit: 每段大小

View File

@ -0,0 +1,38 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models


class Migration(migrations.Migration):
    """Add type/meta columns to DataSet and Document to support web-site datasets."""

    dependencies = [
        ('dataset', '0001_initial'),
    ]

    operations = [
        # Free-form metadata (e.g. source_url/selector for web datasets).
        migrations.AddField(
            model_name='dataset',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        # '0' = generic dataset, '1' = web-site dataset (see dataset.models.Type).
        migrations.AddField(
            model_name='dataset',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        migrations.AddField(
            model_name='document',
            name='meta',
            field=models.JSONField(default=dict, verbose_name='元数据'),
        ),
        migrations.AddField(
            model_name='document',
            name='type',
            field=models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型'),
        ),
        # Dataset names are capped at 150 characters.
        migrations.AlterField(
            model_name='dataset',
            name='name',
            field=models.CharField(max_length=150, verbose_name='数据集名称'),
        ),
    ]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2023-12-29 17:49
from django.db import migrations, models


class Migration(migrations.Migration):
    """Allow paragraph content of up to 4096 characters."""

    dependencies = [
        ('dataset', '0002_dataset_meta_dataset_type_document_meta_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='paragraph',
            name='content',
            field=models.CharField(max_length=4096, verbose_name='段落内容'),
        ),
    ]

View File

@ -21,6 +21,12 @@ class Status(models.TextChoices):
error = 2, '导入失败'
class Type(models.TextChoices):
    """Source type shared by DataSet and Document: generic vs. synced from a web site."""
    # Generic (manually managed) dataset/document.
    base = 0, '通用类型'
    # Content forked from a web site (meta carries source_url/selector).
    web = 1, 'web站点类型'
class DataSet(AppModelMixin):
"""
数据集表
@ -29,6 +35,10 @@ class DataSet(AppModelMixin):
name = models.CharField(max_length=150, verbose_name="数据集名称")
desc = models.CharField(max_length=256, verbose_name="数据库描述")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:
db_table = "dataset"
@ -46,6 +56,11 @@ class Document(AppModelMixin):
default=Status.embedding)
is_active = models.BooleanField(default=True)
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:
db_table = "document"
@ -57,7 +72,7 @@ class Paragraph(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=1024, verbose_name="段落内容")
content = models.CharField(max_length=4096, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)

View File

@ -6,11 +6,13 @@
@date2023/9/21 16:14
@desc:
"""
import logging
import os.path
import traceback
import uuid
from functools import reduce
from itertools import groupby
from typing import Dict
from urllib.parse import urlparse
from django.contrib.postgres.fields import ArrayField
from django.core import validators
@ -23,17 +25,18 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from common.util.fork import ChildLink, Fork, ForkManage
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
from users.models import User
"""
# __exact 精确等于 like aaa
@ -187,30 +190,105 @@ class DataSetSerializers(serializers.ModelSerializer):
return DataSetSerializers.Operate.get_response_body_api()
class Create(ApiMixin, serializers.Serializer):
"""
创建序列化对象
"""
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
user_id = serializers.UUIDField(required=True)
desc = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
class CreateBaseSerializers(ApiMixin, serializers.Serializer):
"""
创建通用数据集序列化对象
"""
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
documents = DocumentInstanceSerializer(required=False, many=True)
desc = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
return True
documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
return True
class CreateWebSerializers(serializers.Serializer):
    """Request serializer for creating a web-site dataset.

    Validates the dataset name/description together with the site root
    url and an optional space-separated CSS selector string used to
    extract page content during the crawl.
    """
    name = serializers.CharField(required=True,
                                 validators=[
                                     validators.MaxLengthValidator(limit_value=20,
                                                                   message="知识库名称在1-20个字符之间"),
                                     validators.MinLengthValidator(limit_value=1,
                                                                   message="知识库名称在1-20个字符之间")
                                 ])
    desc = serializers.CharField(required=True,
                                 validators=[
                                     validators.MaxLengthValidator(limit_value=256,
                                                                   message="知识库名称在1-256个字符之间"),
                                     validators.MinLengthValidator(limit_value=1,
                                                                   message="知识库名称在1-256个字符之间")
                                 ])
    # Root URL of the web site to crawl.
    url = serializers.CharField(required=True)
    # Optional CSS selector(s); may be blank or null (whole page is used then).
    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

    def is_valid(self, *, raise_exception=False):
        # Always raise on invalid data, regardless of the caller's flag.
        super().is_valid(raise_exception=True)
        return True

    @staticmethod
    def get_response_body_api():
        """OpenAPI schema of the created-dataset response body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
                      'update_time', 'create_time', 'document_list'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                       description="名称", default="测试知识库"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
                                       description="描述", default="测试知识库描述"),
                'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
                                          description="所属用户id", default="user_xxxx"),
                'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
                                              description="字符数", default=10),
                'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
                                                 description="文档数量", default=1),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              ),
                'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
                                                description="文档列表",
                                                items=DocumentSerializers.Operate.get_response_body_api())
            }
        )

    @staticmethod
    def get_request_body_api():
        """OpenAPI schema of the create-web-dataset request body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['name', 'desc', 'url'],
            properties={
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
                'url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
            }
        )
@staticmethod
def post_embedding_dataset(document_list, dataset_id):
@ -220,16 +298,21 @@ class DataSetSerializers(serializers.ModelSerializer):
@post(post_function=post_embedding_dataset)
@transaction.atomic
def save(self, user: User):
def save(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
self.CreateBaseSerializers(data=instance).is_valid()
dataset_id = uuid.uuid1()
user_id = self.data.get('user_id')
dataset = DataSet(
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
document_model_list = []
paragraph_model_list = []
problem_model_list = []
# 插入文档
for document in self.data.get('documents') if 'documents' in self.data else []:
for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
@ -252,6 +335,47 @@ class DataSetSerializers(serializers.ModelSerializer):
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
with_valid=True)}, dataset_id
@staticmethod
def get_last_url_path(url):
parsed_url = urlparse(url)
if parsed_url.path is None or len(parsed_url.path) == 0:
return url
else:
return parsed_url.path.split("/")[-1]
@staticmethod
def get_save_handler(dataset_id, selector):
    """Build the per-page callback used while forking a web dataset.

    The returned handler receives each crawled link and its Fork.Response;
    successful pages are split into paragraphs and stored as a web-type
    document under *dataset_id*. Failures are logged and swallowed so one
    bad page does not abort the whole crawl.
    """
    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
            try:
                # Prefer the anchor text as the document name; fall back to the URL
                # when the tag is missing (root link) or its text is blank.
                document_name = child_link.tag.text if child_link.tag is not None and len(
                    child_link.tag.text.strip()) > 0 else child_link.url
                paragraphs = get_split_model('web.md').parse(response.content)
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': child_link.url, 'selector': selector},
                     'type': Type.web}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
    return handler
def save_web(self, instance: Dict, with_valid=True):
    """Create a web-site dataset and kick off an asynchronous crawl.

    :param instance: request payload with name/desc/url and optional selector
    :param with_valid: validate the serializer and the payload before saving
    :return: serialized dataset plus an (initially empty) document_list —
             documents are filled in later by the sync_web_dataset listener
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        self.CreateWebSerializers(data=instance).is_valid(raise_exception=True)
    user_id = self.data.get('user_id')
    dataset_id = uuid.uuid1()
    # source_url/selector are persisted in meta so the dataset can be re-synced.
    dataset = DataSet(
        **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
           'type': Type.web, 'meta': {'source_url': instance.get('url'), 'selector': instance.get('selector')}})
    dataset.save()
    # Fire-and-forget: the listener crawls the site and saves documents.
    ListenerManagement.sync_web_dataset_signal.send(
        SyncWebDatasetArgs(str(dataset_id), instance.get('url'), instance.get('selector'),
                           self.get_save_handler(dataset_id, instance.get('selector'))))
    return {**DataSetSerializers(dataset).data,
            'document_list': []}
@staticmethod
def get_response_body_api():
return openapi.Schema(
@ -298,12 +422,43 @@ class DataSetSerializers(serializers.ModelSerializer):
}
)
class Edit(serializers.Serializer):
class MetaSerializer(serializers.Serializer):
    """Per-type validators for DataSet.meta (see Edit.get_dataset_meta_valid_map)."""

    class WebMeta(serializers.Serializer):
        """Meta for web-site datasets: the crawl source and optional selector."""
        source_url = serializers.CharField(required=True)
        selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            # Probe the source url with an empty selector list; a 500 Fork
            # response means the site is unreachable or returned an error.
            source_url = self.data.get('source_url')
            response = Fork(source_url, []).fork()
            if response.status == 500:
                raise AppApiException(500, response.message)

    class BaseMeta(serializers.Serializer):
        """Generic datasets declare no meta fields; only run base validation."""
        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
class Edit(serializers.Serializer):
    """Partial-update serializer for a dataset; meta is validated per dataset type."""
    name = serializers.CharField(required=False)
    desc = serializers.CharField(required=False)
    meta = serializers.DictField(required=False)
    application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))

    @staticmethod
    def get_dataset_meta_valid_map():
        """Map each dataset Type to the serializer class that validates its meta."""
        dataset_meta_valid_map = {
            Type.base: DataSetSerializers.MetaSerializer.BaseMeta,
            Type.web: DataSetSerializers.MetaSerializer.WebMeta
        }
        return dataset_meta_valid_map

    def is_valid(self, *, dataset: DataSet = None):
        super().is_valid(raise_exception=True)
        # When meta is supplied, validate it with the serializer matching
        # the type of the dataset being edited.
        if 'meta' in self.data and self.data.get('meta') is not None:
            dataset_meta_valid_map = self.get_dataset_meta_valid_map()
            valid_class = dataset_meta_valid_map.get(dataset.type)
            valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class HitTest(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
@ -392,12 +547,14 @@ class DataSetSerializers(serializers.ModelSerializer):
:return:
"""
self.is_valid()
DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True)
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
if "name" in dataset:
_dataset.name = dataset.get("name")
if 'desc' in dataset:
_dataset.desc = dataset.get("desc")
if 'meta' in dataset:
_dataset.meta = dataset.get('meta')
if 'application_id_list' in dataset and dataset.get('application_id_list') is not None:
application_id_list = dataset.get('application_id_list')
# 当前用户可修改关联的知识库列表
@ -429,6 +586,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="知识库元数据",
description="知识库元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表",
description="应用id列表",
items=openapi.Schema(type=openapi.TYPE_STRING))

View File

@ -24,7 +24,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -243,7 +243,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'name': instance.get('name'),
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0)})
0),
'meta': instance.get('meta') if instance.get('meta') is not None else {},
'type': instance.get('type') if instance.get('type') is not None else Type.base})
paragraph_model_dict_list = [ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(

View File

@ -37,7 +37,7 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
段落实例对象
"""
content = serializers.CharField(required=True, validators=[
validators.MaxLengthValidator(limit_value=1024,
validators.MaxLengthValidator(limit_value=4096,
message="段落在1-1024个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="段落在1-1024个字符之间"),

View File

@ -5,6 +5,7 @@ from . import views
app_name = "dataset"
urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/web', views.Dataset.CreateWebDataset.as_view()),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),

View File

@ -23,6 +23,21 @@ from dataset.serializers.dataset_serializers import DataSetSerializers
class Dataset(APIView):
authentication_classes = [TokenAuth]
class CreateWebDataset(APIView):
    """POST dataset/web — create a dataset whose content is synced from a web site."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建web站点知识库",
                         operation_id="创建web站点知识库",
                         request_body=DataSetSerializers.Create.CreateWebSerializers.get_request_body_api(),
                         responses=get_api_response(
                             DataSetSerializers.Create.CreateWebSerializers.get_response_body_api()),
                         tags=["知识库"]
                         )
    @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
    def post(self, request: Request):
        # user_id comes from the authenticated request; the body supplies
        # name/desc/url/selector and is validated inside save_web.
        return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data))
class Application(APIView):
authentication_classes = [TokenAuth]
@ -58,9 +73,7 @@ class Dataset(APIView):
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request):
s = DataSetSerializers.Create(data=request.data)
s.is_valid(raise_exception=True)
return result.success(s.save(request.user))
return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data))
class HitTest(APIView):
authentication_classes = [TokenAuth]

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.10 on 2023-12-28 15:16
from django.db import migrations, models


class Migration(migrations.Migration):
    """Rework TeamMemberPermission: typed auth target plus UUID target id."""

    dependencies = [
        ('setting', '0001_initial'),
    ]

    operations = [
        # Permissions can now target either a dataset or an application.
        migrations.AlterField(
            model_name='teammemberpermission',
            name='auth_target_type',
            field=models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', max_length=128, verbose_name='授权目标'),
        ),
        # The target column stores the dataset/application id as a UUID.
        migrations.AlterField(
            model_name='teammemberpermission',
            name='target',
            field=models.UUIDField(verbose_name='数据集/应用id'),
        ),
    ]