mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持腾讯混元向量模型
This commit is contained in:
parent
6618c6baf3
commit
dead7e8da3
|
|
@ -1,11 +1,5 @@
|
|||
import json
|
||||
from typing import Dict
|
||||
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
|
|
@ -13,48 +7,24 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||
|
||||
|
||||
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
@classmethod
|
||||
def _validate_model_type(cls, model_type: str, provider) -> bool:
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(mt.get('value') == model_type for mt in model_type_list):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
|
||||
for key in ['SecretId', 'SecretKey']:
|
||||
if key not in model_credential:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])
|
||||
|
||||
@classmethod
|
||||
def _test_credentials(cls, client, model_name: str):
|
||||
req = models.GetEmbeddingRequest()
|
||||
params = {
|
||||
"Model": model_name,
|
||||
"Input": "测试"
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
try:
|
||||
res = client.GetEmbedding(req)
|
||||
print(res.to_json_string())
|
||||
except Exception as e:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=True) -> bool:
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
self.valid_form(model_credential)
|
||||
try:
|
||||
self._validate_model_type(model_type, provider)
|
||||
cred = self._validate_credential(model_credential)
|
||||
httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
|
||||
clientProfile = ClientProfile(httpProfile=httpProfile)
|
||||
client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
|
||||
self._test_credentials(client, model_name)
|
||||
return True
|
||||
except AppApiException as e:
|
||||
if raise_exception:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query('你好')
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
return False
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
|
||||
encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
|
||||
|
|
|
|||
|
|
@ -1,25 +1,34 @@
|
|||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from typing import Dict
|
||||
import requests
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.hunyuan.v20230901.hunyuan_client import HunyuanClient
|
||||
from tencentcloud.hunyuan.v20230901.models import GetEmbeddingRequest
|
||||
|
||||
|
||||
class TencentEmbeddingModel(MaxKBBaseModel):
|
||||
def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
|
||||
class TencentEmbeddingModel(Embeddings):
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
request = GetEmbeddingRequest()
|
||||
request.Input = text
|
||||
res = self.client.GetEmbedding(request)
|
||||
return res.Data
|
||||
|
||||
def __init__(self, secret_id: str, secret_key: str, model_name: str):
|
||||
self.secret_id = secret_id
|
||||
self.secret_key = secret_key
|
||||
self.api_base = api_base
|
||||
self.model_name = model_name
|
||||
cred = credential.Credential(
|
||||
secret_id, secret_key
|
||||
)
|
||||
self.client = HunyuanClient(cred, "")
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
|
||||
return TencentEmbeddingModel(
|
||||
secret_id=model_credential.get('SecretId'),
|
||||
secret_key=model_credential.get('SecretKey'),
|
||||
api_base=model_credential.get('api_base'),
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
def _generate_auth_token(self):
|
||||
# Example method to generate an authentication token for the model API
|
||||
return f"{self.secret_id}:{self.secret_key}"
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def _initialize_model_info():
|
|||
|
||||
tencent_embedding_model_info = _create_model_info(
|
||||
'hunyuan-embedding',
|
||||
'',
|
||||
'腾讯混元 Embedding 接口,可以将文本转化为高质量的向量数据。向量维度为1024维。',
|
||||
ModelTypeConst.EMBEDDING,
|
||||
TencentEmbeddingCredential,
|
||||
TencentEmbeddingModel
|
||||
|
|
@ -80,6 +80,7 @@ def _initialize_model_info():
|
|||
|
||||
model_info_manage = ModelInfoManage.builder() \
|
||||
.append_model_info_list(model_info_list) \
|
||||
.append_model_info_list(model_info_embedding_list) \
|
||||
.append_default_model_info(model_info_list[0]) \
|
||||
.build()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue