From dead7e8da3bfd0403cb24b9eea9bcf85667d8c02 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Fri, 18 Oct 2024 13:57:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=85=BE=E8=AE=AF?= =?UTF-8?q?=E6=B7=B7=E5=85=83=E5=90=91=E9=87=8F=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../credential/embedding.py | 56 +++++-------------- .../tencent_model_provider/model/embedding.py | 33 +++++++---- .../tencent_model_provider.py | 3 +- 3 files changed, 36 insertions(+), 56 deletions(-) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py index 40e36caca..a0b00649d 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py @@ -1,11 +1,5 @@ -import json from typing import Dict -from tencentcloud.common import credential -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models - from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm @@ -13,48 +7,24 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class TencentEmbeddingCredential(BaseForm, BaseModelCredential): - @classmethod - def _validate_model_type(cls, model_type: str, provider) -> bool: - model_type_list = provider.get_model_type_list() - if not any(mt.get('value') == model_type for mt in model_type_list): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - return True - - @classmethod - def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential: - for key in ['SecretId', 'SecretKey']: - if key not in model_credential: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - return credential.Credential(model_credential['SecretId'], model_credential['SecretKey']) - - @classmethod - def _test_credentials(cls, client, model_name: str): - req = models.GetEmbeddingRequest() - params = { - "Model": model_name, - "Input": "测试" - } - req.from_json_string(json.dumps(params)) - try: - res = client.GetEmbedding(req) - print(res.to_json_string()) - except Exception as e: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, raise_exception=True) -> bool: + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + self.valid_form(model_credential) try: - self._validate_model_type(model_type, provider) - cred = self._validate_credential(model_credential) - httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com") - clientProfile = ClientProfile(httpProfile=httpProfile) - client = hunyuan_client.HunyuanClient(cred, "", clientProfile) - self._test_credentials(client, model_name) - return True - except AppApiException as e: - if raise_exception: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): raise e - return False + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: encrypted_secret_key = super().encryption(model.get('SecretKey', '')) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py index a5bd0336a..126e51b9d 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py @@ -1,25 +1,34 @@ -from setting.models_provider.base_model_provider import MaxKBBaseModel -from typing import Dict -import requests +from typing import Dict, List + +from langchain_core.embeddings import Embeddings +from tencentcloud.common import credential +from tencentcloud.hunyuan.v20230901.hunyuan_client import HunyuanClient +from tencentcloud.hunyuan.v20230901.models import GetEmbeddingRequest -class TencentEmbeddingModel(MaxKBBaseModel): - def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str): +class TencentEmbeddingModel(Embeddings): + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + request = GetEmbeddingRequest() + request.Input = text + res = self.client.GetEmbedding(request) + return res.Data + + def __init__(self, secret_id: str, secret_key: str, model_name: str): self.secret_id = secret_id self.secret_key = secret_key - self.api_base = api_base self.model_name = model_name + cred = credential.Credential( + secret_id, secret_key + ) + self.client = HunyuanClient(cred, "") @staticmethod def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs): return TencentEmbeddingModel( secret_id=model_credential.get('SecretId'), secret_key=model_credential.get('SecretKey'), - api_base=model_credential.get('api_base'), model_name=model_name, ) - - - def _generate_auth_token(self): - # Example method to generate an authentication token for the model API - return f"{self.secret_id}:{self.secret_key}" diff --git a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py index 572ff940d..47841a032 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py @@ -70,7 +70,7 @@ def _initialize_model_info(): tencent_embedding_model_info = _create_model_info( 'hunyuan-embedding', - '', + '腾讯混元 Embedding 接口,可以将文本转化为高质量的向量数据。向量维度为1024维。', ModelTypeConst.EMBEDDING, TencentEmbeddingCredential, TencentEmbeddingModel @@ -80,6 +80,7 @@ def _initialize_model_info(): model_info_manage = ModelInfoManage.builder() \ .append_model_info_list(model_info_list) \ + .append_model_info_list(model_info_embedding_list) \ .append_default_model_info(model_info_list[0]) \ .build()