From 72423a7c3e3b5a65ee66c5c5d4ce58b06d6fe55b Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Wed, 24 Jul 2024 14:14:05 +0800
Subject: [PATCH] feat: add model integrations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../constants/model_provider_constants.py     |   7 +
 .../aws_bedrock_model_provider/__init__.py    |   2 +
 .../aws_bedrock_model_provider.py             | 107 +++++++
 .../credential/embedding.py                   |  64 ++++
 .../credential/llm.py                         |  62 ++++
 .../icon/bedrock_icon_svg                     |   1 +
 .../model/embedding.py                        |  25 ++
 .../aws_bedrock_model_provider/model/llm.py   |  35 +++
 .../impl/tencent_model_provider/__init__.py   |   2 +
 .../credential/embedding.py                   |  64 ++++
 .../tencent_model_provider/credential/llm.py  |  47 +++
 .../icon/tencent_icon_svg                     |   5 +
 .../tencent_model_provider/model/embedding.py |  25 ++
 .../tencent_model_provider/model/hunyuan.py   | 273 ++++++++++++++++++
 .../impl/tencent_model_provider/model/llm.py  |  37 +++
 .../tencent_model_provider.py                 | 103 +++++++
 .../__init__.py                               |   2 +
 .../credential/embedding.py                   |  46 +++
 .../credential/llm.py                         |  50 ++++
 .../icon/volcanic_engine_icon_svg             |   5 +
 .../model/embedding.py                        |  15 +
 .../model/llm.py                              |  27 ++
 .../volcanic_engine_model_provider.py         |  51 ++++
 pyproject.toml                                |   2 +
 24 files changed, 1057 insertions(+)
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py
 create mode 100644 apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/__init__.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
 create mode 100644 apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
 create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
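Notes (below the cut, ignored by `git am`): a minimal usage sketch for the providers
this patch registers. The `ModelProvideConstants[...]` lookup and the
`provider.get_model(model_type, model_name, model_credential)` call mirror how the
credential classes in this patch exercise a model; the concrete credential values and
the 'LLM' type code are placeholder assumptions, not part of the patch.

    from setting.models_provider.constants.model_provider_constants import ModelProvideConstants

    # Resolve a provider registered below, then build a chat model from it.
    provider = ModelProvideConstants['model_tencent_provider'].value
    model = provider.get_model('LLM', 'hunyuan-lite', {
        'hunyuan_app_id': '<app-id>',        # placeholders; real values come from
        'hunyuan_secret_id': '<secret-id>',  # the TencentLLMModelCredential form
        'hunyuan_secret_key': '<secret-key>',
    })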
diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py
index c9e9659c3..6d4ff9cc0 100644
--- a/apps/setting/models_provider/constants/model_provider_constants.py
+++ b/apps/setting/models_provider/constants/model_provider_constants.py
@@ -8,6 +8,7 @@
 """
 from enum import Enum

+from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
 from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
 from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
 from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
@@ -15,6 +16,9 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
 from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
 from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
 from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
+from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
+from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
+    VolcanicEngineModelProvider
 from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
 from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
 from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
@@ -32,4 +36,7 @@ class ModelProvideConstants(Enum):
     model_xf_provider = XunFeiModelProvider()
     model_deepseek_provider = DeepSeekModelProvider()
     model_gemini_provider = GeminiModelProvider()
+    model_volcanic_engine_provider = VolcanicEngineModelProvider()
+    model_tencent_provider = TencentModelProvider()
+    model_aws_bedrock_provider = BedrockModelProvider()
     model_local_provider = LocalModelProvider()
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
new file mode 100644
index 000000000..8cb7f459e
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
new file mode 100644
index 000000000..a6187f995
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+
+import os
+
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import (
+    IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
+)
+from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
+from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
+from smartdoc.conf import PROJECT_DIR
+
+
+def _create_model_info(model_name, description, model_type, credential_class, model_class):
+    return ModelInfo(
+        name=model_name,
+        desc=description,
+        model_type=model_type,
+        model_credential=credential_class(),
+        model_class=model_class
+    )
+
+
+def _get_aws_bedrock_icon_path():
+    return os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'aws_bedrock_model_provider',
+                        'icon', 'bedrock_icon_svg')
+
+
+def _initialize_model_info():
+    model_info_list = [
+        _create_model_info(
+            'amazon.titan-text-premier-v1:0',
+            'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'amazon.titan-text-lite-v1',
+            'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'amazon.titan-text-express-v1',
+            'Amazon Titan Text Express 的上下文长度长达 8000 个令牌,因而非常适合各种高级常规语言任务,例如开放式文本生成和对话式聊天,以及检索增强生成(RAG)中的支持。在发布时,该模型针对英语进行了优化,但也支持其他语言。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'amazon.titan-embed-text-v2:0',
+            'Amazon Titan Text Embeddings V2 是一种轻量级、高效的模型,非常适合在不同维度上执行高精度检索任务。该模型支持灵活的嵌入大小(1024、512 和 256),并优先考虑在较小维度上保持准确性,从而可以在不影响准确性的情况下降低存储成本。Titan Text Embeddings V2 适用于各种任务,包括文档检索、推荐系统、搜索引擎和对话式系统。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'mistral.mistral-7b-instruct-v0:2',
+            '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'mistral.mistral-large-2402-v1:0',
+            '先进的 Mistral AI 大型语言模型,能够处理任何语言任务,包括复杂的多语言推理、文本理解、转换和代码生成。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'meta.llama3-70b-instruct-v1:0',
+            '非常适合内容创作、会话式人工智能、语言理解、研发和企业应用',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+        _create_model_info(
+            'meta.llama3-8b-instruct-v1:0',
+            '非常适合有限的计算能力和资源、边缘设备和更快的训练时间。',
+            ModelTypeConst.LLM,
+            BedrockLLMModelCredential,
+            BedrockModel),
+    ]
+
+    model_info_manage = ModelInfoManage.builder() \
+        .append_model_info_list(model_info_list) \
+        .append_default_model_info(model_info_list[0]) \
+        .build()
+
+    return model_info_manage
+
+
+class BedrockModelProvider(IModelProvider):
+    def __init__(self):
+        self._model_info_manage = _initialize_model_info()
+
+    def get_model_info_manage(self):
+        return self._model_info_manage
+
+    def get_model_provide_info(self):
+        icon_path = _get_aws_bedrock_icon_path()
+        icon_data = get_file_content(icon_path)
+        return ModelProvideInfo(
+            provider='model_aws_bedrock_provider',
+            name='Amazon Bedrock',
+            icon=icon_data
+        )
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
new file mode 100644
index 000000000..40e36caca
--- /dev/null
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py
@@ -0,0 +1,64 @@
+import json
+from typing import Dict
+
+from tencentcloud.common import credential
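+# NOTE: this credential class currently mirrors the Tencent Hunyuan embedding
+# credential (tencent_model_provider/credential/embedding.py): it validates the
+# SecretId/SecretKey pair against the Hunyuan GetEmbedding API, not against AWS.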
+from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class TencentEmbeddingCredential(BaseForm, BaseModelCredential): + @classmethod + def _validate_model_type(cls, model_type: str, provider) -> bool: + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return True + + @classmethod + def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential: + for key in ['SecretId', 'SecretKey']: + if key not in model_credential: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + return credential.Credential(model_credential['SecretId'], model_credential['SecretKey']) + + @classmethod + def _test_credentials(cls, client, model_name: str): + req = models.GetEmbeddingRequest() + params = { + "Model": model_name, + "Input": "测试" + } + req.from_json_string(json.dumps(params)) + try: + res = client.GetEmbedding(req) + print(res.to_json_string()) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True) -> bool: + try: + self._validate_model_type(model_type, provider) + cred = self._validate_credential(model_credential) + httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com") + clientProfile = ClientProfile(httpProfile=httpProfile) + client = hunyuan_client.HunyuanClient(cred, "", clientProfile) + self._test_credentials(client, model_name) + return True + except AppApiException as e: + if raise_exception: + raise e + return False + + def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: + encrypted_secret_key = super().encryption(model.get('SecretKey', '')) + return {**model, 'SecretKey': encrypted_secret_key} + + SecretId = forms.PasswordInputField('SecretId', required=True) + SecretKey = forms.PasswordInputField('SecretKey', required=True) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py new file mode 100644 index 000000000..848250b1e --- /dev/null +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -0,0 +1,62 @@ +import os +import re +from typing import Dict +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential +from langchain_core.messages import HumanMessage +from common import forms + + +class BedrockLLMModelCredential(BaseForm, BaseModelCredential): + + @staticmethod + def _update_aws_credentials(profile_name, access_key_id, secret_access_key): + credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") + os.makedirs(os.path.dirname(credentials_path), exist_ok=True) + + content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' + pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = 
.*)\n*(aws_secret_access_key = .*)\n*' + content = re.sub(pattern, '', content, flags=re.DOTALL) + + if not re.search(rf'\[{profile_name}\]', content): + content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" + + with open(credentials_path, 'w') as file: + file.write(content) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return False + + required_keys = ['region_name', 'access_key_id', 'secret_access_key'] + if not all(key in model_credential for key in required_keys): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}') + return False + + try: + self._update_aws_credentials('aws-profile', model_credential['access_key_id'], + model_credential['secret_access_key']) + model_credential['credentials_profile_name'] = 'aws-profile' + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except AppApiException: + raise + except Exception as e: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + return False + + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))} + + region_name = forms.TextInputField('Region Name', required=True) + access_key_id = forms.TextInputField('Access Key ID', required=True) + secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg b/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg new file mode 100644 index 000000000..5f176a7d2 --- /dev/null +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/icon/bedrock_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py new file mode 100644 index 000000000..a5bd0336a --- /dev/null +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py @@ -0,0 +1,25 @@ +from setting.models_provider.base_model_provider import MaxKBBaseModel +from typing import Dict +import requests + + +class TencentEmbeddingModel(MaxKBBaseModel): + def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str): + self.secret_id = secret_id + self.secret_key = secret_key + self.api_base = api_base + self.model_name = model_name + + @staticmethod + def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs): + return TencentEmbeddingModel( + secret_id=model_credential.get('SecretId'), + secret_key=model_credential.get('SecretKey'), + api_base=model_credential.get('api_base'), + model_name=model_name, + ) + + + def _generate_auth_token(self): + # Example method to generate an authentication token for the model API + return f"{self.secret_id}:{self.secret_key}" diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py 
b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py new file mode 100644 index 000000000..43c3c00a0 --- /dev/null +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -0,0 +1,35 @@ +from typing import List, Dict, Any +from langchain_community.chat_models import BedrockChat +from langchain_core.messages import BaseMessage, get_buffer_string +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class BedrockModel(MaxKBBaseModel, BedrockChat): + + def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, + streaming: bool = False, **kwargs): + super().__init__(model_id=model_id, region_name=region_name, + credentials_profile_name=credentials_profile_name, streaming=streaming, **kwargs) + + @classmethod + def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], + **model_kwargs) -> 'BedrockModel': + return cls( + model_id=model_name, + region_name=model_credential['region_name'], + credentials_profile_name=model_credential['credentials_profile_name'], + streaming=model_kwargs.pop('streaming', False), + **model_kwargs + ) + + def _get_num_tokens(self, content: str) -> int: + """Helper method to count tokens in a string.""" + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(content)) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages) + + def get_num_tokens(self, text: str) -> int: + return self._get_num_tokens(text) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/__init__.py b/apps/setting/models_provider/impl/tencent_model_provider/__init__.py new file mode 100644 index 000000000..8cb7f459e --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py new file mode 100644 index 000000000..40e36caca --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py @@ -0,0 +1,64 @@ +import json +from typing import Dict + +from tencentcloud.common import credential +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class TencentEmbeddingCredential(BaseForm, BaseModelCredential): + @classmethod + def _validate_model_type(cls, model_type: str, provider) -> bool: + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return True + + @classmethod + def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential: + for key in ['SecretId', 'SecretKey']: + if key not in model_credential: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + return 
credential.Credential(model_credential['SecretId'], model_credential['SecretKey']) + + @classmethod + def _test_credentials(cls, client, model_name: str): + req = models.GetEmbeddingRequest() + params = { + "Model": model_name, + "Input": "测试" + } + req.from_json_string(json.dumps(params)) + try: + res = client.GetEmbedding(req) + print(res.to_json_string()) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True) -> bool: + try: + self._validate_model_type(model_type, provider) + cred = self._validate_credential(model_credential) + httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com") + clientProfile = ClientProfile(httpProfile=httpProfile) + client = hunyuan_client.HunyuanClient(cred, "", clientProfile) + self._test_credentials(client, model_name) + return True + except AppApiException as e: + if raise_exception: + raise e + return False + + def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: + encrypted_secret_key = super().encryption(model.get('SecretKey', '')) + return {**model, 'SecretKey': encrypted_secret_key} + + SecretId = forms.PasswordInputField('SecretId', required=True) + SecretKey = forms.PasswordInputField('SecretKey', required=True) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py new file mode 100644 index 000000000..ad9ab3a82 --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py @@ -0,0 +1,47 @@ +# coding=utf-8 +from langchain_core.messages import HumanMessage +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class TencentLLMModelCredential(BaseForm, BaseModelCredential): + REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key'] + + @classmethod + def _validate_model_type(cls, model_type, provider, raise_exception=False): + if not any(mt['value'] == model_type for mt in provider.get_model_type_list()): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return False + return True + + @classmethod + def _validate_credential_fields(cls, model_credential, raise_exception=False): + missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential] + if missing_keys: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段') + return False + return True + + def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False): + if not (self._validate_model_type(model_type, provider, raise_exception) and + self._validate_credential_fields(model_credential, raise_exception)): + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + return False + return True + + def encryption_dict(self, model): + return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))} + + hunyuan_app_id = forms.TextInputField('APP ID', required=True) + hunyuan_secret_id = 
forms.PasswordInputField('SecretId', required=True) + hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True) \ No newline at end of file diff --git a/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg b/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg new file mode 100644 index 000000000..6cec08b74 --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/icon/tencent_icon_svg @@ -0,0 +1,5 @@ + + + + diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py new file mode 100644 index 000000000..a5bd0336a --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/embedding.py @@ -0,0 +1,25 @@ +from setting.models_provider.base_model_provider import MaxKBBaseModel +from typing import Dict +import requests + + +class TencentEmbeddingModel(MaxKBBaseModel): + def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str): + self.secret_id = secret_id + self.secret_key = secret_key + self.api_base = api_base + self.model_name = model_name + + @staticmethod + def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs): + return TencentEmbeddingModel( + secret_id=model_credential.get('SecretId'), + secret_key=model_credential.get('SecretKey'), + api_base=model_credential.get('api_base'), + model_name=model_name, + ) + + + def _generate_auth_token(self): + # Example method to generate an authentication token for the model API + return f"{self.secret_id}:{self.secret_key}" diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py new file mode 100644 index 000000000..9af6f983c --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py @@ -0,0 +1,273 @@ +import json +import logging +from typing import Any, Dict, Iterator, List, Mapping, Optional, Type + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import ( + BaseChatModel, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + convert_to_secret_str, + get_from_dict_or_env, + get_pydantic_field_names, + pre_init, +) + +logger = logging.getLogger(__name__) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"Role": message.role, "Content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"Role": "user", "Content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"Role": "assistant", "Content": message.content} + else: + raise TypeError(f"Got unknown type {message}") + + return message_dict + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["Role"] + if role == "user": + return HumanMessage(content=_dict["Content"]) + elif role == "assistant": + return AIMessage(content=_dict.get("Content", "") or "") + else: + return 
ChatMessage(content=_dict["Content"], role=role)
+
+
+def _convert_delta_to_message_chunk(
+        _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = _dict.get("Role")
+    content = _dict.get("Content") or ""
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content)
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
+    else:
+        return default_class(content=content)  # type: ignore[call-arg]
+
+
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+    generations = []
+    for choice in response["Choices"]:
+        message = _convert_dict_to_message(choice["Message"])
+        generations.append(ChatGeneration(message=message))
+
+    token_usage = response["Usage"]
+    llm_output = {"token_usage": token_usage}
+    return ChatResult(generations=generations, llm_output=llm_output)
+
+
+class ChatHunyuan(BaseChatModel):
+    """Tencent Hunyuan chat models API by Tencent.
+
+    For more information, see https://cloud.tencent.com/document/product/1729
+    """
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {
+            "hunyuan_app_id": "HUNYUAN_APP_ID",
+            "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
+            "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
+        }
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    hunyuan_app_id: Optional[int] = None
+    """Hunyuan App ID"""
+    hunyuan_secret_id: Optional[str] = None
+    """Hunyuan Secret ID"""
+    hunyuan_secret_key: Optional[SecretStr] = None
+    """Hunyuan Secret Key"""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    request_timeout: int = 60
+    """Timeout for requests to Hunyuan API. Default is 60 seconds."""
+    temperature: float = 1.0
+    """What sampling temperature to use."""
+    top_p: float = 1.0
+    """What probability mass to use."""
+    model: str = "hunyuan-lite"
+    """Which model to use.
+    Optional models:
+    - hunyuan-lite
+    - hunyuan-standard
+    - hunyuan-standard-256K
+    - hunyuan-pro
+    - hunyuan-code
+    - hunyuan-role
+    - hunyuan-functioncall
+    - hunyuan-vision
+    """
+    stream_moderation: bool = False
+    """Whether to review the results or not when streaming is true."""
+    enable_enhancement: bool = True
+    """Whether to enhance the results or not."""
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for API call not explicitly specified."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
+                extra[field_name] = values.pop(field_name)
+
+        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
+        values["model_kwargs"] = extra
+        return values
+
+    @pre_init
+    def validate_environment(cls, values: Dict) -> Dict:
+        values["hunyuan_app_id"] = get_from_dict_or_env(
+            values,
+            "hunyuan_app_id",
+            "HUNYUAN_APP_ID",
+        )
+        values["hunyuan_secret_id"] = get_from_dict_or_env(
+            values,
+            "hunyuan_secret_id",
+            "HUNYUAN_SECRET_ID",
+        )
+        values["hunyuan_secret_key"] = convert_to_secret_str(
+            get_from_dict_or_env(
+                values,
+                "hunyuan_secret_key",
+                "HUNYUAN_SECRET_KEY",
+            )
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling Hunyuan API."""
+        normal_params = {
+            "Temperature": self.temperature,
+            "TopP": self.top_p,
+            "Model": self.model,
+            "Stream": self.streaming,
+            "StreamModeration": self.stream_moderation,
+            "EnableEnhancement": self.enable_enhancement,
+        }
+        return {**normal_params, **self.model_kwargs}
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        res = self._chat(messages, **kwargs)
+        return _create_chat_result(json.loads(res.to_json_string()))
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        res = self._chat(messages, **kwargs)
+
+        default_chunk_class = AIMessageChunk
+        for chunk in res:
+            chunk = chunk.get("data", "")
+            if len(chunk) == 0:
+                continue
+            response = json.loads(chunk)
+            if "error" in response:
+                raise ValueError(f"Error from Hunyuan api response: {response}")
+
+            for choice in response["Choices"]:
+                chunk = _convert_delta_to_message_chunk(
+                    choice["Delta"], default_chunk_class
+                )
+                default_chunk_class = chunk.__class__
+                cg_chunk = ChatGenerationChunk(message=chunk)
+                if run_manager:
+                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                yield cg_chunk
+
+    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
+        if self.hunyuan_secret_key is None:
+            raise ValueError("Hunyuan secret key is not set.")
+
+        try:
+            from tencentcloud.common import credential
+            from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+        except ImportError:
+            raise ImportError(
+                "Could not import tencentcloud python package. "
+                "Please install it with `pip install tencentcloud-sdk-python`."
+ ) + + parameters = {**self._default_params, **kwargs} + cred = credential.Credential( + self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value()) + ) + client = hunyuan_client.HunyuanClient(cred, "") + req = models.ChatCompletionsRequest() + params = { + "Messages": [_convert_message_to_dict(m) for m in messages], + **parameters, + } + req.from_json_string(json.dumps(params)) + resp = client.ChatCompletions(req) + return resp + + @property + def _llm_type(self) -> str: + return "hunyuan-chat" \ No newline at end of file diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py new file mode 100644 index 000000000..81116eb61 --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py @@ -0,0 +1,37 @@ +# coding=utf-8 + +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan + + +class TencentModel(MaxKBBaseModel, ChatHunyuan): + + def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs): + hunyuan_app_id = credentials.get('hunyuan_app_id') + hunyuan_secret_id = credentials.get('hunyuan_secret_id') + hunyuan_secret_key = credentials.get('hunyuan_secret_key') + + if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]): + raise ValueError( + "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.") + + super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id, + hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, **kwargs) + + @staticmethod + def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object], + **model_kwargs) -> 'TencentModel': + streaming = model_kwargs.pop('streaming', False) + return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages) + + def get_num_tokens(self, text: str) -> int: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py new file mode 100644 index 000000000..572ff940d --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import os +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ( + IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage +) +from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential +from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential +from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel +from setting.models_provider.impl.tencent_model_provider.model.llm 
import TencentModel
+from smartdoc.conf import PROJECT_DIR
+
+
+def _create_model_info(model_name, description, model_type, credential_class, model_class):
+    return ModelInfo(
+        name=model_name,
+        desc=description,
+        model_type=model_type,
+        model_credential=credential_class(),
+        model_class=model_class
+    )
+
+
+def _get_tencent_icon_path():
+    return os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'tencent_model_provider',
+                        'icon', 'tencent_icon_svg')
+
+
+def _initialize_model_info():
+    model_info_list = [
+        _create_model_info(
+            'hunyuan-pro',
+            '当前混元模型中效果最优版本,万亿级参数规模 MOE-32K 长文模型。在各种 benchmark 上达到绝对领先的水平,复杂指令和推理,具备复杂数学能力,支持 functioncall,在多语言翻译、金融法律医疗等领域应用重点优化',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+        _create_model_info(
+            'hunyuan-standard',
+            '采用更优的路由策略,同时缓解了负载均衡和专家趋同的问题。长文方面,大海捞针指标达到99.9%',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+        _create_model_info(
+            'hunyuan-lite',
+            '升级为 MOE 结构,上下文窗口为 256k ,在 NLP,代码,数学,行业等多项评测集上领先众多开源模型',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+        _create_model_info(
+            'hunyuan-role',
+            '混元最新版角色扮演模型,混元官方精调训练推出的角色扮演模型,基于混元模型结合角色扮演场景数据集进行增训,在角色扮演场景具有更好的基础效果',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+        _create_model_info(
+            'hunyuan-functioncall',
+            '混元最新 MOE 架构 FunctionCall 模型,经过高质量的 FunctionCall 数据训练,上下文窗口达 32K,在多个维度的评测指标上处于领先。',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+        _create_model_info(
+            'hunyuan-code',
+            '混元最新代码生成模型,经过 200B 高质量代码数据增训基座模型,迭代半年高质量 SFT 数据训练,上下文长窗口长度增大到 8K,五大语言代码生成自动评测指标上位居前列;五大语言10项考量各方面综合代码任务人工高质量评测上,性能处于第一梯队',
+            ModelTypeConst.LLM,
+            TencentLLMModelCredential,
+            TencentModel),
+    ]
+
+    tencent_embedding_model_info = _create_model_info(
+        'hunyuan-embedding',
+        '',
+        ModelTypeConst.EMBEDDING,
+        TencentEmbeddingCredential,
+        TencentEmbeddingModel
+    )
+
+    model_info_embedding_list = [tencent_embedding_model_info]
+
+    model_info_manage = ModelInfoManage.builder() \
+        .append_model_info_list(model_info_list) \
+        .append_default_model_info(model_info_list[0]) \
+        .append_model_info_list(model_info_embedding_list) \
+        .append_default_model_info(tencent_embedding_model_info) \
+        .build()
+
+    return model_info_manage
+
+
+class TencentModelProvider(IModelProvider):
+    def __init__(self):
+        self._model_info_manage = _initialize_model_info()
+
+    def get_model_info_manage(self):
+        return self._model_info_manage
+
+    def get_model_provide_info(self):
+        icon_path = _get_tencent_icon_path()
+        icon_data = get_file_content(icon_path)
+        return ModelProvideInfo(
+            provider='model_tencent_provider',
+            name='腾讯混元',
+            icon=icon_data
+        )
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
new file mode 100644
index 000000000..8cb7f459e
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
new file mode 100644
index 000000000..d49d22e22
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: embedding.py
+    @date:2024/7/12 16:45
+    @desc:
+"""
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class VolcanicEngineEmbeddingCredential(BaseForm, BaseModelCredential):
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=True):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.embed_query('你好')
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
new file mode 100644
index 000000000..5647e2a08
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: llm.py
+    @date:2024/7/11 17:57
+    @desc:
+"""
+from typing import Dict
+
+from langchain_core.messages import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['access_key_id', 'secret_access_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            res = model.invoke([HumanMessage(content='你好')])
+            print(res)
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'access_key_id': super().encryption(model.get('access_key_id', ''))}
+
+    access_key_id = forms.PasswordInputField('Access Key ID', required=True)
+    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg b/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
new file mode 100644
index 000000000..05a1279ef
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/icon/volcanic_engine_icon_svg
@@ -0,0 +1,5 @@
+
+
+
+
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
new file mode 100644
index 000000000..b7307a0e5
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py
@@ -0,0 +1,15 @@
+from typing import Dict
+
+from langchain_community.embeddings import VolcanoEmbeddings
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        # VolcanoEmbeddings authenticates with a volcano_ak/volcano_sk key pair;
+        # OpenAI-style api_key/openai_api_base kwargs are not fields of this class.
+        return VolcanicEngineEmbeddingModel(
+            volcano_ak=model_credential.get('access_key_id'),
+            volcano_sk=model_credential.get('secret_access_key'),
+            model=model_name,
+        )
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
new file mode 100644
index 000000000..3ab309863
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -0,0 +1,27 @@
+from typing import List, Dict
+
+from langchain_community.chat_models import VolcEngineMaasChat
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+class VolcanicEngineChatModel(MaxKBBaseModel, VolcEngineMaasChat):
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        # VolcEngineMaasChat expects the endpoint/model id plus the
+        # volc_engine_maas_ak/volc_engine_maas_sk credential pair.
+        volcanic_engine_chat = VolcanicEngineChatModel(
+            model=model_name,
+            volc_engine_maas_ak=model_credential.get("access_key_id"),
+            volc_engine_maas_sk=model_credential.get("secret_access_key"),
+        )
+        return volcanic_engine_chat
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenizer = TokenizerManage.get_tokenizer()
+        return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
new file mode 100644
index 000000000..af7bc76b3
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File    :volcanic_engine_model_provider.py
+@Author  :Brian Yang
+@Date    :5/13/24 7:47 AM
+"""
+import os
+
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
+    ModelInfoManage
+from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
+from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
+from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
+from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
+from smartdoc.conf import PROJECT_DIR
+
+volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
+
+model_info_list = [
+    ModelInfo('ep-xxxxxxxxxx-yyyy',
+              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
+              ModelTypeConst.LLM,
+              volcanic_engine_llm_model_credential, OpenAIChatModel
+              )
+]
+
+open_ai_embedding_credential = OpenAIEmbeddingCredential()
+model_info_embedding_list = [
+    ModelInfo('ep-xxxxxxxxxx-yyyy',
+              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
+              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
+              OpenAIEmbeddingModel)]
+
+model_info_manage = ModelInfoManage.builder() \
+    .append_model_info_list(model_info_list) \
+    .append_default_model_info(model_info_list[0]) \
+    .append_model_info_list(model_info_embedding_list) \
+    .append_default_model_info(model_info_embedding_list[0]) \
+    .build()
+
+
+class VolcanicEngineModelProvider(IModelProvider):
+
+    def get_model_info_manage(self):
+        return model_info_manage
+
+    def get_model_provide_info(self):
+        return ModelProvideInfo(provider='model_volcanic_engine_provider', name='火山引擎', icon=get_file_content(
+            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'volcanic_engine_model_provider',
+                         'icon',
+                         'volcanic_engine_icon_svg')))
diff --git a/pyproject.toml b/pyproject.toml
index 7566a7014..88fac01ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,8 @@ gunicorn = "^22.0.0"
 python-daemon = "3.0.1"
 gevent = "^24.2.1"
+boto3 = "^1.34.151"
+langchain-aws = "^0.1.13"

 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
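Reviewer sketch (appended after the diff; not applied by `git am`): exercising the
Bedrock chat model added above, the same way BedrockLLMModelCredential.is_valid does
once it has written the 'aws-profile' section to ~/.aws/credentials. The region and
prompt values are placeholder assumptions.

    from langchain_core.messages import HumanMessage
    from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel

    model = BedrockModel.new_instance(
        'LLM',
        'amazon.titan-text-express-v1',
        {'region_name': 'us-east-1',                   # assumed region
         'credentials_profile_name': 'aws-profile'},   # profile written by is_valid
        streaming=False,
    )
    print(model.invoke([HumanMessage(content='你好')]))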