refactor: gemini
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled

This commit is contained in:
wxg0103 2025-06-05 09:11:06 +08:00
parent 7ce66a7bf3
commit e12b1fe14e

View File

@ -13,7 +13,7 @@ from google.ai.generativelanguage_v1beta.types import (
Tool as GoogleTool,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
@ -22,6 +22,8 @@ from langchain_google_genai.chat_models import _chat_with_retry, _response_to_re
from langchain_google_genai._common import (
SafetySettingDict,
)
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -46,10 +48,18 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return self.get_last_generation_info().get('input_tokens', 0)
try:
return self.get_last_generation_info().get('input_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
return self.get_last_generation_info().get('output_tokens', 0)
try:
return self.get_last_generation_info().get('output_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
def _stream(
self,