diff --git a/apps/common/db/sql_execute.py b/apps/common/db/sql_execute.py index 3f459759e..79e7de46a 100644 --- a/apps/common/db/sql_execute.py +++ b/apps/common/db/sql_execute.py @@ -27,6 +27,19 @@ def sql_execute(sql: str, params): return result +def update_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + cursor.close() + return None + + def select_list(sql: str, params: List): """ 执行sql 查询列表数据 diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 7da8ac8ca..0e18b116e 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -201,7 +201,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs): if with_valid: - DocumentInstanceSerializer(data=instance).is_valid() + DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True) self.is_valid(raise_exception=True) dataset_id = self.data.get('dataset_id') @@ -212,13 +212,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): 'char_length': reduce(lambda x, y: x + y, [len(p.get('content')) for p in instance.get('paragraphs', [])], 0)}) + # 插入文档 + document_model.save() + for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []: ParagraphSerializers.Create( data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph, with_valid=True, with_embedding=False) - # 插入文档 - document_model.save() if with_embedding: ListenerManagement.embedding_by_document_signal.send(str(document_model.id)) return DocumentSerializers.Operate( @@ -284,6 +285,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return list(map(lambda f: file_to_paragraph(f, self.data.get("patterns"), self.data.get("with_filter"), self.data.get("limit")), file_list)) + class Batch(ApiMixin, serializers.Serializer): + dataset_id = serializers.UUIDField(required=True) + + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api()) + + def batch_save(self, instance_list: List[Dict], with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True) + create_data = {'dataset_id': self.data.get("dataset_id")} + return [DocumentSerializers.Create(data=create_data).save(instance, + with_valid=True) + for instance in instance_list] + def file_to_paragraph(file, pattern_list: List, with_filter, limit: int): data = file.read() diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index a604e3a05..d0436d691 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -20,6 +20,7 @@ from common.event.listener_manage import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from dataset.models import Paragraph, Problem, Document +from dataset.serializers.common_serializers import update_document_char_length from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer @@ -123,6 +124,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): update_problem_list) > 0 else None _paragraph.save() + update_document_char_length(self.data.get('document_id')) if 'is_active' in instance and instance.get('is_active') is not None: s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get( 'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal) @@ -190,6 +192,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): instance.get('problem_list') if 'problem_list' in instance else [])] # 插入問題 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 修改长度 + update_document_char_length(document_id) if with_embedding: ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id)) return ParagraphSerializers.Operate( @@ -220,6 +224,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): title = serializers.CharField(required=False) + content = serializers.CharField(required=False) + def get_query_set(self): query_set = QuerySet(model=Paragraph) query_set = query_set.filter( @@ -227,6 +233,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): if 'title' in self.data: query_set = query_set.filter( **{'title__contains': self.data.get('title')}) + if 'content' in self.data: + query_set = query_set.filter(**{'content__contains': self.data.get('content')}) return query_set def list(self): @@ -247,7 +255,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, - description='标题') + description='标题'), + openapi.Parameter(name='content', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='内容') ] @staticmethod diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 6845bb34e..0fb5fb455 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -76,7 +76,9 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): 'source_id': problem.id, 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), - 'dataset_id': self.data.get('dataset_id')}) + 'dataset_id': self.data.get('dataset_id'), + 'star_num': 0, + 'trample_num': 0}) return ProblemSerializers.Operate( data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'), diff --git a/apps/dataset/sql/list_dataset.sql b/apps/dataset/sql/list_dataset.sql index e3e85e229..317dacacf 100644 --- a/apps/dataset/sql/list_dataset.sql +++ b/apps/dataset/sql/list_dataset.sql @@ -5,6 +5,9 @@ FROM SELECT "temp_dataset".*, "document_temp"."char_length", + CASE + WHEN + "app_dataset_temp"."count" IS NULL THEN 0 ELSE "app_dataset_temp"."count" END AS application_mapping_count, "document_temp".document_count FROM ( SELECT dataset.* FROM @@ -26,5 +29,6 @@ FROM ) ) temp_dataset LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id + LEFT JOIN (SELECT "count"("id"),dataset_id FROM application_dataset_mapping GROUP BY dataset_id) app_dataset_temp ON temp_dataset."id" = "app_dataset_temp".dataset_id ) temp ${default_sql} \ No newline at end of file diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index a779642ff..bc75c1b35 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -8,6 +8,7 @@ urlpatterns = [ path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), path('dataset//document', views.Document.as_view(), name='document'), + path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document/', views.Document.Operate.as_view(), name="document_operate"), path('dataset/document/split', views.Document.Split.as_view(), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index eaf4cdba0..9e9153e42 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -52,6 +52,24 @@ class Document(APIView): d.is_valid(raise_exception=True) return result.success(d.list()) + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量创建文档", + operation_id="批量创建文档", + request_body= + DocumentSerializers.Batch.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_array_response( + DocumentSerializers.Operate.get_response_body_api()), + tags=["数据集/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data)) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/setting/models/__init__.py b/apps/setting/models/__init__.py index 42f719eea..1de0764b2 100644 --- a/apps/setting/models/__init__.py +++ b/apps/setting/models/__init__.py @@ -7,3 +7,4 @@ @desc: """ from .team_management import * +from .model_management import * diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index ece3b31e4..4313dd670 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -15,7 +15,7 @@ import uuid from django.core import validators, signing, cache from django.core.mail import send_mail from django.db import transaction -from django.db.models import Q +from django.db.models import Q, QuerySet from drf_yasg import openapi from rest_framework import serializers @@ -395,3 +395,32 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer): 'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用") } ) + + class Query(ApiMixin, serializers.Serializer): + email_or_username = serializers.CharField(required=True) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='email_or_username', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='邮箱或者用户名')] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['username', 'email', ], + properties={ + 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), + 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址") + } + ) + + def list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + email_or_username = self.data.get('email_or_username') + return [{'username': user_model.username, 'email': user_model.email} for user_model in + QuerySet(User).filter(Q(username=email_or_username) | Q(email=email_or_username))] diff --git a/apps/users/urls.py b/apps/users/urls.py index 887dec5b0..908c01444 100644 --- a/apps/users/urls.py +++ b/apps/users/urls.py @@ -5,6 +5,7 @@ from . import views app_name = "user" urlpatterns = [ path('user', views.User.as_view(), name="profile"), + path('user/list', views.User.Query.as_view()), path('user/login', views.Login.as_view(), name='login'), path('user/logout', views.Logout.as_view(), name='logout'), path('user/register', views.Register.as_view(), name="register"), diff --git a/apps/users/views/user.py b/apps/users/views/user.py index d79cf2340..e96db337b 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -22,7 +22,7 @@ from common.response import result from smartdoc.settings import JWT_AUTH from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ - SendEmailSerializer, UserProfile + SendEmailSerializer, UserProfile, UserSerializer user_cache = cache.caches['user_cache'] token_cache = cache.caches['token_cache'] @@ -40,6 +40,20 @@ class User(APIView): def get(self, request: Request): return result.success(UserProfile.get_user_profile(request.user)) + class Query(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取用户列表", + operation_id="获取用户列表", + manual_parameters=UserSerializer.Query.get_request_params_api(), + responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()), + tags=['用户']) + @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) + def get(self, request: Request): + return result.success( + UserSerializer.Query(data={'email_or_username': request.query_params.get('email_or_username')}).list()) + class ResetCurrentUserPasswordView(APIView): authentication_classes = [TokenAuth]