From fcb33db41d1b4a747ef68e4bed512a8a88a19685 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 6 Jun 2024 19:27:02 +0800 Subject: [PATCH] =?UTF-8?q?=20feat:=20=E5=B7=A5=E4=BD=9C=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 168 ++++++++++++++++++ .../ai_chat_step_node/i_chat_node.py | 36 ++++ .../ai_chat_step_node/impl/base_chat_node.py | 145 +++++++++++++++ .../i_search_dataset_node.py | 56 ++++++ .../impl/base_search_dataset_node.py | 73 ++++++++ .../flow/step_node/start_node/i_start_node.py | 26 +++ .../start_node/impl/base_start_node.py | 20 +++ apps/application/flow/tools.py | 79 ++++++++ apps/application/flow/workflow_manage.py | 166 +++++++++++++++++ 9 files changed, 769 insertions(+) create mode 100644 apps/application/flow/i_step_node.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py create mode 100644 apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py create mode 100644 apps/application/flow/step_node/start_node/i_start_node.py create mode 100644 apps/application/flow/step_node/start_node/impl/base_start_node.py create mode 100644 apps/application/flow/tools.py create mode 100644 apps/application/flow/workflow_manage.py diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py new file mode 100644 index 000000000..96b1b7c24 --- /dev/null +++ b/apps/application/flow/i_step_node.py @@ -0,0 +1,168 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_step_node.py + @date:2024/6/3 14:57 + @desc: +""" +import time +from abc import abstractmethod +from typing import Type, Dict, List + +from django.db.models import QuerySet +from rest_framework import 
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
    """Store a node's output in the node-level and workflow-level contexts.

    :param step_variable: values copied into ``node.context``; skipped when ``None``
    :param global_variable: values copied into ``workflow.context``; skipped when ``None``
    :param node: object exposing a mutable ``context`` mapping
    :param workflow: object exposing a mutable ``context`` mapping
    """
    if step_variable is not None:
        node.context.update(step_variable)
    if global_variable is not None:
        workflow.context.update(global_variable)


class WorkFlowPostHandler:
    """Persist a finished workflow answer as a chat record and bump access counters."""

    def __init__(self, chat_info, client_id, client_type):
        self.chat_info = chat_info
        self.client_id = client_id
        self.client_type = client_type

    @staticmethod
    def _context_value(workflow, key):
        """Return a statistic from the workflow context, or the sum over executed nodes.

        The chat node and ``INode.run`` write ``message_tokens``, ``answer_tokens`` and
        ``run_time`` to the *node* context, so a plain ``workflow.context[key]`` lookup
        raises ``KeyError`` unless something copied the value up first.
        """
        if key in workflow.context:
            return workflow.context[key]
        executed_nodes = getattr(workflow, 'node_dict', {}).values()
        return sum(node.context.get(key, 0) for node in executed_nodes)

    def handler(self, chat_id,
                chat_record_id,
                answer,
                workflow):
        """Record the answer, refresh the chat cache and update access statistics.

        :param chat_id: conversation id (also the cache key)
        :param chat_record_id: id given to the new chat record
        :param answer: final answer text
        :param workflow: workflow manager; ``params``, ``context`` and the details
                         accessor are read from it
        """
        question = workflow.params['question']
        # The workflow manager added in this change set exposes get_runtime_details();
        # accept either accessor name so this handler works with both.
        get_details = getattr(workflow, 'get_details', None) or workflow.get_runtime_details
        chat_record = ChatRecord(id=chat_record_id,
                                 chat_id=chat_id,
                                 problem_text=question,
                                 answer_text=answer,
                                 details=get_details(),
                                 message_tokens=self._context_value(workflow, 'message_tokens'),
                                 answer_tokens=self._context_value(workflow, 'answer_tokens'),
                                 run_time=self._context_value(workflow, 'run_time'),
                                 index=len(self.chat_info.chat_record_list) + 1)
        self.chat_info.append_chat_record(chat_record, self.client_id)
        # Re-set the cache entry so the appended record is kept for another 30 minutes
        chat_cache.set(chat_id,
                       self.chat_info, timeout=60 * 30)
        if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
            application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(
                id=self.client_id).first()
            if application_public_access_client is not None:
                # NOTE(review): read-modify-write; concurrent requests can lose an increment
                application_public_access_client.access_num = application_public_access_client.access_num + 1
                application_public_access_client.intraday_access_num = \
                    application_public_access_client.intraday_access_num + 1
                application_public_access_client.save()
class NodeResult:
    """Outcome of one node run plus the strategies used to consume it."""

    def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
        """
        :param node_variable: values produced for the node context
        :param workflow_variable: values produced for the workflow context
        :param _to_response: callable turning the result into a response; only needed
                             when the node is the last one of the workflow
        :param _write_context: callable storing the variables in the contexts
        """
        self._write_context = _write_context
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable
        self._to_response = _to_response

    def write_context(self, node, workflow):
        """Store this result in the node / workflow contexts."""
        self._write_context(self.node_variable, self.workflow_variable, node, workflow)

    def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
        """Convert this result into the response returned to the client.

        :raises TypeError: when the result was created without a ``_to_response`` converter
        """
        if self._to_response is None:
            # Same exception type as calling None, but with an actionable message
            raise TypeError(f"{type(node).__name__} produced a NodeResult without a _to_response converter")
        return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
                                 post_handler)


class ReferenceAddressSerializer(serializers.Serializer):
    """Address of a value produced earlier: a node id plus the field path inside its context."""
    node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
    fields = serializers.ListField(
        child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
        error_messages=ErrMessage.list("节点字段数组"))


class FlowParamsSerializer(serializers.Serializer):
    """Workflow-level parameters every node validates before running."""
    # Previous question/answer records of the conversation
    history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
                                                error_messages=ErrMessage.list("历史对答"))

    # CharFields use the char message set (they were wired to ErrMessage.list)
    question = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))

    chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话id"))

    chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))

    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))

    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))

    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
class INode:
    """Base class of every workflow node: validates parameters and times execution."""

    def __init__(self, _id, node_params, workflow_params, workflow_manage):
        """
        :param _id: node id inside the flow
        :param node_params: node-specific configuration, validated by the node serializer
        :param workflow_params: workflow-level parameters, validated by the flow serializer
        :param workflow_manage: owning workflow manager
        :raises: whatever ``is_valid(raise_exception=True)`` raises for invalid parameters
        """
        self.node_params = node_params
        self.workflow_manage = workflow_manage
        self.node_params_serializer = None
        self.flow_params_serializer = None
        # Per-node context holding the values this step produced
        self.context = {}
        self.id = _id
        self.valid_args(node_params, workflow_params)

    def valid_args(self, node_params, flow_params):
        """Build and validate the flow / node serializers (each is optional)."""
        flow_params_serializer_class = self.get_flow_params_serializer_class()
        node_params_serializer_class = self.get_node_params_serializer_class()
        if flow_params_serializer_class is not None:
            self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
            self.flow_params_serializer.is_valid(raise_exception=True)
        if node_params_serializer_class is not None:
            self.node_params_serializer = node_params_serializer_class(data=node_params)
            self.node_params_serializer.is_valid(raise_exception=True)

    def get_reference_field(self, fields: List[str]):
        """Resolve a field path inside this node's context."""
        return self.get_field(self.context, fields)

    @staticmethod
    def get_field(obj, fields: List[str]):
        """Walk ``fields`` through nested mappings.

        :param obj: mapping to start from
        :param fields: keys to follow in order; an empty list returns ``obj`` itself
        :return: the addressed value, or ``None`` when a key is missing, its value is
                 ``None``, or the path continues into a value that is not a mapping
        """
        for field in fields:
            getter = getattr(obj, 'get', None)
            if getter is None:
                # The path goes deeper than the data (e.g. into a str); treat as "not found"
                return None
            obj = getter(field)
            if obj is None:
                return None
        return obj

    @abstractmethod
    def get_node_params_serializer_class(self) -> "Type[serializers.Serializer]":
        """Return the serializer class validating ``node_params`` (``None`` to skip)."""
        pass

    def get_flow_params_serializer_class(self) -> "Type[serializers.Serializer]":
        """Return the serializer class validating the workflow-level parameters."""
        return FlowParamsSerializer

    def run(self) -> "NodeResult":
        """Execute the node and record ``start_time`` / ``run_time`` in its context.

        :return: the execution result
        """
        start_time = time.time()
        self.context['start_time'] = start_time
        result = self._run()
        self.context['run_time'] = time.time() - start_time
        return result

    def _run(self):
        """Hook for subclasses to pass validated parameters to ``execute``."""
        result = self.execute()
        return result

    def execute(self, **kwargs) -> "NodeResult":
        """Node logic; implemented by concrete nodes."""
        pass

    def get_details(self, **kwargs):
        """Return the run details of this step (none by default)."""
        return None
class ChatNodeSerializer(serializers.Serializer):
    # Validates the configuration of an AI chat node.
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
    system = serializers.CharField(required=True, error_messages=ErrMessage.char("角色设定"))
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
    # Number of previous dialogue rounds to include
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))


class IChatNode(INode):
    # Interface of the AI chat node; concrete nodes implement execute().
    type = 'ai-chat-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer used to validate this node's parameters."""
        return ChatNodeSerializer

    def _run(self):
        """Call execute() with the validated node parameters and workflow parameters."""
        # Both serializers' data are splatted as keywords; a key present in both raises TypeError
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        """Run the chat step; this interface method has no implementation."""
        pass
def _store_answer(node_variable: Dict, node, answer):
    """Write the answer and its token statistics into ``node.context``.

    Shared by the streaming, blocking and response-time writers, which previously
    each carried their own copy of this code.
    """
    chat_model = node_variable.get('chat_model')
    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_tokens = chat_model.get_num_tokens(answer)
    node.context['message_tokens'] = message_tokens
    node.context['answer_tokens'] = answer_tokens
    node.context['answer'] = answer


def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: "INode", workflow):
    """Write context data for a streamed result.

    Consumes the whole chunk iterator stored under ``result`` to build the answer.

    @param node_variable: node data (``result``, ``chat_model``, ``message_list``)
    @param workflow_variable: workflow data (unused)
    @param node: node whose context is written
    @param workflow: workflow manager (unused)
    """
    response = node_variable.get('result')
    answer = ''.join(chunk.content for chunk in response)
    _store_answer(node_variable, node, answer)


def write_context(node_variable: Dict, workflow_variable: Dict, node: "INode", workflow):
    """Write context data for a blocking (non-streamed) result.

    @param node_variable: node data (``result``, ``chat_model``, ``message_list``)
    @param workflow_variable: workflow data (unused)
    @param node: node instance whose context is written
    @param workflow: workflow manager (unused)
    """
    response = node_variable.get('result')
    _store_answer(node_variable, node, response.content)


def get_to_response_write_context(node_variable: Dict, node: "INode"):
    """Return a callback ``f(answer)`` that stores the final answer in ``node.context``."""

    def _write_context(answer):
        _store_answer(node_variable, node, answer)

    return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """Convert a streamed result into a streaming response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node data
    @param workflow_variable: workflow data
    @param node: node
    @param workflow: workflow manager
    @param post_handler: post handler, executed after the output has been produced
    @return: streaming response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)


def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """Convert a blocking result into a response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node data
    @param workflow_variable: workflow data
    @param node: node
    @param workflow: workflow manager
    @param post_handler: post handler
    @return: response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)


class BaseChatNode(IChatNode):
    """AI chat node: builds the message list and calls the configured chat model."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        """Call the chat model in streaming or blocking mode.

        :param model_id: id of the ``Model`` row describing the chat model
        :param system: system prompt template
        :param prompt: user prompt template
        :param dialogue_number: number of history rounds to include
        :param history_chat_record: previous chat records, oldest first
        :param stream: stream the answer when true
        :return: node result carrying the raw model result, the model and the message list
        """
        # NOTE(review): .first() returns None for an unknown model_id, so the next line
        # then fails with AttributeError - confirm how callers want that reported
        model = QuerySet(Model).filter(id=model_id).first()
        chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
                                                                           json.loads(
                                                                               rsa_long_decrypt(model.credential)),
                                                                           streaming=True)
        message_list = self.generate_message_list(system, prompt, history_chat_record, dialogue_number)
        if stream:
            r = chat_model.stream(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                               'get_to_response_write_context': get_to_response_write_context}, {},
                              _write_context=write_context_stream,
                              _to_response=to_stream_response)
        else:
            r = chat_model.invoke(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list}, {},
                              _write_context=write_context, _to_response=to_response)

    def generate_message_list(self, system: str, prompt: str, history_chat_record, dialogue_number):
        """Build ``[system, *history, human]`` as one flat list of messages.

        The history used to be spliced in as ``[human, ai]`` pairs, which put nested
        lists - not messages - into the list handed to the chat model.
        """
        start_index = max(len(history_chat_record) - dialogue_number, 0)
        history_message = []
        for record in history_chat_record[start_index:]:
            history_message.append(record.get_human_message())
            history_message.append(record.get_ai_message())
        return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                HumanMessage(self.workflow_manage.generate_prompt(prompt))]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to ``{'role', 'content'}`` dicts and append the answer.

        Every message that is not a ``HumanMessage`` (system messages included) is
        given the role ``'ai'``.
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result
class SearchDatasetStepNodeSerializer(serializers.Serializer):
    """Validates the configuration of a dataset-search node."""
    # Ids of the datasets to search
    dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                            error_messages=ErrMessage.list("数据集id列表"))
    # Number of paragraphs to return
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Similarity threshold
    # NOTE(review): the original comment says "between 0 and 1" but max_value is 2 - confirm
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # The alternation must be grouped: "^embedding|keywords|blend$" also accepts
    # "embeddingX", "xkeywordsx" and "xblend"
    search_mode = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                                  message="类型只支持embedding|keywords|blend", code=500)
    ], error_messages=ErrMessage.char("检索模式"))

    question_reference_address = ReferenceAddressSerializer(required=False,
                                                            error_messages=ErrMessage.char("问题应用地址"))

    def is_valid(self, *, raise_exception=False):
        """Validate and always raise on failure; returns the validation result."""
        return super().is_valid(raise_exception=True)


class ISearchDatasetStepNode(INode):
    """Interface of the dataset-search node; concrete nodes implement ``execute``."""
    type = 'search-dataset-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return SearchDatasetStepNodeSerializer

    def _run(self):
        """Resolve the question to search for and call ``execute``.

        ``question_reference_address`` is optional; without it the workflow-level
        ``question`` is used instead of failing on ``None.get``.
        """
        node_params = dict(self.node_params_serializer.data)
        reference_address = node_params.setdefault('question_reference_address', None)
        if reference_address is None:
            question = self.flow_params_serializer.data.get('question')
        else:
            question = self.workflow_manage.get_reference_field(reference_address.get('node_id'),
                                                                reference_address.get('fields'))
        return self.execute(**node_params, question=question, exclude_paragraph_id_list=[])

    def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
                exclude_paragraph_id_list=None,
                **kwargs) -> NodeResult:
        """Run the search; this interface method has no implementation."""
        pass
class BaseSearchDatasetNode(ISearchDatasetStepNode):
    """Dataset-search node backed by the configured embedding model and vector store."""

    def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
                exclude_paragraph_id_list=None,
                **kwargs) -> NodeResult:
        """Search the datasets for paragraphs matching ``question``.

        :return: node result with ``paragraph_list`` and ``is_hit_handling_method_list``;
                 both keys are present (as empty lists) when the store returns nothing
        """
        embedding_model = EmbeddingModel.get_embedding_model()
        embedding_value = embedding_model.embed_query(question)
        vector = VectorStore.get_embedding_vector()
        # Documents that are switched off must not contribute paragraphs
        exclude_document_id_list = [str(document.id) for document in
                                    QuerySet(Document).filter(
                                        dataset_id__in=dataset_id_list,
                                        is_active=False)]
        embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
                                      exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
        if embedding_list is None:
            # Same key as the normal path (this branch used 'is_hit_handling_method')
            return NodeResult({'paragraph_list': [], 'is_hit_handling_method_list': []}, {})
        paragraph_list = self.list_paragraph(embedding_list, vector)
        result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
        return NodeResult({'paragraph_list': result,
                           'is_hit_handling_method_list': [row.get('is_hit_handling_method') for row in result]}, {})

    @staticmethod
    def reset_paragraph(paragraph: Dict, embedding_list: List):
        """Attach the similarity and the direct-return flag to a paragraph row.

        :return: the paragraph plus ``similarity`` and ``is_hit_handling_method``; when
                 no embedding row matches, ``similarity`` is ``None`` and the flag is
                 ``False`` (previously ``None`` was returned, which broke ``execute``)
        """
        paragraph_id = str(paragraph.get('id'))
        matches = [embedding for embedding in embedding_list if
                   str(embedding.get('paragraph_id')) == paragraph_id]
        if not matches:
            return {**paragraph, 'similarity': None, 'is_hit_handling_method': False}
        found_similarity = matches[-1].get('similarity')
        # Check the handling method first so the threshold is only compared when it applies
        is_hit = paragraph.get('hit_handling_method') == 'directly_return' and found_similarity > paragraph.get(
            'directly_return_similarity')
        return {
            **paragraph,
            'similarity': found_similarity,
            'is_hit_handling_method': is_hit
        }

    @staticmethod
    def list_paragraph(embedding_list: List, vector):
        """Load the paragraph rows for the hits and purge vectors without a paragraph."""
        paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
        if not paragraph_id_list:
            return []
        paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                                       get_file_content(
                                           os.path.join(PROJECT_DIR, "apps", "application", 'sql',
                                                        'list_dataset_paragraph_by_paragraph_id.sql')),
                                       with_table_name=True)
        # Stale entries in the vector store are deleted right away
        if len(paragraph_list) != len(paragraph_id_list):
            # Compare as strings, like reset_paragraph does, so differing id types
            # can never cause a live vector to be deleted
            existing_ids = {str(row.get('id')) for row in paragraph_list}
            for paragraph_id in paragraph_id_list:
                if str(paragraph_id) not in existing_ids:
                    vector.delete_by_paragraph_id(paragraph_id)
        return paragraph_list
class IStarNode(INode):
    # Interface of the workflow start node.
    # NOTE(review): the name looks like a typo of "IStartNode"; it is imported under
    # this name by base_start_node.py, so it is left unchanged here.
    type = 'start-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None:
        """Return None: the start node has no node-specific parameters to validate."""
        return None

    def _run(self):
        """Call execute() with the validated workflow parameters."""
        return self.execute(**self.flow_params_serializer.data)

    def execute(self, question, **kwargs) -> NodeResult:
        """Run the start step; this interface method has no implementation."""
        pass
class BaseStartStepNode(IStarNode):
    def execute(self, question, **kwargs) -> NodeResult:
        """
        Start node: initialises the global variables.

        Stores the question in the node variables and the current time under
        ``time`` in the workflow variables.
        """
        return NodeResult({'question': question}, {'time': time.time()})


def event_content(chat_id, chat_record_id, response, workflow,
                  write_context,
                  post_handler: WorkFlowPostHandler):
    """
    Generate the server-sent-event payloads for a streamed answer.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: iterable of chunks, each exposing ``content``
    @param workflow: workflow manager
    @param write_context: callback that writes the final answer to the node context
    @param post_handler: post handler
    """
    answer = ''
    for chunk in response:
        answer += chunk.content
        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                     'content': chunk.content, 'is_end': False}) + "\n\n"
    # Order matters: the context is written and the post handler runs before the
    # closing event is emitted
    write_context(answer)
    post_handler.handler(chat_id, chat_record_id, answer, workflow)
    yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                 'content': '', 'is_end': True}) + "\n\n"


def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
                       post_handler):
    """
    Wrap a streamed result in an event-stream HTTP response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: response data
    @param workflow: workflow manager
    @param write_context: callback that writes the node context
    @param post_handler: post handler
    @return: response
    """
    r = StreamingHttpResponse(
        streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
        content_type='text/event-stream;charset=utf-8')

    r['Cache-Control'] = 'no-cache'
    return r
def to_response(chat_id, chat_record_id, response: "BaseMessage", workflow, write_context,
                post_handler: "WorkFlowPostHandler"):
    """
    Convert a blocking result into the service response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: response data; its ``content`` is the answer
    @param workflow: workflow manager
    @param write_context: callback that writes the node context
    @param post_handler: post handler
    @return: response
    """
    answer = response.content
    write_context(answer)
    post_handler.handler(chat_id, chat_record_id, answer, workflow)
    return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                           'content': answer, 'is_end': True})


class Edge:
    """Directed connection between two nodes of a flow definition."""

    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
        """Any extra keyword becomes an attribute of the same name."""
        self.id = _id
        self.type = _type
        self.sourceNodeId = sourceNodeId
        self.targetNodeId = targetNodeId
        for keyword, value in keywords.items():
            setattr(self, keyword, value)


class Node:
    """Node of a flow definition as sent by the editor."""

    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
        """Any extra keyword becomes an attribute of the same name."""
        self.id = _id
        self.type = _type
        self.x = x
        self.y = y
        self.properties = properties
        for keyword, value in kwargs.items():
            setattr(self, keyword, value)
class Flow:
    """Parsed workflow graph: its nodes and the directed edges between them."""

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def new_instance(flow_obj: Dict):
        """Build a ``Flow`` from its dict form; missing ``nodes`` / ``edges`` mean none."""
        nodes = [Node(node.get('id'), node.get('type'), **node)
                 for node in flow_obj.get('nodes') or []]
        edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in flow_obj.get('edges') or []]
        return Flow(nodes, edges)


# Registry keyed by each node class's own ``type``, so the keys cannot drift from
# the classes (the chat node declares 'ai-chat-node', the table said 'chat-node').
flow_node_dict = {
    node_cls.type: node_cls
    for node_cls in (BaseStartStepNode, BaseSearchDatasetNode, BaseChatNode)
}
# Keep the key previously used for the chat node working
flow_node_dict.setdefault('chat-node', BaseChatNode)


class WorkflowManage:
    """Runs the nodes of a flow one after another and holds the shared context."""

    def __init__(self, flow: Flow, params):
        """
        :param flow: parsed flow definition
        :param params: workflow-level parameters handed to every node
        """
        self.params = params
        self.flow = flow
        self.context = {}
        self.node_dict = {}
        self.runtime_nodes = []
        self.current_node = None

    def run(self):
        """
        Run the workflow.

        Intermediate nodes write their result to the contexts; the last node's
        result is converted into the response.
        """
        while self.has_next_node():
            self.current_node = self.get_next_node()
            self.node_dict[self.current_node.id] = self.current_node
            result = self.current_node.run()
            if self.has_next_node():
                result.write_context(self.current_node, self)
            else:
                # NOTE(review): chat_info=None and client_type='ss' look like placeholders,
                # and the response is only printed, never returned - confirm before use
                r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self,
                                       WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None,
                                                           client_type='ss'))
                for row in r:
                    print(row)
                print(self)

    def has_next_node(self):
        """
        Tell whether another node can run: the start node before anything ran,
        afterwards any edge leaving the current node.
        """
        if self.current_node is None:
            return self.get_start_node() is not None
        return any(edge.sourceNodeId == self.current_node.id for edge in self.flow.edges)

    def get_runtime_details(self):
        """Return the runtime details of the workflow (currently always empty)."""
        return {}

    def get_details(self):
        """Alias of ``get_runtime_details``; the post handler asks for this name."""
        return self.get_runtime_details()

    def _new_node_instance(self, node):
        """Instantiate the node class registered for ``node.type``, bound to this manager."""
        return flow_node_dict[node.type](node.id, node.properties.get('node_data'),
                                         self.params, self)

    def get_next_node(self):
        """
        Return a new instance of the next node to run, or ``None`` when there is none.
        """
        if self.current_node is None:
            # The start node must receive the manager itself (it used to get
            # self.context, leaving node.workflow_manage pointing at a dict)
            return self._new_node_instance(self.get_start_node())
        for edge in self.flow.edges:
            if edge.sourceNodeId == self.current_node.id:
                return self.get_node_cls_by_id(edge.targetNodeId)
        return None

    def get_reference_field(self, node_id: str, fields: List[str]):
        """
        Resolve a referenced value.

        @param node_id: node id, or 'global' for the workflow context
        @param fields: field path
        @return: the referenced value, or None when the path does not resolve
        """
        if node_id == 'global':
            return INode.get_field(self.context, fields)
        else:
            return self.get_node_by_id(node_id).get_reference_field(fields)

    def generate_prompt(self, prompt: str):
        """
        Render a prompt template.

        @param prompt: jinja2 template text
        @return: the rendered prompt
        """
        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
        # The template sees one variable, ``context``: the workflow context under
        # 'global' and each executed node's context under its node id
        context = {
            'global': self.context,
        }
        for key in self.node_dict:
            context[key] = self.node_dict[key].context
        value = prompt_template.format(context=context)
        return value

    def get_start_node(self):
        """
        Return the start node.

        @return: the node of type 'start-node', else the first node, else None for
                 an empty flow
        """
        nodes = self.flow.nodes
        if not nodes:
            return None
        for node in nodes:
            if node.type == 'start-node':
                return node
        return nodes[0]

    def get_node_cls_by_id(self, node_id):
        """Return a new node instance for the flow node with ``node_id`` (None if unknown)."""
        for node in self.flow.nodes:
            if node.id == node_id:
                return self._new_node_instance(node)
        return None

    def get_node_by_id(self, node_id):
        """Return the already executed node instance; raises KeyError if it has not run."""
        return self.node_dict[node_id]

    def get_node_reference(self, reference_address: Dict):
        """Resolve a reference address against an executed node.

        Supports the ``fields`` path of ``ReferenceAddressSerializer`` and, for
        backward compatibility, the single ``node_field`` key read before.
        """
        node = self.get_node_by_id(reference_address.get('node_id'))
        fields = reference_address.get('fields')
        if fields is not None:
            return node.get_reference_field(fields)
        return node.context[reference_address.get('node_field')]