From 6b4cee14127006cc148f74ec0d31029873b2e05c Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 4 Dec 2024 14:19:37 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E4=BD=BF=E7=94=A8api=E8=B0=83=E7=94=A8=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E5=93=8D=E5=BA=94=E6=95=B0=E6=8D=AE=20(#1755)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../application_node/impl/base_application_node.py | 8 +++++--- apps/application/flow/workflow_manage.py | 6 ++++-- apps/common/handle/impl/response/system_to_response.py | 10 +++++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 532a81e52..5527cf489 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -19,8 +19,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): result = node_variable.get('result') - node.context['child_node'] = node_variable['child_node'] - node.context['is_interrupt_exec'] = node_variable['is_interrupt_exec'] + node.context['child_node'] = node_variable.get('child_node') + node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec') node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) node.context['answer'] = answer @@ -81,7 +81,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param node: 节点实例对象 @param workflow: 工作流管理器 """ - response = node_variable.get('result')['choices'][0]['message'] + response = node_variable.get('result', {}).get('data', {}) + node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'), + 'prompt_tokens': response.get('prompt_tokens')}} answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。" _write_context(node_variable, workflow_variable, node, workflow, answer) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 37efc5467..b126387d5 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -328,11 +328,13 @@ class WorkflowManage: 'message_tokens' in row and row.get('message_tokens') is not None]) answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = self.get_answer_text_list() + answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list) self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, + answer_text, self) return self.base_to_response.to_block_response(self.params['chat_id'], - self.params['chat_record_id'], self.answer, True + self.params['chat_record_id'], answer_text, True , message_tokens, answer_tokens, _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py index 5999f1297..8df5ce139 100644 --- a/apps/common/handle/impl/response/system_to_response.py +++ b/apps/common/handle/impl/response/system_to_response.py @@ -21,15 +21,19 @@ class SystemToResponse(BaseToResponse): if other_params is None: other_params = {} return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, - 'content': content, 'is_end': is_end, **other_params}, response_status=_status, + 'content': content, 'is_end': is_end, **other_params, + 'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens}, + response_status=_status, code=_status) - def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens, + def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, prompt_tokens, other_params: dict = None): if other_params is None: other_params = {} chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, - 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end, + 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, + 'is_end': is_end, 'usage': {'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens, 'total_tokens': completion_tokens + prompt_tokens},