From 15feca802a4315d6d6aeaa4a29c8dd4cc0240f95 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Wed, 2 Apr 2025 17:45:18 +0800
Subject: [PATCH] fix: OpenAI Vector Model Using Openai Supplier (#2781)

---
 .../openai_model_provider/model/embedding.py | 26 +++++++++++++++----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
index 5ac1f8e6f..f95e78188 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
@@ -6,18 +6,34 @@
     @date:2024/7/12 17:44
     @desc:
 """
-from typing import Dict
+from typing import Dict, List
 
-from langchain_community.embeddings import OpenAIEmbeddings
+import openai
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
-class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
+class OpenAIEmbeddingModel(MaxKBBaseModel):
+    model_name: str
+
+    def __init__(self, api_key, base_url, model_name: str):
+        self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
+        self.model_name = model_name
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return OpenAIEmbeddingModel(
             api_key=model_credential.get('api_key'),
-            model=model_name,
-            openai_api_base=model_credential.get('api_base'),
+            model_name=model_name,
+            base_url=model_credential.get('api_base'),
         )
+
+    def embed_query(self, text: str):
+        res = self.embed_documents([text])
+        return res[0]
+
+    def embed_documents(
+            self, texts: List[str], chunk_size: int | None = None
+    ) -> List[List[float]]:
+        res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+        return [e.embedding for e in res.data]
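
Reviewer note: the patch drops the langchain_community OpenAIEmbeddings base class and calls the OpenAI embeddings endpoint directly through the openai SDK, while keeping the embed_query/embed_documents interface so existing callers are unaffected. Below is a minimal usage sketch of the patched class; the credentials, base URL, and model name are placeholders, and the model_type argument is purely illustrative since new_instance ignores it along with **model_kwargs.

    from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel

    # Placeholder credential values; replace with a real key and endpoint.
    model = OpenAIEmbeddingModel.new_instance(
        model_type='EMBEDDING',                  # ignored by new_instance
        model_name='text-embedding-3-small',
        model_credential={
            'api_key': 'sk-...',
            'api_base': 'https://api.openai.com/v1',
        },
    )

    # Single query -> one vector; batch -> one vector per input text.
    query_vector = model.embed_query('hello world')
    doc_vectors = model.embed_documents(['first document', 'second document'])
    print(len(query_vector), len(doc_vectors))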