diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py
new file mode 100644
index 000000000..1cdfa2aff
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: openai_chat_model.py
+    @date:2024/4/18 15:28
+    @desc: ChatOpenAI subclass that falls back to a local GPT-2 tokenizer for token counting.
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai import ChatOpenAI
+
+
+class TokenizerManage:
+    tokenizer = None  # process-wide cached tokenizer instance (lazy singleton)
+
+    @staticmethod
+    def get_tokenizer():
+        from transformers import GPT2TokenizerFast
+        if TokenizerManage.tokenizer is None:
+            TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
+                                                                          cache_dir="/opt/maxkb/model/tokenizer",
+                                                                          resume_download=False,
+                                                                          force_download=False)
+        return TokenizerManage.tokenizer
+
+
+class OpenAIChatModel(ChatOpenAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception:  # model unknown to tiktoken -> approximate with local GPT-2 tokenizer
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
+
+    def get_num_tokens(self, text: str) -> int:
+        try:
+            return super().get_num_tokens(text)
+        except Exception:  # model unknown to tiktoken -> approximate with local GPT-2 tokenizer
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
index 9c773c479..aab6ac08c 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
@@ -10,7 +10,6 @@ import os
 from typing import Dict
 
 from langchain.schema import HumanMessage
-from langchain_openai import ChatOpenAI
 
 from common import forms
 from common.exception.app_exception import AppApiException
@@ -19,6 +18,7 @@ from common.util.file_util import get_file_content
 from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
     ModelInfo, \
     ModelTypeConst, ValidCode
+from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -71,8 +71,8 @@ class OpenAIModelProvider(IModelProvider):
     def get_dialogue_number(self):
         return 3
 
-    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatOpenAI:
-        azure_chat_open_ai = ChatOpenAI(
+    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel:
+        azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key')