feat(provider): added regolo.ai (#3041)
This commit is contained in: parent 357edbfbe0, commit 0cf05a76a0
@@ -19,6 +19,8 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from setting.models_provider.impl.regolo_model_provider.regolo_model_provider import \
    RegoloModelProvider
from setting.models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import \
    SiliconCloudModelProvider
from setting.models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import \
@@ -55,3 +57,4 @@ class ModelProvideConstants(Enum):
    aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
    model_anthropic_provider = AnthropicModelProvider()
    model_siliconCloud_provider = SiliconCloudModelProvider()
    model_regolo_provider = RegoloModelProvider()
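A minimal lookup sketch, not part of the diff: assuming ModelProvideConstants behaves like a standard Python Enum, the newly registered member can be resolved by its key and asked which model types the Regolo provider supports.

    # Hypothetical usage, for illustration only (not part of the commit)
    provider = ModelProvideConstants['model_regolo_provider'].value  # RegoloModelProvider instance
    supported_types = provider.get_model_type_list()                 # e.g. LLM / EMBEDDING / TTI entries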
@@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: __init__.py
@date:2024/3/28 16:25
@desc:
"""
@@ -0,0 +1,52 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext as _

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query(_('Hello'))
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
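A small illustrative sketch, not part of the commit: encryption_dict is what the caller uses before persisting the form, and it runs the key through BaseModelCredential.encryption (presumably a masking step). The key value below is a placeholder.

    credential = RegoloEmbeddingCredential()
    stored = {'api_key': 'sk-demo-key'}          # placeholder, not a real key
    masked = credential.encryption_dict(stored)  # 'api_key' comes back in the encrypted/masked form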
@@ -0,0 +1,74 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext


class RegoloImageModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class RegoloImageModelCredential(BaseForm, BaseModelCredential):
    api_base = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return RegoloImageModelParams()
@@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: llm.py
@date:2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class RegoloLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class RegoloLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        return RegoloLLMModelParams()
@@ -0,0 +1,89 @@
# coding=utf-8
import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class RegoloTTIModelParams(BaseForm):
    size = forms.SingleSelect(
        TooltipLabel(_('Image size'),
                     _('The image generation endpoint allows you to create raw images based on text prompts. ')),
        required=True,
        default_value='1024x1024',
        option_list=[
            {'value': '1024x1024', 'label': '1024x1024'},
            {'value': '1024x1792', 'label': '1024x1792'},
            {'value': '1792x1024', 'label': '1792x1024'},
        ],
        text_field='label',
        value_field='value'
    )

    quality = forms.SingleSelect(
        TooltipLabel(_('Picture quality'), _('''
            By default, images are produced in standard quality.
            ''')),
        required=True,
        default_value='standard',
        option_list=[
            {'value': 'standard', 'label': 'standard'},
            {'value': 'hd', 'label': 'hd'},
        ],
        text_field='label',
        value_field='value'
    )

    n = forms.SliderField(
        TooltipLabel(_('Number of pictures'),
                     _('1 as default')),
        required=True, default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0)


class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.check_auth()
            print(res)
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return RegoloTTIModelParams()
@@ -0,0 +1,65 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
   id="Livello_2"
   data-name="Livello 2"
   viewBox="0 0 104.4 104.38"
   version="1.1"
   sodipodi:docname="Regolo_logo_positive.svg"
   width="104.4"
   height="104.38"
   inkscape:version="1.4 (e7c3feb100, 2024-10-09)"
   xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
   xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
   xmlns="http://www.w3.org/2000/svg"
   xmlns:svg="http://www.w3.org/2000/svg">
  <sodipodi:namedview
     id="namedview13"
     pagecolor="#ffffff"
     bordercolor="#666666"
     borderopacity="1.0"
     inkscape:showpageshadow="2"
     inkscape:pageopacity="0.0"
     inkscape:pagecheckerboard="0"
     inkscape:deskcolor="#d1d1d1"
     inkscape:zoom="2.1335227"
     inkscape:cx="119.05193"
     inkscape:cy="48.511318"
     inkscape:window-width="1920"
     inkscape:window-height="1025"
     inkscape:window-x="0"
     inkscape:window-y="0"
     inkscape:window-maximized="1"
     inkscape:current-layer="g13" />
  <defs
     id="defs1">
    <style
       id="style1">
      .cls-1 {
        fill: #303030;
      }

      .cls-2 {
        fill: #59e389;
      }
    </style>
  </defs>
  <g
     id="Grafica"
     transform="translate(0,-40.87)">
    <g
       id="g13">
      <path
         class="cls-1"
         d="m 104.39,105.96 v 36.18 c 0,0.32 -0.05,0.62 -0.14,0.91 -0.39,1.27 -1.58,2.2 -2.99,2.2 H 65.08 c -1.73,0 -3.13,-1.41 -3.13,-3.13 V 113.4 c 0,-0.15 0,-0.29 0,-0.44 v -7 c 0,-1.73 1.4,-3.13 3.13,-3.13 h 36.19 c 1.5,0 2.77,1.07 3.06,2.5 0.05,0.21 0.07,0.41 0.07,0.63 z"
         id="path1" />
      <path
         class="cls-1"
         d="m 104.39,105.96 v 36.18 c 0,0.32 -0.05,0.62 -0.14,0.91 -0.39,1.27 -1.58,2.2 -2.99,2.2 H 65.08 c -1.73,0 -3.13,-1.41 -3.13,-3.13 V 113.4 c 0,-0.15 0,-0.29 0,-0.44 v -7 c 0,-1.73 1.4,-3.13 3.13,-3.13 h 36.19 c 1.5,0 2.77,1.07 3.06,2.5 0.05,0.21 0.07,0.41 0.07,0.63 z"
         id="path2" />
      <path
         class="cls-2"
         d="M 101.27,40.88 H 65.09 c -1.73,0 -3.13,1.4 -3.13,3.13 v 28.71 c 0,4.71 -1.88,9.23 -5.2,12.56 L 44.42,97.61 c -3.32,3.33 -7.85,5.2 -12.55,5.2 H 18.98 c -2.21,0 -3.99,-1.79 -3.99,-3.99 V 87.29 c 0,-2.21 1.79,-3.99 3.99,-3.99 h 20.34 c 1.41,0 2.59,-0.93 2.99,-2.2 0.09,-0.29 0.14,-0.59 0.14,-0.91 V 44 c 0,-0.22 -0.02,-0.42 -0.07,-0.63 -0.29,-1.43 -1.56,-2.5 -3.06,-2.5 H 3.13 C 1.4,40.87 0,42.27 0,44 v 7 c 0,0.15 0,0.29 0,0.44 v 28.72 c 0,1.72 1.41,3.13 3.13,3.13 h 3.16 c 2.21,0 3.99,1.79 3.99,3.99 v 11.53 c 0,2.21 -1.79,3.99 -3.99,3.99 H 3.15 c -1.73,0 -3.13,1.4 -3.13,3.13 v 36.19 c 0,1.72 1.41,3.13 3.13,3.13 h 36.19 c 1.73,0 3.13,-1.41 3.13,-3.13 V 113.4 c 0,-4.7 1.87,-9.23 5.2,-12.55 L 60,88.51 c 3.33,-3.32 7.85,-5.2 12.56,-5.2 h 28.71 c 1.73,0 3.13,-1.4 3.13,-3.13 V 44 c 0,-1.73 -1.4,-3.13 -3.13,-3.13 z"
         id="path3" />
    </g>
  </g>
</svg>
@@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict

from langchain_community.embeddings import OpenAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class RegoloEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return RegoloEmbeddingModel(
            api_key=model_credential.get('api_key'),
            model=model_name,
            openai_api_base="https://api.regolo.ai/v1",
        )
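A hypothetical instantiation sketch, not part of the diff: new_instance wires the credential's api_key and the fixed Regolo endpoint into the OpenAIEmbeddings wrapper, after which embedding a query is a one-liner. The model name comes from the provider file below; the key is a placeholder.

    # Illustration only (not part of the commit); requires a valid Regolo API key.
    emb = RegoloEmbeddingModel.new_instance('EMBEDDING', 'gte-Qwen2', {'api_key': 'sk-demo-key'})
    vector = emb.embed_query('Hello')  # list of floats returned via https://api.regolo.ai/v1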
@@ -0,0 +1,19 @@
from typing import Dict

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class RegoloImage(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return RegoloImage(
            model_name=model_name,
            openai_api_base="https://api.regolo.ai/v1",
            openai_api_key=model_credential.get('api_key'),
            streaming=True,
            stream_usage=True,
            extra_body=optional_params
        )
@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: llm.py
@date:2024/4/18 15:28
@desc:
"""
from typing import List, Dict

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class RegoloChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return RegoloChatModel(
            model=model_name,
            openai_api_base="https://api.regolo.ai/v1",
            openai_api_key=model_credential.get('api_key'),
            extra_body=optional_params
        )
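A hypothetical usage sketch, not part of the diff: building the chat model roughly the way the provider's get_model path would, then sending a single message. The model name and the temperature/max_tokens parameter names come from the hunks above; the key is a placeholder.

    # Illustration only (not part of the commit); requires a valid Regolo API key.
    from langchain_core.messages import HumanMessage

    chat = RegoloChatModel.new_instance('LLM', 'Llama-3.3-70B-Instruct',
                                        {'api_key': 'sk-demo-key'},
                                        temperature=0.7, max_tokens=800)
    reply = chat.invoke([HumanMessage(content='Hello')])  # single non-streaming call against api.regolo.ai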
@@ -0,0 +1,58 @@
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class RegoloTextToImage(MaxKBBaseModel, BaseTextToImage):
    api_base: str
    api_key: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = "https://api.regolo.ai/v1"
        self.model = kwargs.get('model')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return RegoloTextToImage(
            model=model_name,
            api_base="https://api.regolo.ai/v1",
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def is_cache_model(self):
        return False

    def check_auth(self):
        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
        response_list = chat.models.with_raw_response.list()

        # self.generate_image('Generate a picture of a kitten')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
        res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
        file_urls = []
        for content in res.data:
            url = content.url
            file_urls.append(url)

        return file_urls
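A hypothetical sketch, not part of the diff: check_auth lists the account's models to validate the key, and generate_image forwards the size/quality/n entries of the params dict to the OpenAI-compatible images endpoint. The model name and parameter values below come from the hunks above; the key is a placeholder.

    # Illustration only (not part of the commit); requires a valid Regolo API key.
    tti = RegoloTextToImage.new_instance('TTI', 'FLUX.1-dev', {'api_key': 'sk-demo-key'},
                                         size='1024x1024', quality='standard', n=1)
    tti.check_auth()                                              # raises if the key is rejected
    urls = tti.generate_image('A watercolor lighthouse at dawn')  # list of generated image URLs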
@@ -0,0 +1,89 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: regolo_model_provider.py
@date:2024/3/28 16:26
@desc:
"""
import os

from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
    ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.regolo_model_provider.credential.embedding import \
    RegoloEmbeddingCredential
from setting.models_provider.impl.regolo_model_provider.credential.llm import RegoloLLMModelCredential
from setting.models_provider.impl.regolo_model_provider.credential.tti import \
    RegoloTextToImageModelCredential
from setting.models_provider.impl.regolo_model_provider.model.embedding import RegoloEmbeddingModel
from setting.models_provider.impl.regolo_model_provider.model.llm import RegoloChatModel
from setting.models_provider.impl.regolo_model_provider.model.tti import RegoloTextToImage
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext as _

openai_llm_model_credential = RegoloLLMModelCredential()
openai_tti_model_credential = RegoloTextToImageModelCredential()
model_info_list = [
    ModelInfo('Phi-4', '', ModelTypeConst.LLM,
              openai_llm_model_credential, RegoloChatModel
              ),
    ModelInfo('DeepSeek-R1-Distill-Qwen-32B', '', ModelTypeConst.LLM,
              openai_llm_model_credential,
              RegoloChatModel),
    ModelInfo('maestrale-chat-v0.4-beta', '',
              ModelTypeConst.LLM, openai_llm_model_credential,
              RegoloChatModel),
    ModelInfo('Llama-3.3-70B-Instruct',
              '',
              ModelTypeConst.LLM, openai_llm_model_credential,
              RegoloChatModel),
    ModelInfo('Llama-3.1-8B-Instruct',
              '',
              ModelTypeConst.LLM, openai_llm_model_credential,
              RegoloChatModel),
    ModelInfo('DeepSeek-Coder-6.7B-Instruct', '',
              ModelTypeConst.LLM, openai_llm_model_credential,
              RegoloChatModel)
]
open_ai_embedding_credential = RegoloEmbeddingCredential()
model_info_embedding_list = [
    ModelInfo('gte-Qwen2', '',
              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
              RegoloEmbeddingModel),
]

model_info_tti_list = [
    ModelInfo('FLUX.1-dev', '',
              ModelTypeConst.TTI, openai_tti_model_credential,
              RegoloTextToImage),
    ModelInfo('sdxl-turbo', '',
              ModelTypeConst.TTI, openai_tti_model_credential,
              RegoloTextToImage),
]
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(
        ModelInfo('gpt-3.5-turbo', _('The latest gpt-3.5-turbo, updated with OpenAI adjustments'), ModelTypeConst.LLM,
                  openai_llm_model_credential, RegoloChatModel
                  ))
    .append_model_info_list(model_info_embedding_list)
    .append_default_model_info(model_info_embedding_list[0])
    .append_model_info_list(model_info_tti_list)
    .append_default_model_info(model_info_tti_list[0])
    .build()
)


class RegoloModelProvider(IModelProvider):

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_regolo_provider', name='Regolo', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'regolo_model_provider',
                         'icon',
                         'regolo_icon_svg')))