mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat:文档批量创建,用户列表查询,关联数据集数据
This commit is contained in:
parent
42acc60d87
commit
9e8008e064
|
|
@ -27,6 +27,19 @@ def sql_execute(sql: str, params):
|
|||
return result
|
||||
|
||||
|
||||
def update_execute(sql: str, params):
    """
    Execute a single SQL statement for its side effects.

    :param sql: SQL statement to execute
    :param params: parameters bound to the placeholders in ``sql``
    :return: always ``None``; no rows or row counts are handed back
    """
    # The cursor is used as a context manager, so leaving the ``with`` block
    # releases it (also when ``execute`` raises); an additional explicit
    # ``cursor.close()`` inside the block would only close it a second time.
    with connection.cursor() as cursor:
        cursor.execute(sql, params)
    return None
|
||||
|
||||
|
||||
def select_list(sql: str, params: List):
|
||||
"""
|
||||
执行sql 查询列表数据
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
|
||||
if with_valid:
|
||||
DocumentInstanceSerializer(data=instance).is_valid()
|
||||
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
|
||||
|
|
@ -212,13 +212,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
'char_length': reduce(lambda x, y: x + y,
|
||||
[len(p.get('content')) for p in instance.get('paragraphs', [])],
|
||||
0)})
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
|
||||
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
|
||||
ParagraphSerializers.Create(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
|
||||
with_valid=True,
|
||||
with_embedding=False)
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
||||
return DocumentSerializers.Operate(
|
||||
|
|
@ -284,6 +285,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"),
|
||||
self.data.get("limit")), file_list))
|
||||
|
||||
# Serializer used to create several documents inside one dataset in a single call.
class Batch(ApiMixin, serializers.Serializer):
    # Identifier of the dataset the new documents are attached to.
    dataset_id = serializers.UUIDField(required=True)

    @staticmethod
    def get_request_body_api():
        """Return the request-body schema: an array of single-document create payloads."""
        document_schema = DocumentSerializers.Create.get_request_body_api()
        return openapi.Schema(type=openapi.TYPE_ARRAY, items=document_schema)

    def batch_save(self, instance_list: List[Dict], with_valid=True):
        """
        Create one document per entry of ``instance_list``.

        :param instance_list: document payloads, one dict per document to create
        :param with_valid: when true, validate this serializer and the whole
                           payload list before anything is created
        :return: list with the result of ``DocumentSerializers.Create.save`` for
                 each payload, in input order
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            payload_checker = DocumentInstanceSerializer(many=True, data=instance_list)
            payload_checker.is_valid(raise_exception=True)
        # The same create arguments are shared by every document of the batch.
        create_data = {'dataset_id': self.data.get("dataset_id")}
        saved_documents = []
        for document_payload in instance_list:
            creator = DocumentSerializers.Create(data=create_data)
            saved_documents.append(creator.save(document_payload, with_valid=True))
        return saved_documents
|
||||
|
||||
|
||||
def file_to_paragraph(file, pattern_list: List, with_filter, limit: int):
|
||||
data = file.read()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from common.event.listener_manage import ListenerManagement
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from dataset.models import Paragraph, Problem, Document
|
||||
from dataset.serializers.common_serializers import update_document_char_length
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
|
||||
|
||||
|
||||
|
|
@ -123,6 +124,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
update_problem_list) > 0 else None
|
||||
|
||||
_paragraph.save()
|
||||
update_document_char_length(self.data.get('document_id'))
|
||||
if 'is_active' in instance and instance.get('is_active') is not None:
|
||||
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
||||
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
||||
|
|
@ -190,6 +192,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||
# 插入問題
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 修改长度
|
||||
update_document_char_length(document_id)
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
||||
return ParagraphSerializers.Operate(
|
||||
|
|
@ -220,6 +224,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
title = serializers.CharField(required=False)
|
||||
|
||||
content = serializers.CharField(required=False)
|
||||
|
||||
def get_query_set(self):
|
||||
query_set = QuerySet(model=Paragraph)
|
||||
query_set = query_set.filter(
|
||||
|
|
@ -227,6 +233,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
if 'title' in self.data:
|
||||
query_set = query_set.filter(
|
||||
**{'title__contains': self.data.get('title')})
|
||||
if 'content' in self.data:
|
||||
query_set = query_set.filter(**{'content__contains': self.data.get('content')})
|
||||
return query_set
|
||||
|
||||
def list(self):
|
||||
|
|
@ -247,7 +255,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='标题')
|
||||
description='标题'),
|
||||
openapi.Parameter(name='content',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='内容')
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -76,7 +76,9 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
'source_id': problem.id,
|
||||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id')})
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
'star_num': 0,
|
||||
'trample_num': 0})
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ FROM
|
|||
SELECT
|
||||
"temp_dataset".*,
|
||||
"document_temp"."char_length",
|
||||
CASE
|
||||
WHEN
|
||||
"app_dataset_temp"."count" IS NULL THEN 0 ELSE "app_dataset_temp"."count" END AS application_mapping_count,
|
||||
"document_temp".document_count FROM (
|
||||
SELECT dataset.*
|
||||
FROM
|
||||
|
|
@ -26,5 +29,6 @@ FROM
|
|||
)
|
||||
) temp_dataset
|
||||
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
|
||||
LEFT JOIN (SELECT "count"("id"),dataset_id FROM application_dataset_mapping GROUP BY dataset_id) app_dataset_temp ON temp_dataset."id" = "app_dataset_temp".dataset_id
|
||||
) temp
|
||||
${default_sql}
|
||||
|
|
@ -8,6 +8,7 @@ urlpatterns = [
|
|||
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
|
||||
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
|
||||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
|
||||
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
|
||||
name="document_operate"),
|
||||
path('dataset/document/split', views.Document.Split.as_view(),
|
||||
|
|
|
|||
|
|
@ -52,6 +52,24 @@ class Document(APIView):
|
|||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.list())
|
||||
|
||||
# View that creates several documents of one dataset in a single POST request.
class Batch(APIView):
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(
        operation_summary="批量创建文档",
        operation_id="批量创建文档",
        request_body=DocumentSerializers.Batch.get_request_body_api(),
        manual_parameters=DocumentSerializers.Create.get_request_params_api(),
        responses=result.get_api_array_response(DocumentSerializers.Operate.get_response_body_api()),
        tags=["数据集/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET,
                                operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        # The dataset id comes from the URL; the request body carries the
        # list of document payloads handed to the batch serializer.
        batch_serializer = DocumentSerializers.Batch(data={'dataset_id': dataset_id})
        created_documents = batch_serializer.batch_save(request.data)
        return result.success(created_documents)
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@
|
|||
@desc:
|
||||
"""
|
||||
from .team_management import *
|
||||
from .model_management import *
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import uuid
|
|||
from django.core import validators, signing, cache
|
||||
from django.core.mail import send_mail
|
||||
from django.db import transaction
|
||||
from django.db.models import Q
|
||||
from django.db.models import Q, QuerySet
|
||||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
|
|
@ -395,3 +395,32 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
|
|||
'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用")
|
||||
}
|
||||
)
|
||||
|
||||
# Serializer that looks users up by an exact username or email match.
class Query(ApiMixin, serializers.Serializer):
    # Search value compared against both the username and the email column.
    email_or_username = serializers.CharField(required=True)

    @staticmethod
    def get_request_params_api():
        """Return the query-string parameters accepted by the user list endpoint."""
        search_parameter = openapi.Parameter(name='email_or_username',
                                             in_=openapi.IN_QUERY,
                                             type=openapi.TYPE_STRING,
                                             required=True,
                                             description='邮箱或者用户名')
        return [search_parameter]

    @staticmethod
    def get_response_body_api():
        """Return the schema of one entry of the user list response."""
        user_properties = {
            'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
            'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=['username', 'email'],
                              properties=user_properties)

    def list(self, with_valid=True):
        """
        Return the users whose username or email equals the search value.

        :param with_valid: when true, validate the serializer input first
        :return: list of ``{'username': ..., 'email': ...}`` dicts
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        search_value = self.data.get('email_or_username')
        matched_users = QuerySet(User).filter(Q(username=search_value) | Q(email=search_value))
        user_entries = []
        for user_model in matched_users:
            user_entries.append({'username': user_model.username, 'email': user_model.email})
        return user_entries
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from . import views
|
|||
app_name = "user"
|
||||
urlpatterns = [
|
||||
path('user', views.User.as_view(), name="profile"),
|
||||
path('user/list', views.User.Query.as_view()),
|
||||
path('user/login', views.Login.as_view(), name='login'),
|
||||
path('user/logout', views.Logout.as_view(), name='logout'),
|
||||
path('user/register', views.Register.as_view(), name="register"),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from common.response import result
|
|||
from smartdoc.settings import JWT_AUTH
|
||||
from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \
|
||||
RePasswordSerializer, \
|
||||
SendEmailSerializer, UserProfile
|
||||
SendEmailSerializer, UserProfile, UserSerializer
|
||||
|
||||
user_cache = cache.caches['user_cache']
|
||||
token_cache = cache.caches['token_cache']
|
||||
|
|
@ -40,6 +40,20 @@ class User(APIView):
|
|||
def get(self, request: Request):
|
||||
return result.success(UserProfile.get_user_profile(request.user))
|
||||
|
||||
# View that returns the users matching the ``email_or_username`` query parameter.
class Query(APIView):
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(
        operation_summary="获取用户列表",
        operation_id="获取用户列表",
        manual_parameters=UserSerializer.Query.get_request_params_api(),
        responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()),
        tags=['用户'])
    @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
    def get(self, request: Request):
        # A missing query parameter is passed on as None and left to the
        # serializer's own validation.
        search_value = request.query_params.get('email_or_username')
        query_serializer = UserSerializer.Query(data={'email_or_username': search_value})
        return result.success(query_serializer.list())
|
||||
|
||||
|
||||
class ResetCurrentUserPasswordView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
|
|||
Loading…
Reference in New Issue