diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 9fc25e3b8..eebc3d406 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -21,7 +21,7 @@ from common.event.common import poxy, embedding_poxy from common.util.file_util import get_file_content from common.util.fork import ForkManage, Fork from common.util.lock import try_lock, un_lock -from dataset.models import Paragraph, Status, Document +from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping from embedding.models import SourceType from smartdoc.conf import PROJECT_DIR @@ -44,6 +44,12 @@ class SyncWebDocumentArgs: self.handler = handler +class UpdateProblemArgs: + def __init__(self, problem_id: str, problem_content: str): + self.problem_id = problem_id + self.problem_content = problem_content + + class ListenerManagement: embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_paragraph_signal = signal("embedding_by_paragraph") @@ -59,6 +65,8 @@ class ListenerManagement: init_embedding_model_signal = signal('init_embedding_model') sync_web_dataset_signal = signal('sync_web_dataset') sync_web_document_signal = signal('sync_web_document') + update_problem_signal = signal('update_problem') + delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids') @staticmethod def embedding_by_problem(args): @@ -76,8 +84,8 @@ class ListenerManagement: status = Status.success try: data_list = native_search( - {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter( - **{'problem.paragraph_id': paragraph_id}), + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id': paragraph_id}), 'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)}, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) @@ -104,8 +112,9 @@ 
class ListenerManagement: status = Status.success try: data_list = native_search( - {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter( - **{'problem.document_id': document_id}), + {'problem': QuerySet( + get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter( + **{'paragraph.document_id': document_id}), 'paragraph': QuerySet(Paragraph).filter(document_id=document_id)}, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) @@ -188,6 +197,17 @@ class ListenerManagement: finally: un_lock('sync_web_dataset' + args.lock_key) + @staticmethod + def update_problem(args: UpdateProblemArgs): + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id) + embed_value = VectorStore.get_embedding_vector().embed_query(args.problem_content) + VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list], + {'embedding': embed_value}) + + @staticmethod + def delete_embedding_by_source_ids(source_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM) + @staticmethod @poxy def init_embedding_model(ags): @@ -225,3 +245,6 @@ class ListenerManagement: ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) # 同步web站点 文档 ListenerManagement.sync_web_document_signal.connect(self.sync_web_document) + # 更新问题向量 + ListenerManagement.update_problem_signal.connect(self.update_problem) + ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids) diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql index b051da1c6..74f3b224b 100644 --- a/apps/common/sql/list_embedding_text.sql +++ b/apps/common/sql/list_embedding_text.sql @@ -1,14 +1,15 @@ SELECT - problem."id" AS "source_id", - problem.document_id AS document_id, - 
problem.paragraph_id AS paragraph_id, + problem_paragraph_mapping."id" AS "source_id", + paragraph.document_id AS document_id, + paragraph."id" AS paragraph_id, problem.dataset_id AS dataset_id, 0 AS source_type, problem."content" AS "text", paragraph.is_active AS is_active FROM problem problem - LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id + LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id" + LEFT JOIN paragraph paragraph ON paragraph."id" = problem_paragraph_mapping.paragraph_id ${problem} UNION diff --git a/apps/dataset/migrations/0005_remove_problem_document_remove_problem_paragraph_and_more.py b/apps/dataset/migrations/0005_remove_problem_document_remove_problem_paragraph_and_more.py new file mode 100644 index 000000000..bd8491966 --- /dev/null +++ b/apps/dataset/migrations/0005_remove_problem_document_remove_problem_paragraph_and_more.py @@ -0,0 +1,59 @@ +# Generated by Django 4.1.10 on 2024-03-08 18:29 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + +from embedding.models import SourceType + + +def delete_problem_embedding(apps, schema_editor): + Embedding = apps.get_model('embedding', 'Embedding') + Embedding.objects.filter(source_type=SourceType.PROBLEM).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ('dataset', '0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='problem', + name='document', + ), + migrations.RemoveField( + model_name='problem', + name='paragraph', + ), + migrations.AddField( + model_name='paragraph', + name='hit_num', + field=models.IntegerField(default=0, verbose_name='命中次数'), + ), + migrations.AddField( + model_name='problem', + name='hit_num', + field=models.IntegerField(default=0, verbose_name='命中次数'), + ), + migrations.CreateModel( + name='ProblemParagraphMapping', + fields=[ + ('create_time', 
models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, + to='dataset.dataset')), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')), + ('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, + to='dataset.paragraph')), + ('problem', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, + to='dataset.problem')), + ], + options={ + 'db_table': 'problem_paragraph_mapping', + }, + ), + migrations.RunPython(delete_problem_embedding) + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index 7a69e2753..4cea2ef90 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -76,6 +76,7 @@ class Paragraph(AppModelMixin): title = models.CharField(max_length=256, verbose_name="标题", default="") status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, default=Status.embedding) + hit_num = models.IntegerField(verbose_name="命中次数", default=0) is_active = models.BooleanField(default=True) class Meta: @@ -87,10 +88,20 @@ class Problem(AppModelMixin): 问题表 """ id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") - document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False) dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) - paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) content = models.CharField(max_length=256, verbose_name="问题内容") + hit_num = models.IntegerField(verbose_name="命中次数", default=0) class Meta: 
db_table = "problem" + + +class ProblemParagraphMapping(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING) + problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False) + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) + + class Meta: + db_table = "problem_paragraph_mapping" diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 6ab39a9dd..226502786 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -34,7 +34,7 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping from dataset.serializers.common_serializers import list_paragraph, MetaSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from setting.models import AuthOperate @@ -303,6 +303,7 @@ class DataSetSerializers(serializers.ModelSerializer): document_model_list = [] paragraph_model_list = [] problem_model_list = [] + problem_paragraph_mapping_list = [] # 插入文档 for document in instance.get('documents') if 'documents' in instance else []: document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, @@ -312,6 +313,8 @@ class DataSetSerializers(serializers.ModelSerializer): paragraph_model_list.append(paragraph) for problem in 
document_paragraph_dict_model.get('problem_model_list'): problem_model_list.append(problem) + for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'): + problem_paragraph_mapping_list.append(problem_paragraph_mapping) # 插入知识库 dataset.save() @@ -321,6 +324,9 @@ class DataSetSerializers(serializers.ModelSerializer): QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None # 批量插入问题 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None # 响应数据 return {**DataSetSerializers(dataset).data, diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 4e79cf062..8a0979c8e 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -28,7 +28,7 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import SplitModel, get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR @@ -179,7 +179,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 删除段落 QuerySet(model=Paragraph).filter(document_id=document_id).delete() # 删除问题 - QuerySet(model=Problem).filter(document_id=document_id).delete() + QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() # 删除向量库 
ListenerManagement.delete_embedding_by_document_signal.send(document_id) paragraphs = get_split_model('web.md').parse(result.content) @@ -191,10 +191,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): paragraph_model_list = document_paragraph_model.get('paragraph_model_list') problem_model_list = document_paragraph_model.get('problem_model_list') + problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') # 批量插入段落 QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None # 批量插入问题 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None # 向量化 if with_embedding: ListenerManagement.embedding_by_document_signal.send(document_id) @@ -273,7 +277,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 删除段落 QuerySet(model=Paragraph).filter(document_id=document_id).delete() # 删除问题 - QuerySet(model=Problem).filter(document_id=document_id).delete() + QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() # 删除向量库 ListenerManagement.delete_embedding_by_document_signal.send(document_id) return True @@ -344,12 +348,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_model = document_paragraph_model.get('document') paragraph_model_list = document_paragraph_model.get('paragraph_model_list') problem_model_list = document_paragraph_model.get('problem_model_list') + problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list') + # 插入文档 document_model.save() # 批量插入段落 QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None # 批量插入问题 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + 
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None document_id = str(document_model.id) return DocumentSerializers.Operate( data={'dataset_id': dataset_id, 'document_id': document_id}).one( @@ -396,14 +405,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): paragraph_model_list = [] problem_model_list = [] + problem_paragraph_mapping_list = [] for paragraphs in paragraph_model_dict_list: paragraph = paragraphs.get('paragraph') for problem_model in paragraphs.get('problem_model_list'): problem_model_list.append(problem_model) + for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'): + problem_paragraph_mapping_list.append(problem_paragraph_mapping) paragraph_model_list.append(paragraph) return {'document': document_model, 'paragraph_model_list': paragraph_model_list, - 'problem_model_list': problem_model_list} + 'problem_model_list': problem_model_list, + 'problem_paragraph_mapping_list': problem_paragraph_mapping_list} @staticmethod def get_document_paragraph_model(dataset_id, instance: Dict): @@ -523,6 +536,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_model_list = [] paragraph_model_list = [] problem_model_list = [] + problem_paragraph_mapping_list = [] # 插入文档 for document in instance_list: document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, @@ -532,6 +546,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): paragraph_model_list.append(paragraph) for problem in document_paragraph_dict_model.get('problem_model_list'): problem_model_list.append(problem) + for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'): + problem_paragraph_mapping_list.append(problem_paragraph_mapping) # 插入文档 QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None @@ -539,6 +555,9 @@ class 
DocumentSerializers(ApiMixin, serializers.Serializer): QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None # 批量插入问题 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None # 查询文档 query_set = QuerySet(model=Document) query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]}) diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 4bafcabd1..023f4cecd 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -20,9 +20,10 @@ from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage -from dataset.models import Paragraph, Problem, Document +from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping from dataset.serializers.common_serializers import update_document_char_length -from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer +from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers +from embedding.models import SourceType class ParagraphSerializer(serializers.ModelSerializer): @@ -84,6 +85,193 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char( "分段内容")) + class Problem(ApiMixin, serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + + paragraph_id = 
serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists(): + raise AppApiException(500, "段落id不存在") + + def list(self, with_valid=False): + """ + 获取问题列表 + :param with_valid: 是否校验 + :return: 问题列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"), + paragraph_id=self.data.get( + 'paragraph_id')) + return [ProblemSerializer(row).data for row in + QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])] + + @transaction.atomic + def save(self, instance: Dict, with_valid=True, with_embedding=True): + if with_valid: + self.is_valid() + ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True) + problem = QuerySet(Problem).filter(dataset_id=self.data.get('dataset_id'), + content=instance.get('content')).first() + if problem is None: + problem = Problem(id=uuid.uuid1(), dataset_id=self.data.get('dataset_id'), + content=instance.get('content')) + problem.save() + if QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get('dataset_id'), problem_id=problem.id, + paragraph_id=self.data.get('paragraph_id')).exists(): + raise AppApiException(500, "已经关联,请勿重复关联") + problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), + problem_id=problem.id, + document_id=self.data.get('document_id'), + paragraph_id=self.data.get('paragraph_id'), + dataset_id=self.data.get('dataset_id')) + problem_paragraph_mapping.save() + if with_embedding: + ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, + 'is_active': True, + 'source_type': SourceType.PROBLEM, + 'source_id': problem_paragraph_mapping.id, + 'document_id': self.data.get('document_id'), + 'paragraph_id': self.data.get('paragraph_id'), + 
'dataset_id': self.data.get('dataset_id'), + }) + + return ProblemSerializers.Operate( + data={'dataset_id': self.data.get('dataset_id'), + 'problem_id': problem.id}).one(with_valid=True) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='document_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='文档id'), + openapi.Parameter(name='paragraph_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='段落id')] + + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + required=["content"], + properties={ + 'content': openapi.Schema( + type=openapi.TYPE_STRING, title="内容") + }) + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", + description="问题内容", default='问题内容'), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id", default='xxx'), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + + class Association(ApiMixin, serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + + problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id")) + + document_id = 
serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + + paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id")) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + dataset_id = self.data.get('dataset_id') + paragraph_id = self.data.get('paragraph_id') + problem_id = self.data.get("problem_id") + if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists(): + raise AppApiException(500, "段落不存在") + if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists(): + raise AppApiException(500, "问题不存在") + + def association(self, with_valid=True, with_embedding=True): + if with_valid: + self.is_valid(raise_exception=True) + problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first() + problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), + document_id=self.data.get('document_id'), + paragraph_id=self.data.get('paragraph_id'), + dataset_id=self.data.get('dataset_id'), + problem_id=problem.id) + problem_paragraph_mapping.save() + if with_embedding: + ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, + 'is_active': True, + 'source_type': SourceType.PROBLEM, + 'source_id': problem_paragraph_mapping.id, + 'document_id': self.data.get('document_id'), + 'paragraph_id': self.data.get('paragraph_id'), + 'dataset_id': self.data.get('dataset_id'), + }) + + def un_association(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter( + paragraph_id=self.data.get('paragraph_id'), + dataset_id=self.data.get('dataset_id'), + problem_id=self.data.get( + 'problem_id')).first() + problem_paragraph_mapping_id = problem_paragraph_mapping.id + problem_paragraph_mapping.delete() + ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id) + return True + + @staticmethod + def
get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='document_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='文档id') + , openapi.Parameter(name='paragraph_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='段落id'), + openapi.Parameter(name='problem_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='问题id') + ] + class Operate(ApiMixin, serializers.Serializer): # 段落id paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( @@ -158,8 +346,13 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): return self.one(), instance def get_problem_list(self): - return [ProblemSerializer(problem).data for problem in - QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))] + # look up problems linked to this paragraph via the mapping table + problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter( + paragraph_id=self.data.get("paragraph_id")) + if len(problem_paragraph_mapping) > 0: + return [ProblemSerializer(problem).data for problem in + QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])] + return [] def one(self, with_valid=False): if with_valid: @@ -172,7 +365,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) paragraph_id = self.data.get('paragraph_id') QuerySet(Paragraph).filter(id=paragraph_id).delete() - QuerySet(Problem).filter(paragraph_id=paragraph_id).delete() + QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete() ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id) @staticmethod @@ -210,10 +403,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): paragraph_problem_model = self.get_paragraph_problem_model(dataset_id,
document_id, instance) paragraph = paragraph_problem_model.get('paragraph') problem_model_list = paragraph_problem_model.get('problem_model_list') + problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list') # 插入段落 paragraph_problem_model.get('paragraph').save() # 插入問題 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 插入问题关联关系 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None # 修改长度 update_document_char_length(document_id) if with_embedding: @@ -229,12 +426,35 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): content=instance.get("content"), dataset_id=dataset_id, title=instance.get("title") if 'title' in instance else '') + problem_list = instance.get('problem_list') + exists_problem_list = [] + if 'problem_list' in instance and len(problem_list) > 0: + exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id, + content__in=[p.get('content') for p in + problem_list]).all() - problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id, - document_id=document_id, dataset_id=dataset_id) for problem in ( - instance.get('problem_list') if 'problem_list' in instance else [])] + problem_model_list = [ + ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for + problem in ( + instance.get('problem_list') if 'problem_list' in instance else [])] - return {'paragraph': paragraph, 'problem_model_list': problem_model_list} + problem_paragraph_mapping_list = [ + ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id, + paragraph_id=paragraph.id, + dataset_id=dataset_id) for + problem_model in problem_model_list] + return {'paragraph': paragraph, + 'problem_model_list': [problem_model for problem_model in problem_model_list if + not 
list(exists_problem_list).__contains__(problem_model)], + 'problem_paragraph_mapping_list': problem_paragraph_mapping_list} + + @staticmethod + def or_get(exists_problem_list, content, dataset_id): + exists = [row for row in exists_problem_list if row.content == content] + if len(exists) > 0: + return exists[0] + else: + return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id) @staticmethod def get_request_body_api(): diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 449d6e814..a1bf6df91 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -6,6 +6,7 @@ @date:2023/10/23 13:55 @desc: """ +import os import uuid from typing import Dict @@ -14,19 +15,19 @@ from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers -from common.event.listener_manage import ListenerManagement -from common.exception.app_exception import AppApiException +from common.db.search import native_search, native_page_search +from common.event import ListenerManagement, UpdateProblemArgs from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage -from dataset.models import Problem, Paragraph -from embedding.models import SourceType -from embedding.vector.pg_vector import PGVector +from common.util.file_util import get_file_content +from dataset.models import Problem, Paragraph, ProblemParagraphMapping +from smartdoc.conf import PROJECT_DIR class ProblemSerializer(serializers.ModelSerializer): class Meta: model = Problem - fields = ['id', 'content', 'dataset_id', 'document_id', + fields = ['id', 'content', 'dataset_id', 'create_time', 'update_time'] @@ -49,186 +50,92 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer): class ProblemSerializers(ApiMixin, serializers.Serializer): - class Create(ApiMixin, serializers.Serializer): + class 
Create(serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题列表"), + child=serializers.CharField(required=True, + error_messages=ErrMessage.char("问题"))) - document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) - - paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id")) - - def is_valid(self, *, raise_exception=False): - super().is_valid(raise_exception=True) - if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'), - document_id=self.data.get('document_id'), - dataset_id=self.data.get('dataset_id')).exists(): - raise AppApiException(500, "段落id不正确") - - @transaction.atomic - def save(self, instance: Dict, with_valid=True, with_embedding=True): - if with_valid: - self.is_valid() - ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True) - problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'), - document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'), - content=instance.get('content')) - problem.save() - if with_embedding: - ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, - 'is_active': True, - 'source_type': SourceType.PROBLEM, - 'source_id': problem.id, - 'document_id': self.data.get('document_id'), - 'paragraph_id': self.data.get('paragraph_id'), - 'dataset_id': self.data.get('dataset_id'), - - }) - - return ProblemSerializers.Operate( - data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'), - 'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True) - - @staticmethod - def get_request_body_api(): - return ProblemInstanceSerializer.get_request_body_api() - - @staticmethod - def get_request_params_api(): - return [openapi.Parameter(name='dataset_id', - 
in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='知识库id'), - openapi.Parameter(name='document_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='文档id'), - openapi.Parameter(name='paragraph_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='段落id')] - - class Query(ApiMixin, serializers.Serializer): - dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) - - document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) - - paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id")) - - def is_valid(self, *, raise_exception=True): - super().is_valid(raise_exception=True) - if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists(): - raise AppApiException(500, "段落id不存在") - - def get_query_set(self): - dataset_id = self.data.get('dataset_id') - document_id = self.data.get('document_id') - paragraph_id = self.data.get("paragraph_id") - return QuerySet(Problem).filter( - **{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id}) - - def list(self, with_valid=False): - """ - 获取问题列表 - :param with_valid: 是否校验 - :return: 问题列表 - """ + def batch(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - query_set = self.get_query_set() - return [ProblemSerializer(p).data for p in query_set] + problem_list = self.data.get('problem_list') + dataset_id = self.data.get('dataset_id') + exists_problem_content_list = [problem.content for problem in + QuerySet(Problem).filter(dataset_id=dataset_id, + content__in=problem_list)] + problem_instance_list = [Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content) for + problem_content in + self.data.get('problem_list') if + (not exists_problem_content_list.__contains__(problem_content) if + len(exists_problem_content_list) > 0 else True)] - @staticmethod - def 
get_request_params_api(): - return [openapi.Parameter(name='dataset_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='知识库id'), - openapi.Parameter(name='document_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='文档id') - , openapi.Parameter(name='paragraph_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='段落id')] + QuerySet(Problem).bulk_create(problem_instance_list) if len(problem_instance_list) > 0 else None + return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list] - class Operate(ApiMixin, serializers.Serializer): + class Query(serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + content = serializers.CharField(required=False, error_messages=ErrMessage.char("问题")) - document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + def get_query_set(self): + query_set = QuerySet(model=Problem) + query_set = query_set.filter( + **{'dataset_id': self.data.get('dataset_id')}) + if 'content' in self.data: + query_set = query_set.filter(**{'content__contains': self.data.get('content')}) + return query_set - paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id")) + def list(self): + query_set = self.get_query_set() + return native_search(query_set, select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql'))) + + def page(self, current_page, page_size): + query_set = self.get_query_set() + return native_page_search(current_page, page_size, query_set, select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql'))) + + class Operate(serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) problem_id = serializers.UUIDField(required=True, 
error_messages=ErrMessage.uuid("问题id")) - def delete(self, with_valid=False): + def list_paragraph(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete() - PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM) - ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id')) - return True + problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"), + problem_id=self.data.get("problem_id")) + return native_search( + QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]), + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql'))) - def one(self, with_valid=False): + def one(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data - @staticmethod - def get_request_params_api(): - return [openapi.Parameter(name='dataset_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='知识库id'), - openapi.Parameter(name='document_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='文档id') - , openapi.Parameter(name='paragraph_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='段落id'), - openapi.Parameter(name='problem_id', - in_=openapi.IN_PATH, - type=openapi.TYPE_STRING, - required=True, - description='问题id') - ] + @transaction.atomic + def delete(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter( + dataset_id=self.data.get('dataset_id'), + problem_id=self.data.get('problem_id')) + source_ids = [row.id for row in problem_paragraph_mapping_list] + 
QuerySet(Problem).filter(id=self.data.get('problem_id')).delete() + ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids) + return True - @staticmethod - def get_response_body_api(): - return openapi.Schema( - type=openapi.TYPE_OBJECT, - required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', - 'document_id', - 'create_time', 'update_time'], - properties={ - 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", - description="id", default="xx"), - 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", - description="问题内容", default='问题内容'), - 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", - default=1), - 'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量", - description="点赞数量", default=1), - 'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量", - description="点踩数", default=1), - 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", - description="文档id", default='xxx'), - 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", - description="修改时间", - default="1970-01-01 00:00:00"), - 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", - description="创建时间", - default="1970-01-01 00:00:00" - ) - } - ) + @transaction.atomic + def edit(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + problem_id = self.data.get('problem_id') + dataset_id = self.data.get('dataset_id') + content = instance.get('content') + problem = QuerySet(Problem).filter(id=problem_id, + dataset_id=dataset_id).first() + problem.content = content + problem.save() + ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content)) diff --git a/apps/dataset/sql/list_problem.sql b/apps/dataset/sql/list_problem.sql index 3183f0da8..affb51334 100644 --- a/apps/dataset/sql/list_problem.sql +++ b/apps/dataset/sql/list_problem.sql @@ -1,10 +1,5 @@ SELECT - problem."id", 
- problem."content", - problem_paragraph_mapping.hit_num, - problem_paragraph_mapping.star_num, - problem_paragraph_mapping.trample_num, - problem_paragraph_mapping.paragraph_id + problem.*, + (SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count" FROM problem problem - LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id diff --git a/apps/dataset/swagger_api/problem_api.py b/apps/dataset/swagger_api/problem_api.py new file mode 100644 index 000000000..6e57ad4b6 --- /dev/null +++ b/apps/dataset/swagger_api/problem_api.py @@ -0,0 +1,127 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: problem_api.py + @date:2024/3/11 10:49 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ProblemApi(ApiMixin): + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", + description="问题内容", default='问题内容'), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id", default='xxx'), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + + class Operate(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + 
openapi.Parameter(name='problem_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='问题id')] + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['content'], + properties={ + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", + description="问题内容"), + + } + ) + + class Paragraph(ApiMixin): + @staticmethod + def get_request_params_api(): + return ProblemApi.Operate.get_request_params_api() + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['content'], + properties={ + 'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容", + description="分段内容"), + 'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题", + description="分段标题"), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"), + 'hit_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="命中次数", description="命中次数"), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + } + ) + + class Query(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='content', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='问题')] + + class BatchCreate(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_ARRAY, + items=ProblemApi.Create.get_request_body_api()) + + @staticmethod + def get_request_params_api(): + return ProblemApi.Create.get_request_params_api() + + class Create(ApiMixin): + @staticmethod + def get_request_body_api(): + 
return openapi.Schema(type=openapi.TYPE_STRING, description="问题文本") + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id')] diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 9504741f9..d153d7134 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -28,7 +28,16 @@ urlpatterns = [ path('dataset//document//paragraph/', views.Paragraph.Operate.as_view()), path('dataset//document//paragraph//problem', - views.Problem.as_view()), - path('dataset//document//paragraph//problem/', - views.Problem.Operate.as_view()) + views.Paragraph.Problem.as_view()), + path( + 'dataset//document//paragraph//problem//un_association', + views.Paragraph.Problem.UnAssociation.as_view()), + path( + 'dataset//document//paragraph//problem//association', + views.Paragraph.Problem.Association.as_view()), + path('dataset//problem', views.Problem.as_view()), + path('dataset//problem//', views.Problem.Page.as_view()), + path('dataset//problem/', views.Problem.Operate.as_view()), + path('dataset//problem//paragraph', views.Problem.Paragraph.as_view()), + ] diff --git a/apps/dataset/views/paragraph.py b/apps/dataset/views/paragraph.py index e530769c9..3ce6f1114 100644 --- a/apps/dataset/views/paragraph.py +++ b/apps/dataset/views/paragraph.py @@ -52,6 +52,73 @@ class Paragraph(APIView): return result.success( ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data)) + class Problem(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加关联问题", + operation_id="添加段落关联问题", + manual_parameters=ParagraphSerializers.Problem.get_request_params_api(), + request_body=ParagraphSerializers.Problem.get_request_body_api(), + responses=result.get_api_response(ParagraphSerializers.Problem.get_response_body_api()), + 
tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + return result.success(ParagraphSerializers.Problem( + data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save( + request.data, with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取段落问题列表", + operation_id="获取段落问题列表", + manual_parameters=ParagraphSerializers.Problem.get_request_params_api(), + responses=result.get_api_array_response( + ParagraphSerializers.Problem.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + return result.success(ParagraphSerializers.Problem( + data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list( + with_valid=True)) + + class UnAssociation(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="解除关联问题", + operation_id="解除关联问题", + manual_parameters=ParagraphSerializers.Association.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str): + return result.success(ParagraphSerializers.Association( + data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id, + 'problem_id': problem_id}).un_association()) + + class Association(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], 
detail=False) + @swagger_auto_schema(operation_summary="关联问题", + operation_id="关联问题", + manual_parameters=ParagraphSerializers.Association.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str): + return result.success(ParagraphSerializers.Association( + data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id, + 'problem_id': problem_id}).association()) + class Operate(APIView): authentication_classes = [TokenAuth] @@ -61,7 +128,7 @@ class Paragraph(APIView): manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), request_body=ParagraphSerializers.Operate.get_request_body_api(), responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()) - ,tags=["知识库/文档/段落"]) + , tags=["知识库/文档/段落"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) diff --git a/apps/dataset/views/problem.py b/apps/dataset/views/problem.py index 6c9408f5d..36731425f 100644 --- a/apps/dataset/views/problem.py +++ b/apps/dataset/views/problem.py @@ -8,61 +8,115 @@ """ from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action -from rest_framework.request import Request from rest_framework.views import APIView +from rest_framework.views import Request from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import Permission, Group, Operate from common.response import result +from common.util.common import query_params_to_single_dict from dataset.serializers.problem_serializers import ProblemSerializers +from dataset.swagger_api.problem_api import ProblemApi class Problem(APIView): authentication_classes = [TokenAuth] + 
@action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="问题列表", + operation_id="问题列表", + manual_parameters=ProblemApi.Query.get_request_params_api(), + responses=result.get_api_array_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + q = ProblemSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + q.is_valid(raise_exception=True) + return result.success(q.list()) + @action(methods=['POST'], detail=False) - @swagger_auto_schema(operation_summary="添加关联问题", - operation_id="添加段落关联问题", - manual_parameters=ProblemSerializers.Create.get_request_params_api(), - request_body=ProblemSerializers.Create.get_request_body_api(), - responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()), + @swagger_auto_schema(operation_summary="创建问题", + operation_id="创建问题", + manual_parameters=ProblemApi.BatchCreate.get_request_params_api(), + request_body=ProblemApi.BatchCreate.get_request_body_api(), + responses=result.get_api_response(ProblemApi.Query.get_response_body_api()), tags=["知识库/文档/段落/问题"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) - def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): - return result.success(ProblemSerializers.Create( - data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save( - request.data, with_valid=True)) + def post(self, request: Request, dataset_id: str): + return result.success( + ProblemSerializers.Create( + data={'dataset_id': dataset_id, 'problem_list': request.query_params.get('problem_list')}).save()) - @action(methods=['GET'], detail=False) - @swagger_auto_schema(operation_summary="获取段落问题列表", - 
operation_id="获取段落问题列表", - manual_parameters=ProblemSerializers.Query.get_request_params_api(), - responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()), - tags=["知识库/文档/段落/问题"]) - @has_permissions( - lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, - dynamic_tag=k.get('dataset_id'))) - def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): - return result.success(ProblemSerializers.Query( - data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list( - with_valid=True)) + class Paragraph(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取关联段落列表", + operation_id="获取关联段落列表", + manual_parameters=ProblemApi.Paragraph.get_request_params_api(), + responses=result.get_api_array_response(ProblemApi.Paragraph.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).list_paragraph()) class Operate(APIView): authentication_classes = [TokenAuth] @action(methods=['DELETE'], detail=False) - @swagger_auto_schema(operation_summary="删除段落问题", - operation_id="删除段落问题", - manual_parameters=ProblemSerializers.Query.get_request_params_api(), + @swagger_auto_schema(operation_summary="删除问题", + operation_id="删除问题", + manual_parameters=ProblemApi.Operate.get_request_params_api(), responses=result.get_default_response(), tags=["知识库/文档/段落/问题"]) @has_permissions( lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) - def delete(self, request: Request, dataset_id: str, document_id: str, 
paragraph_id: str, problem_id: str): - o = ProblemSerializers.Operate( - data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id, - 'problem_id': problem_id}) - return result.success(o.delete(with_valid=True)) + def delete(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).delete()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改问题", + operation_id="修改问题", + manual_parameters=ProblemApi.Operate.get_request_params_api(), + request_body=ProblemApi.Operate.get_request_body_api(), + responses=result.get_api_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).edit(request.data)) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取问题列表", + operation_id="分页获取问题列表", + manual_parameters=result.get_page_request_params( + ProblemApi.Query.get_request_params_api()), + responses=result.get_page_api_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, current_page, page_size): + d = ProblemSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + d.is_valid(raise_exception=True) + return 
result.success(d.page(current_page, page_size)) diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 665896970..f12c49895 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -131,6 +131,18 @@ class BaseVectorStore(ABC): def update_by_source_id(self, source_id: str, instance: Dict): pass + @abstractmethod + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + pass + + @abstractmethod + def embed_documents(self, text_list: List[str]): + pass + + @abstractmethod + def embed_query(self, text: str): + pass + @abstractmethod def delete_by_dataset_id(self, dataset_id: str): pass @@ -147,6 +159,10 @@ class BaseVectorStore(ABC): def delete_by_source_id(self, source_id: str, source_type: str): pass + @abstractmethod + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + pass + @abstractmethod def delete_by_paragraph_id(self, paragraph_id: str): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index e8edb5d30..e2abe9509 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -14,6 +14,7 @@ from typing import Dict, List from django.db.models import QuerySet from langchain_community.embeddings import HuggingFaceEmbeddings +from common.config.embedding_config import EmbeddingModel from common.db.search import native_search, generate_sql_by_query_dict from common.db.sql_execute import select_one, select_list from common.util.file_util import get_file_content @@ -24,6 +25,20 @@ from smartdoc.conf import PROJECT_DIR class PGVector(BaseVectorStore): + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete() + + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) + + def 
embed_documents(self, text_list: List[str]): + embedding = EmbeddingModel.get_embedding_model() + return embedding.embed_documents(text_list) + + def embed_query(self, text: str): + embedding = EmbeddingModel.get_embedding_model() + return embedding.embed_query(text) + def vector_is_create(self) -> bool: # 项目启动默认是创建好的 不需要再创建 return True diff --git a/ui/src/api/paragraph.ts b/ui/src/api/paragraph.ts index f29376160..3df63b770 100644 --- a/ui/src/api/paragraph.ts +++ b/ui/src/api/paragraph.ts @@ -137,7 +137,7 @@ const postProblem: ( ) } /** - * 删除问题 + * 解除关联问题 * @param 参数 dataset_id, document_id, paragraph_id,problem_id */ const delProblem: ( @@ -146,8 +146,8 @@ const delProblem: ( paragraph_id: string, problem_id: string ) => Promise> = (dataset_id, document_id, paragraph_id, problem_id) => { - return del( - `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}` + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/un_association` ) }