diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index f44ffac2c..07df04953 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -336,20 +336,53 @@ class ModelSerializer(serializers.Serializer): query_params = self._build_query_params(workspace_id) return self._fetch_models(query_params) + def model_list(self, workspace_id, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + queryset = self._build_query_params(workspace_id) + model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") + if model_workspace_authorization is not None: + queryset = get_authorized_tool(queryset, workspace_id, + model_workspace_authorization=model_workspace_authorization) + shared_model = [] + normal_model = [] + + for model in queryset.order_by("-create_time"): + data = { + 'id': str(model.id), + 'provider': model.provider, + 'name': model.name, + 'model_type': model.model_type, + 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta, + 'user_id': model.user_id, + 'username': model.user.nick_name + } + if model.workspace_id == 'None': + shared_model.append(data) + else: + normal_model.append(data) + + return { + "shared_model": shared_model, + "model": normal_model + } + def _build_query_params(self, workspace_id): - query_params = {} + queryset = QuerySet(Model) if workspace_id: - query_params['workspace_id'] = workspace_id + queryset = queryset.filter(workspace_id=workspace_id) for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: value = self.data.get(field) if value is not None: if field == 'name': - query_params[f'{field}__icontains'] = value + queryset = queryset.filter(**{f'{field}__icontains': value}) elif field == 'create_user': - query_params['user_id'] = value + queryset = queryset.filter(user_id=value) else: - query_params[field] = value - return query_params + queryset = queryset.filter(**{field: value}) + return queryset def _fetch_models(self, query_params): return [ diff --git a/apps/models_provider/urls.py b/apps/models_provider/urls.py index a5b25016b..1860ac48b 100644 --- a/apps/models_provider/urls.py +++ b/apps/models_provider/urls.py @@ -12,6 +12,7 @@ urlpatterns = [ path('provider/model_params_form', views.Provide.ModelParamsForm.as_view()), path('provider/model_form', views.Provide.ModelForm.as_view()), path('workspace//model', views.ModelSetting.as_view()), + path('workspace//model_list', views.ModelList.as_view()), path('workspace//model//model_params_form', views.ModelSetting.ModelParamsForm.as_view()), path('workspace//model/', views.ModelSetting.Operate.as_view()), diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index 02a77d4b9..3586b03ad 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -248,3 +248,22 @@ class WorkspaceSharedModelSetting(APIView): def get(self, request: Request, workspace_id: str): return result.success( WorkspaceSharedModelSerializer(data={'workspace_id': workspace_id}).get_share_model_list()) + + +class ModelList(APIView): + authentication_classes = [TokenAuth] + + @extend_schema(methods=['GET'], + summary=_('Query all model list'), + description=_('Query all model list'), + operation_id=_('Query all model list'), # type: ignore + parameters=ModelListResponse.get_parameters(), + responses=ModelListResponse.get_response(), + tags=[_('Model')]) # type: ignore + @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + def get(self, request: Request, workspace_id: str): + return result.success( + ModelSerializer.Query( + data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id, + with_valid=True))