From 72423a7c3e3b5a65ee66c5c5d4ce58b06d6fe55b Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Wed, 24 Jul 2024 14:14:05 +0800
Subject: [PATCH] feat: add model provider integrations (Amazon Bedrock, Tencent Hunyuan, Volcanic Engine)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 .../constants/model_provider_constants.py | 7 +
 .../aws_bedrock_model_provider/__init__.py | 2 +
 .../aws_bedrock_model_provider.py | 106 ++++++
 .../credential/embedding.py | 49 ++++
 .../credential/llm.py | 68 ++++
 .../icon/bedrock_icon_svg | 1 +
 .../model/embedding.py | 17 ++
 .../aws_bedrock_model_provider/model/llm.py | 37 +++
 .../impl/tencent_model_provider/__init__.py | 2 +
 .../credential/embedding.py | 63 ++++
 .../tencent_model_provider/credential/llm.py | 47 +++
 .../icon/tencent_icon_svg | 5 +
 .../tencent_model_provider/model/embedding.py | 24 ++
 .../tencent_model_provider/model/hunyuan.py | 280 ++++++++++++++++++
 .../impl/tencent_model_provider/model/llm.py | 37 +++
 .../tencent_model_provider.py | 105 +++++++
 .../__init__.py | 2 +
 .../credential/embedding.py | 46 +++
 .../credential/llm.py | 49 ++++
 .../icon/volcanic_engine_icon_svg | 5 +
 .../model/embedding.py | 18 ++
 .../model/llm.py | 28 ++
 .../volcanic_engine_model_provider.py | 53 ++++
 pyproject.toml | 3 +
 24 files changed, 1054 insertions(+)
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py
create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/__init__.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py
index c9e9659c3..6d4ff9cc0 100644
--- a/apps/setting/models_provider/constants/model_provider_constants.py
+++ b/apps/setting/models_provider/constants/model_provider_constants.py
@@ -8,6 +8,7 @@
"""
from enum import Enum
+from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
@@ -15,6 +16,9 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
+from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
+from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
+ VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
@@ -32,4 +36,7 @@ class ModelProvideConstants(Enum):
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
+ model_volcanic_engine_provider = VolcanicEngineModelProvider()
+ model_tencent_provider = TencentModelProvider()
+ model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
new file mode 100644
index 000000000..8cb7f459e
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
new file mode 100644
index 000000000..a6187f995
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+
+import os
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import (
+ IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
+)
+from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
+from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
+from setting.models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
+from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
+from smartdoc.conf import PROJECT_DIR
+
+
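+# Shared helper: packs one model entry (name, description, type, credential form, model class) into a ModelInfo.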
+def _create_model_info(model_name, description, model_type, credential_class, model_class):
+ return ModelInfo(
+ name=model_name,
+ desc=description,
+ model_type=model_type,
+ model_credential=credential_class(),
+ model_class=model_class
+ )
+
+
+def _get_aws_bedrock_icon_path():
+ return os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'aws_bedrock_model_provider',
+ 'icon', 'bedrock_icon_svg')
+
+
+def _initialize_model_info():
+ model_info_list = [_create_model_info(
+ 'amazon.titan-text-premier-v1:0',
+        'Titan Text Premier is a powerful and advanced model in the Titan Text family, designed to deliver superior performance across a wide range of enterprise applications. With cutting-edge capabilities, it offers greater accuracy and outstanding results, making it an excellent choice for organizations that need first-class text-processing solutions.',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel
+ ),
+ _create_model_info(
+ 'amazon.titan-text-lite-v1',
+            'Amazon Titan Text Lite is a lightweight, efficient model well suited to fine-tuning for English-language tasks such as summarization and copywriting, where customers need a smaller, more cost-effective and highly customizable model',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel),
+ _create_model_info(
+ 'amazon.titan-text-express-v1',
+            'Amazon Titan Text Express supports a context length of up to 8,000 tokens, making it well suited to a wide range of advanced general language tasks such as open-ended text generation and conversational chat, as well as support within retrieval-augmented generation (RAG). At launch the model is optimized for English, with other languages also supported.',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel),
+        _create_model_info(
+            'amazon.titan-embed-text-v2:0',
+            'Amazon Titan Text Embeddings V2 is a lightweight, efficient model well suited to high-accuracy retrieval tasks at different dimensions. It supports flexible embedding sizes (1024, 512 and 256) and prioritizes accuracy at smaller dimensions, reducing storage costs without sacrificing accuracy. It fits a wide range of tasks, including document retrieval, recommendation systems, search engines and conversational systems.',
+            ModelTypeConst.EMBEDDING,
+            BedrockEmbeddingCredential,
+            BedrockEmbeddingModel),
+        _create_model_info(
+            'mistral.mistral-7b-instruct-v0:2',
+            'A 7B dense Transformer that deploys quickly and is easy to customize. Small yet powerful for a variety of use cases. Supports English and code, with a 32k context window.',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+ _create_model_info(
+ 'mistral.mistral-large-2402-v1:0',
+            'The advanced Mistral AI large language model, capable of handling any language task, including complex multilingual reasoning, text understanding, transformation and code generation.',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel),
+ _create_model_info(
+ 'meta.llama3-70b-instruct-v1:0',
+            'Well suited to content creation, conversational AI, language understanding, R&D and enterprise applications',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel),
+ _create_model_info(
+ 'meta.llama3-8b-instruct-v1:0',
+            'Well suited to limited compute and resources, edge devices, and faster training times.',
+ ModelTypeConst.LLM,
+ BedrockLLMModelCredential,
+ BedrockModel),
+ ]
+
+ model_info_manage = ModelInfoManage.builder() \
+ .append_model_info_list(model_info_list) \
+ .append_default_model_info(model_info_list[0]) \
+ .build()
+
+ return model_info_manage
+
+
+class BedrockModelProvider(IModelProvider):
+ def __init__(self):
+ self._model_info_manage = _initialize_model_info()
+
+ def get_model_info_manage(self):
+ return self._model_info_manage
+
+ def get_model_provide_info(self):
+ icon_path = _get_aws_bedrock_icon_path()
+ icon_data = get_file_content(icon_path)
+ return ModelProvideInfo(
+ provider='model_aws_bedrock_provider',
+ name='Amazon Bedrock',
+ icon=icon_data
+ )
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
new file mode 100644
index 000000000..40e36caca
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
@@ -0,0 +1,49 @@
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
+
+
+class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+            return False
+
+        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
+        if not all(key in model_credential for key in required_keys):
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value,
+                                      f'The following fields are required: {", ".join(required_keys)}')
+            return False
+
+        try:
+            # Reuse the LLM credential helper so embedding validation shares the same AWS profile.
+            BedrockLLMModelCredential._update_aws_credentials('aws-profile', model_credential['access_key_id'],
+                                                              model_credential['secret_access_key'])
+            model_credential['credentials_profile_name'] = 'aws-profile'
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.embed_query('Hello')
+        except AppApiException:
+            raise
+        except Exception as e:
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value,
+                                      f'Validation failed, please check whether the parameters are correct: {str(e)}')
+            return False
+
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
+
+    region_name = forms.TextInputField('Region Name', required=True)
+    access_key_id = forms.TextInputField('Access Key ID', required=True)
+    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
new file mode 100644
index 000000000..848250b1e
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
@@ -0,0 +1,68 @@
+import os
+import re
+from typing import Dict
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
+from langchain_core.messages import HumanMessage
+from common import forms
+
+
+class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
+
+ @staticmethod
+ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
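+        # Persist the key pair as a named profile in ~/.aws/credentials so that boto3 /
+        # langchain can authenticate via credentials_profile_name.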
+ credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
+ os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
+
+        if os.path.exists(credentials_path):
+            with open(credentials_path, 'r') as file:
+                content = file.read()
+        else:
+            content = ''
+ pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
+ content = re.sub(pattern, '', content, flags=re.DOTALL)
+
+ if not re.search(rf'\[{profile_name}\]', content):
+ content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
+
+ with open(credentials_path, 'w') as file:
+ file.write(content)
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+ raise_exception=False):
+ model_type_list = provider.get_model_type_list()
+ if not any(mt.get('value') == model_type for mt in model_type_list):
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+ return False
+
+ required_keys = ['region_name', 'access_key_id', 'secret_access_key']
+ if not all(key in model_credential for key in required_keys):
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'The following fields are required: {", ".join(required_keys)}')
+ return False
+
+ try:
+ self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
+ model_credential['secret_access_key'])
+ model_credential['credentials_profile_name'] = 'aws-profile'
+ model = provider.get_model(model_type, model_name, model_credential)
+            model.invoke([HumanMessage(content='Hello')])
+ except AppApiException:
+ raise
+ except Exception as e:
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'Validation failed, please check whether the parameters are correct: {str(e)}')
+ return False
+
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
+
+ region_name = forms.TextInputField('Region Name', required=True)
+ access_key_id = forms.TextInputField('Access Key ID', required=True)
+ secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg b/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg
new file mode 100644
index 000000000..5f176a7d2
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py
new file mode 100644
index 000000000..a5bd0336a
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py
@@ -0,0 +1,17 @@
+from typing import Dict
+
+from langchain_community.embeddings import BedrockEmbeddings
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
+    """Bedrock embedding wrapper; credentials come from the AWS profile written by the credential form."""
+
+    @staticmethod
+    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
+        return BedrockEmbeddingModel(
+            model_id=model_name,
+            region_name=model_credential['region_name'],
+            credentials_profile_name=model_credential['credentials_profile_name'],
+        )
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
new file mode 100644
index 000000000..43c3c00a0
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -0,0 +1,37 @@
+from typing import List, Dict, Any
+from langchain_community.chat_models import BedrockChat
+from langchain_core.messages import BaseMessage, get_buffer_string
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
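+# Thin wrapper around langchain's BedrockChat: credentials resolve through the AWS profile
+# written by BedrockLLMModelCredential, and token counting uses the shared local tokenizer.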
+class BedrockModel(MaxKBBaseModel, BedrockChat):
+
+ def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
+ streaming: bool = False, **kwargs):
+ super().__init__(model_id=model_id, region_name=region_name,
+ credentials_profile_name=credentials_profile_name, streaming=streaming, **kwargs)
+
+ @classmethod
+ def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
+ **model_kwargs) -> 'BedrockModel':
+ return cls(
+ model_id=model_name,
+ region_name=model_credential['region_name'],
+ credentials_profile_name=model_credential['credentials_profile_name'],
+ streaming=model_kwargs.pop('streaming', False),
+ **model_kwargs
+ )
+
+ def _get_num_tokens(self, content: str) -> int:
+ """Helper method to count tokens in a string."""
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(content))
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)
+
+ def get_num_tokens(self, text: str) -> int:
+ return self._get_num_tokens(text)
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/__init__.py b/apps/setting/models_provider/impl/tencent_model_provider/__init__.py
new file mode 100644
index 000000000..8cb7f459e
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py
new file mode 100644
index 000000000..40e36caca
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py
@@ -0,0 +1,63 @@
+import json
+from typing import Dict
+
+from tencentcloud.common import credential
+from tencentcloud.common.profile.client_profile import ClientProfile
+from tencentcloud.common.profile.http_profile import HttpProfile
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class TencentEmbeddingCredential(BaseForm, BaseModelCredential):
+ @classmethod
+ def _validate_model_type(cls, model_type: str, provider) -> bool:
+ model_type_list = provider.get_model_type_list()
+ if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+ return True
+
+ @classmethod
+ def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential:
+ for key in ['SecretId', 'SecretKey']:
+ if key not in model_credential:
+                raise AppApiException(ValidCode.valid_error.value, f'{key} is a required field')
+ return credential.Credential(model_credential['SecretId'], model_credential['SecretKey'])
+
+ @classmethod
+ def _test_credentials(cls, client, model_name: str):
+        req = models.GetEmbeddingRequest()
+        params = {
+            "Model": model_name,
+            "Input": "test"
+        }
+        req.from_json_string(json.dumps(params))
+        try:
+            client.GetEmbedding(req)
+        except Exception as e:
+            raise AppApiException(ValidCode.valid_error.value, f'Validation failed, please check whether the parameters are correct: {str(e)}')
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+ raise_exception=True) -> bool:
+ try:
+ self._validate_model_type(model_type, provider)
+ cred = self._validate_credential(model_credential)
+            http_profile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com")
+            client_profile = ClientProfile(httpProfile=http_profile)
+            client = hunyuan_client.HunyuanClient(cred, "", client_profile)
+ self._test_credentials(client, model_name)
+ return True
+ except AppApiException as e:
+ if raise_exception:
+ raise e
+ return False
+
+ def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
+ encrypted_secret_key = super().encryption(model.get('SecretKey', ''))
+ return {**model, 'SecretKey': encrypted_secret_key}
+
+ SecretId = forms.PasswordInputField('SecretId', required=True)
+ SecretKey = forms.PasswordInputField('SecretKey', required=True)
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
new file mode 100644
index 000000000..ad9ab3a82
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
@@ -0,0 +1,47 @@
+# coding=utf-8
+from langchain_core.messages import HumanMessage
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class TencentLLMModelCredential(BaseForm, BaseModelCredential):
+ REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']
+
+ @classmethod
+ def _validate_model_type(cls, model_type, provider, raise_exception=False):
+ if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+ return False
+ return True
+
+ @classmethod
+ def _validate_credential_fields(cls, model_credential, raise_exception=False):
+ missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
+ if missing_keys:
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'The following fields are required: {", ".join(missing_keys)}')
+ return False
+ return True
+
+ def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
+ if not (self._validate_model_type(model_type, provider, raise_exception) and
+ self._validate_credential_fields(model_credential, raise_exception)):
+ return False
+ try:
+ model = provider.get_model(model_type, model_name, model_credential)
+            model.invoke([HumanMessage(content='Hello')])
+ except Exception as e:
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'Validation failed, please check whether the parameters are correct: {str(e)}')
+ return False
+ return True
+
+ def encryption_dict(self, model):
+ return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
+
+ hunyuan_app_id = forms.TextInputField('APP ID', required=True)
+ hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
+ hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg b/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg
new file mode 100644
index 000000000..6cec08b74
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg
@@ -0,0 +1,5 @@
+
+
+
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py
new file mode 100644
index 000000000..a5bd0336a
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py
@@ -0,0 +1,24 @@
+from typing import Dict
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+class TencentEmbeddingModel(MaxKBBaseModel):
+    def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
+        self.secret_id = secret_id
+        self.secret_key = secret_key
+        self.api_base = api_base
+        self.model_name = model_name
+
+    @staticmethod
+    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
+        return TencentEmbeddingModel(
+            secret_id=model_credential.get('SecretId'),
+            secret_key=model_credential.get('SecretKey'),
+            api_base=model_credential.get('api_base'),
+            model_name=model_name,
+        )
+
+    def _generate_auth_token(self):
+        # Example method to generate an authentication token for the model API
+        return f"{self.secret_id}:{self.secret_key}"
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
new file mode 100644
index 000000000..9af6f983c
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
@@ -0,0 +1,280 @@
+import json
+import logging
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import (
+ BaseChatModel,
+ generate_from_stream,
+)
+from langchain_core.messages import (
+ AIMessage,
+ AIMessageChunk,
+ BaseMessage,
+ BaseMessageChunk,
+ ChatMessage,
+ ChatMessageChunk,
+ HumanMessage,
+ HumanMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import (
+ convert_to_secret_str,
+ get_from_dict_or_env,
+ get_pydantic_field_names,
+ pre_init,
+)
+
+logger = logging.getLogger(__name__)
+
+
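+# Convert LangChain messages into the {"Role": ..., "Content": ...} dicts the Hunyuan API expects.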
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+ message_dict: Dict[str, Any]
+ if isinstance(message, ChatMessage):
+ message_dict = {"Role": message.role, "Content": message.content}
+ elif isinstance(message, HumanMessage):
+ message_dict = {"Role": "user", "Content": message.content}
+ elif isinstance(message, AIMessage):
+ message_dict = {"Role": "assistant", "Content": message.content}
+ else:
+ raise TypeError(f"Got unknown type {message}")
+
+ return message_dict
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+ role = _dict["Role"]
+ if role == "user":
+ return HumanMessage(content=_dict["Content"])
+ elif role == "assistant":
+ return AIMessage(content=_dict.get("Content", "") or "")
+ else:
+ return ChatMessage(content=_dict["Content"], role=role)
+
+
+def _convert_delta_to_message_chunk(
+ _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+ role = _dict.get("Role")
+ content = _dict.get("Content") or ""
+
+ if role == "user" or default_class == HumanMessageChunk:
+ return HumanMessageChunk(content=content)
+ elif role == "assistant" or default_class == AIMessageChunk:
+ return AIMessageChunk(content=content)
+ elif role or default_class == ChatMessageChunk:
+ return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
+ else:
+ return default_class(content=content) # type: ignore[call-arg]
+
+
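+# Assemble a ChatResult from the Choices and Usage fields of a Hunyuan response.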
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+ generations = []
+ for choice in response["Choices"]:
+ message = _convert_dict_to_message(choice["Message"])
+ generations.append(ChatGeneration(message=message))
+
+ token_usage = response["Usage"]
+ llm_output = {"token_usage": token_usage}
+ return ChatResult(generations=generations, llm_output=llm_output)
+
+
+class ChatHunyuan(BaseChatModel):
+ """Tencent Hunyuan chat models API by Tencent.
+
+ For more information, see https://cloud.tencent.com/document/product/1729
+ """
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ return {
+ "hunyuan_app_id": "HUNYUAN_APP_ID",
+ "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
+ "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
+ }
+
+ @property
+ def lc_serializable(self) -> bool:
+ return True
+
+ hunyuan_app_id: Optional[int] = None
+ """Hunyuan App ID"""
+ hunyuan_secret_id: Optional[str] = None
+ """Hunyuan Secret ID"""
+ hunyuan_secret_key: Optional[SecretStr] = None
+ """Hunyuan Secret Key"""
+ streaming: bool = False
+ """Whether to stream the results or not."""
+ request_timeout: int = 60
+ """Timeout for requests to Hunyuan API. Default is 60 seconds."""
+ temperature: float = 1.0
+ """What sampling temperature to use."""
+ top_p: float = 1.0
+ """What probability mass to use."""
+ model: str = "hunyuan-lite"
+ """What Model to use.
+ Optional model:
+    - hunyuan-lite
+ - hunyuan-standard
+ - hunyuan-standard-256K
+ - hunyuan-pro
+ - hunyuan-code
+ - hunyuan-role
+ - hunyuan-functioncall
+ - hunyuan-vision
+ """
+ stream_moderation: bool = False
+ """Whether to review the results or not when streaming is true."""
+ enable_enhancement: bool = True
+ """Whether to enhancement the results or not."""
+
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+ """Holds any model parameters valid for API call not explicitly specified."""
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ allow_population_by_field_name = True
+
+ @root_validator(pre=True)
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Build extra kwargs from additional params that were passed in."""
+ all_required_field_names = get_pydantic_field_names(cls)
+ extra = values.get("model_kwargs", {})
+ for field_name in list(values):
+ if field_name in extra:
+ raise ValueError(f"Found {field_name} supplied twice.")
+ if field_name not in all_required_field_names:
+ logger.warning(
+ f"""WARNING! {field_name} is not default parameter.
+ {field_name} was transferred to model_kwargs.
+ Please confirm that {field_name} is what you intended."""
+ )
+ extra[field_name] = values.pop(field_name)
+
+ invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+ if invalid_model_kwargs:
+ raise ValueError(
+ f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+ f"Instead they were passed in as part of `model_kwargs` parameter."
+ )
+
+ values["model_kwargs"] = extra
+ return values
+
+ @pre_init
+ def validate_environment(cls, values: Dict) -> Dict:
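+        # Fall back to the HUNYUAN_* environment variables when values are not supplied directly.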
+ values["hunyuan_app_id"] = get_from_dict_or_env(
+ values,
+ "hunyuan_app_id",
+ "HUNYUAN_APP_ID",
+ )
+ values["hunyuan_secret_id"] = get_from_dict_or_env(
+ values,
+ "hunyuan_secret_id",
+ "HUNYUAN_SECRET_ID",
+ )
+ values["hunyuan_secret_key"] = convert_to_secret_str(
+ get_from_dict_or_env(
+ values,
+ "hunyuan_secret_key",
+ "HUNYUAN_SECRET_KEY",
+ )
+ )
+ return values
+
+ @property
+ def _default_params(self) -> Dict[str, Any]:
+ """Get the default parameters for calling Hunyuan API."""
+ normal_params = {
+ "Temperature": self.temperature,
+ "TopP": self.top_p,
+ "Model": self.model,
+ "Stream": self.streaming,
+ "StreamModeration": self.stream_moderation,
+ "EnableEnhancement": self.enable_enhancement,
+ }
+ return {**normal_params, **self.model_kwargs}
+
+ def _generate(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> ChatResult:
+ if self.streaming:
+ stream_iter = self._stream(
+ messages=messages, stop=stop, run_manager=run_manager, **kwargs
+ )
+ return generate_from_stream(stream_iter)
+
+ res = self._chat(messages, **kwargs)
+ return _create_chat_result(json.loads(res.to_json_string()))
+
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ res = self._chat(messages, **kwargs)
+
+ default_chunk_class = AIMessageChunk
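+        # Each SSE event carries a JSON payload in its "data" field; skip empty keep-alive events.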
+ for chunk in res:
+ chunk = chunk.get("data", "")
+ if len(chunk) == 0:
+ continue
+ response = json.loads(chunk)
+ if "error" in response:
+ raise ValueError(f"Error from Hunyuan api response: {response}")
+
+ for choice in response["Choices"]:
+ chunk = _convert_delta_to_message_chunk(
+ choice["Delta"], default_chunk_class
+ )
+ default_chunk_class = chunk.__class__
+ cg_chunk = ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+ yield cg_chunk
+
+ def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
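+        # Build a ChatCompletionsRequest with the Tencent Cloud SDK; with Stream=True the
+        # response is an SSE event iterator that _stream() consumes chunk by chunk.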
+ if self.hunyuan_secret_key is None:
+ raise ValueError("Hunyuan secret key is not set.")
+
+ try:
+ from tencentcloud.common import credential
+ from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+ except ImportError:
+ raise ImportError(
+ "Could not import tencentcloud python package. "
+ "Please install it with `pip install tencentcloud-sdk-python`."
+ )
+
+ parameters = {**self._default_params, **kwargs}
+ cred = credential.Credential(
+ self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
+ )
+ client = hunyuan_client.HunyuanClient(cred, "")
+ req = models.ChatCompletionsRequest()
+ params = {
+ "Messages": [_convert_message_to_dict(m) for m in messages],
+ **parameters,
+ }
+ req.from_json_string(json.dumps(params))
+ resp = client.ChatCompletions(req)
+ return resp
+
+ @property
+ def _llm_type(self) -> str:
+ return "hunyuan-chat"
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
new file mode 100644
index 000000000..81116eb61
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
@@ -0,0 +1,37 @@
+# coding=utf-8
+
+from typing import List, Dict
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
+
+
+class TencentModel(MaxKBBaseModel, ChatHunyuan):
+
+ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
+ hunyuan_app_id = credentials.get('hunyuan_app_id')
+ hunyuan_secret_id = credentials.get('hunyuan_secret_id')
+ hunyuan_secret_key = credentials.get('hunyuan_secret_key')
+
+ if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
+ raise ValueError(
+ "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")
+
+ super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id,
+ hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, **kwargs)
+
+ @staticmethod
+ def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
+ **model_kwargs) -> 'TencentModel':
+ streaming = model_kwargs.pop('streaming', False)
+ return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
+
+ def get_num_tokens(self, text: str) -> int:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py
new file mode 100644
index 000000000..572ff940d
--- /dev/null
+++ b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+
+import os
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import (
+ IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
+)
+from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
+from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
+from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
+from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
+from smartdoc.conf import PROJECT_DIR
+
+
+def _create_model_info(model_name, description, model_type, credential_class, model_class):
+ return ModelInfo(
+ name=model_name,
+ desc=description,
+ model_type=model_type,
+ model_credential=credential_class(),
+ model_class=model_class
+ )
+
+
+def _get_tencent_icon_path():
+ return os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'tencent_model_provider',
+ 'icon', 'tencent_icon_svg')
+
+
+def _initialize_model_info():
+ model_info_list = [_create_model_info(
+ 'hunyuan-pro',
+        'Currently the best-performing version of the Hunyuan models: a trillion-parameter MOE-32K long-text model. It reaches a clearly leading level on all kinds of benchmarks, handles complex instructions and reasoning, has advanced mathematical capability, supports function calling, and is specially optimized for applications in fields such as multilingual translation, finance, law and healthcare.',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel
+ ),
+ _create_model_info(
+ 'hunyuan-standard',
+            'Adopts an improved routing strategy while alleviating the problems of load balancing and expert convergence. For long texts, the needle-in-a-haystack score reaches 99.9%',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel),
+ _create_model_info(
+ 'hunyuan-lite',
+            'Upgraded to an MOE architecture with a 256k context window; leads many open-source models on evaluation sets for NLP, code, mathematics, industry domains and more',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel),
+ _create_model_info(
+ 'hunyuan-role',
+            'The latest Hunyuan role-playing model, officially fine-tuned and released by Hunyuan; built on the Hunyuan model and further trained on role-playing scenario datasets, it delivers better baseline results in role-playing scenarios',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel),
+ _create_model_info(
+            'hunyuan-functioncall',
+            'The latest Hunyuan MOE-architecture FunctionCall model, trained on high-quality FunctionCall data, with a 32K context window and leading evaluation metrics across multiple dimensions.',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel),
+ _create_model_info(
+ 'hunyuan-code',
+            'The latest Hunyuan code-generation model. Built on a base model further trained on 200B of high-quality code data and refined over half a year of iteration on high-quality SFT data, with the context window enlarged to 8K. It ranks at the top on automatic code-generation benchmarks for five major languages, and in the first tier in high-quality human evaluations across ten comprehensive code tasks in those languages',
+ ModelTypeConst.LLM,
+ TencentLLMModelCredential,
+ TencentModel),
+ ]
+
+ tencent_embedding_model_info = _create_model_info(
+ 'hunyuan-embedding',
+ '',
+ ModelTypeConst.EMBEDDING,
+ TencentEmbeddingCredential,
+ TencentEmbeddingModel
+ )
+
+ model_info_embedding_list = [tencent_embedding_model_info]
+
+    model_info_manage = ModelInfoManage.builder() \
+        .append_model_info_list(model_info_list) \
+        .append_default_model_info(model_info_list[0]) \
+        .append_model_info_list(model_info_embedding_list) \
+        .append_default_model_info(tencent_embedding_model_info) \
+        .build()
+
+ return model_info_manage
+
+
+class TencentModelProvider(IModelProvider):
+ def __init__(self):
+ self._model_info_manage = _initialize_model_info()
+
+ def get_model_info_manage(self):
+ return self._model_info_manage
+
+ def get_model_provide_info(self):
+ icon_path = _get_tencent_icon_path()
+ icon_data = get_file_content(icon_path)
+ return ModelProvideInfo(
+ provider='model_tencent_provider',
+            name='Tencent Hunyuan',
+ icon=icon_data
+ )
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
new file mode 100644
index 000000000..8cb7f459e
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
new file mode 100644
index 000000000..d49d22e22
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: embedding.py
+ @date:2024/7/12 16:45
+ @desc:
+"""
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class VolcanicEngineEmbeddingCredential(BaseForm, BaseModelCredential):
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=True):
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+
+        for key in ['access_key_id', 'secret_access_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} is a required field')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.embed_query('Hello')
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'Validation failed, please check whether the parameters are correct: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
+
+    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
+    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
new file mode 100644
index 000000000..5647e2a08
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
@@ -0,0 +1,49 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: llm.py
+ @date:2024/7/11 17:57
+ @desc:
+"""
+from typing import Dict
+
+from langchain_core.messages import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+ raise_exception=False):
+ model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} model type is not supported')
+
+ for key in ['access_key_id', 'secret_access_key']:
+ if key not in model_credential:
+ if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} is a required field')
+ else:
+ return False
+ try:
+ model = provider.get_model(model_type, model_name, model_credential)
+            model.invoke([HumanMessage(content='Hello')])
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'Validation failed, please check whether the parameters are correct: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
+
+ access_key_id = forms.PasswordInputField('Access Key ID', required=True)
+ secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg b/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
new file mode 100644
index 000000000..05a1279ef
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
@@ -0,0 +1,5 @@
+
+
+
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
new file mode 100644
index 000000000..b7307a0e5
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
@@ -0,0 +1,18 @@
+from typing import Dict
+
+from langchain_community.embeddings import VolcanoEmbeddings
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
+ @staticmethod
+ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        # VolcanoEmbeddings authenticates with a Volcano access-key pair rather than
+        # OpenAI-style api_key/api_base parameters.
+        return VolcanicEngineEmbeddingModel(
+            volcano_ak=model_credential.get('access_key_id'),
+            volcano_sk=model_credential.get('secret_access_key'),
+            model=model_name,
+        )
+ )
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
new file mode 100644
index 000000000..3ab309863
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -0,0 +1,28 @@
+from typing import List, Dict
+
+from langchain_community.chat_models import VolcEngineMaasChat
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+# VolcEngineMaasChat matches the volc_engine_maas_ak / volc_engine_maas_sk credentials
+# passed in new_instance below (ChatOpenAI does not accept them).
+class VolcanicEngineChatModel(MaxKBBaseModel, VolcEngineMaasChat):
+ @staticmethod
+ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+ volcanic_engine_chat = VolcanicEngineChatModel(
+ model=model_name,
+ volc_engine_maas_ak=model_credential.get("access_key_id"),
+ volc_engine_maas_sk=model_credential.get("secret_access_key"),
+ )
+ return volcanic_engine_chat
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+ def get_num_tokens(self, text: str) -> int:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
new file mode 100644
index 000000000..af7bc76b3
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File    :volcanic_engine_model_provider.py
+@Author :Brian Yang
+@Date :5/13/24 7:47 AM
+"""
+import os
+
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
+ ModelInfoManage
+from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
+from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
+from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
+from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
+
+from smartdoc.conf import PROJECT_DIR
+
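+# The Volcano Ark endpoint is OpenAI-compatible, so this provider reuses the OpenAI
+# credential forms and model classes with Ark-specific inference endpoint IDs.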
+volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
+
+model_info_list = [
+ ModelInfo('ep-xxxxxxxxxx-yyyy',
+              'Go to the model inference page in Volcano Ark and create an inference endpoint; enter the endpoint ID in the form ep-xxxxxxxxxx-yyyy here to call it',
+ ModelTypeConst.LLM,
+ volcanic_engine_llm_model_credential, OpenAIChatModel
+ )
+]
+
+open_ai_embedding_credential = OpenAIEmbeddingCredential()
+model_info_embedding_list = [
+ ModelInfo('ep-xxxxxxxxxx-yyyy',
+              'Go to the model inference page in Volcano Ark and create an inference endpoint; enter the endpoint ID in the form ep-xxxxxxxxxx-yyyy here to call it',
+ ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
+ OpenAIEmbeddingModel)]
+
+model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
+    model_info_list[0]).append_model_info_list(model_info_embedding_list).build()
+
+
+class VolcanicEngineModelProvider(IModelProvider):
+
+ def get_model_info_manage(self):
+ return model_info_manage
+
+ def get_model_provide_info(self):
+        return ModelProvideInfo(provider='model_volcanic_engine_provider', name='Volcanic Engine', icon=get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'volcanic_engine_model_provider',
+ 'icon',
+ 'volcanic_engine_icon_svg')))
diff --git a/pyproject.toml b/pyproject.toml
index 7566a7014..88fac01ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,9 @@ gunicorn = "^22.0.0"
python-daemon = "3.0.1"
gevent = "^24.2.1"
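+# AWS SDK and the LangChain Bedrock integration for the new Amazon Bedrock provider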
+boto3 = "^1.34.151"
+langchain-aws = "^0.1.13"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"