From 0fbd5873f7e0424aae75d8178fb10661b6462c2e Mon Sep 17 00:00:00 2001 From: zhangshaohu Date: Thu, 14 Mar 2024 05:43:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=A2=E6=88=B7=E7=AB=AF=E4=B8=8D?= =?UTF-8?q?=E4=BD=BF=E7=94=A8cookie=E5=AD=98=E5=82=A8=E6=94=B9=E4=B8=BAloc?= =?UTF-8?q?alstore,=E4=BC=98=E5=8C=96=E8=AE=A4=E8=AF=81=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/i_chat_step.py | 4 +- .../step/chat_step/impl/base_chat_step.py | 30 +++++-- .../0009_applicationpublicaccessclient.py | 28 +++++++ apps/application/models/api_key_model.py | 10 +++ .../serializers/application_serializers.py | 28 ++++--- .../serializers/chat_message_serializers.py | 46 ++++++++-- .../swagger_api/application_api.py | 2 +- apps/application/views/application_views.py | 5 +- apps/application/views/chat_views.py | 15 ++-- apps/common/auth/authenticate.py | 83 ++++++------------- apps/common/auth/handle/auth_base_handle.py | 19 +++++ .../auth/handle/impl/application_key.py | 41 +++++++++ .../auth/handle/impl/public_access_token.py | 49 +++++++++++ apps/common/auth/handle/impl/user_token.py | 46 ++++++++++ apps/common/constants/authentication_type.py | 6 +- apps/common/constants/permission_constants.py | 6 +- .../middleware/chat_cookie_middleware.py | 66 --------------- apps/common/util/common.py | 28 ------- apps/smartdoc/settings/base.py | 3 +- apps/users/models/user.py | 2 +- 20 files changed, 326 insertions(+), 191 deletions(-) create mode 100644 apps/application/migrations/0009_applicationpublicaccessclient.py create mode 100644 apps/common/auth/handle/auth_base_handle.py create mode 100644 apps/common/auth/handle/impl/application_key.py create mode 100644 apps/common/auth/handle/impl/public_access_token.py create mode 100644 apps/common/auth/handle/impl/user_token.py delete mode 100644 apps/common/middleware/chat_cookie_middleware.py diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 39c6380c5..f18fbc211 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -68,6 +68,8 @@ class IChatStep(IBaseChatPipelineStep): padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题")) # 是否使用流的形式输出 stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出")) + client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) + client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -90,5 +92,5 @@ class IChatStep(IBaseChatPipelineStep): chat_model: BaseChatModel = None, paragraph_list=None, manage: PiplineManage = None, - padding_problem_text: str = None, stream: bool = True, **kwargs): + padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index ab631114c..442f359d0 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -13,6 +13,7 @@ import traceback import uuid from typing import List +from django.db.models import QuerySet from django.http import StreamingHttpResponse from langchain.chat_models.base import BaseChatModel from langchain.schema import BaseMessage @@ -21,9 +22,20 @@ from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler +from application.models.api_key_model import ApplicationPublicAccessClient +from common.constants.authentication_type import AuthenticationType from common.response import result +def add_access_num(client_id=None, client_type=None): + if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: + application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first() + if application_public_access_client is not None: + application_public_access_client.access_num = application_public_access_client.access_num + 1 + application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 + application_public_access_client.save() + + def event_content(response, chat_id, chat_record_id, @@ -34,7 +46,8 @@ def event_content(response, chat_model, message_list: List[BaseMessage], problem_text: str, - padding_problem_text: str = None): + padding_problem_text: str = None, + client_id=None, client_type=None): all_text = '' try: for chunk in response: @@ -57,6 +70,7 @@ def event_content(response, all_text, manage, step, padding_problem_text) yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, 'content': '', 'is_end': True}) + "\n\n" + add_access_num(client_id, client_type) except Exception as e: logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, @@ -73,15 +87,16 @@ class BaseChatStep(IChatStep): manage: PiplineManage = None, padding_problem_text: str = None, stream: bool = True, + client_id=None, client_type=None, **kwargs): if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text) + manage, padding_problem_text, client_id, client_type) else: return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text) + manage, padding_problem_text, client_id, client_type) def get_details(self, manage, **kwargs): return { @@ -111,7 +126,8 @@ class BaseChatStep(IChatStep): chat_model: BaseChatModel = None, paragraph_list=None, manage: PiplineManage = None, - padding_problem_text: str = None): + padding_problem_text: str = None, + client_id=None, client_type=None): # 调用模型 if chat_model is None: chat_result = iter( @@ -123,7 +139,7 @@ class BaseChatStep(IChatStep): r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, post_response_handler, manage, self, chat_model, message_list, problem_text, - padding_problem_text), + padding_problem_text, client_id, client_type), content_type='text/event-stream;charset=utf-8') r['Cache-Control'] = 'no-cache' @@ -136,7 +152,8 @@ class BaseChatStep(IChatStep): chat_model: BaseChatModel = None, paragraph_list=None, manage: PiplineManage = None, - padding_problem_text: str = None): + padding_problem_text: str = None, + client_id=None, client_type=None): # 调用模型 if chat_model is None: chat_result = AIMessage( @@ -156,5 +173,6 @@ class BaseChatStep(IChatStep): manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, chat_result.content, manage, self, padding_problem_text) + add_access_num(client_id, client_type) return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, 'content': chat_result.content, 'is_end': True}) diff --git a/apps/application/migrations/0009_applicationpublicaccessclient.py b/apps/application/migrations/0009_applicationpublicaccessclient.py new file mode 100644 index 000000000..7c9553303 --- /dev/null +++ b/apps/application/migrations/0009_applicationpublicaccessclient.py @@ -0,0 +1,28 @@ +# Generated by Django 4.1.10 on 2024-03-14 05:03 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0008_applicationaccesstoken_access_num_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='ApplicationPublicAccessClient', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')), + ('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')), + ('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')), + ], + options={ + 'db_table': 'application_public_access_client', + }, + ), + ] diff --git a/apps/application/models/api_key_model.py b/apps/application/models/api_key_model.py index 84f3dfd08..1090d2e57 100644 --- a/apps/application/models/api_key_model.py +++ b/apps/application/models/api_key_model.py @@ -42,3 +42,13 @@ class ApplicationAccessToken(AppModelMixin): class Meta: db_table = "application_access_token" + + +class ApplicationPublicAccessClient(AppModelMixin): + id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id") + application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") + access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") + intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") + + class Meta: + db_table = "application_public_access_client" diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 153153ded..3b9fa4626 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -28,10 +28,8 @@ from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404 -from common.util.common import getRestSeconds, set_embed_identity_cookie from common.util.field_message import ErrMessage from common.util.file_util import get_file_content -from common.util.rsa_util import encrypt from dataset.models import DataSet, Document from dataset.serializers.common_serializers import list_paragraph from setting.models import AuthOperate @@ -39,7 +37,6 @@ from setting.models.model_management import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR -from smartdoc.settings import JWT_AUTH token_cache = cache.caches['token_cache'] chat_cache = cache.caches['chat_cache'] @@ -114,7 +111,7 @@ class ApplicationSerializer(serializers.Serializer): protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议")) token = serializers.CharField(required=True, error_messages=ErrMessage.char("token")) - def get_embed(self, request, with_valid=True): + def get_embed(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js') @@ -136,7 +133,6 @@ class ApplicationSerializer(serializers.Serializer): application_access_token.white_list), 'white_active': 'true' if application_access_token.white_active else 'false'})) response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'}) - set_embed_identity_cookie(request, response) return response class AccessTokenSerializer(serializers.Serializer): @@ -197,17 +193,27 @@ class ApplicationSerializer(serializers.Serializer): class Authentication(serializers.Serializer): access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token")) - def auth(self, with_valid=True): + def auth(self, request, with_valid=True): + token = request.META.get('HTTP_AUTHORIZATION', None) + token_details = None + try: + # 校验token + if token is not None: + token_details = signing.loads(token) + except Exception as e: + token = None if with_valid: self.is_valid(raise_exception=True) access_token = self.data.get("access_token") application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() if application_access_token is not None and application_access_token.is_active: - token = signing.dumps({'application_id': str(application_access_token.application_id), - 'user_id': str(application_access_token.application.user.id), - 'access_token': application_access_token.access_token, - 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value}) - token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) + if token is None or (token_details is not None and 'client_id' not in token_details): + client_id = str(uuid.uuid1()) + token = signing.dumps({'application_id': str(application_access_token.application_id), + 'user_id': str(application_access_token.application.user.id), + 'access_token': application_access_token.access_token, + 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value, + 'client_id': client_id}) return token else: raise NotFound404(404, "无效的access_token") diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 0b042ab47..56847ed47 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -23,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping -from common.exception.app_exception import AppApiException +from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken +from common.constants.authentication_type import AuthenticationType +from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage from common.util.rsa_util import decrypt from common.util.split_model import flat_map @@ -32,7 +34,6 @@ from setting.models import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants chat_cache = caches['model_cache'] -chat_embed_identity_cache = caches['chat_cache'] class ChatInfo: @@ -75,15 +76,16 @@ class ChatInfo: 'chat_model': self.chat_model, 'model_id': self.application.model.id if self.application.model is not None else None, 'problem_optimization': self.application.problem_optimization, - 'stream': True + 'stream': True, } def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, - exclude_paragraph_id_list, stream=True): + exclude_paragraph_id_list, client_id: str, client_type, stream=True): params = self.to_base_pipeline_manage_params() return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler, - 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream} + 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id, + 'client_type': client_type} def append_chat_record(self, chat_record: ChatRecord): # 存入缓存中 @@ -127,9 +129,37 @@ def get_post_handler(chat_info: ChatInfo): class ChatMessageSerializer(serializers.Serializer): - chat_id = serializers.UUIDField(required=True) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) + message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) + stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) + re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) + application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) + client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) + client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) - def chat(self, message, re_chat: bool, stream: bool): + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: + access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first() + if access_client is None: + access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'), + application_id=self.data.get('application_id'), + access_num=0, + intraday_access_num=0) + access_client.save() + + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + if application_access_token.access_num <= access_client.intraday_access_num: + raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量") + + def chat(self): + self.is_valid(raise_exception=True) + message = self.data.get('message') + re_chat = self.data.get('re_chat') + stream = self.data.get('stream') + client_id = self.data.get('client_id') + client_type = self.data.get('client_type') self.is_valid(raise_exception=True) chat_id = self.data.get('chat_id') chat_info: ChatInfo = chat_cache.get(chat_id) @@ -156,7 +186,7 @@ class ChatMessageSerializer(serializers.Serializer): exclude_paragraph_id_list = list(set(paragraph_id_list)) # 构建运行参数 params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list, - stream) + client_id, client_type, stream) # 运行流水线作业 pipline_message.run(params) return pipline_message.context['chat_result'] diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index c18544a7c..0bfe77fb1 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -2,7 +2,7 @@ """ @project: maxkb @Author:虎 - @file: application_api.py + @file: application_key.py @date:2023/11/7 10:50 @desc: """ diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index cf4885ea1..d121611b3 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -37,7 +37,7 @@ class Application(APIView): def get(self, request: Request): return ApplicationSerializer.Embed( data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'), - 'host': request.query_params.get('host'), }).get_embed(request) + 'host': request.query_params.get('host'), }).get_embed() class Model(APIView): authentication_classes = [TokenAuth] @@ -192,7 +192,8 @@ class Application(APIView): security=[]) def post(self, request: Request): return result.success( - ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(), + ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth( + request), headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Methods": "POST", "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"} diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 7f7bed080..5f7102246 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -15,6 +15,7 @@ from application.serializers.chat_message_serializers import ChatMessageSerializ from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi from common.auth import TokenAuth, has_permissions +from common.constants.authentication_type import AuthenticationType from common.constants.permission_constants import Permission, Group, Operate, \ RoleConstants, ViewPermission, CompareConstants from common.response import result @@ -71,11 +72,15 @@ class ChatView(APIView): dynamic_tag=keywords.get('application_id'))]) ) def post(self, request: Request, chat_id: str): - return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), - request.data.get( - 're_chat') if 're_chat' in request.data else False, - request.data.get( - 'stream') if 'stream' in request.data else True) + return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'), + 're_chat': (request.data.get( + 're_chat') if 're_chat' in request.data else False), + 'stream': (request.data.get( + 'stream') if 'stream' in request.data else True), + 'application_id': (request.auth.keywords.get( + 'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None), + 'client_id': request.auth.client_id, + 'client_type': request.auth.client_type}).chat() @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话列表", diff --git a/apps/common/auth/authenticate.py b/apps/common/auth/authenticate.py index 88d724459..bdd73874a 100644 --- a/apps/common/auth/authenticate.py +++ b/apps/common/auth/authenticate.py @@ -14,6 +14,9 @@ from django.db.models import QuerySet from rest_framework.authentication import TokenAuthentication from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey +from common.auth.handle.impl.application_key import ApplicationKey +from common.auth.handle.impl.public_access_token import PublicAccessToken +from common.auth.handle.impl.user_token import UserToken from common.constants.authentication_type import AuthenticationType from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \ Operate @@ -29,6 +32,25 @@ class AnonymousAuthentication(TokenAuthentication): return None, None +handles = [UserToken(), PublicAccessToken(), ApplicationKey()] + + +class TokenDetails: + token_details = None + is_load = False + + def __init__(self, token: str): + self.token = token + + def get_token_details(self): + if self.token_details is None and not self.is_load: + try: + self.token_details = signing.loads(self.token) + except Exception as e: + self.is_load = True + return self.token_details + + class TokenAuth(TokenAuthentication): # 重新 authenticate 方法,自定义认证规则 def authenticate(self, request): @@ -38,62 +60,11 @@ class TokenAuth(TokenAuthentication): if auth is None: raise AppAuthenticationFailed(1003, '未登录,请先登录') try: - if str(auth).startswith("application-"): - application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() - if application_api_key is None: - raise AppAuthenticationFailed(500, "secret_key 无效") - if not application_api_key.is_active: - raise AppAuthenticationFailed(500, "secret_key 无效") - permission_list = [Permission(group=Group.APPLICATION, - operate=Operate.USE, - dynamic_tag=str( - application_api_key.application_id)), - Permission(group=Group.APPLICATION, - operate=Operate.MANAGE, - dynamic_tag=str( - application_api_key.application_id)) - ] - return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY], - permission_list=permission_list, - application_id=application_api_key.application_id) - # 解析 token - auth_details = signing.loads(auth) - cache_token = token_cache.get(auth) - if cache_token is None: - raise AppAuthenticationFailed(1002, "登录过期") - if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value: - user = QuerySet(User).get(id=auth_details['id']) - # 续期 - token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) - rule = RoleConstants[user.role] - permission_list = get_permission_list_by_role(RoleConstants[user.role]) - # 获取用户的应用和知识库的权限 - permission_list += get_user_dynamics_permission(str(user.id)) - return user, Auth(role_list=[rule], - permission_list=permission_list) - if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get( - 'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - application_access_token = QuerySet(ApplicationAccessToken).filter( - application_id=auth_details.get('application_id')).first() - if application_access_token is None: - raise AppAuthenticationFailed(1002, "身份验证信息不正确") - if not application_access_token.is_active: - raise AppAuthenticationFailed(1002, "身份验证信息不正确") - if not application_access_token.access_token == auth_details.get('access_token'): - raise AppAuthenticationFailed(1002, "身份验证信息不正确") - return application_access_token.application.user, Auth( - role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN], - permission_list=[ - Permission(group=Group.APPLICATION, - operate=Operate.USE, - dynamic_tag=str( - application_access_token.application_id))], - application_id=application_access_token.application_id - ) - - else: - raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户") - + token_details = TokenDetails(auth) + for handle in handles: + if handle.support(request, auth, token_details.get_token_details): + return handle.handle(request, auth, token_details.get_token_details) + raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户") except Exception as e: traceback.format_exc() if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed): diff --git a/apps/common/auth/handle/auth_base_handle.py b/apps/common/auth/handle/auth_base_handle.py new file mode 100644 index 000000000..991256e91 --- /dev/null +++ b/apps/common/auth/handle/auth_base_handle.py @@ -0,0 +1,19 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: authenticate.py + @date:2024/3/14 03:02 + @desc: 认证处理器 +""" +from abc import ABC, abstractmethod + + +class AuthBaseHandle(ABC): + @abstractmethod + def support(self, request, token: str, get_token_details): + pass + + @abstractmethod + def handle(self, request, token: str, get_token_details): + pass diff --git a/apps/common/auth/handle/impl/application_key.py b/apps/common/auth/handle/impl/application_key.py new file mode 100644 index 000000000..5ebd9db28 --- /dev/null +++ b/apps/common/auth/handle/impl/application_key.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: authenticate.py + @date:2024/3/14 03:02 + @desc: 应用api key认证 +""" +from django.db.models import QuerySet + +from application.models.api_key_model import ApplicationApiKey +from common.auth.handle.auth_base_handle import AuthBaseHandle +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import Permission, Group, Operate, RoleConstants, Auth +from common.exception.app_exception import AppAuthenticationFailed + + +class ApplicationKey(AuthBaseHandle): + def handle(self, request, token: str, get_token_details): + application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=token).first() + if application_api_key is None: + raise AppAuthenticationFailed(500, "secret_key 无效") + if not application_api_key.is_active: + raise AppAuthenticationFailed(500, "secret_key 无效") + permission_list = [Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=str( + application_api_key.application_id)), + Permission(group=Group.APPLICATION, + operate=Operate.MANAGE, + dynamic_tag=str( + application_api_key.application_id)) + ] + return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY], + permission_list=permission_list, + application_id=application_api_key.application_id, + client_id=token, + client_type=AuthenticationType.API_KEY.value) + + def support(self, request, token: str, get_token_details): + return str(token).startswith("application-") diff --git a/apps/common/auth/handle/impl/public_access_token.py b/apps/common/auth/handle/impl/public_access_token.py new file mode 100644 index 000000000..4e882ab9d --- /dev/null +++ b/apps/common/auth/handle/impl/public_access_token.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: authenticate.py + @date:2024/3/14 03:02 + @desc: 公共访问连接认证 +""" +from django.db.models import QuerySet + +from application.models.api_key_model import ApplicationAccessToken +from common.auth.handle.auth_base_handle import AuthBaseHandle +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth +from common.exception.app_exception import AppAuthenticationFailed + + +class PublicAccessToken(AuthBaseHandle): + def support(self, request, token: str, get_token_details): + token_details = get_token_details() + if token_details is None: + return False + return ( + 'application_id' in token_details and + 'access_token' in token_details and + token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value) + + def handle(self, request, token: str, get_token_details): + auth_details = get_token_details() + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=auth_details.get('application_id')).first() + if application_access_token is None: + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + if not application_access_token.is_active: + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + if not application_access_token.access_token == auth_details.get('access_token'): + raise AppAuthenticationFailed(1002, "身份验证信息不正确") + + return application_access_token.application.user, Auth( + role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN], + permission_list=[ + Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=str( + application_access_token.application_id))], + application_id=application_access_token.application_id, + client_id=auth_details.get('client_id'), + client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value + ) diff --git a/apps/common/auth/handle/impl/user_token.py b/apps/common/auth/handle/impl/user_token.py new file mode 100644 index 000000000..5a67bea52 --- /dev/null +++ b/apps/common/auth/handle/impl/user_token.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: authenticate.py + @date:2024/3/14 03:02 + @desc: 用户认证 +""" +from django.db.models import QuerySet + +from common.auth.handle.auth_base_handle import AuthBaseHandle +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth +from common.exception.app_exception import AppAuthenticationFailed +from smartdoc.settings import JWT_AUTH +from users.models import User +from django.core import cache + +from users.models.user import get_user_dynamics_permission + +token_cache = cache.caches['token_cache'] + + +class UserToken(AuthBaseHandle): + def support(self, request, token: str, get_token_details): + auth_details = get_token_details() + if auth_details is None: + return False + return 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value + + def handle(self, request, token: str, get_token_details): + cache_token = token_cache.get(token) + if cache_token is None: + raise AppAuthenticationFailed(1002, "登录过期") + auth_details = get_token_details() + user = QuerySet(User).get(id=auth_details['id']) + # 续期 + token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) + rule = RoleConstants[user.role] + permission_list = get_permission_list_by_role(RoleConstants[user.role]) + # 获取用户的应用和知识库的权限 + permission_list += get_user_dynamics_permission(str(user.id)) + return user, Auth(role_list=[rule], + permission_list=permission_list, + client_id=str(user.id), + client_type=AuthenticationType.USER.value) diff --git a/apps/common/constants/authentication_type.py b/apps/common/constants/authentication_type.py index f223d2a72..33163003e 100644 --- a/apps/common/constants/authentication_type.py +++ b/apps/common/constants/authentication_type.py @@ -10,7 +10,9 @@ from enum import Enum class AuthenticationType(Enum): - # 或者 + # 普通用户 USER = "USER" - # 并且 + # 公共访问链接 APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN" + # key API + API_KEY = "API_KEY" diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 6f3f5c0f3..2d08e0c40 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -151,10 +151,12 @@ class Auth: 用于存储当前用户的角色和权限 """ - def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission], - **keywords): + def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission] + , client_id, client_type, **keywords): self.role_list = role_list self.permission_list = permission_list + self.client_id = client_id + self.client_type = client_type self.keywords = keywords diff --git a/apps/common/middleware/chat_cookie_middleware.py b/apps/common/middleware/chat_cookie_middleware.py deleted file mode 100644 index 60a26fe5f..000000000 --- a/apps/common/middleware/chat_cookie_middleware.py +++ /dev/null @@ -1,66 +0,0 @@ -# coding=utf-8 -""" - @project: maxkb - @Author:虎 - @file: chat_cookie_middleware.py - @date:2024/3/13 20:13 - @desc: -""" -from django.core import cache -from django.core import signing -from django.db.models import QuerySet -from django.utils.deprecation import MiddlewareMixin - -from application.models.api_key_model import ApplicationAccessToken -from common.exception.app_exception import AppEmbedIdentityFailed -from common.response import result -from common.util.common import set_embed_identity_cookie, getRestSeconds -from common.util.rsa_util import decrypt - -chat_cache = cache.caches['chat_cache'] - - -class ChatCookieMiddleware(MiddlewareMixin): - - def process_response(self, request, response): - if request.path.startswith('/api/application/chat_message') or request.path.startswith( - '/api/application/authentication') or request.path.startswith('/api/application/profile'): - set_embed_identity_cookie(request, response) - if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'): - embed_identity = request.COOKIES['embed_identity'] - try: - # 如果无法解密 说明embed_identity并非系统颁发 - value = decrypt(embed_identity) - except Exception as e: - raise AppEmbedIdentityFailed(1004, '嵌入cookie不正确') - # 对话次数+1 - try: - if not chat_cache.incr(value): - # 如果修改失败则设置为1 - chat_cache.set(value, 1, - timeout=getRestSeconds()) - except Exception as e: - # 如果修改失败则设置为1 证明 key不存在 - chat_cache.set(value, 1, - timeout=getRestSeconds()) - return response - - def process_request(self, request): - if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'): - auth = request.META.get('HTTP_AUTHORIZATION', None - ) - auth_details = signing.loads(auth) - application_access_token = QuerySet(ApplicationAccessToken).filter( - application_id=auth_details.get('application_id')).first() - embed_identity = request.COOKIES['embed_identity'] - try: - # 如果无法解密 说明embed_identity并非系统颁发 - value = decrypt(embed_identity) - except Exception as e: - return result.Result(1003, - message='访问次数超过今日访问量', response_status=460) - embed_identity_number = chat_cache.get(value) - if embed_identity_number is not None: - if application_access_token.access_num <= embed_identity_number: - return result.Result(1003, - message='访问次数超过今日访问量', response_status=461) diff --git a/apps/common/util/common.py b/apps/common/util/common.py index f2eb0120c..52d90ec85 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -6,38 +6,10 @@ @date:2023/10/16 16:42 @desc: """ -import datetime import importlib -import uuid from functools import reduce from typing import Dict, List -from django.core import cache - -from .rsa_util import encrypt - -chat_cache = cache.caches['chat_cache'] - - -def set_embed_identity_cookie(request, response): - if 'embed_identity' in request.COOKIES: - embed_identity = request.COOKIES['embed_identity'] - else: - value = str(uuid.uuid1()) - embed_identity = encrypt(value) - chat_cache.set(value, 0, timeout=getRestSeconds()) - response.set_cookie("embed_identity", embed_identity, max_age=3600 * 24 * 100, samesite='None', - secure=True) - return response - - -def getRestSeconds(): - now = datetime.datetime.now() - today_begin = datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - tomorrow_begin = today_begin + datetime.timedelta(days=1) - rest_seconds = (tomorrow_begin - now).seconds - return rest_seconds - def sub_array(array: List, item_num=10): result = [] diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index 69e9a2bd3..4b22d0702 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -46,8 +46,7 @@ MIDDLEWARE = [ 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', - 'common.middleware.static_headers_middleware.StaticHeadersMiddleware', - 'common.middleware.chat_cookie_middleware.ChatCookieMiddleware' + 'common.middleware.static_headers_middleware.StaticHeadersMiddleware' ] diff --git a/apps/users/models/user.py b/apps/users/models/user.py index d7cfc1f4c..3eb015a37 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -17,7 +17,7 @@ from common.db.sql_execute import select_list from common.util.file_util import get_file_content from smartdoc.conf import PROJECT_DIR -__all__ = ["User", "password_encrypt"] +__all__ = ["User", "password_encrypt", 'get_user_dynamics_permission'] def password_encrypt(raw_password):