mirror of https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00

feat: add model setting

This commit is contained in:
parent 934871f5c6
commit 946de675ff
@@ -0,0 +1,86 @@
# coding=utf-8
"""
@project: qabot
@Author:虎
@file: file_cache.py
@date:2023/9/11 15:58
@desc: File cache
"""
import datetime
import math
import os
import time

from diskcache import Cache
from django.core.cache.backends.base import BaseCache


class FileCache(BaseCache):
    def __init__(self, dir, params):
        super().__init__(params)
        self._dir = os.path.abspath(dir)
        self._createdir()
        self.cache = Cache(self._dir)

    def _createdir(self):
        old_umask = os.umask(0o077)
        try:
            os.makedirs(self._dir, 0o700, exist_ok=True)
        finally:
            os.umask(old_umask)

    def add(self, key, value, timeout=None, version=None):
        expire = timeout if isinstance(timeout, (int, float)) or timeout is None else timeout.total_seconds()
        return self.cache.add(self.get_key(key, version), value=value, expire=expire)

    def set(self, key, value, timeout=None, version=None):
        expire = timeout if isinstance(timeout, (int, float)) or timeout is None else timeout.total_seconds()
        return self.cache.set(self.get_key(key, version), value=value, expire=expire)

    def get(self, key, default=None, version=None):
        return self.cache.get(self.get_key(key, version), default=default)

    @staticmethod
    def get_key(key, version):
        if version is None:
            return f"default:{key}"
        return f"{version}:{key}"

    def delete(self, key, version=None):
        return self.cache.delete(self.get_key(key, version))

    def touch(self, key, timeout=None, version=None):
        # Accept int/float seconds, None, or a timedelta; the original crashed on
        # None because it fell through to timeout.total_seconds().
        expire = timeout if isinstance(timeout, (int, float)) or timeout is None else timeout.total_seconds()
        return self.cache.touch(self.get_key(key, version), expire=expire)

    def ttl(self, key, version=None):
        """
        Get the remaining time-to-live of a key.
        :param key: cache key
        :param version: key version
        :return: remaining time as a timedelta, or None if the key is missing
        """
        value, expire_time = self.cache.get(self.get_key(key, version), expire_time=True)
        if value is None:
            return None
        return datetime.timedelta(seconds=math.ceil(expire_time - time.time()))

    def clear_by_application_id(self, application_id):
        delete_keys = []
        for key in self.cache.iterkeys():
            value = self.cache.get(key)
            if (hasattr(value, 'application') and value.application is not None
                    and value.application.id is not None
                    and str(value.application.id) == application_id):
                delete_keys.append(key)
        for key in delete_keys:
            self.cache.delete(key)

    def clear_timeout_data(self):
        # Reading each key evicts expired entries as a side effect.
        for key in self.cache.iterkeys():
            self.get(key)
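A minimal sketch of wiring this backend into Django's CACHES setting; the dotted backend path and cache directory below are assumptions for illustration, not taken from this commit:

# settings.py (hypothetical registration of the FileCache backend)
CACHES = {
    'file_cache': {
        'BACKEND': 'common.cache.file_cache.FileCache',  # assumed module path
        'LOCATION': '/opt/maxkb/cache',                  # assumed cache directory, passed as `dir`
    }
}

# caller side
from django.core.cache import caches
file_cache = caches['file_cache']
file_cache.set('chat:abc', {'user': 'u1'}, timeout=300)
file_cache.ttl('chat:abc')  # approximately datetime.timedelta(seconds=300)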
@@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: mem_cache.py
@date:2024/3/6 11:20
@desc:
"""
from django.core.cache.backends.base import DEFAULT_TIMEOUT
from django.core.cache.backends.locmem import LocMemCache


class MemCache(LocMemCache):
    def __init__(self, name, params):
        super().__init__(name, params)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        # Store the value as-is instead of pickling it, so live model instances stay usable.
        pickled = value
        with self._lock:
            self._set(key, pickled, timeout)

    def get(self, key, default=None, version=None):
        key = self.make_and_validate_key(key, version=version)
        with self._lock:
            if self._has_expired(key):
                self._delete(key)
                return default
            pickled = self._cache[key]
            self._cache.move_to_end(key, last=False)
        return pickled

    def clear_by_application_id(self, application_id):
        delete_keys = []
        for key in self._cache.keys():
            value = self._cache.get(key)
            if (hasattr(value, 'application') and value.application is not None
                    and value.application.id is not None
                    and str(value.application.id) == application_id):
                delete_keys.append(key)
        for key in delete_keys:
            self._delete(key)

    def clear_timeout_data(self):
        # Snapshot the keys first: self.get() may delete expired entries,
        # which would otherwise mutate the dict during iteration.
        for key in list(self._cache.keys()):
            self.get(key)
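A short usage sketch; the cache name string is arbitrary and the stored object is illustrative:

cache = MemCache('example', {})
cache.set('model-id', object(), timeout=30)
value = cache.get('model-id')  # a hit also moves the key to the front of the LRU order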
@@ -0,0 +1,66 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: embedding_config.py
@date:2023/10/23 16:03
@desc:
"""
import threading
import time

from common.cache.mem_cache import MemCache

lock = threading.Lock()


class ModelManage:
    cache = MemCache('model', {})
    up_clear_time = time.time()

    @staticmethod
    def get_model(_id, get_model):
        # Acquire the lock
        lock.acquire()
        try:
            model_instance = ModelManage.cache.get(_id)
            if model_instance is None or not model_instance.is_cache_model():
                model_instance = get_model(_id)
                ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
                return model_instance
            # Renew the expiry
            ModelManage.cache.touch(_id, timeout=60 * 30)
            ModelManage.clear_timeout_cache()
            return model_instance
        finally:
            # Release the lock
            lock.release()

    @staticmethod
    def clear_timeout_cache():
        if time.time() - ModelManage.up_clear_time > 60:
            ModelManage.cache.clear_timeout_data()
            # Reset the marker so expired entries are swept at most once a minute
            # (the original never updated it, so every call after the first minute swept again).
            ModelManage.up_clear_time = time.time()

    @staticmethod
    def delete_key(_id):
        if ModelManage.cache.has_key(_id):
            ModelManage.cache.delete(_id)


class VectorStore:
    from embedding.vector.pg_vector import PGVector
    from embedding.vector.base_vector import BaseVectorStore
    instance_map = {
        'pg_vector': PGVector,
    }
    instance = None

    @staticmethod
    def get_embedding_vector() -> BaseVectorStore:
        from embedding.vector.pg_vector import PGVector
        if VectorStore.instance is None:
            from maxkb.const import CONFIG
            vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), PGVector)
            VectorStore.instance = vector_store_class()
        return VectorStore.instance
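A minimal sketch of the intended call pattern: callers pass a model id plus a loader callback, and cached instances are reused for 30 minutes; get_model_by_id below is a hypothetical loader, not part of this commit:

def get_model_by_id(_id):
    ...  # build and return an instance exposing is_cache_model()

model = ModelManage.get_model('model-uuid', get_model_by_id)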
@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: tokenizer_manage_config.py
@date:2024/4/28 10:17
@desc:
"""


class TokenizerManage:
    tokenizer = None

    @staticmethod
    def get_tokenizer():
        from transformers import BertTokenizer
        if TokenizerManage.tokenizer is None:
            TokenizerManage.tokenizer = BertTokenizer.from_pretrained(
                'bert-base-cased',
                cache_dir="/opt/maxkb/model/tokenizer",
                local_files_only=True,
                resume_download=False,
                force_download=False)
        return TokenizerManage.tokenizer
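A quick usage sketch, assuming the bert-base-cased files are already present under the cache directory (local_files_only=True means nothing is downloaded):

tokenizer = TokenizerManage.get_tokenizer()
token_count = len(tokenizer.encode('hello world'))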
@@ -111,6 +111,13 @@ class PermissionConstants(Enum):
     USER_DELETE = Permission(group=Group.USER, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN])
     TOOL_CREATE = Permission(group=Group.USER, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
                                                                                   RoleConstants.USER])
+    MODEL_CREATE = Permission(group=Group.USER, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
+                                                                                   RoleConstants.USER])
+    MODEL_READ = Permission(group=Group.USER, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
+                                                                               RoleConstants.USER])
+    MODEL_EDIT = Permission(group=Group.USER, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, RoleConstants.USER])
+    MODEL_DELETE = Permission(group=Group.USER, operate=Operate.DELETE,
+                              role_list=[RoleConstants.ADMIN, RoleConstants.USER])

     def get_workspace_application_permission(self):
         return lambda r, kwargs: Permission(group=self.value.group, operate=self.value.operate,
@@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py
@date:2023/10/31 17:56
@desc:
"""
from .array_object_card import *
from .base_field import *
from .base_form import *
from .multi_select import *
from .object_card import *
from .password_input import *
from .radio_field import *
from .single_select_field import *
from .tab_card import *
from .table_radio import *
from .text_input_field import *
from .radio_button_field import *
from .table_checkbox import *
from .radio_card_field import *
from .label import *
from .slider_field import *
@@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: array_object_card.py
@date:2023/10/31 18:03
@desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class ArrayCard(BaseExecField):
    """
    Collects a List[Object]
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("ArrayObjectCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
@@ -0,0 +1,157 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_field.py
@date:2023/10/31 18:07
@desc:
"""
from enum import Enum
from typing import List, Dict

from common.exception.app_exception import AppApiException
from common.forms.label.base_label import BaseLabel
from django.utils.translation import gettext_lazy as _


class TriggerType(Enum):
    # Execute a function to fetch the option-list data
    OPTION_LIST = 'OPTION_LIST'
    # Execute a function to fetch child forms
    CHILD_FORMS = 'CHILD_FORMS'


class BaseField:
    def __init__(self,
                 input_type: str,
                 label: str or BaseLabel,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """
        :param input_type: input type of the field
        :param label: label/hint shown for the field
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; the field is shown only when `field` has a value and that value is in field_value_list
        :param relation_trigger_field_dict: {field: field_value_list}; the data-fetching function runs only when `field` has a value and that value is in field_value_list
        :param trigger_type: executor type; OPTION_LIST requests the option-list data, CHILD_FORMS requests child forms
        :param attrs: front-end attrs data
        :param props_info: other extra information
        """
        if props_info is None:
            props_info = {}
        if attrs is None:
            attrs = {}
        self.label = label
        self.attrs = attrs
        self.props_info = props_info
        self.default_value = default_value
        self.input_type = input_type
        self.relation_show_field_dict = {} if relation_show_field_dict is None else relation_show_field_dict
        # Default to {} to match the Dict annotation (the original defaulted to []).
        self.relation_trigger_field_dict = {} if relation_trigger_field_dict is None else relation_trigger_field_dict
        self.required = required
        self.trigger_type = trigger_type

    def is_valid(self, value):
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if self.required and value is None:
            raise AppApiException(500,
                                  _('The field {field_label} is required').format(field_label=field_label))

    def to_dict(self, **kwargs):
        return {
            'input_type': self.input_type,
            'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label,
            'required': self.required,
            'default_value': self.default_value,
            'relation_show_field_dict': self.relation_show_field_dict,
            'relation_trigger_field_dict': self.relation_trigger_field_dict,
            'trigger_type': self.trigger_type.value,
            'attrs': self.attrs,
            'props_info': self.props_info,
            **kwargs
        }


class BaseDefaultOptionField(BaseField):
    def __init__(self, input_type: str,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict[str, object] = None,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """
        :param input_type: input type of the field
        :param label: label
        :param text_field: text field
        :param value_field: value field
        :param option_list: option list
        :param required: whether the field is required
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; the field is shown only when `field` has a value and that value is in field_value_list
        :param attrs: front-end attrs data
        :param props_info: other extra information
        """
        super().__init__(input_type, label, required, default_value, relation_show_field_dict,
                         {}, TriggerType.OPTION_LIST, attrs, props_info)
        self.text_field = text_field
        self.value_field = value_field
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'option_list': self.option_list}


class BaseExecField(BaseField):
    def __init__(self,
                 input_type: str,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """
        :param input_type: input type of the field
        :param label: label/hint shown for the field
        :param text_field: text field
        :param value_field: value field
        :param provider: the provider to call
        :param method: the provider method to execute
        :param required: whether the field is required
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; the field is shown only when `field` has a value and that value is in field_value_list
        :param relation_trigger_field_dict: {field: field_value_list}; the data-fetching function runs only when `field` has a value and that value is in field_value_list
        :param trigger_type: executor type; OPTION_LIST requests the option-list data, CHILD_FORMS requests child forms
        :param attrs: front-end attrs data
        :param props_info: other extra information
        """
        super().__init__(input_type, label, required, default_value, relation_show_field_dict,
                         relation_trigger_field_dict,
                         trigger_type, attrs, props_info)
        self.text_field = text_field
        self.value_field = value_field
        self.provider = provider
        self.method = method

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'provider': self.provider, 'method': self.method}
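For orientation, a minimal sketch of how a field serializes; the label text is illustrative:

field = BaseField('TextInput', 'API Key', required=True)
field.to_dict()
# {'input_type': 'TextInput', 'label': 'API Key', 'required': True,
#  'default_value': None, 'relation_show_field_dict': {}, 'relation_trigger_field_dict': {},
#  'trigger_type': 'OPTION_LIST', 'attrs': {}, 'props_info': {}}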
@@ -0,0 +1,30 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_form.py
@date:2023/11/1 16:04
@desc:
"""
from typing import Dict

from common.forms import BaseField


class BaseForm:
    def to_form_list(self, **kwargs):
        return [{**self.__getattribute__(key).to_dict(**kwargs), 'field': key} for key in
                list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
                            [attr for attr in vars(self.__class__) if not attr.startswith("__")]))]

    def valid_form(self, form_data):
        field_keys = list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
                                 [attr for attr in vars(self.__class__) if not attr.startswith("__")]))
        for field_key in field_keys:
            self.__getattribute__(field_key).is_valid(form_data.get(field_key))

    def get_default_form_data(self):
        return {key: self.__getattribute__(key).default_value for key in
                [attr for attr in vars(self.__class__) if not attr.startswith("__")] if
                isinstance(self.__getattribute__(key), BaseField) and
                self.__getattribute__(key).default_value is not None}
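A minimal sketch of how the form pieces compose: fields are declared as class attributes on a BaseForm subclass; the credential form below is illustrative, not part of this commit:

class ApiCredentialForm(BaseForm):
    api_key = PasswordInputField('API Key', required=True)
    api_base = TextInputField('API Base', required=False)

form = ApiCredentialForm()
form.to_form_list()                # serialized field descriptors for the front end
form.valid_form({'api_key': 'x'})  # raises AppApiException(500, ...) if a required field is missing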
@@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/8/22 17:19
@desc:
"""
from .base_label import *
from .tooltip_label import *
@@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_label.py
@date:2024/8/22 17:11
@desc:
"""


class BaseLabel:
    def __init__(self,
                 input_type: str,
                 label: str,
                 attrs=None,
                 props_info=None):
        self.input_type = input_type
        self.label = label
        self.attrs = attrs
        self.props_info = props_info

    def to_dict(self, **kwargs):
        return {
            'input_type': self.input_type,
            'label': self.label,
            'attrs': {} if self.attrs is None else self.attrs,
            'props_info': {} if self.props_info is None else self.props_info,
        }
@@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: tooltip_label.py
@date:2024/8/22 17:19
@desc:
"""
from common.forms.label.base_label import BaseLabel


class TooltipLabel(BaseLabel):
    def __init__(self, label, tooltip):
        super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={})
@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: multi_select.py
@date:2023/10/31 18:00
@desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class MultiSelect(BaseExecField):
    """
    Drop-down multi-select
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],  # the original annotation List[str:object] was not valid typing syntax
                 provider: str = None,
                 method: str = None,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
@@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: object_card.py
@date:2023/10/31 18:02
@desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class ObjectCard(BaseExecField):
    """
    Card that collects an object sub-form
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
@@ -0,0 +1,26 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: password_input.py
@date:2023/11/1 14:48
@desc:
"""
from typing import Dict

from common.forms import BaseField, TriggerType


class PasswordInputField(BaseField):
    """
    Password input box
    """

    def __init__(self, label: str,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        super().__init__('PasswordInput', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, attrs, props_info)
@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: radio_field.py
@date:2023/10/31 17:59
@desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-select radio button
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("RadioButton", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: radio_field.py
@date:2023/10/31 17:59
@desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-select radio card
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("RadioCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: radio_field.py
@date:2023/10/31 17:59
@desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-select radio
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
@@ -0,0 +1,39 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: single_select_field.py
@date:2023/10/31 18:00
@desc:
"""
from typing import List, Dict

from common.forms import BaseLabel
from common.forms.base_field import TriggerType, BaseExecField


class SingleSelect(BaseExecField):
    """
    Drop-down single select
    """

    def __init__(self,
                 label: str or BaseLabel,
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 provider: str = None,
                 method: str = None,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
@@ -0,0 +1,65 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: slider_field.py
@date:2024/8/22 17:06
@desc:
"""
from typing import Dict

from common.exception.app_exception import AppApiException
from common.forms import BaseField, TriggerType, BaseLabel
from django.utils.translation import gettext_lazy as _


class SliderField(BaseField):
    """
    Slider input box
    """

    def __init__(self, label: str or BaseLabel,
                 _min,
                 _max,
                 _step,
                 precision,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        """
        @param label: label/hint
        @param _min: minimum value
        @param _max: maximum value
        @param _step: step size
        @param precision: number of decimal places to keep
        @param required: whether the field is required
        @param default_value: default value
        @param relation_show_field_dict:
        @param attrs:
        @param props_info:
        """
        _attrs = {'min': _min, 'max': _max, 'step': _step,
                  'precision': precision, 'show-input-controls': False, 'show-input': True}
        if attrs is not None:
            _attrs.update(attrs)
        super().__init__('Slider', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, _attrs, props_info)

    def is_valid(self, value):
        super().is_valid(value)
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if value is not None:
            if value < self.attrs.get('min'):
                raise AppApiException(500,
                                      _("The {field_label} cannot be less than {min}").format(
                                          field_label=field_label, min=self.attrs.get('min')))
            if value > self.attrs.get('max'):
                raise AppApiException(500,
                                      _("The {field_label} cannot be greater than {max}").format(
                                          field_label=field_label, max=self.attrs.get('max')))
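A short sketch of the range validation above; the values are illustrative:

temperature = SliderField(TooltipLabel('Temperature', 'Sampling temperature'),
                          _min=0, _max=2, _step=0.01, precision=2, default_value=0.7)
temperature.is_valid(0.7)  # passes
temperature.is_valid(5)    # raises AppApiException: cannot be greater than 2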
@@ -0,0 +1,33 @@
"""
@project: MaxKB
@Author:虎
@file: switch_field.py
@date:2024/10/13 19:43
@desc:
"""
from typing import Dict

from common.forms import BaseField, TriggerType, BaseLabel


class SwitchField(BaseField):
    """
    Switch (toggle) input
    """

    def __init__(self, label: str or BaseLabel,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        """
        @param required: whether the field is required
        @param default_value: default value
        @param relation_show_field_dict:
        @param attrs:
        @param props_info:
        """
        super().__init__('Switch', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, attrs, props_info)
@@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: tab_card.py
@date:2023/10/31 18:03
@desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class TabCard(BaseExecField):
    """
    Collects tab-style data: tab1:{}, tab2:{}
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("TabCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
@@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: table_radio.py
@date:2023/10/31 18:01
@desc:
"""
from typing import Dict

from common.forms.base_field import TriggerType, BaseExecField


class TableRadio(BaseExecField):
    """
    Table single select
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("TableCheckbox", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
@@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: table_radio.py
@date:2023/10/31 18:01
@desc:
"""
from typing import Dict

from common.forms.base_field import TriggerType, BaseExecField


class TableRadio(BaseExecField):
    """
    Table single select
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("TableRadio", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
@@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: text_input_field.py
@date:2023/10/31 17:58
@desc:
"""
from typing import Dict

from common.forms import BaseLabel
from common.forms.base_field import BaseField, TriggerType


class TextInputField(BaseField):
    """
    Text input box
    """

    def __init__(self, label: str or BaseLabel,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        super().__init__('TextInput', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, attrs, props_info)
@@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: app_model_mixin.py
@date:2023/9/21 9:41
@desc:
"""
from django.db import models


class AppModelMixin(models.Model):
    create_time = models.DateTimeField(verbose_name="Create time", auto_now_add=True)
    update_time = models.DateTimeField(verbose_name="Update time", auto_now=True)

    class Meta:
        abstract = True
        ordering = ['create_time']
@@ -7,8 +7,18 @@
@desc:
"""
import hashlib
import io
import mimetypes
import re
import shutil
from typing import List

from django.core.files.uploadedfile import InMemoryUploadedFile
from django.utils.translation import gettext as _
from pydub import AudioSegment

from ..exception.app_exception import AppApiException


def password_encrypt(row_password):
    """
@@ -36,3 +46,153 @@ def group_by(list_source: List, key):
        array.append(e)
        result[k] = array
    return result


def encryption(message: str):
    """
    Mask a sensitive field. For example, if the password is 1234567890, the front end receives something like 123***************890.
    :param message:
    :return:
    """
    max_pre_len = 8
    max_post_len = 4
    message_len = len(message)
    pre_len = int(message_len / 5 * 2)
    post_len = int(message_len / 5 * 1)
    pre_str = "".join([message[index] for index in
                       range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))])
    # Keep at most max_post_len trailing characters (the original compared pre_len here,
    # which looks like a copy-paste slip).
    end_str = "".join(
        [message[index] for index in
         range(message_len - (int(post_len) if post_len < max_post_len else max_post_len), message_len)])
    content = "***************"
    return pre_str + content + end_str


def _remove_empty_lines(text):
    if not isinstance(text, str):
        raise AppApiException(500, _('Text-to-speech node, the text content must be of string type'))
    if not text:
        raise AppApiException(500, _('Text-to-speech node, the text content cannot be empty'))
    result = '\n'.join(line for line in text.split('\n') if line.strip())
    return markdown_to_plain_text(result)


def markdown_to_plain_text(md: str) -> str:
    # Remove images: ![alt](url)
    text = re.sub(r'!\[.*?\]\(.*?\)', '', md)
    # Remove links: [text](url)
    text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
    # Remove Markdown heading markers (#, ##, ###)
    text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
    # Remove bold: **text** or __text__
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
    text = re.sub(r'__(.*?)__', r'\1', text)
    # Remove italics: *text* or _text_
    text = re.sub(r'\*(.*?)\*', r'\1', text)
    text = re.sub(r'_(.*?)_', r'\1', text)
    # Remove inline code: `code`
    text = re.sub(r'`(.*?)`', r'\1', text)
    # Remove fenced code blocks: ```code```
    text = re.sub(r'```[\s\S]*?```', '', text)
    # Collapse extra newlines
    text = re.sub(r'\n{2,}', '\n', text)
    # Strip all HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Collapse extra whitespace (including newlines and tabs)
    text = re.sub(r'\s+', ' ', text)
    # Strip form-render blocks (the original discarded the result, so the
    # substitution never took effect; assign it back)
    text = re.sub(r'<form_rander>[\s\S]*?<\/form_rander>', '', text)
    # Trim leading and trailing whitespace
    text = text.strip()
    return text


def get_file_content(path):
    with open(path, "r", encoding='utf-8') as file:
        content = file.read()
    return content


def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    content_type, _ = mimetypes.guess_type(file_name)
    if content_type is None:
        # Fall back to the generic binary type when the name is not recognized
        content_type = "application/octet-stream"
    # Create an in-memory byte stream
    file_stream = io.BytesIO(file_bytes)

    # File size
    file_size = len(file_bytes)

    # Build the InMemoryUploadedFile object
    uploaded_file = InMemoryUploadedFile(
        file=file_stream,
        field_name=None,
        name=file_name,
        content_type=content_type,
        size=file_size,
        charset=None,
    )
    return uploaded_file


def any_to_amr(any_path, amr_path):
    """
    Convert any audio format to an amr file
    """
    if any_path.endswith(".amr"):
        shutil.copy2(any_path, amr_path)
        return
    if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
        raise NotImplementedError("Not support file type: {}".format(any_path))
    audio = AudioSegment.from_file(any_path)
    audio = audio.set_frame_rate(8000)  # amr only supports 8000 Hz
    audio.export(amr_path, format="amr")
    return audio.duration_seconds * 1000


def any_to_mp3(any_path, mp3_path):
    """
    Convert any audio format to an mp3 file
    """
    if any_path.endswith(".mp3"):
        shutil.copy2(any_path, mp3_path)
        return
    if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
        # Decode silk in place and transcode the result below (the original then
        # reassigned any_path to the not-yet-written mp3_path, which could not be read).
        sil_to_wav(any_path, any_path)
    audio = AudioSegment.from_file(any_path)
    audio = audio.set_frame_rate(16000)
    audio.export(mp3_path, format="mp3")


def sil_to_wav(silk_path, wav_path, rate: int = 24000):
    """
    Convert a silk file to wav
    """
    try:
        import pysilk
    except ImportError:
        # AppApiException takes (code, message); the original passed only the message.
        raise AppApiException(500, "import pysilk failed, wechaty voice message will not be supported.")
    wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
    with open(wav_path, "wb") as f:
        f.write(wav_data)


def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_format="mp3"):
    audio_data = AudioSegment.from_file(file_path, format=audio_format)
    audio_length_ms = len(audio_data)

    if audio_length_ms <= max_segment_length_ms:
        return model.speech_to_text(io.BytesIO(audio_data.export(format=audio_format).read()))

    full_text = []
    for start_ms in range(0, audio_length_ms, max_segment_length_ms):
        end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
        segment = audio_data[start_ms:end_ms]
        text = model.speech_to_text(io.BytesIO(segment.export(format=audio_format).read()))
        if isinstance(text, str):
            full_text.append(text)
    return ' '.join(full_text)
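For orientation, the masking keeps roughly 2/5 of the head and 1/5 of the tail (capped at 8 and 4 characters), so the docstring's 123...890 shape is approximate:

encryption('1234567890')  # pre_len=4, post_len=2 -> '1234***************90'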
@@ -0,0 +1,140 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: rsa_util.py
@date:2023/11/3 11:13
@desc:
"""
import base64
import threading

from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
from django.core import cache
from django.db.models import QuerySet

from system_manage.models import SystemSetting, SettingType

lock = threading.Lock()
rsa_cache = cache.caches['default']
cache_key = "rsa_key"
# Passphrase used to encrypt the private key
secret_code = "mac_kb_password"


def generate():
    """
    Generate a public/private key pair
    :return: {key: 'public key', value: 'private key'}
    """
    # Generate a 2048-bit key
    key = RSA.generate(2048)

    # Export the encrypted private key
    encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
                                   protection="scryptAndAES128-CBC")
    return {'key': key.publickey().export_key(), 'value': encrypted_key}


def get_key_pair():
    rsa_value = rsa_cache.get(cache_key)
    if rsa_value is None:
        lock.acquire()
        try:
            # Double-checked locking: another thread may have filled the cache already
            rsa_value = rsa_cache.get(cache_key)
            if rsa_value is not None:
                return rsa_value
            rsa_value = get_key_pair_by_sql()
            rsa_cache.set(cache_key, rsa_value)
        finally:
            # Releasing in finally also covers the early return above
            # (the original returned while still holding the lock).
            lock.release()
    return rsa_value


def get_key_pair_by_sql():
    system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
    if system_setting is None:
        kv = generate()
        system_setting = SystemSetting(type=SettingType.RSA.value,
                                       meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()})
        system_setting.save()
    return system_setting.meta


def encrypt(msg, public_key: str | None = None):
    """
    Encrypt
    :param msg: data to encrypt
    :param public_key: public key
    :return: encrypted data
    """
    if public_key is None:
        public_key = get_key_pair().get('key')
    cipher = PKCS1_cipher.new(RSA.importKey(public_key))
    encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
    return base64.b64encode(encrypt_msg).decode()


def decrypt(msg, pri_key: str | None = None):
    """
    Decrypt
    :param msg: data to decrypt
    :param pri_key: private key
    :return: decrypted data
    """
    if pri_key is None:
        pri_key = get_key_pair().get('value')
    cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
    decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
    return decrypt_data.decode("utf-8")


def rsa_long_encrypt(message, public_key: str | None = None, length=200):
    """
    Encrypt long text
    :param message: string to encrypt
    :param public_key: public key
    :param length: use 100 for a 1024-bit certificate, 200 for a 2048-bit certificate
    :return: encrypted data
    """
    # Load the public key
    if public_key is None:
        public_key = get_key_pair().get('key')
    cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
                                            passphrase=secret_code))
    # Handle "Plaintext is too long" by encrypting in chunks
    if len(message) <= length:
        # Encrypt the encoded data and base64-encode the result
        result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
    else:
        rsa_text = []
        # Slice the message, since a single RSA block cannot encrypt arbitrary lengths
        for i in range(0, len(message), length):
            cont = message[i:i + length]
            # Encrypt each slice and append it
            rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
        # Join the encrypted chunks
        cipher_text = b''.join(rsa_text)
        # base64-encode the result
        result = base64.b64encode(cipher_text)
    return result.decode()


def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
    """
    Decrypt long text
    :param message: data to decrypt
    :param pri_key: private key
    :param length: use 128 for a 1024-bit certificate, 256 for a 2048-bit certificate
    :return: decrypted data
    """
    if pri_key is None:
        pri_key = get_key_pair().get('value')
    cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
    base64_de = base64.b64decode(message)
    res = []
    for i in range(0, len(base64_de), length):
        res.append(cipher.decrypt(base64_de[i:i + length], 0))
    return b"".join(res).decode()
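A quick round-trip sketch using the module-level key pair; note that get_key_pair falls back to the database, so the system_manage tables must exist:

token = rsa_long_encrypt('some secret payload')
assert rsa_long_decrypt(token) == 'some secret payload'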
@@ -102,3 +102,5 @@ msgstr "用户管理"
 #: .\apps\users\views\user.py:24 .\apps\users\views\user.py:25
 msgid "Get current user information"
 msgstr "获取当前用户信息"
+
+
@@ -41,7 +41,8 @@ INSTALLED_APPS = [
     'users.apps.UsersConfig',
     'tools.apps.ToolConfig',
     'common',
-    'system_manage'
+    'system_manage',
+    'models_provider',
 ]

 MIDDLEWARE = [
@@ -22,7 +22,8 @@ from maxkb import settings

 urlpatterns = [
     path("api/", include("users.urls")),
-    path("api/", include("tools.urls"))
+    path("api/", include("tools.urls")),
+    path("api/", include("models_provider.urls")),
 ]
 urlpatterns += [
     path('schema/', SpectacularAPIView.as_view(), name='schema'),  # route for the schema file; the two UI routes below are generated from this configuration
@@ -0,0 +1,3 @@
from django.contrib import admin

# Register your models here.
@@ -0,0 +1 @@
# coding=utf-8
@@ -0,0 +1,20 @@
# coding=utf-8

from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer
from models_provider.serializers.model import ModelCreateRequest, ModelModelSerializer


class ModelCreateResponse(ResultSerializer):
    def get_data(self):
        return ModelModelSerializer()


class ModelCreateAPI(APIMixin):
    @staticmethod
    def get_request():
        return ModelCreateRequest

    @staticmethod
    def get_response():
        return ModelCreateResponse
@@ -0,0 +1,85 @@
# coding=utf-8
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _


class ProvideResponse(ResultSerializer):
    def get_data(self):
        return ProvideSerializer()


class ProvideSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, max_length=64, label=_("model name"))
    provider = serializers.CharField(required=True, label=_("provider"))
    icon = serializers.CharField(required=True, label=_("icon"))


class ProvideListSerializer(serializers.Serializer):
    key = serializers.CharField(required=True, max_length=64, label=_("model name"))
    value = serializers.CharField(required=True, label=_("value"))


class ModelListSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, label=_("model name"))
    model_type = serializers.CharField(required=True, label=_("model type"))
    desc = serializers.CharField(required=True, label=_("model desc"))


class ProvideApi(APIMixin):
    class ModelList(APIMixin):
        @staticmethod
        def get_query_params_api():
            return [OpenApiParameter(
                # Parameter name
                name="model_type",
                # Parameter description
                description="model_type",
                # Parameter type
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                # Whether the parameter is required
                required=False,
            ), OpenApiParameter(
                # Parameter name
                name="provider",
                # Parameter description
                description="provider",
                # Parameter type
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                # Whether the parameter is required
                required=True,
            )]

        @staticmethod
        def get_response():
            # The original defined get_response twice (this list serializer, then one
            # returning ProvideResponse); only one definition can take effect, and the
            # list response matches what ModelList returns, so the duplicate is dropped.
            return serializers.ListSerializer(child=ModelListSerializer())

    class ModelTypeList(APIMixin):
        @staticmethod
        def get_query_params_api():
            return [OpenApiParameter(
                # Parameter name
                name="provider",
                # Parameter description
                description="provider",
                # Parameter type
                type=OpenApiTypes.STR,
                location=OpenApiParameter.QUERY,
                # Whether the parameter is required
                required=True,
            )]

        @staticmethod
        def get_response():
            return serializers.ListSerializer(child=ProvideListSerializer())
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class ModelsProviderConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'models_provider'
@ -0,0 +1,255 @@
|
|||
# coding=utf-8
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict, Iterator, Type, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.utils.common import encryption
|
||||
|
||||
|
||||
class DownModelChunkStatus(Enum):
|
||||
success = "success"
|
||||
error = "error"
|
||||
pulling = "pulling"
|
||||
unknown = 'unknown'
|
||||
|
||||
|
||||
class ValidCode(Enum):
|
||||
valid_error = 500
|
||||
model_not_fount = 404
|
||||
|
||||
|
||||
class DownModelChunk:
|
||||
def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
|
||||
self.details = details
|
||||
self.status = status
|
||||
self.digest = digest
|
||||
self.progress = progress
|
||||
self.index = index
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"details": self.details,
|
||||
"status": self.status.value,
|
||||
"digest": self.digest,
|
||||
"progress": self.progress,
|
||||
"index": self.index
|
||||
}
|
||||
|
||||
|
||||
class IModelProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_model_info_manage(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_provide_info(self):
|
||||
pass
|
||||
|
||||
def get_model_type_list(self):
|
||||
return self.get_model_info_manage().get_model_type_list()
|
||||
|
||||
def get_model_list(self, model_type):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, _('Model type cannot be empty'))
|
||||
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_credential
|
||||
|
||||
def get_model_params(self, model_type, model_name):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_credential
|
||||
|
||||
def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object],
|
||||
model_params: Dict[str, object], raise_exception=False):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
|
||||
raise_exception=raise_exception)
|
||||
|
||||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 3
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
raise AppApiException(500, _('The current platform does not support downloading models'))
|
||||
|
||||
|
||||
class MaxKBBaseModel(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def is_cache_model():
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_optional_params(model_kwargs):
|
||||
optional_params = {}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
|
||||
optional_params[key] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class BaseModelCredential(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider,
|
||||
raise_exception=True):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
"""
|
||||
:param model_info: 模型数据
|
||||
:return: 加密后数据
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
"""
|
||||
模型参数设置表单
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def encryption(message: str):
|
||||
"""
|
||||
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
|
||||
:param message:
|
||||
:return:
|
||||
"""
|
||||
return encryption(message)
|
||||
|
||||
|
||||
class ModelTypeConst(Enum):
|
||||
LLM = {'code': 'LLM', 'message': _('LLM')}
|
||||
    EMBEDDING = {'code': 'EMBEDDING', 'message': _('Embedding Model')}
    STT = {'code': 'STT', 'message': _('Speech2Text')}
    TTS = {'code': 'TTS', 'message': _('TTS')}
    IMAGE = {'code': 'IMAGE', 'message': _('Vision Model')}
    TTI = {'code': 'TTI', 'message': _('Image Generation')}
    RERANKER = {'code': 'RERANKER', 'message': _('Rerank')}


class ModelInfo:

    def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
                 model_class: Type[MaxKBBaseModel],
                 **keywords):
        self.name = name
        self.desc = desc
        self.model_type = model_type.name
        self.model_credential = model_credential
        self.model_class = model_class
        # **keywords is always a dict (possibly empty), so no None check is needed
        for key, value in keywords.items():
            self.__setattr__(key, value)

    def get_name(self):
        """
        Get the model name
        :return: model name
        """
        return self.name

    def get_desc(self):
        """
        Get the model description
        :return: model description
        """
        return self.desc

    def get_model_type(self):
        return self.model_type

    def get_model_class(self):
        return self.model_class

    def to_dict(self):
        return reduce(lambda x, y: {**x, **y},
                      [{attr: self.__getattribute__(attr)} for attr in vars(self) if
                       not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})


class ModelInfoManage:
    def __init__(self):
        self.model_dict = {}
        self.model_list = []
        self.default_model_list = []
        self.default_model_dict = {}

    def append_model_info(self, model_info: ModelInfo):
        self.model_list.append(model_info)
        model_type_dict = self.model_dict.get(model_info.model_type)
        if model_type_dict is None:
            self.model_dict[model_info.model_type] = {model_info.name: model_info}
        else:
            model_type_dict[model_info.name] = model_info

    def append_default_model_info(self, model_info: ModelInfo):
        self.default_model_list.append(model_info)
        self.default_model_dict[model_info.model_type] = model_info

    def get_model_list(self):
        return [model.to_dict() for model in self.model_list]

    def get_model_list_by_model_type(self, model_type):
        return [model.to_dict() for model in self.model_list if model.model_type == model_type]

    def get_model_type_list(self):
        return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
                len([model for model in self.model_list if model.model_type == _type.name]) > 0]

    def get_model_info(self, model_type, model_name) -> ModelInfo:
        model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
        if model_info is None:
            raise AppApiException(500, _('The model is not supported'))
        return model_info

    class builder:
        def __init__(self):
            self.modelInfoManage = ModelInfoManage()

        def append_model_info(self, model_info: ModelInfo):
            self.modelInfoManage.append_model_info(model_info)
            return self

        def append_model_info_list(self, model_info_list: List[ModelInfo]):
            for model_info in model_info_list:
                self.modelInfoManage.append_model_info(model_info)
            return self

        def append_default_model_info(self, model_info: ModelInfo):
            self.modelInfoManage.append_default_model_info(model_info)
            return self

        def build(self):
            return self.modelInfoManage


class ModelProvideInfo:
    def __init__(self, provider: str, name: str, icon: str):
        self.provider = provider
        self.name = name
        self.icon = icon

    def to_dict(self):
        return reduce(lambda x, y: {**x, **y},
                      [{attr: self.__getattribute__(attr)} for attr in vars(self) if
                       not attr.startswith("__")], {})

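For orientation, a minimal usage sketch of the fluent builder above; `MyCredential` and `MyChatModel` are hypothetical stand-ins for a real credential/model pair, not classes from this commit:

# Illustrative only: how a provider module assembles its catalogue.
llm_credential = MyCredential()
llm_info = ModelInfo('my-llm', '', ModelTypeConst.LLM, llm_credential, MyChatModel)
manage = (
    ModelInfoManage.builder()
    .append_model_info(llm_info)
    .append_default_model_info(llm_info)
    .build()
)
# An exact name hits the registry; an unknown name falls back to the default entry.
assert manage.get_model_info('LLM', 'my-llm').get_name() == 'my-llm'
assert manage.get_model_info('LLM', 'anything-else').get_name() == 'my-llm'
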
@ -0,0 +1,48 @@
# coding=utf-8
from enum import Enum

from models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \
    AliyunBaiLianModelProvider
from models_provider.impl.anthropic_model_provider.anthropic_model_provider import AnthropicModelProvider
from models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
from models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
from models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
from models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
from models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import SiliconCloudModelProvider
from models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import TencentCloudModelProvider
from models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
from models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider
from models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
    VolcanicEngineModelProvider
from models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
from models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider


class ModelProvideConstants(Enum):
    model_azure_provider = AzureModelProvider()
    model_wenxin_provider = WenxinModelProvider()
    model_ollama_provider = OllamaModelProvider()
    model_openai_provider = OpenAIModelProvider()
    model_kimi_provider = KimiModelProvider()
    model_qwen_provider = QwenModelProvider()
    model_zhipu_provider = ZhiPuModelProvider()
    model_xf_provider = XunFeiModelProvider()
    model_deepseek_provider = DeepSeekModelProvider()
    model_gemini_provider = GeminiModelProvider()
    model_volcanic_engine_provider = VolcanicEngineModelProvider()
    model_tencent_provider = TencentModelProvider()
    model_tencent_cloud_provider = TencentCloudModelProvider()
    model_aws_bedrock_provider = BedrockModelProvider()
    model_local_provider = LocalModelProvider()
    model_xinference_provider = XinferenceModelProvider()
    model_vllm_provider = VllmModelProvider()
    aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
    model_anthropic_provider = AnthropicModelProvider()
    model_siliconCloud_provider = SiliconCloudModelProvider()

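A short, hedged sketch of how this enum resolves a provider at runtime (assuming the member name, e.g. 'model_openai_provider', is what the application persists on a model record, as the `provider=` strings in the provider modules suggest):

# Illustrative only.
provider = ModelProvideConstants['model_openai_provider'].value
info = provider.get_model_provide_info()
types = provider.get_model_info_manage().get_model_type_list()
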
@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/9/9 17:42
@desc:
"""

@ -0,0 +1,100 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: aliyun_bai_lian_model_provider.py
@date:2024/9/9 17:43
@desc:
"""
import os

from common.utils.common import get_file_content
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
    ModelInfoManage
from models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \
    AliyunBaiLianEmbeddingCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.image import QwenVLModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.llm import BaiLianLLMModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \
    AliyunBaiLianRerankerCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.tti import QwenTextToImageModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
from models_provider.impl.aliyun_bai_lian_model_provider.model.image import QwenVLChatModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.llm import BaiLianChatModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText
from models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
from models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _, gettext

aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()
aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential()
aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential()
aliyun_bai_lian_llm_model_credential = BaiLianLLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
qwentti_model_credential = QwenTextToImageModelCredential()

model_info_list = [
    ModelInfo('gte-rerank',
              _('The GTE-Rerank text ranking model series developed by Alibaba Tongyi Lab; developers can integrate high-quality text retrieval and ranking through the LlamaIndex framework.'),
              ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker),
    ModelInfo('paraformer-realtime-v2',
              _('Supports Chinese (including dialects such as Cantonese), English, Japanese, and Korean, with free switching between languages.'),
              ModelTypeConst.STT, aliyun_bai_lian_stt_model_credential, AliyunBaiLianSpeechToText),
    ModelInfo('cosyvoice-v1',
              _('CosyVoice is based on a new generation of large generative speech models; it can predict emotion, intonation, and rhythm from context, producing more human-like speech.'),
              ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech),
    ModelInfo('text-embedding-v1',
              _("The universal text embedding is Tongyi Lab's multilingual unified text embedding model built on the LLM base. It provides high-quality embedding services for the world's mainstream languages and helps developers quickly convert text data into high-quality vector data."),
              ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential,
              AliyunBaiLianEmbedding),
    ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
              BaiLianChatModel),
    ModelInfo('qwen-plus', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
              BaiLianChatModel),
    ModelInfo('qwen-max', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential,
              BaiLianChatModel)
]

module_info_vl_list = [
    ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
    ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
    ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
]
module_info_tti_list = [
    ModelInfo('wanx-v1',
              _('Tongyi Wanxiang, a large text-to-image model. It supports bilingual Chinese/English input, and reference images for content or style transfer. Key styles include, but are not limited to, watercolor, oil painting, Chinese painting, sketch, flat illustration, two-dimensional, and 3D cartoon.'),
              ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_model_info_list(module_info_vl_list)
    .append_default_model_info(module_info_vl_list[0])
    .append_model_info_list(module_info_tti_list)
    .append_default_model_info(module_info_tti_list[0])
    .append_default_model_info(model_info_list[1])
    .append_default_model_info(model_info_list[2])
    .append_default_model_info(model_info_list[3])
    .append_default_model_info(model_info_list[4])
    .append_default_model_info(model_info_list[0])
    .build()
)


class AliyunBaiLianModelProvider(IModelProvider):

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='aliyun_bai_lian_model_provider', name=gettext('Alibaba Cloud Bailian'),
                                icon=get_file_content(
                                    os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl',
                                                 'aliyun_bai_lian_model_provider',
                                                 'icon',
                                                 'aliyun_bai_lian_icon_svg')))

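Given the registrations above, the default-model fallback behaves as sketched below (illustrative only; 'qwen-max-latest' is a made-up name used purely to trigger the fallback path):

manage = AliyunBaiLianModelProvider().get_model_info_manage()
manage.get_model_info('RERANKER', 'gte-rerank')   # exact match in the registry
manage.get_model_info('LLM', 'qwen-max-latest')   # unknown name, falls back to the default 'qwen-turbo' entry
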
@ -0,0 +1,74 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/10/16 17:01
@desc:
"""
import traceback
from typing import Dict, Any

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding


class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, Any],
            model_params: Dict[str, Any],
            provider: Any,
            raise_exception: bool = False
    ) -> bool:
        """
        Validate the model credentials.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                f"{model_type} Model type is not supported"
            )
        required_keys = ['dashscope_api_key']
        missing_keys = [key for key in required_keys if key not in model_credential]
        if missing_keys:
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    f"{', '.join(missing_keys)} is required"
                )
            return False

        try:
            model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential)
            model.embed_query(_("Hello"))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    f"Verification failed, please check whether the parameters are correct: {e}"
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]:
        """
        Encrypt sensitive fields.
        """
        api_key = model.get('dashscope_api_key', '')
        return {**model, 'dashscope_api_key': super().encryption(api_key)}

    dashscope_api_key = forms.PasswordInputField('API Key', required=True)

@ -0,0 +1,71 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: image.py
@date:2024/7/11 18:41
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class QwenVLModelCredential(BaseForm, BaseModelCredential):

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            model_params: dict,
            provider,
            raise_exception: bool = False
    ) -> bool:
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('{model_type} Model type is not supported').format(model_type=model_type)
            )
        required_keys = ['api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        gettext('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    gettext('Verification failed, please check whether the parameters are correct: {error}').format(
                        error=str(e)
                    )
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

@ -0,0 +1,94 @@
# coding=utf-8
import traceback
from typing import Dict

from langchain_core.messages import HumanMessage
from django.utils.translation import gettext_lazy as _, gettext

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class BaiLianLLMModelParams(BaseForm):
    temperature = forms.SliderField(
        TooltipLabel(
            _('Temperature'),
            _('Higher values make the output more random, while lower values make it more focused and deterministic.')
        ),
        required=True,
        default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2
    )

    max_tokens = forms.SliderField(
        TooltipLabel(
            _('Output the maximum Tokens'),
            _('Specify the maximum number of tokens that the model can generate.')
        ),
        required=True,
        default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0
    )


class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):

    api_base = forms.TextInputField(_('API URL'), required=True)
    api_key = forms.PasswordInputField(_('API Key'), required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            model_params: dict,
            provider,
            raise_exception: bool = False
    ) -> bool:
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('{model_type} Model type is not supported').format(model_type=model_type)
            )

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        gettext('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    gettext('Verification failed, please check whether the parameters are correct: {error}').format(
                        error=str(e)
                    )
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name: str) -> BaiLianLLMModelParams:
        return BaiLianLLMModelParams()

@ -0,0 +1,85 @@
# coding=utf-8

import traceback
from typing import Dict, Any

from django.utils.translation import gettext as _
from langchain_core.documents import Document

from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker


class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
    """
    Credential class for the Aliyun BaiLian Reranker model.
    Provides validation and encryption for the model credentials.
    """

    dashscope_api_key = PasswordInputField('API Key', required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, Any],
            model_params: Dict[str, Any],
            provider,
            raise_exception: bool = False
    ) -> bool:
        """
        Validate the model credentials.

        :param model_type: Type of the model (e.g., 'RERANKER').
        :param model_name: Name of the model.
        :param model_credential: Dictionary containing the model credentials.
        :param model_params: Parameters for the model.
        :param provider: Model provider instance.
        :param raise_exception: Whether to raise an exception on validation failure.
        :return: Boolean indicating whether the credentials are valid.
        """
        if model_type != 'RERANKER':
            raise AppApiException(
                ValidCode.valid_error.value,
                _('{model_type} Model type is not supported').format(model_type=model_type)
            )

        required_keys = ['dashscope_api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        _('{key} is required').format(key=key)
                    )
                return False

        try:
            model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    _('Verification failed, please check whether the parameters are correct: {error}').format(error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypt sensitive fields in the model dictionary.

        :param model: Dictionary containing model details.
        :return: Dictionary with encrypted sensitive fields.
        """
        return {
            **model,
            'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))
        }

@ -0,0 +1,92 @@
# coding=utf-8

import traceback
from typing import Dict, Any

from django.utils.translation import gettext as _
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
    """
    Credential class for the Aliyun BaiLian STT (Speech-to-Text) model.
    Provides validation and encryption for the model credentials.
    """

    api_key = PasswordInputField("API Key", required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, Any],
            model_params: Dict[str, Any],
            provider,
            raise_exception: bool = False
    ) -> bool:
        """
        Validate the model credentials.

        :param model_type: Type of the model (e.g., 'STT').
        :param model_name: Name of the model.
        :param model_credential: Dictionary containing the model credentials.
        :param model_params: Parameters for the model.
        :param provider: Model provider instance.
        :param raise_exception: Whether to raise an exception on validation failure.
        :return: Boolean indicating whether the credentials are valid.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                _('{model_type} Model type is not supported').format(model_type=model_type)
            )

        required_keys = ['api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        _('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    _('Verification failed, please check whether the parameters are correct: {error}').format(error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypt sensitive fields in the model dictionary.

        :param model: Dictionary containing model details.
        :return: Dictionary with encrypted sensitive fields.
        """
        return {
            **model,
            'api_key': super().encryption(model.get('api_key', ''))
        }

    def get_model_params_setting_form(self, model_name: str):
        """
        Get the parameter setting form for the specified model.

        :param model_name: Name of the model.
        :return: Parameter setting form (not implemented).
        """
        pass

@ -0,0 +1,147 @@
# coding=utf-8

import traceback
from typing import Dict, Any

from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class QwenModelParams(BaseForm):
    """
    Parameters class for the Qwen Text-to-Image model.
    Defines fields such as image size, number of images, and style.
    """

    size = SingleSelect(
        TooltipLabel(_('Image size'), _('Specify the size of the generated image, such as: 1024*1024')),
        required=True,
        default_value='1024*1024',
        option_list=[
            {'value': '1024*1024', 'label': '1024*1024'},
            {'value': '720*1280', 'label': '720*1280'},
            {'value': '768*1152', 'label': '768*1152'},
            {'value': '1280*720', 'label': '1280*720'},
        ],
        text_field='label',
        value_field='value'
    )

    n = SliderField(
        TooltipLabel(_('Number of pictures'), _('Specify the number of generated images')),
        required=True,
        default_value=1,
        _min=1,
        _max=4,
        _step=1,
        precision=0
    )

    style = SingleSelect(
        TooltipLabel(_('Style'), _('Specify the style of generated images')),
        required=True,
        default_value='<auto>',
        option_list=[
            {'value': '<auto>', 'label': _('Default value, the image style is randomly output by the model')},
            {'value': '<photography>', 'label': _('photography')},
            {'value': '<portrait>', 'label': _('Portraits')},
            {'value': '<3d cartoon>', 'label': _('3D cartoon')},
            {'value': '<anime>', 'label': _('animation')},
            {'value': '<oil painting>', 'label': _('oil painting')},
            {'value': '<watercolor>', 'label': _('watercolor')},
            {'value': '<sketch>', 'label': _('sketch')},
            {'value': '<chinese painting>', 'label': _('Chinese painting')},
            {'value': '<flat illustration>', 'label': _('flat illustration')},
        ],
        text_field='label',
        value_field='value'
    )


class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
    """
    Credential class for the Qwen Text-to-Image model.
    Provides validation and encryption for the model credentials.
    """

    api_key = PasswordInputField('API Key', required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, Any],
            model_params: Dict[str, Any],
            provider,
            raise_exception: bool = False
    ) -> bool:
        """
        Validate the model credentials.

        :param model_type: Type of the model (e.g., 'TTI').
        :param model_name: Name of the model.
        :param model_credential: Dictionary containing the model credentials.
        :param model_params: Parameters for the model.
        :param provider: Model provider instance.
        :param raise_exception: Whether to raise an exception on validation failure.
        :return: Boolean indicating whether the credentials are valid.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('{model_type} Model type is not supported').format(model_type=model_type)
            )

        required_keys = ['api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        gettext('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.check_auth()
            print(res)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    gettext(
                        'Verification failed, please check whether the parameters are correct: {error}'
                    ).format(error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypt sensitive fields in the model dictionary.

        :param model: Dictionary containing model details.
        :return: Dictionary with encrypted sensitive fields.
        """
        return {
            **model,
            'api_key': super().encryption(model.get('api_key', ''))
        }

    def get_model_params_setting_form(self, model_name: str):
        """
        Get the parameter setting form for the specified model.

        :param model_name: Name of the model.
        :return: Parameter setting form.
        """
        return QwenModelParams()

@ -0,0 +1,139 @@
# coding=utf-8

import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext

from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AliyunBaiLianTTSModelGeneralParams(BaseForm):
    """
    Parameters class for the Aliyun BaiLian TTS (Text-to-Speech) model.
    Defines fields such as voice and speech rate.
    """

    voice = SingleSelect(
        TooltipLabel(_('Timbre'), _('Chinese voices also support mixed Chinese and English scenes')),
        required=True,
        default_value='longxiaochun',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('Long Xiaochun'), 'value': 'longxiaochun'},
            {'text': _('Long Xiaoxia'), 'value': 'longxiaoxia'},
            {'text': _('Long Xiaocheng'), 'value': 'longxiaocheng'},
            {'text': _('Long Xiaobai'), 'value': 'longxiaobai'},
            {'text': _('Long Laotie'), 'value': 'longlaotie'},
            {'text': _('Long Shu'), 'value': 'longshu'},
            {'text': _('Long Shuo'), 'value': 'longshuo'},
            {'text': _('Long Jing'), 'value': 'longjing'},
            {'text': _('Long Miao'), 'value': 'longmiao'},
            {'text': _('Long Yue'), 'value': 'longyue'},
            {'text': _('Long Yuan'), 'value': 'longyuan'},
            {'text': _('Long Fei'), 'value': 'longfei'},
            {'text': _('Long Jielidou'), 'value': 'longjielidou'},
            {'text': _('Long Tong'), 'value': 'longtong'},
            {'text': _('Long Xiang'), 'value': 'longxiang'},
            {'text': 'Stella', 'value': 'loongstella'},
            {'text': 'Bella', 'value': 'loongbella'},
        ]
    )

    speech_rate = SliderField(
        TooltipLabel(_('Speaking speed'), _('[0.5, 2], the default is 1; one decimal place is usually enough')),
        required=True,
        default_value=1,
        _min=0.5,
        _max=2,
        _step=0.1,
        precision=1
    )


class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
    """
    Credential class for the Aliyun BaiLian TTS (Text-to-Speech) model.
    Provides validation and encryption for the model credentials.
    """

    api_key = PasswordInputField("API Key", required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            model_params,
            provider,
            raise_exception: bool = False
    ) -> bool:
        """
        Validate the model credentials.

        :param model_type: Type of the model (e.g., 'TTS').
        :param model_name: Name of the model.
        :param model_credential: Dictionary containing the model credentials.
        :param model_params: Parameters for the model.
        :param provider: Model provider instance.
        :param raise_exception: Whether to raise an exception on validation failure.
        :return: Boolean indicating whether the credentials are valid.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('{model_type} Model type is not supported').format(model_type=model_type)
            )

        required_keys = ['api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        gettext('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    gettext(
                        'Verification failed, please check whether the parameters are correct: {error}'
                    ).format(error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """
        Encrypt sensitive fields in the model dictionary.

        :param model: Dictionary containing model details.
        :return: Dictionary with encrypted sensitive fields.
        """
        return {
            **model,
            'api_key': super().encryption(model.get('api_key', ''))
        }

    def get_model_params_setting_form(self, model_name: str):
        """
        Get the parameter setting form for the specified model.

        :param model_name: Name of the model.
        :return: Parameter setting form.
        """
        return AliyunBaiLianTTSModelGeneralParams()

@ -0,0 +1 @@
<svg id="图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 50 50"><defs><style>.cls-1{fill:#6d49f6;}.cls-2{fill:#6be8d2;}.cls-3{fill:#9b8ee8;}.cls-4{fill:#0d22d2;}.cls-5{fill:#2c53dc;}.cls-6{fill:#5eccc9;}.cls-7{fill:#fff;}</style></defs><title>【icon】阿里百炼大模型</title><path class="cls-1" d="M35.463,6.454,25,12.5,14.186,18.737,4.032,12.882a2.33,2.33,0,0,1,.851-.851L23.829,1.089a2.358,2.358,0,0,1,2.342,0Z"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><polygon class="cls-2" points="35.825 31.232 35.825 31.243 25 37.491 35.814 31.232 35.825 31.232"/><path class="cls-3" d="M45.117,12.031,35.463,6.454,25,12.5,14.186,18.737v.01L25,25h.011l10.814-6.248,10.143-5.865A2.33,2.33,0,0,0,45.117,12.031Z"/><path class="cls-4" d="M4.032,12.882a2.254,2.254,0,0,0-.32,1.171V35.926a2.319,2.319,0,0,0,.32,1.182l10.143-5.865v-12.5h.011v-.01Z"/><polygon class="cls-5" points="25 24.995 14.175 31.243 14.175 18.747 14.186 18.747 25 24.995"/><path class="cls-6" d="M45.968,37.1a2.278,2.278,0,0,1-.851.862L26.171,48.9a2.358,2.358,0,0,1-2.342,0l-9.3-5.375,10.462-6.035H25l10.825-6.248Z"/><path class="cls-2" d="M46.288,25.2V35.926a2.251,2.251,0,0,1-.32,1.171L35.825,31.243v-.011Z"/><path class="cls-7" d="M46.288,14.053V25.2L35.825,31.232V18.747l10.143-5.865A2.254,2.254,0,0,1,46.288,14.053Z"/><polygon class="cls-7" points="35.825 18.747 35.825 31.232 35.814 31.232 25.011 24.995 35.825 18.747"/><path class="cls-2" d="M35.814,31.232,25.011,25H25L14.175,31.243,4.032,37.108a2.33,2.33,0,0,0,.851.851l9.644,5.567,10.462-6.035H25l10.825-6.248v-.011Z"/></svg>
After Width: | Height: | Size: 1.9 KiB |
@ -0,0 +1,66 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/10/16 16:34
@desc:
"""
from functools import reduce
from typing import Dict, List

from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.embeddings.dashscope import embed_with_retry

from models_provider.base_model_provider import MaxKBBaseModel


def proxy_embed_documents(texts: List[str], step_size, embed_documents):
    # Embed the texts in batches of `step_size`, then flatten the per-batch
    # results back into a single list.
    value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
             range(0, len(texts), step_size)]
    return reduce(lambda x, y: [*x, *y], value, [])


class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return AliyunBaiLianEmbedding(
            model=model_name,
            dashscope_api_key=model_credential.get('dashscope_api_key')
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # For text-embedding-v3, send the texts in batches of 6 per request.
        if self.model == 'text-embedding-v3':
            return proxy_embed_documents(texts, 6, self._embed_documents)
        return self._embed_documents(texts)

    def _embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to DashScope's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = embed_with_retry(
            self, input=texts, text_type="document", model=self.model
        )
        embedding_list = [item["embedding"] for item in embeddings]
        return embedding_list

    def embed_query(self, text: str) -> List[float]:
        """Call out to DashScope's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = embed_with_retry(
            self, input=[text], text_type="document", model=self.model
        )[0]["embedding"]
        return embedding

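A small illustration of the batching helper above, with a stub in place of the real DashScope call:

# Illustrative only: nine texts are embedded as batches of 6 + 3 and the
# per-batch results are flattened back into a single list.
def fake_embed(batch):
    return [[float(len(t))] for t in batch]

texts = [f'doc-{i}' for i in range(9)]
assert len(proxy_embed_documents(texts, 6, fake_embed)) == 9
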
Binary file not shown.
@ -0,0 +1,23 @@
# coding=utf-8

from typing import Dict

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        chat_tong_yi = QwenVLChatModel(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
            # stream_options={"include_usage": True},
            streaming=True,
            stream_usage=True,
            **optional_params,
        )
        return chat_tong_yi

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from typing import Dict

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # The qwen-omni-turbo and qwq models only return streamed responses.
        if 'qwen-omni-turbo' in model_name or 'qwq' in model_name:
            optional_params['streaming'] = True
        return BaiLianChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            **optional_params
        )

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: reranker.py
@date:2024/9/2 16:42
@desc:
"""
from typing import Dict

from langchain_community.document_compressors import DashScopeRerank

from models_provider.base_model_provider import MaxKBBaseModel


class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
                                     top_n=model_kwargs.get('top_n', 3))

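A hedged usage sketch for the reranker wrapper above ('sk-xxx' is a placeholder key):

# Illustrative only: rank two candidate documents against a query.
from langchain_core.documents import Document

reranker = AliyunBaiLianReranker.new_instance(
    'RERANKER', 'gte-rerank', {'dashscope_api_key': 'sk-xxx'}, top_n=2)
ranked = reranker.compress_documents(
    [Document(page_content='MaxKB overview'), Document(page_content='unrelated text')],
    'What is MaxKB?')
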
@ -0,0 +1,75 @@
import os
import tempfile
from typing import Dict

import dashscope
from dashscope.audio.asr import Recognition
from pydub import AudioSegment

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText


class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        if model_name == 'qwen-omni-turbo':
            optional_params['streaming'] = True
        return AliyunBaiLianSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        # Transcribe the bundled sample file to verify that the key works.
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            self.speech_to_text(f)

    def speech_to_text(self, audio_file):
        dashscope.api_key = self.api_key
        recognition = Recognition(model=self.model,
                                  format='mp3',
                                  sample_rate=16000,
                                  callback=None)
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            # Save the uploaded file to a temporary file
            temp_file.write(audio_file.read())
            # Keep the path of the temporary file
            temp_file_path = temp_file.name

        try:
            audio = AudioSegment.from_file(temp_file_path)
            if audio.channels != 1:
                audio = audio.set_channels(1)
                audio = audio.set_frame_rate(16000)

            # Save the converted audio back to the temporary file
            audio.export(temp_file_path, format='mp3')
            # Run recognition on the temporary file
            result = recognition.call(temp_file_path)
            text = ''
            if result.status_code == 200:
                result_sentence = result.get_sentence()
                if result_sentence is not None:
                    for sentence in result_sentence:
                        text += sentence['text']
                return text
            else:
                raise Exception(f'Error: {result.message}')
        finally:
            # Remove the temporary file
            os.remove(temp_file_path)

@ -0,0 +1,59 @@
# coding=utf-8
from http import HTTPStatus
from typing import Dict

from dashscope import ImageSynthesis
from django.utils.translation import gettext
from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import HumanMessage

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tti import BaseTextToImage


class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
    api_key: str
    model_name: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model_name = kwargs.get('model_name')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        chat_tong_yi = QwenTextToImageModel(
            model_name=model_name,
            api_key=model_credential.get('api_key'),
            **optional_params,
        )
        return chat_tong_yi

    def is_cache_model(self):
        return False

    def check_auth(self):
        chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
        chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])

    def generate_image(self, prompt: str, negative_prompt: str = None):
        # api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
        rsp = ImageSynthesis.call(api_key=self.api_key,
                                  model=self.model_name,
                                  prompt=prompt,
                                  negative_prompt=negative_prompt,
                                  **self.params)
        file_urls = []
        if rsp.status_code == HTTPStatus.OK:
            for result in rsp.output.results:
                file_urls.append(result.url)
        else:
            print('sync_call Failed, status_code: %s, code: %s, message: %s' %
                  (rsp.status_code, rsp.code, rsp.message))
        return file_urls

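A brief, illustrative call sequence for the text-to-image class above ('sk-xxx' is a placeholder key; extra keyword arguments flow into the DashScope request via `params`):

tti = QwenTextToImageModel.new_instance(
    'TTI', 'wanx-v1', {'api_key': 'sk-xxx'}, size='1024*1024', n=1)
urls = tti.generate_image('a watercolor lighthouse at dawn')  # list of image URLs
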
@ -0,0 +1,57 @@
from typing import Dict

import dashscope

from django.utils.translation import gettext as _

from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech


class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    api_key: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value

        return AliyunBaiLianTextToSpeech(
            model=model_name,
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        dashscope.api_key = self.api_key
        text = _remove_empty_lines(text)
        # The sambert series uses the v1 synthesizer API; other models use v2.
        if 'sambert' in self.model:
            from dashscope.audio.tts import SpeechSynthesizer
            audio = SpeechSynthesizer.call(model=self.model, text=text, **self.params).get_audio_data()
        else:
            from dashscope.audio.tts_v2 import SpeechSynthesizer
            synthesizer = SpeechSynthesizer(model=self.model, **self.params)
            audio = synthesizer.call(text)
        if audio is None:
            raise Exception('Failed to generate audio')
        # A string result is an error payload rather than audio bytes.
        if isinstance(audio, str):
            print(audio)
            raise Exception(audio)
        return audio

    def is_cache_model(self):
        return False

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py
@date:2024/3/28 16:25
@desc:
"""

@ -0,0 +1,62 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: anthropic_model_provider.py
@date:2024/3/28 16:26
@desc:
"""
import os

from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
    ModelTypeConst, ModelInfoManage
from models_provider.impl.anthropic_model_provider.credential.image import AnthropicImageModelCredential
from models_provider.impl.anthropic_model_provider.credential.llm import AnthropicLLMModelCredential
from models_provider.impl.anthropic_model_provider.model.image import AnthropicImage
from models_provider.impl.anthropic_model_provider.model.llm import AnthropicChatModel
from maxkb.conf import PROJECT_DIR

anthropic_llm_model_credential = AnthropicLLMModelCredential()
anthropic_image_model_credential = AnthropicImageModelCredential()

model_info_list = [
    ModelInfo('claude-3-opus-20240229', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel
              ),
    ModelInfo('claude-3-sonnet-20240229', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-haiku-20240307', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20240620', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-haiku-20241022', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
]

image_model_info = [
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.IMAGE, anthropic_image_model_credential,
              AnthropicImage),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(image_model_info)
    .append_default_model_info(image_model_info[0])
    .build()
)


class AnthropicModelProvider(IModelProvider):

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_anthropic_provider', name='Anthropic', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'anthropic_model_provider', 'icon',
                         'anthropic_icon_svg')))

@ -0,0 +1,53 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
    api_base = forms.TextInputField(_('API URL'), required=True)
    api_key = forms.PasswordInputField(_('API Key'), required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext("Hello")}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        pass

@@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext


class AnthropicLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField(_('API URL'), required=True)
    api_key = forms.PasswordInputField(_('API Key'), required=True)

    def get_model_params_setting_form(self, model_name):
        return AnthropicLLMModelParams()

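As a quick orientation (not part of the diff), a minimal sketch of how a credential form like the one above is typically exercised; the provider object and all credential values are illustrative assumptions:

# Hypothetical usage sketch; 'provider' stands in for an IModelProvider implementation.
credential = AnthropicLLMModelCredential()
ok = credential.is_valid(
    'LLM', 'claude-3-haiku-20240307',
    {'api_base': 'https://api.anthropic.com', 'api_key': 'sk-...'},
    {'temperature': 0.7, 'max_tokens': 800},
    provider,
    raise_exception=False,  # return False instead of raising on failure
)
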
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" shape-rendering="geometricPrecision" text-rendering="geometricPrecision" image-rendering="optimizeQuality" fill-rule="evenodd" clip-rule="evenodd" viewBox="0 0 512 512"><rect fill="#CC9B7A" width="512" height="512" rx="104.187" ry="105.042"/><path fill="#1F1F1E" fill-rule="nonzero" d="M318.663 149.787h-43.368l78.952 212.423 43.368.004-78.952-212.427zm-125.326 0l-78.952 212.427h44.255l15.932-44.608 82.846-.004 16.107 44.612h44.255l-79.126-212.427h-45.317zm-4.251 128.341l26.91-74.701 27.083 74.701h-53.993z"/></svg>

After Width: | Height: | Size: 558 B |

@@ -0,0 +1,26 @@
from typing import Dict

from langchain_anthropic import ChatAnthropic

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AnthropicImage(MaxKBBaseModel, ChatAnthropic):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return AnthropicImage(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            # stream_options={"include_usage": True},
            streaming=True,
            **optional_params,
        )

@@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: llm.py
@date:2024/4/18 15:28
@desc:
"""
from typing import List, Dict

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AnthropicChatModel(MaxKBBaseModel, ChatAnthropic):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        anthropic_chat_model = AnthropicChatModel(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            **optional_params,
            custom_get_token_ids=custom_get_token_ids
        )
        return anthropic_chat_model

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))

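For context, a hedged sketch of how new_instance above would be driven; the credential values are placeholders, not from this diff:

# Illustrative only; real calls come from the provider layer.
model = AnthropicChatModel.new_instance(
    'LLM', 'claude-3-5-sonnet-20240620',
    {'api_base': 'https://api.anthropic.com', 'api_key': 'sk-...'},
    temperature=0.7, max_tokens=800,
)
token_count = model.get_num_tokens('Hello')  # falls back to the local tokenizer on error
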
@@ -0,0 +1,2 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

@@ -0,0 +1,154 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import os

from common.utils.common import get_file_content
from models_provider.base_model_provider import (
    IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
from models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _


def _create_model_info(model_name, description, model_type, credential_class, model_class):
    return ModelInfo(
        name=model_name,
        desc=description,
        model_type=model_type,
        model_credential=credential_class(),
        model_class=model_class
    )


def _get_aws_bedrock_icon_path():
    return os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'aws_bedrock_model_provider',
                        'icon', 'bedrock_icon_svg')


def _initialize_model_info():
    model_info_list = [
        _create_model_info(
            'anthropic.claude-v2:1',
            _('An update to Claude 2 that doubles the context window and improves reliability, hallucination rates, and evidence-based accuracy in long documents and RAG contexts.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'anthropic.claude-v2',
            _('Anthropic is a powerful model that can handle a variety of tasks, from complex dialogue and creative content generation to detailed command obedience.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'anthropic.claude-3-haiku-20240307-v1:0',
            _("The Claude 3 Haiku is Anthropic's fastest and most compact model, with near-instant responsiveness. The model can answer simple queries and requests quickly. Customers will be able to build seamless AI experiences that mimic human interactions. Claude 3 Haiku can process images and return text output, and provides 200K context windows."),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'anthropic.claude-3-sonnet-20240229-v1:0',
            _("The Claude 3 Sonnet model from Anthropic strikes the ideal balance between intelligence and speed, especially when it comes to handling enterprise workloads. This model offers maximum utility while being priced lower than competing products, and it's been engineered to be a solid choice for deploying AI at scale."),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'anthropic.claude-3-5-sonnet-20240620-v1:0',
            _('The Claude 3.5 Sonnet raises the industry standard for intelligence, outperforming competing models and the Claude 3 Opus in extensive evaluations, with the speed and cost-effectiveness of our mid-range models.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'anthropic.claude-instant-v1',
            _('A faster, more affordable but still very powerful model that can handle a range of tasks including casual conversation, text analysis, summarization and document question answering.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'amazon.titan-text-premier-v1:0',
            _("Titan Text Premier is the most powerful and advanced model in the Titan Text series, designed to deliver exceptional performance for a variety of enterprise applications. With its cutting-edge features, it delivers greater accuracy and outstanding results, making it an excellent choice for organizations looking for a top-notch text processing solution."),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'amazon.titan-text-lite-v1',
            _('Amazon Titan Text Lite is a lightweight, efficient model ideal for fine-tuning English-language tasks, including summarization and copywriting, where customers require smaller, more cost-effective, and highly customizable models.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'amazon.titan-text-express-v1',
            _('Amazon Titan Text Express has context lengths of up to 8,000 tokens, making it ideal for a variety of high-level general language tasks, such as open-ended text generation and conversational chat, as well as support in retrieval-augmented generation (RAG). At launch, the model is optimized for English, but other languages are supported.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'mistral.mistral-7b-instruct-v0:2',
            _('7B dense converter for rapid deployment and easy customization. Small in size yet powerful in a variety of use cases. Supports English and code, as well as 32k context windows.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'mistral.mistral-large-2402-v1:0',
            _('Advanced Mistral AI large-scale language model capable of handling any language task, including complex multilingual reasoning, text understanding, transformation, and code generation.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'meta.llama3-70b-instruct-v1:0',
            _('Ideal for content creation, conversational AI, language understanding, R&D, and enterprise applications'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
        _create_model_info(
            'meta.llama3-8b-instruct-v1:0',
            _('Ideal for limited computing power and resources, edge devices, and faster training times.'),
            ModelTypeConst.LLM,
            BedrockLLMModelCredential,
            BedrockModel
        ),
    ]
    embedded_model_info_list = [
        _create_model_info(
            'amazon.titan-embed-text-v1',
            _('Titan Embed Text is the largest embedding model in the Amazon Titan Embed series and can handle various text embedding tasks, such as text classification, text similarity calculation, etc.'),
            ModelTypeConst.EMBEDDING,
            BedrockEmbeddingCredential,
            BedrockEmbeddingModel
        ),
    ]

    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_default_model_info(model_info_list[0]) \
        .append_model_info_list(embedded_model_info_list) \
        .append_default_model_info(embedded_model_info_list[0]) \
        .build()

    return model_info_manage


class BedrockModelProvider(IModelProvider):
    def __init__(self):
        self._model_info_manage = _initialize_model_info()

    def get_model_info_manage(self):
        return self._model_info_manage

    def get_model_provide_info(self):
        icon_path = _get_aws_bedrock_icon_path()
        icon_data = get_file_content(icon_path)
        return ModelProvideInfo(
            provider='model_aws_bedrock_provider',
            name='Amazon Bedrock',
            icon=icon_data
        )

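A brief note on the builder chain above: each append_default_model_info call registers the fallback entry for the list appended just before it, so the LLM and embedding defaults are set independently. A hypothetical lookup through the resulting provider might read:

# Illustrative sketch, assuming the IModelProvider interface used throughout this diff.
provider = BedrockModelProvider()
manage = provider.get_model_info_manage()
# Credential forms and model classes are then resolved per model name, e.g.
# 'anthropic.claude-v2:1' -> (BedrockLLMModelCredential, BedrockModel).
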
@@ -0,0 +1,53 @@
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel


class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('{model_type} Model type is not supported').format(model_type=model_type))
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('The following fields are required: {keys}').format(
                                          keys=", ".join(required_keys)))
            return False

        try:
            model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
            embedding = model.embed_query(_('Hello'))
            print(embedding)
        except AppApiException:
            raise
        except Exception as e:
            traceback.print_exc()
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

@@ -0,0 +1,76 @@
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import ValidCode, BaseModelCredential


class BedrockLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class BedrockLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('{model_type} Model type is not supported').format(model_type=model_type))
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('The following fields are required: {keys}').format(
                                          keys=", ".join(required_keys)))
            return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except AppApiException:
            raise
        except Exception as e:
            traceback.print_exc()
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
    base_url = forms.TextInputField('Proxy URL', required=False)

    def get_model_params_setting_form(self, model_name):
        return BedrockLLMModelParams()

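The required_keys/all() guard above fails fast on incomplete credentials; a hedged illustration of the two outcomes, with 'provider' assumed to be the Bedrock provider instance:

# Illustrative only.
cred = BedrockLLMModelCredential()
cred.is_valid('LLM', 'anthropic.claude-v2', {'region_name': 'us-east-1'}, {}, provider)
# -> False: access_key_id and secret_access_key are missing.
# With raise_exception=True the same call raises AppApiException instead.
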
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" version="2.0" focusable="false" aria-hidden="true" class="globalNav-1216 globalNav-1213" data-testid="awsc-logo" viewBox="0 0 29 17"><path class="globalNav-1214" d="M8.38 6.17a2.6 2.6 0 00.11.83c.08.232.18.456.3.67a.4.4 0 01.07.21.36.36 0 01-.18.28l-.59.39a.43.43 0 01-.24.08.38.38 0 01-.28-.13 2.38 2.38 0 01-.34-.43c-.09-.16-.18-.34-.28-.55a3.44 3.44 0 01-2.74 1.29 2.54 2.54 0 01-1.86-.67 2.36 2.36 0 01-.68-1.79 2.43 2.43 0 01.84-1.92 3.43 3.43 0 012.29-.72 6.75 6.75 0 011 .07c.35.05.7.12 1.07.2V3.3a2.06 2.06 0 00-.44-1.49 2.12 2.12 0 00-1.52-.43 4.4 4.4 0 00-1 .12 6.85 6.85 0 00-1 .32l-.33.12h-.14c-.14 0-.2-.1-.2-.29v-.46A.62.62 0 012.3.87a.78.78 0 01.27-.2A6 6 0 013.74.25 5.7 5.7 0 015.19.07a3.37 3.37 0 012.44.76 3 3 0 01.77 2.29l-.02 3.05zM4.6 7.59a3 3 0 001-.17 2 2 0 00.88-.6 1.36 1.36 0 00.32-.59 3.18 3.18 0 00.09-.81V5A7.52 7.52 0 006 4.87h-.88a2.13 2.13 0 00-1.38.37 1.3 1.3 0 00-.46 1.08 1.3 1.3 0 00.34 1c.278.216.63.313.98.27zm7.49 1a.56.56 0 01-.36-.09.73.73 0 01-.2-.37L9.35.93a1.39 1.39 0 01-.08-.38c0-.15.07-.23.22-.23h.92a.56.56 0 01.36.09.74.74 0 01.19.37L12.53 7 14 .79a.61.61 0 01.18-.37.59.59 0 01.37-.09h.75a.62.62 0 01.38.09.74.74 0 01.18.37L17.31 7 18.92.76a.74.74 0 01.19-.37.56.56 0 01.36-.09h.87a.21.21 0 01.23.23 1 1 0 010 .15s0 .13-.06.23l-2.26 7.2a.74.74 0 01-.19.37.6.6 0 01-.36.09h-.8a.53.53 0 01-.37-.1.64.64 0 01-.18-.37l-1.45-6-1.44 6a.64.64 0 01-.18.37.55.55 0 01-.37.1l-.82.02zm12 .24a6.29 6.29 0 01-1.44-.16 4.21 4.21 0 01-1.07-.37.69.69 0 01-.29-.26.66.66 0 01-.06-.27V7.3c0-.19.07-.29.21-.29a.57.57 0 01.18 0l.23.1c.32.143.656.25 1 .32.365.08.737.12 1.11.12a2.47 2.47 0 001.36-.31 1 1 0 00.48-.88.88.88 0 00-.25-.65 2.29 2.29 0 00-.94-.49l-1.35-.43a2.83 2.83 0 01-1.49-.94 2.24 2.24 0 01-.47-1.36 2 2 0 01.25-1c.167-.3.395-.563.67-.77a3 3 0 011-.48A4.1 4.1 0 0124.4.08a4.4 4.4 0 01.62 0l.61.1.53.15.39.16c.105.062.2.14.28.23a.57.57 0 01.08.31v.44c0 .2-.07.3-.21.3a.92.92 0 01-.36-.12 4.35 4.35 0 00-1.8-.36 2.51 2.51 0 00-1.24.26.92.92 0 00-.44.84c0 .249.1.488.28.66.295.236.635.41 1 .51l1.32.42a2.88 2.88 0 011.44.9 2.1 2.1 0 01.43 1.31 2.38 2.38 0 01-.24 1.08 2.34 2.34 0 01-.68.82 3 3 0 01-1 .53 4.59 4.59 0 01-1.35.22l.03-.01z"></path><path class="globalNav-1215" d="M25.82 13.43a20.07 20.07 0 01-11.35 3.47A20.54 20.54 0 01.61 11.62c-.29-.26 0-.62.32-.42a27.81 27.81 0 0013.86 3.68 27.54 27.54 0 0010.58-2.16c.52-.22.96.34.45.71z"></path><path class="globalNav-1215" d="M27.1 12c-.4-.51-2.6-.24-3.59-.12-.3 0-.34-.23-.07-.42 1.75-1.23 4.63-.88 5-.46.37.42-.09 3.3-1.74 4.68-.25.21-.49.09-.38-.18.34-.95 1.17-3.02.78-3.5z"></path></svg>

After Width: | Height: | Size: 2.6 KiB |

@@ -0,0 +1,60 @@
from langchain_community.embeddings import BedrockEmbeddings

from models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict, List

from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials


class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 **kwargs):
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name, **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockEmbeddingModel':
        # The AWS shared-credentials profile is keyed by the access key id itself,
        # which is why access_key_id is passed both as profile name and key.
        _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
                                model_credential['secret_access_key'])
        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['access_key_id'],
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a Bedrock model.

        Args:
            texts: The list of texts to embed

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        for text in texts:
            response = self._embedding_func(text)

            if self.normalize:
                response = self._normalize_vector(response)

            results.append(response)

        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a Bedrock model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embedding = self._embedding_func(text)

        if self.normalize:
            return self._normalize_vector(embedding)

        return embedding

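A short usage sketch of the embedding wrapper above; all credential values are placeholders:

# Illustrative only.
emb = BedrockEmbeddingModel.new_instance(
    'EMBEDDING', 'amazon.titan-embed-text-v1',
    {'region_name': 'us-east-1', 'access_key_id': 'AKIA...', 'secret_access_key': '...'},
)
vector = emb.embed_query('Hello')          # List[float]
vectors = emb.embed_documents(['a', 'b'])  # List[List[float]], normalized if self.normalize
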
@@ -0,0 +1,88 @@
import os
import re
from typing import Dict

from botocore.config import Config
from langchain_community.chat_models import BedrockChat

from models_provider.base_model_provider import MaxKBBaseModel


def get_max_tokens_keyword(model_name):
    """
    Return the correct max_tokens keyword for the given model name.

    :param model_name: model name string
    :return: the matching max_tokens keyword
    """
    maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
    # max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"]
    maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"]
    max_new_tokens = [
        "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0",
        "us.meta.llama3-2-90b-instruct-v1:0"]
    if model_name in maxTokens:
        return 'maxTokens'
    elif model_name in maxTokenCount:
        return 'maxTokenCount'
    elif model_name in max_new_tokens:
        return 'max_new_tokens'
    else:
        return 'max_tokens'


class BedrockModel(MaxKBBaseModel, BedrockChat):

    @staticmethod
    def is_cache_model():
        return False

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, config: Config = None, **kwargs):
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name, streaming=streaming, config=config,
                         **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        config = None
        # Route requests through a proxy when the credential provides a non-empty base_url.
        if 'base_url' in model_credential and model_credential['base_url']:
            proxy_url = model_credential['base_url']
            config = Config(
                proxies={
                    'http': proxy_url,
                    'https': proxy_url
                },
                connect_timeout=60,
                read_timeout=60
            )
        # The AWS shared-credentials profile is keyed by the access key id itself.
        _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
                                model_credential['secret_access_key'])

        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['access_key_id'],
            streaming=model_kwargs.pop('streaming', True),
            model_kwargs=optional_params,
            config=config
        )


def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
    credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
    os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

    if os.path.exists(credentials_path):
        with open(credentials_path, 'r') as file:
            content = file.read()
    else:
        content = ''
    # Drop any stale block for this profile before (re)appending it.
    pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
    content = re.sub(pattern, '', content, flags=re.DOTALL)

    if not re.search(rf'\[{profile_name}\]', content):
        content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"

    with open(credentials_path, 'w') as file:
        file.write(content)

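For clarity, _update_aws_credentials maintains the standard AWS shared-credentials INI file, one profile per access key; with placeholder keys the result looks like the sketch below, and get_max_tokens_keyword maps model ids onto each family's parameter name:

# ~/.aws/credentials after _update_aws_credentials('AKIAEXAMPLE', 'AKIAEXAMPLE', 'wJal...'):
#
# [AKIAEXAMPLE]
# aws_access_key_id = AKIAEXAMPLE
# aws_secret_access_key = wJal...
#
# get_max_tokens_keyword('amazon.titan-text-lite-v1')       -> 'maxTokenCount'
# get_max_tokens_keyword('mistral.mistral-7b-instruct-v0:2') -> 'max_tokens'
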
@@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py
@date:2023/10/31 17:16
@desc:
"""

@@ -0,0 +1,116 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: azure_model_provider.py
@date:2023/10/31 16:19
@desc:
"""
import os

from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
    ModelTypeConst, ModelInfoManage
from models_provider.impl.azure_model_provider.credential.embedding import AzureOpenAIEmbeddingCredential
from models_provider.impl.azure_model_provider.credential.image import AzureOpenAIImageModelCredential
from models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
from models_provider.impl.azure_model_provider.credential.stt import AzureOpenAISTTModelCredential
from models_provider.impl.azure_model_provider.credential.tti import AzureOpenAITextToImageModelCredential
from models_provider.impl.azure_model_provider.credential.tts import AzureOpenAITTSModelCredential
from models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from models_provider.impl.azure_model_provider.model.embedding import AzureOpenAIEmbeddingModel
from models_provider.impl.azure_model_provider.model.image import AzureOpenAIImage
from models_provider.impl.azure_model_provider.model.stt import AzureOpenAISpeechToText
from models_provider.impl.azure_model_provider.model.tti import AzureOpenAITextToImage
from models_provider.impl.azure_model_provider.model.tts import AzureOpenAITextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _

base_azure_llm_model_credential = AzureLLMModelCredential()
base_azure_embedding_model_credential = AzureOpenAIEmbeddingCredential()
base_azure_image_model_credential = AzureOpenAIImageModelCredential()
base_azure_tti_model_credential = AzureOpenAITextToImageModelCredential()
base_azure_tts_model_credential = AzureOpenAITTSModelCredential()
base_azure_stt_model_credential = AzureOpenAISTTModelCredential()

default_model_info = [
    ModelInfo('Azure OpenAI', '', ModelTypeConst.LLM,
              base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
              ),
    ModelInfo('gpt-4', '', ModelTypeConst.LLM,
              base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
              ),
    ModelInfo('gpt-4o', '', ModelTypeConst.LLM,
              base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
              ),
    ModelInfo('gpt-4o-mini', '', ModelTypeConst.LLM,
              base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
              ),
]

embedding_model_info = [
    ModelInfo('text-embedding-3-large', '', ModelTypeConst.EMBEDDING,
              base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
              ),
    ModelInfo('text-embedding-3-small', '', ModelTypeConst.EMBEDDING,
              base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
              ),
    ModelInfo('text-embedding-ada-002', '', ModelTypeConst.EMBEDDING,
              base_azure_embedding_model_credential, AzureOpenAIEmbeddingModel, api_version='2023-05-15'
              ),
]

image_model_info = [
    ModelInfo('gpt-4o', '', ModelTypeConst.IMAGE,
              base_azure_image_model_credential, AzureOpenAIImage, api_version='2023-05-15'
              ),
    ModelInfo('gpt-4o-mini', '', ModelTypeConst.IMAGE,
              base_azure_image_model_credential, AzureOpenAIImage, api_version='2023-05-15'
              ),
]

tti_model_info = [
    ModelInfo('dall-e-3', '', ModelTypeConst.TTI,
              base_azure_tti_model_credential, AzureOpenAITextToImage, api_version='2023-05-15'
              ),
]

tts_model_info = [
    ModelInfo('tts', '', ModelTypeConst.TTS,
              base_azure_tts_model_credential, AzureOpenAITextToSpeech, api_version='2023-05-15'
              ),
]

stt_model_info = [
    ModelInfo('whisper', '', ModelTypeConst.STT,
              base_azure_stt_model_credential, AzureOpenAISpeechToText, api_version='2023-05-15'
              ),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_default_model_info(default_model_info[0])
    .append_model_info_list(default_model_info)
    .append_model_info_list(embedding_model_info)
    .append_default_model_info(embedding_model_info[0])
    .append_model_info_list(image_model_info)
    .append_default_model_info(image_model_info[0])
    .append_model_info_list(stt_model_info)
    .append_default_model_info(stt_model_info[0])
    .append_model_info_list(tts_model_info)
    .append_default_model_info(tts_model_info[0])
    .append_model_info_list(tti_model_info)
    .append_default_model_info(tti_model_info[0])
    .build()
)


class AzureModelProvider(IModelProvider):

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'azure_model_provider', 'icon',
                         'azure_icon_svg')))

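The registration pattern above makes exposing another Azure deployment a one-entry change; a hypothetical addition mirroring the existing entries, not part of this diff:

# Illustrative only:
# ModelInfo('gpt-35-turbo', '', ModelTypeConst.LLM,
#           base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview')
# appended to default_model_info before the builder chain runs.
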
@@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query(_('Hello'))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct'))
            else:
                return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_version = forms.TextInputField("API Version", required=True)

    api_base = forms.TextInputField('Azure Endpoint', required=True)

    api_key = forms.PasswordInputField("API Key", required=True)

@@ -0,0 +1,75 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext


class AzureOpenAIImageModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
    api_version = forms.TextInputField("API Version", required=True)
    api_base = forms.TextInputField('Azure Endpoint', required=True)
    api_key = forms.PasswordInputField("API Key", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAIImageModelParams()

@@ -0,0 +1,96 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict

from langchain_core.messages import HumanMessage
from openai import BadRequestError

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext


class AzureLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class O3MiniLLMModelParams(BaseForm):
    max_completion_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=5000,
        _step=1,
        precision=0)


class AzureLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, (AppApiException, BadRequestError)):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('Verification failed, please check whether the parameters are correct'))
            else:
                return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_version = forms.TextInputField("API Version", required=True)

    api_base = forms.TextInputField('Azure Endpoint', required=True)

    api_key = forms.PasswordInputField("API Key", required=True)

    deployment_name = forms.TextInputField("Deployment name", required=True)

    def get_model_params_setting_form(self, model_name):
        # o1/o3 deployments take max_completion_tokens, so they get the reduced form.
        if 'o3' in model_name or 'o1' in model_name:
            return O3MiniLLMModelParams()
        return AzureLLMModelParams()

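A hedged check of the params-form dispatch above: o1/o3 family names get the reduced form without a temperature slider, everything else the general one:

# Illustrative only.
cred = AzureLLMModelCredential()
type(cred.get_model_params_setting_form('o3-mini'))  # -> O3MiniLLMModelParams
type(cred.get_model_params_setting_form('gpt-4o'))   # -> AzureLLMModelParams
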
@@ -0,0 +1,50 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
    api_version = forms.TextInputField("API Version", required=True)
    api_base = forms.TextInputField('Azure Endpoint', required=True)
    api_key = forms.PasswordInputField("API Key", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        pass

@@ -0,0 +1,87 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AzureOpenAITTIModelParams(BaseForm):
    size = forms.SingleSelect(
        TooltipLabel(_('Image size'), _('Specify the size of the generated image, such as: 1024x1024')),
        required=True,
        default_value='1024x1024',
        option_list=[
            {'value': '1024x1024', 'label': '1024x1024'},
            {'value': '1024x1792', 'label': '1024x1792'},
            {'value': '1792x1024', 'label': '1792x1024'},
        ],
        text_field='label',
        value_field='value'
    )

    quality = forms.SingleSelect(
        TooltipLabel(_('Picture quality'), ''),
        required=True,
        default_value='standard',
        option_list=[
            {'value': 'standard', 'label': 'standard'},
            {'value': 'hd', 'label': 'hd'},
        ],
        text_field='label',
        value_field='value'
    )

    n = forms.SliderField(
        TooltipLabel(_('Number of pictures'), _('Specify the number of generated images')),
        required=True, default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0)


class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
    api_version = forms.TextInputField("API Version", required=True)
    api_base = forms.TextInputField('Azure Endpoint', required=True)
    api_key = forms.PasswordInputField("API Key", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.check_auth()
            print(res)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAITTIModelParams()

@@ -0,0 +1,68 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class AzureOpenAITTSModelGeneralParams(BaseForm):
    # alloy, echo, fable, onyx, nova, shimmer
    voice = forms.SingleSelect(
        TooltipLabel('Voice',
                     _('Try out the different voices (Alloy, Echo, Fable, Onyx, Nova, and Shimmer) to find one that suits your desired tone and audience. The current voiceover is optimized for English.')),
        required=True, default_value='alloy',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': 'alloy', 'value': 'alloy'},
            {'text': 'echo', 'value': 'echo'},
            {'text': 'fable', 'value': 'fable'},
            {'text': 'onyx', 'value': 'onyx'},
            {'text': 'nova', 'value': 'nova'},
            {'text': 'shimmer', 'value': 'shimmer'},
        ])


class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
    api_version = forms.TextInputField("API Version", required=True)
    api_base = forms.TextInputField('Azure Endpoint', required=True)
    api_key = forms.PasswordInputField("API Key", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, gettext(
                    'Verification failed, please check whether the parameters are correct: {error}').format(
                    error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAITTSModelGeneralParams()

@@ -0,0 +1 @@
<svg t="1724827784525" class="icon" viewBox="0 0 1083 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5331" width="100%" height="100%"><path d="M540.115607 419.535168L460.236208 3.644671 376.260809 1.620347c-31.825387-5.244301-54.020439 8.390289-66.588115 40.908208C297.107977 75.046474 194.101272 379.938035 0.654058 957.203237-1.245965 999.731792 12.341272 1020.99311 41.418728 1020.99311h275.015029c14.599399-4.010173 25.104277-17.02178 31.514636-39.031861 6.410358-22.010081 70.46659-209.486428 192.168694-562.427561z" fill="#0D559D" p-id="5332"></path><path d="M669.463676 658.929202l-351.23052 6.189873c-29.527306 6.969711-32.546035 22.615306-9.054705 46.941226 35.237734 36.486659 357.010497 317.661965 372.719722 311.580115 9.100578-3.52185 16.537896-4.005734 26.473064-9.14941 24.994775-12.939098 21.644578-70.568694-10.05059-172.887307l-28.855491-182.674497z" fill="#0078D4" p-id="5333"></path><path d="M371.417526 0.331468l350.216879-0.325549a47.405873 47.405873 0 0 1 31.042589 11.556994c7.219792 6.252023 13.292763 14.842081 18.220393 25.765734 9.603699 21.293873 112.287815 325.368601 308.052347 912.225665a54.729249 54.729249 0 0 1 1.284439 30.169526c-6.225387 25.780532-19.432324 39.509827-39.623768 41.181965-24.348116 2.018405-144.088046 2.962497-359.221272 2.833758 30.551306-9.66141 44.410821-28.275422 41.578543-55.843515C720.153156 940.511445 618.336185 634.30289 417.516763 49.264462h0.00296a84.346821 84.346821 0 0 0-13.869873-25.255213C392.112092 9.559306 382.194682 2.04652 373.893179 1.47237c-10.553711-0.728046-11.379422-1.109827-2.475653-1.140902z" fill="#2FA7E7" p-id="5334"></path></svg>

After Width: | Height: | Size: 1.6 KiB |

@@ -0,0 +1,51 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: azure_chat_model.py
@date:2024/4/28 11:45
@desc:
"""

from typing import List, Dict

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel


class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        return AzureChatModel(
            azure_endpoint=model_credential.get('api_base'),
            model_name=model_name,
            openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
            deployment_name=model_credential.get('deployment_name'),
            openai_api_key=model_credential.get('api_key'),
            openai_api_type="azure",
            **optional_params,
            streaming=True,
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))

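A short sketch of instantiating the chat wrapper above; endpoint, key and deployment name are placeholders:

# Illustrative only.
model = AzureChatModel.new_instance(
    'LLM', 'gpt-4o',
    {'api_base': 'https://example.openai.azure.com', 'api_key': '...',
     'api_version': '2024-02-15-preview', 'deployment_name': 'gpt-4o'},
    temperature=0.7, max_tokens=800,
)
model.get_num_tokens('Hello')  # uses the local tokenizer if the tiktoken lookup fails
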
@@ -0,0 +1,25 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict

from langchain_openai import AzureOpenAIEmbeddings

from models_provider.base_model_provider import MaxKBBaseModel


class AzureOpenAIEmbeddingModel(MaxKBBaseModel, AzureOpenAIEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return AzureOpenAIEmbeddingModel(
            model=model_name,
            openai_api_key=model_credential.get('api_key'),
            azure_endpoint=model_credential.get('api_base'),
            openai_api_version=model_credential.get('api_version'),
            openai_api_type="azure",
        )

@@ -0,0 +1,42 @@
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AzureOpenAIImage(MaxKBBaseModel, AzureChatOpenAI):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return AzureOpenAIImage(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            azure_endpoint=model_credential.get('api_base'),
            openai_api_version=model_credential.get('api_version'),
            openai_api_type="azure",
            streaming=True,
            **optional_params,
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))

@@ -0,0 +1,62 @@
import io
from typing import Dict

from openai import AzureOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AzureOpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
    api_base: str
    api_key: str
    api_version: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.api_version = kwargs.get('api_version')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return AzureOpenAISpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            api_version=model_credential.get('api_version'),
            **optional_params,
        )

    def check_auth(self):
        client = AzureOpenAI(
            azure_endpoint=self.api_base,
            api_key=self.api_key,
            api_version=self.api_version
        )
        # Listing models is a cheap way to verify the endpoint, key and version.
        response_list = client.models.with_raw_response.list()
        # print(response_list)

    def speech_to_text(self, audio_file):
        client = AzureOpenAI(
            azure_endpoint=self.api_base,
            api_key=self.api_key,
            api_version=self.api_version
        )
        audio_data = audio_file.read()
        buffer = io.BytesIO(audio_data)
        buffer.name = "file.mp3"  # the API infers the audio format from this file name
        res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
        return res.text

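A usage sketch for the transcription wrapper above; any binary file-like object with a read() method works because the bytes are re-buffered under a .mp3 name. Credential values are placeholders:

# Illustrative only.
stt = AzureOpenAISpeechToText.new_instance(
    'STT', 'whisper',
    {'api_base': 'https://example.openai.azure.com', 'api_key': '...', 'api_version': '2023-05-15'},
)
with open('sample.mp3', 'rb') as f:
    text = stt.speech_to_text(f)
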
@ -0,0 +1,61 @@
from typing import Dict

from openai import AzureOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tti import BaseTextToImage


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AzureOpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
    api_base: str
    api_key: str
    api_version: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.api_version = kwargs.get('api_version')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return AzureOpenAITextToImage(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            api_version=model_credential.get('api_version'),
            **optional_params,
        )

    def is_cache_model(self):
        return False

    def check_auth(self):
        chat = AzureOpenAI(api_key=self.api_key, azure_endpoint=self.api_base, api_version=self.api_version)
        # Listing the deployed models is enough to verify the credentials.
        chat.models.with_raw_response.list()
        # self.generate_image('draw a picture of a kitten')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        # negative_prompt is part of the BaseTextToImage interface but is not used by the OpenAI images API.
        chat = AzureOpenAI(api_key=self.api_key, azure_endpoint=self.api_base, api_version=self.api_version)
        res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
        file_urls = []
        for content in res.data:
            file_urls.append(content.url)
        return file_urls
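A hedged sketch of calling the image generator above (deployment name and prompt are illustrative; extra keyword arguments flow into `params`):

tti = AzureOpenAITextToImage.new_instance(
    'TTI', 'dall-e-3',
    {'api_base': 'https://my-resource.openai.azure.com', 'api_key': 'sk-...', 'api_version': '2024-02-01'},
    size='1024x1024', quality='standard', n=1)
urls = tti.generate_image('a watercolor lighthouse at dusk')  # list of image URLs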
@ -0,0 +1,69 @@
from typing import Dict

from openai import AzureOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AzureOpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    api_base: str
    api_key: str
    api_version: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.api_version = kwargs.get('api_version')
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {'voice': 'alloy'}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return AzureOpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            api_version=model_credential.get('api_version'),
            **optional_params,
        )

    def check_auth(self):
        client = AzureOpenAI(
            azure_endpoint=self.api_base,
            api_key=self.api_key,
            api_version=self.api_version
        )
        # Listing the deployed models is enough to verify the credentials.
        client.models.with_raw_response.list()

    def text_to_speech(self, text):
        client = AzureOpenAI(
            azure_endpoint=self.api_base,
            api_key=self.api_key,
            api_version=self.api_version
        )
        text = _remove_empty_lines(text)
        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                input=text,
                **self.params
        ) as response:
            return response.read()

    def is_cache_model(self):
        return False
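A sketch of the text-to-speech path (deployment name is hypothetical; `voice` defaults to 'alloy' in new_instance):

tts = AzureOpenAITextToSpeech.new_instance(
    'TTS', 'tts-1',
    {'api_base': 'https://my-resource.openai.azure.com', 'api_key': 'sk-...', 'api_version': '2024-02-01'},
    voice='alloy')
with open('out.mp3', 'wb') as f:
    f.write(tts.text_to_speech('Hello from MaxKB'))  # raw audio bytes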
@ -0,0 +1,168 @@
# coding=utf-8
import warnings
from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union

import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class BaseChatOpenAI(ChatOpenAI):
    usage_metadata: dict = {}
    custom_get_token_ids = custom_get_token_ids

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Prefer the token usage reported by the provider; fall back to local counting.
        if self.usage_metadata is None or self.usage_metadata == {}:
            try:
                return super().get_num_tokens_from_messages(messages)
            except Exception:
                tokenizer = TokenizerManage.get_tokenizer()
                return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        # After a generation, this reports the output tokens of that generation.
        if self.usage_metadata is None or self.usage_metadata == {}:
            try:
                return super().get_num_tokens(text)
            except Exception:
                tokenizer = TokenizerManage.get_tokenizer()
                return len(tokenizer.encode(text))
        return self.get_last_generation_info().get('output_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        # Set default stream_options.
        kwargs["stream_options"] = {"include_usage": True}
        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
        # Note: stream_options is not a valid parameter for Azure OpenAI.
        # To support users proxying Azure through ChatOpenAI, here we only specify
        # stream_options if include_usage is set to True.
        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
        # for release notes.
        if stream_usage:
            kwargs["stream_options"] = {"include_usage": stream_usage}

        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        base_generation_info = {}

        if "response_format" in payload and is_basemodel_subclass(
                payload["response_format"]
        ):
            # TODO: Add support for streaming with Pydantic response_format.
            warnings.warn("Streaming with Pydantic response_format not yet supported.")
            chat_result = self._generate(
                messages, stop, run_manager=run_manager, **kwargs
            )
            msg = chat_result.generations[0].message
            yield ChatGenerationChunk(
                message=AIMessageChunk(
                    **msg.dict(exclude={"type", "additional_kwargs"}),
                    # preserve the "parsed" Pydantic object without converting to dict
                    additional_kwargs=msg.additional_kwargs,
                ),
                generation_info=chat_result.generations[0].generation_info,
            )
            return
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()

                generation_chunk = super()._convert_chunk_to_generation_chunk(
                    chunk,
                    default_chunk_class,
                    base_generation_info if is_first_chunk else {},
                )
                if generation_chunk is None:
                    continue

                # custom code: surface the model's reasoning trace alongside the content.
                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
                    generation_chunk.message.additional_kwargs["reasoning_content"] = \
                        chunk['choices'][0]['delta']['reasoning_content']

                default_chunk_class = generation_chunk.message.__class__
                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                # custom code: remember the token usage reported in the final chunk.
                if generation_chunk.message.usage_metadata is not None:
                    self.usage_metadata = generation_chunk.message.usage_metadata
                yield generation_chunk

    def _create_chat_result(self,
                            response: Union[dict, openai.BaseModel],
                            generation_info: Optional[Dict] = None):
        result = super()._create_chat_result(response, generation_info)
        try:
            reasoning_content = ''
            reasoning_content_enable = False
            for res in response.choices:
                if 'reasoning_content' in res.message.model_extra:
                    reasoning_content_enable = True
                    _reasoning_content = res.message.model_extra.get('reasoning_content')
                    if _reasoning_content is not None:
                        reasoning_content += _reasoning_content
            if reasoning_content_enable:
                result.llm_output['reasoning_content'] = reasoning_content
        except Exception:
            pass
        return result

    def invoke(
            self,
            input: LanguageModelInput,
            config: Optional[RunnableConfig] = None,
            *,
            stop: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> BaseMessage:
        config = ensure_config(config)
        chat_result = cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            ).generations[0][0],
        ).message
        self.usage_metadata = chat_result.response_metadata.get('token_usage', chat_result.usage_metadata)
        return chat_result
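Taken together, the overrides above make token accounting usage-driven: counts come from the provider-reported usage when one has been recorded and fall back to local tokenization otherwise, while `reasoning_content` is attached to streamed chunks via `additional_kwargs`. A hedged consumption sketch (assumes `chat` is an instance of a concrete BaseChatOpenAI subclass):

from langchain_core.messages import HumanMessage

messages = [HumanMessage(content='Hello')]
result = chat.invoke(messages)                               # records token usage on the instance
input_tokens = chat.get_num_tokens_from_messages(messages)   # served from the recorded usage when its keys match
output_tokens = chat.get_num_tokens(result.content)          # the last generation's output_tokens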
@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseSpeechToText(BaseModel):
    @abstractmethod
    def check_auth(self):
        pass

    @abstractmethod
    def speech_to_text(self, audio_file):
        pass
@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseTextToImage(BaseModel):
    @abstractmethod
    def check_auth(self):
        pass

    @abstractmethod
    def generate_image(self, prompt: str, negative_prompt: str = None):
        pass
@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseTextToSpeech(BaseModel):
    @abstractmethod
    def check_auth(self):
        pass

    @abstractmethod
    def text_to_speech(self, text):
        pass
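These three pydantic interfaces define the contract each concrete backend (such as the Azure classes above) fulfils: an auth probe plus a single conversion method. A skeletal, purely hypothetical implementation:

class DummyTextToSpeech(BaseTextToSpeech):
    # Illustrative stub only; a real backend calls its vendor SDK here.
    def check_auth(self):
        pass  # e.g. issue a cheap authenticated request

    def text_to_speech(self, text):
        return b''  # real implementations return encoded audio bytes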
@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :__init__.py
@Author :Brian Yang
@Date :5/12/24 7:38 AM
"""
@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 17:51
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class DeepSeekLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        return DeepSeekLLMModelParams()
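A sketch of how a caller might exercise the credential form above (the provider instance and API key are hypothetical):

credential = DeepSeekLLMModelCredential()
ok = credential.is_valid('LLM', 'deepseek-chat', {'api_key': 'sk-...'},
                         {'temperature': 0.7, 'max_tokens': 800},
                         provider, raise_exception=False)  # returns False instead of raising on bad keys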
@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :deepseek_model_provider.py
@Author :Brian Yang
@Date :5/12/24 7:40 AM
"""
import os

from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
    ModelInfoManage
from models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
from models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

deepseek_llm_model_credential = DeepSeekLLMModelCredential()
deepseek_reasoner = ModelInfo('deepseek-reasoner', '', ModelTypeConst.LLM,
                              deepseek_llm_model_credential, DeepSeekChatModel)

deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'),
                          ModelTypeConst.LLM,
                          deepseek_llm_model_credential, DeepSeekChatModel)

deepseek_coder = ModelInfo('deepseek-coder', _('Good at handling programming tasks, supports 16K contexts'),
                           ModelTypeConst.LLM,
                           deepseek_llm_model_credential,
                           DeepSeekChatModel)

model_info_manage = ModelInfoManage.builder() \
    .append_model_info(deepseek_reasoner) \
    .append_model_info(deepseek_chat) \
    .append_model_info(deepseek_coder) \
    .append_default_model_info(deepseek_coder) \
    .build()


class DeepSeekModelProvider(IModelProvider):

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_deepseek_provider', name='DeepSeek', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'deepseek_model_provider', 'icon',
                         'deepseek_icon_svg')))
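With the registry built above, the provider facade is what the rest of MaxKB talks to. A hedged sketch using only the two methods defined here:

provider = DeepSeekModelProvider()
manage = provider.get_model_info_manage()   # the ModelInfoManage registry built above
info = provider.get_model_provide_info()    # provider id, display name, and icon SVG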
@ -0,0 +1,6 @@
<svg width="100%" height="100%" viewBox="0 0 50 50" fill="none" xmlns="http://www.w3.org/2000/svg"
xmlns:xlink="http://www.w3.org/1999/xlink">
<path id="path"
d="M48.8354 10.0479C48.3232 9.79199 48.1025 10.2798 47.8032 10.5278C47.7007 10.6079 47.6143 10.7119 47.5273 10.8076C46.7793 11.624 45.9048 12.1597 44.7622 12.0957C43.0923 12 41.666 12.5356 40.4058 13.8398C40.1377 12.2319 39.2476 11.272 37.8926 10.6558C37.1836 10.3359 36.4668 10.0156 35.9702 9.31982C35.6235 8.82373 35.5293 8.27197 35.356 7.72754C35.2456 7.3999 35.1353 7.06396 34.7651 7.00781C34.3633 6.94385 34.2056 7.2876 34.0479 7.57568C33.418 8.75195 33.1733 10.0479 33.1973 11.3599C33.2524 14.312 34.4736 16.6641 36.8999 18.3359C37.1758 18.5278 37.2466 18.7197 37.1597 19C36.9946 19.5757 36.7974 20.1357 36.624 20.7119C36.5137 21.0801 36.3486 21.1597 35.9624 21C34.6309 20.4321 33.481 19.5918 32.4644 18.5757C30.7393 16.8721 29.1792 14.9917 27.2334 13.52C26.7764 13.1758 26.3193 12.856 25.8467 12.5518C23.8618 10.584 26.1069 8.96777 26.627 8.77588C27.1704 8.57568 26.8159 7.8877 25.0591 7.896C23.3022 7.90381 21.6953 8.50391 19.647 9.30371C19.3477 9.42383 19.0322 9.51172 18.7095 9.58398C16.8501 9.22363 14.9199 9.14355 12.9033 9.37598C9.10596 9.80762 6.07275 11.6396 3.84326 14.7681C1.16455 18.5278 0.53418 22.7998 1.30664 27.2559C2.11768 31.9521 4.46582 35.8398 8.07373 38.8799C11.8159 42.0322 16.1255 43.5762 21.041 43.2803C24.0269 43.104 27.3516 42.6963 31.1016 39.4561C32.0469 39.936 33.0396 40.1279 34.686 40.272C35.9546 40.3921 37.1758 40.208 38.1211 40.0078C39.6021 39.688 39.4995 38.2881 38.9639 38.0322C34.623 35.9678 35.5762 36.8081 34.71 36.1279C36.9155 33.4639 40.2402 30.6958 41.54 21.728C41.6426 21.0161 41.5557 20.5679 41.54 19.9917C41.5322 19.6396 41.6108 19.5039 42.0049 19.4639C43.0923 19.3359 44.1479 19.0317 45.1167 18.4878C47.9292 16.9199 49.064 14.3438 49.3315 11.2559C49.3711 10.7837 49.3237 10.2959 48.8354 10.0479ZM24.3262 37.8398C20.1196 34.4639 18.0791 33.3521 17.2358 33.3999C16.4482 33.4482 16.5898 34.3682 16.7632 34.9678C16.9443 35.5601 17.1812 35.9683 17.5117 36.4878C17.7402 36.832 17.8979 37.3442 17.2832 37.728C15.9282 38.584 13.5728 37.4399 13.4624 37.3838C10.7207 35.7358 8.42822 33.5601 6.81348 30.584C5.25342 27.7197 4.34766 24.6479 4.19775 21.3677C4.1582 20.5757 4.38672 20.2959 5.15869 20.1519C6.17529 19.96 7.22314 19.9199 8.23926 20.0718C12.5327 20.7119 16.1885 22.6719 19.2529 25.7759C21.002 27.5439 22.3252 29.6558 23.6885 31.7202C25.1377 33.9121 26.6978 36 28.6831 37.7119C29.3843 38.312 29.9434 38.7681 30.479 39.104C28.8643 39.2881 26.1699 39.3281 24.3262 37.8398ZM26.3433 24.6001C26.3433 24.248 26.6191 23.9678 26.9658 23.9678C27.0444 23.9678 27.1152 23.9839 27.1782 24.0078C27.2651 24.04 27.3438 24.0879 27.4067 24.1602C27.5171 24.272 27.5801 24.4321 27.5801 24.6001C27.5801 24.9521 27.3042 25.2319 26.9575 25.2319C26.6108 25.2319 26.3433 24.9521 26.3433 24.6001ZM32.6064 27.8799C32.2046 28.0479 31.8027 28.1919 31.4165 28.208C30.8179 28.2397 30.1641 27.9922 29.8096 27.688C29.2583 27.2158 28.8643 26.9521 28.6987 26.1279C28.6279 25.7759 28.6675 25.2319 28.7305 24.9199C28.8721 24.248 28.7144 23.8159 28.2495 23.4238C27.8716 23.104 27.3911 23.0161 26.8633 23.0161C26.666 23.0161 26.4849 22.9277 26.3511 22.856C26.1304 22.7441 25.9492 22.4639 26.1226 22.1201C26.1777 22.0078 26.4458 21.7358 26.5088 21.688C27.2256 21.272 28.0527 21.4077 28.8169 21.7197C29.5259 22.0161 30.0615 22.5601 30.834 23.3281C31.6216 24.2559 31.7632 24.5117 32.2124 25.208C32.5669 25.752 32.8901 26.312 33.1104 26.9521C33.2446 27.3521 33.0713 27.6802 32.6064 27.8799Z"
fill="#4D6BFE" fill-opacity="1.000000" fill-rule="nonzero" />
</svg>
After Width: | Height: | Size: 3.5 KiB |
@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :llm.py
@Author :Brian Yang
@Date :5/12/24 7:44 AM
"""
from typing import Dict

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        deepseek_chat_open_ai = DeepSeekChatModel(
            model=model_name,
            openai_api_base='https://api.deepseek.com',
            openai_api_key=model_credential.get('api_key'),
            **optional_params
        )
        return deepseek_chat_open_ai
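Tying the pieces together, a hypothetical end-to-end call (assumes filter_optional_params passes temperature and max_tokens through):

from langchain_core.messages import HumanMessage

model = DeepSeekChatModel.new_instance('LLM', 'deepseek-chat', {'api_key': 'sk-...'},
                                       temperature=0.7, max_tokens=800)
print(model.invoke([HumanMessage(content='Hello')]).content)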
@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :__init__.py
@Author :Brian Yang
@Date :5/13/24 7:40 AM
"""
@ -0,0 +1,52 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query(_('Hello'))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
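As with the DeepSeek form, a sketch of validating a Gemini embedding credential (model name and provider instance are hypothetical):

credential = GeminiEmbeddingCredential()
ok = credential.is_valid('EMBEDDING', 'models/embedding-001', {'api_key': '...'},
                         {}, provider, raise_exception=False)  # runs embed_query(_('Hello')) under the hood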
@ -0,0 +1,71 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiImageModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class GeminiImageModelCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return GeminiImageModelParams()
@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 17:57
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class GeminiLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.invoke([HumanMessage(content=gettext('Hello'))])
            print(res)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        return GeminiLLMModelParams()
@ -0,0 +1,48 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        pass