refactor: model

This commit is contained in:
wxg0103 2025-04-25 18:12:35 +08:00
parent 7f492b4d92
commit 83ace97ecc
16 changed files with 24 additions and 37 deletions

View File

@@ -18,6 +18,6 @@ class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)
return chat_tong_yi

View File

@@ -28,5 +28,5 @@ class OllamaImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -16,5 +16,5 @@ class OpenAIImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -35,8 +35,8 @@ class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
streaming = False
chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
extra_body=optional_params,
streaming=streaming,
custom_get_token_ids=custom_get_token_ids

View File

@@ -16,5 +16,5 @@ class SiliconCloudImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -33,21 +33,8 @@ class TencentCloudChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params,
extra_body=optional_params,
custom_get_token_ids=custom_get_token_ids
)
return azure_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@@ -11,10 +11,10 @@ class TencentVision(MaxKBBaseModel, BaseChatOpenAI):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return TencentVision(
model_name=model_name,
openai_api_base='https://api.hunyuan.cloud.tencent.com/v1',
openai_api_key=model_credential.get('api_key'),
api_base='https://api.hunyuan.cloud.tencent.com/v1',
api_key=model_credential.get('api_key'),
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -19,7 +19,7 @@ class VllmImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)
def is_cache_model(self):

View File

@@ -16,5 +16,5 @@ class VolcanicEngineImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -27,15 +27,15 @@ from django.utils.translation import gettext as _
ssl._create_default_https_context = ssl.create_default_context()
qwen_model_credential = XunFeiLLMModelCredential()
xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
@@ -45,7 +45,7 @@ model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_default_model_info(
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM))
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
.append_default_model_info(
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
)

View File

@@ -19,7 +19,7 @@ class XinferenceImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:

View File

@@ -16,5 +16,5 @@ class ZhiPuImage(MaxKBBaseModel, BaseChatOpenAI):
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
extra_body=optional_params
)

View File

@@ -20,14 +20,14 @@ from models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
qwen_model_credential = ZhiPuLLMModelCredential()
zhipu_model_credential = ZhiPuLLMModelCredential()
zhipu_image_model_credential = ZhiPuImageModelCredential()
zhipu_tti_model_credential = ZhiPuTextToImageModelCredential()
model_info_list = [
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)
ModelInfo('glm-4', '', ModelTypeConst.LLM, zhipu_model_credential, ZhipuChatModel),
ModelInfo('glm-4v', '', ModelTypeConst.LLM, zhipu_model_credential, ZhipuChatModel),
ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, zhipu_model_credential, ZhipuChatModel)
]
model_info_image_list = [
@@ -57,7 +57,7 @@ model_info_tti_list = [
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel))
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, zhipu_model_credential, ZhipuChatModel))
.append_model_info_list(model_info_image_list)
.append_default_model_info(model_info_image_list[0])
.append_model_info_list(model_info_tti_list)