Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 01:33:05 +00:00)
feat: chat supports returning the full response at once (对话支持全部返回)

parent 7eb18fbf30
commit b7a406db56
@@ -116,10 +116,10 @@ class IBaseChatPipelineStep:
         :return: execution result
         """
         start_time = time.time()
+        self.context['start_time'] = start_time
         # Validate the arguments
         self.valid_args(manage)
         self._run(manage)
-        self.context['start_time'] = start_time
         self.context['run_time'] = time.time() - start_time

     def _run(self, manage):
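For orientation: the hunk above is the template-method wrapper every pipeline step inherits. run() stamps start_time into the step context, validates arguments, delegates to the subclass's _run(), and records run_time; the assignment into the context appears to move ahead of _run here, which the new blocking step relies on when it computes durations. A minimal self-contained sketch of that contract (class name and the manage argument are illustrative stand-ins, not names from this commit):

import time

class PipelineStepSketch:
    """Illustrative stand-in for IBaseChatPipelineStep's run/_run contract."""

    def __init__(self):
        self.context = {}

    def valid_args(self, manage):
        # Subclasses validate their serializer fields here.
        pass

    def _run(self, manage):
        # Subclass hook: the actual work of the step.
        raise NotImplementedError

    def run(self, manage):
        start_time = time.time()
        self.context['start_time'] = start_time  # set before _run so the step can read it
        self.valid_args(manage)
        self._run(manage)
        self.context['run_time'] = time.time() - start_time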
@@ -63,6 +63,8 @@ class IChatStep(IBaseChatPipelineStep):
     post_response_handler = InstanceField(model_type=PostResponseHandler)
     # The completed (padded) question
     padding_problem_text = serializers.CharField(required=False)
+    # Whether to output as a stream
+    stream = serializers.BooleanField(required=False)

     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -85,5 +87,5 @@ class IChatStep(IBaseChatPipelineStep):
                 chat_model: BaseChatModel = None,
                 paragraph_list=None,
                 manage: PiplineManage = None,
-                padding_problem_text: str = None, **kwargs):
+                padding_problem_text: str = None, stream: bool = True, **kwargs):
        pass
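Because the new stream parameter defaults to True, existing implementations and call sites keep the current streaming behavior; only a caller that explicitly passes stream=False opts into the blocking response. An illustrative call (assuming step is any IChatStep implementation and the other arguments are already in scope):

response = step.execute(message_list=message_list, chat_id=chat_id,
                        problem_text=problem_text,
                        post_response_handler=handler,
                        stream=False)  # opt out of SSE, get one complete answer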
@@ -16,11 +16,12 @@ from typing import List
 from django.http import StreamingHttpResponse
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
-from langchain.schema.messages import BaseMessageChunk, HumanMessage
+from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage

 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
 from common.response import result


 def event_content(response,
@@ -71,23 +72,16 @@ class BaseChatStep(IChatStep):
                 paragraph_list=None,
                 manage: PiplineManage = None,
                 padding_problem_text: str = None,
+                stream: bool = True,
                 **kwargs):
-        # Invoke the model
-        if chat_model is None:
-            chat_result = iter(
-                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
-        else:
-            chat_result = chat_model.stream(message_list)
-
-        chat_record_id = uuid.uuid1()
-        r = StreamingHttpResponse(
-            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
-                                            post_response_handler, manage, self, chat_model, message_list, problem_text,
-                                            padding_problem_text),
-            content_type='text/event-stream;charset=utf-8')
-
-        r['Cache-Control'] = 'no-cache'
-        return r
+        if stream:
+            return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
+                                       paragraph_list,
+                                       manage, padding_problem_text)
+        return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
+                                  paragraph_list,
+                                  manage, padding_problem_text)

     def get_details(self, manage, **kwargs):
         return {
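With this dispatch in place, a client chooses the transport per request. A minimal consumption sketch using requests (the URL, port, and response envelope are assumptions about a typical MaxKB deployment, not taken from this commit):

import requests

url = "http://localhost:8080/api/application/chat_message/<chat_id>"  # hypothetical endpoint

# Blocking mode: one JSON payload containing the whole answer.
resp = requests.post(url, json={"message": "你好", "re_chat": False, "stream": False})
print(resp.json())  # assuming result.success() wraps the payload in a JSON envelope

# Streaming mode (the default): a text/event-stream response, consumed line by line.
with requests.post(url, json={"message": "你好"}, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))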
@@ -109,3 +103,58 @@ class BaseChatStep(IChatStep):
             message_list]
         result.append({'role': 'ai', 'content': answer_text})
         return result
+
+    def execute_stream(self, message_list: List[BaseMessage],
+                       chat_id,
+                       problem_text,
+                       post_response_handler: PostResponseHandler,
+                       chat_model: BaseChatModel = None,
+                       paragraph_list=None,
+                       manage: PiplineManage = None,
+                       padding_problem_text: str = None):
+        # Invoke the model; without one, stream the retrieved paragraphs back as chunks
+        if chat_model is None:
+            chat_result = iter(
+                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+        else:
+            chat_result = chat_model.stream(message_list)
+
+        chat_record_id = uuid.uuid1()
+        r = StreamingHttpResponse(
+            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
+                                            post_response_handler, manage, self, chat_model, message_list, problem_text,
+                                            padding_problem_text),
+            content_type='text/event-stream;charset=utf-8')
+
+        r['Cache-Control'] = 'no-cache'
+        return r
+
+    def execute_block(self, message_list: List[BaseMessage],
+                      chat_id,
+                      problem_text,
+                      post_response_handler: PostResponseHandler,
+                      chat_model: BaseChatModel = None,
+                      paragraph_list=None,
+                      manage: PiplineManage = None,
+                      padding_problem_text: str = None):
+        # Invoke the model; without one, join the retrieved paragraphs into a single answer
+        if chat_model is None:
+            chat_result = AIMessage(
+                content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
+        else:
+            chat_result = chat_model(message_list)
+        chat_record_id = uuid.uuid1()
+        request_token = chat_model.get_num_tokens_from_messages(message_list)
+        response_token = chat_model.get_num_tokens(chat_result.content)
+        self.context['message_tokens'] = request_token
+        self.context['answer_tokens'] = response_token
+        current_time = time.time()
+        self.context['answer_text'] = chat_result.content
+        self.context['run_time'] = current_time - self.context['start_time']
+        manage.context['run_time'] = current_time - manage.context['start_time']
+        manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
+        manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
+        post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+                                      chat_result.content, manage, self, padding_problem_text)
+        return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                               'content': chat_result.content, 'is_end': True})
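The blocking path gathers the whole completion, records token counts and timings in both the step and pipeline contexts, runs the post-response handler, and returns an ordinary JSON result with is_end=True. One caveat worth noting: request_token and response_token are computed via chat_model even on the chat_model is None branch, which would raise AttributeError when no model is configured. A defensive variant (not in the commit) might guard that accounting:

if chat_model is None:
    request_token, response_token = 0, 0  # nothing to meter without a model
else:
    request_token = chat_model.get_num_tokens_from_messages(message_list)
    response_token = chat_model.get_num_tokens(chat_result.content)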
@@ -72,15 +72,16 @@ class ChatInfo:
                 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
             'chat_model': self.chat_model,
             'model_id': self.application.model.id if self.application.model is not None else None,
-            'problem_optimization': self.application.problem_optimization
+            'problem_optimization': self.application.problem_optimization,
+            'stream': True
         }

     def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
-                                  exclude_paragraph_id_list):
+                                  exclude_paragraph_id_list, stream=True):
         params = self.to_base_pipeline_manage_params()
         return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
-                'exclude_paragraph_id_list': exclude_paragraph_id_list}
+                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}

     def append_chat_record(self, chat_record: ChatRecord):
         # Store it in the cache
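Note the override order: to_base_pipeline_manage_params() now carries 'stream': True as a default, and to_pipeline_manage_params() merges its explicit stream argument in afterwards, so the per-request flag wins. A reduced illustration of the dict merge in the hunk above:

params = self.to_base_pipeline_manage_params()  # contains 'stream': True
merged = {**params, 'stream': stream}           # later key wins: per-request value overrides the default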
@@ -126,7 +127,7 @@ def get_post_handler(chat_info: ChatInfo):
 class ChatMessageSerializer(serializers.Serializer):
     chat_id = serializers.UUIDField(required=True)

-    def chat(self, message, re_chat: bool):
+    def chat(self, message, re_chat: bool, stream: bool):
         self.is_valid(raise_exception=True)
         chat_id = self.data.get('chat_id')
         chat_info: ChatInfo = chat_cache.get(chat_id)
@@ -152,7 +153,8 @@ class ChatMessageSerializer(serializers.Serializer):
                                                            chat_info.chat_record_list)])
         exclude_paragraph_id_list = list(set(paragraph_id_list))
         # Build the run parameters
-        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
+        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
+                                                     stream)
         # Run the pipeline job
         pipline_message.run(params)
         return pipline_message.context['chat_result']
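End to end, the serializer now threads the flag from the HTTP layer into the pipeline. A minimal sketch of the call, mirroring how ChatView.post invokes it in the last hunk (message and chat_id assumed in scope):

serializer = ChatMessageSerializer(data={'chat_id': chat_id})
# stream=False returns the blocking JSON result; stream=True returns a StreamingHttpResponse.
response = serializer.chat(message, re_chat=False, stream=False)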
@@ -20,7 +20,8 @@ class ChatApi(ApiMixin):
             required=['message'],
             properties={
                 'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
-                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
+                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
+                'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否流式输出", default=True)
             }
         )
@@ -72,7 +72,8 @@ class ChatView(APIView):
                          )
     def post(self, request: Request, chat_id: str):
-        return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
-            're_chat') if 're_chat' in request.data else False)
+        return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
+            're_chat') if 're_chat' in request.data else False, request.data.get(
+            'stream') if 'stream' in request.data else True)

     @action(methods=['GET'], detail=False)
     @swagger_auto_schema(operation_summary="获取对话列表",
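A small style note on the new post(): the nested `x if x in request.data else default` conditionals are equivalent to dict-style .get() with a default, which reads more directly. A drop-in rewrite of the method body above (behaviorally identical for the dict-like request.data; same imports as the file):

def post(self, request: Request, chat_id: str):
    return ChatMessageSerializer(data={'chat_id': chat_id}).chat(
        request.data.get('message'),
        request.data.get('re_chat', False),
        request.data.get('stream', True))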