mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
refactor: 模型设置支持配置参数
This commit is contained in:
parent
3b995d34eb
commit
67f3ec34a1
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from common.forms import BaseLabel
|
||||
from common.forms.base_field import TriggerType, BaseExecField
|
||||
|
||||
|
||||
|
|
@ -17,7 +18,7 @@ class SingleSelect(BaseExecField):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
label: str,
|
||||
label: str or BaseLabel,
|
||||
text_field: str,
|
||||
value_field: str,
|
||||
option_list: List[str:object],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.2.15 on 2024-10-15 14:49
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dataset', '0008_alter_document_status_alter_paragraph_status'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='document',
|
||||
name='status',
|
||||
field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='3', max_length=1, verbose_name='状态'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='paragraph',
|
||||
name='status',
|
||||
field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='0', max_length=1, verbose_name='状态'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.2.15 on 2024-10-15 14:49
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0006_alter_model_status'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='model',
|
||||
name='model_params_form',
|
||||
field=models.JSONField(default=list, verbose_name='模型参数配置'),
|
||||
),
|
||||
]
|
||||
|
|
@ -29,6 +29,13 @@ class PermissionType(models.TextChoices):
|
|||
PUBLIC = "PUBLIC", '公开'
|
||||
PRIVATE = "PRIVATE", "私有"
|
||||
|
||||
class ModelParam(models.Model):
|
||||
label = models.CharField(max_length=128, verbose_name="参数")
|
||||
field = models.CharField(max_length=256, verbose_name="显示名称")
|
||||
default_value = models.CharField(max_length=1000, verbose_name="默认值")
|
||||
input_type = models.CharField(max_length=32, verbose_name="组件类型")
|
||||
attrs = models.JSONField(verbose_name="属性")
|
||||
required = models.BooleanField(verbose_name="必填")
|
||||
|
||||
class Model(AppModelMixin):
|
||||
"""
|
||||
|
|
@ -56,6 +63,9 @@ class Model(AppModelMixin):
|
|||
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
|
||||
default=PermissionType.PRIVATE)
|
||||
|
||||
model_params_form = models.JSONField(verbose_name="模型参数配置", default=list)
|
||||
|
||||
|
||||
def is_permission(self, user_id):
|
||||
if self.permission_type == PermissionType.PUBLIC or str(user_id) == str(self.user_id):
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ model_info_list = [ModelInfo('gte-rerank',
|
|||
ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech),
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).build()
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
model_info_list[1]).append_default_model_info(model_info_list[2]).build()
|
||||
|
||||
|
||||
class AliyunBaiLianModelProvider(IModelProvider):
|
||||
|
|
|
|||
|
|
@ -4,10 +4,44 @@ from typing import Dict
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class AliyunBaiLianTTSModelGeneralParams(BaseForm):
|
||||
voice = forms.SingleSelect(
|
||||
TooltipLabel('音色', '中文音色可支持中英文混合场景'),
|
||||
required=True, default_value='longxiaochun',
|
||||
text_field='text',
|
||||
value_field='value',
|
||||
option_list=[
|
||||
{'text': '龙小淳', 'value': 'longxiaochun'},
|
||||
{'text': '龙小夏', 'value': 'longxiaoxia'},
|
||||
{'text': '龙小诚', 'value': 'longxiaocheng'},
|
||||
{'text': '龙小白', 'value': 'longxiaobai'},
|
||||
{'text': '龙老铁', 'value': 'longlaotie'},
|
||||
{'text': '龙书', 'value': 'longshu'},
|
||||
{'text': '龙硕', 'value': 'longshuo'},
|
||||
{'text': '龙婧', 'value': 'longjing'},
|
||||
{'text': '龙妙', 'value': 'longmiao'},
|
||||
{'text': '龙悦', 'value': 'longyue'},
|
||||
{'text': '龙媛', 'value': 'longyuan'},
|
||||
{'text': '龙飞', 'value': 'longfei'},
|
||||
{'text': '龙杰力豆', 'value': 'longjielidou'},
|
||||
{'text': '龙彤', 'value': 'longtong'},
|
||||
{'text': '龙祥', 'value': 'longxiang'},
|
||||
{'text': 'Stella', 'value': 'loongstella'},
|
||||
{'text': 'Bella', 'value': 'loongbella'},
|
||||
])
|
||||
speech_rate = forms.SliderField(
|
||||
TooltipLabel('语速', '[0.5,2],默认为1,通常保留一位小数即可'),
|
||||
required=True, default_value=1,
|
||||
_min=0.5,
|
||||
_max=2,
|
||||
_step=0.1,
|
||||
precision=1)
|
||||
|
||||
|
||||
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
||||
api_key = forms.PasswordInputField("API Key", required=True)
|
||||
|
||||
|
|
@ -38,6 +72,5 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
|||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
pass
|
||||
return AliyunBaiLianTTSModelGeneralParams()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
class OpenAITTSModelGeneralParams(BaseForm):
|
||||
# alloy, echo, fable, onyx, nova, shimmer
|
||||
voice = forms.SingleSelect(
|
||||
TooltipLabel('Voice', '尝试不同的声音(合金、回声、寓言、缟玛瑙、新星和闪光),找到一种适合您所需的音调和听众的声音。当前的语音针对英语进行了优化。'),
|
||||
required=True, default_value='alloy',
|
||||
text_field='text',
|
||||
value_field='value',
|
||||
option_list=[
|
||||
{'text': 'alloy', 'value': 'alloy'},
|
||||
{'text': 'echo', 'value': 'echo'},
|
||||
{'text': 'fable', 'value': 'fable'},
|
||||
{'text': 'onyx', 'value': 'onyx'},
|
||||
{'text': 'nova', 'value': 'nova'},
|
||||
{'text': 'shimmer', 'value': 'shimmer'},
|
||||
])
|
||||
|
||||
|
||||
class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
||||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return OpenAITTSModelGeneralParams()
|
||||
|
|
@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
||||
|
|
@ -22,6 +23,7 @@ from smartdoc.conf import PROJECT_DIR
|
|||
|
||||
openai_llm_model_credential = OpenAILLMModelCredential()
|
||||
openai_stt_model_credential = OpenAISTTModelCredential()
|
||||
openai_tts_model_credential = OpenAITTSModelCredential()
|
||||
model_info_list = [
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
|
|
@ -70,7 +72,7 @@ model_info_list = [
|
|||
ModelTypeConst.STT, openai_stt_model_credential,
|
||||
OpenAISpeechToText),
|
||||
ModelInfo('tts-1', '',
|
||||
ModelTypeConst.TTS, openai_stt_model_credential,
|
||||
ModelTypeConst.TTS, openai_tts_model_credential,
|
||||
OpenAITextToSpeech)
|
||||
]
|
||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||
|
|
|
|||
|
|
@ -4,10 +4,38 @@ from typing import Dict
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class VolcanicEngineTTSModelGeneralParams(BaseForm):
|
||||
voice_type = forms.SingleSelect(
|
||||
TooltipLabel('音色', '中文音色可支持中英文混合场景'),
|
||||
required=True, default_value='BV002_streaming',
|
||||
text_field='text',
|
||||
value_field='value',
|
||||
option_list=[
|
||||
{'text': '灿灿 2.0', 'value': 'BV700_V2_streaming'},
|
||||
{'text': '炀炀', 'value': 'BV705_streaming'},
|
||||
{'text': '擎苍 2.0', 'value': 'BV701_V2_streaming'},
|
||||
{'text': '通用女声 2.0', 'value': 'BV001_V2_streaming'},
|
||||
{'text': '灿灿', 'value': 'BV700_streaming'},
|
||||
{'text': '超自然音色-梓梓2.0', 'value': 'BV406_V2_streaming'},
|
||||
{'text': '超自然音色-梓梓', 'value': 'BV406_streaming'},
|
||||
{'text': '超自然音色-燃燃2.0', 'value': 'BV407_V2_streaming'},
|
||||
{'text': '超自然音色-燃燃', 'value': 'BV407_streaming'},
|
||||
{'text': '通用女声', 'value': 'BV001_streaming'},
|
||||
{'text': '通用男声', 'value': 'BV002_streaming'},
|
||||
])
|
||||
speed_ratio = forms.SliderField(
|
||||
TooltipLabel('语速', '[0.2,3],默认为1,通常保留一位小数即可'),
|
||||
required=True, default_value=1,
|
||||
_min=0.2,
|
||||
_max=3,
|
||||
_step=0.1,
|
||||
precision=1)
|
||||
|
||||
|
||||
class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential):
|
||||
volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary')
|
||||
volcanic_app_id = forms.TextInputField('App ID', required=True)
|
||||
|
|
@ -42,4 +70,4 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential):
|
|||
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
pass
|
||||
return VolcanicEngineTTSModelGeneralParams()
|
||||
|
|
|
|||
|
|
@ -4,10 +4,32 @@ from typing import Dict
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XunFeiTTSModelGeneralParams(BaseForm):
|
||||
vcn = forms.SingleSelect(
|
||||
TooltipLabel('发音人', '发音人,可选值:请到控制台添加试用或购买发音人,添加后即显示发音人参数值'),
|
||||
required=True, default_value='xiaoyan',
|
||||
text_field='text',
|
||||
value_field='value',
|
||||
option_list=[
|
||||
{'text': '讯飞小燕', 'value': 'xiaoyan'},
|
||||
{'text': '讯飞许久', 'value': 'aisjiuxu'},
|
||||
{'text': '讯飞小萍', 'value': 'aisxping'},
|
||||
{'text': '讯飞小婧', 'value': 'aisjinger'},
|
||||
{'text': '讯飞许小宝', 'value': 'aisbabyxu'},
|
||||
])
|
||||
speed = forms.SliderField(
|
||||
TooltipLabel('语速', '语速,可选值:[0-100],默认为50'),
|
||||
required=True, default_value=50,
|
||||
_min=1,
|
||||
_max=100,
|
||||
_step=5,
|
||||
precision=1)
|
||||
|
||||
|
||||
class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
|
||||
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts')
|
||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||
|
|
@ -41,6 +63,5 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
|
|||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
pass
|
||||
return XunFeiTTSModelGeneralParams()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
# coding=utf-8
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XInferenceTTSModelGeneralParams(BaseForm):
|
||||
# ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']
|
||||
voice = forms.SingleSelect(
|
||||
TooltipLabel('音色', ''),
|
||||
required=True, default_value='中文女',
|
||||
text_field='text',
|
||||
value_field='value',
|
||||
option_list=[
|
||||
{'text': '中文女', 'value': '中文女'},
|
||||
{'text': '中文男', 'value': '中文男'},
|
||||
{'text': '日语男', 'value': '日语男'},
|
||||
{'text': '粤语女', 'value': '粤语女'},
|
||||
{'text': '英文女', 'value': '英文女'},
|
||||
{'text': '英文男', 'value': '英文男'},
|
||||
{'text': '韩语女', 'value': '韩语女'},
|
||||
])
|
||||
|
||||
|
||||
class XInferenceTTSModelCredential(BaseForm, BaseModelCredential):
|
||||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return XInferenceTTSModelGeneralParams()
|
||||
|
|
@ -12,6 +12,7 @@ from setting.models_provider.impl.xinference_model_provider.credential.embedding
|
|||
from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
|
||||
from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential
|
||||
from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential
|
||||
from setting.models_provider.impl.xinference_model_provider.credential.tts import XInferenceTTSModelCredential
|
||||
from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
|
||||
from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
|
||||
from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker
|
||||
|
|
@ -21,6 +22,7 @@ from smartdoc.conf import PROJECT_DIR
|
|||
|
||||
xinference_llm_model_credential = XinferenceLLMModelCredential()
|
||||
xinference_stt_model_credential = XInferenceSTTModelCredential()
|
||||
xinference_tts_model_credential = XInferenceTTSModelCredential()
|
||||
|
||||
model_info_list = [
|
||||
ModelInfo(
|
||||
|
|
@ -279,7 +281,7 @@ model_info_list = [
|
|||
'CosyVoice-300M-SFT',
|
||||
'CosyVoice-300M-SFT是一个小型的语音合成模型。',
|
||||
ModelTypeConst.TTS,
|
||||
xinference_stt_model_credential,
|
||||
xinference_tts_model_credential,
|
||||
XInferenceTextToSpeech
|
||||
),
|
||||
ModelInfo(
|
||||
|
|
|
|||
|
|
@ -244,7 +244,38 @@ class ModelSerializer(serializers.Serializer):
|
|||
model_id = self.data.get('id')
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||
return credential.get_model_params_setting_form(model.model_name).to_form_list()
|
||||
# 已经保存过的模型参数表单
|
||||
if model.model_params_form is not None and len(model.model_params_form) > 0:
|
||||
return model.model_params_form
|
||||
# 没有保存过的LLM类型的
|
||||
if credential.get_model_params_setting_form(model.model_name) is not None:
|
||||
return credential.get_model_params_setting_form(model.model_name).to_form_list()
|
||||
# 其他的
|
||||
return model.model_params_form
|
||||
|
||||
class ModelParamsForm(serializers.Serializer):
|
||||
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
model = QuerySet(Model).filter(id=self.data.get("id")).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, '模型不存在')
|
||||
if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id):
|
||||
raise AppApiException(500, '没有权限访问到此模型')
|
||||
|
||||
def save_model_params_form(self, model_params_form, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
if model_params_form is None:
|
||||
model_params_form = []
|
||||
model_id = self.data.get('id')
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
model.model_params_form = model_params_form
|
||||
model.save()
|
||||
return True
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
|
|
|
|||
|
|
@ -94,6 +94,17 @@ class Model(APIView):
|
|||
return result.success(
|
||||
ModelSerializer.ModelParams(data={'id': model_id, 'user_id': request.user.id}).get_model_params())
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="保存模型参数表单",
|
||||
operation_id="保存模型参数表单",
|
||||
manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
|
||||
tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def put(self, request: Request, model_id: str):
|
||||
return result.success(
|
||||
ModelSerializer.ModelParamsForm(data={'id': model_id, 'user_id': request.user.id})
|
||||
.save_model_params_form(request.data))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -118,6 +118,20 @@ const updateModel: (
|
|||
return put(`${prefix}/${model_id}`, request, {}, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改模型参数配置
|
||||
* @param request 請求對象
|
||||
* @param loading 加載器
|
||||
* @returns
|
||||
*/
|
||||
const updateModelParamsForm: (
|
||||
model_id: string,
|
||||
request: any[],
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<Model>> = (model_id, request, loading) => {
|
||||
return put(`${prefix}/${model_id}/model_params_form`, request, {}, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型详情根据模型id 包括认证信息
|
||||
* @param model_id 模型id
|
||||
|
|
@ -172,5 +186,6 @@ export default {
|
|||
getModelById,
|
||||
getModelMetaById,
|
||||
pauseDownload,
|
||||
getModelParamsForm
|
||||
getModelParamsForm,
|
||||
updateModelParamsForm
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,6 +76,10 @@ interface Model {
|
|||
* 元数据
|
||||
*/
|
||||
meta: Dict<any>
|
||||
/**
|
||||
* 模型参数配置
|
||||
*/
|
||||
model_params_form: Dict<any>[]
|
||||
}
|
||||
interface CreateModelRequest {
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@
|
|||
:model="form_data"
|
||||
v-bind="$attrs"
|
||||
>
|
||||
<el-form-item label="参数" :required="true" prop="label" :rules="rules.label">
|
||||
<el-input v-model="form_data.label" placeholder="请输入参数" />
|
||||
<el-form-item label="显示名称" :required="true" prop="label" :rules="rules.label">
|
||||
<el-input v-model="form_data.label" placeholder="请输入显示名称" />
|
||||
</el-form-item>
|
||||
<el-form-item label="显示名称" :required="true" prop="field" :rules="rules.field">
|
||||
<el-input v-model="form_data.field" placeholder="请输入显示名称" />
|
||||
<el-form-item label="参数" :required="true" prop="field" :rules="rules.field">
|
||||
<el-input v-model="form_data.field" placeholder="请输入参数" />
|
||||
</el-form-item>
|
||||
<el-form-item label="参数提示说明">
|
||||
<el-input v-model="form_data.tooltip" placeholder="请输入参数提示说明" />
|
||||
|
|
@ -19,8 +19,8 @@
|
|||
<el-form-item label="是否必填" :required="true" prop="required" :rules="rules.required">
|
||||
<el-switch v-model="form_data.required" />
|
||||
</el-form-item>
|
||||
<el-form-item label="组建类型" :required="true" prop="input_type" :rules="rules.input_type">
|
||||
<el-select v-model="form_data.input_type" placeholder="请选择组建类型">
|
||||
<el-form-item label="组件类型" :required="true" prop="input_type" :rules="rules.input_type">
|
||||
<el-select v-model="form_data.input_type" placeholder="请选择组件类型">
|
||||
<el-option
|
||||
v-for="input_type in input_type_list"
|
||||
:key="input_type.value"
|
||||
|
|
@ -86,17 +86,6 @@ const getData = () => {
|
|||
}
|
||||
}
|
||||
|
||||
const resetFields = () => {
|
||||
console.log(123)
|
||||
form_data.value = {
|
||||
label: '',
|
||||
field: '',
|
||||
tooltip: '',
|
||||
required: false,
|
||||
input_type: ''
|
||||
}
|
||||
}
|
||||
|
||||
const validate = () => {
|
||||
if (ruleFormRef.value) {
|
||||
return ruleFormRef.value?.validate()
|
||||
|
|
@ -120,6 +109,6 @@ onMounted(() => {
|
|||
}
|
||||
})
|
||||
|
||||
defineExpose({ getData, resetFields, validate })
|
||||
defineExpose({ getData, validate })
|
||||
</script>
|
||||
<style lang="scss"></style>
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@
|
|||
|
||||
<div
|
||||
class="w-full flex-between mb-8"
|
||||
v-for="(option, $index) in formValue.optionList"
|
||||
v-for="(option, $index) in formValue.option_list"
|
||||
:key="$index"
|
||||
>
|
||||
<el-input v-model="formValue.optionList[$index]" placeholder="请输入选项值" />
|
||||
<el-input v-model="formValue.option_list[$index].value" placeholder="请输入选项值" />
|
||||
<el-button link class="ml-8" @click.stop="delOption($index)">
|
||||
<el-icon>
|
||||
<Delete />
|
||||
|
|
@ -33,7 +33,7 @@
|
|||
:rules="formValue.required ? [{ required: true, message: '默认值 为必填属性' }] : []"
|
||||
>
|
||||
<el-select v-model="formValue.default_value">
|
||||
<el-option v-for="(option, index) in formValue.optionList" :key="index" :label="option" :value="option" />
|
||||
<el-option v-for="(option, index) in formValue.option_list" :key="index" :label="option.value" :value="option.value" />
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</template>
|
||||
|
|
@ -54,11 +54,11 @@ const formValue = computed({
|
|||
})
|
||||
|
||||
const addOption = () => {
|
||||
formValue.value.optionList.push('')
|
||||
formValue.value.option_list.push('')
|
||||
}
|
||||
|
||||
const delOption = (index: number) => {
|
||||
formValue.value.optionList.splice(index, 1)
|
||||
formValue.value.option_list.splice(index, 1)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -67,13 +67,14 @@ const getData = () => {
|
|||
input_type: 'SingleSelect',
|
||||
attrs: {},
|
||||
default_value: formValue.value.default_value,
|
||||
optionList: formValue.value.optionList
|
||||
text_field: formValue.value.text_field,
|
||||
value_field: formValue.value.value_field,
|
||||
option_list: formValue.value.option_list
|
||||
}
|
||||
}
|
||||
defineExpose({ getData })
|
||||
onMounted(() => {
|
||||
console.log('props.modelValue', props.modelValue)
|
||||
formValue.value.optionList = props.modelValue.optionList || []
|
||||
formValue.value.option_list = props.modelValue.option_list || []
|
||||
})
|
||||
</script>
|
||||
<style lang="scss"></style>
|
||||
|
|
|
|||
|
|
@ -57,10 +57,11 @@ function confirmClick() {
|
|||
const formEl = DynamicsFormConstructorRef.value
|
||||
formEl?.validate().then((valid) => {
|
||||
if (valid) {
|
||||
emit('refresh', formEl?.getData(), currentIndex.value)
|
||||
drawer.value = false
|
||||
isEdit.value = false
|
||||
currentItem.value = null
|
||||
emit('refresh', formEl?.getData(), currentIndex.value)
|
||||
currentIndex.value = null
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,7 +93,11 @@
|
|||
</el-button>
|
||||
<template #dropdown>
|
||||
<el-dropdown-menu>
|
||||
<el-dropdown-item icon="Setting" @click.stop="openParamSetting">
|
||||
<el-dropdown-item
|
||||
v-if="currentModel.model_type === 'TTS' || currentModel.model_type === 'LLM'"
|
||||
:disabled="!is_permisstion"
|
||||
icon="Setting" @click.stop="openParamSetting"
|
||||
>
|
||||
模型参数设置
|
||||
</el-dropdown-item>
|
||||
<el-dropdown-item icon="Delete" :disabled="!is_permisstion" text @click.stop="deleteModel">
|
||||
|
|
@ -105,7 +109,7 @@
|
|||
</div>
|
||||
</template>
|
||||
<EditModel ref="editModelRef" @submit="emit('change')"></EditModel>
|
||||
<ParamSettingDialog ref="paramSettingRef"/>
|
||||
<ParamSettingDialog ref="paramSettingRef" :model="model"/>
|
||||
</card-box>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@
|
|||
添加参数
|
||||
</el-button>
|
||||
<el-table
|
||||
:data="modelParams"
|
||||
:data="modelParamsForm"
|
||||
class="mb-16"
|
||||
>
|
||||
<el-table-column prop="label" label="参数">
|
||||
<el-table-column prop="label" label="显示名称">
|
||||
<template #default="{ row }">
|
||||
<span v-if="row.label && row.label.input_type === 'TooltipLabel'">{{ row.label.label }}</span>
|
||||
<span v-else>{{ row.label }}</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="field" label="显示名称" />
|
||||
<el-table-column prop="field" label="参数" />
|
||||
<el-table-column label="组件类型">
|
||||
<template #default="{ row }">
|
||||
<el-tag type="info" class="info-tag" v-if="row.input_type === 'TextInput'">文本框</el-tag>
|
||||
|
|
@ -72,15 +72,31 @@
|
|||
import type { Model } from '@/api/type/model'
|
||||
import { ref } from 'vue'
|
||||
import AddParamDrawer from './AddParamDrawer.vue'
|
||||
import { MsgError } from '@/utils/message'
|
||||
import { MsgError, MsgSuccess } from '@/utils/message'
|
||||
import ModelApi from '@/api/model'
|
||||
|
||||
|
||||
const props = defineProps<{
|
||||
model: Model
|
||||
}>()
|
||||
|
||||
const loading = ref<boolean>(false)
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
const modelParams = ref<any[]>([])
|
||||
const modelParamsForm = ref<any[]>([])
|
||||
const AddParamRef = ref()
|
||||
|
||||
const open = (model: Model) => {
|
||||
const open = () => {
|
||||
dialogVisible.value = true
|
||||
loading.value = true
|
||||
ModelApi.getModelParamsForm(
|
||||
props.model.id,
|
||||
loading
|
||||
).then((ok) => {
|
||||
loading.value = false
|
||||
modelParamsForm.value = ok.data
|
||||
}).catch(() => {
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
|
||||
const close = () => {
|
||||
|
|
@ -93,27 +109,49 @@ function openAddDrawer(data?: any, index?: any) {
|
|||
}
|
||||
|
||||
function deleteParam(index: any) {
|
||||
modelParams.value.splice(index, 1)
|
||||
modelParamsForm.value.splice(index, 1)
|
||||
}
|
||||
|
||||
function refresh(data: any, index: any) {
|
||||
// console.log(data, index)
|
||||
for (let i = 0; i < modelParams.value.length; i++) {
|
||||
if (modelParams.value[i].field === data.field && index !== i) {
|
||||
for (let i = 0; i < modelParamsForm.value.length; i++) {
|
||||
let field = modelParamsForm.value[i].field
|
||||
let label = modelParamsForm.value[i].label
|
||||
if (label && label.input_type === 'TooltipLabel') {
|
||||
label = label.label
|
||||
}
|
||||
let label2 = data.label
|
||||
if (label2 && label2.input_type === 'TooltipLabel') {
|
||||
label2 = label2.label
|
||||
}
|
||||
|
||||
if (field === data.field && index !== i) {
|
||||
MsgError('变量已存在: ' + data.field)
|
||||
return
|
||||
}
|
||||
if (label === label2 && index !== i) {
|
||||
MsgError('变量已存在: ' + label)
|
||||
return
|
||||
}
|
||||
}
|
||||
if (index !== null) {
|
||||
modelParams.value.splice(index, 1, data)
|
||||
modelParamsForm.value.splice(index, 1, data)
|
||||
} else {
|
||||
modelParams.value.push(data)
|
||||
modelParamsForm.value.push(data)
|
||||
}
|
||||
console.log(modelParams.value)
|
||||
}
|
||||
|
||||
function submit() {
|
||||
console.log('submit')
|
||||
// console.log('submit: ', modelParamsForm.value)
|
||||
ModelApi.updateModelParamsForm(
|
||||
props.model.id,
|
||||
modelParamsForm.value,
|
||||
loading
|
||||
).then((ok) => {
|
||||
MsgSuccess('模型参数保存成功')
|
||||
close()
|
||||
// emit('submit')
|
||||
})
|
||||
}
|
||||
|
||||
defineExpose({ open, close })
|
||||
|
|
|
|||
Loading…
Reference in New Issue