mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: 修复openai模型在对接其他兼容openai接口平台时获取tokens错误 (#157)
This commit is contained in:
parent
d214e31f71
commit
0b083eecee
|
|
@ -0,0 +1,42 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: openai_chat_model.py
|
||||
@date:2024/4/18 15:28
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class TokenizerManage:
|
||||
tokenizer = None
|
||||
|
||||
@staticmethod
|
||||
def get_tokenizer():
|
||||
from transformers import GPT2TokenizerFast
|
||||
if TokenizerManage.tokenizer is None:
|
||||
TokenizerManage.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2',
|
||||
cache_dir="/opt/maxkb/model/tokenizer",
|
||||
resume_download=False,
|
||||
force_download=False)
|
||||
return TokenizerManage.tokenizer
|
||||
|
||||
|
||||
class OpenAIChatModel(ChatOpenAI):
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
try:
|
||||
return super().get_num_tokens(text)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
|
@ -10,7 +10,6 @@ import os
|
|||
from typing import Dict
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -19,6 +18,7 @@ from common.util.file_util import get_file_content
|
|||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst, ValidCode
|
||||
from setting.models_provider.impl.openai_model_provider.model.openai_chat_model import OpenAIChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -71,8 +71,8 @@ class OpenAIModelProvider(IModelProvider):
|
|||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatOpenAI:
|
||||
azure_chat_open_ai = ChatOpenAI(
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> OpenAIChatModel:
|
||||
azure_chat_open_ai = OpenAIChatModel(
|
||||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key')
|
||||
|
|
|
|||
Loading…
Reference in New Issue