From 27dd90c7df2c161cf7a131f4712ab0b613f2e241 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Wed, 19 Nov 2025 15:14:36 +0800
Subject: [PATCH] feat: enhance token management and normalize content
 handling in chat model

---
 .../step/chat_step/impl/base_chat_step.py  | 18 ++++++----
 apps/application/flow/tools.py             | 17 ++++++++++
 .../azure_model_provider/credential/llm.py |  3 +-
 .../model/azure_chat_model.py              | 34 +++++++++++++++++--
 4 files changed, 62 insertions(+), 10 deletions(-)

diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 8eb18d5bc..ed953e13c 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -93,6 +93,7 @@ def event_content(response,
             reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
         else:
             reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+        content_chunk = reasoning._normalize_content(content_chunk)
         all_text += content_chunk
         if reasoning_content_chunk is None:
             reasoning_content_chunk = ''
@@ -191,13 +192,15 @@ class BaseChatStep(IChatStep):
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting, model_setting,
-                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                       mcp_output_enable)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type,
                                       no_references_setting, model_setting,
-                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                      mcp_output_enable)

     def get_details(self, manage, **kwargs):
         return {
@@ -264,7 +267,6 @@ class BaseChatStep(IChatStep):

         return None

-
     def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
@@ -294,7 +296,8 @@ class BaseChatStep(IChatStep):
         else:
             # Handle the MCP request
             mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
+                message_list,
             )
             if mcp_result:
                 return mcp_result, True
@@ -319,7 +322,8 @@ class BaseChatStep(IChatStep):
                        tool_ids=None,
                        mcp_output_enable=True):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
+                                                         mcp_servers, mcp_source, tool_enable, tool_ids,
                                                          mcp_output_enable)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
@@ -394,7 +398,9 @@ class BaseChatStep(IChatStep):
         # Invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                                            no_references_setting, problem_text, mcp_enable,
+                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                            tool_ids, mcp_output_enable)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
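Why the pipeline needs the `_normalize_content` call added above: newer LangChain chat models can stream `AIMessageChunk` objects whose `content` is a list of typed parts rather than a plain string, and `all_text += content_chunk` then raises `TypeError`. A minimal sketch of the two chunk shapes involved (the part layout follows LangChain's list-of-content-parts convention; `Reasoning._normalize_content` itself is added in the next file):

```python
from langchain_core.messages import AIMessageChunk

# Plain string content: what the accumulator always assumed.
plain_chunk = AIMessageChunk(content="Hello")

# List-of-parts content, e.g. from multimodal-capable deployments;
# concatenating this onto a str raises TypeError without normalization.
typed_chunk = AIMessageChunk(content=[
    {"type": "text", "text": "Hel"},
    {"type": "text", "text": "lo"},
])
```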
diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py
index e70e312a5..149eafaa0 100644
--- a/apps/application/flow/tools.py
+++ b/apps/application/flow/tools.py
@@ -48,6 +48,21 @@ class Reasoning:
                 return r
         return {'content': '', 'reasoning_content': ''}
+    def _normalize_content(self, content):
+        """Normalize content of different types into a single string."""
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            # Handle a list containing multiple content types
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            return ''.join(normalized_parts)
+        else:
+            return str(content)
+
     def get_reasoning_content(self, chunk):
         # If there is no reasoning start tag, everything is final content
         if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
             return {'content': chunk.content, 'reasoning_content': ''}
@@ -56,6 +71,7 @@ class Reasoning:
         # If there is no reasoning end tag, everything is reasoning content
         if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
             return {'content': '', 'reasoning_content': chunk.content}
+        chunk.content = self._normalize_content(chunk.content)
         self.all_content += chunk.content
         if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
             if self.all_content.startswith(self.reasoning_content_start_tag):
@@ -202,6 +218,7 @@ def to_stream_response_simple(stream_event):
     r['Cache-Control'] = 'no-cache'
     return r

+
 tool_message_json_template = """
 ```json
 %s

diff --git a/apps/models_provider/impl/azure_model_provider/credential/llm.py b/apps/models_provider/impl/azure_model_provider/credential/llm.py
index 1e1967f6f..e8d8205c8 100644
--- a/apps/models_provider/impl/azure_model_provider/credential/llm.py
+++ b/apps/models_provider/impl/azure_model_provider/credential/llm.py
@@ -66,7 +66,8 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
             return False
         try:
             model = provider.get_model(model_type, model_name, model_credential, **model_params)
-            model.invoke([HumanMessage(content=gettext('Hello'))])
+            res = model.invoke([HumanMessage(content=gettext('Hello'))])
+            print(res)
         except Exception as e:
             traceback.print_exc()
             if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
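Expected behaviour of `_normalize_content`, shown as a standalone sketch rather than a test from the repository: string content passes through unchanged, list content keeps only the `text` parts (non-text parts such as `image_url` are silently dropped), and anything else is stringified.

```python
def normalize_content(content):
    # Standalone copy of Reasoning._normalize_content, for illustration only.
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        return ''.join(item.get('text', '') for item in content
                       if isinstance(item, dict) and item.get('type') == 'text')
    return str(content)

assert normalize_content('plain') == 'plain'
assert normalize_content([
    {'type': 'text', 'text': 'foo'},
    {'type': 'image_url', 'image_url': {'url': 'https://example.invalid/x.png'}},  # dropped
    {'type': 'text', 'text': 'bar'},
]) == 'foobar'
assert normalize_content(None) == 'None'  # non-str/list values are stringified
```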
diff --git a/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py
index 4c0b546ff..36a12553a 100644
--- a/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py
+++ b/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -7,9 +7,11 @@
 @desc:
 """
-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.runnables import RunnableConfig
 from langchain_openai import AzureChatOpenAI

 from common.config.tokenizer_manage_config import TokenizerManage

@@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
             streaming=True,
         )

+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         try:
-            return super().get_num_tokens_from_messages(messages)
+            return self.get_last_generation_info().get('input_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

     def get_num_tokens(self, text: str) -> int:
         try:
-            return super().get_num_tokens(text)
+            return self.get_last_generation_info().get('output_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return len(tokenizer.encode(text))
+
+    def invoke(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[list[str]] = None,
+            **kwargs: Any,
+    ) -> BaseMessage:
+        message = super().invoke(input, config, stop=stop, **kwargs)
+        if isinstance(message.content, str):
+            return message
+        elif isinstance(message.content, list):
+            # Build a new normalized response message to return
+            content = message.content
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            message.content = ''.join(normalized_parts)
+        self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata)
+        return message
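Two caveats in the `invoke` override are worth noting. First, the early `return message` for plain-string content skips the `_last_generation_info` update, so the usage-based token counts in `get_num_tokens_from_messages` / `get_num_tokens` only refresh after a list-content response (for string responses, `get_last_generation_info()` may stay `None`, and the `AttributeError` on `.get(...)` routes both methods to the tokenizer fallback). Second, LangChain's `usage_metadata` can be `None`, in which case `dict.update(None)` raises `TypeError`. A hedged sketch of a variant that records usage on both paths; the names mirror the patch, but the control flow is a suggestion rather than the committed code:

```python
def invoke(self, input, config=None, *, stop=None, **kwargs):
    message = super().invoke(input, config, stop=stop, **kwargs)
    if isinstance(message.content, list):
        # Keep only the text parts, mirroring Reasoning._normalize_content.
        message.content = ''.join(item.get('text', '') for item in message.content
                                  if isinstance(item, dict) and item.get('type') == 'text')
    # Record usage for string and list responses alike, guarding against None.
    if getattr(message, 'usage_metadata', None):
        self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata)
    return message
```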