refactor: support volcanic engine embeddings
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled

This commit is contained in:
CaptainB 2025-01-10 10:46:26 +08:00 committed by 刘瑞斌
parent de85895ad6
commit 6e281f6242
3 changed files with 12 additions and 6 deletions

View File

@ -14,7 +14,7 @@ from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
class VolcanicEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()

View File

@ -1,15 +1,16 @@
from typing import Dict
from langchain_community.embeddings import VolcanoEmbeddings
from langchain_openai import OpenAIEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return VolcanicEngineEmbeddingModel(
api_key=model_credential.get('api_key'),
openai_api_key=model_credential.get('api_key'),
model=model_name,
openai_api_base=model_credential.get('api_base'),
check_embedding_ctx_length=False,
)

View File

@ -14,10 +14,12 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.volcanic_engine_model_provider.credential.embedding import VolcanicEmbeddingCredential
from setting.models_provider.impl.volcanic_engine_model_provider.credential.image import \
VolcanicEngineImageModelCredential
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
from setting.models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel
from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
@ -81,12 +83,13 @@ model_info_list = [
),
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
open_ai_embedding_credential = VolcanicEmbeddingCredential()
model_info_embedding_list = [
ModelInfo('ep-xxxxxxxxxx-yyyy',
'用户前往火山方舟的模型推理页面创建推理接入点这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
OpenAIEmbeddingModel)]
VolcanicEngineEmbeddingModel)
]
model_info_manage = (
ModelInfoManage.builder()
@ -96,6 +99,8 @@ model_info_manage = (
.append_default_model_info(model_info_list[2])
.append_default_model_info(model_info_list[3])
.append_default_model_info(model_info_list[4])
.append_model_info_list(model_info_embedding_list)
.append_default_model_info(model_info_embedding_list[0])
.build()
)