Mirror of https://github.com/1Panel-dev/MaxKB.git
feat: enhance token management and normalize content handling in chat model
Parent: 7d3f92bd51
Commit: 27dd90c7df
@@ -93,6 +93,7 @@ def event_content(response,
                 reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
             else:
                 reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+            content_chunk = reasoning._normalize_content(content_chunk)
             all_text += content_chunk
             if reasoning_content_chunk is None:
                 reasoning_content_chunk = ''
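The added `_normalize_content` call guards the string concatenation that follows it. A minimal standalone sketch of the failure it prevents (illustrative values, not repository code): LangChain chunks may carry list-form content, which cannot be concatenated onto the `all_text` string.

all_text = ''
content_chunk = [{'type': 'text', 'text': 'Hello'}]  # list-form chunk content
try:
    all_text += content_chunk
except TypeError as e:
    print(e)  # can only concatenate str (not "list") to str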
@@ -191,13 +192,15 @@ class BaseChatStep(IChatStep):
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting,
                                        model_setting,
-                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                       mcp_output_enable)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
                                       model_setting,
-                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                      mcp_output_enable)

     def get_details(self, manage, **kwargs):
         return {
@@ -264,7 +267,6 @@ class BaseChatStep(IChatStep):

         return None

-
     def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
@@ -294,7 +296,8 @@ class BaseChatStep(IChatStep):
         else:
             # Handle the MCP request
             mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
+                message_list,
             )
             if mcp_result:
                 return mcp_result, True
@@ -319,7 +322,8 @@ class BaseChatStep(IChatStep):
                        tool_ids=None,
                        mcp_output_enable=True):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
+                                                         mcp_servers, mcp_source, tool_enable, tool_ids,
                                                          mcp_output_enable)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
@@ -394,7 +398,9 @@ class BaseChatStep(IChatStep):
         # Invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                                            no_references_setting, problem_text, mcp_enable,
+                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                            tool_ids, mcp_output_enable)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
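The two token-counting context lines above are the consumers of the token-management change later in this commit: `get_num_tokens_from_messages` and `get_num_tokens` are exactly the methods that `AzureChatModel` overrides below to prefer provider-reported usage over local estimation.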
@@ -48,6 +48,21 @@ class Reasoning:
             return r
         return {'content': '', 'reasoning_content': ''}

+    def _normalize_content(self, content):
+        """Normalize content of different types into a single string"""
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            # Handle a list containing multiple content types
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            return ''.join(normalized_parts)
+        else:
+            return str(content)
+
     def get_reasoning_content(self, chunk):
         # If there is no reasoning start tag, everything is the final result
         if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
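A standalone sketch of the new helper's behavior (the function below copies the method's logic outside the class so it runs on its own): text parts are concatenated, non-text parts such as image blocks are dropped, and anything that is neither str nor list is stringified.

def normalize_content(content):
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        normalized_parts = []
        for item in content:
            # Only dict parts tagged as text contribute to the result
            if isinstance(item, dict) and item.get('type') == 'text':
                normalized_parts.append(item.get('text', ''))
        return ''.join(normalized_parts)
    return str(content)

assert normalize_content('plain') == 'plain'
assert normalize_content([{'type': 'text', 'text': 'a'},
                          {'type': 'image_url', 'image_url': {}},
                          {'type': 'text', 'text': 'b'}]) == 'ab'
assert normalize_content(42) == '42'

Note that list items that are not dicts (for example bare strings inside the list) are also dropped, since only dict parts with `'type': 'text'` are collected.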
@@ -56,6 +71,7 @@ class Reasoning:
         # If there is no reasoning end tag, everything is reasoning content
         if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
             return {'content': '', 'reasoning_content': chunk.content}
+        chunk.content = self._normalize_content(chunk.content)
         self.all_content += chunk.content
         if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
             if self.all_content.startswith(self.reasoning_content_start_tag):
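The normalization runs before the tag scan, so `self.all_content` stays a plain string and the `startswith` checks on the start and end tags cannot fail on list-form chunk content. Note the early return just above the added line: when no end tag is configured, `chunk.content` is passed through without normalization.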
@@ -202,6 +218,7 @@ def to_stream_response_simple(stream_event):
     r['Cache-Control'] = 'no-cache'
     return r

+
tool_message_json_template = """
```json
%s
@@ -66,7 +66,8 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
             return False
         try:
             model = provider.get_model(model_type, model_name, model_credential, **model_params)
-            model.invoke([HumanMessage(content=gettext('Hello'))])
+            res = model.invoke([HumanMessage(content=gettext('Hello'))])
+            print(res)
         except Exception as e:
             traceback.print_exc()
             if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
@@ -7,9 +7,11 @@
 @desc:
 """

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.runnables import RunnableConfig
 from langchain_openai import AzureChatOpenAI

 from common.config.tokenizer_manage_config import TokenizerManage
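The new imports serve the additions below: `Optional` and `Any` type the `get_last_generation_info` accessor, while `LanguageModelInput` and `RunnableConfig` type the signature of the `invoke` override.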
@@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
             streaming=True,
         )

+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.__dict__.get('_last_generation_info')
+
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         try:
-            return super().get_num_tokens_from_messages(messages)
+            return self.get_last_generation_info().get('input_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

     def get_num_tokens(self, text: str) -> int:
         try:
-            return super().get_num_tokens(text)
+            return self.get_last_generation_info().get('output_tokens', 0)
         except Exception as e:
             tokenizer = TokenizerManage.get_tokenizer()
             return len(tokenizer.encode(text))
+
+    def invoke(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[list[str]] = None,
+            **kwargs: Any,
+    ) -> BaseMessage:
+        message = super().invoke(input, config, stop=stop, **kwargs)
+        if isinstance(message.content, str):
+            return message
+        elif isinstance(message.content, list):
+            # Build and return a new, normalized response message
+            content = message.content
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            message.content = ''.join(normalized_parts)
+        self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata)
+        return message
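How the pieces fit together, as a minimal runnable sketch (a plain class stands in for `AzureChatModel`; only the `_last_generation_info` bookkeeping is modeled): `invoke` caches the response's `usage_metadata`, and the token-count methods read from that cache.

class Demo:
    def get_last_generation_info(self):
        return self.__dict__.get('_last_generation_info')

    def record(self, usage_metadata):
        # Same create-or-update pattern as the invoke() override above.
        self.__dict__.setdefault('_last_generation_info', {}).update(usage_metadata)

d = Demo()
assert d.get_last_generation_info() is None          # before any invoke()
d.record({'input_tokens': 9, 'output_tokens': 12})
assert d.get_last_generation_info()['input_tokens'] == 9

Two consequences of the code as shown: before the first call, `get_last_generation_info()` returns None, so `.get('input_tokens', 0)` raises AttributeError inside the try block and the tokenizer fallback in the except branch takes over; and responses whose content is already a plain string return early, before the usage cache is refreshed.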