feat: Application import and export (#1836)

This commit is contained in:
shaohuzhang1 2024-12-16 14:19:57 +08:00 committed by GitHub
parent 390014fa1b
commit 64443ee136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 260 additions and 24 deletions

View File

@ -10,6 +10,7 @@ import datetime
import hashlib
import json
import os
import pickle
import re
import uuid
from functools import reduce
@ -19,10 +20,10 @@ from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from django.http import HttpResponse
from django.template import Template, Context
from rest_framework import serializers
from rest_framework import serializers, status
from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
@ -34,15 +35,17 @@ from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.field.common import UploadedImageField, UploadedFileField
from common.models.db_model_manage import DBModelManage
from common.response import result
from common.util.common import valid_license, password_encrypt
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer
from function_lib.models.function import FunctionLib, PermissionType
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer, FunctionLibModelSerializer
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
@ -54,6 +57,13 @@ from users.models import User
chat_cache = cache.caches['chat_cache']
class MKInstance:
def __init__(self, application: dict, function_lib_list: List[dict], version: str):
self.application = application
self.function_lib_list = function_lib_list
self.version = version
class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
@ -662,6 +672,72 @@ class ApplicationSerializer(serializers.Serializer):
get_application_access_token(application_access_token.access_token, False)
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}
class Import(serializers.Serializer):
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@valid_license(model=Application, count=5,
message='社区版最多支持 5 个应用如需拥有更多应用请联系我们https://fit2cloud.com/)。')
@transaction.atomic
def import_(self, with_valid=True):
if with_valid:
self.is_valid()
user_id = self.data.get('user_id')
mk_instance_bytes = self.data.get('file').read()
mk_instance = pickle.loads(mk_instance_bytes)
application = mk_instance.application
function_lib_list = mk_instance.function_lib_list
if len(function_lib_list) > 0:
function_lib_id_list = [function_lib.get('id') for function_lib in function_lib_list]
exits_function_lib_id_list = [str(function_lib.id) for function_lib in
QuerySet(FunctionLib).filter(id__in=function_lib_id_list)]
# 获取到需要插入的函数
function_lib_list = [function_lib for function_lib in function_lib_list if
not exits_function_lib_id_list.__contains__(function_lib.get('id'))]
application_model = self.to_application(application, user_id)
function_lib_model_list = [self.to_function_lib(f, user_id) for f in function_lib_list]
application_model.save()
QuerySet(FunctionLib).bulk_create(function_lib_model_list) if len(function_lib_model_list) > 0 else None
return True
@staticmethod
def to_application(application, user_id):
work_flow = application.get('work_flow')
for node in work_flow.get('nodes', []):
if node.get('type') == 'search-dataset-node':
node.get('properties', {}).get('node_data', {})['dataset_id_list'] = []
return Application(id=uuid.uuid1(), user_id=user_id, name=application.get('name'),
desc=application.get('desc'),
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
dataset_setting=application.get('dataset_setting'),
model_params_setting=application.get('model_params_setting'),
tts_model_params_setting=application.get('tts_model_params_setting'),
problem_optimization=application.get('problem_optimization'),
icon=application.get('icon'),
work_flow=work_flow,
type=application.get('type'),
problem_optimization_prompt=application.get('problem_optimization_prompt'),
tts_model_enable=application.get('tts_model_enable'),
stt_model_enable=application.get('stt_model_enable'),
tts_type=application.get('tts_type'),
clean_time=application.get('clean_time'),
file_upload_enable=application.get('file_upload_enable'),
file_upload_setting=application.get('file_upload_setting'),
)
@staticmethod
def to_function_lib(function_lib, user_id):
"""
@param user_id: 用户id
@param function_lib: 函数库
@return:
"""
return FunctionLib(id=function_lib.get('id'), user_id=user_id, name=function_lib.get('name'),
code=function_lib.get('code'), input_field_list=function_lib.get('input_field_list'),
is_active=function_lib.get('is_active'),
permission_type=PermissionType.PRIVATE)
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -708,6 +784,31 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True
def export(self, with_valid=True):
try:
if with_valid:
self.is_valid()
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
function_lib_id_list = [node.get('properties', {}).get('node_data', {}).get('function_lib_id') for node
in
application.work_flow.get('nodes', []) if
node.get('type') == 'function-lib-node']
function_lib_list = []
if len(function_lib_id_list) > 0:
function_lib_list = QuerySet(FunctionLib).filter(id__in=function_lib_id_list)
application_dict = ApplicationSerializerModel(application).data
mk_instance = MKInstance(application_dict,
[FunctionLibModelSerializer(function_lib).data for function_lib in
function_lib_list], 'v1')
application_pickle = pickle.dumps(mk_instance)
response = HttpResponse(content_type='text/plain', content=application_pickle)
response['Content-Disposition'] = f'attachment; filename="{application.name}.mk"'
return response
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@transaction.atomic
def publish(self, instance, with_valid=True):
if with_valid:

View File

@ -336,6 +336,27 @@ class ApplicationApi(ApiMixin):
description='应用描述')
]
class Export(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
]
class Import(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_FILE,
required=True,
description='上传图片文件')
]
class Operate(ApiMixin):
@staticmethod
def get_request_params_api():

View File

@ -5,11 +5,13 @@ from . import views
app_name = "application"
urlpatterns = [
path('application', views.Application.as_view(), name="application"),
path('application/import', views.Application.Import.as_view()),
path('application/profile', views.Application.Profile.as_view(), name='application/profile'),
path('application/embed', views.Application.Embed.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/publish', views.Application.Publish.as_view()),
path('application/<str:application_id>/edit_icon', views.Application.EditIcon.as_view()),
path('application/<str:application_id>/export', views.Application.Export.as_view()),
path('application/<str:application_id>/statistics/customer_count',
views.ApplicationStatistics.CustomerCount.as_view()),
path('application/<str:application_id>/statistics/customer_count_trend',

View File

@ -27,7 +27,6 @@ from common.response import result
from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers
from setting.swagger_api.provide_api import ProvideApi
chat_cache = cache.caches['chat_cache']
@ -158,6 +157,34 @@ class Application(APIView):
data={'application_id': application_id, 'user_id': request.user.id,
'image': request.FILES.get('file')}).edit(request.data))
class Import(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导入应用", operation_id="导入应用",
manual_parameters=ApplicationApi.Import.get_request_params_api(),
tags=["应用"]
)
@has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
def post(self, request: Request):
return result.success(ApplicationSerializer.Import(
data={'user_id': request.user.id, 'file': request.FILES.get('file')}).import_())
class Export(APIView):
authentication_classes = [TokenAuth]
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导出应用", operation_id="导出应用",
manual_parameters=ApplicationApi.Export.get_request_params_api(),
tags=["应用"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
dynamic_tag=keywords.get('application_id')))
def get(self, request: Request, application_id: str):
return ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).export()
class Embed(APIView):
@action(methods=["GET"], detail=False)
@swagger_auto_schema(operation_summary="获取嵌入js",
@ -362,7 +389,8 @@ class Application(APIView):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data))
ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(
request.data))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取应用 AccessToken信息",
@ -382,9 +410,10 @@ class Application(APIView):
class Authentication(APIView):
@action(methods=['OPTIONS'], detail=False)
def options(self, request, *args, **kwargs):
return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )
return HttpResponse(
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, )
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="应用认证",
@ -404,6 +433,7 @@ class Application(APIView):
)
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建应用",
operation_id="创建应用",
request_body=ApplicationApi.Create.get_request_body_api(),
@ -444,7 +474,8 @@ class Application(APIView):
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
'search_mode': request.query_params.get(
'search_mode')}).hit_test(
))
class Publish(APIView):
@ -502,7 +533,8 @@ class Application(APIView):
compare=CompareConstants.AND))
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id}).edit(
request.data))
@action(methods=['GET'], detail=False)
@ -528,11 +560,14 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取当前应用可使用的知识库",
operation_id="获取当前应用可使用的知识库",
manual_parameters=ApplicationApi.Operate.get_request_params_api(),
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
responses=result.get_api_array_response(
DataSetSerializers.Query.get_response_body_api()),
tags=['应用'])
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
return result.success(ApplicationSerializer.Operate(

View File

@ -157,10 +157,10 @@ def success(data, **kwargs):
return Result(data=data, **kwargs)
def error(message):
def error(message, **kwargs):
"""
获取一个失败的响应对象
:param message: 错误提示
:return: 接口响应对象
"""
return Result(code=500, message=message)
return Result(code=500, message=message, **kwargs)

View File

@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put, request, download } from '@/request/index'
import { get, post, postStream, del, put, request, download, exportFile } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
@ -300,7 +300,6 @@ const getApplicationTTIModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
}
/**
*
* @param
@ -377,7 +376,6 @@ const uploadFile: (
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
}
/**
*
*/
@ -503,6 +501,28 @@ const getUserList: (type: string, loading?: Ref<boolean>) => Promise<Result<any>
return get(`/user/list/${type}`, undefined, loading)
}
const exportApplication = (
application_id: string,
application_name: string,
loading?: Ref<boolean>
) => {
return exportFile(
application_name + '.mk',
`/application/${application_id}/export`,
undefined,
loading
)
}
/**
*
*/
const importApplication: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
data,
loading
) => {
return post(`${prefix}/import`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
@ -544,5 +564,7 @@ export default {
playDemoText,
getUserList,
getApplicationList,
uploadFile
uploadFile,
exportApplication,
importApplication
}

View File

@ -227,7 +227,6 @@ export const exportExcel: (
) => {
return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then(
(res: any) => {
console.log(res)
if (res) {
const blob = new Blob([res], {
type: 'application/vnd.ms-excel'
@ -244,6 +243,35 @@ export const exportExcel: (
)
}
export const exportFile: (
fileName: string,
url: string,
params: any,
loading?: NProgress | Ref<boolean>
) => Promise<any> = (
fileName: string,
url: string,
params: any,
loading?: NProgress | Ref<boolean>
) => {
return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading).then(
(res: any) => {
if (res) {
const blob = new Blob([res], {
type: 'application/octet-stream'
})
const link = document.createElement('a')
link.href = window.URL.createObjectURL(blob)
link.download = fileName
link.click()
//释放内存
window.URL.revokeObjectURL(link.href)
}
return true
}
)
}
export const exportExcelPost: (
fileName: string,
url: string,

View File

@ -3,6 +3,18 @@
<div class="flex-between mb-16">
<h4>{{ $t('views.application.applicationList.title') }}</h4>
<div class="flex-between">
<el-upload
:file-list="[]"
class="flex-between"
action="#"
multiple
:auto-upload="false"
:show-file-list="false"
:limit="1"
:on-change="(file: any, fileList: any) => importApplication(file)"
>
<el-button>导入应用</el-button>
</el-upload>
<el-select
v-model="selectUserId"
class="mr-12"
@ -128,7 +140,9 @@
<AppIcon iconName="app-copy"></AppIcon>
复制</el-dropdown-item
>
<el-dropdown-item icon="Delete" @click.stop="exportApplication(item)">
导出
</el-dropdown-item>
<el-dropdown-item icon="Delete" @click.stop="deleteApplication(item)">{{
$t('views.application.applicationList.card.delete.tooltip')
}}</el-dropdown-item>
@ -152,7 +166,7 @@ import { ref, onMounted, reactive } from 'vue'
import applicationApi from '@/api/application'
import CreateApplicationDialog from './component/CreateApplicationDialog.vue'
import CopyApplicationDialog from './component/CopyApplicationDialog.vue'
import { MsgSuccess, MsgConfirm, MsgAlert } from '@/utils/message'
import { MsgSuccess, MsgConfirm, MsgAlert, MsgError } from '@/utils/message'
import { isAppIcon } from '@/utils/application'
import { useRouter } from 'vue-router'
import { isWorkFlow } from '@/utils/application'
@ -203,7 +217,20 @@ function settingApplication(row: any) {
router.push({ path: `/application/${row.id}/${row.type}/setting` })
}
}
const exportApplication = (application: any) => {
applicationApi.exportApplication(application.id, application.name, loading).catch((e) => {
e.response.data.text().then((res: string) => {
MsgError(`导出失败:${JSON.parse(res).message}`)
})
})
}
const importApplication = (file: any) => {
const formData = new FormData()
formData.append('file', file.raw, file.name)
applicationApi.importApplication(formData, loading).then((ok) => {
searchHandle()
})
}
function openCreateDialog() {
if (user.isEnterprise()) {
CreateApplicationDialogRef.value.open()