diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 72fb178e7..440ff0703 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -44,8 +44,10 @@ class WorkFlowPostHandler: workflow): question = workflow.params['question'] details = workflow.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row]) + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) chat_record = ChatRecord(id=chat_record_id, chat_id=chat_id, problem_text=question, diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index deef0a40e..766c31c80 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -123,8 +123,11 @@ class BaseChatNode(IChatNode): rsa_long_decrypt(model.credential)), streaming=True) history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) + self.context['question'] = question.content message_list = self.generate_message_list(system, prompt, history_message) + self.context['message_list'] = message_list if stream: r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, @@ -171,14 +174,13 @@ class BaseChatNode(IChatNode): 'run_time': self.context.get('run_time'), 'system': self.node_params.get('system'), 'history_message': [{'content': message.content, 'role': message.type} for message in - self.context.get('history_message')], + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], 'question': self.context.get('question'), 'answer': self.context.get('answer'), 'type': self.node.type, - 'message_tokens': self.context['message_tokens'], - 'answer_tokens': self.context['answer_tokens'], + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), 'status': self.status, 'err_message': self.err_message - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + } diff --git a/apps/application/flow/step_node/condition_node/i_condition_node.py b/apps/application/flow/step_node/condition_node/i_condition_node.py index f85bd082d..ffb975a98 100644 --- a/apps/application/flow/step_node/condition_node/i_condition_node.py +++ b/apps/application/flow/step_node/condition_node/i_condition_node.py @@ -23,6 +23,7 @@ class ConditionSerializer(serializers.Serializer): class ConditionBranchSerializer(serializers.Serializer): id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id")) + type = serializers.CharField(required=True, error_messages=ErrMessage.char("分支类型")) condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and")) conditions = ConditionSerializer(many=True) diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py index 62dfa905e..94f004b19 100644 --- a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py +++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py @@ -17,7 +17,7 @@ class BaseConditionNode(IConditionNode): def execute(self, **kwargs) -> NodeResult: branch_list = self.node_params_serializer.data['branch'] branch = self._execute(branch_list) - r = NodeResult({'branch_id': branch.get('id')}, {}) + r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {}) return r def _execute(self, branch_list: List): @@ -44,8 +44,6 @@ class BaseConditionNode(IConditionNode): 'branch_id': self.context.get('branch_id'), 'branch_name': self.context.get('branch_name'), 'type': self.node.type, - 'status': self.context.get('status'), - 'err_message': self.context.get('err_message') - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 5675fa832..b7e77dde0 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -86,6 +86,4 @@ class BaseReplyNode(IReplyNode): 'answer': self.context.get('answer'), 'status': self.status, 'err_message': self.err_message - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + } diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 6192d13e9..4163e79b4 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -120,8 +120,11 @@ class BaseQuestionNode(IQuestionNode): rsa_long_decrypt(model.credential)), streaming=True) history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) + self.context['question'] = question.content message_list = self.generate_message_list(system, prompt, history_message) + self.context['message_list'] = message_list if stream: r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, @@ -168,14 +171,13 @@ class BaseQuestionNode(IQuestionNode): 'run_time': self.context.get('run_time'), 'system': self.node_params.get('system'), 'history_message': [{'content': message.content, 'role': message.type} for message in - self.context.get('history_message')], + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], 'question': self.context.get('question'), 'answer': self.context.get('answer'), 'type': self.node.type, - 'message_tokens': self.context['message_tokens'], - 'answer_tokens': self.context['answer_tokens'], + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), 'status': self.status, 'err_message': self.err_message - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + } diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index 7476ecd04..020273b71 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -12,7 +12,7 @@ from typing import Type from django.core import validators from rest_framework import serializers -from application.flow.i_step_node import INode, ReferenceAddressSerializer, NodeResult +from application.flow.i_step_node import INode, NodeResult from common.util.field_message import ErrMessage diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 1f22e8603..2b233ecb2 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -25,6 +25,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: + self.context['question'] = question embedding_model = EmbeddingModel.get_embedding_model() embedding_value = embedding_model.embed_query(question) vector = VectorStore.get_embedding_vector() @@ -85,6 +86,4 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): 'type': self.node.type, 'status': self.status, 'err_message': self.err_message - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + } diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 67cd2a48e..8ed8a67fd 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -27,6 +27,4 @@ class BaseStartStepNode(IStarNode): 'type': self.node.type, 'status': self.status, 'err_message': self.err_message - } if self.status == 200 else {"index": index, 'type': self.node.type, - 'status': self.status, - 'err_message': self.err_message} + } diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index c255369c6..e791f6c11 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -32,11 +32,11 @@ def event_content(chat_id, chat_record_id, response, workflow, for chunk in response: answer += chunk.content yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, - 'content': chunk.content, 'is_end': False}) + "\n\n" + 'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n" write_context(answer) post_handler.handler(chat_id, chat_record_id, answer, workflow) yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, - 'content': '', 'is_end': True}) + "\n\n" + 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n" def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context, @@ -53,7 +53,8 @@ def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageCh """ r = StreamingHttpResponse( streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler), - content_type='text/event-stream;charset=utf-8') + content_type='text/event-stream;charset=utf-8', + charset='utf-8') r['Cache-Control'] = 'no-cache' return r diff --git a/apps/common/response/result.py b/apps/common/response/result.py index d1cf6a3ad..bb2ba0faf 100644 --- a/apps/common/response/result.py +++ b/apps/common/response/result.py @@ -15,6 +15,7 @@ class Page(dict): class Result(JsonResponse): + charset = 'utf-8' """ 接口统一返回对象 """