feat: AI chat, question optimization, and direct reply nodes support a return-result toggle

shaohuzhang1 2024-08-01 18:47:02 +08:00
parent 7b48599e57
commit 209702cad2
9 changed files with 128 additions and 239 deletions

View File

@@ -10,6 +10,7 @@ import time
from abc import abstractmethod
from typing import Type, Dict, List
from django.core import cache
from django.db.models import QuerySet
from rest_framework import serializers
@@ -18,7 +19,6 @@ from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from django.core import cache
chat_cache = cache.caches['chat_cache']
@@ -27,6 +27,9 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
if step_variable is not None:
for key in step_variable:
node.context[key] = step_variable[key]
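# A node flagged as a result node yields its answer to the caller
# and accumulates it on the workflow-level answer buffer.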
if workflow.is_result() and 'answer' in step_variable:
yield step_variable['answer']
workflow.answer += step_variable['answer']
if global_variable is not None:
for key in global_variable:
workflow.context[key] = global_variable[key]
@@ -70,18 +73,14 @@ class WorkFlowPostHandler:
class NodeResult:
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
def __init__(self, node_variable: Dict, workflow_variable: Dict,
_write_context=write_context):
self._write_context = _write_context
self.node_variable = node_variable
self.workflow_variable = workflow_variable
self._to_response = _to_response
def write_context(self, node, workflow):
self._write_context(self.node_variable, self.workflow_variable, node, workflow)
def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
post_handler)
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
def is_assertion_result(self):
return 'branch_id' in self.node_variable
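With `_to_response` removed, a step node now shapes the response solely through its `_write_context` callback, which may be a generator that the workflow manager drains (`run_block` consumes it silently, while `stream_event` forwards each yielded chunk to the client). A minimal sketch of a custom callback under that contract; the name `my_write_context` and the `{'answer': ...}` payload are illustrative, not part of this commit:

def my_write_context(step_variable, global_variable, node, workflow):
    if step_variable is not None:
        for key in step_variable:
            node.context[key] = step_variable[key]
        # Surface the answer only when this node is flagged to return its result.
        if workflow.is_result() and 'answer' in step_variable:
            yield step_variable['answer']
            workflow.answer += step_variable['answer']

result = NodeResult({'answer': 'hello'}, {}, _write_context=my_write_context)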

View File

@@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
# Number of dialogue rounds
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
class IChatNode(INode):
type = 'ai-chat-node'

View File

@@ -6,19 +6,30 @@
@date:2024/6/4 14:30
@desc:
"""
import time
from functools import reduce
from typing import List, Dict
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models_provider.tools import get_model_instance_by_model_user_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
if workflow.is_result():
workflow.answer += answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
Write context data (streaming)
@@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
answer = ''
for chunk in response:
answer += chunk.content
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
yield answer
_write_context(node_variable, workflow_variable, node, workflow, answer)
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
@param workflow: workflow manager
"""
response = node_variable.get('result')
chat_model = node_variable.get('chat_model')
answer = response.content
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer, status=200):
chat_model = node_variable.get('chat_model')
if status == 200:
answer_tokens = chat_model.get_num_tokens(answer)
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
else:
answer_tokens = 0
message_tokens = 0
node.err_message = answer
node.status = status
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['run_time'] = time.time() - node.context['start_time']
return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert streaming data into a streaming response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler, executed after the result is output
@return: streaming response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert the result into a response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler
@return: response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
_write_context(node_variable, workflow_variable, node, workflow, answer)
class BaseChatNode(IChatNode):
@@ -132,13 +73,12 @@ class BaseChatNode(IChatNode):
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
_write_context=write_context_stream)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context, _to_response=to_response)
_write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number):

View File

@@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("直接回答内容"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
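For illustration, a params payload exercising the new flag might look like this sketch (the 'referencing' reply_type is confirmed by BaseReplyNode.execute below; the referenced-field format is hypothetical):

params = {
    'reply_type': 'referencing',
    'fields': ['node_id.answer'],  # hypothetical referenced field
    'is_result': True,             # new: return this node's output to the caller
}
ReplyNodeParamsSerializer(data=params).is_valid(raise_exception=True)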

View File

@@ -6,69 +6,19 @@
@date:2024/6/11 17:25
@desc:
"""
from typing import List, Dict
from typing import List
from langchain_core.messages import AIMessage, AIMessageChunk
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.i_step_node import NodeResult
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer, status=200):
node.context['answer'] = answer
return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert streaming data into a streaming response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler, executed after the result is output
@return: streaming response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert the result into a response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler
@return: response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
class BaseReplyNode(IReplyNode):
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
if reply_type == 'referencing':
result = self.get_reference_content(fields)
else:
result = self.generate_reply_content(content)
if stream:
return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
_to_response=to_stream_response)
else:
return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
return NodeResult({'answer': result}, {})
def generate_reply_content(self, prompt):
return self.workflow_manage.generate_prompt(prompt)

View File

@@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
# Number of dialogue rounds
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
class IQuestionNode(INode):
type = 'question-node'

View File

@@ -6,19 +6,30 @@
@date:2024/6/4 14:30
@desc:
"""
import time
from functools import reduce
from typing import List, Dict
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from setting.models_provider.tools import get_model_instance_by_model_user_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
if workflow.is_result():
workflow.answer += answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""
Write context data (streaming)
@@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
answer = ''
for chunk in response:
answer += chunk.content
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
yield answer
_write_context(node_variable, workflow_variable, node, workflow, answer)
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
@param workflow: workflow manager
"""
response = node_variable.get('result')
chat_model = node_variable.get('chat_model')
answer = response.content
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
def get_to_response_write_context(node_variable: Dict, node: INode):
def _write_context(answer, status=200):
chat_model = node_variable.get('chat_model')
if status == 200:
answer_tokens = chat_model.get_num_tokens(answer)
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
else:
answer_tokens = 0
message_tokens = 0
node.err_message = answer
node.status = status
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['run_time'] = time.time() - node.context['start_time']
return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert streaming data into a streaming response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler, executed after the result is output
@return: streaming response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
post_handler):
"""
Convert the result into a response
@param chat_id: chat id
@param chat_record_id: chat record id
@param node_variable: node data
@param workflow_variable: workflow data
@param node: node
@param workflow: workflow manager
@param post_handler: post handler
@return: response
"""
response = node_variable.get('result')
_write_context = get_to_response_write_context(node_variable, node)
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
_write_context(node_variable, workflow_variable, node, workflow, answer)
class BaseQuestionNode(IQuestionNode):
@@ -131,15 +72,13 @@ class BaseQuestionNode(IQuestionNode):
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'get_to_response_write_context': get_to_response_write_context,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
_write_context=write_context_stream)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context, _to_response=to_response)
_write_context=write_context)
@staticmethod
def get_history_message(history_chat_record, dialogue_number):

View File

@@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_
post_handler.handler(chat_id, chat_record_id, answer, workflow)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': answer, 'is_end': True})
def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
post_handler: WorkFlowPostHandler):
answer = response.content
post_handler.handler(chat_id, chat_record_id, answer, workflow)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': answer, 'is_end': True})
def to_stream_response_simple(stream_event):
r = StreamingHttpResponse(
streaming_content=stream_event,
content_type='text/event-stream;charset=utf-8',
charset='utf-8')
r['Cache-Control'] = 'no-cache'
return r
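A minimal usage sketch for the new streaming helper; the `demo_events` generator is illustrative only and stands in for WorkflowManage.stream_event:

def demo_events():
    yield 'data: {"content": "Hel", "is_end": false}\n\n'
    yield 'data: {"content": "lo", "is_end": true}\n\n'

response = to_stream_response_simple(demo_events())
# -> StreamingHttpResponse with content type 'text/event-stream;charset=utf-8'
#    and 'Cache-Control: no-cache', ready to return from a Django view.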

View File

@@ -6,10 +6,11 @@
@date:2024/1/9 17:40
@desc:
"""
import json
from functools import reduce
from typing import List, Dict
from langchain_core.messages import AIMessageChunk, AIMessage
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate
from application.flow import tools
@@ -63,7 +64,6 @@ class Flow:
def get_search_node(self):
return [node for node in self.nodes if node.type == 'search-dataset-node']
def is_valid(self):
"""
Validate the workflow data
@@ -140,33 +140,71 @@ class WorkflowManage:
self.work_flow_post_handler = work_flow_post_handler
self.current_node = None
self.current_result = None
self.answer = ""
def run(self):
"""
Run the workflow
"""
if self.params.get('stream'):
return self.run_stream()
return self.run_block()
def run_block(self):
try:
while self.has_next_node(self.current_result):
self.current_node = self.get_next_node()
self.node_context.append(self.current_node)
self.current_result = self.current_node.run()
if self.has_next_node(self.current_result):
self.current_result.write_context(self.current_node, self)
else:
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
self.current_node, self,
self.work_flow_post_handler)
return r
result = self.current_result.write_context(self.current_node, self)
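# write_context may be a generator; drain it so its side effects
# (filling node.context and appending to self.answer) still run in blocking mode.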
if result is not None:
list(result)
if not self.has_next_node(self.current_result):
return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(self.answer), self,
self.work_flow_post_handler)
except Exception as e:
if self.params.get('stream'):
return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
iter([AIMessageChunk(str(e))]), self,
self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
else:
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
def run_stream(self):
return tools.to_stream_response_simple(self.stream_event())
def stream_event(self):
try:
while self.has_next_node(self.current_result):
self.current_node = self.get_next_node()
self.node_context.append(self.current_node)
self.current_result = self.current_node.run()
result = self.current_result.write_context(self.current_node, self)
if result is not None:
for r in result:
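# Only nodes flagged is_result (the terminal node by default) stream their chunks to the client.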
if self.is_result():
yield self.get_chunk_content(r)
if not self.has_next_node(self.current_result):
yield self.get_chunk_content('', True)
break
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer,
self)
except Exception as e:
self.current_node.get_write_error_context(e)
self.answer += str(e)
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer,
self)
yield self.get_chunk_content(str(e), True)
def is_result(self):
"""
Determine whether the current node should return its result; falls back to
"is this the last node" when 'is_result' is unset, and False when the node has no params.
@return: bool
"""
return self.current_node.node_params.get('is_result', not self.has_next_node(
self.current_result)) if self.current_node.node_params is not None else False
def get_chunk_content(self, chunk, is_end=False):
return 'data: ' + json.dumps(
{'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"
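# For example (sketch): get_chunk_content('Hi') emits one SSE frame such as
# data: {"chat_id": "<chat_id>", "id": "<chat_record_id>", "operate": true, "content": "Hi", "is_end": false}\n\n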
def has_next_node(self, node_result: NodeResult | None):
"""