From 4ca4d2fa84a20c2c7ec28d237e787bf58e0d0dbd Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 3 Jul 2024 11:46:37 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E3=80=90=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E7=BC=96=E6=8E=92=E3=80=91=E6=B2=A1=E6=9C=89=E6=9D=83?= =?UTF-8?q?=E9=99=90=E7=9A=84=E7=9F=A5=E8=AF=86=E5=BA=93=EF=BC=8C=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=98=AF=E7=A9=BA=E7=99=BD=E7=9A=84=20(#691)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/workflow_manage.py | 4 ++ .../serializers/application_serializers.py | 39 ++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index f99d9351b..68a2ba022 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -60,6 +60,10 @@ class Flow: start_node_list = [node for node in self.nodes if node.id == 'start-node'] return start_node_list[0] + def get_search_node(self): + return [node for node in self.nodes if node.type == 'search-dataset-node'] + + def is_valid(self): """ 校验工作流数据 diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 40e447d48..64ff19580 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -12,7 +12,7 @@ import os import re import uuid from functools import reduce -from typing import Dict +from typing import Dict, List from django.contrib.postgres.fields import ArrayField from django.core import cache, validators @@ -548,6 +548,8 @@ class ApplicationSerializer(serializers.Serializer): application.name = node_data.get('name') application.desc = node_data.get('desc') application.prologue = node_data.get('prologue') + dataset_list = self.list_dataset(with_valid=False) + self.update_reverse_search_node(work_flow, [str(dataset.get('id')) for dataset in dataset_list]) application.work_flow = work_flow application.save() work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application) @@ -565,9 +567,38 @@ class ApplicationSerializer(serializers.Serializer): dataset_id_list = [d.get('id') for d in list(filter(lambda row: mapping_dataset_id_list.__contains__(row.get('id')), dataset_list))] + self.update_search_node(application.work_flow, [str(dataset.get('id')) for dataset in dataset_list]) return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data), 'dataset_id_list': dataset_id_list} + def get_search_node(self, work_flow): + return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-dataset-node'] + + def update_search_node(self, work_flow, user_dataset_id_list: List): + search_node_list = self.get_search_node(work_flow) + for search_node in search_node_list: + node_data = search_node.get('properties', {}).get('node_data', {}) + dataset_id_list = node_data.get('dataset_id_list', []) + node_data['source_dataset_id_list'] = dataset_id_list + node_data['dataset_id_list'] = [dataset_id for dataset_id in dataset_id_list if + user_dataset_id_list.__contains__(dataset_id)] + + def update_reverse_search_node(self, work_flow, user_dataset_id_list: List): + search_node_list = self.get_search_node(work_flow) + for search_node in search_node_list: + node_data = search_node.get('properties', {}).get('node_data', {}) + dataset_id_list = node_data.get('dataset_id_list', []) + for dataset_id in dataset_id_list: + if not user_dataset_id_list.__contains__(dataset_id): + raise AppApiException(500, f"未知的知识库id${dataset_id},无法关联") + + source_dataset_id_list = node_data.get('source_dataset_id_list', []) + source_dataset_id_list = [source_dataset_id for source_dataset_id in source_dataset_id_list if + not user_dataset_id_list.__contains__(source_dataset_id)] + source_dataset_id_list = list({*source_dataset_id_list, *dataset_id_list}) + node_data['source_dataset_id_list'] = [] + node_data['dataset_id_list'] = source_dataset_id_list + def profile(self, with_valid=True): if with_valid: self.is_valid() @@ -596,6 +627,12 @@ class ApplicationSerializer(serializers.Serializer): user_id=application.user_id).first() if model is None: raise AppApiException(500, "模型不存在") + if 'work_flow' in instance: + # 当前用户可修改关联的知识库列表 + application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in + self.list_dataset(with_valid=False)] + self.update_reverse_search_node(instance.get('work_flow'), application_dataset_id_list) + update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'api_key_is_active', 'icon', 'work_flow']