mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
Merge pull request #924 from 1Panel-dev/pr@main@feat_model
feat: 增加对接模型
This commit is contained in:
commit
6fd33cdeed
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
from enum import Enum
|
||||
|
||||
from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
|
||||
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
|
||||
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
|
||||
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
|
||||
|
|
@ -15,6 +16,9 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
|
|||
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
||||
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
||||
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
|
||||
from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
|
||||
from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
|
||||
VolcanicEngineModelProvider
|
||||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||
|
|
@ -32,4 +36,7 @@ class ModelProvideConstants(Enum):
|
|||
model_xf_provider = XunFeiModelProvider()
|
||||
model_deepseek_provider = DeepSeekModelProvider()
|
||||
model_gemini_provider = GeminiModelProvider()
|
||||
model_volcanic_engine_provider = VolcanicEngineModelProvider()
|
||||
model_tencent_provider = TencentModelProvider()
|
||||
model_aws_bedrock_provider = BedrockModelProvider()
|
||||
model_local_provider = LocalModelProvider()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
import os
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import (
|
||||
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||
)
|
||||
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
|
||||
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Wrap one Bedrock model entry in a ModelInfo.

    The credential class is instantiated here so every entry carries its own
    fresh credential form instance.
    """
    credential_instance = credential_class()
    return ModelInfo(
        name=model_name,
        desc=description,
        model_type=model_type,
        model_credential=credential_instance,
        model_class=model_class,
    )
|
||||
|
||||
|
||||
def _get_aws_bedrock_icon_path():
    """Return the absolute path of the Bedrock provider icon asset.

    NOTE(review): the asset is named 'bedrock_icon_svg' with no file
    extension — confirm this matches the file shipped in the icon directory.
    """
    segments = (
        "apps", "setting", 'models_provider', 'impl',
        'aws_bedrock_model_provider', 'icon', 'bedrock_icon_svg',
    )
    return os.path.join(PROJECT_DIR, *segments)
|
||||
|
||||
|
||||
def _initialize_model_info():
    """Build the registry of AWS Bedrock models exposed by this provider.

    Returns a ModelInfoManage containing every supported model; the first
    entry (amazon.titan-text-premier-v1:0) is registered as the default.
    """
    model_info_list = [
        _create_model_info(
            'amazon.titan-text-premier-v1:0',
            'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-text-lite-v1',
            'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-text-express-v1',
            'Amazon Titan Text Express 的上下文长度长达 8000 个令牌,因而非常适合各种高级常规语言任务,例如开放式文本生成和对话式聊天,以及检索增强生成(RAG)中的支持。在发布时,该模型针对英语进行了优化,但也支持其他语言。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        # NOTE(review): titan-embed-text-v2 is an embeddings model, yet it is
        # wired to the LLM credential/model classes and registered as LLM —
        # confirm this is intentional before changing the wiring.
        _create_model_info(
            'amazon.titan-embed-text-v2:0',
            'Amazon Titan Text Embeddings V2 是一种轻量级、高效的模型,非常适合在不同维度上执行高精度检索任务。该模型支持灵活的嵌入大小(1024、512 和 256),并优先考虑在较小维度上保持准确性,从而可以在不影响准确性的情况下降低存储成本。Titan Text Embeddings V2 适用于各种任务,包括文档检索、推荐系统、搜索引擎和对话式系统。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        # Fixed: mistral-7b-instruct is an instruct/chat model (all of its
        # wiring is LLM), but it was mistakenly registered as EMBEDDING.
        _create_model_info(
            'mistral.mistral-7b-instruct-v0:2',
            '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'mistral.mistral-large-2402-v1:0',
            '先进的 Mistral AI 大型语言模型,能够处理任何语言任务,包括复杂的多语言推理、文本理解、转换和代码生成。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-70b-instruct-v1:0',
            '非常适合内容创作、会话式人工智能、语言理解、研发和企业应用',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-8b-instruct-v1:0',
            '非常适合有限的计算能力和资源、边缘设备和更快的训练时间。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
    ]

    # The first list entry doubles as the provider's default model.
    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()

    return model_info_manage
|
||||
|
||||
|
||||
class BedrockModelProvider(IModelProvider):
    """Amazon Bedrock model provider: exposes the Bedrock model registry
    and the provider's display metadata (name + icon)."""

    def __init__(self):
        # Build the model registry once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        """Return the cached ModelInfoManage built at construction time."""
        return self._model_info_manage

    def get_model_provide_info(self):
        """Return display metadata for this provider (id, name, icon bytes)."""
        svg_content = get_file_content(_get_aws_bedrock_icon_path())
        return ModelProvideInfo(
            provider='model_aws_bedrock_provider',
            name='Amazon Bedrock',
            icon=svg_content,
        )
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
from typing import Dict
|
||||
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the Tencent Hunyuan embedding API.

    Validation opens a real Hunyuan client and performs one GetEmbedding
    call, so a successful `is_valid` proves the SecretId/SecretKey pair works.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        """Raise AppApiException if `provider` does not advertise `model_type`."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        """Check the required keys and build a tencentcloud Credential."""
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        """Issue a one-shot GetEmbedding request to prove the credentials work.

        Raises AppApiException wrapping any SDK/network error.
        """
        req = models.GetEmbeddingRequest()
        params = {
            "Model": model_name,
            "Input": "测试"
        }
        req.from_json_string(json.dumps(params))
        try:
            # Fixed: the response was previously printed to stdout (leftover
            # debug output); only the success/failure of the call matters here.
            client.GetEmbedding(req)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Validate type + credentials by calling the live Hunyuan endpoint.

        Returns True on success; on failure either re-raises the
        AppApiException (raise_exception=True) or returns False.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException as e:
            if raise_exception:
                raise e
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return a copy of `model` with the SecretKey value encrypted."""
        encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
        return {**model, 'SecretKey': encrypted_secret_key}

    # Form fields rendered to the user; both are masked password inputs.
    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
|
||||
from langchain_core.messages import HumanMessage
|
||||
from common import forms
|
||||
|
||||
|
||||
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for AWS Bedrock LLMs.

    Validation writes the supplied keys into ~/.aws/credentials under a
    dedicated profile, then invokes the model once to prove they work.
    """

    @staticmethod
    def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
        """Create or refresh the [profile_name] section of ~/.aws/credentials.

        Any existing section with the same name is stripped first, then the
        section is re-appended with the new key pair.
        """
        credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

        # Fixed: read via a context manager — the previous open(...).read()
        # leaked the file handle.
        if os.path.exists(credentials_path):
            with open(credentials_path, 'r') as file:
                content = file.read()
        else:
            content = ''

        # Fixed: escape the profile name so regex metacharacters in it cannot
        # corrupt the pattern. NOTE(review): with re.DOTALL the '.*' groups are
        # greedy across lines; behavior kept as-is — confirm intended.
        escaped_profile = re.escape(profile_name)
        pattern = rf'\n*\[{escaped_profile}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
        content = re.sub(pattern, '', content, flags=re.DOTALL)

        if not re.search(rf'\[{escaped_profile}\]', content):
            content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"

        with open(credentials_path, 'w') as file:
            file.write(content)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate model type, required fields, and live model invocation.

        Returns True on success. On failure returns False, or raises
        AppApiException when raise_exception is True (AppApiException raised
        by deeper layers is always re-raised).
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
            return False

        try:
            self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
                                         model_credential['secret_access_key'])
            # NOTE(review): this mutates the caller's dict in place — downstream
            # code may rely on the persisted profile name; confirm before changing.
            model_credential['credentials_profile_name'] = 'aws-profile'
            model = provider.get_model(model_type, model_name, model_credential)
            # One minimal invocation proves the credentials and region work.
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of `model` with the secret access key encrypted."""
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    # Form fields rendered to the user.
    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" version="2.0" focusable="false" aria-hidden="true" class="globalNav-1216 globalNav-1213" data-testid="awsc-logo" viewBox="0 0 29 17"><path class="globalNav-1214" d="M8.38 6.17a2.6 2.6 0 00.11.83c.08.232.18.456.3.67a.4.4 0 01.07.21.36.36 0 01-.18.28l-.59.39a.43.43 0 01-.24.08.38.38 0 01-.28-.13 2.38 2.38 0 01-.34-.43c-.09-.16-.18-.34-.28-.55a3.44 3.44 0 01-2.74 1.29 2.54 2.54 0 01-1.86-.67 2.36 2.36 0 01-.68-1.79 2.43 2.43 0 01.84-1.92 3.43 3.43 0 012.29-.72 6.75 6.75 0 011 .07c.35.05.7.12 1.07.2V3.3a2.06 2.06 0 00-.44-1.49 2.12 2.12 0 00-1.52-.43 4.4 4.4 0 00-1 .12 6.85 6.85 0 00-1 .32l-.33.12h-.14c-.14 0-.2-.1-.2-.29v-.46A.62.62 0 012.3.87a.78.78 0 01.27-.2A6 6 0 013.74.25 5.7 5.7 0 015.19.07a3.37 3.37 0 012.44.76 3 3 0 01.77 2.29l-.02 3.05zM4.6 7.59a3 3 0 001-.17 2 2 0 00.88-.6 1.36 1.36 0 00.32-.59 3.18 3.18 0 00.09-.81V5A7.52 7.52 0 006 4.87h-.88a2.13 2.13 0 00-1.38.37 1.3 1.3 0 00-.46 1.08 1.3 1.3 0 00.34 1c.278.216.63.313.98.27zm7.49 1a.56.56 0 01-.36-.09.73.73 0 01-.2-.37L9.35.93a1.39 1.39 0 01-.08-.38c0-.15.07-.23.22-.23h.92a.56.56 0 01.36.09.74.74 0 01.19.37L12.53 7 14 .79a.61.61 0 01.18-.37.59.59 0 01.37-.09h.75a.62.62 0 01.38.09.74.74 0 01.18.37L17.31 7 18.92.76a.74.74 0 01.19-.37.56.56 0 01.36-.09h.87a.21.21 0 01.23.23 1 1 0 010 .15s0 .13-.06.23l-2.26 7.2a.74.74 0 01-.19.37.6.6 0 01-.36.09h-.8a.53.53 0 01-.37-.1.64.64 0 01-.18-.37l-1.45-6-1.44 6a.64.64 0 01-.18.37.55.55 0 01-.37.1l-.82.02zm12 .24a6.29 6.29 0 01-1.44-.16 4.21 4.21 0 01-1.07-.37.69.69 0 01-.29-.26.66.66 0 01-.06-.27V7.3c0-.19.07-.29.21-.29a.57.57 0 01.18 0l.23.1c.32.143.656.25 1 .32.365.08.737.12 1.11.12a2.47 2.47 0 001.36-.31 1 1 0 00.48-.88.88.88 0 00-.25-.65 2.29 2.29 0 00-.94-.49l-1.35-.43a2.83 2.83 0 01-1.49-.94 2.24 2.24 0 01-.47-1.36 2 2 0 01.25-1c.167-.3.395-.563.67-.77a3 3 0 011-.48A4.1 4.1 0 0124.4.08a4.4 4.4 0 01.62 0l.61.1.53.15.39.16c.105.062.2.14.28.23a.57.57 0 01.08.31v.44c0 .2-.07.3-.21.3a.92.92 0 01-.36-.12 4.35 4.35 0 00-1.8-.36 
2.51 2.51 0 00-1.24.26.92.92 0 00-.44.84c0 .249.1.488.28.66.295.236.635.41 1 .51l1.32.42a2.88 2.88 0 011.44.9 2.1 2.1 0 01.43 1.31 2.38 2.38 0 01-.24 1.08 2.34 2.34 0 01-.68.82 3 3 0 01-1 .53 4.59 4.59 0 01-1.35.22l.03-.01z"></path><path class="globalNav-1215" d="M25.82 13.43a20.07 20.07 0 01-11.35 3.47A20.54 20.54 0 01.61 11.62c-.29-.26 0-.62.32-.42a27.81 27.81 0 0013.86 3.68 27.54 27.54 0 0010.58-2.16c.52-.22.96.34.45.71z"></path><path class="globalNav-1215" d="M27.1 12c-.4-.51-2.6-.24-3.59-.12-.3 0-.34-.23-.07-.42 1.75-1.23 4.63-.88 5-.46.37.42-.09 3.3-1.74 4.68-.25.21-.49.09-.38-.18.34-.95 1.17-3.02.78-3.5z"></path></svg>
|
||||
|
After Width: | Height: | Size: 2.6 KiB |
|
|
@ -0,0 +1,25 @@
|
|||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from typing import Dict
|
||||
import requests
|
||||
|
||||
|
||||
class TencentEmbeddingModel(MaxKBBaseModel):
    """Holds connection settings for Tencent's embedding endpoint."""

    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_base = api_base
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        """Factory used by the provider layer; extra model_kwargs are ignored."""
        init_kwargs = {
            'secret_id': model_credential.get('SecretId'),
            'secret_key': model_credential.get('SecretKey'),
            'api_base': model_credential.get('api_base'),
            'model_name': model_name,
        }
        return TencentEmbeddingModel(**init_kwargs)

    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API
        return f"{self.secret_id}:{self.secret_key}"
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
from typing import List, Dict, Any
|
||||
from langchain_community.chat_models import BedrockChat
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class BedrockModel(MaxKBBaseModel, BedrockChat):
    """LangChain BedrockChat wrapper adding MaxKB's token-counting hooks."""

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, **kwargs):
        # Pure pass-through to BedrockChat; MaxKBBaseModel adds no state here.
        super().__init__(
            model_id=model_id,
            region_name=region_name,
            credentials_profile_name=credentials_profile_name,
            streaming=streaming,
            **kwargs,
        )

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        """Factory used by the provider layer; pops 'streaming' from kwargs."""
        streaming = model_kwargs.pop('streaming', False)
        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['credentials_profile_name'],
            streaming=streaming,
            **model_kwargs,
        )

    def _get_num_tokens(self, content: str) -> int:
        """Helper method to count tokens in a string."""
        encoder = TokenizerManage.get_tokenizer()
        return len(encoder.encode(content))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum token counts over the rendered text of each message."""
        total = 0
        for message in messages:
            total += self._get_num_tokens(get_buffer_string([message]))
        return total

    def get_num_tokens(self, text: str) -> int:
        """Token count for a plain string."""
        return self._get_num_tokens(text)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
from typing import Dict
|
||||
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the Tencent Hunyuan embedding API.

    Validation opens a real Hunyuan client and performs one GetEmbedding
    call, so a successful is_valid proves the SecretId/SecretKey pair works.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        # Reject any model type the provider does not advertise.
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        # Both keys are mandatory; missing ones raise a field-required error.
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        # One-shot GetEmbedding request to prove the credentials actually work.
        req = models.GetEmbeddingRequest()
        params = {
            "Model": model_name,
            "Input": "测试"
        }
        req.from_json_string(json.dumps(params))
        try:
            res = client.GetEmbedding(req)
            # NOTE(review): debug output left in — prints the raw API response
            # to stdout on every successful validation.
            print(res.to_json_string())
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Validate type + credentials against the live Hunyuan endpoint.

        Returns True on success; on AppApiException either re-raises
        (raise_exception=True) or returns False. Non-AppApiException errors
        from client construction propagate regardless.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            clientProfile = ClientProfile(httpProfile=httpProfile)
            client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException as e:
            if raise_exception:
                raise e
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        # Return a copy of `model` with the SecretKey value encrypted.
        encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
        return {**model, 'SecretKey': encrypted_secret_key}

    # Form fields rendered to the user; both are masked password inputs.
    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
from langchain_core.messages import HumanMessage
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Tencent Hunyuan LLMs."""

    # Keys that must be present in the credential dict.
    REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']

    @classmethod
    def _validate_model_type(cls, model_type, provider, raise_exception=False):
        """True when the provider supports model_type; else raise or return False."""
        supported = (mt['value'] for mt in provider.get_model_type_list())
        if model_type in supported:
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return False

    @classmethod
    def _validate_credential_fields(cls, model_credential, raise_exception=False):
        """True when all REQUIRED_FIELDS are present; else raise or return False."""
        missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
        if not missing_keys:
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
        return False

    def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
        """Validate type, fields, and one live model invocation."""
        if not self._validate_model_type(model_type, provider, raise_exception):
            return False
        if not self._validate_credential_fields(model_credential, raise_exception):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model):
        """Return a copy of `model` with the secret key encrypted."""
        encrypted = super().encryption(model.get('hunyuan_secret_key', ''))
        return {**model, 'hunyuan_secret_key': encrypted}

    # Form fields rendered to the user.
    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 74 KiB |
|
|
@ -0,0 +1,25 @@
|
|||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from typing import Dict
|
||||
import requests
|
||||
|
||||
|
||||
class TencentEmbeddingModel(MaxKBBaseModel):
    """Container for Tencent embedding endpoint settings.

    NOTE(review): no embedding method is defined here — presumably invocation
    lives elsewhere or is still to be implemented; confirm against callers.
    """

    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_base = api_base
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        # Factory used by the provider layer; extra model_kwargs are ignored.
        return TencentEmbeddingModel(
            secret_id=model_credential.get('SecretId'),
            secret_key=model_credential.get('SecretKey'),
            api_base=model_credential.get('api_base'),
            model_name=model_name,
        )


    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API
        return f"{self.secret_id}:{self.secret_key}"
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into Hunyuan's {"Role", "Content"} dict.

    Raises TypeError for message classes with no Hunyuan role mapping.
    """
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    else:
        raise TypeError(f"Got unknown type {message}")

    message_dict: Dict[str, Any] = {"Role": role, "Content": message.content}
    return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Translate a Hunyuan {"Role", "Content"} dict back into a LangChain message."""
    role = _dict["Role"]
    if role == "user":
        return HumanMessage(content=_dict["Content"])
    if role == "assistant":
        # Assistant content may be absent or None; normalize to "".
        return AIMessage(content=_dict.get("Content", "") or "")
    # Any other role round-trips as a generic ChatMessage.
    return ChatMessage(content=_dict["Content"], role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
        _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Map a streaming delta dict to a message-chunk class.

    The delta's Role wins; when absent, default_class decides. Check order
    (user, assistant, other-role, fallback) matches the non-streaming mapping.
    """
    role = _dict.get("Role")
    content = _dict.get("Content") or ""

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
    """Build a ChatResult from a parsed Hunyuan ChatCompletions response."""
    generations = [
        ChatGeneration(message=_convert_dict_to_message(choice["Message"]))
        for choice in response["Choices"]
    ]
    # Surface the API's usage block so callers can do token accounting.
    llm_output = {"token_usage": response["Usage"]}
    return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
class ChatHunyuan(BaseChatModel):
|
||||
"""Tencent Hunyuan chat models API by Tencent.
|
||||
|
||||
For more information, see https://cloud.tencent.com/document/product/1729
|
||||
"""
|
||||
|
||||
    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map credential attribute names to the env-var names LangChain
        uses for secret redaction/serialization."""
        return {
            "hunyuan_app_id": "HUNYUAN_APP_ID",
            "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
            "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
        }
|
||||
|
||||
    @property
    def lc_serializable(self) -> bool:
        """Mark this chat model as serializable by LangChain."""
        return True
|
||||
|
||||
hunyuan_app_id: Optional[int] = None
|
||||
"""Hunyuan App ID"""
|
||||
hunyuan_secret_id: Optional[str] = None
|
||||
"""Hunyuan Secret ID"""
|
||||
hunyuan_secret_key: Optional[SecretStr] = None
|
||||
"""Hunyuan Secret Key"""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
request_timeout: int = 60
|
||||
"""Timeout for requests to Hunyuan API. Default is 60 seconds."""
|
||||
temperature: float = 1.0
|
||||
"""What sampling temperature to use."""
|
||||
top_p: float = 1.0
|
||||
"""What probability mass to use."""
|
||||
model: str = "hunyuan-lite"
|
||||
"""What Model to use.
|
||||
Optional model:
|
||||
- hunyuan-lite、
|
||||
- hunyuan-standard
|
||||
- hunyuan-standard-256K
|
||||
- hunyuan-pro
|
||||
- hunyuan-code
|
||||
- hunyuan-role
|
||||
- hunyuan-functioncall
|
||||
- hunyuan-vision
|
||||
"""
|
||||
stream_moderation: bool = False
|
||||
"""Whether to review the results or not when streaming is true."""
|
||||
enable_enhancement: bool = True
|
||||
"""Whether to enhancement the results or not."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Unknown constructor kwargs are moved into `model_kwargs` (with a
        warning); declared fields supplied inside `model_kwargs` are rejected.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            # Same key given both directly and via model_kwargs is ambiguous.
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        # Declared fields must be passed explicitly, never via model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values
|
||||
|
||||
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials from kwargs or environment variables.

        Each credential falls back to its HUNYUAN_* env var; the secret key
        is wrapped in a SecretStr so it is not exposed in reprs/logs.
        """
        values["hunyuan_app_id"] = get_from_dict_or_env(
            values,
            "hunyuan_app_id",
            "HUNYUAN_APP_ID",
        )
        values["hunyuan_secret_id"] = get_from_dict_or_env(
            values,
            "hunyuan_secret_id",
            "HUNYUAN_SECRET_ID",
        )
        values["hunyuan_secret_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "hunyuan_secret_key",
                "HUNYUAN_SECRET_KEY",
            )
        )
        return values
|
||||
|
||||
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Hunyuan API.

        Keys use the API's PascalCase names; entries in model_kwargs
        override these defaults.
        """
        normal_params = {
            "Temperature": self.temperature,
            "TopP": self.top_p,
            "Model": self.model,
            "Stream": self.streaming,
            "StreamModeration": self.stream_moderation,
            "EnableEnhancement": self.enable_enhancement,
        }
        return {**normal_params, **self.model_kwargs}
|
||||
|
||||
    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Run one chat completion; delegates to _stream when streaming is on.

        NOTE(review): `stop` is accepted but not forwarded to the API in the
        non-streaming path — confirm whether Hunyuan supports stop sequences.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            # Collapse the streamed chunks into a single ChatResult.
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        return _create_chat_result(json.loads(res.to_json_string()))
|
||||
|
||||
def _stream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Yield generation chunks from a streaming Hunyuan response.

    Empty SSE events are skipped; a payload containing an ``error`` key
    raises ``ValueError``.
    """
    stream = self._chat(messages, **kwargs)
    chunk_class = AIMessageChunk
    for event in stream:
        payload = event.get("data", "")
        if not payload:
            continue
        parsed = json.loads(payload)
        if "error" in parsed:
            raise ValueError(f"Error from Hunyuan api response: {parsed}")
        for choice in parsed["Choices"]:
            message_chunk = _convert_delta_to_message_chunk(
                choice["Delta"], chunk_class
            )
            # Track the concrete chunk class so later deltas keep the same type.
            chunk_class = message_chunk.__class__
            generation = ChatGenerationChunk(message=message_chunk)
            if run_manager:
                run_manager.on_llm_new_token(message_chunk.content, chunk=generation)
            yield generation
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
    """Send *messages* to the Hunyuan ChatCompletions endpoint.

    Raises ``ValueError`` when no secret key is configured and
    ``ImportError`` when the Tencent Cloud SDK is not installed.
    """
    if self.hunyuan_secret_key is None:
        raise ValueError("Hunyuan secret key is not set.")

    try:
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
    except ImportError:
        raise ImportError(
            "Could not import tencentcloud python package. "
            "Please install it with `pip install tencentcloud-sdk-python`."
        )

    # Build the request payload: Messages first, then defaults, then
    # per-call overrides (later keys win on collision).
    payload = {"Messages": [_convert_message_to_dict(m) for m in messages]}
    payload.update(self._default_params)
    payload.update(kwargs)

    cred = credential.Credential(
        self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
    )
    client = hunyuan_client.HunyuanClient(cred, "")
    request = models.ChatCompletionsRequest()
    request.from_json_string(json.dumps(payload))
    return client.ChatCompletions(request)
|
||||
|
||||
@property
def _llm_type(self) -> str:
    """Identifier for this chat model type, used by LangChain internals."""
    return "hunyuan-chat"
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
|
||||
|
||||
|
||||
class TencentModel(MaxKBBaseModel, ChatHunyuan):
    """MaxKB wrapper around the Tencent Hunyuan chat model."""

    def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
        """Build a Hunyuan client from a credential dict.

        Raises ``ValueError`` when any of the three credential fields is
        missing or empty.
        """
        resolved = {
            name: credentials.get(name)
            for name in ('hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key')
        }
        if not all(resolved.values()):
            raise ValueError(
                "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")
        super().__init__(model=model_name, streaming=streaming, **resolved, **kwargs)

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentModel':
        """Factory used by the provider registry; honours a 'streaming' kwarg."""
        streaming = model_kwargs.pop('streaming', False)
        return TencentModel(model_name=model_name, credentials=model_credential,
                            streaming=streaming, **model_kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Approximate the token count of *messages* with the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count of *text* with the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
import os
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import (
|
||||
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||
)
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Build a ``ModelInfo``, instantiating the given credential class."""
    credential = credential_class()
    return ModelInfo(
        name=model_name,
        desc=description,
        model_type=model_type,
        model_credential=credential,
        model_class=model_class,
    )
|
||||
|
||||
|
||||
def _get_tencent_icon_path():
    """Absolute path of the bundled Tencent provider icon file."""
    relative_parts = ("apps", "setting", 'models_provider', 'impl',
                      'tencent_model_provider', 'icon', 'tencent_icon_svg')
    return os.path.join(PROJECT_DIR, *relative_parts)
|
||||
|
||||
|
||||
def _initialize_model_info():
    """Assemble the ``ModelInfoManage`` registry for the Tencent provider.

    Registers the Hunyuan LLM variants and the Hunyuan embedding model, and
    marks a default model for each model type.
    """
    model_info_list = [
        _create_model_info(
            'hunyuan-pro',
            '当前混元模型中效果最优版本,万亿级参数规模 MOE-32K 长文模型。在各种 benchmark 上达到绝对领先的水平,复杂指令和推理,具备复杂数学能力,支持 functioncall,在多语言翻译、金融法律医疗等领域应用重点优化',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-standard',
            '采用更优的路由策略,同时缓解了负载均衡和专家趋同的问题。长文方面,大海捞针指标达到99.9%',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-lite',
            '升级为 MOE 结构,上下文窗口为 256k ,在 NLP,代码,数学,行业等多项评测集上领先众多开源模型',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-role',
            '混元最新版角色扮演模型,混元官方精调训练推出的角色扮演模型,基于混元模型结合角色扮演场景数据集进行增训,在角色扮演场景具有更好的基础效果',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        # Fix: the original name carried a trailing space
        # ('hunyuan-functioncall ') and would never match the model name
        # selected by callers.
        _create_model_info(
            'hunyuan-functioncall',
            '混元最新 MOE 架构 FunctionCall 模型,经过高质量的 FunctionCall 数据训练,上下文窗口达 32K,在多个维度的评测指标上处于领先。',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-code',
            '混元最新代码生成模型,经过 200B 高质量代码数据增训基座模型,迭代半年高质量 SFT 数据训练,上下文长窗口长度增大到 8K,五大语言代码生成自动评测指标上位居前列;五大语言10项考量各方面综合代码任务人工高质量评测上,性能处于第一梯队',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
    ]

    tencent_embedding_model_info = _create_model_info(
        'hunyuan-embedding',
        '',
        ModelTypeConst.EMBEDDING,
        TencentEmbeddingCredential,
        TencentEmbeddingModel)

    # Fix: the embedding model was built but never registered with the
    # builder, making it invisible to the provider. Register it and mark it
    # as the default for the EMBEDDING model type.
    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .append_model_info_list([tencent_embedding_model_info]) \
        .append_default_model_info(tencent_embedding_model_info) \
        .build()

    return model_info_manage
|
||||
|
||||
|
||||
class TencentModelProvider(IModelProvider):
    """Provider entry for Tencent Hunyuan models."""

    def __init__(self):
        # Build the model registry once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        """Return the registry of models supported by this provider."""
        return self._model_info_manage

    def get_model_provide_info(self):
        """Describe this provider (key, display name, icon) for the UI."""
        icon = get_file_content(_get_tencent_icon_path())
        return ModelProvideInfo(provider='model_tencent_provider',
                                name='腾讯混元',
                                icon=icon)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for an OpenAI-compatible embedding endpoint.

    NOTE(review): the class name says "OpenAI" although this diff adds it
    under a provider-specific credential module — confirm the importing
    module expects this exact name.
    """

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """Validate the credential by issuing a probe ``embed_query`` call."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        missing = [key for key in ('api_base', 'api_key') if key not in model_credential]
        if missing:
            if raise_exception:
                # Report the first missing field, matching the original order.
                raise AppApiException(ValidCode.valid_error.value, f'{missing[0]} 字段为必填字段')
            return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query('你好')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the API key masked for persistence."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 17:57
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Volcanic Engine chat models."""

    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by sending a probe chat message.

        Returns True on success; returns False (or raises, when
        *raise_exception* is set) on missing fields or a failed probe call.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ('access_key_id', 'secret_access_key'):
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe call only; the response content is not inspected.
            # (Removed a leftover debug `print(res)` from the original.)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Mask the access key id for persistence.

        NOTE(review): only ``access_key_id`` is masked; ``secret_access_key``
        is stored in the clear — confirm whether it should be encrypted too.
        """
        return {**model, 'access_key_id': super().encryption(model.get('access_key_id', ''))}
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 60 KiB |
|
|
@ -0,0 +1,15 @@
|
|||
from typing import Dict
|
||||
|
||||
from langchain_community.embeddings import VolcanoEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
    """MaxKB wrapper for a Volcanic Engine embedding model.

    NOTE(review): the constructor forwards ``api_key``/``openai_api_base``
    style keyword arguments although the base class is ``VolcanoEmbeddings``
    — confirm those field names are accepted by the base class.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry (``model_kwargs`` is unused)."""
        return VolcanicEngineEmbeddingModel(
            model=model_name,
            api_key=model_credential.get('api_key'),
            openai_api_base=model_credential.get('api_base'),
        )
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
from typing import List, Dict
|
||||
|
||||
from langchain_community.chat_models import VolcEngineMaasChat
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class VolcanicEngineChatModel(MaxKBBaseModel, ChatOpenAI):
    """MaxKB wrapper for a Volcanic Engine chat model.

    NOTE(review): the base class is ``ChatOpenAI`` while the constructor
    forwards ``volc_engine_maas_ak``/``volc_engine_maas_sk`` keyword
    arguments — confirm these are accepted by the base class.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry (``model_kwargs`` is unused)."""
        return VolcanicEngineChatModel(
            model=model_name,
            volc_engine_maas_ak=model_credential.get("access_key_id"),
            volc_engine_maas_sk=model_credential.get("secret_access_key"),
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Approximate the token count of *messages* with the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count of *text* with the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :MaxKB
|
||||
@File    :volcanic_engine_model_provider.py
|
||||
@Author :Brian Yang
|
||||
@Date :5/13/24 7:47 AM
|
||||
"""
|
||||
import os
|
||||
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
# Credential forms are reused from the OpenAI-compatible implementations
# because Volcanic Ark endpoints speak the OpenAI wire protocol.
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
open_ai_embedding_credential = OpenAIEmbeddingCredential()

# Volcanic Ark exposes models through user-created inference endpoints named
# "ep-xxxxxxxxxx-yyyy"; the placeholder entries below document that.
model_info_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.LLM,
              volcanic_engine_llm_model_credential, OpenAIChatModel
              )
]

model_info_embedding_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
              OpenAIEmbeddingModel)]

# Fix: model_info_embedding_list was built but never registered with the
# builder, leaving the embedding model invisible to the provider. Register it
# and mark it as the default for the EMBEDDING model type.
model_info_manage = ModelInfoManage.builder() \
    .append_model_info_list(model_info_list) \
    .append_default_model_info(model_info_list[0]) \
    .append_model_info_list(model_info_embedding_list) \
    .append_default_model_info(model_info_embedding_list[0]) \
    .build()
|
||||
|
||||
|
||||
class VolcanicEngineModelProvider(IModelProvider):
    """Provider entry for Volcanic Engine (火山引擎) models."""

    def get_model_info_manage(self):
        """Return the module-level registry of supported models."""
        return model_info_manage

    def get_model_provide_info(self):
        """Describe this provider (key, display name, icon) for the UI."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'volcanic_engine_model_provider', 'icon',
                                 'volcanic_engine_icon_svg')
        return ModelProvideInfo(provider='model_volcanic_engine_provider',
                                name='火山引擎',
                                icon=get_file_content(icon_path))
|
||||
|
|
@ -46,6 +46,8 @@ gunicorn = "^22.0.0"
|
|||
python-daemon = "3.0.1"
|
||||
gevent = "^24.2.1"
|
||||
|
||||
boto3 = "^1.34.151"
|
||||
langchain-aws = "^0.1.13"
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
|||
Loading…
Reference in New Issue