From 8cdb0857348e778fc64867d9e93f7d5277a6fa40 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Tue, 26 Aug 2025 13:20:08 +0800 Subject: [PATCH] feat: add support for v2 API version in embedding models and update validation logic --- .../impl/openai_model_provider/model/llm.py | 1 - .../credential/embedding.py | 49 ++++++++++++++--- .../wenxin_model_provider/model/embedding.py | 52 +++++++++++++++++-- 3 files changed, 88 insertions(+), 14 deletions(-) diff --git a/apps/models_provider/impl/openai_model_provider/model/llm.py b/apps/models_provider/impl/openai_model_provider/model/llm.py index 2d5b76af7..798338037 100644 --- a/apps/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/models_provider/impl/openai_model_provider/model/llm.py @@ -9,7 +9,6 @@ from typing import List, Dict from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai.chat_models import ChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage from models_provider.base_model_provider import MaxKBBaseModel diff --git a/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py b/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py index 5d72b773a..a0fc15d13 100644 --- a/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py @@ -21,11 +21,27 @@ class QianfanEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, - _('{model_type} Model type is not supported').format(model_type=model_type)) - self.valid_form(model_credential) + api_version = model_credential.get('api_version', 'v1') + model = provider.get_model(model_type, model_name, model_credential, **model_params) + if api_version == 'v1': + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + _('{model_type} Model type is not supported').format(model_type=model_type)) + model_info = [model.lower() for model in model.client.models()] + if not model_info.__contains__(model_name.lower()): + raise AppApiException(ValidCode.valid_error.value, + _('{model_name} The model does not support').format(model_name=model_name)) + required_keys = ['qianfan_ak', 'qianfan_sk'] + if api_version == 'v2': + required_keys = ['api_base', 'qianfan_ak'] + + for key in required_keys: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key)) + else: + return False try: model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) @@ -42,8 +58,25 @@ class QianfanEmbeddingCredential(BaseForm, BaseModelCredential): return True def encryption_dict(self, model: Dict[str, object]): - return {**model, 'qianfan_sk': super().encryption(model.get('qianfan_sk', ''))} + api_version = model.get('api_version', 'v1') + if api_version == 'v1': + return {**model, 'qianfan_sk': super().encryption(model.get('qianfan_sk', ''))} + else: # v2 + return {**model, 'qianfan_ak': super().encryption(model.get('qianfan_ak', ''))} + api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value', + option_list=[ + {'label': 'v1', 'value': 'v1'}, + {'label': 'v2', 'value': 'v2'} + ], + default_value='v1', + provider='', + method='', ) + + # v2版本字段 + api_base = forms.TextInputField("API URL", required=True, relation_show_field_dict={"api_version": ["v2"]}) + + # v1版本字段 qianfan_ak = forms.PasswordInputField('API Key', required=True) - - qianfan_sk = forms.PasswordInputField("Secret Key", required=True) + qianfan_sk = forms.PasswordInputField("Secret Key", required=True, + relation_show_field_dict={"api_version": ["v1"]}) diff --git a/apps/models_provider/impl/wenxin_model_provider/model/embedding.py b/apps/models_provider/impl/wenxin_model_provider/model/embedding.py index 418aabde5..cfa555d92 100644 --- a/apps/models_provider/impl/wenxin_model_provider/model/embedding.py +++ b/apps/models_provider/impl/wenxin_model_provider/model/embedding.py @@ -6,18 +6,60 @@ @date:2024/10/17 16:48 @desc: """ -from typing import Dict - +from typing import Dict, List from langchain_community.embeddings import QianfanEmbeddingsEndpoint - +import openai from models_provider.base_model_provider import MaxKBBaseModel -class QianfanEmbeddings(MaxKBBaseModel, QianfanEmbeddingsEndpoint): +class QianfanV1Embeddings(MaxKBBaseModel, QianfanEmbeddingsEndpoint): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - return QianfanEmbeddings( + return QianfanV1Embeddings( model=model_name, qianfan_ak=model_credential.get('qianfan_ak'), qianfan_sk=model_credential.get('qianfan_sk'), ) + + +class QianfanV2EmbeddingModel(MaxKBBaseModel): + model_name: str + + @staticmethod + def is_cache_model(): + return False + + def __init__(self, api_key, base_url, model_name: str): + self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings + self.model_name = model_name + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return QianfanV2EmbeddingModel( + api_key=model_credential.get('qianfan_ak'), + model_name=model_name, + base_url=model_credential.get('api_base'), + ) + + def embed_query(self, text: str): + res = self.embed_documents([text]) + return res[0] + + def embed_documents( + self, texts: List[ str], + ) -> List[List[float]]: + res = self.client.create(input=texts, model=self.model_name, encoding_format="float") + return [e.embedding for e in res.data] + + +class QianfanEmbeddings(MaxKBBaseModel): + + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + api_version = model_credential.get('api_version', 'v1') + + if api_version == "v1": + return QianfanV1Embeddings.new_instance(model_type, model_name, model_credential, **model_kwargs) + elif api_version == "v2": + return QianfanV2EmbeddingModel.new_instance(model_type, model_name, model_credential, **model_kwargs)