diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index f99d9351b..68a2ba022 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -60,6 +60,10 @@ class Flow: start_node_list = [node for node in self.nodes if node.id == 'start-node'] return start_node_list[0] + def get_search_node(self): + return [node for node in self.nodes if node.type == 'search-dataset-node'] + + def is_valid(self): """ 校验工作流数据 diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 40e447d48..64ff19580 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -12,7 +12,7 @@ import os import re import uuid from functools import reduce -from typing import Dict +from typing import Dict, List from django.contrib.postgres.fields import ArrayField from django.core import cache, validators @@ -548,6 +548,8 @@ class ApplicationSerializer(serializers.Serializer): application.name = node_data.get('name') application.desc = node_data.get('desc') application.prologue = node_data.get('prologue') + dataset_list = self.list_dataset(with_valid=False) + self.update_reverse_search_node(work_flow, [str(dataset.get('id')) for dataset in dataset_list]) application.work_flow = work_flow application.save() work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application) @@ -565,9 +567,38 @@ class ApplicationSerializer(serializers.Serializer): dataset_id_list = [d.get('id') for d in list(filter(lambda row: mapping_dataset_id_list.__contains__(row.get('id')), dataset_list))] + self.update_search_node(application.work_flow, [str(dataset.get('id')) for dataset in dataset_list]) return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data), 'dataset_id_list': dataset_id_list} + def get_search_node(self, work_flow): + return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-dataset-node'] + + def update_search_node(self, work_flow, user_dataset_id_list: List): + search_node_list = self.get_search_node(work_flow) + for search_node in search_node_list: + node_data = search_node.get('properties', {}).get('node_data', {}) + dataset_id_list = node_data.get('dataset_id_list', []) + node_data['source_dataset_id_list'] = dataset_id_list + node_data['dataset_id_list'] = [dataset_id for dataset_id in dataset_id_list if + user_dataset_id_list.__contains__(dataset_id)] + + def update_reverse_search_node(self, work_flow, user_dataset_id_list: List): + search_node_list = self.get_search_node(work_flow) + for search_node in search_node_list: + node_data = search_node.get('properties', {}).get('node_data', {}) + dataset_id_list = node_data.get('dataset_id_list', []) + for dataset_id in dataset_id_list: + if not user_dataset_id_list.__contains__(dataset_id): + raise AppApiException(500, f"未知的知识库id${dataset_id},无法关联") + + source_dataset_id_list = node_data.get('source_dataset_id_list', []) + source_dataset_id_list = [source_dataset_id for source_dataset_id in source_dataset_id_list if + not user_dataset_id_list.__contains__(source_dataset_id)] + source_dataset_id_list = list({*source_dataset_id_list, *dataset_id_list}) + node_data['source_dataset_id_list'] = [] + node_data['dataset_id_list'] = source_dataset_id_list + def profile(self, with_valid=True): if with_valid: self.is_valid() @@ -596,6 +627,12 @@ class ApplicationSerializer(serializers.Serializer): user_id=application.user_id).first() if model is None: raise AppApiException(500, "模型不存在") + if 'work_flow' in instance: + # 当前用户可修改关联的知识库列表 + application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in + self.list_dataset(with_valid=False)] + self.update_reverse_search_node(instance.get('work_flow'), application_dataset_id_list) + update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'api_key_is_active', 'icon', 'work_flow']