fix: Change Tokenizer loading order (load lazily on first use)

This commit is contained in:
shaohuzhang1 2024-03-20 17:01:26 +08:00
parent 7933b15a38
commit fb7dfba567
2 changed files with 31 additions and 7 deletions

View File

@@ -10,15 +10,27 @@ from typing import List
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from transformers import GPT2TokenizerFast
# Eager, module-level load of the GPT-2 tokenizer from the local cache dir.
# NOTE(review): this runs at import time; the TokenizerManage class below
# provides the same tokenizer lazily, which makes this eager load redundant.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
                                              force_download=False)
class TokenizerManage:
    """Process-wide lazy cache for the GPT-2 tokenizer.

    Loading the tokenizer is deferred until the first token count is
    requested, so importing this module stays cheap.
    """

    # Shared tokenizer instance; stays None until first use.
    tokenizer = None

    @staticmethod
    def get_tokenizer():
        """Return the cached GPT2 tokenizer, creating it on the first call."""
        from transformers import GPT2TokenizerFast

        cached = TokenizerManage.tokenizer
        if cached is None:
            cached = GPT2TokenizerFast.from_pretrained(
                'gpt2',
                cache_dir="/opt/maxkb/model/tokenizer",
                resume_download=False,
                force_download=False,
            )
            TokenizerManage.tokenizer = cached
        return cached
class OllamaChatModel(ChatOpenAI):
    """ChatOpenAI subclass whose token counting uses the shared GPT-2
    tokenizer from TokenizerManage."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the total GPT-2 token count over all rendered messages."""
        tok = TokenizerManage.get_tokenizer()
        per_message = (len(tok.encode(get_buffer_string([msg]))) for msg in messages)
        return sum(per_message)

    def get_num_tokens(self, text: str) -> int:
        """Return the GPT-2 token count of a single string."""
        return len(TokenizerManage.get_tokenizer().encode(text))

View File

@@ -9,7 +9,6 @@
from typing import Optional, List, Any, Iterator, cast
from langchain.callbacks.manager import CallbackManager
from langchain_community.chat_models import QianfanChatEndpoint
from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult
@@ -17,18 +16,31 @@ from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
from transformers import GPT2TokenizerFast
from langchain_community.chat_models import QianfanChatEndpoint
# Eager, module-level load of the GPT-2 tokenizer from the local cache dir.
# NOTE(review): this runs at import time; the TokenizerManage class below
# provides the same tokenizer lazily, which makes this eager load redundant.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
                                              force_download=False)
class TokenizerManage:
    """Singleton-style holder for the GPT-2 tokenizer used for token counts."""

    tokenizer = None  # populated lazily by get_tokenizer()

    @staticmethod
    def get_tokenizer():
        """Load the GPT2 tokenizer once, then return the shared instance."""
        from transformers import GPT2TokenizerFast

        if TokenizerManage.tokenizer is None:
            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
                'gpt2',
                cache_dir="/opt/maxkb/model/tokenizer",
                resume_download=False,
                force_download=False,
            )
        return TokenizerManage.tokenizer
class QianfanChatModel(QianfanChatEndpoint):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Sum the GPT-2 token counts of each message's rendered text."""
    tok = TokenizerManage.get_tokenizer()
    total = 0
    for message in messages:
        total += len(tok.encode(get_buffer_string([message])))
    return total
def get_num_tokens(self, text: str) -> int:
    """Count GPT-2 tokens in a single string using the shared tokenizer."""
    encoded = TokenizerManage.get_tokenizer().encode(text)
    return len(encoded)
def stream(