From d2ec6d558b20dcad07fafae973855f670bcd0790 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Fri, 20 Jun 2025 10:40:51 +0800 Subject: [PATCH] fix: model add workspace_id --- .../serializers/model_serializer.py | 17 +++++++++++------ apps/models_provider/views/model.py | 16 ++++++++++------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index f745805ff..9f8e0d087 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -107,12 +107,15 @@ class ModelSerializer(serializers.Serializer): class Operate(serializers.Serializer): id = serializers.UUIDField(required=True, label=_("model id")) user_id = serializers.UUIDField(required=False, label=_("user id")) + workspace_id = serializers.CharField(required=False, label=_("workspace id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - model = QuerySet(Model).filter( - id=self.data.get("id") - ).first() + workspace_id = self.data.get("workspace_id") + model_query = QuerySet(Model).filter(id=self.data.get("id")) + if workspace_id is not None: + model_query = model_query.filter(workspace_id=workspace_id) + model = model_query.first() if model is None: raise AppApiException(500, _('Model does not exist')) if model.workspace_id == 'None': @@ -122,7 +125,7 @@ class ModelSerializer(serializers.Serializer): if with_valid: super().is_valid(raise_exception=True) model = QuerySet(Model).get( - id=self.data.get('id') + id=self.data.get('id'), workspace_id=self.data.get('workspace_id') ) return ModelSerializer.model_to_dict(model) @@ -130,13 +133,15 @@ class ModelSerializer(serializers.Serializer): model = None if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get("id")).first() + model = QuerySet(Model).filter(id=self.data.get("id"), + workspace_id=self.data.get('workspace_id')).first() if model is None: raise AppApiException(500, _('Model does not exist')) return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, - 'meta': model.meta + 'meta': model.meta, + 'workspace_id': model.workspace_id, } def pause_download(self, with_valid=True): diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index cd9e02eac..a688c9db7 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -121,8 +121,9 @@ class ModelSetting(APIView): ) def put(self, request: Request, workspace_id, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).edit(request.data, - str(request.user.id))) + ModelSerializer.Operate( + data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).edit(request.data, + str(request.user.id))) @extend_schema(methods=['DELETE'], summary=_('Delete model'), @@ -138,7 +139,8 @@ class ModelSetting(APIView): ) def delete(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).delete()) + ModelSerializer.Operate( + data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).delete()) @extend_schema(methods=['GET'], summary=_('Query model details'), @@ -151,7 +153,9 @@ class ModelSetting(APIView): RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one(with_valid=True)) + ModelSerializer.Operate( + data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).one( + with_valid=True)) class ModelParamsForm(APIView): authentication_classes = [TokenAuth] @@ -203,7 +207,7 @@ class ModelSetting(APIView): RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id}).one_meta(with_valid=True)) + ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True)) class PauseDownload(APIView): authentication_classes = [TokenAuth] @@ -220,7 +224,7 @@ class ModelSetting(APIView): RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id}).pause_download()) + ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download()) class SharedModel(APIView):