diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 1267f4336..2673c6b7b 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -65,16 +65,21 @@ class IChatStep(IBaseChatPipelineStep): post_response_handler = InstanceField(model_type=PostResponseHandler, error_messages=ErrMessage.base(_("Post-processor"))) # 补全问题 - padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base(_("Completion Question"))) + padding_problem_text = serializers.CharField(required=False, + error_messages=ErrMessage.base(_("Completion Question"))) # 是否使用流的形式输出 stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output"))) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id"))) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type"))) # 未查询到引用分段 - no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings"))) + no_references_setting = NoReferencesSetting(required=True, + error_messages=ErrMessage.base(_("No reference segment settings"))) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID"))) + model_setting = serializers.DictField(required=True, allow_null=True, + error_messages=ErrMessage.dict(_("Model settings"))) + model_params_setting = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Model parameter settings"))) @@ -101,5 +106,5 @@ class IChatStep(IBaseChatPipelineStep): paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, - no_references_setting=None, model_params_setting=None, **kwargs): + no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 8b8de260b..5bcdf1094 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -24,6 +24,7 @@ from rest_framework import status from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler +from application.flow.tools import Reasoning from application.models.api_key_model import ApplicationPublicAccessClient from common.constants.authentication_type import AuthenticationType from setting.models_provider.tools import get_model_instance_by_model_user_id @@ -63,17 +64,37 @@ def event_content(response, problem_text: str, padding_problem_text: str = None, client_id=None, client_type=None, - is_ai_chat: bool = None): + is_ai_chat: bool = None, + model_setting=None): + if model_setting is None: + model_setting = {} + reasoning_content_enable = model_setting.get('reasoning_content_enable', False) + reasoning_content_start = model_setting.get('reasoning_content_start', '') + reasoning_content_end = model_setting.get('reasoning_content_end', '') + reasoning = Reasoning(reasoning_content_start, + reasoning_content_end) all_text = '' + 
reasoning_content = '' try: for chunk in response: - all_text += chunk.content + reasoning_chunk = reasoning.get_reasoning_content(chunk) + content_chunk = reasoning_chunk.get('content') + if 'reasoning_content' in chunk.additional_kwargs: + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') + else: + reasoning_content_chunk = reasoning_chunk.get('reasoning_content') + all_text += content_chunk + if reasoning_content_chunk is None: + reasoning_content_chunk = '' + reasoning_content += reasoning_content_chunk yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', - [], chunk.content, + [], content_chunk, False, 0, 0, {'node_is_end': False, 'view_type': 'many_view', - 'node_type': 'ai-chat-node'}) + 'node_type': 'ai-chat-node', + 'real_node_id': 'ai-chat-node', + 'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''}) # 获取token if is_ai_chat: try: @@ -87,7 +108,8 @@ def event_content(response, response_token = 0 write_context(step, manage, request_token, response_token, all_text) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, - all_text, manage, step, padding_problem_text, client_id) + all_text, manage, step, padding_problem_text, client_id, + reasoning_content=reasoning_content if reasoning_content_enable else '') yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', [], '', True, request_token, response_token, @@ -122,17 +144,20 @@ class BaseChatStep(IChatStep): client_id=None, client_type=None, no_references_setting=None, model_params_setting=None, + model_setting=None, **kwargs): chat_model = get_model_instance_by_model_user_id(model_id, user_id, **model_params_setting) if model_id is not None else None if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text, client_id, client_type, no_references_setting) + manage, padding_problem_text, client_id, client_type, no_references_setting, + model_setting) else: return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text, client_id, client_type, no_references_setting) + manage, padding_problem_text, client_id, client_type, no_references_setting, + model_setting) def get_details(self, manage, **kwargs): return { @@ -187,14 +212,15 @@ class BaseChatStep(IChatStep): manage: PipelineManage = None, padding_problem_text: str = None, client_id=None, client_type=None, - no_references_setting=None): + no_references_setting=None, + model_setting=None): chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, no_references_setting, problem_text) chat_record_id = uuid.uuid1() r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, post_response_handler, manage, self, chat_model, message_list, problem_text, - padding_problem_text, client_id, client_type, is_ai_chat), + padding_problem_text, client_id, client_type, is_ai_chat, model_setting), content_type='text/event-stream;charset=utf-8') r['Cache-Control'] = 'no-cache' @@ -230,7 +256,13 @@ class BaseChatStep(IChatStep): paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, - client_id=None, client_type=None, no_references_setting=None): + client_id=None, client_type=None, no_references_setting=None, + 
                      model_setting=None):
+        reasoning_content_enable = (model_setting or {}).get('reasoning_content_enable', False)
+        reasoning_content_start = (model_setting or {}).get('reasoning_content_start', '')
+        reasoning_content_end = (model_setting or {}).get('reasoning_content_end', '')
+        reasoning = Reasoning(reasoning_content_start,
+                              reasoning_content_end)
         chat_record_id = uuid.uuid1()
         # 调用模型
         try:
@@ -243,14 +275,22 @@
             request_token = 0
             response_token = 0
             write_context(self, manage, request_token, response_token, chat_result.content)
+            reasoning_result = reasoning.get_reasoning_content(chat_result)
+            content = reasoning_result.get('content')
+            if 'reasoning_content' in chat_result.response_metadata:
+                reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
+            else:
+                reasoning_content = reasoning_result.get('reasoning_content')
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
-                                          chat_result.content, manage, self, padding_problem_text, client_id)
+                                          chat_result.content, manage, self, padding_problem_text, client_id,
+                                          reasoning_content=reasoning_content if reasoning_content_enable else '')
             add_access_num(client_id, client_type, manage.context.get('application_id'))
             return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
-                                                                   chat_result.content, True,
-                                                                   request_token, response_token)
+                                                                   content, True,
+                                                                   request_token, response_token,
+                                                                   {'reasoning_content': reasoning_content})
         except Exception as e:
-            all_text = '异常' + str(e)
+            all_text = 'Exception:' + str(e)
             write_context(self, manage, 0, 0, all_text)
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                           all_text, manage, self, padding_problem_text, client_id)
diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py
index e0bfcdbb2..f5d4cb9b0 100644
--- a/apps/application/flow/common.py
+++ b/apps/application/flow/common.py
@@ -9,16 +9,22 @@
 class Answer:
-    def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
+    def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
+                 reasoning_content):
         self.view_type = view_type
         self.content = content
+        self.reasoning_content = reasoning_content
         self.runtime_node_id = runtime_node_id
         self.chat_record_id = chat_record_id
         self.child_node = child_node
+        self.real_node_id = real_node_id

     def to_dict(self):
         return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
-                'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
+                'chat_record_id': self.chat_record_id,
+                'child_node': self.child_node,
+                'reasoning_content': self.reasoning_content,
+                'real_node_id': self.real_node_id}


 class NodeChunk:
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index 211fd4e2e..e1a567731 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -62,7 +62,9 @@ class WorkFlowPostHandler:
         answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
                              'answer_tokens' in row and row.get('answer_tokens') is not None])
         answer_text_list = workflow.get_answer_text_list()
-        answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
+        answer_text = '\n\n'.join(
+            '\n\n'.join([a.get('content') for a in answer]) for answer in
+            answer_text_list)
         if workflow.chat_record is not None:
             chat_record = workflow.chat_record
             chat_record.answer_text = answer_text
@@ -157,8 +159,10 @@ class INode:
     def get_answer_list(self) -> List[Answer] | None:
         if self.answer_text is None:
             return None
+        reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
         return [
-            Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
+            Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
+                   self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

     def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
                  get_node_params=lambda node: node.properties.get('node_data')):
diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
index 9475079af..336753450 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
@@ -28,8 +28,9 @@ class ChatNodeSerializer(serializers.Serializer):
                                                error_messages=ErrMessage.boolean(_('Whether to return content')))
     model_params_setting = serializers.DictField(required=False,
-                                                 error_messages=ErrMessage.integer(_("Model parameter settings")))
-
+                                                 error_messages=ErrMessage.dict(_("Model parameter settings")))
+    model_setting = serializers.DictField(required=False,
+                                          error_messages=ErrMessage.dict(_("Model settings")))
     dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                           error_messages=ErrMessage.char(_("Context Type")))

@@ -47,5 +48,6 @@ class IChatNode(INode):
                 chat_record_id,
                 model_params_setting=None,
                 dialogue_type=None,
+                model_setting=None,
                 **kwargs) -> NodeResult:
         pass
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
index 86061dba0..dc73887eb 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -14,14 +14,17 @@ from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
 from langchain_core.messages import BaseMessage, AIMessage

+from application.flow.common import Answer
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
+from application.flow.tools import Reasoning
 from setting.models import Model
 from setting.models_provider import get_model_credential
 from setting.models_provider.tools import get_model_instance_by_model_user_id


-def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
+                   reasoning_content: str):
     chat_model = node_variable.get('chat_model')
     message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
     answer_tokens = chat_model.get_num_tokens(answer)
@@ -31,6 +34,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
     node.context['history_message'] = node_variable['history_message']
     node.context['question'] = node_variable['question']
     node.context['run_time'] = time.time() - node.context['start_time']
+    node.context['reasoning_content'] = reasoning_content
     if workflow.is_result(node, NodeResult(node_variable,
workflow_variable)): node.answer_text = answer @@ -45,10 +49,27 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo """ response = node_variable.get('result') answer = '' + reasoning_content = '' + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start', ''), + model_setting.get('reasoning_content_end', '')) for chunk in response: - answer += chunk.content - yield chunk.content - _write_context(node_variable, workflow_variable, node, workflow, answer) + reasoning_chunk = reasoning.get_reasoning_content(chunk) + content_chunk = reasoning_chunk.get('content') + if 'reasoning_content' in chunk.additional_kwargs: + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') + else: + reasoning_content_chunk = reasoning_chunk.get('reasoning_content') + answer += content_chunk + if reasoning_content_chunk is None: + reasoning_content_chunk = '' + reasoning_content += reasoning_content_chunk + yield {'content': content_chunk, + 'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable', + False) else ''} + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -60,8 +81,17 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param workflow: 工作流管理器 """ response = node_variable.get('result') - answer = response.content - _write_context(node_variable, workflow_variable, node, workflow, answer) + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end')) + reasoning_result = reasoning.get_reasoning_content(response) + content = reasoning_result.get('content') + if 'reasoning_content' in response.response_metadata: + reasoning_content = response.response_metadata.get('reasoning_content', '') + else: + reasoning_content = reasoning_result.get('reasoning_content') + _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) def get_default_model_params_setting(model_id): @@ -92,17 +122,23 @@ class BaseChatNode(IChatNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') + self.context['reasoning_content'] = details.get('reasoning_content') self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, dialogue_type=None, + model_setting=None, **kwargs) -> NodeResult: if dialogue_type is None: dialogue_type = 'WORKFLOW' if model_params_setting is None: model_params_setting = get_default_model_params_setting(model_id) + if model_setting is None: + model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''} + self.context['model_setting'] = model_setting chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, @@ -164,6 +200,7 @@ class 
BaseChatNode(IChatNode): 'history_message') is not None else [])], 'question': self.context.get('question'), 'answer': self.context.get('answer'), + 'reasoning_content': self.context.get('reasoning_content'), 'type': self.node.type, 'message_tokens': self.context.get('message_tokens'), 'answer_tokens': self.context.get('answer_tokens'), diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index e6de4cc97..b1ef5f4d7 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -19,7 +19,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): return node_variable.get('is_interrupt_exec', False) -def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, + reasoning_content: str): result = node_variable.get('result') node.context['application_node_dict'] = node_variable.get('application_node_dict') node.context['node_dict'] = node_variable.get('node_dict', {}) @@ -28,6 +29,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) node.context['answer'] = answer node.context['result'] = answer + node.context['reasoning_content'] = reasoning_content node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): @@ -44,6 +46,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo """ response = node_variable.get('result') answer = '' + reasoning_content = '' usage = {} node_child_node = {} application_node_dict = node.context.get('application_node_dict', {}) @@ -60,9 +63,11 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo node_type = response_content.get('node_type') real_node_id = response_content.get('real_node_id') node_is_end = response_content.get('node_is_end', False) + _reasoning_content = response_content.get('reasoning_content', '') if node_type == 'form-node': is_interrupt_exec = True answer += content + reasoning_content += _reasoning_content node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, 'child_node': child_node} @@ -75,13 +80,16 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo 'chat_record_id': chat_record_id, 'child_node': child_node, 'index': len(application_node_dict), - 'view_type': view_type} + 'view_type': view_type, + 'reasoning_content': _reasoning_content} else: application_node['content'] += content + application_node['reasoning_content'] += _reasoning_content yield {'content': content, 'node_type': node_type, 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'reasoning_content': _reasoning_content, 'child_node': child_node, 'real_node_id': real_node_id, 'node_is_end': node_is_end, @@ -91,7 +99,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo node_variable['is_interrupt_exec'] = is_interrupt_exec node_variable['child_node'] = node_child_node node_variable['application_node_dict'] = application_node_dict - _write_context(node_variable, 
workflow_variable, node, workflow, answer) + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -106,7 +114,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'), 'prompt_tokens': response.get('prompt_tokens')}} answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。" - _write_context(node_variable, workflow_variable, node, workflow, answer) + reasoning_content = response.get('reasoning_content', '') + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) def reset_application_node_dict(application_node_dict, runtime_node_id, node_data): @@ -139,18 +148,22 @@ class BaseApplicationNode(IApplicationNode): if application_node_dict is None or len(application_node_dict) == 0: return [ Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], - self.context.get('child_node'))] + self.context.get('child_node'), self.runtime_node_id, '')] else: return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id, self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'), 'chat_record_id': n.get('chat_record_id') - , 'child_node': n.get('child_node')}) for n in + , 'child_node': n.get('child_node')}, n.get('real_node_id'), + n.get('reasoning_content', '')) + for n in sorted(application_node_dict.values(), key=lambda item: item.get('index'))] def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') + self.context['result'] = details.get('answer') self.context['question'] = details.get('question') self.context['type'] = details.get('type') + self.context['reasoning_content'] = details.get('reasoning_content') self.answer_text = details.get('answer') def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, @@ -229,6 +242,7 @@ class BaseApplicationNode(IApplicationNode): 'run_time': self.context.get('run_time'), 'question': self.context.get('question'), 'answer': self.context.get('answer'), + 'reasoning_content': self.context.get('reasoning_content'), 'type': self.node.type, 'message_tokens': self.context.get('message_tokens'), 'answer_tokens': self.context.get('answer_tokens'), diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index 8890538a1..3aeca8560 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -75,7 +75,8 @@ class BaseFormNode(IFormNode): form_content_format = self.workflow_manage.reset_prompt(form_content_format) prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') value = prompt_template.format(form=form, context=context) - return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None)] + return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None, + self.runtime_node_id, '')] def get_details(self, index: int, **kwargs): form_content_format = self.context.get('form_content_format') diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index b2bf6b1ad..2d4ebbcfd 100644 --- 
a/apps/application/flow/tools.py
+++ b/apps/application/flow/tools.py
@@ -16,6 +16,70 @@ from application.flow.i_step_node import WorkFlowPostHandler
 from common.response import result


+class Reasoning:
+    def __init__(self, reasoning_content_start, reasoning_content_end):
+        self.content = ""
+        self.reasoning_content = ""
+        self.all_content = ""
+        self.reasoning_content_start_tag = reasoning_content_start
+        self.reasoning_content_end_tag = reasoning_content_end
+        self.reasoning_content_start_tag_len = len(reasoning_content_start)
+        self.reasoning_content_end_tag_len = len(reasoning_content_end)
+        self.reasoning_content_end_tag_prefix = reasoning_content_end[0] if reasoning_content_end else ''
+        self.reasoning_content_is_start = False
+        self.reasoning_content_is_end = False
+        self.reasoning_content_chunk = ""
+
+    def get_reasoning_content(self, chunk):
+        self.all_content += chunk.content
+        if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
+            if self.all_content.startswith(self.reasoning_content_start_tag):
+                self.reasoning_content_is_start = True
+                self.reasoning_content_chunk = self.all_content[self.reasoning_content_start_tag_len:]
+            else:
+                self.reasoning_content_is_end = True
+        else:
+            if self.reasoning_content_is_start:
+                self.reasoning_content_chunk += chunk.content
+        reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
+            self.reasoning_content_end_tag_prefix)
+        if self.reasoning_content_is_end:
+            self.content += chunk.content
+            return {'content': chunk.content, 'reasoning_content': ''}
+        # Check whether the buffered text contains the end tag (or at least its first character)
+        if reasoning_content_end_tag_prefix_index > -1:
+            if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index > self.reasoning_content_end_tag_len:
+                reasoning_content_end_tag_index = self.reasoning_content_chunk.find(self.reasoning_content_end_tag)
+                if reasoning_content_end_tag_index > -1:
+                    reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_index]
+                    content_chunk = self.reasoning_content_chunk[
+                                    reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
+                    self.reasoning_content += reasoning_content_chunk
+                    self.content += content_chunk
+                    self.reasoning_content_chunk = ""
+                    self.reasoning_content_is_end = True
+                    return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
+                else:
+                    reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_prefix_index + 1]
+                    self.reasoning_content_chunk = self.reasoning_content_chunk[reasoning_content_end_tag_prefix_index + 1:]
+                    self.reasoning_content += reasoning_content_chunk
+                    return {'content': '', 'reasoning_content': reasoning_content_chunk}
+            else:
+                return {'content': '', 'reasoning_content': ''}
+        else:
+            if self.reasoning_content_is_end:
+                self.content += chunk.content
+                return {'content': chunk.content, 'reasoning_content': ''}
+            else:
+                # No end-tag prefix in the buffer: flush everything buffered as reasoning content
+                result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
+                self.reasoning_content += self.reasoning_content_chunk
+                self.reasoning_content_chunk = ""
+                return result
+
+
 def event_content(chat_id,
                   chat_record_id,
                   response,
                   workflow,
                   write_context,
                   post_handler: WorkFlowPostHandler):
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
index 9472ad376..9c27d6fc9 100644
--- a/apps/application/flow/workflow_manage.py
+++ b/apps/application/flow/workflow_manage.py
@@ -470,6 +470,7 @@ class WorkflowManage:
                     if result is not None:
                         if self.is_result(current_node, current_result):
                             for r in result:
+
reasoning_content = '' content = r child_node = {} node_is_end = False @@ -479,9 +480,12 @@ class WorkflowManage: child_node = {'runtime_node_id': r.get('runtime_node_id'), 'chat_record_id': r.get('chat_record_id') , 'child_node': r.get('child_node')} - real_node_id = r.get('real_node_id') - node_is_end = r.get('node_is_end') + if r.__contains__('real_node_id'): + real_node_id = r.get('real_node_id') + if r.__contains__('node_is_end'): + node_is_end = r.get('node_is_end') view_type = r.get('view_type') + reasoning_content = r.get('reasoning_content') chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], current_node.id, @@ -492,7 +496,8 @@ class WorkflowManage: 'view_type': view_type, 'child_node': child_node, 'node_is_end': node_is_end, - 'real_node_id': real_node_id}) + 'real_node_id': real_node_id, + 'reasoning_content': reasoning_content}) current_node.node_chunk.add_chunk(chunk) chunk = (self.base_to_response .to_stream_chunk_response(self.params['chat_id'], @@ -504,7 +509,8 @@ class WorkflowManage: 'node_type': current_node.type, 'view_type': view_type, 'child_node': child_node, - 'real_node_id': real_node_id})) + 'real_node_id': real_node_id, + 'reasoning_content': ''})) current_node.node_chunk.add_chunk(chunk) else: list(result) @@ -516,7 +522,7 @@ class WorkflowManage: self.params['chat_record_id'], current_node.id, current_node.up_node_id_list, - str(e), False, 0, 0, + 'Exception:' + str(e), False, 0, 0, {'node_is_end': True, 'runtime_node_id': current_node.runtime_node_id, 'node_type': current_node.type, @@ -603,20 +609,19 @@ class WorkflowManage: if len(current_answer.content) > 0: if up_node is None or current_answer.view_type == 'single_view' or ( current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'): - result.append(current_answer) + result.append([current_answer]) else: if len(result) > 0: exec_index = len(result) - 1 - content = result[exec_index].content - result[exec_index].content += current_answer.content if len( - content) == 0 else ('\n\n' + current_answer.content) + if isinstance(result[exec_index], list): + result[exec_index].append(current_answer) else: - result.insert(0, current_answer) + result.insert(0, [current_answer]) up_node = current_answer if len(result) == 0: # 如果没有响应 就响应一个空数据 - return [Answer('', '', '', '', {}).to_dict()] - return [r.to_dict() for r in result] + return [[]] + return [[item.to_dict() for item in r] for r in result] def get_next_node(self): """ diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 8e6200d07..ace8c29d4 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -35,7 +35,13 @@ def get_dataset_setting_dict(): def get_model_setting_dict(): - return {'prompt': Application.get_default_model_prompt(), 'no_references_prompt': '{question}'} + return { + 'prompt': Application.get_default_model_prompt(), + 'no_references_prompt': '{question}', + 'reasoning_content_start': '', + 'reasoning_content_end': '', + 'reasoning_content_enable': False, + } class Application(AppModelMixin): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index d199a4eab..d97f4782a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -140,6 +140,13 @@ class ModelSettingSerializer(serializers.Serializer): 
error_messages=ErrMessage.char(_("Role prompts"))) no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True, error_messages=ErrMessage.char(_("No citation segmentation prompt"))) + reasoning_content_enable = serializers.BooleanField(required=False, + error_messages=ErrMessage.char(_("Thinking process switch"))) + reasoning_content_start = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, + error_messages=ErrMessage.char( + _("The thinking process begins to mark"))) + reasoning_content_end = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, + error_messages=ErrMessage.char(_("End of thinking process marker"))) class ApplicationWorkflowSerializer(serializers.Serializer): diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index ca593c61c..86394bb3e 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -22,6 +22,7 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera BaseGenerateHumanMessageStep from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep +from application.flow.common import Answer from application.flow.i_step_node import WorkFlowPostHandler from application.flow.workflow_manage import WorkflowManage, Flow from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices, \ @@ -104,6 +105,7 @@ class ChatInfo: 'model_id': model_id, 'problem_optimization': self.application.problem_optimization, 'stream': True, + 'model_setting': model_setting, 'model_params_setting': model_params_setting if self.application.model_params_setting is None or len( self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, 'search_mode': self.application.dataset_setting.get( @@ -157,6 +159,8 @@ def get_post_handler(chat_info: ChatInfo): padding_problem_text: str = None, client_id=None, **kwargs): + answer_list = [[Answer(answer_text, 'ai-chat-node', 'ai-chat-node', 'ai-chat-node', {}, 'ai-chat-node', + kwargs.get('reasoning_content', '')).to_dict()]] chat_record = ChatRecord(id=chat_record_id, chat_id=chat_id, problem_text=problem_text, @@ -164,7 +168,7 @@ def get_post_handler(chat_info: ChatInfo): details=manage.get_details(), message_tokens=manage.context['message_tokens'], answer_tokens=manage.context['answer_tokens'], - answer_text_list=[answer_text], + answer_text_list=answer_list, run_time=manage.context['run_time'], index=len(chat_info.chat_record_list) + 1) chat_info.append_chat_record(chat_record, client_id) @@ -242,15 +246,18 @@ class ChatMessageSerializer(serializers.Serializer): runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char(_("Runtime node id"))) - node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.char(_("Node parameters"))) - application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Application ID"))) + node_data = serializers.DictField(required=False, allow_null=True, + error_messages=ErrMessage.char(_("Node parameters"))) + application_id = 
serializers.UUIDField(required=False, allow_null=True, + error_messages=ErrMessage.uuid(_("Application ID"))) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id"))) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type"))) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char(_("Global variables"))) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture"))) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document"))) audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio"))) - child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Child Nodes"))) + child_node = serializers.DictField(required=False, allow_null=True, + error_messages=ErrMessage.dict(_("Child Nodes"))) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py index a932a0cb2..38b6b20fc 100644 --- a/apps/setting/models_provider/impl/base_chat_open_ai.py +++ b/apps/setting/models_provider/impl/base_chat_open_ai.py @@ -1,12 +1,16 @@ # coding=utf-8 +import warnings +from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union -from typing import List, Dict, Optional, Any, Iterator, cast - +import openai +from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk from langchain_core.outputs import ChatGenerationChunk, ChatGeneration from langchain_core.runnables import RunnableConfig, ensure_config +from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _convert_chunk_to_generation_chunk from common.config.tokenizer_manage_config import TokenizerManage @@ -36,14 +40,101 @@ class BaseChatOpenAI(ChatOpenAI): return self.get_last_generation_info().get('output_tokens', 0) def _stream( - self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + + """Set default stream_options.""" + stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs) + # Note: stream_options is not a valid parameter for Azure OpenAI. + # To support users proxying Azure through ChatOpenAI, here we only specify + # stream_options if include_usage is set to True. + # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new + # for release notes. 
+ if stream_usage: + kwargs["stream_options"] = {"include_usage": stream_usage} + kwargs["stream"] = True - kwargs["stream_options"] = {"include_usage": True} - for chunk in super()._stream(*args, stream_usage=stream_usage, **kwargs): - if chunk.message.usage_metadata is not None: - self.usage_metadata = chunk.message.usage_metadata - yield chunk + payload = self._get_request_payload(messages, stop=stop, **kwargs) + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk + base_generation_info = {} + + if "response_format" in payload and is_basemodel_subclass( + payload["response_format"] + ): + # TODO: Add support for streaming with Pydantic response_format. + warnings.warn("Streaming with Pydantic response_format not yet supported.") + chat_result = self._generate( + messages, stop, run_manager=run_manager, **kwargs + ) + msg = chat_result.generations[0].message + yield ChatGenerationChunk( + message=AIMessageChunk( + **msg.dict(exclude={"type", "additional_kwargs"}), + # preserve the "parsed" Pydantic object without converting to dict + additional_kwargs=msg.additional_kwargs, + ), + generation_info=chat_result.generations[0].generation_info, + ) + return + if self.include_response_headers: + raw_response = self.client.with_raw_response.create(**payload) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = self.client.create(**payload) + with response: + is_first_chunk = True + for chunk in response: + if not isinstance(chunk, dict): + chunk = chunk.model_dump() + + generation_chunk = _convert_chunk_to_generation_chunk( + chunk, + default_chunk_class, + base_generation_info if is_first_chunk else {}, + ) + if generation_chunk is None: + continue + + # custom code + if generation_chunk.message.usage_metadata is not None: + self.usage_metadata = generation_chunk.message.usage_metadata + # custom code + if 'reasoning_content' in chunk['choices'][0]['delta']: + generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][ + 'reasoning_content'] + + default_chunk_class = generation_chunk.message.__class__ + logprobs = (generation_chunk.generation_info or {}).get("logprobs") + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk, logprobs=logprobs + ) + is_first_chunk = False + yield generation_chunk + + def _create_chat_result(self, + response: Union[dict, openai.BaseModel], + generation_info: Optional[Dict] = None): + result = super()._create_chat_result(response, generation_info) + try: + reasoning_content = '' + reasoning_content_enable = False + for res in response.choices: + if 'reasoning_content' in res.message.model_extra: + reasoning_content_enable = True + _reasoning_content = res.message.model_extra.get('reasoning_content') + if _reasoning_content is not None: + reasoning_content += _reasoning_content + if reasoning_content_enable: + result.llm_output['reasoning_content'] = reasoning_content + except Exception as e: + pass + return result def invoke( self, diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 5f0d1b295..15b94f6cd 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -13,6 +13,7 @@ from langchain_openai.chat_models import ChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage from 
setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI def custom_get_token_ids(text: str): @@ -20,7 +21,7 @@ def custom_get_token_ids(text: str): return tokenizer.encode(text) -class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): +class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def is_cache_model(): diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 404f5a55e..5e05bdd03 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -29,6 +29,7 @@ interface Chunk { chat_id: string chat_record_id: string content: string + reasoning_content: string node_id: string up_node_id: string is_end: boolean @@ -43,12 +44,16 @@ interface chatType { problem_text: string answer_text: string buffer: Array - answer_text_list: Array<{ - content: string - chat_record_id?: string - runtime_node_id?: string - child_node?: any - }> + answer_text_list: Array< + Array<{ + content: string + reasoning_content: string + chat_record_id?: string + runtime_node_id?: string + child_node?: any + real_node_id?: string + }> + > /** * 是否写入结束 */ @@ -83,6 +88,7 @@ interface WriteNodeInfo { answer_text_list_index: number current_up_node?: any divider_content?: Array + divider_reasoning_content?: Array } export class ChatRecordManage { id?: any @@ -105,20 +111,38 @@ export class ChatRecordManage { } append_answer( chunk_answer: string, + reasoning_content: string, index?: number, chat_record_id?: string, runtime_node_id?: string, - child_node?: any + child_node?: any, + real_node_id?: string ) { - const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1 - const content = this.chat.answer_text_list[set_index] - ? this.chat.answer_text_list[set_index].content + chunk_answer - : chunk_answer - this.chat.answer_text_list[set_index] = { - content: content, - chat_record_id, - runtime_node_id, - child_node + if (chunk_answer || reasoning_content) { + const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1 + let card_list = this.chat.answer_text_list[set_index] + if (!card_list) { + card_list = [] + this.chat.answer_text_list[set_index] = card_list + } + const answer_value = card_list.find((item) => item.real_node_id == real_node_id) + const content = answer_value ? answer_value.content + chunk_answer : chunk_answer + const _reasoning_content = answer_value + ? 
answer_value.reasoning_content + reasoning_content + : reasoning_content + if (answer_value) { + answer_value.content = content + answer_value.reasoning_content = _reasoning_content + } else { + card_list.push({ + content: content, + reasoning_content: _reasoning_content, + chat_record_id, + runtime_node_id, + child_node, + real_node_id + }) + } } this.chat.answer_text = this.chat.answer_text + chunk_answer @@ -155,7 +179,7 @@ export class ChatRecordManage { ) { const none_index = this.findIndex( this.chat.answer_text_list, - (item) => item.content == '', + (item) => (item.length == 1 && item[0].content == '') || item.length == 0, 'index' ) if (none_index > -1) { @@ -166,7 +190,7 @@ export class ChatRecordManage { } else { const none_index = this.findIndex( this.chat.answer_text_list, - (item) => item.content === '', + (item) => (item.length == 1 && item[0].content == '') || item.length == 0, 'index' ) if (none_index > -1) { @@ -178,10 +202,10 @@ export class ChatRecordManage { this.write_node_info = { current_node: run_node, - divider_content: ['\n\n'], current_up_node: current_up_node, answer_text_list_index: answer_text_list_index } + return this.write_node_info } return undefined @@ -210,7 +234,7 @@ export class ChatRecordManage { } const last_index = this.findIndex( this.chat.answer_text_list, - (item) => item.content == '', + (item) => (item.length == 1 && item[0].content == '') || item.length == 0, 'last' ) if (last_index > 0) { @@ -234,7 +258,8 @@ export class ChatRecordManage { } return } - const { current_node, answer_text_list_index, divider_content } = node_info + const { current_node, answer_text_list_index } = node_info + if (current_node.buffer.length > 20) { const context = current_node.is_end ? current_node.buffer.splice(0) @@ -242,12 +267,20 @@ export class ChatRecordManage { 0, current_node.is_end ? undefined : current_node.buffer.length - 20 ) + const reasoning_content = current_node.is_end + ? current_node.reasoning_content_buffer.splice(0) + : current_node.reasoning_content_buffer.splice( + 0, + current_node.is_end ? undefined : current_node.reasoning_content_buffer.length - 20 + ) this.append_answer( - (divider_content ? divider_content.splice(0).join('') : '') + context.join(''), + context.join(''), + reasoning_content.join(''), answer_text_list_index, current_node.chat_record_id, current_node.runtime_node_id, - current_node.child_node + current_node.child_node, + current_node.real_node_id ) } else if (this.is_close) { while (true) { @@ -257,27 +290,46 @@ export class ChatRecordManage { break } this.append_answer( - (node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') + - node_info.current_node.buffer.splice(0).join(''), + node_info.current_node.buffer.splice(0).join(''), + node_info.current_node.reasoning_content_buffer.splice(0).join(''), node_info.answer_text_list_index, node_info.current_node.chat_record_id, node_info.current_node.runtime_node_id, - node_info.current_node.child_node + node_info.current_node.child_node, + node_info.current_node.real_node_id ) - if (node_info.current_node.buffer.length == 0) { + + if ( + node_info.current_node.buffer.length == 0 && + node_info.current_node.reasoning_content_buffer.length == 0 + ) { node_info.current_node.is_end = true } } this.closeInterval() } else { const s = current_node.buffer.shift() + const reasoning_content = current_node.reasoning_content_buffer.shift() if (s !== undefined) { this.append_answer( - (divider_content ? 
divider_content.splice(0).join('') : '') + s, + s, + '', answer_text_list_index, current_node.chat_record_id, current_node.runtime_node_id, - current_node.child_node + current_node.child_node, + current_node.real_node_id + ) + } + if (reasoning_content !== undefined) { + this.append_answer( + '', + reasoning_content, + answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node, + current_node.real_node_id ) } } @@ -303,9 +355,15 @@ export class ChatRecordManage { if (n) { n.buffer.push(...chunk.content) n.content += chunk.content + if (chunk.reasoning_content) { + n.reasoning_content_buffer.push(...chunk.reasoning_content) + n.reasoning_content += chunk.reasoning_content + } } else { n = { buffer: [...chunk.content], + reasoning_content_buffer: chunk.reasoning_content ? [...chunk.reasoning_content] : [], + reasoning_content: chunk.reasoning_content ? chunk.reasoning_content : '', content: chunk.content, real_node_id: chunk.real_node_id, node_id: chunk.node_id, @@ -324,13 +382,18 @@ export class ChatRecordManage { n['is_end'] = true } } - append(answer_text_block: string) { + append(answer_text_block: string, reasoning_content?: string) { let set_index = this.findIndex( this.chat.answer_text_list, - (item) => item.content == '', + (item) => item.length == 1 && item[0].content == '', 'index' ) - this.chat.answer_text_list[set_index] = { content: answer_text_block } + this.chat.answer_text_list[set_index] = [ + { + content: answer_text_block, + reasoning_content: reasoning_content ? reasoning_content : '' + } + ] } } @@ -346,10 +409,10 @@ export class ChatManagement { chatRecord.appendChunk(chunk) } } - static append(chatRecordId: string, content: string) { + static append(chatRecordId: string, content: string, reasoning_content?: string) { const chatRecord = this.chatMessageContainer[chatRecordId] if (chatRecord) { - chatRecord.append(content) + chatRecord.append(content, reasoning_content) } } static updateStatus(chatRecordId: string, code: number) { diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 263c4feb1..620e7ca5a 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -214,6 +214,17 @@ {{ item.question || '-' }} +
+
+ {{ $t('views.applicationWorkflow.nodes.aiChatNode.think')}} +
+
+ {{ item.reasoning_content || '-' }} +
+
{{ diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue index f4f420daa..b90a7043d 100644 --- a/ui/src/components/ai-chat/component/answer-content/index.vue +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -10,19 +10,23 @@ - + {{ $t('chat.tip.stopAnswer') }} @@ -90,9 +94,20 @@ const openControl = (event: any) => { const answer_text_list = computed(() => { return props.chatRecord.answer_text_list.map((item) => { if (typeof item == 'string') { - return { content: item } + return [ + { + content: item, + chat_record_id: undefined, + child_node: undefined, + runtime_node_id: undefined, + reasoning_content: undefined + } + ] + } else if (item instanceof Array) { + return item + } else { + return [item] } - return item }) }) diff --git a/ui/src/components/ai-chat/component/prologue-content/index.vue b/ui/src/components/ai-chat/component/prologue-content/index.vue index 62e293c68..2fa247d8e 100644 --- a/ui/src/components/ai-chat/component/prologue-content/index.vue +++ b/ui/src/components/ai-chat/component/prologue-content/index.vue @@ -7,7 +7,11 @@
- +
@@ -27,9 +31,7 @@ const toQuickQuestion = (match: string, offset: number, input: string) => { return `${match.replace('- ', '')}` } const prologue = computed(() => { - const temp = props.available - ? props.application?.prologue - : t('chat.tip.prologueMessage') + const temp = props.available ? props.application?.prologue : t('chat.tip.prologueMessage') return temp?.replace(/-\s.+/g, toQuickQuestion) }) diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 1260fb33e..e811cd06d 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -284,8 +284,10 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para id: randomId(), problem_text: problem ? problem : inputValue.value.trim(), answer_text: '', - answer_text_list: [{ content: '' }], + answer_text_list: [[]], buffer: [], + reasoning_content: '', + reasoning_content_buffer: [], write_ed: false, is_stop: false, record_id: '', diff --git a/ui/src/components/markdown/MdRenderer.vue b/ui/src/components/markdown/MdRenderer.vue index bc2f90709..1cbdc7deb 100644 --- a/ui/src/components/markdown/MdRenderer.vue +++ b/ui/src/components/markdown/MdRenderer.vue @@ -1,40 +1,44 @@ + diff --git a/ui/src/locales/lang/en-US/views/application-workflow.ts b/ui/src/locales/lang/en-US/views/application-workflow.ts index 33913c887..c1275462a 100644 --- a/ui/src/locales/lang/en-US/views/application-workflow.ts +++ b/ui/src/locales/lang/en-US/views/application-workflow.ts @@ -113,7 +113,8 @@ export default { tooltip: `If turned off, the content of this node will not be output to the user. If you want the user to see the output of this node, please turn on the switch.` }, - defaultPrompt: 'Known Information' + defaultPrompt: 'Known Information', + think: 'Thinking Process', }, searchDatasetNode: { label: 'Knowledge Retrieval', diff --git a/ui/src/locales/lang/en-US/views/application.ts b/ui/src/locales/lang/en-US/views/application.ts index e05b5daa7..3ed968a4b 100644 --- a/ui/src/locales/lang/en-US/views/application.ts +++ b/ui/src/locales/lang/en-US/views/application.ts @@ -61,8 +61,9 @@ export default { references: ' (References Knowledge)', placeholder: 'Please enter prompt', requiredMessage: 'Please enter prompt', - tooltip:'By adjusting the content of the prompt, you can guide the direction of the large model chat.', - + tooltip: + 'By adjusting the content of the prompt, you can guide the direction of the large model chat.', + noReferencesTooltip: 'By adjusting the content of the prompt, you can guide the direction of the large model chat. This prompt will be fixed at the beginning of the context. 
Variables used: {question} is the question posed by the user.', referencesTooltip: @@ -105,6 +106,13 @@ export default { browser: 'Browser playback (free)', tts: 'TTS Model', listeningTest: 'Preview' + }, + reasoningContent: { + label: 'Output Thinking', + tooltip: + 'According to the thinking tags set by the model, the content between the tags will be considered as the thought process.', + start: 'Start', + end: 'End' } }, buttons: { diff --git a/ui/src/locales/lang/zh-CN/views/application-workflow.ts b/ui/src/locales/lang/zh-CN/views/application-workflow.ts index e134f79e8..2ccaf16a1 100644 --- a/ui/src/locales/lang/zh-CN/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-CN/views/application-workflow.ts @@ -114,7 +114,8 @@ export default { tooltip: `关闭后该节点的内容则不输出给用户。 如果你想让用户看到该节点的输出内容,请打开开关。` }, - defaultPrompt: '已知信息' + defaultPrompt: '已知信息', + think: '思考过程', }, searchDatasetNode: { label: '知识库检索', diff --git a/ui/src/locales/lang/zh-CN/views/application.ts b/ui/src/locales/lang/zh-CN/views/application.ts index 209544d3d..6f52b82a0 100644 --- a/ui/src/locales/lang/zh-CN/views/application.ts +++ b/ui/src/locales/lang/zh-CN/views/application.ts @@ -55,7 +55,8 @@ export default { references: ' (引用知识库)', placeholder: '请输入提示词', requiredMessage: '请输入提示词', - tooltip:'通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头,可以使用变量。', + tooltip: + '通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头,可以使用变量。', noReferencesTooltip: '通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头。可以使用变量:{question} 是用户提出问题的占位符。', referencesTooltip: @@ -96,6 +97,12 @@ export default { browser: '浏览器播放(免费)', tts: 'TTS模型', listeningTest: '试听' + }, + reasoningContent: { + label: '输出思考', + tooltip: '请根据模型返回的思考标签设置,标签中间的内容将为认定为思考过程', + start: '开始', + end: '结束' } }, buttons: { @@ -194,6 +201,5 @@ export default { text: '针对用户提问调试段落匹配情况,保障回答效果。', emptyMessage1: '命中段落显示在这里', emptyMessage2: '没有命中的分段' - }, - + } } diff --git a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts index 5e308fa8f..c7fa83e5a 100644 --- a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts @@ -113,7 +113,8 @@ export default { tooltip: `關閉後該節點的內容則不輸出給用戶。 如果你想讓用戶看到該節點的輸出內容,請打開開關。` }, - defaultPrompt: '已知信息' + defaultPrompt: '已知信息', + think: '思考過程', }, searchDatasetNode: { label: '知識庫檢索', diff --git a/ui/src/locales/lang/zh-Hant/views/application.ts b/ui/src/locales/lang/zh-Hant/views/application.ts index dcfb7bf04..69062f4ec 100644 --- a/ui/src/locales/lang/zh-Hant/views/application.ts +++ b/ui/src/locales/lang/zh-Hant/views/application.ts @@ -97,6 +97,12 @@ export default { browser: '瀏覽器播放(免費)', tts: 'TTS模型', listeningTest: '試聽' + }, + reasoningContent: { + label: '輸出思考', + tooltip:'請根據模型返回的思考標簽設置,標簽中間的內容將爲認定爲思考過程', + start: '開始', + end: '結束', } }, buttons: { diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index 6d1b4a938..e4ddcc788 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -65,6 +65,7 @@ + diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index 1661f214e..43a632adf 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -65,6 +65,10 @@ export const aiChatNode = { { label: t('views.applicationWorkflow.nodes.aiChatNode.answer'), value: 'answer' + }, + { + label: t('views.applicationWorkflow.nodes.aiChatNode.think'), + value: 'reasoning_content' } ] } diff --git 
a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue
index 4e301398b..2502d4a67 100644
--- a/ui/src/workflow/nodes/ai-chat-node/index.vue
+++ b/ui/src/workflow/nodes/ai-chat-node/index.vue
@@ -7,7 +7,6 @@
       :model="chat_data"
       label-position="top"
       require-asterisk-position="right"
-      class="mb-24"
       label-width="auto"
       ref="aiChatNodeFormRef"
       hide-required-asterisk
@@ -29,6 +28,7 @@
         }}*
+
-
+
+
+
+
+
+
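
Note on the new `Reasoning` helper in apps/application/flow/tools.py: it is a stateful splitter that is fed streamed chunks and incrementally separates the text between the configured start/end tags from the visible answer. Below is a minimal sketch of the intended behavior, assuming only that each streamed chunk exposes a `.content` string (the pipeline feeds it LangChain `AIMessageChunk` objects); the `FakeChunk` stand-in and the sample `<think>`/`</think>` tag values are illustrative, not part of the patch:

```python
from dataclasses import dataclass

from application.flow.tools import Reasoning


@dataclass
class FakeChunk:
    # Illustrative stand-in for langchain_core.messages.AIMessageChunk.
    content: str


reasoning = Reasoning('<think>', '</think>')
answer, thoughts = '', ''
# Simulate a stream whose reasoning tags span chunk boundaries.
for piece in ['<th', 'ink>first ', 'reason step</th', 'ink>final answer']:
    parts = reasoning.get_reasoning_content(FakeChunk(piece))
    answer += parts['content']
    thoughts += parts['reasoning_content']

print(answer)    # expected: 'final answer'
print(thoughts)  # expected: 'first reason step'
```

Models whose APIs return reasoning natively skip this tag parsing: the patched `BaseChatOpenAI._stream` copies the delta's `reasoning_content` into `additional_kwargs` (and `_create_chat_result` mirrors it into `llm_output`), and the consumers in `event_content` and `write_context_stream` prefer that field whenever it is present.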