feat: add optional parameters to OpenAIEmbeddingModel for enhanced embedding functionality

This commit is contained in:
CaptainB 2025-09-28 12:03:03 +08:00 committed by 刘瑞斌
parent 01ba883946
commit e4ac7783e3
5 changed files with 63 additions and 50 deletions

View File

@ -13,10 +13,27 @@ from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
class BaiLianEmbeddingModelParams(BaseForm):
    """Parameter form for Aliyun BaiLian embedding models.

    Declares a single required ``dimensions`` select offering 1024, 768 and
    512, with 1024 as the default.
    """
    # Option values are integers so that they have the same type as
    # ``default_value`` (previously the default was the int 1024 while every
    # option value was a string, so the default matched none of the options).
    # NOTE(review): the selected value is presumably forwarded as the
    # ``dimensions`` request parameter, which the OpenAI-style embeddings API
    # documents as an integer - confirm against the caller.
    dimensions = forms.SingleSelect(
        TooltipLabel(
            _('Dimensions'),
            _('')
        ),
        required=True,
        default_value=1024,
        value_field='value',
        text_field='label',
        option_list=[
            {'label': '1024', 'value': 1024},
            {'label': '768', 'value': 768},
            {'label': '512', 'value': 512},
        ]
    )
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
@ -71,4 +88,8 @@ class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
api_key = model.get('dashscope_api_key', '')
return {**model, 'dashscope_api_key': super().encryption(api_key)}
def get_model_params_setting_form(self, model_name):
    """Return the parameter-setting form for an embedding model.

    ``model_name`` is accepted but not consulted: the same
    ``BaiLianEmbeddingModelParams`` form is returned for every model.
    """
    params_form = BaiLianEmbeddingModelParams()
    return params_form
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -6,61 +6,43 @@
@date2024/10/16 16:34
@desc:
"""
from functools import reduce
from typing import Dict, List
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.embeddings.dashscope import embed_with_retry
from openai import OpenAI
from models_provider.base_model_provider import MaxKBBaseModel
def proxy_embed_documents(texts: List[str], step_size, embed_documents):
    """Embed ``texts`` in consecutive slices and concatenate the results.

    Args:
        texts: The texts to embed.
        step_size: Maximum number of texts handed to ``embed_documents`` per call.
        embed_documents: Callable that receives one slice of ``texts`` and
            returns an iterable of results for that slice.

    Returns:
        A single flat list holding the results of every slice, in input order.
    """
    # Run every batch first, then flatten, mirroring the original evaluation order.
    per_batch = []
    for offset in range(0, len(texts), step_size):
        per_batch.append(embed_documents(texts[offset:offset + step_size]))

    flattened = []
    for vectors in per_batch:
        flattened.extend(vectors)
    return flattened
class AliyunBaiLianEmbedding(MaxKBBaseModel):
model_name: str
optional_params: dict
def __init__(self, api_key, model_name: str, optional_params: dict):
    """Create the embedding client and remember the model settings.

    Args:
        api_key: API key handed to the ``OpenAI`` client.
        model_name: Model identifier stored on the instance.
        optional_params: Extra parameters stored on the instance.
    """
    # The client targets DashScope's OpenAI-compatible endpoint; only its
    # ``embeddings`` resource is kept.
    compatible_client = OpenAI(
        api_key=api_key,
        base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
    )
    self.client = compatible_client.embeddings
    self.model_name = model_name
    self.optional_params = optional_params
class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Build an ``AliyunBaiLianEmbedding`` from stored credentials.

    Args:
        model_type: Unused here.
        model_name: Model identifier passed to the constructor.
        model_credential: Credential mapping; ``dashscope_api_key`` is read from it.
        **model_kwargs: Extra settings; whatever
            ``MaxKBBaseModel.filter_optional_params`` keeps is forwarded as
            ``optional_params``.

    Returns:
        The configured ``AliyunBaiLianEmbedding`` instance.
    """
    optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
    # Only the keyword arguments accepted by __init__ are passed; the stale
    # ``model=`` / ``dashscope_api_key=`` arguments are gone.
    return AliyunBaiLianEmbedding(
        api_key=model_credential.get('dashscope_api_key'),
        model_name=model_name,
        optional_params=optional_params
    )
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed ``texts``, batching requests for ``text-embedding-v3``.

    For the ``text-embedding-v3`` model the texts are sent through
    ``proxy_embed_documents`` in slices of 6; every other model gets a
    single ``_embed_documents`` call.
    """
    if self.model != 'text-embedding-v3':
        return self._embed_documents(texts)
    return proxy_embed_documents(texts, 6, self._embed_documents)
def embed_query(self, text: str):
    """Embed one text by delegating to ``embed_documents`` and returning its first result."""
    return self.embed_documents([text])[0]
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed a list of texts with one ``embed_with_retry`` call.

    Args:
        texts: The list of texts to embed.

    Returns:
        The ``"embedding"`` entry of every item returned by the call.
    """
    response_items = embed_with_retry(
        self, input=texts, text_type="document", model=self.model
    )
    return [entry["embedding"] for entry in response_items]
def embed_query(self, text: str) -> List[float]:
    """Embed a single query text with one ``embed_with_retry`` call.

    Args:
        text: The text to embed.

    Returns:
        The ``"embedding"`` entry of the first returned item.
    """
    # NOTE(review): the request is sent with text_type="document" even though
    # this method embeds a query; kept exactly as in the original.
    response_items = embed_with_retry(
        self, input=[text], text_type="document", model=self.model
    )
    first_item = response_items[0]
    return first_item["embedding"]
def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
if len(self.optional_params) > 0:
res = self.client.create(
input=texts, model=self.model_name, encoding_format="float",
**self.optional_params
)
else:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
return [e.embedding for e in res.data]

View File

@ -15,17 +15,21 @@ from models_provider.base_model_provider import MaxKBBaseModel
class OpenAIEmbeddingModel(MaxKBBaseModel):
model_name: str
optional_params: dict
def __init__(self, api_key, base_url, model_name: str, optional_params: dict | None = None):
    """Create the embeddings client and remember the model settings.

    Args:
        api_key: API key handed to ``openai.OpenAI``.
        base_url: Base URL handed to ``openai.OpenAI``.
        model_name: Model identifier stored on the instance.
        optional_params: Extra parameters stored on the instance. May be
            omitted, so callers using the earlier three-argument form keep
            working.
    """
    self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
    self.model_name = model_name
    # Normalise a missing value to an empty dict so the attribute is always a mapping.
    self.optional_params = optional_params if optional_params is not None else {}
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Build an ``OpenAIEmbeddingModel`` from stored credentials.

    Args:
        model_type: Unused here.
        model_name: Model identifier passed to the constructor.
        model_credential: Credential mapping; ``api_key`` and ``api_base`` are read from it.
        **model_kwargs: Extra settings; whatever
            ``MaxKBBaseModel.filter_optional_params`` keeps is forwarded as
            ``optional_params``.

    Returns:
        The configured ``OpenAIEmbeddingModel`` instance.
    """
    extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
    credential_api_key = model_credential.get('api_key')
    credential_api_base = model_credential.get('api_base')
    return OpenAIEmbeddingModel(
        api_key=credential_api_key,
        model_name=model_name,
        base_url=credential_api_base,
        optional_params=extra_params,
    )
def embed_query(self, text: str):
@ -35,5 +39,11 @@ class OpenAIEmbeddingModel(MaxKBBaseModel):
def embed_documents(
self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
if len(self.optional_params) > 0:
res = self.client.create(
input=texts, model=self.model_name, encoding_format="float",
**self.optional_params
)
else:
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
return [e.embedding for e in res.data]

View File

@ -140,8 +140,7 @@
/>
<el-empty
v-else-if="
base_form_data.model_type === 'RERANKER' ||
base_form_data.model_type === 'EMBEDDING'
base_form_data.model_type === 'RERANKER'
"
:description="$t('views.model.tip.emptyMessage2')"
/>
@ -150,7 +149,7 @@
<el-button
type="text"
@click.stop="openAddDrawer()"
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT'].includes(base_form_data.model_type)"
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT', 'EMBEDDING'].includes(base_form_data.model_type)"
>
<AppIcon iconName="app-add-outlined" class="mr-4"/> {{ $t('common.add') }}
</el-button>

View File

@ -95,6 +95,7 @@
currentModel.model_type === 'IMAGE' ||
currentModel.model_type === 'TTI' ||
currentModel.model_type === 'ITV' ||
currentModel.model_type === 'EMBEDDING' ||
currentModel.model_type === 'TTV') &&
permissionPrecise.paramSetting(model.id)
"