feat: 工作流程

This commit is contained in:
shaohuzhang1 2024-06-06 19:27:02 +08:00
parent f87919642b
commit fcb33db41d
9 changed files with 769 additions and 0 deletions

View File

@ -0,0 +1,168 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_step_node.py
@date2024/6/3 14:57
@desc:
"""
import time
from abc import abstractmethod
from typing import Type, Dict, List
from django.db.models import QuerySet
from rest_framework import serializers
from application.models import ChatRecord
from application.models.api_key_model import ApplicationPublicAccessClient
from application.serializers.application_serializers import chat_cache
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
if global_variable is not None:
for key in global_variable:
workflow.context[key] = global_variable[key]
class WorkFlowPostHandler:
def __init__(self, chat_info, client_id, client_type):
self.chat_info = chat_info
self.client_id = client_id
self.client_type = client_type
def handler(self, chat_id,
chat_record_id,
answer,
workflow):
question = workflow.params['question']
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
problem_text=question,
answer_text=answer,
details=workflow.get_details(),
message_tokens=workflow.context['message_tokens'],
answer_tokens=workflow.context['answer_tokens'],
run_time=workflow.context['run_time'],
index=len(self.chat_info.chat_record_list) + 1)
self.chat_info.append_chat_record(chat_record, self.client_id)
# 重新设置缓存
chat_cache.set(chat_id,
self.chat_info, timeout=60 * 30)
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first()
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
class NodeResult:
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
self._write_context = _write_context
self.node_variable = node_variable
self.workflow_variable = workflow_variable
self._to_response = _to_response
def write_context(self, node, workflow):
self._write_context(self.node_variable, self.workflow_variable, node, workflow)
def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
post_handler)
class ReferenceAddressSerializer(serializers.Serializer):
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
fields = serializers.ListField(
child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
error_messages=ErrMessage.list("节点字段数组"))
class FlowParamsSerializer(serializers.Serializer):
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
question = serializers.CharField(required=True, error_messages=ErrMessage.list("用户问题"))
chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("对话id"))
chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
class INode:
def __init__(self, _id, node_params, workflow_params, workflow_manage):
# 当前步骤上下文,用于存储当前步骤信息
self.node_params = node_params
self.workflow_manage = workflow_manage
self.node_params_serializer = None
self.flow_params_serializer = None
self.context = {}
self.id = _id
self.valid_args(node_params, workflow_params)
def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class()
node_params_serializer_class = self.get_node_params_serializer_class()
if flow_params_serializer_class is not None:
self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
self.flow_params_serializer.is_valid(raise_exception=True)
if node_params_serializer_class is not None:
self.node_params_serializer = node_params_serializer_class(data=node_params)
self.node_params_serializer.is_valid(raise_exception=True)
def get_reference_field(self, fields: List[str]):
return self.get_field(self.context, fields)
@staticmethod
def get_field(obj, fields: List[str]):
for field in fields:
value = obj.get(field)
if value is None:
return None
else:
obj = value
return obj
@abstractmethod
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
pass
def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
return FlowParamsSerializer
def run(self) -> NodeResult:
"""
:return: 执行结果
"""
start_time = time.time()
self.context['start_time'] = start_time
result = self._run()
self.context['run_time'] = time.time() - start_time
return result
def _run(self):
result = self.execute()
return result
def execute(self, **kwargs) -> NodeResult:
pass
def get_details(self, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return None

View File

@ -0,0 +1,36 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_node.py
@date2024/6/4 13:58
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class ChatNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
system = serializers.CharField(required=True, error_messages=ErrMessage.char("角色设定"))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
class IChatNode(INode):
type = 'ai-chat-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ChatNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,145 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_chat_node.py
@date2024/6/4 14:30
@desc:
"""
import json
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from common.util.rsa_util import rsa_long_decrypt
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据 (流式)
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
answer = ''
for chunk in response:
answer += chunk.content
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
写入上下文数据
@param node_variable: 节点数据
@param workflow_variable: 全局数据
@param node: 节点实例对象
@param workflow: 工作流管理器
"""
response = node_variable.get('result')
chat_model = node_variable.get('chat_model')
answer = response.content
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将流式数据 转换为 流式响应
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器 输出结果后执行
@return: 流式响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
将结果转换
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param node_variable: 节点数据
@param workflow_variable: 工作流数据
@param node: 节点
@param workflow: 工作流管理器
@param post_handler: 后置处理器
@return: 响应
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult:
model = QuerySet(Model).filter(id=model_id).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
message_list = self.generate_message_list(system, prompt, history_chat_record, dialogue_number)
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'get_to_response_write_context': get_to_response_write_context}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list}, {},
_write_context=write_context, _to_response=to_response)
def generate_message_list(self, system: str, prompt: str, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result

View File

@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_search_dataset_node.py
@date2024/6/3 17:52
@desc:
"""
import re
from typing import Type
from django.core import validators
from rest_framework import serializers
from application.flow.i_step_node import INode, ReferenceAddressSerializer, NodeResult
from common.util.field_message import ErrMessage
class SearchDatasetStepNodeSerializer(serializers.Serializer):
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("数据集id列表"))
# 需要查询的条数
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer("引用分段数"))
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
error_messages=ErrMessage.float("引用分段数"))
search_mode = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
question_reference_address = ReferenceAddressSerializer(required=False,
error_messages=ErrMessage.char("问题应用地址"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
class ISearchDatasetStepNode(INode):
type = 'search-dataset-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return SearchDatasetStepNodeSerializer
def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address').get('node_id'),
self.node_params_serializer.data.get('question_reference_address').get('fields'))
return self.execute(**self.node_params_serializer.data, question=question, exclude_paragraph_id_list=[])
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_search_dataset_node.py
@date2024/6/4 11:56
@desc:
"""
import os
from typing import List, Dict
from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import EmbeddingModel, VectorStore
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph
from embedding.models import SearchMode
from smartdoc.conf import PROJECT_DIR
class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
return NodeResult({'paragraph_list': result,
'is_hit_handling_method_list': [row.get('is_hit_handling_method') for row in result]}, {})
@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List):
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return {
**paragraph,
'similarity': find_embedding.get('similarity'),
'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return'
}
@staticmethod
def list_paragraph(embedding_list: List, vector):
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
# 如果向量库中存在脏数据 直接删除
if len(paragraph_list) != len(paragraph_id_list):
exist_paragraph_list = [row.get('id') for row in paragraph_list]
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
return paragraph_list

View File

@ -0,0 +1,26 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_start_node.py
@date2024/6/3 16:54
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
class IStarNode(INode):
type = 'start-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None:
return None
def _run(self):
return self.execute(**self.flow_params_serializer.data)
def execute(self, question, **kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_start_node.py
@date2024/6/3 17:17
@desc:
"""
import time
from application.flow.i_step_node import NodeResult
from application.flow.step_node.start_node.i_start_node import IStarNode
class BaseStartStepNode(IStarNode):
def execute(self, question, **kwargs) -> NodeResult:
"""
开始节点 初始化全局变量
"""
return NodeResult({'question': question}, {'time': time.time()})

View File

@ -0,0 +1,79 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file utils.py
@date2024/6/6 15:15
@desc:
"""
import json
from typing import Iterator
from django.http import StreamingHttpResponse
from langchain_core.messages import BaseMessageChunk, BaseMessage
from application.flow.i_step_node import WorkFlowPostHandler
from common.response import result
def event_content(chat_id, chat_record_id, response, workflow,
write_context,
post_handler: WorkFlowPostHandler):
"""
用于处理流式输出
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param response: 响应数据
@param workflow: 工作流管理器
@param write_context 写入节点上下文
@param post_handler: 后置处理器
"""
answer = ''
for chunk in response:
answer += chunk.content
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n"
write_context(answer)
post_handler.handler(chat_id, chat_record_id, answer, workflow)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
post_handler):
"""
将结果转换为服务流输出
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param response: 响应数据
@param workflow: 工作流管理器
@param write_context 写入节点上下文
@param post_handler: 后置处理器
@return: 响应
"""
r = StreamingHttpResponse(
streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
post_handler: WorkFlowPostHandler):
"""
将结果转换为服务输出
@param chat_id: 会话id
@param chat_record_id: 对话记录id
@param response: 响应数据
@param workflow: 工作流管理器
@param write_context 写入节点上下文
@param post_handler: 后置处理器
@return: 响应
"""
answer = response.content
write_context(answer)
post_handler.handler(chat_id, chat_record_id, answer, workflow)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': answer, 'is_end': True})

View File

@ -0,0 +1,166 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file workflow_manage.py
@date2024/1/9 17:40
@desc:
"""
from typing import List, Dict
from langchain_core.prompts import PromptTemplate
from application.flow.i_step_node import INode, WorkFlowPostHandler
from application.flow.step_node.ai_chat_step_node.impl.base_chat_node import BaseChatNode
from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode
from application.flow.step_node.start_node.impl.base_start_node import BaseStartStepNode
class Edge:
def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
self.id = _id
self.type = _type
self.sourceNodeId = sourceNodeId
self.targetNodeId = targetNodeId
for keyword in keywords:
self.__setattr__(keyword, keywords.get(keyword))
class Node:
def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
self.id = _id
self.type = _type
self.x = x
self.y = y
self.properties = properties
for keyword in kwargs:
self.__setattr__(keyword, kwargs.get(keyword))
class Flow:
def __init__(self, nodes: List[Node], edges: List[Edge]):
self.nodes = nodes
self.edges = edges
@staticmethod
def new_instance(flow_obj: Dict):
nodes = flow_obj.get('nodes')
edges = flow_obj.get('edges')
nodes = [Node(node.get('id'), node.get('type'), **node)
for node in nodes]
edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
return Flow(nodes, edges)
flow_node_dict = {
'start-node': BaseStartStepNode,
'search-dataset-node': BaseSearchDatasetNode,
'chat-node': BaseChatNode
}
class WorkflowManage:
def __init__(self, flow: Flow, params):
self.params = params
self.flow = flow
self.context = {}
self.node_dict = {}
self.runtime_nodes = []
self.current_node = None
def run(self):
"""
运行工作流
"""
while self.has_next_node():
self.current_node = self.get_next_node()
self.node_dict[self.current_node.id] = self.current_node
result = self.current_node.run()
if self.has_next_node():
result.write_context(self.current_node, self)
else:
r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self,
WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None,
client_type='ss'))
for row in r:
print(row)
print(self)
def has_next_node(self):
"""
是否有下一个可运行的节点
"""
if self.current_node is None:
if self.get_start_node() is not None:
return True
else:
for edge in self.flow.edges:
if edge.sourceNodeId == self.current_node.id:
return True
return False
def get_runtime_details(self):
return {}
def get_next_node(self):
"""
获取下一个可运行的所有节点
"""
if self.current_node is None:
node = self.get_start_node()
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
self.params, self.context)
return node_instance
for edge in self.flow.edges:
if edge.sourceNodeId == self.current_node.id:
return self.get_node_cls_by_id(edge.targetNodeId)
return None
def get_reference_field(self, node_id: str, fields: List[str]):
"""
@param node_id: 节点id
@param fields: 字段
@return:
"""
if node_id == 'global':
return INode.get_field(self.context, fields)
else:
return self.get_node_by_id(node_id).get_reference_field(fields)
def generate_prompt(self, prompt: str):
"""
格式化生成提示词
@param prompt: 提示词信息
@return: 格式化后的提示词
"""
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
context = {
'global': self.context,
}
for key in self.node_dict:
context[key] = self.node_dict[key].context
value = prompt_template.format(context=context)
return value
def get_start_node(self):
"""
获取启动节点
@return:
"""
return self.flow.nodes[0]
def get_node_cls_by_id(self, node_id):
for node in self.flow.nodes:
if node.id == node_id:
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
self.params, self)
return node_instance
return None
def get_node_by_id(self, node_id):
return self.node_dict[node_id]
def get_node_reference(self, reference_address: Dict):
node = self.get_node_by_id(reference_address.get('node_id'))
return node.context[reference_address.get('node_field')]