From 1f26744cef09aaebfc3c0e4e13aab9995b8b8351 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Fri, 20 Sep 2024 17:07:52 +0800
Subject: [PATCH] feat: batch associate problems (#1235)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/common/event/listener_manage.py          |  5 ++
 .../serializers/problem_serializers.py        | 74 ++++++++++++++-
 apps/dataset/swagger_api/problem_api.py       | 30 +++++++
 apps/dataset/views/problem.py                 | 14 +++
 apps/embedding/task/embedding.py              |  5 ++
 apps/embedding/vector/base_vector.py          | 36 ++++----
 ui/src/api/problem.ts                         | 19 +++-
 .../problem/component/RelateProblemDialog.vue | 90 +++++++++++++------
 ui/src/views/problem/index.vue                | 25 +++++-
 9 files changed, 245 insertions(+), 53 deletions(-)

diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index 537677ba2..40ac4884d 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -137,6 +137,11 @@ class ListenerManagement:
         QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
         max_kb.info(f'结束--->向量化段落:{paragraph_id}')
 
+    @staticmethod
+    def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
+        # 批量向量化
+        VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)
+
     @staticmethod
     def embedding_by_document(document_id, embedding_model: Embeddings):
         """
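The new `ListenerManagement.embedding_by_data_list` hands a caller-built list of records directly to `batch_save`, so every entry must already carry the fields the vector store expects. A minimal sketch of that record shape, inferred from the `data_list` built in `apps/dataset/serializers/problem_serializers.py` below; the ids and text are placeholders, not real data:

```python
# Sketch of one data_list entry consumed by embedding_by_data_list / batch_save.
# All ids below are made-up placeholders.
import uuid

example_entry = {
    'text': 'How do I reset my password?',  # problem content to vectorize
    'is_active': True,
    'source_type': 'PROBLEM',               # SourceType.PROBLEM in the real code
    'source_id': str(uuid.uuid1()),         # ProblemParagraphMapping id
    'document_id': str(uuid.uuid1()),
    'paragraph_id': str(uuid.uuid1()),
    'dataset_id': str(uuid.uuid1()),
}

# The listener would then be invoked roughly as:
# ListenerManagement.embedding_by_data_list([example_entry], embedding_model)
```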
diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py
index 14e3bced8..22ff6c1da 100644
--- a/apps/dataset/serializers/problem_serializers.py
+++ b/apps/dataset/serializers/problem_serializers.py
@@ -8,6 +8,7 @@
 """
 import os
 import uuid
+from functools import reduce
 from typing import Dict, List
 
 from django.db import transaction
@@ -21,7 +22,8 @@ from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
 from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
-from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
+from embedding.models import SourceType
+from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -50,6 +52,35 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
         })
 
 
+class AssociationParagraph(serializers.Serializer):
+    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
+    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
+
+
+class BatchAssociation(serializers.Serializer):
+    problem_id_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题id列表"),
+                                            child=serializers.UUIDField(required=True,
+                                                                        error_messages=ErrMessage.uuid("问题id")))
+    paragraph_list = AssociationParagraph(many=True)
+
+
+def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
+    filter_list = [exits_problem_paragraph_mapping for exits_problem_paragraph_mapping in
+                   exits_problem_paragraph_mapping_list if
+                   str(exits_problem_paragraph_mapping.paragraph_id) == new_paragraph_mapping.paragraph_id
+                   and str(exits_problem_paragraph_mapping.problem_id) == new_paragraph_mapping.problem_id
+                   and str(exits_problem_paragraph_mapping.dataset_id) == new_paragraph_mapping.dataset_id]
+    return len(filter_list) > 0
+
+
+def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
+    return ProblemParagraphMapping(id=uuid.uuid1(),
+                                   document_id=document_id,
+                                   paragraph_id=paragraph_id,
+                                   dataset_id=dataset_id,
+                                   problem_id=str(problem.id)), problem
+
+
 class ProblemSerializers(ApiMixin, serializers.Serializer):
     class Create(serializers.Serializer):
         dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
@@ -115,6 +146,47 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
             delete_embedding_by_source_ids(source_ids)
             return True
 
+        def association(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                BatchAssociation(data=instance).is_valid(raise_exception=True)
+            dataset_id = self.data.get('dataset_id')
+            paragraph_list = instance.get('paragraph_list')
+            problem_id_list = instance.get('problem_id_list')
+            problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
+            exits_problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(problem_id__in=problem_id_list,
+                                                                                       paragraph_id__in=[
+                                                                                           p.get('paragraph_id')
+                                                                                           for p in
+                                                                                           paragraph_list])
+            problem_paragraph_mapping_list = [(problem_paragraph_mapping, problem) for
+                                              problem_paragraph_mapping, problem in reduce(lambda x, y: [*x, *y],
+                                                                                           [[
+                                                                                               to_problem_paragraph_mapping(
+                                                                                                   problem,
+                                                                                                   paragraph.get(
+                                                                                                       'document_id'),
+                                                                                                   paragraph.get(
+                                                                                                       'paragraph_id'),
+                                                                                                   dataset_id) for
+                                                                                               paragraph in
+                                                                                               paragraph_list]
+                                                                                            for problem in
+                                                                                            problem_list], []) if
+                                              not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)]
+            QuerySet(ProblemParagraphMapping).bulk_create(
+                [problem_paragraph_mapping for problem_paragraph_mapping, problem in problem_paragraph_mapping_list])
+            data_list = [{'text': problem.content,
+                          'is_active': True,
+                          'source_type': SourceType.PROBLEM,
+                          'source_id': str(problem_paragraph_mapping.id),
+                          'document_id': str(problem_paragraph_mapping.document_id),
+                          'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
+                          'dataset_id': dataset_id,
+                          } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
+            model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
+            embedding_by_data_list(data_list, model_id=model_id)
+
     class Operate(serializers.Serializer):
         dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
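The `reduce` plus nested comprehension in `association` above is dense; the sketch below restates the same idea with plain loops — build the cross product of the selected problems and paragraphs, then drop pairs that are already mapped. It is an illustration only: `FakeMapping` and the dict-shaped problems stand in for the Django models, while the real code keeps `is_exits` and `to_problem_paragraph_mapping` exactly as shown above.

```python
# Plain-loop restatement of the mapping build in BatchOperate.association.
# Stand-in types replace the Django models; ids are placeholders.
import uuid
from dataclasses import dataclass


@dataclass
class FakeMapping:  # stand-in for ProblemParagraphMapping
    id: str
    document_id: str
    paragraph_id: str
    dataset_id: str
    problem_id: str


def build_new_mappings(problem_list, paragraph_list, existing_mappings, dataset_id):
    # Keys of mappings that already exist, mirroring the three checks in is_exits().
    existing = {(str(m.problem_id), str(m.paragraph_id), str(m.dataset_id))
                for m in existing_mappings}
    result = []
    for problem in problem_list:              # cross product: problems x paragraphs
        for paragraph in paragraph_list:
            key = (str(problem['id']), paragraph['paragraph_id'], dataset_id)
            if key in existing:
                continue                      # skip pairs that are already associated
            result.append((FakeMapping(id=str(uuid.uuid1()),
                                       document_id=paragraph['document_id'],
                                       paragraph_id=paragraph['paragraph_id'],
                                       dataset_id=dataset_id,
                                       problem_id=str(problem['id'])),
                           problem))
    return result


new_pairs = build_new_mappings(
    problem_list=[{'id': str(uuid.uuid1()), 'content': 'demo problem'}],
    paragraph_list=[{'paragraph_id': str(uuid.uuid1()), 'document_id': str(uuid.uuid1())}],
    existing_mappings=[],
    dataset_id=str(uuid.uuid1()))
```

Keeping a set of `(problem_id, paragraph_id, dataset_id)` keys also avoids rescanning the existing-mapping list for every candidate pair, which is what the list-comprehension check inside `is_exits` does.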
diff --git a/apps/dataset/swagger_api/problem_api.py b/apps/dataset/swagger_api/problem_api.py
index a7397aaaf..7932e0cea 100644
--- a/apps/dataset/swagger_api/problem_api.py
+++ b/apps/dataset/swagger_api/problem_api.py
@@ -36,6 +36,36 @@ class ProblemApi(ApiMixin):
             }
         )
 
+    class BatchAssociation(ApiMixin):
+        @staticmethod
+        def get_request_params_api():
+            return ProblemApi.BatchOperate.get_request_params_api()
+
+        @staticmethod
+        def get_request_body_api():
+            return openapi.Schema(
+                type=openapi.TYPE_OBJECT,
+                required=['problem_id_list'],
+                properties={
+                    'problem_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题id列表",
+                                                      description="问题id列表",
+                                                      items=openapi.Schema(type=openapi.TYPE_STRING)),
+                    'paragraph_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="关联段落信息列表",
+                                                     description="关联段落信息列表",
+                                                     items=openapi.Schema(type=openapi.TYPE_OBJECT,
+                                                                          required=['paragraph_id', 'document_id'],
+                                                                          properties={
+                                                                              'paragraph_id': openapi.Schema(
+                                                                                  type=openapi.TYPE_STRING,
+                                                                                  title="段落id"),
+                                                                              'document_id': openapi.Schema(
+                                                                                  type=openapi.TYPE_STRING,
+                                                                                  title="文档id")
+                                                                          }))
+
+                }
+            )
+
     class BatchOperate(ApiMixin):
         @staticmethod
         def get_request_params_api():
diff --git a/apps/dataset/views/problem.py b/apps/dataset/views/problem.py
index beebcc673..1d0ccb53b 100644
--- a/apps/dataset/views/problem.py
+++ b/apps/dataset/views/problem.py
@@ -88,6 +88,20 @@ class Problem(APIView):
             return result.success(
                 ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).delete(request.data))
 
+        @action(methods=['POST'], detail=False)
+        @swagger_auto_schema(operation_summary="批量关联段落",
+                             operation_id="批量关联段落",
+                             request_body=ProblemApi.BatchAssociation.get_request_body_api(),
+                             manual_parameters=ProblemApi.BatchOperate.get_request_params_api(),
+                             responses=result.get_default_response(),
+                             tags=["知识库/文档/段落/问题"])
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def post(self, request: Request, dataset_id: str):
+            return result.success(
+                ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).association(request.data))
+
     class Operate(APIView):
         authentication_classes = [TokenAuth]
diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
index 779969c64..67590e2c7 100644
--- a/apps/embedding/task/embedding.py
+++ b/apps/embedding/task/embedding.py
@@ -111,6 +111,11 @@ def embedding_by_problem(args, model_id):
     ListenerManagement.embedding_by_problem(args, embedding_model)
 
 
+def embedding_by_data_list(args: List, model_id):
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_data_list(args, embedding_model)
+
+
 def delete_embedding_by_document(document_id):
     """
     删除指定文档id的向量
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index f9ef23d72..ab5ab4103 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -87,27 +87,21 @@ class BaseVectorStore(ABC):
                 self._batch_save(child_array, embedding, lambda: True)
 
     def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
-        # 获取锁
-        lock.acquire()
-        try:
-            """
-            批量插入
-            :param data_list: 数据列表
-            :param embedding: 向量化处理器
-            :return: bool
-            """
-            self.save_pre_handler()
-            chunk_list = chunk_data_list(data_list)
-            result = sub_array(chunk_list)
-            for child_array in result:
-                if is_save_function():
-                    self._batch_save(child_array, embedding, is_save_function)
-                else:
-                    break
-        finally:
-            # 释放锁
-            lock.release()
-        return True
+        """
+        批量插入
+        @param data_list: 数据列表
+        @param embedding: 向量化处理器
+        @param is_save_function:
+        :return: bool
+        """
+        self.save_pre_handler()
+        chunk_list = chunk_data_list(data_list)
+        result = sub_array(chunk_list)
+        for child_array in result:
+            if is_save_function():
+                self._batch_save(child_array, embedding, is_save_function)
+            else:
+                break
 
     @abstractmethod
     def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
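The reworked `batch_save` above relies on `chunk_data_list` and `sub_array` (defined elsewhere, not shown in this diff) to split the work, and calls `is_save_function()` between batches as a cancellation check; the module lock and the trailing `return True` are gone, so callers should not depend on a return value. A sketch of that control flow under those assumptions — `sub_array_stub` is a stand-in, not the real helper, and the batch size is illustrative:

```python
# Sketch of the batch_save control flow with assumed chunking.
from typing import Callable, Dict, List


def sub_array_stub(items: List[Dict], size: int = 50) -> List[List[Dict]]:
    # Assumed stand-in for sub_array: split a flat list into fixed-size batches.
    return [items[i:i + size] for i in range(0, len(items), size)]


def batch_save_flow(data_list: List[Dict],
                    embed_batch: Callable[[List[Dict]], None],
                    is_save_function: Callable[[], bool]) -> None:
    # Mirrors the loop in BaseVectorStore.batch_save: stop as soon as the
    # cancellation check fails, otherwise hand each batch to the embedder.
    for child_array in sub_array_stub(data_list):
        if not is_save_function():
            break
        embed_batch(child_array)


# The listener path passes `lambda: True`, i.e. never cancel:
batch_save_flow([{'text': 'demo'}], lambda batch: None, lambda: True)
```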
diff --git a/ui/src/api/problem.ts b/ui/src/api/problem.ts
index 7d8d16226..4625d6de6 100644
--- a/ui/src/api/problem.ts
+++ b/ui/src/api/problem.ts
@@ -97,11 +97,28 @@ const getDetailProblems: (
   return get(`${prefix}/${dataset_id}/problem/${problem_id}/paragraph`, undefined, loading)
 }
 
+/**
+ * 批量关联段落
+ * @param 参数 dataset_id,
+ * {
+      "problem_id_list": "Array",
+      "paragraph_list": "Array",
+    }
+ */
+const postMulAssociationProblem: (
+  dataset_id: string,
+  data: any,
+  loading?: Ref<boolean>
+) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
+  return post(`${prefix}/${dataset_id}/problem/_batch`, data, undefined, loading)
+}
+
 export default {
   getProblems,
   postProblems,
   delProblems,
   putProblems,
   getDetailProblems,
-  delMulProblem
+  delMulProblem,
+  postMulAssociationProblem
 }
diff --git a/ui/src/views/problem/component/RelateProblemDialog.vue b/ui/src/views/problem/component/RelateProblemDialog.vue
index 7306a827f..fed4fa073 100644
--- a/ui/src/views/problem/component/RelateProblemDialog.vue
+++ b/ui/src/views/problem/component/RelateProblemDialog.vue
@@ -91,6 +91,12 @@
 
 
+
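On the wire, `postMulAssociationProblem` issues a POST to `${prefix}/${dataset_id}/problem/_batch` (the `prefix` constant and request helpers live elsewhere in `ui/src/api` and are not part of this diff, and the remaining UI changes are truncated above). A sketch of the request body accepted by the `BatchAssociation` serializer and the swagger schema earlier in the patch; the ids are placeholders:

```python
# Example payload for POST <prefix>/<dataset_id>/problem/_batch.
# UUIDs are placeholders, not real records.
import json
import uuid

payload = {
    'problem_id_list': [str(uuid.uuid1()), str(uuid.uuid1())],
    'paragraph_list': [
        {'paragraph_id': str(uuid.uuid1()), 'document_id': str(uuid.uuid1())},
    ],
}

print(json.dumps(payload, indent=2))
```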