mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
refactor: Workflow execution logic (#1913)
This commit is contained in:
parent
bb58ac6f2c
commit
efa73c827f
|
|
@ -19,3 +19,20 @@ class Answer:
|
|||
def to_dict(self):
|
||||
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
|
||||
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
|
||||
|
||||
|
||||
class NodeChunk:
|
||||
def __init__(self):
|
||||
self.status = 0
|
||||
self.chunk_list = []
|
||||
|
||||
def add_chunk(self, chunk):
|
||||
self.chunk_list.append(chunk)
|
||||
|
||||
def end(self, chunk=None):
|
||||
if chunk is not None:
|
||||
self.add_chunk(chunk)
|
||||
self.status = 200
|
||||
|
||||
def is_end(self):
|
||||
return self.status == 200
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from django.db.models import QuerySet
|
|||
from rest_framework import serializers
|
||||
from rest_framework.exceptions import ValidationError, ErrorDetail
|
||||
|
||||
from application.flow.common import Answer
|
||||
from application.flow.common import Answer, NodeChunk
|
||||
from application.models import ChatRecord
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
|
|
@ -175,6 +175,7 @@ class INode:
|
|||
if up_node_id_list is None:
|
||||
up_node_id_list = []
|
||||
self.up_node_id_list = up_node_id_list
|
||||
self.node_chunk = NodeChunk()
|
||||
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
|
||||
"".join([*sorted(up_node_id_list),
|
||||
node.id]))),
|
||||
|
|
@ -214,6 +215,7 @@ class INode:
|
|||
|
||||
def get_write_error_context(self, e):
|
||||
self.status = 500
|
||||
self.answer_text = str(e)
|
||||
self.err_message = str(e)
|
||||
self.context['run_time'] = time.time() - self.context['start_time']
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import os
|
|||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from django.db import connection
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import VectorStore
|
||||
|
|
@ -77,6 +77,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
||||
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
||||
# 手动关闭数据库连接
|
||||
connection.close()
|
||||
if embedding_list is None:
|
||||
return get_none_result(question)
|
||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
import concurrent
|
||||
import json
|
||||
import threading
|
||||
import traceback
|
||||
|
|
@ -13,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db import close_old_connections
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from rest_framework import status
|
||||
|
|
@ -223,23 +225,6 @@ class NodeChunkManage:
|
|||
return None
|
||||
|
||||
|
||||
class NodeChunk:
|
||||
def __init__(self):
|
||||
self.status = 0
|
||||
self.chunk_list = []
|
||||
|
||||
def add_chunk(self, chunk):
|
||||
self.chunk_list.append(chunk)
|
||||
|
||||
def end(self, chunk=None):
|
||||
if chunk is not None:
|
||||
self.add_chunk(chunk)
|
||||
self.status = 200
|
||||
|
||||
def is_end(self):
|
||||
return self.status == 200
|
||||
|
||||
|
||||
class WorkflowManage:
|
||||
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
|
||||
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
|
||||
|
|
@ -273,8 +258,9 @@ class WorkflowManage:
|
|||
self.status = 200
|
||||
self.base_to_response = base_to_response
|
||||
self.chat_record = chat_record
|
||||
self.await_future_map = {}
|
||||
self.child_node = child_node
|
||||
self.future_list = []
|
||||
self.lock = threading.Lock()
|
||||
if start_node_id is not None:
|
||||
self.load_node(chat_record, start_node_id, start_node_data)
|
||||
else:
|
||||
|
|
@ -319,6 +305,7 @@ class WorkflowManage:
|
|||
self.node_context.append(node)
|
||||
|
||||
def run(self):
|
||||
close_old_connections()
|
||||
if self.params.get('stream'):
|
||||
return self.run_stream(self.start_node, None)
|
||||
return self.run_block()
|
||||
|
|
@ -328,8 +315,9 @@ class WorkflowManage:
|
|||
非流式响应
|
||||
@return: 结果
|
||||
"""
|
||||
result = self.run_chain_async(None, None)
|
||||
result.result()
|
||||
self.run_chain_async(None, None)
|
||||
while self.is_run():
|
||||
pass
|
||||
details = self.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
|
|
@ -350,12 +338,22 @@ class WorkflowManage:
|
|||
流式响应
|
||||
@return:
|
||||
"""
|
||||
result = self.run_chain_async(current_node, node_result_future)
|
||||
return tools.to_stream_response_simple(self.await_result(result))
|
||||
self.run_chain_async(current_node, node_result_future)
|
||||
return tools.to_stream_response_simple(self.await_result())
|
||||
|
||||
def await_result(self, result):
|
||||
def is_run(self, timeout=0.1):
|
||||
self.lock.acquire()
|
||||
try:
|
||||
while await_result(result):
|
||||
r = concurrent.futures.wait(self.future_list, timeout)
|
||||
return len(r.not_done) > 0
|
||||
except Exception as e:
|
||||
return True
|
||||
finally:
|
||||
self.lock.release()
|
||||
|
||||
def await_result(self):
|
||||
try:
|
||||
while self.is_run():
|
||||
while True:
|
||||
chunk = self.node_chunk_manage.pop()
|
||||
if chunk is not None:
|
||||
|
|
@ -383,12 +381,16 @@ class WorkflowManage:
|
|||
'', True, message_tokens, answer_tokens, {})
|
||||
|
||||
def run_chain_async(self, current_node, node_result_future):
|
||||
return executor.submit(self.run_chain_manage, current_node, node_result_future)
|
||||
future = executor.submit(self.run_chain_manage, current_node, node_result_future)
|
||||
self.future_list.append(future)
|
||||
|
||||
def run_chain_manage(self, current_node, node_result_future):
|
||||
if current_node is None:
|
||||
start_node = self.get_start_node()
|
||||
current_node = get_node(start_node.type)(start_node, self.params, self)
|
||||
self.node_chunk_manage.add_node_chunk(current_node.node_chunk)
|
||||
# 添加节点
|
||||
self.append_node(current_node)
|
||||
result = self.run_chain(current_node, node_result_future)
|
||||
if result is None:
|
||||
return
|
||||
|
|
@ -396,29 +398,22 @@ class WorkflowManage:
|
|||
if len(node_list) == 1:
|
||||
self.run_chain_manage(node_list[0], None)
|
||||
elif len(node_list) > 1:
|
||||
|
||||
sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y)
|
||||
# 获取到可执行的子节点
|
||||
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
|
||||
node_list]
|
||||
self.set_await_map(result_list)
|
||||
[r.get('future').result() for r in result_list]
|
||||
|
||||
def set_await_map(self, node_run_list):
|
||||
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
|
||||
for index in range(len(sorted_node_run_list)):
|
||||
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
|
||||
sorted_node_run_list[i].get('future')
|
||||
for i in range(index)]
|
||||
sorted_node_run_list]
|
||||
try:
|
||||
self.lock.acquire()
|
||||
for r in result_list:
|
||||
self.future_list.append(r.get('future'))
|
||||
finally:
|
||||
self.lock.release()
|
||||
|
||||
def run_chain(self, current_node, node_result_future=None):
|
||||
if node_result_future is None:
|
||||
node_result_future = self.run_node_future(current_node)
|
||||
try:
|
||||
is_stream = self.params.get('stream', True)
|
||||
# 处理节点响应
|
||||
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
|
||||
if await_future_list is not None:
|
||||
[f.result() for f in await_future_list]
|
||||
result = self.hand_event_node_result(current_node,
|
||||
node_result_future) if is_stream else self.hand_node_result(
|
||||
current_node, node_result_future)
|
||||
|
|
@ -434,16 +429,14 @@ class WorkflowManage:
|
|||
if result is not None:
|
||||
# 阻塞获取结果
|
||||
list(result)
|
||||
# 添加节点
|
||||
self.node_context.append(current_node)
|
||||
return current_result
|
||||
except Exception as e:
|
||||
# 添加节点
|
||||
self.node_context.append(current_node)
|
||||
traceback.print_exc()
|
||||
self.status = 500
|
||||
current_node.get_write_error_context(e)
|
||||
self.answer += str(e)
|
||||
finally:
|
||||
current_node.node_chunk.end()
|
||||
|
||||
def append_node(self, current_node):
|
||||
for index in range(len(self.node_context)):
|
||||
|
|
@ -454,15 +447,14 @@ class WorkflowManage:
|
|||
self.node_context.append(current_node)
|
||||
|
||||
def hand_event_node_result(self, current_node, node_result_future):
|
||||
node_chunk = NodeChunk()
|
||||
real_node_id = current_node.runtime_node_id
|
||||
child_node = {}
|
||||
view_type = current_node.view_type
|
||||
try:
|
||||
current_result = node_result_future.result()
|
||||
result = current_result.write_context(current_node, self)
|
||||
if result is not None:
|
||||
if self.is_result(current_node, current_result):
|
||||
self.node_chunk_manage.add_node_chunk(node_chunk)
|
||||
for r in result:
|
||||
content = r
|
||||
child_node = {}
|
||||
|
|
@ -487,26 +479,24 @@ class WorkflowManage:
|
|||
'child_node': child_node,
|
||||
'node_is_end': node_is_end,
|
||||
'real_node_id': real_node_id})
|
||||
node_chunk.add_chunk(chunk)
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
'', False, 0, 0, {'node_is_end': True,
|
||||
'runtime_node_id': current_node.runtime_node_id,
|
||||
'node_type': current_node.type,
|
||||
'view_type': view_type,
|
||||
'child_node': child_node,
|
||||
'real_node_id': real_node_id})
|
||||
node_chunk.end(chunk)
|
||||
current_node.node_chunk.add_chunk(chunk)
|
||||
chunk = (self.base_to_response
|
||||
.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
'', False, 0, 0, {'node_is_end': True,
|
||||
'runtime_node_id': current_node.runtime_node_id,
|
||||
'node_type': current_node.type,
|
||||
'view_type': view_type,
|
||||
'child_node': child_node,
|
||||
'real_node_id': real_node_id}))
|
||||
current_node.node_chunk.add_chunk(chunk)
|
||||
else:
|
||||
list(result)
|
||||
# 添加节点
|
||||
self.append_node(current_node)
|
||||
return current_result
|
||||
except Exception as e:
|
||||
# 添加节点
|
||||
self.append_node(current_node)
|
||||
traceback.print_exc()
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
|
|
@ -519,12 +509,12 @@ class WorkflowManage:
|
|||
'view_type': current_node.view_type,
|
||||
'child_node': {},
|
||||
'real_node_id': real_node_id})
|
||||
if not self.node_chunk_manage.contains(node_chunk):
|
||||
self.node_chunk_manage.add_node_chunk(node_chunk)
|
||||
node_chunk.end(chunk)
|
||||
current_node.node_chunk.add_chunk(chunk)
|
||||
current_node.get_write_error_context(e)
|
||||
self.status = 500
|
||||
return None
|
||||
finally:
|
||||
current_node.node_chunk.end()
|
||||
|
||||
def run_node_async(self, node):
|
||||
future = executor.submit(self.run_node, node)
|
||||
|
|
@ -636,6 +626,8 @@ class WorkflowManage:
|
|||
|
||||
@staticmethod
|
||||
def dependent_node(up_node_id, node):
|
||||
if not node.node_chunk.is_end():
|
||||
return False
|
||||
if node.id == up_node_id:
|
||||
if node.type == 'form-node':
|
||||
if node.context.get('form_data', None) is not None:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2024/7/22 11:18
|
||||
@desc:
|
||||
"""
|
||||
from django.db import connection
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.config.embedding_config import ModelManage
|
||||
|
|
@ -15,6 +16,8 @@ from setting.models_provider import get_model
|
|||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
# 手动关闭数据库连接
|
||||
connection.close()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2024/8/20 20:39
|
||||
@desc:
|
||||
"""
|
||||
from django.db import connection
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.documents import Document
|
||||
from rest_framework import serializers
|
||||
|
|
@ -18,6 +19,8 @@ from setting.models_provider import get_model
|
|||
|
||||
def get_embedding_model(model_id):
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
# 手动关闭数据库连接
|
||||
connection.close()
|
||||
embedding_model = ModelManage.get_model(model_id,
|
||||
lambda _id: get_model(model, use_local=True))
|
||||
return embedding_model
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class Config(dict):
|
|||
"DB_PORT": 5432,
|
||||
"DB_USER": "root",
|
||||
"DB_PASSWORD": "Password123@postgres",
|
||||
"DB_ENGINE": "django.db.backends.postgresql_psycopg2",
|
||||
"DB_ENGINE": "dj_db_conn_pool.backends.postgresql",
|
||||
# 向量模型
|
||||
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
|
||||
"EMBEDDING_DEVICE": "cpu",
|
||||
|
|
@ -108,7 +108,11 @@ class Config(dict):
|
|||
"PORT": self.get('DB_PORT'),
|
||||
"USER": self.get('DB_USER'),
|
||||
"PASSWORD": self.get('DB_PASSWORD'),
|
||||
"ENGINE": self.get('DB_ENGINE')
|
||||
"ENGINE": self.get('DB_ENGINE'),
|
||||
"POOL_OPTIONS": {
|
||||
"POOL_SIZE": 20,
|
||||
"MAX_OVERFLOW": 5
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, *args):
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ DB_HOST: 127.0.0.1
|
|||
DB_PORT: 5432
|
||||
DB_USER: root
|
||||
DB_PASSWORD: Password123@postgres
|
||||
DB_ENGINE: django.db.backends.postgresql_psycopg2
|
||||
DB_ENGINE: dj_db_conn_pool.backends.postgresql
|
||||
EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding
|
||||
EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ django-filter = "23.2"
|
|||
langchain = "0.2.16"
|
||||
langchain_community = "0.2.17"
|
||||
langchain-huggingface = "^0.0.3"
|
||||
psycopg2-binary = "2.9.7"
|
||||
psycopg2-binary = "2.9.10"
|
||||
jieba = "^0.42.1"
|
||||
diskcache = "^5.6.3"
|
||||
pillow = "^10.2.0"
|
||||
|
|
@ -57,6 +57,7 @@ pylint = "3.1.0"
|
|||
pydub = "^0.25.1"
|
||||
cffi = "^1.17.1"
|
||||
pysilk = "^0.0.1"
|
||||
django-db-connection-pool = "^1.2.5"
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
|||
Loading…
Reference in New Issue