mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 数据集,文档,段落,问题,向量化接口
This commit is contained in:
parent
64e93679f8
commit
a2de9691fb
|
|
@ -162,6 +162,7 @@ cython_debug/
|
|||
ui/node_modules
|
||||
ui/dist
|
||||
apps/static
|
||||
models/
|
||||
data
|
||||
.idea
|
||||
.dev
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-09 06:33
|
||||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: embedding_config.py
|
||||
@date:2023/10/23 16:03
|
||||
@desc:
|
||||
"""
|
||||
import types
|
||||
from smartdoc.const import CONFIG
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
|
||||
class EmbeddingModel:
    """Process-wide lazy singleton holder for the HuggingFace embedding model."""

    # Cached HuggingFaceEmbeddings instance; built on first use.
    instance = None

    @staticmethod
    def get_embedding_model():
        """
        Return the shared embedding model, creating it on first call.

        Model name, cache folder and device are read from CONFIG
        (EMBEDDING_MODEL_NAME / EMBEDDING_MODEL_PATH / EMBEDDING_DEVICE).
        NOTE(review): lazy init is not thread-safe — two threads hitting
        this simultaneously may build the model twice; confirm first use
        is serialized (e.g. via the init_embedding_model signal).
        :return: the shared HuggingFaceEmbeddings instance
        """
        if EmbeddingModel.instance is None:
            model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
            cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
            device = CONFIG.get('EMBEDDING_DEVICE')
            e = HuggingFaceEmbeddings(
                model_name=model_name,
                cache_folder=cache_folder,
                model_kwargs={'device': device})
            EmbeddingModel.instance = e
        return EmbeddingModel.instance
|
||||
|
||||
|
||||
class VectorStore:
    # Imported in the class body so the registry below can reference them.
    from embedding.vector.pg_vector import PGVector
    from embedding.vector.base_vector import BaseVectorStore
    # Vector-store implementations keyed by CONFIG's VECTOR_STORE_NAME.
    instance_map = {
        'pg_vector': PGVector,
    }
    # Cached store instance; created lazily by get_embedding_vector().
    instance = None

    @staticmethod
    def get_embedding_vector() -> BaseVectorStore:
        # Local re-import is required: class-body names are not visible
        # from inside a method, and PGVector serves as the fallback below.
        from embedding.vector.pg_vector import PGVector
        if VectorStore.instance is None:
            from smartdoc.const import CONFIG
            # Unknown/unset VECTOR_STORE_NAME falls back to PGVector.
            vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
                                                              PGVector)
            VectorStore.instance = vector_store_class()
        return VectorStore.instance
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
|
|
|
|||
|
|
@ -0,0 +1,162 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: listener_manage.py
|
||||
@date:2023/10/20 14:01
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import django.db.models
|
||||
from blinker import signal
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.db.search import native_search, get_dynamics_model
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph, Status, Document
|
||||
from embedding.models import SourceType
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def poxy(poxy_function):
    """
    Decorator (name presumably a typo of "proxy" — confirm before renaming)
    that runs the wrapped single-argument function asynchronously on
    ListenerManagement's worker thread pool.

    The returned wrapper discards the Future, so return values are lost and
    exceptions surface only inside the pool.
    """
    def inner(args):
        ListenerManagement.work_thread_pool.submit(poxy_function, args)

    return inner
|
||||
|
||||
|
||||
class ListenerManagement:
    """
    Central registry of blinker signals for embedding maintenance.

    Handlers create, delete, enable or disable vectors in the vector store
    in response to dataset / document / paragraph / problem events.
    @poxy-decorated handlers run asynchronously on the shared thread pool;
    the others run synchronously in the sender's thread.
    """
    # Worker pool shared by all @poxy-decorated handlers.
    work_thread_pool = ThreadPoolExecutor(5)

    # Signals other modules send to request embedding work.
    embedding_by_problem_signal = signal("embedding_by_problem")
    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
    embedding_by_dataset_signal = signal("embedding_by_dataset")
    embedding_by_document_signal = signal("embedding_by_document")
    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
    init_embedding_model_signal = signal('init_embedding_model')

    @staticmethod
    def embedding_by_problem(args):
        # Store a single problem's vector; args are forwarded to save() as kwargs.
        VectorStore.get_embedding_vector().save(**args)

    @staticmethod
    @poxy
    def embedding_by_paragraph(paragraph_id):
        """
        Vectorize one paragraph (and its attached problems) by paragraph id.

        :param paragraph_id: paragraph id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
                **{'problem.paragraph_id': paragraph_id}),
                'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize the collected texts.
        VectorStore.get_embedding_vector().batch_save(data_list)
        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_document(document_id):
        """
        Vectorize a whole document.

        :param document_id: document id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
                **{'problem.document_id': document_id}),
                'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize the collected texts.
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark the document and all its paragraphs as finished.
        QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})

    @staticmethod
    @poxy
    def embedding_by_dataset(dataset_id):
        """
        Vectorize a whole dataset.

        :param dataset_id: dataset id
        :return: None
        """
        data_list = native_search(
            {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
                **{'problem.dataset_id': dataset_id}),
                'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
        # Batch-vectorize the collected texts.
        VectorStore.get_embedding_vector().batch_save(data_list)
        # Mark documents and paragraphs of the dataset as finished.
        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})

    @staticmethod
    def delete_embedding_by_document(document_id):
        # Remove all vectors belonging to a document.
        VectorStore.get_embedding_vector().delete_by_document_id(document_id)

    @staticmethod
    def delete_embedding_by_dataset(dataset_id):
        # Remove all vectors belonging to a dataset.
        VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)

    @staticmethod
    def delete_embedding_by_paragraph(paragraph_id):
        # Remove all vectors belonging to a paragraph.
        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

    @staticmethod
    def delete_embedding_by_source(source_id):
        # Remove the vector for one source; hard-coded to PROBLEM sources.
        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)

    @staticmethod
    def disable_embedding_by_paragraph(paragraph_id):
        # Soft-disable: mark the paragraph's vectors inactive.
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})

    @staticmethod
    def enable_embedding_by_paragraph(paragraph_id):
        # Re-enable the paragraph's vectors.
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    @poxy
    def init_embedding_model(ags):
        # Warm up the embedding model in the background (arg is unused;
        # parameter name 'ags' is presumably a typo of 'args').
        EmbeddingModel.get_embedding_model()

    def run(self):
        """Wire every signal to its handler. Call once at startup."""
        # Add vectors by problem id
        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
        # Add vectors by paragraph id
        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
        # Add vectors by dataset id
        ListenerManagement.embedding_by_dataset_signal.connect(
            self.embedding_by_dataset)
        # Add vectors by document id
        ListenerManagement.embedding_by_document_signal.connect(
            self.embedding_by_document)
        # Delete vectors by document
        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
        # Delete vectors by dataset id
        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
        # Delete vectors by paragraph id
        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
            self.delete_embedding_by_paragraph)
        # Delete vectors by source id
        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
        # Disable a paragraph's vectors
        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
        # Enable a paragraph's vectors
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
        # Initialize the embedding model
        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
|
||||
|
|
@ -14,7 +14,7 @@ from rest_framework.views import exception_handler
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.response import result
|
||||
|
||||
|
||||
import traceback
|
||||
def to_result(key, args, parent_key=None):
|
||||
"""
|
||||
将校验异常 args转换为统一数据
|
||||
|
|
@ -59,6 +59,7 @@ def handle_exception(exc, context):
|
|||
exception_class = exc.__class__
|
||||
# 先调用REST framework默认的异常处理方法获得标准错误响应对象
|
||||
response = exception_handler(exc, context)
|
||||
traceback.print_exc()
|
||||
# 在此处补充自定义的异常处理
|
||||
if issubclass(exception_class, ValidationError):
|
||||
return validation_error_to_result(exc)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,9 @@ def get_page_api_response(response_data_schema: openapi.Schema):
|
|||
title="总条数",
|
||||
default=1,
|
||||
description="数据总条数"),
|
||||
"records": response_data_schema,
|
||||
"records": openapi.Schema(
|
||||
type=openapi.TYPE_ARRAY,
|
||||
items=response_data_schema),
|
||||
"current": openapi.Schema(
|
||||
type=openapi.TYPE_INTEGER,
|
||||
title="当前页",
|
||||
|
|
@ -115,6 +117,36 @@ def get_api_response(response_data_schema: openapi.Schema):
|
|||
)})
|
||||
|
||||
|
||||
def get_default_response():
    """Unified API response schema whose data payload is a bare boolean."""
    boolean_schema = openapi.Schema(type=openapi.TYPE_BOOLEAN)
    return get_api_response(boolean_schema)
|
||||
|
||||
|
||||
def get_api_array_response(response_data_schema: openapi.Schema):
    """
    Build the unified API response schema (code/message/data envelope)
    whose ``data`` field is an ARRAY of *response_data_schema* items.

    :param response_data_schema: schema of a single array element
    :return: openapi.Responses describing the 200 case
    """
    return openapi.Responses(responses={200: openapi.Response(description="响应参数",
                                                              schema=openapi.Schema(
                                                                  type=openapi.TYPE_OBJECT,
                                                                  properties={
                                                                      'code': openapi.Schema(
                                                                          type=openapi.TYPE_INTEGER,
                                                                          title="响应码",
                                                                          default=200,
                                                                          description="成功:200 失败:其他"),
                                                                      "message": openapi.Schema(
                                                                          type=openapi.TYPE_STRING,
                                                                          title="提示",
                                                                          default='成功',
                                                                          description="错误提示"),
                                                                      "data": openapi.Schema(type=openapi.TYPE_ARRAY,
                                                                                             items=response_data_schema)

                                                                  }
                                                              ),
                                                              )})
|
||||
|
||||
|
||||
def success(data):
|
||||
"""
|
||||
获取一个成功的响应对象
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
-- Lists every text that needs a vector, shaped as
-- (source_id, document_id, paragraph_id, dataset_id, source_type, text, is_active).
-- source_type: 0 = problem row, 1 = paragraph row.
-- ${problem} / ${paragraph} are template placeholders replaced by generated
-- WHERE clauses before execution (see native_search callers).
SELECT
    problem."id" AS "source_id",
    problem.document_id AS document_id,
    problem.paragraph_id AS paragraph_id,
    problem.dataset_id AS dataset_id,
    0 AS source_type,
    problem."content" AS "text",
    -- a problem inherits its paragraph's active flag
    paragraph.is_active AS is_active
FROM
    problem problem
    LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
${problem}

UNION
SELECT
    paragraph."id" AS "source_id",
    paragraph.document_id AS document_id,
    paragraph."id" AS paragraph_id,
    paragraph.dataset_id AS dataset_id,
    1 AS source_type,
    paragraph."content" AS "text",
    paragraph.is_active AS is_active
FROM
    paragraph paragraph

${paragraph}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: common.py
|
||||
@date:2023/10/16 16:42
|
||||
@desc:
|
||||
"""
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def query_params_to_single_dict(query_params: Dict):
    """
    Collapse a multi-value query-parameter mapping into a plain dict of
    single values.

    For every (key, values) pair the first element ``values[0]`` is taken
    (assumes each value is indexable, e.g. a Django QueryDict list of
    strings — TODO confirm against callers); pairs whose first element is
    None are dropped.

    Fix: the original had a dead conditional
    (``x if isinstance(x, list) and len(x) > 0 else x`` — both branches
    identical) buried in a map/filter/reduce chain; this comprehension is
    behaviorally identical and readable.

    :param query_params: mapping of key -> indexable value collection
    :return: dict of key -> first value, entries with a None first value excluded
    """
    return {key: values[0] for key, values in query_params.items() if values[0] is not None}
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
@desc:
|
||||
"""
|
||||
import re
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
import jieba
|
||||
|
|
@ -25,7 +26,7 @@ def get_level_block(text, level_content_list, level_content_index):
|
|||
level_content_list) else None
|
||||
start_index = text.index(start_content)
|
||||
end_index = text.index(next_content) if next_content is not None else len(text)
|
||||
return text[start_index:end_index]
|
||||
return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "")
|
||||
|
||||
|
||||
def to_tree_obj(content, state='title'):
|
||||
|
|
@ -88,7 +89,7 @@ def to_paragraph(obj: dict):
|
|||
content = obj['content']
|
||||
return {"keywords": get_keyword(content),
|
||||
'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
|
||||
'content': content}
|
||||
'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
|
||||
|
||||
|
||||
def get_keyword(content: str):
|
||||
|
|
@ -109,13 +110,15 @@ def titles_to_paragraph(list_title: List[dict]):
|
|||
:return: 块段落
|
||||
"""
|
||||
if len(list_title) > 0:
|
||||
content = "\n".join(
|
||||
content = "\n,".join(
|
||||
list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))
|
||||
|
||||
return {'keywords': '',
|
||||
'parent_chain': list(
|
||||
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
|
||||
'content': content}
|
||||
'content': ",".join(list(
|
||||
map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
|
||||
list_title[0]['parent_chain']))) + content}
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -144,6 +147,15 @@ def to_block_paragraph(tree_data_list: List[dict]):
|
|||
return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
|
||||
|
||||
|
||||
def parse_title_level(text, content_level_pattern: List, index):
    """
    Match *text* against the heading patterns starting at *index*.

    Returns the matches of the first pattern (at or after *index*) that
    yields any result; an empty list when every remaining pattern misses.
    """
    for pattern_index in range(index, len(content_level_pattern)):
        matched = parse_level(text, content_level_pattern[pattern_index])
        if len(matched) > 0:
            return matched
    return []
|
||||
|
||||
|
||||
def parse_level(text, pattern: str):
|
||||
"""
|
||||
获取正则匹配到的文本
|
||||
|
|
@ -151,10 +163,17 @@ def parse_level(text, pattern: str):
|
|||
:param pattern: 正则
|
||||
:return: 符合正则的文本
|
||||
"""
|
||||
level_content_list = list(map(to_tree_obj, re.findall(pattern, text, flags=0)))
|
||||
level_content_list = list(map(to_tree_obj, re_findall(pattern, text)))
|
||||
return list(map(filter_special_symbol, level_content_list))
|
||||
|
||||
|
||||
def re_findall(pattern, text):
    """
    Like re.findall, but flattens tuple results (produced by multi-group
    patterns) into one flat list and drops None / empty-string entries.
    """
    raw_matches = re.findall(pattern, text, flags=0)
    flattened = []
    for match in raw_matches:
        if isinstance(match, tuple):
            flattened.extend(match)
        else:
            flattened.append(match)
    return [entry for entry in flattened if entry is not None and len(entry) > 0]
|
||||
|
||||
|
||||
def to_flat_obj(parent_chain: List[dict], content: str, state: str):
|
||||
"""
|
||||
将树形属性转换为扁平对象
|
||||
|
|
@ -194,10 +213,79 @@ def group_by(list_source: List, key):
|
|||
return result
|
||||
|
||||
|
||||
def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain):
    """
    Flatten a parsed title tree into paragraph objects.

    Fix: removed a leftover ``print(item)`` debug statement that wrote
    every visited node to stdout.

    :param result_tree: parsed tree; nodes carry 'state', 'content' and an
                        optional 'children' list
    :param result: pass [] — accumulator shared across recursive calls
    :param parent_chain: pass [] — ancestor title contents for recursion
    :return: List[{'title': 'xx', 'content': 'xx'}]
    """
    for item in result_tree:
        if item.get('state') == 'block':
            # Leaf content: the title is the space-joined ancestor titles.
            result.append({'title': " ".join(parent_chain), 'content': item.get("content")})
        children = item.get("children")
        if children is not None and len(children) > 0:
            # Descend with this node's content appended to the title chain.
            result_tree_to_paragraph(children, result, [*parent_chain, item.get('content')])
    return result
|
||||
|
||||
|
||||
def post_handler_paragraph(content: str, limit: int, with_filter: bool):
    """
    Re-split text into chunks of at most *limit* characters.

    Lines (split on '\\n') are greedily packed into a chunk until adding
    the next line would exceed the limit; the newline between packed lines
    is dropped.  Any chunk that is still over the limit (a single over-long
    line) is cut into fixed-width slices.

    :param with_filter: whether to run filter_special_char on each chunk
    :param content: text to split
    :param limit: maximum characters per chunk
    :return: list of chunks
    """
    chunk_list = []
    buffer = ''
    for line in content.split('\n'):
        if len(buffer + line) > limit:
            chunk_list.append(buffer)
            buffer = ''
        buffer = buffer + line
    if len(buffer) > 0:
        chunk_list.append(buffer)
    # Slice any remaining over-long chunk into pieces of at most `limit` chars.
    slice_pattern = "[\\S\\s]{1," + str(limit) + '}'
    pieces = []
    for chunk in chunk_list:
        pieces.extend(re.findall(slice_pattern, chunk))
    return [filter_special_char(piece) if with_filter else piece for piece in pieces]
|
||||
|
||||
|
||||
# Regex -> replacement table applied in dict order by filter_special_char.
# NOTE(review): the second rule ('\s+' -> ' ') also matches the '\n' left
# by the first rule, so newlines end up collapsed to single spaces —
# presumably intentional; confirm if line boundaries must survive filtering.
replace_map = {
    re.compile('\n+'): '\n',
    re.compile('\\s+'): ' ',
    re.compile('#+'): "",
    re.compile("\t+"): ''
}
|
||||
|
||||
|
||||
def filter_special_char(content: str):
    """
    Normalize special characters in *content* by applying every rule of
    the module-level ``replace_map`` regex table, in insertion order.

    :param content: raw text
    :return: filtered text
    """
    for pattern, replacement in replace_map.items():
        content = re.sub(pattern, replacement, content)
    return content
|
||||
|
||||
|
||||
class SplitModel:
|
||||
|
||||
def __init__(self, content_level_pattern):
|
||||
def __init__(self, content_level_pattern, with_filter=True, limit=1024):
|
||||
self.content_level_pattern = content_level_pattern
|
||||
self.with_filter = with_filter
|
||||
if limit is None or limit > 1024:
|
||||
limit = 1024
|
||||
if limit < 50:
|
||||
limit = 50
|
||||
self.limit = limit
|
||||
|
||||
def parse_to_tree(self, text: str, index=0):
|
||||
"""
|
||||
|
|
@ -208,23 +296,27 @@ class SplitModel:
|
|||
"""
|
||||
if len(self.content_level_pattern) == index:
|
||||
return
|
||||
level_content_list = parse_level(text, pattern=self.content_level_pattern[index])
|
||||
level_content_list = parse_title_level(text, self.content_level_pattern, index)
|
||||
for i in range(len(level_content_list)):
|
||||
block = get_level_block(text, level_content_list, i)
|
||||
children = self.parse_to_tree(text=block.replace(level_content_list[i]['content'][:-1], ""),
|
||||
children = self.parse_to_tree(text=block,
|
||||
index=index + 1)
|
||||
if children is not None and len(children) > 0:
|
||||
level_content_list[i]['children'] = children
|
||||
else:
|
||||
if len(block) > 0:
|
||||
level_content_list[i]['children'] = [to_tree_obj(block, 'block')]
|
||||
level_content_list[i]['children'] = list(
|
||||
map(lambda row: to_tree_obj(row, 'block'),
|
||||
post_handler_paragraph(block, with_filter=self.with_filter, limit=self.limit)))
|
||||
if len(level_content_list) > 0:
|
||||
end_index = text.index(level_content_list[0].get('content'))
|
||||
if end_index == 0:
|
||||
return level_content_list
|
||||
other_content = text[0:end_index]
|
||||
if len(other_content.strip()) > 0:
|
||||
level_content_list.append(to_tree_obj(other_content, 'block'))
|
||||
level_content_list = [*level_content_list, *list(
|
||||
map(lambda row: to_tree_obj(row, 'block'),
|
||||
post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))]
|
||||
return level_content_list
|
||||
|
||||
def parse(self, text: str):
|
||||
|
|
@ -234,4 +326,35 @@ class SplitModel:
|
|||
:return: 解析后数据 {content:段落数据,keywords:[‘段落关键词’],parent_chain:['段落父级链路']}
|
||||
"""
|
||||
result_tree = self.parse_to_tree(text, 0)
|
||||
return flat_map(to_block_paragraph(result_tree))
|
||||
return result_tree_to_paragraph(result_tree, [], [])
|
||||
|
||||
|
||||
# Registry of split models keyed by file type: 'md' splits on markdown
# heading levels (h1..h5) and then top-level bullet items; 'default'
# splits on blank-line-separated blocks.
# NOTE(review): the '#####' (h5) pattern appears twice — the second copy
# was presumably meant to be '######' (h6); confirm and fix upstream.
split_model_map = {
    'md': SplitModel(
        [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
         re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
         re.compile("(?<!#)#####(?!#).*"),
         re.compile("(?<! )- .*")]),
    'default': SplitModel([re.compile("(?<!\n)\n\n.+")])
}
|
||||
|
||||
|
||||
def get_split_model(filename: str):
    """
    Pick the paragraph-split model for a file by its extension.

    :param filename: name of the file
    :return: the 'md' model for *.md files, otherwise the 'default' model
    """
    key = 'md' if filename.endswith(".md") else "default"
    return split_model_map.get(key)
|
||||
|
||||
|
||||
def to_title_tree_string(result_tree: List):
    """Render the tree's title nodes as a '│'-connected outline string."""
    flattened = flat(result_tree, [], [])
    title_rows = [row for row in flattened if row.get('state') == 'title']
    return "\n│".join(title_tostring(row) for row in title_rows)
|
||||
|
||||
|
||||
def title_tostring(title_obj):
    """Render one title node: depth-proportional '│ '-joined padding,
    then a '├───' branch marker and the title content."""
    depth = len(title_obj.get("parent_chain"))
    indent = "│ ".join(" " for _ in range(depth))
    return indent + "├───" + title_obj.get('content')
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-09 06:33
|
||||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
|
@ -36,6 +36,7 @@ class Migration(migrations.Migration):
|
|||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('name', models.CharField(max_length=150, verbose_name='文档名称')),
|
||||
('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')),
|
||||
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
|
||||
('is_active', models.BooleanField(default=True)),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
|
||||
],
|
||||
|
|
@ -50,11 +51,14 @@ class Migration(migrations.Migration):
|
|||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('content', models.CharField(max_length=1024, verbose_name='段落内容')),
|
||||
('title', models.CharField(default='', max_length=256, verbose_name='标题')),
|
||||
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
|
||||
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
|
||||
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
|
||||
('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')),
|
||||
('is_active', models.BooleanField(default=True)),
|
||||
('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
|
||||
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'paragraph',
|
||||
|
|
@ -67,25 +71,15 @@ class Migration(migrations.Migration):
|
|||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('content', models.CharField(max_length=256, verbose_name='问题内容')),
|
||||
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
|
||||
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
|
||||
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
|
||||
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')),
|
||||
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
|
||||
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'problem',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ProblemAnswerMapping',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('hit_num', models.IntegerField(default=0, verbose_name='命中数量')),
|
||||
('star_num', models.IntegerField(default=0, verbose_name='点赞数')),
|
||||
('trample_num', models.IntegerField(default=0, verbose_name='点踩数')),
|
||||
('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')),
|
||||
('problem', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.problem')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'problem_paragraph_mapping',
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,6 +14,13 @@ from common.mixins.app_model_mixin import AppModelMixin
|
|||
from users.models import User
|
||||
|
||||
|
||||
class Status(models.TextChoices):
    """Import/vectorization status for documents and paragraphs.

    (The original docstring read "订单类型" / "order type" — presumably a
    copy-paste slip; nothing here relates to orders.)
    NOTE(review): members are declared with int values inside a
    TextChoices, while the migration's CharField choices use the strings
    '0'/'1'/'2' — confirm the stored values stay consistent.
    """
    embedding = 0, '导入中'
    success = 1, '已完成'
    error = 2, '导入失败'
|
||||
|
||||
|
||||
class DataSet(AppModelMixin):
|
||||
"""
|
||||
数据集表
|
||||
|
|
@ -35,6 +42,8 @@ class Document(AppModelMixin):
|
|||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
||||
name = models.CharField(max_length=150, verbose_name="文档名称")
|
||||
char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
|
||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||
default=Status.embedding)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
class Meta:
|
||||
|
|
@ -46,11 +55,15 @@ class Paragraph(AppModelMixin):
|
|||
段落表
|
||||
"""
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
||||
content = models.CharField(max_length=1024, verbose_name="段落内容")
|
||||
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||
default=Status.embedding)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
class Meta:
|
||||
|
|
@ -62,23 +75,13 @@ class Problem(AppModelMixin):
|
|||
问题表
|
||||
"""
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
content = models.CharField(max_length=256, verbose_name="问题内容")
|
||||
|
||||
class Meta:
|
||||
db_table = "problem"
|
||||
|
||||
|
||||
class ProblemAnswerMapping(AppModelMixin):
|
||||
"""
|
||||
问题 段落 映射表
|
||||
"""
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING)
|
||||
problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING)
|
||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
||||
|
||||
class Meta:
|
||||
db_table = "problem_paragraph_mapping"
|
||||
|
||||
db_table = "problem"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
"""
|
||||
import os.path
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
|
@ -19,11 +18,12 @@ from drf_yasg import openapi
|
|||
from rest_framework import serializers
|
||||
|
||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph
|
||||
from dataset.serializers.document_serializers import CreateDocumentSerializers
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||
from setting.models import AuthOperate
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from users.models import User
|
||||
|
|
@ -81,12 +81,12 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
user_id = self.data.get("user_id")
|
||||
query_set_dict = {}
|
||||
query_set = QuerySet(model=get_dynamics_model(
|
||||
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
|
||||
{'temp.name': models.CharField(), 'temp.desc': models.CharField(),
|
||||
"document_temp.char_length": models.IntegerField()}))
|
||||
if "desc" in self.data:
|
||||
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
|
||||
if "name" in self.data:
|
||||
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
|
||||
if "desc" in self.data and self.data.get('desc') is not None:
|
||||
query_set = query_set.filter(**{'temp.desc__contains': self.data.get("desc")})
|
||||
if "name" in self.data and self.data.get('name') is not None:
|
||||
query_set = query_set.filter(**{'temp.name__contains': self.data.get("name")})
|
||||
|
||||
query_set_dict['default_sql'] = query_set
|
||||
|
||||
|
|
@ -133,9 +133,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
title="数据集列表", description="数据集列表",
|
||||
items=DataSetSerializers.Operate.get_response_body_api())
|
||||
return DataSetSerializers.Operate.get_response_body_api()
|
||||
|
||||
class Create(ApiMixin, serializers.Serializer):
|
||||
"""
|
||||
|
|
@ -157,7 +155,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
message="数据集名称在1-256个字符之间")
|
||||
])
|
||||
|
||||
documents = CreateDocumentSerializers(required=False, many=True)
|
||||
documents = DocumentInstanceSerializer(required=False, many=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -168,28 +166,46 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
dataset_id = uuid.uuid1()
|
||||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
|
||||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
if 'documents' in self.data:
|
||||
documents = self.data.get('documents')
|
||||
for document in documents:
|
||||
document_model = Document(**{'dataset': dataset, 'id': uuid.uuid1(), 'name': document.get('name'),
|
||||
'char_length': reduce(lambda x, y: x + y,
|
||||
list(
|
||||
map(lambda p: len(p),
|
||||
document.get("paragraphs"))), 0)})
|
||||
document_model_list.append(document_model)
|
||||
if 'paragraphs' in document:
|
||||
paragraph_model_list += list(map(lambda p: Paragraph(
|
||||
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
|
||||
document.get('paragraphs')))
|
||||
# 插入数据集
|
||||
dataset.save()
|
||||
# 插入文档
|
||||
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||
# 插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
return True
|
||||
for document in self.data.get('documents') if 'documents' in self.data else []:
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
|
||||
with_embedding=False)
|
||||
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
|
||||
return {**DataSetSerializers(dataset).data,
|
||||
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
|
||||
'update_time', 'create_time', 'document_list'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
|
||||
description="名称", default="测试数据集"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
|
||||
description="描述", default="测试数据集描述"),
|
||||
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
|
||||
description="所属用户id", default="user_xxxx"),
|
||||
'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
|
||||
description="字符数", default=10),
|
||||
'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
|
||||
description="文档数量", default=1),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
),
|
||||
'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
|
||||
description="文档列表",
|
||||
items=DocumentSerializers.Operate.get_response_body_api())
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
@ -200,7 +216,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"),
|
||||
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
|
||||
items=CreateDocumentSerializers().get_request_body_api()
|
||||
items=DocumentSerializers().Create.get_request_body_api()
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
@ -217,10 +233,11 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
def delete(self):
|
||||
self.is_valid()
|
||||
dataset = QuerySet(DataSet).get(id=self.data.get("id"))
|
||||
document_list = QuerySet(Document).filter(dataset=dataset)
|
||||
QuerySet(Paragraph).filter(document__in=document_list).delete()
|
||||
document_list.delete()
|
||||
QuerySet(Document).filter(dataset=dataset).delete()
|
||||
QuerySet(Paragraph).filter(dataset=dataset).delete()
|
||||
QuerySet(Problem).filter(dataset=dataset).delete()
|
||||
dataset.delete()
|
||||
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
|
||||
return True
|
||||
|
||||
def one(self, user_id, with_valid=True):
|
||||
|
|
@ -303,9 +320,9 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='id',
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
required=True,
|
||||
description='数据集id')
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,20 +6,29 @@
|
|||
@date:2023/9/22 13:43
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.core import validators
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.db.search import native_search, native_page_search
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.split_model import SplitModel, get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
|
||||
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
|
||||
name = serializers.CharField(required=True,
|
||||
validators=[
|
||||
validators.MaxLengthValidator(limit_value=128,
|
||||
|
|
@ -28,52 +37,265 @@ class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
message="数据集名称在1-128个字符之间")
|
||||
])
|
||||
|
||||
paragraphs = serializers.ListField(required=False,
|
||||
child=serializers.CharField(required=True,
|
||||
validators=[
|
||||
validators.MaxLengthValidator(limit_value=256,
|
||||
message="段落在1-256个字符之间"),
|
||||
validators.MinLengthValidator(limit_value=1,
|
||||
message="段落在1-256个字符之间")
|
||||
]))
|
||||
paragraphs = ParagraphInstanceSerializer(required=False, many=True)
|
||||
|
||||
def is_valid(self, *, dataset_id=None, raise_exception=False):
|
||||
if not QuerySet(DataSet).filter(id=dataset_id).exists():
|
||||
raise AppApiException(10000, "数据集id不存在")
|
||||
return super().is_valid(raise_exception=True)
|
||||
|
||||
def save(self, dataset_id: str, **kwargs):
|
||||
document_model = Document(
|
||||
**{'dataset': DataSet(id=dataset_id),
|
||||
'id': uuid.uuid1(),
|
||||
'name': self.data.get('name'),
|
||||
'char_length': reduce(lambda x, y: x + y, list(map(lambda p: len(p), self.data.get("paragraphs"))), 0)})
|
||||
|
||||
paragraph_model_list = list(map(lambda p: Paragraph(
|
||||
**{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
|
||||
self.data.get('paragraphs')))
|
||||
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
# 插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
return True
|
||||
|
||||
def get_request_body_api(self):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['name', 'paragraph'],
|
||||
required=['name', 'paragraphs'],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
|
||||
'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING, title="段落数据",
|
||||
description="段落数据"))
|
||||
items=ParagraphSerializers.Create.get_request_body_api())
|
||||
}
|
||||
)
|
||||
|
||||
def get_request_params_api(self):
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
|
||||
class DocumentSerializers(ApiMixin, serializers.Serializer):
|
||||
class Query(ApiMixin, serializers.Serializer):
|
||||
# 数据集id
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
name = serializers.CharField(required=False,
|
||||
validators=[
|
||||
validators.MaxLengthValidator(limit_value=128,
|
||||
message="文档名称在1-128个字符之间"),
|
||||
validators.MinLengthValidator(limit_value=1,
|
||||
message="数据集名称在1-128个字符之间")
|
||||
])
|
||||
|
||||
def get_query_set(self):
|
||||
query_set = QuerySet(model=Document)
|
||||
query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
|
||||
if 'name' in self.data and self.data.get('name') is not None:
|
||||
query_set = query_set.filter(**{'name__contains': self.data.get('name')})
|
||||
return query_set
|
||||
|
||||
def list(self, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
query_set = self.get_query_set()
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
|
||||
|
||||
def page(self, current_page, page_size):
|
||||
query_set = self.get_query_set()
|
||||
return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='name',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='文档名称')]
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
title="文档列表", description="文档列表",
|
||||
items=DocumentSerializers.Operate.get_response_body_api())
|
||||
|
||||
class Operate(ApiMixin, serializers.Serializer):
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
]
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
document_id = self.data.get('document_id')
|
||||
if not QuerySet(Document).filter(id=document_id).exists():
|
||||
raise AppApiException(500, "文档id不存在")
|
||||
|
||||
def one(self, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
query_set = QuerySet(model=Document)
|
||||
query_set = query_set.filter(**{'id': self.data.get("document_id")})
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)
|
||||
|
||||
def edit(self, instance: Dict, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
_document = QuerySet(Document).get(id=self.data.get("document_id"))
|
||||
update_keys = ['name', 'is_active']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
_document.__setattr__(update_key, instance.get(update_key))
|
||||
_document.save()
|
||||
return self.one()
|
||||
|
||||
@transaction.atomic
|
||||
def delete(self):
|
||||
document_id = self.data.get("document_id")
|
||||
QuerySet(model=Document).filter(id=document_id).delete()
|
||||
# 删除段落
|
||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||
# 删除问题
|
||||
QuerySet(model=Problem).filter(document_id=document_id).delete()
|
||||
# 删除向量库
|
||||
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active'
|
||||
'update_time', 'create_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
|
||||
description="名称", default="测试数据集"),
|
||||
'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
|
||||
description="字符数", default=10),
|
||||
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
|
||||
'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
|
||||
description="文档数量", default=1),
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
|
||||
description="是否可用", default=True),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
|
||||
}
|
||||
)
|
||||
|
||||
class Create(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
|
||||
raise AppApiException(10000, "数据集id不存在")
|
||||
return True
|
||||
|
||||
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
|
||||
if with_valid:
|
||||
DocumentInstanceSerializer(data=instance).is_valid()
|
||||
self.is_valid(raise_exception=True)
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
|
||||
document_model = Document(
|
||||
**{'dataset_id': dataset_id,
|
||||
'id': uuid.uuid1(),
|
||||
'name': instance.get('name'),
|
||||
'char_length': reduce(lambda x, y: x + y,
|
||||
[len(p.get('content')) for p in instance.get('paragraphs', [])],
|
||||
0)})
|
||||
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
|
||||
ParagraphSerializers.Create(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
|
||||
with_valid=True,
|
||||
with_embedding=False)
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
|
||||
with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return DocumentInstanceSerializer.get_request_body_api()
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id')
|
||||
]
|
||||
|
||||
class Split(ApiMixin, serializers.Serializer):
|
||||
file = serializers.ListField(required=True)
|
||||
|
||||
limit = serializers.IntegerField(required=False)
|
||||
|
||||
patterns = serializers.ListField(required=False,
|
||||
child=serializers.CharField(required=True))
|
||||
|
||||
with_filter = serializers.BooleanField(required=False)
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid()
|
||||
files = self.data.get('file')
|
||||
for f in files:
|
||||
if f.size > 1024 * 1024 * 10:
|
||||
raise AppApiException(500, "上传文件最大不能超过10m")
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [
|
||||
openapi.Parameter(name='file',
|
||||
in_=openapi.IN_FORM,
|
||||
type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Items(type=openapi.TYPE_FILE),
|
||||
required=True,
|
||||
description='数据集id')]
|
||||
description='上传文件'),
|
||||
openapi.Parameter(name='limit',
|
||||
in_=openapi.IN_FORM,
|
||||
required=False,
|
||||
type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
|
||||
openapi.Parameter(name='patterns',
|
||||
in_=openapi.IN_FORM,
|
||||
required=False,
|
||||
type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
|
||||
title="分段正则列表", description="分段正则列表"),
|
||||
openapi.Parameter(name='with_filter',
|
||||
in_=openapi.IN_FORM,
|
||||
required=False,
|
||||
type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
|
||||
]
|
||||
|
||||
def parse(self):
|
||||
file_list = self.data.get("file")
|
||||
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
|
||||
self.data.get("limit")), file_list))
|
||||
|
||||
|
||||
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
    """Split one uploaded file into paragraphs.

    :param file: uploaded file object (supports ``read()`` and has ``.name``)
    :param pattern_list: regex patterns used to split the text; when ``None``
        or empty the split model is inferred from the file extension instead
    :param with_filter: whether special characters are stripped during splitting
    :param limit: maximum paragraph length passed to ``SplitModel``
    :return: ``{'name': <file name>, 'content': <paragraph list>}``; ``content``
        is an empty list when the file is not valid UTF-8 text
    """
    data = file.read()
    # Bug fix: the original condition ``pattern_list is None or len(pattern_list) > 0``
    # sent a ``None`` pattern list into ``SplitModel`` when no patterns were supplied.
    # Custom patterns must only be used when actually provided; otherwise fall back
    # to the extension-based split model.
    if pattern_list is not None and len(pattern_list) > 0:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name)
    try:
        content = data.decode('utf-8')
    except Exception:
        # Binary / non-UTF-8 uploads are tolerated and yield no paragraphs.
        return {'name': file.name,
                'content': []}
    return {'name': file.name,
            'content': split_model.parse(content)
            }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,278 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: paragraph_serializers.py
|
||||
@date:2023/10/16 15:51
|
||||
@desc:
|
||||
"""
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.core import validators
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.db.search import page_search
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from dataset.models import Paragraph, Problem
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
|
||||
|
||||
|
||||
class ParagraphSerializer(serializers.ModelSerializer):
    """Model serializer exposing the persisted fields of a Paragraph row."""

    class Meta:
        model = Paragraph
        # Explicit whitelist of the columns returned by the API, so new model
        # fields are never leaked by accident.
        fields = [
            'id', 'content', 'hit_num', 'star_num', 'trample_num',
            'is_active', 'document_id', 'title', 'create_time', 'update_time',
        ]
|
||||
|
||||
|
||||
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """Incoming paragraph payload: content plus optional title, problem list
    and active flag (the paragraph "instance" object)."""

    content = serializers.CharField(required=True, validators=[
        validators.MaxLengthValidator(limit_value=1024,
                                      message="段落在1-1024个字符之间"),
        validators.MinLengthValidator(limit_value=1,
                                      message="段落在1-1024个字符之间")
    ])

    title = serializers.CharField(required=False)

    problem_list = ProblemInstanceSerializer(required=False, many=True)

    is_active = serializers.BooleanField(required=False)

    @staticmethod
    def get_request_body_api():
        """OpenAPI schema describing this payload for swagger docs."""
        # Build the property map first so the final Schema call stays readable.
        properties = {
            'content': openapi.Schema(type=openapi.TYPE_STRING, title="分段内容", description="分段内容"),
            'title': openapi.Schema(type=openapi.TYPE_STRING, title="分段标题",
                                    description="分段标题"),
            'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
            'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                           description="问题列表",
                                           items=ProblemInstanceSerializer.get_request_body_api()),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=['content'],
                              properties=properties)
|
||||
|
||||
|
||||
class ParagraphSerializers(ApiMixin, serializers.Serializer):
    """Paragraph CRUD sub-serializers: Operate (one row), Create and Query."""

    class Operate(ApiMixin, serializers.Serializer):
        """Operations addressing one existing paragraph by id."""
        # paragraph id
        paragraph_id = serializers.UUIDField(required=True)
        # dataset id
        dataset_id = serializers.UUIDField(required=True)
        # document id (the original comment said "dataset id" — copy/paste slip)
        document_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=True):
            """Validate field formats, then ensure the paragraph row exists."""
            super().is_valid(raise_exception=True)
            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                raise AppApiException(500, "段落id不存在")

        @transaction.atomic
        def edit(self, instance: Dict):
            """Update paragraph fields and synchronize its problem list.

            ``instance`` may carry 'title'/'content'/'is_active' and a full
            'problem_list': payload rows with an id are updated, rows without
            an id are created, and stored problems missing from the payload
            are deleted. Runs inside a single transaction.

            :return: the updated paragraph with its problem list (see one()).
            """
            self.is_valid()
            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
            for update_key in ('title', 'content', 'is_active'):
                if update_key in instance and instance.get(update_key) is not None:
                    setattr(_paragraph, update_key, instance.get(update_key))

            if 'problem_list' in instance:
                problem_payload = instance.get('problem_list')
                # Split payload into rows to update (carry an id) and rows to create.
                update_problem_list = [row for row in problem_payload
                                       if 'id' in row and row.get('id') is not None]
                create_problem_list = [row for row in problem_payload if row.get('id') is None]

                # Problems currently stored for this paragraph.
                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
                exist_problem_ids = {str(row.id) for row in problem_list}

                # Reject client-supplied ids that do not belong to this paragraph.
                for update_problem in update_problem_list:
                    if update_problem.get('id') not in exist_problem_ids:
                        # str() guards the message concat against non-string ids.
                        raise AppApiException(500, str(update_problem.get('id')) + '问题id不存在')
                # Stored problems absent from the payload are deleted — but only
                # when the payload actually referenced existing rows (original
                # semantics: a create-only payload deletes nothing).
                update_ids = {str(update_row.get('id')) for update_row in update_problem_list}
                delete_problem_list = ([row for row in problem_list if str(row.id) not in update_ids]
                                       if len(update_problem_list) > 0 else [])
                if len(delete_problem_list) > 0:
                    QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete()
                # Insert the new problems.
                if len(create_problem_list) > 0:
                    QuerySet(Problem).bulk_create(
                        [Problem(id=uuid.uuid1(), content=p.get('content'),
                                 paragraph_id=self.data.get('paragraph_id'),
                                 dataset_id=self.data.get('dataset_id'),
                                 document_id=self.data.get('document_id'))
                         for p in create_problem_list])
                # Update the content of the surviving problems.
                if len(update_problem_list) > 0:
                    QuerySet(Problem).bulk_update(
                        [Problem(id=row.get('id'), content=row.get('content'))
                         for row in update_problem_list],
                        ['content'])

            _paragraph.save()
            if 'is_active' in instance and instance.get('is_active') is not None:
                # Keep the vector store in sync with the active flag.
                signal = (ListenerManagement.enable_embedding_by_paragraph_signal
                          if instance.get('is_active')
                          else ListenerManagement.disable_embedding_by_paragraph_signal)
                signal.send(self.data.get('paragraph_id'))
            return self.one()

        def get_problem_list(self):
            """Return the serialized problems attached to this paragraph."""
            return [ProblemSerializer(problem).data for problem in
                    QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]

        def one(self, with_valid=False):
            """Return the paragraph fields merged with its problem list."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
                    'problem_list': self.get_problem_list()}

        def delete(self, with_valid=False):
            """Delete the paragraph, its problems and its embeddings."""
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id = self.data.get('paragraph_id')
            QuerySet(Paragraph).filter(id=paragraph_id).delete()
            QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
            ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_response_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
                                      description="段落id")]

    class Create(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)

        document_id = serializers.UUIDField(required=True)

        def save(self, instance: Dict, with_valid=True, with_embedding=True):
            """Create a paragraph (and its problems) under dataset/document.

            Bug fixes vs. the original:
            - the payload is validated with ParagraphInstanceSerializer (the
              instance schema) instead of the aggregate ParagraphSerializers,
              which has no fields and therefore validated nothing;
            - path-field validation raises on error instead of silently
              returning a bool.

            :param instance: paragraph payload (content/title/problem_list)
            :param with_valid: validate payload and path fields first
            :param with_embedding: emit the embedding signal after insert
            :return: the created paragraph with its problem list
            """
            if with_valid:
                ParagraphInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get("dataset_id")
            document_id = self.data.get('document_id')

            paragraph = Paragraph(id=uuid.uuid1(),
                                  document_id=document_id,
                                  content=instance.get("content"),
                                  dataset_id=dataset_id,
                                  title=instance.get("title") if 'title' in instance else '')
            # Insert the paragraph row.
            paragraph.save()
            problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'),
                                          paragraph_id=paragraph.id,
                                          document_id=document_id, dataset_id=dataset_id)
                                  for problem in (instance.get('problem_list')
                                                  if 'problem_list' in instance else [])]
            # Bulk-insert the attached problems, if any.
            if len(problem_model_list) > 0:
                QuerySet(Problem).bulk_create(problem_model_list)
            if with_embedding:
                ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
            return ParagraphSerializers.Operate(
                data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True)

        @staticmethod
        def get_request_body_api():
            return ParagraphInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='数据集id'),
                    openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description="文档id")
                    ]

    class Query(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True)

        document_id = serializers.UUIDField(required=True)

        title = serializers.CharField(required=False)

        def get_query_set(self):
            """Build the Paragraph queryset filtered by dataset/document and
            an optional title substring."""
            query_set = QuerySet(model=Paragraph).filter(
                **{'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get("document_id")})
            if 'title' in self.data:
                query_set = query_set.filter(
                    **{'title__contains': self.data.get('title')})
            return query_set

        def list(self):
            """Return all matching paragraphs, serialized."""
            return [ParagraphSerializer(row).data for row in self.get_query_set()]

        def page(self, current_page, page_size):
            """Return one page of matching paragraphs, serialized."""
            query_set = self.get_query_set()
            return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id'),
                    openapi.Parameter(name='title',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='标题')
                    ]

        @staticmethod
        def get_response_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                              description="段落内容", default='段落内容'),
                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                            description="标题", default="xxx的描述"),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: problem_serializers.py
|
||||
@date:2023/10/23 13:55
|
||||
@desc:
|
||||
"""
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from dataset.models import Problem, Paragraph
|
||||
from embedding.models import SourceType
|
||||
from embedding.vector.pg_vector import PGVector
|
||||
|
||||
|
||||
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer exposing the persisted fields of a Problem row."""

    class Meta:
        model = Problem
        # Explicit whitelist of the columns returned by the API.
        fields = [
            'id', 'content', 'hit_num', 'star_num', 'trample_num',
            'dataset_id', 'document_id', 'create_time', 'update_time',
        ]
|
||||
|
||||
|
||||
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """Incoming problem payload: content, plus an id only when editing an
    existing problem (omit the id on create)."""

    id = serializers.CharField(required=False)

    content = serializers.CharField(required=True)

    @staticmethod
    def get_request_body_api():
        """OpenAPI schema describing this payload for swagger docs."""
        properties = {
            'id': openapi.Schema(
                type=openapi.TYPE_STRING,
                title="问题id,修改的时候传递,创建的时候不传"),
            'content': openapi.Schema(
                type=openapi.TYPE_STRING, title="内容"),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties=properties)
|
||||
|
||||
|
||||
class ProblemSerializers(ApiMixin, serializers.Serializer):
|
||||
class Create(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True)
|
||||
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
|
||||
document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
|
||||
content=instance.get('content'))
|
||||
problem.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
'source_id': problem.id,
|
||||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id')})
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return ProblemInstanceSerializer.get_request_body_api()
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id'),
|
||||
openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id')]
|
||||
|
||||
class Query(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
||||
raise AppApiException(500, "段落id不存在")
|
||||
|
||||
def get_query_set(self):
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
document_id = self.data.get('document_id')
|
||||
paragraph_id = self.data.get("paragraph_id")
|
||||
return QuerySet(Problem).filter(
|
||||
**{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
|
||||
|
||||
def list(self, with_valid=False):
|
||||
"""
|
||||
获取问题列表
|
||||
:param with_valid: 是否校验
|
||||
:return: 问题列表
|
||||
"""
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
query_set = self.get_query_set()
|
||||
return [ProblemSerializer(p).data for p in query_set]
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
, openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id')]
|
||||
|
||||
class Operate(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True)
|
||||
|
||||
document_id = serializers.UUIDField(required=True)
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True)
|
||||
|
||||
problem_id = serializers.UUIDField(required=True)
|
||||
|
||||
def delete(self, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
|
||||
PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
|
||||
ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
|
||||
return True
|
||||
|
||||
def one(self, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='数据集id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
, openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id'),
|
||||
openapi.Parameter(name='problem_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='问题id')
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
|
||||
'document_id',
|
||||
'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
||||
description="问题内容", default='问题内容'),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
||||
default=1),
|
||||
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
|
||||
description="点赞数量", default=1),
|
||||
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
||||
description="点踩数", default=1),
|
||||
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
||||
description="文档id", default='xxx'),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
SELECT
|
||||
"document".* ,
|
||||
(SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
|
||||
FROM
|
||||
"document" "document"
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
SELECT
|
||||
problem."id",
|
||||
problem."content",
|
||||
problem_paragraph_mapping.hit_num,
|
||||
problem_paragraph_mapping.star_num,
|
||||
problem_paragraph_mapping.trample_num,
|
||||
problem_paragraph_mapping.paragraph_id
|
||||
FROM
|
||||
problem problem
|
||||
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id
|
||||
|
|
@ -7,5 +7,18 @@ urlpatterns = [
|
|||
path('dataset', views.Dataset.as_view(), name="dataset"),
|
||||
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
|
||||
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
|
||||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document')
|
||||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
|
||||
name="document_operate"),
|
||||
path('dataset/document/split', views.Document.Split.as_view(),
|
||||
name="document_operate"),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
|
||||
views.Paragraph.Page.as_view(), name='paragraph_page'),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
|
||||
views.Paragraph.Operate.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
|
||||
views.Problem.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
|
||||
views.Problem.Operate.as_view())
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,3 +8,5 @@
|
|||
"""
|
||||
from .dataset import *
|
||||
from .document import *
|
||||
from .paragraph import *
|
||||
from .problem import *
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class Dataset(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取数据集列表",
|
||||
operation_id="获取数据集列表",
|
||||
manual_parameters=DataSetSerializers.Query.get_request_params_api(),
|
||||
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request):
|
||||
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||
|
|
@ -36,19 +36,21 @@ class Dataset(APIView):
|
|||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建数据集",
|
||||
operation_id="创建数据集",
|
||||
request_body=DataSetSerializers.Create.get_request_body_api())
|
||||
request_body=DataSetSerializers.Create.get_request_body_api(),
|
||||
responses=get_api_response(DataSetSerializers.Create.get_response_body_api()))
|
||||
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
|
||||
def post(self, request: Request):
|
||||
s = DataSetSerializers.Create(data=request.data)
|
||||
if s.is_valid():
|
||||
s.save(request.user)
|
||||
return result.success("ok")
|
||||
s.is_valid(raise_exception=True)
|
||||
return result.success(s.save(request.user))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods="DELETE", detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集")
|
||||
@swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('dataset_id')),
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
|
||||
|
|
@ -59,6 +61,7 @@ class Dataset(APIView):
|
|||
|
||||
@action(methods="GET", detail=False)
|
||||
@swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('dataset_id')))
|
||||
|
|
@ -67,6 +70,7 @@ class Dataset(APIView):
|
|||
|
||||
@action(methods="PUT", detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
|
||||
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
|
||||
request_body=DataSetSerializers.Operate.get_request_body_api(),
|
||||
responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
|
|
@ -84,8 +88,10 @@ class Dataset(APIView):
|
|||
manual_parameters=get_page_request_params(
|
||||
DataSetSerializers.Query.get_request_params_api()),
|
||||
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
|
||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request, current_page, page_size):
|
||||
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||
d = DataSetSerializers.Query(
|
||||
data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
|
||||
'user_id': str(request.user.id)})
|
||||
d.is_valid()
|
||||
return result.success(d.page(current_page, page_size))
|
||||
|
|
|
|||
|
|
@ -9,13 +9,15 @@
|
|||
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants
|
||||
from common.response import result
|
||||
from dataset.serializers.dataset_serializers import CreateDocumentSerializers
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.document_serializers import DocumentSerializers
|
||||
|
||||
|
||||
class Document(APIView):
|
||||
|
|
@ -24,28 +26,102 @@ class Document(APIView):
|
|||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建文档",
|
||||
operation_id="创建文档",
|
||||
request_body=CreateDocumentSerializers().get_request_body_api(),
|
||||
manual_parameters=CreateDocumentSerializers().get_request_params_api())
|
||||
@has_permissions(PermissionConstants.DATASET_CREATE)
|
||||
request_body=DocumentSerializers.Create.get_request_body_api(),
|
||||
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def post(self, request: Request, dataset_id: str):
|
||||
d = CreateDocumentSerializers(data=request.data)
|
||||
if d.is_valid(dataset_id=dataset_id):
|
||||
d.save(dataset_id)
|
||||
return result.success("ok")
|
||||
|
||||
|
||||
class DocumentDetails(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
return result.success(
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(request.data, with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取文档详情",
|
||||
operation_id="获取文档详情",
|
||||
request_body=CreateDocumentSerializers().get_request_body_api(),
|
||||
manual_parameters=CreateDocumentSerializers().get_request_params_api())
|
||||
@swagger_auto_schema(operation_summary="文档列表",
|
||||
operation_id="文档列表",
|
||||
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id')))
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str):
|
||||
d = CreateDocumentSerializers(data=request.data)
|
||||
if d.is_valid(dataset_id=dataset_id):
|
||||
d.save(dataset_id)
|
||||
return result.success("ok")
|
||||
d = DocumentSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.list())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取文档详情",
|
||||
operation_id="获取文档详情",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
|
||||
operate.is_valid(raise_exception=True)
|
||||
return result.success(operate.one())
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改文档",
|
||||
operation_id="修改文档",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
request_body=DocumentSerializers.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
|
||||
)
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str, document_id: str):
|
||||
return result.success(
|
||||
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).edit(
|
||||
request.data,
|
||||
with_valid=True))
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除文档",
|
||||
operation_id="删除文档",
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def delete(self, request: Request, dataset_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
|
||||
operate.is_valid(raise_exception=True)
|
||||
return result.success(operate.delete())
|
||||
|
||||
class Split(APIView):
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分段文档",
|
||||
operation_id="分段文档",
|
||||
manual_parameters=DocumentSerializers.Split.get_request_params_api())
|
||||
def post(self, request: Request):
|
||||
ds = DocumentSerializers.Split(
|
||||
data={'file': request.FILES.getlist('file'),
|
||||
'patterns': request.data.getlist('patterns[]')})
|
||||
ds.is_valid(raise_exception=True)
|
||||
return result.success(ds.parse())
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取数据集分页列表",
|
||||
operation_id="获取数据集分页列表",
|
||||
manual_parameters=DocumentSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, current_page, page_size):
|
||||
d = DocumentSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.page(current_page, page_size))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: paragraph_serializers.py
|
||||
@date:2023/10/16 15:51
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
|
||||
|
||||
class Paragraph(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="段落列表",
|
||||
operation_id="段落列表",
|
||||
manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str):
|
||||
q = ParagraphSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||
'document_id': document_id})
|
||||
q.is_valid(raise_exception=True)
|
||||
return result.success(q.list())
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建段落",
|
||||
operation_id="创建段落",
|
||||
manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Create.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def post(self, request: Request, dataset_id: str, document_id: str):
|
||||
return result.success(
|
||||
ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['UPDATE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改段落数据",
|
||||
operation_id="修改段落数据",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
o = ParagraphSerializers.Operate(
|
||||
data={"paragraph_id": paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
|
||||
o.is_valid(raise_exception=True)
|
||||
return result.success(o.edit(request.data))
|
||||
|
||||
@action(methods=['UPDATE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取段落详情",
|
||||
operation_id="获取段落详情",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
o = ParagraphSerializers.Operate(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
|
||||
o.is_valid(raise_exception=True)
|
||||
return result.success(o.one())
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除段落",
|
||||
operation_id="删除段落",
|
||||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
o = ParagraphSerializers.Operate(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
|
||||
o.is_valid(raise_exception=True)
|
||||
return result.success(o.delete())
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分页获取段落列表",
|
||||
operation_id="分页获取段落列表",
|
||||
manual_parameters=result.get_page_request_params(
|
||||
ParagraphSerializers.Query.get_request_params_api()),
|
||||
responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str, current_page, page_size):
|
||||
d = ParagraphSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||
'document_id': document_id})
|
||||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.page(current_page, page_size))
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: problem.py
|
||||
@date:2023/10/23 13:54
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate
|
||||
from common.response import result
|
||||
from dataset.serializers.problem_serializers import ProblemSerializers
|
||||
|
||||
|
||||
class Problem(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="添加关联问题",
|
||||
operation_id="添加段落关联问题",
|
||||
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
|
||||
request_body=ProblemSerializers.Create.get_request_body_api(),
|
||||
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ProblemSerializers.Create(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
|
||||
request.data, with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取段落问题列表",
|
||||
operation_id="获取段落问题列表",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ProblemSerializers.Query(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
|
||||
with_valid=True))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除段落问题",
|
||||
operation_id="删除段落问题",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_default_response())
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
|
||||
o = ProblemSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
|
||||
'problem_id': problem_id})
|
||||
return result.success(o.delete(with_valid=True))
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-09 06:33
|
||||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
|
||||
import common.field.vector_field
|
||||
from django.db import migrations, models
|
||||
|
|
@ -20,8 +20,11 @@ class Migration(migrations.Migration):
|
|||
('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')),
|
||||
('source_id', models.CharField(max_length=128, verbose_name='资源id')),
|
||||
('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
|
||||
('is_active', models.BooleanField(default=True, max_length=1, verbose_name='是否可用')),
|
||||
('embedding', common.field.vector_field.VectorField(verbose_name='向量')),
|
||||
('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')),
|
||||
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='文档关联')),
|
||||
('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document', verbose_name='文档关联')),
|
||||
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落关联')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'embedding',
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
from django.db import models
|
||||
|
||||
from common.field.vector_field import VectorField
|
||||
from dataset.models.data_set import DataSet
|
||||
from dataset.models.data_set import Document, Paragraph, DataSet
|
||||
|
||||
|
||||
class SourceType(models.TextChoices):
|
||||
|
|
@ -26,7 +26,13 @@ class Embedding(models.Model):
|
|||
source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
|
||||
default=SourceType.PROBLEM)
|
||||
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联")
|
||||
is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
|
||||
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
|
||||
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
|
||||
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)
|
||||
|
||||
embedding = VectorField(verbose_name="向量")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_vector.py
|
||||
@date:2023/10/18 19:16
|
||||
@desc:
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from common.config.embedding_config import EmbeddingModel
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
    """
    Abstract base class for vector stores.

    Concrete implementations (e.g. a pgvector-backed store) provide the
    storage primitives (`_save`, `_batch_save`, search/update/delete); this
    base class supplies lazy store creation and the default embedding model.
    """
    # Process-wide flag: set to True once any instance has verified (or
    # created) the underlying vector store, so the existence check and
    # creation run at most once per process.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """
        Check whether the vector store has already been created.

        :return: True if the vector store exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """
        Create the vector store.

        :return:
        """
        pass

    def save_pre_handler(self):
        """
        Pre-insert handler: lazily ensure the vector store exists before any
        data is written, creating it on first use.

        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
             source_id: str,
             is_active: bool,
             embedding=None):
        """
        Embed a single text and insert the resulting vector row.

        :param text: text to embed
        :param source_type: type of the source resource
        :param dataset_id: dataset id
        :param document_id: document id
        :param paragraph_id: paragraph id
        :param source_id: source resource id
        :param is_active: whether the row is active (True) or disabled (False)
        :param embedding: embedding model; defaults to the shared model from EmbeddingModel
        :return: bool
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
        # Fix: the docstring promises a bool but the method previously
        # returned None; return True for consistency with batch_save().
        return True

    def batch_save(self, data_list: List[Dict], embedding=None):
        """
        Embed and insert a batch of rows.

        :param data_list: list of row dicts (text, ids, source_type, is_active, ...)
        :param embedding: embedding model; defaults to the shared model from EmbeddingModel
        :return: bool
        """
        if embedding is None:
            embedding = EmbeddingModel.get_embedding_model()
        self.save_pre_handler()
        self._batch_save(data_list, embedding)
        return True

    @abstractmethod
    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        # Storage-specific single-row insert; called after save_pre_handler().
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        # Storage-specific batch insert; called after save_pre_handler().
        pass

    @abstractmethod
    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        # Similarity search over the given datasets.
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        # Update all vector rows belonging to a paragraph.
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        # Update all vector rows belonging to a source resource.
        pass

    @abstractmethod
    def delete_by_dataset_id(self, dataset_id: str):
        # Delete all vector rows belonging to a dataset.
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        # Delete all vector rows belonging to a document.
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        # Delete all vector rows for a (source_id, source_type) pair.
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        # Delete all vector rows belonging to a paragraph.
        pass
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: pg_vector.py
|
||||
@date:2023/10/19 15:28
|
||||
@desc:
|
||||
"""
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from embedding.models import Embedding, SourceType
|
||||
from embedding.vector.base_vector import BaseVectorStore
|
||||
|
||||
|
||||
class PGVector(BaseVectorStore):
    """Vector store backed by the PostgreSQL ``embedding`` table (pgvector)."""

    def vector_is_create(self) -> bool:
        # The backing table is created by Django migrations at project
        # startup, so the store is always considered present.
        return True

    def vector_create(self):
        # Nothing to do: table creation is handled by migrations.
        return True

    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
              source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        """
        Embed ``text`` and persist a single Embedding row.

        :return: True
        """
        text_embedding = embedding.embed_query(text)
        # Fix: the original rebound the ``embedding`` parameter to the ORM
        # row, shadowing the model; use a distinct name for the record.
        record = Embedding(id=uuid.uuid1(),
                           dataset_id=dataset_id,
                           document_id=document_id,
                           is_active=is_active,
                           paragraph_id=paragraph_id,
                           source_id=source_id,
                           embedding=text_embedding,
                           source_type=source_type)
        record.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        """
        Embed every text in one model call and bulk-insert the rows.

        :param text_list: row dicts with keys text/document_id/paragraph_id/
                          dataset_id/source_id/source_type and optional is_active
        :return: True
        """
        # Guard clause replaces the original statement-level
        # ``bulk_create(...) if len(text_list) > 0 else None`` expression.
        if not text_list:
            return True
        embeddings = embedding.embed_documents([row.get('text') for row in text_list])
        QuerySet(Embedding).bulk_create([
            Embedding(id=uuid.uuid1(),
                      document_id=row.get('document_id'),
                      paragraph_id=row.get('paragraph_id'),
                      dataset_id=row.get('dataset_id'),
                      is_active=row.get('is_active', True),
                      source_id=row.get('source_id'),
                      source_type=row.get('source_type'),
                      embedding=vector)
            for row, vector in zip(text_list, embeddings)])
        return True

    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        # TODO: similarity search is not implemented yet.
        pass

    def update_by_source_id(self, source_id: str, instance: Dict):
        # Bulk-update every row belonging to the given source resource.
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        # Bulk-update every row belonging to the given paragraph.
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def delete_by_dataset_id(self, dataset_id: str):
        QuerySet(Embedding).filter(dataset_id=dataset_id).delete()

    def delete_by_document_id(self, document_id: str):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_source_id(self, source_id: str, source_type: str):
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-09 06:33
|
||||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import os
|
|||
import re
|
||||
from importlib import import_module
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import torch.backends
|
||||
import yaml
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
|
@ -88,7 +88,14 @@ class Config(dict):
|
|||
"EMAIL_HOST": "",
|
||||
"EMAIL_PORT": 465,
|
||||
"EMAIL_HOST_USER": "",
|
||||
"EMAIL_HOST_PASSWORD": ""
|
||||
"EMAIL_HOST_PASSWORD": "",
|
||||
# 向量模型
|
||||
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
|
||||
"EMBEDDING_DEVICE": "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",
|
||||
"EMBEDDING_MODEL_PATH": os.path.join(PROJECT_DIR, 'models'),
|
||||
# 向量库配置
|
||||
"VECTOR_STORE_NAME": 'pg_vector'
|
||||
|
||||
}
|
||||
|
||||
def get_db_setting(self) -> dict:
|
||||
|
|
@ -120,6 +127,8 @@ class ConfigManager:
|
|||
def __init__(self, root_path=None):
|
||||
self.root_path = root_path
|
||||
self.config = self.config_class()
|
||||
for key in self.config_class.defaults:
|
||||
self.config[key] = self.config_class.defaults[key]
|
||||
|
||||
def from_mapping(self, *mapping, **kwargs):
|
||||
"""Updates the config like :meth:`update` ignoring items with non-upper
|
||||
|
|
|
|||
|
|
@ -100,6 +100,11 @@ LOGGING = {
|
|||
'level': LOG_LEVEL,
|
||||
'propagate': False,
|
||||
},
|
||||
'sqlalchemy': {
|
||||
'handlers': ['console', 'file', 'syslog'],
|
||||
'level': LOG_LEVEL,
|
||||
'propagate': False,
|
||||
},
|
||||
'django.db.backends': {
|
||||
'handlers': ['console', 'file', 'syslog'],
|
||||
'propagate': False,
|
||||
|
|
|
|||
|
|
@ -14,3 +14,12 @@ from django.core.wsgi import get_wsgi_application
|
|||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
|
||||
|
||||
application = get_wsgi_application()
|
||||
|
||||
|
||||
def post_handler():
    """Run post-WSGI-startup initialisation hooks."""
    # Local import keeps Django app loading ahead of this module-level hook.
    from common.event.listener_manage import ListenerManagement
    # Start the listener manager, then emit its init signal — presumably to
    # trigger embedding-model initialisation (TODO confirm against
    # ListenerManagement's signal receivers).
    ListenerManagement().run()
    ListenerManagement.init_embedding_model_signal.send()


# Executed at module import time, right after the WSGI application is built.
post_handler()
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ from django.apps import AppConfig
|
|||
class UsersConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'users'
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# Generated by Django 4.1.10 on 2023-10-09 06:33
|
||||
# Generated by Django 4.1.10 on 2023-10-24 12:13
|
||||
|
||||
from django.db import migrations, models
|
||||
import uuid
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from common.constants.permission_constants import PermissionConstants, CompareCo
|
|||
from common.response import result
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
from users.models.user import User as UserModel
|
||||
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, UserSerializer, CheckCodeSerializer, \
|
||||
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
|
||||
RePasswordSerializer, \
|
||||
SendEmailSerializer, UserProfile
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ psycopg2-binary = "2.9.7"
|
|||
jieba = "^0.42.1"
|
||||
diskcache = "^5.6.3"
|
||||
pillow = "9.5.0"
|
||||
filetype = "^1.2.0"
|
||||
chardet = "^5.2.0"
|
||||
torch = "^2.1.0"
|
||||
sentence-transformers = "^2.2.2"
|
||||
blinker = "^1.6.3"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue