From 5c6e1ada42cd3bf8e7cf6ad6ec9c77452ff5c7a1 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Wed, 11 Dec 2024 19:51:17 +0800
Subject: [PATCH] fix: In the dialogue, the form nodes of the sub-application
are not displayed as separate cards. (#1821)
(cherry picked from commit 37c963b4ad487652b8e68f03f87b5715064de92d)
---
apps/application/flow/common.py | 21 ++++++
apps/application/flow/i_step_node.py | 7 +-
.../impl/base_application_node.py | 65 +++++++++++++++++--
.../form_node/impl/base_form_node.py | 8 +--
apps/application/flow/workflow_manage.py | 50 +++++++-------
ui/src/api/type/application.ts | 14 +++-
6 files changed, 124 insertions(+), 41 deletions(-)
create mode 100644 apps/application/flow/common.py
diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py
new file mode 100644
index 000000000..96db00119
--- /dev/null
+++ b/apps/application/flow/common.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: common.py
+ @date:2024/12/11 17:57
+ @desc:
+"""
+
+
+class Answer:
+ def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
+ self.view_type = view_type
+ self.content = content
+ self.runtime_node_id = runtime_node_id
+ self.chat_record_id = chat_record_id
+ self.child_node = child_node
+
+ def to_dict(self):
+ return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
+ 'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index caf25836f..a9316770b 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -17,6 +17,7 @@ from django.db.models import QuerySet
from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
+from application.flow.common import Answer
from application.models import ChatRecord
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
@@ -151,11 +152,11 @@ class INode:
def save_context(self, details, workflow_manage):
pass
- def get_answer_text(self):
+ def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
- return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
- 'chat_record_id': self.workflow_params['chat_record_id']}
+ return [
+ Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py
index 17b30d0eb..76f92f878 100644
--- a/apps/application/flow/step_node/application_node/impl/base_application_node.py
+++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py
@@ -1,9 +1,11 @@
# coding=utf-8
import json
+import re
import time
import uuid
-from typing import Dict
+from typing import Dict, List
+from application.flow.common import Answer
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.application_node.i_application_node import IApplicationNode
from application.models import Chat
@@ -19,7 +21,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
result = node_variable.get('result')
- node.context['child_node'] = node_variable.get('child_node')
+ node.context['application_node_dict'] = node_variable.get('application_node_dict')
+ node.context['node_dict'] = node_variable.get('node_dict', {})
node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec')
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
@@ -43,6 +46,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
answer = ''
usage = {}
node_child_node = {}
+ application_node_dict = node.context.get('application_node_dict', {})
is_interrupt_exec = False
for chunk in response:
# 先把流转成字符串
@@ -61,6 +65,20 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
answer += content
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
'child_node': child_node}
+
+ if real_node_id is not None:
+ application_node = application_node_dict.get(real_node_id, None)
+ if application_node is None:
+
+ application_node_dict[real_node_id] = {'content': content,
+ 'runtime_node_id': runtime_node_id,
+ 'chat_record_id': chat_record_id,
+ 'child_node': child_node,
+ 'index': len(application_node_dict),
+ 'view_type': view_type}
+ else:
+ application_node['content'] += content
+
yield {'content': content,
'node_type': node_type,
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
@@ -72,6 +90,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
node_variable['result'] = {'usage': usage}
node_variable['is_interrupt_exec'] = is_interrupt_exec
node_variable['child_node'] = node_child_node
+ node_variable['application_node_dict'] = application_node_dict
_write_context(node_variable, workflow_variable, node, workflow, answer)
@@ -90,12 +109,43 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer)
+def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
+ try:
+ if application_node_dict is None:
+ return
+ for key in application_node_dict:
+ application_node = application_node_dict[key]
+ if application_node.get('runtime_node_id') == runtime_node_id:
+ content: str = application_node.get('content')
+ match = re.search('.*?', content)
+ if match:
+ form_setting_str = match.group().replace('', '').replace('', '')
+ form_setting = json.loads(form_setting_str)
+ form_setting['is_submit'] = True
+ form_setting['form_data'] = node_data
+ value = f'{json.dumps(form_setting)}'
+ res = re.sub('.*?',
+ '${value}', content)
+ application_node['content'] = res.replace('${value}', value)
+ except Exception as e:
+ pass
+
+
class BaseApplicationNode(IApplicationNode):
- def get_answer_text(self):
+ def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
- return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
- 'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')}
+ application_node_dict = self.context.get('application_node_dict')
+ if application_node_dict is None:
+ return [
+ Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
+ self.context.get('child_node'))]
+ else:
+ return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
+ self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'),
+ 'chat_record_id': n.get('chat_record_id')
+ , 'child_node': n.get('child_node')}) for n in
+ sorted(application_node_dict.values(), key=lambda item: item.get('index'))]
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
@@ -124,6 +174,8 @@ class BaseApplicationNode(IApplicationNode):
runtime_node_id = child_node.get('runtime_node_id')
record_id = child_node.get('chat_record_id')
child_node_value = child_node.get('child_node')
+ application_node_dict = self.context.get('application_node_dict')
+ reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
response = ChatMessageSerializer(
data={'chat_id': current_chat_id, 'message': message,
@@ -181,5 +233,6 @@ class BaseApplicationNode(IApplicationNode):
'err_message': self.err_message,
'global_fields': global_fields,
'document_list': self.workflow_manage.document_list,
- 'image_list': self.workflow_manage.image_list
+ 'image_list': self.workflow_manage.image_list,
+ 'application_node_dict': self.context.get('application_node_dict')
}
diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py
index 69bb69372..7a2b25c3c 100644
--- a/apps/application/flow/step_node/form_node/impl/base_form_node.py
+++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py
@@ -8,10 +8,11 @@
"""
import json
import time
-from typing import Dict
+from typing import Dict, List
from langchain_core.prompts import PromptTemplate
+from application.flow.common import Answer
from application.flow.i_step_node import NodeResult
from application.flow.step_node.form_node.i_form_node import IFormNode
@@ -60,7 +61,7 @@ class BaseFormNode(IFormNode):
{'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
_write_context=write_context)
- def get_answer_text(self):
+ def get_answer_list(self) -> List[Answer] | None:
form_content_format = self.context.get('form_content_format')
form_field_list = self.context.get('form_field_list')
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
@@ -70,8 +71,7 @@ class BaseFormNode(IFormNode):
form = f'{json.dumps(form_setting)}'
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form)
- return {'content': value, 'runtime_node_id': self.runtime_node_id,
- 'chat_record_id': self.workflow_params['chat_record_id']}
+ return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None)]
def get_details(self, index: int, **kwargs):
form_content_format = self.context.get('form_content_format')
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
index e580a5a6a..c62719b6e 100644
--- a/apps/application/flow/workflow_manage.py
+++ b/apps/application/flow/workflow_manage.py
@@ -19,6 +19,7 @@ from rest_framework import status
from rest_framework.exceptions import ErrorDetail, ValidationError
from application.flow import tools
+from application.flow.common import Answer
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
from common.exception.app_exception import AppApiException
@@ -302,6 +303,9 @@ class WorkflowManage:
get_node_params=get_node_params)
self.start_node.valid_args(
{**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params)
+ if self.start_node.type == 'application-node':
+ application_node_dict = node_details.get('application_node_dict', {})
+ self.start_node.context['application_node_dict'] = application_node_dict
self.node_context.append(self.start_node)
continue
@@ -482,7 +486,7 @@ class WorkflowManage:
'', False, 0, 0, {'node_is_end': True,
'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
- 'view_type': current_node.view_type,
+ 'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id})
node_chunk.end(chunk)
@@ -577,35 +581,29 @@ class WorkflowManage:
def get_answer_text_list(self):
result = []
- next_node_id_list = []
- if self.start_node is not None:
- next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if
- edge.sourceNodeId == self.start_node.id]
- for index in range(len(self.node_context)):
- node = self.node_context[index]
- up_node = None
- if index > 0:
- up_node = self.node_context[index - 1]
- answer_text = node.get_answer_text()
- if answer_text is not None:
- if up_node is None or node.view_type == 'single_view' or (
- node.view_type == 'many_view' and up_node.view_type == 'single_view'):
- result.append(node.get_answer_text())
- elif self.chat_record is not None and next_node_id_list.__contains__(
- node.id) and up_node is not None and not next_node_id_list.__contains__(
- up_node.id):
- result.append(node.get_answer_text())
+ answer_list = reduce(lambda x, y: [*x, *y],
+ [n.get_answer_list() for n in self.node_context if n.get_answer_list() is not None],
+ [])
+ up_node = None
+ for index in range(len(answer_list)):
+ current_answer = answer_list[index]
+ if len(current_answer.content) > 0:
+ if up_node is None or current_answer.view_type == 'single_view' or (
+ current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'):
+ result.append(current_answer)
else:
if len(result) > 0:
exec_index = len(result) - 1
- content = result[exec_index]['content']
- result[exec_index]['content'] += answer_text['content'] if len(
- content) == 0 else ('\n\n' + answer_text['content'])
+ content = result[exec_index].content
+ result[exec_index].content += current_answer.content if len(
+ content) == 0 else ('\n\n' + current_answer.content)
else:
- answer_text = node.get_answer_text()
- result.insert(0, answer_text)
-
- return result
+ result.insert(0, current_answer)
+ up_node = current_answer
+ if len(result) == 0:
+ # 如果没有响应 就响应一个空数据
+ return [Answer('', '', '', '', {}).to_dict()]
+ return [r.to_dict() for r in result]
def get_next_node(self):
"""
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index c1d07c968..868528023 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -120,7 +120,15 @@ export class ChatRecordManage {
this.chat.answer_text = this.chat.answer_text + chunk_answer
}
-
+ get_current_up_node() {
+ for (let i = this.node_list.length - 2; i >= 0; i--) {
+ const n = this.node_list[i]
+ if (n.content.length > 0) {
+ return n
+ }
+ }
+ return undefined
+ }
get_run_node() {
if (
this.write_node_info &&
@@ -135,7 +143,7 @@ export class ChatRecordManage {
const index = this.node_list.indexOf(run_node)
let current_up_node = undefined
if (index > 0) {
- current_up_node = this.node_list[index - 1]
+ current_up_node = this.get_current_up_node()
}
let answer_text_list_index = 0
@@ -293,9 +301,11 @@ export class ChatRecordManage {
let n = this.node_list.find((item) => item.real_node_id == chunk.real_node_id)
if (n) {
n.buffer.push(...chunk.content)
+ n.content += chunk.content
} else {
n = {
buffer: [...chunk.content],
+ content: chunk.content,
real_node_id: chunk.real_node_id,
node_id: chunk.node_id,
chat_record_id: chunk.chat_record_id,