diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 9c2e1c544..251a8f437 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -10,6 +10,7 @@ import time from abc import abstractmethod from typing import Type, Dict, List +from django.core import cache from django.db.models import QuerySet from rest_framework import serializers @@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient from common.constants.authentication_type import AuthenticationType from common.field.common import InstanceField from common.util.field_message import ErrMessage -from django.core import cache chat_cache = cache.caches['chat_cache'] @@ -27,6 +27,9 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if step_variable is not None: for key in step_variable: node.context[key] = step_variable[key] + if workflow.is_result() and 'answer' in step_variable: + yield step_variable['answer'] + workflow.answer += step_variable['answer'] if global_variable is not None: for key in global_variable: workflow.context[key] = global_variable[key] @@ -70,18 +73,14 @@ class WorkFlowPostHandler: class NodeResult: - def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context): + def __init__(self, node_variable: Dict, workflow_variable: Dict, + _write_context=write_context): self._write_context = _write_context self.node_variable = node_variable self.workflow_variable = workflow_variable - self._to_response = _to_response def write_context(self, node, workflow): - self._write_context(self.node_variable, self.workflow_variable, node, workflow) - - def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler): - return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow, - post_handler) + return self._write_context(self.node_variable, self.workflow_variable, node, workflow) def is_assertion_result(self): return 'branch_id' in self.node_variable diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 3b941922f..78cbc462c 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer): # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + class IChatNode(INode): type = 'ai-chat-node' diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index d8d087b04..06dec696b 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -6,19 +6,30 @@ @date:2024/6/4 14:30 @desc: """ -import time from functools import reduce from typing import List, Dict from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage -from application.flow import tools from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from setting.models_provider.tools import get_model_instance_by_model_user_id +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + if workflow.is_result(): + workflow.answer += answer + + def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ 写入上下文数据 (流式) @@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer = '' for chunk in response: answer += chunk.content - chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - node.context['run_time'] = time.time() - node.context['start_time'] + yield answer + _write_context(node_variable, workflow_variable, node, workflow, answer) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param workflow: 工作流管理器 """ response = node_variable.get('result') - chat_model = node_variable.get('chat_model') answer = response.content - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - - -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - chat_model = node_variable.get('chat_model') - - if status == 200: - answer_tokens = chat_model.get_num_tokens(answer) - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - else: - answer_tokens = 0 - message_tokens = 0 - node.err_message = answer - node.status = status - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['run_time'] = time.time() - node.context['start_time'] - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + _write_context(node_variable, workflow_variable, node, workflow, answer) class BaseChatNode(IChatNode): @@ -132,13 +73,12 @@ class BaseChatNode(IChatNode): r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context_stream, - _to_response=to_stream_response) + _write_context=write_context_stream) else: r = chat_model.invoke(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context, _to_response=to_response) + _write_context=write_context) @staticmethod def get_history_message(history_chat_record, dialogue_number): diff --git a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py index 1d5256ac5..3c0f35875 100644 --- a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py @@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer): fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段")) content = serializers.CharField(required=False, allow_blank=True, allow_null=True, error_messages=ErrMessage.char("直接回答内容")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 717dce161..de79279d9 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -6,69 +6,19 @@ @date:2024/6/11 17:25 @desc: """ -from typing import List, Dict +from typing import List -from langchain_core.messages import AIMessage, AIMessageChunk - -from application.flow import tools -from application.flow.i_step_node import NodeResult, INode +from application.flow.i_step_node import NodeResult from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - node.context['answer'] = answer - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - class BaseReplyNode(IReplyNode): def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) else: result = self.generate_reply_content(content) - if stream: - return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {}, - _to_response=to_stream_response) - else: - return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response) + return NodeResult({'answer': result}, {}) def generate_reply_content(self, prompt): return self.workflow_manage.generate_prompt(prompt) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index ede120def..30790c7c6 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer): # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + class IQuestionNode(INode): type = 'question-node' diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index f5855361e..ae0d404d2 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -6,19 +6,30 @@ @date:2024/6/4 14:30 @desc: """ -import time from functools import reduce from typing import List, Dict from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage -from application.flow import tools from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.question_node.i_question_node import IQuestionNode from setting.models_provider.tools import get_model_instance_by_model_user_id +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + if workflow.is_result(): + workflow.answer += answer + + def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ 写入上下文数据 (流式) @@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer = '' for chunk in response: answer += chunk.content - chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - node.context['run_time'] = time.time() - node.context['start_time'] + yield answer + _write_context(node_variable, workflow_variable, node, workflow, answer) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param workflow: 工作流管理器 """ response = node_variable.get('result') - chat_model = node_variable.get('chat_model') answer = response.content - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - - -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - chat_model = node_variable.get('chat_model') - - if status == 200: - answer_tokens = chat_model.get_num_tokens(answer) - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - else: - answer_tokens = 0 - message_tokens = 0 - node.err_message = answer - node.status = status - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['run_time'] = time.time() - node.context['start_time'] - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + _write_context(node_variable, workflow_variable, node, workflow, answer) class BaseQuestionNode(IQuestionNode): @@ -131,15 +72,13 @@ class BaseQuestionNode(IQuestionNode): if stream: r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, - 'get_to_response_write_context': get_to_response_write_context, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context_stream, - _to_response=to_stream_response) + _write_context=write_context_stream) else: r = chat_model.invoke(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context, _to_response=to_response) + _write_context=write_context) @staticmethod def get_history_message(history_chat_record, dialogue_number): diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 839aae8da..b2bf6b1ad 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_ post_handler.handler(chat_id, chat_record_id, answer, workflow) return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, 'content': answer, 'is_end': True}) + + +def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow, + post_handler: WorkFlowPostHandler): + answer = response.content + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_stream_response_simple(stream_event): + r = StreamingHttpResponse( + streaming_content=stream_event, + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 68a2ba022..cfc0d6404 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -6,10 +6,11 @@ @date:2024/1/9 17:40 @desc: """ +import json from functools import reduce from typing import List, Dict -from langchain_core.messages import AIMessageChunk, AIMessage +from langchain_core.messages import AIMessage from langchain_core.prompts import PromptTemplate from application.flow import tools @@ -63,7 +64,6 @@ class Flow: def get_search_node(self): return [node for node in self.nodes if node.type == 'search-dataset-node'] - def is_valid(self): """ 校验工作流数据 @@ -140,33 +140,71 @@ class WorkflowManage: self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None + self.answer = "" def run(self): - """ - 运行工作流 - """ + if self.params.get('stream'): + return self.run_stream() + return self.run_block() + + def run_block(self): try: while self.has_next_node(self.current_result): self.current_node = self.get_next_node() self.node_context.append(self.current_node) self.current_result = self.current_node.run() - if self.has_next_node(self.current_result): - self.current_result.write_context(self.current_node, self) - else: - r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], - self.current_node, self, - self.work_flow_post_handler) - return r + result = self.current_result.write_context(self.current_node, self) + if result is not None: + list(result) + if not self.has_next_node(self.current_result): + return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(self.answer), self, + self.work_flow_post_handler) except Exception as e: - if self.params.get('stream'): - return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'], - iter([AIMessageChunk(str(e))]), self, - self.current_node.get_write_error_context(e), - self.work_flow_post_handler) - else: - return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], - AIMessage(str(e)), self, self.current_node.get_write_error_context(e), - self.work_flow_post_handler) + return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(str(e)), self, self.current_node.get_write_error_context(e), + self.work_flow_post_handler) + + def run_stream(self): + return tools.to_stream_response_simple(self.stream_event()) + + def stream_event(self): + try: + while self.has_next_node(self.current_result): + self.current_node = self.get_next_node() + self.node_context.append(self.current_node) + self.current_result = self.current_node.run() + result = self.current_result.write_context(self.current_node, self) + if result is not None: + for r in result: + if self.is_result(): + yield self.get_chunk_content(r) + if not self.has_next_node(self.current_result): + yield self.get_chunk_content('', True) + break + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + except Exception as e: + self.current_node.get_write_error_context(e) + self.answer += str(e) + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + yield self.get_chunk_content(str(e), True) + + def is_result(self): + """ + 判断是否是返回节点 + @return: + """ + return self.current_node.node_params.get('is_result', not self.has_next_node( + self.current_result)) if self.current_node.node_params is not None else False + + def get_chunk_content(self, chunk, is_end=False): + return 'data: ' + json.dumps( + {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True, + 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n" def has_next_node(self, node_result: NodeResult | None): """