From 29427a0ad695fdf7f7156809f5da112a87d980d4 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Sun, 28 Apr 2024 11:59:19 +0800
Subject: [PATCH] fix: token calculation error in private deployment (#284)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/common/config/tokenizer_manage_config.py | 24 +++++++++++++++++++
 .../azure_model_provider.py                   |  5 ++--
 .../model/azure_chat_model.py                 | 24 +++++++++++++++++++
 .../model/kimi_chat_model.py                  | 14 +----------
 .../model/ollama_chat_model.py                | 14 +----------
 .../model/openai_chat_model.py                | 14 +----------
 .../model/qwen_chat_model.py                  | 24 +++++++++++++++++++
 .../qwen_model_provider.py                    |  3 ++-
 .../model/qian_fan_chat_model.py              | 14 +----------
 .../xf_model_provider/model/xf_chat_model.py  | 12 +++++++++-
 .../model/zhipu_chat_model.py                 | 24 +++++++++++++++++++
 .../zhipu_model_provider.py                   |  3 ++-
 12 files changed, 118 insertions(+), 57 deletions(-)
 create mode 100644 apps/common/config/tokenizer_manage_config.py
 create mode 100644 apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
 create mode 100644 apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py
 create mode 100644 apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py

diff --git a/apps/common/config/tokenizer_manage_config.py b/apps/common/config/tokenizer_manage_config.py
new file mode 100644
index 000000000..1d3fa8dd9
--- /dev/null
+++ b/apps/common/config/tokenizer_manage_config.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: tokenizer_manage_config.py
+    @date:2024/4/28 10:17
+    @desc:
+"""
+
+
+class TokenizerManage:
+    tokenizer = None
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained(
+                'gpt2',
+                cache_dir="/opt/maxkb/model/tokenizer",
+                local_files_only=True,
+                resume_download=False,
+                force_download=False)
+        return TokenizerManage.tokenizer
diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
index f58f8744c..1f04a268a 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
@@ -19,6 +19,7 @@ from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
     ModelInfo, \
     ModelTypeConst, ValidCode
+from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -119,8 +120,8 @@ class AzureModelProvider(IModelProvider):
     def get_model(self, model_type, model_name, model_credential: Dict[str, object],
                   **model_kwargs) -> AzureChatOpenAI:
         model_info: ModelInfo = model_dict.get(model_name)
-        azure_chat_open_ai = AzureChatOpenAI(
-            openai_api_base=model_credential.get('api_base'),
+        azure_chat_open_ai = AzureChatModel(
+            azure_endpoint=model_credential.get('api_base'),
             openai_api_version=model_info.api_version if model_name in model_dict else model_credential.get(
                 'api_version'),
             deployment_name=model_credential.get('deployment_name'),
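The new TokenizerManage singleton is the core of the fix: it lazily loads a single GPT2TokenizerFast from the bundled cache directory, and local_files_only=True stops transformers from ever contacting the Hugging Face Hub. That Hub download attempt is what raised the token-calculation error on offline private deployments. A minimal usage sketch, assuming the gpt2 tokenizer files were shipped into /opt/maxkb/model/tokenizer at image build time:

    from common.config.tokenizer_manage_config import TokenizerManage

    tokenizer = TokenizerManage.get_tokenizer()   # loaded once, then cached on the class
    print(len(tokenizer.encode("hello world")))   # 2 tokens under the GPT-2 BPE vocabulary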
diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
new file mode 100644
index 000000000..f11249de0
--- /dev/null
+++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: azure_chat_model.py
+    @date:2024/4/28 11:45
+    @desc:
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai import AzureChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class AzureChatModel(AzureChatOpenAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py b/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py
index c69cae48d..deee11a02 100644
--- a/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py
+++ b/apps/setting/models_provider/impl/kimi_model_provider/model/kimi_chat_model.py
@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class KimiChatModel(ChatOpenAI):
diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
index 876e1f28d..86c5219d4 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/ollama_chat_model.py
@@ -11,19 +11,7 @@ from typing import List
 from langchain_community.chat_models import ChatOpenAI
 from langchain_core.messages import BaseMessage, get_buffer_string
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class OllamaChatModel(ChatOpenAI):
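Every provider override above repeats the same two-method body, replacing the per-file TokenizerManage copies that each carried their own from_pretrained call. A hypothetical mixin (not part of this patch) would express the shared logic once; the patch instead keeps the duplication explicit per provider:

    from typing import List

    from langchain_core.messages import BaseMessage, get_buffer_string

    from common.config.tokenizer_manage_config import TokenizerManage


    class LocalTokenCountMixin:
        # Hypothetical helper: count tokens with the shared offline GPT-2 tokenizer.
        def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

        def get_num_tokens(self, text: str) -> int:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))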
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py
index 1cdfa2aff..7271fe8ad 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py
@@ -11,19 +11,7 @@ from typing import List
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_openai import ChatOpenAI
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class OpenAIChatModel(ChatOpenAI):
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py b/apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py
new file mode 100644
index 000000000..d3894d1d0
--- /dev/null
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/qwen_chat_model.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: qwen_chat_model.py
+    @date:2024/4/28 11:44
+    @desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatTongyi
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class QwenChatModel(ChatTongyi):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py
index 46ad1c6ec..179f90368 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py
@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.qwen_model_provider.model.qwen_chat_model import QwenChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -66,7 +67,7 @@ class QwenModelProvider(IModelProvider):
         return 3
 
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatTongyi:
-        chat_tong_yi = ChatTongyi(
+        chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key')
         )
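Because QwenChatModel subclasses ChatTongyi (and each of the other wrappers subclasses its provider class), get_model can return the wrapper without changing its declared return type; callers only see that token counting now works offline. A quick sanity check, with hypothetical credential values and no network call needed for the count:

    provider = QwenModelProvider()
    model = provider.get_model('LLM', 'qwen-turbo', {'api_key': 'sk-placeholder'})
    assert isinstance(model, ChatTongyi)   # the wrapper keeps the public contract
    print(model.get_num_tokens('hello'))   # counted locally via TokenizerManage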
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
index a2065186e..b07e8a01b 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
@@ -18,19 +18,7 @@ from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
 from langchain_community.chat_models import QianfanChatEndpoint
 
-
-class TokenizerManage:
-    tokenizer = None
-
-    @staticmethod
-    def get_tokenizer():
-        from transformers import GPT2TokenizerFast
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
-                                                                          cache_dir="/opt/maxkb/model/tokenizer",
-                                                                          resume_download=False,
-                                                                          force_download=False)
-        return TokenizerManage.tokenizer
+from common.config.tokenizer_manage_config import TokenizerManage
 
 
 class QianfanChatModel(QianfanChatEndpoint):
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
index a09d48092..3b6a22c47 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
@@ -12,11 +12,21 @@ from typing import List, Optional, Any, Iterator
 from langchain_community.chat_models import ChatSparkLLM
 from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.messages import BaseMessage, AIMessageChunk, get_buffer_string
 from langchain_core.outputs import ChatGenerationChunk
 
+from common.config.tokenizer_manage_config import TokenizerManage
+
 
 class XFChatSparkLLM(ChatSparkLLM):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
+
     def _stream(
             self,
             messages: List[BaseMessage],
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py
new file mode 100644
index 000000000..ceab8988d
--- /dev/null
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/zhipu_chat_model.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: zhipu_chat_model.py
+    @date:2024/4/28 11:42
+    @desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class ZhipuChatModel(ChatZhipuAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
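Note that get_num_tokens_from_messages encodes each message through get_buffer_string, so the count includes the rendered role prefixes ('Human: ', 'AI: ') rather than just the raw content. An illustration using the shared tokenizer directly (the message contents are made up):

    from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

    from common.config.tokenizer_manage_config import TokenizerManage

    msgs = [HumanMessage(content='What is MaxKB?'), AIMessage(content='A knowledge base app.')]
    print(get_buffer_string([msgs[0]]))   # "Human: What is MaxKB?"
    tokenizer = TokenizerManage.get_tokenizer()
    print(sum(len(tokenizer.encode(get_buffer_string([m]))) for m in msgs))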
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
index b84bb3d15..ebbb3b469 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
@@ -18,6 +18,7 @@ from common.forms import BaseForm
 from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
     ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.zhipu_model_provider.model.zhipu_chat_model import ZhipuChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -66,7 +67,7 @@ class ZhiPuModelProvider(IModelProvider):
         return 3
 
     def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
-        zhipuai_chat = ChatZhipuAI(
+        zhipuai_chat = ZhipuChatModel(
             temperature=0.5,
             api_key=model_credential.get('api_key'),
             model=model_name
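One trade-off worth noting: the GPT-2 BPE vocabulary only approximates the native tokenizers of Qwen, Zhipu, Spark, and the other providers, so these counts are estimates rather than billing-exact figures, which is an acceptable price for working fully offline. For the local_files_only=True load to succeed, the gpt2 files must already sit under /opt/maxkb/model/tokenizer; an operator could seed that cache once on an online machine and copy it into the offline image (a sketch, assuming the standard transformers cache layout):

    # Run once with internet access, then ship the cache directory into the offline deployment.
    from transformers import GPT2TokenizerFast

    GPT2TokenizerFast.from_pretrained('gpt2', cache_dir='/opt/maxkb/model/tokenizer')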