diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 6c4a68707..74abad178 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -27,6 +27,7 @@ from application.models.api_key_model import ApplicationAccessToken from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo +from common.constants.permission_constants import RoleConstants from common.db.search import native_search, native_page_search, page_search, get_dynamics_model from common.event import ListenerManagement from common.exception.app_exception import AppApiException @@ -281,13 +282,13 @@ class ChatRecordSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id")) - def is_valid(self, *, raise_exception=False): + def is_valid(self, *, current_role=None, raise_exception=False): super().is_valid(raise_exception=True) application_access_token = QuerySet(ApplicationAccessToken).filter( application_id=self.data.get('application_id')).first() if application_access_token is None: raise AppApiException(500, '不存在的应用认证信息') - if not application_access_token.show_source: + if not application_access_token.show_source and current_role == RoleConstants.APPLICATION_ACCESS_TOKEN.value: raise AppApiException(500, '未开启显示知识来源') def get_chat_record(self): @@ -301,9 +302,9 @@ class ChatRecordSerializer(serializers.Serializer): return chat_record_list[-1] return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() - def one(self, with_valid=True): + def one(self, current_role: RoleConstants, with_valid=True): if with_valid: - self.is_valid(raise_exception=True) + self.is_valid(current_role=current_role, raise_exception=True) chat_record = self.get_chat_record() if chat_record is None: raise AppApiException(500, "对话不存在") diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index ff8eabef4..2d6ef10f1 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -181,7 +181,7 @@ class ChatView(APIView): return result.success(ChatRecordSerializer.Operate( data={'application_id': application_id, 'chat_id': chat_id, - 'chat_record_id': chat_record_id}).one()) + 'chat_record_id': chat_record_id}).one(request.auth.current_role)) @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话记录列表", diff --git a/apps/common/auth/handle/impl/application_key.py b/apps/common/auth/handle/impl/application_key.py index cee128ef8..b35ef80fc 100644 --- a/apps/common/auth/handle/impl/application_key.py +++ b/apps/common/auth/handle/impl/application_key.py @@ -35,7 +35,9 @@ class ApplicationKey(AuthBaseHandle): permission_list=permission_list, application_id=application_api_key.application_id, client_id=str(application_api_key.id), - client_type=AuthenticationType.API_KEY.value) + client_type=AuthenticationType.API_KEY.value, + current_role=RoleConstants.APPLICATION_KEY + ) def support(self, request, token: str, get_token_details): return str(token).startswith("application-") diff --git a/apps/common/auth/handle/impl/public_access_token.py b/apps/common/auth/handle/impl/public_access_token.py index 4e882ab9d..1655187a8 100644 --- a/apps/common/auth/handle/impl/public_access_token.py +++ b/apps/common/auth/handle/impl/public_access_token.py @@ -45,5 +45,6 @@ class PublicAccessToken(AuthBaseHandle): application_access_token.application_id))], application_id=application_access_token.application_id, client_id=auth_details.get('client_id'), - client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value + client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value, + current_role=RoleConstants.APPLICATION_ACCESS_TOKEN ) diff --git a/apps/common/auth/handle/impl/user_token.py b/apps/common/auth/handle/impl/user_token.py index 5a67bea52..6559797ba 100644 --- a/apps/common/auth/handle/impl/user_token.py +++ b/apps/common/auth/handle/impl/user_token.py @@ -43,4 +43,5 @@ class UserToken(AuthBaseHandle): return user, Auth(role_list=[rule], permission_list=permission_list, client_id=str(user.id), - client_type=AuthenticationType.USER.value) + client_type=AuthenticationType.USER.value, + current_role=rule) diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 6484897c7..04f86bbc7 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -152,12 +152,13 @@ class Auth: """ def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission] - , client_id, client_type, **keywords): + , client_id, client_type, current_role: RoleConstants, **keywords): self.role_list = role_list self.permission_list = permission_list self.client_id = client_id self.client_type = client_type self.keywords = keywords + self.current_role = current_role class CompareConstants(Enum):