From 92fd481affc5b998ddf8a5d4ba986f87b22195f9 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 23 Jun 2025 17:21:55 +0800 Subject: [PATCH] refactor: shared model --- .../serializers/model_serializer.py | 58 +++++++------------ apps/models_provider/views/model.py | 2 +- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 07df04953..a95058903 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -334,35 +334,22 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) query_params = self._build_query_params(workspace_id) - return self._fetch_models(query_params) + return [self._build_model_data(model) for model in query_params.order_by("-create_time")] def model_list(self, workspace_id, with_valid=True): if with_valid: self.is_valid(raise_exception=True) queryset = self._build_query_params(workspace_id) model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") - if model_workspace_authorization is not None: - queryset = get_authorized_tool(queryset, workspace_id, - model_workspace_authorization=model_workspace_authorization) - shared_model = [] - normal_model = [] - for model in queryset.order_by("-create_time"): - data = { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'username': model.user.nick_name - } - if model.workspace_id == 'None': - shared_model.append(data) - else: - normal_model.append(data) + shared_queryset = QuerySet(Model).filter(workspace_id='None') + if model_workspace_authorization is not None: + shared_queryset = get_authorized_tool(QuerySet(Model), workspace_id, + model_workspace_authorization=model_workspace_authorization) + + # 构建共享模型和普通模型列表 + shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")] + normal_model = [self._build_model_data(model) for model in queryset.order_by("-create_time")] return { "shared_model": shared_model, @@ -384,21 +371,18 @@ class ModelSerializer(serializers.Serializer): queryset = queryset.filter(**{field: value}) return queryset - def _fetch_models(self, query_params): - return [ - { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'username': model.user.nick_name - } - for model in Model.objects.filter(**query_params).order_by("-create_time") - ] + def _build_model_data(self, model): + return { + 'id': str(model.id), + 'provider': model.provider, + 'name': model.name, + 'model_type': model.model_type, + 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta, + 'user_id': model.user_id, + 'username': model.user.nick_name + } class ModelParams(serializers.Serializer): id = serializers.UUIDField(required=True, label=_('model id')) diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index 3586b03ad..14832a50e 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -21,7 +21,7 @@ from common.utils.common import query_params_to_single_dict from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse from models_provider.api.provide import ProvideApi from models_provider.models import Model -from models_provider.serializers.model_serializer import ModelSerializer, SharedModelSerializer, \ +from models_provider.serializers.model_serializer import ModelSerializer, \ WorkspaceSharedModelSerializer from system_manage.views import encryption_str