feat: 对话支持全部返回

This commit is contained in:
shaohuzhang1 2024-02-21 18:10:18 +08:00
parent 7eb18fbf30
commit b7a406db56
6 changed files with 80 additions and 25 deletions

View File

@ -116,10 +116,10 @@ class IBaseChatPipelineStep:
:return: 执行结果
"""
start_time = time.time()
self.context['start_time'] = start_time
# 校验参数,
self.valid_args(manage)
self._run(manage)
self.context['start_time'] = start_time
self.context['run_time'] = time.time() - start_time
def _run(self, manage):

View File

@ -63,6 +63,8 @@ class IChatStep(IBaseChatPipelineStep):
post_response_handler = InstanceField(model_type=PostResponseHandler)
# 补全问题
padding_problem_text = serializers.CharField(required=False)
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -85,5 +87,5 @@ class IChatStep(IBaseChatPipelineStep):
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None, **kwargs):
padding_problem_text: str = None, stream: bool = True, **kwargs):
pass

View File

@ -16,11 +16,12 @@ from typing import List
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from common.response import result
def event_content(response,
@ -71,23 +72,16 @@ class BaseChatStep(IChatStep):
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None,
stream: bool = True,
**kwargs):
# 调用模型
if chat_model is None:
chat_result = iter(
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text)
else:
chat_result = chat_model.stream(message_list)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text)
def get_details(self, manage, **kwargs):
return {
@ -109,3 +103,58 @@ class BaseChatStep(IChatStep):
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def execute_stream(self, message_list: List[BaseMessage],
                   chat_id,
                   problem_text,
                   post_response_handler: PostResponseHandler,
                   chat_model: BaseChatModel = None,
                   paragraph_list=None,
                   manage: PiplineManage = None,
                   padding_problem_text: str = None):
    """Run the chat step in streaming mode and return a server-sent-events response.

    When no chat model is configured, the matched knowledge-base paragraphs are
    streamed back verbatim (one chunk per paragraph) instead of a model answer.
    """
    if chat_model is None:
        # No model available: fall back to streaming the matched paragraphs.
        fallback_chunks = [BaseMessageChunk(content=p.title + "\n" + p.content)
                           for p in paragraph_list]
        chat_result = iter(fallback_chunks)
    else:
        chat_result = chat_model.stream(message_list)
    record_id = uuid.uuid1()
    sse_body = event_content(chat_result, chat_id, record_id, paragraph_list,
                             post_response_handler, manage, self, chat_model,
                             message_list, problem_text, padding_problem_text)
    response = StreamingHttpResponse(streaming_content=sse_body,
                                     content_type='text/event-stream;charset=utf-8')
    # Disable intermediary caching so chunks reach the client immediately.
    response['Cache-Control'] = 'no-cache'
    return response
def execute_block(self, message_list: List[BaseMessage],
                  chat_id,
                  problem_text,
                  post_response_handler: PostResponseHandler,
                  chat_model: BaseChatModel = None,
                  paragraph_list=None,
                  manage: PiplineManage = None,
                  padding_problem_text: str = None):
    """Run the chat step in blocking mode and return the complete answer at once.

    :param message_list: conversation messages handed to the model
    :param chat_id: id of the current chat session
    :param problem_text: the user's original question
    :param post_response_handler: callback invoked with the finished answer
    :param chat_model: model to query; when None the matched paragraphs are
        returned verbatim as the answer
    :param paragraph_list: knowledge-base paragraphs matched for this question
    :param manage: pipeline manager whose context accumulates run time / tokens
    :param padding_problem_text: optimized (completed) question text, if any
    :return: JSON success payload with chat_record id and full answer content
    """
    if chat_model is None:
        # No model configured: answer directly with the matched paragraphs.
        chat_result = AIMessage(
            content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
        # Bug fix: the original called chat_model.get_num_tokens_* here even
        # when chat_model is None, raising AttributeError. No model means no
        # token accounting, so record zero.
        request_token = 0
        response_token = 0
    else:
        chat_result = chat_model(message_list)
        request_token = chat_model.get_num_tokens_from_messages(message_list)
        response_token = chat_model.get_num_tokens(chat_result.content)
    chat_record_id = uuid.uuid1()
    self.context['message_tokens'] = request_token
    self.context['answer_tokens'] = response_token
    current_time = time.time()
    self.context['answer_text'] = chat_result.content
    self.context['run_time'] = current_time - self.context['start_time']
    manage.context['run_time'] = current_time - manage.context['start_time']
    manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
    manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
    post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                  chat_result.content, manage, self, padding_problem_text)
    return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                           'content': chat_result.content, 'is_end': True})

View File

@ -72,15 +72,16 @@ class ChatInfo:
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization
'problem_optimization': self.application.problem_optimization,
'stream': True
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
                              exclude_paragraph_id_list, stream=True):
    """Build the full parameter dict for one pipeline run.

    Extends the base pipeline parameters with the current question, the
    response handler, the paragraph exclusion list and the stream flag.
    """
    merged = dict(self.to_base_pipeline_manage_params())
    merged.update({
        'problem_text': problem_text,
        'post_response_handler': post_response_handler,
        'exclude_paragraph_id_list': exclude_paragraph_id_list,
        'stream': stream,
    })
    return merged
def append_chat_record(self, chat_record: ChatRecord):
# 存入缓存中
@ -126,7 +127,7 @@ def get_post_handler(chat_info: ChatInfo):
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
def chat(self, message, re_chat: bool):
def chat(self, message, re_chat: bool, stream: bool):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
@ -152,7 +153,8 @@ class ChatMessageSerializer(serializers.Serializer):
chat_info.chat_record_list)])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
stream)
# 运行流水线作业
pipline_message.run(params)
return pipline_message.context['chat_result']

View File

@ -20,7 +20,8 @@ class ChatApi(ApiMixin):
required=['message'],
properties={
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否流式输出", default=True)
}
)

View File

@ -72,7 +72,8 @@ class ChatView(APIView):
)
def post(self, request: Request, chat_id: str):
    """Handle a chat message POST for the given chat session.

    Body fields:
      message: the user's question (required)
      re_chat: regenerate the previous answer (default False)
      stream:  stream the answer as SSE instead of one block (default True)
    """
    # dict.get with a default replaces the redundant `x.get(k) if k in x`
    # membership-check pattern (same behavior, single lookup).
    return ChatMessageSerializer(data={'chat_id': chat_id}).chat(
        request.data.get('message'),
        request.data.get('re_chat', False),
        request.data.get('stream', True))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表",