refactor: Workflow execution logic (#1886)
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled

This commit is contained in:
shaohuzhang1 2024-12-20 20:29:15 +08:00 committed by GitHub
parent abef79efa6
commit a00af1e288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,7 +29,7 @@ from function_lib.models.function import FunctionLib
from setting.models import Model
from setting.models_provider import get_model_credential
executor = ThreadPoolExecutor(max_workers=50)
executor = ThreadPoolExecutor(max_workers=200)
class Edge:
@ -271,7 +271,7 @@ class WorkflowManage:
self.current_result = None
self.answer = ""
self.answer_list = ['']
self.status = 0
self.status = 200
self.base_to_response = base_to_response
self.chat_record = chat_record
self.await_future_map = {}
@ -384,8 +384,23 @@ class WorkflowManage:
'', True, message_tokens, answer_tokens, {})
def run_chain_async(self, current_node, node_result_future):
future = executor.submit(self.run_chain, current_node, node_result_future)
return future
return executor.submit(self.run_chain_manage, current_node, node_result_future)
def run_chain_manage(self, current_node, node_result_future):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
result = self.run_chain(current_node, node_result_future)
node_list = self.get_next_node_list(current_node, result)
if len(node_list) == 1:
self.run_chain_manage(node_list[0], None)
elif len(node_list) > 1:
# 获取到可执行的子节点
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]
def set_await_map(self, node_run_list):
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
@ -395,9 +410,6 @@ class WorkflowManage:
for i in range(index)]
def run_chain(self, current_node, node_result_future=None):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
if node_result_future is None:
node_result_future = self.run_node_future(current_node)
try:
@ -409,18 +421,10 @@ class WorkflowManage:
result = self.hand_event_node_result(current_node,
node_result_future) if is_stream else self.hand_node_result(
current_node, node_result_future)
with self.lock:
if current_node.status == 500:
return
node_list = self.get_next_node_list(current_node, result)
# 获取到可执行的子节点
result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]
if self.status == 0:
self.status = 200
return result
except Exception as e:
traceback.print_exc()
return []
def hand_node_result(self, current_node, node_result_future):
try: