From 5c6e1ada42cd3bf8e7cf6ad6ec9c77452ff5c7a1 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:51:17 +0800 Subject: [PATCH] fix: In the dialogue, the form nodes of the sub-application are not displayed as separate cards. (#1821) (cherry picked from commit 37c963b4ad487652b8e68f03f87b5715064de92d) --- apps/application/flow/common.py | 21 ++++++ apps/application/flow/i_step_node.py | 7 +- .../impl/base_application_node.py | 65 +++++++++++++++++-- .../form_node/impl/base_form_node.py | 8 +-- apps/application/flow/workflow_manage.py | 50 +++++++------- ui/src/api/type/application.ts | 14 +++- 6 files changed, 124 insertions(+), 41 deletions(-) create mode 100644 apps/application/flow/common.py diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py new file mode 100644 index 000000000..96db00119 --- /dev/null +++ b/apps/application/flow/common.py @@ -0,0 +1,21 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: common.py + @date:2024/12/11 17:57 + @desc: +""" + + +class Answer: + def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node): + self.view_type = view_type + self.content = content + self.runtime_node_id = runtime_node_id + self.chat_record_id = chat_record_id + self.child_node = child_node + + def to_dict(self): + return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.chat_record_id, 'child_node': self.child_node} diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index caf25836f..a9316770b 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -17,6 +17,7 @@ from django.db.models import QuerySet from rest_framework import serializers from rest_framework.exceptions import ValidationError, ErrorDetail +from application.flow.common import Answer from application.models import ChatRecord from application.models.api_key_model import ApplicationPublicAccessClient from common.constants.authentication_type import AuthenticationType @@ -151,11 +152,11 @@ class INode: def save_context(self, details, workflow_manage): pass - def get_answer_text(self): + def get_answer_list(self) -> List[Answer] | None: if self.answer_text is None: return None - return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, - 'chat_record_id': self.workflow_params['chat_record_id']} + return [ + Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})] def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, get_node_params=lambda node: node.properties.get('node_data')): diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 17b30d0eb..76f92f878 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -1,9 +1,11 @@ # coding=utf-8 import json +import re import time import uuid -from typing import Dict +from typing import Dict, List +from application.flow.common import Answer from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.application_node.i_application_node import IApplicationNode from application.models import Chat @@ -19,7 +21,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): result = node_variable.get('result') - node.context['child_node'] = node_variable.get('child_node') + node.context['application_node_dict'] = node_variable.get('application_node_dict') + node.context['node_dict'] = node_variable.get('node_dict', {}) node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec') node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) @@ -43,6 +46,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer = '' usage = {} node_child_node = {} + application_node_dict = node.context.get('application_node_dict', {}) is_interrupt_exec = False for chunk in response: # 先把流转成字符串 @@ -61,6 +65,20 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer += content node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, 'child_node': child_node} + + if real_node_id is not None: + application_node = application_node_dict.get(real_node_id, None) + if application_node is None: + + application_node_dict[real_node_id] = {'content': content, + 'runtime_node_id': runtime_node_id, + 'chat_record_id': chat_record_id, + 'child_node': child_node, + 'index': len(application_node_dict), + 'view_type': view_type} + else: + application_node['content'] += content + yield {'content': content, 'node_type': node_type, 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, @@ -72,6 +90,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo node_variable['result'] = {'usage': usage} node_variable['is_interrupt_exec'] = is_interrupt_exec node_variable['child_node'] = node_child_node + node_variable['application_node_dict'] = application_node_dict _write_context(node_variable, workflow_variable, node, workflow, answer) @@ -90,12 +109,43 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, answer) +def reset_application_node_dict(application_node_dict, runtime_node_id, node_data): + try: + if application_node_dict is None: + return + for key in application_node_dict: + application_node = application_node_dict[key] + if application_node.get('runtime_node_id') == runtime_node_id: + content: str = application_node.get('content') + match = re.search('.*?', content) + if match: + form_setting_str = match.group().replace('', '').replace('', '') + form_setting = json.loads(form_setting_str) + form_setting['is_submit'] = True + form_setting['form_data'] = node_data + value = f'{json.dumps(form_setting)}' + res = re.sub('.*?', + '${value}', content) + application_node['content'] = res.replace('${value}', value) + except Exception as e: + pass + + class BaseApplicationNode(IApplicationNode): - def get_answer_text(self): + def get_answer_list(self) -> List[Answer] | None: if self.answer_text is None: return None - return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, - 'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')} + application_node_dict = self.context.get('application_node_dict') + if application_node_dict is None: + return [ + Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], + self.context.get('child_node'))] + else: + return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id, + self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'), + 'chat_record_id': n.get('chat_record_id') + , 'child_node': n.get('child_node')}) for n in + sorted(application_node_dict.values(), key=lambda item: item.get('index'))] def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') @@ -124,6 +174,8 @@ class BaseApplicationNode(IApplicationNode): runtime_node_id = child_node.get('runtime_node_id') record_id = child_node.get('chat_record_id') child_node_value = child_node.get('child_node') + application_node_dict = self.context.get('application_node_dict') + reset_application_node_dict(application_node_dict, runtime_node_id, node_data) response = ChatMessageSerializer( data={'chat_id': current_chat_id, 'message': message, @@ -181,5 +233,6 @@ class BaseApplicationNode(IApplicationNode): 'err_message': self.err_message, 'global_fields': global_fields, 'document_list': self.workflow_manage.document_list, - 'image_list': self.workflow_manage.image_list + 'image_list': self.workflow_manage.image_list, + 'application_node_dict': self.context.get('application_node_dict') } diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index 69bb69372..7a2b25c3c 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -8,10 +8,11 @@ """ import json import time -from typing import Dict +from typing import Dict, List from langchain_core.prompts import PromptTemplate +from application.flow.common import Answer from application.flow.i_step_node import NodeResult from application.flow.step_node.form_node.i_form_node import IFormNode @@ -60,7 +61,7 @@ class BaseFormNode(IFormNode): {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {}, _write_context=write_context) - def get_answer_text(self): + def get_answer_list(self) -> List[Answer] | None: form_content_format = self.context.get('form_content_format') form_field_list = self.context.get('form_field_list') form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, @@ -70,8 +71,7 @@ class BaseFormNode(IFormNode): form = f'{json.dumps(form_setting)}' prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') value = prompt_template.format(form=form) - return {'content': value, 'runtime_node_id': self.runtime_node_id, - 'chat_record_id': self.workflow_params['chat_record_id']} + return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None)] def get_details(self, index: int, **kwargs): form_content_format = self.context.get('form_content_format') diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index e580a5a6a..c62719b6e 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -19,6 +19,7 @@ from rest_framework import status from rest_framework.exceptions import ErrorDetail, ValidationError from application.flow import tools +from application.flow.common import Answer from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult from application.flow.step_node import get_node from common.exception.app_exception import AppApiException @@ -302,6 +303,9 @@ class WorkflowManage: get_node_params=get_node_params) self.start_node.valid_args( {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params) + if self.start_node.type == 'application-node': + application_node_dict = node_details.get('application_node_dict', {}) + self.start_node.context['application_node_dict'] = application_node_dict self.node_context.append(self.start_node) continue @@ -482,7 +486,7 @@ class WorkflowManage: '', False, 0, 0, {'node_is_end': True, 'runtime_node_id': current_node.runtime_node_id, 'node_type': current_node.type, - 'view_type': current_node.view_type, + 'view_type': view_type, 'child_node': child_node, 'real_node_id': real_node_id}) node_chunk.end(chunk) @@ -577,35 +581,29 @@ class WorkflowManage: def get_answer_text_list(self): result = [] - next_node_id_list = [] - if self.start_node is not None: - next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if - edge.sourceNodeId == self.start_node.id] - for index in range(len(self.node_context)): - node = self.node_context[index] - up_node = None - if index > 0: - up_node = self.node_context[index - 1] - answer_text = node.get_answer_text() - if answer_text is not None: - if up_node is None or node.view_type == 'single_view' or ( - node.view_type == 'many_view' and up_node.view_type == 'single_view'): - result.append(node.get_answer_text()) - elif self.chat_record is not None and next_node_id_list.__contains__( - node.id) and up_node is not None and not next_node_id_list.__contains__( - up_node.id): - result.append(node.get_answer_text()) + answer_list = reduce(lambda x, y: [*x, *y], + [n.get_answer_list() for n in self.node_context if n.get_answer_list() is not None], + []) + up_node = None + for index in range(len(answer_list)): + current_answer = answer_list[index] + if len(current_answer.content) > 0: + if up_node is None or current_answer.view_type == 'single_view' or ( + current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'): + result.append(current_answer) else: if len(result) > 0: exec_index = len(result) - 1 - content = result[exec_index]['content'] - result[exec_index]['content'] += answer_text['content'] if len( - content) == 0 else ('\n\n' + answer_text['content']) + content = result[exec_index].content + result[exec_index].content += current_answer.content if len( + content) == 0 else ('\n\n' + current_answer.content) else: - answer_text = node.get_answer_text() - result.insert(0, answer_text) - - return result + result.insert(0, current_answer) + up_node = current_answer + if len(result) == 0: + # 如果没有响应 就响应一个空数据 + return [Answer('', '', '', '', {}).to_dict()] + return [r.to_dict() for r in result] def get_next_node(self): """ diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index c1d07c968..868528023 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -120,7 +120,15 @@ export class ChatRecordManage { this.chat.answer_text = this.chat.answer_text + chunk_answer } - + get_current_up_node() { + for (let i = this.node_list.length - 2; i >= 0; i--) { + const n = this.node_list[i] + if (n.content.length > 0) { + return n + } + } + return undefined + } get_run_node() { if ( this.write_node_info && @@ -135,7 +143,7 @@ export class ChatRecordManage { const index = this.node_list.indexOf(run_node) let current_up_node = undefined if (index > 0) { - current_up_node = this.node_list[index - 1] + current_up_node = this.get_current_up_node() } let answer_text_list_index = 0 @@ -293,9 +301,11 @@ export class ChatRecordManage { let n = this.node_list.find((item) => item.real_node_id == chunk.real_node_id) if (n) { n.buffer.push(...chunk.content) + n.content += chunk.content } else { n = { buffer: [...chunk.content], + content: chunk.content, real_node_id: chunk.real_node_id, node_id: chunk.node_id, chat_record_id: chunk.chat_record_id,