fix: token counting error in private deployment (#284)

shaohuzhang1 2024-04-28 11:59:19 +08:00 committed by GitHub
parent 9d808b4ccd
commit 29427a0ad6
12 changed files with 118 additions and 57 deletions

View File

@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: tokenizer_manage_config.py
@date: 2024/4/28 10:17
@desc:
"""


class TokenizerManage:
    tokenizer = None

    @staticmethod
    def get_tokenizer():
        # Lazy import so transformers is only loaded when a tokenizer is first needed.
        from transformers import GPT2TokenizerFast
        if TokenizerManage.tokenizer is None:
            # local_files_only=True keeps the lookup fully offline, which is what a
            # private deployment without internet access requires.
            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
                'gpt2',
                cache_dir="/opt/maxkb/model/tokenizer",
                local_files_only=True,
                resume_download=False,
                force_download=False)
        return TokenizerManage.tokenizer
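Note: with local_files_only=True, from_pretrained() raises unless the gpt2 tokenizer files already sit in the cache directory, so an offline deployment has to pre-seed /opt/maxkb/model/tokenizer. A minimal sketch of a build-time step that could do this, assuming network access is available during the image build (the helper file is hypothetical and not part of this commit):

# pre_seed_tokenizer.py - hypothetical build-time helper, not part of this commit.
# Downloads the gpt2 tokenizer once into the cache dir that TokenizerManage reads from.
from transformers import GPT2TokenizerFast

GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer")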

View File

@@ -19,6 +19,7 @@ from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
     ModelInfo, \
     ModelTypeConst, ValidCode
+from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
 from smartdoc.conf import PROJECT_DIR
@@ -119,8 +120,8 @@ class AzureModelProvider(IModelProvider):
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
         model_info: ModelInfo = model_dict.get(model_name)
-        azure_chat_open_ai = AzureChatOpenAI(
-            openai_api_base=model_credential.get('api_base'),
+        azure_chat_open_ai = AzureChatModel(
+            azure_endpoint=model_credential.get('api_base'),
             openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
                 'api_version'),
             deployment_name=model_credential.get('deployment_name'),

View File

@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: azure_chat_model.py
@date: 2024/4/28 11:45
@desc:
"""
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage


class AzureChatModel(AzureChatOpenAI):

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Count tokens with the shared local tokenizer instead of tiktoken,
        # which would try to fetch its encoding file over the network.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
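A quick way to sanity-check the overridden counters; everything below is illustrative, and the endpoint, key, and deployment name are dummies, which is fine because neither counter calls the remote API:

from langchain_core.messages import HumanMessage

model = AzureChatModel(
    azure_endpoint="https://example.openai.azure.com",  # dummy value
    openai_api_version="2023-05-15",                    # dummy value
    openai_api_key="sk-dummy",                          # dummy value
    deployment_name="gpt-35")                           # dummy value
print(model.get_num_tokens("hello world"))
print(model.get_num_tokens_from_messages([HumanMessage(content="hello world")]))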

View File

@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 class KimiChatModel(ChatOpenAI):

View File

@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 class OllamaChatModel(ChatOpenAI):

View File

@@ -11,19 +11,7 @@ from typing import List
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_openai import ChatOpenAI
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 class OpenAIChatModel(ChatOpenAI):

View File

@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: qwen_chat_model.py
@date: 2024/4/28 11:44
@desc:
"""
from typing import List

from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage


class QwenChatModel(ChatTongyi):

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))

View File

@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
 from smartdoc.conf import PROJECT_DIR
@@ -66,7 +67,7 @@ class QwenModelProvider(IModelProvider):
         return 3

     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
-        chat_tong_yi = ChatTongyi(
+        chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key')
         )

View File

@@ -18,19 +18,7 @@ from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
 from langchain_community.chat_models import QianfanChatEndpoint
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 class QianfanChatModel(QianfanChatEndpoint):

View File

@@ -12,11 +12,21 @@ from typing import List, Optional, Any, Iterator
 from langchain_community.chat_models import ChatSparkLLM
 from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
 from langchain_core.outputs import ChatGenerationChunk
+from common.config.tokenizer_manage_config import TokenizerManage
 class XFChatSparkLLM(ChatSparkLLM):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
+
     def _stream(
         self,
         messages: List[BaseMessage],

View File

@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: zhipu_chat_model.py
@date: 2024/4/28 11:42
@desc:
"""
from typing import List

from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage


class ZhipuChatModel(ChatZhipuAI):

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
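The two counting overrides above are now duplicated verbatim across AzureChatModel, QwenChatModel, XFChatSparkLLM, and ZhipuChatModel (the pre-existing Kimi/Ollama/OpenAI/Qianfan models follow the same pattern). A possible follow-up refactor, sketched here with the hypothetical name LocalTokenCountMixin and not part of this commit, would hoist them into one mixin:

# Hypothetical refactor sketch - not part of this commit.
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage


class LocalTokenCountMixin:
    """Token counting via the shared local GPT-2 tokenizer."""

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))


# Each subclass would then become e.g.:
# class ZhipuChatModel(LocalTokenCountMixin, ChatZhipuAI): pass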

View File

@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
 from smartdoc.conf import PROJECT_DIR
@@ -66,7 +67,7 @@ class ZhiPuModelProvider(IModelProvider):
         return 3

     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
-        zhipuai_chat = ChatZhipuAI(
+        zhipuai_chat = ZhipuChatModel(
             temperature=0.5,
             api_key=model_credential.get('api_key'),
             model=model_name