From 5ac4f64a93f5884282c6087e4ece1f1379db7c6a Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 13 Dec 2023 17:14:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A0=87=E6=B3=A8=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../0011_alter_application_model.py | 20 +++++++ apps/application/models/application.py | 2 +- .../serializers/chat_serializers.py | 34 ++++++++++- apps/application/swagger_api/chat_api.py | 57 +++++++++++++++++++ apps/application/urls.py | 2 + apps/application/views/chat_views.py | 22 ++++++- .../serializers/paragraph_serializers.py | 5 +- 7 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 apps/application/migrations/0011_alter_application_model.py diff --git a/apps/application/migrations/0011_alter_application_model.py b/apps/application/migrations/0011_alter_application_model.py new file mode 100644 index 000000000..583e4fd87 --- /dev/null +++ b/apps/application/migrations/0011_alter_application_model.py @@ -0,0 +1,20 @@ +# Generated by Django 4.1.10 on 2023-12-13 07:35 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0003_alter_model_provider'), + ('application', '0010_rename_improve_problem_id_list_chatrecord_improve_paragraph_id_list_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='application', + name='model', + field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='setting.model'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 90a2efbe5..d15b182d9 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -26,7 +26,7 @@ class Application(AppModelMixin): example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True)) dialogue_number = models.IntegerField(default=0, verbose_name="会话数量") user = models.ForeignKey(User, on_delete=models.DO_NOTHING) - model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False) + model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) class Meta: db_table = "application" diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index ffb9a1bb0..3e392a544 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -131,7 +131,9 @@ class ChatSerializers(serializers.Serializer): def open(self): self.is_valid(raise_exception=True) chat_id = str(uuid.uuid1()) - model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id')) + model = QuerySet(Model).filter(user_id=self.data.get('user_id'), id=self.data.get('model_id')).first() + if model is None: + raise AppApiException(500, "模型不存在") dataset_id_list = self.data.get('dataset_id_list') chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( @@ -251,6 +253,36 @@ class ChatRecordSerializer(serializers.Serializer): title = serializers.CharField(required=False) content = serializers.CharField(required=True) + class ParagraphModel(serializers.ModelSerializer): + class Meta: + model = Paragraph + fields = "__all__" + + class ChatRecordImprove(serializers.Serializer): + chat_id = serializers.UUIDField(required=True) + + chat_record_id = serializers.UUIDField(required=True) + + def get(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + chat_record_id = self.data.get('chat_record_id') + chat_id = self.data.get('chat_id') + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() + if chat_record is None: + raise AppApiException(500, '不存在的对话记录') + if chat_record.improve_paragraph_id_list is None or len(chat_record.improve_paragraph_id_list) == 0: + return [] + + paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list) + if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list): + paragraph_model_id_list = [str(p.id) for p in paragraph_model_list] + chat_record.improve_paragraph_id_list = list( + filter(lambda p_id: paragraph_model_id_list.__contains__(p_id), + chat_record.improve_paragraph_id_list)) + chat_record.save() + return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list] + class Improve(serializers.Serializer): chat_id = serializers.UUIDField(required=True) diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 1ac15e6f2..3a8f3a2ea 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -159,3 +159,60 @@ class VoteApi(ApiMixin): } ) + + +class ChatRecordImproveApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='chat_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='会话id'), + openapi.Parameter(name='chat_record_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='会话记录id') + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id', + 'document_id', 'title', + 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容", + description="段落内容", default='段落内容'), + 'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题", + description="标题", default="xxx的描述"), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量", + description="点赞数量", default=1), + 'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量", + description="点踩数", default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id", + description="数据集id", default='xxx'), + 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", + description="文档id", default='xxx'), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", + description="是否可用", default=True), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) diff --git a/apps/application/urls.py b/apps/application/urls.py index 72293af32..e02613264 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -31,6 +31,8 @@ urlpatterns = [ 'application//chat//chat_record//dataset//document_id//improve', views.ChatView.ChatRecord.Improve.as_view(), name=''), + path('application//chat//chat_record/', + views.ChatView.ChatRecord.ChatRecordImprove.as_view()), path('application/chat_message/', views.ChatView.Message.as_view()) ] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 0ed0228f5..74f96e95e 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -13,11 +13,10 @@ from rest_framework.views import APIView from application.serializers.chat_message_serializers import ChatMessageSerializer from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer -from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi +from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import Permission, Group, Operate, \ RoleConstants, ViewPermission, CompareConstants -from common.exception.app_exception import AppAuthenticationFailed from common.response import result from common.util.common import query_params_to_single_dict @@ -191,6 +190,25 @@ class ChatView(APIView): data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id, 'chat_record_id': chat_record_id}).vote()) + class ChatRecordImprove(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取标注段落列表信息", + operation_id="获取标注段落列表信息", + manual_parameters=ChatRecordImproveApi.get_request_params_api(), + responses=result.get_api_response(ChatRecordImproveApi.get_response_body_api()), + tags=["应用/对话日志/标注"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))] + )) + def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordSerializer.ChatRecordImprove( + data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get()) + class Improve(APIView): authentication_classes = [TokenAuth] diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 97db251e5..fe8c6ae03 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -275,7 +275,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): def get_response_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title', + required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id', + 'document_id', 'title', 'create_time', 'update_time'], properties={ 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", @@ -290,6 +291,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): description="点赞数量", default=1), 'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量", description="点踩数", default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id", + description="数据集id", default='xxx'), 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", description="文档id", default='xxx'), 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",