mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: add optional parameters to OpenAIEmbeddingModel for enhanced embedding functionality
This commit is contained in:
parent
6b23469c29
commit
ae2e08d90d
|
|
@ -13,10 +13,27 @@ from django.utils.translation import gettext as _
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
|
||||
|
||||
class BaiLianEmbeddingModelParams(BaseForm):
|
||||
dimensions = forms.SingleSelect(
|
||||
TooltipLabel(
|
||||
_('Dimensions'),
|
||||
_('')
|
||||
),
|
||||
required=True,
|
||||
default_value=1024,
|
||||
value_field='value',
|
||||
text_field='label',
|
||||
option_list=[
|
||||
{'label': '1024', 'value': '1024'},
|
||||
{'label': '768', 'value': '768'},
|
||||
{'label': '512', 'value': '512'},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -71,4 +88,8 @@ class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
api_key = model.get('dashscope_api_key', '')
|
||||
return {**model, 'dashscope_api_key': super().encryption(api_key)}
|
||||
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return BaiLianEmbeddingModelParams()
|
||||
|
||||
dashscope_api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
|
|
|||
|
|
@ -6,61 +6,43 @@
|
|||
@date:2024/10/16 16:34
|
||||
@desc:
|
||||
"""
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
from langchain_community.embeddings.dashscope import embed_with_retry
|
||||
from openai import OpenAI
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
def proxy_embed_documents(texts: List[str], step_size, embed_documents):
|
||||
value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
|
||||
range(0, len(texts), step_size)]
|
||||
return reduce(lambda x, y: [*x, *y], value, [])
|
||||
class AliyunBaiLianEmbedding(MaxKBBaseModel):
|
||||
model_name: str
|
||||
optional_params: dict
|
||||
|
||||
def __init__(self, api_key, model_name: str, optional_params: dict):
|
||||
self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings
|
||||
self.model_name = model_name
|
||||
self.optional_params = optional_params
|
||||
|
||||
class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return AliyunBaiLianEmbedding(
|
||||
model=model_name,
|
||||
dashscope_api_key=model_credential.get('dashscope_api_key')
|
||||
api_key=model_credential.get('dashscope_api_key'),
|
||||
model_name=model_name,
|
||||
optional_params=optional_params
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
if self.model == 'text-embedding-v3':
|
||||
return proxy_embed_documents(texts, 6, self._embed_documents)
|
||||
return self._embed_documents(texts)
|
||||
def embed_query(self, text: str):
|
||||
res = self.embed_documents([text])
|
||||
return res[0]
|
||||
|
||||
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to DashScope's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||
specified by the class.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = embed_with_retry(
|
||||
self, input=texts, text_type="document", model=self.model
|
||||
)
|
||||
embedding_list = [item["embedding"] for item in embeddings]
|
||||
return embedding_list
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to DashScope's embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
embedding = embed_with_retry(
|
||||
self, input=[text], text_type="document", model=self.model
|
||||
)[0]["embedding"]
|
||||
return embedding
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
if len(self.optional_params) > 0:
|
||||
res = self.client.create(
|
||||
input=texts, model=self.model_name, encoding_format="float",
|
||||
**self.optional_params
|
||||
)
|
||||
else:
|
||||
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
|
||||
return [e.embedding for e in res.data]
|
||||
|
|
|
|||
|
|
@ -15,17 +15,21 @@ from models_provider.base_model_provider import MaxKBBaseModel
|
|||
|
||||
class OpenAIEmbeddingModel(MaxKBBaseModel):
|
||||
model_name: str
|
||||
optional_params: dict
|
||||
|
||||
def __init__(self, api_key, base_url, model_name: str):
|
||||
def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
|
||||
self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
|
||||
self.model_name = model_name
|
||||
self.optional_params = optional_params
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return OpenAIEmbeddingModel(
|
||||
api_key=model_credential.get('api_key'),
|
||||
model_name=model_name,
|
||||
base_url=model_credential.get('api_base'),
|
||||
optional_params=optional_params
|
||||
)
|
||||
|
||||
def embed_query(self, text: str):
|
||||
|
|
@ -35,5 +39,11 @@ class OpenAIEmbeddingModel(MaxKBBaseModel):
|
|||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int | None = None
|
||||
) -> List[List[float]]:
|
||||
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
|
||||
if len(self.optional_params) > 0:
|
||||
res = self.client.create(
|
||||
input=texts, model=self.model_name, encoding_format="float",
|
||||
**self.optional_params
|
||||
)
|
||||
else:
|
||||
res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
|
||||
return [e.embedding for e in res.data]
|
||||
|
|
|
|||
|
|
@ -140,8 +140,7 @@
|
|||
/>
|
||||
<el-empty
|
||||
v-else-if="
|
||||
base_form_data.model_type === 'RERANKER' ||
|
||||
base_form_data.model_type === 'EMBEDDING'
|
||||
base_form_data.model_type === 'RERANKER'
|
||||
"
|
||||
:description="$t('views.model.tip.emptyMessage2')"
|
||||
/>
|
||||
|
|
@ -150,7 +149,7 @@
|
|||
<el-button
|
||||
type="text"
|
||||
@click.stop="openAddDrawer()"
|
||||
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT'].includes(base_form_data.model_type)"
|
||||
:disabled="!['TTS', 'LLM', 'IMAGE', 'TTI', 'TTV', 'ITV','STT', 'EMBEDDING'].includes(base_form_data.model_type)"
|
||||
>
|
||||
<AppIcon iconName="app-add-outlined" class="mr-4"/> {{ $t('common.add') }}
|
||||
</el-button>
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@
|
|||
currentModel.model_type === 'IMAGE' ||
|
||||
currentModel.model_type === 'TTI' ||
|
||||
currentModel.model_type === 'ITV' ||
|
||||
currentModel.model_type === 'EMBEDDING' ||
|
||||
currentModel.model_type === 'TTV') &&
|
||||
permissionPrecise.paramSetting(model.id)
|
||||
"
|
||||
|
|
|
|||
Loading…
Reference in New Issue