Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-29 07:52:50 +00:00)

feat: support direct output via node parameter settings #846

This commit is contained in: parent 76c1acbabb · commit 35f0c18dd3
@@ -10,6 +10,7 @@ import time
 from abc import abstractmethod
 from typing import Type, Dict, List
 
+from django.core import cache
 from django.db.models import QuerySet
 from rest_framework import serializers
 
@@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient
 from common.constants.authentication_type import AuthenticationType
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from django.core import cache
 
 chat_cache = cache.caches['chat_cache']
 
@@ -27,9 +27,13 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     if step_variable is not None:
         for key in step_variable:
             node.context[key] = step_variable[key]
+        if workflow.is_result() and 'answer' in step_variable:
+            yield step_variable['answer']
+            workflow.answer += step_variable['answer']
     if global_variable is not None:
         for key in global_variable:
             workflow.context[key] = global_variable[key]
     node.context['run_time'] = time.time() - node.context['start_time']
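The shared write_context above is now a generator: it only surfaces a node's answer when the workflow flags that node as a result node. A toy illustration of the behavior (the node and workflow stand-ins below are hypothetical, not MaxKB classes, and `import time` is assumed from the file header):

class FakeNode:
    def __init__(self):
        self.context = {'start_time': 0.0}

class FakeWorkflow:
    def __init__(self):
        self.answer = ''
        self.context = {}
    def is_result(self):
        return True  # pretend this node's is_result flag is on

node, workflow = FakeNode(), FakeWorkflow()
gen = write_context({'answer': 'hello'}, {}, node, workflow)
print(list(gen))        # ['hello'] -- yielded because is_result() is True
print(workflow.answer)  # 'hello'  -- also accumulated on the workflow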
 
 
 class WorkFlowPostHandler:
@@ -70,18 +74,14 @@ class WorkFlowPostHandler:
 
 
 class NodeResult:
-    def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
+    def __init__(self, node_variable: Dict, workflow_variable: Dict,
+                 _write_context=write_context):
         self._write_context = _write_context
         self.node_variable = node_variable
         self.workflow_variable = workflow_variable
-        self._to_response = _to_response
 
     def write_context(self, node, workflow):
-        self._write_context(self.node_variable, self.workflow_variable, node, workflow)
-
-    def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
-        return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
-                                 post_handler)
+        return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
 
     def is_assertion_result(self):
         return 'branch_id' in self.node_variable
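With _to_response gone, NodeResult.write_context is the single hook, and it may hand back a generator. A minimal sketch (the helper name is hypothetical) of how a blocking caller is expected to drain it, mirroring run_block further down:

def drain_blocking(node_result, node, workflow):
    # write_context may return a generator (streaming nodes) or None
    result = node_result.write_context(node, workflow)
    if result is not None:
        list(result)  # exhaust it so node/workflow context gets populated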
@@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
     # Number of dialogue rounds
     dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("number of dialogue rounds"))
+
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('whether to return content'))
 
 
 class IChatNode(INode):
     type = 'ai-chat-node'
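Because is_result is optional, older node configurations keep validating. A quick sanity check with a hypothetical payload (partial=True skips the serializer's other required fields, such as model_id and prompt, which this hunk does not show):

s = ChatNodeSerializer(data={'dialogue_number': 3}, partial=True)
s.is_valid(raise_exception=True)
print(s.validated_data.get('is_result', 'not set'))  # 'not set' -- the flag is optional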
@@ -13,12 +13,25 @@ from typing import List, Dict
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
 
 from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
 from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+    chat_model = node_variable.get('chat_model')
+    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+    answer_tokens = chat_model.get_num_tokens(answer)
+    node.context['message_tokens'] = message_tokens
+    node.context['answer_tokens'] = answer_tokens
+    node.context['answer'] = answer
+    node.context['history_message'] = node_variable['history_message']
+    node.context['question'] = node_variable['question']
+    node.context['run_time'] = time.time() - node.context['start_time']
+    if workflow.is_result():
+        workflow.answer += answer
+
+
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
     Write context data (streaming)
@@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     answer = ''
     for chunk in response:
         answer += chunk.content
-    chat_model = node_variable.get('chat_model')
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-    node.context['run_time'] = time.time() - node.context['start_time']
+    yield answer
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
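write_context_stream now drains the model stream, accumulates the chunks, yields the finished answer, and delegates all bookkeeping to _write_context. A self-contained sketch of that accumulate-then-finalize pattern (the Chunk class is a stand-in for LangChain's message chunks):

class Chunk:
    def __init__(self, content):
        self.content = content

def fake_stream():
    # stands in for chat_model.stream(message_list)
    yield from (Chunk('Hel'), Chunk('lo'))

answer = ''
for chunk in fake_stream():
    answer += chunk.content
print(answer)  # 'Hello' -- the full answer, yielded once and then recorded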
 
 
 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     @param workflow: workflow manager
     """
     response = node_variable.get('result')
-    chat_model = node_variable.get('chat_model')
     answer = response.content
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-
-
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        chat_model = node_variable.get('chat_model')
-
-        if status == 200:
-            answer_tokens = chat_model.get_num_tokens(answer)
-            message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-        else:
-            answer_tokens = 0
-            message_tokens = 0
-            node.err_message = answer
-        node.status = status
-        node.context['message_tokens'] = message_tokens
-        node.context['answer_tokens'] = answer_tokens
-        node.context['answer'] = answer
-        node.context['run_time'] = time.time() - node.context['start_time']
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is output
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
 
 
 class BaseChatNode(IChatNode):
@@ -132,13 +75,12 @@ class BaseChatNode(IChatNode):
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context_stream,
-                              _to_response=to_stream_response)
+                              _write_context=write_context_stream)
         else:
             r = chat_model.invoke(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context, _to_response=to_response)
+                              _write_context=write_context)
 
     @staticmethod
     def get_history_message(history_chat_record, dialogue_number):
@@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
     fields = serializers.ListField(required=False, error_messages=ErrMessage.list("referenced fields"))
     content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                     error_messages=ErrMessage.char("direct reply content"))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('whether to return content'))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -6,69 +6,19 @@
 @date:2024/6/11 17:25
 @desc:
 """
-from typing import List, Dict
+from typing import List
 
-from langchain_core.messages import AIMessage, AIMessageChunk
-
-from application.flow import tools
-from application.flow.i_step_node import NodeResult, INode
+from application.flow.i_step_node import NodeResult
 from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
 
 
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        node.context['answer'] = answer
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is output
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
 class BaseReplyNode(IReplyNode):
     def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
         if reply_type == 'referencing':
             result = self.get_reference_content(fields)
         else:
             result = self.generate_reply_content(content)
-        if stream:
-            return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
-                              _to_response=to_stream_response)
-        else:
-            return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
+        return NodeResult({'answer': result}, {})
 
     def generate_reply_content(self, prompt):
        return self.workflow_manage.generate_prompt(prompt)
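Since the reply node no longer wires up bespoke to_response helpers, its output rides entirely on the default write_context. A toy check (hypothetical stand-ins again; NodeResult and the default write_context from i_step_node are assumed to be imported):

class Node:
    context = {'start_time': 0.0}

class Workflow:
    answer = ''
    context = {}
    def is_result(self):
        return True

nr = NodeResult({'answer': 'fixed reply'}, {})
print(list(nr.write_context(Node(), Workflow())))  # ['fixed reply']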
@@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
     # Number of dialogue rounds
     dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("number of dialogue rounds"))
+
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('whether to return content'))
 
 
 class IQuestionNode(INode):
     type = 'question-node'
@@ -13,12 +13,25 @@ from typing import List, Dict
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage
 
 from application.flow import tools
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.question_node.i_question_node import IQuestionNode
 from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+    chat_model = node_variable.get('chat_model')
+    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+    answer_tokens = chat_model.get_num_tokens(answer)
+    node.context['message_tokens'] = message_tokens
+    node.context['answer_tokens'] = answer_tokens
+    node.context['answer'] = answer
+    node.context['history_message'] = node_variable['history_message']
+    node.context['question'] = node_variable['question']
+    node.context['run_time'] = time.time() - node.context['start_time']
+    if workflow.is_result():
+        workflow.answer += answer
+
+
 def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
     Write context data (streaming)
@@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     answer = ''
     for chunk in response:
         answer += chunk.content
-    chat_model = node_variable.get('chat_model')
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-    node.context['run_time'] = time.time() - node.context['start_time']
+    yield answer
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
 
 
 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     @param workflow: workflow manager
     """
     response = node_variable.get('result')
-    chat_model = node_variable.get('chat_model')
     answer = response.content
-    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-    answer_tokens = chat_model.get_num_tokens(answer)
-    node.context['message_tokens'] = message_tokens
-    node.context['answer_tokens'] = answer_tokens
-    node.context['answer'] = answer
-    node.context['history_message'] = node_variable['history_message']
-    node.context['question'] = node_variable['question']
-
-
-def get_to_response_write_context(node_variable: Dict, node: INode):
-    def _write_context(answer, status=200):
-        chat_model = node_variable.get('chat_model')
-
-        if status == 200:
-            answer_tokens = chat_model.get_num_tokens(answer)
-            message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
-        else:
-            answer_tokens = 0
-            message_tokens = 0
-            node.err_message = answer
-        node.status = status
-        node.context['message_tokens'] = message_tokens
-        node.context['answer_tokens'] = answer_tokens
-        node.context['answer'] = answer
-        node.context['run_time'] = time.time() - node.context['start_time']
-
-    return _write_context
-
-
-def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                       post_handler):
-    """
-    Convert streaming data into a streaming response
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler, executed after the result is output
-    @return: streaming response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
-
-
-def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
-                post_handler):
-    """
-    Convert the result
-    @param chat_id: conversation id
-    @param chat_record_id: chat record id
-    @param node_variable: node data
-    @param workflow_variable: workflow data
-    @param node: node
-    @param workflow: workflow manager
-    @param post_handler: post handler
-    @return: response
-    """
-    response = node_variable.get('result')
-    _write_context = get_to_response_write_context(node_variable, node)
-    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+    _write_context(node_variable, workflow_variable, node, workflow, answer)
 
 
 class BaseQuestionNode(IQuestionNode):
@@ -131,15 +74,13 @@ class BaseQuestionNode(IQuestionNode):
         if stream:
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
-                               'get_to_response_write_context': get_to_response_write_context,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context_stream,
-                              _to_response=to_stream_response)
+                              _write_context=write_context_stream)
         else:
             r = chat_model.invoke(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                                'history_message': history_message, 'question': question.content}, {},
-                              _write_context=write_context, _to_response=to_response)
+                              _write_context=write_context)
 
     @staticmethod
     def get_history_message(history_chat_record, dialogue_number):
@@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_
     post_handler.handler(chat_id, chat_record_id, answer, workflow)
     return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                            'content': answer, 'is_end': True})
+
+
+def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
+                       post_handler: WorkFlowPostHandler):
+    answer = response.content
+    post_handler.handler(chat_id, chat_record_id, answer, workflow)
+    return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                           'content': answer, 'is_end': True})
+
+
+def to_stream_response_simple(stream_event):
+    r = StreamingHttpResponse(
+        streaming_content=stream_event,
+        content_type='text/event-stream;charset=utf-8',
+        charset='utf-8')
+
+    r['Cache-Control'] = 'no-cache'
+    return r
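to_stream_response_simple wraps any generator of pre-formatted SSE frames in a Django StreamingHttpResponse. A minimal sketch (the view and the build_workflow factory are illustrative names, not part of this commit) of how the workflow manager's stream is served:

def chat_view(request):
    # build_workflow stands in for however the caller assembles a
    # WorkflowManage instance from the request
    workflow = build_workflow(request)
    return to_stream_response_simple(workflow.stream_event())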
@@ -6,10 +6,11 @@
 @date:2024/1/9 17:40
 @desc:
 """
+import json
 from functools import reduce
 from typing import List, Dict
 
-from langchain_core.messages import AIMessageChunk, AIMessage
+from langchain_core.messages import AIMessage
 from langchain_core.prompts import PromptTemplate
 
 from application.flow import tools
@@ -63,7 +64,6 @@ class Flow:
     def get_search_node(self):
         return [node for node in self.nodes if node.type == 'search-dataset-node']
 
-
     def is_valid(self):
         """
         Validate workflow data
@@ -140,33 +140,71 @@ class WorkflowManage:
         self.work_flow_post_handler = work_flow_post_handler
         self.current_node = None
         self.current_result = None
+        self.answer = ""
 
     def run(self):
         """
         Run the workflow
         """
+        if self.params.get('stream'):
+            return self.run_stream()
+        return self.run_block()
+
+    def run_block(self):
         try:
             while self.has_next_node(self.current_result):
                 self.current_node = self.get_next_node()
                 self.node_context.append(self.current_node)
                 self.current_result = self.current_node.run()
-                if self.has_next_node(self.current_result):
-                    self.current_result.write_context(self.current_node, self)
-                else:
-                    r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                        self.current_node, self,
-                                                        self.work_flow_post_handler)
-                    return r
+                result = self.current_result.write_context(self.current_node, self)
+                if result is not None:
+                    list(result)
+                if not self.has_next_node(self.current_result):
+                    return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
+                                                    AIMessage(self.answer), self,
+                                                    self.work_flow_post_handler)
         except Exception as e:
-            if self.params.get('stream'):
-                return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                iter([AIMessageChunk(str(e))]), self,
-                                                self.current_node.get_write_error_context(e),
-                                                self.work_flow_post_handler)
-            else:
-                return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
-                                         AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
-                                         self.work_flow_post_handler)
+            return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
+                                     AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
+                                     self.work_flow_post_handler)
 
+    def run_stream(self):
+        return tools.to_stream_response_simple(self.stream_event())
+
+    def stream_event(self):
+        try:
+            while self.has_next_node(self.current_result):
+                self.current_node = self.get_next_node()
+                self.node_context.append(self.current_node)
+                self.current_result = self.current_node.run()
+                result = self.current_result.write_context(self.current_node, self)
+                if result is not None:
+                    for r in result:
+                        if self.is_result():
+                            yield self.get_chunk_content(r)
+                if not self.has_next_node(self.current_result):
+                    yield self.get_chunk_content('', True)
+                    break
+            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+                                                self.answer,
+                                                self)
+        except Exception as e:
+            self.current_node.get_write_error_context(e)
+            self.answer += str(e)
+            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+                                                self.answer,
+                                                self)
+            yield self.get_chunk_content(str(e), True)
+
+    def is_result(self):
+        """
+        Determine whether the current node should return output to the user
+        @return:
+        """
+        return self.current_node.node_params.get('is_result', not self.has_next_node(
+            self.current_result)) if self.current_node.node_params is not None else False
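The is_result fallback above is the heart of the feature: an explicit node-level flag wins, and when the flag is absent only the terminal node (one with no successor) counts as the result. A hedged restatement of just that decision logic as a standalone function (the name is illustrative):

def node_returns_output(node_params, has_next_node):
    # mirrors WorkflowManage.is_result: explicit flag first,
    # then the "last node is the result" default
    if node_params is None:
        return False
    return node_params.get('is_result', not has_next_node)

print(node_returns_output({'is_result': False}, has_next_node=False))  # False: explicitly off
print(node_returns_output({}, has_next_node=False))                    # True: terminal-node default
print(node_returns_output({}, has_next_node=True))                     # False: mid-flow default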
+
+    def get_chunk_content(self, chunk, is_end=False):
+        return 'data: ' + json.dumps(
+            {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
+             'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"
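Each chunk is framed as a server-sent event: a data: prefix, a JSON payload, and a trailing blank line. What one frame looks like with illustrative IDs:

import json

frame = 'data: ' + json.dumps(
    {'chat_id': 'c-1', 'id': 'r-1', 'operate': True,
     'content': 'Hello', 'is_end': False}, ensure_ascii=False) + "\n\n"
print(frame)
# data: {"chat_id": "c-1", "id": "r-1", "operate": true, "content": "Hello", "is_end": false}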
 
     def has_next_node(self, node_result: NodeResult | None):
         """
@@ -170,3 +170,13 @@ export const nodeDict: any = {
 export function isWorkFlow(type: string | undefined) {
   return type === 'WORK_FLOW'
 }
+
+export function isLastNode(nodeModel: any) {
+  const incoming = nodeModel.graphModel.getNodeIncomingNode(nodeModel.id)
+  const outcomming = nodeModel.graphModel.getNodeOutgoingNode(nodeModel.id)
+  if (incoming.length > 0 && outcomming.length === 0) {
+    return true
+  } else {
+    return false
+  }
+}
@@ -132,6 +132,23 @@
           class="w-full"
         />
       </el-form-item>
+      <el-form-item label="Return content" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>Return content<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                When disabled, this node's content is not shown to the user.
+                Turn the switch on if you want users to see this node's output.
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="chat_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
 
@@ -156,6 +173,7 @@ import applicationApi from '@/api/application'
 import useStore from '@/stores'
 import { relatedObject } from '@/utils/utils'
 import type { Provider } from '@/api/type/model'
+import { isLastNode } from '@/workflow/common/data'
 
 const { model } = useStore()
 const isKeyDown = ref(false)
@@ -180,7 +198,8 @@ const form = {
   model_id: '',
   system: '',
   prompt: defaultPrompt,
-  dialogue_number: 1
+  dialogue_number: 1,
+  is_result: false
 }
 
 const chat_data = computed({
@@ -240,6 +259,12 @@ const openCreateModel = (provider?: Provider) => {
 onMounted(() => {
   getProvider()
   getModel()
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
+
   set(props.nodeModel, 'validate', validate)
 })
 </script>
@@ -133,6 +133,23 @@
           class="w-full"
         />
       </el-form-item>
+      <el-form-item label="Return content" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>Return content<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                When disabled, this node's content is not shown to the user.
+                Turn the switch on if you want users to see this node's output.
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="form_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
   <!-- Add template -->
@@ -156,6 +173,8 @@ import applicationApi from '@/api/application'
 import useStore from '@/stores'
 import { relatedObject } from '@/utils/utils'
 import type { Provider } from '@/api/type/model'
+import { isLastNode } from '@/workflow/common/data'
 
 const { model } = useStore()
 const isKeyDown = ref(false)
 const wheel = (e: any) => {
@@ -177,7 +196,8 @@ const form = {
   model_id: '',
   system: 'You are a question optimization master',
   prompt: defaultPrompt,
-  dialogue_number: 1
+  dialogue_number: 1,
+  is_result: false
 }
 
 const form_data = computed({
@@ -237,6 +257,11 @@ const openCreateModel = (provider?: Provider) => {
 onMounted(() => {
   getProvider()
   getModel()
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
   set(props.nodeModel, 'validate', validate)
 })
 </script>
@@ -46,6 +46,23 @@
           v-model="form_data.fields"
         />
       </el-form-item>
+      <el-form-item label="Return content" @click.prevent>
+        <template #label>
+          <div class="flex align-center">
+            <div class="mr-4">
+              <span>Return content<span class="danger">*</span></span>
+            </div>
+            <el-tooltip effect="dark" placement="right" popper-class="max-w-200">
+              <template #content>
+                When disabled, this node's content is not shown to the user.
+                Turn the switch on if you want users to see this node's output.
+              </template>
+              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
+            </el-tooltip>
+          </div>
+        </template>
+        <el-switch size="small" v-model="form_data.is_result" />
+      </el-form-item>
     </el-form>
   </el-card>
   <!-- Reply content popup -->
@@ -64,12 +81,14 @@ import { set } from 'lodash'
 import NodeContainer from '@/workflow/common/NodeContainer.vue'
 import NodeCascader from '@/workflow/common/NodeCascader.vue'
 import { ref, computed, onMounted } from 'vue'
+import { isLastNode } from '@/workflow/common/data'
 
 const props = defineProps<{ nodeModel: any }>()
 const form = {
   reply_type: 'content',
   content: '',
-  fields: []
+  fields: [],
+  is_result: false
 }
 const footers: any = [null, '=', 0]
 
@@ -111,6 +130,12 @@ const validate = () => {
 }
 
 onMounted(() => {
+  if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
+    if (isLastNode(props.nodeModel)) {
+      set(props.nodeModel.properties.node_data, 'is_result', true)
+    }
+  }
+
   set(props.nodeModel, 'validate', validate)
 })
 </script>