fix: 校验提示

This commit is contained in:
shaohuzhang1 2024-03-04 10:12:18 +08:00
parent 8c6cd3b9a4
commit f7a8f11ef7
16 changed files with 457 additions and 276 deletions

View File

@ -16,6 +16,7 @@ from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
@ -50,21 +51,23 @@ class PostResponseHandler:
class IChatStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True))
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表"))
# 大语言模型
chat_model = ModelField()
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
# 段落列表
paragraph_list = serializers.ListField()
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
# 用户问题
problem_text = serializers.CharField(required=True)
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户问题"))
# 后置处理器
post_response_handler = InstanceField(model_type=PostResponseHandler)
post_response_handler = InstanceField(model_type=PostResponseHandler,
error_messages=ErrMessage.base("用户问题"))
# 补全问题
padding_problem_text = serializers.CharField(required=False)
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False)
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)

View File

@ -16,25 +16,29 @@ from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.models import ChatRecord
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题
problem_text = serializers.CharField(required=True)
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("问题"))
# 段落列表
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True))
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
error_messages=ErrMessage.list("段落列表"))
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True)
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
# 最大携带知识库段落长度
max_paragraph_char_number = serializers.IntegerField(required=True)
max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
"最大携带知识库段落长度"))
# 模板
prompt = serializers.CharField(required=True)
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
# 补齐问题
padding_problem_text = serializers.CharField(required=False)
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer

View File

@ -17,16 +17,18 @@ from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
from application.models import ChatRecord
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
class IResetProblemStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题文本
problem_text = serializers.CharField(required=True)
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float("问题文本"))
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField()
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer

View File

@ -13,25 +13,31 @@ from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
class ISearchDatasetStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 原始问题文本
problem_text = serializers.CharField(required=True)
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("文档id"))
# 系统补全问题文本
padding_problem_text = serializers.CharField(required=False)
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("系统补全问题文本"))
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("数据集id列表"))
# 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("排除的文档id列表"))
# 需要排除向量id
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("排除向量id列表"))
# 需要查询的条数
top_n = serializers.IntegerField(required=True)
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer("引用分段数"))
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("引用分段数"))
def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer

View File

@ -26,6 +26,7 @@ from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph
@ -39,9 +40,12 @@ token_cache = cache.caches['token_cache']
class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
model_id = serializers.CharField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid(
"知识库id")),
error_messages=ErrMessage.list("知识库列表"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -64,29 +68,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
class DatasetSettingSerializer(serializers.Serializer):
top_n = serializers.FloatField(required=True)
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)
top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
error_messages=ErrMessage.float("引用分段数"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("相识度"))
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
error_messages=ErrMessage.integer("最多引用字符数"))
class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=4096)
prompt = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char("提示词"))
class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True)
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
allow_null=True)
allow_null=True, error_messages=ErrMessage.list("关联知识库"))
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
@ -94,11 +104,12 @@ class ApplicationSerializer(serializers.Serializer):
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
class AccessTokenSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.boolean("应用id"))
class AccessTokenEditSerializer(serializers.Serializer):
access_token_reset = serializers.UUIDField(required=False)
is_active = serializers.BooleanField(required=False)
access_token_reset = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("重置Token"))
is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否开启"))
def edit(self, instance: Dict, with_valid=True):
if with_valid:
@ -132,7 +143,7 @@ class ApplicationSerializer(serializers.Serializer):
"is_active": application_access_token.is_active}
class Authentication(serializers.Serializer):
access_token = serializers.CharField(required=True)
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
def auth(self, with_valid=True):
if with_valid:
@ -150,21 +161,30 @@ class ApplicationSerializer(serializers.Serializer):
raise NotFound404(404, "无效的access_token")
class Edit(serializers.Serializer):
name = serializers.CharField(required=False)
desc = serializers.CharField(required=False)
model_id = serializers.CharField(required=False)
multiple_rounds_dialogue = serializers.BooleanField(required=False)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
name = serializers.CharField(required=False, max_length=64, min_length=1,
error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("关联知识库")
)
# 数据集相关设置
dataset_setting = serializers.JSONField(required=False, allow_null=True)
dataset_setting = serializers.JSONField(required=False, allow_null=True,
error_messages=ErrMessage.json("数据集设置"))
# 模型相关设置
model_setting = serializers.JSONField(required=False, allow_null=True)
model_setting = serializers.JSONField(required=False, allow_null=True,
error_messages=ErrMessage.json("模型设置"))
# 问题补全
problem_optimization = serializers.BooleanField(required=False, allow_null=True)
problem_optimization = serializers.BooleanField(required=False, allow_null=True,
error_messages=ErrMessage.boolean("问题补全"))
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@transaction.atomic
def insert(self, application: Dict):
@ -201,11 +221,13 @@ class ApplicationSerializer(serializers.Serializer):
return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
class HitTest(serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
query_text = serializers.CharField(required=True)
top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("用户id"))
query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本"))
top_number = serializers.IntegerField(required=True, max_value=10, min_value=1,
error_messages=ErrMessage.integer("topN"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("相关度"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -234,11 +256,11 @@ class ApplicationSerializer(serializers.Serializer):
'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]
class Query(serializers.Serializer):
name = serializers.CharField(required=False)
name = serializers.CharField(required=False, error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False)
desc = serializers.CharField(required=False, error_messages=ErrMessage.char("应用描述"))
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def get_query_set(self):
user_id = self.data.get("user_id")
@ -299,8 +321,8 @@ class ApplicationSerializer(serializers.Serializer):
fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -392,9 +414,9 @@ class ApplicationSerializer(serializers.Serializer):
fields = "__all__"
class ApplicationKeySerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
application_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -423,12 +445,12 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(ApplicationApiKey).filter(application_id=application_id)]
class Edit(serializers.Serializer):
is_active = serializers.BooleanField(required=False)
is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否可用"))
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
api_key_id = serializers.CharField(required=True)
api_key_id = serializers.CharField(required=True, error_messages=ErrMessage.char("ApiKeyid"))
def delete(self, with_valid=True):
if with_valid:

View File

@ -28,6 +28,7 @@ from common.db.search import native_search, native_page_search, page_search, get
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
@ -42,8 +43,8 @@ chat_cache = cache
class ChatSerializers(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
def delete(self, with_valid=True):
if with_valid:
@ -52,13 +53,15 @@ class ChatSerializers(serializers.Serializer):
return True
class Query(serializers.Serializer):
abstract = serializers.CharField(required=False)
history_day = serializers.IntegerField(required=True)
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
min_star = serializers.IntegerField(required=False, min_value=0)
min_trample = serializers.IntegerField(required=False, min_value=0)
comparer = serializers.CharField(required=False, validators=[
abstract = serializers.CharField(required=False, error_messages=ErrMessage.char("摘要"))
history_day = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("历史天数"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
min_star = serializers.IntegerField(required=False, min_value=0,
error_messages=ErrMessage.integer("最小点赞数"))
min_trample = serializers.IntegerField(required=False, min_value=0,
error_messages=ErrMessage.integer("最小点踩数"))
comparer = serializers.CharField(required=False, error_messages=ErrMessage.char("比较器"), validators=[
validators.RegexValidator(regex=re.compile("^and|or$"),
message="只支持and|or", code=500)
])
@ -116,9 +119,9 @@ class ChatSerializers(serializers.Serializer):
with_table_name=False)
class OpenChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
application_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -153,19 +156,21 @@ class ChatSerializers(serializers.Serializer):
return chat_id
class OpenTempChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.UUIDField(required=True)
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
multiple_rounds_dialogue = serializers.BooleanField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True,
error_messages=ErrMessage.boolean("多轮会话"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("关联数据集"))
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -208,9 +213,9 @@ class ChatRecordSerializerModel(serializers.ModelSerializer):
class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
chat_record_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
@ -274,11 +279,11 @@ class ChatRecordSerializer(serializers.Serializer):
return page
class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
chat_record_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
vote_status = serializers.ChoiceField(choices=VoteChoices.choices)
vote_status = serializers.ChoiceField(choices=VoteChoices.choices, error_messages=ErrMessage.uuid("投标状态"))
@transaction.atomic
def vote(self, with_valid=True):
@ -313,8 +318,9 @@ class ChatRecordSerializer(serializers.Serializer):
return True
class ImproveSerializer(serializers.Serializer):
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
content = serializers.CharField(required=True)
title = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("段落标题"))
content = serializers.CharField(required=True, error_messages=ErrMessage.char("段落内容"))
class ParagraphModel(serializers.ModelSerializer):
class Meta:
@ -322,9 +328,9 @@ class ChatRecordSerializer(serializers.Serializer):
fields = "__all__"
class ChatRecordImprove(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
chat_record_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
def get(self, with_valid=True):
if with_valid:
@ -347,13 +353,13 @@ class ChatRecordSerializer(serializers.Serializer):
return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list]
class Improve(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
chat_record_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)

View File

@ -8,6 +8,7 @@
"""
import logging
import traceback
from typing import Dict
from rest_framework.exceptions import ValidationError, ErrorDetail, APIException
from rest_framework.views import exception_handler
@ -36,7 +37,7 @@ def to_result(key, args, parent_key=None):
return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code,
message=f"{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str(
error_detail) == "This field is required." else f"{key if parent_key is None else parent_key + '.' + key}" + error_detail)
error_detail) == "This field is required." else error_detail)
def validation_error_to_result(exc: ValidationError):
@ -46,14 +47,26 @@ def validation_error_to_result(exc: ValidationError):
:return: 接口响应对象
"""
try:
res = list(map(lambda key: to_result(key, args=exc.args),
exc.args[0].keys() if len(exc.args) > 0 else []))
v = find_err_detail(exc.detail)
if v is None:
return result.error(str(exc.detail))
return result.error(str(v))
except Exception as e:
return result.error(str(exc.detail))
if len(res) > 0:
return res[0]
else:
return result.error("未知异常")
def find_err_detail(exc_detail: Dict):
if isinstance(exc_detail, dict):
keys = exc_detail.keys()
for key in keys:
_value = exc_detail[key]
if isinstance(_value, list):
for v in _value:
return v
elif isinstance(_value, ErrorDetail):
return _value
elif isinstance(_value, dict):
return find_err_detail(_value)
def handle_exception(exc, context):

View File

@ -0,0 +1,88 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file field_message.py
@date2024/3/1 14:30
@desc:
"""
from django.utils.translation import gettext_lazy
class ErrMessage:
@staticmethod
def char(field: str):
return {
'invalid': gettext_lazy("%s】不是有效的字符串。" % field),
'blank': gettext_lazy("%s】此字段不能为空字符串。" % field),
'max_length': gettext_lazy("%s】请确保此字段的字符数不超过 {max_length} 个。" % field),
'min_length': gettext_lazy("%s】请确保此字段至少包含 {min_length} 个字符。" % field),
'required': gettext_lazy('此字段必填。'),
'null': gettext_lazy('此字段不能为null。')
}
@staticmethod
def uuid(field: str):
return {'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
'invalid': gettext_lazy("%s】必须是有效的UUID。" % field),
}
@staticmethod
def integer(field: str):
return {'invalid': gettext_lazy('%s】必须是有效的integer。' % field),
'max_value': gettext_lazy('%s】请确保此值小于或等于 {max_value}' % field),
'min_value': gettext_lazy('%s】请确保此值大于或等于 {min_value}' % field),
'max_string_length': gettext_lazy('%s】字符串值太大。') % field,
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}
@staticmethod
def list(field: str):
return {'not_a_list': gettext_lazy('%s】应为列表,但得到的类型为 "{input_type}".' % field),
'empty': gettext_lazy('%s】此列表不能为空。' % field),
'min_length': gettext_lazy('%s】请确保此字段至少包含 {min_length} 个元素。' % field),
'max_length': gettext_lazy('%s】请确保此字段的元素不超过 {max_length} 个。' % field),
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}
@staticmethod
def boolean(field: str):
return {'invalid': gettext_lazy('%s】必须是有效的布尔值。' % field),
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field)}
@staticmethod
def dict(field: str):
return {'not_a_dict': gettext_lazy('%s】应为字典,但得到的类型为 "{input_type}' % field),
'empty': gettext_lazy('%s】能是空的。' % field),
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}
@staticmethod
def float(field: str):
return {'invalid': gettext_lazy('%s】需要一个有效的数字。' % field),
'max_value': gettext_lazy('%s】请确保此值小于或等于 {max_value}' % field),
'min_value': gettext_lazy('%s】请确保此值大于或等于 {min_value}' % field),
'max_string_length': gettext_lazy('%s】字符串值太大。' % field),
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}
@staticmethod
def json(field: str):
return {
'invalid': gettext_lazy('%s】值必须是有效的JSON。' % field),
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}
@staticmethod
def base(field: str):
return {
'required': gettext_lazy('%s】此字段必填。' % field),
'null': gettext_lazy('%s】此字段不能为null。' % field),
}

View File

@ -17,6 +17,7 @@ from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph
@ -38,8 +39,9 @@ def list_paragraph(paragraph_list: List[str]):
class MetaSerializer(serializers.Serializer):
class WebMeta(serializers.Serializer):
source_url = serializers.CharField(required=True)
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("文档地址"))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("选择器"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -54,7 +56,8 @@ class MetaSerializer(serializers.Serializer):
class BatchSerializer(ApiMixin, serializers.Serializer):
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.char("id列表"))
def is_valid(self, *, model=None, raise_exception=False):
super().is_valid(raise_exception=True)

View File

@ -30,6 +30,7 @@ from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
@ -67,9 +68,9 @@ class DataSetSerializers(serializers.ModelSerializer):
fields = ['id', 'name', 'desc', 'meta', 'create_time', 'update_time']
class Application(ApiMixin, serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
@staticmethod
def get_request_params_api():
@ -113,20 +114,15 @@ class DataSetSerializers(serializers.ModelSerializer):
查询对象
"""
name = serializers.CharField(required=False,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
error_messages=ErrMessage.char("知识库名称"),
max_length=64,
min_length=1)
desc = serializers.CharField(required=False,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
error_messages=ErrMessage.char("知识库描述"),
max_length=256,
min_length=1,
)
user_id = serializers.CharField(required=True)
@ -191,27 +187,21 @@ class DataSetSerializers(serializers.ModelSerializer):
return DataSetSerializers.Operate.get_response_body_api()
class Create(ApiMixin, serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"), )
class CreateBaseSerializers(ApiMixin, serializers.Serializer):
"""
创建通用数据集序列化对象
"""
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
error_messages=ErrMessage.char("知识库名称"),
max_length=64,
min_length=1)
desc = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
error_messages=ErrMessage.char("知识库描述"),
max_length=256,
min_length=1)
documents = DocumentInstanceSerializer(required=False, many=True)
@ -224,23 +214,18 @@ class DataSetSerializers(serializers.ModelSerializer):
创建web站点序列化对象
"""
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=20,
message="知识库名称在1-20个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-20个字符之间")
])
error_messages=ErrMessage.char("知识库名称"),
max_length=64,
min_length=1)
desc = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=256,
message="知识库名称在1-256个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-256个字符之间")
])
source_url = serializers.CharField(required=True)
error_messages=ErrMessage.char("知识库描述"),
max_length=256,
min_length=1)
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("选择器"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -426,10 +411,15 @@ class DataSetSerializers(serializers.ModelSerializer):
)
class Edit(serializers.Serializer):
name = serializers.CharField(required=False)
desc = serializers.CharField(required=False)
name = serializers.CharField(required=False, max_length=64, min_length=1,
error_messages=ErrMessage.char("知识库名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1,
error_messages=ErrMessage.char("知识库描述"))
meta = serializers.DictField(required=False)
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
"应用id")),
error_messages=ErrMessage.char("应用列表"))
@staticmethod
def get_dataset_meta_valid_map():
@ -447,11 +437,13 @@ class DataSetSerializers(serializers.ModelSerializer):
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class HitTest(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
query_text = serializers.CharField(required=True)
top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
id = serializers.CharField(required=True, error_messages=ErrMessage.char("id"))
user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char("用户id"))
query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本"))
top_number = serializers.IntegerField(required=True, max_value=10, min_value=1,
error_messages=ErrMessage.char("响应Top"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.char("相似度"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -476,11 +468,14 @@ class DataSetSerializers(serializers.ModelSerializer):
'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]
class SyncWeb(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
sync_type = serializers.CharField(required=True, validators=[
id = serializers.CharField(required=True, error_messages=ErrMessage.char(
"知识库id"))
user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char(
"用户id"))
sync_type = serializers.CharField(required=True, error_messages=ErrMessage.char(
"同步类型"), validators=[
validators.RegexValidator(regex=re.compile("^replace|complete$"),
message="replace|complete", code=500)
message="同步类型只支持:replace|complete", code=500)
])
def is_valid(self, *, raise_exception=False):
@ -565,8 +560,10 @@ class DataSetSerializers(serializers.ModelSerializer):
]
class Operate(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)
id = serializers.CharField(required=True, error_messages=ErrMessage.char(
"知识库id"))
user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char(
"用户id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)

View File

@ -13,7 +13,6 @@ import uuid
from functools import reduce
from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
@ -25,6 +24,7 @@ from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
@ -36,8 +36,11 @@ from smartdoc.conf import PROJECT_DIR
class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
meta = serializers.DictField(required=False)
name = serializers.CharField(required=False)
is_active = serializers.BooleanField(required=False)
name = serializers.CharField(required=False, max_length=128, min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char(
"文档是否可用"))
@staticmethod
def get_meta_valid_map():
@ -56,8 +59,14 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True)
source_url_list = serializers.ListField(required=True,
child=serializers.CharField(required=True, error_messages=ErrMessage.char(
"文档地址")),
error_messages=ErrMessage.char(
"文档地址列表"))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char(
"选择器"))
@staticmethod
def get_request_body_api():
@ -74,12 +83,9 @@ class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=128,
message="文档名称在1-128个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-128个字符之间")
])
error_messages=ErrMessage.char("文档名称"),
max_length=128,
min_length=1)
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
@ -99,15 +105,14 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
class DocumentSerializers(ApiMixin, serializers.Serializer):
class Query(ApiMixin, serializers.Serializer):
# 知识库id
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
"知识库id"))
name = serializers.CharField(required=False,
validators=[
validators.MaxLengthValidator(limit_value=128,
message="文档名称在1-128个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="知识库名称在1-128个字符之间")
])
name = serializers.CharField(required=False, max_length=128,
min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
def get_query_set(self):
query_set = QuerySet(model=Document)
@ -144,7 +149,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
items=DocumentSerializers.Operate.get_response_body_api())
class Sync(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -202,7 +208,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return True
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
@staticmethod
def get_request_params_api():
@ -312,7 +319,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
)
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -427,14 +435,20 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
]
class Split(ApiMixin, serializers.Serializer):
file = serializers.ListField(required=True)
file = serializers.ListField(required=True, error_messages=ErrMessage.list(
"文件列表"))
limit = serializers.IntegerField(required=False)
limit = serializers.IntegerField(required=False, error_messages=ErrMessage.integer(
"分段长度"))
patterns = serializers.ListField(required=False,
child=serializers.CharField(required=True))
child=serializers.CharField(required=True, error_messages=ErrMessage.char(
"分段标识")),
error_messages=ErrMessage.uuid(
"分段标识列表"))
with_filter = serializers.BooleanField(required=False)
with_filter = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(
"自动清洗"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -486,7 +500,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
{'key': '空行', 'value': '(?<!\\n)\\n\\n(?!\\n)'}]
class Batch(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
@staticmethod
def get_request_body_api():

View File

@ -9,7 +9,6 @@
import uuid
from typing import Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
@ -20,6 +19,7 @@ from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
@ -36,18 +36,17 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
"""
段落实例对象
"""
content = serializers.CharField(required=True, validators=[
validators.MaxLengthValidator(limit_value=4096,
message="段落在1-1024个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="段落在1-1024个字符之间"),
], allow_null=True, allow_blank=True)
content = serializers.CharField(required=True, error_messages=ErrMessage.char("段落内容"),
max_length=4096,
min_length=1,
allow_null=True, allow_blank=True)
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
title = serializers.CharField(required=False, error_messages=ErrMessage.char("段落标题"),
allow_null=True, allow_blank=True)
problem_list = ProblemInstanceSerializer(required=False, many=True)
is_active = serializers.BooleanField(required=False)
is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char("段落是否可用"))
@staticmethod
def get_request_body_api():
@ -72,11 +71,14 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
class ParagraphSerializers(ApiMixin, serializers.Serializer):
class Operate(ApiMixin, serializers.Serializer):
# 段落id
paragraph_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"段落id"))
# 知识库id
dataset_id = serializers.UUIDField(required=True)
# 知识库id
document_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"知识库id"))
# 文档id
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -171,9 +173,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
description="段落id")]
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -234,11 +238,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
]
class Query(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
title = serializers.CharField(required=False)
title = serializers.CharField(required=False, error_messages=ErrMessage.char(
"段落标题"))
content = serializers.CharField(required=False)

View File

@ -17,6 +17,7 @@ from rest_framework import serializers
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from dataset.models import Problem, Paragraph
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector
@ -30,9 +31,9 @@ class ProblemSerializer(serializers.ModelSerializer):
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=False)
id = serializers.CharField(required=False, error_messages=ErrMessage.char("问题id"))
content = serializers.CharField(required=True)
content = serializers.CharField(required=True, error_messages=ErrMessage.char("问题内容"))
@staticmethod
def get_request_body_api():
@ -49,11 +50,11 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
class ProblemSerializers(ApiMixin, serializers.Serializer):
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -109,11 +110,11 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
description='段落id')]
class Query(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -157,13 +158,13 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
description='段落id')]
class Operate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
problem_id = serializers.UUIDField(required=True)
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
def delete(self, with_valid=False):
if with_valid:

View File

@ -14,6 +14,7 @@ from django.db.models import QuerySet
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -21,15 +22,16 @@ from setting.models_provider.constants.model_provider_constants import ModelProv
class ModelSerializer(serializers.Serializer):
class Query(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
name = serializers.CharField(required=False)
name = serializers.CharField(required=False, max_length=20, min_length=10,
error_messages=ErrMessage.char("模型名称"))
model_type = serializers.CharField(required=False)
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
model_name = serializers.CharField(required=False)
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("基础模型"))
provider = serializers.CharField(required=False)
provider = serializers.CharField(required=False, error_messages=ErrMessage.char("供应商"))
def list(self, with_valid):
if with_valid:
@ -50,15 +52,16 @@ class ModelSerializer(serializers.Serializer):
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
class Edit(serializers.Serializer):
user_id = serializers.CharField(required=False)
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
name = serializers.CharField(required=False)
name = serializers.CharField(required=False, max_length=20, min_length=10,
error_messages=ErrMessage.char("模型名称"))
model_type = serializers.CharField(required=False)
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
model_name = serializers.CharField(required=False)
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
credential = serializers.DictField(required=False)
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
def is_valid(self, model=None, raise_exception=False):
super().is_valid(raise_exception=True)
@ -93,17 +96,18 @@ class ModelSerializer(serializers.Serializer):
return credential
class Create(serializers.Serializer):
user_id = serializers.CharField(required=True)
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
name = serializers.CharField(required=True)
name = serializers.CharField(required=True, max_length=20, min_length=10,
error_messages=ErrMessage.char("用户id"))
provider = serializers.CharField(required=True)
provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商"))
model_type = serializers.CharField(required=True)
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
model_name = serializers.CharField(required=True)
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
credential = serializers.DictField(required=True)
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -144,9 +148,9 @@ class ModelSerializer(serializers.Serializer):
credential)}
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True)
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
user_id = serializers.UUIDField(required=True)
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -188,9 +192,9 @@ class ModelSerializer(serializers.Serializer):
class ProviderSerializer(serializers.Serializer):
provider = serializers.CharField(required=True)
provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商"))
method = serializers.CharField(required=True)
method = serializers.CharField(required=True, error_messages=ErrMessage.char("执行函数名称"))
def exec(self, exec_params: Dict[str, object], with_valid=False):
if with_valid:

View File

@ -23,6 +23,7 @@ from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.response.result import get_api_response
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from setting.models import TeamMember, TeamMemberPermission
from smartdoc.conf import PROJECT_DIR
@ -49,8 +50,8 @@ def get_response_body_api():
class TeamMemberPermissionOperate(ApiMixin, serializers.Serializer):
USE = serializers.BooleanField(required=True)
MANAGE = serializers.BooleanField(required=True)
USE = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("使用"))
MANAGE = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("管理"))
def get_request_body_api(self):
return openapi.Schema(type=openapi.TYPE_OBJECT,
@ -68,8 +69,8 @@ class TeamMemberPermissionOperate(ApiMixin, serializers.Serializer):
class UpdateTeamMemberItemPermissionSerializer(ApiMixin, serializers.Serializer):
target_id = serializers.CharField(required=True)
type = serializers.CharField(required=True)
target_id = serializers.CharField(required=True, error_messages=ErrMessage.char("目标id"))
type = serializers.CharField(required=True, error_messages=ErrMessage.char("目标类型"))
operate = TeamMemberPermissionOperate(required=True, many=False)
def get_request_body_api(self):
@ -142,7 +143,7 @@ class UpdateTeamMemberPermissionSerializer(ApiMixin, serializers.Serializer):
class TeamMemberSerializer(ApiMixin, serializers.Serializer):
team_id = serializers.UUIDField(required=True)
team_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("团队id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -246,9 +247,9 @@ class TeamMemberSerializer(ApiMixin, serializers.Serializer):
class Operate(ApiMixin, serializers.Serializer):
# 团队 成员id
member_id = serializers.CharField(required=True)
member_id = serializers.CharField(required=True, error_messages=ErrMessage.char("成员id"))
# 团队id
team_id = serializers.CharField(required=True)
team_id = serializers.CharField(required=True, error_messages=ErrMessage.char("团队id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)

View File

@ -25,6 +25,7 @@ from common.constants.permission_constants import RoleConstants, get_permission_
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.response.result import get_api_response
from common.util.field_message import ErrMessage
from common.util.lock import lock
from setting.models import Team
from smartdoc.conf import PROJECT_DIR
@ -36,14 +37,11 @@ user_cache = cache.caches['user_cache']
class LoginSerializer(ApiMixin, serializers.Serializer):
username = serializers.CharField(required=True,
validators=[
validators.MaxLengthValidator(limit_value=1024,
message=ExceptionCodeConstants.USERNAME_ERROR.value.message),
validators.MinLengthValidator(limit_value=6,
message=ExceptionCodeConstants.USERNAME_ERROR.value.message)
])
error_messages=ErrMessage.char("用户名"),
max_length=20,
min_length=6)
password = serializers.CharField(required=True)
password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"))
def is_valid(self, *, raise_exception=False):
"""
@ -99,29 +97,33 @@ class RegisterSerializer(ApiMixin, serializers.Serializer):
注册请求对象
"""
email = serializers.EmailField(
required=True,
error_messages=ErrMessage.char("邮箱"),
validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])
username = serializers.CharField(required=True,
error_messages=ErrMessage.char("用户名"),
max_length=20,
min_length=6,
validators=[
validators.MaxLengthValidator(limit_value=20,
message=ExceptionCodeConstants.USERNAME_ERROR.value.message),
validators.MinLengthValidator(limit_value=6,
message=ExceptionCodeConstants.USERNAME_ERROR.value.message),
validators.RegexValidator(regex=re.compile("^[a-zA-Z][a-zA-Z1-9_]{5,20}$"),
message="用户名字符数为 6-20 个字符,必须以字母开头,可使用字母、数字、下划线等")
])
password = serializers.CharField(required=True, validators=[validators.RegexValidator(regex=re.compile(
"^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)"
"(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~()-+=]{6,20}$")
, message="密码长度6-20个字符必须字母、数字、特殊字符组合")])
password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"),
validators=[validators.RegexValidator(regex=re.compile(
"^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)"
"(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~()-+=]{6,20}$")
, message="密码长度6-20个字符必须字母、数字、特殊字符组合")])
re_password = serializers.CharField(required=True, validators=[validators.RegexValidator(regex=re.compile(
"^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)"
"(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~()-+=]{6,20}$")
, message="密码长度6-20个字符必须字母、数字、特殊字符组合")])
re_password = serializers.CharField(required=True,
error_messages=ErrMessage.char("确认密码"),
validators=[validators.RegexValidator(regex=re.compile(
"^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)"
"(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~()-+=]{6,20}$")
, message="确认密码长度6-20个字符必须字母、数字、特殊字符组合")])
code = serializers.CharField(required=True)
code = serializers.CharField(required=True, error_messages=ErrMessage.char("验证码"))
class Meta:
model = User
@ -186,14 +188,18 @@ class CheckCodeSerializer(ApiMixin, serializers.Serializer):
校验验证码
"""
email = serializers.EmailField(
required=True,
error_messages=ErrMessage.char("邮箱"),
validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])
code = serializers.CharField(required=True)
code = serializers.CharField(required=True, error_messages=ErrMessage.char("验证码"))
type = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^register|reset_password$"),
message="只支持register|reset_password", code=500)
])
type = serializers.CharField(required=True,
error_messages=ErrMessage.char("类型"),
validators=[
validators.RegexValidator(regex=re.compile("^register|reset_password$"),
message="类型只支持register|reset_password", code=500)
])
def is_valid(self, *, raise_exception=False):
super().is_valid()
@ -227,14 +233,16 @@ class CheckCodeSerializer(ApiMixin, serializers.Serializer):
class RePasswordSerializer(ApiMixin, serializers.Serializer):
email = serializers.EmailField(
required=True,
error_messages=ErrMessage.char("邮箱"),
validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])
code = serializers.CharField(required=True)
code = serializers.CharField(required=True, error_messages=ErrMessage.char("验证码"))
password = serializers.CharField(required=True)
password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"))
re_password = serializers.CharField(required=True)
re_password = serializers.CharField(required=True, error_messages=ErrMessage.char("确认密码"))
class Meta:
model = User
@ -281,12 +289,14 @@ class RePasswordSerializer(ApiMixin, serializers.Serializer):
class SendEmailSerializer(ApiMixin, serializers.Serializer):
email = serializers.EmailField(
required=True
, error_messages=ErrMessage.char("邮箱"),
validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])
type = serializers.CharField(required=True, validators=[
type = serializers.CharField(required=True, error_messages=ErrMessage.char("类型"), validators=[
validators.RegexValidator(regex=re.compile("^register|reset_password$"),
message="只支持register|reset_password", code=500)
message="类型只支持register|reset_password", code=500)
])
class Meta: