From cc3813e97ef64cbfaf628eff408b398073213e5f Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Fri, 27 Dec 2024 18:32:22 +0800 Subject: [PATCH] fix: refactoring embedded application logic --- .../serializers/application_serializers.py | 35 +++++++- apps/application/urls.py | 2 + apps/application/views/application_views.py | 20 ++++- ui/src/api/application.ts | 11 ++- ui/src/workflow/common/NodeContainer.vue | 2 +- .../workflow/nodes/application-node/index.vue | 86 +++++++++++++++++++ 6 files changed, 150 insertions(+), 6 deletions(-) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 291d89603..969ec9065 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -21,6 +21,7 @@ from django.core import cache, validators from django.core import signing from django.db import transaction, models from django.db.models import QuerySet +from django.db.models.expressions import RawSQL from django.http import HttpResponse from django.template import Template, Context from rest_framework import serializers, status @@ -46,13 +47,15 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding from embedding.models import SearchMode from function_lib.models.function import FunctionLib, PermissionType from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer -from setting.models import AuthOperate +from setting.models import AuthOperate, TeamMemberPermission from setting.models.model_management import Model from setting.models_provider import get_model_credential from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR from users.models import User +from django.db.models import Value +from django.db.models.fields.json import KeyTextTransform chat_cache = cache.caches['chat_cache'] @@ -1110,12 +1113,38 @@ class ApplicationSerializer(serializers.Serializer): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') application_id = self.data.get('application_id') - application = Application.objects.filter(user_id=user_id).exclude(id=application_id) + application = QuerySet(Application).get(id=application_id) + + application_user_id = user_id if user_id == str(application.user_id) else None + + if application_user_id is not None: + all_applications = Application.objects.filter(user_id=application_user_id).exclude(id=application_id) + else: + all_applications = Application.objects.none() + + # 获取团队共享的应用 + shared_applications = Application.objects.filter( + id__in=TeamMemberPermission.objects.filter( + auth_target_type='APPLICATION', + operate__contains=RawSQL("ARRAY['USE']", []), + member_id__team_id=application.user_id, + member_id__user_id=user_id + ).values('target') + ) + all_applications = all_applications.union(shared_applications) + # 把应用的type为WORK_FLOW的应用放到最上面 然后再按名称排序 - serialized_data = ApplicationSerializerModel(application, many=True).data + serialized_data = ApplicationSerializerModel(all_applications, many=True).data application = sorted(serialized_data, key=lambda x: (x['type'] != 'WORK_FLOW', x['name'])) return list(application) + def get_application(self, app_id, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return ApplicationSerializer.Operate(data={'user_id': application.user_id, 'application_id': app_id}).one( + with_valid=True) + class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: model = ApplicationApiKey diff --git a/apps/application/urls.py b/apps/application/urls.py index f4a6f9321..6dc2ae5af 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -25,6 +25,8 @@ urlpatterns = [ path('application//function_lib/', views.Application.FunctionLib.Operate.as_view()), path('application//application', views.Application.Application.as_view()), + path('application//application/', + views.Application.Application.Operate.as_view()), path('application//model_params_form/', views.Application.ModelParamsForm.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index dd16e2c65..810d8d0a9 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -288,6 +288,25 @@ class Application(APIView): data={'application_id': application_id, 'user_id': request.user.id}).application_list()) + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取应用数据", + operation_id="获取应用数据", + tags=["应用"], + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str, app_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).get_application(app_id)) + class Profile(APIView): authentication_classes = [TokenAuth] @@ -433,7 +452,6 @@ class Application(APIView): ) @action(methods=['POST'], detail=False) - @swagger_auto_schema(operation_summary="创建应用", operation_id="创建应用", request_body=ApplicationApi.Create.get_request_body_api(), diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 2870d96b0..a85b031c0 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -349,6 +349,14 @@ const getFunctionLib: ( ) => Promise> = (application_id, function_lib_id, loading) => { return get(`${prefix}/${application_id}/function_lib/${function_lib_id}`, undefined, loading) } + +const getApplicationById: ( + application_id: String, + app_id: String, + loading?: Ref +) => Promise> = (application_id, app_id, loading) => { + return get(`${prefix}/${application_id}/application/${app_id}`, undefined, loading) +} /** * 获取模型参数表单 * @param application_id 应用id @@ -567,5 +575,6 @@ export default { getApplicationList, uploadFile, exportApplication, - importApplication + importApplication, + getApplicationById } diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index 6f388b0c1..0c7c84c98 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -79,7 +79,7 @@ () const form_data = computed({ @@ -174,6 +180,85 @@ const form_data = computed({ } }) +function handleFileUpload(type: string, isEnabled: boolean) { + const listKey = `${type}_list` + if (isEnabled) { + if (!props.nodeModel.properties.node_data[listKey]) { + set(props.nodeModel.properties.node_data, listKey, []) + } + } else { + // eslint-disable-next-line vue/no-mutating-props + delete props.nodeModel.properties.node_data[listKey] + } +} + +const update_field = () => { + if (!props.nodeModel.properties.node_data.application_id) { + set(props.nodeModel.properties, 'status', 500) + return + } + applicationApi + .getApplicationById(id, props.nodeModel.properties.node_data.application_id) + .then((ok) => { + const old_api_input_field_list = props.nodeModel.properties.node_data.api_input_field_list + const old_user_input_field_list = props.nodeModel.properties.node_data.user_input_field_list + + if (isWorkFlow(ok.data.type)) { + const nodeData = ok.data.work_flow.nodes[0].properties.node_data + const new_api_input_field_list = ok.data.work_flow.nodes[0].properties.api_input_field_list + const new_user_input_field_list = + ok.data.work_flow.nodes[0].properties.user_input_field_list + const merge_api_input_field_list = new_api_input_field_list.map((item: any) => { + const find_field = old_api_input_field_list.find( + (old_item: any) => old_item.variable == item.variable + ) + if (find_field) { + return { ...item, default_value: JSON.parse(JSON.stringify(find_field.default_value)) } + } else { + return item + } + }) + set( + props.nodeModel.properties.node_data, + 'api_input_field_list', + merge_api_input_field_list + ) + const merge_user_input_field_list = new_user_input_field_list.map((item: any) => { + const find_field = old_user_input_field_list.find( + (old_item: any) => old_item.field == item.field + ) + if (find_field) { + return { ...item, default_value: JSON.parse(JSON.stringify(find_field.default_value)) } + } else { + return item + } + }) + console.log(merge_user_input_field_list) + set( + props.nodeModel.properties.node_data, + 'user_input_field_list', + merge_user_input_field_list + ) + const fileUploadSetting = nodeData.file_upload_setting + // 如果是true,说明有文件上传 + if (fileUploadSetting) { + handleFileUpload('document', fileUploadSetting.document) + handleFileUpload('image', fileUploadSetting.image) + handleFileUpload('audio', fileUploadSetting.audio) + } else { + ;['document_list', 'image_list', 'audio_list'].forEach((list) => { + // eslint-disable-next-line vue/no-mutating-props + delete props.nodeModel.properties.node_data[list] + }) + } + set(props.nodeModel.properties, 'status', ok.data.id ? 200 : 500) + } + }) + .catch((err) => { + set(props.nodeModel.properties, 'status', 500) + }) +} + const props = defineProps<{ nodeModel: any }>() const validate = () => { @@ -183,6 +268,7 @@ const validate = () => { } onMounted(() => { + update_field() set(props.nodeModel, 'validate', validate) })