feat: 工作流处理默认参数

This commit is contained in:
shaohuzhang1 2024-09-23 16:42:34 +08:00 committed by shaohuzhang1
parent 264a7309f7
commit f4a3883be2
3 changed files with 37 additions and 15 deletions

View File

@ -8,23 +8,41 @@
"""
import time
from datetime import datetime
from typing import List, Type
from rest_framework import serializers
from application.flow.i_step_node import NodeResult
from application.flow.step_node.start_node.i_start_node import IStarNode
def get_default_global_variable(input_field_list: List):
return {item.get('variable'): item.get('default_value') for item in input_field_list if
item.get('default_value', None) is not None}
def get_global_variable(node):
history_chat_record = node.flow_params_serializer.data.get('history_chat_record', [])
history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in
history_chat_record]
chat_id = node.flow_params_serializer.data.get('chat_id')
return {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(),
'history_context': history_context, 'chat_id': str(chat_id), **node.workflow_manage.form_data}
class BaseStartStepNode(IStarNode):
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
pass
def execute(self, question, **kwargs) -> NodeResult:
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in
history_chat_record]
chat_id = self.flow_params_serializer.data.get('chat_id')
base_node = self.workflow_manage.get_base_node()
default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
workflow_variable = {**default_global_variable, **get_global_variable(self)}
"""
开始节点 初始化全局变量
"""
return NodeResult({'question': question},
{'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(),
'history_context': history_context, 'chat_id': str(chat_id)})
workflow_variable)
def get_details(self, index: int, **kwargs):
global_fields = []
@ -33,7 +51,7 @@ class BaseStartStepNode(IStarNode):
global_fields.append({
'label': field['label'],
'key': key,
'value': self.workflow_manage[key] if key in self.workflow_manage else ''
'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else ''
})
return {
'name': self.node.properties.get('stepName'),

View File

@ -169,9 +169,10 @@ class WorkflowManage:
base_to_response: BaseToResponse = SystemToResponse(), form_data=None):
if form_data is None:
form_data = {}
self.form_data = form_data
self.params = params
self.flow = flow
self.context = form_data
self.context = {}
self.node_context = []
self.work_flow_post_handler = work_flow_post_handler
self.current_node = None
@ -302,7 +303,7 @@ class WorkflowManage:
"""
if self.current_node is None:
node = self.get_start_node()
node_instance = get_node(node.type)(node, self.params, self.context)
node_instance = get_node(node.type)(node, self.params, self)
return node_instance
if self.current_result is not None and self.current_result.is_assertion_result():
for edge in self.flow.edges:
@ -368,6 +369,14 @@ class WorkflowManage:
start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
return start_node_list[0]
def get_base_node(self):
"""
获取基础节点
@return:
"""
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
return base_node_list[0]
def get_node_cls_by_id(self, node_id):
for node in self.flow.nodes:
if node.id == node_id:

View File

@ -224,6 +224,7 @@ const {
const form = {
reranker_reference_list: [[]],
reranker_model_id: '',
question_reference_address: [],
reranker_setting: {
top_n: 3,
similarity: 0.6,
@ -306,12 +307,6 @@ const openCreateModel = (provider?: Provider) => {
onMounted(() => {
getProvider()
getModel()
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
if (isLastNode(props.nodeModel)) {
set(props.nodeModel.properties.node_data, 'is_result', true)
}
}
set(props.nodeModel, 'validate', validate)
})
</script>