diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
index 4106cc1d6..af23d0341 100644
--- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
@@ -13,7 +13,7 @@ from google.ai.generativelanguage_v1beta.types import (
     Tool as GoogleTool,
 )
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_core.outputs import ChatGenerationChunk
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
@@ -22,6 +22,8 @@ from langchain_google_genai.chat_models import _chat_with_retry, _response_to_re
 from langchain_google_genai._common import (
     SafetySettingDict,
 )
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
@@ -46,10 +48,18 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
         return self.__dict__.get('_last_generation_info')
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('input_tokens', 0)
+        try:
+            return self.get_last_generation_info().get('input_tokens', 0)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('output_tokens', 0)
+        try:
+            return self.get_last_generation_info().get('output_tokens', 0)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
 
     def _stream(
         self,
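
For context, the patch follows a common fallback pattern: prefer the exact token counts the Gemini API reports (cached in `_last_generation_info`), and fall back to a local tokenizer only when that info is unavailable. Before the first generation, `get_last_generation_info()` returns `None` (the `_last_generation_info` key is absent from `__dict__`), so calling `.get(...)` on it raises, and the `except` branch computes a local estimate instead. Below is a minimal, self-contained sketch of the same pattern; `TokenCountingModel` is a hypothetical class, and `tiktoken` stands in for MaxKB's `TokenizerManage`, which is an assumption about its behavior, not the project's actual tokenizer.

```python
# Minimal sketch of the fallback pattern in the patch (hypothetical names).
# Assumption: tiktoken stands in for MaxKB's TokenizerManage; the real code
# calls TokenizerManage.get_tokenizer() instead.
from typing import List, Optional

import tiktoken
from langchain_core.messages import BaseMessage, HumanMessage, get_buffer_string


class TokenCountingModel:
    def __init__(self) -> None:
        # Populated from the provider's usage metadata after each call;
        # None until the first response arrives.
        self._last_generation_info: Optional[dict] = None

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            # Prefer the exact input-token count reported by the API.
            return self._last_generation_info.get('input_tokens', 0)
        except AttributeError:
            # No generation info yet: approximate with a local tokenizer,
            # rendering each message to text first.
            tokenizer = tiktoken.get_encoding('cl100k_base')
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return self._last_generation_info.get('output_tokens', 0)
        except AttributeError:
            tokenizer = tiktoken.get_encoding('cl100k_base')
            return len(tokenizer.encode(text))


model = TokenCountingModel()
print(model.get_num_tokens_from_messages([HumanMessage(content="hello")]))  # local estimate
model._last_generation_info = {'input_tokens': 12, 'output_tokens': 34}
print(model.get_num_tokens("ignored"))  # 34, from the cached generation info
```

Two differences from the sketch are worth noting: the patch catches `Exception` broadly rather than `AttributeError` specifically (and binds `e` without using it), and the local count is only an approximation, since MaxKB's tokenizer will not match Gemini's tokenization exactly. The API-reported counts remain the authoritative numbers once a response has been received.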