feat: 增加对接模型

This commit is contained in:
wxg0103 2024-07-24 14:14:05 +08:00
parent 35f0c18dd3
commit 72423a7c3e
24 changed files with 1057 additions and 0 deletions

View File

@ -8,6 +8,7 @@
"""
from enum import Enum
from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
@ -15,6 +16,9 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
@ -32,4 +36,7 @@ class ModelProvideConstants(Enum):
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
model_volcanic_engine_provider = VolcanicEngineModelProvider()
model_tencent_provider = TencentModelProvider()
model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import (
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from smartdoc.conf import PROJECT_DIR
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Assemble a ModelInfo record for one concrete model offering.

    The credential class is instantiated here so every registry entry
    carries its own credential form instance.
    """
    info_kwargs = {
        'name': model_name,
        'desc': description,
        'model_type': model_type,
        'model_credential': credential_class(),
        'model_class': model_class,
    }
    return ModelInfo(**info_kwargs)
def _get_aws_bedrock_icon_path():
    """Absolute path of the Bedrock provider icon inside the apps tree."""
    # NOTE(review): the file name 'bedrock_icon_svg' has no '.' before 'svg' —
    # confirm it matches the file actually shipped in the icon directory.
    provider_dir = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider',
                                'impl', 'aws_bedrock_model_provider')
    return os.path.join(provider_dir, 'icon', 'bedrock_icon_svg')
def _initialize_model_info():
    """Build the ModelInfoManage registry of supported Bedrock models.

    Every entry uses the Bedrock LLM credential form and chat model class;
    the first entry of the list is also registered as the default model.
    """
    model_info_list = [
        _create_model_info(
            'amazon.titan-text-premier-v1:0',
            'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-text-lite-v1',
            'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'amazon.titan-text-express-v1',
            'Amazon Titan Text Express 的上下文长度长达 8000 个令牌因而非常适合各种高级常规语言任务例如开放式文本生成和对话式聊天以及检索增强生成RAG中的支持。在发布时该模型针对英语进行了优化但也支持其他语言。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        # NOTE(review): 'amazon.titan-embed-text-v2:0' is an embedding model, but this
        # module only ships LLM credential/model classes, so it stays registered as
        # LLM here — confirm whether a dedicated embedding integration is intended.
        _create_model_info(
            'amazon.titan-embed-text-v2:0',
            'Amazon Titan Text Embeddings V2 是一种轻量级、高效的模型非常适合在不同维度上执行高精度检索任务。该模型支持灵活的嵌入大小1024、512 和 256并优先考虑在较小维度上保持准确性从而可以在不影响准确性的情况下降低存储成本。Titan Text Embeddings V2 适用于各种任务,包括文档检索、推荐系统、搜索引擎和对话式系统。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'mistral.mistral-7b-instruct-v0:2',
            '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
            # Fix: this is an instruct/chat model (see its own description);
            # it was mislabeled ModelTypeConst.EMBEDDING.
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'mistral.mistral-large-2402-v1:0',
            '先进的 Mistral AI 大型语言模型,能够处理任何语言任务,包括复杂的多语言推理、文本理解、转换和代码生成。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-70b-instruct-v1:0',
            '非常适合内容创作、会话式人工智能、语言理解、研发和企业应用',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
        _create_model_info(
            'meta.llama3-8b-instruct-v1:0',
            '非常适合有限的计算能力和资源、边缘设备和更快的训练时间。',
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel),
    ]
    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()
    return model_info_manage
class BedrockModelProvider(IModelProvider):
    """Model provider exposing Amazon Bedrock hosted models."""

    def __init__(self):
        # The model registry is built once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        """Return the pre-built model registry."""
        return self._model_info_manage

    def get_model_provide_info(self):
        """Describe this provider (identifier, display name, SVG icon)."""
        return ModelProvideInfo(
            provider='model_aws_bedrock_provider',
            name='Amazon Bedrock',
            icon=get_file_content(_get_aws_bedrock_icon_path()),
        )

View File

@ -0,0 +1,64 @@
import json
from typing import Dict
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan embedding models.

    Validates the SecretId/SecretKey pair by issuing a real GetEmbedding
    request against the Hunyuan endpoint.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        """Raise AppApiException if *provider* does not list *model_type*."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        """Check required keys and build a tencentcloud Credential from them."""
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        """Issue a minimal GetEmbedding call; wrap any SDK failure."""
        req = models.GetEmbeddingRequest()
        params = {
            "Model": model_name,
            "Input": "测试"
        }
        req.from_json_string(json.dumps(params))
        try:
            # Fix: the response was printed to stdout (debug leftover leaking
            # API responses); success alone proves the credential works.
            client.GetEmbedding(req)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Return True when the credential can call the Hunyuan embedding API.

        When raise_exception is falsy, validation failures return False
        instead of propagating AppApiException.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException:
            if raise_exception:
                raise
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return *model* with the SecretKey value encrypted."""
        return {**model, 'SecretKey': super().encryption(model.get('SecretKey', ''))}

    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)

View File

@ -0,0 +1,62 @@
import os
import re
from typing import Dict
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
from langchain_core.messages import HumanMessage
from common import forms
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Amazon Bedrock chat models.

    Persists the supplied key pair into ~/.aws/credentials under a fixed
    profile name, then validates it with a one-message chat invocation.
    """

    @staticmethod
    def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
        """Create or refresh *profile_name* in ~/.aws/credentials.

        Any stale section for the profile is stripped before the new key
        pair is appended, so repeated validations do not duplicate entries.
        """
        credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
        # Fix: read through a context manager — the original used
        # open(path, 'r').read(), leaking the file handle.
        if os.path.exists(credentials_path):
            with open(credentials_path, 'r') as file:
                content = file.read()
        else:
            content = ''
        pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
        content = re.sub(pattern, '', content, flags=re.DOTALL)
        if not re.search(rf'\[{profile_name}\]', content):
            content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
        with open(credentials_path, 'w') as file:
            file.write(content)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Return True when the credential can invoke *model_name* on Bedrock."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
            return False
        try:
            self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
                                         model_credential['secret_access_key'])
            # NOTE: mutates the caller's dict so the model factory can find the profile.
            model_credential['credentials_profile_name'] = 'aws-profile'
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the secret access key encrypted."""
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" version="2.0" focusable="false" aria-hidden="true" class="globalNav-1216 globalNav-1213" data-testid="awsc-logo" viewBox="0 0 29 17"><path class="globalNav-1214" d="M8.38 6.17a2.6 2.6 0 00.11.83c.08.232.18.456.3.67a.4.4 0 01.07.21.36.36 0 01-.18.28l-.59.39a.43.43 0 01-.24.08.38.38 0 01-.28-.13 2.38 2.38 0 01-.34-.43c-.09-.16-.18-.34-.28-.55a3.44 3.44 0 01-2.74 1.29 2.54 2.54 0 01-1.86-.67 2.36 2.36 0 01-.68-1.79 2.43 2.43 0 01.84-1.92 3.43 3.43 0 012.29-.72 6.75 6.75 0 011 .07c.35.05.7.12 1.07.2V3.3a2.06 2.06 0 00-.44-1.49 2.12 2.12 0 00-1.52-.43 4.4 4.4 0 00-1 .12 6.85 6.85 0 00-1 .32l-.33.12h-.14c-.14 0-.2-.1-.2-.29v-.46A.62.62 0 012.3.87a.78.78 0 01.27-.2A6 6 0 013.74.25 5.7 5.7 0 015.19.07a3.37 3.37 0 012.44.76 3 3 0 01.77 2.29l-.02 3.05zM4.6 7.59a3 3 0 001-.17 2 2 0 00.88-.6 1.36 1.36 0 00.32-.59 3.18 3.18 0 00.09-.81V5A7.52 7.52 0 006 4.87h-.88a2.13 2.13 0 00-1.38.37 1.3 1.3 0 00-.46 1.08 1.3 1.3 0 00.34 1c.278.216.63.313.98.27zm7.49 1a.56.56 0 01-.36-.09.73.73 0 01-.2-.37L9.35.93a1.39 1.39 0 01-.08-.38c0-.15.07-.23.22-.23h.92a.56.56 0 01.36.09.74.74 0 01.19.37L12.53 7 14 .79a.61.61 0 01.18-.37.59.59 0 01.37-.09h.75a.62.62 0 01.38.09.74.74 0 01.18.37L17.31 7 18.92.76a.74.74 0 01.19-.37.56.56 0 01.36-.09h.87a.21.21 0 01.23.23 1 1 0 010 .15s0 .13-.06.23l-2.26 7.2a.74.74 0 01-.19.37.6.6 0 01-.36.09h-.8a.53.53 0 01-.37-.1.64.64 0 01-.18-.37l-1.45-6-1.44 6a.64.64 0 01-.18.37.55.55 0 01-.37.1l-.82.02zm12 .24a6.29 6.29 0 01-1.44-.16 4.21 4.21 0 01-1.07-.37.69.69 0 01-.29-.26.66.66 0 01-.06-.27V7.3c0-.19.07-.29.21-.29a.57.57 0 01.18 0l.23.1c.32.143.656.25 1 .32.365.08.737.12 1.11.12a2.47 2.47 0 001.36-.31 1 1 0 00.48-.88.88.88 0 00-.25-.65 2.29 2.29 0 00-.94-.49l-1.35-.43a2.83 2.83 0 01-1.49-.94 2.24 2.24 0 01-.47-1.36 2 2 0 01.25-1c.167-.3.395-.563.67-.77a3 3 0 011-.48A4.1 4.1 0 0124.4.08a4.4 4.4 0 01.62 0l.61.1.53.15.39.16c.105.062.2.14.28.23a.57.57 0 01.08.31v.44c0 .2-.07.3-.21.3a.92.92 0 01-.36-.12 4.35 4.35 0 00-1.8-.36 
2.51 2.51 0 00-1.24.26.92.92 0 00-.44.84c0 .249.1.488.28.66.295.236.635.41 1 .51l1.32.42a2.88 2.88 0 011.44.9 2.1 2.1 0 01.43 1.31 2.38 2.38 0 01-.24 1.08 2.34 2.34 0 01-.68.82 3 3 0 01-1 .53 4.59 4.59 0 01-1.35.22l.03-.01z"></path><path class="globalNav-1215" d="M25.82 13.43a20.07 20.07 0 01-11.35 3.47A20.54 20.54 0 01.61 11.62c-.29-.26 0-.62.32-.42a27.81 27.81 0 0013.86 3.68 27.54 27.54 0 0010.58-2.16c.52-.22.96.34.45.71z"></path><path class="globalNav-1215" d="M27.1 12c-.4-.51-2.6-.24-3.59-.12-.3 0-.34-.23-.07-.42 1.75-1.23 4.63-.88 5-.46.37.42-.09 3.3-1.74 4.68-.25.21-.49.09-.38-.18.34-.95 1.17-3.02.78-3.5z"></path></svg>

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

@ -0,0 +1,25 @@
from setting.models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict
import requests
class TencentEmbeddingModel(MaxKBBaseModel):
    """Holder for Tencent embedding credentials, endpoint and model name."""

    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_base = api_base
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        """Factory used by the provider layer; reads keys from the credential dict."""
        return TencentEmbeddingModel(
            model_credential.get('SecretId'),
            model_credential.get('SecretKey'),
            model_credential.get('api_base'),
            model_name,
        )

    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API.
        return f"{self.secret_id}:{self.secret_key}"

View File

@ -0,0 +1,35 @@
from typing import List, Dict, Any
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
class BedrockModel(MaxKBBaseModel, BedrockChat):
    """BedrockChat wrapper that counts tokens with the shared tokenizer."""

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, **kwargs):
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name, streaming=streaming, **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        """Factory used by the provider layer; 'streaming' is extracted from kwargs."""
        streaming = model_kwargs.pop('streaming', False)
        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['credentials_profile_name'],
            streaming=streaming,
            **model_kwargs
        )

    def _get_num_tokens(self, content: str) -> int:
        """Count tokens in *content* using the shared tokenizer."""
        return len(TokenizerManage.get_tokenizer().encode(content))

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Total token count across all messages."""
        total = 0
        for message in messages:
            total += self._get_num_tokens(get_buffer_string([message]))
        return total

    def get_num_tokens(self, text: str) -> int:
        """Token count for a plain string."""
        return self._get_num_tokens(text)

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

View File

@ -0,0 +1,64 @@
import json
from typing import Dict
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan embedding models.

    Validates the SecretId/SecretKey pair by issuing a real GetEmbedding
    request against the Hunyuan endpoint.
    """

    @classmethod
    def _validate_model_type(cls, model_type: str, provider) -> bool:
        """Raise AppApiException if *provider* does not list *model_type*."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return True

    @classmethod
    def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
        """Check required keys and build a tencentcloud Credential from them."""
        for key in ['SecretId', 'SecretKey']:
            if key not in model_credential:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])

    @classmethod
    def _test_credentials(cls, client, model_name: str):
        """Issue a minimal GetEmbedding call; wrap any SDK failure."""
        req = models.GetEmbeddingRequest()
        params = {
            "Model": model_name,
            "Input": "测试"
        }
        req.from_json_string(json.dumps(params))
        try:
            # Fix: the response was printed to stdout (debug leftover leaking
            # API responses); success alone proves the credential works.
            client.GetEmbedding(req)
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        """Return True when the credential can call the Hunyuan embedding API.

        When raise_exception is falsy, validation failures return False
        instead of propagating AppApiException.
        """
        try:
            self._validate_model_type(model_type, provider)
            cred = self._validate_credential(model_credential)
            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
            self._test_credentials(client, model_name)
            return True
        except AppApiException:
            if raise_exception:
                raise
            return False

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return *model* with the SecretKey value encrypted."""
        return {**model, 'SecretKey': super().encryption(model.get('SecretKey', ''))}

    SecretId = forms.PasswordInputField('SecretId', required=True)
    SecretKey = forms.PasswordInputField('SecretKey', required=True)

View File

@ -0,0 +1,47 @@
# coding=utf-8
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Tencent Hunyuan chat models."""

    REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']

    @classmethod
    def _validate_model_type(cls, model_type, provider, raise_exception=False):
        """True if *provider* supports *model_type*; optionally raise instead."""
        supported = any(mt['value'] == model_type for mt in provider.get_model_type_list())
        if supported:
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        return False

    @classmethod
    def _validate_credential_fields(cls, model_credential, raise_exception=False):
        """True if every required field is present; optionally raise instead."""
        missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
        if not missing_keys:
            return True
        if raise_exception:
            raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
        return False

    def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
        """Validate type and fields, then probe the API with a short chat call."""
        if not self._validate_model_type(model_type, provider, raise_exception):
            return False
        if not self._validate_credential_fields(model_credential, raise_exception):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model):
        """Return *model* with the secret key encrypted."""
        return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}

    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 74 KiB

View File

@ -0,0 +1,25 @@
from setting.models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict
import requests
class TencentEmbeddingModel(MaxKBBaseModel):
def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
self.secret_id = secret_id
self.secret_key = secret_key
self.api_base = api_base
self.model_name = model_name
@staticmethod
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
return TencentEmbeddingModel(
secret_id=model_credential.get('SecretId'),
secret_key=model_credential.get('SecretKey'),
api_base=model_credential.get('api_base'),
model_name=model_name,
)
def _generate_auth_token(self):
# Example method to generate an authentication token for the model API
return f"{self.secret_id}:{self.secret_key}"

View File

@ -0,0 +1,273 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
logger = logging.getLogger(__name__)
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Map a LangChain message to the Hunyuan Role/Content wire format."""
    # ChatMessage is checked first so an explicit role always wins.
    if isinstance(message, ChatMessage):
        role, content = message.role, message.content
    elif isinstance(message, HumanMessage):
        role, content = "user", message.content
    elif isinstance(message, AIMessage):
        role, content = "assistant", message.content
    else:
        raise TypeError(f"Got unknown type {message}")
    return {"Role": role, "Content": content}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Map a Hunyuan response message dict back to a LangChain message."""
    role = _dict["Role"]
    if role == "user":
        return HumanMessage(content=_dict["Content"])
    if role == "assistant":
        # Assistant content may be missing or None; normalize to "".
        return AIMessage(content=_dict.get("Content", "") or "")
    # Any other role is preserved as a generic ChatMessage.
    return ChatMessage(content=_dict["Content"], role=role)
def _convert_delta_to_message_chunk(
        delta: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Map a streaming delta dict to the matching message-chunk class.

    When the delta carries no role, the previous chunk's class
    (*default_class*) decides the type.
    """
    role = delta.get("Role")
    content = delta.get("Content") or ""
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
    """Build a ChatResult from a non-streaming Hunyuan response dict."""
    generations = [
        ChatGeneration(message=_convert_dict_to_message(choice["Message"]))
        for choice in response["Choices"]
    ]
    # Token accounting is surfaced through llm_output for downstream callbacks.
    return ChatResult(generations=generations, llm_output={"token_usage": response["Usage"]})
class ChatHunyuan(BaseChatModel):
    """Tencent Hunyuan chat models API by Tencent.
    For more information, see https://cloud.tencent.com/document/product/1729
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps secret-bearing fields to the environment variables they may
        # be loaded from (used by LangChain's serialization machinery).
        return {
            "hunyuan_app_id": "HUNYUAN_APP_ID",
            "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
            "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        # Marks this model as safe for LangChain load/dump.
        return True

    hunyuan_app_id: Optional[int] = None
    """Hunyuan App ID"""
    hunyuan_secret_id: Optional[str] = None
    """Hunyuan Secret ID"""
    hunyuan_secret_key: Optional[SecretStr] = None
    """Hunyuan Secret Key"""
    streaming: bool = False
    """Whether to stream the results or not."""
    request_timeout: int = 60
    """Timeout for requests to Hunyuan API. Default is 60 seconds."""
    temperature: float = 1.0
    """What sampling temperature to use."""
    top_p: float = 1.0
    """What probability mass to use."""
    model: str = "hunyuan-lite"
    """What Model to use.
    Optional model:
    - hunyuan-lite
    - hunyuan-standard
    - hunyuan-standard-256K
    - hunyuan-pro
    - hunyuan-code
    - hunyuan-role
    - hunyuan-functioncall
    - hunyuan-vision
    """
    stream_moderation: bool = False
    """Whether to review the results or not when streaming is true."""
    enable_enhancement: bool = True
    """Whether to enhancement the results or not."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for API call not explicitly specified."""

    class Config:
        """Configuration for this pydantic object."""
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Unknown constructor kwargs are moved into model_kwargs with a warning.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        # Declared fields must be passed explicitly, never via model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        values["model_kwargs"] = extra
        return values

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        # Resolve credentials from explicit kwargs or environment variables;
        # the secret key is wrapped in a SecretStr to avoid accidental logging.
        values["hunyuan_app_id"] = get_from_dict_or_env(
            values,
            "hunyuan_app_id",
            "HUNYUAN_APP_ID",
        )
        values["hunyuan_secret_id"] = get_from_dict_or_env(
            values,
            "hunyuan_secret_id",
            "HUNYUAN_SECRET_ID",
        )
        values["hunyuan_secret_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "hunyuan_secret_key",
                "HUNYUAN_SECRET_KEY",
            )
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Hunyuan API."""
        normal_params = {
            "Temperature": self.temperature,
            "TopP": self.top_p,
            "Model": self.model,
            "Stream": self.streaming,
            "StreamModeration": self.stream_moderation,
            "EnableEnhancement": self.enable_enhancement,
        }
        # model_kwargs entries override the explicit fields on key collision.
        return {**normal_params, **self.model_kwargs}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Non-streaming entry point; delegates to _stream when streaming is on."""
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        return _create_chat_result(json.loads(res.to_json_string()))

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield chunks from the SSE response of the Hunyuan API."""
        res = self._chat(messages, **kwargs)
        default_chunk_class = AIMessageChunk
        for chunk in res:
            # Each SSE event carries a JSON payload under "data"; skip keep-alives.
            chunk = chunk.get("data", "")
            if len(chunk) == 0:
                continue
            response = json.loads(chunk)
            if "error" in response:
                raise ValueError(f"Error from Hunyuan api response: {response}")

            for choice in response["Choices"]:
                chunk = _convert_delta_to_message_chunk(
                    choice["Delta"], default_chunk_class
                )
                # Remember the chunk class so role-less deltas keep their type.
                default_chunk_class = chunk.__class__
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
        """Issue a ChatCompletions request via the tencentcloud SDK."""
        if self.hunyuan_secret_key is None:
            raise ValueError("Hunyuan secret key is not set.")

        try:
            from tencentcloud.common import credential
            from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
        except ImportError:
            raise ImportError(
                "Could not import tencentcloud python package. "
                "Please install it with `pip install tencentcloud-sdk-python`."
            )

        parameters = {**self._default_params, **kwargs}
        cred = credential.Credential(
            self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
        )
        client = hunyuan_client.HunyuanClient(cred, "")
        req = models.ChatCompletionsRequest()
        params = {
            "Messages": [_convert_message_to_dict(m) for m in messages],
            **parameters,
        }
        req.from_json_string(json.dumps(params))
        resp = client.ChatCompletions(req)
        return resp

    @property
    def _llm_type(self) -> str:
        return "hunyuan-chat"

View File

@ -0,0 +1,37 @@
# coding=utf-8
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
class TencentModel(MaxKBBaseModel, ChatHunyuan):
    """ChatHunyuan wrapper wired to MaxKB credentials and the shared tokenizer."""

    def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
        app_id = credentials.get('hunyuan_app_id')
        secret_id = credentials.get('hunyuan_secret_id')
        secret_key = credentials.get('hunyuan_secret_key')
        if not all([app_id, secret_id, secret_key]):
            raise ValueError(
                "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")
        super().__init__(model=model_name, hunyuan_app_id=app_id, hunyuan_secret_id=secret_id,
                         hunyuan_secret_key=secret_key, streaming=streaming, **kwargs)

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentModel':
        """Factory used by the provider layer; 'streaming' is extracted from kwargs."""
        streaming = model_kwargs.pop('streaming', False)
        return TencentModel(model_name=model_name, credentials=model_credential,
                            streaming=streaming, **model_kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Total token count across all messages."""
        tokenizer = TokenizerManage.get_tokenizer()
        total = 0
        for message in messages:
            total += len(tokenizer.encode(get_buffer_string([message])))
        return total

    def get_num_tokens(self, text: str) -> int:
        """Token count for a plain string."""
        return len(TokenizerManage.get_tokenizer().encode(text))

View File

@ -0,0 +1,103 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import (
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from smartdoc.conf import PROJECT_DIR
def _create_model_info(model_name, description, model_type, credential_class, model_class):
    """Construct a ModelInfo entry for a single Hunyuan model offering."""
    credential_form = credential_class()
    return ModelInfo(name=model_name, desc=description, model_type=model_type,
                     model_credential=credential_form, model_class=model_class)
def _get_tencent_icon_path():
    """Absolute path of the Tencent provider icon inside the apps tree."""
    # NOTE(review): the file name 'tencent_icon_svg' has no '.' before 'svg' —
    # confirm it matches the file actually shipped in the icon directory.
    return os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                        'tencent_model_provider', 'icon', 'tencent_icon_svg')
def _initialize_model_info():
    """Assemble the ModelInfoManage registry for all Tencent Hunyuan models.

    Returns the built ModelInfoManage containing every LLM entry plus the
    embedding entry, with 'hunyuan-pro' as the default model.
    """
    model_info_list = [
        _create_model_info(
            # Fixed: the name previously had a trailing space in
            # 'hunyuan-functioncall ' — model names are sent verbatim to the
            # API, so the padded name would never match a real model.
            'hunyuan-pro',
            '当前混元模型中效果最优版本,万亿级参数规模 MOE-32K 长文模型。在各种 benchmark 上达到绝对领先的水平,复杂指令和推理,具备复杂数学能力,支持 functioncall在多语言翻译、金融法律医疗等领域应用重点优化',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-standard',
            '采用更优的路由策略同时缓解了负载均衡和专家趋同的问题。长文方面大海捞针指标达到99.9%',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-lite',
            '升级为 MOE 结构,上下文窗口为 256k ,在 NLP代码数学行业等多项评测集上领先众多开源模型',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-role',
            '混元最新版角色扮演模型,混元官方精调训练推出的角色扮演模型,基于混元模型结合角色扮演场景数据集进行增训,在角色扮演场景具有更好的基础效果',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-functioncall',
            '混元最新 MOE 架构 FunctionCall 模型,经过高质量的 FunctionCall 数据训练,上下文窗口达 32K在多个维度的评测指标上处于领先。',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
        _create_model_info(
            'hunyuan-code',
            '混元最新代码生成模型,经过 200B 高质量代码数据增训基座模型,迭代半年高质量 SFT 数据训练,上下文长窗口长度增大到 8K五大语言代码生成自动评测指标上位居前列五大语言10项考量各方面综合代码任务人工高质量评测上性能处于第一梯队',
            ModelTypeConst.LLM,
            TencentLLMModelCredential,
            TencentModel),
    ]
    tencent_embedding_model_info = _create_model_info(
        'hunyuan-embedding',
        '',
        ModelTypeConst.EMBEDDING,
        TencentEmbeddingCredential,
        TencentEmbeddingModel
    )
    model_info_embedding_list = [tencent_embedding_model_info]
    # Fixed: the embedding list was previously built but never appended to the
    # builder, so the embedding model did not appear in the registry at all.
    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_model_info_list(model_info_embedding_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()
    return model_info_manage
class TencentModelProvider(IModelProvider):
    """Provider descriptor exposing Tencent Hunyuan models to MaxKB."""

    def __init__(self):
        # Build the model registry once per provider instance.
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        """Return the registry of supported Tencent models."""
        return self._model_info_manage

    def get_model_provide_info(self):
        """Return the provider's identity: key, display name, and icon bytes."""
        return ModelProvideInfo(
            provider='model_tencent_provider',
            name='腾讯混元',
            icon=get_file_content(_get_tencent_icon_path())
        )

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form (API base URL + API key) for an OpenAI-compatible embedding model."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """Validate the credential by building the model and issuing a probe embedding call.

        Returns True on success; returns False (or raises AppApiException when
        raise_exception is True) on failure.
        """
        # An unsupported model type always raises, regardless of raise_exception.
        supported = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not supported:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for key in ('api_base', 'api_key'):
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            # Probe request: any failure here means the credential is unusable.
            provider.get_model(model_type, model_name, model_credential).embed_query('你好')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of *model* with the API key masked for storage/display."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,50 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:57
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form (Access Key ID / Secret Access Key) for Volcano Engine LLMs."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and issuing a probe chat call.

        Returns True on success; returns False (or raises AppApiException when
        raise_exception is True) on failure.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type always raises, regardless of raise_exception.
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for key in ['access_key_id', 'secret_access_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe request; removed leftover debug print of the response.
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of *model* with the access key masked for storage/display."""
        # NOTE(review): only access_key_id is masked here; secret_access_key is
        # stored as-is — confirm whether it should be encrypted as well.
        return {**model, 'access_key_id': super().encryption(model.get('access_key_id', ''))}

    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 60 KiB

View File

@ -0,0 +1,15 @@
from typing import Dict
from langchain_community.embeddings import VolcanoEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
    # Embedding model for the Volcano Engine provider.

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an embedding model instance from the provider-supplied credential dict."""
        # NOTE(review): VolcanoEmbeddings is normally configured via its own
        # fields (ak/sk/host), but this passes 'api_key' and 'openai_api_base'
        # — these kwargs look copied from the OpenAI provider; confirm the
        # base class actually accepts them.
        return VolcanicEngineEmbeddingModel(
            api_key=model_credential.get('api_key'),
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
        )

View File

@ -0,0 +1,27 @@
from typing import List, Dict
from langchain_community.chat_models import VolcEngineMaasChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from langchain_openai import ChatOpenAI
class VolcanicEngineChatModel(MaxKBBaseModel, ChatOpenAI):
    # NOTE(review): this class subclasses ChatOpenAI, yet new_instance passes
    # VolcEngineMaas-style kwargs (volc_engine_maas_ak/sk), which ChatOpenAI
    # does not declare. VolcEngineMaasChat is imported above but unused —
    # confirm whether the base class should be VolcEngineMaasChat, or whether
    # the kwargs should be the OpenAI-compatible ones for an Ark endpoint.

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a chat model instance from the provider-supplied credential dict."""
        volcanic_engine_chat = VolcanicEngineChatModel(
            model=model_name,
            volc_engine_maas_ak=model_credential.get("access_key_id"),
            volc_engine_maas_sk=model_credential.get("secret_access_key"),
        )
        return volcanic_engine_chat

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Approximate the total token count of *messages* using the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count of *text* using the local tokenizer."""
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File gemini_model_provider.py
@Author Brian Yang
@Date 5/13/24 7:47 AM
"""
import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from smartdoc.conf import PROJECT_DIR
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
# Volcano Engine Ark exposes models through inference endpoints; the "model
# name" users enter is the endpoint id (ep-xxxxxxxxxx-yyyy).
model_info_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.LLM,
              volcanic_engine_llm_model_credential, OpenAIChatModel
              )
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
              OpenAIEmbeddingModel)]
# Fixed: the embedding list was previously built but never appended to the
# builder, so embedding endpoints did not appear in the registry at all.
model_info_manage = ModelInfoManage.builder() \
    .append_model_info_list(model_info_list) \
    .append_model_info_list(model_info_embedding_list) \
    .append_default_model_info(model_info_list[0]) \
    .build()
class VolcanicEngineModelProvider(IModelProvider):
    """Provider descriptor exposing Volcano Engine (火山引擎) models to MaxKB."""

    def get_model_info_manage(self):
        """Return the module-level registry of supported models."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider's identity: key, display name, and icon bytes."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'volcanic_engine_model_provider', 'icon',
                                 'volcanic_engine_icon_svg')
        return ModelProvideInfo(provider='model_volcanic_engine_provider', name='火山引擎',
                                icon=get_file_content(icon_path))

View File

@ -46,6 +46,8 @@ gunicorn = "^22.0.0"
python-daemon = "3.0.1"
gevent = "^24.2.1"
boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"