mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 工作流程
This commit is contained in:
parent
f87919642b
commit
fcb33db41d
|
|
@ -0,0 +1,168 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_step_node.py
|
||||
@date:2024/6/3 14:57
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Type, Dict, List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.models import ChatRecord
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from application.serializers.application_serializers import chat_cache
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
    """
    Default context writer: copy a node's output into the runtime contexts.

    @param step_variable: values merged into ``node.context``; skipped when ``None``
    @param global_variable: values merged into ``workflow.context``; skipped when ``None``
    @param node: object exposing a mutable ``context`` mapping
    @param workflow: object exposing a mutable ``context`` mapping
    """
    if step_variable is not None:
        # Existing keys are overwritten, other keys are left untouched.
        node.context.update(step_variable)
    if global_variable is not None:
        workflow.context.update(global_variable)
|
||||
|
||||
|
||||
class WorkFlowPostHandler:
    """
    Hook that runs after a workflow has produced its final answer.

    It builds a ``ChatRecord`` for the exchange, appends it to the chat info,
    re-caches the chat info and, for access-token clients, increments the
    client's access counters.
    """

    def __init__(self, chat_info, client_id, client_type):
        self.client_type = client_type
        self.client_id = client_id
        self.chat_info = chat_info

    def handler(self, chat_id,
                chat_record_id,
                answer,
                workflow):
        """
        Record the finished exchange.

        @param chat_id: conversation id, also used as the cache key
        @param chat_record_id: id given to the new chat record
        @param answer: final answer text
        @param workflow: workflow manager; its ``params`` and ``context`` are read
        """
        # NOTE(review): the WorkflowManage class in this change set exposes
        # get_runtime_details(), not get_details(), and the chat node writes the
        # token counts to the node context - confirm these lookups resolve.
        question = workflow.params['question']
        details = workflow.get_details()
        flow_context = workflow.context
        message_tokens = flow_context['message_tokens']
        answer_tokens = flow_context['answer_tokens']
        run_time = flow_context['run_time']
        next_index = len(self.chat_info.chat_record_list) + 1
        chat_record = ChatRecord(id=chat_record_id,
                                 chat_id=chat_id,
                                 problem_text=question,
                                 answer_text=answer,
                                 details=details,
                                 message_tokens=message_tokens,
                                 answer_tokens=answer_tokens,
                                 run_time=run_time,
                                 index=next_index)
        self.chat_info.append_chat_record(chat_record, self.client_id)
        # Re-set the cache entry so it holds the updated chat info
        chat_cache.set(chat_id,
                       self.chat_info, timeout=60 * 30)
        if self.client_type != AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
            return
        access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first()
        if access_client is None:
            return
        access_client.access_num += 1
        access_client.intraday_access_num += 1
        access_client.save()
|
||||
|
||||
|
||||
class NodeResult:
    """
    Result of executing one node.

    Bundles the node-level and workflow-level variables with two pluggable
    callables: one that writes them into the contexts and one that turns them
    into the final response.
    """

    def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable
        self._to_response = _to_response
        self._write_context = _write_context

    def write_context(self, node, workflow):
        """
        Store this result in the node / workflow contexts using the configured writer.

        @param node: node that produced the result
        @param workflow: workflow manager
        """
        writer = self._write_context
        writer(self.node_variable, self.workflow_variable, node, workflow)

    def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
        """
        Convert this result into a response using the configured converter.

        NOTE(review): ``_to_response`` defaults to ``None``, so results created
        without a converter cannot be used here.

        @param chat_id: conversation id
        @param chat_record_id: chat record id
        @param node: node that produced the result
        @param workflow: workflow manager
        @param post_handler: handler forwarded to the converter
        @return: whatever the converter returns
        """
        converter = self._to_response
        return converter(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
                         post_handler)
|
||||
|
||||
|
||||
class ReferenceAddressSerializer(serializers.Serializer):
    """Address of a value produced by another node: a node id plus a field path."""
    # Id of the node whose context is referenced
    node_id = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("节点id"))
    # Field names followed one after another inside that node's context
    fields = serializers.ListField(
        required=True,
        child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")),
        error_messages=ErrMessage.list("节点字段数组"))
|
||||
|
||||
|
||||
class FlowParamsSerializer(serializers.Serializer):
    """Validates the parameters shared by every node of a workflow run."""
    # Previous question/answer records of the conversation
    history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
                                                error_messages=ErrMessage.list("历史对答"))

    # User question. A CharField, so it uses the char error messages like the
    # other CharFields below (it previously used the list messages).
    question = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))

    # Conversation id (same correction as for ``question``)
    chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话id"))

    # Id of the chat record being produced
    chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))

    # Whether the answer is streamed
    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))

    # Id of the calling client
    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))

    # Type of the calling client
    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
|
||||
class INode:
    """
    Base class of every workflow node.

    Validates the node and workflow parameters on construction, keeps a
    per-node ``context`` dict and times each execution in ``run``.
    """

    def __init__(self, _id, node_params, workflow_params, workflow_manage):
        self.id = _id
        self.node_params = node_params
        self.workflow_manage = workflow_manage
        # Context of this step; stores information produced by this step
        self.context = {}
        self.flow_params_serializer = None
        self.node_params_serializer = None
        self.valid_args(node_params, workflow_params)

    def valid_args(self, node_params, flow_params):
        """
        Build and validate the serializers for the workflow and node parameters.

        A serializer is skipped when the corresponding ``get_*_serializer_class``
        hook returns ``None``. Validation errors are raised by the serializer.

        @param node_params: parameters configured on this node
        @param flow_params: parameters of the whole workflow run
        """
        flow_cls = self.get_flow_params_serializer_class()
        node_cls = self.get_node_params_serializer_class()
        if flow_cls is not None:
            flow_serializer = flow_cls(data=flow_params)
            self.flow_params_serializer = flow_serializer
            flow_serializer.is_valid(raise_exception=True)
        if node_cls is not None:
            node_serializer = node_cls(data=node_params)
            self.node_params_serializer = node_serializer
            node_serializer.is_valid(raise_exception=True)

    def get_reference_field(self, fields: List[str]):
        """
        Look up a value in this node's context by following ``fields``.

        @param fields: keys to follow, outermost first
        @return: the value found, or ``None`` when any step is missing
        """
        return self.get_field(self.context, fields)

    @staticmethod
    def get_field(obj, fields: List[str]):
        """
        Follow ``fields`` through nested mappings starting at ``obj``.

        @param obj: mapping to start from
        @param fields: keys to follow, outermost first
        @return: the final value, ``None`` as soon as a key yields ``None``,
                 or ``obj`` itself when ``fields`` is empty
        """
        current = obj
        for name in fields:
            current = current.get(name)
            if current is None:
                return None
        return current

    # NOTE(review): INode does not derive from abc.ABC, so this decorator is
    # not enforced at instantiation time.
    @abstractmethod
    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        pass

    def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate the workflow parameters."""
        return FlowParamsSerializer

    def run(self) -> NodeResult:
        """
        Execute the node and record ``start_time`` and ``run_time`` in its context.

        :return: execution result
        """
        started_at = time.time()
        self.context['start_time'] = started_at
        node_result = self._run()
        self.context['run_time'] = time.time() - started_at
        return node_result

    def _run(self):
        """Call ``execute``; subclasses override this to supply the arguments."""
        return self.execute()

    def execute(self, **kwargs) -> NodeResult:
        """Node logic; the base implementation does nothing and returns ``None``."""
        pass

    def get_details(self, **kwargs):
        """
        Runtime details of this step.

        :return: step details; the base implementation returns ``None``
        """
        return None
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_node.py
|
||||
@date:2024/6/4 13:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ChatNodeSerializer(serializers.Serializer):
    """Validates the parameters configured on an AI chat node."""
    # Id of the model to use
    model_id = serializers.CharField(
        required=True, error_messages=ErrMessage.char("模型id"))
    # System / role prompt
    system = serializers.CharField(
        required=True, error_messages=ErrMessage.char("角色设定"))
    # User prompt
    prompt = serializers.CharField(
        required=True, error_messages=ErrMessage.char("提示词"))
    # Number of multi-turn dialogue rounds
    dialogue_number = serializers.IntegerField(
        required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
|
||||
class IChatNode(INode):
    """Interface of the AI chat node; concrete classes implement ``execute``."""
    type = 'ai-chat-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class validating this node's parameters."""
        return ChatNodeSerializer

    def _run(self):
        """Call ``execute`` with the validated node and workflow parameters as keyword arguments."""
        node_kwargs = self.node_params_serializer.data
        flow_kwargs = self.flow_params_serializer.data
        return self.execute(**node_kwargs, **flow_kwargs)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        """
        Run the chat node; this interface implementation does nothing.

        @param model_id: id of the model to use
        @param system: system / role prompt
        @param prompt: user prompt
        @param dialogue_number: number of multi-turn dialogue rounds
        @param history_chat_record: previous chat records
        @param stream: whether the answer is streamed
        @param chat_id: conversation id
        @param chat_record_id: chat record id
        @param kwargs: remaining validated parameters
        """
        pass
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_chat_node.py
|
||||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data (streaming).

    Consumes the chunk iterator stored under ``result``, joins the chunk
    contents into the answer and stores the answer and both token counts in the
    node context.

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node
    @param workflow: workflow manager
    """
    chunks = node_variable.get('result')
    answer = ''.join(chunk.content for chunk in chunks)
    chat_model = node_variable.get('chat_model')
    prompt_token_count = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_token_count = chat_model.get_num_tokens(answer)
    node.context.update({
        'message_tokens': prompt_token_count,
        'answer_tokens': answer_token_count,
        'answer': answer,
    })
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data.

    Reads the answer from the message stored under ``result`` and stores the
    answer and both token counts in the node context.

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node instance
    @param workflow: workflow manager
    """
    message = node_variable.get('result')
    chat_model = node_variable.get('chat_model')
    answer = message.content
    prompt_token_count = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_token_count = chat_model.get_num_tokens(answer)
    node.context.update({
        'message_tokens': prompt_token_count,
        'answer_tokens': answer_token_count,
        'answer': answer,
    })
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
    """
    Build a callback that records the final answer in the node context.

    @param node_variable: node data holding ``chat_model`` and ``message_list``
    @param node: node whose context receives the answer and token counts
    @return: function taking the answer text
    """

    def _write_context(answer):
        chat_model = node_variable.get('chat_model')
        prompt_token_count = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
        answer_token_count = chat_model.get_num_tokens(answer)
        node.context.update({
            'message_tokens': prompt_token_count,
            'answer_tokens': answer_token_count,
            'answer': answer,
        })

    return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """
    Convert streamed data into a streaming response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node data
    @param workflow_variable: workflow data
    @param node: node
    @param workflow: workflow manager
    @param post_handler: post handler, executed after the result has been output
    @return: streaming response
    """
    answer_writer = get_to_response_write_context(node_variable, node)
    return tools.to_stream_response(chat_id, chat_record_id, node_variable.get('result'), workflow,
                                    answer_writer, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """
    Convert the result into a response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node data
    @param workflow_variable: workflow data
    @param node: node
    @param workflow: workflow manager
    @param post_handler: post handler
    @return: response
    """
    answer_writer = get_to_response_write_context(node_variable, node)
    return tools.to_response(chat_id, chat_record_id, node_variable.get('result'), workflow,
                             answer_writer, post_handler)
|
||||
|
||||
|
||||
class BaseChatNode(IChatNode):
    """Chat node implementation that sends the prompt and history to the configured model."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        """
        Query the chat model.

        @param model_id: id of the ``Model`` row describing the model
        @param system: system / role prompt template
        @param prompt: user prompt template
        @param dialogue_number: number of most recent history records to include
        @param history_chat_record: previous chat records
        @param stream: when true the model is streamed, otherwise invoked once
        @param chat_id: conversation id
        @param chat_record_id: chat record id
        @return: NodeResult carrying the model output, the model and the message list
        """
        model = QuerySet(Model).filter(id=model_id).first()
        # The stored credential is decrypted and parsed as JSON before use.
        chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
                                                                           json.loads(
                                                                               rsa_long_decrypt(model.credential)),
                                                                           streaming=True)
        message_list = self.generate_message_list(system, prompt, history_chat_record, dialogue_number)
        if stream:
            r = chat_model.stream(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                               'get_to_response_write_context': get_to_response_write_context}, {},
                              _write_context=write_context_stream,
                              _to_response=to_stream_response)
        else:
            r = chat_model.invoke(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list}, {},
                              _write_context=write_context, _to_response=to_response)

    def generate_message_list(self, system: str, prompt: str, history_chat_record, dialogue_number):
        """
        Build the flat message list sent to the model.

        @param system: system prompt template
        @param prompt: user prompt template
        @param history_chat_record: previous chat records
        @param dialogue_number: number of most recent records to include
        @return: [system message, human/ai message of each kept record..., human message]
        """
        start_index = max(len(history_chat_record) - dialogue_number, 0)
        # Flatten the (human, ai) pairs: the model expects a flat sequence of
        # messages, not one nested list per history record.
        history_message = []
        for index in range(start_index, len(history_chat_record)):
            record = history_chat_record[index]
            history_message.append(record.get_human_message())
            history_message.append(record.get_ai_message())
        return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                HumanMessage(self.workflow_manage.generate_prompt(prompt))]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """
        Convert messages into ``{'role', 'content'}`` dicts and append the answer.

        @param message_list: messages sent to the model
        @param answer_text: answer text, appended with role ``'ai'``
        @return: list of dicts; human messages get role ``'user'``, all others ``'ai'``
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_search_dataset_node.py
|
||||
@date:2024/6/3 17:52
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from django.core import validators
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, ReferenceAddressSerializer, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class SearchDatasetStepNodeSerializer(serializers.Serializer):
    """Validates the parameters configured on a dataset search node."""
    # Ids of the datasets to search
    dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                            error_messages=ErrMessage.list("数据集id列表"))
    # Number of results to retrieve
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Similarity threshold.
    # NOTE(review): the original comment said "between 0 and 1" while the
    # validator accepts up to 2 - confirm the intended upper bound.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # The alternatives are grouped so the anchors apply to every option; the
    # ungrouped pattern matched e.g. any string merely containing "keywords".
    search_mode = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                                  message="类型只支持embedding|keywords|blend", code=500)
    ], error_messages=ErrMessage.char("检索模式"))

    # Optional address of the value used as the search question
    question_reference_address = ReferenceAddressSerializer(required=False,
                                                            error_messages=ErrMessage.char("问题应用地址"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the data, always raising on failure.

        @param raise_exception: accepted for signature compatibility; ignored
        @return: the result of the parent ``is_valid``
        """
        return super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class ISearchDatasetStepNode(INode):
    """Interface of the dataset search node; concrete classes implement ``execute``."""
    type = 'search-dataset-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class validating this node's parameters."""
        return SearchDatasetStepNodeSerializer

    def _run(self):
        """
        Resolve the question from its reference address and call ``execute``.

        @raise ValueError: when ``question_reference_address`` is not configured
        """
        node_params = self.node_params_serializer.data
        reference_address = node_params.get('question_reference_address')
        if reference_address is None:
            # The serializer treats the address as optional, but the question
            # cannot be resolved without it; fail with an explicit message
            # instead of an AttributeError on None.
            raise ValueError("question_reference_address is required to run a search-dataset-node")
        question = self.workflow_manage.get_reference_field(
            reference_address.get('node_id'),
            reference_address.get('fields'))
        return self.execute(**node_params, question=question, exclude_paragraph_id_list=[])

    def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
                exclude_paragraph_id_list=None,
                **kwargs) -> NodeResult:
        """
        Run the search; this interface implementation does nothing.

        @param dataset_id_list: ids of the datasets to search
        @param top_n: number of results to retrieve
        @param similarity: similarity threshold
        @param search_mode: one of embedding, keywords, blend
        @param question_reference_address: address the question was resolved from
        @param question: resolved question
        @param exclude_paragraph_id_list: paragraph ids to exclude
        """
        pass
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_search_dataset_node.py
|
||||
@date:2024/6/4 11:56
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import EmbeddingModel, VectorStore
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Document, Paragraph
|
||||
from embedding.models import SearchMode
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class BaseSearchDatasetNode(ISearchDatasetStepNode):
    """Dataset search node backed by the embedding model and the vector store."""

    def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
                exclude_paragraph_id_list=None,
                **kwargs) -> NodeResult:
        """
        Search the datasets for paragraphs matching the question.

        @param dataset_id_list: ids of the datasets to search
        @param top_n: number of results requested from the vector store
        @param similarity: similarity threshold passed to the vector store
        @param search_mode: value converted to ``SearchMode``
        @param question_reference_address: address the question was resolved from (unused here)
        @param question: question text
        @param exclude_paragraph_id_list: paragraph ids to exclude
        @return: NodeResult with ``paragraph_list`` and ``is_hit_handling_method_list``
        """
        embedding_model = EmbeddingModel.get_embedding_model()
        embedding_value = embedding_model.embed_query(question)
        vector = VectorStore.get_embedding_vector()
        # Documents that are not active are excluded from the search.
        exclude_document_id_list = [str(document.id) for document in
                                    QuerySet(Document).filter(
                                        dataset_id__in=dataset_id_list,
                                        is_active=False)]
        embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
                                      exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
        if embedding_list is None:
            # Same keys as the normal path so downstream references behave the same.
            return NodeResult({'paragraph_list': [], 'is_hit_handling_method_list': []}, {})
        paragraph_list = self.list_paragraph(embedding_list, vector)
        # reset_paragraph returns None when no embedding matches the paragraph;
        # skip those rows so the ``row.get`` below cannot fail on None.
        candidates = (self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list)
        result = [row for row in candidates if row is not None]
        return NodeResult({'paragraph_list': result,
                           'is_hit_handling_method_list': [row.get('is_hit_handling_method') for row in result]}, {})

    @staticmethod
    def reset_paragraph(paragraph: Dict, embedding_list: List):
        """
        Attach similarity information to a paragraph.

        @param paragraph: paragraph row
        @param embedding_list: search hits, each with ``paragraph_id`` and ``similarity``
        @return: the paragraph extended with ``similarity`` and
                 ``is_hit_handling_method``, or ``None`` when no hit refers to it
        """
        paragraph_id = str(paragraph.get('id'))
        filter_embedding_list = [embedding for embedding in embedding_list if
                                 str(embedding.get('paragraph_id')) == paragraph_id]
        if not filter_embedding_list:
            return None
        # When several hits refer to the paragraph the last one is used.
        find_embedding = filter_embedding_list[-1]
        return {
            **paragraph,
            'similarity': find_embedding.get('similarity'),
            'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
                'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return'
        }

    @staticmethod
    def list_paragraph(embedding_list: List, vector):
        """
        Load the paragraph rows referenced by the search hits.

        Hits whose paragraph no longer exists are removed from the vector store.

        @param embedding_list: search hits, each with ``paragraph_id``
        @param vector: vector store, used to delete stale entries
        @return: paragraph rows returned by the SQL query
        """
        paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
        if not paragraph_id_list:
            return []
        paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                                       get_file_content(
                                           os.path.join(PROJECT_DIR, "apps", "application", 'sql',
                                                        'list_dataset_paragraph_by_paragraph_id.sql')),
                                       with_table_name=True)
        # If the vector store contains dirty data, delete it directly
        if len(paragraph_list) != len(paragraph_id_list):
            # Compare string forms, as reset_paragraph does, so a differing id
            # type alone never marks an existing paragraph as stale.
            exist_paragraph_ids = {str(row.get('id')) for row in paragraph_list}
            for paragraph_id in paragraph_id_list:
                if str(paragraph_id) not in exist_paragraph_ids:
                    vector.delete_by_paragraph_id(paragraph_id)
        return paragraph_list
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_start_node.py
|
||||
@date:2024/6/3 16:54
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
|
||||
class IStarNode(INode):
    """Interface of the start node; it has no node-level parameters."""
    type = 'start-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None:
        """Return ``None``: there are no node parameters to validate."""
        return None

    def _run(self):
        """Call ``execute`` with the validated workflow parameters as keyword arguments."""
        flow_kwargs = self.flow_params_serializer.data
        return self.execute(**flow_kwargs)

    def execute(self, question, **kwargs) -> NodeResult:
        """
        Run the start node; this interface implementation does nothing.

        @param question: user question
        @param kwargs: remaining workflow parameters
        """
        pass
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_start_node.py
|
||||
@date:2024/6/3 17:17
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.start_node.i_start_node import IStarNode
|
||||
|
||||
|
||||
class BaseStartStepNode(IStarNode):
    """Start node implementation."""

    def execute(self, question, **kwargs) -> NodeResult:
        """
        Start node: initialise the global variables.

        @param question: user question, stored as the node variable ``question``
        @return: NodeResult whose workflow variables hold the current ``time``
        """
        node_variable = {'question': question}
        workflow_variable = {'time': time.time()}
        return NodeResult(node_variable, workflow_variable)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: utils.py
|
||||
@date:2024/6/6 15:15
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from typing import Iterator
|
||||
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain_core.messages import BaseMessageChunk, BaseMessage
|
||||
|
||||
from application.flow.i_step_node import WorkFlowPostHandler
|
||||
from common.response import result
|
||||
|
||||
|
||||
def event_content(chat_id, chat_record_id, response, workflow,
                  write_context,
                  post_handler: WorkFlowPostHandler):
    """
    Generator used for streaming output.

    Yields one ``data:`` event per chunk, then writes the full answer to the
    node context, runs the post handler and yields a closing event.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: response data (iterable of chunks with ``content``)
    @param workflow: workflow manager
    @param write_context: writes the answer into the node context
    @param post_handler: post handler
    """

    def _event(content, is_end):
        # One server-sent event carrying a JSON payload.
        return 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                      'content': content, 'is_end': is_end}) + "\n\n"

    pieces = []
    for chunk in response:
        pieces.append(chunk.content)
        yield _event(chunk.content, False)
    answer = ''.join(pieces)
    write_context(answer)
    post_handler.handler(chat_id, chat_record_id, answer, workflow)
    yield _event('', True)
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
                       post_handler):
    """
    Convert the result into a streaming HTTP response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: response data
    @param workflow: workflow manager
    @param write_context: writes the answer into the node context
    @param post_handler: post handler
    @return: response
    """
    events = event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler)
    stream_response = StreamingHttpResponse(
        streaming_content=events,
        content_type='text/event-stream;charset=utf-8')
    stream_response['Cache-Control'] = 'no-cache'
    return stream_response
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
                post_handler: WorkFlowPostHandler):
    """
    Convert the result into a regular response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: response data
    @param workflow: workflow manager
    @param write_context: writes the answer into the node context
    @param post_handler: post handler
    @return: response
    """
    answer = response.content
    write_context(answer)
    post_handler.handler(chat_id, chat_record_id, answer, workflow)
    payload = {'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
               'content': answer, 'is_end': True}
    return result.success(payload)
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: workflow_manage.py
|
||||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from application.flow.i_step_node import INode, WorkFlowPostHandler
|
||||
from application.flow.step_node.ai_chat_step_node.impl.base_chat_node import BaseChatNode
|
||||
from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode
|
||||
from application.flow.step_node.start_node.impl.base_start_node import BaseStartStepNode
|
||||
|
||||
|
||||
class Edge:
    """
    A directed connection between two nodes of a flow.

    Any additional keyword argument is stored as an attribute of the same name.
    """

    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
        self.id = _id
        self.type = _type
        self.sourceNodeId = sourceNodeId
        self.targetNodeId = targetNodeId
        # Extra keys are applied last, so keys named like a fixed attribute
        # (e.g. ``id``) overwrite it.
        for keyword, value in keywords.items():
            setattr(self, keyword, value)
|
||||
|
||||
|
||||
class Node:
    """
    A node of a flow definition: id, type, position and properties.

    Any additional keyword argument is stored as an attribute of the same name.
    """

    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
        self.id = _id
        self.type = _type
        self.x = x
        self.y = y
        self.properties = properties
        # Extra keys are applied last, so keys named like a fixed attribute
        # (e.g. ``id``) overwrite it.
        for keyword, value in kwargs.items():
            setattr(self, keyword, value)
|
||||
|
||||
|
||||
class Flow:
    """A flow definition: its nodes and the edges connecting them."""

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def new_instance(flow_obj: Dict):
        """
        Build a Flow from its dict form.

        @param flow_obj: dict with ``nodes`` and ``edges`` lists of dicts
        @return: Flow with ``Node`` / ``Edge`` objects
        """
        node_list = []
        for raw_node in flow_obj.get('nodes'):
            # The whole dict is also passed as keyword arguments.
            node_list.append(Node(raw_node.get('id'), raw_node.get('type'), **raw_node))
        edge_list = []
        for raw_edge in flow_obj.get('edges'):
            edge_list.append(Edge(raw_edge.get('id'), raw_edge.get('type'), **raw_edge))
        return Flow(node_list, edge_list)
|
||||
|
||||
|
||||
# Registry mapping a flow node ``type`` string to the class implementing it.
flow_node_dict = {
    'start-node': BaseStartStepNode,
    'search-dataset-node': BaseSearchDatasetNode,
    # The chat node declares ``type = 'ai-chat-node'``; register it under that
    # name so flows using the declared type resolve.
    'ai-chat-node': BaseChatNode,
    # Previous key, kept so existing flow definitions keep working.
    'chat-node': BaseChatNode
}
|
||||
|
||||
|
||||
class WorkflowManage:
    """
    Runs a flow node by node.

    Holds the workflow-level ``context``, the instantiated nodes keyed by id
    (``node_dict``) and the node currently being executed.
    """

    def __init__(self, flow: Flow, params):
        self.params = params
        self.flow = flow
        self.context = {}
        self.node_dict = {}
        self.runtime_nodes = []
        self.current_node = None

    def run(self):
        """
        Run the workflow.

        Each node is executed in turn. Intermediate results are written to the
        contexts; the result of the last node is converted to a response.
        """
        while self.has_next_node():
            self.current_node = self.get_next_node()
            self.node_dict[self.current_node.id] = self.current_node
            result = self.current_node.run()
            if self.has_next_node():
                result.write_context(self.current_node, self)
            else:
                # NOTE(review): chat_info=None and client_type='ss' look like
                # placeholders, and the response is only printed, not returned
                # to the caller - confirm before relying on this path.
                r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self,
                                       WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None,
                                                           client_type='ss'))
                for row in r:
                    print(row)
                print(self)

    def has_next_node(self):
        """
        Whether there is another node to run.

        Before the first node runs this is true when a start node exists;
        afterwards it is true when an edge leaves the current node.
        """
        if self.current_node is None:
            if self.get_start_node() is not None:
                return True
        else:
            for edge in self.flow.edges:
                if edge.sourceNodeId == self.current_node.id:
                    return True
        return False

    def get_runtime_details(self):
        """Runtime details of the workflow; currently always an empty dict."""
        return {}

    def get_next_node(self):
        """
        Instantiate the next node to run.

        @return: the start node when nothing has run yet, otherwise the target
                 of the first edge leaving the current node; ``None`` when
                 there is none
        """
        if self.current_node is None:
            node = self.get_start_node()
            if node is None:
                return None
            # Built like every other node, so it receives this manager (not
            # the bare context dict) as ``workflow_manage``.
            return self._new_node_instance(node)
        for edge in self.flow.edges:
            if edge.sourceNodeId == self.current_node.id:
                return self.get_node_cls_by_id(edge.targetNodeId)
        return None

    def get_reference_field(self, node_id: str, fields: List[str]):
        """
        Resolve a referenced value.

        @param node_id: node id, or ``'global'`` for the workflow context
        @param fields: keys to follow, outermost first
        @return: the referenced value, or ``None`` when a key is missing
        """
        if node_id == 'global':
            return INode.get_field(self.context, fields)
        else:
            return self.get_node_by_id(node_id).get_reference_field(fields)

    def generate_prompt(self, prompt: str):
        """
        Render a prompt template.

        The template is formatted as jinja2 with a single variable ``context``
        holding the workflow context under ``global`` and each executed node's
        context under its node id.

        @param prompt: prompt template
        @return: rendered prompt
        """
        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
        context = {
            'global': self.context,
        }
        for key in self.node_dict:
            context[key] = self.node_dict[key].context
        value = prompt_template.format(context=context)
        return value

    def get_start_node(self):
        """
        Get the start node.

        @return: the first node of the flow, or ``None`` when the flow has no nodes
        """
        nodes = self.flow.nodes
        return nodes[0] if nodes else None

    def get_node_cls_by_id(self, node_id):
        """
        Instantiate the node with the given id.

        @param node_id: id of a node of the flow definition
        @return: new node instance, or ``None`` when the id is unknown
        """
        for node in self.flow.nodes:
            if node.id == node_id:
                return self._new_node_instance(node)
        return None

    def _new_node_instance(self, node):
        """Create the runtime node for a flow-definition node via ``flow_node_dict``."""
        return flow_node_dict[node.type](node.id, node.properties.get('node_data'),
                                         self.params, self)

    def get_node_by_id(self, node_id):
        """
        Return the already-instantiated node with this id.

        @raise KeyError: when the node has not been run yet
        """
        return self.node_dict[node_id]

    def get_node_reference(self, reference_address: Dict):
        """
        Resolve a reference address against an executed node's context.

        @param reference_address: dict with ``node_id`` and either ``node_field``
               (a single context key) or ``fields`` (a key path)
        @return: the referenced value
        """
        node = self.get_node_by_id(reference_address.get('node_id'))
        fields = reference_address.get('fields')
        if fields is not None and 'node_field' not in reference_address:
            # Address in the node_id/fields shape used by the reference serializer.
            return node.get_reference_field(fields)
        return node.context[reference_address.get('node_field')]
|
||||
Loading…
Reference in New Issue