feat: 对话纪要添加详情字段

This commit is contained in:
shaohuzhang1 2024-06-24 14:45:45 +08:00
parent c7883cd3fd
commit c3e17229cd
11 changed files with 35 additions and 33 deletions

View File

@ -44,8 +44,10 @@ class WorkFlowPostHandler:
workflow):
question = workflow.params['question']
details = workflow.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row])
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
problem_text=question,

View File

@ -123,8 +123,11 @@ class BaseChatNode(IChatNode):
rsa_long_decrypt(model.credential)),
streaming=True)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
message_list = self.generate_message_list(system, prompt, history_message)
self.context['message_list'] = message_list
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
@ -171,14 +174,13 @@ class BaseChatNode(IChatNode):
'run_time': self.context.get('run_time'),
'system': self.node_params.get('system'),
'history_message': [{'content': message.content, 'role': message.type} for message in
self.context.get('history_message')],
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens'],
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
}

View File

@ -23,6 +23,7 @@ class ConditionSerializer(serializers.Serializer):
class ConditionBranchSerializer(serializers.Serializer):
id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
type = serializers.CharField(required=True, error_messages=ErrMessage.char("分支类型"))
condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
conditions = ConditionSerializer(many=True)

View File

@ -17,7 +17,7 @@ class BaseConditionNode(IConditionNode):
def execute(self, **kwargs) -> NodeResult:
branch_list = self.node_params_serializer.data['branch']
branch = self._execute(branch_list)
r = NodeResult({'branch_id': branch.get('id')}, {})
r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
return r
def _execute(self, branch_list: List):
@ -44,8 +44,6 @@ class BaseConditionNode(IConditionNode):
'branch_id': self.context.get('branch_id'),
'branch_name': self.context.get('branch_name'),
'type': self.node.type,
'status': self.context.get('status'),
'err_message': self.context.get('err_message')
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
'status': self.status,
'err_message': self.err_message
}

View File

@ -86,6 +86,4 @@ class BaseReplyNode(IReplyNode):
'answer': self.context.get('answer'),
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
}

View File

@ -120,8 +120,11 @@ class BaseQuestionNode(IQuestionNode):
rsa_long_decrypt(model.credential)),
streaming=True)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
message_list = self.generate_message_list(system, prompt, history_message)
self.context['message_list'] = message_list
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
@ -168,14 +171,13 @@ class BaseQuestionNode(IQuestionNode):
'run_time': self.context.get('run_time'),
'system': self.node_params.get('system'),
'history_message': [{'content': message.content, 'role': message.type} for message in
self.context.get('history_message')],
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens'],
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
}

View File

@ -12,7 +12,7 @@ from typing import Type
from django.core import validators
from rest_framework import serializers
from application.flow.i_step_node import INode, ReferenceAddressSerializer, NodeResult
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage

View File

@ -25,6 +25,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
self.context['question'] = question
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
@ -85,6 +86,4 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
}

View File

@ -27,6 +27,4 @@ class BaseStartStepNode(IStarNode):
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}
}

View File

@ -32,11 +32,11 @@ def event_content(chat_id, chat_record_id, response, workflow,
for chunk in response:
answer += chunk.content
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n"
'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
write_context(answer)
post_handler.handler(chat_id, chat_record_id, answer, workflow)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
@ -53,7 +53,8 @@ def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageCh
"""
r = StreamingHttpResponse(
streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
content_type='text/event-stream;charset=utf-8')
content_type='text/event-stream;charset=utf-8',
charset='utf-8')
r['Cache-Control'] = 'no-cache'
return r

View File

@ -15,6 +15,7 @@ class Page(dict):
class Result(JsonResponse):
charset = 'utf-8'
"""
接口统一返回对象
"""