From fb7dfba5674c5e670e877f8ccd45bdd5785bb153 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 20 Mar 2024 17:01:26 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9Tokenizer=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E9=A1=BA=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/ollama_chat_model.py | 18 ++++++++++++++--- .../model/qian_fan_chat_model.py | 20 +++++++++++++++---- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py index 7136a2769..876e1f28d 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py @@ -10,15 +10,27 @@ from typing import List from langchain_community.chat_models import ChatOpenAI from langchain_core.messages import BaseMessage, get_buffer_string -from transformers import GPT2TokenizerFast -tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False, - force_download=False) + +class TokenizerManage: + tokenizer = None + + @staticmethod + def get_tokenizer(): + from transformers import GPT2TokenizerFast + if TokenizerManage.tokenizer is None: + TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', + cache_dir="/opt/maxkb/model/tokenizer", + resume_download=False, + force_download=False) + return TokenizerManage.tokenizer class OllamaChatModel(ChatOpenAI): def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py index 6fdd4ac50..a2065186e 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py @@ -9,7 +9,6 @@ from typing import Optional, List, Any, Iterator, cast from langchain.callbacks.manager import CallbackManager -from langchain_community.chat_models import QianfanChatEndpoint from langchain.chat_models.base import BaseChatModel from langchain.load import dumpd from langchain.schema import LLMResult @@ -17,18 +16,31 @@ from langchain.schema.language_model import LanguageModelInput from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string from langchain.schema.output import ChatGenerationChunk from langchain.schema.runnable import RunnableConfig -from transformers import GPT2TokenizerFast +from langchain_community.chat_models import QianfanChatEndpoint -tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False, - force_download=False) + +class TokenizerManage: + tokenizer = None + + @staticmethod + def get_tokenizer(): + from transformers import GPT2TokenizerFast + if TokenizerManage.tokenizer is None: + TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', + cache_dir="/opt/maxkb/model/tokenizer", + resume_download=False, + force_download=False) + return TokenizerManage.tokenizer class QianfanChatModel(QianfanChatEndpoint): def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) def stream(