mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
194 lines
7.1 KiB
Python
194 lines
7.1 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: maxkb
|
||
@Author:虎
|
||
@file: workflow_manage.py
|
||
@date:2024/1/9 17:40
|
||
@desc:
|
||
"""
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List
|
||
|
||
from django.db import close_old_connections
|
||
from django.utils.translation import get_language
|
||
from langchain_core.prompts import PromptTemplate
|
||
|
||
from application.flow.common import Workflow
|
||
from application.flow.i_step_node import WorkFlowPostHandler, INode
|
||
from application.flow.step_node import get_node
|
||
from application.flow.workflow_manage import WorkflowManage
|
||
from common.handle.base_to_response import BaseToResponse
|
||
from common.handle.impl.response.system_to_response import SystemToResponse
|
||
|
||
# Module-level thread pool capped at 200 worker threads.
# NOTE(review): nothing in the visible part of this file submits work to it;
# presumably it is imported and used elsewhere — confirm before changing the size.
executor = ThreadPoolExecutor(max_workers=200)
|
||
|
||
|
||
class NodeResultFuture:
    """
    Minimal future-like holder that carries either a value or an exception.

    @param r: value handed back by ``result`` when ``status`` is 200
    @param e: exception raised by ``result`` when ``status`` is anything else
    @param status: 200 marks success; every other value marks failure
    """

    def __init__(self, r, e, status=200):
        self.r = r
        self.e = e
        self.status = status

    def result(self, timeout=None):
        """
        Return the stored value, or raise the stored exception.

        @param timeout: accepted and ignored. It exists so this object can be
            called as ``result(timeout)``, the call shape ``await_result`` in
            this module uses; without it that call raised ``TypeError``.
        @return: ``self.r`` when ``status`` is 200
        @raise: ``self.e`` when ``status`` is not 200
        """
        if self.status == 200:
            return self.r
        raise self.e
|
||
|
||
|
||
def await_result(result, timeout=1):
    """
    Wait on a future-like object and report whether the wait did NOT succeed.

    @param result: object exposing ``result(timeout)``
    @param timeout: value forwarded to ``result.result``
    @return: False when ``result.result(timeout)`` returned normally,
        True when it raised any ``Exception`` (a timeout or otherwise)
    """
    try:
        result.result(timeout)
    except Exception:
        # Every failure is treated the same way: the result is not available.
        return True
    else:
        return False
|
||
|
||
|
||
class NodeChunkManage:
    """
    Serves stream chunks from a queue of node chunks, one node chunk at a time.

    @param work_flow: owning workflow; ``pop`` reads its ``params`` and
        ``base_to_response`` and calls ``answer_is_not_empty`` / ``append_answer``
    """

    def __init__(self, work_flow):
        self.node_chunk_list = []
        self.current_node_chunk = None
        self.work_flow = work_flow

    def add_node_chunk(self, node_chunk):
        """Queue ``node_chunk`` behind the ones already waiting."""
        self.node_chunk_list.append(node_chunk)

    def contains(self, node_chunk):
        """Return whether ``node_chunk`` is still waiting in the queue."""
        return node_chunk in self.node_chunk_list

    def pop(self):
        """
        Return the next chunk to emit, or None when nothing is available now.

        When the current node chunk is exhausted and reports ``is_end()``, it
        is dropped. If the workflow already has answer text, a ``'\\n\\n'``
        separator chunk is returned and appended to the answer; otherwise the
        next queued node chunk is tried immediately.

        @return: a chunk, a separator chunk, or None
        """
        while True:
            if self.current_node_chunk is None:
                try:
                    self.current_node_chunk = self.node_chunk_list.pop(0)
                except IndexError:
                    # Queue is empty and nothing is in progress.
                    return None

            active = self.current_node_chunk
            try:
                return active.chunk_list.pop(0)
            except IndexError:
                pass

            if not active.is_end():
                # The node is still producing; nothing to hand out yet.
                return None

            self.current_node_chunk = None
            if self.work_flow.answer_is_not_empty():
                separator = self.work_flow.base_to_response.to_stream_chunk_response(
                    self.work_flow.params['chat_id'],
                    self.work_flow.params['chat_record_id'],
                    '\n\n', False, 0, 0)
                self.work_flow.append_answer('\n\n')
                return separator
            # No answer text yet: move straight on to the next queued node chunk.
|
||
|
||
|
||
class LoopWorkflowManage(WorkflowManage):
    """
    Workflow manager for the body of a loop.

    It starts from the ``loop-start-node``, resolves ``loop`` references
    through ``get_loop_context`` and hands ``global`` / ``chat`` references,
    as well as unknown node ids, to the parent workflow manager.
    """

    def __init__(self, flow: Workflow,
                 params,
                 work_flow_post_handler: WorkFlowPostHandler,
                 parentWorkflowManage,
                 loop_params,
                 get_loop_context,
                 base_to_response: BaseToResponse = SystemToResponse(),
                 start_node_id=None,
                 start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
        """
        @param flow: workflow definition of the loop body
        @param params: run parameters forwarded to the base class and to nodes
        @param work_flow_post_handler: post handler forwarded to the base class
        @param parentWorkflowManage: manager of the enclosing workflow
        @param loop_params: mapping read by ``get_index`` (key ``'index'``)
        @param get_loop_context: zero-argument callable returning the loop context
        @param base_to_response: response formatter forwarded to the base class
        @param start_node_id: forwarded to the base class
        @param start_node_data: forwarded to the base class
        @param chat_record: forwarded to the base class
        @param child_node: forwarded to the base class
        @param is_the_task_interrupted: forwarded to the base class
        """
        # NOTE(review): the default SystemToResponse() is built once, when the
        # class is defined, and shared by every instance that omits the argument.
        # These attributes are assigned before the base initializer runs —
        # presumably because it may call overridden methods that read them;
        # TODO confirm against WorkflowManage.__init__.
        self.parentWorkflowManage = parentWorkflowManage
        self.loop_params = loop_params
        self.get_loop_context = get_loop_context
        self.loop_field_list = []
        # The six None values fill base-class positional parameters that this
        # subclass does not supply.
        super().__init__(
            flow, params, work_flow_post_handler, base_to_response,
            None, None, None, None, None, None,
            start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted,
        )

    def get_node_cls_by_id(self, node_id, up_node_id_list=None,
                           get_node_params=lambda node: node.properties.get('node_data')):
        """
        Build a node instance for the first flow node whose id is ``node_id``.

        @param node_id: id of the node to instantiate
        @param up_node_id_list: forwarded to the node constructor
        @param get_node_params: callable forwarded to the node constructor
        @return: the new node instance, or None when no node has that id
        """
        for candidate in self.flow.nodes:
            if candidate.id != node_id:
                continue
            node_cls = get_node(candidate.type, self.flow.workflow_mode)
            # The current loop index is passed to the node as ``salt``.
            return node_cls(candidate,
                            self.params, self, up_node_id_list,
                            get_node_params,
                            salt=self.get_index())
        return None

    def stream(self):
        """
        Start the chain from the start node and return the awaited result.

        @return: whatever ``self.await_result(is_cleanup=False)`` returns
        """
        close_old_connections()
        active_language = get_language()
        self.run_chain_async(self.start_node, None, active_language)
        return self.await_result(is_cleanup=False)

    def get_index(self):
        """Return ``loop_params['index']``, or None when the key is absent."""
        return self.loop_params.get('index')

    def get_start_node(self):
        """
        Return the first flow node of type ``loop-start-node``.

        @raise IndexError: when the flow has no such node
        """
        candidates = [node for node in self.flow.nodes if node.type == 'loop-start-node']
        return candidates[0]

    def get_reference_field(self, node_id: str, fields: List[str]):
        """
        Resolve a referenced field.

        @param node_id: node id, or one of ``global`` / ``chat`` / ``loop``
        @param fields: field path
        @return: the referenced value
        """
        if node_id in ('global', 'chat'):
            # Global and chat variables live on the enclosing workflow.
            return self.parentWorkflowManage.get_reference_field(node_id, fields)
        if node_id == 'loop':
            return INode.get_field(self.get_loop_context(), fields)
        own_node = self.get_node_by_id(node_id)
        if own_node:
            return own_node.get_reference_field(fields)
        # Not a node of this loop body: let the enclosing workflow resolve it.
        return self.parentWorkflowManage.get_reference_field(node_id, fields)

    def get_workflow_content(self):
        """
        Collect the contexts visible to this workflow.

        @return: dict with ``global``, ``chat`` and ``loop`` entries plus one
            entry per node in ``self.node_context``, keyed by node id
        """
        content = {
            'global': self.context,
            'chat': self.chat_context,
            'loop': self.get_loop_context(),
        }
        content.update({node.id: node.context for node in self.node_context})
        return content

    def init_fields(self):
        """
        Run the base initialisation, then rebuild ``self.loop_field_list``
        from the ``loop_input_field_list`` property of the loop start node.
        """
        super().init_fields()
        start_node = self.flow.get_node('loop-start-node')
        input_fields = start_node.properties.get('loop_input_field_list')
        step_name = start_node.properties.get('stepName')
        start_node_id = start_node.id
        if input_fields is None:
            self.loop_field_list = []
            return
        self.loop_field_list = [
            {'label': item.get('label'), 'value': item.get('field'), 'node_id': start_node_id,
             'node_name': step_name}
            for item in input_fields
        ]

    def reset_prompt(self, prompt: str):
        """
        Rewrite variable references in ``prompt`` into context lookups.

        The base-class rewrite runs first, then every ``loop.<field>``
        reference is replaced, and finally the parent workflow applies its own
        rewrite.

        @param prompt: prompt text
        @return: the rewritten prompt
        """
        prompt = super().reset_prompt(prompt)
        for field in self.loop_field_list:
            reference = f"loop.{field.get('value')}"
            lookup = f"context.get('loop').get('{field.get('value', '')}','')"
            prompt = prompt.replace(reference, lookup)
        return self.parentWorkflowManage.reset_prompt(prompt)

    def generate_prompt(self, prompt: str):
        """
        Format and generate the prompt.

        @param prompt: prompt text
        @return: the formatted prompt
        """
        # Entries from the parent workflow override same-named entries of this one.
        merged_context = dict(self.get_workflow_content())
        merged_context.update(self.parentWorkflowManage.get_workflow_content())
        rewritten = self.reset_prompt(prompt)
        template = PromptTemplate.from_template(rewritten, template_format='jinja2')
        return template.format(context=merged_context)
|