MaxKB/apps/application/flow/workflow_manage.py
2024-06-14 19:18:07 +08:00

181 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
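@Author:
@file: workflow_manage.py
@date: 2024/1/9 17:40
@desc: Workflow graph model (Edge, Node, Flow) and the WorkflowManage runner that walks it node by node.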
"""
import json
import uuid
from typing import List, Dict
from langchain_core.prompts import PromptTemplate
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
class Edge:
    """A directed connection between two nodes of a workflow graph.

    Besides the four required fields, every additional keyword argument is
    stored verbatim as an instance attribute.
    """

    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
        """
        @param _id: edge identifier, stored as ``id``
        @param _type: edge type, stored as ``type``
        @param sourceNodeId: id of the node the edge starts from
        @param targetNodeId: id of the node the edge points to
        @param keywords: any further attributes to attach to the edge
        """
        self.id = _id
        self.type = _type
        self.sourceNodeId = sourceNodeId
        self.targetNodeId = targetNodeId
        # Extras are applied last, so a keyword that shares a name with one of
        # the attributes above replaces the value assigned from the parameter.
        for attr_name, attr_value in keywords.items():
            setattr(self, attr_name, attr_value)
class Node:
    """A single node of a workflow graph.

    Besides the required fields, every additional keyword argument is stored
    verbatim as an instance attribute.
    """

    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
        """
        @param _id: node identifier, stored as ``id``
        @param _type: node type, stored as ``type``
        @param x: x coordinate of the node
        @param y: y coordinate of the node
        @param properties: node properties dictionary
        @param kwargs: any further attributes to attach to the node
        """
        self.id = _id
        self.type = _type
        self.x = x
        self.y = y
        self.properties = properties
        # Extras are applied last, so a keyword that shares a name with one of
        # the attributes above replaces the value assigned from the parameter.
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
class Flow:
    """A workflow graph: a list of nodes plus the edges connecting them."""

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        """
        @param nodes: nodes of the graph
        @param edges: edges of the graph
        """
        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def new_instance(flow_obj: Dict):
        """
        Build a Flow from its dictionary form.

        @param flow_obj: dictionary holding a 'nodes' list and an 'edges' list,
                         each item being a dictionary of constructor arguments
        @return: the Flow built from those items
        """
        raw_nodes = flow_obj.get('nodes')
        raw_edges = flow_obj.get('edges')

        # Each raw dictionary is expanded in full, so its 'id' and 'type'
        # entries reach the constructor both positionally and as extra keywords.
        node_list = []
        for raw_node in raw_nodes:
            node_list.append(Node(raw_node.get('id'), raw_node.get('type'), **raw_node))

        edge_list = []
        for raw_edge in raw_edges:
            edge_list.append(Edge(raw_edge.get('id'), raw_edge.get('type'), **raw_edge))

        return Flow(node_list, edge_list)
class WorkflowManage:
    """Walk a workflow graph one node at a time, starting at its start node.

    The manager keeps the node instances that have already been executed in
    ``node_context`` and a workflow-wide dictionary in ``context``.
    """

    def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
        """
        @param flow: the graph to execute
        @param params: workflow parameters; ``run`` reads 'chat_id' and
                       'chat_record_id' from it
        @param work_flow_post_handler: handler passed on when the final
                                       response is built
        """
        self.params = params
        self.flow = flow
        self.context = {}
        # Node instances in the order they were executed.
        self.node_context = []
        self.work_flow_post_handler = work_flow_post_handler
        self.current_node = None
        self.current_result = None

    def run(self):
        """
        Run the workflow.

        @return: the response built from the result of the last executed node,
                 or None when there is no node to run at all
        """
        while self.has_next_node(self.current_result):
            self.current_node = self.get_next_node()
            self.node_context.append(self.current_node)
            self.current_result = self.current_node.run()
            if self.has_next_node(self.current_result):
                # More nodes follow: record this result before moving on.
                self.current_result.write_context(self.current_node, self)
            else:
                # Last node: its result becomes the workflow response.
                return self.current_result.to_response(self.params['chat_id'],
                                                       self.params['chat_record_id'],
                                                       self.current_node, self,
                                                       self.work_flow_post_handler)
        return None

    def _find_next_edge(self, node_result):
        """
        Return the first edge leaving the current node that should be followed.

        For an assertion result only the edge whose ``sourceAnchorId`` equals
        ``<sourceNodeId>_<branch_id>_right`` qualifies; otherwise any edge
        leaving the current node does.

        @param node_result: result of the current node, may be None
        @return: the matching edge, or None when there is none
        """
        current_id = self.current_node.id
        if node_result is not None and node_result.is_assertion_result():
            return next(
                (edge for edge in self.flow.edges
                 if edge.sourceNodeId == current_id and
                 f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId),
                None)
        return next((edge for edge in self.flow.edges if edge.sourceNodeId == current_id), None)

    def has_next_node(self, node_result: NodeResult | None):
        """
        Tell whether there is another node to run.

        @param node_result: result of the current node, may be None
        @return: True when a start node exists and nothing has run yet, or
                 when an edge to follow leaves the current node
        """
        if self.current_node is None:
            return self.get_start_node() is not None
        return self._find_next_edge(node_result) is not None

    def get_runtime_details(self):
        """
        Collect the details of every executed node.

        @return: dictionary mapping node id to that node's details; each node
                 is given its position in the execution order as 'index'
        """
        details_result = {}
        for index, node in enumerate(self.node_context):
            details_result[node.id] = node.get_details({'index': index})
        return details_result

    def get_next_node(self):
        """
        Get the next node to run.

        @return: a new instance of the next node, or None when there is no
                 start node, no edge to follow, or the edge targets an unknown
                 node
        """
        if self.current_node is None:
            node = self.get_start_node()
            if node is None:
                return None
            # NOTE(review): the start node receives self.context as its third
            # argument while get_node_cls_by_id passes self - confirm against
            # the INode constructor whether this difference is intended.
            return get_node(node.type)(node, self.params, self.context)
        edge = self._find_next_edge(self.current_result)
        if edge is None:
            return None
        return self.get_node_cls_by_id(edge.targetNodeId)

    def get_reference_field(self, node_id: str, fields: List[str]):
        """
        Resolve a referenced field.

        @param node_id: node id, or 'global' for the workflow-wide context
        @param fields: field path to look up
        @return: the referenced value
        """
        if node_id == 'global':
            return INode.get_field(self.context, fields)
        # Only nodes that have already been executed can be referenced.
        return self.get_node_by_id(node_id).get_reference_field(fields)

    def generate_prompt(self, prompt: str):
        """
        Render a prompt template.

        @param prompt: jinja2 prompt template
        @return: the rendered prompt
        """
        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
        # The template sees one variable, ``context``, holding the global
        # context under 'global' and each executed node's context under its id.
        context = {
            'global': self.context,
        }
        for node in self.node_context:
            context[node.id] = node.context
        return prompt_template.format(context=context)

    def get_start_node(self):
        """
        Get the start node.

        @return: the first node of type 'start-node', or None when the flow
                 has none
        """
        return next((node for node in self.flow.nodes if node.type == 'start-node'), None)

    def get_node_cls_by_id(self, node_id):
        """
        Create a node instance for the flow node with the given id.

        @param node_id: id of the flow node
        @return: a new node instance, or None when the flow has no such node
        """
        for node in self.flow.nodes:
            if node.id == node_id:
                return get_node(node.type)(node, self.params, self)
        return None

    def get_node_by_id(self, node_id):
        """
        Find an already executed node by id.

        @param node_id: id of the node
        @return: the executed node instance, or None when it has not run
        """
        for node in self.node_context:
            if node.id == node_id:
                return node
        return None

    def get_node_reference(self, reference_address: Dict):
        """
        Read one field from the context of an executed node.

        @param reference_address: dictionary with 'node_id' and 'node_field'
        @return: the value stored under that field in the node's context
        """
        node = self.get_node_by_id(reference_address.get('node_id'))
        return node.context[reference_address.get('node_field')]