# Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 18:22:46 +00:00; 264 lines, 9.7 KiB, Python)
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: common.py
@date:2024/12/11 17:57
@desc:
"""
|
||
|
||
from typing import List, Dict
|
||
|
||
from django.db.models import QuerySet
|
||
from django.utils.translation import gettext as _
|
||
from rest_framework.exceptions import ErrorDetail, ValidationError
|
||
|
||
from common.exception.app_exception import AppApiException
|
||
from common.utils.common import group_by
|
||
from models_provider.models import Model
|
||
from models_provider.tools import get_model_credential
|
||
from tools.models.tool import Tool
|
||
|
||
# Node types that may legitimately terminate a workflow branch: a node of one
# of these types is not required to have an outgoing edge (see
# Workflow.is_valid_node, which rejects any other node type without one).
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
             'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node']
|
||
|
||
|
||
class Answer:
    """Plain value object holding one answer fragment and the ids that locate it."""

    def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
                 reasoning_content):
        """Store every constructor argument on the instance under the same name."""
        self.view_type = view_type
        self.content = content
        self.reasoning_content = reasoning_content
        self.runtime_node_id = runtime_node_id
        self.chat_record_id = chat_record_id
        self.child_node = child_node
        self.real_node_id = real_node_id

    def to_dict(self):
        """Return the answer as a plain dict keyed by attribute name.

        The values are the stored objects themselves (no copying is done).
        """
        return {
            'view_type': self.view_type,
            'content': self.content,
            'runtime_node_id': self.runtime_node_id,
            'chat_record_id': self.chat_record_id,
            'child_node': self.child_node,
            'reasoning_content': self.reasoning_content,
            'real_node_id': self.real_node_id,
        }
|
||
|
||
|
||
class NodeChunk:
    """Accumulates chunks for a node and tracks whether the node has finished.

    ``status`` starts at 0 and is set to 200 once ``end`` has been called.
    """

    def __init__(self):
        self.status = 0
        self.chunk_list = []

    def add_chunk(self, chunk):
        """Append ``chunk`` to ``chunk_list``."""
        self.chunk_list.append(chunk)

    def end(self, chunk=None):
        """Mark the node as finished, optionally appending a final ``chunk`` first.

        A ``chunk`` of ``None`` is treated as "no final chunk" and is not stored.
        """
        if chunk is not None:
            self.add_chunk(chunk)
        self.status = 200

    def is_end(self):
        """Return True once ``end`` has been called (status is 200)."""
        return self.status == 200
|
||
|
||
|
||
class Edge:
    """A connection between two workflow nodes.

    Any extra keyword arguments are copied onto the instance as attributes,
    so the whole edge dict can be splatted in (``Edge(d['id'], d['type'], **d)``).
    """

    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
        """
        @param _id: edge id, stored as ``id``
        @param _type: edge type, stored as ``type``
        @param sourceNodeId: id of the node the edge starts from
        @param targetNodeId: id of the node the edge points to
        @param keywords: extra attributes to set on the instance
        """
        self.id = _id
        self.type = _type
        self.sourceNodeId = sourceNodeId
        self.targetNodeId = targetNodeId
        # Extras are applied last, so a key such as 'id' or 'type' in
        # ``keywords`` overrides the positional value set above.
        for name, value in keywords.items():
            setattr(self, name, value)
|
||
|
||
|
||
class Node:
    """A workflow node.

    Any extra keyword arguments are copied onto the instance as attributes,
    so the whole node dict can be splatted in (``Node(d['id'], d['type'], **d)``).
    """

    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
        """
        @param _id: node id, stored as ``id``
        @param _type: node type, stored as ``type``
        @param x: x value of the node, stored as given
        @param y: y value of the node, stored as given
        @param properties: node properties dict, stored as given (not copied)
        @param kwargs: extra attributes to set on the instance
        """
        self.id = _id
        self.type = _type
        self.x = x
        self.y = y
        self.properties = properties
        # Extras are applied last, so a key such as 'id' or 'type' in
        # ``kwargs`` overrides the positional value set above.
        for name, value in kwargs.items():
            setattr(self, name, value)
|
||
|
||
|
||
class EdgeNode:
    """Pairs an edge with the node found at one end of it."""

    # The connecting edge.
    edge: Edge
    # The node associated with the edge; Workflow stores the source node in
    # up_node_map entries and the target node in next_node_map entries.
    node: Node

    def __init__(self, edge, node):
        self.edge = edge
        self.node = node
|
||
|
||
|
||
class Workflow:
    """Graph view of a workflow definition (nodes plus connecting edges).

    Builds id-based lookup tables once in ``__init__`` and offers the
    ``is_valid*`` methods that raise when the workflow is not runnable.
    """

    # List of nodes.
    nodes: List[Node]
    # List of edges (connecting lines).
    edges: List[Edge]
    # node id -> node
    node_map: Dict[str, Node]
    # node id -> every (edge, source node) pair directly upstream of that node
    up_node_map: Dict[str, List[EdgeNode]]
    # node id -> every (edge, target node) pair directly downstream of that node
    next_node_map: Dict[str, List[EdgeNode]]

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        """
        @param nodes: all nodes of the workflow
        @param edges: all edges of the workflow
        """
        self.nodes = nodes
        self.edges = edges
        self.node_map = {node.id: node for node in nodes}

        # Edges grouped by their target: who feeds into each node.
        # The paired node is None when an edge references an unknown node id.
        self.up_node_map = {
            target_id: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for edge in incoming]
            for target_id, incoming in group_by(edges, key=lambda edge: edge.targetNodeId).items()
        }

        # Edges grouped by their source: where each node leads to.
        self.next_node_map = {
            source_id: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in outgoing]
            for source_id, outgoing in group_by(edges, key=lambda edge: edge.sourceNodeId).items()
        }

    def get_node(self, node_id):
        """
        Look up a node by its id.
        @param node_id: node id
        @return: the node, or None when no node has that id
        """
        return self.node_map.get(node_id)

    def get_up_edge_nodes(self, node_id) -> List[EdgeNode]:
        """
        Get the upstream nodes of a node together with the connecting edges.
        @param node_id: node id
        @return: list of edge/node pairs, or None when the node has no incoming edge
        """
        return self.up_node_map.get(node_id)

    def get_next_edge_nodes(self, node_id) -> List[EdgeNode]:
        """
        Get the downstream nodes of a node together with the connecting edges.
        @param node_id: node id
        @return: list of edge/node pairs, or None when the node has no outgoing edge
        """
        return self.next_node_map.get(node_id)

    def get_up_nodes(self, node_id) -> List[Node]:
        """
        Get the nodes directly upstream of a node.
        @param node_id: node id
        @return: list of nodes; empty when the node has no incoming edge
        """
        # Default to [] (as get_next_nodes does) so a node without incoming
        # edges yields an empty list instead of a TypeError.
        return [edge_node.node for edge_node in self.up_node_map.get(node_id, [])]

    def get_next_nodes(self, node_id) -> List[Node]:
        """
        Get the nodes directly downstream of a node.
        @param node_id: node id
        @return: list of nodes; empty when the node has no outgoing edge
        """
        return [edge_node.node for edge_node in self.next_node_map.get(node_id, [])]

    @staticmethod
    def new_instance(flow_obj: Dict):
        """
        Build a Workflow from its dict form.
        @param flow_obj: dict with a 'nodes' list and an 'edges' list of dicts
        @return: Workflow instance
        """
        # A missing or null list is treated as empty so that validation can
        # report a proper error instead of failing here with a TypeError.
        raw_nodes = flow_obj.get('nodes') or []
        raw_edges = flow_obj.get('edges') or []
        nodes = [Node(node.get('id'), node.get('type'), **node) for node in raw_nodes]
        edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in raw_edges]
        return Workflow(nodes, edges)

    def get_start_node(self):
        """Return the node whose id is 'start-node', or None when absent."""
        return self.get_node('start-node')

    def get_search_node(self):
        """Return every node of type 'search-dataset-node'."""
        return [node for node in self.nodes if node.type == 'search-dataset-node']

    def is_valid(self):
        """
        Validate the workflow data; raises on the first problem found.
        """
        self.is_valid_model_params()
        self.is_valid_start_node()
        self.is_valid_base_node()
        self.is_valid_work_flow()

    @staticmethod
    def is_valid_node_params(node: Node):
        """Instantiate the step-node class registered for ``node.type`` with this node.

        Presumably the step-node constructor validates the node parameters and
        raises when they are invalid -- confirm in application.flow.step_node.
        """
        # Imported here rather than at module level (likely to avoid a
        # circular import -- TODO confirm).
        from application.flow.step_node import get_node
        get_node(node.type)(node, None, None)

    def is_valid_node(self, node: Node):
        """
        Validate a single node: its parameters and its outgoing connections.
        @param node: node to validate
        @raise AppApiException: a condition branch is not connected, or a node
                                that may not end the workflow has no outgoing edge
        """
        self.is_valid_node_params(node)
        if node.type == 'condition-node':
            # Every branch of a condition node needs at least one edge leaving
            # its right-hand anchor.
            branch_list = node.properties.get('node_data', {}).get('branch') or []
            for branch in branch_list:
                source_anchor_id = f"{node.id}_{branch.get('id')}_right"
                # getattr: anchor ids are optional extras on Edge, so an edge
                # without one simply does not match.
                connected = any(getattr(edge, 'sourceAnchorId', None) == source_anchor_id
                                for edge in self.edges)
                if not connected:
                    raise AppApiException(500,
                                          _('The branch {branch} of the {node} node needs to be connected').format(
                                              node=node.properties.get("stepName"), branch=branch.get("type")))
        else:
            has_outgoing_edge = any(edge.sourceNodeId == node.id for edge in self.edges)
            if not has_outgoing_edge and node.type not in end_nodes:
                raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format(
                    node=node.properties.get("stepName")))

    def is_valid_work_flow(self, up_node=None):
        """
        Validate every node reachable from ``up_node`` by following edges.
        @param up_node: node to start from; defaults to the start node
        @raise AppApiException: no start node exists, or a reachable node is invalid
        """
        if up_node is None:
            up_node = self.get_start_node()
        if up_node is None:
            raise AppApiException(500, _('The starting node is required'))

        # Iterative depth-first walk. Successors are looked up by node *id*
        # (next_node_map is keyed by id), and visited ids are remembered so
        # shared successors are validated once and cycles terminate.
        visited = set()
        pending = [up_node]
        while pending:
            current = pending.pop()
            if current.id in visited:
                continue
            visited.add(current.id)
            self.is_valid_node(current)
            # Skip edges whose target id is not a known node (node is None).
            successors = [node for node in self.get_next_nodes(current.id) if node is not None]
            # Reversed so that nodes are popped in their original edge order.
            pending.extend(reversed(successors))

    def is_valid_start_node(self):
        """
        Require exactly one node with id 'start-node'.
        @raise AppApiException: the start node is missing or duplicated
        """
        start_node_list = [node for node in self.nodes if node.id == 'start-node']
        if not start_node_list:
            raise AppApiException(500, _('The starting node is required'))
        if len(start_node_list) > 1:
            raise AppApiException(500, _('There can only be one starting node'))

    def is_valid_model_params(self):
        """
        Validate model-backed nodes and function-library nodes.

        For each 'ai-chat-node' / 'question-node': the referenced model must
        exist, missing 'model_params_setting' is filled with the form's default
        data, and the node status must be 200. For each 'function-lib-node':
        'function_lib_id' must be set and refer to an existing Tool.
        @raise ValidationError: any of the checks above fails
        """
        model_node_list = [node for node in self.nodes if node.type in ('ai-chat-node', 'question-node')]
        for node in model_node_list:
            node_data = node.properties.get('node_data', {})
            model = QuerySet(Model).filter(id=node_data.get('model_id')).first()
            if model is None:
                raise ValidationError(ErrorDetail(
                    _('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            model_params_setting = node_data.get('model_params_setting')
            model_params_setting_form = credential.get_model_params_setting_form(
                model.model_name)
            if model_params_setting is None:
                # Fill in defaults so later consumers of the node see a value.
                model_params_setting = model_params_setting_form.get_default_form_data()
                node_data['model_params_setting'] = model_params_setting
            if node.properties.get('status', 200) != 200:
                # The placeholder is named, so the value must be passed by
                # keyword; a positional argument raises KeyError('node').
                raise ValidationError(
                    ErrorDetail(_("Node {node} is unavailable").format(node=node.properties.get("stepName"))))

        function_lib_node_list = [node for node in self.nodes if node.type == 'function-lib-node']
        for node in function_lib_node_list:
            function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
            if function_lib_id is None:
                raise ValidationError(ErrorDetail(
                    _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
            f_lib = QuerySet(Tool).filter(id=function_lib_id).first()
            if f_lib is None:
                raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
                    node=node.properties.get("stepName"))))

    def is_valid_base_node(self):
        """
        Require exactly one node with id 'base-node'.
        @raise AppApiException: the base node is missing or duplicated
        """
        base_node_list = [node for node in self.nodes if node.id == 'base-node']
        if not base_node_list:
            raise AppApiException(500, _('Basic information node is required'))
        if len(base_node_list) > 1:
            raise AppApiException(500, _('There can only be one basic information node'))
|