diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 2c3691cd2..39c6380c5 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -16,6 +16,7 @@ from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from common.field.common import InstanceField +from common.util.field_message import ErrMessage from dataset.models import Paragraph @@ -50,21 +51,23 @@ class PostResponseHandler: class IChatStep(IBaseChatPipelineStep): class InstanceSerializer(serializers.Serializer): # 对话列表 - message_list = serializers.ListField(required=True, child=MessageField(required=True)) + message_list = serializers.ListField(required=True, child=MessageField(required=True), + error_messages=ErrMessage.list("对话列表")) # 大语言模型 - chat_model = ModelField() + chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) # 段落列表 - paragraph_list = serializers.ListField() + paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) # 对话id - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) # 用户问题 - problem_text = serializers.CharField(required=True) + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户问题")) # 后置处理器 - post_response_handler = InstanceField(model_type=PostResponseHandler) + post_response_handler = InstanceField(model_type=PostResponseHandler, + error_messages=ErrMessage.base("用户问题")) # 补全问题 - padding_problem_text = serializers.CharField(required=False) + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题")) # 是否使用流的形式输出 - stream = serializers.BooleanField(required=False) + stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py index b5b93d05d..01ce57ea0 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -16,25 +16,29 @@ from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep from application.chat_pipeline.pipeline_manage import PiplineManage from application.models import ChatRecord from common.field.common import InstanceField +from common.util.field_message import ErrMessage from dataset.models import Paragraph class IGenerateHumanMessageStep(IBaseChatPipelineStep): class InstanceSerializer(serializers.Serializer): # 问题 - problem_text = serializers.CharField(required=True) + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("问题")) # 段落列表 - paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True)) + paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True), + error_messages=ErrMessage.list("段落列表")) # 历史对答 - history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True)) + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("历史对答")) # 多轮对话数量 - dialogue_number = serializers.IntegerField(required=True) + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) # 最大携带知识库段落长度 - max_paragraph_char_number = serializers.IntegerField(required=True) + max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer( + "最大携带知识库段落长度")) # 模板 - prompt = serializers.CharField(required=True) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) # 补齐问题 - padding_problem_text = serializers.CharField(required=False) + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题")) def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index c8f9143dd..4ff578666 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -17,16 +17,18 @@ from application.chat_pipeline.pipeline_manage import PiplineManage from application.chat_pipeline.step.chat_step.i_chat_step import ModelField from application.models import ChatRecord from common.field.common import InstanceField +from common.util.field_message import ErrMessage class IResetProblemStep(IBaseChatPipelineStep): class InstanceSerializer(serializers.Serializer): # 问题文本 - problem_text = serializers.CharField(required=True) + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float("问题文本")) # 历史对答 - history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True)) + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("历史对答")) # 大语言模型 - chat_model = ModelField() + chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 08ea08d1f..11431cafc 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -13,25 +13,31 @@ from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage +from common.util.field_message import ErrMessage from dataset.models import Paragraph class ISearchDatasetStep(IBaseChatPipelineStep): class InstanceSerializer(serializers.Serializer): # 原始问题文本 - problem_text = serializers.CharField(required=True) + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("文档id")) # 系统补全问题文本 - padding_problem_text = serializers.CharField(required=False) + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("系统补全问题文本")) # 需要查询的数据集id列表 - dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("数据集id列表")) # 需要排除的文档id - exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("排除的文档id列表")) # 需要排除向量id - exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("排除向量id列表")) # 需要查询的条数 - top_n = serializers.IntegerField(required=True) + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer("引用分段数")) # 相似度 0-1之间 - similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.float("引用分段数")) def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 953079b55..0523c9894 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -26,6 +26,7 @@ from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404 +from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document from dataset.serializers.common_serializers import list_paragraph @@ -39,9 +40,12 @@ token_cache = cache.caches['token_cache'] class ModelDatasetAssociation(serializers.Serializer): - user_id = serializers.UUIDField(required=True) - model_id = serializers.CharField(required=True) - dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, + error_messages=ErrMessage.uuid( + "知识库id")), + error_messages=ErrMessage.list("知识库列表")) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) @@ -64,29 +68,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer): class DatasetSettingSerializer(serializers.Serializer): - top_n = serializers.FloatField(required=True) - similarity = serializers.FloatField(required=True, max_value=1, min_value=0) - max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000) + top_n = serializers.FloatField(required=True, max_value=100, min_value=1, + error_messages=ErrMessage.float("引用分段数")) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.float("相识度")) + max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000, + error_messages=ErrMessage.integer("最多引用字符数")) class ModelSettingSerializer(serializers.Serializer): - prompt = serializers.CharField(required=True, max_length=4096) + prompt = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char("提示词")) class ApplicationSerializer(serializers.Serializer): - name = serializers.CharField(required=True) - desc = serializers.CharField(required=False, allow_null=True, allow_blank=True) - model_id = serializers.CharField(required=True) - multiple_rounds_dialogue = serializers.BooleanField(required=True) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) + name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称")) + desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, + max_length=256, min_length=1, + error_messages=ErrMessage.char("应用描述")) + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) + multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("开场白")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), - allow_null=True) + allow_null=True, error_messages=ErrMessage.list("关联知识库")) # 数据集相关设置 dataset_setting = DatasetSettingSerializer(required=True) # 模型相关设置 model_setting = ModelSettingSerializer(required=True) # 问题补全 - problem_optimization = serializers.BooleanField(required=True) + problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) def is_valid(self, *, user_id=None, raise_exception=False): super().is_valid(raise_exception=True) @@ -94,11 +104,12 @@ class ApplicationSerializer(serializers.Serializer): 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() class AccessTokenSerializer(serializers.Serializer): - application_id = serializers.UUIDField(required=True) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.boolean("应用id")) class AccessTokenEditSerializer(serializers.Serializer): - access_token_reset = serializers.UUIDField(required=False) - is_active = serializers.BooleanField(required=False) + access_token_reset = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean("重置Token")) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否开启")) def edit(self, instance: Dict, with_valid=True): if with_valid: @@ -132,7 +143,7 @@ class ApplicationSerializer(serializers.Serializer): "is_active": application_access_token.is_active} class Authentication(serializers.Serializer): - access_token = serializers.CharField(required=True) + access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token")) def auth(self, with_valid=True): if with_valid: @@ -150,21 +161,30 @@ class ApplicationSerializer(serializers.Serializer): raise NotFound404(404, "无效的access_token") class Edit(serializers.Serializer): - name = serializers.CharField(required=False) - desc = serializers.CharField(required=False) - model_id = serializers.CharField(required=False) - multiple_rounds_dialogue = serializers.BooleanField(required=False) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) - dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + name = serializers.CharField(required=False, max_length=64, min_length=1, + error_messages=ErrMessage.char("应用名称")) + desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("应用描述")) + model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) + multiple_rounds_dialogue = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean("多轮会话")) + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("开场白")) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("关联知识库") + ) # 数据集相关设置 - dataset_setting = serializers.JSONField(required=False, allow_null=True) + dataset_setting = serializers.JSONField(required=False, allow_null=True, + error_messages=ErrMessage.json("数据集设置")) # 模型相关设置 - model_setting = serializers.JSONField(required=False, allow_null=True) + model_setting = serializers.JSONField(required=False, allow_null=True, + error_messages=ErrMessage.json("模型设置")) # 问题补全 - problem_optimization = serializers.BooleanField(required=False, allow_null=True) + problem_optimization = serializers.BooleanField(required=False, allow_null=True, + error_messages=ErrMessage.boolean("问题补全")) class Create(serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @transaction.atomic def insert(self, application: Dict): @@ -201,11 +221,13 @@ class ApplicationSerializer(serializers.Serializer): return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id) class HitTest(serializers.Serializer): - id = serializers.CharField(required=True) - user_id = serializers.UUIDField(required=False) - query_text = serializers.CharField(required=True) - top_number = serializers.IntegerField(required=True, max_value=10, min_value=1) - similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("应用id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("用户id")) + query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本")) + top_number = serializers.IntegerField(required=True, max_value=10, min_value=1, + error_messages=ErrMessage.integer("topN")) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.float("相关度")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -234,11 +256,11 @@ class ApplicationSerializer(serializers.Serializer): 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] class Query(serializers.Serializer): - name = serializers.CharField(required=False) + name = serializers.CharField(required=False, error_messages=ErrMessage.char("应用名称")) - desc = serializers.CharField(required=False) + desc = serializers.CharField(required=False, error_messages=ErrMessage.char("应用描述")) - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def get_query_set(self): user_id = self.data.get("user_id") @@ -299,8 +321,8 @@ class ApplicationSerializer(serializers.Serializer): fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number'] class Operate(serializers.Serializer): - application_id = serializers.UUIDField(required=True) - user_id = serializers.UUIDField(required=True) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -392,9 +414,9 @@ class ApplicationSerializer(serializers.Serializer): fields = "__all__" class ApplicationKeySerializer(serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - application_id = serializers.UUIDField(required=True) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -423,12 +445,12 @@ class ApplicationSerializer(serializers.Serializer): QuerySet(ApplicationApiKey).filter(application_id=application_id)] class Edit(serializers.Serializer): - is_active = serializers.BooleanField(required=False) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否可用")) class Operate(serializers.Serializer): - application_id = serializers.UUIDField(required=True) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) - api_key_id = serializers.CharField(required=True) + api_key_id = serializers.CharField(required=True, error_messages=ErrMessage.char("ApiKeyid")) def delete(self, with_valid=True): if with_valid: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 0c60ab662..0fd6b264d 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -28,6 +28,7 @@ from common.db.search import native_search, native_page_search, page_search, get from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.util.common import post +from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock from common.util.rsa_util import decrypt @@ -42,8 +43,8 @@ chat_cache = cache class ChatSerializers(serializers.Serializer): class Operate(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) - application_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) def delete(self, with_valid=True): if with_valid: @@ -52,13 +53,15 @@ class ChatSerializers(serializers.Serializer): return True class Query(serializers.Serializer): - abstract = serializers.CharField(required=False) - history_day = serializers.IntegerField(required=True) - user_id = serializers.UUIDField(required=True) - application_id = serializers.UUIDField(required=True) - min_star = serializers.IntegerField(required=False, min_value=0) - min_trample = serializers.IntegerField(required=False, min_value=0) - comparer = serializers.CharField(required=False, validators=[ + abstract = serializers.CharField(required=False, error_messages=ErrMessage.char("摘要")) + history_day = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("历史天数")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) + min_star = serializers.IntegerField(required=False, min_value=0, + error_messages=ErrMessage.integer("最小点赞数")) + min_trample = serializers.IntegerField(required=False, min_value=0, + error_messages=ErrMessage.integer("最小点踩数")) + comparer = serializers.CharField(required=False, error_messages=ErrMessage.char("比较器"), validators=[ validators.RegexValidator(regex=re.compile("^and|or$"), message="只支持and|or", code=500) ]) @@ -116,9 +119,9 @@ class ChatSerializers(serializers.Serializer): with_table_name=False) class OpenChat(serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - application_id = serializers.UUIDField(required=True) + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -153,19 +156,21 @@ class ChatSerializers(serializers.Serializer): return chat_id class OpenTempChat(serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - model_id = serializers.UUIDField(required=True) + model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) - multiple_rounds_dialogue = serializers.BooleanField(required=True) + multiple_rounds_dialogue = serializers.BooleanField(required=True, + error_messages=ErrMessage.boolean("多轮会话")) - dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("关联数据集")) # 数据集相关设置 dataset_setting = DatasetSettingSerializer(required=True) # 模型相关设置 model_setting = ModelSettingSerializer(required=True) # 问题补全 - problem_optimization = serializers.BooleanField(required=True) + problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -208,9 +213,9 @@ class ChatRecordSerializerModel(serializers.ModelSerializer): class ChatRecordSerializer(serializers.Serializer): class Operate(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - chat_record_id = serializers.UUIDField(required=True) + chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id")) def get_chat_record(self): chat_record_id = self.data.get('chat_record_id') @@ -274,11 +279,11 @@ class ChatRecordSerializer(serializers.Serializer): return page class Vote(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - chat_record_id = serializers.UUIDField(required=True) + chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id")) - vote_status = serializers.ChoiceField(choices=VoteChoices.choices) + vote_status = serializers.ChoiceField(choices=VoteChoices.choices, error_messages=ErrMessage.uuid("投标状态")) @transaction.atomic def vote(self, with_valid=True): @@ -313,8 +318,9 @@ class ChatRecordSerializer(serializers.Serializer): return True class ImproveSerializer(serializers.Serializer): - title = serializers.CharField(required=False, allow_null=True, allow_blank=True) - content = serializers.CharField(required=True) + title = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("段落标题")) + content = serializers.CharField(required=True, error_messages=ErrMessage.char("段落内容")) class ParagraphModel(serializers.ModelSerializer): class Meta: @@ -322,9 +328,9 @@ class ChatRecordSerializer(serializers.Serializer): fields = "__all__" class ChatRecordImprove(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - chat_record_id = serializers.UUIDField(required=True) + chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id")) def get(self, with_valid=True): if with_valid: @@ -347,13 +353,13 @@ class ChatRecordSerializer(serializers.Serializer): return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list] class Improve(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - chat_record_id = serializers.UUIDField(required=True) + chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - dataset_id = serializers.UUIDField(required=True) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) - document_id = serializers.UUIDField(required=True) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/common/handle/handle_exception.py b/apps/common/handle/handle_exception.py index 32784861b..272588c56 100644 --- a/apps/common/handle/handle_exception.py +++ b/apps/common/handle/handle_exception.py @@ -8,6 +8,7 @@ """ import logging import traceback +from typing import Dict from rest_framework.exceptions import ValidationError, ErrorDetail, APIException from rest_framework.views import exception_handler @@ -36,7 +37,7 @@ def to_result(key, args, parent_key=None): return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code, message=f"【{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str( - error_detail) == "This field is required." else f"【{key if parent_key is None else parent_key + '.' + key}】" + error_detail) + error_detail) == "This field is required." else error_detail) def validation_error_to_result(exc: ValidationError): @@ -46,14 +47,26 @@ def validation_error_to_result(exc: ValidationError): :return: 接口响应对象 """ try: - res = list(map(lambda key: to_result(key, args=exc.args), - exc.args[0].keys() if len(exc.args) > 0 else [])) + v = find_err_detail(exc.detail) + if v is None: + return result.error(str(exc.detail)) + return result.error(str(v)) except Exception as e: return result.error(str(exc.detail)) - if len(res) > 0: - return res[0] - else: - return result.error("未知异常") + + +def find_err_detail(exc_detail: Dict): + if isinstance(exc_detail, dict): + keys = exc_detail.keys() + for key in keys: + _value = exc_detail[key] + if isinstance(_value, list): + for v in _value: + return v + elif isinstance(_value, ErrorDetail): + return _value + elif isinstance(_value, dict): + return find_err_detail(_value) def handle_exception(exc, context): diff --git a/apps/common/util/field_message.py b/apps/common/util/field_message.py new file mode 100644 index 000000000..3d0fe9d46 --- /dev/null +++ b/apps/common/util/field_message.py @@ -0,0 +1,88 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: field_message.py + @date:2024/3/1 14:30 + @desc: +""" +from django.utils.translation import gettext_lazy + + +class ErrMessage: + @staticmethod + def char(field: str): + return { + 'invalid': gettext_lazy("【%s】不是有效的字符串。" % field), + 'blank': gettext_lazy("【%s】此字段不能为空字符串。" % field), + 'max_length': gettext_lazy("【%s】请确保此字段的字符数不超过 {max_length} 个。" % field), + 'min_length': gettext_lazy("【%s】请确保此字段至少包含 {min_length} 个字符。" % field), + 'required': gettext_lazy('此字段必填。'), + 'null': gettext_lazy('此字段不能为null。') + } + + @staticmethod + def uuid(field: str): + return {'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + 'invalid': gettext_lazy("【%s】必须是有效的UUID。" % field), + } + + @staticmethod + def integer(field: str): + return {'invalid': gettext_lazy('【%s】必须是有效的integer。' % field), + 'max_value': gettext_lazy('【%s】请确保此值小于或等于 {max_value} 。' % field), + 'min_value': gettext_lazy('【%s】请确保此值大于或等于 {min_value} 。' % field), + 'max_string_length': gettext_lazy('【%s】字符串值太大。') % field, + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def list(field: str): + return {'not_a_list': gettext_lazy('【%s】应为列表,但得到的类型为 "{input_type}".' % field), + 'empty': gettext_lazy('【%s】此列表不能为空。' % field), + 'min_length': gettext_lazy('【%s】请确保此字段至少包含 {min_length} 个元素。' % field), + 'max_length': gettext_lazy('【%s】请确保此字段的元素不超过 {max_length} 个。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def boolean(field: str): + return {'invalid': gettext_lazy('【%s】必须是有效的布尔值。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field)} + + @staticmethod + def dict(field: str): + return {'not_a_dict': gettext_lazy('【%s】应为字典,但得到的类型为 "{input_type}' % field), + 'empty': gettext_lazy('【%s】能是空的。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def float(field: str): + return {'invalid': gettext_lazy('【%s】需要一个有效的数字。' % field), + 'max_value': gettext_lazy('【%s】请确保此值小于或等于 {max_value}。' % field), + 'min_value': gettext_lazy('【%s】请确保此值大于或等于 {min_value}。' % field), + 'max_string_length': gettext_lazy('【%s】字符串值太大。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def json(field: str): + return { + 'invalid': gettext_lazy('【%s】值必须是有效的JSON。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def base(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 651764356..ebd3dbe8d 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -17,6 +17,7 @@ from common.db.search import native_search from common.db.sql_execute import update_execute from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin +from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork from dataset.models import Paragraph @@ -38,8 +39,9 @@ def list_paragraph(paragraph_list: List[str]): class MetaSerializer(serializers.Serializer): class WebMeta(serializers.Serializer): - source_url = serializers.CharField(required=True) - selector = serializers.CharField(required=False, allow_null=True, allow_blank=True) + source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("文档地址")) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("选择器")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -54,7 +56,8 @@ class MetaSerializer(serializers.Serializer): class BatchSerializer(ApiMixin, serializers.Serializer): - id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.char("id列表")) def is_valid(self, *, model=None, raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index a4ba19749..89cda6371 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -30,6 +30,7 @@ from common.event import ListenerManagement, SyncWebDatasetArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post +from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model @@ -67,9 +68,9 @@ class DataSetSerializers(serializers.ModelSerializer): fields = ['id', 'name', 'desc', 'meta', 'create_time', 'update_time'] class Application(ApiMixin, serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) - dataset_id = serializers.UUIDField(required=True) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id")) @staticmethod def get_request_params_api(): @@ -113,20 +114,15 @@ class DataSetSerializers(serializers.ModelSerializer): 查询对象 """ name = serializers.CharField(required=False, - validators=[ - validators.MaxLengthValidator(limit_value=20, - message="知识库名称在1-20个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-20个字符之间") - ]) + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) desc = serializers.CharField(required=False, - validators=[ - validators.MaxLengthValidator(limit_value=256, - message="知识库名称在1-256个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-256个字符之间") - ]) + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1, + ) user_id = serializers.CharField(required=True) @@ -191,27 +187,21 @@ class DataSetSerializers(serializers.ModelSerializer): return DataSetSerializers.Operate.get_response_body_api() class Create(ApiMixin, serializers.Serializer): - user_id = serializers.UUIDField(required=True) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"), ) class CreateBaseSerializers(ApiMixin, serializers.Serializer): """ 创建通用数据集序列化对象 """ name = serializers.CharField(required=True, - validators=[ - validators.MaxLengthValidator(limit_value=20, - message="知识库名称在1-20个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-20个字符之间") - ]) + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) desc = serializers.CharField(required=True, - validators=[ - validators.MaxLengthValidator(limit_value=256, - message="知识库名称在1-256个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-256个字符之间") - ]) + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) documents = DocumentInstanceSerializer(required=False, many=True) @@ -224,23 +214,18 @@ class DataSetSerializers(serializers.ModelSerializer): 创建web站点序列化对象 """ name = serializers.CharField(required=True, - validators=[ - validators.MaxLengthValidator(limit_value=20, - message="知识库名称在1-20个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-20个字符之间") - ]) + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) desc = serializers.CharField(required=True, - validators=[ - validators.MaxLengthValidator(limit_value=256, - message="知识库名称在1-256个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-256个字符之间") - ]) - source_url = serializers.CharField(required=True) + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), ) - selector = serializers.CharField(required=False, allow_null=True, allow_blank=True) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("选择器")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -426,10 +411,15 @@ class DataSetSerializers(serializers.ModelSerializer): ) class Edit(serializers.Serializer): - name = serializers.CharField(required=False) - desc = serializers.CharField(required=False) + name = serializers.CharField(required=False, max_length=64, min_length=1, + error_messages=ErrMessage.char("知识库名称")) + desc = serializers.CharField(required=False, max_length=256, min_length=1, + error_messages=ErrMessage.char("知识库描述")) meta = serializers.DictField(required=False) - application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, + error_messages=ErrMessage.char( + "应用id")), + error_messages=ErrMessage.char("应用列表")) @staticmethod def get_dataset_meta_valid_map(): @@ -447,11 +437,13 @@ class DataSetSerializers(serializers.ModelSerializer): valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) class HitTest(ApiMixin, serializers.Serializer): - id = serializers.CharField(required=True) - user_id = serializers.UUIDField(required=False) - query_text = serializers.CharField(required=True) - top_number = serializers.IntegerField(required=True, max_value=10, min_value=1) - similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + id = serializers.CharField(required=True, error_messages=ErrMessage.char("id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char("用户id")) + query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本")) + top_number = serializers.IntegerField(required=True, max_value=10, min_value=1, + error_messages=ErrMessage.char("响应Top")) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.char("相似度")) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) @@ -476,11 +468,14 @@ class DataSetSerializers(serializers.ModelSerializer): 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] class SyncWeb(ApiMixin, serializers.Serializer): - id = serializers.CharField(required=True) - user_id = serializers.UUIDField(required=False) - sync_type = serializers.CharField(required=True, validators=[ + id = serializers.CharField(required=True, error_messages=ErrMessage.char( + "知识库id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char( + "用户id")) + sync_type = serializers.CharField(required=True, error_messages=ErrMessage.char( + "同步类型"), validators=[ validators.RegexValidator(regex=re.compile("^replace|complete$"), - message="replace|complete", code=500) + message="同步类型只支持:replace|complete", code=500) ]) def is_valid(self, *, raise_exception=False): @@ -565,8 +560,10 @@ class DataSetSerializers(serializers.ModelSerializer): ] class Operate(ApiMixin, serializers.Serializer): - id = serializers.CharField(required=True) - user_id = serializers.UUIDField(required=False) + id = serializers.CharField(required=True, error_messages=ErrMessage.char( + "知识库id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char( + "用户id")) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index bc6312bf1..3670d33eb 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -13,7 +13,6 @@ import uuid from functools import reduce from typing import List, Dict -from django.core import validators from django.db import transaction from django.db.models import QuerySet from drf_yasg import openapi @@ -25,6 +24,7 @@ from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post +from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import SplitModel, get_split_model @@ -36,8 +36,11 @@ from smartdoc.conf import PROJECT_DIR class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer): meta = serializers.DictField(required=False) - name = serializers.CharField(required=False) - is_active = serializers.BooleanField(required=False) + name = serializers.CharField(required=False, max_length=128, min_length=1, + error_messages=ErrMessage.char( + "文档名称")) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char( + "文档是否可用")) @staticmethod def get_meta_valid_map(): @@ -56,8 +59,14 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer): class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): - source_url_list = serializers.ListField(required=True, child=serializers.CharField(required=True)) - selector = serializers.CharField(required=False, allow_null=True, allow_blank=True) + source_url_list = serializers.ListField(required=True, + child=serializers.CharField(required=True, error_messages=ErrMessage.char( + "文档地址")), + error_messages=ErrMessage.char( + "文档地址列表")) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char( + "选择器")) @staticmethod def get_request_body_api(): @@ -74,12 +83,9 @@ class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): name = serializers.CharField(required=True, - validators=[ - validators.MaxLengthValidator(limit_value=128, - message="文档名称在1-128个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-128个字符之间") - ]) + error_messages=ErrMessage.char("文档名称"), + max_length=128, + min_length=1) paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) @@ -99,15 +105,14 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): class DocumentSerializers(ApiMixin, serializers.Serializer): class Query(ApiMixin, serializers.Serializer): # 知识库id - dataset_id = serializers.UUIDField(required=True) + dataset_id = serializers.UUIDField(required=True, + error_messages=ErrMessage.char( + "知识库id")) - name = serializers.CharField(required=False, - validators=[ - validators.MaxLengthValidator(limit_value=128, - message="文档名称在1-128个字符之间"), - validators.MinLengthValidator(limit_value=1, - message="知识库名称在1-128个字符之间") - ]) + name = serializers.CharField(required=False, max_length=128, + min_length=1, + error_messages=ErrMessage.char( + "文档名称")) def get_query_set(self): query_set = QuerySet(model=Document) @@ -144,7 +149,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): items=DocumentSerializers.Operate.get_response_body_api()) class Sync(ApiMixin, serializers.Serializer): - document_id = serializers.UUIDField(required=True) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( + "文档id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -202,7 +208,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return True class Operate(ApiMixin, serializers.Serializer): - document_id = serializers.UUIDField(required=True) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( + "文档id")) @staticmethod def get_request_params_api(): @@ -312,7 +319,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): ) class Create(ApiMixin, serializers.Serializer): - dataset_id = serializers.UUIDField(required=True) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( + "文档id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -427,14 +435,20 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): ] class Split(ApiMixin, serializers.Serializer): - file = serializers.ListField(required=True) + file = serializers.ListField(required=True, error_messages=ErrMessage.list( + "文件列表")) - limit = serializers.IntegerField(required=False) + limit = serializers.IntegerField(required=False, error_messages=ErrMessage.integer( + "分段长度")) patterns = serializers.ListField(required=False, - child=serializers.CharField(required=True)) + child=serializers.CharField(required=True, error_messages=ErrMessage.char( + "分段标识")), + error_messages=ErrMessage.uuid( + "分段标识列表")) - with_filter = serializers.BooleanField(required=False) + with_filter = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean( + "自动清洗")) def is_valid(self, *, raise_exception=True): super().is_valid(raise_exception=True) @@ -486,7 +500,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): {'key': '空行', 'value': '(?