mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
refactor: gemini
This commit is contained in:
parent
7ce66a7bf3
commit
e12b1fe14e
|
|
@ -13,7 +13,7 @@ from google.ai.generativelanguage_v1beta.types import (
|
|||
Tool as GoogleTool,
|
||||
)
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
|
||||
|
|
@ -22,6 +22,8 @@ from langchain_google_genai.chat_models import _chat_with_retry, _response_to_re
|
|||
from langchain_google_genai._common import (
|
||||
SafetySettingDict,
|
||||
)
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
|
|
@ -46,10 +48,18 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
|
|||
return self.__dict__.get('_last_generation_info')
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
return self.get_last_generation_info().get('input_tokens', 0)
|
||||
try:
|
||||
return self.get_last_generation_info().get('input_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
return self.get_last_generation_info().get('output_tokens', 0)
|
||||
try:
|
||||
return self.get_last_generation_info().get('output_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue