Mirror of https://github.com/1Panel-dev/MaxKB.git, synced 2025-12-26 01:33:05 +00:00
fix: change the Tokenizer loading order
This commit is contained in:
parent 7933b15a38
commit fb7dfba567
@@ -10,15 +10,27 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
 from transformers import GPT2TokenizerFast
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
 
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class OllamaChatModel(ChatOpenAI):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
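The two removed lines ran `from_pretrained` at module scope, so merely importing the file could block on, or fail during, a Hugging Face download; `TokenizerManage.get_tokenizer()` defers that cost to the first token count and caches the result on the class. Below is a standalone sketch of the same lazy-singleton pattern, with a stand-in for the expensive load so it runs offline (illustrative names only, not MaxKB code):

```python
import time


class ExpensiveResource:
    """Stand-in for GPT2TokenizerFast.from_pretrained: slow to construct."""

    def __init__(self):
        time.sleep(1)  # simulate a model download / large file load


class ResourceManage:
    resource = None

    @staticmethod
    def get_resource():
        # Build once, on first use; later calls return the cached object.
        if ResourceManage.resource is None:
            ResourceManage.resource = ExpensiveResource()
        return ResourceManage.resource


# Importing this module would be instant; the one-second cost is paid here,
# on the first call only:
first = ResourceManage.get_resource()
second = ResourceManage.get_resource()
assert first is second  # one shared instance, loaded exactly once
```

Like the original, the sketch is not guarded against two threads racing through the first `is None` check; both would build the resource and one result would silently win.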
@@ -9,7 +9,6 @@
 from typing import Optional, List, Any, Iterator, cast
 
 from langchain.callbacks.manager import CallbackManager
-from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.chat_models.base import BaseChatModel
 from langchain.load import dumpd
 from langchain.schema import LLMResult
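This removal pairs with the next hunk, which re-adds the same `QianfanChatEndpoint` import after `from transformers import GPT2TokenizerFast`: the import is moved rather than dropped, which appears to be the loading-order change named in the commit title. Repeated imports, such as the moved one or the one inside `get_tokenizer()`, are cheap after the first, because Python caches modules; a small standalone illustration (not MaxKB code):

```python
import sys

import json                    # first import executes the module body once
assert 'json' in sys.modules   # ...and caches the module object

import json                    # later imports are a dict lookup, no file I/O
```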
@@ -17,18 +16,31 @@ from langchain.schema.language_model import LanguageModelInput
 from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
 from transformers import GPT2TokenizerFast
+from langchain_community.chat_models import QianfanChatEndpoint
 
-tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
-                                              force_download=False)
 
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
 
 
 class QianfanChatModel(QianfanChatEndpoint):
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
         return len(tokenizer.encode(text))
 
     def stream(
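Both overridden counters flatten each message with `get_buffer_string` before encoding, so role prefixes such as "Human: " and "AI: " are included in the count. A usage sketch (assumes `langchain` and `transformers` are installed and the gpt2 files are in the cache or downloadable):

```python
from langchain.schema.messages import HumanMessage, AIMessage, get_buffer_string
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

messages = [HumanMessage(content="What is MaxKB?"),
            AIMessage(content="An open-source knowledge-base platform.")]

# get_buffer_string([m]) renders one message as e.g. "Human: What is MaxKB?",
# which is what get_num_tokens_from_messages encodes per message.
total = sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
print(total)  # token count across both messages, role prefixes included
```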