diff --git a/.gitignore b/.gitignore
index b1b65726d..136ad6f4e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,6 +162,7 @@ cython_debug/
 ui/node_modules
 ui/dist
 apps/static
+models/
 data
 .idea
 .dev
diff --git a/apps/application/migrations/0001_initial.py b/apps/application/migrations/0001_initial.py
index 91af9d5ea..8837f441f 100644
--- a/apps/application/migrations/0001_initial.py
+++ b/apps/application/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.1.10 on 2023-10-09 06:33
+# Generated by Django 4.1.10 on 2023-10-24 12:13
 
 import django.contrib.postgres.fields
 from django.db import migrations, models
diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py
new file mode 100644
index 000000000..37592b424
--- /dev/null
+++ b/apps/common/config/embedding_config.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: embedding_config.py
+    @date:2023/10/23 16:03
+    @desc:
+"""
+import types
+from smartdoc.const import CONFIG
+from langchain.embeddings import HuggingFaceEmbeddings
+
+
+class EmbeddingModel:
+    instance = None
+
+    @staticmethod
+    def get_embedding_model():
+        """
+        Get the embedding model (lazily created singleton)
+        :return: HuggingFaceEmbeddings instance
+        """
+        if EmbeddingModel.instance is None:
+            model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
+            cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
+            device = CONFIG.get('EMBEDDING_DEVICE')
+            e = HuggingFaceEmbeddings(
+                model_name=model_name,
+                cache_folder=cache_folder,
+                model_kwargs={'device': device})
+            EmbeddingModel.instance = e
+        return EmbeddingModel.instance
+
+
+class VectorStore:
+    from embedding.vector.pg_vector import PGVector
+    from embedding.vector.base_vector import BaseVectorStore
+    instance_map = {
+        'pg_vector': PGVector,
+    }
+    instance = None
+
+    @staticmethod
+    def get_embedding_vector() -> BaseVectorStore:
+        from embedding.vector.pg_vector import PGVector
+        if VectorStore.instance is None:
+            from smartdoc.const import CONFIG
+            vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
+                                                              PGVector)
+            VectorStore.instance = vector_store_class()
+        return VectorStore.instance
diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py
index 8ee8d9230..f948f9f42 100644
--- a/apps/common/constants/permission_constants.py
+++ b/apps/common/constants/permission_constants.py
@@ -1,4 +1,3 @@
-# coding=utf-8
 """
     @project: qabot
     @Author:虎
diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
new file mode 100644
index 000000000..6b3c85af6
--- /dev/null
+++ b/apps/common/event/listener_manage.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: listener_manage.py
+    @date:2023/10/20 14:01
+    @desc:
+"""
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import django.db.models
+from blinker import signal
+from django.db.models import QuerySet
+
+from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.db.search import native_search, get_dynamics_model
+from common.util.file_util import get_file_content
+from dataset.models import Paragraph, Status, Document
+from embedding.models import SourceType
+from smartdoc.conf import PROJECT_DIR
+
+
+def poxy(poxy_function):
+    # Run the wrapped listener asynchronously on the shared thread pool
+    def inner(args):
+        ListenerManagement.work_thread_pool.submit(poxy_function, args)
+
+    return inner
+
+
+class ListenerManagement:
+    work_thread_pool = ThreadPoolExecutor(5)
+
+    embedding_by_problem_signal = signal("embedding_by_problem")
+    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
+    embedding_by_dataset_signal = signal("embedding_by_dataset")
+    embedding_by_document_signal = signal("embedding_by_document")
+    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
+    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
+    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
+    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
+    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
+    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
+    init_embedding_model_signal = signal('init_embedding_model')
+
+    @staticmethod
+    def embedding_by_problem(args):
+        VectorStore.get_embedding_vector().save(**args)
+
+    @staticmethod
+    @poxy
+    def embedding_by_paragraph(paragraph_id):
+        """
+        Embed a paragraph by its id
+        :param paragraph_id: paragraph id
+        :return: None
+        """
+        data_list = native_search(
+            {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
+                **{'problem.paragraph_id': paragraph_id}),
+                'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
+            select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+        # Batch embedding
+        VectorStore.get_embedding_vector().batch_save(data_list)
+        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})
+
+    @staticmethod
+    @poxy
+    def embedding_by_document(document_id):
+        """
+        Embed a document by its id
+        :param document_id: document id
+        :return: None
+        """
+        data_list = native_search(
+            {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
+                **{'problem.document_id': document_id}),
+                'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
+            select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+        # Batch embedding
+        VectorStore.get_embedding_vector().batch_save(data_list)
+        # Update status
+        QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
+        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})
+
+    @staticmethod
+    @poxy
+    def embedding_by_dataset(dataset_id):
+        """
+        Embed a dataset by its id
+        :param dataset_id: dataset id
+        :return: None
+        """
+        data_list = native_search(
+            {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
+                **{'problem.dataset_id': dataset_id}),
+                'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
+            select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+        # Batch embedding
+        VectorStore.get_embedding_vector().batch_save(data_list)
+        # Update document and paragraph status
+        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
+        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
+
+    @staticmethod
+    def delete_embedding_by_document(document_id):
+        VectorStore.get_embedding_vector().delete_by_document_id(document_id)
+
+    @staticmethod
+    def delete_embedding_by_dataset(dataset_id):
+        VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)
+
+    @staticmethod
+    def delete_embedding_by_paragraph(paragraph_id):
+        VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
+
+    @staticmethod
+    def delete_embedding_by_source(source_id):
+        VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM)
+
+    @staticmethod
+    def disable_embedding_by_paragraph(paragraph_id):
+        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False})
+
+    @staticmethod
+    def enable_embedding_by_paragraph(paragraph_id):
+        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
+
+    @staticmethod
+    @poxy
+    def init_embedding_model(args):
+        EmbeddingModel.get_embedding_model()
+
+    def run(self):
+        # Add vectors by problem id
+        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
+        # Add vectors by paragraph id
+        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
+        # Add vectors by dataset id
+        ListenerManagement.embedding_by_dataset_signal.connect(
+            self.embedding_by_dataset)
+        # Add vectors by document id
+        ListenerManagement.embedding_by_document_signal.connect(
+            self.embedding_by_document)
+        # Delete vectors by document id
+        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
+        # Delete vectors by dataset id
+        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
+        # Delete vectors by paragraph id
+        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
+            self.delete_embedding_by_paragraph)
+        # Delete vectors by source id
+        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
+        # Disable a paragraph's vectors
+        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
+        # Enable a paragraph's vectors
+        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
+        # Initialize the embedding model
+        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
diff --git a/apps/common/handle/handle_exception.py b/apps/common/handle/handle_exception.py
index 5bab9e3f8..e12befb6e 100644
--- a/apps/common/handle/handle_exception.py
+++ b/apps/common/handle/handle_exception.py
@@ -14,7 +14,7 @@ from rest_framework.views import exception_handler
 
 from common.exception.app_exception import AppApiException
 from common.response import result
-
+import traceback
 
 def to_result(key, args, parent_key=None):
     """
     Convert validation-exception args into the unified result format
@@ -59,6 +59,7 @@ def handle_exception(exc, context):
     exception_class = exc.__class__
     # First call DRF's default exception handler to get the standard error response
     response = exception_handler(exc, context)
+    traceback.print_exc()
     # Add custom exception handling here
     if issubclass(exception_class, ValidationError):
         return validation_error_to_result(exc)
diff --git a/apps/common/response/result.py b/apps/common/response/result.py
index 1b205069b..8e756aec7 100644
--- a/apps/common/response/result.py
+++ b/apps/common/response/result.py
@@ -70,7 +70,9 @@ def get_page_api_response(response_data_schema: openapi.Schema):
                 title="总条数",
                 default=1,
                 description="数据总条数"),
-            "records": response_data_schema,
+            "records": openapi.Schema(
+                type=openapi.TYPE_ARRAY,
+                items=response_data_schema),
             "current": openapi.Schema(
                 type=openapi.TYPE_INTEGER,
                 title="当前页",
@@ -115,6 +117,36 @@ def get_api_response(response_data_schema: openapi.Schema):
                                       )})
 
 
+def get_default_response():
+    return get_api_response(openapi.Schema(type=openapi.TYPE_BOOLEAN))
+
+
+def get_api_array_response(response_data_schema: openapi.Schema):
+    """
+    Unified array response schema for the API docs
+    """
+    return openapi.Responses(responses={200: openapi.Response(description="响应参数",
+                                                              schema=openapi.Schema(
+                                                                  type=openapi.TYPE_OBJECT,
+                                                                  properties={
+                                                                      'code': openapi.Schema(
+                                                                          type=openapi.TYPE_INTEGER,
+                                                                          title="响应码",
+                                                                          default=200,
+                                                                          description="成功:200 失败:其他"),
+                                                                      "message": openapi.Schema(
+                                                                          type=openapi.TYPE_STRING,
+                                                                          title="提示",
+                                                                          default='成功',
+                                                                          description="错误提示"),
+                                                                      "data": openapi.Schema(type=openapi.TYPE_ARRAY,
+                                                                                             items=response_data_schema)
+
+                                                                  }
+                                                              ),
+                                                              )})
+
+
 def success(data):
     """
     Get a success response object
description="错误提示"), + "data": openapi.Schema(type=openapi.TYPE_ARRAY, + items=response_data_schema) + + } + ), + )}) + + def success(data): """ 获取一个成功的响应对象 diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql new file mode 100644 index 000000000..2c55697ec --- /dev/null +++ b/apps/common/sql/list_embedding_text.sql @@ -0,0 +1,26 @@ +SELECT + problem."id" AS "source_id", + problem.document_id AS document_id, + problem.paragraph_id AS paragraph_id, + problem.dataset_id AS dataset_id, + 0 AS source_type, + problem."content" AS "text", + paragraph.is_active AS is_active +FROM + problem problem + LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id + ${problem} + +UNION +SELECT + paragraph."id" AS "source_id", + paragraph.document_id AS document_id, + paragraph."id" AS paragraph_id, + paragraph.dataset_id AS dataset_id, + 1 AS source_type, + paragraph."content" AS "text", + paragraph.is_active AS is_active +FROM + paragraph paragraph + + ${paragraph} \ No newline at end of file diff --git a/apps/common/util/common.py b/apps/common/util/common.py new file mode 100644 index 000000000..0fce2affd --- /dev/null +++ b/apps/common/util/common.py @@ -0,0 +1,19 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/10/16 16:42 + @desc: +""" +from functools import reduce +from typing import Dict + + +def query_params_to_single_dict(query_params: Dict): + return reduce(lambda x, y: {**x, y[0]: y[1]}, list(filter(lambda row: row[1] is not None, + list(map(lambda row: ( + row[0], row[1][0] if isinstance(row[1][0], + list) and len( + row[1][0]) > 0 else row[1][0]), + query_params.items())))), {}) diff --git a/apps/common/util/split_model.py b/apps/common/util/split_model.py index 11047be78..7f391468c 100644 --- a/apps/common/util/split_model.py +++ b/apps/common/util/split_model.py @@ -7,6 +7,7 @@ @desc: """ import re +from functools import reduce from typing import List import jieba @@ -25,7 +26,7 @@ def get_level_block(text, level_content_list, level_content_index): level_content_list) else None start_index = text.index(start_content) end_index = text.index(next_content) if next_content is not None else len(text) - return text[start_index:end_index] + return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "") def to_tree_obj(content, state='title'): @@ -88,7 +89,7 @@ def to_paragraph(obj: dict): content = obj['content'] return {"keywords": get_keyword(content), 'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])), - 'content': content} + 'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content} def get_keyword(content: str): @@ -109,13 +110,15 @@ def titles_to_paragraph(list_title: List[dict]): :return: 块段落 """ if len(list_title) > 0: - content = "\n".join( + content = "\n,".join( list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title))) return {'keywords': '', 'parent_chain': list( map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])), - 'content': content} + 'content': ",".join(list( + map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), + list_title[0]['parent_chain']))) + content} return None @@ -144,6 +147,15 @@ def to_block_paragraph(tree_data_list: List[dict]): return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict)) +def parse_title_level(text, content_level_pattern: List, index): + if 
diff --git a/apps/common/util/split_model.py b/apps/common/util/split_model.py
index 11047be78..7f391468c 100644
--- a/apps/common/util/split_model.py
+++ b/apps/common/util/split_model.py
@@ -7,6 +7,7 @@
     @desc:
 """
 import re
+from functools import reduce
 from typing import List
 
 import jieba
@@ -25,7 +26,7 @@ def get_level_block(text, level_content_list, level_content_index):
         level_content_list) else None
     start_index = text.index(start_content)
     end_index = text.index(next_content) if next_content is not None else len(text)
-    return text[start_index:end_index]
+    return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "")
 
 
 def to_tree_obj(content, state='title'):
@@ -88,7 +89,7 @@ def to_paragraph(obj: dict):
     content = obj['content']
     return {"keywords": get_keyword(content),
             'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
-            'content': content}
+            'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
 
 
 def get_keyword(content: str):
@@ -109,13 +110,15 @@ def titles_to_paragraph(list_title: List[dict]):
     :return: block paragraph
     """
     if len(list_title) > 0:
-        content = "\n".join(
+        content = "\n,".join(
             list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))
         return {'keywords': '',
                 'parent_chain': list(
                     map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
-                'content': content}
+                'content': ",".join(list(
+                    map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
+                        list_title[0]['parent_chain']))) + content}
     return None
 
 
@@ -144,6 +147,15 @@ def to_block_paragraph(tree_data_list: List[dict]):
     return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
 
 
+def parse_title_level(text, content_level_pattern: List, index):
+    if len(content_level_pattern) == index:
+        return []
+    result = parse_level(text, content_level_pattern[index])
+    if len(result) == 0 and len(content_level_pattern) > index + 1:
+        return parse_title_level(text, content_level_pattern, index + 1)
+    return result
+
+
@@ -151,10 +163,17 @@ def parse_level(text, pattern: str):
     """
     Get the text matched by the regex
     :param text:    text
    :param pattern: regex
     :return: matched fragments
     """
-    level_content_list = list(map(to_tree_obj, re.findall(pattern, text, flags=0)))
+    level_content_list = list(map(to_tree_obj, re_findall(pattern, text)))
     return list(map(filter_special_symbol, level_content_list))
 
 
+def re_findall(pattern, text):
+    result = re.findall(pattern, text, flags=0)
+    return list(filter(lambda r: r is not None and len(r) > 0, reduce(lambda x, y: [*x, *y], list(
+        map(lambda row: [*(row if isinstance(row, tuple) else [row])], result)),
+        [])))
+
+
 def to_flat_obj(parent_chain: List[dict], content: str, state: str):
     """
     Convert tree attributes into a flat object
@@ -194,10 +213,79 @@ def group_by(list_source: List, key):
     return result
 
 
+def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain):
+    """
+    Convert a parse tree into paragraph objects
+    :param result_tree:  parse tree of the text
+    :param result:       pass [] (used by the recursion to collect output)
+    :param parent_chain: pass [] (used by the recursion to track ancestors)
+    :return: List[{'problem':'xx','content':'xx'}]
+    """
+    for item in result_tree:
+        if item.get('state') == 'block':
+            result.append({'title': " ".join(parent_chain), 'content': item.get("content")})
+        children = item.get("children")
+        if children is not None and len(children) > 0:
+            result_tree_to_paragraph(children, result, [*parent_chain, item.get('content')])
+    return result
+
+
+def post_handler_paragraph(content: str, limit: int, with_filter: bool):
+    """
+    Split text by a maximum number of characters
+    :param with_filter: whether to strip special characters
+    :param content:     text to split
+    :param limit:       maximum characters per chunk
+    :return: list of chunks
+    """
+    split_list = content.split('\n')
+    result = []
+    temp_char = ''
+    for split in split_list:
+        if len(temp_char + split) > limit:
+            result.append(temp_char)
+            temp_char = ''
+        temp_char = temp_char + split
+    if len(temp_char) > 0:
+        result.append(temp_char)
+    pattern = "[\\S\\s]{1," + str(limit) + '}'
+    # If a single segment still exceeds the limit, keep splitting it
+    s = list(map(lambda row: filter_special_char(row) if with_filter else row, list(
+        reduce(lambda x, y: [*x, *y], list(map(lambda row: list(re.findall(pattern, row)), result)), []))))
+    return s
+
+
+replace_map = {
+    re.compile('\n+'): '\n',
+    re.compile('\\s+'): ' ',
+    re.compile('#+'): "",
+    re.compile("\t+"): ''
+}
+
+
+def filter_special_char(content: str):
+    """
+    Strip special characters
+    :param content: text
+    :return: filtered text
+    """
+    items = replace_map.items()
+    for key, value in items:
+        content = re.sub(key, value, content)
+    return content
+
+
 class SplitModel:
 
-    def __init__(self, content_level_pattern):
+    def __init__(self, content_level_pattern, with_filter=True, limit=1024):
         self.content_level_pattern = content_level_pattern
+        self.with_filter = with_filter
+        if limit is None or limit > 1024:
+            limit = 1024
+        if limit < 50:
+            limit = 50
+        self.limit = limit
 
     def parse_to_tree(self, text: str, index=0):
         """
@@ -208,23 +296,27 @@ class SplitModel:
         """
         if len(self.content_level_pattern) == index:
             return
-        level_content_list = parse_level(text, pattern=self.content_level_pattern[index])
+        level_content_list = parse_title_level(text, self.content_level_pattern, index)
         for i in range(len(level_content_list)):
             block = get_level_block(text, level_content_list, i)
-            children = self.parse_to_tree(text=block.replace(level_content_list[i]['content'][:-1], ""),
+            children = self.parse_to_tree(text=block,
                                           index=index + 1)
             if children is not None and len(children) > 0:
                 level_content_list[i]['children'] = children
             else:
                 if len(block) > 0:
-                    level_content_list[i]['children'] = [to_tree_obj(block, 'block')]
+                    level_content_list[i]['children'] = list(
+                        map(lambda row: to_tree_obj(row, 'block'),
+                            post_handler_paragraph(block, with_filter=self.with_filter, limit=self.limit)))
         if len(level_content_list) > 0:
             end_index = text.index(level_content_list[0].get('content'))
             if end_index == 0:
                 return level_content_list
             other_content = text[0:end_index]
             if len(other_content.strip()) > 0:
-                level_content_list.append(to_tree_obj(other_content, 'block'))
+                level_content_list = [*level_content_list, *list(
+                    map(lambda row: to_tree_obj(row, 'block'),
+                        post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))]
         return level_content_list
 
     def parse(self, text: str):
@@ -234,4 +326,35 @@
        :return: parsed data {content: paragraph data, keywords: ['paragraph keywords'], parent_chain: ['ancestor chain']}
         """
         result_tree = self.parse_to_tree(text, 0)
-        return flat_map(to_block_paragraph(result_tree))
+        return result_tree_to_paragraph(result_tree, [], [])
+
+
+split_model_map = {
+    'md': SplitModel(
+        [re.compile("^# .*"), re.compile('(? 0 else None
-        # Insert paragraphs
-        QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
-        return True
+        for document in self.data.get('documents') if 'documents' in self.data else []:
+            DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
+                                                                             with_embedding=False)
+        ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
+        return {**DataSetSerializers(dataset).data,
+                'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
+
+    @staticmethod
+    def get_response_body_api():
+        return openapi.Schema(
+            type=openapi.TYPE_OBJECT,
+            required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
+                      'update_time', 'create_time', 'document_list'],
+            properties={
+                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
+                                     description="id", default="xx"),
+                'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
+                                       description="名称", default="测试数据集"),
+                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
+                                       description="描述", default="测试数据集描述"),
+                'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
+                                          description="所属用户id", default="user_xxxx"),
+                'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
+                                              description="字符数", default=10),
+                'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
+                                                 description="文档数量", default=1),
+                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
+                                              description="修改时间",
+                                              default="1970-01-01 00:00:00"),
+                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
+                                              description="创建时间",
+                                              default="1970-01-01 00:00:00"
+                                              ),
+                'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
+                                                description="文档列表",
+                                                items=DocumentSerializers.Operate.get_response_body_api())
+            }
+        )
 
     @staticmethod
     def get_request_body_api():
@@ -200,7 +216,7 @@ class DataSetSerializers(serializers.ModelSerializer):
                 'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"),
                 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"),
                 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
-                                            items=CreateDocumentSerializers().get_request_body_api()
+                                            items=DocumentSerializers().Create.get_request_body_api()
                                             )
             }
         )
@@ -217,10 +233,11 @@ class DataSetSerializers(serializers.ModelSerializer):
         def delete(self):
             self.is_valid()
             dataset = QuerySet(DataSet).get(id=self.data.get("id"))
-            document_list = QuerySet(Document).filter(dataset=dataset)
-            QuerySet(Paragraph).filter(document__in=document_list).delete()
-            document_list.delete()
+            QuerySet(Document).filter(dataset=dataset).delete()
+            QuerySet(Paragraph).filter(dataset=dataset).delete()
+            QuerySet(Problem).filter(dataset=dataset).delete()
             dataset.delete()
+            ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
             return True
 
         def one(self, user_id, with_valid=True):
@@ -303,9 +320,9 @@ class DataSetSerializers(serializers.ModelSerializer):
 
         @staticmethod
         def get_request_params_api():
-            return [openapi.Parameter(name='id',
+            return [openapi.Parameter(name='dataset_id',
                                       in_=openapi.IN_PATH,
                                       type=openapi.TYPE_STRING,
-                                      required=False,
+                                      required=True,
                                       description='数据集id')
                     ]
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 1f1e92b04..7da8ac8ca 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -6,20 +6,29 @@
     @date:2023/9/22 13:43
     @desc:
 """
+import os
 import uuid
 from functools import reduce
+from typing import List, Dict
 
 from django.core import validators
+from django.db import transaction
 from django.db.models import QuerySet
 from drf_yasg import openapi
 from rest_framework import serializers
 
+from common.db.search import native_search, native_page_search
+from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
-from dataset.models.data_set import DataSet, Document, Paragraph
+from common.util.file_util import get_file_content
+from common.util.split_model import SplitModel, get_split_model
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem
+from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
+from smartdoc.conf import PROJECT_DIR
 
 
-class CreateDocumentSerializers(ApiMixin, serializers.Serializer):
+class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
     name = serializers.CharField(required=True,
                                  validators=[
                                      validators.MaxLengthValidator(limit_value=128,
                                                                    message="数据集名称在1-128个字符之间"),
                                      validators.MinLengthValidator(limit_value=1,
                                                                    message="数据集名称在1-128个字符之间")
                                  ])
-    paragraphs = serializers.ListField(required=False,
-                                       child=serializers.CharField(required=True,
-                                                                   validators=[
-                                                                       validators.MaxLengthValidator(limit_value=256,
-                                                                                                     message="段落在1-256个字符之间"),
-                                                                       validators.MinLengthValidator(limit_value=1,
-                                                                                                     message="段落在1-256个字符之间")
-                                                                   ]))
+    paragraphs = ParagraphInstanceSerializer(required=False, many=True)
 
-    def is_valid(self, *, dataset_id=None, raise_exception=False):
-        if not QuerySet(DataSet).filter(id=dataset_id).exists():
-            raise AppApiException(10000, "数据集id不存在")
-        return super().is_valid(raise_exception=True)
-
-    def save(self, dataset_id: str, **kwargs):
-        document_model = Document(
-            **{'dataset': DataSet(id=dataset_id),
-               'id': uuid.uuid1(),
-               'name': self.data.get('name'),
-               'char_length': reduce(lambda x, y: x + y, list(map(lambda p: len(p), self.data.get("paragraphs"))), 0)})
-
-        paragraph_model_list = list(map(lambda p: Paragraph(
-            **{'document': document_model, 'id': uuid.uuid1(), 'content': p}),
-                                        self.data.get('paragraphs')))
-
-        # Insert the document
-        document_model.save()
-        # Insert paragraphs
-        QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
-        return True
-
-    def get_request_body_api(self):
+    @staticmethod
+    def get_request_body_api():
         return openapi.Schema(
             type=openapi.TYPE_OBJECT,
-            required=['name', 'paragraph'],
+            required=['name', 'paragraphs'],
             properties={
                 'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                 'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
-                                             items=openapi.Schema(type=openapi.TYPE_STRING, title="段落数据",
-                                                                  description="段落数据"))
+                                             items=ParagraphSerializers.Create.get_request_body_api())
             }
         )
 
-    def get_request_params_api(self):
-        return [openapi.Parameter(name='dataset_id',
-                                  in_=openapi.IN_PATH,
-                                  type=openapi.TYPE_STRING,
-                                  required=True,
-                                  description='数据集id')]
+
+class DocumentSerializers(ApiMixin, serializers.Serializer):
+    class Query(ApiMixin, serializers.Serializer):
+        # dataset id
+        dataset_id = serializers.UUIDField(required=True)
+
+        name = serializers.CharField(required=False,
+                                     validators=[
+                                         validators.MaxLengthValidator(limit_value=128,
+                                                                       message="文档名称在1-128个字符之间"),
+                                         validators.MinLengthValidator(limit_value=1,
+                                                                       message="数据集名称在1-128个字符之间")
+                                     ])
+
+        def get_query_set(self):
+            query_set = QuerySet(model=Document)
+            query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
+            if 'name' in self.data and self.data.get('name') is not None:
+                query_set = query_set.filter(**{'name__contains': self.data.get('name')})
+            return query_set
+
+        def list(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            query_set = self.get_query_set()
+            return native_search(query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
+
+        def page(self, current_page, page_size):
+            query_set = self.get_query_set()
+            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='name',
+                                      in_=openapi.IN_QUERY,
+                                      type=openapi.TYPE_STRING,
+                                      required=False,
+                                      description='文档名称')]
+
+        @staticmethod
+        def get_response_body_api():
+            return openapi.Schema(type=openapi.TYPE_ARRAY,
+                                  title="文档列表", description="文档列表",
+                                  items=DocumentSerializers.Operate.get_response_body_api())
+
+    class Operate(ApiMixin, serializers.Serializer):
+        document_id = serializers.UUIDField(required=True)
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id'),
+                    openapi.Parameter(name='document_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='文档id')
+                    ]
+
+        def is_valid(self, *, raise_exception=False):
+            super().is_valid(raise_exception=True)
+            document_id = self.data.get('document_id')
+            if not QuerySet(Document).filter(id=document_id).exists():
+                raise AppApiException(500, "文档id不存在")
+
+        def one(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            query_set = QuerySet(model=Document)
+            query_set = query_set.filter(**{'id': self.data.get("document_id")})
+            return native_search(query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)
+
+        def edit(self, instance: Dict, with_valid=False):
+            if with_valid:
+                self.is_valid()
+            _document = QuerySet(Document).get(id=self.data.get("document_id"))
+            update_keys = ['name', 'is_active']
+            for update_key in update_keys:
+                if update_key in instance and instance.get(update_key) is not None:
+                    _document.__setattr__(update_key, instance.get(update_key))
+            _document.save()
+            return self.one()
+
+        @transaction.atomic
+        def delete(self):
+            document_id = self.data.get("document_id")
+            QuerySet(model=Document).filter(id=document_id).delete()
+            # Delete paragraphs
+            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
+            # Delete problems
+            QuerySet(model=Problem).filter(document_id=document_id).delete()
+            # Delete from the vector store
+            ListenerManagement.delete_embedding_by_document_signal.send(document_id)
+            return True
+
+        @staticmethod
+        def get_response_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active',
+                          'update_time', 'create_time'],
+                properties={
+                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
+                                         description="id", default="xx"),
+                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
+                                           description="名称", default="测试数据集"),
+                    'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
+                                                  description="字符数", default=10),
+                    'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
+                    'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
+                                                      description="文档数量", default=1),
+                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
+                                                description="是否可用", default=True),
+                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
+                                                  description="修改时间",
+                                                  default="1970-01-01 00:00:00"),
+                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
+                                                  description="创建时间",
+                                                  default="1970-01-01 00:00:00"
+                                                  )
+                }
+            )
+
+        @staticmethod
+        def get_request_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                properties={
+                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
+                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
+                }
+            )
+
+    class Create(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        def is_valid(self, *, raise_exception=False):
+            super().is_valid(raise_exception=True)
+            if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
+                raise AppApiException(10000, "数据集id不存在")
+            return True
+
+        def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
+            if with_valid:
+                DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            dataset_id = self.data.get('dataset_id')
+
+            document_model = Document(
+                **{'dataset_id': dataset_id,
+                   'id': uuid.uuid1(),
+                   'name': instance.get('name'),
+                   'char_length': reduce(lambda x, y: x + y,
+                                         [len(p.get('content')) for p in instance.get('paragraphs', [])],
+                                         0)})
+            # Insert the document first so the paragraph rows can reference it
+            document_model.save()
+            for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
+                ParagraphSerializers.Create(
+                    data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
+                                                                                                 with_valid=True,
+                                                                                                 with_embedding=False)
+            if with_embedding:
+                ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
+            return DocumentSerializers.Operate(
+                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
+                with_valid=True)
+
+        @staticmethod
+        def get_request_body_api():
+            return DocumentInstanceSerializer.get_request_body_api()
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id')
+                    ]
+
+    class Split(ApiMixin, serializers.Serializer):
+        file = serializers.ListField(required=True)
+
+        limit = serializers.IntegerField(required=False)
+
+        patterns = serializers.ListField(required=False,
+                                         child=serializers.CharField(required=True))
+
+        with_filter = serializers.BooleanField(required=False)
+
+        def is_valid(self, *, raise_exception=True):
+            super().is_valid()
+            files = self.data.get('file')
+            for f in files:
+                if f.size > 1024 * 1024 * 10:
+                    raise AppApiException(500, "上传文件最大不能超过10m")
+
+        @staticmethod
+        def get_request_params_api():
+            return [
+                openapi.Parameter(name='file',
+                                  in_=openapi.IN_FORM,
+                                  type=openapi.TYPE_ARRAY,
+                                  items=openapi.Items(type=openapi.TYPE_FILE),
+                                  required=True,
+                                  description='上传文件'),
+                openapi.Parameter(name='limit',
+                                  in_=openapi.IN_FORM,
+                                  required=False,
+                                  type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
+                openapi.Parameter(name='patterns',
+                                  in_=openapi.IN_FORM,
+                                  required=False,
+                                  type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
+                                  title="分段正则列表", description="分段正则列表"),
+                openapi.Parameter(name='with_filter',
+                                  in_=openapi.IN_FORM,
+                                  required=False,
+                                  type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
+            ]
+
+        def parse(self):
+            file_list = self.data.get("file")
+            return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
+                                                        self.data.get("limit")), file_list))
+
+
+def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
+    data = file.read()
+    if pattern_list is not None and len(pattern_list) > 0:
+        split_model = SplitModel(pattern_list, with_filter, limit)
+    else:
+        split_model = get_split_model(file.name)
+    try:
+        content = data.decode('utf-8')
+    except BaseException as e:
+        return {'name': file.name,
+                'content': []}
+    return {'name': file.name,
+            'content': split_model.parse(content)
+            }
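Split.parse and Create.save together form a two-step import pipeline: the client first uploads raw files and gets back name/paragraphs structures, then posts one of those structures as a document, and embedding fires once per document at the end. A minimal sketch of the server-side flow (the ids and content are illustrative):

# Minimal sketch of the import flow, assuming an existing dataset.
from dataset.serializers.document_serializers import DocumentSerializers

dataset_id = 'an-existing-dataset-uuid'  # illustrative
parsed = {'name': 'intro.md',
          'paragraphs': [{'title': 'Overview', 'content': 'MaxKB is ...'}]}

# with_embedding=True (the default) fires embedding_by_document_signal after save.
document = DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
    parsed, with_valid=True, with_embedding=True)
print(document['id'], document['paragraph_count'])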
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
new file mode 100644
index 000000000..9e4e7867a
--- /dev/null
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -0,0 +1,278 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: paragraph_serializers.py
+    @date:2023/10/16 15:51
+    @desc:
+"""
+import uuid
+from typing import Dict
+
+from django.core import validators
+from django.db import transaction
+from django.db.models import QuerySet
+from drf_yasg import openapi
+from rest_framework import serializers
+
+from common.db.search import page_search
+from common.event.listener_manage import ListenerManagement
+from common.exception.app_exception import AppApiException
+from common.mixins.api_mixin import ApiMixin
+from dataset.models import Paragraph, Problem
+from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
+
+
+class ParagraphSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = Paragraph
+        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
+                  'create_time', 'update_time']
+
+
+class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
+    """
+    Paragraph instance object
+    """
+    content = serializers.CharField(required=True, validators=[
+        validators.MaxLengthValidator(limit_value=1024,
+                                      message="段落在1-1024个字符之间"),
+        validators.MinLengthValidator(limit_value=1,
+                                      message="段落在1-1024个字符之间")
+    ])
+
+    title = serializers.CharField(required=False)
+
+    problem_list = ProblemInstanceSerializer(required=False, many=True)
+
+    is_active = serializers.BooleanField(required=False)
+
+    @staticmethod
+    def get_request_body_api():
+        return openapi.Schema(
+            type=openapi.TYPE_OBJECT,
+            required=['content'],
+            properties={
+                'content': openapi.Schema(type=openapi.TYPE_STRING, title="分段内容", description="分段内容"),
+
+                'title': openapi.Schema(type=openapi.TYPE_STRING, title="分段标题",
+                                        description="分段标题"),
+
+                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
+
+                'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
+                                               description="问题列表",
+                                               items=ProblemInstanceSerializer.get_request_body_api())
+            }
+        )
+
+
+class ParagraphSerializers(ApiMixin, serializers.Serializer):
+    class Operate(ApiMixin, serializers.Serializer):
+        # paragraph id
+        paragraph_id = serializers.UUIDField(required=True)
+        # dataset id
+        dataset_id = serializers.UUIDField(required=True)
+        # document id
+        document_id = serializers.UUIDField(required=True)
+
+        def is_valid(self, *, raise_exception=True):
+            super().is_valid(raise_exception=True)
+            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
+                raise AppApiException(500, "段落id不存在")
+
+        @transaction.atomic
+        def edit(self, instance: Dict):
+            self.is_valid()
+            _paragraph = QuerySet(Paragraph).get(id=self.data.get("paragraph_id"))
+            update_keys = ['title', 'content', 'is_active']
+            for update_key in update_keys:
+                if update_key in instance and instance.get(update_key) is not None:
+                    _paragraph.__setattr__(update_key, instance.get(update_key))
+
+            if 'problem_list' in instance:
+                update_problem_list = list(
+                    filter(lambda row: 'id' in row and row.get('id') is not None, instance.get('problem_list')))
+
+                create_problem_list = list(filter(lambda row: row.get('id') is None, instance.get('problem_list')))
+
+                # Existing problems for this paragraph
+                problem_list = QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))
+
+                # Validate the ids sent by the client
+                for update_problem in update_problem_list:
+                    if not set([str(row.id) for row in problem_list]).__contains__(update_problem.get('id')):
+                        raise AppApiException(500, update_problem.get('id') + '问题id不存在')
+                # Determine which problems must be deleted
+                delete_problem_list = list(filter(
+                    lambda row: not [str(update_row.get('id')) for update_row in update_problem_list].__contains__(
+                        str(row.id)), problem_list)) if len(update_problem_list) > 0 else []
+                # Delete problems
+                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete() if len(
+                    delete_problem_list) > 0 else None
+                # Insert new problems
+                QuerySet(Problem).bulk_create(
+                    [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=self.data.get('paragraph_id'),
+                             dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id')) for
+                     p in create_problem_list]) if len(create_problem_list) else None
+
+                # Update existing problems
+                QuerySet(Problem).bulk_update(
+                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
+                    ['content']) if len(
+                    update_problem_list) > 0 else None
+
+            _paragraph.save()
+            if 'is_active' in instance and instance.get('is_active') is not None:
+                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
+                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
+                s.send(self.data.get('paragraph_id'))
+            return self.one()
+
+        def get_problem_list(self):
+            return [ProblemSerializer(problem).data for problem in
+                    QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]
+
+        def one(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            return {**ParagraphSerializer(QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))).data,
+                    'problem_list': self.get_problem_list()}
+
+        def delete(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            paragraph_id = self.data.get('paragraph_id')
+            QuerySet(Paragraph).filter(id=paragraph_id).delete()
+            QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
+            ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
+
+        @staticmethod
+        def get_request_body_api():
+            return ParagraphInstanceSerializer.get_request_body_api()
+
+        @staticmethod
+        def get_response_body_api():
+            return ParagraphInstanceSerializer.get_request_body_api()
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
+                                      description="段落id")]
+
+    class Create(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        document_id = serializers.UUIDField(required=True)
+
+        def save(self, instance: Dict, with_valid=True, with_embedding=True):
+            if with_valid:
+                ParagraphInstanceSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            dataset_id = self.data.get("dataset_id")
+            document_id = self.data.get('document_id')
+
+            paragraph = Paragraph(id=uuid.uuid1(),
+                                  document_id=document_id,
+                                  content=instance.get("content"),
+                                  dataset_id=dataset_id,
+                                  title=instance.get("title") if 'title' in instance else '')
+            # Insert the paragraph
+            paragraph.save()
+            problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
+                                          document_id=document_id, dataset_id=dataset_id) for problem in (
+                instance.get('problem_list') if 'problem_list' in instance else [])]
+            # Insert problems
+            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
+            if with_embedding:
+                ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
+            return ParagraphSerializers.Operate(
+                data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
+                with_valid=True)
+
+        @staticmethod
+        def get_request_body_api():
+            return ParagraphInstanceSerializer.get_request_body_api()
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id'),
+                    openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description="文档id")
+                    ]
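ParagraphSerializers.Operate.edit reconciles problem_list in one atomic pass: items carrying an id must already exist and are updated, items without an id are created, and existing problems absent from the payload are deleted. A minimal sketch of a payload exercising all three branches (ids are illustrative):

# Illustrative edit payload: one update, one create; any other existing
# problem on the paragraph is removed by the reconciliation step.
from dataset.serializers.paragraph_serializers import ParagraphSerializers

payload = {
    'content': 'Updated paragraph text',
    'is_active': True,  # also fires enable_embedding_by_paragraph_signal
    'problem_list': [
        {'id': 'existing-problem-uuid', 'content': 'Reworded question'},  # update
        {'content': 'Brand new question'},                                # create
    ],
}
ParagraphSerializers.Operate(data={
    'dataset_id': 'dataset-uuid', 'document_id': 'document-uuid',
    'paragraph_id': 'paragraph-uuid',  # illustrative ids
}).edit(payload)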
+    class Query(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        document_id = serializers.UUIDField(required=True)
+
+        title = serializers.CharField(required=False)
+
+        def get_query_set(self):
+            query_set = QuerySet(model=Paragraph)
+            query_set = query_set.filter(
+                **{'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get("document_id")})
+            if 'title' in self.data:
+                query_set = query_set.filter(
+                    **{'title__contains': self.data.get('title')})
+            return query_set
+
+        def list(self):
+            return list(map(lambda row: ParagraphSerializer(row).data, self.get_query_set()))
+
+        def page(self, current_page, page_size):
+            query_set = self.get_query_set()
+            return page_search(current_page, page_size, query_set, lambda row: ParagraphSerializer(row).data)
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='document_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='文档id'),
+                    openapi.Parameter(name='title',
+                                      in_=openapi.IN_QUERY,
+                                      type=openapi.TYPE_STRING,
+                                      required=False,
+                                      description='标题')
+                    ]
+
+        @staticmethod
+        def get_response_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
+                          'create_time', 'update_time'],
+                properties={
+                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
+                                         description="id", default="xx"),
+                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
+                                              description="段落内容", default='段落内容'),
+                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
+                                            description="标题", default="xxx的描述"),
+                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
+                                              default=1),
+                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
+                                               description="点赞数量", default=1),
+                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
+                                                  description="点踩数", default=1),
+                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
+                                                  description="文档id", default='xxx'),
+                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
+                                                description="是否可用", default=True),
+                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
+                                                  description="修改时间",
+                                                  default="1970-01-01 00:00:00"),
+                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
+                                                  description="创建时间",
+                                                  default="1970-01-01 00:00:00"
+                                                  )
+                }
+            )
diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py
new file mode 100644
index 000000000..16c9ff1fa
--- /dev/null
+++ b/apps/dataset/serializers/problem_serializers.py
@@ -0,0 +1,222 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: problem_serializers.py
+    @date:2023/10/23 13:55
+    @desc:
+"""
+import uuid
+from typing import Dict
+
+from django.db.models import QuerySet
+from drf_yasg import openapi
+from rest_framework import serializers
+
+from common.event.listener_manage import ListenerManagement
+from common.exception.app_exception import AppApiException
+from common.mixins.api_mixin import ApiMixin
+from dataset.models import Problem, Paragraph
+from embedding.models import SourceType
+
+
+class ProblemSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = Problem
+        fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
+                  'create_time', 'update_time']
+
+
+class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
+    id = serializers.CharField(required=False)
+
+    content = serializers.CharField(required=True)
+
+    @staticmethod
+    def get_request_body_api():
+        return openapi.Schema(type=openapi.TYPE_OBJECT,
+                              required=["content"],
+                              properties={
+                                  'id': openapi.Schema(
+                                      type=openapi.TYPE_STRING,
+                                      title="问题id,修改的时候传递,创建的时候不传"),
+                                  'content': openapi.Schema(
+                                      type=openapi.TYPE_STRING, title="内容")
+                              })
+
+
+class ProblemSerializers(ApiMixin, serializers.Serializer):
+    class Create(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        document_id = serializers.UUIDField(required=True)
+
+        paragraph_id = serializers.UUIDField(required=True)
+
+        def save(self, instance: Dict, with_valid=True, with_embedding=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
+            problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
+                              document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
+                              content=instance.get('content'))
+            problem.save()
+            if with_embedding:
+                ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
+                                                                     'is_active': True,
+                                                                     'source_type': SourceType.PROBLEM,
+                                                                     'source_id': problem.id,
+                                                                     'document_id': self.data.get('document_id'),
+                                                                     'paragraph_id': self.data.get('paragraph_id'),
+                                                                     'dataset_id': self.data.get('dataset_id')})
+
+            return ProblemSerializers.Operate(
+                data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
+                      'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
+
+        @staticmethod
+        def get_request_body_api():
+            return ProblemInstanceSerializer.get_request_body_api()
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id'),
+                    openapi.Parameter(name='document_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='文档id'),
+                    openapi.Parameter(name='paragraph_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='段落id')]
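ProblemSerializers.Create.save spells out the embedding record contract: the payload sent on embedding_by_problem_signal mirrors the columns produced by list_embedding_text.sql (source id and type, text, the owning paragraph/document/dataset, and an active flag). A minimal sketch of the payload shape (ids and content are illustrative):

# Minimal sketch: the record shape consumed by ListenerManagement.embedding_by_problem,
# which forwards it to VectorStore.get_embedding_vector().save(**args).
from embedding.models import SourceType

embedding_record = {
    'text': 'How do I reset my password?',  # illustrative content
    'is_active': True,
    'source_type': SourceType.PROBLEM,      # 0 in list_embedding_text.sql
    'source_id': 'problem-uuid',            # illustrative ids below
    'document_id': 'document-uuid',
    'paragraph_id': 'paragraph-uuid',
    'dataset_id': 'dataset-uuid',
}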
+    class Query(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        document_id = serializers.UUIDField(required=True)
+
+        paragraph_id = serializers.UUIDField(required=True)
+
+        def is_valid(self, *, raise_exception=True):
+            super().is_valid(raise_exception=True)
+            if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
+                raise AppApiException(500, "段落id不存在")
+
+        def get_query_set(self):
+            dataset_id = self.data.get('dataset_id')
+            document_id = self.data.get('document_id')
+            paragraph_id = self.data.get("paragraph_id")
+            return QuerySet(Problem).filter(
+                **{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
+
+        def list(self, with_valid=False):
+            """
+            List the problems of a paragraph
+            :param with_valid: whether to validate first
+            :return: problem list
+            """
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            query_set = self.get_query_set()
+            return [ProblemSerializer(p).data for p in query_set]
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id'),
+                    openapi.Parameter(name='document_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='文档id'),
+                    openapi.Parameter(name='paragraph_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='段落id')]
+
+    class Operate(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True)
+
+        document_id = serializers.UUIDField(required=True)
+
+        paragraph_id = serializers.UUIDField(required=True)
+
+        problem_id = serializers.UUIDField(required=True)
+
+        def delete(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
+            ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
+            return True
+
+        def one(self, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            return ProblemSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
+
+        @staticmethod
+        def get_request_params_api():
+            return [openapi.Parameter(name='dataset_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='数据集id'),
+                    openapi.Parameter(name='document_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='文档id'),
+                    openapi.Parameter(name='paragraph_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='段落id'),
+                    openapi.Parameter(name='problem_id',
+                                      in_=openapi.IN_PATH,
+                                      type=openapi.TYPE_STRING,
+                                      required=True,
+                                      description='问题id')
+                    ]
+
+        @staticmethod
+        def get_response_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
+                          'document_id',
+                          'create_time', 'update_time'],
+                properties={
+                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
+                                         description="id", default="xx"),
+                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
+                                              description="问题内容", default='问题内容'),
+                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
+                                              default=1),
+                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
+                                               description="点赞数量", default=1),
+                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
+                                                  description="点踩数", default=1),
+                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
+                                                  description="文档id", default='xxx'),
+                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
+                                                  description="修改时间",
+                                                  default="1970-01-01 00:00:00"),
+                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
+                                                  description="创建时间",
+                                                  default="1970-01-01 00:00:00"
+                                                  )
+                }
+            )
diff --git a/apps/dataset/sql/list_document.sql b/apps/dataset/sql/list_document.sql
new file mode 100644
index 000000000..7d8d6968f
--- /dev/null
+++ b/apps/dataset/sql/list_document.sql
@@ -0,0 +1,5 @@
+SELECT
+    "document".* ,
+    (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
+FROM
+    "document" "document"
diff --git a/apps/dataset/sql/list_problem.sql b/apps/dataset/sql/list_problem.sql
new file mode 100644
index 000000000..3183f0da8
--- /dev/null
+++ b/apps/dataset/sql/list_problem.sql
@@ -0,0 +1,10 @@
+SELECT
+    problem."id",
+    problem."content",
+    problem_paragraph_mapping.hit_num,
+    problem_paragraph_mapping.star_num,
+    problem_paragraph_mapping.trample_num,
+    problem_paragraph_mapping.paragraph_id
+FROM
+    problem problem
+    LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id
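The url patterns below nest four resource levels, and each angle-bracket converter has to match a keyword argument on the view method it routes to (dataset_id, document_id, paragraph_id, problem_id, plus the integer paging pair). A minimal sketch of how one nested route maps onto a view signature:

# Minimal sketch: Django passes each converter segment as a keyword argument,
# so the paragraph detail route must expose exactly these parameter names.
from django.urls import path
from rest_framework.views import APIView


class ParagraphOperateSketch(APIView):  # stand-in for views.Paragraph.Operate
    def put(self, request, dataset_id: str, document_id: str, paragraph_id: str):
        ...


urlpatterns = [
    path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>',
         ParagraphOperateSketch.as_view()),
]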
path('dataset//document//paragraph//', + views.Paragraph.Page.as_view(), name='paragraph_page'), + path('dataset//document//paragraph/', + views.Paragraph.Operate.as_view()), + path('dataset//document//paragraph//problem', + views.Problem.as_view()), + path('dataset//document//paragraph//problem/', + views.Problem.Operate.as_view()) ] diff --git a/apps/dataset/views/__init__.py b/apps/dataset/views/__init__.py index 413cf8363..b82d9ef72 100644 --- a/apps/dataset/views/__init__.py +++ b/apps/dataset/views/__init__.py @@ -8,3 +8,5 @@ """ from .dataset import * from .document import * +from .paragraph import * +from .problem import * diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 3079c7bc2..175bedc40 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -26,7 +26,7 @@ class Dataset(APIView): @swagger_auto_schema(operation_summary="获取数据集列表", operation_id="获取数据集列表", manual_parameters=DataSetSerializers.Query.get_request_params_api(), - responses=get_api_response(DataSetSerializers.Query.get_response_body_api())) + responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api())) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) def get(self, request: Request): d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) @@ -36,19 +36,21 @@ class Dataset(APIView): @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="创建数据集", operation_id="创建数据集", - request_body=DataSetSerializers.Create.get_request_body_api()) + request_body=DataSetSerializers.Create.get_request_body_api(), + responses=get_api_response(DataSetSerializers.Create.get_response_body_api())) @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) def post(self, request: Request): s = DataSetSerializers.Create(data=request.data) - if s.is_valid(): - s.save(request.user) - return result.success("ok") + s.is_valid(raise_exception=True) + return result.success(s.save(request.user)) class Operate(APIView): authentication_classes = [TokenAuth] @action(methods="DELETE", detail=False) - @swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集") + @swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + responses=result.get_default_response()) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=keywords.get('dataset_id')), lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE, @@ -59,6 +61,7 @@ class Dataset(APIView): @action(methods="GET", detail=False) @swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) @@ -67,6 +70,7 @@ class Dataset(APIView): @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), request_body=DataSetSerializers.Operate.get_request_body_api(), responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @@ 
-84,8 +88,10 @@ class Dataset(APIView): manual_parameters=get_page_request_params( DataSetSerializers.Query.get_request_params_api()), responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) - @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) def get(self, request: Request, current_page, page_size): - d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) + d = DataSetSerializers.Query( + data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None), + 'user_id': str(request.user.id)}) d.is_valid() return result.success(d.page(current_page, page_size)) diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index 7f759e0ac..9354e7b25 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -9,13 +9,15 @@ from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser from rest_framework.views import APIView from rest_framework.views import Request from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants from common.response import result -from dataset.serializers.dataset_serializers import CreateDocumentSerializers +from common.util.common import query_params_to_single_dict +from dataset.serializers.document_serializers import DocumentSerializers class Document(APIView): @@ -24,28 +26,102 @@ class Document(APIView): @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="创建文档", operation_id="创建文档", - request_body=CreateDocumentSerializers().get_request_body_api(), - manual_parameters=CreateDocumentSerializers().get_request_params_api()) - @has_permissions(PermissionConstants.DATASET_CREATE) + request_body=DocumentSerializers.Create.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) def post(self, request: Request, dataset_id: str): - d = CreateDocumentSerializers(data=request.data) - if d.is_valid(dataset_id=dataset_id): - d.save(dataset_id) - return result.success("ok") - - -class DocumentDetails(APIView): - authentication_classes = [TokenAuth] + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(request.data, with_valid=True)) @action(methods=['GET'], detail=False) - @swagger_auto_schema(operation_summary="获取文档详情", - operation_id="获取文档详情", - request_body=CreateDocumentSerializers().get_request_body_api(), - manual_parameters=CreateDocumentSerializers().get_request_params_api()) + @swagger_auto_schema(operation_summary="文档列表", + operation_id="文档列表", + manual_parameters=DocumentSerializers.Query.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api())) @has_permissions( - lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) def get(self, request: Request, dataset_id: str): - d = CreateDocumentSerializers(data=request.data) - if 
-        if d.is_valid(dataset_id=dataset_id):
-            d.save(dataset_id)
-            return result.success("ok")
+        d = DocumentSerializers.Query(
+            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
+        d.is_valid(raise_exception=True)
+        return result.success(d.list())
+
+    class Operate(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="获取文档详情",
+                             operation_id="获取文档详情",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()))
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str, document_id: str):
+            operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
+            operate.is_valid(raise_exception=True)
+            return result.success(operate.one())
+
+        @action(methods=['PUT'], detail=False)
+        @swagger_auto_schema(operation_summary="修改文档",
+                             operation_id="修改文档",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             request_body=DocumentSerializers.Operate.get_request_body_api(),
+                             responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api())
+                             )
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str, document_id: str):
+            return result.success(
+                DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).edit(
+                    request.data,
+                    with_valid=True))
+
+        @action(methods=['DELETE'], detail=False)
+        @swagger_auto_schema(operation_summary="删除文档",
+                             operation_id="删除文档",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             responses=result.get_default_response())
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def delete(self, request: Request, dataset_id: str, document_id: str):
+            operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id})
+            operate.is_valid(raise_exception=True)
+            return result.success(operate.delete())
+
+    class Split(APIView):
+        parser_classes = [MultiPartParser]
+
+        @action(methods=['POST'], detail=False)
+        @swagger_auto_schema(operation_summary="分段文档",
+                             operation_id="分段文档",
+                             manual_parameters=DocumentSerializers.Split.get_request_params_api())
+        def post(self, request: Request):
+            ds = DocumentSerializers.Split(
+                data={'file': request.FILES.getlist('file'),
+                      'patterns': request.data.getlist('patterns[]')})
+            ds.is_valid(raise_exception=True)
+            return result.success(ds.parse())
+
+    class Page(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="获取文档分页列表",
+                             operation_id="获取文档分页列表",
+                             manual_parameters=DocumentSerializers.Query.get_request_params_api(),
+                             responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()))
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str, current_page, page_size):
+            d = DocumentSerializers.Query(
+                data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
+            d.is_valid(raise_exception=True)
+            return result.success(d.page(current_page, page_size))
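The new `Document.Split` view takes multipart uploads plus a list of split patterns and returns the parsed paragraphs without persisting anything, which lets the UI preview a split before creating documents. A hedged client-side sketch (host and URL prefix are assumptions; the real route lives in `apps/dataset/urls.py`):

```python
# Client-side sketch of calling the split endpoint; host and route are assumed,
# the 'file' and 'patterns[]' field names match what the view reads above.
import requests

with open('manual.md', 'rb') as f:
    resp = requests.post(
        'http://localhost:8080/api/dataset/document/split',   # assumed route
        files=[('file', f)],                         # parsed by MultiPartParser
        data={'patterns[]': [r'^#+\s', r'^\d+\.']},  # regex split patterns
    )
resp.raise_for_status()
print(resp.json())  # result.success(...) envelope wrapping the parsed paragraphs
```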
diff --git a/apps/dataset/views/paragraph.py b/apps/dataset/views/paragraph.py
new file mode 100644
index 000000000..37fe76fda
--- /dev/null
+++ b/apps/dataset/views/paragraph.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: paragraph.py
+    @date:2023/10/16 15:51
+    @desc:
+"""
+from drf_yasg.utils import swagger_auto_schema
+from rest_framework.decorators import action
+from rest_framework.views import APIView
+from rest_framework.views import Request
+
+from common.auth import TokenAuth, has_permissions
+from common.constants.permission_constants import Permission, Group, Operate
+from common.response import result
+from common.util.common import query_params_to_single_dict
+from dataset.serializers.paragraph_serializers import ParagraphSerializers
+
+
+class Paragraph(APIView):
+    authentication_classes = [TokenAuth]
+
+    @action(methods=['GET'], detail=False)
+    @swagger_auto_schema(operation_summary="段落列表",
+                         operation_id="段落列表",
+                         manual_parameters=ParagraphSerializers.Query.get_request_params_api(),
+                         responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()))
+    @has_permissions(
+        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                dynamic_tag=k.get('dataset_id')))
+    def get(self, request: Request, dataset_id: str, document_id: str):
+        q = ParagraphSerializers.Query(
+            data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
+                  'document_id': document_id})
+        q.is_valid(raise_exception=True)
+        return result.success(q.list())
+
+    @action(methods=['POST'], detail=False)
+    @swagger_auto_schema(operation_summary="创建段落",
+                         operation_id="创建段落",
+                         manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
+                         request_body=ParagraphSerializers.Create.get_request_body_api(),
+                         responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()))
+    @has_permissions(
+        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                dynamic_tag=k.get('dataset_id')))
+    def post(self, request: Request, dataset_id: str, document_id: str):
+        return result.success(
+            ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(
+                request.data))
+
+    class Operate(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @swagger_auto_schema(operation_summary="修改段落数据",
+                             operation_id="修改段落数据",
+                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
+                             request_body=ParagraphSerializers.Operate.get_request_body_api(),
+                             responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
+            o = ParagraphSerializers.Operate(
+                data={"paragraph_id": paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
+            o.is_valid(raise_exception=True)
+            return result.success(o.edit(request.data))
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="获取段落详情",
+                             operation_id="获取段落详情",
+                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
+                             responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()))
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
+            o = ParagraphSerializers.Operate(
+                data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
+            o.is_valid(raise_exception=True)
+            return result.success(o.one())
+
+        @action(methods=['DELETE'], detail=False)
+        @swagger_auto_schema(operation_summary="删除段落",
+                             operation_id="删除段落",
+                             manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
+                             responses=result.get_default_response())
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
+            o = ParagraphSerializers.Operate(
+                data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id})
+            o.is_valid(raise_exception=True)
+            return result.success(o.delete())
+
+    class Page(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="分页获取段落列表",
+                             operation_id="分页获取段落列表",
+                             manual_parameters=result.get_page_request_params(
+                                 ParagraphSerializers.Query.get_request_params_api()),
+                             responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()))
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str, document_id: str, current_page, page_size):
+            d = ParagraphSerializers.Query(
+                data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
+                      'document_id': document_id})
+            d.is_valid(raise_exception=True)
+            return result.success(d.page(current_page, page_size))
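Several of these list views flatten `request.query_params` through `query_params_to_single_dict` before handing the data to a serializer. The helper lives in `common.util.common`; the sketch below shows the behaviour the views rely on, not the authoritative implementation:

```python
# Assumed behaviour of query_params_to_single_dict: collapse Django's
# multi-valued QueryDict into a plain dict of single values so the
# serializers can validate simple filter fields such as ?title=...&content=...
from django.http import QueryDict


def query_params_to_single_dict(query_params: QueryDict) -> dict:
    # QueryDict.get returns the last value stored for a key.
    return {key: query_params.get(key) for key in query_params}


params = QueryDict('title=install&content=gpu')
assert query_params_to_single_dict(params) == {'title': 'install', 'content': 'gpu'}
```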
diff --git a/apps/dataset/views/problem.py b/apps/dataset/views/problem.py
new file mode 100644
index 000000000..d12064fbb
--- /dev/null
+++ b/apps/dataset/views/problem.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: problem.py
+    @date:2023/10/23 13:54
+    @desc:
+"""
+from drf_yasg.utils import swagger_auto_schema
+from rest_framework.decorators import action
+from rest_framework.request import Request
+from rest_framework.views import APIView
+
+from common.auth import TokenAuth, has_permissions
+from common.constants.permission_constants import Permission, Group, Operate
+from common.response import result
+from dataset.serializers.problem_serializers import ProblemSerializers
+
+
+class Problem(APIView):
+    authentication_classes = [TokenAuth]
+
+    @action(methods=['POST'], detail=False)
+    @swagger_auto_schema(operation_summary="添加关联问题",
+                         operation_id="添加段落关联问题",
+                         manual_parameters=ProblemSerializers.Create.get_request_params_api(),
+                         request_body=ProblemSerializers.Create.get_request_body_api(),
+                         responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()))
+    @has_permissions(
+        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                dynamic_tag=k.get('dataset_id')))
+    def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
+        return result.success(ProblemSerializers.Create(
+            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
+            request.data, with_valid=True))
+
+    @action(methods=['GET'], detail=False)
+    @swagger_auto_schema(operation_summary="获取段落问题列表",
+                         operation_id="获取段落问题列表",
+                         manual_parameters=ProblemSerializers.Query.get_request_params_api(),
+                         responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()))
+    @has_permissions(
+        lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
+                                dynamic_tag=k.get('dataset_id')))
+    def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
+        return result.success(ProblemSerializers.Query(
+            data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
+            with_valid=True))
+
+    class Operate(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['DELETE'], detail=False)
+        @swagger_auto_schema(operation_summary="删除段落问题",
+                             operation_id="删除段落问题",
+                             manual_parameters=ProblemSerializers.Query.get_request_params_api(),
+                             responses=result.get_default_response())
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
+            o = ProblemSerializers.Operate(
+                data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
+                      'problem_id': problem_id})
+            return result.success(o.delete(with_valid=True))
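A paragraph can carry several linked questions, and each one is vectorised alongside the paragraph text so that a user's phrasing can match a question even when it doesn't match the paragraph. A hedged sketch of exercising the new route; the ids, host, header form and `content` body field are placeholders inferred from the serializer names, not confirmed API:

```python
# Hedged sketch: ids are placeholders and the request body shape is assumed;
# the authoritative contract is ProblemSerializers.Create.
import requests

dataset_id, document_id, paragraph_id = 'd-1', 'doc-1', 'p-1'
resp = requests.post(
    f'http://localhost:8080/api/dataset/{dataset_id}/document/{document_id}'
    f'/paragraph/{paragraph_id}/problem',
    headers={'Authorization': 'user-token'},  # TokenAuth credential (assumed form)
    json={'content': 'How do I reset my password?'},
)
print(resp.json())
```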
diff --git a/apps/embedding/migrations/0001_initial.py b/apps/embedding/migrations/0001_initial.py
index 968025f20..5f9bc91aa 100644
--- a/apps/embedding/migrations/0001_initial.py
+++ b/apps/embedding/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.1.10 on 2023-10-09 06:33
+# Generated by Django 4.1.10 on 2023-10-24 12:13

 import common.field.vector_field
 from django.db import migrations, models
@@ -20,8 +20,11 @@ class Migration(migrations.Migration):
                 ('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')),
                 ('source_id', models.CharField(max_length=128, verbose_name='资源id')),
                 ('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')),
+                ('is_active', models.BooleanField(default=True, max_length=1, verbose_name='是否可用')),
                 ('embedding', common.field.vector_field.VectorField(verbose_name='向量')),
-                ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')),
+                ('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')),
+                ('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document', verbose_name='文档关联')),
+                ('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落关联')),
             ],
             options={
                 'db_table': 'embedding',
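All three foreign keys on the `embedding` table are declared with `db_constraint=False`: Django still creates the `_id` columns and performs joins through the ORM, but no FOREIGN KEY constraint is emitted in the DDL, so embedding rows can be bulk-written and bulk-deleted without referential checks against the dataset tables. The same pattern in isolation, with hypothetical models:

```python
# Hypothetical two-table illustration of db_constraint=False: the ORM keeps
# the relation, while the database gets a plain indexed column with no FK.
from django.db import models


class Chunk(models.Model):
    text = models.TextField()


class ChunkVector(models.Model):
    chunk = models.ForeignKey(Chunk, on_delete=models.DO_NOTHING,
                              db_constraint=False)  # no FOREIGN KEY in the DDL
```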
diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py
index 816dcd608..e3606f664 100644
--- a/apps/embedding/models/embedding.py
+++ b/apps/embedding/models/embedding.py
@@ -9,7 +9,7 @@
 from django.db import models

 from common.field.vector_field import VectorField
-from dataset.models.data_set import DataSet
+from dataset.models.data_set import Document, Paragraph, DataSet


 class SourceType(models.TextChoices):
@@ -26,7 +26,13 @@ class Embedding(models.Model):
     source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices,
                                    default=SourceType.PROBLEM)

-    dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联")
+    is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True)
+
+    dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联", db_constraint=False)
+
+    document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False)
+
+    paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False)

     embedding = VectorField(verbose_name="向量")
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
new file mode 100644
index 000000000..90d57abb9
--- /dev/null
+++ b/apps/embedding/vector/base_vector.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: base_vector.py
+    @date:2023/10/18 19:16
+    @desc:
+"""
+from abc import ABC, abstractmethod
+from typing import List, Dict
+
+from langchain.embeddings import HuggingFaceEmbeddings
+
+from common.config.embedding_config import EmbeddingModel
+from embedding.models import SourceType
+
+
+class BaseVectorStore(ABC):
+    vector_exists = False
+
+    @abstractmethod
+    def vector_is_create(self) -> bool:
+        """
+        判断向量库是否创建
+        :return: 是否创建向量库
+        """
+        pass
+
+    @abstractmethod
+    def vector_create(self):
+        """
+        创建 向量库
+        :return:
+        """
+        pass
+
+    def save_pre_handler(self):
+        """
+        插入前置处理器 主要是判断向量库是否创建
+        :return: True
+        """
+        if not BaseVectorStore.vector_exists:
+            if not self.vector_is_create():
+                self.vector_create()
+            BaseVectorStore.vector_exists = True
+        return True
+
+    def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
+             source_id: str, is_active: bool,
+             embedding=None):
+        """
+        插入向量数据
+        :param source_id: 资源id
+        :param dataset_id: 数据集id
+        :param text: 文本
+        :param source_type: 资源类型
+        :param document_id: 文档id
+        :param is_active: 是否可用
+        :param embedding: 向量化处理器
+        :param paragraph_id 段落id
+        :return: bool
+        """
+        if embedding is None:
+            embedding = EmbeddingModel.get_embedding_model()
+        self.save_pre_handler()
+        self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
+
+    def batch_save(self, data_list: List[Dict], embedding=None):
+        """
+        批量插入
+        :param data_list: 数据列表
+        :param embedding: 向量化处理器
+        :return: bool
+        """
+        if embedding is None:
+            embedding = EmbeddingModel.get_embedding_model()
+        self.save_pre_handler()
+        self._batch_save(data_list, embedding)
+        return True
+
+    @abstractmethod
+    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
+              source_id: str, is_active: bool,
+              embedding: HuggingFaceEmbeddings):
+        pass
+
+    @abstractmethod
+    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+        pass
+
+    @abstractmethod
+    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
+        pass
+
+    @abstractmethod
+    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
+        pass
+
+    @abstractmethod
+    def update_by_source_id(self, source_id: str, instance: Dict):
+        pass
+
+    @abstractmethod
+    def delete_by_dataset_id(self, dataset_id: str):
+        pass
+
+    @abstractmethod
+    def delete_by_document_id(self, document_id: str):
+        pass
+
+    @abstractmethod
+    def delete_by_source_id(self, source_id: str, source_type: str):
+        pass
+
+    @abstractmethod
+    def delete_by_paragraph_id(self, paragraph_id: str):
+        pass
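`BaseVectorStore` follows a template-method layout: the public `save`/`batch_save` handle the shared concerns (falling back to the global `EmbeddingModel`, lazily provisioning the store via `save_pre_handler`) and delegate the actual writes to the underscored hooks. A toy in-memory subclass, sketched only to show the contract, with the remaining hooks stubbed so the class is instantiable:

```python
# Toy in-memory store; everything here is illustrative, not part of the patch.
from typing import Dict, List

from embedding.vector.base_vector import BaseVectorStore


class MemoryVectorStore(BaseVectorStore):
    def __init__(self):
        self.rows: List[Dict] = []

    def vector_is_create(self) -> bool:
        return True  # nothing to provision for an in-memory store

    def vector_create(self):
        return True

    def _save(self, text, source_type, dataset_id, document_id, paragraph_id,
              source_id, is_active, embedding):
        # embed_query vectorises a single text through the shared model.
        self.rows.append({'text': text, 'vector': embedding.embed_query(text)})

    def _batch_save(self, text_list, embedding):
        # embed_documents vectorises the whole batch in one call.
        vectors = embedding.embed_documents([row.get('text') for row in text_list])
        self.rows.extend({'row': row, 'vector': vector}
                         for row, vector in zip(text_list, vectors))

    # Remaining hooks stubbed purely so this sketch can be instantiated.
    def search(self, query_text, dataset_id_list, is_active, embedding): return []
    def update_by_paragraph_id(self, paragraph_id, instance): pass
    def update_by_source_id(self, source_id, instance): pass
    def delete_by_dataset_id(self, dataset_id): pass
    def delete_by_document_id(self, document_id): pass
    def delete_by_source_id(self, source_id, source_type): pass
    def delete_by_paragraph_id(self, paragraph_id): pass
```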
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
new file mode 100644
index 000000000..2daa8cbed
--- /dev/null
+++ b/apps/embedding/vector/pg_vector.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: pg_vector.py
+    @date:2023/10/19 15:28
+    @desc:
+"""
+import uuid
+from typing import Dict, List
+
+from django.db.models import QuerySet
+from langchain.embeddings import HuggingFaceEmbeddings
+
+from embedding.models import Embedding, SourceType
+from embedding.vector.base_vector import BaseVectorStore
+
+
+class PGVector(BaseVectorStore):
+
+    def vector_is_create(self) -> bool:
+        # 项目启动默认是创建好的 不需要再创建
+        return True
+
+    def vector_create(self):
+        return True
+
+    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str,
+              source_id: str, is_active: bool,
+              embedding: HuggingFaceEmbeddings):
+        text_embedding = embedding.embed_query(text)
+        embedding = Embedding(id=uuid.uuid1(),
+                              dataset_id=dataset_id,
+                              document_id=document_id,
+                              is_active=is_active,
+                              paragraph_id=paragraph_id,
+                              source_id=source_id,
+                              embedding=text_embedding,
+                              source_type=source_type,
+                              )
+        embedding.save()
+        return True
+
+    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
+        texts = [row.get('text') for row in text_list]
+        embeddings = embedding.embed_documents(texts)
+        QuerySet(Embedding).bulk_create([Embedding(id=uuid.uuid1(),
+                                                   document_id=text_list[index].get('document_id'),
+                                                   paragraph_id=text_list[index].get('paragraph_id'),
+                                                   dataset_id=text_list[index].get('dataset_id'),
+                                                   is_active=text_list[index].get('is_active', True),
+                                                   source_id=text_list[index].get('source_id'),
+                                                   source_type=text_list[index].get('source_type'),
+                                                   embedding=embeddings[index]) for index in
+                                         range(0, len(text_list))]) if len(text_list) > 0 else None
+        return True
+
+    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
+        pass
+
+    def update_by_source_id(self, source_id: str, instance: Dict):
+        QuerySet(Embedding).filter(source_id=source_id).update(**instance)
+
+    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
+        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)
+
+    def delete_by_dataset_id(self, dataset_id: str):
+        QuerySet(Embedding).filter(dataset_id=dataset_id).delete()
+
+    def delete_by_document_id(self, document_id: str):
+        QuerySet(Embedding).filter(document_id=document_id).delete()
+        return True
+
+    def delete_by_source_id(self, source_id: str, source_type: str):
+        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
+        return True
+
+    def delete_by_paragraph_id(self, paragraph_id: str):
+        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
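`PGVector.search` is still a stub in this commit; writes go through the ORM while retrieval is presumably wired up later. Assuming the `embedding` column is backed by the pgvector extension, a cosine-distance query could look like the sketch below (the raw-SQL approach and pgvector's `<=>` cosine-distance operator are assumptions, not part of this patch):

```python
# Hedged sketch of a similarity query against the embedding table; requires
# the pgvector extension, which defines the <=> cosine-distance operator.
from django.db import connection


def naive_search(query_vector: list, dataset_id_list: list, limit: int = 5):
    sql = """
        SELECT source_id, paragraph_id, embedding <=> %s::vector AS distance
        FROM embedding
        WHERE is_active AND dataset_id = ANY(%s)
        ORDER BY distance
        LIMIT %s
    """
    # pgvector accepts a '[x,y,...]' literal cast to the vector type.
    vector_literal = '[' + ','.join(map(str, query_vector)) + ']'
    with connection.cursor() as cursor:
        cursor.execute(sql, [vector_literal, dataset_id_list, limit])
        return cursor.fetchall()
```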
- "EMAIL_HOST_PASSWORD": "" + "EMAIL_HOST_PASSWORD": "", + # 向量模型 + "EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese", + "EMBEDDING_DEVICE": "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu", + "EMBEDDING_MODEL_PATH": os.path.join(PROJECT_DIR, 'models'), + # 向量库配置 + "VECTOR_STORE_NAME": 'pg_vector' + } def get_db_setting(self) -> dict: @@ -120,6 +127,8 @@ class ConfigManager: def __init__(self, root_path=None): self.root_path = root_path self.config = self.config_class() + for key in self.config_class.defaults: + self.config[key] = self.config_class.defaults[key] def from_mapping(self, *mapping, **kwargs): """Updates the config like :meth:`update` ignoring items with non-upper diff --git a/apps/smartdoc/settings/logging.py b/apps/smartdoc/settings/logging.py index 3944cf84e..76c7c85df 100644 --- a/apps/smartdoc/settings/logging.py +++ b/apps/smartdoc/settings/logging.py @@ -100,6 +100,11 @@ LOGGING = { 'level': LOG_LEVEL, 'propagate': False, }, + 'sqlalchemy': { + 'handlers': ['console', 'file', 'syslog'], + 'level': LOG_LEVEL, + 'propagate': False, + }, 'django.db.backends': { 'handlers': ['console', 'file', 'syslog'], 'propagate': False, diff --git a/apps/smartdoc/wsgi.py b/apps/smartdoc/wsgi.py index 287bec89a..ce5346b2b 100644 --- a/apps/smartdoc/wsgi.py +++ b/apps/smartdoc/wsgi.py @@ -14,3 +14,12 @@ from django.core.wsgi import get_wsgi_application os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings') application = get_wsgi_application() + + +def post_handler(): + from common.event.listener_manage import ListenerManagement + ListenerManagement().run() + ListenerManagement.init_embedding_model_signal.send() + + +post_handler() diff --git a/apps/users/apps.py b/apps/users/apps.py index 72b140106..1ea7bf62f 100644 --- a/apps/users/apps.py +++ b/apps/users/apps.py @@ -4,3 +4,4 @@ from django.apps import AppConfig class UsersConfig(AppConfig): default_auto_field = 'django.db.models.BigAutoField' name = 'users' + diff --git a/apps/users/migrations/0001_initial.py b/apps/users/migrations/0001_initial.py index 7bd5bb2fa..4e354adb4 100644 --- a/apps/users/migrations/0001_initial.py +++ b/apps/users/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.1.10 on 2023-10-09 06:33 +# Generated by Django 4.1.10 on 2023-10-24 12:13 from django.db import migrations, models import uuid diff --git a/apps/users/views/user.py b/apps/users/views/user.py index 8c16d11b3..44cde7973 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -22,7 +22,7 @@ from common.constants.permission_constants import PermissionConstants, CompareCo from common.response import result from smartdoc.settings import JWT_AUTH from users.models.user import User as UserModel -from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, UserSerializer, CheckCodeSerializer, \ +from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ SendEmailSerializer, UserProfile diff --git a/pyproject.toml b/pyproject.toml index 3f5816b6d..e844bfa77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,12 @@ psycopg2-binary = "2.9.7" jieba = "^0.42.1" diskcache = "^5.6.3" pillow = "9.5.0" +filetype = "^1.2.0" +chardet = "^5.2.0" +torch = "^2.1.0" +sentence-transformers = "^2.2.2" +blinker = "^1.6.3" + [build-system] requires = ["poetry-core"]