mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 标注信息查询
This commit is contained in:
parent
0c6d39892c
commit
5ac4f64a93
|
|
@ -0,0 +1,20 @@
|
|||
# Generated by Django 4.1.10 on 2023-12-13 07:35
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0003_alter_model_provider'),
|
||||
('application', '0010_rename_improve_problem_id_list_chatrecord_improve_paragraph_id_list_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='application',
|
||||
name='model',
|
||||
field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='setting.model'),
|
||||
),
|
||||
]
|
||||
|
|
@ -26,7 +26,7 @@ class Application(AppModelMixin):
|
|||
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True))
|
||||
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
|
||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
|
||||
model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "application"
|
||||
|
|
|
|||
|
|
@ -131,7 +131,9 @@ class ChatSerializers(serializers.Serializer):
|
|||
def open(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = str(uuid.uuid1())
|
||||
model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id'))
|
||||
model = QuerySet(Model).filter(user_id=self.data.get('user_id'), id=self.data.get('model_id')).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
dataset_id_list = self.data.get('dataset_id_list')
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
|
|
@ -251,6 +253,36 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
title = serializers.CharField(required=False)
|
||||
content = serializers.CharField(required=True)
|
||||
|
||||
class ParagraphModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Paragraph
|
||||
fields = "__all__"
|
||||
|
||||
class ChatRecordImprove(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
chat_record_id = serializers.UUIDField(required=True)
|
||||
|
||||
def get(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_record_id = self.data.get('chat_record_id')
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
|
||||
if chat_record is None:
|
||||
raise AppApiException(500, '不存在的对话记录')
|
||||
if chat_record.improve_paragraph_id_list is None or len(chat_record.improve_paragraph_id_list) == 0:
|
||||
return []
|
||||
|
||||
paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list)
|
||||
if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
|
||||
paragraph_model_id_list = [str(p.id) for p in paragraph_model_list]
|
||||
chat_record.improve_paragraph_id_list = list(
|
||||
filter(lambda p_id: paragraph_model_id_list.__contains__(p_id),
|
||||
chat_record.improve_paragraph_id_list))
|
||||
chat_record.save()
|
||||
return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list]
|
||||
|
||||
class Improve(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -159,3 +159,60 @@ class VoteApi(ApiMixin):
|
|||
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ChatRecordImproveApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='chat_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话id'),
|
||||
openapi.Parameter(name='chat_record_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='会话记录id')
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
|
||||
'document_id', 'title',
|
||||
'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
|
||||
description="段落内容", default='段落内容'),
|
||||
'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
|
||||
description="标题", default="xxx的描述"),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
||||
default=1),
|
||||
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
|
||||
description="点赞数量", default=1),
|
||||
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
||||
description="点踩数", default=1),
|
||||
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id",
|
||||
description="数据集id", default='xxx'),
|
||||
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
||||
description="文档id", default='xxx'),
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
|
||||
description="是否可用", default=True),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ urlpatterns = [
|
|||
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve',
|
||||
views.ChatView.ChatRecord.Improve.as_view(),
|
||||
name=''),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>',
|
||||
views.ChatView.ChatRecord.ChatRecordImprove.as_view()),
|
||||
path('application/chat_message/<str:chat_id>', views.ChatView.Message.as_view())
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,11 +13,10 @@ from rest_framework.views import APIView
|
|||
|
||||
from application.serializers.chat_message_serializers import ChatMessageSerializer
|
||||
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
||||
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
|
||||
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate, \
|
||||
RoleConstants, ViewPermission, CompareConstants
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
|
||||
|
|
@ -191,6 +190,25 @@ class ChatView(APIView):
|
|||
data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
|
||||
'chat_record_id': chat_record_id}).vote())
|
||||
|
||||
class ChatRecordImprove(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取标注段落列表信息",
|
||||
operation_id="获取标注段落列表信息",
|
||||
manual_parameters=ChatRecordImproveApi.get_request_params_api(),
|
||||
responses=result.get_api_response(ChatRecordImproveApi.get_response_body_api()),
|
||||
tags=["应用/对话日志/标注"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))]
|
||||
))
|
||||
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
|
||||
return result.success(ChatRecordSerializer.ChatRecordImprove(
|
||||
data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get())
|
||||
|
||||
class Improve(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -275,7 +275,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
|
||||
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
|
||||
'document_id', 'title',
|
||||
'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
|
|
@ -290,6 +291,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
description="点赞数量", default=1),
|
||||
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
||||
description="点踩数", default=1),
|
||||
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id",
|
||||
description="数据集id", default='xxx'),
|
||||
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
||||
description="文档id", default='xxx'),
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
|
||||
|
|
|
|||
Loading…
Reference in New Issue