feat: add model setting

This commit is contained in:
wxg0103 2025-04-17 18:01:33 +08:00
parent 934871f5c6
commit 946de675ff
277 changed files with 15472 additions and 3 deletions

86
apps/common/cache/file_cache.py vendored Normal file
View File

@ -0,0 +1,86 @@
# coding=utf-8
"""
@project: qabot
@Author
@file file_cache.py
@date2023/9/11 15:58
@desc: 文件缓存
"""
import datetime
import math
import os
import time
from diskcache import Cache
from django.core.cache.backends.base import BaseCache
class FileCache(BaseCache):
def __init__(self, dir, params):
super().__init__(params)
self._dir = os.path.abspath(dir)
self._createdir()
self.cache = Cache(self._dir)
def _createdir(self):
old_umask = os.umask(0o077)
try:
os.makedirs(self._dir, 0o700, exist_ok=True)
finally:
os.umask(old_umask)
def add(self, key, value, timeout=None, version=None):
expire = timeout if isinstance(timeout, int) or isinstance(timeout,
float) or timeout is None else timeout.total_seconds()
return self.cache.add(self.get_key(key, version), value=value, expire=expire)
def set(self, key, value, timeout=None, version=None):
expire = timeout if isinstance(timeout, int) or isinstance(timeout,
float) or timeout is None else timeout.total_seconds()
return self.cache.set(self.get_key(key, version), value=value, expire=expire)
def get(self, key, default=None, version=None):
return self.cache.get(self.get_key(key, version), default=default)
@staticmethod
def get_key(key, version):
if version is None:
return f"default:{key}"
return f"{version}:{key}"
def delete(self, key, version=None):
return self.cache.delete(self.get_key(key, version))
def touch(self, key, timeout=None, version=None):
expire = timeout if isinstance(timeout, int) or isinstance(timeout,
float) else timeout.total_seconds()
return self.cache.touch(self.get_key(key, version), expire=expire)
def ttl(self, key, version=None):
"""
获取key的剩余时间
:param key: key
:return: 剩余时间
@param version:
"""
value, expire_time = self.cache.get(self.get_key(key, version), expire_time=True)
if value is None:
return None
return datetime.timedelta(seconds=math.ceil(expire_time - time.time()))
def clear_by_application_id(self, application_id):
delete_keys = []
for key in self.cache.iterkeys():
value = self.cache.get(key)
if (hasattr(value,
'application') and value.application is not None and value.application.id is not None and
str(
value.application.id) == application_id):
delete_keys.append(key)
for key in delete_keys:
self.cache.delete(key)
def clear_timeout_data(self):
for key in self.cache.iterkeys():
self.get(key)

47
apps/common/cache/mem_cache.py vendored Normal file
View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file mem_cache.py
@date2024/3/6 11:20
@desc:
"""
from django.core.cache.backends.base import DEFAULT_TIMEOUT
from django.core.cache.backends.locmem import LocMemCache
class MemCache(LocMemCache):
def __init__(self, name, params):
super().__init__(name, params)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
pickled = value
with self._lock:
self._set(key, pickled, timeout)
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
return default
pickled = self._cache[key]
self._cache.move_to_end(key, last=False)
return pickled
def clear_by_application_id(self, application_id):
delete_keys = []
for key in self._cache.keys():
value = self._cache.get(key)
if (hasattr(value,
'application') and value.application is not None and value.application.id is not None and
str(
value.application.id) == application_id):
delete_keys.append(key)
for key in delete_keys:
self._delete(key)
def clear_timeout_data(self):
for key in self._cache.keys():
self.get(key)

View File

@ -0,0 +1,66 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file embedding_config.py
@date2023/10/23 16:03
@desc:
"""
import threading
import time
from common.cache.mem_cache import MemCache
lock = threading.Lock()
class ModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def get_model(_id, get_model):
# 获取锁
lock.acquire()
try:
model_instance = ModelManage.cache.get(_id)
if model_instance is None or not model_instance.is_cache_model():
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
return model_instance
# 续期
ModelManage.cache.touch(_id, timeout=60 * 30)
ModelManage.clear_timeout_cache()
return model_instance
finally:
# 释放锁
lock.release()
@staticmethod
def clear_timeout_cache():
if time.time() - ModelManage.up_clear_time > 60:
ModelManage.cache.clear_timeout_data()
@staticmethod
def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
class VectorStore:
from embedding.vector.pg_vector import PGVector
from embedding.vector.base_vector import BaseVectorStore
instance_map = {
'pg_vector': PGVector,
}
instance = None
@staticmethod
def get_embedding_vector() -> BaseVectorStore:
from embedding.vector.pg_vector import PGVector
if VectorStore.instance is None:
from maxkb.const import CONFIG
vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
PGVector)
VectorStore.instance = vector_store_class()
return VectorStore.instance

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file tokenizer_manage_config.py
@date2024/4/28 10:17
@desc:
"""
class TokenizerManage:
tokenizer = None
@staticmethod
def get_tokenizer():
from transformers import BertTokenizer
if TokenizerManage.tokenizer is None:
TokenizerManage.tokenizer = BertTokenizer.from_pretrained(
'bert-base-cased',
cache_dir="/opt/maxkb/model/tokenizer",
local_files_only=True,
resume_download=False,
force_download=False)
return TokenizerManage.tokenizer

View File

@ -111,6 +111,13 @@ class PermissionConstants(Enum):
USER_DELETE = Permission(group=Group.USER, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN])
TOOL_CREATE = Permission(group=Group.USER, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
MODEL_CREATE = Permission(group=Group.USER, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
MODEL_READ = Permission(group=Group.USER, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
MODEL_EDIT = Permission(group=Group.USER, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_DELETE = Permission(group=Group.USER, operate=Operate.DELETE,
role_list=[RoleConstants.ADMIN, RoleConstants.USER])
def get_workspace_application_permission(self):
return lambda r, kwargs: Permission(group=self.value.group, operate=self.value.operate,

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2023/10/31 17:56
@desc:
"""
from .array_object_card import *
from .base_field import *
from .base_form import *
from .multi_select import *
from .object_card import *
from .password_input import *
from .radio_field import *
from .single_select_field import *
from .tab_card import *
from .table_radio import *
from .text_input_field import *
from .radio_button_field import *
from .table_checkbox import *
from .radio_card_field import *
from .label import *
from .slider_field import *

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file array_object_card.py
@date2023/10/31 18:03
@desc:
"""
from typing import Dict
from common.forms.base_field import BaseExecField, TriggerType
class ArrayCard(BaseExecField):
"""
收集List[Object]
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("ArrayObjectCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)

View File

@ -0,0 +1,157 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_field.py
@date2023/10/31 18:07
@desc:
"""
from enum import Enum
from typing import List, Dict
from common.exception.app_exception import AppApiException
from common.forms.label.base_label import BaseLabel
from django.utils.translation import gettext_lazy as _
class TriggerType(Enum):
# 执行函数获取 OptionList数据
OPTION_LIST = 'OPTION_LIST'
# 执行函数获取子表单
CHILD_FORMS = 'CHILD_FORMS'
class BaseField:
def __init__(self,
input_type: str,
label: str or BaseLabel,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: 提示
:param default_value: 默认值
:param relation_show_field_dict: {field:field_value_list} 表示在 field有值 ,并且值在field_value_list中才显示
:param relation_trigger_field_dict: {field:field_value_list} 表示在 field有值 ,并且值在field_value_list中才 执行函数获取 数据
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
if props_info is None:
props_info = {}
if attrs is None:
attrs = {}
self.label = label
self.attrs = attrs
self.props_info = props_info
self.default_value = default_value
self.input_type = input_type
self.relation_show_field_dict = {} if relation_show_field_dict is None else relation_show_field_dict
self.relation_trigger_field_dict = [] if relation_trigger_field_dict is None else relation_trigger_field_dict
self.required = required
self.trigger_type = trigger_type
def is_valid(self, value):
field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
if self.required and value is None:
raise AppApiException(500,
_('The field {field_label} is required').format(field_label=field_label))
def to_dict(self, **kwargs):
return {
'input_type': self.input_type,
'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label,
'required': self.required,
'default_value': self.default_value,
'relation_show_field_dict': self.relation_show_field_dict,
'relation_trigger_field_dict': self.relation_trigger_field_dict,
'trigger_type': self.trigger_type.value,
'attrs': self.attrs,
'props_info': self.props_info,
**kwargs
}
class BaseDefaultOptionField(BaseField):
def __init__(self, input_type: str,
label: str,
text_field: str,
value_field: str,
option_list: List[dict],
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict[str, object] = None,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: label
:param text_field: 文本字段
:param value_field: 值字段
:param option_list: 可选列表
:param required: 是否必填
:param default_value: 默认值
:param relation_show_field_dict: {field:field_value_list} 表示在 field有值 ,并且值在field_value_list中才显示
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
super().__init__(input_type, label, required, default_value, relation_show_field_dict,
{}, TriggerType.OPTION_LIST, attrs, props_info)
self.text_field = text_field
self.value_field = value_field
self.option_list = option_list
def to_dict(self, **kwargs):
return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
'option_list': self.option_list}
class BaseExecField(BaseField):
def __init__(self,
input_type: str,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
"""
:param input_type: 字段
:param label: 提示
:param text_field: 文本字段
:param value_field: 值字段
:param provider: 指定供应商
:param method: 执行供应商函数 method
:param required: 是否必填
:param default_value: 默认值
:param relation_show_field_dict: {field:field_value_list} 表示在 field有值 ,并且值在field_value_list中才显示
:param relation_trigger_field_dict: {field:field_value_list} 表示在 field有值 ,并且值在field_value_list中才 执行函数获取 数据
:param trigger_type: 执行器类型 OPTION_LIST请求Option_list数据 CHILD_FORMS请求子表单
:param attrs: 前端attr数据
:param props_info: 其他额外信息
"""
super().__init__(input_type, label, required, default_value, relation_show_field_dict,
relation_trigger_field_dict,
trigger_type, attrs, props_info)
self.text_field = text_field
self.value_field = value_field
self.provider = provider
self.method = method
def to_dict(self, **kwargs):
return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
'provider': self.provider, 'method': self.method}

View File

@ -0,0 +1,30 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_form.py
@date2023/11/1 16:04
@desc:
"""
from typing import Dict
from common.forms import BaseField
class BaseForm:
def to_form_list(self, **kwargs):
return [{**self.__getattribute__(key).to_dict(**kwargs), 'field': key} for key in
list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))]
def valid_form(self, form_data):
field_keys = list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))
for field_key in field_keys:
self.__getattribute__(field_key).is_valid(form_data.get(field_key))
def get_default_form_data(self):
return {key: self.__getattribute__(key).default_value for key in
[attr for attr in vars(self.__class__) if not attr.startswith("__")] if
isinstance(self.__getattribute__(key), BaseField) and self.__getattribute__(
key).default_value is not None}

View File

@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/8/22 17:19
@desc:
"""
from .base_label import *
from .tooltip_label import *

View File

@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_label.py
@date2024/8/22 17:11
@desc:
"""
class BaseLabel:
def __init__(self,
input_type: str,
label: str,
attrs=None,
props_info=None):
self.input_type = input_type
self.label = label
self.attrs = attrs
self.props_info = props_info
def to_dict(self, **kwargs):
return {
'input_type': self.input_type,
'label': self.label,
'attrs': {} if self.attrs is None else self.attrs,
'props_info': {} if self.props_info is None else self.props_info,
}

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tooltip_label.py
@date2024/8/22 17:19
@desc:
"""
from common.forms.label.base_label import BaseLabel
class TooltipLabel(BaseLabel):
def __init__(self, label, tooltip):
super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={})

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file multi_select.py
@date2023/10/31 18:00
@desc:
"""
from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class MultiSelect(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str = None,
method: str = None,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file object_card.py
@date2023/10/31 18:02
@desc:
"""
from typing import Dict
from common.forms.base_field import BaseExecField, TriggerType
class ObjectCard(BaseExecField):
"""
收集对象子表卡片
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)

View File

@ -0,0 +1,26 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file password_input.py
@date2023/11/1 14:48
@desc:
"""
from typing import Dict
from common.forms import BaseField, TriggerType
class PasswordInputField(BaseField):
"""
文本输入框
"""
def __init__(self, label: str,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
attrs=None, props_info=None):
super().__init__('PasswordInput', label, required, default_value, relation_show_field_dict,
{},
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file radio_field.py
@date2023/10/31 17:59
@desc:
"""
from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("RadioButton", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file radio_field.py
@date2023/10/31 17:59
@desc:
"""
from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("RadioCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file radio_field.py
@date2023/10/31 17:59
@desc:
"""
from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,39 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file single_select_field.py
@date2023/10/31 18:00
@desc:
"""
from typing import List, Dict
from common.forms import BaseLabel
from common.forms.base_field import TriggerType, BaseExecField
class SingleSelect(BaseExecField):
"""
下拉单选
"""
def __init__(self,
label: str or BaseLabel,
text_field: str,
value_field: str,
option_list: List[str:object],
provider: str = None,
method: str = None,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'option_list': self.option_list}

View File

@ -0,0 +1,65 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file slider_field.py
@date2024/8/22 17:06
@desc:
"""
from typing import Dict
from common.exception.app_exception import AppApiException
from common.forms import BaseField, TriggerType, BaseLabel
from django.utils.translation import gettext_lazy as _
class SliderField(BaseField):
"""
滑块输入框
"""
def __init__(self, label: str or BaseLabel,
_min,
_max,
_step,
precision,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
attrs=None, props_info=None):
"""
@param label: 提示
@param _min: 最小值
@param _max: 最大值
@param _step: 步长
@param precision: 保留多少小数
@param required: 是否必填
@param default_value: 默认值
@param relation_show_field_dict:
@param attrs:
@param props_info:
"""
_attrs = {'min': _min, 'max': _max, 'step': _step,
'precision': precision, 'show-input-controls': False, 'show-input': True}
if attrs is not None:
_attrs.update(attrs)
super().__init__('Slider', label, required, default_value, relation_show_field_dict,
{},
TriggerType.OPTION_LIST, _attrs, props_info)
def is_valid(self, value):
super().is_valid(value)
field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
if value is not None:
if value < self.attrs.get('min'):
raise AppApiException(500,
_("The {field_label} cannot be less than {min}").format(field_label=field_label,
min=self.attrs.get(
'min')))
if value > self.attrs.get('max'):
raise AppApiException(500,
_("The {field_label} cannot be greater than {max}").format(
field_label=field_label,
max=self.attrs.get(
'max')))

View File

@ -0,0 +1,33 @@
"""
@project: MaxKB
@Author
@file switch_field.py
@date2024/10/13 19:43
@desc:
"""
from typing import Dict
from common.forms import BaseField, TriggerType, BaseLabel
class SwitchField(BaseField):
"""
滑块输入框
"""
def __init__(self, label: str or BaseLabel,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
attrs=None, props_info=None):
"""
@param required: 是否必填
@param default_value: 默认值
@param relation_show_field_dict:
@param attrs:
@param props_info:
"""
super().__init__('Switch', label, required, default_value, relation_show_field_dict,
{},
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file tab_card.py
@date2023/10/31 18:03
@desc:
"""
from typing import Dict
from common.forms.base_field import BaseExecField, TriggerType
class TabCard(BaseExecField):
"""
收集 Tab类型数据 tab1:{},tab2:{}
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("TabCard", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file table_radio.py
@date2023/10/31 18:01
@desc:
"""
from typing import Dict
from common.forms.base_field import TriggerType, BaseExecField
class TableRadio(BaseExecField):
"""
table 单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("TableCheckbox", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file table_radio.py
@date2023/10/31 18:01
@desc:
"""
from typing import Dict
from common.forms.base_field import TriggerType, BaseExecField
class TableRadio(BaseExecField):
"""
table 单选
"""
def __init__(self,
label: str,
text_field: str,
value_field: str,
provider: str,
method: str,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
relation_trigger_field_dict: Dict = None,
trigger_type: TriggerType = TriggerType.OPTION_LIST,
attrs: Dict[str, object] = None,
props_info: Dict[str, object] = None):
super().__init__("TableRadio", label, text_field, value_field, provider, method, required, default_value,
relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)

View File

@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file text_input_field.py
@date2023/10/31 17:58
@desc:
"""
from typing import Dict
from common.forms import BaseLabel
from common.forms.base_field import BaseField, TriggerType
class TextInputField(BaseField):
"""
文本输入框
"""
def __init__(self, label: str or BaseLabel,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
attrs=None, props_info=None):
super().__init__('TextInput', label, required, default_value, relation_show_field_dict,
{},
TriggerType.OPTION_LIST, attrs, props_info)

View File

@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file app_model_mixin.py
@date2023/9/21 9:41
@desc:
"""
from django.db import models
class AppModelMixin(models.Model):
create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True)
update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True)
class Meta:
abstract = True
ordering = ['create_time']

View File

@ -7,8 +7,18 @@
@desc:
"""
import hashlib
import io
import mimetypes
import re
import shutil
from typing import List
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.utils.translation import gettext as _
from pydub import AudioSegment
from ..exception.app_exception import AppApiException
def password_encrypt(row_password):
"""
@ -36,3 +46,153 @@ def group_by(list_source: List, key):
array.append(e)
result[k] = array
return result
def encryption(message: str):
"""
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
:param message:
:return:
"""
max_pre_len = 8
max_post_len = 4
message_len = len(message)
pre_len = int(message_len / 5 * 2)
post_len = int(message_len / 5 * 1)
pre_str = "".join([message[index] for index in
range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))])
end_str = "".join(
[message[index] for index in
range(message_len - (int(post_len) if pre_len < max_post_len else max_post_len), message_len)])
content = "***************"
return pre_str + content + end_str
def _remove_empty_lines(text):
if not isinstance(text, str):
raise AppApiException(500, _('Text-to-speech node, the text content must be of string type'))
if not text:
raise AppApiException(500, _('Text-to-speech node, the text content cannot be empty'))
result = '\n'.join(line for line in text.split('\n') if line.strip())
return markdown_to_plain_text(result)
def markdown_to_plain_text(md: str) -> str:
# 移除图片 ![alt](url)
text = re.sub(r'!\[.*?\]\(.*?\)', '', md)
# 移除链接 [text](url)
text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
# 移除 Markdown 标题符号 (#, ##, ###)
text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
# 移除加粗 **text** 或 __text__
text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
text = re.sub(r'__(.*?)__', r'\1', text)
# 移除斜体 *text* 或 _text_
text = re.sub(r'\*(.*?)\*', r'\1', text)
text = re.sub(r'_(.*?)_', r'\1', text)
# 移除行内代码 `code`
text = re.sub(r'`(.*?)`', r'\1', text)
# 移除代码块 ```code```
text = re.sub(r'```[\s\S]*?```', '', text)
# 移除多余的换行符
text = re.sub(r'\n{2,}', '\n', text)
# 使用正则表达式去除所有 HTML 标签
text = re.sub(r'<[^>]+>', '', text)
# 去除多余的空白字符(包括换行符、制表符等)
text = re.sub(r'\s+', ' ', text)
# 去除表单渲染
re.sub(r'<form_rander>[\s\S]*?<\/form_rander>', '', text)
# 去除首尾空格
text = text.strip()
return text
def get_file_content(path):
with open(path, "r", encoding='utf-8') as file:
content = file.read()
return content
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
content_type, _ = mimetypes.guess_type(file_name)
if content_type is None:
# 如果未能识别,设置为默认的二进制文件类型
content_type = "application/octet-stream"
# 创建一个内存中的字节流对象
file_stream = io.BytesIO(file_bytes)
# 获取文件大小
file_size = len(file_bytes)
# 创建 InMemoryUploadedFile 对象
uploaded_file = InMemoryUploadedFile(
file=file_stream,
field_name=None,
name=file_name,
content_type=content_type,
size=file_size,
charset=None,
)
return uploaded_file
def any_to_amr(any_path, amr_path):
"""
把任意格式转成amr文件
"""
if any_path.endswith(".amr"):
shutil.copy2(any_path, amr_path)
return
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
raise NotImplementedError("Not support file type: {}".format(any_path))
audio = AudioSegment.from_file(any_path)
audio = audio.set_frame_rate(8000) # only support 8000
audio.export(amr_path, format="amr")
return audio.duration_seconds * 1000
def any_to_mp3(any_path, mp3_path):
"""
把任意格式转成mp3文件
"""
if any_path.endswith(".mp3"):
shutil.copy2(any_path, mp3_path)
return
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
sil_to_wav(any_path, any_path)
any_path = mp3_path
audio = AudioSegment.from_file(any_path)
audio = audio.set_frame_rate(16000)
audio.export(mp3_path, format="mp3")
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
"""
silk 文件转 wav
"""
try:
import pysilk
except ImportError:
raise AppApiException("import pysilk failed, wechaty voice message will not be supported.")
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
with open(wav_path, "wb") as f:
f.write(wav_data)
def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_format="mp3"):
audio_data = AudioSegment.from_file(file_path, format=audio_format)
audio_length_ms = len(audio_data)
if audio_length_ms <= max_segment_length_ms:
return model.speech_to_text(io.BytesIO(audio_data.export(format=audio_format).read()))
full_text = []
for start_ms in range(0, audio_length_ms, max_segment_length_ms):
end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
segment = audio_data[start_ms:end_ms]
text = model.speech_to_text(io.BytesIO(segment.export(format=audio_format).read()))
if isinstance(text, str):
full_text.append(text)
return ' '.join(full_text)

View File

@ -0,0 +1,140 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file rsa_util.py
@date2023/11/3 11:13
@desc:
"""
import base64
import threading
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
from django.core import cache
from django.db.models import QuerySet
from system_manage.models import SystemSetting, SettingType
lock = threading.Lock()
rsa_cache = cache.caches['default']
cache_key = "rsa_key"
# 对密钥加密的密码
secret_code = "mac_kb_password"
def generate():
"""
生成 私钥秘钥对
:return:{key:'公钥',value:'私钥'}
"""
# 生成一个 2048 位的密钥
key = RSA.generate(2048)
# 获取私钥
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
protection="scryptAndAES128-CBC")
return {'key': key.publickey().export_key(), 'value': encrypted_key}
def get_key_pair():
rsa_value = rsa_cache.get(cache_key)
if rsa_value is None:
lock.acquire()
rsa_value = rsa_cache.get(cache_key)
if rsa_value is not None:
return rsa_value
try:
rsa_value = get_key_pair_by_sql()
rsa_cache.set(cache_key, rsa_value)
finally:
lock.release()
return rsa_value
def get_key_pair_by_sql():
system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
if system_setting is None:
kv = generate()
system_setting = SystemSetting(type=SettingType.RSA.value,
meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()})
system_setting.save()
return system_setting.meta
def encrypt(msg, public_key: str | None = None):
"""
加密
:param msg: 加密数据
:param public_key: 公钥
:return: 加密后的数据
"""
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
return base64.b64encode(encrypt_msg).decode()
def decrypt(msg, pri_key: str | None = None):
"""
解密
:param msg: 需要解密的数据
:param pri_key: 私钥
:return: 解密后数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
"""
超长文本加密
:param message: 需要加密的字符串
:param public_key 公钥
:param length: 1024bit的证书用100 2048bit的证书用 200
:return: 加密后的数据
"""
# 读取公钥
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
passphrase=secret_code))
# 处理Plaintext is too long. 分段加密
if len(message) <= length:
# 对编码的数据进行加密并通过base64进行编码
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
else:
rsa_text = []
# 对编码后的数据进行切片,原因:加密长度不能过长
for i in range(0, len(message), length):
cont = message[i:i + length]
# 对切片后的数据进行加密并新增到text后面
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
# 加密完进行拼接
cipher_text = b''.join(rsa_text)
# base64进行编码
result = base64.b64encode(cipher_text)
return result.decode()
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
"""
超长文本解密默认不加密
:param message: 需要解密的数据
:param pri_key: 秘钥
:param length : 1024bit的证书用1282048bit证书用256位
:return: 解密后的数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
base64_de = base64.b64decode(message)
res = []
for i in range(0, len(base64_de), length):
res.append(cipher.decrypt(base64_de[i:i + length], 0))
return b"".join(res).decode()

View File

@ -102,3 +102,5 @@ msgstr "用户管理"
#: .\apps\users\views\user.py:24 .\apps\users\views\user.py:25
msgid "Get current user information"
msgstr "获取当前用户信息"

View File

@ -41,7 +41,8 @@ INSTALLED_APPS = [
'users.apps.UsersConfig',
'tools.apps.ToolConfig',
'common',
'system_manage'
'system_manage',
'models_provider',
]
MIDDLEWARE = [

View File

@ -22,7 +22,8 @@ from maxkb import settings
urlpatterns = [
path("api/", include("users.urls")),
path("api/", include("tools.urls"))
path("api/", include("tools.urls")),
path("api/", include("models_provider.urls")),
]
urlpatterns += [
path('schema/', SpectacularAPIView.as_view(), name='schema'), # schema的配置文件的路由下面两个ui也是根据这个配置文件来生成的

View File

View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

View File

@ -0,0 +1 @@
# coding=utf-8

View File

@ -0,0 +1,20 @@
# coding=utf-8
from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer
from models_provider.serializers.model import ModelCreateRequest, ModelModelSerializer
class ModelCreateResponse(ResultSerializer):
def get_data(self):
return ModelModelSerializer()
class ModelCreateAPI(APIMixin):
@staticmethod
def get_request():
return ModelCreateRequest
@staticmethod
def get_response():
return ModelCreateResponse

View File

@ -0,0 +1,85 @@
# coding=utf-8
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _
class ProvideResponse(ResultSerializer):
def get_data(self):
return ProvideSerializer()
class ProvideSerializer(serializers.Serializer):
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
provider = serializers.CharField(required=True, label=_("provider"))
icon = serializers.CharField(required=True, label=_("icon"))
class ProvideListSerializer(serializers.Serializer):
key = serializers.CharField(required=True, max_length=64, label=_("model name"))
value = serializers.CharField(required=True, label=_("value"))
class ModelListSerializer(serializers.Serializer):
name = serializers.CharField(required=True, label=_("model name"))
model_type = serializers.CharField(required=True, label=_("model type"))
desc = serializers.CharField(required=True, label=_("model name"))
class ProvideApi(APIMixin):
class ModelList(APIMixin):
@staticmethod
def get_query_params_api():
return [OpenApiParameter(
# 参数的名称是done
name="model_type",
# 对参数的备注
description="model_type",
# 指定参数的类型
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
# 指定必须给
required=False,
), OpenApiParameter(
# 参数的名称是done
name="provider",
# 对参数的备注
description="provider",
# 指定参数的类型
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
# 指定必须给
required=True,
)
]
@staticmethod
def get_response():
return serializers.ListSerializer(child=ModelListSerializer())
@staticmethod
def get_response():
return ProvideResponse
class ModelTypeList(APIMixin):
@staticmethod
def get_query_params_api():
return [OpenApiParameter(
# 参数的名称是done
name="provider",
# 对参数的备注
description="provider",
# 指定参数的类型
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
# 指定必须给
required=True,
)]
@staticmethod
def get_response():
return serializers.ListSerializer(child=ProvideListSerializer())

View File

@ -0,0 +1,6 @@
from django.apps import AppConfig
class ModelsProviderConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'models_provider'

View File

@ -0,0 +1,255 @@
# coding=utf-8
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
from typing import Dict, Iterator, Type, List
from pydantic import BaseModel
from common.exception.app_exception import AppApiException
from django.utils.translation import gettext_lazy as _
from common.utils.common import encryption
class DownModelChunkStatus(Enum):
success = "success"
error = "error"
pulling = "pulling"
unknown = 'unknown'
class ValidCode(Enum):
valid_error = 500
model_not_fount = 404
class DownModelChunk:
def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
self.details = details
self.status = status
self.digest = digest
self.progress = progress
self.index = index
def to_dict(self):
return {
"details": self.details,
"status": self.status.value,
"digest": self.digest,
"progress": self.progress,
"index": self.index
}
class IModelProvider(ABC):
@abstractmethod
def get_model_info_manage(self):
pass
@abstractmethod
def get_model_provide_info(self):
pass
def get_model_type_list(self):
return self.get_model_info_manage().get_model_type_list()
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, _('Model type cannot be empty'))
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential
def get_model_params(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object],
model_params: Dict[str, object], raise_exception=False):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
raise_exception=raise_exception)
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
def get_dialogue_number(self):
return 3
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
raise AppApiException(500, _('The current platform does not support downloading models'))
class MaxKBBaseModel(ABC):
@staticmethod
@abstractmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
pass
@staticmethod
def is_cache_model():
return True
@staticmethod
def filter_optional_params(model_kwargs):
optional_params = {}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
optional_params[key] = value
return optional_params
class BaseModelCredential(ABC):
@abstractmethod
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider,
raise_exception=True):
pass
@abstractmethod
def encryption_dict(self, model_info: Dict[str, object]):
"""
:param model_info: 模型数据
:return: 加密后数据
"""
pass
def get_model_params_setting_form(self, model_name):
"""
模型参数设置表单
:return:
"""
pass
@staticmethod
def encryption(message: str):
"""
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
:param message:
:return:
"""
return encryption(message)
class ModelTypeConst(Enum):
LLM = {'code': 'LLM', 'message': _('LLM')}
EMBEDDING = {'code': 'EMBEDDING', 'message': _('Embedding Model')}
STT = {'code': 'STT', 'message': _('Speech2Text')}
TTS = {'code': 'TTS', 'message': _('TTS')}
IMAGE = {'code': 'IMAGE', 'message': _('Vision Model')}
TTI = {'code': 'TTI', 'message': _('Image Generation')}
RERANKER = {'code': 'RERANKER', 'message': _('Rerank')}
class ModelInfo:
def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
model_class: Type[MaxKBBaseModel],
**keywords):
self.name = name
self.desc = desc
self.model_type = model_type.name
self.model_credential = model_credential
self.model_class = model_class
if keywords is not None:
for key in keywords.keys():
self.__setattr__(key, keywords.get(key))
def get_name(self):
"""
获取模型名称
:return: 模型名称
"""
return self.name
def get_desc(self):
"""
获取模型描述
:return: 模型描述
"""
return self.desc
def get_model_type(self):
return self.model_type
def get_model_class(self):
return self.model_class
def to_dict(self):
return reduce(lambda x, y: {**x, **y},
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})
class ModelInfoManage:
def __init__(self):
self.model_dict = {}
self.model_list = []
self.default_model_list = []
self.default_model_dict = {}
def append_model_info(self, model_info: ModelInfo):
self.model_list.append(model_info)
model_type_dict = self.model_dict.get(model_info.model_type)
if model_type_dict is None:
self.model_dict[model_info.model_type] = {model_info.name: model_info}
else:
model_type_dict[model_info.name] = model_info
def append_default_model_info(self, model_info: ModelInfo):
self.default_model_list.append(model_info)
self.default_model_dict[model_info.model_type] = model_info
def get_model_list(self):
return [model.to_dict() for model in self.model_list]
def get_model_list_by_model_type(self, model_type):
return [model.to_dict() for model in self.model_list if model.model_type == model_type]
def get_model_type_list(self):
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
len([model for model in self.model_list if model.model_type == _type.name]) > 0]
def get_model_info(self, model_type, model_name) -> ModelInfo:
model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
if model_info is None:
raise AppApiException(500, _('The model does not support'))
return model_info
class builder:
def __init__(self):
self.modelInfoManage = ModelInfoManage()
def append_model_info(self, model_info: ModelInfo):
self.modelInfoManage.append_model_info(model_info)
return self
def append_model_info_list(self, model_info_list: List[ModelInfo]):
for model_info in model_info_list:
self.modelInfoManage.append_model_info(model_info)
return self
def append_default_model_info(self, model_info: ModelInfo):
self.modelInfoManage.append_default_model_info(model_info)
return self
def build(self):
return self.modelInfoManage
class ModelProvideInfo:
def __init__(self, provider: str, name: str, icon: str):
self.provider = provider
self.name = name
self.icon = icon
def to_dict(self):
return reduce(lambda x, y: {**x, **y},
[{attr: self.__getattribute__(attr)} for attr in vars(self) if
not attr.startswith("__")], {})

View File

@ -0,0 +1,48 @@
# coding=utf-8
from enum import Enum
from models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \
AliyunBaiLianModelProvider
from models_provider.impl.anthropic_model_provider.anthropic_model_provider import AnthropicModelProvider
from models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
from models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
from models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
from models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import SiliconCloudModelProvider
from models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import TencentCloudModelProvider
from models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
from models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider
from models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
VolcanicEngineModelProvider
from models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
from models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
class ModelProvideConstants(Enum):
model_azure_provider = AzureModelProvider()
model_wenxin_provider = WenxinModelProvider()
model_ollama_provider = OllamaModelProvider()
model_openai_provider = OpenAIModelProvider()
model_kimi_provider = KimiModelProvider()
model_qwen_provider = QwenModelProvider()
model_zhipu_provider = ZhiPuModelProvider()
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
model_volcanic_engine_provider = VolcanicEngineModelProvider()
model_tencent_provider = TencentModelProvider()
model_tencent_cloud_provider = TencentCloudModelProvider()
model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()
model_xinference_provider = XinferenceModelProvider()
model_vllm_provider = VllmModelProvider()
aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
model_anthropic_provider = AnthropicModelProvider()
model_siliconCloud_provider = SiliconCloudModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/9 17:42
@desc:
"""

View File

@ -0,0 +1,100 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file aliyun_bai_lian_model_provider.py
@date2024/9/9 17:43
@desc:
"""
import os
from common.utils.common import get_file_content
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \
AliyunBaiLianEmbeddingCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.image import QwenVLModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.llm import BaiLianLLMModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \
AliyunBaiLianRerankerCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.tti import QwenTextToImageModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
from models_provider.impl.aliyun_bai_lian_model_provider.model.image import QwenVLChatModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.llm import BaiLianChatModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText
from models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _, gettext
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()
aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential()
aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential()
aliyun_bai_lian_llm_model_credential = BaiLianLLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
qwentti_model_credential = QwenTextToImageModelCredential()
model_info_list = [ModelInfo('gte-rerank',
_('With the GTE-Rerank text sorting series model developed by Alibaba Tongyi Lab, developers can integrate high-quality text retrieval and sorting through the LlamaIndex framework.'),
ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker),
ModelInfo('paraformer-realtime-v2',
_('Chinese (including various dialects such as Cantonese), English, Japanese, and Korean support free switching between multiple languages.'),
ModelTypeConst.STT, aliyun_bai_lian_stt_model_credential, AliyunBaiLianSpeechToText),
ModelInfo('cosyvoice-v1',
_('CosyVoice is based on a new generation of large generative speech models, which can predict emotions, intonation, rhythm, etc. based on context, and has better anthropomorphic effects.'),
ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech),
ModelInfo('text-embedding-v1',
_("Universal text vector is Tongyi Lab's multi-language text unified vector model based on the LLM base. It provides high-level vector services for multiple mainstream languages around the world and helps developers quickly convert text data into high-quality vector data."),
ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential,
AliyunBaiLianEmbedding),
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
BaiLianChatModel),
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
BaiLianChatModel),
ModelInfo('qwen-max', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
BaiLianChatModel)
]
module_info_vl_list = [
ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
]
module_info_tti_list = [
ModelInfo('wanx-v1',
_('Tongyi Wanxiang - a large image model for text generation, supports bilingual input in Chinese and English, and supports the input of reference pictures for reference content or reference style migration. Key styles include but are not limited to watercolor, oil painting, Chinese painting, sketch, flat illustration, two-dimensional, and 3D. Cartoon.'),
ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
]
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_model_info_list(module_info_vl_list)
.append_default_model_info(module_info_vl_list[0])
.append_model_info_list(module_info_tti_list)
.append_default_model_info(module_info_tti_list[0])
.append_default_model_info(model_info_list[1])
.append_default_model_info(model_info_list[2])
.append_default_model_info(model_info_list[3])
.append_default_model_info(model_info_list[4])
.append_default_model_info(model_info_list[0])
.build()
)
class AliyunBaiLianModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='aliyun_bai_lian_model_provider', name=gettext('Alibaba Cloud Bailian'),
icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl',
'aliyun_bai_lian_model_provider',
'icon',
'aliyun_bai_lian_icon_svg')))

View File

@ -0,0 +1,74 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/10/16 17:01
@desc:
"""
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, Any],
model_params: Dict[str, Any],
provider: Any,
raise_exception: bool = False
) -> bool:
"""
验证模型凭据是否有效
"""
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
f"{model_type} Model type is not supported"
)
required_keys = ['dashscope_api_key']
missing_keys = [key for key in required_keys if key not in model_credential]
if missing_keys:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
f"{', '.join(missing_keys)} is required"
)
return False
try:
model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_("Hello"))
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
f"Verification failed, please check whether the parameters are correct: {e}"
)
return False
return True
def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
"""
加密敏感信息
"""
api_key = model.get('dashscope_api_key', '')
return {**model, 'dashscope_api_key': super().encryption(api_key)}
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,71 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:41
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenVLModelCredential(BaseForm, BaseModelCredential):
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, object],
model_params: dict,
provider,
raise_exception: bool = False
) -> bool:
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type)
)
required_keys = ['api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('{key} is required').format(key=key)
)
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
for chunk in res:
print(chunk)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)
)
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,94 @@
# coding=utf-8
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class BaiLianLLMModelParams(BaseForm):
temperature = forms.SliderField(
TooltipLabel(
_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic.')
),
required=True,
default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2
)
max_tokens = forms.SliderField(
TooltipLabel(
_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate.')
),
required=True,
default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0
)
class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField(_('API URL'), required=True)
api_key = forms.PasswordInputField(_('API Key'), required=True)
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, object],
model_params: dict,
provider,
raise_exception: bool = False
) -> bool:
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type)
)
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('{key} is required').format(key=key)
)
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)
)
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name: str) -> BaiLianLLMModelParams:
return BaiLianLLMModelParams()

View File

@ -0,0 +1,85 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
from langchain_core.documents import Document
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian Reranker model.
Provides validation and encryption for the model credentials.
"""
dashscope_api_key = PasswordInputField('API Key', required=True)
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, Any],
model_params: Dict[str, Any],
provider,
raise_exception: bool = False
) -> bool:
"""
Validate the model credentials.
:param model_type: Type of the model (e.g., 'RERANKER').
:param model_name: Name of the model.
:param model_credential: Dictionary containing the model credentials.
:param model_params: Parameters for the model.
:param provider: Model provider instance.
:param raise_exception: Whether to raise an exception on validation failure.
:return: Boolean indicating whether the credentials are valid.
"""
if model_type != 'RERANKER':
raise AppApiException(
ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type)
)
required_keys = ['dashscope_api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
_('{key} is required').format(key=key)
)
return False
try:
model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(error=str(e))
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
"""
Encrypt sensitive fields in the model dictionary.
:param model: Dictionary containing model details.
:return: Dictionary with encrypted sensitive fields.
"""
return {
**model,
'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))
}

View File

@ -0,0 +1,92 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
Provides validation and encryption for the model credentials.
"""
api_key = PasswordInputField("API Key", required=True)
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, Any],
model_params: Dict[str, Any],
provider,
raise_exception: bool = False
) -> bool:
"""
Validate the model credentials.
:param model_type: Type of the model (e.g., 'STT').
:param model_name: Name of the model.
:param model_credential: Dictionary containing the model credentials.
:param model_params: Parameters for the model.
:param provider: Model provider instance.
:param raise_exception: Whether to raise an exception on validation failure.
:return: Boolean indicating whether the credentials are valid.
"""
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type)
)
required_keys = ['api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
_('{key} is required').format(key=key)
)
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(error=str(e))
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
"""
Encrypt sensitive fields in the model dictionary.
:param model: Dictionary containing model details.
:return: Dictionary with encrypted sensitive fields.
"""
return {
**model,
'api_key': super().encryption(model.get('api_key', ''))
}
def get_model_params_setting_form(self, model_name: str):
"""
Get the parameter setting form for the specified model.
:param model_name: Name of the model.
:return: Parameter setting form (not implemented).
"""
pass

View File

@ -0,0 +1,147 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext_lazy as _, gettext
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenModelParams(BaseForm):
"""
Parameters class for the Qwen Text-to-Image model.
Defines fields such as image size, number of images, and style.
"""
size = SingleSelect(
TooltipLabel(_('Image size'), _('Specify the size of the generated image, such as: 1024x1024')),
required=True,
default_value='1024*1024',
option_list=[
{'value': '1024*1024', 'label': '1024*1024'},
{'value': '720*1280', 'label': '720*1280'},
{'value': '768*1152', 'label': '768*1152'},
{'value': '1280*720', 'label': '1280*720'},
],
text_field='label',
value_field='value'
)
n = SliderField(
TooltipLabel(_('Number of pictures'), _('Specify the number of generated images')),
required=True,
default_value=1,
_min=1,
_max=4,
_step=1,
precision=0
)
style = SingleSelect(
TooltipLabel(_('Style'), _('Specify the style of generated images')),
required=True,
default_value='<auto>',
option_list=[
{'value': '<auto>', 'label': _('Default value, the image style is randomly output by the model')},
{'value': '<photography>', 'label': _('photography')},
{'value': '<portrait>', 'label': _('Portraits')},
{'value': '<3d cartoon>', 'label': _('3D cartoon')},
{'value': '<anime>', 'label': _('animation')},
{'value': '<oil painting>', 'label': _('painting')},
{'value': '<watercolor>', 'label': _('watercolor')},
{'value': '<sketch>', 'label': _('sketch')},
{'value': '<chinese painting>', 'label': _('Chinese painting')},
{'value': '<flat illustration>', 'label': _('flat illustration')},
],
text_field='label',
value_field='value'
)
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Qwen Text-to-Image model.
Provides validation and encryption for the model credentials.
"""
api_key = PasswordInputField('API Key', required=True)
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, Any],
model_params: Dict[str, Any],
provider,
raise_exception: bool = False
) -> bool:
"""
Validate the model credentials.
:param model_type: Type of the model (e.g., 'TEXT_TO_IMAGE').
:param model_name: Name of the model.
:param model_credential: Dictionary containing the model credentials.
:param model_params: Parameters for the model.
:param provider: Model provider instance.
:param raise_exception: Whether to raise an exception on validation failure.
:return: Boolean indicating whether the credentials are valid.
"""
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type)
)
required_keys = ['api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('{key} is required').format(key=key)
)
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}'
).format(error=str(e))
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
"""
Encrypt sensitive fields in the model dictionary.
:param model: Dictionary containing model details.
:return: Dictionary with encrypted sensitive fields.
"""
return {
**model,
'api_key': super().encryption(model.get('api_key', ''))
}
def get_model_params_setting_form(self, model_name: str):
"""
Get the parameter setting form for the specified model.
:param model_name: Name of the model.
:return: Parameter setting form.
"""
return QwenModelParams()

View File

@ -0,0 +1,139 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AliyunBaiLianTTSModelGeneralParams(BaseForm):
"""
Parameters class for the Aliyun BaiLian TTS (Text-to-Speech) model.
Defines fields such as voice and speech rate.
"""
voice = SingleSelect(
TooltipLabel(_('Timbre'), _('Chinese sounds can support mixed scenes of Chinese and English')),
required=True,
default_value='longxiaochun',
text_field='value',
value_field='value',
option_list=[
{'text': _('Long Xiaochun'), 'value': 'longxiaochun'},
{'text': _('Long Xiaoxia'), 'value': 'longxiaoxia'},
{'text': _('Long Xiaochen'), 'value': 'longxiaocheng'},
{'text': _('Long Xiaobai'), 'value': 'longxiaobai'},
{'text': _('Long Laotie'), 'value': 'longlaotie'},
{'text': _('Long Shu'), 'value': 'longshu'},
{'text': _('Long Shuo'), 'value': 'longshuo'},
{'text': _('Long Jing'), 'value': 'longjing'},
{'text': _('Long Miao'), 'value': 'longmiao'},
{'text': _('Long Yue'), 'value': 'longyue'},
{'text': _('Long Yuan'), 'value': 'longyuan'},
{'text': _('Long Fei'), 'value': 'longfei'},
{'text': _('Long Jielidou'), 'value': 'longjielidou'},
{'text': _('Long Tong'), 'value': 'longtong'},
{'text': _('Long Xiang'), 'value': 'longxiang'},
{'text': 'Stella', 'value': 'loongstella'},
{'text': 'Bella', 'value': 'loongbella'},
]
)
speech_rate = SliderField(
TooltipLabel(_('Speaking speed'), _('[0.5, 2], the default is 1, usually one decimal place is enough')),
required=True,
default_value=1,
_min=0.5,
_max=2,
_step=0.1,
precision=1
)
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian TTS (Text-to-Speech) model.
Provides validation and encryption for the model credentials.
"""
api_key = PasswordInputField("API Key", required=True)
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: Dict[str, object],
model_params,
provider,
raise_exception: bool = False
) -> bool:
"""
Validate the model credentials.
:param model_type: Type of the model (e.g., 'TTS').
:param model_name: Name of the model.
:param model_credential: Dictionary containing the model credentials.
:param model_params: Parameters for the model.
:param provider: Model provider instance.
:param raise_exception: Whether to raise an exception on validation failure.
:return: Boolean indicating whether the credentials are valid.
"""
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
raise AppApiException(
ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type)
)
required_keys = ['api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext('{key} is required').format(key=key)
)
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(
ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}'
).format(error=str(e))
)
return False
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
"""
Encrypt sensitive fields in the model dictionary.
:param model: Dictionary containing model details.
:return: Dictionary with encrypted sensitive fields.
"""
return {
**model,
'api_key': super().encryption(model.get('api_key', ''))
}
def get_model_params_setting_form(self, model_name: str):
"""
Get the parameter setting form for the specified model.
:param model_name: Name of the model.
:return: Parameter setting form.
"""
return AliyunBaiLianTTSModelGeneralParams()

View File

@ -0,0 +1 @@
<svg id="图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 50 50"><defs><style>.cls-1{fill:#6d49f6;}.cls-2{fill:#6be8d2;}.cls-3{fill:#9b8ee8;}.cls-4{fill:#0d22d2;}.cls-5{fill:#2c53dc;}.cls-6{fill:#5eccc9;}.cls-7{fill:#fff;}</style></defs><title>【icon】阿里百炼大模型</title><path class="cls-1" d="M35.463,6.454,25,12.5,14.186,18.737,4.032,12.882a2.33,2.33,0,0,1,.851-.851L23.829,1.089a2.358,2.358,0,0,1,2.342,0Z"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><path class="cls-3" d="M45.117,12.031,35.463,6.454,25,12.5,14.186,18.737v.01L25,25h.011l10.814-6.248,10.143-5.865A2.33,2.33,0,0,0,45.117,12.031Z"/><path class="cls-4" d="M4.032,12.882a2.254,2.254,0,0,0-.32,1.171V35.926a2.319,2.319,0,0,0,.32,1.182l10.143-5.865v-12.5h.011v-.01Z"/><polygon class="cls-5" points="25 24.995 14.175 31.243 14.175 18.747 14.186 18.747 25 24.995"/><path class="cls-6" d="M45.968,37.1a2.278,2.278,0,0,1-.851.862L26.171,48.9a2.358,2.358,0,0,1-2.342,0l-9.3-5.375,10.462-6.035H25l10.825-6.248Z"/><path class="cls-2" d="M46.288,25.2V35.926a2.251,2.251,0,0,1-.32,1.171L35.825,31.243v-.011Z"/><path class="cls-7" d="M46.288,14.053V25.2L35.825,31.232V18.747l10.143-5.865A2.254,2.254,0,0,1,46.288,14.053Z"/><polygon class="cls-7" points="35.825 18.747 35.825 31.232 35.814 31.232 25.011 24.995 35.825 18.747"/><path class="cls-2" d="M35.814,31.232,25.011,25H25L14.175,31.243,4.032,37.108a2.33,2.33,0,0,0,.851.851l9.644,5.567,10.462-6.035H25l10.825-6.248v-.011Z"/></svg>

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -0,0 +1,66 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/10/16 16:34
@desc:
"""
from functools import reduce
from typing import Dict, List
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.embeddings.dashscope import embed_with_retry
from models_provider.base_model_provider import MaxKBBaseModel
def proxy_embed_documents(texts: List[str], step_size, embed_documents):
value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
range(0, len(texts), step_size)]
return reduce(lambda x, y: [*x, *y], value, [])
class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return AliyunBaiLianEmbedding(
model=model_name,
dashscope_api_key=model_credential.get('dashscope_api_key')
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
if self.model == 'text-embedding-v3':
return proxy_embed_documents(texts, 6, self._embed_documents)
return self._embed_documents(texts)
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to DashScope's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.
Returns:
List of embeddings, one for each text.
"""
embeddings = embed_with_retry(
self, input=texts, text_type="document", model=self.model
)
embedding_list = [item["embedding"] for item in embeddings]
return embedding_list
def embed_query(self, text: str) -> List[float]:
"""Call out to DashScope's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
embedding = embed_with_retry(
self, input=[text], text_type="document", model=self.model
)[0]["embedding"]
return embedding

View File

@ -0,0 +1,23 @@
# coding=utf-8
from typing import Dict
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
chat_tong_yi = QwenVLChatModel(
model_name=model_name,
openai_api_key=model_credential.get('api_key'),
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
# stream_options={"include_usage": True},
streaming=True,
stream_usage=True,
**optional_params,
)
return chat_tong_yi

View File

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from typing import Dict
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
if 'qwen-omni-turbo' in model_name or 'qwq' in model_name:
optional_params['streaming'] = True
return BaiLianChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params
)

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py.py
@date2024/9/2 16:42
@desc:
"""
from typing import Dict
from langchain_community.document_compressors import DashScopeRerank
from models_provider.base_model_provider import MaxKBBaseModel
class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
top_n=model_kwargs.get('top_n', 3))

View File

@ -0,0 +1,75 @@
import os
import tempfile
from typing import Dict
import dashscope
from dashscope.audio.asr import (Recognition)
from pydub import AudioSegment
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText
class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_key: str
model: str
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
if model_name == 'qwen-omni-turbo':
optional_params['streaming'] = True
return AliyunBaiLianSpeechToText(
model=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
def check_auth(self):
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
self.speech_to_text(f)
def speech_to_text(self, audio_file):
dashscope.api_key = self.api_key
recognition = Recognition(model=self.model,
format='mp3',
sample_rate=16000,
callback=None)
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中
temp_file.write(audio_file.read())
# 获取临时文件的路径
temp_file_path = temp_file.name
try:
audio = AudioSegment.from_file(temp_file_path)
if audio.channels != 1:
audio = audio.set_channels(1)
audio = audio.set_frame_rate(16000)
# 将转换后的音频文件保存到临时文件中
audio.export(temp_file_path, format='mp3')
# 识别临时文件
result = recognition.call(temp_file_path)
text = ''
if result.status_code == 200:
result_sentence = result.get_sentence()
if result_sentence is not None:
for sentence in result_sentence:
text += sentence['text']
return text
else:
raise Exception('Error: ', result.message)
finally:
# 删除临时文件
os.remove(temp_file_path)

View File

@ -0,0 +1,59 @@
# coding=utf-8
from http import HTTPStatus
from typing import Dict
from dashscope import ImageSynthesis
from django.utils.translation import gettext
from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import HumanMessage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tti import BaseTextToImage
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
api_key: str
model_name: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model_name = kwargs.get('model_name')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
chat_tong_yi = QwenTextToImageModel(
model_name=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
return chat_tong_yi
def is_cache_model(self):
return False
def check_auth(self):
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
def generate_image(self, prompt: str, negative_prompt: str = None):
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
rsp = ImageSynthesis.call(api_key=self.api_key,
model=self.model_name,
prompt=prompt,
negative_prompt=negative_prompt,
**self.params)
file_urls = []
if rsp.status_code == HTTPStatus.OK:
for result in rsp.output.results:
file_urls.append(result.url)
else:
print('sync_call Failed, status_code: %s, code: %s, message: %s' %
(rsp.status_code, rsp.code, rsp.message))
return file_urls

View File

@ -0,0 +1,57 @@
from typing import Dict
import dashscope
from django.utils.translation import gettext as _
from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech
class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_key: str
model: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return AliyunBaiLianTextToSpeech(
model=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
def check_auth(self):
self.text_to_speech(_('Hello'))
def text_to_speech(self, text):
dashscope.api_key = self.api_key
text = _remove_empty_lines(text)
if 'sambert' in self.model:
from dashscope.audio.tts import SpeechSynthesizer
audio = SpeechSynthesizer.call(model=self.model, text=text, **self.params).get_audio_data()
else:
from dashscope.audio.tts_v2 import SpeechSynthesizer
synthesizer = SpeechSynthesizer(model=self.model, **self.params)
audio = synthesizer.call(text)
if audio is None:
raise Exception('Failed to generate audio')
if type(audio) == str:
print(audio)
raise Exception(audio)
return audio
def is_cache_model(self):
return False

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/3/28 16:25
@desc:
"""

View File

@ -0,0 +1,62 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file openai_model_provider.py
@date2024/3/28 16:26
@desc:
"""
import os
from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from models_provider.impl.anthropic_model_provider.credential.image import AnthropicImageModelCredential
from models_provider.impl.anthropic_model_provider.credential.llm import AnthropicLLMModelCredential
from models_provider.impl.anthropic_model_provider.model.image import AnthropicImage
from models_provider.impl.anthropic_model_provider.model.llm import AnthropicChatModel
from maxkb.conf import PROJECT_DIR
openai_llm_model_credential = AnthropicLLMModelCredential()
openai_image_model_credential = AnthropicImageModelCredential()
model_info_list = [
ModelInfo('claude-3-opus-20240229', '', ModelTypeConst.LLM,
openai_llm_model_credential, AnthropicChatModel
),
ModelInfo('claude-3-sonnet-20240229', '', ModelTypeConst.LLM, openai_llm_model_credential,
AnthropicChatModel),
ModelInfo('claude-3-haiku-20240307', '', ModelTypeConst.LLM, openai_llm_model_credential,
AnthropicChatModel),
ModelInfo('claude-3-5-sonnet-20240620', '', ModelTypeConst.LLM, openai_llm_model_credential,
AnthropicChatModel),
ModelInfo('claude-3-5-haiku-20241022', '', ModelTypeConst.LLM, openai_llm_model_credential,
AnthropicChatModel),
ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.LLM, openai_llm_model_credential,
AnthropicChatModel),
]
image_model_info = [
ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.IMAGE, openai_image_model_credential,
AnthropicImage),
]
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_default_model_info(model_info_list[0])
.append_model_info_list(image_model_info)
.append_default_model_info(image_model_info[0])
.build()
)
class AnthropicModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_anthropic_provider', name='Anthropic', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'anthropic_model_provider', 'icon',
'anthropic_icon_svg')))

View File

@ -0,0 +1,53 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField(_('API URL'), required=True)
api_key = forms.PasswordInputField(_('API Key'), required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext("Hello")}])])
for chunk in res:
print(chunk)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass

View File

@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext
class AnthropicLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField(_('API URL'), required=True)
api_key = forms.PasswordInputField(_('API Key'), required=True)
def get_model_params_setting_form(self, model_name):
return AnthropicLLMModelParams()

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" shape-rendering="geometricPrecision" text-rendering="geometricPrecision" image-rendering="optimizeQuality" fill-rule="evenodd" clip-rule="evenodd" viewBox="0 0 512 512"><rect fill="#CC9B7A" width="512" height="512" rx="104.187" ry="105.042"/><path fill="#1F1F1E" fill-rule="nonzero" d="M318.663 149.787h-43.368l78.952 212.423 43.368.004-78.952-212.427zm-125.326 0l-78.952 212.427h44.255l15.932-44.608 82.846-.004 16.107 44.612h44.255l-79.126-212.427h-45.317zm-4.251 128.341l26.91-74.701 27.083 74.701h-53.993z"/></svg>

After

Width:  |  Height:  |  Size: 558 B

View File

@ -0,0 +1,26 @@
from typing import Dict
from langchain_anthropic import ChatAnthropic
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AnthropicImage(MaxKBBaseModel, ChatAnthropic):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AnthropicImage(
model=model_name,
anthropic_api_url=model_credential.get('api_base'),
anthropic_api_key=model_credential.get('api_key'),
# stream_options={"include_usage": True},
streaming=True,
**optional_params,
)

View File

@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/4/18 15:28
@desc:
"""
from typing import List, Dict
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AnthropicChatModel(MaxKBBaseModel, ChatAnthropic):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
azure_chat_open_ai = AnthropicChatModel(
model=model_name,
anthropic_api_url=model_credential.get('api_base'),
anthropic_api_key=model_credential.get('api_key'),
**optional_params,
custom_get_token_ids=custom_get_token_ids
)
return azure_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

View File

@ -0,0 +1,154 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
from common.utils.common import get_file_content
from models_provider.base_model_provider import (
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
from models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
def _create_model_info(model_name, description, model_type, credential_class, model_class):
return ModelInfo(
name=model_name,
desc=description,
model_type=model_type,
model_credential=credential_class(),
model_class=model_class
)
def _get_aws_bedrock_icon_path():
return os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'aws_bedrock_model_provider',
'icon', 'bedrock_icon_svg')
def _initialize_model_info():
model_info_list = [
_create_model_info(
'anthropic.claude-v2:1',
_('An update to Claude 2 that doubles the context window and improves reliability, hallucination rates, and evidence-based accuracy in long documents and RAG contexts.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-v2',
_('Anthropic is a powerful model that can handle a variety of tasks, from complex dialogue and creative content generation to detailed command obedience.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-haiku-20240307-v1:0',
_("The Claude 3 Haiku is Anthropic's fastest and most compact model, with near-instant responsiveness. The model can answer simple queries and requests quickly. Customers will be able to build seamless AI experiences that mimic human interactions. Claude 3 Haiku can process images and return text output, and provides 200K context windows."),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-sonnet-20240229-v1:0',
_("The Claude 3 Sonnet model from Anthropic strikes the ideal balance between intelligence and speed, especially when it comes to handling enterprise workloads. This model offers maximum utility while being priced lower than competing products, and it's been engineered to be a solid choice for deploying AI at scale."),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-5-sonnet-20240620-v1:0',
_('The Claude 3.5 Sonnet raises the industry standard for intelligence, outperforming competing models and the Claude 3 Opus in extensive evaluations, with the speed and cost-effectiveness of our mid-range models.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-instant-v1',
_('A faster, more affordable but still very powerful model that can handle a range of tasks including casual conversation, text analysis, summarization and document question answering.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-premier-v1:0',
_("Titan Text Premier is the most powerful and advanced model in the Titan Text series, designed to deliver exceptional performance for a variety of enterprise applications. With its cutting-edge features, it delivers greater accuracy and outstanding results, making it an excellent choice for organizations looking for a top-notch text processing solution."),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-lite-v1',
_('Amazon Titan Text Lite is a lightweight, efficient model ideal for fine-tuning English-language tasks, including summarization and copywriting, where customers require smaller, more cost-effective, and highly customizable models.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'amazon.titan-text-express-v1',
_('Amazon Titan Text Express has context lengths of up to 8,000 tokens, making it ideal for a variety of high-level general language tasks, such as open-ended text generation and conversational chat, as well as support in retrieval-augmented generation (RAG). At launch, the model is optimized for English, but other languages are supported.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'mistral.mistral-7b-instruct-v0:2',
_('7B dense converter for rapid deployment and easy customization. Small in size yet powerful in a variety of use cases. Supports English and code, as well as 32k context windows.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'mistral.mistral-large-2402-v1:0',
_('Advanced Mistral AI large-scale language model capable of handling any language task, including complex multilingual reasoning, text understanding, transformation, and code generation.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'meta.llama3-70b-instruct-v1:0',
_('Ideal for content creation, conversational AI, language understanding, R&D, and enterprise applications'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'meta.llama3-8b-instruct-v1:0',
_('Ideal for limited computing power and resources, edge devices, and faster training times.'),
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
]
embedded_model_info_list = [
_create_model_info(
'amazon.titan-embed-text-v1',
_('Titan Embed Text is the largest embedding model in the Amazon Titan Embed series and can handle various text embedding tasks, such as text classification, text similarity calculation, etc.'),
ModelTypeConst.EMBEDDING,
BedrockEmbeddingCredential,
BedrockEmbeddingModel
),
]
model_info_manage = ModelInfoManage.builder() \
.append_model_info_list(model_info_list) \
.append_default_model_info(model_info_list[0]) \
.append_model_info_list(embedded_model_info_list) \
.append_default_model_info(embedded_model_info_list[0]) \
.build()
return model_info_manage
class BedrockModelProvider(IModelProvider):
def __init__(self):
self._model_info_manage = _initialize_model_info()
def get_model_info_manage(self):
return self._model_info_manage
def get_model_provide_info(self):
icon_path = _get_aws_bedrock_icon_path()
icon_data = get_file_content(icon_path)
return ModelProvideInfo(
provider='model_aws_bedrock_provider',
name='Amazon Bedrock',
icon=icon_data
)

View File

@ -0,0 +1,53 @@
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
return False
required_keys = ['region_name', 'access_key_id', 'secret_access_key']
if not all(key in model_credential for key in required_keys):
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('The following fields are required: {keys}').format(
keys=", ".join(required_keys)))
return False
try:
model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
aa = model.embed_query(_('Hello'))
print(aa)
except AppApiException:
raise
except Exception as e:
traceback.print_exc()
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
region_name = forms.TextInputField('Region Name', required=True)
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

View File

@ -0,0 +1,76 @@
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import ValidCode, BaseModelCredential
class BedrockLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=1024,
_min=1,
_max=100000,
_step=1,
precision=0)
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
return False
required_keys = ['region_name', 'access_key_id', 'secret_access_key']
if not all(key in model_credential for key in required_keys):
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext('The following fields are required: {keys}').format(
keys=", ".join(required_keys)))
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except AppApiException:
raise
except Exception as e:
traceback.print_exc()
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}
region_name = forms.TextInputField('Region Name', required=True)
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
base_url = forms.TextInputField('Proxy URL', required=False)
def get_model_params_setting_form(self, model_name):
return BedrockLLMModelParams()

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" version="2.0" focusable="false" aria-hidden="true" class="globalNav-1216 globalNav-1213" data-testid="awsc-logo" viewBox="0 0 29 17"><path class="globalNav-1214" d="M8.38 6.17a2.6 2.6 0 00.11.83c.08.232.18.456.3.67a.4.4 0 01.07.21.36.36 0 01-.18.28l-.59.39a.43.43 0 01-.24.08.38.38 0 01-.28-.13 2.38 2.38 0 01-.34-.43c-.09-.16-.18-.34-.28-.55a3.44 3.44 0 01-2.74 1.29 2.54 2.54 0 01-1.86-.67 2.36 2.36 0 01-.68-1.79 2.43 2.43 0 01.84-1.92 3.43 3.43 0 012.29-.72 6.75 6.75 0 011 .07c.35.05.7.12 1.07.2V3.3a2.06 2.06 0 00-.44-1.49 2.12 2.12 0 00-1.52-.43 4.4 4.4 0 00-1 .12 6.85 6.85 0 00-1 .32l-.33.12h-.14c-.14 0-.2-.1-.2-.29v-.46A.62.62 0 012.3.87a.78.78 0 01.27-.2A6 6 0 013.74.25 5.7 5.7 0 015.19.07a3.37 3.37 0 012.44.76 3 3 0 01.77 2.29l-.02 3.05zM4.6 7.59a3 3 0 001-.17 2 2 0 00.88-.6 1.36 1.36 0 00.32-.59 3.18 3.18 0 00.09-.81V5A7.52 7.52 0 006 4.87h-.88a2.13 2.13 0 00-1.38.37 1.3 1.3 0 00-.46 1.08 1.3 1.3 0 00.34 1c.278.216.63.313.98.27zm7.49 1a.56.56 0 01-.36-.09.73.73 0 01-.2-.37L9.35.93a1.39 1.39 0 01-.08-.38c0-.15.07-.23.22-.23h.92a.56.56 0 01.36.09.74.74 0 01.19.37L12.53 7 14 .79a.61.61 0 01.18-.37.59.59 0 01.37-.09h.75a.62.62 0 01.38.09.74.74 0 01.18.37L17.31 7 18.92.76a.74.74 0 01.19-.37.56.56 0 01.36-.09h.87a.21.21 0 01.23.23 1 1 0 010 .15s0 .13-.06.23l-2.26 7.2a.74.74 0 01-.19.37.6.6 0 01-.36.09h-.8a.53.53 0 01-.37-.1.64.64 0 01-.18-.37l-1.45-6-1.44 6a.64.64 0 01-.18.37.55.55 0 01-.37.1l-.82.02zm12 .24a6.29 6.29 0 01-1.44-.16 4.21 4.21 0 01-1.07-.37.69.69 0 01-.29-.26.66.66 0 01-.06-.27V7.3c0-.19.07-.29.21-.29a.57.57 0 01.18 0l.23.1c.32.143.656.25 1 .32.365.08.737.12 1.11.12a2.47 2.47 0 001.36-.31 1 1 0 00.48-.88.88.88 0 00-.25-.65 2.29 2.29 0 00-.94-.49l-1.35-.43a2.83 2.83 0 01-1.49-.94 2.24 2.24 0 01-.47-1.36 2 2 0 01.25-1c.167-.3.395-.563.67-.77a3 3 0 011-.48A4.1 4.1 0 0124.4.08a4.4 4.4 0 01.62 0l.61.1.53.15.39.16c.105.062.2.14.28.23a.57.57 0 01.08.31v.44c0 .2-.07.3-.21.3a.92.92 0 01-.36-.12 4.35 4.35 0 00-1.8-.36 2.51 2.51 0 00-1.24.26.92.92 0 00-.44.84c0 .249.1.488.28.66.295.236.635.41 1 .51l1.32.42a2.88 2.88 0 011.44.9 2.1 2.1 0 01.43 1.31 2.38 2.38 0 01-.24 1.08 2.34 2.34 0 01-.68.82 3 3 0 01-1 .53 4.59 4.59 0 01-1.35.22l.03-.01z"></path><path class="globalNav-1215" d="M25.82 13.43a20.07 20.07 0 01-11.35 3.47A20.54 20.54 0 01.61 11.62c-.29-.26 0-.62.32-.42a27.81 27.81 0 0013.86 3.68 27.54 27.54 0 0010.58-2.16c.52-.22.96.34.45.71z"></path><path class="globalNav-1215" d="M27.1 12c-.4-.51-2.6-.24-3.59-.12-.3 0-.34-.23-.07-.42 1.75-1.23 4.63-.88 5-.46.37.42-.09 3.3-1.74 4.68-.25.21-.49.09-.38-.18.34-.95 1.17-3.02.78-3.5z"></path></svg>

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

@ -0,0 +1,60 @@
from langchain_community.embeddings import BedrockEmbeddings
from models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict, List
from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
**kwargs):
super().__init__(model_id=model_id, region_name=region_name,
credentials_profile_name=credentials_profile_name, **kwargs)
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
model_credential['secret_access_key'])
return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['access_key_id'],
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
Args:
texts: The list of texts to embed
Returns:
List of embeddings, one for each text.
"""
results = []
for text in texts:
response = self._embedding_func(text)
if self.normalize:
response = self._normalize_vector(response)
results.append(response)
return results
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a Bedrock model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
embedding = self._embedding_func(text)
if self.normalize:
return self._normalize_vector(embedding)
return embedding

View File

@ -0,0 +1,88 @@
import os
import re
from typing import Dict
from botocore.config import Config
from langchain_community.chat_models import BedrockChat
from models_provider.base_model_provider import MaxKBBaseModel
def get_max_tokens_keyword(model_name):
"""
根据模型名称返回正确的 max_tokens 关键字
:param model_name: 模型名称字符串
:return: 对应的 max_tokens 关键字字符串
"""
maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
# max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"]
maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"]
max_new_tokens = [
"us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0",
"us.meta.llama3-2-90b-instruct-v1:0"]
if model_name in maxTokens:
return 'maxTokens'
elif model_name in maxTokenCount:
return 'maxTokenCount'
elif model_name in max_new_tokens:
return 'max_new_tokens'
else:
return 'max_tokens'
class BedrockModel(MaxKBBaseModel, BedrockChat):
@staticmethod
def is_cache_model():
return False
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
streaming: bool = False, config: Config = None, **kwargs):
super().__init__(model_id=model_id, region_name=region_name,
credentials_profile_name=credentials_profile_name, streaming=streaming, config=config,
**kwargs)
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
config = {}
# 判断model_kwargs是否包含 base_url 且不为空
if 'base_url' in model_credential and model_credential['base_url']:
proxy_url = model_credential['base_url']
config = Config(
proxies={
'http': proxy_url,
'https': proxy_url
},
connect_timeout=60,
read_timeout=60
)
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
model_credential['secret_access_key'])
return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['access_key_id'],
streaming=model_kwargs.pop('streaming', True),
model_kwargs=optional_params,
config=config
)
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
content = re.sub(pattern, '', content, flags=re.DOTALL)
if not re.search(rf'\[{profile_name}\]', content):
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
with open(credentials_path, 'w') as file:
file.write(content)

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2023/10/31 17:16
@desc:
"""

View File

@ -0,0 +1,116 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file azure_model_provider.py
@date2023/10/31 16:19
@desc:
"""
import os
from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from models_provider.impl.azure_model_provider.credential.embedding import AzureOpenAIEmbeddingCredential
from models_provider.impl.azure_model_provider.credential.image import AzureOpenAIImageModelCredential
from models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
from models_provider.impl.azure_model_provider.credential.stt import AzureOpenAISTTModelCredential
from models_provider.impl.azure_model_provider.credential.tti import AzureOpenAITextToImageModelCredential
from models_provider.impl.azure_model_provider.credential.tts import AzureOpenAITTSModelCredential
from models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from models_provider.impl.azure_model_provider.model.embedding import AzureOpenAIEmbeddingModel
from models_provider.impl.azure_model_provider.model.image import AzureOpenAIImage
from models_provider.impl.azure_model_provider.model.stt import AzureOpenAISpeechToText
from models_provider.impl.azure_model_provider.model.tti import AzureOpenAITextToImage
from models_provider.impl.azure_model_provider.model.tts import AzureOpenAITextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
base_azure_llm_model_credential = AzureLLMModelCredential()
base_azure_embedding_model_credential = AzureOpenAIEmbeddingCredential()
base_azure_image_model_credential = AzureOpenAIImageModelCredential()
base_azure_tti_model_credential = AzureOpenAITextToImageModelCredential()
base_azure_tts_model_credential = AzureOpenAITTSModelCredential()
base_azure_stt_model_credential = AzureOpenAISTTModelCredential()
default_model_info = [
ModelInfo('Azure OpenAI', '', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
),
ModelInfo('gpt-4', '', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
),
ModelInfo('gpt-4o', '', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
),
ModelInfo('gpt-4o-mini', '', ModelTypeConst.LLM,
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
),
]
embedding_model_info = [
ModelInfo('text-embedding-3-large', '', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
ModelInfo('text-embedding-3-small', '', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
ModelInfo('text-embedding-ada-002', '', ModelTypeConst.EMBEDDING,
base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
),
]
image_model_info = [
ModelInfo('gpt-4o', '', ModelTypeConst.IMAGE,
base_azure_image_model_credential, AzureOpenAIImage, api_version='2023-05-15'
),
ModelInfo('gpt-4o-mini', '', ModelTypeConst.IMAGE,
base_azure_image_model_credential, AzureOpenAIImage, api_version='2023-05-15'
),
]
tti_model_info = [
ModelInfo('dall-e-3', '', ModelTypeConst.TTI,
base_azure_tti_model_credential, AzureOpenAITextToImage, api_version='2023-05-15'
),
]
tts_model_info = [
ModelInfo('tts', '', ModelTypeConst.TTS,
base_azure_tts_model_credential, AzureOpenAITextToSpeech, api_version='2023-05-15'
),
]
stt_model_info = [
ModelInfo('whisper', '', ModelTypeConst.STT,
base_azure_stt_model_credential, AzureOpenAISpeechToText, api_version='2023-05-15'
),
]
model_info_manage = (
ModelInfoManage.builder()
.append_default_model_info(default_model_info[0])
.append_model_info_list(default_model_info)
.append_model_info_list(embedding_model_info)
.append_default_model_info(embedding_model_info[0])
.append_model_info_list(image_model_info)
.append_default_model_info(image_model_info[0])
.append_model_info_list(stt_model_info)
.append_default_model_info(stt_model_info[0])
.append_model_info_list(tts_model_info)
.append_default_model_info(tts_model_info[0])
.append_model_info_list(tti_model_info)
.append_default_model_info(tti_model_info[0])
.build()
)
class AzureModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'azure_model_provider', 'icon',
'azure_icon_svg')))

View File

@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct'))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_version = forms.TextInputField("Api Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)

View File

@ -0,0 +1,75 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext
class AzureOpenAIImageModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_version = forms.TextInputField("API Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
for chunk in res:
print(chunk)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
return AzureOpenAIImageModelParams()

View File

@ -0,0 +1,96 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
from openai import BadRequestError
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext
class AzureLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class o3MiniLLMModelParams(BaseForm):
max_completion_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=5000,
_step=1,
precision=0)
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext('Verification failed, please check whether the parameters are correct'))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_version = forms.TextInputField("API Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)
deployment_name = forms.TextInputField("Deployment name", required=True)
def get_model_params_setting_form(self, model_name):
if 'o3' in model_name or 'o1' in model_name:
return o3MiniLLMModelParams()
return AzureLLMModelParams()

View File

@ -0,0 +1,50 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_version = forms.TextInputField("API Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass

View File

@ -0,0 +1,87 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureOpenAITTIModelParams(BaseForm):
size = forms.SingleSelect(
TooltipLabel(_('Image size'), _('Specify the size of the generated image, such as: 1024x1024')),
required=True,
default_value='1024x1024',
option_list=[
{'value': '1024x1024', 'label': '1024x1024'},
{'value': '1024x1792', 'label': '1024x1792'},
{'value': '1792x1024', 'label': '1792x1024'},
],
text_field='label',
value_field='value'
)
quality = forms.SingleSelect(
TooltipLabel(_('Picture quality'), ''),
required=True,
default_value='standard',
option_list=[
{'value': 'standard', 'label': 'standard'},
{'value': 'hd', 'label': 'hd'},
],
text_field='label',
value_field='value'
)
n = forms.SliderField(
TooltipLabel(_('Number of pictures'), _('Specify the number of generated images')),
required=True, default_value=1,
_min=1,
_max=10,
_step=1,
precision=0)
class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
api_version = forms.TextInputField("API Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
return AzureOpenAITTIModelParams()

View File

@ -0,0 +1,68 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureOpenAITTSModelGeneralParams(BaseForm):
# alloy, echo, fable, onyx, nova, shimmer
voice = forms.SingleSelect(
TooltipLabel('Voice',
_('Try out the different sounds (Alloy, Echo, Fable, Onyx, Nova, and Sparkle) to find one that suits your desired tone and audience. The current voiceover is optimized for English.')),
required=True, default_value='alloy',
text_field='value',
value_field='value',
option_list=[
{'text': 'alloy', 'value': 'alloy'},
{'text': 'echo', 'value': 'echo'},
{'text': 'fable', 'value': 'fable'},
{'text': 'onyx', 'value': 'onyx'},
{'text': 'nova', 'value': 'nova'},
{'text': 'shimmer', 'value': 'shimmer'},
])
class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
api_version = forms.TextInputField("API Version", required=True)
api_base = forms.TextInputField('Azure Endpoint', required=True)
api_key = forms.PasswordInputField("API Key", required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_base', 'api_key', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
return AzureOpenAITTSModelGeneralParams()

View File

@ -0,0 +1 @@
<svg t="1724827784525" class="icon" viewBox="0 0 1083 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5331" width="100%" height="100%"><path d="M540.115607 419.535168L460.236208 3.644671 376.260809 1.620347c-31.825387-5.244301-54.020439 8.390289-66.588115 40.908208C297.107977 75.046474 194.101272 379.938035 0.654058 957.203237-1.245965 999.731792 12.341272 1020.99311 41.418728 1020.99311h275.015029c14.599399-4.010173 25.104277-17.02178 31.514636-39.031861 6.410358-22.010081 70.46659-209.486428 192.168694-562.427561z" fill="#0D559D" p-id="5332"></path><path d="M669.463676 658.929202l-351.23052 6.189873c-29.527306 6.969711-32.546035 22.615306-9.054705 46.941226 35.237734 36.486659 357.010497 317.661965 372.719722 311.580115 9.100578-3.52185 16.537896-4.005734 26.473064-9.14941 24.994775-12.939098 21.644578-70.568694-10.05059-172.887307l-28.855491-182.674497z" fill="#0078D4" p-id="5333"></path><path d="M371.417526 0.331468l350.216879-0.325549a47.405873 47.405873 0 0 1 31.042589 11.556994c7.219792 6.252023 13.292763 14.842081 18.220393 25.765734 9.603699 21.293873 112.287815 325.368601 308.052347 912.225665a54.729249 54.729249 0 0 1 1.284439 30.169526c-6.225387 25.780532-19.432324 39.509827-39.623768 41.181965-24.348116 2.018405-144.088046 2.962497-359.221272 2.833758 30.551306-9.66141 44.410821-28.275422 41.578543-55.843515C720.153156 940.511445 618.336185 634.30289 417.516763 49.264462h0.00296a84.346821 84.346821 0 0 0-13.869873-25.255213C392.112092 9.559306 382.194682 2.04652 373.893179 1.47237c-10.553711-0.728046-11.379422-1.109827-2.475653-1.140902z" fill="#2FA7E7" p-id="5334"></path></svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -0,0 +1,51 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file azure_chat_model.py
@date2024/4/28 11:45
@desc:
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
model_name=model_name,
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure",
**optional_params,
streaming=True,
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -0,0 +1,25 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 17:44
@desc:
"""
from typing import Dict
from langchain_openai import AzureOpenAIEmbeddings
from models_provider.base_model_provider import MaxKBBaseModel
class AzureOpenAIEmbeddingModel(MaxKBBaseModel, AzureOpenAIEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return AzureOpenAIEmbeddingModel(
model=model_name,
openai_api_key=model_credential.get('api_key'),
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version'),
openai_api_type="azure",
)

View File

@ -0,0 +1,42 @@
from typing import Dict, List
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AzureOpenAIImage(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AzureOpenAIImage(
model_name=model_name,
openai_api_key=model_credential.get('api_key'),
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version'),
openai_api_type="azure",
streaming=True,
**optional_params,
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))

View File

@ -0,0 +1,62 @@
import io
from typing import Dict
from openai import AzureOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
api_base: str
api_key: str
api_version: str
model: str
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.api_version = kwargs.get('api_version')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return AzureOpenAISpeechToText(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'),
**optional_params,
)
def check_auth(self):
client = AzureOpenAI(
azure_endpoint=self.api_base,
api_key=self.api_key,
api_version=self.api_version
)
response_list = client.models.with_raw_response.list()
# print(response_list)
def speech_to_text(self, audio_file):
client = AzureOpenAI(
azure_endpoint=self.api_base,
api_key=self.api_key,
api_version=self.api_version
)
audio_data = audio_file.read()
buffer = io.BytesIO(audio_data)
buffer.name = "file.mp3" # this is the important line
res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
return res.text

View File

@ -0,0 +1,61 @@
from typing import Dict
from openai import AzureOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tti import BaseTextToImage
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AzureOpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
api_base: str
api_key: str
api_version: str
model: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.api_version = kwargs.get('api_version')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return AzureOpenAITextToImage(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'),
**optional_params,
)
def is_cache_model(self):
return False
def check_auth(self):
chat = AzureOpenAI(api_key=self.api_key, azure_endpoint=self.api_base, api_version=self.api_version)
response_list = chat.models.with_raw_response.list()
# self.generate_image('生成一个小猫图片')
def generate_image(self, prompt: str, negative_prompt: str = None):
chat = AzureOpenAI(api_key=self.api_key, azure_endpoint=self.api_base, api_version=self.api_version)
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
file_urls = []
for content in res.data:
url = content.url
file_urls.append(url)
return file_urls

View File

@ -0,0 +1,69 @@
from typing import Dict
from openai import AzureOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class AzureOpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_base: str
api_key: str
api_version: str
model: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.api_version = kwargs.get('api_version')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {'voice': 'alloy'}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return AzureOpenAITextToSpeech(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
api_version=model_credential.get('api_version'),
**optional_params,
)
def check_auth(self):
client = AzureOpenAI(
azure_endpoint=self.api_base,
api_key=self.api_key,
api_version=self.api_version
)
response_list = client.models.with_raw_response.list()
# print(response_list)
def text_to_speech(self, text):
client = AzureOpenAI(
azure_endpoint=self.api_base,
api_key=self.api_key,
api_version=self.api_version
)
text = _remove_empty_lines(text)
with client.audio.speech.with_streaming_response.create(
model=self.model,
input=text,
**self.params
) as response:
return response.read()
def is_cache_model(self):
return False

View File

@ -0,0 +1,168 @@
# coding=utf-8
import warnings
from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class BaseChatOpenAI(ChatOpenAI):
usage_metadata: dict = {}
custom_get_token_ids = custom_get_token_ids
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.usage_metadata
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.usage_metadata.get('input_tokens', 0)
def get_num_tokens(self, text: str) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('output_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
"""Set default stream_options."""
stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
# Note: stream_options is not a valid parameter for Azure OpenAI.
# To support users proxying Azure through ChatOpenAI, here we only specify
# stream_options if include_usage is set to True.
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
# for release notes.
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload and is_basemodel_subclass(
payload["response_format"]
):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
chat_result = self._generate(
messages, stop, run_manager=run_manager, **kwargs
)
msg = chat_result.generations[0].message
yield ChatGenerationChunk(
message=AIMessageChunk(
**msg.dict(exclude={"type", "additional_kwargs"}),
# preserve the "parsed" Pydantic object without converting to dict
additional_kwargs=msg.additional_kwargs,
),
generation_info=chat_result.generations[0].generation_info,
)
return
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
with response:
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
generation_chunk = super()._convert_chunk_to_generation_chunk(
chunk,
default_chunk_class,
base_generation_info if is_first_chunk else {},
)
if generation_chunk is None:
continue
# custom code
if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
'reasoning_content']
default_chunk_class = generation_chunk.message.__class__
logprobs = (generation_chunk.generation_info or {}).get("logprobs")
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
is_first_chunk = False
# custom code
if generation_chunk.message.usage_metadata is not None:
self.usage_metadata = generation_chunk.message.usage_metadata
yield generation_chunk
def _create_chat_result(self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None):
result = super()._create_chat_result(response, generation_info)
try:
reasoning_content = ''
reasoning_content_enable = False
for res in response.choices:
if 'reasoning_content' in res.message.model_extra:
reasoning_content_enable = True
_reasoning_content = res.message.model_extra.get('reasoning_content')
if _reasoning_content is not None:
reasoning_content += _reasoning_content
if reasoning_content_enable:
result.llm_output['reasoning_content'] = reasoning_content
except Exception as e:
pass
return result
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessage:
config = ensure_config(config)
chat_result = cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
).generations[0][0],
).message
self.usage_metadata = chat_result.response_metadata[
'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
return chat_result

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseSpeechToText(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def speech_to_text(self, audio_file):
pass

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseTextToImage(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def generate_image(self, prompt: str, negative_prompt: str = None):
pass

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseTextToSpeech(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def text_to_speech(self, text):
pass

View File

@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File __init__.py.py
@Author Brian Yang
@Date 5/12/24 7:38 AM
"""

View File

@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:51
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class DeepSeekLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_model_params_setting_form(self, model_name):
return DeepSeekLLMModelParams()

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File deepseek_model_provider.py
@Author Brian Yang
@Date 5/12/24 7:40 AM
"""
import os
from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfoManage
from models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
from models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
deepseek_reasoner = ModelInfo('deepseek-reasoner', '', ModelTypeConst.LLM,
deepseek_llm_model_credential, DeepSeekChatModel
)
deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'),
ModelTypeConst.LLM,
deepseek_llm_model_credential, DeepSeekChatModel
)
deepseek_coder = ModelInfo('deepseek-coder', _('Good at handling programming tasks, supports 16K contexts'),
ModelTypeConst.LLM,
deepseek_llm_model_credential,
DeepSeekChatModel)
model_info_manage = ModelInfoManage.builder().append_model_info(deepseek_reasoner).append_model_info(deepseek_chat).append_model_info(
deepseek_coder).append_default_model_info(
deepseek_coder).build()
class DeepSeekModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
'deepseek_icon_svg')))

View File

@ -0,0 +1,6 @@
<svg width="100%" height="100%" viewBox="0 0 50 50" fill="none" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink">
<path id="path"
d="M48.8354 10.0479C48.3232 9.79199 48.1025 10.2798 47.8032 10.5278C47.7007 10.6079 47.6143 10.7119 47.5273 10.8076C46.7793 11.624 45.9048 12.1597 44.7622 12.0957C43.0923 12 41.666 12.5356 40.4058 13.8398C40.1377 12.2319 39.2476 11.272 37.8926 10.6558C37.1836 10.3359 36.4668 10.0156 35.9702 9.31982C35.6235 8.82373 35.5293 8.27197 35.356 7.72754C35.2456 7.3999 35.1353 7.06396 34.7651 7.00781C34.3633 6.94385 34.2056 7.2876 34.0479 7.57568C33.418 8.75195 33.1733 10.0479 33.1973 11.3599C33.2524 14.312 34.4736 16.6641 36.8999 18.3359C37.1758 18.5278 37.2466 18.7197 37.1597 19C36.9946 19.5757 36.7974 20.1357 36.624 20.7119C36.5137 21.0801 36.3486 21.1597 35.9624 21C34.6309 20.4321 33.481 19.5918 32.4644 18.5757C30.7393 16.8721 29.1792 14.9917 27.2334 13.52C26.7764 13.1758 26.3193 12.856 25.8467 12.5518C23.8618 10.584 26.1069 8.96777 26.627 8.77588C27.1704 8.57568 26.8159 7.8877 25.0591 7.896C23.3022 7.90381 21.6953 8.50391 19.647 9.30371C19.3477 9.42383 19.0322 9.51172 18.7095 9.58398C16.8501 9.22363 14.9199 9.14355 12.9033 9.37598C9.10596 9.80762 6.07275 11.6396 3.84326 14.7681C1.16455 18.5278 0.53418 22.7998 1.30664 27.2559C2.11768 31.9521 4.46582 35.8398 8.07373 38.8799C11.8159 42.0322 16.1255 43.5762 21.041 43.2803C24.0269 43.104 27.3516 42.6963 31.1016 39.4561C32.0469 39.936 33.0396 40.1279 34.686 40.272C35.9546 40.3921 37.1758 40.208 38.1211 40.0078C39.6021 39.688 39.4995 38.2881 38.9639 38.0322C34.623 35.9678 35.5762 36.8081 34.71 36.1279C36.9155 33.4639 40.2402 30.6958 41.54 21.728C41.6426 21.0161 41.5557 20.5679 41.54 19.9917C41.5322 19.6396 41.6108 19.5039 42.0049 19.4639C43.0923 19.3359 44.1479 19.0317 45.1167 18.4878C47.9292 16.9199 49.064 14.3438 49.3315 11.2559C49.3711 10.7837 49.3237 10.2959 48.8354 10.0479ZM24.3262 37.8398C20.1196 34.4639 18.0791 33.3521 17.2358 33.3999C16.4482 33.4482 16.5898 34.3682 16.7632 34.9678C16.9443 35.5601 17.1812 35.9683 17.5117 36.4878C17.7402 36.832 17.8979 37.3442 17.2832 37.728C15.9282 38.584 13.5728 37.4399 13.4624 37.3838C10.7207 35.7358 8.42822 33.5601 6.81348 30.584C5.25342 27.7197 4.34766 24.6479 4.19775 21.3677C4.1582 20.5757 4.38672 20.2959 5.15869 20.1519C6.17529 19.96 7.22314 19.9199 8.23926 20.0718C12.5327 20.7119 16.1885 22.6719 19.2529 25.7759C21.002 27.5439 22.3252 29.6558 23.6885 31.7202C25.1377 33.9121 26.6978 36 28.6831 37.7119C29.3843 38.312 29.9434 38.7681 30.479 39.104C28.8643 39.2881 26.1699 39.3281 24.3262 37.8398ZM26.3433 24.6001C26.3433 24.248 26.6191 23.9678 26.9658 23.9678C27.0444 23.9678 27.1152 23.9839 27.1782 24.0078C27.2651 24.04 27.3438 24.0879 27.4067 24.1602C27.5171 24.272 27.5801 24.4321 27.5801 24.6001C27.5801 24.9521 27.3042 25.2319 26.9575 25.2319C26.6108 25.2319 26.3433 24.9521 26.3433 24.6001ZM32.6064 27.8799C32.2046 28.0479 31.8027 28.1919 31.4165 28.208C30.8179 28.2397 30.1641 27.9922 29.8096 27.688C29.2583 27.2158 28.8643 26.9521 28.6987 26.1279C28.6279 25.7759 28.6675 25.2319 28.7305 24.9199C28.8721 24.248 28.7144 23.8159 28.2495 23.4238C27.8716 23.104 27.3911 23.0161 26.8633 23.0161C26.666 23.0161 26.4849 22.9277 26.3511 22.856C26.1304 22.7441 25.9492 22.4639 26.1226 22.1201C26.1777 22.0078 26.4458 21.7358 26.5088 21.688C27.2256 21.272 28.0527 21.4077 28.8169 21.7197C29.5259 22.0161 30.0615 22.5601 30.834 23.3281C31.6216 24.2559 31.7632 24.5117 32.2124 25.208C32.5669 25.752 32.8901 26.312 33.1104 26.9521C33.2446 27.3521 33.0713 27.6802 32.6064 27.8799Z"
fill="#4D6BFE" fill-opacity="1.000000" fill-rule="nonzero" />
</svg>

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File llm.py
@Author Brian Yang
@Date 5/12/24 7:44 AM
"""
from typing import Dict
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,
openai_api_base='https://api.deepseek.com',
openai_api_key=model_credential.get('api_key'),
**optional_params
)
return deepseek_chat_open_ai

View File

@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project MaxKB
@File __init__.py.py
@Author Brian Yang
@Date 5/13/24 7:40 AM
"""

View File

@ -0,0 +1,52 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,71 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiImageModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
for chunk in res:
print(chunk)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
return GeminiImageModelParams()

View File

@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 17:57
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.invoke([HumanMessage(content=gettext('Hello'))])
print(res)
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_model_params_setting_form(self, model_name):
return GeminiLLMModelParams()

View File

@ -0,0 +1,48 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass

Some files were not shown because too many files have changed in this diff Show More