Mirror of https://github.com/1Panel-dev/MaxKB.git
feat: add a details field to chat records (对话纪要添加详情字段)
This commit is contained in:
parent c7883cd3fd
commit c3e17229cd
@@ -44,8 +44,10 @@ class WorkFlowPostHandler:
                 workflow):
         question = workflow.params['question']
         details = workflow.get_runtime_details()
-        message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row])
-        answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row])
+        message_tokens = sum([row.get('message_tokens') for row in details.values() if
+                              'message_tokens' in row and row.get('message_tokens') is not None])
+        answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+                             'answer_tokens' in row and row.get('answer_tokens') is not None])
         chat_record = ChatRecord(id=chat_record_id,
                                  chat_id=chat_id,
                                  problem_text=question,

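The reworked comprehensions above matter because some workflow nodes report message_tokens or answer_tokens as None (for example steps that never call a model), and summing a list that contains None raises TypeError. A minimal standalone sketch of the same guard; the details dict below is hypothetical, the real one comes from workflow.get_runtime_details():

# Hypothetical runtime details; the real dict comes from workflow.get_runtime_details().
details = {
    'start-node': {'type': 'start-node'},                           # no token fields at all
    'chat-node': {'message_tokens': 120, 'answer_tokens': 48},
    'reply-node': {'message_tokens': None, 'answer_tokens': None},  # keys present, values None
}

# Checking only key membership would feed None into sum() and raise TypeError,
# so the value is also checked before it is counted.
message_tokens = sum([row.get('message_tokens') for row in details.values() if
                      'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
                     'answer_tokens' in row and row.get('answer_tokens') is not None])

print(message_tokens, answer_tokens)  # 120 48
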
@@ -123,8 +123,11 @@ class BaseChatNode(IChatNode):
                                           rsa_long_decrypt(model.credential)),
                                       streaming=True)
         history_message = self.get_history_message(history_chat_record, dialogue_number)
+        self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
+        self.context['question'] = question.content
         message_list = self.generate_message_list(system, prompt, history_message)
+        self.context['message_list'] = message_list
         if stream:
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,

@@ -171,14 +174,13 @@ class BaseChatNode(IChatNode):
             'run_time': self.context.get('run_time'),
             'system': self.node_params.get('system'),
             'history_message': [{'content': message.content, 'role': message.type} for message in
-                                self.context.get('history_message')],
+                                (self.context.get('history_message') if self.context.get(
+                                    'history_message') is not None else [])],
             'question': self.context.get('question'),
             'answer': self.context.get('answer'),
             'type': self.node.type,
-            'message_tokens': self.context['message_tokens'],
-            'answer_tokens': self.context['answer_tokens'],
+            'message_tokens': self.context.get('message_tokens'),
+            'answer_tokens': self.context.get('answer_tokens'),
             'status': self.status,
             'err_message': self.err_message
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+        }

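Together with the self.context writes added in the previous hunk, get_details can now be built entirely from .get() lookups, so a node that failed before populating its context (or one without history) still produces a detail entry instead of raising KeyError. A small sketch of that defensive pattern, with a hypothetical context dict standing in for the node state:

# Hypothetical node context; a node that failed early may have stored very little.
context = {'run_time': 0.42, 'question': '你好'}

detail = {
    'run_time': context.get('run_time'),
    # context['message_tokens'] would raise KeyError here; .get() yields None instead.
    'message_tokens': context.get('message_tokens'),
    'answer_tokens': context.get('answer_tokens'),
    # Fall back to an empty list so the comprehension never iterates over None.
    'history_message': [{'content': m.content, 'role': m.type} for m in
                        (context.get('history_message') if context.get('history_message') is not None else [])],
}

print(detail)  # {'run_time': 0.42, 'message_tokens': None, 'answer_tokens': None, 'history_message': []}
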
@@ -23,6 +23,7 @@ class ConditionSerializer(serializers.Serializer):
 
 class ConditionBranchSerializer(serializers.Serializer):
     id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
+    type = serializers.CharField(required=True, error_messages=ErrMessage.char("分支类型"))
     condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
     conditions = ConditionSerializer(many=True)
 

@@ -17,7 +17,7 @@ class BaseConditionNode(IConditionNode):
     def execute(self, **kwargs) -> NodeResult:
         branch_list = self.node_params_serializer.data['branch']
         branch = self._execute(branch_list)
-        r = NodeResult({'branch_id': branch.get('id')}, {})
+        r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
         return r
 
     def _execute(self, branch_list: List):

@@ -44,8 +44,6 @@ class BaseConditionNode(IConditionNode):
             'branch_id': self.context.get('branch_id'),
+            'branch_name': self.context.get('branch_name'),
             'type': self.node.type,
-            'status': self.context.get('status'),
-            'err_message': self.context.get('err_message')
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+            'status': self.status,
+            'err_message': self.err_message
+        }

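The two condition-node hunks work as a pair: execute() now carries branch_name (taken from the branch's type field) in its NodeResult, and get_details reports it next to branch_id. A rough sketch of that flow, using simplified stand-ins rather than the real classes from application.flow:

# Simplified stand-in; the real NodeResult in application.flow carries more behaviour.
class NodeResult:
    def __init__(self, node_variable: dict, workflow_variable: dict):
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable

branch = {'id': 'branch-1', 'type': 'IF'}  # hypothetical branch definition
r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})

# The workflow engine copies node variables into the node context,
# which is where get_details later reads them from.
context = dict(r.node_variable)
print({'branch_id': context.get('branch_id'), 'branch_name': context.get('branch_name')})
# {'branch_id': 'branch-1', 'branch_name': 'IF'}
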
@@ -86,6 +86,4 @@ class BaseReplyNode(IReplyNode):
             'answer': self.context.get('answer'),
             'status': self.status,
             'err_message': self.err_message
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+        }

@@ -120,8 +120,11 @@ class BaseQuestionNode(IQuestionNode):
                                           rsa_long_decrypt(model.credential)),
                                       streaming=True)
         history_message = self.get_history_message(history_chat_record, dialogue_number)
+        self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
+        self.context['question'] = question.content
         message_list = self.generate_message_list(system, prompt, history_message)
+        self.context['message_list'] = message_list
         if stream:
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,

@@ -168,14 +171,13 @@ class BaseQuestionNode(IQuestionNode):
             'run_time': self.context.get('run_time'),
             'system': self.node_params.get('system'),
             'history_message': [{'content': message.content, 'role': message.type} for message in
-                                self.context.get('history_message')],
+                                (self.context.get('history_message') if self.context.get(
+                                    'history_message') is not None else [])],
             'question': self.context.get('question'),
             'answer': self.context.get('answer'),
             'type': self.node.type,
-            'message_tokens': self.context['message_tokens'],
-            'answer_tokens': self.context['answer_tokens'],
+            'message_tokens': self.context.get('message_tokens'),
+            'answer_tokens': self.context.get('answer_tokens'),
             'status': self.status,
             'err_message': self.err_message
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+        }

@@ -12,7 +12,7 @@ from typing import Type
 from django.core import validators
 from rest_framework import serializers
 
-from application.flow.i_step_node import INode, ReferenceAddressSerializer, NodeResult
+from application.flow.i_step_node import INode, NodeResult
 from common.util.field_message import ErrMessage
 
 

@@ -25,6 +25,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
     def execute(self, dataset_id_list, dataset_setting, question,
                 exclude_paragraph_id_list=None,
                 **kwargs) -> NodeResult:
+        self.context['question'] = question
         embedding_model = EmbeddingModel.get_embedding_model()
         embedding_value = embedding_model.embed_query(question)
         vector = VectorStore.get_embedding_vector()

@@ -85,6 +86,4 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
             'type': self.node.type,
             'status': self.status,
             'err_message': self.err_message
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+        }

@@ -27,6 +27,4 @@ class BaseStartStepNode(IStarNode):
             'type': self.node.type,
             'status': self.status,
             'err_message': self.err_message
-        } if self.status == 200 else {"index": index, 'type': self.node.type,
-                                      'status': self.status,
-                                      'err_message': self.err_message}
+        }

@@ -32,11 +32,11 @@ def event_content(chat_id, chat_record_id, response, workflow,
     for chunk in response:
         answer += chunk.content
         yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                     'content': chunk.content, 'is_end': False}) + "\n\n"
+                                     'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
     write_context(answer)
     post_handler.handler(chat_id, chat_record_id, answer, workflow)
     yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                 'content': '', 'is_end': True}) + "\n\n"
+                                 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
 
 
 def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,

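ensure_ascii=False is what keeps Chinese answer text readable in the SSE frames instead of being escaped to \uXXXX sequences; the charset changes in the following hunks make sure the resulting UTF-8 bytes are also declared correctly. A quick comparison:

import json

chunk = {'content': '你好', 'is_end': False}

print('data: ' + json.dumps(chunk) + "\n\n")
# data: {"content": "\u4f60\u597d", "is_end": false}

print('data: ' + json.dumps(chunk, ensure_ascii=False) + "\n\n")
# data: {"content": "你好", "is_end": false}
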
@@ -53,7 +53,8 @@ def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
     """
     r = StreamingHttpResponse(
         streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
-        content_type='text/event-stream;charset=utf-8')
+        content_type='text/event-stream;charset=utf-8',
+        charset='utf-8')
 
     r['Cache-Control'] = 'no-cache'
     return r

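Passing charset='utf-8' sets the StreamingHttpResponse's charset explicitly instead of leaving it to Django's default, which matches the unescaped UTF-8 content now flowing through the event stream. A minimal sketch, assuming a standalone script that configures Django settings itself (inside MaxKB the project settings already do this):

from django.conf import settings
from django.http import StreamingHttpResponse

settings.configure()  # standalone-script setup only

def fake_event_stream():
    # Stand-in for event_content(); yields already-formatted SSE frames.
    yield 'data: {"content": "你好", "is_end": false}\n\n'
    yield 'data: {"content": "", "is_end": true}\n\n'

r = StreamingHttpResponse(streaming_content=fake_event_stream(),
                          content_type='text/event-stream;charset=utf-8',
                          charset='utf-8')
r['Cache-Control'] = 'no-cache'
print(r['Content-Type'], r.charset)  # text/event-stream;charset=utf-8 utf-8
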
@@ -15,6 +15,7 @@ class Page(dict):
 
 
 class Result(JsonResponse):
+    charset = 'utf-8'
     """
     接口统一返回对象
     """

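Declaring charset = 'utf-8' as a class attribute shadows the charset property Result inherits from Django's response classes, so every API response built from it is encoded and declared as UTF-8. A reduced sketch of the idea; the constructor and the code/message fields below are hypothetical, only the charset attribute mirrors the commit:

from django.conf import settings
from django.http import JsonResponse

settings.configure()  # standalone-script setup only


class Result(JsonResponse):
    charset = 'utf-8'  # class attribute shadows the inherited charset property

    def __init__(self, code=200, message="成功", data=None, **kwargs):
        # Hypothetical unified payload; the real fields live in MaxKB's common package.
        kwargs.setdefault('json_dumps_params', {'ensure_ascii': False})
        super().__init__({'code': code, 'message': message, 'data': data}, **kwargs)


r = Result(data={'answer': '你好'})
print(r.charset)                  # utf-8
print(r.content.decode('utf-8'))  # {"code": 200, "message": "成功", "data": {"answer": "你好"}}
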