refactor: Workflow execution logic (#1913)

shaohuzhang1 2024-12-26 15:29:55 +08:00 committed by GitHub
parent bb58ac6f2c
commit efa73c827f
9 changed files with 94 additions and 70 deletions

View File

@@ -19,3 +19,20 @@ class Answer:
def to_dict(self):
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
class NodeChunk:
def __init__(self):
self.status = 0
self.chunk_list = []
def add_chunk(self, chunk):
self.chunk_list.append(chunk)
def end(self, chunk=None):
if chunk is not None:
self.add_chunk(chunk)
self.status = 200
def is_end(self):
return self.status == 200
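NodeChunk (moved here from the workflow module; its old definition is deleted further down) is a small status-flagged buffer: status stays 0 while a node is still streaming and becomes 200 once end() is called. A minimal, purely illustrative sketch of that contract using only the methods defined above; the drain loop and variable names are hypothetical:

buffer = NodeChunk()
buffer.add_chunk('partial ')
buffer.end('final')                         # optionally append a last chunk, then mark status = 200

consumed = []
while buffer.chunk_list or not buffer.is_end():
    if buffer.chunk_list:
        consumed.append(buffer.chunk_list.pop(0))
assert buffer.is_end() and ''.join(consumed) == 'partial final'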

View File

@@ -17,7 +17,7 @@ from django.db.models import QuerySet
from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
from application.flow.common import Answer
from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
@@ -175,6 +175,7 @@ class INode:
if up_node_id_list is None:
up_node_id_list = []
self.up_node_id_list = up_node_id_list
self.node_chunk = NodeChunk()
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
@@ -214,6 +215,7 @@ class INode:
def get_write_error_context(self, e):
self.status = 500
self.answer_text = str(e)
self.err_message = str(e)
self.context['run_time'] = time.time() - self.context['start_time']

View File

@@ -10,7 +10,7 @@ import os
from typing import List, Dict
from django.db.models import QuerySet
from django.db import connection
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore
@@ -77,6 +77,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
# Manually close the database connection
connection.close()
if embedding_list is None:
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
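The connection.close() added above (and the similar calls in the model loaders further down) is there because this code now runs on executor threads outside Django's request/response cycle, where aged-out connections are never cleaned up automatically. A minimal sketch of that pattern; the wrapper name run_in_worker is illustrative only:

from django.db import close_old_connections, connection

def run_in_worker(task):
    # Drop connections that exceeded CONN_MAX_AGE before touching the ORM,
    # then hand the thread-local connection back once the task finishes.
    close_old_connections()
    try:
        return task()
    finally:
        connection.close()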

View File

@@ -6,6 +6,7 @@
@date: 2024/1/9 17:40
@desc:
"""
import concurrent
import json
import threading
import traceback
@@ -13,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from typing import List, Dict
from django.db import close_old_connections
from django.db.models import QuerySet
from langchain_core.prompts import PromptTemplate
from rest_framework import status
@@ -223,23 +225,6 @@ class NodeChunkManage:
return None
class NodeChunk:
def __init__(self):
self.status = 0
self.chunk_list = []
def add_chunk(self, chunk):
self.chunk_list.append(chunk)
def end(self, chunk=None):
if chunk is not None:
self.add_chunk(chunk)
self.status = 200
def is_end(self):
return self.status == 200
class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
@@ -273,8 +258,9 @@ class WorkflowManage:
self.status = 200
self.base_to_response = base_to_response
self.chat_record = chat_record
self.await_future_map = {}
self.child_node = child_node
self.future_list = []
self.lock = threading.Lock()
if start_node_id is not None:
self.load_node(chat_record, start_node_id, start_node_data)
else:
@@ -319,6 +305,7 @@ class WorkflowManage:
self.node_context.append(node)
def run(self):
close_old_connections()
if self.params.get('stream'):
return self.run_stream(self.start_node, None)
return self.run_block()
@@ -328,8 +315,9 @@ class WorkflowManage:
Non-streaming response
@return: result
"""
result = self.run_chain_async(None, None)
result.result()
self.run_chain_async(None, None)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
@@ -350,12 +338,22 @@ class WorkflowManage:
Streaming response
@return:
"""
result = self.run_chain_async(current_node, node_result_future)
return tools.to_stream_response_simple(self.await_result(result))
self.run_chain_async(current_node, node_result_future)
return tools.to_stream_response_simple(self.await_result())
def await_result(self, result):
def is_run(self, timeout=0.1):
self.lock.acquire()
try:
while await_result(result):
r = concurrent.futures.wait(self.future_list, timeout)
return len(r.not_done) > 0
except Exception as e:
return True
finally:
self.lock.release()
def await_result(self):
try:
while self.is_run():
while True:
chunk = self.node_chunk_manage.pop()
if chunk is not None:
@@ -383,12 +381,16 @@ class WorkflowManage:
'', True, message_tokens, answer_tokens, {})
def run_chain_async(self, current_node, node_result_future):
return executor.submit(self.run_chain_manage, current_node, node_result_future)
future = executor.submit(self.run_chain_manage, current_node, node_result_future)
self.future_list.append(future)
def run_chain_manage(self, current_node, node_result_future):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
self.node_chunk_manage.add_node_chunk(current_node.node_chunk)
# Add the node
self.append_node(current_node)
result = self.run_chain(current_node, node_result_future)
if result is None:
return
@@ -396,29 +398,22 @@ class WorkflowManage:
if len(node_list) == 1:
self.run_chain_manage(node_list[0], None)
elif len(node_list) > 1:
sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y)
# Get the executable child nodes
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]
def set_await_map(self, node_run_list):
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
for index in range(len(sorted_node_run_list)):
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
sorted_node_run_list[i].get('future')
for i in range(index)]
sorted_node_run_list]
try:
self.lock.acquire()
for r in result_list:
self.future_list.append(r.get('future'))
finally:
self.lock.release()
def run_chain(self, current_node, node_result_future=None):
if node_result_future is None:
node_result_future = self.run_node_future(current_node)
try:
is_stream = self.params.get('stream', True)
# Handle the node response
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
if await_future_list is not None:
[f.result() for f in await_future_list]
result = self.hand_event_node_result(current_node,
node_result_future) if is_stream else self.hand_node_result(
current_node, node_result_future)
@@ -434,16 +429,14 @@ class WorkflowManage:
if result is not None:
# Block until the result is available
list(result)
# Add the node
self.node_context.append(current_node)
return current_result
except Exception as e:
# Add the node
self.node_context.append(current_node)
traceback.print_exc()
self.status = 500
current_node.get_write_error_context(e)
self.answer += str(e)
finally:
current_node.node_chunk.end()
def append_node(self, current_node):
for index in range(len(self.node_context)):
@@ -454,15 +447,14 @@ class WorkflowManage:
self.node_context.append(current_node)
def hand_event_node_result(self, current_node, node_result_future):
node_chunk = NodeChunk()
real_node_id = current_node.runtime_node_id
child_node = {}
view_type = current_node.view_type
try:
current_result = node_result_future.result()
result = current_result.write_context(current_node, self)
if result is not None:
if self.is_result(current_node, current_result):
self.node_chunk_manage.add_node_chunk(node_chunk)
for r in result:
content = r
child_node = {}
@@ -487,26 +479,24 @@ class WorkflowManage:
'child_node': child_node,
'node_is_end': node_is_end,
'real_node_id': real_node_id})
node_chunk.add_chunk(chunk)
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id})
node_chunk.end(chunk)
current_node.node_chunk.add_chunk(chunk)
chunk = (self.base_to_response
.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id}))
current_node.node_chunk.add_chunk(chunk)
else:
list(result)
# Add the node
self.append_node(current_node)
return current_result
except Exception as e:
# Add the node
self.append_node(current_node)
traceback.print_exc()
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
@@ -519,12 +509,12 @@ class WorkflowManage:
'view_type': current_node.view_type,
'child_node': {},
'real_node_id': real_node_id})
if not self.node_chunk_manage.contains(node_chunk):
self.node_chunk_manage.add_node_chunk(node_chunk)
node_chunk.end(chunk)
current_node.node_chunk.add_chunk(chunk)
current_node.get_write_error_context(e)
self.status = 500
return None
finally:
current_node.node_chunk.end()
def run_node_async(self, node):
future = executor.submit(self.run_node, node)
@@ -636,6 +626,8 @@ class WorkflowManage:
@staticmethod
def dependent_node(up_node_id, node):
if not node.node_chunk.is_end():
return False
if node.id == up_node_id:
if node.type == 'form-node':
if node.context.get('form_data', None) is not None:
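The refactor replaces the old await_future_map bookkeeping with a lock-guarded future_list plus a completion check built on concurrent.futures.wait (see is_run above). A stripped-down sketch of that check, with pending_futures standing in for self.future_list:

import concurrent.futures

def all_nodes_finished(pending_futures, timeout=0.1):
    # wait() returns a (done, not_done) pair; the workflow keeps polling
    # for output chunks while anything is still not done.
    remaining = concurrent.futures.wait(pending_futures, timeout=timeout).not_done
    return len(remaining) == 0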

View File

@@ -6,6 +6,7 @@
@date: 2024/7/22 11:18
@desc:
"""
from django.db import connection
from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
@@ -15,6 +16,8 @@ from setting.models_provider import get_model
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
# Manually close the database connection
connection.close()
if model is None:
raise Exception("模型不存在")  # "Model does not exist"
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):

View File

@@ -6,6 +6,7 @@
@date: 2024/8/20 20:39
@desc:
"""
from django.db import connection
from django.db.models import QuerySet
from langchain_core.documents import Document
from rest_framework import serializers
@@ -18,6 +19,8 @@ from setting.models_provider import get_model
def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
# Manually close the database connection
connection.close()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_model(model, use_local=True))
return embedding_model

View File

@@ -80,7 +80,7 @@ class Config(dict):
"DB_PORT": 5432,
"DB_USER": "root",
"DB_PASSWORD": "Password123@postgres",
"DB_ENGINE": "django.db.backends.postgresql_psycopg2",
"DB_ENGINE": "dj_db_conn_pool.backends.postgresql",
# Vector model
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
"EMBEDDING_DEVICE": "cpu",
@@ -108,7 +108,11 @@ class Config(dict):
"PORT": self.get('DB_PORT'),
"USER": self.get('DB_USER'),
"PASSWORD": self.get('DB_PASSWORD'),
"ENGINE": self.get('DB_ENGINE')
"ENGINE": self.get('DB_ENGINE'),
"POOL_OPTIONS": {
"POOL_SIZE": 20,
"MAX_OVERFLOW": 5
}
}
def __init__(self, *args):
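With DB_ENGINE switched to dj_db_conn_pool.backends.postgresql, the assembled Django DATABASES entry would look roughly like the sketch below; the NAME value is an assumption, and POOL_SIZE / MAX_OVERFLOW follow the SQLAlchemy-style pool semantics used by django-db-connection-pool (20 pooled connections plus up to 5 overflow connections under load):

DATABASES = {
    'default': {
        'ENGINE': 'dj_db_conn_pool.backends.postgresql',
        'NAME': 'maxkb',                     # assumption: the real value comes from the DB name in this config
        'HOST': '127.0.0.1',
        'PORT': 5432,
        'USER': 'root',
        'PASSWORD': 'Password123@postgres',
        'POOL_OPTIONS': {
            'POOL_SIZE': 20,                 # connections kept open in the pool
            'MAX_OVERFLOW': 5,               # extra connections allowed under burst load
        },
    }
}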

View File

@@ -13,7 +13,7 @@ DB_HOST: 127.0.0.1
DB_PORT: 5432
DB_USER: root
DB_PASSWORD: Password123@postgres
DB_ENGINE: django.db.backends.postgresql_psycopg2
DB_ENGINE: dj_db_conn_pool.backends.postgresql
EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding
EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese

View File

@@ -14,7 +14,7 @@ django-filter = "23.2"
langchain = "0.2.16"
langchain_community = "0.2.17"
langchain-huggingface = "^0.0.3"
psycopg2-binary = "2.9.7"
psycopg2-binary = "2.9.10"
jieba = "^0.42.1"
diskcache = "^5.6.3"
pillow = "^10.2.0"
@@ -57,6 +57,7 @@ pylint = "3.1.0"
pydub = "^0.25.1"
cffi = "^1.17.1"
pysilk = "^0.0.1"
django-db-connection-pool = "^1.2.5"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"