feat: 数据集,文档,段落,问题,向量化接口

This commit is contained in:
shaohuzhang1 2023-10-24 20:24:32 +08:00
parent 64e93679f8
commit a2de9691fb
36 changed files with 1843 additions and 166 deletions

1
.gitignore vendored
View File

@ -162,6 +162,7 @@ cython_debug/
ui/node_modules
ui/dist
apps/static
models/
data
.idea
.dev

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33
# Generated by Django 4.1.10 on 2023-10-24 12:13
import django.contrib.postgres.fields
from django.db import migrations, models

View File

@ -0,0 +1,51 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file embedding_config.py
@date2023/10/23 16:03
@desc:
"""
import types
from smartdoc.const import CONFIG
from langchain.embeddings import HuggingFaceEmbeddings
class EmbeddingModel:
    """Lazily-constructed, process-wide singleton for the HuggingFace embedding model."""

    # Cached model instance; None until first use.
    instance = None

    @staticmethod
    def get_embedding_model():
        """
        Return the shared embedding model, creating it on first call.

        Model name, cache folder and device are read from CONFIG
        (EMBEDDING_MODEL_NAME / EMBEDDING_MODEL_PATH / EMBEDDING_DEVICE).

        :return: the HuggingFaceEmbeddings singleton
        """
        if EmbeddingModel.instance is None:
            EmbeddingModel.instance = HuggingFaceEmbeddings(
                model_name=CONFIG.get('EMBEDDING_MODEL_NAME'),
                cache_folder=CONFIG.get('EMBEDDING_MODEL_PATH'),
                model_kwargs={'device': CONFIG.get('EMBEDDING_DEVICE')})
        return EmbeddingModel.instance
class VectorStore:
    """Singleton factory resolving the configured vector-store backend."""
    from embedding.vector.pg_vector import PGVector
    from embedding.vector.base_vector import BaseVectorStore

    # Registry of supported backends, keyed by the VECTOR_STORE_NAME config value.
    instance_map = {
        'pg_vector': PGVector,
    }
    # Cached backend instance; None until first use.
    instance = None

    @staticmethod
    def get_embedding_vector() -> BaseVectorStore:
        """Return the shared vector store, instantiating the configured backend once."""
        from embedding.vector.pg_vector import PGVector
        if VectorStore.instance is None:
            from smartdoc.const import CONFIG
            backend_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
                                                         PGVector)
            VectorStore.instance = backend_class()
        return VectorStore.instance

View File

@ -1,4 +1,3 @@
# coding=utf-8
"""
@project: qabot
@Author

View File

@ -0,0 +1,162 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
import os
from concurrent.futures import ThreadPoolExecutor
import django.db.models
from blinker import signal
from django.db.models import QuerySet
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search, get_dynamics_model
from common.util.file_util import get_file_content
from dataset.models import Paragraph, Status, Document
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
def poxy(poxy_function):
    """
    Decorator: run the wrapped listener asynchronously on the shared worker pool.

    The decorated function returns immediately; the real work is submitted to
    ListenerManagement.work_thread_pool. Exceptions inside the worker surface
    only via the returned (discarded) Future.

    :param poxy_function: single-argument listener to run in the background
    :return: wrapper with the original function's metadata preserved
    """
    from functools import wraps

    @wraps(poxy_function)  # keep __name__/__doc__ for debugging and signal introspection
    def inner(args):
        ListenerManagement.work_thread_pool.submit(poxy_function, args)

    return inner
class ListenerManagement:
    """
    Blinker-signal hub for embedding lifecycle events.

    Handlers either vectorise dataset content (problems/paragraphs/documents/
    datasets) into the vector store, or delete/toggle existing vectors.
    Handlers decorated with @poxy run asynchronously on work_thread_pool;
    the others run synchronously in the sender's thread.
    Call run() once at startup to wire every signal to its handler.
    """
    # Shared pool executing the @poxy-decorated (asynchronous) handlers.
    work_thread_pool = ThreadPoolExecutor(5)
    embedding_by_problem_signal = signal("embedding_by_problem")
    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
    embedding_by_dataset_signal = signal("embedding_by_dataset")
    embedding_by_document_signal = signal("embedding_by_document")
    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
    init_embedding_model_signal = signal('init_embedding_model')

    @staticmethod
    def embedding_by_problem(args):
        # Vectorise a single problem; `args` is a dict of keyword arguments
        # forwarded verbatim to the vector store's save().
        VectorStore.get_embedding_vector().save(**args)

    @staticmethod
    @poxy
    def embedding_by_paragraph(paragraph_id):
        """
        Vectorise one paragraph (and the problems attached to it) by paragraph id.

        Runs asynchronously; on completion the paragraph's status is set to success.

        :param paragraph_id: paragraph id
        :return: None
        """
        # Collect embeddable rows (problems + the paragraph itself) via the
        # shared SQL template; the filters below fill its placeholders.
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
                **{'problem.paragraph_id': paragraph_id}),
                'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorise the collected rows
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark the paragraph as imported only after the vectors are stored.
        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_document(document_id):
        """
        Vectorise every paragraph (and attached problems) of one document.

        Runs asynchronously; on completion the document and its paragraphs
        are marked success.

        :param document_id: document id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
                **{'problem.document_id': document_id}),
                'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorise the collected rows
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Update status of the document and all of its paragraphs.
        QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_dataset(dataset_id):
        """
        Vectorise the whole dataset (all paragraphs and attached problems).

        Runs asynchronously; on completion every document and paragraph in the
        dataset is marked success.

        :param dataset_id: dataset id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
                **{'problem.dataset_id': dataset_id}),
                'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorise the collected rows
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Update status of the dataset's documents and paragraphs.
        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})

    @staticmethod
    def delete_embedding_by_document(document_id):
        # Remove all vectors belonging to one document.
        VectorStore.get_embedding_vector().delete_by_document_id(document_id)

    @staticmethod
    def delete_embedding_by_dataset(dataset_id):
        # Remove all vectors belonging to one dataset.
        VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)

    @staticmethod
    def delete_embedding_by_paragraph(paragraph_id):
        # Remove all vectors belonging to one paragraph.
        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

    @staticmethod
    def delete_embedding_by_source(source_id):
        # Remove the vector of one problem source (source type PROBLEM).
        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)

    @staticmethod
    def disable_embedding_by_paragraph(paragraph_id):
        # Deactivate (hide from retrieval) a paragraph's vectors without deleting them.
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})

    @staticmethod
    def enable_embedding_by_paragraph(paragraph_id):
        # Reactivate a paragraph's vectors.
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    @poxy
    def init_embedding_model(ags):
        # Warm up the embedding model in the background; `ags` (signal sender)
        # is ignored.
        EmbeddingModel.get_embedding_model()

    def run(self):
        """Connect every signal to its handler. Call once at application startup."""
        # Add vectors by problem id
        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
        # Add vectors by paragraph id
        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
        # Add vectors by dataset id
        ListenerManagement.embedding_by_dataset_signal.connect(
            self.embedding_by_dataset)
        # Add vectors by document id
        ListenerManagement.embedding_by_document_signal.connect(
            self.embedding_by_document)
        # Delete vectors by document
        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
        # Delete vectors by dataset id
        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
        # Delete vectors by paragraph id
        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
            self.delete_embedding_by_paragraph)
        # Delete vectors by source id
        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
        # Disable a paragraph's vectors
        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
        # Enable a paragraph's vectors
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
        # Initialise the embedding model
        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)

View File

@ -14,7 +14,7 @@ from rest_framework.views import exception_handler
from common.exception.app_exception import AppApiException
from common.response import result
import traceback
def to_result(key, args, parent_key=None):
"""
将校验异常 args转换为统一数据
@ -59,6 +59,7 @@ def handle_exception(exc, context):
exception_class = exc.__class__
# 先调用REST framework默认的异常处理方法获得标准错误响应对象
response = exception_handler(exc, context)
traceback.print_exc()
# 在此处补充自定义的异常处理
if issubclass(exception_class, ValidationError):
return validation_error_to_result(exc)

View File

@ -70,7 +70,9 @@ def get_page_api_response(response_data_schema: openapi.Schema):
title="总条数",
default=1,
description="数据总条数"),
"records": response_data_schema,
"records": openapi.Schema(
type=openapi.TYPE_ARRAY,
items=response_data_schema),
"current": openapi.Schema(
type=openapi.TYPE_INTEGER,
title="当前页",
@ -115,6 +117,36 @@ def get_api_response(response_data_schema: openapi.Schema):
)})
def get_default_response():
    """Return the standard API response schema whose data field is a plain boolean."""
    return get_api_response(openapi.Schema(type=openapi.TYPE_BOOLEAN))
def get_api_array_response(response_data_schema: openapi.Schema):
    """
    Build the unified API response schema whose ``data`` field is an array
    of the given item schema.

    :param response_data_schema: schema of one element of the data array
    :return: openapi.Responses describing the 200 response
    """
    body_schema = openapi.Schema(
        type=openapi.TYPE_OBJECT,
        properties={
            'code': openapi.Schema(
                type=openapi.TYPE_INTEGER,
                title="响应码",
                default=200,
                description="成功:200 失败:其他"),
            "message": openapi.Schema(
                type=openapi.TYPE_STRING,
                title="提示",
                default='成功',
                description="错误提示"),
            "data": openapi.Schema(
                type=openapi.TYPE_ARRAY,
                items=response_data_schema)
        })
    return openapi.Responses(responses={200: openapi.Response(description="响应参数",
                                                              schema=body_schema)})
def success(data):
"""
获取一个成功的响应对象

View File

@ -0,0 +1,26 @@
-- Collects every row of text that should be embedded: problems joined to
-- their paragraph (source_type = 0), unioned with the paragraphs themselves
-- (source_type = 1). ${problem} and ${paragraph} are template placeholders
-- substituted at runtime with per-query WHERE clauses.
SELECT
problem."id" AS "source_id",
problem.document_id AS document_id,
problem.paragraph_id AS paragraph_id,
problem.dataset_id AS dataset_id,
0 AS source_type,
problem."content" AS "text",
-- a problem inherits the active flag of the paragraph it belongs to
paragraph.is_active AS is_active
FROM
problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
${problem}
UNION
SELECT
paragraph."id" AS "source_id",
paragraph.document_id AS document_id,
-- a paragraph is its own paragraph source
paragraph."id" AS paragraph_id,
paragraph.dataset_id AS dataset_id,
1 AS source_type,
paragraph."content" AS "text",
paragraph.is_active AS is_active
FROM
paragraph paragraph
${paragraph}

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2023/10/16 16:42
@desc:
"""
from functools import reduce
from typing import Dict
def query_params_to_single_dict(query_params: Dict) -> Dict:
    """
    Collapse a query-params mapping whose values are sequences into a plain
    dict of scalars: keep the first element of each value, dropping entries
    whose first element is None.

    Note: each value must be subscriptable (e.g. the per-key list a Django
    QueryDict stores); a plain string value would yield its first character.

    :param query_params: mapping of parameter name -> sequence of values
    :return: dict of parameter name -> first value
    """
    # The original reduce/filter/map chain contained a dead conditional
    # (both ternary branches returned value[0]); this loop is the equivalent.
    single = {}
    for key, value in query_params.items():
        first = value[0]
        if first is not None:
            single[key] = first
    return single

View File

@ -7,6 +7,7 @@
@desc:
"""
import re
from functools import reduce
from typing import List
import jieba
@ -25,7 +26,7 @@ def get_level_block(text, level_content_list, level_content_index):
level_content_list) else None
start_index = text.index(start_content)
end_index = text.index(next_content) if next_content is not None else len(text)
return text[start_index:end_index]
return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "")
def to_tree_obj(content, state='title'):
@ -88,7 +89,7 @@ def to_paragraph(obj: dict):
content = obj['content']
return {"keywords": get_keyword(content),
'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
'content': content}
'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
def get_keyword(content: str):
@ -109,13 +110,15 @@ def titles_to_paragraph(list_title: List[dict]):
:return: 块段落
"""
if len(list_title) > 0:
content = "\n".join(
content = "\n,".join(
list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))
return {'keywords': '',
'parent_chain': list(
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
'content': content}
'content': ",".join(list(
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
list_title[0]['parent_chain']))) + content}
return None
@ -144,6 +147,15 @@ def to_block_paragraph(tree_data_list: List[dict]):
return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
def parse_title_level(text, content_level_pattern: List, index):
    """
    Try each title pattern starting at ``index`` and return the first
    non-empty list of matched title nodes; an empty list if none match.

    :param text: text to scan
    :param content_level_pattern: ordered list of title regex patterns
    :param index: position in the pattern list to start from
    :return: list of title tree nodes (possibly empty)
    """
    for pattern_index in range(index, len(content_level_pattern)):
        matches = parse_level(text, content_level_pattern[pattern_index])
        if matches:
            return matches
    return []
def parse_level(text, pattern: str):
"""
获取正则匹配到的文本
@ -151,10 +163,17 @@ def parse_level(text, pattern: str):
:param pattern: 正则
:return: 符合正则的文本
"""
level_content_list = list(map(to_tree_obj, re.findall(pattern, text, flags=0)))
level_content_list = list(map(to_tree_obj, re_findall(pattern, text)))
return list(map(filter_special_symbol, level_content_list))
def re_findall(pattern, text):
    """
    Like ``re.findall`` but always returns a flat list of strings.

    When the pattern contains several capturing groups, findall yields
    tuples; those are flattened, and empty/None entries (unmatched optional
    groups) are dropped.

    :param pattern: regex pattern (string or compiled)
    :param text: text to scan
    :return: flat list of non-empty matched strings
    """
    flat = []
    # extend() replaces the original quadratic reduce([*x, *y]) flattening.
    for match in re.findall(pattern, text, flags=0):
        if isinstance(match, tuple):
            flat.extend(match)
        else:
            flat.append(match)
    return [m for m in flat if m is not None and len(m) > 0]
def to_flat_obj(parent_chain: List[dict], content: str, state: str):
"""
将树形属性转换为扁平对象
@ -194,10 +213,79 @@ def group_by(list_source: List, key):
return result
def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain):
    """
    Flatten a parsed title/block tree into paragraph dicts.

    :param result_tree: tree nodes of shape {'state', 'content', 'children'}
    :param result: accumulator list, mutated in place across recursive calls
    :param parent_chain: titles of the enclosing levels, outermost first
    :return: result, as [{'title': 'xx', 'content': 'xx'}, ...]
    """
    # Removed a leftover debug print(item) that wrote every node to stdout.
    for node in result_tree:
        if node.get('state') == 'block':
            result.append({'title': " ".join(parent_chain), 'content': node.get("content")})
        children = node.get("children")
        if children is not None and len(children) > 0:
            # Titles accumulate down the chain; blocks inherit the joined chain.
            result_tree_to_paragraph(children, result, [*parent_chain, node.get('content')])
    return result
def post_handler_paragraph(content: str, limit: int, with_filter: bool):
    """
    Re-chunk text so that no piece exceeds ``limit`` characters.

    Lines are first greedily packed into chunks of at most ``limit`` chars;
    any single line longer than the limit is then cut into limit-sized
    slices by regex.

    :param with_filter: whether to strip special characters from each piece
    :param content: text to split
    :param limit: maximum characters per piece
    :return: list of pieces
    """
    chunks = []
    buffer = ''
    for line in content.split('\n'):
        if len(buffer + line) > limit:
            chunks.append(buffer)
            buffer = ''
        buffer = buffer + line
    if len(buffer) > 0:
        chunks.append(buffer)
    # A single line may still exceed the limit: slice it into pieces of at most `limit` chars.
    pattern = "[\\S\\s]{1," + str(limit) + '}'
    pieces = []
    for chunk in chunks:
        pieces.extend(re.findall(pattern, chunk))
    if with_filter:
        return [filter_special_char(piece) for piece in pieces]
    return pieces
# Pre-compiled normalisation rules, applied in insertion order:
# collapse newline runs, collapse whitespace runs, strip '#' header
# markers, drop tab runs.
replace_map = {
    re.compile('\n+'): '\n',
    re.compile('\\s+'): ' ',
    re.compile('#+'): "",
    re.compile("\t+"): ''
}


def filter_special_char(content: str):
    """
    Normalise special characters in a chunk of text using ``replace_map``.

    :param content: raw text
    :return: cleaned text
    """
    for pattern, replacement in replace_map.items():
        content = pattern.sub(replacement, content)
    return content
class SplitModel:
def __init__(self, content_level_pattern):
def __init__(self, content_level_pattern, with_filter=True, limit=1024):
self.content_level_pattern = content_level_pattern
self.with_filter = with_filter
if limit is None or limit > 1024:
limit = 1024
if limit < 50:
limit = 50
self.limit = limit
def parse_to_tree(self, text: str, index=0):
"""
@ -208,23 +296,27 @@ class SplitModel:
"""
if len(self.content_level_pattern) == index:
return
level_content_list = parse_level(text, pattern=self.content_level_pattern[index])
level_content_list = parse_title_level(text, self.content_level_pattern, index)
for i in range(len(level_content_list)):
block = get_level_block(text, level_content_list, i)
children = self.parse_to_tree(text=block.replace(level_content_list[i]['content'][:-1], ""),
children = self.parse_to_tree(text=block,
index=index + 1)
if children is not None and len(children) > 0:
level_content_list[i]['children'] = children
else:
if len(block) > 0:
level_content_list[i]['children'] = [to_tree_obj(block, 'block')]
level_content_list[i]['children'] = list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(block, with_filter=self.with_filter, limit=self.limit)))
if len(level_content_list) > 0:
end_index = text.index(level_content_list[0].get('content'))
if end_index == 0:
return level_content_list
other_content = text[0:end_index]
if len(other_content.strip()) > 0:
level_content_list.append(to_tree_obj(other_content, 'block'))
level_content_list = [*level_content_list, *list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))]
return level_content_list
def parse(self, text: str):
@ -234,4 +326,35 @@ class SplitModel:
:return: 解析后数据 {content:段落数据,keywords:[段落关键词],parent_chain:['段落父级链路']}
"""
result_tree = self.parse_to_tree(text, 0)
return flat_map(to_block_paragraph(result_tree))
return result_tree_to_paragraph(result_tree, [], [])
split_model_map = {
'md': SplitModel(
[re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<! )- .*")]),
'default': SplitModel([re.compile("(?<!\n)\n\n.+")])
}
def get_split_model(filename: str):
    """
    Pick the split model matching a file name.

    :param filename: name of the uploaded file
    :return: the SplitModel registered for the file type ('md' for Markdown,
             otherwise the default model)
    """
    model_key = 'md' if filename.endswith(".md") else 'default'
    return split_model_map.get(model_key)
def to_title_tree_string(result_tree: List):
    """Render the title nodes of a parsed tree as an indented outline string."""
    # NOTE(review): `flat` is assumed to be a tree-flattening helper defined
    # elsewhere in this module (earlier revisions used flat_map) — confirm.
    nodes = flat(result_tree, [], [])
    title_nodes = [node for node in nodes if node.get('state') == 'title']
    return "\n".join([title_tostring(node) for node in title_nodes])
def title_tostring(title_obj):
    """
    Render one title node as an indented outline line.

    Indentation is one space per ancestor in the node's parent_chain,
    followed by a tree-branch marker and the title text.
    """
    indent = " " * len(title_obj.get("parent_chain"))
    return indent + "├───" + title_obj.get('content')

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33
# Generated by Django 4.1.10 on 2023-10-24 12:13
from django.db import migrations, models
import django.db.models.deletion
@ -36,6 +36,7 @@ class Migration(migrations.Migration):
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('name', models.CharField(max_length=150, verbose_name='文档名称')),
('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')),
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
('is_active', models.BooleanField(default=True)),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
],
@ -50,11 +51,14 @@ class Migration(migrations.Migration):
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('content', models.CharField(max_length=1024, verbose_name='段落内容')),
('title', models.CharField(default='', max_length=256, verbose_name='标题')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
('is_active', models.BooleanField(default=True)),
('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
],
options={
'db_table': 'paragraph',
@ -67,25 +71,15 @@ class Migration(migrations.Migration):
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('content', models.CharField(max_length=256, verbose_name='问题内容')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
],
options={
'db_table': 'problem',
},
),
migrations.CreateModel(
name='ProblemAnswerMapping',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.problem')),
],
options={
'db_table': 'problem_paragraph_mapping',
},
),
]

View File

@ -14,6 +14,13 @@ from common.mixins.app_model_mixin import AppModelMixin
from users.models import User
class Status(models.TextChoices):
    """Import status for documents and paragraphs (previous docstring '订单类型'/'order type' was a copy-paste error)."""
    embedding = 0, '导入中'
    success = 1, '已完成'
    error = 2, '导入失败'
class DataSet(AppModelMixin):
"""
数据集表
@ -35,6 +42,8 @@ class Document(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
name = models.CharField(max_length=150, verbose_name="文档名称")
char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True)
class Meta:
@ -46,11 +55,15 @@ class Paragraph(AppModelMixin):
段落表
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=1024, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True)
class Meta:
@ -62,23 +75,13 @@ class Problem(AppModelMixin):
问题表
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容")
class Meta:
db_table = "problem"
class ProblemAnswerMapping(AppModelMixin):
"""
问题 段落 映射表
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING)
problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING)
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
class Meta:
db_table = "problem_paragraph_mapping"
db_table = "problem"

View File

@ -8,7 +8,6 @@
"""
import os.path
import uuid
from functools import reduce
from typing import Dict
from django.contrib.postgres.fields import ArrayField
@ -19,11 +18,12 @@ from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph
from dataset.serializers.document_serializers import CreateDocumentSerializers
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
from users.models import User
@ -81,12 +81,12 @@ class DataSetSerializers(serializers.ModelSerializer):
user_id = self.data.get("user_id")
query_set_dict = {}
query_set = QuerySet(model=get_dynamics_model(
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
{'temp.name': models.CharField(), 'temp.desc': models.CharField(),
"document_temp.char_length": models.IntegerField()}))
if "desc" in self.data:
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
if "name" in self.data:
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
if "desc" in self.data and self.data.get('desc') is not None:
query_set = query_set.filter(**{'temp.desc__contains': self.data.get("desc")})
if "name" in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'temp.name__contains': self.data.get("name")})
query_set_dict['default_sql'] = query_set
@ -133,9 +133,7 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod
def get_response_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY,
title="数据集列表", description="数据集列表",
items=DataSetSerializers.Operate.get_response_body_api())
return DataSetSerializers.Operate.get_response_body_api()
class Create(ApiMixin, serializers.Serializer):
"""
@ -157,7 +155,7 @@ class DataSetSerializers(serializers.ModelSerializer):
message="数据集名称在1-256个字符之间")
])
documents = CreateDocumentSerializers(required=False, many=True)
documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -168,28 +166,46 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1()
dataset = DataSet(
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
document_model_list = []
paragraph_model_list = []
if 'documents' in self.data:
documents = self.data.get('documents')
for document in documents:
document_model = Document(**{'dataset': dataset, 'id': uuid.uuid1(), 'name': document.get('name'),
'char_length': reduce(lambda x, y: x + y,
list(
map(lambda p: len(p),
document.get("paragraphs"))), 0)})
document_model_list.append(document_model)
if 'paragraphs' in document:
paragraph_model_list += list(map(lambda p: Paragraph(
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
document.get('paragraphs')))
# 插入数据集
dataset.save()
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
return True
for document in self.data.get('documents') if 'documents' in self.data else []:
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
with_embedding=False)
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
return {**DataSetSerializers(dataset).data,
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
'update_time', 'create_time', 'document_list'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试数据集"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
description="描述", default="测试数据集描述"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
description="所属用户id", default="user_xxxx"),
'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
description="字符数", default=10),
'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
description="文档数量", default=1),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
),
'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
}
)
@staticmethod
def get_request_body_api():
@ -200,7 +216,7 @@ class DataSetSerializers(serializers.ModelSerializer):
'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"),
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
items=CreateDocumentSerializers().get_request_body_api()
items=DocumentSerializers().Create.get_request_body_api()
)
}
)
@ -217,10 +233,11 @@ class DataSetSerializers(serializers.ModelSerializer):
def delete(self):
self.is_valid()
dataset = QuerySet(DataSet).get(id=self.data.get("id"))
document_list = QuerySet(Document).filter(dataset=dataset)
QuerySet(Paragraph).filter(document__in=document_list).delete()
document_list.delete()
QuerySet(Document).filter(dataset=dataset).delete()
QuerySet(Paragraph).filter(dataset=dataset).delete()
QuerySet(Problem).filter(dataset=dataset).delete()
dataset.delete()
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
return True
def one(self, user_id, with_valid=True):
@ -303,9 +320,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='id',
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=False,
required=True,
description='数据集id')
]

View File

@ -6,20 +6,29 @@
@date2023/9/22 13:43
@desc:
"""
import os
import uuid
from functools import reduce
from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models.data_set import DataSet, Document, Paragraph
from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=128,
@ -28,52 +37,265 @@ class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
message="数据集名称在1-128个字符之间")
])
paragraphs = serializers.ListField(required=False,
child=serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="段落在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="段落在1-256个字符之间")
]))
paragraphs = ParagraphInstanceSerializer(required=False, many=True)
def is_valid(self, *, dataset_id=None, raise_exception=False):
if not QuerySet(DataSet).filter(id=dataset_id).exists():
raise AppApiException(10000, "数据集id不存在")
return super().is_valid(raise_exception=True)
def save(self, dataset_id: str, **kwargs):
document_model = Document(
**{'dataset': DataSet(id=dataset_id),
'id': uuid.uuid1(),
'name': self.data.get('name'),
'char_length': reduce(lambda x, y: x + y, list(map(lambda p: len(p), self.data.get("paragraphs"))), 0)})
paragraph_model_list = list(map(lambda p: Paragraph(
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
self.data.get('paragraphs')))
# 插入文档
document_model.save()
# 插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
return True
def get_request_body_api(self):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['name', 'paragraph'],
required=['name', 'paragraphs'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
items=openapi.Schema(type=openapi.TYPE_STRING, title="段落数据",
description="段落数据"))
items=ParagraphSerializers.Create.get_request_body_api())
}
)
def get_request_params_api(self):
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
class DocumentSerializers(ApiMixin, serializers.Serializer):
class Query(ApiMixin, serializers.Serializer):
# 数据集id
dataset_id = serializers.UUIDField(required=True)
name = serializers.CharField(required=False,
validators=[
validators.MaxLengthValidator(limit_value=128,
message="文档名称在1-128个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="数据集名称在1-128个字符之间")
])
def get_query_set(self):
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
if 'name' in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'name__contains': self.data.get('name')})
return query_set
def list(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
query_set = self.get_query_set()
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
def page(self, current_page, page_size):
query_set = self.get_query_set()
return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='name',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档名称')]
@staticmethod
def get_response_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY,
title="文档列表", description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
]
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
document_id = self.data.get('document_id')
if not QuerySet(Document).filter(id=document_id).exists():
raise AppApiException(500, "文档id不存在")
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'id': self.data.get("document_id")})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)
def edit(self, instance: Dict, with_valid=False):
if with_valid:
self.is_valid()
_document = QuerySet(Document).get(id=self.data.get("document_id"))
update_keys = ['name', 'is_active']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
_document.save()
return self.one()
@transaction.atomic
def delete(self):
document_id = self.data.get("document_id")
QuerySet(model=Document).filter(id=document_id).delete()
# 删除段落
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
# 删除问题
QuerySet(model=Problem).filter(document_id=document_id).delete()
# 删除向量库
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
return True
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active'
'update_time', 'create_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试数据集"),
'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
description="字符数", default=10),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
description="文档数量", default=1),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
description="是否可用", default=True),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
}
)
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
raise AppApiException(10000, "数据集id不存在")
return True
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
if with_valid:
DocumentInstanceSerializer(data=instance).is_valid()
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
document_model = Document(
**{'dataset_id': dataset_id,
'id': uuid.uuid1(),
'name': instance.get('name'),
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0)})
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
with_valid=True,
with_embedding=False)
# 插入文档
document_model.save()
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
with_valid=True)
@staticmethod
def get_request_body_api():
return DocumentInstanceSerializer.get_request_body_api()
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='数据集id')
]
class Split(ApiMixin, serializers.Serializer):
file = serializers.ListField(required=True)
limit = serializers.IntegerField(required=False)
patterns = serializers.ListField(required=False,
child=serializers.CharField(required=True))
with_filter = serializers.BooleanField(required=False)
def is_valid(self, *, raise_exception=True):
super().is_valid()
files = self.data.get('file')
for f in files:
if f.size > 1024 * 1024 * 10:
raise AppApiException(500, "上传文件最大不能超过10m")
@staticmethod
def get_request_params_api():
return [
openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_ARRAY,
items=openapi.Items(type=openapi.TYPE_FILE),
required=True,
description='数据集id')]
description='上传文件'),
openapi.Parameter(name='limit',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
openapi.Parameter(name='patterns',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
title="分段正则列表", description="分段正则列表"),
openapi.Parameter(name='with_filter',
in_=openapi.IN_FORM,
required=False,
type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
]
def parse(self):
file_list = self.data.get("file")
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
self.data.get("limit")), file_list))
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
    """
    Split one uploaded file into paragraphs.

    :param file: uploaded file object (must expose ``read()`` and ``name``)
    :param pattern_list: custom split regex list; when empty or ``None`` the
        split model is inferred from the file name instead
    :param with_filter: whether to strip special characters while splitting
    :param limit: maximum paragraph length
    :return: ``{'name': ..., 'content': [...]}``; ``content`` is an empty list
        when the file cannot be decoded as UTF-8 text
    """
    data = file.read()
    # Bug fix: the original condition was `pattern_list is None or len(...) > 0`,
    # which passed None into SplitModel whenever no patterns were supplied.
    # Only build a custom SplitModel when patterns are actually provided.
    if pattern_list is not None and len(pattern_list) > 0:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name)
    try:
        content = data.decode('utf-8')
    except BaseException:
        # Best effort: binary / non-UTF-8 uploads yield no paragraphs
        # instead of failing the whole request.
        return {'name': file.name,
                'content': []}
    return {'name': file.name,
            'content': split_model.parse(content)
            }

View File

@ -0,0 +1,278 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file paragraph_serializers.py
@date2023/10/16 15:51
@desc:
"""
import uuid
from typing import Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Paragraph, Problem
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
class ParagraphSerializer(serializers.ModelSerializer):
    """Read-only model serializer rendering a Paragraph row in API responses."""

    class Meta:
        # Exposes the paragraph's content/stats fields plus timestamps.
        model = Paragraph
        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
                  'create_time', 'update_time']
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Paragraph instance payload: the request body for creating/updating a
    single paragraph (content, optional title, optional problem list).
    """
    # Paragraph body text, 1-1024 characters.
    content = serializers.CharField(required=True, validators=[
        validators.MaxLengthValidator(limit_value=1024,
                                      message="段落在1-1024个字符之间"),
        validators.MinLengthValidator(limit_value=1,
                                      message="段落在1-1024个字符之间")
    ])
    # Optional paragraph title.
    title = serializers.CharField(required=False)
    # Optional nested list of problems (questions) attached to this paragraph.
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    # Whether the paragraph is active (participates in retrieval).
    is_active = serializers.BooleanField(required=False)

    @staticmethod
    def get_request_body_api():
        # OpenAPI schema describing this payload for swagger documentation.
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="分段内容", description="分段内容"),
                'title': openapi.Schema(type=openapi.TYPE_STRING, title="分段标题",
                                        description="分段标题"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                               description="问题列表",
                                               items=ProblemInstanceSerializer.get_request_body_api())
            }
        )
class ParagraphSerializers(ApiMixin, serializers.Serializer):
    """
    Paragraph API serializers: ``Operate`` (get/edit/delete one paragraph),
    ``Create`` (insert a paragraph with its problems) and ``Query`` (list/page).
    """

    class Operate(ApiMixin, serializers.Serializer):
        # paragraph id
        paragraph_id = serializers.UUIDField(required=True)
        # dataset id
        dataset_id = serializers.UUIDField(required=True)
        # document id
        document_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=True):
            """Validate the ids and check the paragraph actually exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, "段落id不存在")

        @transaction.atomic
        def edit(self, instance: Dict):
            """
            Update paragraph fields and synchronise its problem list.

            ``instance`` may carry 'title'/'content'/'is_active' and a
            'problem_list' where rows with an id are updates and rows without
            are inserts; DB problems missing from the list are deleted.
            :return: the updated paragraph (same shape as ``one``)
            """
            self.is_valid()
            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
            update_keys = ['title', 'content', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    _paragraph.__setattr__(update_key, instance.get(update_key))
            if 'problem_list' in instance:
                update_problem_list = list(
                    filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
                create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
                # existing problems of this paragraph
                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
                # reject ids from the client that do not belong to this paragraph
                for update_problem in update_problem_list:
                    if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
                        raise AppApiException(500, update_problem.get('id') + '问题id不存在')
                # problems present in DB but absent from the update list -> delete
                delete_problem_list = list(filter(
                    lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
                        str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
                # delete removed problems
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
                    delete_problem_list) > 0 else None
                # insert new problems
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
                             dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id')) for
                     p in create_problem_list]) if len(create_problem_list) else None
                # update content of existing problems
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content']) if len(
                    update_problem_list) > 0 else None
            _paragraph.save()
            if 'is_active' in instance and instance.get('is_active') is not None:
                # enable/disable the paragraph's embeddings to match is_active
                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
                s.send(self.data.get('paragraph_id'))
            return self.one()

        def get_problem_list(self):
            """Return the serialized problems attached to this paragraph."""
            return [ProblemSerializer(problem).data for problem in
                    QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]

        def one(self, with_valid=False):
            """Return one paragraph with its problem list."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                    'problem_list': self.get_problem_list()}

        # Consistency fix: DocumentSerializers.Operate.delete runs atomically;
        # the paragraph + problem deletes here should not be partially applied.
        @transaction.atomic
        def delete(self, with_valid=False):
            """Delete the paragraph, its problems and its embeddings."""
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id = self.data.get('paragraph_id')
            QuerySet(Paragraph).filter(id=paragraph_id).delete()
            QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
            ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_response_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
                                      description="段落id")]

    class Create(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """
            Insert a paragraph (and its problems) under dataset/document.

            :param instance: paragraph payload (ParagraphInstanceSerializer shape)
            :param with_valid: validate payload and ids before saving
            :param with_embedding: emit the embedding signal after insert
            :return: the created paragraph (same shape as Operate.one)
            """
            if with_valid:
                # Bug fix: the payload was validated with ParagraphSerializers,
                # which declares no fields, so validation was a no-op. Validate
                # against the actual instance serializer, and raise on bad ids.
                ParagraphInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get("dataset_id")
            document_id = self.data.get('document_id')
            paragraph = Paragraph(id=uuid.uuid1(),
                                  document_id=document_id,
                                  content=instance.get("content"),
                                  dataset_id=dataset_id,
                                  title=instance.get("title") if 'title' in instance else '')
            # insert the paragraph
            paragraph.save()
            problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
                                          document_id=document_id, dataset_id=dataset_id) for problem in (
                                      instance.get('problem_list') if 'problem_list' in instance else [])]
            # insert the problems
            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
            if with_embedding:
                ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
            return ParagraphSerializers.Operate(
                data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description="文档id")
                    ]

    class Query(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        # optional fuzzy filter on the paragraph title
        title = serializers.CharField(required=False)

        def get_query_set(self):
            query_set = QuerySet(model=Paragraph)
            query_set = query_set.filter(
                **{'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get("document_id")})
            # Robustness fix (matches DocumentSerializers.Query): guard against
            # an explicit None title, which would otherwise build
            # title__contains=None and fail.
            if 'title' in self.data and self.data.get('title') is not None:
                query_set = query_set.filter(
                    **{'title__contains': self.data.get('title')})
            return query_set

        def list(self):
            """Return all matching paragraphs (unpaged)."""
            return list(map(lambda row: ParagraphSerializer(row).data, self.get_query_set()))

        def page(self, current_page, page_size):
            """Return one page of matching paragraphs."""
            query_set = self.get_query_set()
            return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id'),
                    openapi.Parameter(name='title',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='标题')
                    ]

        @staticmethod
        def get_response_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                              description="段落内容", default='段落内容'),
                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                            description="标题", default="xxx的描述"),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

View File

@ -0,0 +1,222 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file problem_serializers.py
@date2023/10/23 13:55
@desc:
"""
import uuid
from typing import Dict
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Problem, Paragraph
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer):
    """Read-only model serializer rendering a Problem row in API responses."""

    class Meta:
        # Exposes the problem content/stats plus owning dataset/document ids.
        model = Problem
        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
                  'create_time', 'update_time']
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Problem instance payload: request body for one problem (question).
    ``id`` is supplied only on update; omitted on create.
    """
    # Problem id — present only when updating an existing problem.
    id = serializers.CharField(required=False)
    # Problem text.
    content = serializers.CharField(required=True)

    @staticmethod
    def get_request_body_api():
        # OpenAPI schema describing this payload for swagger documentation.
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties={
                                  'id': openapi.Schema(
                                      type=openapi.TYPE_STRING,
                                      title="问题id,修改的时候传递,创建的时候不传"),
                                  'content': openapi.Schema(
                                      type=openapi.TYPE_STRING, title="内容")
                              })
class ProblemSerializers(ApiMixin, serializers.Serializer):
    """
    Problem API serializers: ``Create`` (insert one problem), ``Query``
    (list problems of a paragraph) and ``Operate`` (get/delete one problem).
    """

    class Create(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """
            Insert one problem under dataset/document/paragraph.

            :param instance: problem payload (ProblemInstanceSerializer shape)
            :param with_valid: validate payload and ids before saving
            :param with_embedding: emit the embedding signal after insert
            :return: the created problem (same shape as Operate.one)
            """
            if with_valid:
                # Consistency fix: raise on invalid path ids instead of
                # silently continuing (matches the other Create serializers).
                self.is_valid(raise_exception=True)
                ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
            problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
                              document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
                              content=instance.get('content'))
            problem.save()
            if with_embedding:
                # hand the new problem to the embedding pipeline
                ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                                     'is_active': True,
                                                                     'source_type': SourceType.PROBLEM,
                                                                     'source_id': problem.id,
                                                                     'document_id': self.data.get('document_id'),
                                                                     'paragraph_id': self.data.get('paragraph_id'),
                                                                     'dataset_id': self.data.get('dataset_id')})
            return ProblemSerializers.Operate(
                data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
                      'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)

        @staticmethod
        def get_request_body_api():
            return ProblemInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id'),
                    openapi.Parameter(name='paragraph_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='段落id')]

    class Query(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=True):
            """Validate the ids and check the paragraph actually exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, "段落id不存在")

        def get_query_set(self):
            dataset_id = self.data.get('dataset_id')
            document_id = self.data.get('document_id')
            paragraph_id = self.data.get("paragraph_id")
            return QuerySet(Problem).filter(
                **{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})

        def list(self, with_valid=False):
            """
            获取问题列表
            :param with_valid: 是否校验
            :return: 问题列表
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = self.get_query_set()
            return [ProblemSerializer(p).data for p in query_set]

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                , openapi.Parameter(name='paragraph_id',
                                    in_=openapi.IN_PATH,
                                    type=openapi.TYPE_STRING,
                                    required=True,
                                    description='段落id')]

    class Operate(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)
        document_id = serializers.UUIDField(required=True)
        paragraph_id = serializers.UUIDField(required=True)
        problem_id = serializers.UUIDField(required=True)

        def delete(self, with_valid=False):
            """
            Delete the problem row and its embedding.

            :return: True on completion
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
            # Fix: the original additionally called
            # PGVector().delete_by_source_id(problem_id, SourceType.PROBLEM),
            # hard-coding the pg_vector backend and deleting the same embedding
            # twice. The delete_embedding_by_source signal below already removes
            # the embedding through the configured vector store.
            ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
            return True

        def one(self, with_valid=False):
            """Return one problem in the instance-payload shape (id/content)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                , openapi.Parameter(name='paragraph_id',
                                    in_=openapi.IN_PATH,
                                    type=openapi.TYPE_STRING,
                                    required=True,
                                    description='段落id'),
                    openapi.Parameter(name='problem_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='问题id')
                    ]

        @staticmethod
        def get_response_body_api():
            # NOTE(review): documents the full Problem row, but `one` above
            # returns only id/content via ProblemInstanceSerializer — confirm
            # which shape the API is meant to expose.
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
                          'document_id',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                              description="问题内容", default='问题内容'),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

View File

@ -0,0 +1,5 @@
-- List documents together with a per-document paragraph count
-- (correlated subquery against the paragraph table).
-- Loaded from file and executed via native_search; presumably the caller
-- appends filter conditions from the QuerySet — verify before adding
-- trailing clauses here.
SELECT
"document".* ,
(SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
FROM
"document" "document"

View File

@ -0,0 +1,10 @@
-- List problems with their per-paragraph hit/star/trample counters taken from
-- problem_paragraph_mapping (LEFT JOIN, so unmapped problems still appear with
-- NULL counters).
-- NOTE(review): the serializers in this commit read hit_num etc. directly off
-- the problem model; confirm the mapping table is the intended source.
SELECT
problem."id",
problem."content",
problem_paragraph_mapping.hit_num,
problem_paragraph_mapping.star_num,
problem_paragraph_mapping.trample_num,
problem_paragraph_mapping.paragraph_id
FROM
problem problem
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id

View File

@ -7,5 +7,18 @@ urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document')
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"),
path('dataset/document/split', views.Document.Split.as_view(),
name="document_operate"),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
views.Paragraph.Operate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
views.Problem.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
views.Problem.Operate.as_view())
]

View File

@ -8,3 +8,5 @@
"""
from .dataset import *
from .document import *
from .paragraph import *
from .problem import *

View File

@ -26,7 +26,7 @@ class Dataset(APIView):
@swagger_auto_schema(operation_summary="获取数据集列表",
operation_id="获取数据集列表",
manual_parameters=DataSetSerializers.Query.get_request_params_api(),
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request):
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
@ -36,19 +36,21 @@ class Dataset(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建数据集",
operation_id="创建数据集",
request_body=DataSetSerializers.Create.get_request_body_api())
request_body=DataSetSerializers.Create.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request):
s = DataSetSerializers.Create(data=request.data)
if s.is_valid():
s.save(request.user)
return result.success("ok")
s.is_valid(raise_exception=True)
return result.success(s.save(request.user))
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods="DELETE", detail=False)
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集")
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=result.get_default_response())
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
@ -59,6 +61,7 @@ class Dataset(APIView):
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
@ -67,6 +70,7 @@ class Dataset(APIView):
@action(methods="PUT", detail=False)
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
request_body=DataSetSerializers.Operate.get_request_body_api(),
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
@ -84,8 +88,10 @@ class Dataset(APIView):
manual_parameters=get_page_request_params(
DataSetSerializers.Query.get_request_params_api()),
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
d = DataSetSerializers.Query(
data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
'user_id': str(request.user.id)})
d.is_valid()
return result.success(d.page(current_page, page_size))

View File

@ -9,13 +9,15 @@
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
from common.response import result
from dataset.serializers.dataset_serializers import CreateDocumentSerializers
from common.util.common import query_params_to_single_dict
from dataset.serializers.document_serializers import DocumentSerializers
class Document(APIView):
@ -24,28 +26,102 @@ class Document(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建文档",
operation_id="创建文档",
request_body=CreateDocumentSerializers().get_request_body_api(),
manual_parameters=CreateDocumentSerializers().get_request_params_api())
@has_permissions(PermissionConstants.DATASET_CREATE)
request_body=DocumentSerializers.Create.get_request_body_api(),
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
d = CreateDocumentSerializers(data=request.data)
if d.is_valid(dataset_id=dataset_id):
d.save(dataset_id)
return result.success("ok")
class DocumentDetails(APIView):
authentication_classes = [TokenAuth]
return result.success(
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(request.data, with_valid=True))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取文档详情",
operation_id="获取文档详情",
request_body=CreateDocumentSerializers().get_request_body_api(),
manual_parameters=CreateDocumentSerializers().get_request_params_api())
@swagger_auto_schema(operation_summary="文档列表",
operation_id="文档列表",
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()))
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id')))
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
d = CreateDocumentSerializers(data=request.data)
if d.is_valid(dataset_id=dataset_id):
d.save(dataset_id)
return result.success("ok")
d = DocumentSerializers.Query(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
d.is_valid(raise_exception=True)
return result.success(d.list())
class Operate(APIView):
    """Single-document operations: retrieve, update and delete one document
    inside a dataset.

    NOTE(review): this class is named ``Operate`` while the permission lambdas
    below reference the imported ``Operate`` enum (``Operate.USE`` /
    ``Operate.MANAGE``). That only resolves correctly if this class is nested
    inside an outer view class (so the module-level name is not shadowed) —
    confirm against the full file.
    """
    authentication_classes = [TokenAuth]
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取文档详情",
                         operation_id="获取文档详情",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str):
        # Validate path parameters via the serializer, then return one record.
        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
        operate.is_valid(raise_exception=True)
        return result.success(operate.one())
    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary="修改文档",
                         operation_id="修改文档",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         request_body=DocumentSerializers.Operate.get_request_body_api(),
                         responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
                         )
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def put(self, request: Request, dataset_id: str, document_id: str):
        # edit(..., with_valid=True) validates the serializer internally.
        return result.success(
            DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).edit(
                request.data,
                with_valid=True))
    @action(methods=['DELETE'], detail=False)
    @swagger_auto_schema(operation_summary="删除文档",
                         operation_id="删除文档",
                         manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
                         responses=result.get_default_response())
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def delete(self, request: Request, dataset_id: str, document_id: str):
        # Validate path parameters first, then delete the document.
        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
        operate.is_valid(raise_exception=True)
        return result.success(operate.delete())
class Split(APIView):
    """Upload one or more files and split them into paragraphs by the given
    patterns, without persisting anything.

    NOTE(review): unlike the sibling views, this class declares no
    ``authentication_classes`` and no ``@has_permissions`` guard — the endpoint
    appears unauthenticated. Confirm whether that is intentional.
    """
    parser_classes = [MultiPartParser]
    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="分段文档",
                         operation_id="分段文档",
                         manual_parameters=DocumentSerializers.Split.get_request_params_api())
    def post(self, request: Request):
        # 'file' may carry several uploads; 'patterns[]' are the split regexes.
        ds = DocumentSerializers.Split(
            data={'file': request.FILES.getlist('file'),
                  'patterns': request.data.getlist('patterns[]')})
        ds.is_valid(raise_exception=True)
        return result.success(ds.parse())
class Page(APIView):
    """Paginated listing of the documents inside a dataset.

    Fix: the swagger summary/id previously said "获取数据集分页列表"
    (dataset page list) — copied from the dataset view — but this endpoint
    paginates documents, so it now says "获取文档分页列表".
    """
    authentication_classes = [TokenAuth]
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取文档分页列表",
                         operation_id="获取文档分页列表",
                         manual_parameters=DocumentSerializers.Query.get_request_params_api(),
                         responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, current_page, page_size):
        # Merge query-string filters with the dataset id, validate, paginate.
        d = DocumentSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
        d.is_valid(raise_exception=True)
        return result.success(d.page(current_page, page_size))

View File

@ -0,0 +1,115 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file paragraph_serializers.py
@date2023/10/16 15:51
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.paragraph_serializers import ParagraphSerializers
class Paragraph(APIView):
    """Paragraph collection endpoints for one document: list and create."""
    authentication_classes = [TokenAuth]
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="段落列表",
                         operation_id="段落列表",
                         manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
                         responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str):
        # Combine query-string filters with the path ids, validate, list.
        q = ParagraphSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
                  'document_id': document_id})
        q.is_valid(raise_exception=True)
        return result.success(q.list())
    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建段落",
                         operation_id="创建段落",
                         manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
                         request_body=ParagraphSerializers.Create.get_request_body_api(),
                         responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str, document_id: str):
        # NOTE(review): save() is called without with_valid=True, unlike the
        # document Create endpoint — presumably Create.save validates
        # internally; confirm against the serializer.
        return result.success(
            ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
class Operate(APIView):
    """Single-paragraph operations: retrieve, update and delete.

    Fix: both the ``put`` and ``get`` handlers were declared with
    ``@action(methods=['UPDATE'])`` — 'UPDATE' is not an HTTP method. They now
    declare ``['PUT']`` and ``['GET']`` respectively, matching the handler
    names and the swagger summaries.

    NOTE(review): the permission lambdas reference the ``Operate`` enum, which
    this class name shadows at module level — this only works if the class is
    nested inside an outer view class; confirm against the full file.
    """
    authentication_classes = [TokenAuth]
    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(operation_summary="修改段落数据",
                         operation_id="修改段落数据",
                         manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                         request_body=ParagraphSerializers.Operate.get_request_body_api(),
                         responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        # Validate path parameters via the serializer, then apply the edit.
        o = ParagraphSerializers.Operate(
            data={"paragraph_id": paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
        o.is_valid(raise_exception=True)
        return result.success(o.edit(request.data))
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取段落详情",
                         operation_id="获取段落详情",
                         manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                         responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        # Validate path parameters, then return the single paragraph.
        o = ParagraphSerializers.Operate(
            data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
        o.is_valid(raise_exception=True)
        return result.success(o.one())
    @action(methods=['DELETE'], detail=False)
    @swagger_auto_schema(operation_summary="删除段落",
                         operation_id="删除段落",
                         manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
                         responses=result.get_default_response())
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        # Validate path parameters, then delete the paragraph.
        o = ParagraphSerializers.Operate(
            data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
        o.is_valid(raise_exception=True)
        return result.success(o.delete())
class Page(APIView):
    """Paginated listing of the paragraphs inside one document."""
    authentication_classes = [TokenAuth]
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="分页获取段落列表",
                         operation_id="分页获取段落列表",
                         manual_parameters=result.get_page_request_params(
                             ParagraphSerializers.Query.get_request_params_api()),
                         responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str, current_page, page_size):
        # Merge query-string filters with the path ids, validate, paginate.
        d = ParagraphSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
                  'document_id': document_id})
        d.is_valid(raise_exception=True)
        return result.success(d.page(current_page, page_size))

View File

@ -0,0 +1,65 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file problem.py
@date2023/10/23 13:54
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from dataset.serializers.problem_serializers import ProblemSerializers
class Problem(APIView):
    """Endpoints for the questions attached to one paragraph: create and list."""
    authentication_classes = [TokenAuth]
    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="添加关联问题",
                         operation_id="添加段落关联问题",
                         manual_parameters=ProblemSerializers.Create.get_request_params_api(),
                         request_body=ProblemSerializers.Create.get_request_body_api(),
                         responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        # with_valid=True makes the serializer validate inside save().
        return result.success(ProblemSerializers.Create(
            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
            request.data, with_valid=True))
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取段落问题列表",
                         operation_id="获取段落问题列表",
                         manual_parameters=ProblemSerializers.Query.get_request_params_api(),
                         responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
                                dynamic_tag=k.get('dataset_id')))
    def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
        # with_valid=True makes the serializer validate inside list().
        return result.success(ProblemSerializers.Query(
            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
            with_valid=True))
class Operate(APIView):
    """Single-question operations: delete one question from a paragraph."""
    authentication_classes = [TokenAuth]
    @action(methods=['DELETE'], detail=False)
    @swagger_auto_schema(operation_summary="删除段落问题",
                         operation_id="删除段落问题",
                         # NOTE(review): Query params are reused here for a
                         # delete-by-id endpoint — presumably Operate has its
                         # own get_request_params_api; confirm.
                         manual_parameters=ProblemSerializers.Query.get_request_params_api(),
                         responses=result.get_default_response())
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
        # is_valid is not called here; delete(with_valid=True) presumably
        # validates internally — verify against the serializer.
        o = ProblemSerializers.Operate(
            data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
                  'problem_id': problem_id})
        return result.success(o.delete(with_valid=True))

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33
# Generated by Django 4.1.10 on 2023-10-24 12:13
import common.field.vector_field
from django.db import migrations, models
@ -20,8 +20,11 @@ class Migration(migrations.Migration):
('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')),
('source_id', models.CharField(max_length=128, verbose_name='资源id')),
('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
('is_active', models.BooleanField(default=True, max_length=1, verbose_name='是否可用')),
('embedding', common.field.vector_field.VectorField(verbose_name='向量')),
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')),
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='文档关联')),
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document', verbose_name='文档关联')),
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落关联')),
],
options={
'db_table': 'embedding',

View File

@ -9,7 +9,7 @@
from django.db import models
from common.field.vector_field import VectorField
from dataset.models.data_set import DataSet
from dataset.models.data_set import Document, Paragraph, DataSet
class SourceType(models.TextChoices):
@ -26,7 +26,13 @@ class Embedding(models.Model):
source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
default=SourceType.PROBLEM)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联")
is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)
embedding = VectorField(verbose_name="向量")

View File

@ -0,0 +1,117 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_vector.py
@date2023/10/18 19:16
@desc:
"""
from abc import ABC, abstractmethod
from typing import List, Dict
from langchain.embeddings import HuggingFaceEmbeddings
from common.config.embedding_config import EmbeddingModel
from embedding.models import SourceType
class BaseVectorStore(ABC):
    """Abstract base for vector stores: lazy store creation, single and batch
    insertion, search, updates and deletions keyed by dataset / document /
    paragraph / source.

    Fix: ``save()`` documented ``:return: bool`` but returned ``None`` while
    ``batch_save()`` returned ``True`` — ``save()`` now propagates the
    ``_save`` result so both honor the documented contract.
    """
    # Process-wide flag: flipped once the store is known to exist so
    # save_pre_handler() probes/creates it only on first use.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """Return True when the underlying vector store already exists."""
        pass

    @abstractmethod
    def vector_create(self):
        """Create the vector store."""
        pass

    def save_pre_handler(self):
        """Pre-insert hook: lazily create the vector store on first use.

        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
             is_active: bool,
             embedding=None):
        """Embed *text* and insert one vector row.

        :param text: text to embed
        :param source_type: source type (problem / paragraph)
        :param dataset_id: dataset id
        :param document_id: document id
        :param paragraph_id: paragraph id
        :param source_id: source id
        :param is_active: whether the row is enabled
        :param embedding: embedding encoder; defaults to the shared model
        :return: bool
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        # Propagate the backend result so the documented bool return holds.
        return self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)

    def batch_save(self, data_list: List[Dict], embedding=None):
        """Embed and insert many rows at once.

        :param data_list: list of row dicts (text, ids, source info)
        :param embedding: embedding encoder; defaults to the shared model
        :return: bool
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        self._batch_save(data_list, embedding)
        return True

    @abstractmethod
    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        pass

    @abstractmethod
    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        pass

    @abstractmethod
    def delete_by_dataset_id(self, dataset_id: str):
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        pass

View File

@ -0,0 +1,79 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pg_vector.py
@date2023/10/19 15:28
@desc:
"""
import uuid
from typing import Dict, List
from django.db.models import QuerySet
from langchain.embeddings import HuggingFaceEmbeddings
from embedding.models import Embedding, SourceType
from embedding.vector.base_vector import BaseVectorStore
class PGVector(BaseVectorStore):
    """pgvector-backed store persisting embeddings through the Django ORM.

    Fixes: ``_save`` no longer shadows the ``embedding`` encoder parameter
    with the ``Embedding`` model instance; ``_batch_save`` uses a guard +
    comprehension instead of a conditional expression as a statement and
    ``zip`` instead of ``range(len(...))``; the delete methods consistently
    return True.
    """

    def vector_is_create(self) -> bool:
        # The table is created by Django migrations at startup; nothing to probe.
        return True

    def vector_create(self):
        return True

    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        text_embedding = embedding.embed_query(text)
        record = Embedding(id=uuid.uuid1(),
                           dataset_id=dataset_id,
                           document_id=document_id,
                           is_active=is_active,
                           paragraph_id=paragraph_id,
                           source_id=source_id,
                           embedding=text_embedding,
                           source_type=source_type,
                           )
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        if not text_list:
            return True
        texts = [row.get('text') for row in text_list]
        # One encoder call for the whole batch, then a single bulk insert.
        embeddings = embedding.embed_documents(texts)
        records = [Embedding(id=uuid.uuid1(),
                             document_id=row.get('document_id'),
                             paragraph_id=row.get('paragraph_id'),
                             dataset_id=row.get('dataset_id'),
                             is_active=row.get('is_active', True),
                             source_id=row.get('source_id'),
                             source_type=row.get('source_type'),
                             embedding=vector)
                   for row, vector in zip(text_list, embeddings)]
        QuerySet(Embedding).bulk_create(records)
        return True

    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        # Not implemented yet in this revision.
        pass

    def update_by_source_id(self, source_id: str, instance: Dict):
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def delete_by_dataset_id(self, dataset_id: str):
        QuerySet(Embedding).filter(dataset_id=dataset_id).delete()
        return True

    def delete_by_document_id(self, document_id: str):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_source_id(self, source_id: str, source_type: str):
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
        return True

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33
# Generated by Django 4.1.10 on 2023-10-24 12:13
import django.contrib.postgres.fields
from django.db import migrations, models

View File

@ -13,7 +13,7 @@ import os
import re
from importlib import import_module
from urllib.parse import urljoin, urlparse
import torch.backends
import yaml
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -88,7 +88,14 @@ class Config(dict):
"EMAIL_HOST": "",
"EMAIL_PORT": 465,
"EMAIL_HOST_USER": "",
"EMAIL_HOST_PASSWORD": ""
"EMAIL_HOST_PASSWORD": "",
# 向量模型
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
"EMBEDDING_DEVICE": "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",
"EMBEDDING_MODEL_PATH": os.path.join(PROJECT_DIR, 'models'),
# 向量库配置
"VECTOR_STORE_NAME": 'pg_vector'
}
def get_db_setting(self) -> dict:
@ -120,6 +127,8 @@ class ConfigManager:
def __init__(self, root_path=None):
self.root_path = root_path
self.config = self.config_class()
for key in self.config_class.defaults:
self.config[key] = self.config_class.defaults[key]
def from_mapping(self, *mapping, **kwargs):
"""Updates the config like :meth:`update` ignoring items with non-upper

View File

@ -100,6 +100,11 @@ LOGGING = {
'level': LOG_LEVEL,
'propagate': False,
},
'sqlalchemy': {
'handlers': ['console', 'file', 'syslog'],
'level': LOG_LEVEL,
'propagate': False,
},
'django.db.backends': {
'handlers': ['console', 'file', 'syslog'],
'propagate': False,

View File

@ -14,3 +14,12 @@ from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
application = get_wsgi_application()
def post_handler():
    # Start background listeners once the WSGI application is built, then
    # fire the blinker signal that initializes the embedding model.
    # Import is local to avoid app-loading side effects at module import time.
    from common.event.listener_manage import ListenerManagement
    ListenerManagement().run()
    ListenerManagement.init_embedding_model_signal.send()
post_handler()

View File

@ -4,3 +4,4 @@ from django.apps import AppConfig
class UsersConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'users'

View File

@ -1,4 +1,4 @@
# Generated by Django 4.1.10 on 2023-10-09 06:33
# Generated by Django 4.1.10 on 2023-10-24 12:13
from django.db import migrations, models
import uuid

View File

@ -22,7 +22,7 @@ from common.constants.permission_constants import PermissionConstants, CompareCo
from common.response import result
from smartdoc.settings import JWT_AUTH
from users.models.user import User as UserModel
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, UserSerializer, CheckCodeSerializer, \
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
RePasswordSerializer, \
SendEmailSerializer, UserProfile

View File

@ -17,6 +17,12 @@ psycopg2-binary = "2.9.7"
jieba = "^0.42.1"
diskcache = "^5.6.3"
pillow = "9.5.0"
filetype = "^1.2.0"
chardet = "^5.2.0"
torch = "^2.1.0"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
[build-system]
requires = ["poetry-core"]