diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py
index 75eafd80b..4d5dda29d 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
 
+from langchain_core.messages import get_buffer_string, BaseMessage
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -21,3 +24,15 @@ class VllmImage(MaxKBBaseModel, BaseChatOpenAI):
 
     def is_cache_model(self):
         return False
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
index 6f3ed0620..7d2a63acd 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
@@ -1,8 +1,11 @@
 # coding=utf-8
 
-from typing import Dict
+from typing import Dict, List
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -33,3 +36,15 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
             stream_usage=True,
         )
         return vllm_chat_open_ai
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
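
Note (not part of the patch): both models now prefer the token usage reported by the vLLM server and only fall back to MaxKB's local tokenizer when no usage metadata is available. The minimal sketch below illustrates that fallback pattern in isolation; FakeTokenizer and VllmLikeModel are hypothetical stand-ins for the tokenizer returned by TokenizerManage.get_tokenizer() and for the patched model classes, and the metadata branch is simplified to a plain dict lookup.

# Minimal, self-contained sketch (not part of the patch) of the fallback
# token-counting pattern added above: use server-reported usage metadata when
# present, otherwise count tokens with a local tokenizer.
# FakeTokenizer and VllmLikeModel are hypothetical stand-ins only.
from typing import Dict, List, Optional


class FakeTokenizer:
    def encode(self, text: str) -> List[int]:
        # Crude whitespace "tokenization", for illustration only.
        return [len(tok) for tok in text.split()]


class VllmLikeModel:
    def __init__(self, usage_metadata: Optional[Dict] = None):
        self.usage_metadata = usage_metadata

    def get_num_tokens(self, text: str) -> int:
        # Same shape as the patched method: local count when no usage metadata.
        if self.usage_metadata is None or self.usage_metadata == {}:
            tokenizer = FakeTokenizer()
            return len(tokenizer.encode(text))
        return self.usage_metadata.get('output_tokens', 0)


# Without metadata the local tokenizer is used ...
assert VllmLikeModel().get_num_tokens("hello vllm world") == 3
# ... with metadata the server-reported figure wins.
assert VllmLikeModel({'output_tokens': 42}).get_num_tokens("ignored") == 42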