diff --git a/apps/common/util/split_model.py b/apps/common/util/split_model.py index 7f391468c..8e305bb27 100644 --- a/apps/common/util/split_model.py +++ b/apps/common/util/split_model.py @@ -13,7 +13,7 @@ from typing import List import jieba -def get_level_block(text, level_content_list, level_content_index): +def get_level_block(text, level_content_list, level_content_index, cursor): """ 从文本中获取块数据 :param text: 文本 @@ -24,9 +24,10 @@ def get_level_block(text, level_content_list, level_content_index): start_content: str = level_content_list[level_content_index].get('content') next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len( level_content_list) else None - start_index = text.index(start_content) - end_index = text.index(next_content) if next_content is not None else len(text) - return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "") + print(len(text), cursor, start_content) + start_index = text.index(start_content, cursor) + end_index = text.index(next_content, start_index + 1) if next_content is not None else len(text) + return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], ""), end_index def to_tree_obj(content, state='title'): @@ -297,8 +298,9 @@ class SplitModel: if len(self.content_level_pattern) == index: return level_content_list = parse_title_level(text, self.content_level_pattern, index) + cursor = 0 for i in range(len(level_content_list)): - block = get_level_block(text, level_content_list, i) + block, cursor = get_level_block(text, level_content_list, i, cursor) children = self.parse_to_tree(text=block, index=index + 1) if children is not None and len(children) > 0: @@ -317,6 +319,11 @@ class SplitModel: level_content_list = [*level_content_list, *list( map(lambda row: to_tree_obj(row, 'block'), post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))] + else: + if len(text.strip()) > 0: + level_content_list = [*level_content_list, *list( + map(lambda row: to_tree_obj(row, 'block'), + post_handler_paragraph(text, with_filter=self.with_filter, limit=self.limit)))] return level_content_list def parse(self, text: str): @@ -329,25 +336,29 @@ class SplitModel: return result_tree_to_paragraph(result_tree, [], []) -split_model_map = { - 'md': SplitModel( - [re.compile("^# .*"), re.compile('(? 1024 * 1024 * 10: @@ -282,8 +282,21 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): def parse(self): file_list = self.data.get("file") - return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"), - self.data.get("limit")), file_list)) + return list( + map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None), + self.data.get("limit", None)), file_list)) + + class SplitPattern(ApiMixin, serializers.Serializer): + @staticmethod + def list(): + return [{'key': "#", 'value': '^# .*'}, {'key': '##', 'value': '(? 0: + if pattern_list is not None and len(pattern_list) > 0: split_model = SplitModel(pattern_list, with_filter, limit) else: - split_model = get_split_model(file.name) + split_model = get_split_model(file.name, with_filter=with_filter, limit=limit) try: content = data.decode('utf-8') except BaseException as e: diff --git a/apps/dataset/sql/update_document_char_length.sql b/apps/dataset/sql/update_document_char_length.sql new file mode 100644 index 000000000..a09c8cabb --- /dev/null +++ b/apps/dataset/sql/update_document_char_length.sql @@ -0,0 +1,4 @@ +UPDATE "document" +SET "char_length" = ( SELECT "sum" ( "char_length" ( "content" ) ) FROM paragraph WHERE "document_id" = %s ) +WHERE + "id" = %s \ No newline at end of file diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index bc75c1b35..2bff2ae81 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -13,6 +13,8 @@ urlpatterns = [ name="document_operate"), path('dataset/document/split', views.Document.Split.as_view(), name="document_operate"), + path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(), + name="document_operate"), path('dataset//document//paragraph', views.Paragraph.as_view()), path('dataset//document//paragraph//', views.Paragraph.Page.as_view(), name='paragraph_page'), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index 9e9153e42..382d5cc43 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -118,6 +118,15 @@ class Document(APIView): operate.is_valid(raise_exception=True) return result.success(operate.delete()) + class SplitPattern(APIView): + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取分段标识列表", + operation_id="获取分段标识列表", + tags=["数据集/文档"], + security=[]) + def get(self, request: Request): + return result.success(DocumentSerializers.SplitPattern.list()) + class Split(APIView): parser_classes = [MultiPartParser] @@ -128,9 +137,17 @@ class Document(APIView): tags=["数据集/文档"], security=[]) def post(self, request: Request): + split_data = {'file': request.FILES.getlist('file')} + request_data = request.data + if 'patterns' in request.data and request.data.get('patterns') is not None and len( + request.data.get('patterns')) > 0: + split_data.__setitem__('patterns', request_data.getlist('patterns')) + if 'limit' in request.data: + split_data.__setitem__('limit', request_data.get('limit')) + if 'with_filter' in request.data: + split_data.__setitem__('with_filter', request_data.get('with_filter')) ds = DocumentSerializers.Split( - data={'file': request.FILES.getlist('file'), - 'patterns': request.data.getlist('patterns[]')}) + data=split_data) ds.is_valid(raise_exception=True) return result.success(ds.parse()) diff --git a/apps/setting/serializers/team_serializers.py b/apps/setting/serializers/team_serializers.py index af4b2e34c..b0c770a7f 100644 --- a/apps/setting/serializers/team_serializers.py +++ b/apps/setting/serializers/team_serializers.py @@ -10,9 +10,10 @@ import itertools import json import os import uuid -from typing import Dict +from typing import Dict, List from django.core import cache +from django.db import transaction from django.db.models import QuerySet, Q from drf_yasg import openapi from rest_framework import serializers @@ -141,12 +142,22 @@ class UpdateTeamMemberPermissionSerializer(ApiMixin, serializers.Serializer): class TeamMemberSerializer(ApiMixin, serializers.Serializer): - team_id = serializers.CharField(required=True) + team_id = serializers.UUIDField(required=True) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - def get_request_body_api(self): + @staticmethod + def get_bach_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_ARRAY, + title="用户id列表", + description="用户id列表", + items=openapi.Schema(type=openapi.TYPE_STRING) + ) + + @staticmethod + def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, required=['username_or_email'], @@ -157,6 +168,37 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer): } ) + @transaction.atomic + def batch_add_member(self, user_id_list: List[str], with_valid=True): + """ + 批量添加成员 + :param user_id_list: 用户id列表 + :param with_valid: 是否校验 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + use_user_id_list = [str(u.id) for u in QuerySet(User).filter(id__in=user_id_list)] + + team_member_user_id_list = [str(team_member.user_id) for team_member in + QuerySet(TeamMember).filter(team_id=self.data.get('team_id'))] + team_id = self.data.get("team_id") + create_team_member_list = [ + self.to_member_model(add_user_id, team_member_user_id_list, use_user_id_list, team_id) for add_user_id in + user_id_list] + QuerySet(TeamMember).bulk_create(create_team_member_list) if len(create_team_member_list) > 0 else None + return TeamMemberSerializer( + data={'team_id': self.data.get("team_id")}).list_member() + + def to_member_model(self, add_user_id, team_member_user_id_list, use_user_id_list, user_id): + if use_user_id_list.__contains__(add_user_id): + if team_member_user_id_list.__contains__(add_user_id) or user_id == add_user_id: + raise AppApiException(500, "团队中已存在当前成员,不要重复添加") + else: + return TeamMember(team_id=self.data.get("team_id"), user_id=add_user_id) + else: + raise AppApiException(500, "不存在的用户") + def add_member(self, username_or_email: str, with_valid=True): """ 添加一个成员 @@ -172,10 +214,11 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer): Q(username=username_or_email) | Q(email=username_or_email)).first() if user is None: raise AppApiException(500, "不存在的用户") - if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists(): + if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists() or self.data.get( + "team_id") == str(user.id): raise AppApiException(500, "团队中已存在当前成员,不要重复添加") TeamMember(team_id=self.data.get("team_id"), user=user).save() - return TeamMemberSerializer(data={'team_id': self.data.get("team_id")}).list_member() + return self.list_member(with_valid=False) def list_member(self, with_valid=True): """ diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 44b408c50..5501e9d2f 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -5,6 +5,7 @@ from . import views app_name = "team" urlpatterns = [ path('team/member', views.TeamMember.as_view(), name="team"), + path('team/member/_batch', views.TeamMember.Batch.as_view()), path('team/member/', views.TeamMember.Operate.as_view(), name='member'), path('provider//', views.Provide.Exec.as_view(), name='provide_exec'), path('provider', views.Provide.as_view(), name='provide'), diff --git a/apps/setting/views/Team.py b/apps/setting/views/Team.py index 30b884527..71710e3d6 100644 --- a/apps/setting/views/Team.py +++ b/apps/setting/views/Team.py @@ -40,6 +40,19 @@ class TeamMember(APIView): team = TeamMemberSerializer(data={'team_id': str(request.user.id)}) return result.success((team.add_member(**request.data))) + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量添加成员", + operation_id="批量添加成员", + request_body=TeamMemberSerializer.get_bach_request_body_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_CREATE) + def post(self, request: Request): + return result.success( + TeamMemberSerializer(data={'team_id': request.user.id}).batch_add_member(request.data)) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index 4313dd670..1259e202b 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -411,8 +411,9 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer): def get_response_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['username', 'email', ], + required=['username', 'email', 'id'], properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title='用户主键id', description="用户主键id"), 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址") } @@ -422,5 +423,5 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer): if with_valid: self.is_valid(raise_exception=True) email_or_username = self.data.get('email_or_username') - return [{'username': user_model.username, 'email': user_model.email} for user_model in + return [{'id': user_model.id, 'username': user_model.username, 'email': user_model.email} for user_model in QuerySet(User).filter(Q(username=email_or_username) | Q(email=email_or_username))] diff --git a/apps/users/views/user.py b/apps/users/views/user.py index e96db337b..76226aaa4 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -17,7 +17,7 @@ from rest_framework.views import Request from common.auth.authenticate import TokenAuth from common.auth.authentication import has_permissions -from common.constants.permission_constants import PermissionConstants, CompareConstants +from common.constants.permission_constants import PermissionConstants from common.response import result from smartdoc.settings import JWT_AUTH from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ @@ -36,7 +36,7 @@ class User(APIView): operation_id="获取当前用户信息", responses=result.get_api_response(UserProfile.get_response_body_api()), tags=['用户']) - @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.USER_READ) def get(self, request: Request): return result.success(UserProfile.get_user_profile(request.user)) @@ -49,7 +49,7 @@ class User(APIView): manual_parameters=UserSerializer.Query.get_request_params_api(), responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()), tags=['用户']) - @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.USER_READ) def get(self, request: Request): return result.success( UserSerializer.Query(data={'email_or_username': request.query_params.get('email_or_username')}).list()) diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index 5254a0b86..322017a08 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -1,7 +1,8 @@ import { Result } from '@/request/Result' import { get, post, del, put } from '@/request/index' import type { datasetListRequest, datasetData } from '@/api/type/dataset' - +import type { Ref } from 'vue' +import type { KeyValue } from '@/api/type/common' const prefix = '/dataset' /** @@ -96,6 +97,17 @@ const postSplitDocument: (data: any) => Promise> = (data) => { return post(`${prefix}/document/split`, data) } +/** + * 分段标识列表 + * @param loading 加载器 + * @returns 分段标识列表 + */ +const listSplitPattern: (loading?: Ref) => Promise>> = ( + loading +) => { + return get(`${prefix}/document/split_pattern`, {}, loading) +} + /** * 文档列表 * @param 参数 dataset_id, name @@ -249,11 +261,11 @@ const putParagraph: ( * 问题列表 * @param 参数 dataset_id,document_id,paragraph_id */ -const getProblem: (dataset_id: string, document_id: string, paragraph_id: string) => Promise> = ( - dataset_id, - document_id, - paragraph_id: string, -) => { +const getProblem: ( + dataset_id: string, + document_id: string, + paragraph_id: string +) => Promise> = (dataset_id, document_id, paragraph_id: string) => { return get(`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem`) } @@ -285,9 +297,11 @@ const delProblem: ( dataset_id: string, document_id: string, paragraph_id: string, - problem_id: string, -) => Promise> = (dataset_id, document_id, paragraph_id,problem_id) => { - return del(`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`) + problem_id: string +) => Promise> = (dataset_id, document_id, paragraph_id, problem_id) => { + return del( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}` + ) } export default { @@ -309,5 +323,6 @@ export default { postParagraph, getProblem, postProblem, - delProblem + delProblem, + listSplitPattern } diff --git a/ui/src/api/type/common.ts b/ui/src/api/type/common.ts new file mode 100644 index 000000000..32f883255 --- /dev/null +++ b/ui/src/api/type/common.ts @@ -0,0 +1,9 @@ +interface KeyValue { + key: K + value: V +} +interface Dict { + [propName: string]: V +} + +export type { KeyValue, Dict } diff --git a/ui/src/views/dataset/step/StepSecond.vue b/ui/src/views/dataset/step/StepSecond.vue index e7c56609c..d4430cbbc 100644 --- a/ui/src/views/dataset/step/StepSecond.vue +++ b/ui/src/views/dataset/step/StepSecond.vue @@ -30,12 +30,12 @@ - + @@ -47,7 +47,7 @@ v-model="form.limit" show-input :show-input-controls="false" - :min="10" + :min="50" :max="1024" /> @@ -80,22 +80,19 @@