From 171d1aaea0a3e35a0a0fb39afd8b92d6eb91b2df Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 3 Jul 2024 12:26:07 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E3=80=90=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E3=80=91=E8=AE=BE=E7=BD=AE=E4=B8=AD=E5=85=B3?= =?UTF-8?q?=E8=81=94=E5=BA=94=E7=94=A8=E6=9C=AA=E6=98=BE=E7=A4=BA=E5=9C=A8?= =?UTF-8?q?=E9=AB=98=E7=BA=A7=E7=BC=96=E6=8E=92=E4=B8=AD=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E8=AF=A5=E7=9F=A5=E8=AF=86=E5=BA=93=E7=9A=84=E5=BA=94=E7=94=A8?= =?UTF-8?q?=20(#692)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 64ff19580..189b70d48 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -533,6 +533,7 @@ class ApplicationSerializer(serializers.Serializer): QuerySet(Application).filter(id=self.data.get('application_id')).delete() return True + @transaction.atomic def publish(self, instance, with_valid=True): if with_valid: self.is_valid() @@ -549,9 +550,12 @@ class ApplicationSerializer(serializers.Serializer): application.desc = node_data.get('desc') application.prologue = node_data.get('prologue') dataset_list = self.list_dataset(with_valid=False) - self.update_reverse_search_node(work_flow, [str(dataset.get('id')) for dataset in dataset_list]) + dataset_id_list = self.update_reverse_search_node(work_flow, + [str(dataset.get('id')) for dataset in dataset_list]) application.work_flow = work_flow application.save() + # 插入知识库关联关系 + self.save_application_mapping(dataset_id_list, application.id) work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application) work_flow_version.save() return True @@ -585,6 +589,7 @@ class ApplicationSerializer(serializers.Serializer): def update_reverse_search_node(self, work_flow, user_dataset_id_list: List): search_node_list = self.get_search_node(work_flow) + dataset_id_list = [] for search_node in search_node_list: node_data = search_node.get('properties', {}).get('node_data', {}) dataset_id_list = node_data.get('dataset_id_list', []) @@ -598,6 +603,8 @@ class ApplicationSerializer(serializers.Serializer): source_dataset_id_list = list({*source_dataset_id_list, *dataset_id_list}) node_data['source_dataset_id_list'] = [] node_data['dataset_id_list'] = source_dataset_id_list + dataset_id_list = [*source_dataset_id_list, *dataset_id_list] + return list(set(dataset_id_list)) def profile(self, with_valid=True): if with_valid: @@ -611,6 +618,7 @@ class ApplicationSerializer(serializers.Serializer): {**ApplicationSerializer.ApplicationModel(application).data, 'show_source': application_access_token.show_source}) + @transaction.atomic def edit(self, instance: Dict, with_valid=True): if with_valid: self.is_valid() @@ -653,16 +661,20 @@ class ApplicationSerializer(serializers.Serializer): if not application_dataset_id_list.__contains__(dataset_id): raise AppApiException(500, f"未知的知识库id${dataset_id},无法关联") - # 删除已经关联的id - QuerySet(ApplicationDatasetMapping).filter(dataset_id__in=application_dataset_id_list, - application_id=application_id).delete() - # 插入 - QuerySet(ApplicationDatasetMapping).bulk_create( - [ApplicationDatasetMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in - dataset_id_list]) if len(dataset_id_list) > 0 else None + self.save_application_mapping(application_dataset_id_list, application_id) chat_cache.clear_by_application_id(application_id) return self.one(with_valid=False) + @staticmethod + def save_application_mapping(dataset_id_list, application_id): + # 删除已经关联的id + QuerySet(ApplicationDatasetMapping).filter(dataset_id__in=dataset_id_list, + application_id=application_id).delete() + # 插入 + QuerySet(ApplicationDatasetMapping).bulk_create( + [ApplicationDatasetMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in + dataset_id_list]) if len(dataset_id_list) > 0 else None + def list_dataset(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True)