feat: batch add team members; advanced split patterns for document segmentation

shaohuzhang1 2023-11-20 18:53:18 +08:00
parent 2032c83ba1
commit 78a9697f50
14 changed files with 212 additions and 57 deletions

View File

@ -13,7 +13,7 @@ from typing import List
import jieba
def get_level_block(text, level_content_list, level_content_index):
def get_level_block(text, level_content_list, level_content_index, cursor):
"""
Fetch a block of data from the text
:param text: source text
@ -24,9 +24,10 @@ def get_level_block(text, level_content_list, level_content_index):
start_content: str = level_content_list[level_content_index].get('content')
next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len(
level_content_list) else None
start_index = text.index(start_content)
end_index = text.index(next_content) if next_content is not None else len(text)
return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], "")
print(len(text), cursor, start_content)
start_index = text.index(start_content, cursor)
end_index = text.index(next_content, start_index + 1) if next_content is not None else len(text)
return text[start_index:end_index].replace(level_content_list[level_content_index]['content'], ""), end_index
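The new cursor parameter matters when the same heading text appears more than once in the document: without it, text.index(start_content) always resolves to the first occurrence, so every later duplicate would re-emit the first block. A minimal standalone sketch of the idea (the sample text and headings are hypothetical, not from this commit):

# Hypothetical sample: the heading "## Install" occurs twice.
text = "## Install\nfoo\n\n## Usage\nbar\n\n## Install\nbaz\n"
first = text.index("## Install")               # always 0 without a cursor
second = text.index("## Install", first + 1)   # found only by searching past the first block
print(first, second)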
def to_tree_obj(content, state='title'):
@ -297,8 +298,9 @@ class SplitModel:
if len(self.content_level_pattern) == index:
return
level_content_list = parse_title_level(text, self.content_level_pattern, index)
cursor = 0
for i in range(len(level_content_list)):
block = get_level_block(text, level_content_list, i)
block, cursor = get_level_block(text, level_content_list, i, cursor)
children = self.parse_to_tree(text=block,
index=index + 1)
if children is not None and len(children) > 0:
@ -317,6 +319,11 @@ class SplitModel:
level_content_list = [*level_content_list, *list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(other_content, with_filter=self.with_filter, limit=self.limit)))]
else:
if len(text.strip()) > 0:
level_content_list = [*level_content_list, *list(
map(lambda row: to_tree_obj(row, 'block'),
post_handler_paragraph(text, with_filter=self.with_filter, limit=self.limit)))]
return level_content_list
def parse(self, text: str):
@ -329,25 +336,29 @@ class SplitModel:
return result_tree_to_paragraph(result_tree, [], [])
split_model_map = {
'md': SplitModel(
[re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<! )- .*")]),
'default': SplitModel([re.compile("(?<!\n)\n\n.+")])
default_split_pattern = {
'md': [re.compile("^# .*"), re.compile('(?<!#)## (?!#).*'), re.compile("(?<!#)### (?!#).*"),
re.compile("(?<!#)####(?!#).*"), re.compile("(?<!#)#####(?!#).*"),
re.compile("(?<!#)######(?!#).*"),
re.compile("(?<! )- .*")],
'default': [re.compile("(?<!\n)\n\n.+")]
}
def get_split_model(filename: str):
def get_split_model(filename: str, with_filter: bool, limit: int):
"""
Get the split model based on the file name
:param limit: maximum size of each segment
:param with_filter: whether to filter out special characters
:param filename: file name
:return: split model
"""
if filename.endswith(".md"):
return split_model_map.get('md')
return split_model_map.get("default")
pattern_list = default_split_pattern.get('md')
return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
pattern_list = default_split_pattern.get('default')
return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
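After this refactor the caller still picks the pattern set by file extension, but now controls filtering and segment size itself. A rough usage sketch, executed in this module's own context (the file name and text are hypothetical; get_split_model and SplitModel.parse are the functions shown in this file):

# Hypothetical call site: a ".md" name selects the Markdown pattern list,
# anything else falls back to the blank-line pattern.
model = get_split_model("handbook.md", with_filter=True, limit=256)
paragraphs = model.parse("# Title\n\nSome body text.\n")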
def to_title_tree_string(result_tree: List):

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: common_serializers.py
@date: 2023/11/17 11:00
@desc:
"""
import os
from common.db.sql_execute import update_execute
from common.util.file_util import get_file_content
from smartdoc.conf import PROJECT_DIR
def update_document_char_length(document_id: str):
update_execute(get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),
(document_id, document_id))

View File

@ -250,7 +250,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
with_filter = serializers.BooleanField(required=False)
def is_valid(self, *, raise_exception=True):
super().is_valid()
super().is_valid(raise_exception=True)
files = self.data.get('file')
for f in files:
if f.size > 1024 * 1024 * 10:
@ -282,8 +282,21 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
def parse(self):
file_list = self.data.get("file")
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
self.data.get("limit")), file_list))
return list(
map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
self.data.get("limit", None)), file_list))
class SplitPattern(ApiMixin, serializers.Serializer):
@staticmethod
def list():
return [{'key': "#", 'value': '^# .*'}, {'key': '##', 'value': '(?<!#)## (?!#).*'},
{'key': '###', 'value': "(?<!#)### (?!#).*"}, {'key': '####', 'value': "(?<!#)####(?!#).*"},
{'key': '#####', 'value': "(?<!#)#####(?!#).*"}, {'key': '######', 'value': "(?<!#)######(?!#).*"},
{'key': '-', 'value': '(?<! )- .*'},
{'key': '空格', 'value': '(?<!\\s)\\s(?!\\s)'},
{'key': '分号', 'value': '(?<!;);(?!;)'}, {'key': '逗号', 'value': '(?<!,),(?!,)'},
{'key': '句号', 'value': '(?<!。)。(?!。)'}, {'key': '回车', 'value': '(?<!\\n)\\n(?!\\n)'},
{'key': '空行', 'value': '(?<!\\n)\\n\\n(?!\\n)'}]
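Each entry pairs a display key with a regular-expression value; the values a user selects are what arrive as pattern_list in file_to_paragraph below. Whether SplitModel compiles raw strings itself is not visible in this diff, so a cautious sketch compiles them first (the selection is hypothetical):

import re

# Hypothetical selection of two patterns from SplitPattern.list(); the compiled
# objects match the form that default_split_pattern hands to SplitModel.
selected = ['^# .*', '(?<!#)## (?!#).*']
pattern_list = [re.compile(value) for value in selected]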
class Batch(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
@ -302,12 +315,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
for instance in instance_list]
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
data = file.read()
if pattern_list is None or len(pattern_list) > 0:
if pattern_list is not None and len(pattern_list) > 0:
split_model = SplitModel(pattern_list, with_filter, limit)
else:
split_model = get_split_model(file.name)
split_model = get_split_model(file.name, with_filter=with_filter, limit=limit)
try:
content = data.decode('utf-8')
except BaseException as e:

View File

@ -0,0 +1,4 @@
UPDATE "document"
SET "char_length" = ( SELECT "sum" ( "char_length" ( "content" ) ) FROM paragraph WHERE "document_id" = %s )
WHERE
"id" = %s

View File

@ -13,6 +13,8 @@ urlpatterns = [
name="document_operate"),
path('dataset/document/split', views.Document.Split.as_view(),
name="document_operate"),
path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(),
name="document_operate"),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),

View File

@ -118,6 +118,15 @@ class Document(APIView):
operate.is_valid(raise_exception=True)
return result.success(operate.delete())
class SplitPattern(APIView):
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取分段标识列表",
operation_id="获取分段标识列表",
tags=["数据集/文档"],
security=[])
def get(self, request: Request):
return result.success(DocumentSerializers.SplitPattern.list())
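A hedged way to exercise the new read-only endpoint (the route comes from urls.py above; host, API prefix and any auth header required by the deployment are assumptions):

import requests

# Hypothetical dev host and prefix; add whatever auth header the deployment requires.
resp = requests.get("http://localhost:8080/api/dataset/document/split_pattern")
print(resp.json())  # expected: the key/value list returned by SplitPattern.list()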
class Split(APIView):
parser_classes = [MultiPartParser]
@ -128,9 +137,17 @@ class Document(APIView):
tags=["数据集/文档"],
security=[])
def post(self, request: Request):
split_data = {'file': request.FILES.getlist('file')}
request_data = request.data
if 'patterns' in request.data and request.data.get('patterns') is not None and len(
request.data.get('patterns')) > 0:
split_data.__setitem__('patterns', request_data.getlist('patterns'))
if 'limit' in request.data:
split_data.__setitem__('limit', request_data.get('limit'))
if 'with_filter' in request.data:
split_data.__setitem__('with_filter', request_data.get('with_filter'))
ds = DocumentSerializers.Split(
data={'file': request.FILES.getlist('file'),
'patterns': request.data.getlist('patterns[]')})
data=split_data)
ds.is_valid(raise_exception=True)
return result.success(ds.parse())
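The reworked handler only forwards patterns, limit and with_filter when the client actually sent them, so a bare file upload still works. A hedged multipart example (URL, file and auth handling are assumptions; the field names mirror what the view reads, and requests repeats a form field per list item, matching request_data.getlist('patterns')):

import requests

# Hypothetical file and host; the 'patterns' values would come from the
# split_pattern endpoint shown above.
with open("handbook.md", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/api/dataset/document/split",
        files={"file": f},
        data={"patterns": ["^# .*", "(?<!#)## (?!#).*"], "limit": 256, "with_filter": "true"},
    )
print(resp.json())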

View File

@ -10,9 +10,10 @@ import itertools
import json
import os
import uuid
from typing import Dict
from typing import Dict, List
from django.core import cache
from django.db import transaction
from django.db.models import QuerySet, Q
from drf_yasg import openapi
from rest_framework import serializers
@ -141,12 +142,22 @@ class UpdateTeamMemberPermissionSerializer(ApiMixin, serializers.Serializer):
class TeamMemberSerializer(ApiMixin, serializers.Serializer):
team_id = serializers.CharField(required=True)
team_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
def get_request_body_api(self):
@staticmethod
def get_bach_request_body_api():
return openapi.Schema(
type=openapi.TYPE_ARRAY,
title="用户id列表",
description="用户id列表",
items=openapi.Schema(type=openapi.TYPE_STRING)
)
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['username_or_email'],
@ -157,6 +168,37 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer):
}
)
@transaction.atomic
def batch_add_member(self, user_id_list: List[str], with_valid=True):
"""
Batch-add team members
:param user_id_list: list of user ids
:param with_valid: whether to validate first
:return: member list
"""
if with_valid:
self.is_valid(raise_exception=True)
use_user_id_list = [str(u.id) for u in QuerySet(User).filter(id__in=user_id_list)]
team_member_user_id_list = [str(team_member.user_id) for team_member in
QuerySet(TeamMember).filter(team_id=self.data.get('team_id'))]
team_id = self.data.get("team_id")
create_team_member_list = [
self.to_member_model(add_user_id, team_member_user_id_list, use_user_id_list, team_id) for add_user_id in
user_id_list]
QuerySet(TeamMember).bulk_create(create_team_member_list) if len(create_team_member_list) > 0 else None
return TeamMemberSerializer(
data={'team_id': self.data.get("team_id")}).list_member()
def to_member_model(self, add_user_id, team_member_user_id_list, use_user_id_list, user_id):
if use_user_id_list.__contains__(add_user_id):
if team_member_user_id_list.__contains__(add_user_id) or user_id == add_user_id:
raise AppApiException(500, "团队中已存在当前成员,不要重复添加")
else:
return TeamMember(team_id=self.data.get("team_id"), user_id=add_user_id)
else:
raise AppApiException(500, "不存在的用户")
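A rough usage sketch of the new method (ids are hypothetical; as in the views below, team_id is the owning user's id). batch_add_member validates the team id, lets to_member_model reject unknown users and members already in the team, bulk-creates the rest, and returns the refreshed member list:

# Hypothetical owner and member ids.
serializer = TeamMemberSerializer(data={'team_id': '11111111-2222-3333-4444-555555555555'})
members = serializer.batch_add_member(['aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'])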
def add_member(self, username_or_email: str, with_valid=True):
"""
Add a single member
@ -172,10 +214,11 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer):
Q(username=username_or_email) | Q(email=username_or_email)).first()
if user is None:
raise AppApiException(500, "不存在的用户")
if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists():
if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists() or self.data.get(
"team_id") == str(user.id):
raise AppApiException(500, "团队中已存在当前成员,不要重复添加")
TeamMember(team_id=self.data.get("team_id"), user=user).save()
return TeamMemberSerializer(data={'team_id': self.data.get("team_id")}).list_member()
return self.list_member(with_valid=False)
def list_member(self, with_valid=True):
"""

View File

@ -5,6 +5,7 @@ from . import views
app_name = "team"
urlpatterns = [
path('team/member', views.TeamMember.as_view(), name="team"),
path('team/member/_batch', views.TeamMember.Batch.as_view()),
path('team/member/<str:member_id>', views.TeamMember.Operate.as_view(), name='member'),
path('provider/<str:provider>/<str:method>', views.Provide.Exec.as_view(), name='provide_exec'),
path('provider', views.Provide.as_view(), name='provide'),

View File

@ -40,6 +40,19 @@ class TeamMember(APIView):
team = TeamMemberSerializer(data={'team_id': str(request.user.id)})
return result.success((team.add_member(**request.data)))
class Batch(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量添加成员",
operation_id="批量添加成员",
request_body=TeamMemberSerializer.get_bach_request_body_api(),
tags=["团队"])
@has_permissions(PermissionConstants.TEAM_CREATE)
def post(self, request: Request):
return result.success(
TeamMemberSerializer(data={'team_id': request.user.id}).batch_add_member(request.data))
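A hedged client-side call for the new batch route (host, prefix and the auth header name expected by TokenAuth are assumptions; the body is a plain JSON array of user ids, passed straight to batch_add_member as request.data):

import requests

# Hypothetical host, token and ids; the path team/member/_batch comes from urls.py above.
resp = requests.post(
    "http://localhost:8080/api/team/member/_batch",
    json=["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"],
    headers={"AUTHORIZATION": "hypothetical-token"},
)
print(resp.json())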
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -411,8 +411,9 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['username', 'email', ],
required=['username', 'email', 'id'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title='用户主键id', description="用户主键id"),
'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址")
}
@ -422,5 +423,5 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
if with_valid:
self.is_valid(raise_exception=True)
email_or_username = self.data.get('email_or_username')
return [{'username': user_model.username, 'email': user_model.email} for user_model in
return [{'id': user_model.id, 'username': user_model.username, 'email': user_model.email} for user_model in
QuerySet(User).filter(Q(username=email_or_username) | Q(email=email_or_username))]

View File

@ -17,7 +17,7 @@ from rest_framework.views import Request
from common.auth.authenticate import TokenAuth
from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants, CompareConstants
from common.constants.permission_constants import PermissionConstants
from common.response import result
from smartdoc.settings import JWT_AUTH
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
@ -36,7 +36,7 @@ class User(APIView):
operation_id="获取当前用户信息",
responses=result.get_api_response(UserProfile.get_response_body_api()),
tags=['用户'])
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
@has_permissions(PermissionConstants.USER_READ)
def get(self, request: Request):
return result.success(UserProfile.get_user_profile(request.user))
@ -49,7 +49,7 @@ class User(APIView):
manual_parameters=UserSerializer.Query.get_request_params_api(),
responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()),
tags=['用户'])
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
@has_permissions(PermissionConstants.USER_READ)
def get(self, request: Request):
return result.success(
UserSerializer.Query(data={'email_or_username': request.query_params.get('email_or_username')}).list())

View File

@ -1,7 +1,8 @@
import { Result } from '@/request/Result'
import { get, post, del, put } from '@/request/index'
import type { datasetListRequest, datasetData } from '@/api/type/dataset'
import type { Ref } from 'vue'
import type { KeyValue } from '@/api/type/common'
const prefix = '/dataset'
/**
@ -96,6 +97,17 @@ const postSplitDocument: (data: any) => Promise<Result<any>> = (data) => {
return post(`${prefix}/document/split`, data)
}
/**
*
* @param loading
* @returns
*/
const listSplitPattern: (loading?: Ref<boolean>) => Promise<Result<KeyValue<string, string>>> = (
loading
) => {
return get(`${prefix}/document/split_pattern`, {}, loading)
}
/**
*
* @param dataset_id, name
@ -249,11 +261,11 @@ const putParagraph: (
*
* @param dataset_id, document_id, paragraph_id
*/
const getProblem: (dataset_id: string, document_id: string, paragraph_id: string) => Promise<Result<any>> = (
dataset_id,
document_id,
paragraph_id: string,
) => {
const getProblem: (
dataset_id: string,
document_id: string,
paragraph_id: string
) => Promise<Result<any>> = (dataset_id, document_id, paragraph_id: string) => {
return get(`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem`)
}
@ -285,9 +297,11 @@ const delProblem: (
dataset_id: string,
document_id: string,
paragraph_id: string,
problem_id: string,
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id,problem_id) => {
return del(`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`)
problem_id: string
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id, problem_id) => {
return del(
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`
)
}
export default {
@ -309,5 +323,6 @@ export default {
postParagraph,
getProblem,
postProblem,
delProblem
delProblem,
listSplitPattern
}

View File

@ -0,0 +1,9 @@
interface KeyValue<K, V> {
key: K
value: V
}
interface Dict<V> {
[propName: string]: V
}
export type { KeyValue, Dict }

View File

@ -30,12 +30,12 @@
</el-tooltip>
</div>
<el-select v-model="form.patterns" multiple placeholder="请选择">
<el-select v-loading="patternLoading" v-model="form.patterns" multiple placeholder="请选择">
<el-option
v-for="item in patternsList"
v-for="item in splitPatternList"
:key="item"
:label="item"
:value="item"
:label="item.key"
:value="item.value"
multiple
>
</el-option>
@ -47,7 +47,7 @@
v-model="form.limit"
show-input
:show-input-controls="false"
:min="10"
:min="50"
:max="1024"
/>
</div>
@ -80,22 +80,19 @@
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, reactive } from 'vue'
import { ref, computed, onMounted, reactive, watch } from 'vue'
import ParagraphPreview from '@/views/dataset/component/ParagraphPreview.vue'
import DatasetApi from '@/api/dataset'
import useStore from '@/stores'
import type { KeyValue } from '@/api/type/common'
const { dataset } = useStore()
const documentsFiles = computed(() => dataset.documentsFiles)
const patternType = ['空行', '#', '##', '###', '####', '-', '空格', '回车', '句号', '逗号', '分号']
const marks = reactive({
10: '10',
1024: '1024'
})
const splitPatternList = ref<KeyValue<string, string>[]>([])
const radio = ref('1')
const loading = ref(false)
const paragraphList = ref<any[]>([])
const patternLoading = ref<boolean>(false)
const form = reactive<any>({
patterns: [] as any,
@ -103,8 +100,6 @@ const form = reactive<any>({
with_filter: false
})
const patternsList = ref<string[]>(patternType)
function splitDocument() {
loading.value = true
let fd = new FormData()
@ -128,6 +123,18 @@ function splitDocument() {
})
}
const initSplitPatternList = () => {
DatasetApi.listSplitPattern(patternLoading).then((ok) => {
splitPatternList.value = ok.data
})
}
watch(radio, () => {
if (radio.value === '2') {
initSplitPatternList()
}
})
onMounted(() => {
splitDocument()
})