MaxKB/apps/application/views/chat_views.py
2023-12-06 16:29:14 +08:00

219 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file: chat_views.py
@date: 2023/11/14 9:53
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from application.serializers.chat_message_serializers import ChatMessageSerializer
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.response import result
from common.util.common import query_params_to_single_dict
class ChatView(APIView):
    """Chat endpoints of an application.

    ``get`` on this class lists the chats of an application. The nested views
    open chats, exchange messages, delete and page chats, and expose the
    per-chat record endpoints (``ChatRecord`` and its sub-views).
    """
    authentication_classes = [TokenAuth]

    class Open(APIView):
        """Open a chat for an application and return its id."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取会话id,根据应用id",
                             operation_id="获取会话id,根据应用id",
                             manual_parameters=ChatApi.OpenChat.get_request_params_api(),
                             tags=["应用/会话"])
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
                            RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))],
                           compare=CompareConstants.AND)
        )
        def get(self, request: Request, application_id: str):
            """Open a chat of ``application_id`` for the requesting user.

            :param request: the incoming request; ``request.user.id`` is recorded as the chat owner.
            :param application_id: id of the application to chat with.
            :return: ``result.success`` wrapping ``ChatSerializers.OpenChat(...).open()``.
            """
            return result.success(ChatSerializers.OpenChat(
                data={'user_id': request.user.id, 'application_id': application_id}).open())

    class OpenTemp(APIView):
        """Open a temporary chat configured entirely by the request body.

        Per the API summary, the body carries a model id, a dataset list and a
        multi-round flag.
        """
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="获取会话id(根据模型id,数据集列表,是否多轮会话)",
                             operation_id="获取会话id",
                             request_body=ChatApi.OpenTempChat.get_request_body_api(),
                             tags=["应用/会话"])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def post(self, request: Request):
            """Open a temporary chat from ``request.data``.

            :param request: the incoming request. Its body fields are forwarded to
                the serializer; ``user_id`` is spread last so the body cannot override it.
            :return: ``result.success`` wrapping ``ChatSerializers.OpenTempChat(...).open()``.
            """
            return result.success(ChatSerializers.OpenTempChat(
                data={**request.data, 'user_id': request.user.id}).open())

    class Message(APIView):
        """Send a message to an open chat."""
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="对话",
                             operation_id="对话",
                             request_body=ChatApi.get_request_body_api(),
                             tags=["应用/会话"])
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
                            RoleConstants.APPLICATION_ACCESS_TOKEN],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def post(self, request: Request, chat_id: str):
            """Send ``request.data['message']`` to chat ``chat_id``.

            :param request: the incoming request; only its ``message`` field is used.
            :param chat_id: id of the chat to post to.
            :return: whatever ``ChatMessageSerializer.chat`` returns, passed through
                unwrapped (no ``result.success``). Presumably this is because it may
                be a streamed response — TODO confirm.
            """
            # NOTE(review): this handler only receives ``chat_id``, so
            # ``keywords.get('application_id')`` in the permission above is
            # presumably None here — confirm how the dynamic tag is resolved.
            return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取对话列表",
                         operation_id="获取对话列表",
                         manual_parameters=ChatApi.get_request_params_api(),
                         tags=["应用/对话日志"]
                         )
    @has_permissions(
        ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                       [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                       dynamic_tag=keywords.get('application_id'))])
    )
    def get(self, request: Request, application_id: str):
        """List the chats of ``application_id``, filtered by the query parameters.

        :param request: the incoming request; its query parameters become serializer fields.
        :param application_id: id of the application whose chats are listed.
        :return: ``result.success`` wrapping ``ChatSerializers.Query(...).list()``.
        """
        # Path/user values are spread after the query parameters so a client
        # cannot override them via the query string.
        return result.success(ChatSerializers.Query(
            data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
                  'user_id': request.user.id}).list())

    # This nested view shares its name with the imported ``Operate`` permission
    # enum. The permission lambdas resolve ``Operate`` at module scope (class
    # bodies are not enclosing scopes for functions), so they still see the enum.
    # Avoid referencing ``Operate`` directly in this class body below this point.
    class Operate(APIView):
        """Delete a single chat of an application."""
        authentication_classes = [TokenAuth]

        @action(methods=['DELETE'], detail=False)
        @swagger_auto_schema(operation_summary="删除对话",
                             operation_id="删除对话",
                             tags=["应用/对话日志"])
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE,
                                                           dynamic_tag=keywords.get('application_id'))],
                           compare=CompareConstants.AND),
            lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE,
                                    dynamic_tag=k.get('application_id')),
            compare=CompareConstants.AND)
        def delete(self, request: Request, application_id: str, chat_id: str):
            """Delete chat ``chat_id`` of application ``application_id``.

            :return: ``result.success`` wrapping ``ChatSerializers.Operate(...).delete()``.
            """
            return result.success(
                ChatSerializers.Operate(
                    data={'application_id': application_id, 'user_id': request.user.id,
                          'chat_id': chat_id}).delete())

    class Page(APIView):
        """Page through the chats of an application."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="分页获取对话列表",
                             operation_id="分页获取对话列表",
                             manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()),
                             tags=["应用/对话日志"]
                             )
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def get(self, request: Request, application_id: str, current_page: int, page_size: int):
            """Return one page of the chats of ``application_id``.

            :param current_page: page number, passed straight to ``Query.page``.
            :param page_size: page size, passed straight to ``Query.page``.
            :return: ``result.success`` wrapping ``ChatSerializers.Query(...).page(...)``.
            """
            # Path/user values are spread last so the query string cannot override them.
            return result.success(ChatSerializers.Query(
                data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
                      'user_id': request.user.id}).page(current_page=current_page,
                                                       page_size=page_size))

    class ChatRecord(APIView):
        """Records of a single chat, with paging, voting and improvement sub-views."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取对话记录列表",
                             operation_id="获取对话记录列表",
                             manual_parameters=ChatRecordApi.get_request_params_api(),
                             tags=["应用/对话日志"]
                             )
        @has_permissions(
            ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                           [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                           dynamic_tag=keywords.get('application_id'))])
        )
        def get(self, request: Request, application_id: str, chat_id: str):
            """List all records of chat ``chat_id`` in application ``application_id``.

            :return: ``result.success`` wrapping ``ChatRecordSerializer.Query(...).list()``.
            """
            return result.success(ChatRecordSerializer.Query(
                data={'application_id': application_id,
                      'chat_id': chat_id}).list())

        class Page(APIView):
            """Page through the records of a single chat."""
            # Every sibling view authenticates with TokenAuth. Without this line
            # the view would fall back to DRF's DEFAULT_AUTHENTICATION_CLASSES.
            authentication_classes = [TokenAuth]

            @action(methods=['GET'], detail=False)
            # operation_id must be unique across the API; ChatRecord.get already
            # uses "获取对话记录列表".
            @swagger_auto_schema(operation_summary="分页获取对话记录列表",
                                 operation_id="分页获取对话记录列表",
                                 manual_parameters=result.get_page_request_params(
                                     ChatRecordApi.get_request_params_api()),
                                 tags=["应用/对话日志"]
                                 )
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))])
            )
            def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int):
                """Return one page of the records of chat ``chat_id``.

                :param current_page: page number, passed positionally to ``Query.page``.
                :param page_size: page size, passed positionally to ``Query.page``.
                :return: ``result.success`` wrapping ``ChatRecordSerializer.Query(...).page(...)``.
                """
                return result.success(ChatRecordSerializer.Query(
                    data={'application_id': application_id,
                          'chat_id': chat_id}).page(current_page, page_size))

        class Vote(APIView):
            """Up-vote or down-vote a single chat record."""
            authentication_classes = [TokenAuth]

            @action(methods=['PUT'], detail=False)
            @swagger_auto_schema(operation_summary="点赞,点踩",
                                 operation_id="点赞,点踩",
                                 manual_parameters=VoteApi.get_request_params_api(),
                                 request_body=VoteApi.get_request_body_api(),
                                 responses=result.get_default_response(),
                                 tags=["应用/会话"]
                                 )
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
                                RoleConstants.APPLICATION_ACCESS_TOKEN],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))])
            )
            def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
                """Apply ``request.data['vote_status']`` to record ``chat_record_id`` of chat ``chat_id``.

                :return: ``result.success`` wrapping ``ChatRecordSerializer.Vote(...).vote()``.
                """
                return result.success(ChatRecordSerializer.Vote(
                    data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id,
                          'chat_record_id': chat_record_id}).vote())

        class Improve(APIView):
            """Annotate a chat record into a document of a dataset."""
            authentication_classes = [TokenAuth]

            @action(methods=['PUT'], detail=False)
            @swagger_auto_schema(operation_summary="标注",
                                 operation_id="标注",
                                 manual_parameters=ImproveApi.get_request_params_api(),
                                 request_body=ImproveApi.get_request_body_api(),
                                 responses=result.get_default_response(),
                                 tags=["应用/对话日志/标注"]
                                 )
            # The application check (USE) and the dataset check (MANAGE) are two
            # sibling ViewPermissions combined with AND. Writing into a dataset
            # therefore also requires rights on that dataset. Previously the
            # dataset ViewPermission was nested as the third positional argument
            # of the application one (where ``compare`` goes), so it was never
            # enforced as its own requirement.
            @has_permissions(
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
                               [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                               dynamic_tag=keywords.get('application_id'))]),
                ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
                               [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                                               dynamic_tag=keywords.get('dataset_id'))],
                               compare=CompareConstants.AND),
                compare=CompareConstants.AND)
            def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str,
                    document_id: str):
                """Improve record ``chat_record_id`` of chat ``chat_id`` into ``dataset_id``/``document_id``.

                :param request: the incoming request; its body is passed to ``improve``.
                :return: ``result.success`` wrapping ``ChatRecordSerializer.Improve(...).improve(request.data)``.
                """
                return result.success(ChatRecordSerializer.Improve(
                    data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
                          'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data))