fix: 修复xinference向量模型添加失败的缺陷

(cherry picked from commit fc015d17bc)
This commit is contained in:
wxg0103 2024-11-06 18:59:47 +08:00 committed by shaohuzhang1
parent cb83ed8688
commit cf66e33c58
2 changed files with 73 additions and 3 deletions

View File

@ -15,7 +15,8 @@ class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'),
'embedding')
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = provider.get_model_info_by_name(model_list, model_name)
@ -36,3 +37,4 @@ class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential):
return self
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -1,18 +1,26 @@
# coding=utf-8
import threading
from typing import Dict
from typing import Dict, Optional, List, Any
from langchain_community.embeddings import XinferenceEmbeddings
from langchain_core.embeddings import Embeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Create an XinferenceEmbedding from a model name and its credential.

    Args:
        model_type: Not used by this method.
        model_name: Passed to the constructor as ``model_uid``.
        model_credential: Mapping read for the ``api_base`` and ``api_key``
            entries.
        **model_kwargs: Accepted but not forwarded to the constructor.

    Returns:
        A new ``XinferenceEmbedding`` instance.
    """
    base_url = model_credential.get('api_base')
    secret = model_credential.get('api_key')
    return XinferenceEmbedding(model_uid=model_name, server_url=base_url, api_key=secret)
def down_model(self):
@ -22,3 +30,63 @@ class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
thread = threading.Thread(target=self.down_model)
thread.daemon = True
thread.start()
def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
        api_key: Optional[str] = None
):
    """Store the connection settings and create the Xinference REST client.

    Args:
        server_url: URL of the xinference server; must not be ``None``.
        model_uid: UID of the launched model; must not be ``None``.
        api_key: Optional key handed to ``RESTfulClient`` as its second
            positional argument.

    Raises:
        ImportError: If ``RESTfulClient`` can be imported from neither
            ``xinference.client`` nor ``xinference_client``.
        ValueError: If ``server_url`` or ``model_uid`` is ``None``.
    """
    try:
        from xinference.client import RESTfulClient
    except ImportError:
        # Fall back to the standalone client package before giving up.
        try:
            from xinference_client import RESTfulClient
        except ImportError as import_error:
            raise ImportError(
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference` or `pip install xinference_client`."
            ) from import_error

    # Checked in this order so a missing URL is reported before a missing UID.
    required_arguments = (
        (server_url, "Please provide server URL"),
        (model_uid, "Please provide the model UID"),
    )
    for value, message in required_arguments:
        if value is None:
            raise ValueError(message)

    self.server_url = server_url
    self.model_uid = model_uid
    self.api_key = api_key
    # Built last, after the plain attributes have been stored.
    self.client = RESTfulClient(server_url, api_key)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed each text in ``texts`` with the configured Xinference model.

    The model handle is fetched once from ``self.client``; after that
    ``create_embedding`` is called once per text, in input order.

    Args:
        texts: The texts to embed.

    Returns:
        One embedding per input text, each converted to a list of floats.
    """
    handle = self.client.get_model(self.model_uid)

    # Collect every raw embedding first; float conversion happens afterwards.
    raw_vectors = []
    for item in texts:
        response = handle.create_embedding(item)
        raw_vectors.append(response["data"][0]["embedding"])

    return [[float(component) for component in vector] for vector in raw_vectors]
def embed_query(self, text: str) -> List[float]:
    """Embed a single query text with the configured Xinference model.

    Args:
        text: The text to embed.

    Returns:
        The embedding for ``text`` as a list of floats.
    """
    handle = self.client.get_model(self.model_uid)
    # The first entry of the response's "data" list holds the embedding.
    vector = handle.create_embedding(text)["data"][0]["embedding"]
    return [float(component) for component in vector]