diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 8698f4627..98fa74bbe 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -299,7 +299,7 @@ class ApplicationSerializer(serializers.Serializer): if 'dataset_id_list' in instance: dataset_id_list = instance.get('dataset_id_list') # 当前用户可修改关联的数据集列表 - application_dataset_id_list = [dataset_dict.get('id') for dataset_dict in + application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in self.list_dataset(with_valid=False)] for dataset_id in dataset_id_list: if not application_dataset_id_list.__contains__(dataset_id): diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index c199f842f..680cb4ea7 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -17,7 +17,9 @@ from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers +from application.models import ApplicationDatasetMapping from common.db.search import get_dynamics_model, native_page_search, native_search +from common.db.sql_execute import select_list from common.event.listener_manage import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin @@ -55,6 +57,48 @@ class DataSetSerializers(serializers.ModelSerializer): model = DataSet fields = ['id', 'name', 'desc', 'create_time', 'update_time'] + class Application(ApiMixin, serializers.Serializer): + user_id = serializers.UUIDField(required=True) + + dataset_id = serializers.UUIDField(required=True) + + @staticmethod + def get_request_params_api(): + return [ + openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='数据集id') + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', + 'create_time', + 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), + 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), + "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", + description="是否开启多轮对话"), + 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), + 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="示例列表", description="示例列表"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"), + + 'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'), + + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'), + + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间') + } + ) + class Query(ApiMixin, serializers.Serializer): """ 查询对象 @@ -223,8 +267,15 @@ class DataSetSerializers(serializers.ModelSerializer): } ) + class Edit(serializers.Serializer): + + name = serializers.CharField(required=False) + desc = serializers.CharField(required=False) + application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + class Operate(ApiMixin, serializers.Serializer): id = serializers.CharField(required=True) + user_id = serializers.UUIDField(required=False) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) @@ -242,6 +293,14 @@ class DataSetSerializers(serializers.ModelSerializer): ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id')) return True + def list_application(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + dataset = QuerySet(DataSet).get(id=self.data.get("id")) + return select_list(get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset_application.sql')), + [self.data.get('user_id'), dataset.user_id, self.data.get('user_id')]) + def one(self, user_id, with_valid=True): if with_valid: self.is_valid() @@ -260,9 +319,15 @@ class DataSetSerializers(serializers.ModelSerializer): default=AuthOperate.USE) )})).filter( **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})} - - return native_search(query_set_dict, select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True) + all_application_list = [str(adm.get('id')) for adm in self.list_application(with_valid=False)] + return {**native_search(query_set_dict, select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True), + 'application_id_list': list( + filter(lambda application_id: all_application_list.__contains__(application_id), + [str(application_dataset_mapping.application_id) for + application_dataset_mapping in + QuerySet(ApplicationDatasetMapping).filter( + dataset_id=self.data.get('id'))]))} def edit(self, dataset: Dict, user_id: str): """ @@ -272,11 +337,32 @@ class DataSetSerializers(serializers.ModelSerializer): :return: """ self.is_valid() + DataSetSerializers.Edit(data=dataset).is_valid(raise_exception=True) _dataset = QuerySet(DataSet).get(id=self.data.get("id")) if "name" in dataset: _dataset.name = dataset.get("name") if 'desc' in dataset: _dataset.desc = dataset.get("desc") + if 'application_id_list' in dataset and dataset.get('application_id_list') is not None: + application_id_list = dataset.get('application_id_list') + # 当前用户可修改关联的数据集列表 + application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in + self.list_application(with_valid=False)] + for dataset_id in application_id_list: + if not application_dataset_id_list.__contains__(dataset_id): + raise AppApiException(500, f"未知的应用id${dataset_id},无法关联") + + # 删除已经关联的id + QuerySet(ApplicationDatasetMapping).filter(application_id__in=application_dataset_id_list, + dataset_id=self.data.get("id")).delete() + # 插入 + QuerySet(ApplicationDatasetMapping).bulk_create( + [ApplicationDatasetMapping(application_id=application_id, dataset_id=self.data.get('id')) for + application_id in + application_id_list]) if len(application_id_list) > 0 else None + [ApplicationDatasetMapping(application_id=application_id, dataset_id=self.data.get('id')) for + application_id in application_id_list] + _dataset.save() return self.one(with_valid=False, user_id=user_id) @@ -287,7 +373,10 @@ class DataSetSerializers(serializers.ModelSerializer): required=['name', 'desc'], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"), - 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述") + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"), + 'application_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="应用id列表", + description="应用id列表", + items=openapi.Schema(type=openapi.TYPE_STRING)) } ) diff --git a/apps/dataset/sql/list_dataset_application.sql b/apps/dataset/sql/list_dataset_application.sql new file mode 100644 index 000000000..9da36a3cf --- /dev/null +++ b/apps/dataset/sql/list_dataset_application.sql @@ -0,0 +1,20 @@ +SELECT + * +FROM + application +WHERE + user_id = %s UNION +SELECT + * +FROM + application +WHERE + "id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + WHERE + ( "team_member_permission"."auth_target_type" = 'APPLICATION' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s ) + ) \ No newline at end of file diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 1c6e05c57..e1ef1b154 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -6,6 +6,7 @@ app_name = "dataset" urlpatterns = [ path('dataset', views.Dataset.as_view(), name="dataset"), path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), + path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), path('dataset//document', views.Document.as_view(), name='document'), path('dataset//document/_bach', views.Document.Batch.as_view()), diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 20cd37321..89b1e2e1c 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -22,6 +22,20 @@ from dataset.serializers.dataset_serializers import DataSetSerializers class Dataset(APIView): authentication_classes = [TokenAuth] + class Application(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取数据集可用应用列表", + operation_id="获取数据集可用应用列表", + manual_parameters=DataSetSerializers.Application.get_request_params_api(), + responses=result.get_api_array_response( + DataSetSerializers.Application.get_response_body_api()), + tags=["数据集"]) + def get(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.Operate( + data={'id': dataset_id, 'user_id': str(request.user.id)}).list_application()) + @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取数据集列表", operation_id="获取数据集列表", @@ -71,7 +85,8 @@ class Dataset(APIView): @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) def get(self, request: Request, dataset_id: str): - return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id)) + return result.success(DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).one( + user_id=request.user.id)) @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", @@ -84,7 +99,8 @@ class Dataset(APIView): dynamic_tag=keywords.get('dataset_id'))) def put(self, request: Request, dataset_id: str): return result.success( - DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id)) + DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).edit(request.data, + user_id=request.user.id)) class Page(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index a75b0a254..11f5ca543 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -2,6 +2,7 @@ import { Result } from '@/request/Result' import { get, post, del, put } from '@/request/index' import type { datasetData } from '@/api/type/dataset' import type { pageRequest } from '@/api/type/common' +import type { ApplicationFormType } from '@/api/type/application' import { type Ref } from 'vue' const prefix = '/dataset' @@ -88,12 +89,24 @@ const putDateset: (dataset_id: string, data: any) => Promise> = ( ) => { return put(`${prefix}/${dataset_id}`, data) } - +/** + * 获取数据集 可关联的应用列表 + * @param dataset_id + * @param loading + * @returns + */ +const listUsableApplication: ( + dataset_id: string, + loading?: Ref +) => Promise>> = (dataset_id, loading) => { + return get(`${prefix}/${dataset_id}/application`, {}, loading) +} export default { getDateset, getAllDateset, delDateset, postDateset, getDatesetDetail, - putDateset + putDateset, + listUsableApplication } diff --git a/ui/src/views/dataset/component/BaseForm.vue b/ui/src/views/dataset/component/BaseForm.vue index bc1a320dd..ae88a2e8a 100644 --- a/ui/src/views/dataset/component/BaseForm.vue +++ b/ui/src/views/dataset/component/BaseForm.vue @@ -25,25 +25,37 @@ :autosize="{ minRows: 3 }" /> + + + + + {{ item.name }} + + + +