From 4c164a49de740349519f2b27d0b711ba238149df Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Fri, 1 Dec 2023 17:30:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/models/application.py | 4 +- .../serializers/application_serializers.py | 4 +- .../serializers/provider_serializers.py | 83 ++++++- apps/setting/swagger_api/provide_api.py | 24 ++ apps/setting/urls.py | 1 + apps/setting/views/model.py | 36 ++- ui/src/api/model.ts | 35 ++- ui/src/api/type/model.ts | 31 ++- ui/src/components/dynamics-form/index.vue | 21 +- ui/src/components/index.ts | 4 +- .../views/template/component/CreateModel.vue | 13 +- ui/src/views/template/component/EditModel.vue | 211 ++++++++++++++++++ ui/src/views/template/component/Model.vue | 50 ++++- ui/src/views/template/index.vue | 11 +- 14 files changed, 494 insertions(+), 34 deletions(-) create mode 100644 ui/src/views/template/component/EditModel.vue diff --git a/apps/application/models/application.py b/apps/application/models/application.py index b7daeadf0..7dc728e3d 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -21,8 +21,8 @@ from users.models import User class Application(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") name = models.CharField(max_length=128, verbose_name="应用名称") - desc = models.CharField(max_length=128, verbose_name="引用描述") - prologue = models.CharField(max_length=1024, verbose_name="开场白") + desc = models.CharField(max_length=128, verbose_name="引用描述", default="") + prologue = models.CharField(max_length=1024, verbose_name="开场白", default="") example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True)) dialogue_number = models.IntegerField(default=0, verbose_name="会话数量") status = models.BooleanField(default=True, verbose_name="是否发布") diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 817808de7..8698f4627 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -62,10 +62,10 @@ class ApplicationSerializerModel(serializers.ModelSerializer): class ApplicationSerializer(serializers.Serializer): name = serializers.CharField(required=True) - desc = serializers.CharField(required=True) + desc = serializers.CharField(required=False) model_id = serializers.CharField(required=True) multiple_rounds_dialogue = serializers.BooleanField(required=True) - prologue = serializers.CharField(required=True) + prologue = serializers.CharField(required=False) example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True)) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index fba0b0a66..72271a9b5 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -49,6 +49,49 @@ class ModelSerializer(serializers.Serializer): return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)] + class Edit(serializers.Serializer): + user_id = serializers.CharField(required=False) + + name = serializers.CharField(required=False) + + model_type = serializers.CharField(required=False) + + model_name = serializers.CharField(required=False) + + credential = serializers.DictField(required=False) + + def is_valid(self, model=None, raise_exception=False): + super().is_valid(raise_exception=True) + filter_params = {'user_id': self.data.get('user_id')} + if 'name' in self.data and self.data.get('name') is not None: + filter_params['name'] = self.data.get('name') + if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists(): + raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') + + ModelSerializer.model_to_dict(model) + + provider = model.provider + model_type = self.data.get('model_type') + model_name = self.data.get( + 'model_name') + credential = self.data.get('credential') + + model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, + model_name) + source_model_credential = json.loads(decrypt(model.credential)) + source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) + if credential is not None: + for k in source_encryption_model_credential.keys(): + if credential[k] == source_encryption_model_credential[k]: + credential[k] = source_model_credential[k] + # 校验模型认证数据 + model_credential.is_valid( + model_type, + model_name, + credential, + raise_exception=True) + return credential + class Create(serializers.Serializer): user_id = serializers.CharField(required=True) @@ -89,7 +132,7 @@ class ModelSerializer(serializers.Serializer): credential=encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name) model.save() - return ModelSerializer.Operate(data={'id': model.id}).one(user_id, with_valid=True) + return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True) @staticmethod def model_to_dict(model: Model): @@ -103,12 +146,46 @@ class ModelSerializer(serializers.Serializer): class Operate(serializers.Serializer): id = serializers.UUIDField(required=True) - def one(self, user_id, with_valid=False): + user_id = serializers.UUIDField(required=True) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get("id"), user_id=self.data.get("user_id")).first() + if model is None: + raise AppApiException(500, '模型不存在') + + def one(self, with_valid=False): if with_valid: self.is_valid(raise_exception=True) - model = QuerySet(Model).get(id=self.data.get('id'), user_id=user_id) + model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) return ModelSerializer.model_to_dict(model) + def delete(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + QuerySet(Model).filter(id=self.data.get('id')).delete() + return True + + def edit(self, instance: Dict, user_id: str, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get('id')).first() + + if model is None: + raise AppApiException(500, '不存在的id') + else: + credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model) + update_keys = ['credential', 'name', 'model_type', 'model_name'] + for update_key in update_keys: + if update_key in instance and instance.get(update_key) is not None: + if update_key == 'credential': + model_credential_str = json.dumps(credential) + model.__setattr__(update_key, encrypt(model_credential_str)) + else: + model.__setattr__(update_key, instance.get(update_key)) + model.save() + return self.one(with_valid=False) + class ProviderSerializer(serializers.Serializer): provider = serializers.CharField(required=True) diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index 083d37c8b..f68ac5be4 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -35,6 +35,30 @@ class ModelQueryApi(ApiMixin): ] +class ModelEditApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="调用函数所需要的参数", + description="调用函数所需要的参数", + required=['provide', 'model_info'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, + title="模型名称", + description="模型名称"), + 'model_type': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'model_name': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, + title="模型证书信息", + description="模型证书信息") + } + ) + + class ModelCreateApi(ApiMixin): @staticmethod diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 5501e9d2f..f58679046 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -15,5 +15,6 @@ urlpatterns = [ path('provider/model_form', views.Provide.ModelForm.as_view(), name="provider/model_form"), path('model', views.Model.as_view(), name='model'), + path('model/', views.Model.Operate.as_view(), name='model/operate') ] diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 0efbc3aaa..330834efd 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -17,7 +17,7 @@ from common.response import result from common.util.common import query_params_to_single_dict from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer -from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi +from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi class Model(APIView): @@ -46,9 +46,43 @@ class Model(APIView): data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list( with_valid=True)) + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改模型", + operation_id="修改模型", + request_body=ModelEditApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).edit(request.data, + str(request.user.id))) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除模型", + operation_id="删除模型", + responses=result.get_default_response() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_DELETE) + def delete(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).delete()) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="查询模型详细信息", + operation_id="查询模型详细信息", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one(with_valid=True)) + class Provide(APIView): authentication_classes = [TokenAuth] + class Exec(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index e126c4077..cc52a072d 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -8,7 +8,8 @@ import type { ListModelRequest, Model, BaseModel, - CreateModelRequest + CreateModelRequest, + EditModelRequest } from '@/api/type/model' import type { FormField } from '@/components/dynamics-form/type' import type { KeyValue } from './type/common' @@ -84,11 +85,33 @@ const listBaseModel: ( * @param loading 加载器 * @returns */ -const createModel: (request: CreateModelRequest, loading?: Ref) => Promise = ( - request, +const createModel: ( + request: CreateModelRequest, + loading?: Ref +) => Promise> = (request, loading) => { + return post(`${prefix}`, request, {}, loading) +} + +/** + * 修改模型 + * @param request 請求對象 + * @param loading 加載器 + * @returns + */ +const updateModel: ( + model_id: string, + request: EditModelRequest, + loading?: Ref +) => Promise> = (model_id, request, loading) => { + console.log(request) + return put(`${prefix}/${model_id}`, request, {}, loading) +} + +const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( + model_id, loading ) => { - return post(`${prefix}`, request, {}, loading) + return del(`${prefix}/${model_id}`, {}, loading) } export default { getModel, @@ -96,5 +119,7 @@ export default { getModelCreateForm, listModelType, listBaseModel, - createModel + createModel, + updateModel, + deleteModel } diff --git a/ui/src/api/type/model.ts b/ui/src/api/type/model.ts index 529db4a12..167154ada 100644 --- a/ui/src/api/type/model.ts +++ b/ui/src/api/type/model.ts @@ -43,7 +43,7 @@ interface Model { /** * 主键id */ - id: String + id: string /** * 模型名 */ @@ -88,6 +88,25 @@ interface CreateModelRequest { provider: string } +interface EditModelRequest { + /** + * 模型名 + */ + name: string + /** + * 模型类型 + */ + model_type: string + /** + * 基础模型 + */ + model_name: string + /** + * 认证信息 + */ + credential: any +} + interface BaseModel { /** * 基础模型名称 @@ -102,4 +121,12 @@ interface BaseModel { */ model_type: string } -export type { modelRequest, Provider, ListModelRequest, Model, BaseModel, CreateModelRequest } +export type { + modelRequest, + Provider, + ListModelRequest, + Model, + BaseModel, + CreateModelRequest, + EditModelRequest +} diff --git a/ui/src/components/dynamics-form/index.vue b/ui/src/components/dynamics-form/index.vue index 0abdd60c0..962e41294 100644 --- a/ui/src/components/dynamics-form/index.vue +++ b/ui/src/components/dynamics-form/index.vue @@ -104,9 +104,13 @@ const change = (field: FormField, value: any) => { formValue.value[field.field] = value } -watch(formValue.value, () => { - emit('update:modelValue', formValue.value) -}) +watch( + formValue, + () => { + emit('update:modelValue', formValue.value) + }, + { deep: true } +) /** * 触发器,用户获取子表单 或者 下拉选项 @@ -150,10 +154,13 @@ const initDefaultData = (formField: FormField) => { } onMounted(() => { - render(props.render_data) + render(props.render_data, {}) }) -const render = (render_data: string | Array | Promise>>) => { +const render = ( + render_data: string | Array | Promise>>, + data?: Dict +) => { if (typeof render_data == 'string') { triggerApi.get(render_data, {}, loading).then((ok) => { formFieldList.value = ok.data @@ -165,6 +172,10 @@ const render = (render_data: string | Array | Promise 取消 - 添加 + 添加 @@ -116,7 +115,12 @@ const credential_form_data = ref>({}) const form_data = computed({ get: () => { - return { ...base_form_data.value, ...credential_form_data.value } + return { + ...credential_form_data.value, + name: base_form_data.value.name, + model_type: base_form_data.value.model_type, + model_name: base_form_data.value.model_name + } }, set: (event: any) => { credential_form_data.value = event @@ -132,7 +136,7 @@ const getModelForm = (model_name: string) => { ).then((ok) => { model_form_field.value = ok.data // 渲染动态表单 - dynamicsFormRef.value?.render(model_form_field.value) + dynamicsFormRef.value?.render(model_form_field.value, undefined) }) } } @@ -171,6 +175,7 @@ const submit = () => { }, loading ).then((ok) => { + close() MsgSuccess('创建模型成功') emit('submit') }) diff --git a/ui/src/views/template/component/EditModel.vue b/ui/src/views/template/component/EditModel.vue new file mode 100644 index 000000000..7d84b173d --- /dev/null +++ b/ui/src/views/template/component/EditModel.vue @@ -0,0 +1,211 @@ + + + diff --git a/ui/src/views/template/component/Model.vue b/ui/src/views/template/component/Model.vue index 310f1f260..79de25108 100644 --- a/ui/src/views/template/component/Model.vue +++ b/ui/src/views/template/component/Model.vue @@ -28,16 +28,50 @@ > + +