From 0b083eecee45654bef291ba97a8dc56d8a068199 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:41:48 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dopenai=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=9C=A8=E5=AF=B9=E6=8E=A5=E5=85=B6=E4=BB=96=E5=85=BC?= =?UTF-8?q?=E5=AE=B9openai=E6=8E=A5=E5=8F=A3=E5=B9=B3=E5=8F=B0=E6=97=B6?= =?UTF-8?q?=E8=8E=B7=E5=8F=96tokens=E9=94=99=E8=AF=AF=20(#157)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/openai_chat_model.py | 42 +++++++++++++++++++ .../openai_model_provider.py | 6 +-- 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py new file mode 100644 index 000000000..1cdfa2aff --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/openai_chat_model.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: openai_chat_model.py + @date:2024/4/18 15:28 + @desc: +""" +from typing import List + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_openai import ChatOpenAI + + +class TokenizerManage: + tokenizer = None + + @staticmethod + def get_tokenizer(): + from transformers import GPT2TokenizerFast + if TokenizerManage.tokenizer is None: + TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', + cache_dir="/opt/maxkb/model/tokenizer", + resume_download=False, + force_download=False) + return TokenizerManage.tokenizer + + +class OpenAIChatModel(ChatOpenAI): + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + try: + return super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + try: + return super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 9c773c479..aab6ac08c 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -10,7 +10,6 @@ import os from typing import Dict from langchain.schema import HumanMessage -from langchain_openai import ChatOpenAI from common import forms from common.exception.app_exception import AppApiException @@ -19,6 +18,7 @@ from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ ModelInfo, \ ModelTypeConst, ValidCode +from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel from smartdoc.conf import PROJECT_DIR @@ -71,8 +71,8 @@ class OpenAIModelProvider(IModelProvider): def get_dialogue_number(self): return 3 - def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatOpenAI: - azure_chat_open_ai = ChatOpenAI( + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel: + azure_chat_open_ai = OpenAIChatModel( model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key')