From f85ce4a74511a023b17cdd5e8f6db00fe0949f4d Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 17 Oct 2024 16:36:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=AE=AF=E9=A3=9E?= =?UTF-8?q?=E5=90=91=E9=87=8F=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../xf_model_provider/credential/embedding.py | 43 ++++++++++++++++ .../impl/xf_model_provider/model/embedding.py | 49 +++++++++++++++++++ .../xf_model_provider/xf_model_provider.py | 4 ++ 3 files changed, 96 insertions(+) create mode 100644 apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py create mode 100644 apps/setting/models_provider/impl/xf_model_provider/model/embedding.py diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py new file mode 100644 index 000000000..b4f429f43 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/17 15:40 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XFEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + self.valid_form(model_credential) + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + base_url = forms.TextInputField('API 域名', required=True, default_value="https://emb-cn-huabei-1.xf-yun.com/") + spark_app_id = forms.TextInputField('APP ID', required=True) + spark_api_key = forms.PasswordInputField("API Key", required=True) + spark_api_secret = forms.PasswordInputField('API Secret', required=True) diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/embedding.py b/apps/setting/models_provider/impl/xf_model_provider/model/embedding.py new file mode 100644 index 000000000..78cc04ceb --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/model/embedding.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/17 15:29 + @desc: +""" + +import base64 +import json +from typing import Dict, Optional + +import numpy as np +from langchain_community.embeddings import SparkLLMTextEmbeddings +from numpy import ndarray + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return XFEmbedding( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret') + ) + + @staticmethod + def _parser_message( + message: str, + ) -> Optional[ndarray]: + data = json.loads(message) + code = data["header"]["code"] + if code != 0: + # 这里是讯飞的QPS限制会报错,所以不建议用讯飞的向量模型 + raise Exception(f"Request error: {code}, {data}") + else: + text_base = data["payload"]["feature"]["text"] + text_data = base64.b64decode(text_base) + dt = np.dtype(np.float32) + dt = dt.newbyteorder("<") + text = np.frombuffer(text_data, dtype=dt) + if len(text) > 2560: + array = text[:2560] + else: + array = text + return array diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 8c60083cf..04fd2d439 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -12,9 +12,11 @@ import ssl from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage +from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential +from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech @@ -25,12 +27,14 @@ ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() stt_model_credential = XunFeiSTTModelCredential() tts_model_credential = XunFeiTTSModelCredential() +embedding_model_credential = XFEmbeddingCredential() model_info_list = [ ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), + ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding) ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(