From 861db4ad111fa0ba3b22d0f41379b6a98478241a Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 21 Dec 2023 18:31:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=A0=87=E6=B3=A8?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E5=80=BC,=E6=B7=BB=E5=8A=A0=E6=97=A5?= =?UTF-8?q?=E5=BF=97api=E6=96=87=E6=A1=A3=E5=93=8D=E5=BA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_serializers.py | 2 +- apps/application/swagger_api/chat_api.py | 77 +++++++++++++++++++ apps/application/views/chat_views.py | 6 +- 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 1d176d64d..474de0ee8 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -328,4 +328,4 @@ class ChatRecordSerializer(serializers.Serializer): # 添加标注 chat_record.save() ListenerManagement.embedding_by_paragraph_signal.send(paragraph.id) - return True + return ChatRecordSerializerModel(chat_record).data diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 07766cf4e..0cde547e1 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -23,6 +23,40 @@ class ChatApi(ApiMixin): } ) + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'application', 'abstract', 'chat_record_count', 'mark_sum', 'star_num', 'trample_num', + 'update_time', 'create_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'application_id': openapi.Schema(type=openapi.TYPE_STRING, title="应用id", + description="应用id", default='应用id'), + 'abstract': openapi.Schema(type=openapi.TYPE_STRING, title="摘要", + description="摘要", default='摘要'), + 'chat_id': openapi.Schema(type=openapi.TYPE_STRING, title="对话id", + description="对话id", default="对话id"), + 'chat_record_count': openapi.Schema(type=openapi.TYPE_STRING, title="对话提问数量", + description="对话提问数量", + default="对话提问数量"), + 'mark_sum': openapi.Schema(type=openapi.TYPE_STRING, title="标记数量", + description="标记数量", default=1), + 'star_num': openapi.Schema(type=openapi.TYPE_STRING, title="点赞数量", + description="点赞数量", default=1), + 'trample_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="点踩数量", + description="点踩数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + class OpenChat(ApiMixin): @staticmethod def get_request_params_api(): @@ -82,6 +116,49 @@ class ChatRecordApi(ApiMixin): description='对话id'), ] + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'chat', 'vote_status', 'dataset', 'paragraph', 'source_id', 'source_type', + 'message_tokens', 'answer_tokens', + 'problem_text', 'answer_text', 'improve_paragraph_id_list'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'chat': openapi.Schema(type=openapi.TYPE_STRING, title="会话日志id", + description="会话日志id", default='会话日志id'), + 'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态", + description="投票状态", default="投票状态"), + 'dataset': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id", description="数据集id", + default="数据集id"), + 'paragraph': openapi.Schema(type=openapi.TYPE_STRING, title="段落id", + description="段落id", default=1), + 'source_id': openapi.Schema(type=openapi.TYPE_STRING, title="资源id", + description="资源id", default=1), + 'source_type': openapi.Schema(type=openapi.TYPE_STRING, title="资源类型", + description="资源类型", default='xxx'), + 'message_tokens': openapi.Schema(type=openapi.TYPE_INTEGER, title="问题消耗token数量", + description="问题消耗token数量", default=0), + 'answer_tokens': openapi.Schema(type=openapi.TYPE_INTEGER, title="答案消耗token数量", + description="答案消耗token数量", default=0), + 'improve_paragraph_id_list': openapi.Schema(type=openapi.TYPE_STRING, title="改进标注列表", + description="改进标注列表", + default=[]), + 'index': openapi.Schema(type=openapi.TYPE_STRING, title="对应会话 对应下标", + description="对应会话id对应下标", + default="对应会话id对应下标" + ), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + class ImproveApi(ApiMixin): @staticmethod diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 415c7ca0d..3ee0223c3 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -77,6 +77,7 @@ class ChatView(APIView): @swagger_auto_schema(operation_summary="获取对话列表", operation_id="获取对话列表", manual_parameters=ChatApi.get_request_params_api(), + responses=result.get_api_array_response(ChatApi.get_response_body_api()), tags=["应用/对话日志"] ) @has_permissions( @@ -116,6 +117,7 @@ class ChatView(APIView): @swagger_auto_schema(operation_summary="分页获取对话列表", operation_id="分页获取对话列表", manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()), + responses=result.get_page_api_response(ChatApi.get_response_body_api()), tags=["应用/对话日志"] ) @has_permissions( @@ -136,6 +138,7 @@ class ChatView(APIView): @swagger_auto_schema(operation_summary="获取对话记录列表", operation_id="获取对话记录列表", manual_parameters=ChatRecordApi.get_request_params_api(), + responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()), tags=["应用/对话日志"] ) @has_permissions( @@ -156,6 +159,7 @@ class ChatView(APIView): operation_id="获取对话记录列表", manual_parameters=result.get_page_request_params( ChatRecordApi.get_request_params_api()), + responses=result.get_page_api_response(ChatRecordApi.get_response_body_api()), tags=["应用/对话日志"] ) @has_permissions( @@ -217,7 +221,7 @@ class ChatView(APIView): operation_id="标注", manual_parameters=ImproveApi.get_request_params_api(), request_body=ImproveApi.get_request_body_api(), - responses=result.get_default_response(), + responses=result.get_api_response(ChatRecordApi.get_response_body_api()), tags=["应用/对话日志/标注"] ) @has_permissions(