diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
index 5ac1f8e6f..f95e78188 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py
@@ -6,18 +6,34 @@
     @date:2024/7/12 17:44
     @desc:
 """
-from typing import Dict
+from typing import Dict, List
 
-from langchain_community.embeddings import OpenAIEmbeddings
+import openai
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
-class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
+class OpenAIEmbeddingModel(MaxKBBaseModel):
+    model_name: str
+
+    def __init__(self, api_key, base_url, model_name: str):
+        self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
+        self.model_name = model_name
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return OpenAIEmbeddingModel(
             api_key=model_credential.get('api_key'),
-            model=model_name,
-            openai_api_base=model_credential.get('api_base'),
+            model_name=model_name,
+            base_url=model_credential.get('api_base'),
         )
+
+    def embed_query(self, text: str):
+        res = self.embed_documents([text])
+        return res[0]
+
+    def embed_documents(
+        self, texts: List[str], chunk_size: int | None = None
+    ) -> List[List[float]]:
+        res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+        return [e.embedding for e in res.data]
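
For reviewers, a minimal usage sketch of the rewritten wrapper (not part of the diff). It assumes the `apps` directory is on the import path, as the in-repo import of `setting.models_provider.base_model_provider` suggests; the model name, API key, base URL, and `model_type` value are placeholders.

# Minimal sketch: exercise OpenAIEmbeddingModel against an OpenAI-compatible
# embeddings endpoint. All credential values below are placeholders.
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel

model = OpenAIEmbeddingModel.new_instance(
    model_type='EMBEDDING',                       # passed through; not used by new_instance
    model_name='text-embedding-3-small',          # placeholder model name
    model_credential={
        'api_key': 'sk-...',                      # placeholder key
        'api_base': 'https://api.openai.com/v1',  # any OpenAI-compatible base URL
    },
)

vectors = model.embed_documents(['hello world', 'MaxKB embeddings'])  # one vector per input text
query_vector = model.embed_query('hello world')                       # single vector for a query
print(len(vectors), len(query_vector))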