feat: 标注信息查询

This commit is contained in:
shaohuzhang1 2023-12-13 17:14:43 +08:00
parent 0c6d39892c
commit 5ac4f64a93
7 changed files with 137 additions and 5 deletions

View File

@ -0,0 +1,20 @@
# Generated by Django 4.1.10 on 2023-12-13 07:35
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('setting', '0003_alter_model_provider'),
('application', '0010_rename_improve_problem_id_list_chatrecord_improve_paragraph_id_list_and_more'),
]
operations = [
migrations.AlterField(
model_name='application',
name='model',
field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='setting.model'),
),
]

View File

@ -26,7 +26,7 @@ class Application(AppModelMixin):
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True))
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False)
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
class Meta:
db_table = "application"

View File

@ -131,7 +131,9 @@ class ChatSerializers(serializers.Serializer):
def open(self):
self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id'))
model = QuerySet(Model).filter(user_id=self.data.get('user_id'), id=self.data.get('model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
@ -251,6 +253,36 @@ class ChatRecordSerializer(serializers.Serializer):
title = serializers.CharField(required=False)
content = serializers.CharField(required=True)
class ParagraphModel(serializers.ModelSerializer):
class Meta:
model = Paragraph
fields = "__all__"
class ChatRecordImprove(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
def get(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
if chat_record is None:
raise AppApiException(500, '不存在的对话记录')
if chat_record.improve_paragraph_id_list is None or len(chat_record.improve_paragraph_id_list) == 0:
return []
paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list)
if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
paragraph_model_id_list = [str(p.id) for p in paragraph_model_list]
chat_record.improve_paragraph_id_list = list(
filter(lambda p_id: paragraph_model_id_list.__contains__(p_id),
chat_record.improve_paragraph_id_list))
chat_record.save()
return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list]
class Improve(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)

View File

@ -159,3 +159,60 @@ class VoteApi(ApiMixin):
}
)
class ChatRecordImproveApi(ApiMixin):
@staticmethod
def get_request_body_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='chat_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话id'),
openapi.Parameter(name='chat_record_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='会话记录id')
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
'document_id', 'title',
'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
description="段落内容", default='段落内容'),
'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
description="标题", default="xxx的描述"),
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
default=1),
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
description="点赞数量", default=1),
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
description="点踩数", default=1),
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id",
description="数据集id", default='xxx'),
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
description="文档id", default='xxx'),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
description="是否可用", default=True),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)

View File

@ -31,6 +31,8 @@ urlpatterns = [
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve',
views.ChatView.ChatRecord.Improve.as_view(),
name=''),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>',
views.ChatView.ChatRecord.ChatRecordImprove.as_view()),
path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view())
]

View File

@ -13,11 +13,10 @@ from rest_framework.views import APIView
from application.serializers.chat_message_serializers import ChatMessageSerializer
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.response import result
from common.util.common import query_params_to_single_dict
@ -191,6 +190,25 @@ class ChatView(APIView):
data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
'chat_record_id': chat_record_id}).vote())
class ChatRecordImprove(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取标注段落列表信息",
operation_id="获取标注段落列表信息",
manual_parameters=ChatRecordImproveApi.get_request_params_api(),
responses=result.get_api_response(ChatRecordImproveApi.get_response_body_api()),
tags=["应用/对话日志/标注"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))]
))
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
return result.success(ChatRecordSerializer.ChatRecordImprove(
data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get())
class Improve(APIView):
authentication_classes = [TokenAuth]

View File

@ -275,7 +275,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
'document_id', 'title',
'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
@ -290,6 +291,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
description="点赞数量", default=1),
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
description="点踩数", default=1),
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id",
description="数据集id", default='xxx'),
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
description="文档id", default='xxx'),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",