feat: enhance token management and normalize content handling in chat model

wxg0103 2025-11-19 15:14:36 +08:00
parent 7d3f92bd51
commit 27dd90c7df
4 changed files with 62 additions and 10 deletions

View File

@@ -93,6 +93,7 @@ def event_content(response,
                 reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
             else:
                 reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+            content_chunk = reasoning._normalize_content(content_chunk)
             all_text += content_chunk
             if reasoning_content_chunk is None:
                 reasoning_content_chunk = ''
@@ -191,13 +192,15 @@ class BaseChatStep(IChatStep):
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting,
                                        model_setting,
-                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                       mcp_output_enable)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
                                       model_setting,
-                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                      mcp_output_enable)

     def get_details(self, manage, **kwargs):
         return {
@@ -264,7 +267,6 @@ class BaseChatStep(IChatStep):
         return None

     def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
@@ -294,7 +296,8 @@ class BaseChatStep(IChatStep):
         else:
             # Handle the MCP request
             mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
+                message_list,
             )
             if mcp_result:
                 return mcp_result, True
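For orientation, the dispatch this hunk rewraps gives the MCP/tool path the first chance to answer and only falls through to a plain model call when it declines. The sketch below is a hypothetical, simplified rendition; the real `_handle_mcp_request` takes the full flag set shown above, not this reduced signature:

```python
from typing import Any, Callable, List, Optional, Tuple

def get_result(message_list: List[Any],
               chat_model: Any,
               mcp_enable: bool,
               handle_mcp_request: Callable[[Any, List[Any]], Optional[Any]]) -> Tuple[Any, bool]:
    # Give the MCP/tool path the first chance to produce the answer.
    if mcp_enable:
        mcp_result = handle_mcp_request(chat_model, message_list)
        if mcp_result:
            return mcp_result, True
    # Otherwise fall through to a plain streaming model call.
    return chat_model.stream(message_list), True
```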
@@ -319,7 +322,8 @@ class BaseChatStep(IChatStep):
                        tool_ids=None,
                        mcp_output_enable=True):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
+                                                         mcp_servers, mcp_source, tool_enable, tool_ids,
                                                          mcp_output_enable)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
@@ -394,7 +398,9 @@ class BaseChatStep(IChatStep):
         # Invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                                            no_references_setting, problem_text, mcp_enable,
+                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                            tool_ids, mcp_output_enable)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
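As a usage note: in the block path the token counts are taken after the call, so models that override the counters (like the Azure model later in this commit) can report provider-measured usage instead of local estimates. A minimal sketch of the same accounting against any LangChain chat model, with construction of `chat_model` assumed:

```python
from langchain_core.messages import HumanMessage

message_list = [HumanMessage(content='Hello')]
chat_result = chat_model.invoke(message_list)  # blocking (non-stream) call
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
print(request_token, response_token)
```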

View File

@@ -48,6 +48,21 @@ class Reasoning:
                 return r
         return {'content': '', 'reasoning_content': ''}

+    def _normalize_content(self, content):
+        """Normalize content of different types into a single string."""
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            # Handle a list that mixes several content-part types
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            return ''.join(normalized_parts)
+        else:
+            return str(content)
+
     def get_reasoning_content(self, chunk):
         # If there is no reasoning start tag, everything is result content
         if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
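To make the new helper concrete, here is what it returns for the three input shapes it distinguishes; the constructor arguments are placeholders assumed for this example, not values from the commit:

```python
reasoning = Reasoning('<think>', '</think>')  # assumed constructor arguments

reasoning._normalize_content('plain text')
# -> 'plain text'

reasoning._normalize_content([
    {'type': 'text', 'text': 'Hello, '},
    {'type': 'image_url', 'image_url': {'url': 'https://...'}},  # non-text parts are dropped
    {'type': 'text', 'text': 'world'},
])
# -> 'Hello, world'

reasoning._normalize_content(42)
# -> '42'
```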
@@ -56,6 +71,7 @@ class Reasoning:
         # If there is no reasoning end tag, everything is reasoning content
         if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
             return {'content': '', 'reasoning_content': chunk.content}
+        chunk.content = self._normalize_content(chunk.content)
         self.all_content += chunk.content
         if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
             if self.all_content.startswith(self.reasoning_content_start_tag):
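The method this hunk touches incrementally scans the accumulated stream for the configured start/end reasoning tags. As a simplified illustration of the same splitting idea, applied to a whole buffer rather than chunk by chunk, with `<think>` tags assumed:

```python
def split_reasoning(buffer: str, start_tag: str = '<think>', end_tag: str = '</think>'):
    """Return (reasoning_content, content) for a fully accumulated buffer."""
    if not buffer.startswith(start_tag):
        return '', buffer                      # no reasoning section at all
    end = buffer.find(end_tag, len(start_tag))
    if end == -1:
        return buffer[len(start_tag):], ''     # still inside the reasoning section
    return buffer[len(start_tag):end], buffer[end + len(end_tag):]

print(split_reasoning('<think>weigh options</think>final answer'))
# -> ('weigh options', 'final answer')
```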
@@ -202,6 +218,7 @@ def to_stream_response_simple(stream_event):
     r['Cache-Control'] = 'no-cache'
     return r
tool_message_json_template = """
```json
%s

View File

@@ -66,7 +66,8 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
             return False
         try:
             model = provider.get_model(model_type, model_name, model_credential, **model_params)
-            model.invoke([HumanMessage(content=gettext('Hello'))])
+            res = model.invoke([HumanMessage(content=gettext('Hello'))])
+            print(res)
         except Exception as e:
             traceback.print_exc()
             if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
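The check above is the usual credential probe: build the model and fire a one-word prompt, treating any exception as an invalid configuration. Stripped of the surrounding form logic, the pattern looks like this; the `provider.get_model` interface is taken from the hunk, everything else is illustrative:

```python
from langchain_core.messages import HumanMessage

def credential_is_valid(provider, model_type, model_name, credential, **params) -> bool:
    try:
        model = provider.get_model(model_type, model_name, credential, **params)
        model.invoke([HumanMessage(content='Hello')])  # cheap round-trip probe
        return True
    except Exception:
        return False
```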

View File

@@ -7,9 +7,11 @@
     @desc:
 """
-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.runnables import RunnableConfig
 from langchain_openai import AzureChatOpenAI

 from common.config.tokenizer_manage_config import TokenizerManage
@@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
             streaming=True,
         )

+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         try:
-            return super().get_num_tokens_from_messages(messages)
+            return self.get_last_generation_info().get('input_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

     def get_num_tokens(self, text: str) -> int:
         try:
-            return super().get_num_tokens(text)
+            return self.get_last_generation_info().get('output_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return len(tokenizer.encode(text))
+
+    def invoke(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[list[str]] = None,
+            **kwargs: Any,
+    ) -> BaseMessage:
+        message = super().invoke(input, config, stop=stop, **kwargs)
+        # Cache provider-reported usage so the token getters above can read it
+        self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata or {})
+        if isinstance(message.content, str):
+            return message
+        elif isinstance(message.content, list):
+            # Build a new, text-only response content from the content parts
+            normalized_parts = []
+            for item in message.content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            message.content = ''.join(normalized_parts)
+        return message
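A note on the token methods above: before the first `invoke`, `get_last_generation_info()` returns `None`, so the `.get(...)` call raises `AttributeError` and the `except` branch falls back to a local tokenizer estimate; after a call, the counts come from the provider's `usage_metadata`. A self-contained sketch of that fallback shape, with `tiktoken` standing in for the project's `TokenizerManage`:

```python
from typing import Any, Dict, Optional
import tiktoken

_last_info: Optional[Dict[str, Any]] = None  # filled in after each generation

def num_output_tokens(text: str) -> int:
    try:
        # Prefer the provider-reported usage from the last generation...
        return _last_info.get('output_tokens', 0)  # AttributeError while None
    except Exception:
        # ...otherwise estimate locally with a stand-in tokenizer.
        enc = tiktoken.get_encoding('cl100k_base')
        return len(enc.encode(text))

print(num_output_tokens('Hello world'))  # local estimate: no generation yet
```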
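And a hedged usage sketch of the overridden `invoke`: list-shaped content is flattened to its text parts, and `usage_metadata` is cached so the token getters can answer from it. The constructor arguments below follow `AzureChatOpenAI` conventions and are placeholders, not values from this commit:

```python
from langchain_core.messages import HumanMessage

model = AzureChatModel(
    azure_endpoint='https://example.openai.azure.com',  # placeholder endpoint
    azure_deployment='gpt-4o',                          # placeholder deployment
    api_version='2024-02-01',
    api_key='...',
)

reply = model.invoke([HumanMessage(content='Hello')])
assert isinstance(reply.content, str)  # content-part lists are flattened

# Both counts now come from the cached usage_metadata of this generation.
print(model.get_num_tokens_from_messages([HumanMessage(content='Hello')]))
print(model.get_num_tokens(reply.content))
```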