MaxKB/apps/models_provider/impl/gemini_model_provider/model/llm.py

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""
from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast
from google.ai.generativelanguage_v1 import GenerateContentResponse
from google.ai.generativelanguage_v1beta.types import (
    Tool as GoogleTool,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result, \
    _FunctionDeclarationType
from langchain_google_genai._common import (
    SafetySettingDict,
)
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel


class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        gemini_chat = GeminiChatModel(
            model=model_name,
            google_api_key=model_credential.get('api_key'),
            **optional_params
        )
        return gemini_chat

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            # Prefer the usage metadata reported by the API for the last call.
            return self.get_last_generation_info().get('input_tokens', 0)
        except Exception:
            # Fall back to a local tokenizer estimate when no usage info has been recorded yet.
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        try:
            return self.get_last_generation_info().get('output_tokens', 0)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            *,
            tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
            functions: Optional[Sequence[_FunctionDeclarationType]] = None,
            safety_settings: Optional[SafetySettingDict] = None,
            tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
            generation_config: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self._prepare_request(
            messages,
            stop=stop,
            tools=tools,
            functions=functions,
            safety_settings=safety_settings,
            tool_config=tool_config,
            generation_config=generation_config,
        )
        response: GenerateContentResponse = _chat_with_retry(
            request=request,
            generation_method=self.client.stream_generate_content,
            **kwargs,
            metadata=self.default_metadata,
        )
        for chunk in response:
            _chat_result = _response_to_result(chunk, stream=True)
            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
            if gen.message:
                # Record the token usage carried on the streamed chunk so the
                # token counters above can report it after the call completes.
                token_usage = gen.message.usage_metadata
                if token_usage:
                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
            if run_manager:
                run_manager.on_llm_new_token(gen.text)
            yield gen
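

# A minimal usage sketch, not part of the original module: the model name and
# API key below are placeholders. It builds the model through new_instance(),
# streams a reply via LangChain's public stream() entry point (which dispatches
# to _stream above), then reads the token counts recorded during streaming.
if __name__ == '__main__':
    from langchain_core.messages import HumanMessage

    chat = GeminiChatModel.new_instance(
        model_type='LLM',                                      # placeholder; not used by new_instance
        model_name='gemini-1.5-flash',                         # placeholder model name
        model_credential={'api_key': 'YOUR_GOOGLE_API_KEY'},   # placeholder credential
    )
    demo_messages = [HumanMessage(content='Say hello in one sentence.')]
    for piece in chat.stream(demo_messages):   # BaseChatModel.stream() calls _stream()
        print(piece.content, end='', flush=True)
    print()
    # The counters read the usage metadata captured while streaming above.
    print('input tokens:', chat.get_num_tokens_from_messages(demo_messages))
    print('output tokens:', chat.get_num_tokens(''))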