From b6c65154c54bfdaf75dbb58b0435629d9c3edcd8 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Tue, 3 Dec 2024 15:23:53 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=AD=90=E5=BA=94?=
=?UTF-8?q?=E7=94=A8=E8=A1=A8=E5=8D=95=E8=B0=83=E7=94=A8=E6=97=A0=E6=B3=95?=
=?UTF-8?q?=E8=B0=83=E7=94=A8=E9=97=AE=E9=A2=98=20(#1741)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
apps/application/flow/i_step_node.py | 27 ++++-
.../application_node/i_application_node.py | 4 +-
.../impl/base_application_node.py | 62 ++++++++--
.../flow/step_node/form_node/i_form_node.py | 3 +-
.../form_node/impl/base_form_node.py | 10 +-
apps/application/flow/workflow_manage.py | 63 ++++++++---
...application_file_upload_enable_and_more.py | 7 +-
apps/application/models/application.py | 3 +-
.../serializers/chat_message_serializers.py | 5 +-
apps/application/views/chat_views.py | 3 +-
.../impl/response/openai_to_response.py | 2 +-
.../impl/response/system_to_response.py | 2 +-
apps/embedding/task/embedding.py | 3 +-
ui/src/api/type/application.ts | 106 ++++++++++++++----
.../component/answer-content/index.vue | 25 ++++-
ui/src/components/ai-chat/index.vue | 4 +-
ui/src/components/markdown/FormRander.vue | 31 +++--
ui/src/components/markdown/MdRenderer.vue | 6 +
18 files changed, 273 insertions(+), 93 deletions(-)
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index d3c6207aa..caf25836f 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -40,6 +40,10 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
node.context['run_time'] = time.time() - node.context['start_time']
+def is_interrupt(node, step_variable: Dict, global_variable: Dict):
+ return node.type == 'form-node' and not node.context.get('is_submit', False)
+
+
class WorkFlowPostHandler:
def __init__(self, chat_info, client_id, client_type):
self.chat_info = chat_info
@@ -57,7 +61,7 @@ class WorkFlowPostHandler:
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = workflow.get_answer_text_list()
- answer_text = '\n\n'.join(answer_text_list)
+ answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
if workflow.chat_record is not None:
chat_record = workflow.chat_record
chat_record.answer_text = answer_text
@@ -91,10 +95,11 @@ class WorkFlowPostHandler:
class NodeResult:
def __init__(self, node_variable: Dict, workflow_variable: Dict,
- _write_context=write_context):
+ _write_context=write_context, _is_interrupt=is_interrupt):
self._write_context = _write_context
self.node_variable = node_variable
self.workflow_variable = workflow_variable
+ self._is_interrupt = _is_interrupt
def write_context(self, node, workflow):
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
@@ -102,6 +107,14 @@ class NodeResult:
def is_assertion_result(self):
return 'branch_id' in self.node_variable
+ def is_interrupt_exec(self, current_node):
+ """
+ 是否中断执行
+ @param current_node:
+ @return:
+ """
+ return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
+
class ReferenceAddressSerializer(serializers.Serializer):
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
@@ -139,14 +152,18 @@ class INode:
pass
def get_answer_text(self):
- return self.answer_text
+ if self.answer_text is None:
+ return None
+ return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
+ 'chat_record_id': self.workflow_params['chat_record_id']}
- def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
+ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
+ get_node_params=lambda node: node.properties.get('node_data')):
# 当前步骤上下文,用于存储当前步骤信息
self.status = 200
self.err_message = ''
self.node = node
- self.node_params = node.properties.get('node_data')
+ self.node_params = get_node_params(node)
self.workflow_params = workflow_params
self.workflow_manage = workflow_manage
self.node_params_serializer = None
diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py
index d0184d4a3..8c4675ea7 100644
--- a/apps/application/flow/step_node/application_node/i_application_node.py
+++ b/apps/application/flow/step_node/application_node/i_application_node.py
@@ -14,6 +14,8 @@ class ApplicationNodeSerializer(serializers.Serializer):
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
+ child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
+ node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))
class IApplicationNode(INode):
@@ -55,5 +57,5 @@ class IApplicationNode(INode):
message=str(question), **kwargs)
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
- app_document_list=None, app_image_list=None, **kwargs) -> NodeResult:
+ app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult:
pass
diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py
index 30954a2b6..532a81e52 100644
--- a/apps/application/flow/step_node/application_node/impl/base_application_node.py
+++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py
@@ -2,19 +2,25 @@
import json
import time
import uuid
-from typing import List, Dict
+from typing import Dict
+
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.application_node.i_application_node import IApplicationNode
from application.models import Chat
-from common.handle.impl.response.openai_to_response import OpenaiToResponse
def string_to_uuid(input_str):
return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
+def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
+ return node_variable.get('is_interrupt_exec', False)
+
+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
result = node_variable.get('result')
+ node.context['child_node'] = node_variable['child_node']
+ node.context['is_interrupt_exec'] = node_variable['is_interrupt_exec']
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
node.context['answer'] = answer
@@ -36,17 +42,34 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
response = node_variable.get('result')
answer = ''
usage = {}
+ node_child_node = {}
+ is_interrupt_exec = False
for chunk in response:
# 先把流转成字符串
response_content = chunk.decode('utf-8')[6:]
response_content = json.loads(response_content)
- choices = response_content.get('choices')
- if choices and isinstance(choices, list) and len(choices) > 0:
- content = choices[0].get('delta', {}).get('content', '')
- answer += content
- yield content
+ content = response_content.get('content', '')
+ runtime_node_id = response_content.get('runtime_node_id', '')
+ chat_record_id = response_content.get('chat_record_id', '')
+ child_node = response_content.get('child_node')
+ node_type = response_content.get('node_type')
+ real_node_id = response_content.get('real_node_id')
+ node_is_end = response_content.get('node_is_end', False)
+ if node_type == 'form-node':
+ is_interrupt_exec = True
+ answer += content
+ node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
+ 'child_node': child_node}
+ yield {'content': content,
+ 'node_type': node_type,
+ 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
+ 'child_node': child_node,
+ 'real_node_id': real_node_id,
+ 'node_is_end': node_is_end}
usage = response_content.get('usage', {})
node_variable['result'] = {'usage': usage}
+ node_variable['is_interrupt_exec'] = is_interrupt_exec
+ node_variable['child_node'] = node_child_node
_write_context(node_variable, workflow_variable, node, workflow, answer)
@@ -64,6 +87,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
class BaseApplicationNode(IApplicationNode):
+ def get_answer_text(self):
+ if self.answer_text is None:
+ return None
+ return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
+ 'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')}
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
@@ -72,7 +100,7 @@ class BaseApplicationNode(IApplicationNode):
self.answer_text = details.get('answer')
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
- app_document_list=None, app_image_list=None,
+ app_document_list=None, app_image_list=None, child_node=None, node_data=None,
**kwargs) -> NodeResult:
from application.serializers.chat_message_serializers import ChatMessageSerializer
# 生成嵌入应用的chat_id
@@ -85,6 +113,14 @@ class BaseApplicationNode(IApplicationNode):
app_document_list = []
if app_image_list is None:
app_image_list = []
+ runtime_node_id = None
+ record_id = None
+ child_node_value = None
+ if child_node is not None:
+ runtime_node_id = child_node.get('runtime_node_id')
+ record_id = child_node.get('chat_record_id')
+ child_node_value = child_node.get('child_node')
+
response = ChatMessageSerializer(
data={'chat_id': current_chat_id, 'message': message,
're_chat': re_chat,
@@ -94,16 +130,20 @@ class BaseApplicationNode(IApplicationNode):
'client_type': client_type,
'document_list': app_document_list,
'image_list': app_image_list,
- 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse())
+ 'runtime_node_id': runtime_node_id,
+ 'chat_record_id': record_id,
+ 'child_node': child_node_value,
+ 'node_data': node_data,
+ 'form_data': kwargs}).chat()
if response.status_code == 200:
if stream:
content_generator = response.streaming_content
return NodeResult({'result': content_generator, 'question': message}, {},
- _write_context=write_context_stream)
+ _write_context=write_context_stream, _is_interrupt=_is_interrupt_exec)
else:
data = json.loads(response.content)
return NodeResult({'result': data, 'question': message}, {},
- _write_context=write_context)
+ _write_context=write_context, _is_interrupt=_is_interrupt_exec)
def get_details(self, index: int, **kwargs):
global_fields = []
diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py
index b793a5b78..cfd178ded 100644
--- a/apps/application/flow/step_node/form_node/i_form_node.py
+++ b/apps/application/flow/step_node/form_node/i_form_node.py
@@ -17,6 +17,7 @@ from common.util.field_message import ErrMessage
class FormNodeParamsSerializer(serializers.Serializer):
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置"))
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容'))
+ form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))
class IFormNode(INode):
@@ -29,5 +30,5 @@ class IFormNode(INode):
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
- def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
+ def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
pass
diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py
index ae86b11d6..db179f588 100644
--- a/apps/application/flow/step_node/form_node/impl/base_form_node.py
+++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py
@@ -42,7 +42,12 @@ class BaseFormNode(IFormNode):
for key in form_data:
self.context[key] = form_data[key]
- def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
+ def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
+ if form_data is not None:
+ self.context['is_submit'] = True
+ self.context['form_data'] = form_data
+ else:
+ self.context['is_submit'] = False
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
"is_submit": self.context.get("is_submit", False)}
@@ -63,7 +68,8 @@ class BaseFormNode(IFormNode):
form = f'{json.dumps(form_setting)}'
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form)
- return value
+ return {'content': value, 'runtime_node_id': self.runtime_node_id,
+ 'chat_record_id': self.workflow_params['chat_record_id']}
def get_details(self, index: int, **kwargs):
form_content_format = self.context.get('form_content_format')
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
index c62612c0a..60fad57c9 100644
--- a/apps/application/flow/workflow_manage.py
+++ b/apps/application/flow/workflow_manage.py
@@ -244,15 +244,15 @@ class WorkflowManage:
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
document_list=None,
start_node_id=None,
- start_node_data=None, chat_record=None):
+ start_node_data=None, chat_record=None, child_node=None):
if form_data is None:
form_data = {}
if image_list is None:
image_list = []
if document_list is None:
document_list = []
+ self.start_node_id = start_node_id
self.start_node = None
- self.start_node_result_future = None
self.form_data = form_data
self.image_list = image_list
self.document_list = document_list
@@ -270,6 +270,7 @@ class WorkflowManage:
self.base_to_response = base_to_response
self.chat_record = chat_record
self.await_future_map = {}
+ self.child_node = child_node
if start_node_id is not None:
self.load_node(chat_record, start_node_id, start_node_data)
else:
@@ -290,11 +291,17 @@ class WorkflowManage:
for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
node_id = node_details.get('node_id')
if node_details.get('runtime_node_id') == start_node_id:
- self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
- self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params)
- self.start_node.save_context(node_details, self)
- node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {})
- self.start_node_result_future = NodeResultFuture(node_result, None)
+ def get_node_params(n):
+ is_result = False
+ if n.type == 'application-node':
+ is_result = True
+ return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data,
+ 'child_node': self.child_node, 'is_result': is_result}
+
+ self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'),
+ get_node_params=get_node_params)
+ self.start_node.valid_args(
+ {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params)
self.node_context.append(self.start_node)
continue
@@ -306,7 +313,7 @@ class WorkflowManage:
def run(self):
if self.params.get('stream'):
- return self.run_stream(self.start_node, self.start_node_result_future)
+ return self.run_stream(self.start_node, None)
return self.run_block()
def run_block(self):
@@ -429,22 +436,41 @@ class WorkflowManage:
if result is not None:
if self.is_result(current_node, current_result):
self.node_chunk_manage.add_node_chunk(node_chunk)
+ child_node = {}
+ real_node_id = current_node.runtime_node_id
for r in result:
+ content = r
+ child_node = {}
+ node_is_end = False
+ if isinstance(r, dict):
+ content = r.get('content')
+ child_node = {'runtime_node_id': r.get('runtime_node_id'),
+ 'chat_record_id': r.get('chat_record_id')
+ , 'child_node': r.get('child_node')}
+ real_node_id = r.get('real_node_id')
+ node_is_end = r.get('node_is_end')
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
- r, False, 0, 0,
+ content, False, 0, 0,
{'node_type': current_node.type,
- 'view_type': current_node.view_type})
+ 'runtime_node_id': current_node.runtime_node_id,
+ 'view_type': current_node.view_type,
+ 'child_node': child_node,
+ 'node_is_end': node_is_end,
+ 'real_node_id': real_node_id})
node_chunk.add_chunk(chunk)
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
+ 'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
- 'view_type': current_node.view_type})
+ 'view_type': current_node.view_type,
+ 'child_node': child_node,
+ 'real_node_id': real_node_id})
node_chunk.end(chunk)
else:
list(result)
@@ -554,9 +580,9 @@ class WorkflowManage:
else:
if len(result) > 0:
exec_index = len(result) - 1
- content = result[exec_index]
- result[exec_index] += answer_text if len(
- content) == 0 else ('\n\n' + answer_text)
+ content = result[exec_index]['content']
+ result[exec_index]['content'] += answer_text['content'] if len(
+ content) == 0 else ('\n\n' + answer_text['content'])
else:
answer_text = node.get_answer_text()
result.insert(0, answer_text)
@@ -613,8 +639,8 @@ class WorkflowManage:
@param current_node_result: 当前可执行节点结果
@return: 可执行节点列表
"""
-
- if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
+ # 判断是否中断执行
+ if current_node_result.is_interrupt_exec(current_node):
return []
node_list = []
if current_node_result is not None and current_node_result.is_assertion_result():
@@ -689,11 +715,12 @@ class WorkflowManage:
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
return base_node_list[0]
- def get_node_cls_by_id(self, node_id, up_node_id_list=None):
+ def get_node_cls_by_id(self, node_id, up_node_id_list=None,
+ get_node_params=lambda node: node.properties.get('node_data')):
for node in self.flow.nodes:
if node.id == node_id:
node_instance = get_node(node.type)(node,
- self.params, self, up_node_id_list)
+ self.params, self, up_node_id_list, get_node_params)
return node_instance
return None
diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py
index f8c33c2c1..7b12321f1 100644
--- a/apps/application/migrations/0019_application_file_upload_enable_and_more.py
+++ b/apps/application/migrations/0019_application_file_upload_enable_and_more.py
@@ -4,8 +4,8 @@ import django.contrib.postgres.fields
from django.db import migrations, models
sql = """
-UPDATE "public".application_chat_record
-SET "answer_text_list" = ARRAY[answer_text];
+UPDATE application_chat_record
+SET answer_text_list=ARRAY[jsonb_build_object('content',answer_text)]
"""
@@ -28,8 +28,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='chatrecord',
name='answer_text_list',
- field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=40960), default=list,
- size=None, verbose_name='改进标注列表'),
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), default=list, size=None, verbose_name='改进标注列表')
),
migrations.RunSQL(sql)
]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index c90d1325c..6ed33c48c 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -69,7 +69,6 @@ class Application(AppModelMixin):
file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False)
file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default=dict)
-
@staticmethod
def get_default_model_prompt():
return ('已知信息:'
@@ -148,7 +147,7 @@ class ChatRecord(AppModelMixin):
problem_text = models.CharField(max_length=10240, verbose_name="问题")
answer_text = models.CharField(max_length=40960, verbose_name="答案")
answer_text_list = ArrayField(verbose_name="改进标注列表",
- base_field=models.CharField(max_length=40960)
+ base_field=models.JSONField()
, default=list)
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 5f38cfc7d..c6374c914 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -238,13 +238,14 @@ class ChatMessageSerializer(serializers.Serializer):
runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("运行时节点id"))
- node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
+ node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.char("节点参数"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
+ child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
@@ -353,7 +354,7 @@ class ChatMessageSerializer(serializers.Serializer):
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data, image_list, document_list,
self.data.get('runtime_node_id'),
- self.data.get('node_data'), chat_record)
+ self.data.get('node_data'), chat_record, self.data.get('child_node'))
r = work_flow_manage.run()
return r
diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py
index a3ef50766..ccd727c8b 100644
--- a/apps/application/views/chat_views.py
+++ b/apps/application/views/chat_views.py
@@ -138,7 +138,8 @@ class ChatView(APIView):
'node_id': request.data.get('node_id', None),
'runtime_node_id': request.data.get('runtime_node_id', None),
'node_data': request.data.get('node_data', {}),
- 'chat_record_id': request.data.get('chat_record_id')}
+ 'chat_record_id': request.data.get('chat_record_id'),
+ 'child_node': request.data.get('child_node')}
).chat()
@action(methods=['GET'], detail=False)
diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py
index 0ce605514..7897a8c60 100644
--- a/apps/common/handle/impl/response/openai_to_response.py
+++ b/apps/common/handle/impl/response/openai_to_response.py
@@ -35,7 +35,7 @@ class OpenaiToResponse(BaseToResponse):
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens,
prompt_tokens, other_params: dict = None):
chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
- created=datetime.datetime.now().second, choices=[
created=datetime.datetime.now().second, choices=[
Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None,
index=0)],
usage=CompletionUsage(completion_tokens=completion_tokens,
diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py
index 7c374ef01..5999f1297 100644
--- a/apps/common/handle/impl/response/system_to_response.py
+++ b/apps/common/handle/impl/response/system_to_response.py
@@ -28,7 +28,7 @@ class SystemToResponse(BaseToResponse):
prompt_tokens, other_params: dict = None):
if other_params is None:
other_params = {}
- chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True,
'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end,
'usage': {'completion_tokens': completion_tokens,
'prompt_tokens': prompt_tokens,
diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
index 7087636c2..7c39ebcdd 100644
--- a/apps/embedding/task/embedding.py
+++ b/apps/embedding/task/embedding.py
@@ -6,7 +6,6 @@
@date:2024/8/19 14:13
@desc:
"""
-import datetime
import logging
import traceback
from typing import List
@@ -17,7 +16,7 @@ from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
UpdateEmbeddingDocumentIdArgs
-from dataset.models import Document, Status, TaskType, State
+from dataset.models import Document, TaskType, State
from ops import celery_app
from setting.models import Model
from setting.models_provider import get_model
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index b45a6783f..95812bbab 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -23,8 +23,9 @@ interface ApplicationFormType {
tts_type?: string
}
interface Chunk {
+ real_node_id: string
chat_id: string
- id: string
+ chat_record_id: string
content: string
node_id: string
up_node_id: string
@@ -32,13 +33,20 @@ interface Chunk {
node_is_end: boolean
node_type: string
view_type: string
+ runtime_node_id: string
+ child_node: any
}
interface chatType {
id: string
problem_text: string
answer_text: string
buffer: Array
- answer_text_list: Array
+ answer_text_list: Array<{
+ content: string
+ chat_record_id?: string
+ runtime_node_id?: string
+ child_node?: any
+ }>
/**
* 是否写入结束
*/
@@ -92,15 +100,24 @@ export class ChatRecordManage {
this.write_ed = false
this.node_list = []
}
- append_answer(chunk_answer: string, index?: number) {
- this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] =
- this.chat.answer_text_list[
- index !== undefined ? index : this.chat.answer_text_list.length - 1
- ]
- ? this.chat.answer_text_list[
- index !== undefined ? index : this.chat.answer_text_list.length - 1
- ] + chunk_answer
- : chunk_answer
+ append_answer(
+ chunk_answer: string,
+ index?: number,
+ chat_record_id?: string,
+ runtime_node_id?: string,
+ child_node?: any
+ ) {
+ const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1
+ const content = this.chat.answer_text_list[set_index]
+ ? this.chat.answer_text_list[set_index].content + chunk_answer
+ : chunk_answer
+ this.chat.answer_text_list[set_index] = {
+ content: content,
+ chat_record_id,
+ runtime_node_id,
+ child_node
+ }
+
this.chat.answer_text = this.chat.answer_text + chunk_answer
}
@@ -127,14 +144,22 @@ export class ChatRecordManage {
run_node.view_type == 'single_view' ||
(run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view')
) {
- const none_index = this.chat.answer_text_list.indexOf('')
+ const none_index = this.findIndex(
+ this.chat.answer_text_list,
+ (item) => item.content == '',
+ 'index'
+ )
if (none_index > -1) {
answer_text_list_index = none_index
} else {
answer_text_list_index = this.chat.answer_text_list.length
}
} else {
- const none_index = this.chat.answer_text_list.indexOf('')
+ const none_index = this.findIndex(
+ this.chat.answer_text_list,
+ (item) => item.content === '',
+ 'index'
+ )
if (none_index > -1) {
answer_text_list_index = none_index
} else {
@@ -152,6 +177,19 @@ export class ChatRecordManage {
}
return undefined
}
+ findIndex(array: Array, find: (item: T) => boolean, type: 'last' | 'index') {
+ let set_index = -1
+ for (let index = 0; index < array.length; index++) {
+ const element = array[index]
+ if (find(element)) {
+ set_index = index
+ if (type == 'index') {
+ break
+ }
+ }
+ }
+ return set_index
+ }
closeInterval() {
this.chat.write_ed = true
this.write_ed = true
@@ -161,7 +199,11 @@ export class ChatRecordManage {
if (this.id) {
clearInterval(this.id)
}
- const last_index = this.chat.answer_text_list.lastIndexOf('')
+ const last_index = this.findIndex(
+ this.chat.answer_text_list,
+ (item) => item.content == '',
+ 'last'
+ )
if (last_index > 0) {
this.chat.answer_text_list.splice(last_index, 1)
}
@@ -193,19 +235,29 @@ export class ChatRecordManage {
)
this.append_answer(
(divider_content ? divider_content.splice(0).join('') : '') + context.join(''),
- answer_text_list_index
+ answer_text_list_index,
+ current_node.chat_record_id,
+ current_node.runtime_node_id,
+ current_node.child_node
)
} else if (this.is_close) {
while (true) {
const node_info = this.get_run_node()
+
if (node_info == undefined) {
break
}
this.append_answer(
(node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') +
node_info.current_node.buffer.splice(0).join(''),
- node_info.answer_text_list_index
+ node_info.answer_text_list_index,
+ current_node.chat_record_id,
+ current_node.runtime_node_id,
+ current_node.child_node
)
+ if (node_info.current_node.buffer.length == 0) {
+ node_info.current_node.is_end = true
+ }
}
this.closeInterval()
} else {
@@ -213,7 +265,10 @@ export class ChatRecordManage {
if (s !== undefined) {
this.append_answer(
(divider_content ? divider_content.splice(0).join('') : '') + s,
- answer_text_list_index
+ answer_text_list_index,
+ current_node.chat_record_id,
+ current_node.runtime_node_id,
+ current_node.child_node
)
}
}
@@ -235,16 +290,18 @@ export class ChatRecordManage {
this.is_stop = false
}
appendChunk(chunk: Chunk) {
- let n = this.node_list.find(
- (item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id
- )
+ let n = this.node_list.find((item) => item.real_node_id == chunk.real_node_id)
if (n) {
n.buffer.push(...chunk.content)
} else {
n = {
buffer: [...chunk.content],
+ real_node_id: chunk.real_node_id,
node_id: chunk.node_id,
+ chat_record_id: chunk.chat_record_id,
up_node_id: chunk.up_node_id,
+ runtime_node_id: chunk.runtime_node_id,
+ child_node: chunk.child_node,
node_type: chunk.node_type,
index: this.node_list.length,
view_type: chunk.view_type,
@@ -257,9 +314,12 @@ export class ChatRecordManage {
}
}
append(answer_text_block: string) {
- const index =this.chat.answer_text_list.indexOf("")
- this.chat.answer_text_list[index]=answer_text_block
-
+ let set_index = this.findIndex(
+ this.chat.answer_text_list,
+ (item) => item.content == '',
+ 'index'
+ )
+ this.chat.answer_text_list[set_index] = { content: answer_text_block }
}
}
diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue
index ec4f13f2b..4d1c7a684 100644
--- a/ui/src/components/ai-chat/component/answer-content/index.vue
+++ b/ui/src/components/ai-chat/component/answer-content/index.vue
@@ -1,6 +1,6 @@
-
+
@@ -9,14 +9,18 @@
@@ -51,6 +55,7 @@ import KnowledgeSource from '@/components/ai-chat/KnowledgeSource.vue'
import MdRenderer from '@/components/markdown/MdRenderer.vue'
import OperationButton from '@/components/ai-chat/component/operation-button/index.vue'
import { type chatType } from '@/api/type/application'
+import { computed } from 'vue'
const props = defineProps<{
chatRecord: chatType
application: any
@@ -71,9 +76,17 @@ const chatMessage = (question: string, type: 'old' | 'new', other_params_data?:
props.sendMessage(question, other_params_data)
}
}
-const add_answer_text_list = (answer_text_list: Array) => {
- answer_text_list.push('')
+const add_answer_text_list = (answer_text_list: Array) => {
+ answer_text_list.push({ content: '' })
}
+const answer_text_list = computed(() => {
+ return props.chatRecord.answer_text_list.map((item) => {
+ if (typeof item == 'string') {
+ return { content: item }
+ }
+ return item
+ })
+})
function showSource(row: any) {
if (props.type === 'log') {
diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue
index 72a138753..6e4993ef6 100644
--- a/ui/src/components/ai-chat/index.vue
+++ b/ui/src/components/ai-chat/index.vue
@@ -222,7 +222,7 @@ const getWrite = (chat: any, reader: any, stream: boolean) => {
for (const index in split) {
const chunk = JSON?.parse(split[index].replace('data:', ''))
chat.chat_id = chunk.chat_id
- chat.record_id = chunk.id
+ chat.record_id = chunk.chat_record_id
ChatManagement.appendChunk(chat.id, chunk)
if (chunk.is_end) {
@@ -278,7 +278,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para
id: randomId(),
problem_text: problem ? problem : inputValue.value.trim(),
answer_text: '',
- answer_text_list: [''],
+ answer_text_list: [{ content: '' }],
buffer: [],
write_ed: false,
is_stop: false,
diff --git a/ui/src/components/markdown/FormRander.vue b/ui/src/components/markdown/FormRander.vue
index 1d675e8b4..aa4437614 100644
--- a/ui/src/components/markdown/FormRander.vue
+++ b/ui/src/components/markdown/FormRander.vue
@@ -10,7 +10,10 @@
v-model="form_data"
:model="form_data"
>
- 提交
@@ -18,13 +21,19 @@