From b7a406db56dc535706ccc81f9f1b12c2b9c9a339 Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Wed, 21 Feb 2024 18:10:18 +0800
Subject: [PATCH] feat: chat supports returning the full response at once
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../chat_pipeline/I_base_chat_pipeline.py    |  2 +-
 .../step/chat_step/i_chat_step.py            |  4 +-
 .../step/chat_step/impl/base_chat_step.py    | 81 +++++++++++++++----
 .../serializers/chat_message_serializers.py  | 12 +--
 apps/application/swagger_api/chat_api.py     |  3 +-
 apps/application/views/chat_views.py         |  3 +-
 6 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
index 3a7f43060..1c0f8d998 100644
--- a/apps/application/chat_pipeline/I_base_chat_pipeline.py
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -116,10 +116,10 @@ class IBaseChatPipelineStep:
         :return: execution result
         """
         start_time = time.time()
+        self.context['start_time'] = start_time
         # validate parameters
         self.valid_args(manage)
         self._run(manage)
-        self.context['start_time'] = start_time
         self.context['run_time'] = time.time() - start_time
 
     def _run(self, manage):
diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index 1168fb9af..2c3691cd2 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -63,6 +63,8 @@ class IChatStep(IBaseChatPipelineStep):
         post_response_handler = InstanceField(model_type=PostResponseHandler)
         # the completed (rewritten) question
         padding_problem_text = serializers.CharField(required=False)
+        # whether to return the output as a stream
+        stream = serializers.BooleanField(required=False)
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -85,5 +87,5 @@ class IChatStep(IBaseChatPipelineStep):
                 chat_model: BaseChatModel = None,
                 paragraph_list=None,
                 manage: PiplineManage = None,
-                padding_problem_text: str = None, **kwargs):
+                padding_problem_text: str = None, stream: bool = True, **kwargs):
         pass
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 080142893..6d6b5c496 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -16,11 +16,12 @@ from typing import List
 
 from django.http import StreamingHttpResponse
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
-from langchain.schema.messages import BaseMessageChunk, HumanMessage
+from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
 
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
+from common.response import result
 
 def event_content(response,
@@ -71,23 +72,16 @@ class BaseChatStep(IChatStep):
                 paragraph_list=None,
                 manage: PiplineManage = None,
                 padding_problem_text: str = None,
+                stream: bool = True,
                 **kwargs):
-        # invoke the model
-        if chat_model is None:
-            chat_result = iter(
-                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+        if stream:
+            return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
+                                       paragraph_list,
+                                       manage, padding_problem_text)
         else:
-            chat_result = chat_model.stream(message_list)
-
-        chat_record_id = uuid.uuid1()
-        r = StreamingHttpResponse(
-            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
-                                            post_response_handler, manage, self, chat_model, message_list, problem_text,
-                                            padding_problem_text),
-            content_type='text/event-stream;charset=utf-8')
-
-        r['Cache-Control'] = 'no-cache'
-        return r
+            return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
+                                      paragraph_list,
+                                      manage, padding_problem_text)
 
     def get_details(self, manage, **kwargs):
         return {
@@ -109,3 +103,58 @@
             message_list]
         result.append({'role': 'ai', 'content': answer_text})
         return result
+
+    def execute_stream(self, message_list: List[BaseMessage],
+                       chat_id,
+                       problem_text,
+                       post_response_handler: PostResponseHandler,
+                       chat_model: BaseChatModel = None,
+                       paragraph_list=None,
+                       manage: PiplineManage = None,
+                       padding_problem_text: str = None):
+        # invoke the model
+        if chat_model is None:
+            chat_result = iter(
+                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+        else:
+            chat_result = chat_model.stream(message_list)
+
+        chat_record_id = uuid.uuid1()
+        r = StreamingHttpResponse(
+            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
+                                            post_response_handler, manage, self, chat_model, message_list, problem_text,
+                                            padding_problem_text),
+            content_type='text/event-stream;charset=utf-8')
+
+        r['Cache-Control'] = 'no-cache'
+        return r
+
+    def execute_block(self, message_list: List[BaseMessage],
+                      chat_id,
+                      problem_text,
+                      post_response_handler: PostResponseHandler,
+                      chat_model: BaseChatModel = None,
+                      paragraph_list=None,
+                      manage: PiplineManage = None,
+                      padding_problem_text: str = None):
+        # invoke the model
+        if chat_model is None:
+            chat_result = AIMessage(
+                content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
+        else:
+            chat_result = chat_model(message_list)
+        chat_record_id = uuid.uuid1()
+        request_token = chat_model.get_num_tokens_from_messages(message_list) if chat_model is not None else 0
+        response_token = chat_model.get_num_tokens(chat_result.content) if chat_model is not None else 0
+        self.context['message_tokens'] = request_token
+        self.context['answer_tokens'] = response_token
+        current_time = time.time()
+        self.context['answer_text'] = chat_result.content
+        self.context['run_time'] = current_time - self.context['start_time']
+        manage.context['run_time'] = current_time - manage.context['start_time']
+        manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
+        manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
+        post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+                                      chat_result.content, manage, self, padding_problem_text)
+        return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                               'content': chat_result.content, 'is_end': True})
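A minimal client-side sketch of the two modes that execute_stream and execute_block above now serve. The host and route are placeholders and the exact JSON envelope produced by result.success is an assumption; only the message, re_chat and stream fields and the text/event-stream content type come from this patch:

    import requests  # third-party HTTP client, used only for this sketch

    URL = 'http://localhost:8080/chat/<chat_id>'  # hypothetical route served by ChatView.post

    # stream=False: execute_block() waits for the full model answer and returns one
    # JSON body wrapping {'chat_id', 'id', 'operate', 'content', 'is_end': True}.
    resp = requests.post(URL, json={'message': 'hello', 're_chat': False, 'stream': False})
    print(resp.json())

    # stream=True (the default): execute_stream() returns a StreamingHttpResponse with
    # content type text/event-stream; chunks arrive while the model is still generating.
    with requests.post(URL, json={'message': 'hello', 'stream': True}, stream=True) as resp:
        for line in resp.iter_lines():
            if line:
                print(line.decode('utf-8'))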
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 3d03e57c3..9fe10065f 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -72,15 +72,16 @@ class ChatInfo:
                 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
             'chat_model': self.chat_model,
             'model_id': self.application.model.id if self.application.model is not None else None,
-            'problem_optimization': self.application.problem_optimization
+            'problem_optimization': self.application.problem_optimization,
+            'stream': True
         }
 
     def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
-                                  exclude_paragraph_id_list):
+                                  exclude_paragraph_id_list, stream=True):
         params = self.to_base_pipeline_manage_params()
         return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
-                'exclude_paragraph_id_list': exclude_paragraph_id_list}
+                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}
 
     def append_chat_record(self, chat_record: ChatRecord):
         # store it in the cache
@@ -126,7 +127,7 @@ def get_post_handler(chat_info: ChatInfo):
 class ChatMessageSerializer(serializers.Serializer):
     chat_id = serializers.UUIDField(required=True)
 
-    def chat(self, message, re_chat: bool):
+    def chat(self, message, re_chat: bool, stream: bool):
         self.is_valid(raise_exception=True)
         chat_id = self.data.get('chat_id')
         chat_info: ChatInfo = chat_cache.get(chat_id)
@@ -152,7 +153,8 @@ class ChatMessageSerializer(serializers.Serializer):
                                                  chat_info.chat_record_list)])
             exclude_paragraph_id_list = list(set(paragraph_id_list))
         # build the run parameters
-        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
+        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
+                                                     stream)
         # run the pipeline job
         pipline_message.run(params)
         return pipline_message.context['chat_result']
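The serializer changes above thread the flag through two dict builders, and the override works because later keys win in a dict literal: to_base_pipeline_manage_params() seeds 'stream': True, then to_pipeline_manage_params() rebuilds the dict with {**params, ..., 'stream': stream}. A self-contained toy of that pattern, with simplified stand-in names rather than the real classes:

    def to_base_params():
        # stand-in for to_base_pipeline_manage_params(): seeds the default flag
        return {'problem_optimization': False, 'stream': True}

    def to_manage_params(problem_text, stream=True):
        # stand-in for to_pipeline_manage_params(): the later 'stream' key
        # overrides the seeded default, so the per-request value wins
        return {**to_base_params(), 'problem_text': problem_text, 'stream': stream}

    print(to_manage_params('hi'))                # {'problem_optimization': False, 'stream': True, 'problem_text': 'hi'}
    print(to_manage_params('hi', stream=False))  # 'stream': False (request-level override)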
diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py
index d77f8de47..8ae2b64cd 100644
--- a/apps/application/swagger_api/chat_api.py
+++ b/apps/application/swagger_api/chat_api.py
@@ -20,7 +20,8 @@ class ChatApi(ApiMixin):
             required=['message'],
             properties={
                 'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
-                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
+                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
+                'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="流式输出", default=True)
             }
         )
diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py
index c39c149c9..3ce7f3e8c 100644
--- a/apps/application/views/chat_views.py
+++ b/apps/application/views/chat_views.py
@@ -72,7 +72,8 @@ class ChatView(APIView):
     )
     def post(self, request: Request, chat_id: str):
         return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
-            're_chat') if 're_chat' in request.data else False)
+            're_chat') if 're_chat' in request.data else False,
+            request.data.get('stream') if 'stream' in request.data else True)
 
     @action(methods=['GET'], detail=False)
     @swagger_auto_schema(operation_summary="获取对话列表",