Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26)
fix: 私有部署计算tokens报错 (#284) (fix: token counting error in private deployments)

This commit moves the duplicated TokenizerManage helper into a shared common/config/tokenizer_manage_config.py and overrides get_num_tokens / get_num_tokens_from_messages on the Azure, Qwen, Spark, and Zhipu chat models, so tokens are counted with a locally cached GPT-2 tokenizer and token calculation no longer errors in private (offline) deployments.
parent 9d808b4ccd
commit 29427a0ad6
common/config/tokenizer_manage_config.py (new file):

@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: tokenizer_manage_config.py
+@date:2024/4/28 10:17
+@desc:
+"""
+
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
+                'gpt2',
+                cache_dir="/opt/maxkb/model/tokenizer",
+                local_files_only=True,
+                resume_download=False,
+                force_download=False)
+        return TokenizerManage.tokenizer
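Because get_tokenizer() loads with local_files_only=True, the gpt2 tokenizer files must already exist under /opt/maxkb/model/tokenizer, so an offline private deployment has to populate that cache ahead of time. A minimal sketch (not part of the commit) of seeding it while network access is still available, e.g. during image build:

# Not part of the commit: run once on a machine with network access
# to seed the cache that TokenizerManage.get_tokenizer() later reads offline.
from transformers import GPT2TokenizerFast

# Downloads the vocab/merges/tokenizer files into the same cache directory
# the application later loads from with local_files_only=True.
GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer")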
AzureModelProvider:

@@ -19,6 +19,7 @@ from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
     ModelInfo, \
     ModelTypeConst, ValidCode
+from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
 from smartdoc.conf import PROJECT_DIR


@@ -119,8 +120,8 @@ class AzureModelProvider(IModelProvider):

     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatOpenAI:
         model_info: ModelInfo = model_dict.get(model_name)
-        azure_chat_open_ai = AzureChatOpenAI(
-            openai_api_base=model_credential.get('api_base'),
+        azure_chat_open_ai = AzureChatModel(
+            azure_endpoint=model_credential.get('api_base'),
             openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
                 'api_version'),
             deployment_name=model_credential.get('deployment_name'),
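Note the parameter rename that comes with switching classes: langchain_openai's AzureChatOpenAI, which the new AzureChatModel (defined in the file below) subclasses, takes the resource URL as azure_endpoint, whereas the older class accepted openai_api_base. A minimal construction sketch; every value here is a placeholder, the real ones come from model_credential in get_model():

from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel

# Placeholder credentials for illustration only.
model = AzureChatModel(
    azure_endpoint="https://example-resource.openai.azure.com",
    openai_api_version="2023-07-01-preview",
    deployment_name="gpt-35-turbo",
    openai_api_key="placeholder",
)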
setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py (new file):

@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: azure_chat_model.py
+@date:2024/4/28 11:45
+@desc:
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai import AzureChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class AzureChatModel(AzureChatOpenAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
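The override sidesteps the base class's token counting entirely: each message is rendered with get_buffer_string and encoded with the local GPT-2 tokenizer. A minimal sketch of the same counting logic, assuming the tokenizer cache is in place (the message contents are illustrative):

from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage

messages = [HumanMessage(content="Hello"), AIMessage(content="Hi, how can I help?")]
tokenizer = TokenizerManage.get_tokenizer()

# Same logic as get_num_tokens_from_messages: render each message to its
# "Human: ..." / "AI: ..." buffer string, encode it, and sum the token counts.
total = sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
print(total)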
KimiChatModel (inline TokenizerManage replaced with the shared import):

@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string

-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage


 class KimiChatModel(ChatOpenAI):
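The same duplicated TokenizerManage block is deleted from the Ollama, OpenAI, and Qianfan chat models below; every caller now goes through the single definition in common.config. A minimal sketch (not part of the commit) of what centralizing buys, assuming the cache is populated:

from common.config.tokenizer_manage_config import TokenizerManage

# The class attribute acts as a process-wide singleton: the GPT-2 tokenizer
# is loaded lazily on first use, then the same instance is reused by every
# chat model instead of each module keeping its own copy.
t1 = TokenizerManage.get_tokenizer()
t2 = TokenizerManage.get_tokenizer()
assert t1 is t2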
OllamaChatModel (same refactor):

@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string

-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage


 class OllamaChatModel(ChatOpenAI):
OpenAIChatModel (same refactor):

@@ -11,19 +11,7 @@ from typing import List
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_openai import ChatOpenAI

-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage


 class OpenAIChatModel(ChatOpenAI):
setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py (new file):

@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: qwen_chat_model.py
+@date:2024/4/28 11:44
+@desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatTongyi
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class QwenChatModel(ChatTongyi):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
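With this override, token counting for Tongyi/Qwen no longer depends on a remote service, which is what broke in private deployments. A minimal sketch, assuming the dashscope SDK is installed and the tokenizer cache exists (the key and model name are placeholders):

from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel

# Placeholder credentials; counting runs entirely against the locally
# cached GPT-2 tokenizer, so no request leaves the machine for this call.
model = QwenChatModel(model_name="qwen-turbo", dashscope_api_key="sk-placeholder")
print(model.get_num_tokens("你好，世界"))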
QwenModelProvider:

@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
 from smartdoc.conf import PROJECT_DIR


@@ -66,7 +67,7 @@ class QwenModelProvider(IModelProvider):
         return 3

     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
-        chat_tong_yi = ChatTongyi(
+        chat_tong_yi = QwenChatModel(
            model_name=model_name,
            dashscope_api_key=model_credential.get('api_key')
        )
QianfanChatModel (same refactor):

@@ -18,19 +18,7 @@ from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
 from langchain_community.chat_models import QianfanChatEndpoint

-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage


 class QianfanChatModel(QianfanChatEndpoint):
XFChatSparkLLM (local token counting added):

@@ -12,11 +12,21 @@ from typing import List, Optional, Any, Iterator
 from langchain_community.chat_models import ChatSparkLLM
 from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
 from langchain_core.outputs import ChatGenerationChunk

+from common.config.tokenizer_manage_config import TokenizerManage
+

 class XFChatSparkLLM(ChatSparkLLM):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
+
     def _stream(
         self,
         messages: List[BaseMessage],
setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py (new file):

@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: zhipu_chat_model.py
+@date:2024/4/28 11:42
+@desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class ZhipuChatModel(ChatZhipuAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
ZhiPuModelProvider:

@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
 from smartdoc.conf import PROJECT_DIR


@@ -66,7 +67,7 @@ class ZhiPuModelProvider(IModelProvider):
         return 3

     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
-        zhipuai_chat = ChatZhipuAI(
+        zhipuai_chat = ZhipuChatModel(
            temperature=0.5,
            api_key=model_credential.get('api_key'),
            model=model_name