mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: ai对话,问题优化,指定回复节点支持返回结果开关
This commit is contained in:
parent
7b48599e57
commit
209702cad2
|
|
@ -10,6 +10,7 @@ import time
|
|||
from abc import abstractmethod
|
||||
from typing import Type, Dict, List
|
||||
|
||||
from django.core import cache
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
|
|
@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient
|
|||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.core import cache
|
||||
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
|
@ -27,6 +27,9 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
|||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result() and 'answer' in step_variable:
|
||||
yield step_variable['answer']
|
||||
workflow.answer += step_variable['answer']
|
||||
if global_variable is not None:
|
||||
for key in global_variable:
|
||||
workflow.context[key] = global_variable[key]
|
||||
|
|
@ -70,18 +73,14 @@ class WorkFlowPostHandler:
|
|||
|
||||
|
||||
class NodeResult:
|
||||
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
|
||||
def __init__(self, node_variable: Dict, workflow_variable: Dict,
|
||||
_write_context=write_context):
|
||||
self._write_context = _write_context
|
||||
self.node_variable = node_variable
|
||||
self.workflow_variable = workflow_variable
|
||||
self._to_response = _to_response
|
||||
|
||||
def write_context(self, node, workflow):
|
||||
self._write_context(self.node_variable, self.workflow_variable, node, workflow)
|
||||
|
||||
def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
|
||||
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
|
||||
post_handler)
|
||||
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
|
||||
|
||||
def is_assertion_result(self):
|
||||
return 'branch_id' in self.node_variable
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
|
|||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
|
||||
class IChatNode(INode):
|
||||
type = 'ai-chat-node'
|
||||
|
|
|
|||
|
|
@ -6,19 +6,30 @@
|
|||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
if workflow.is_result():
|
||||
workflow.answer += answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
|
|
@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
yield answer
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
|
@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
chat_model = node_variable.get('chat_model')
|
||||
answer = response.content
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
def _write_context(answer, status=200):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
|
||||
if status == 200:
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
else:
|
||||
answer_tokens = 0
|
||||
message_tokens = 0
|
||||
node.err_message = answer
|
||||
node.status = status
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将流式数据 转换为 流式响应
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器 输出结果后执行
|
||||
@return: 流式响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将结果转换
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器
|
||||
@return: 响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
class BaseChatNode(IChatNode):
|
||||
|
|
@ -132,13 +73,12 @@ class BaseChatNode(IChatNode):
|
|||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream,
|
||||
_to_response=to_stream_response)
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context, _to_response=to_response)
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
|
|||
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
|
||||
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char("直接回答内容"))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
|
|||
|
|
@ -6,69 +6,19 @@
|
|||
@date:2024/6/11 17:25
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
def _write_context(answer, status=200):
|
||||
node.context['answer'] = answer
|
||||
|
||||
return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将流式数据 转换为 流式响应
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器 输出结果后执行
|
||||
@return: 流式响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将结果转换
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器
|
||||
@return: 响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
class BaseReplyNode(IReplyNode):
|
||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||
if reply_type == 'referencing':
|
||||
result = self.get_reference_content(fields)
|
||||
else:
|
||||
result = self.generate_reply_content(content)
|
||||
if stream:
|
||||
return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
|
||||
_to_response=to_stream_response)
|
||||
else:
|
||||
return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
|
||||
return NodeResult({'answer': result}, {})
|
||||
|
||||
def generate_reply_content(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
|
|||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
|
||||
class IQuestionNode(INode):
|
||||
type = 'question-node'
|
||||
|
|
|
|||
|
|
@ -6,19 +6,30 @@
|
|||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
if workflow.is_result():
|
||||
workflow.answer += answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
|
|
@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
yield answer
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
|
@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
chat_model = node_variable.get('chat_model')
|
||||
answer = response.content
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
def _write_context(answer, status=200):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
|
||||
if status == 200:
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
else:
|
||||
answer_tokens = 0
|
||||
message_tokens = 0
|
||||
node.err_message = answer
|
||||
node.status = status
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将流式数据 转换为 流式响应
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器 输出结果后执行
|
||||
@return: 流式响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将结果转换
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器
|
||||
@return: 响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
class BaseQuestionNode(IQuestionNode):
|
||||
|
|
@ -131,15 +72,13 @@ class BaseQuestionNode(IQuestionNode):
|
|||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'get_to_response_write_context': get_to_response_write_context,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream,
|
||||
_to_response=to_stream_response)
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context, _to_response=to_response)
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
|
|
|
|||
|
|
@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_
|
|||
post_handler.handler(chat_id, chat_record_id, answer, workflow)
|
||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': answer, 'is_end': True})
|
||||
|
||||
|
||||
def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
|
||||
post_handler: WorkFlowPostHandler):
|
||||
answer = response.content
|
||||
post_handler.handler(chat_id, chat_record_id, answer, workflow)
|
||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': answer, 'is_end': True})
|
||||
|
||||
|
||||
def to_stream_response_simple(stream_event):
|
||||
r = StreamingHttpResponse(
|
||||
streaming_content=stream_event,
|
||||
content_type='text/event-stream;charset=utf-8',
|
||||
charset='utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -6,10 +6,11 @@
|
|||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, AIMessage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from application.flow import tools
|
||||
|
|
@ -63,7 +64,6 @@ class Flow:
|
|||
def get_search_node(self):
|
||||
return [node for node in self.nodes if node.type == 'search-dataset-node']
|
||||
|
||||
|
||||
def is_valid(self):
|
||||
"""
|
||||
校验工作流数据
|
||||
|
|
@ -140,33 +140,71 @@ class WorkflowManage:
|
|||
self.work_flow_post_handler = work_flow_post_handler
|
||||
self.current_node = None
|
||||
self.current_result = None
|
||||
self.answer = ""
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
运行工作流
|
||||
"""
|
||||
if self.params.get('stream'):
|
||||
return self.run_stream()
|
||||
return self.run_block()
|
||||
|
||||
def run_block(self):
|
||||
try:
|
||||
while self.has_next_node(self.current_result):
|
||||
self.current_node = self.get_next_node()
|
||||
self.node_context.append(self.current_node)
|
||||
self.current_result = self.current_node.run()
|
||||
if self.has_next_node(self.current_result):
|
||||
self.current_result.write_context(self.current_node, self)
|
||||
else:
|
||||
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.current_node, self,
|
||||
self.work_flow_post_handler)
|
||||
return r
|
||||
result = self.current_result.write_context(self.current_node, self)
|
||||
if result is not None:
|
||||
list(result)
|
||||
if not self.has_next_node(self.current_result):
|
||||
return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
|
||||
AIMessage(self.answer), self,
|
||||
self.work_flow_post_handler)
|
||||
except Exception as e:
|
||||
if self.params.get('stream'):
|
||||
return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
iter([AIMessageChunk(str(e))]), self,
|
||||
self.current_node.get_write_error_context(e),
|
||||
self.work_flow_post_handler)
|
||||
else:
|
||||
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
|
||||
self.work_flow_post_handler)
|
||||
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
|
||||
self.work_flow_post_handler)
|
||||
|
||||
def run_stream(self):
|
||||
return tools.to_stream_response_simple(self.stream_event())
|
||||
|
||||
def stream_event(self):
|
||||
try:
|
||||
while self.has_next_node(self.current_result):
|
||||
self.current_node = self.get_next_node()
|
||||
self.node_context.append(self.current_node)
|
||||
self.current_result = self.current_node.run()
|
||||
result = self.current_result.write_context(self.current_node, self)
|
||||
if result is not None:
|
||||
for r in result:
|
||||
if self.is_result():
|
||||
yield self.get_chunk_content(r)
|
||||
if not self.has_next_node(self.current_result):
|
||||
yield self.get_chunk_content('', True)
|
||||
break
|
||||
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.answer,
|
||||
self)
|
||||
except Exception as e:
|
||||
self.current_node.get_write_error_context(e)
|
||||
self.answer += str(e)
|
||||
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.answer,
|
||||
self)
|
||||
yield self.get_chunk_content(str(e), True)
|
||||
|
||||
def is_result(self):
|
||||
"""
|
||||
判断是否是返回节点
|
||||
@return:
|
||||
"""
|
||||
return self.current_node.node_params.get('is_result', not self.has_next_node(
|
||||
self.current_result)) if self.current_node.node_params is not None else False
|
||||
|
||||
def get_chunk_content(self, chunk, is_end=False):
|
||||
return 'data: ' + json.dumps(
|
||||
{'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
|
||||
'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
def has_next_node(self, node_result: NodeResult | None):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue