From 01c0f62b4a0276c530a4d8af790e12afd33dfcec Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 5 Mar 2024 17:03:45 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=8F=90=E4=BE=9B=E6=A0=B9=E6=8D=AE?= =?UTF-8?q?=E5=BA=94=E7=94=A8=E6=9F=A5=E8=AF=A2=E6=A8=A1=E5=9E=8B=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 15 ++++++++++++++- apps/application/urls.py | 1 + apps/application/views/application_views.py | 19 +++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 2d13d0ad1..d421f4616 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -33,6 +33,7 @@ from dataset.serializers.common_serializers import list_paragraph from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR from smartdoc.settings import JWT_AUTH @@ -329,6 +330,14 @@ class ApplicationSerializer(serializers.Serializer): if not QuerySet(Application).filter(id=self.data.get('application_id')).exists(): raise AppApiException(500, '不存在的应用id') + def list_model(self, with_valid=True): + if with_valid: + self.is_valid() + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return ModelSerializer.Query( + data={'user_id': application.user_id}).list( + with_valid=True) + def delete(self, with_valid=True): if with_valid: self.is_valid() @@ -366,7 +375,11 @@ class ApplicationSerializer(serializers.Serializer): application = QuerySet(Application).get(id=application_id) - model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id) + model = QuerySet(Model).filter( + id=instance.get('model_id') if 'model_id' in instance else application.model_id, + user_id=application.user_id).first() + if model is None: + raise AppApiException(500, "模型不存在") update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', diff --git a/apps/application/urls.py b/apps/application/urls.py index 6842737b8..7b530eefc 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -7,6 +7,7 @@ urlpatterns = [ path('application', views.Application.as_view(), name="application"), path('application/profile', views.Application.Profile.as_view()), path('application/authentication', views.Application.Authentication.as_view()), + path('application//model', views.Application.Model.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), path('application//api_key', views.Application.ApplicationKey.as_view()), path("application//api_key/", diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 8fde8fad4..f1d7595d6 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -27,6 +27,25 @@ from dataset.serializers.dataset_serializers import DataSetSerializers class Application(APIView): authentication_classes = [TokenAuth] + class Model(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取模型列表", + operation_id="获取模型列表", + tags=["应用"], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).list_model()) + class Profile(APIView): authentication_classes = [TokenAuth]