diff --git a/apps/application/flow/__init__.py b/apps/application/flow/__init__.py new file mode 100644 index 000000000..328e8f8ec --- /dev/null +++ b/apps/application/flow/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" diff --git a/apps/application/flow/default_workflow.json b/apps/application/flow/default_workflow.json new file mode 100644 index 000000000..9c460f549 --- /dev/null +++ b/apps/application/flow/default_workflow.json @@ -0,0 +1,426 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 440, + "y": 3350, + "properties": { + "config": {}, + "height": 517, + "stepName": "基本信息", + "node_data": { + "desc": "", + "name": "", + "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?" + } + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 440, + "y": 3710, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "value": "time", + "label": "当前时间" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 268.533, + "stepName": "开始", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 830, + "y": 3470, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 754.8, + "stepName": "知识库检索", + "node_data": { + "dataset_id_list": [], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + 
] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1380, + "y": 3470, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 524.6669999999999, + "stepName": "判断器", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", + "condition": "and", + "conditions": [] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 116.133, + "id": "1009" + }, + { + "index": 1, + "height": 116.133, + "id": "4908" + }, + { + "index": 2, + "height": 40, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2090, + "y": 2820, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 312.267, + "stepName": "指定回复", + "node_data": { + "fields": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + "content": "", + "reply_type": "referencing" + } + } + }, + { + "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2090, + "y": 3460, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 681.4, + "stepName": "AI 对话", + "node_data": { + "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0 + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2090, + "y": 4180, + "properties": { + 
"config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 681.4, + "stepName": "AI 对话1", + "node_data": { + "prompt": "{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0 + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 600, + "y": 3710 + }, + "endPoint": { + "x": 670, + "y": 3470 + }, + "properties": {}, + "pointsList": [ + { + "x": 600, + "y": 3710 + }, + { + "x": 710, + "y": 3710 + }, + { + "x": 560, + "y": 3470 + }, + { + "x": 670, + "y": 3470 + } + ], + "sourceAnchorId": "start-node_right", + "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 990, + "y": 3470 + }, + "endPoint": { + "x": 1090, + "y": 3470 + }, + "properties": {}, + "pointsList": [ + { + "x": 990, + "y": 3470 + }, + { + "x": 1100, + "y": 3470 + }, + { + "x": 980, + "y": 3470 + }, + { + "x": 1090, + "y": 3470 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1670, + "y": 3340.733 + }, + "endPoint": { + "x": 1930, + "y": 2820 + }, + "properties": {}, + "pointsList": [ + { + "x": 1670, + "y": 3340.733 + }, + { + "x": 1780, + "y": 3340.733 + }, + { + "x": 1820, + "y": 2820 + }, + { + "x": 1930, + "y": 2820 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": 
"4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1670, + "y": 3464.866 + }, + "endPoint": { + "x": 1930, + "y": 3460 + }, + "properties": {}, + "pointsList": [ + { + "x": 1670, + "y": 3464.866 + }, + { + "x": 1780, + "y": 3464.866 + }, + { + "x": 1820, + "y": 3460 + }, + { + "x": 1930, + "y": 3460 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 1670, + "y": 3550.9325000000003 + }, + "endPoint": { + "x": 1930, + "y": 4180 + }, + "properties": {}, + "pointsList": [ + { + "x": 1670, + "y": 3550.9325000000003 + }, + { + "x": 1780, + "y": 3550.9325000000003 + }, + { + "x": 1820, + "y": 4180 + }, + { + "x": 1930, + "y": 4180 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py new file mode 100644 index 000000000..0aa620e26 --- /dev/null +++ b/apps/application/flow/i_step_node.py @@ -0,0 +1,190 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_step_node.py + @date:2024/6/3 14:57 + @desc: +""" +import time +from abc import abstractmethod +from typing import Type, Dict, List + +from django.db.models import QuerySet +from rest_framework import serializers + +from application.models import ChatRecord +from application.models.api_key_model import ApplicationPublicAccessClient +from common.constants.authentication_type import 
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
    """Merge a step's output into the contexts.

    Keys in *step_variable* go into the node-local context; keys in
    *global_variable* go into the workflow-wide context. Either dict may
    be None, in which case it is skipped.
    """
    if step_variable is not None:
        for key, value in step_variable.items():
            node.context[key] = value
    if global_variable is not None:
        for key, value in global_variable.items():
            workflow.context[key] = value


class WorkFlowPostHandler:
    """Post-processing hook that persists one finished chat round.

    Collects token usage from the workflow's runtime details, stores the
    ChatRecord on the cached chat session and, for public access-token
    clients, bumps the client's access counters.
    """

    def __init__(self, chat_info, client_id, client_type):
        self.chat_info = chat_info
        self.client_id = client_id
        self.client_type = client_type

    def handler(self, chat_id,
                chat_record_id,
                answer,
                workflow):
        """Build the ChatRecord for this round and persist it.

        @param chat_id: conversation id
        @param chat_record_id: id for the new chat record
        @param answer: final answer text
        @param workflow: workflow manager that just finished running
        """
        question = workflow.params['question']
        details = workflow.get_runtime_details()
        # Sum token usage over every node that reported a (non-None) count.
        message_tokens = sum(row.get('message_tokens') for row in details.values()
                             if 'message_tokens' in row and row.get('message_tokens') is not None)
        answer_tokens = sum(row.get('answer_tokens') for row in details.values()
                            if 'answer_tokens' in row and row.get('answer_tokens') is not None)
        chat_record = ChatRecord(id=chat_record_id,
                                 chat_id=chat_id,
                                 problem_text=question,
                                 answer_text=answer,
                                 details=details,
                                 message_tokens=message_tokens,
                                 answer_tokens=answer_tokens,
                                 run_time=time.time() - workflow.context['start_time'],
                                 index=0)
        self.chat_info.append_chat_record(chat_record, self.client_id)
        # Refresh the cached chat session (30-minute TTL).
        chat_cache.set(chat_id, self.chat_info, timeout=60 * 30)
        if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
            client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first()
            if client is not None:
                client.access_num = client.access_num + 1
                client.intraday_access_num = client.intraday_access_num + 1
                client.save()
class NodeResult:
    """Result of one node execution.

    @param node_variable: values scoped to this node's context
    @param workflow_variable: values promoted to the global workflow context
    @param _to_response: converter producing the HTTP (possibly streaming) response
    @param _write_context: callback persisting the variables into the contexts
    """

    def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
        self._write_context = _write_context
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable
        self._to_response = _to_response

    def write_context(self, node, workflow):
        """Persist this result into the node and workflow contexts."""
        self._write_context(self.node_variable, self.workflow_variable, node, workflow)

    def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
        """Convert this result into a response via the configured converter."""
        return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
                                 post_handler)

    def is_assertion_result(self):
        # Condition nodes mark their result with the chosen branch id.
        return 'branch_id' in self.node_variable


class ReferenceAddressSerializer(serializers.Serializer):
    """Address of a referenced value: a node id plus a field path inside that node."""
    node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
    fields = serializers.ListField(
        child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
        error_messages=ErrMessage.list("节点字段数组"))


class FlowParamsSerializer(serializers.Serializer):
    """Runtime parameters passed to every workflow node."""
    # Chat history of this conversation
    history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
                                                error_messages=ErrMessage.list("历史对答"))

    # Fix: question/chat_id are CharFields, so use the char() error-message
    # factory (the original used ErrMessage.list, yielding list-typed error
    # messages; chat_record_id below already used char() correctly).
    question = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))

    chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话id"))

    chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))

    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))

    client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))

    client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
class INode:
    """Base class for every workflow step node.

    Holds the node definition, validates its parameters through the
    configured serializers and exposes the run/execute lifecycle used by
    the workflow manager.
    """

    def __init__(self, node, workflow_params, workflow_manage):
        # Per-step context used to store runtime information for this step.
        self.status = 200
        self.err_message = ''
        self.node = node
        self.node_params = node.properties.get('node_data')
        self.workflow_manage = workflow_manage
        self.node_params_serializer = None
        self.flow_params_serializer = None
        self.context = {}
        self.id = node.id
        self.valid_args(self.node_params, workflow_params)

    def valid_args(self, node_params, flow_params):
        """Validate workflow-level and node-level parameters (raises on invalid data)."""
        flow_serializer_class = self.get_flow_params_serializer_class()
        node_serializer_class = self.get_node_params_serializer_class()
        if flow_serializer_class is not None and flow_params is not None:
            self.flow_params_serializer = flow_serializer_class(data=flow_params)
            self.flow_params_serializer.is_valid(raise_exception=True)
        if node_serializer_class is not None:
            self.node_params_serializer = node_serializer_class(data=node_params)
            self.node_params_serializer.is_valid(raise_exception=True)

    def get_reference_field(self, fields: List[str]):
        """Resolve a field path against this node's own context."""
        return self.get_field(self.context, fields)

    @staticmethod
    def get_field(obj, fields: List[str]):
        """Walk *obj* along the key path *fields*; return None as soon as a key is missing."""
        for field in fields:
            obj = obj.get(field)
            if obj is None:
                return None
        return obj

    @abstractmethod
    def get_node_params_serializer_class(self) -> 'Type[serializers.Serializer]':
        """Serializer class validating this node's own parameters."""
        pass

    def get_flow_params_serializer_class(self) -> 'Type[serializers.Serializer]':
        """Serializer class validating the workflow runtime parameters."""
        return FlowParamsSerializer

    def get_write_error_context(self, e):
        """Record *e* as this node's failure and return a no-op answer writer."""
        self.status = 500
        self.err_message = str(e)

        def write_error_context(answer, status=200):
            # Intentionally empty: the error state was already recorded above.
            pass

        return write_error_context

    def run(self) -> 'NodeResult':
        """Execute the node, recording start and elapsed time in the context.

        :return: execution result
        """
        started = time.time()
        self.context['start_time'] = started
        result = self._run()
        self.context['run_time'] = time.time() - started
        return result

    def _run(self):
        return self.execute()

    def execute(self, **kwargs) -> 'NodeResult':
        pass

    def get_details(self, index: int, **kwargs):
        """Runtime details of this step.

        :return: step detail dict
        """
        return {}
# Registry of all built-in step node implementations.
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode,
             BaseReplyNode]


def get_node(node_type):
    """Look up the node implementation class registered for *node_type*.

    Returns the first registered class whose ``type`` matches, or None
    when the type is unknown.
    """
    return next((node for node in node_list if node.type == node_type), None)
class ChatNodeSerializer(serializers.Serializer):
    """Validated parameters of an AI-chat node."""
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
    # Optional system/role prompt; may be blank or null.
    system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                   error_messages=ErrMessage.char("角色设定"))
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
    # Number of history rounds carried into the conversation.
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))


class IChatNode(INode):
    """Interface of the AI chat step node."""
    type = 'ai-chat-node'

    def get_node_params_serializer_class(self) -> 'Type[serializers.Serializer]':
        return ChatNodeSerializer

    def _run(self):
        # Feed validated node params plus workflow runtime params into execute().
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id, **kwargs) -> 'NodeResult':
        pass
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: 'INode', workflow):
    """Drain a streaming model response and record it into the node context.

    @param node_variable: node-scoped data (result/chat_model/message_list/...)
    @param workflow_variable: workflow-scoped data
    @param node: node instance
    @param workflow: workflow manager
    """
    answer = ''.join(chunk.content for chunk in node_variable.get('result'))
    chat_model = node_variable.get('chat_model')
    node.context['message_tokens'] = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    node.context['answer_tokens'] = chat_model.get_num_tokens(answer)
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
    node.context['run_time'] = time.time() - node.context['start_time']


def write_context(node_variable: Dict, workflow_variable: Dict, node: 'INode', workflow):
    """Record a non-streaming model response into the node context.

    NOTE(review): unlike the streaming variant this does not set
    'run_time' here — it appears to be set elsewhere; confirm.

    @param node_variable: node-scoped data
    @param workflow_variable: workflow-scoped data
    @param node: node instance
    @param workflow: workflow manager
    """
    response = node_variable.get('result')
    chat_model = node_variable.get('chat_model')
    answer = response.content
    node.context['message_tokens'] = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    node.context['answer_tokens'] = chat_model.get_num_tokens(answer)
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
def get_to_response_write_context(node_variable: Dict, node: 'INode'):
    """Create the callback that records a finished response on *node*.

    On success (status == 200) token usage is computed from the chat
    model; on failure the node is marked failed and token counts are
    zeroed.
    """

    def _write_context(answer, status=200):
        chat_model = node_variable.get('chat_model')
        if status == 200:
            answer_tokens = chat_model.get_num_tokens(answer)
            message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
        else:
            answer_tokens = 0
            message_tokens = 0
            node.err_message = answer
            node.status = status
        node.context['message_tokens'] = message_tokens
        node.context['answer_tokens'] = answer_tokens
        node.context['answer'] = answer
        node.context['run_time'] = time.time() - node.context['start_time']

    return _write_context


def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """Convert a streaming model result into a streaming HTTP response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped data
    @param workflow_variable: workflow-scoped data
    @param node: node instance
    @param workflow: workflow manager
    @param post_handler: post handler run after output completes
    @return: streaming response
    """
    response = node_variable.get('result')
    writer = get_to_response_write_context(node_variable, node)
    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, writer, post_handler)


def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """Convert a blocking model result into a plain HTTP response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped data
    @param workflow_variable: workflow-scoped data
    @param node: node instance
    @param workflow: workflow manager
    @param post_handler: post handler run after output completes
    @return: response
    """
    response = node_variable.get('result')
    writer = get_to_response_write_context(node_variable, node)
    return tools.to_response(chat_id, chat_record_id, response, workflow, writer, post_handler)
class BaseChatNode(IChatNode):
    """Default AI chat node: loads the configured model, assembles the
    prompt with history, and returns either a streaming or a blocking
    NodeResult."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id, **kwargs) -> 'NodeResult':
        model = QuerySet(Model).filter(id=model_id).first()
        # NOTE(review): streaming=True is passed regardless of `stream` — confirm intended.
        chat_model = ModelProvideConstants[model.provider].value.get_model(
            model.model_type, model.model_name,
            json.loads(rsa_long_decrypt(model.credential)),
            streaming=True)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question.content
        message_list = self.generate_message_list(system, prompt, history_message)
        self.context['message_list'] = message_list
        if stream:
            return NodeResult({'result': chat_model.stream(message_list), 'chat_model': chat_model,
                               'message_list': message_list, 'history_message': history_message,
                               'question': question.content}, {},
                              _write_context=write_context_stream,
                              _to_response=to_stream_response)
        return NodeResult({'result': chat_model.invoke(message_list), 'chat_model': chat_model,
                           'message_list': message_list, 'history_message': history_message,
                           'question': question.content}, {},
                          _write_context=write_context, _to_response=to_response)

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """Return the last *dialogue_number* rounds flattened as
        [human, ai, human, ai, ...]; empty list when dialogue_number is 0."""
        start = len(history_chat_record) - dialogue_number
        if start < 0:
            start = 0
        history_message = []
        for record in history_chat_record[start:]:
            history_message.append(record.get_human_message())
            history_message.append(record.get_ai_message())
        return history_message

    def generate_prompt_question(self, prompt):
        """Render the prompt template into a HumanMessage."""
        return HumanMessage(self.workflow_manage.generate_prompt(prompt))

    def generate_message_list(self, system: str, prompt: str, history_message):
        """Assemble [system?, *history, human] with templates rendered."""
        rendered_prompt = HumanMessage(self.workflow_manage.generate_prompt(prompt))
        if system is not None and len(system) > 0:
            return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message, rendered_prompt]
        return [*history_message, rendered_prompt]

    @staticmethod
    def reset_message_list(message_list: 'List[BaseMessage]', answer_text):
        """Flatten messages to role/content dicts and append the final answer.

        NOTE(review): every non-human message (including system) is
        labelled 'ai' here — confirm this is intended.
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Runtime details of this step for the chat record."""
        history = self.context.get('history_message')
        if history is None:
            history = []
        return {
            'name': self.node.properties.get('stepName'),
            'index': index,
            'run_time': self.context.get('run_time'),
            'system': self.node_params.get('system'),
            'history_message': [{'content': message.content, 'role': message.type} for message in history],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message
        }
class Compare:
    """Strategy interface for condition-node comparators."""

    @abstractmethod
    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """Return truthy when this comparator handles the *compare* keyword."""
        pass

    @abstractmethod
    def compare(self, source_value, compare, target_value):
        """Evaluate the comparison and return a bool."""
        pass


class ContainCompare(Compare):
    """'contain' comparator: substring test for strings, element membership otherwise."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'contain':
            return True

    def compare(self, source_value, compare, target_value):
        if isinstance(source_value, str):
            return str(target_value) in source_value
        return any([str(item) == str(target_value) for item in source_value])


# Registry of available comparators, probed in order via support().
# NOTE(review): IsNullCompare / IsNotNullCompare exist in this package but
# are not registered here — confirm whether they should be added.
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
                       LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare()]
class EqualCompare(Compare):
    """'eq' comparator: string equality of both operands."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'eq':
            return True

    def compare(self, source_value, compare, target_value):
        return str(source_value) == str(target_value)


class GECompare(Compare):
    """'ge' comparator: numeric source >= target; False when not numeric."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'ge':
            return True

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) >= float(target_value)
        except Exception:
            # Non-numeric operands never satisfy a numeric comparison.
            return False
# coding=utf-8
"""
'gt', 'is_not_null', 'is_null', 'le' and 'len_eq' comparators for the
workflow condition node.
"""
from typing import List

# Fix: is_null/is_not_null originally imported Compare from the package
# __init__ (application.flow.step_node.condition_node.compare) while every
# sibling imports the concrete .compare module; registering them in the
# package's compare_handle_list would then create a circular import.
from application.flow.step_node.condition_node.compare.compare import Compare


class GTCompare(Compare):
    """Comparator for the 'gt' (numeric >) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'gt':
            return True

    def compare(self, source_value, compare, target_value):
        # Non-numeric operands are treated as "condition not met".
        try:
            return float(source_value) > float(target_value)
        except Exception:
            return False


class IsNotNullCompare(Compare):
    """Comparator for the 'is_not_null' operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'is_not_null':
            return True

    def compare(self, source_value, compare, target_value=None):
        # Only None counts as null; "" and [] are considered present.
        return source_value is not None


class IsNullCompare(Compare):
    """Comparator for the 'is_null' operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'is_null':
            return True

    def compare(self, source_value, compare, target_value=None):
        return source_value is None


class LECompare(Compare):
    """Comparator for the 'le' (numeric <=) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'le':
            return True

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) <= float(target_value)
        except Exception:
            return False


class LenEqualCompare(Compare):
    """Comparator for the 'len_eq' (length ==) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'len_eq':
            return True

    def compare(self, source_value, compare, target_value):
        # len()/int() failures (e.g. numbers) are treated as "condition not met".
        try:
            return len(source_value) == int(target_value)
        except Exception:
            return False
# coding=utf-8
"""
'len_ge' and 'len_gt' comparators for the workflow condition node.
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare


class LenGECompare(Compare):
    """Comparator for the 'len_ge' (length >=) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'len_ge':
            return True

    def compare(self, source_value, compare, target_value):
        # len()/int() failures are treated as "condition not met".
        try:
            return len(source_value) >= int(target_value)
        except Exception:
            return False


class LenGTCompare(Compare):
    """Comparator for the 'len_gt' (length >) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'len_gt':
            return True

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) > int(target_value)
        except Exception:
            return False
# coding=utf-8
"""
'len_le' and 'len_lt' comparators for the workflow condition node.
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare


class LenLECompare(Compare):
    """Comparator for the 'len_le' (length <=) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'len_le':
            return True

    def compare(self, source_value, compare, target_value):
        # len()/int() failures are treated as "condition not met".
        try:
            return len(source_value) <= int(target_value)
        except Exception:
            return False


class LenLTCompare(Compare):
    """Comparator for the 'len_lt' (length <) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'len_lt':
            return True

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) < int(target_value)
        except Exception:
            return False
# coding=utf-8
"""
'lt' and 'not_contain' comparators for the workflow condition node.
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare


class LTCompare(Compare):
    """Comparator for the 'lt' (numeric <) operator."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'lt':
            return True

    def compare(self, source_value, compare, target_value):
        # Non-numeric operands are treated as "condition not met".
        try:
            return float(source_value) < float(target_value)
        except Exception:
            return False


class NotContainCompare(Compare):
    """Comparator for the 'not_contain' operator.

    Fix: this class was originally also named ContainCompare, colliding with
    the class of the same name in contain_compare.py; under the package's
    wildcard imports one definition would shadow the other, so it is renamed
    NotContainCompare.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        if compare == 'not_contain':
            return True

    def compare(self, source_value, compare, target_value):
        # Strings: substring absence; other iterables: no element equals target.
        if isinstance(source_value, str):
            return str(target_value) not in source_value
        return not any([str(item) == str(target_value) for item in source_value])
# coding=utf-8
"""
Condition node implementation: walks the configured branches in order and
reports the first branch whose condition group evaluates to true.
"""
from typing import List

from application.flow.i_step_node import NodeResult
from application.flow.step_node.condition_node.compare import compare_handle_list
from application.flow.step_node.condition_node.i_condition_node import IConditionNode


class BaseConditionNode(IConditionNode):
    def execute(self, **kwargs) -> NodeResult:
        """Evaluate the configured branches and expose the matched one."""
        matched = self._execute(self.node_params_serializer.data['branch'])
        return NodeResult({'branch_id': matched.get('id'), 'branch_name': matched.get('type')}, {})

    def _execute(self, branch_list: List):
        # First branch whose condition group holds wins; None when none match.
        return next((candidate for candidate in branch_list if self.branch_assertion(candidate)), None)

    def branch_assertion(self, branch):
        # Every condition is evaluated eagerly (no short-circuit), then the
        # outcomes are combined by the branch's 'and'/'or' connective.
        outcomes = [self.assertion(row.get('field'), row.get('compare'), row.get('value'))
                    for row in branch.get('conditions')]
        return all(outcomes) if branch.get('condition') == 'and' else any(outcomes)

    def assertion(self, field_list: List[str], compare: str, value):
        # Resolve the referenced field ([node_id, field, ...]) and dispatch to
        # the first registered comparator supporting this operator.
        resolved = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
        for handler in compare_handle_list:
            if handler.support(field_list[0], field_list[1:], resolved, compare, value):
                return handler.compare(resolved, compare, value)

    def get_details(self, index: int, **kwargs):
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'branch_id': self.context.get('branch_id'),
            'branch_name': self.context.get('branch_name'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message
        }
# coding=utf-8
"""
Interface definition for the direct-reply workflow node.

A reply node either echoes a field referenced from an upstream node
(reply_type == 'referencing') or renders a fixed content template.
"""
from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage


class ReplyNodeParamsSerializer(serializers.Serializer):
    # 'referencing' (copy a field from another node) or direct content reply.
    reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
    # Reference path [node_id, field, ...]; required when reply_type == 'referencing'.
    fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
    # Literal reply template; required for the non-referencing mode.
    content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                    error_messages=ErrMessage.char("直接回答内容"))

    def is_valid(self, *, raise_exception=False):
        """Field-level validation plus cross-field checks.

        NOTE(review): always raises on invalid data regardless of the
        raise_exception argument (super() is called with raise_exception=True).
        """
        super().is_valid(raise_exception=True)
        if self.data.get('reply_type') == 'referencing':
            # A reference needs at least [node_id, field_name].
            if 'fields' not in self.data:
                raise AppApiException(500, "引用字段不能为空")
            if len(self.data.get('fields')) < 2:
                raise AppApiException(500, "引用字段错误")
        else:
            if 'content' not in self.data or self.data.get('content') is None:
                raise AppApiException(500, "内容不能为空")


class IReplyNode(INode):
    # Node type discriminator used by the workflow engine.
    type = 'reply-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ReplyNodeParamsSerializer

    def _run(self):
        # Merge validated node params and flow params into the execute() kwargs.
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        # Implemented by BaseReplyNode.
        pass
# coding=utf-8
"""
Direct-reply node implementation: emits either a referenced upstream field or
a rendered content template, as a streaming or plain response.
"""
from typing import List, Dict

from langchain_core.messages import AIMessage, AIMessageChunk

from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode


def get_to_response_write_context(node_variable: Dict, node: INode):
    # Returned callback stores the final answer on the node context once the
    # response has been emitted.
    def _write_context(answer, status=200):
        node.context['answer'] = answer

    return _write_context


def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """Convert a streaming result into a streaming response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-local data (holds 'result')
    @param workflow_variable: global workflow data (unused here)
    @param node: node instance
    @param workflow: workflow manager
    @param post_handler: hook run after the response has been written
    @return: streaming response
    """
    return tools.to_stream_response(chat_id, chat_record_id, node_variable.get('result'), workflow,
                                    get_to_response_write_context(node_variable, node), post_handler)


def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """Convert a plain (non-streaming) result into a response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-local data (holds 'result')
    @param workflow_variable: global workflow data (unused here)
    @param node: node instance
    @param workflow: workflow manager
    @param post_handler: hook run after the response has been written
    @return: response
    """
    return tools.to_response(chat_id, chat_record_id, node_variable.get('result'), workflow,
                             get_to_response_write_context(node_variable, node), post_handler)


class BaseReplyNode(IReplyNode):
    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        """Produce the reply text and wrap it for streaming or direct output."""
        if reply_type == 'referencing':
            answer_text = self.get_reference_content(fields)
        else:
            answer_text = self.generate_reply_content(content)
        if not stream:
            return NodeResult({'result': AIMessage(content=answer_text)}, {}, _to_response=to_response)
        # A single-chunk iterator keeps the streaming contract of chat nodes.
        return NodeResult({'result': iter([AIMessageChunk(content=answer_text)])}, {},
                          _to_response=to_stream_response)

    def generate_reply_content(self, prompt):
        # Render the content template against workflow variables.
        return self.workflow_manage.generate_prompt(prompt)

    def get_reference_content(self, fields: List[str]):
        # fields = [node_id, field, ...]; the referenced value is coerced to str.
        return str(self.workflow_manage.get_reference_field(fields[0], fields[1:]))

    def get_details(self, index: int, **kwargs):
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'answer': self.context.get('answer'),
            'status': self.status,
            'err_message': self.err_message
        }
# coding=utf-8
"""
Interface definition for the LLM question workflow node.
"""
from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class QuestionNodeSerializer(serializers.Serializer):
    # Id of the configured LLM to call.
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
    # Optional role/system setting; blank and null allowed.
    system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                   error_messages=ErrMessage.char("角色设定"))
    # Prompt template rendered against workflow variables.
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
    # Number of historical dialogue rounds to include as context.
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))


class IQuestionNode(INode):
    # Node type discriminator used by the workflow engine.
    type = 'question-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return QuestionNodeSerializer

    def _run(self):
        # Merge validated node params and flow params into the execute() kwargs.
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        # Implemented by BaseQuestionNode.
        pass
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: 'INode', workflow):
    """Write a streamed chat-model result into the node context.

    Drains the response chunk iterator, then records the assembled answer,
    token usage, history, question and elapsed run time on the node.
    @param node_variable: node-local data ('result', 'chat_model', 'message_list', ...)
    @param workflow_variable: global workflow data (unused here)
    @param node: node whose context is updated
    @param workflow: workflow manager (unused here)
    """
    response = node_variable.get('result')
    answer = ''
    for chunk in response:
        answer += chunk.content
    chat_model = node_variable.get('chat_model')
    node.context['message_tokens'] = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    node.context['answer_tokens'] = chat_model.get_num_tokens(answer)
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
    node.context['run_time'] = time.time() - node.context['start_time']


def write_context(node_variable: Dict, workflow_variable: Dict, node: 'INode', workflow):
    """Write a non-streaming chat-model result into the node context.

    @param node_variable: node-local data ('result', 'chat_model', 'message_list', ...)
    @param workflow_variable: global workflow data (unused here)
    @param node: node whose context is updated
    @param workflow: workflow manager (unused here)
    """
    response = node_variable.get('result')
    chat_model = node_variable.get('chat_model')
    answer = response.content
    node.context['message_tokens'] = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    node.context['answer_tokens'] = chat_model.get_num_tokens(answer)
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
    # Fix: the non-streaming path never recorded run_time (the streaming
    # variant does), so get_details reported None for non-streaming calls.
    node.context['run_time'] = time.time() - node.context['start_time']
dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + **kwargs) -> NodeResult: + model = QuerySet(Model).filter(id=model_id).first() + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + rsa_long_decrypt(model.credential)), + streaming=True) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question.content + message_list = self.generate_message_list(system, prompt, history_message) + self.context['message_list'] = message_list + if stream: + r = chat_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'get_to_response_write_context': get_to_response_write_context, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream, + _to_response=to_stream_response) + else: + r = chat_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context, _to_response=to_response) + + @staticmethod + def get_history_message(history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, system: str, prompt: str, history_message): + if system is None or len(system) == 0: + return 
[SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message, + HumanMessage(self.workflow_manage.generate_prompt(prompt))] + else: + return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'system': self.node_params.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/search_dataset_node/__init__.py b/apps/application/flow/step_node/search_dataset_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py new file mode 100644 index 000000000..0a134527c --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -0,0 +1,61 @@ +# coding=utf-8 
+""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_node.py + @date:2024/6/3 17:52 + @desc: +""" +import re +from typing import Type + +from django.core import validators +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class DatasetSettingSerializer(serializers.Serializer): + # 需要查询的条数 + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer("引用分段数")) + # 相似度 0-1之间 + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.float("引用分段数")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), + message="类型只支持register|reset_password", code=500) + ], error_messages=ErrMessage.char("检索模式")) + max_paragraph_char_number = serializers.IntegerField(required=True, + error_messages=ErrMessage.float("最大引用分段字数")) + + +class SearchDatasetStepNodeSerializer(serializers.Serializer): + # 需要查询的数据集id列表 + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("数据集id列表")) + dataset_setting = DatasetSettingSerializer(required=True) + + question_reference_address = serializers.ListField(required=True, ) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class ISearchDatasetStepNode(INode): + type = 'search-dataset-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return SearchDatasetStepNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[]) + + def execute(self, 
dataset_id_list, dataset_setting, question, + exclude_paragraph_id_list=None, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py new file mode 100644 index 000000000..a9cff0d09 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_search_dataset_node import BaseSearchDatasetNode diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py new file mode 100644 index 000000000..20e0af9fc --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -0,0 +1,93 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_node.py + @date:2024/6/4 11:56 + @desc: +""" +import os +from typing import List, Dict + +from django.db.models import QuerySet + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode +from common.config.embedding_config import EmbeddingModel, VectorStore +from common.db.search import native_search +from common.util.file_util import get_file_content +from dataset.models import Document, Paragraph +from embedding.models import SearchMode +from smartdoc.conf import PROJECT_DIR + + +class BaseSearchDatasetNode(ISearchDatasetStepNode): + def execute(self, dataset_id_list, dataset_setting, question, + exclude_paragraph_id_list=None, + **kwargs) -> NodeResult: + self.context['question'] = question + embedding_model = EmbeddingModel.get_embedding_model() + embedding_value = embedding_model.embed_query(question) + vector = VectorStore.get_embedding_vector() + 
exclude_document_id_list = [str(document.id) for document in
+                                    QuerySet(Document).filter(
+                                        dataset_id__in=dataset_id_list,
+                                        is_active=False)]
+        embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
+                                      exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
+                                      dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
+        if embedding_list is None:
+            return NodeResult({'paragraph_list': [], 'is_hit_handling_method_list': [], 'data': '', 'directly_return': '', 'question': question}, {})
+        paragraph_list = self.list_paragraph(embedding_list, vector)
+        result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+        return NodeResult({'paragraph_list': result,
+                           'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
+                           'data': '\n'.join([paragraph.get('content') for paragraph in paragraph_list]),
+                           'directly_return': '\n'.join([paragraph.get('content') for paragraph in result if
+                                                         paragraph.get('is_hit_handling_method')]),
+                           'question': question},
+
+                          {})
+
+    @staticmethod
+    def reset_paragraph(paragraph: Dict, embedding_list: List):
+        filter_embedding_list = [embedding for embedding in embedding_list if
+                                 str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+        if len(filter_embedding_list) == 0:
+            return {**paragraph, 'similarity': 0, 'is_hit_handling_method': False}
+        find_embedding = filter_embedding_list[-1]
+        return {
+            **paragraph,
+            'similarity': find_embedding.get('similarity'),
+            'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get('directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return'
+        }
+
+    @staticmethod
+    def list_paragraph(embedding_list: List, vector):
+        paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
+        if paragraph_id_list is None or len(paragraph_id_list) == 0:
+            return []
+        paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+                                       get_file_content(
+                                           os.path.join(PROJECT_DIR, "apps",
"application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + # 如果向量库中存在脏数据 直接删除 + if len(paragraph_list) != len(paragraph_id_list): + exist_paragraph_list = [row.get('id') for row in paragraph_list] + for paragraph_id in paragraph_id_list: + if not exist_paragraph_list.__contains__(paragraph_id): + vector.delete_by_paragraph_id(paragraph_id) + return paragraph_list + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + 'question': self.context.get('question'), + "index": index, + 'run_time': self.context.get('run_time'), + 'paragraph_list': self.context.get('paragraph_list'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/start_node/__init__.py b/apps/application/flow/step_node/start_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ b/apps/application/flow/step_node/start_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py new file mode 100644 index 000000000..4c1ecfd2a --- /dev/null +++ b/apps/application/flow/step_node/start_node/i_start_node.py @@ -0,0 +1,26 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_start_node.py + @date:2024/6/3 16:54 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult + + +class IStarNode(INode): + type = 'start-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None: + return None + + def _run(self): + return self.execute(**self.flow_params_serializer.data) + + def execute(self, question, **kwargs) -> NodeResult: + pass diff --git 
a/apps/application/flow/step_node/start_node/impl/__init__.py b/apps/application/flow/step_node/start_node/impl/__init__.py new file mode 100644 index 000000000..b68a92d02 --- /dev/null +++ b/apps/application/flow/step_node/start_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:36 + @desc: +""" +from .base_start_node import BaseStartStepNode diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py new file mode 100644 index 000000000..7043e42eb --- /dev/null +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -0,0 +1,33 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_start_node.py + @date:2024/6/3 17:17 + @desc: +""" +import time +from datetime import datetime + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.start_node.i_start_node import IStarNode + + +class BaseStartStepNode(IStarNode): + def execute(self, question, **kwargs) -> NodeResult: + """ + 开始节点 初始化全局变量 + """ + return NodeResult({'question': question}, + {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time()}) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "question": self.context.get('question'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py new file mode 100644 index 000000000..839aae8da --- /dev/null +++ b/apps/application/flow/tools.py @@ -0,0 +1,87 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: utils.py + @date:2024/6/6 15:15 + @desc: +""" +import json +from typing import Iterator + +from django.http import StreamingHttpResponse +from langchain_core.messages import 
BaseMessageChunk, BaseMessage + +from application.flow.i_step_node import WorkFlowPostHandler +from common.response import result + + +def event_content(chat_id, chat_record_id, response, workflow, + write_context, + post_handler: WorkFlowPostHandler): + """ + 用于处理流式输出 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param response: 响应数据 + @param workflow: 工作流管理器 + @param write_context 写入节点上下文 + @param post_handler: 后置处理器 + """ + answer = '' + try: + for chunk in response: + answer += chunk.content + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n" + write_context(answer, 200) + post_handler.handler(chat_id, chat_record_id, answer, workflow) + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n" + except Exception as e: + answer = str(e) + write_context(answer, 500) + post_handler.handler(chat_id, chat_record_id, answer, workflow) + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}, ensure_ascii=False) + "\n\n" + + +def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context, + post_handler): + """ + 将结果转换为服务流输出 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param response: 响应数据 + @param workflow: 工作流管理器 + @param write_context 写入节点上下文 + @param post_handler: 后置处理器 + @return: 响应 + """ + r = StreamingHttpResponse( + streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler), + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r + + +def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context, + post_handler: WorkFlowPostHandler): + """ + 将结果转换为服务输出 + + 
@param chat_id: 会话id + @param chat_record_id: 对话记录id + @param response: 响应数据 + @param workflow: 工作流管理器 + @param write_context 写入节点上下文 + @param post_handler: 后置处理器 + @return: 响应 + """ + answer = response.content + write_context(answer) + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py new file mode 100644 index 000000000..f99d9351b --- /dev/null +++ b/apps/application/flow/workflow_manage.py @@ -0,0 +1,282 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: workflow_manage.py + @date:2024/1/9 17:40 + @desc: +""" +from functools import reduce +from typing import List, Dict + +from langchain_core.messages import AIMessageChunk, AIMessage +from langchain_core.prompts import PromptTemplate + +from application.flow import tools +from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult +from application.flow.step_node import get_node +from common.exception.app_exception import AppApiException + + +class Edge: + def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords): + self.id = _id + self.type = _type + self.sourceNodeId = sourceNodeId + self.targetNodeId = targetNodeId + for keyword in keywords: + self.__setattr__(keyword, keywords.get(keyword)) + + +class Node: + def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs): + self.id = _id + self.type = _type + self.x = x + self.y = y + self.properties = properties + for keyword in kwargs: + self.__setattr__(keyword, kwargs.get(keyword)) + + +end_nodes = ['ai-chat-node', 'reply-node'] + + +class Flow: + def __init__(self, nodes: List[Node], edges: List[Edge]): + self.nodes = nodes + self.edges = edges + + @staticmethod + def new_instance(flow_obj: Dict): + nodes = flow_obj.get('nodes') + 
edges = flow_obj.get('edges') + nodes = [Node(node.get('id'), node.get('type'), **node) + for node in nodes] + edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges] + return Flow(nodes, edges) + + def get_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + return start_node_list[0] + + def is_valid(self): + """ + 校验工作流数据 + """ + self.is_valid_start_node() + self.is_valid_base_node() + self.is_valid_work_flow() + + @staticmethod + def is_valid_node_params(node: Node): + get_node(node.type)(node, None, None) + + def is_valid_node(self, node: Node): + self.is_valid_node_params(node) + if node.type == 'condition-node': + branch_list = node.properties.get('node_data').get('branch') + for branch in branch_list: + source_anchor_id = f"{node.id}_{branch.get('id')}_right" + edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id] + if len(edge_list) == 0: + raise AppApiException(500, + f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支需要连接') + elif len(edge_list) > 1: + raise AppApiException(500, + f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支不能连接俩个节点') + + else: + edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] + if len(edge_list) == 0 and not end_nodes.__contains__(node.type): + raise AppApiException(500, f'{node.properties.get("stepName")} 节点不能当做结束节点') + elif len(edge_list) > 1: + raise AppApiException(500, + f'{node.properties.get("stepName")} 节点不能连接俩个节点') + + def get_next_nodes(self, node: Node): + edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] + node_list = reduce(lambda x, y: [*x, *y], + [[node for node in self.nodes if node.id == edge.targetNodeId] for edge in edge_list], + []) + if len(node_list) == 0 and not end_nodes.__contains__(node.type): + raise AppApiException(500, + f'不存在的下一个节点') + return node_list + + def is_valid_work_flow(self, up_node=None): + if up_node is None: + up_node = 
self.get_start_node() + self.is_valid_node(up_node) + next_nodes = self.get_next_nodes(up_node) + for next_node in next_nodes: + self.is_valid_work_flow(next_node) + + def is_valid_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + if len(start_node_list) == 0: + raise AppApiException(500, '开始节点必填') + if len(start_node_list) > 1: + raise AppApiException(500, '开始节点只能有一个') + + def is_valid_base_node(self): + base_node_list = [node for node in self.nodes if node.id == 'base-node'] + if len(base_node_list) == 0: + raise AppApiException(500, '基本信息节点必填') + if len(base_node_list) > 1: + raise AppApiException(500, '基本信息节点只能有一个') + + +class WorkflowManage: + def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler): + self.params = params + self.flow = flow + self.context = {} + self.node_context = [] + self.work_flow_post_handler = work_flow_post_handler + self.current_node = None + self.current_result = None + + def run(self): + """ + 运行工作流 + """ + try: + while self.has_next_node(self.current_result): + self.current_node = self.get_next_node() + self.node_context.append(self.current_node) + self.current_result = self.current_node.run() + if self.has_next_node(self.current_result): + self.current_result.write_context(self.current_node, self) + else: + r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], + self.current_node, self, + self.work_flow_post_handler) + return r + except Exception as e: + if self.params.get('stream'): + return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'], + iter([AIMessageChunk(str(e))]), self, + self.current_node.get_write_error_context(e), + self.work_flow_post_handler) + else: + return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(str(e)), self, self.current_node.get_write_error_context(e), + self.work_flow_post_handler) + + def has_next_node(self, node_result: 
NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + if self.current_node is None: + if self.get_start_node() is not None: + return True + else: + if node_result is not None and node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == self.current_node.id and + f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return True + else: + for edge in self.flow.edges: + if edge.sourceNodeId == self.current_node.id: + return True + return False + + def get_runtime_details(self): + details_result = {} + for index in range(len(self.node_context)): + node = self.node_context[index] + details = node.get_details(index) + details_result[node.id] = details + return details_result + + def get_next_node(self): + """ + 获取下一个可运行的所有节点 + """ + if self.current_node is None: + node = self.get_start_node() + node_instance = get_node(node.type)(node, self.params, self.context) + return node_instance + if self.current_result is not None and self.current_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == self.current_node.id and + f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return self.get_node_cls_by_id(edge.targetNodeId) + else: + for edge in self.flow.edges: + if edge.sourceNodeId == self.current_node.id: + return self.get_node_cls_by_id(edge.targetNodeId) + + return None + + def get_reference_field(self, node_id: str, fields: List[str]): + """ + + @param node_id: 节点id + @param fields: 字段 + @return: + """ + if node_id == 'global': + return INode.get_field(self.context, fields) + else: + return self.get_node_by_id(node_id).get_reference_field(fields) + + def generate_prompt(self, prompt: str): + """ + 格式化生成提示词 + @param prompt: 提示词信息 + @return: 格式化后的提示词 + """ + context = { + 'global': self.context, + } + + for node in self.node_context: + properties = node.node.properties + node_config = 
properties.get('config') + if node_config is not None: + fields = node_config.get('fields') + if fields is not None: + for field in fields: + globeLabel = f"{properties.get('stepName')}.{field.get('value')}" + globeValue = f"context['{node.id}'].{field.get('value')}" + prompt = prompt.replace(globeLabel, globeValue) + global_fields = node_config.get('globalFields') + if global_fields is not None: + for field in global_fields: + globeLabel = f"全局变量.{field.get('value')}" + globeValue = f"context['global'].{field.get('value')}" + prompt = prompt.replace(globeLabel, globeValue) + context[node.id] = node.context + prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2') + + value = prompt_template.format(context=context) + return value + + def get_start_node(self): + """ + 获取启动节点 + @return: + """ + start_node_list = [node for node in self.flow.nodes if node.type == 'start-node'] + return start_node_list[0] + + def get_node_cls_by_id(self, node_id): + for node in self.flow.nodes: + if node.id == node_id: + node_instance = get_node(node.type)(node, + self.params, self) + return node_instance + return None + + def get_node_by_id(self, node_id): + for node in self.node_context: + if node.id == node_id: + return node + return None + + def get_node_reference(self, reference_address: Dict): + node = self.get_node_by_id(reference_address.get('node_id')) + return node.context[reference_address.get('node_field')] diff --git a/apps/application/migrations/0009_application_type_application_work_flow_and_more.py b/apps/application/migrations/0009_application_type_application_work_flow_and_more.py new file mode 100644 index 000000000..5d0bf0c9f --- /dev/null +++ b/apps/application/migrations/0009_application_type_application_work_flow_and_more.py @@ -0,0 +1,38 @@ +# Generated by Django 4.1.13 on 2024-06-25 16:30 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies 
= [ + ('application', '0008_chat_is_deleted'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='type', + field=models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE', max_length=256, verbose_name='应用类型'), + ), + migrations.AddField( + model_name='application', + name='work_flow', + field=models.JSONField(default=dict, verbose_name='工作流数据'), + ), + migrations.CreateModel( + name='WorkFlowVersion', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ], + options={ + 'db_table': 'application_work_flow_version', + }, + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 073e980c7..bdd0a672e 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -6,6 +6,8 @@ @date:2023/9/25 14:24 @desc: """ +import datetime +import json import uuid from django.contrib.postgres.fields import ArrayField @@ -18,6 +20,12 @@ from setting.models.model_management import Model from users.models import User +class ApplicationTypeChoices(models.TextChoices): + """订单类型""" + SIMPLE = 'SIMPLE', '简易' + WORK_FLOW = 'WORK_FLOW', '工作流' + + def get_dataset_setting_dict(): return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding', 'no_references_setting': { @@ -42,6 +50,9 @@ class Application(AppModelMixin): model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict) problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) icon = 
models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico") + work_flow = models.JSONField(verbose_name="工作流数据", default=dict) + type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices, + default=ApplicationTypeChoices.SIMPLE, max_length=256) @staticmethod def get_default_model_prompt(): @@ -61,6 +72,15 @@ class Application(AppModelMixin): db_table = "application" +class WorkFlowVersion(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + application = models.ForeignKey(Application, on_delete=models.CASCADE) + work_flow = models.JSONField(verbose_name="工作流数据", default=dict) + + class Meta: + db_table = "application_work_flow_version" + + class ApplicationDatasetMapping(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") application = models.ForeignKey(Application, on_delete=models.CASCADE) @@ -88,6 +108,16 @@ class VoteChoices(models.TextChoices): TRAMPLE = 1, '反对' +class DateEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, uuid.UUID): + return str(obj) + if isinstance(obj, datetime.datetime): + return obj.strftime("%Y-%m-%d %H:%M:%S") + else: + return json.JSONEncoder.default(self, obj) + + class ChatRecord(AppModelMixin): """ 对话日志 详情 @@ -101,7 +131,7 @@ class ChatRecord(AppModelMixin): message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) const = models.IntegerField(verbose_name="总费用", default=0) - details = models.JSONField(verbose_name="对话详情", default=dict) + details = models.JSONField(verbose_name="对话详情", default=dict, encoder=DateEncoder) improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表", base_field=models.UUIDField(max_length=128, blank=True) , default=list) diff --git a/apps/application/serializers/application_serializers.py 
b/apps/application/serializers/application_serializers.py index 47ecb4446..40e447d48 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -7,6 +7,7 @@ @desc: """ import hashlib +import json import os import re import uuid @@ -22,7 +23,8 @@ from django.http import HttpResponse from django.template import Template, Context from rest_framework import serializers -from application.models import Application, ApplicationDatasetMapping +from application.flow.workflow_manage import Flow +from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from common.config.embedding_config import VectorStore, EmbeddingModel from common.constants.authentication_type import AuthenticationType @@ -105,6 +107,47 @@ class ModelSettingSerializer(serializers.Serializer): prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词")) +class ApplicationWorkflowSerializer(serializers.Serializer): + name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称")) + desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, + max_length=256, min_length=1, + error_messages=ErrMessage.char("应用描述")) + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + error_messages=ErrMessage.char("开场白")) + + @staticmethod + def to_application_model(user_id: str, application: Dict): + + default_workflow_json = get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'flow', 'default_workflow.json')) + default_workflow = json.loads(default_workflow_json) + for node in default_workflow.get('nodes'): + if node.get('id') == 'base-node': + node.get('properties')['node_data'] = {"desc": application.get('desc'), + "name": 
application.get('name'), + "prologue": application.get('prologue')} + return Application(id=uuid.uuid1(), + name=application.get('name'), + desc=application.get('desc'), + prologue="", + dialogue_number=0, + user_id=user_id, model_id=None, + dataset_setting={}, + model_setting={}, + problem_optimization=False, + type=ApplicationTypeChoices.WORK_FLOW, + work_flow=default_workflow + ) + + +def get_base_node_work_flow(work_flow): + node_list = work_flow.get('nodes') + base_node_list = [node for node in node_list if node.get('id') == 'base-node'] + if len(base_node_list) > 0: + return base_node_list[-1] + return None + + class ApplicationSerializer(serializers.Serializer): name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称")) desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, @@ -123,6 +166,13 @@ class ApplicationSerializer(serializers.Serializer): model_setting = ModelSettingSerializer(required=True) # 问题补全 problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) + # 应用类型 + type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"), + validators=[ + validators.RegexValidator(regex=re.compile("^SIMPLE|WORK_FLOW$"), + message="应用类型只支持SIMPLE|WORK_FLOW", code=500) + ] + ) def is_valid(self, *, user_id=None, raise_exception=False): super().is_valid(raise_exception=True) @@ -281,6 +331,24 @@ class ApplicationSerializer(serializers.Serializer): @transaction.atomic def insert(self, application: Dict): + application_type = application.get('type') + if 'WORK_FLOW' == application_type: + return self.insert_workflow(application) + else: + return self.insert_simple(application) + + def insert_workflow(self, application: Dict): + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + ApplicationWorkflowSerializer(data=application).is_valid(raise_exception=True) + application_model = 
ApplicationWorkflowSerializer.to_application_model(user_id, application) + application_model.save() + # 插入认证信息 + ApplicationAccessToken(application_id=application_model.id, + access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save() + return ApplicationSerializerModel(application_model).data + + def insert_simple(self, application: Dict): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True) @@ -296,7 +364,7 @@ class ApplicationSerializer(serializers.Serializer): access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save() # 插入关联数据 QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list) - return True + return ApplicationSerializerModel(application_model).data @staticmethod def to_application_model(user_id: str, application: Dict): @@ -306,7 +374,9 @@ class ApplicationSerializer(serializers.Serializer): user_id=user_id, model_id=application.get('model_id'), dataset_setting=application.get('dataset_setting'), model_setting=application.get('model_setting'), - problem_optimization=application.get('problem_optimization') + problem_optimization=application.get('problem_optimization'), + type=ApplicationTypeChoices.SIMPLE, + work_flow={} ) @staticmethod @@ -420,7 +490,7 @@ class ApplicationSerializer(serializers.Serializer): class ApplicationModel(serializers.ModelSerializer): class Meta: model = Application - fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon'] + fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon', 'type'] class IconOperate(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) @@ -463,6 +533,27 @@ class ApplicationSerializer(serializers.Serializer): QuerySet(Application).filter(id=self.data.get('application_id')).delete() return True + def publish(self, instance, with_valid=True): + 
if with_valid: + self.is_valid() + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + work_flow = instance.get('work_flow') + if work_flow is None: + raise AppApiException(500, "work_flow是必填字段") + Flow.new_instance(work_flow).is_valid() + base_node = get_base_node_work_flow(work_flow) + if base_node is not None: + node_data = base_node.get('properties').get('node_data') + if node_data is not None: + application.name = node_data.get('name') + application.desc = node_data.get('desc') + application.prologue = node_data.get('prologue') + application.work_flow = work_flow + application.save() + work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application) + work_flow_version.save() + return True + def one(self, with_valid=True): if with_valid: self.is_valid() @@ -507,7 +598,7 @@ class ApplicationSerializer(serializers.Serializer): raise AppApiException(500, "模型不存在") update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', - 'api_key_is_active', 'icon'] + 'api_key_is_active', 'icon', 'work_flow'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'multiple_rounds_dialogue': diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index cc658f71d..41f19bc0d 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -7,6 +7,7 @@ @desc: """ import json +import uuid from typing import List from uuid import UUID @@ -22,7 +23,10 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera BaseGenerateHumanMessageStep from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from 
application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep -from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping +from application.flow.i_step_node import WorkFlowPostHandler +from application.flow.workflow_manage import WorkflowManage, Flow +from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices, \ + WorkFlowVersion from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed @@ -39,10 +43,11 @@ chat_cache = caches['model_cache'] class ChatInfo: def __init__(self, chat_id: str, - chat_model: BaseChatModel, + chat_model: BaseChatModel | None, dataset_id_list: List[str], exclude_document_id_list: list[str], - application: Application): + application: Application, + work_flow_version: WorkFlowVersion = None): """ :param chat_id: 对话id :param chat_model: 对话模型 @@ -56,6 +61,7 @@ class ChatInfo: self.dataset_id_list = dataset_id_list self.exclude_document_id_list = exclude_document_id_list self.chat_record_list: List[ChatRecord] = [] + self.work_flow_version = work_flow_version def to_base_pipeline_manage_params(self): dataset_setting = self.application.dataset_setting @@ -146,8 +152,10 @@ class ChatMessageSerializer(serializers.Serializer): client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) - def is_valid(self, *, raise_exception=False): - super().is_valid(raise_exception=True) + def is_valid_application_workflow(self, *, raise_exception=False): + self.is_valid_intraday_access_num() + + def is_valid_intraday_access_num(self): if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: 
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first() if access_client is None: @@ -161,12 +169,9 @@ class ChatMessageSerializer(serializers.Serializer): application_id=self.data.get('application_id')).first() if application_access_token.access_num <= access_client.intraday_access_num: raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量") - chat_id = self.data.get('chat_id') - chat_info: ChatInfo = chat_cache.get(chat_id) - if chat_info is None: - chat_info = self.re_open_chat(chat_id) - chat_cache.set(chat_id, - chat_info, timeout=60 * 30) + + def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False): + self.is_valid_intraday_access_num() model = chat_info.application.model if model is None: return chat_info @@ -179,8 +184,7 @@ class ChatMessageSerializer(serializers.Serializer): raise AppApiException(500, "模型正在下载中,请稍后再发起对话") return chat_info - def chat(self): - chat_info = self.is_valid(raise_exception=True) + def chat_simple(self, chat_info: ChatInfo): message = self.data.get('message') re_chat = self.data.get('re_chat') stream = self.data.get('stream') @@ -211,14 +215,54 @@ class ChatMessageSerializer(serializers.Serializer): pipeline_message.run(params) return pipeline_message.context['chat_result'] - @staticmethod - def re_open_chat(chat_id: str): + def chat_work_flow(self, chat_info: ChatInfo): + message = self.data.get('message') + re_chat = self.data.get('re_chat') + stream = self.data.get('stream') + client_id = self.data.get('client_id') + client_type = self.data.get('client_type') + work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), + {'history_chat_record': chat_info.chat_record_list, 'question': message, + 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), + 'stream': stream, + 're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type)) + r = work_flow_manage.run() + return r + + def chat(self): + 
super().is_valid(raise_exception=True) + chat_info = self.get_chat_info() + if chat_info.application.type == ApplicationTypeChoices.SIMPLE: + self.is_valid_application_simple(raise_exception=True, chat_info=chat_info), + return self.chat_simple(chat_info) + else: + self.is_valid_application_workflow(raise_exception=True) + return self.chat_work_flow(chat_info) + + def get_chat_info(self): + self.is_valid(raise_exception=True) + chat_id = self.data.get('chat_id') + chat_info: ChatInfo = chat_cache.get(chat_id) + if chat_info is None: + chat_info: ChatInfo = self.re_open_chat(chat_id) + chat_cache.set(chat_id, + chat_info, timeout=60 * 30) + return chat_info + + def re_open_chat(self, chat_id: str): chat = QuerySet(Chat).filter(id=chat_id).first() if chat is None: raise AppApiException(500, "会话不存在") application = QuerySet(Application).filter(id=chat.application_id).first() if application is None: raise AppApiException(500, "应用不存在") + if application.type == ApplicationTypeChoices.SIMPLE: + return self.re_open_chat_simple(chat_id, application) + else: + return self.re_open_chat_work_flow(chat_id, application) + + @staticmethod + def re_open_chat_simple(chat_id, application): model = QuerySet(Model).filter(id=application.model_id).first() chat_model = None if model is not None: @@ -238,3 +282,11 @@ class ChatMessageSerializer(serializers.Serializer): dataset_id__in=dataset_id_list, is_active=False)] return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application) + + @staticmethod + def re_open_chat_work_flow(chat_id, application): + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by( + '-create_time')[0:1].first() + if work_flow_version is None: + raise AppApiException(500, "应用未发布,请发布后再使用") + return ChatInfo(chat_id, None, [], [], application, work_flow_version) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index a0d0b7690..402eff691 
100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -22,7 +22,9 @@ from django.db.models import QuerySet, Q from django.http import HttpResponse from rest_framework import serializers -from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord +from application.flow.workflow_manage import Flow +from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord, WorkFlowVersion, \ + ApplicationTypeChoices from application.models.api_key_model import ApplicationAccessToken from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ ModelSettingSerializer @@ -45,6 +47,11 @@ from smartdoc.conf import PROJECT_DIR chat_cache = caches['model_cache'] +class WorkFlowSerializers(serializers.Serializer): + nodes = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("节点")) + edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线")) + + class ChatSerializers(serializers.Serializer): class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) @@ -207,6 +214,27 @@ class ChatSerializers(serializers.Serializer): self.is_valid(raise_exception=True) application_id = self.data.get('application_id') application = QuerySet(Application).get(id=application_id) + if application.type == ApplicationTypeChoices.SIMPLE: + return self.open_simple(application) + else: + return self.open_work_flow(application) + + def open_work_flow(self, application): + self.is_valid(raise_exception=True) + application_id = self.data.get('application_id') + chat_id = str(uuid.uuid1()) + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by( + '-create_time')[0:1].first() + if work_flow_version is None: + raise AppApiException(500, 
"应用未发布,请发布后再使用") + chat_cache.set(chat_id, + ChatInfo(chat_id, None, [], + [], + application, work_flow_version), timeout=60 * 30) + return chat_id + + def open_simple(self, application): + application_id = self.data.get('application_id') model = QuerySet(Model).filter(id=application.model_id).first() dataset_id_list = [str(row.dataset_id) for row in QuerySet(ApplicationDatasetMapping).filter( @@ -229,6 +257,27 @@ class ChatSerializers(serializers.Serializer): application), timeout=60 * 30) return chat_id + class OpenWorkFlowChat(serializers.Serializer): + work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流")) + + def open(self): + self.is_valid(raise_exception=True) + work_flow = self.data.get('work_flow') + Flow.new_instance(work_flow).is_valid() + chat_id = str(uuid.uuid1()) + application = Application(id=None, dialogue_number=3, model=None, + dataset_setting={}, + model_setting={}, + problem_optimization=None, + type=ApplicationTypeChoices.WORK_FLOW + ) + work_flow_version = WorkFlowVersion(work_flow=work_flow) + chat_cache.set(chat_id, + ChatInfo(chat_id, None, [], + [], + application, work_flow_version), timeout=60 * 30) + return chat_id + class OpenTempChat(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -329,7 +378,7 @@ class ChatRecordSerializer(serializers.Serializer): chat_info: ChatInfo = chat_cache.get(chat_id) if chat_info is not None: chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if - chat_record.id == uuid.UUID(chat_record_id)] + str(chat_record.id) == str(chat_record_id)] if chat_record_list is not None and len(chat_record_list): return chat_record_list[-1] return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() @@ -377,7 +426,8 @@ class ChatRecordSerializer(serializers.Serializer): 'padding_problem_text': chat_record.details.get('problem_padding').get( 'padding_problem_text') if 'problem_padding' in 
chat_record.details else None, 'dataset_list': dataset_list, - 'paragraph_list': paragraph_list + 'paragraph_list': paragraph_list, + 'execution_details': [chat_record.details[key] for key in chat_record.details] } def page(self, current_page: int, page_size: int, with_valid=True): diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 4bacc5831..6e46931a6 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -161,7 +161,25 @@ class ApplicationApi(ApiMixin): 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", description="是否开启问题优化", default=True), 'icon': openapi.Schema(type=openapi.TYPE_STRING, title="icon", - description="icon", default="/ui/favicon.ico") + description="icon", default="/ui/favicon.ico"), + 'work_flow': ApplicationApi.WorkFlow.get_request_body_api() + + } + ) + + class WorkFlow(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=[''], + properties={ + 'nodes': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT), + title="节点列表", description="节点列表", + default=[]), + 'edges': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT), + title='连线列表', description="连线列表", + default={}), } ) @@ -219,6 +237,17 @@ class ApplicationApi(ApiMixin): } ) + class Publish(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=[], + properties={ + 'work_flow': ApplicationApi.WorkFlow.get_request_body_api() + } + ) + class Create(ApiMixin): @staticmethod def get_request_body_api(): @@ -239,7 +268,9 @@ class ApplicationApi(ApiMixin): 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(), 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(), 'problem_optimization': 
openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", - description="是否开启问题优化", default=True) + description="是否开启问题优化", default=True), + 'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型", + description="应用类型 简易:SIMPLE|工作流:WORK_FLOW") } ) diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 9c56cd21e..2ff8f8ac5 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -82,6 +82,17 @@ class ChatApi(ApiMixin): ] + class OpenWorkFlowTemp(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=[], + properties={ + 'work_flow': ApplicationApi.WorkFlow.get_request_body_api() + } + ) + class OpenTempChat(ApiMixin): @staticmethod def get_request_body_api(): diff --git a/apps/application/urls.py b/apps/application/urls.py index 4fcbbbf0c..335205d37 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -8,6 +8,7 @@ urlpatterns = [ path('application/profile', views.Application.Profile.as_view(), name='application/profile'), path('application/embed', views.Application.Embed.as_view()), path('application/authentication', views.Application.Authentication.as_view()), + path('application//publish', views.Application.Publish.as_view()), path('application//edit_icon', views.Application.EditIcon.as_view()), path('application//statistics/customer_count', views.ApplicationStatistics.CustomerCount.as_view()), @@ -30,6 +31,7 @@ urlpatterns = [ path('application//', views.Application.Page.as_view(), name='application_page'), path('application//chat/open', views.ChatView.Open.as_view(), name='application/open'), path("application/chat/open", views.ChatView.OpenTemp.as_view()), + path("application/chat_workflow/open", views.ChatView.OpenWorkFlowTemp.as_view()), path("application//chat/client//", views.ChatView.ClientChatHistoryPage.as_view()), path("application//chat/client/", diff --git 
a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 3ebed0899..16663d32d 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -7,6 +7,7 @@ @desc: """ +from django.core import cache from django.http import HttpResponse from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action @@ -27,6 +28,8 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers +chat_cache = cache.caches['model_cache'] + class ApplicationStatistics(APIView): class CustomerCount(APIView): @@ -332,8 +335,7 @@ class Application(APIView): tags=['应用']) @has_permissions(PermissionConstants.APPLICATION_CREATE, compare=CompareConstants.AND) def post(self, request: Request): - ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data) - return result.success(True) + return result.success(ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data)) @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取应用列表", @@ -370,6 +372,26 @@ class Application(APIView): 'search_mode': request.query_params.get('search_mode')}).hit_test( )) + class Publish(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="发布应用", + operation_id="发布应用", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + request_body=ApplicationApi.Publish.get_request_body_api(), + responses=result.get_default_response(), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, 
application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).publish(request.data)) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index b7f5968f2..577c08838 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -64,6 +64,18 @@ class ChatView(APIView): return result.success(ChatSerializers.OpenChat( data={'user_id': request.user.id, 'application_id': application_id}).open()) + class OpenWorkFlowTemp(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="获取工作流临时会话id", + operation_id="获取工作流临时会话id", + request_body=ChatApi.OpenWorkFlowTemp.get_request_body_api(), + tags=["应用/会话"]) + def post(self, request: Request): + return result.success(ChatSerializers.OpenWorkFlowChat( + data={**request.data, 'user_id': request.user.id}).open()) + class OpenTemp(APIView): authentication_classes = [TokenAuth] diff --git a/apps/common/db/compiler.py b/apps/common/db/compiler.py index 9a65f93e1..69640c8a0 100644 --- a/apps/common/db/compiler.py +++ b/apps/common/db/compiler.py @@ -7,9 +7,10 @@ @desc: """ -from django.core.exceptions import EmptyResultSet +from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError from django.db.models.sql.compiler import SQLCompiler +from django.db.transaction import TransactionManagementError class AppSQLCompiler(SQLCompiler): @@ -19,15 +20,16 @@ class AppSQLCompiler(SQLCompiler): field_replace_dict = {} self.field_replace_dict = field_replace_dict - def get_query_str(self, with_limits=True, with_table_name=False): + def get_query_str(self, with_limits=True, with_table_name=False, with_col_aliases=False): refcounts_before = self.query.alias_refcount.copy() try: - extra_select, order_by, 
group_by = self.pre_sql_setup() + combinator = self.query.combinator + extra_select, order_by, group_by = self.pre_sql_setup( + with_col_aliases=with_col_aliases or bool(combinator), + ) for_update_part = None # Is a LIMIT/OFFSET clause needed? - with_limit_offset = with_limits and ( - self.query.high_mark is not None or self.query.low_mark - ) + with_limit_offset = with_limits and self.query.is_sliced combinator = self.query.combinator features = self.connection.features if combinator: @@ -40,8 +42,14 @@ class AppSQLCompiler(SQLCompiler): result, params = self.get_combinator_sql( combinator, self.query.combinator_all ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None else: distinct_fields, distinct_params = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' + # (see docstring of get_from_clause() for details). + from_, f_params = self.get_from_clause() try: where, w_params = ( self.compile(self.where) if self.where is not None else ("", []) @@ -51,11 +59,92 @@ class AppSQLCompiler(SQLCompiler): raise # Use a predicate that's always False. 
where, w_params = "0 = 1", [] - having, h_params = ( - self.compile(self.having) if self.having is not None else ("", []) - ) + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] result = [] params = [] + + if self.query.distinct: + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params + + out_cols = [] + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = "%s AS %s" % ( + s_sql, + self.connection.ops.quote_name(alias), + ) + params.extend(s_params) + out_cols.append(s_sql) + + params.extend(f_params) + + if self.query.select_for_update and features.has_select_for_update: + if ( + self.connection.get_autocommit() + # Don't raise an exception when database doesn't + # support transactions, as it's a noop. + and features.supports_transactions + ): + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." + ) + + if ( + with_limit_offset + and not features.supports_select_for_update_with_limit + ): + raise NotSupportedError( + "LIMIT/OFFSET is not supported with " + "select_for_update on this database backend." + ) + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + of = self.query.select_for_update_of + no_key = self.query.select_for_no_key_update + # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the + # backend doesn't support it, raise NotSupportedError to + # prevent a possible deadlock. + if nowait and not features.has_select_for_update_nowait: + raise NotSupportedError( + "NOWAIT is not supported on this database backend." 
+ ) + elif skip_locked and not features.has_select_for_update_skip_locked: + raise NotSupportedError( + "SKIP LOCKED is not supported on this database backend." + ) + elif of and not features.has_select_for_update_of: + raise NotSupportedError( + "FOR UPDATE OF is not supported on this database backend." + ) + elif no_key and not features.has_select_for_no_key_update: + raise NotSupportedError( + "FOR NO KEY UPDATE is not supported on this " + "database backend." + ) + for_update_part = self.connection.ops.for_update_sql( + nowait=nowait, + skip_locked=skip_locked, + of=self.get_select_for_update_of_arguments(), + no_key=no_key, + ) + + if for_update_part and features.for_update_after_from: + result.append(for_update_part) + if where: result.append("WHERE %s" % where) params.extend(w_params) @@ -91,7 +180,11 @@ class AppSQLCompiler(SQLCompiler): for _, (o_sql, o_params, _) in order_by: ordering.append(o_sql) params.extend(o_params) - result.append("ORDER BY %s" % ", ".join(ordering)) + order_by_sql = "ORDER BY %s" % ", ".join(ordering) + if combinator and features.requires_compound_order_by_subquery: + result = ["SELECT * FROM (", *result, ")", order_by_sql] + else: + result.append(order_by_sql) if with_limit_offset: result.append( @@ -102,6 +195,7 @@ class AppSQLCompiler(SQLCompiler): if for_update_part and not features.for_update_after_from: result.append(for_update_part) + from_, f_params = self.get_from_clause() sql = " ".join(result) if not with_table_name: diff --git a/apps/common/response/result.py b/apps/common/response/result.py index d1cf6a3ad..bb2ba0faf 100644 --- a/apps/common/response/result.py +++ b/apps/common/response/result.py @@ -15,6 +15,7 @@ class Page(dict): class Result(JsonResponse): + charset = 'utf-8' """ 接口统一返回对象 """ diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py index b07e8a01b..c8892e055 100644 --- 
a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py @@ -30,59 +30,3 @@ class QianfanChatModel(QianfanChatEndpoint): def get_num_tokens(self, text: str) -> int: tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) - - def stream( - self, - input: LanguageModelInput, - config: Optional[RunnableConfig] = None, - *, - stop: Optional[List[str]] = None, - **kwargs: Any, - ) -> Iterator[BaseMessageChunk]: - if len(input) % 2 == 0: - input = [HumanMessage(content='padding'), *input] - input = [ - HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content) - for index in range(0, len(input))] - if type(self)._stream == BaseChatModel._stream: - # model doesn't implement streaming, so use default implementation - yield cast( - BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) - ) - else: - config = config or {} - messages = self._convert_input(input).to_messages() - params = self._get_invocation_params(stop=stop, **kwargs) - options = {"stop": stop, **kwargs} - callback_manager = CallbackManager.configure( - config.get("callbacks"), - self.callbacks, - self.verbose, - config.get("tags"), - self.tags, - config.get("metadata"), - self.metadata, - ) - (run_manager,) = callback_manager.on_chat_model_start( - dumpd(self), - [messages], - invocation_params=params, - options=options, - name=config.get("run_name"), - ) - try: - generation: Optional[ChatGenerationChunk] = None - for chunk in self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ): - yield chunk.message - if generation is None: - generation = chunk - assert generation is not None - except BaseException as e: - run_manager.on_llm_error(e) - raise e - else: - run_manager.on_llm_end( - LLMResult(generations=[[generation]]), - ) diff --git a/package-lock.json b/package-lock.json new file 
mode 100644 index 000000000..d70a5c3b4 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "MaxKB", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/pyproject.toml b/pyproject.toml index fffd51cff..826ec71ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,8 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.11" -django = "4.1.13" -djangorestframework = "3.14.0" +django = "4.2.13" +djangorestframework = "^3.15.2" drf-yasg = "1.21.7" django-filter = "23.2" langchain = "^0.2.3" diff --git a/ui/index.html b/ui/index.html index 4101498d3..09bec9ae4 100644 --- a/ui/index.html +++ b/ui/index.html @@ -5,7 +5,7 @@ diff --git a/ui/package.json b/ui/package.json index 4bc39ccc8..fb9d937d0 100644 --- a/ui/package.json +++ b/ui/package.json @@ -14,26 +14,20 @@ }, "dependencies": { "@ctrl/tinycolor": "^4.1.0", + "@logicflow/core": "^1.2.27", + "@logicflow/extension": "^1.2.27", "@vueuse/core": "^10.9.0", "axios": "^0.28.0", "cropperjs": "^1.6.2", "echarts": "^5.5.0", "element-plus": "^2.5.6", "file-saver": "^2.0.5", + "highlight.js": "^11.9.0", "install": "^0.13.0", "katex": "^0.16.10", "lodash": "^4.17.21", - "markdown-it": "^13.0.2", - "markdown-it-abbr": "^1.0.4", - "markdown-it-anchor": "^8.6.7", - "markdown-it-footnote": "^3.0.3", - "markdown-it-highlightjs": "^4.0.1", - "markdown-it-sub": "^1.0.0", - "markdown-it-sup": "^1.0.0", - "markdown-it-task-lists": "^2.1.1", - "markdown-it-toc-done-right": "^4.2.0", "marked": "^12.0.2", - "md-editor-v3": "4.12.1", + "md-editor-v3": "^4.16.7", "medium-zoom": "^1.1.0", "mermaid": "^10.9.0", "mitt": "^3.0.0", @@ -53,8 +47,6 @@ "@tsconfig/node18": "^18.2.0", "@types/file-saver": "^2.0.7", "@types/jsdom": "^21.1.1", - "@types/markdown-it": "^13.0.7", - "@types/markdown-it-highlightjs": "^3.3.4", "@types/node": "^18.17.5", "@types/nprogress": "^0.2.0", "@vitejs/plugin-vue": "^4.3.1", diff --git a/ui/src/App.vue b/ui/src/App.vue index d59d59725..3ed4b37bf 100644 --- 
a/ui/src/App.vue +++ b/ui/src/App.vue @@ -4,6 +4,4 @@ - + diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index c83ff3c8e..f37f67937 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -45,8 +45,7 @@ const postApplication: ( /** * 修改应用 - * @param 参数 - + * @param 参数 */ const putApplication: ( application_id: String, @@ -150,6 +149,16 @@ const postChatOpen: (data: ApplicationFormType) => Promise> = (data) return post(`${prefix}/chat/open`, data) } +/** + * 获得工作流临时回话Id + * @param 参数 + +} + */ +const postWorkflowChatOpen: (data: ApplicationFormType) => Promise> = (data) => { + return post(`${prefix}/chat_workflow/open`, data) +} + /** * 正式回话Id * @param 参数 @@ -228,6 +237,18 @@ const getApplicationModel: ( return get(`${prefix}/${application_id}/model`, loading) } +/** + * 发布应用 + * @param 参数 + */ +const putPublishApplication: ( + application_id: String, + data: ApplicationFormType, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/publish`, data, undefined, loading) +} + export default { getAllAppilcation, getApplication, @@ -245,5 +266,7 @@ export default { getProfile, putChatVote, getApplicationHitTest, - getApplicationModel + getApplicationModel, + putPublishApplication, + postWorkflowChatOpen } diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 6da9dd84a..4dc9047f8 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -11,6 +11,8 @@ interface ApplicationFormType { model_setting?: any problem_optimization?: boolean icon?: string | undefined + type?: string + work_flow?: any } interface chatType { id: string diff --git a/ui/src/assets/MaxKB-logo.svg b/ui/src/assets/MaxKB-logo.svg new file mode 100644 index 000000000..326c901d0 --- /dev/null +++ b/ui/src/assets/MaxKB-logo.svg @@ -0,0 +1 @@ +MaxKB \ No newline at end of file diff --git a/ui/src/assets/icon_condition.svg b/ui/src/assets/icon_condition.svg new 
file mode 100644 index 000000000..2bc80a212 --- /dev/null +++ b/ui/src/assets/icon_condition.svg @@ -0,0 +1,3 @@ + + + diff --git a/ui/src/assets/icon_globe_color.svg b/ui/src/assets/icon_globe_color.svg new file mode 100644 index 000000000..7ede591d5 --- /dev/null +++ b/ui/src/assets/icon_globe_color.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ui/src/assets/icon_hi.svg b/ui/src/assets/icon_hi.svg new file mode 100644 index 000000000..84bb36ac2 --- /dev/null +++ b/ui/src/assets/icon_hi.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/ui/src/assets/icon_reply.svg b/ui/src/assets/icon_reply.svg new file mode 100644 index 000000000..430fc7fc1 --- /dev/null +++ b/ui/src/assets/icon_reply.svg @@ -0,0 +1,3 @@ + + + diff --git a/ui/src/assets/icon_setting.svg b/ui/src/assets/icon_setting.svg new file mode 100644 index 000000000..afa97360f --- /dev/null +++ b/ui/src/assets/icon_setting.svg @@ -0,0 +1,3 @@ + + + diff --git a/ui/src/assets/icon_start.svg b/ui/src/assets/icon_start.svg new file mode 100644 index 000000000..0b8d73064 --- /dev/null +++ b/ui/src/assets/icon_start.svg @@ -0,0 +1,4 @@ + + + + diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue new file mode 100644 index 000000000..848691c05 --- /dev/null +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -0,0 +1,209 @@ + + + diff --git a/ui/src/components/ai-chat/KnowledgeSource.vue b/ui/src/components/ai-chat/KnowledgeSource.vue new file mode 100644 index 000000000..8b8db4a78 --- /dev/null +++ b/ui/src/components/ai-chat/KnowledgeSource.vue @@ -0,0 +1,79 @@ + + + diff --git a/ui/src/components/ai-chat/ParagraphSourceDialog.vue b/ui/src/components/ai-chat/ParagraphSourceDialog.vue index 5a5b9e10b..4034c7cac 100644 --- a/ui/src/components/ai-chat/ParagraphSourceDialog.vue +++ b/ui/src/components/ai-chat/ParagraphSourceDialog.vue @@ -19,50 +19,7 @@ @@ -75,10 +32,9 @@ import { ref, watch, onBeforeUnmount } from 'vue' import { 
cloneDeep } from 'lodash' import { arraySort } from '@/utils/utils' -import { MdPreview } from 'md-editor-v3' +import ParagraphCard from './component/ParagraphCard.vue' const emit = defineEmits(['refresh']) -const ParagraphDialogRef = ref() const dialogVisible = ref(false) const detail = ref({}) @@ -114,9 +70,6 @@ defineExpose({ open }) .paragraph-source-height { max-height: calc(100vh - 260px); } - .paragraph-source-card { - height: 260px; - } } @media only screen and (max-width: 768px) { .paragraph-source { @@ -124,9 +77,6 @@ defineExpose({ open }) .footer-content { display: block; } - .paragraph-source-card { - height: 285px; - } } } diff --git a/ui/src/components/ai-chat/component/ParagraphCard.vue b/ui/src/components/ai-chat/component/ParagraphCard.vue new file mode 100644 index 000000000..104e66242 --- /dev/null +++ b/ui/src/components/ai-chat/component/ParagraphCard.vue @@ -0,0 +1,58 @@ + + + diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 56be13c5d..5ad33cfb1 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -27,6 +27,8 @@ ref="editorRef" editorId="preview-only" :modelValue="item.str" + noIconfont + no-mermaid /> @@ -73,38 +75,9 @@ +
- 知识来源 -
- - {{ dataset.name }} - -
- -
- 引用分段:{{ item.paragraph_list?.length || 0 }} - - 消耗 tokens: {{ item?.message_tokens + item?.answer_tokens }} - - - 耗时: {{ item.run_time?.toFixed(2) }} s - -
+
@@ -168,8 +141,6 @@
- - diff --git a/ui/src/components/markdown-renderer/index.vue b/ui/src/components/markdown-renderer/index.vue deleted file mode 100644 index 586722d14..000000000 --- a/ui/src/components/markdown-renderer/index.vue +++ /dev/null @@ -1,66 +0,0 @@ - - - - diff --git a/ui/src/components/markdown/MdEditor.vue b/ui/src/components/markdown/MdEditor.vue new file mode 100644 index 000000000..ec1c3ae67 --- /dev/null +++ b/ui/src/components/markdown/MdEditor.vue @@ -0,0 +1,14 @@ + + + diff --git a/ui/src/components/markdown/MdPreview.vue b/ui/src/components/markdown/MdPreview.vue new file mode 100644 index 000000000..7efcf95cf --- /dev/null +++ b/ui/src/components/markdown/MdPreview.vue @@ -0,0 +1,8 @@ + + + diff --git a/ui/src/components/markdown-renderer/MdRenderer.vue b/ui/src/components/markdown/MdRenderer.vue similarity index 96% rename from ui/src/components/markdown-renderer/MdRenderer.vue rename to ui/src/components/markdown/MdRenderer.vue index 416cf7f2b..0053e3f87 100644 --- a/ui/src/components/markdown-renderer/MdRenderer.vue +++ b/ui/src/components/markdown/MdRenderer.vue @@ -1,5 +1,6 @@