diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
index 7136a2769..876e1f28d 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
@@ -10,15 +10,27 @@ from typing import List
 
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
-from transformers import GPT2TokenizerFast
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class OllamaChatModel(ChatOpenAI):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
index 6fdd4ac50..a2065186e 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
@@ -9,7 +9,6 @@ from typing import Optional, List, Any, Iterator, cast
 
 from langchain.callbacks.manager import CallbackManager
-from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.chat_models.base import BaseChatModel
 from langchain.load import dumpd
 from langchain.schema import LLMResult
@@ -17,18 +16,31 @@ from langchain.schema.language_model import LanguageModelInput
 from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
-from transformers import GPT2TokenizerFast
+from langchain_community.chat_models import QianfanChatEndpoint
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class QianfanChatModel(QianfanChatEndpoint):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
 
     def stream(
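
For reference, a minimal standalone sketch of the lazy-singleton pattern both files now share: the GPT-2 tokenizer is loaded on first use rather than at module import time, and the `transformers` import itself is deferred into the accessor. The `__main__` demo and the sample string are illustrative additions, not part of the change; the cache path and `from_pretrained` arguments mirror the diff and assume `transformers` is installed.

# Sketch of the lazy tokenizer loading introduced by this diff (assumptions noted above).
class TokenizerManage:
    tokenizer = None  # class-level cache; shared by all callers

    @staticmethod
    def get_tokenizer():
        # Deferred import: transformers is only pulled in when a token
        # count is actually requested, not when this module is imported.
        from transformers import GPT2TokenizerFast
        if TokenizerManage.tokenizer is None:
            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
                'gpt2',
                cache_dir="/opt/maxkb/model/tokenizer",
                resume_download=False,
                force_download=False)
        return TokenizerManage.tokenizer


if __name__ == '__main__':
    # First call triggers the (potentially slow) load from cache_dir ...
    t1 = TokenizerManage.get_tokenizer()
    # ... every later call returns the same cached instance.
    t2 = TokenizerManage.get_tokenizer()
    assert t1 is t2
    print(len(t1.encode("hello world")))  # token count for a sample string

The practical effect is that importing the model modules (e.g., at application startup) no longer forces a tokenizer download/load; that cost is paid only on the first `get_num_tokens*` call, and exactly once per process thereafter.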