feat:文档批量创建,用户列表查询,关联数据集数据

This commit is contained in:
shaohuzhang1 2023-11-17 17:43:35 +08:00
parent 42acc60d87
commit 9e8008e064
11 changed files with 120 additions and 7 deletions

View File

@ -27,6 +27,19 @@ def sql_execute(sql: str, params):
return result
def update_execute(sql: str, params):
"""
执行一条sql
:param sql: 需要执行的sql
:param params: sql参数
:return: 执行结果
"""
with connection.cursor() as cursor:
cursor.execute(sql, params)
cursor.close()
return None
def select_list(sql: str, params: List):
"""
执行sql 查询列表数据

View File

@ -201,7 +201,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
if with_valid:
DocumentInstanceSerializer(data=instance).is_valid()
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
@ -212,13 +212,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0)})
# 插入文档
document_model.save()
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
with_valid=True,
with_embedding=False)
# 插入文档
document_model.save()
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
return DocumentSerializers.Operate(
@ -284,6 +285,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
self.data.get("limit")), file_list))
class Batch(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
def batch_save(self, instance_list: List[Dict], with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
create_data = {'dataset_id': self.data.get("dataset_id")}
return [DocumentSerializers.Create(data=create_data).save(instance,
with_valid=True)
for instance in instance_list]
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
data = file.read()

View File

@ -20,6 +20,7 @@ from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Paragraph, Problem, Document
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
@ -123,6 +124,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
update_problem_list) > 0 else None
_paragraph.save()
update_document_char_length(self.data.get('document_id'))
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
@ -190,6 +192,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
instance.get('problem_list') if 'problem_list' in instance else [])]
# 插入問題
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 修改长度
update_document_char_length(document_id)
if with_embedding:
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
return ParagraphSerializers.Operate(
@ -220,6 +224,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
title = serializers.CharField(required=False)
content = serializers.CharField(required=False)
def get_query_set(self):
query_set = QuerySet(model=Paragraph)
query_set = query_set.filter(
@ -227,6 +233,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
if 'title' in self.data:
query_set = query_set.filter(
**{'title__contains': self.data.get('title')})
if 'content' in self.data:
query_set = query_set.filter(**{'content__contains': self.data.get('content')})
return query_set
def list(self):
@ -247,7 +255,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='标题')
description='标题'),
openapi.Parameter(name='content',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='内容')
]
@staticmethod

View File

@ -76,7 +76,9 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
'source_id': problem.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id')})
'dataset_id': self.data.get('dataset_id'),
'star_num': 0,
'trample_num': 0})
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),

View File

@ -5,6 +5,9 @@ FROM
SELECT
"temp_dataset".*,
"document_temp"."char_length",
CASE
WHEN
"app_dataset_temp"."count" IS NULL THEN 0 ELSE "app_dataset_temp"."count" END AS application_mapping_count,
"document_temp".document_count FROM (
SELECT dataset.*
FROM
@ -26,5 +29,6 @@ FROM
)
) temp_dataset
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
LEFT JOIN (SELECT "count"("id"),dataset_id FROM application_dataset_mapping GROUP BY dataset_id) app_dataset_temp ON temp_dataset."id" = "app_dataset_temp".dataset_id
) temp
${default_sql}

View File

@ -8,6 +8,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"),
path('dataset/document/split', views.Document.Split.as_view(),

View File

@ -52,6 +52,24 @@ class Document(APIView):
d.is_valid(raise_exception=True)
return result.success(d.list())
class Batch(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量创建文档",
operation_id="批量创建文档",
request_body=
DocumentSerializers.Batch.get_request_body_api(),
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
responses=result.get_api_array_response(
DocumentSerializers.Operate.get_response_body_api()),
tags=["数据集/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data))
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -7,3 +7,4 @@
@desc:
"""
from .team_management import *
from .model_management import *

View File

@ -15,7 +15,7 @@ import uuid
from django.core import validators, signing, cache
from django.core.mail import send_mail
from django.db import transaction
from django.db.models import Q
from django.db.models import Q, QuerySet
from drf_yasg import openapi
from rest_framework import serializers
@ -395,3 +395,32 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用")
}
)
class Query(ApiMixin, serializers.Serializer):
email_or_username = serializers.CharField(required=True)
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='email_or_username',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='邮箱或者用户名')]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['username', 'email', ],
properties={
'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址")
}
)
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
email_or_username = self.data.get('email_or_username')
return [{'username': user_model.username, 'email': user_model.email} for user_model in
QuerySet(User).filter(Q(username=email_or_username) | Q(email=email_or_username))]

View File

@ -5,6 +5,7 @@ from . import views
app_name = "user"
urlpatterns = [
path('user', views.User.as_view(), name="profile"),
path('user/list', views.User.Query.as_view()),
path('user/login', views.Login.as_view(), name='login'),
path('user/logout', views.Logout.as_view(), name='logout'),
path('user/register', views.Register.as_view(), name="register"),

View File

@ -22,7 +22,7 @@ from common.response import result
from smartdoc.settings import JWT_AUTH
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
RePasswordSerializer, \
SendEmailSerializer, UserProfile
SendEmailSerializer, UserProfile, UserSerializer
user_cache = cache.caches['user_cache']
token_cache = cache.caches['token_cache']
@ -40,6 +40,20 @@ class User(APIView):
def get(self, request: Request):
return result.success(UserProfile.get_user_profile(request.user))
class Query(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取用户列表",
operation_id="获取用户列表",
manual_parameters=UserSerializer.Query.get_request_params_api(),
responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()),
tags=['用户'])
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
def get(self, request: Request):
return result.success(
UserSerializer.Query(data={'email_or_username': request.query_params.get('email_or_username')}).list())
class ResetCurrentUserPasswordView(APIView):
authentication_classes = [TokenAuth]