MaxKB/apps/models_provider/impl/openai_model_provider/credential/embedding.py
CaptainB 18543a0703 feat: add OpenAI embedding model parameters form
--bug=1062976 --user=刘瑞斌 【模型】openai的向量模型,设置dimensions参数后,向量化失败 https://www.tapd.cn/62980211/s/1790130
2025-10-27 12:46:49 +08:00

75 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingModelParams(BaseForm):
    """Model-parameter form for OpenAI embedding models.

    Exposes a single required ``dimensions`` selector offering 1536, 1024,
    768 and 512, defaulting to 1024.
    """

    # default_value is a string so it has the same type as the option values
    # below. As an int (1024) it did not equal any option's value, and the
    # submitted value's type depended on whether the user changed the selection.
    dimensions = forms.SingleSelect(
        TooltipLabel(
            _('Dimensions'),
            _('')
        ),
        required=True,
        default_value='1024',
        value_field='value',
        text_field='label',
        option_list=[
            {'label': '1536', 'value': '1536'},
            {'label': '1024', 'value': '1024'},
            {'label': '768', 'value': '768'},
            {'label': '512', 'value': '512'},
        ]
    )
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI embedding models.

    Collects an API URL and API key. Validation builds the model through the
    provider and embeds a short probe string.
    """

    api_base = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Check that the credential and parameters can actually embed text.

        :param model_type: model type to validate; must appear in
            ``provider.get_model_type_list()``.
        :param model_name: name of the embedding model.
        :param model_credential: credential values; must contain ``api_base``
            and ``api_key``.
        :param model_params: values from the model-parameter form
            (e.g. ``dimensions``), forwarded to ``provider.get_model``;
            ``None`` is treated as "no parameters".
        :param provider: model provider used to list types and build the model.
        :param raise_exception: when True, failures raise ``AppApiException``;
            when False, they return ``False`` instead.
        :return: ``True`` when the probe embedding succeeds.
        :raises AppApiException: on any failure when ``raise_exception`` is
            True. An ``AppApiException`` raised while building or probing the
            model is always propagated unchanged.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # Honour raise_exception here too, like the required-key check below.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('{model_type} Model type is not supported').format(model_type=model_type))
            return False

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False

        try:
            # Forward the parameter-form values (e.g. dimensions) so the probe
            # exercises the configured settings instead of silently ignoring
            # them. Assumes provider.get_model accepts them as keyword
            # arguments -- TODO confirm against the provider implementation.
            model = provider.get_model(model_type, model_name, model_credential, **(model_params or {}))
            model.embed_query(_('Hello'))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` is passed through ``encryption``.

        A missing ``api_key`` is treated as an empty string.
        """
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form for OpenAI embedding models (same form for every model name)."""
        return OpenAIEmbeddingModelParams()