Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 01:33:05 +00:00)
refactor: update ZhipuChatModel to use BaseChatOpenAI and improve token counting
--bug=1061305 --user=刘瑞斌 [Application] After tools are enabled in an AI conversation, some models (Zhipu) do not count tokens https://www.tapd.cn/62980211/s/1791683
This commit is contained in:
parent ef162bdc3e · commit 543f83a822
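In brief: the old class stashed the usage block it saw in its own _stream loop and answered token queries from that cache, so request paths that bypassed that loop (the linked bug points at tool calling) apparently reported zero tokens. The new class inherits BaseChatOpenAI's counting and only falls back to a local tokenizer when the parent raises. A minimal, hedged sketch of that try/except shape — Parent and LocalTokenizer are hypothetical stand-ins for BaseChatOpenAI and TokenizerManage, not MaxKB code:

from typing import List


class LocalTokenizer:
    # Hypothetical stand-in for TokenizerManage.get_tokenizer().
    def encode(self, text: str) -> List[str]:
        return text.split()


class Parent:
    # Hypothetical stand-in for BaseChatOpenAI's tiktoken-backed counting.
    def get_num_tokens(self, text: str) -> int:
        raise RuntimeError("unknown model name for tiktoken")


class Child(Parent):
    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception:
            # Fall back to the local tokenizer, as the commit does.
            return len(LocalTokenizer().encode(text))


print(Child().get_num_tokens("four words right here"))  # -> 4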
@@ -7,27 +7,21 @@
 @desc:
 """
-import json
-from collections.abc import Iterator
-from typing import Any, Dict, List, Optional
+from typing import Dict, List
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage
-)
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.messages import BaseMessage, get_buffer_string
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from models_provider.base_model_provider import MaxKBBaseModel
+from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
-class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
-    optional_params: dict
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def is_cache_model():
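The first hunk drops the class-level optional_params field and defines custom_get_token_ids at module level; the second hunk passes it to the model constructor. This works because langchain-core's BaseLanguageModel defers get_token_ids (and therefore get_num_tokens) to a custom_get_token_ids callable when one is set. A runnable sketch using langchain-core's built-in fake chat model, assuming current langchain-core behavior; the whitespace tokenizer is a stand-in for MaxKB's tokenizer, not its actual implementation:

from langchain_core.language_models import FakeListChatModel


def custom_get_token_ids(text: str) -> list:
    # Stand-in for TokenizerManage.get_tokenizer().encode(text).
    return text.split()


# get_num_tokens() routes through the callable instead of the default method.
llm = FakeListChatModel(responses=["ok"], custom_get_token_ids=custom_get_token_ids)
print(llm.get_num_tokens("count these four tokens"))  # -> 4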
@@ -39,69 +33,23 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
             base_url='https://open.bigmodel.cn/api/paas/v4',
+            extra_body=optional_params,
             streaming=model_kwargs.get('streaming', False),
-            optional_params=optional_params,
-            **optional_params,
+            custom_get_token_ids=custom_get_token_ids
         )
         return zhipuai_chat
 
-    usage_metadata: dict = {}
-
-    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.usage_metadata
-
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.usage_metadata.get('prompt_tokens', 0)
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
-        return self.usage_metadata.get('completion_tokens', 0)
-
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        """Stream the chat response in chunks."""
-        if self.zhipuai_api_key is None:
-            raise ValueError("Did not find zhipuai_api_key.")
-        if self.zhipuai_api_base is None:
-            raise ValueError("Did not find zhipu_api_base.")
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
-        _truncate_params(payload)
-        headers = {
-            "Authorization": _get_jwt_token(self.zhipuai_api_key),
-            "Accept": "application/json",
-        }
-
-        default_chunk_class = AIMessageChunk
-        import httpx
-
-        with httpx.Client(headers=headers, timeout=60) as client:
-            with connect_sse(
-                    client, "POST", self.zhipuai_api_base, json=payload
-            ) as event_source:
-                for sse in event_source.iter_sse():
-                    chunk = json.loads(sse.data)
-                    if len(chunk["choices"]) == 0:
-                        continue
-                    choice = chunk["choices"][0]
-                    generation_info = {}
-                    if "usage" in chunk:
-                        generation_info = chunk["usage"]
-                    self.usage_metadata = generation_info
-                    chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    finish_reason = choice.get("finish_reason", None)
-                    chunk = ChatGenerationChunk(
-                        message=chunk, generation_info=generation_info
-                    )
-                    yield chunk
-                    if run_manager:
-                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-                    if finish_reason is not None:
-                        break
+        try:
+            return super().get_num_tokens(text)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
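Two details in the second hunk are worth noting. Replacing **optional_params with extra_body=optional_params sends the tunables inside the OpenAI-compatible request body rather than spreading them as constructor kwargs, and the except branch of get_num_tokens_from_messages flattens each message with get_buffer_string before encoding it. A small sketch of that fallback summation; the whitespace encode() below is a hypothetical stand-in for MaxKB's tokenizer:

from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string


def encode(text: str) -> list:
    # Hypothetical stand-in for TokenizerManage.get_tokenizer().encode(text).
    return text.split()


messages = [
    HumanMessage(content="What is MaxKB?"),
    AIMessage(content="A knowledge base application."),
]
# Mirrors the except branch: flatten each message, encode it, sum the lengths.
total = sum(len(encode(get_buffer_string([m]))) for m in messages)
print(total)  # counts "Human: What is MaxKB?" plus "AI: A knowledge base application."

get_buffer_string prefixes each message with its role ("Human:", "AI:"), so the fallback count includes a few role tokens per message; that is an approximation, which is acceptable for a path that only runs when the parent's counter fails.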