refactor: update ZhipuChatModel to use BaseChatOpenAI and improve token counting

--bug=1061305 --user=刘瑞斌 [Application] After tools are enabled in an AI chat, some models (Zhipu) do not count tokens https://www.tapd.cn/62980211/s/1791683
CaptainB 2025-10-30 10:51:52 +08:00 committed by 刘瑞斌
parent ef162bdc3e
commit 543f83a822


@@ -7,27 +7,21 @@
 @desc:
 """
 import json
-from collections.abc import Iterator
-from typing import Any, Dict, List, Optional
+from typing import Dict, List
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage
-)
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.messages import BaseMessage, get_buffer_string
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from models_provider.base_model_provider import MaxKBBaseModel
+from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
-    optional_params: dict
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def is_cache_model():
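For context, the switch to BaseChatOpenAI matters because langchain_core's BaseLanguageModel routes get_num_tokens through the custom_get_token_ids callable whenever one is supplied, which is exactly the hook the new module-level custom_get_token_ids provides. A minimal sketch of that routing, using langchain_core's FakeListChatModel as a stand-in for ZhipuChatModel so it runs without credentials or network access (the whitespace tokenizer is a toy assumption, not MaxKB's TokenizerManage):

from langchain_core.language_models import FakeListChatModel


def whitespace_token_ids(text: str) -> list:
    # Toy tokenizer standing in for TokenizerManage.get_tokenizer().
    return list(range(len(text.split())))


model = FakeListChatModel(responses=["ok"], custom_get_token_ids=whitespace_token_ids)
# BaseLanguageModel.get_token_ids prefers the custom callable when it is set,
# so get_num_tokens counts whitespace-delimited tokens here:
assert model.get_num_tokens("hello zhipu world") == 3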
@@ -39,69 +33,23 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
+            base_url='https://open.bigmodel.cn/api/paas/v4',
+            extra_body=optional_params,
             streaming=model_kwargs.get('streaming', False),
-            optional_params=optional_params,
-            **optional_params,
+            custom_get_token_ids=custom_get_token_ids
         )
         return zhipuai_chat
 
-    usage_metadata: dict = {}
-
-    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.usage_metadata
-
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.usage_metadata.get('prompt_tokens', 0)
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
-        return self.usage_metadata.get('completion_tokens', 0)
-
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        """Stream the chat response in chunks."""
-        if self.zhipuai_api_key is None:
-            raise ValueError("Did not find zhipuai_api_key.")
-        if self.zhipuai_api_base is None:
-            raise ValueError("Did not find zhipu_api_base.")
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
-        _truncate_params(payload)
-        headers = {
-            "Authorization": _get_jwt_token(self.zhipuai_api_key),
-            "Accept": "application/json",
-        }
-        default_chunk_class = AIMessageChunk
-        import httpx
-        with httpx.Client(headers=headers, timeout=60) as client:
-            with connect_sse(
-                    client, "POST", self.zhipuai_api_base, json=payload
-            ) as event_source:
-                for sse in event_source.iter_sse():
-                    chunk = json.loads(sse.data)
-                    if len(chunk["choices"]) == 0:
-                        continue
-                    choice = chunk["choices"][0]
-                    generation_info = {}
-                    if "usage" in chunk:
-                        generation_info = chunk["usage"]
-                        self.usage_metadata = generation_info
-                    chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    finish_reason = choice.get("finish_reason", None)
-                    chunk = ChatGenerationChunk(
-                        message=chunk, generation_info=generation_info
-                    )
-                    yield chunk
-                    if run_manager:
-                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-                    if finish_reason is not None:
-                        break
+        try:
+            return super().get_num_tokens(text)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
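Taken together, the new methods implement a two-tier counting strategy: prefer the exact, provider-aware counters inherited from BaseChatOpenAI, and fall back to the local tokenizer when the superclass raises (for example when tiktoken has no encoding for a glm model). Below is a standalone sketch of the same pattern, runnable outside MaxKB; SimpleTokenizer is a hypothetical stand-in for TokenizerManage.get_tokenizer():

from typing import List

from langchain_core.messages import BaseMessage, HumanMessage, get_buffer_string


class SimpleTokenizer:
    # Hypothetical stand-in; the real tokenizer is model specific.
    def encode(self, text: str) -> List[int]:
        return [ord(ch) for ch in text]  # one "token" per character


def count_prompt_tokens(model, messages: List[BaseMessage]) -> int:
    try:
        # Preferred: the tiktoken-backed counter inherited from BaseChatOpenAI.
        return model.get_num_tokens_from_messages(messages)
    except Exception:
        # Fallback: render each message to text and count with the local tokenizer.
        tokenizer = SimpleTokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)


# With a model object that lacks a working counter, the fallback path runs:
print(count_prompt_tokens(object(), [HumanMessage(content="hi")]))  # -> 9 ("Human: hi")

Rendering each message through get_buffer_string before encoding mirrors the diff above, so role prefixes are included in the approximate count.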