From 6e281f62426c42d6d5439aeacddac7c30614afa4 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Fri, 10 Jan 2025 10:46:26 +0800 Subject: [PATCH] refactor: support volcanic engine embeddings --- .../credential/embedding.py | 2 +- .../volcanic_engine_model_provider/model/embedding.py | 7 ++++--- .../volcanic_engine_model_provider.py | 9 +++++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py index 15c1b2add..58f70c02e 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py @@ -14,7 +14,7 @@ from common.forms import BaseForm from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode -class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): +class VolcanicEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): model_type_list = provider.get_model_type_list() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py index b7307a0e5..b950beacf 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py @@ -1,15 +1,16 @@ from typing import Dict -from langchain_community.embeddings import VolcanoEmbeddings +from langchain_openai import OpenAIEmbeddings from setting.models_provider.base_model_provider import MaxKBBaseModel -class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings): +class VolcanicEngineEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return VolcanicEngineEmbeddingModel( - api_key=model_credential.get('api_key'), + openai_api_key=model_credential.get('api_key'), model=model_name, openai_api_base=model_credential.get('api_base'), + check_embedding_ctx_length=False, ) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index b0cb1343f..dc2c66094 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -14,10 +14,12 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel +from setting.models_provider.impl.volcanic_engine_model_provider.credential.embedding import VolcanicEmbeddingCredential from setting.models_provider.impl.volcanic_engine_model_provider.credential.image import \ VolcanicEngineImageModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential +from setting.models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential @@ -81,12 +83,13 @@ model_info_list = [ ), ] -open_ai_embedding_credential = OpenAIEmbeddingCredential() +open_ai_embedding_credential = VolcanicEmbeddingCredential() model_info_embedding_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.EMBEDDING, open_ai_embedding_credential, - OpenAIEmbeddingModel)] + VolcanicEngineEmbeddingModel) +] model_info_manage = ( ModelInfoManage.builder() @@ -96,6 +99,8 @@ model_info_manage = ( .append_default_model_info(model_info_list[2]) .append_default_model_info(model_info_list[3]) .append_default_model_info(model_info_list[4]) + .append_model_info_list(model_info_embedding_list) + .append_default_model_info(model_info_embedding_list[0]) .build() )