From 67f3ec34a120d8885d78a71b1f03cc52e0c8aec9 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Tue, 15 Oct 2024 13:19:35 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=A8=A1=E5=9E=8B=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/forms/single_select_field.py | 3 +- ..._document_status_alter_paragraph_status.py | 23 +++++++ .../0007_model_model_params_form.py | 18 ++++++ apps/setting/models/model_management.py | 10 +++ .../aliyun_bai_lian_model_provider.py | 3 +- .../credential/tts.py | 39 ++++++++++- .../openai_model_provider/credential/tts.py | 58 +++++++++++++++++ .../openai_model_provider.py | 4 +- .../credential/tts.py | 32 +++++++++- .../impl/xf_model_provider/credential/tts.py | 27 +++++++- .../credential/tts.py | 60 +++++++++++++++++ .../xinference_model_provider.py | 4 +- .../serializers/provider_serializers.py | 33 +++++++++- apps/setting/views/model.py | 11 ++++ ui/src/api/model.ts | 17 ++++- ui/src/api/type/model.ts | 4 ++ .../dynamics-form/constructor/index.vue | 25 ++------ .../items/SingleSelectConstructor.vue | 17 ++--- .../template/component/AddParamDrawer.vue | 3 +- ui/src/views/template/component/ModelCard.vue | 8 ++- .../template/component/ParamSettingDialog.vue | 64 +++++++++++++++---- 21 files changed, 407 insertions(+), 56 deletions(-) create mode 100644 apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py create mode 100644 apps/setting/migrations/0007_model_model_params_form.py create mode 100644 apps/setting/models_provider/impl/openai_model_provider/credential/tts.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py diff --git a/apps/common/forms/single_select_field.py b/apps/common/forms/single_select_field.py index cf3d50409..21bd5de57 100644 --- a/apps/common/forms/single_select_field.py +++ b/apps/common/forms/single_select_field.py @@ -8,6 +8,7 @@ """ from typing import List, Dict +from common.forms import BaseLabel from common.forms.base_field import TriggerType, BaseExecField @@ -17,7 +18,7 @@ class SingleSelect(BaseExecField): """ def __init__(self, - label: str, + label: str or BaseLabel, text_field: str, value_field: str, option_list: List[str:object], diff --git a/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py b/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py new file mode 100644 index 000000000..7c138a609 --- /dev/null +++ b/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0008_alter_document_status_alter_paragraph_status'), + ] + + operations = [ + migrations.AlterField( + model_name='document', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='3', max_length=1, verbose_name='状态'), + ), + migrations.AlterField( + model_name='paragraph', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='0', max_length=1, verbose_name='状态'), + ), + ] diff --git a/apps/setting/migrations/0007_model_model_params_form.py b/apps/setting/migrations/0007_model_model_params_form.py new file mode 100644 index 000000000..fa40b660d --- /dev/null +++ b/apps/setting/migrations/0007_model_model_params_form.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0006_alter_model_status'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='model_params_form', + field=models.JSONField(default=list, verbose_name='模型参数配置'), + ), + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index c2f203fb0..638161e46 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -29,6 +29,13 @@ class PermissionType(models.TextChoices): PUBLIC = "PUBLIC", '公开' PRIVATE = "PRIVATE", "私有" +class ModelParam(models.Model): + label = models.CharField(max_length=128, verbose_name="参数") + field = models.CharField(max_length=256, verbose_name="显示名称") + default_value = models.CharField(max_length=1000, verbose_name="默认值") + input_type = models.CharField(max_length=32, verbose_name="组件类型") + attrs = models.JSONField(verbose_name="属性") + required = models.BooleanField(verbose_name="必填") class Model(AppModelMixin): """ @@ -56,6 +63,9 @@ class Model(AppModelMixin): permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices, default=PermissionType.PRIVATE) + model_params_form = models.JSONField(verbose_name="模型参数配置", default=list) + + def is_permission(self, user_id): if self.permission_type == PermissionType.PUBLIC or str(user_id) == str(self.user_id): return True diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index 83d42a57c..fffc7e68d 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -35,7 +35,8 @@ model_info_list = [ModelInfo('gte-rerank', ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech), ] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).build() +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + model_info_list[1]).append_default_model_info(model_info_list[2]).build() class AliyunBaiLianModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py index fe54ddaba..c00ffddd0 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py @@ -4,10 +4,44 @@ from typing import Dict from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class AliyunBaiLianTTSModelGeneralParams(BaseForm): + voice = forms.SingleSelect( + TooltipLabel('音色', '中文音色可支持中英文混合场景'), + required=True, default_value='longxiaochun', + text_field='text', + value_field='value', + option_list=[ + {'text': '龙小淳', 'value': 'longxiaochun'}, + {'text': '龙小夏', 'value': 'longxiaoxia'}, + {'text': '龙小诚', 'value': 'longxiaocheng'}, + {'text': '龙小白', 'value': 'longxiaobai'}, + {'text': '龙老铁', 'value': 'longlaotie'}, + {'text': '龙书', 'value': 'longshu'}, + {'text': '龙硕', 'value': 'longshuo'}, + {'text': '龙婧', 'value': 'longjing'}, + {'text': '龙妙', 'value': 'longmiao'}, + {'text': '龙悦', 'value': 'longyue'}, + {'text': '龙媛', 'value': 'longyuan'}, + {'text': '龙飞', 'value': 'longfei'}, + {'text': '龙杰力豆', 'value': 'longjielidou'}, + {'text': '龙彤', 'value': 'longtong'}, + {'text': '龙祥', 'value': 'longxiang'}, + {'text': 'Stella', 'value': 'loongstella'}, + {'text': 'Bella', 'value': 'loongbella'}, + ]) + speech_rate = forms.SliderField( + TooltipLabel('语速', '[0.5,2],默认为1,通常保留一位小数即可'), + required=True, default_value=1, + _min=0.5, + _max=2, + _step=0.1, + precision=1) + + class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField("API Key", required=True) @@ -38,6 +72,5 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): def encryption_dict(self, model: Dict[str, object]): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - def get_model_params_setting_form(self, model_name): - pass + return AliyunBaiLianTTSModelGeneralParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py new file mode 100644 index 000000000..38d839ca0 --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py @@ -0,0 +1,58 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + +class OpenAITTSModelGeneralParams(BaseForm): + # alloy, echo, fable, onyx, nova, shimmer + voice = forms.SingleSelect( + TooltipLabel('Voice', '尝试不同的声音(合金、回声、寓言、缟玛瑙、新星和闪光),找到一种适合您所需的音调和听众的声音。当前的语音针对英语进行了优化。'), + required=True, default_value='alloy', + text_field='text', + value_field='value', + option_list=[ + {'text': 'alloy', 'value': 'alloy'}, + {'text': 'echo', 'value': 'echo'}, + {'text': 'fable', 'value': 'fable'}, + {'text': 'onyx', 'value': 'onyx'}, + {'text': 'nova', 'value': 'nova'}, + {'text': 'shimmer', 'value': 'shimmer'}, + ]) + + +class OpenAITTSModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return OpenAITTSModelGeneralParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index c2cc4a27e..f9221388f 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential +from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText @@ -22,6 +23,7 @@ from smartdoc.conf import PROJECT_DIR openai_llm_model_credential = OpenAILLMModelCredential() openai_stt_model_credential = OpenAISTTModelCredential() +openai_tts_model_credential = OpenAITTSModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -70,7 +72,7 @@ model_info_list = [ ModelTypeConst.STT, openai_stt_model_credential, OpenAISpeechToText), ModelInfo('tts-1', '', - ModelTypeConst.TTS, openai_stt_model_credential, + ModelTypeConst.TTS, openai_tts_model_credential, OpenAITextToSpeech) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py index 7a9866613..7565a2546 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py @@ -4,10 +4,38 @@ from typing import Dict from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class VolcanicEngineTTSModelGeneralParams(BaseForm): + voice_type = forms.SingleSelect( + TooltipLabel('音色', '中文音色可支持中英文混合场景'), + required=True, default_value='BV002_streaming', + text_field='text', + value_field='value', + option_list=[ + {'text': '灿灿 2.0', 'value': 'BV700_V2_streaming'}, + {'text': '炀炀', 'value': 'BV705_streaming'}, + {'text': '擎苍 2.0', 'value': 'BV701_V2_streaming'}, + {'text': '通用女声 2.0', 'value': 'BV001_V2_streaming'}, + {'text': '灿灿', 'value': 'BV700_streaming'}, + {'text': '超自然音色-梓梓2.0', 'value': 'BV406_V2_streaming'}, + {'text': '超自然音色-梓梓', 'value': 'BV406_streaming'}, + {'text': '超自然音色-燃燃2.0', 'value': 'BV407_V2_streaming'}, + {'text': '超自然音色-燃燃', 'value': 'BV407_streaming'}, + {'text': '通用女声', 'value': 'BV001_streaming'}, + {'text': '通用男声', 'value': 'BV002_streaming'}, + ]) + speed_ratio = forms.SliderField( + TooltipLabel('语速', '[0.2,3],默认为1,通常保留一位小数即可'), + required=True, default_value=1, + _min=0.2, + _max=3, + _step=0.1, + precision=1) + + class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary') volcanic_app_id = forms.TextInputField('App ID', required=True) @@ -42,4 +70,4 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} def get_model_params_setting_form(self, model_name): - pass + return VolcanicEngineTTSModelGeneralParams() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py index 309f882b9..f0e68b38b 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py @@ -4,10 +4,32 @@ from typing import Dict from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class XunFeiTTSModelGeneralParams(BaseForm): + vcn = forms.SingleSelect( + TooltipLabel('发音人', '发音人,可选值:请到控制台添加试用或购买发音人,添加后即显示发音人参数值'), + required=True, default_value='xiaoyan', + text_field='text', + value_field='value', + option_list=[ + {'text': '讯飞小燕', 'value': 'xiaoyan'}, + {'text': '讯飞许久', 'value': 'aisjiuxu'}, + {'text': '讯飞小萍', 'value': 'aisxping'}, + {'text': '讯飞小婧', 'value': 'aisjinger'}, + {'text': '讯飞许小宝', 'value': 'aisbabyxu'}, + ]) + speed = forms.SliderField( + TooltipLabel('语速', '语速,可选值:[0-100],默认为50'), + required=True, default_value=50, + _min=1, + _max=100, + _step=5, + precision=1) + + class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts') spark_app_id = forms.TextInputField('APP ID', required=True) @@ -41,6 +63,5 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): def encryption_dict(self, model: Dict[str, object]): return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - def get_model_params_setting_form(self, model_name): - pass + return XunFeiTTSModelGeneralParams() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py new file mode 100644 index 000000000..d2844739f --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py @@ -0,0 +1,60 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceTTSModelGeneralParams(BaseForm): + # ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女'] + voice = forms.SingleSelect( + TooltipLabel('音色', ''), + required=True, default_value='中文女', + text_field='text', + value_field='value', + option_list=[ + {'text': '中文女', 'value': '中文女'}, + {'text': '中文男', 'value': '中文男'}, + {'text': '日语男', 'value': '日语男'}, + {'text': '粤语女', 'value': '粤语女'}, + {'text': '英文女', 'value': '英文女'}, + {'text': '英文男', 'value': '英文男'}, + {'text': '韩语女', 'value': '韩语女'}, + ]) + + +class XInferenceTTSModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return XInferenceTTSModelGeneralParams() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py index 06c199829..0e432e061 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -12,6 +12,7 @@ from setting.models_provider.impl.xinference_model_provider.credential.embedding from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.tts import XInferenceTTSModelCredential from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker @@ -21,6 +22,7 @@ from smartdoc.conf import PROJECT_DIR xinference_llm_model_credential = XinferenceLLMModelCredential() xinference_stt_model_credential = XInferenceSTTModelCredential() +xinference_tts_model_credential = XInferenceTTSModelCredential() model_info_list = [ ModelInfo( @@ -279,7 +281,7 @@ model_info_list = [ 'CosyVoice-300M-SFT', 'CosyVoice-300M-SFT是一个小型的语音合成模型。', ModelTypeConst.TTS, - xinference_stt_model_credential, + xinference_tts_model_credential, XInferenceTextToSpeech ), ModelInfo( diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 5b926ec0b..e76e67d2e 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -244,7 +244,38 @@ class ModelSerializer(serializers.Serializer): model_id = self.data.get('id') model = QuerySet(Model).filter(id=model_id).first() credential = get_model_credential(model.provider, model.model_type, model.model_name) - return credential.get_model_params_setting_form(model.model_name).to_form_list() + # 已经保存过的模型参数表单 + if model.model_params_form is not None and len(model.model_params_form) > 0: + return model.model_params_form + # 没有保存过的LLM类型的 + if credential.get_model_params_setting_form(model.model_name) is not None: + return credential.get_model_params_setting_form(model.model_name).to_form_list() + # 其他的 + return model.model_params_form + + class ModelParamsForm(serializers.Serializer): + id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get("id")).first() + if model is None: + raise AppApiException(500, '模型不存在') + if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id): + raise AppApiException(500, '没有权限访问到此模型') + + def save_model_params_form(self, model_params_form, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + if model_params_form is None: + model_params_form = [] + model_id = self.data.get('id') + model = QuerySet(Model).filter(id=model_id).first() + model.model_params_form = model_params_form + model.save() + return True class Operate(serializers.Serializer): id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index d57a19471..b4b7c4aaf 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -94,6 +94,17 @@ class Model(APIView): return result.success( ModelSerializer.ModelParams(data={'id': model_id, 'user_id': request.user.id}).get_model_params()) + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="保存模型参数表单", + operation_id="保存模型参数表单", + manual_parameters=ProvideApi.ModelForm.get_request_params_api(), + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.ModelParamsForm(data={'id': model_id, 'user_id': request.user.id}) + .save_model_params_form(request.data)) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index bcb31f9c4..59ca0f82d 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -118,6 +118,20 @@ const updateModel: ( return put(`${prefix}/${model_id}`, request, {}, loading) } +/** + * 修改模型参数配置 + * @param request 請求對象 + * @param loading 加載器 + * @returns + */ +const updateModelParamsForm: ( + model_id: string, + request: any[], + loading?: Ref +) => Promise> = (model_id, request, loading) => { + return put(`${prefix}/${model_id}/model_params_form`, request, {}, loading) +} + /** * 获取模型详情根据模型id 包括认证信息 * @param model_id 模型id @@ -172,5 +186,6 @@ export default { getModelById, getModelMetaById, pauseDownload, - getModelParamsForm + getModelParamsForm, + updateModelParamsForm } diff --git a/ui/src/api/type/model.ts b/ui/src/api/type/model.ts index 1660ffd71..667292055 100644 --- a/ui/src/api/type/model.ts +++ b/ui/src/api/type/model.ts @@ -76,6 +76,10 @@ interface Model { * 元数据 */ meta: Dict + /** + * 模型参数配置 + */ + model_params_form: Dict[] } interface CreateModelRequest { /** diff --git a/ui/src/components/dynamics-form/constructor/index.vue b/ui/src/components/dynamics-form/constructor/index.vue index 31c04d360..41c038850 100644 --- a/ui/src/components/dynamics-form/constructor/index.vue +++ b/ui/src/components/dynamics-form/constructor/index.vue @@ -7,11 +7,11 @@ :model="form_data" v-bind="$attrs" > - - + + - - + + @@ -19,8 +19,8 @@ - - + + { } } -const resetFields = () => { - console.log(123) - form_data.value = { - label: '', - field: '', - tooltip: '', - required: false, - input_type: '' - } -} - const validate = () => { if (ruleFormRef.value) { return ruleFormRef.value?.validate() @@ -120,6 +109,6 @@ onMounted(() => { } }) -defineExpose({ getData, resetFields, validate }) +defineExpose({ getData, validate }) diff --git a/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue b/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue index 615414795..1b96f879f 100644 --- a/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue +++ b/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue @@ -14,10 +14,10 @@
- + @@ -33,7 +33,7 @@ :rules="formValue.required ? [{ required: true, message: '默认值 为必填属性' }] : []" > - + @@ -54,11 +54,11 @@ const formValue = computed({ }) const addOption = () => { - formValue.value.optionList.push('') + formValue.value.option_list.push('') } const delOption = (index: number) => { - formValue.value.optionList.splice(index, 1) + formValue.value.option_list.splice(index, 1) } @@ -67,13 +67,14 @@ const getData = () => { input_type: 'SingleSelect', attrs: {}, default_value: formValue.value.default_value, - optionList: formValue.value.optionList + text_field: formValue.value.text_field, + value_field: formValue.value.value_field, + option_list: formValue.value.option_list } } defineExpose({ getData }) onMounted(() => { - console.log('props.modelValue', props.modelValue) - formValue.value.optionList = props.modelValue.optionList || [] + formValue.value.option_list = props.modelValue.option_list || [] }) diff --git a/ui/src/views/template/component/AddParamDrawer.vue b/ui/src/views/template/component/AddParamDrawer.vue index 3bda9f304..ee0c34d99 100644 --- a/ui/src/views/template/component/AddParamDrawer.vue +++ b/ui/src/views/template/component/AddParamDrawer.vue @@ -57,10 +57,11 @@ function confirmClick() { const formEl = DynamicsFormConstructorRef.value formEl?.validate().then((valid) => { if (valid) { + emit('refresh', formEl?.getData(), currentIndex.value) drawer.value = false isEdit.value = false currentItem.value = null - emit('refresh', formEl?.getData(), currentIndex.value) + currentIndex.value = null } }) } diff --git a/ui/src/views/template/component/ModelCard.vue b/ui/src/views/template/component/ModelCard.vue index bf72bbcc4..40ccb4684 100644 --- a/ui/src/views/template/component/ModelCard.vue +++ b/ui/src/views/template/component/ModelCard.vue @@ -93,7 +93,11 @@ - +