From 7d938f9e602b39301b842b1c8384c9038d35c46b Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:30:45 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=AE=BE=E7=BD=AE=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E9=BB=98=E8=AE=A4=E5=80=BC,=E5=85=81?= =?UTF-8?q?=E8=AE=B8=E7=94=A8=E6=88=B7=E4=B8=8D=E8=AE=BE=E7=BD=AE=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=8F=82=E6=95=B0=20(#1750)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../0009_set_default_model_params_form.py | 19 +++++++++++++++++++ .../serializers/provider_serializers.py | 15 ++++----------- 2 files changed, 23 insertions(+), 11 deletions(-) create mode 100644 apps/setting/migrations/0009_set_default_model_params_form.py diff --git a/apps/setting/migrations/0009_set_default_model_params_form.py b/apps/setting/migrations/0009_set_default_model_params_form.py new file mode 100644 index 000000000..6b4d4b453 --- /dev/null +++ b/apps/setting/migrations/0009_set_default_model_params_form.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + +sql = """ +UPDATE "public"."model" +SET "model_params_form" = '[{"attrs": {"max": 1, "min": 0.1, "step": 0.01, "precision": 2, "show-input": true, "show-input-controls": false}, "field": "temperature", "label": {"attrs": {"tooltip": "较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定"}, "label": "温度", "input_type": "TooltipLabel", "props_info": {}}, "required": true, "input_type": "Slider", "props_info": {}, "trigger_type": "OPTION_LIST", "default_value": 0.5, "relation_show_field_dict": {}, "relation_trigger_field_dict": {}}, {"attrs": {"max": 100000, "min": 1, "step": 1, "precision": 0, "show-input": true, "show-input-controls": false}, "field": "max_tokens", "label": {"attrs": {"tooltip": "指定模型可生成的最大token个数"}, "label": "输出最大Tokens", "input_type": "TooltipLabel", "props_info": {}}, "required": true, "input_type": "Slider", "props_info": {}, "trigger_type": "OPTION_LIST", "default_value": 4096, "relation_show_field_dict": {}, "relation_trigger_field_dict": {}}]' +WHERE jsonb_array_length(model_params_form)=0 +""" + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0008_modelparam'), + ] + + operations = [ + migrations.RunSQL(sql) + ] diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index e76e67d2e..e732087b0 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -79,7 +79,6 @@ class ModelSerializer(serializers.Serializer): create_user = serializers.CharField(required=False, error_messages=ErrMessage.char("创建者")) - def list(self, with_valid): if with_valid: self.is_valid(raise_exception=True) @@ -92,7 +91,8 @@ class ModelSerializer(serializers.Serializer): model_query_set = QuerySet(Model).filter(Q(user_id=create_user)) # 当前用户能查看其他人的模型,只能查看公开的 else: - model_query_set = QuerySet(Model).filter((Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC'))) + model_query_set = QuerySet(Model).filter( + (Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC'))) else: model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) query_params = {} @@ -107,11 +107,11 @@ class ModelSerializer(serializers.Serializer): if self.data.get('permission_type') is not None: query_params['permission_type'] = self.data.get('permission_type') - return [ {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, 'meta': model.meta, - 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} for model in + 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} + for model in model_query_set.filter(**query_params).order_by("-create_time")] class Edit(serializers.Serializer): @@ -243,14 +243,7 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) model_id = self.data.get('id') model = QuerySet(Model).filter(id=model_id).first() - credential = get_model_credential(model.provider, model.model_type, model.model_name) # 已经保存过的模型参数表单 - if model.model_params_form is not None and len(model.model_params_form) > 0: - return model.model_params_form - # 没有保存过的LLM类型的 - if credential.get_model_params_setting_form(model.model_name) is not None: - return credential.get_model_params_setting_form(model.model_name).to_form_list() - # 其他的 return model.model_params_form class ModelParamsForm(serializers.Serializer):