feat: 客户端不使用cookie存储改为localstore,优化认证代码

This commit is contained in:
zhangshaohu 2024-03-14 05:43:01 +08:00
parent 21a557ef43
commit 0fbd5873f7
20 changed files with 326 additions and 191 deletions

View File

@ -68,6 +68,8 @@ class IChatStep(IBaseChatPipelineStep):
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -90,5 +92,5 @@ class IChatStep(IBaseChatPipelineStep):
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None, stream: bool = True, **kwargs):
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
pass

View File

@ -13,6 +13,7 @@ import traceback
import uuid
from typing import List
from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
@ -21,9 +22,20 @@ from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.response import result
def add_access_num(client_id=None, client_type=None):
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
def event_content(response,
chat_id,
chat_record_id,
@ -34,7 +46,8 @@ def event_content(response,
chat_model,
message_list: List[BaseMessage],
problem_text: str,
padding_problem_text: str = None):
padding_problem_text: str = None,
client_id=None, client_type=None):
all_text = ''
try:
for chunk in response:
@ -57,6 +70,7 @@ def event_content(response,
all_text, manage, step, padding_problem_text)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
add_access_num(client_id, client_type)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
@ -73,15 +87,16 @@ class BaseChatStep(IChatStep):
manage: PiplineManage = None,
padding_problem_text: str = None,
stream: bool = True,
client_id=None, client_type=None,
**kwargs):
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text)
manage, padding_problem_text, client_id, client_type)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text)
manage, padding_problem_text, client_id, client_type)
def get_details(self, manage, **kwargs):
return {
@ -111,7 +126,8 @@ class BaseChatStep(IChatStep):
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None):
padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型
if chat_model is None:
chat_result = iter(
@ -123,7 +139,7 @@ class BaseChatStep(IChatStep):
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text),
padding_problem_text, client_id, client_type),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
@ -136,7 +152,8 @@ class BaseChatStep(IChatStep):
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None):
padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型
if chat_model is None:
chat_result = AIMessage(
@ -156,5 +173,6 @@ class BaseChatStep(IChatStep):
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
chat_result.content, manage, self, padding_problem_text)
add_access_num(client_id, client_type)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chat_result.content, 'is_end': True})

View File

@ -0,0 +1,28 @@
# Generated by Django 4.1.10 on 2024-03-14 05:03
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('application', '0008_applicationaccesstoken_access_num_and_more'),
]
operations = [
migrations.CreateModel(
name='ApplicationPublicAccessClient',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')),
('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')),
('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
],
options={
'db_table': 'application_public_access_client',
},
),
]

View File

@ -42,3 +42,13 @@ class ApplicationAccessToken(AppModelMixin):
class Meta:
db_table = "application_access_token"
class ApplicationPublicAccessClient(AppModelMixin):
id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id")
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
class Meta:
db_table = "application_public_access_client"

View File

@ -28,10 +28,8 @@ from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.util.common import getRestSeconds, set_embed_identity_cookie
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.rsa_util import encrypt
from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph
from setting.models import AuthOperate
@ -39,7 +37,6 @@ from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from smartdoc.settings import JWT_AUTH
token_cache = cache.caches['token_cache']
chat_cache = cache.caches['chat_cache']
@ -114,7 +111,7 @@ class ApplicationSerializer(serializers.Serializer):
protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议"))
token = serializers.CharField(required=True, error_messages=ErrMessage.char("token"))
def get_embed(self, request, with_valid=True):
def get_embed(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js')
@ -136,7 +133,6 @@ class ApplicationSerializer(serializers.Serializer):
application_access_token.white_list),
'white_active': 'true' if application_access_token.white_active else 'false'}))
response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'})
set_embed_identity_cookie(request, response)
return response
class AccessTokenSerializer(serializers.Serializer):
@ -197,17 +193,27 @@ class ApplicationSerializer(serializers.Serializer):
class Authentication(serializers.Serializer):
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
def auth(self, with_valid=True):
def auth(self, request, with_valid=True):
token = request.META.get('HTTP_AUTHORIZATION', None)
token_details = None
try:
# 校验token
if token is not None:
token_details = signing.loads(token)
except Exception as e:
token = None
if with_valid:
self.is_valid(raise_exception=True)
access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
if application_access_token is not None and application_access_token.is_active:
token = signing.dumps({'application_id': str(application_access_token.application_id),
'user_id': str(application_access_token.application.user.id),
'access_token': application_access_token.access_token,
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value})
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'])
if token is None or (token_details is not None and 'client_id' not in token_details):
client_id = str(uuid.uuid1())
token = signing.dumps({'application_id': str(application_access_token.application_id),
'user_id': str(application_access_token.application.user.id),
'access_token': application_access_token.access_token,
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value,
'client_id': client_id})
return token
else:
raise NotFound404(404, "无效的access_token")

View File

@ -23,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
from common.exception.app_exception import AppApiException
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage
from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
@ -32,7 +34,6 @@ from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = caches['model_cache']
chat_embed_identity_cache = caches['chat_cache']
class ChatInfo:
@ -75,15 +76,16 @@ class ChatInfo:
'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization,
'stream': True
'stream': True,
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
exclude_paragraph_id_list, stream=True):
exclude_paragraph_id_list, client_id: str, client_type, stream=True):
params = self.to_base_pipeline_manage_params()
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
'client_type': client_type}
def append_chat_record(self, chat_record: ChatRecord):
# 存入缓存中
@ -127,9 +129,37 @@ def get_post_handler(chat_info: ChatInfo):
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
def chat(self, message, re_chat: bool, stream: bool):
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
if access_client is None:
access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'),
application_id=self.data.get('application_id'),
access_num=0,
intraday_access_num=0)
access_client.save()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token.access_num <= access_client.intraday_access_num:
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
def chat(self):
self.is_valid(raise_exception=True)
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
@ -156,7 +186,7 @@ class ChatMessageSerializer(serializers.Serializer):
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
stream)
client_id, client_type, stream)
# 运行流水线作业
pipline_message.run(params)
return pipline_message.context['chat_result']

View File

@ -2,7 +2,7 @@
"""
@project: maxkb
@Author
@file application_api.py
@file application_key.py
@date2023/11/7 10:50
@desc:
"""

View File

@ -37,7 +37,7 @@ class Application(APIView):
def get(self, request: Request):
return ApplicationSerializer.Embed(
data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'),
'host': request.query_params.get('host'), }).get_embed(request)
'host': request.query_params.get('host'), }).get_embed()
class Model(APIView):
authentication_classes = [TokenAuth]
@ -192,7 +192,8 @@ class Application(APIView):
security=[])
def post(self, request: Request):
return result.success(
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(),
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(
request),
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}

View File

@ -15,6 +15,7 @@ from application.serializers.chat_message_serializers import ChatMessageSerializ
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
from common.auth import TokenAuth, has_permissions
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
from common.response import result
@ -71,11 +72,15 @@ class ChatView(APIView):
dynamic_tag=keywords.get('application_id'))])
)
def post(self, request: Request, chat_id: str):
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'),
request.data.get(
're_chat') if 're_chat' in request.data else False,
request.data.get(
'stream') if 'stream' in request.data else True)
return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'),
're_chat': (request.data.get(
're_chat') if 're_chat' in request.data else False),
'stream': (request.data.get(
'stream') if 'stream' in request.data else True),
'application_id': (request.auth.keywords.get(
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
'client_id': request.auth.client_id,
'client_type': request.auth.client_type}).chat()
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表",

View File

@ -14,6 +14,9 @@ from django.db.models import QuerySet
from rest_framework.authentication import TokenAuthentication
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.auth.handle.impl.application_key import ApplicationKey
from common.auth.handle.impl.public_access_token import PublicAccessToken
from common.auth.handle.impl.user_token import UserToken
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
Operate
@ -29,6 +32,25 @@ class AnonymousAuthentication(TokenAuthentication):
return None, None
handles = [UserToken(), PublicAccessToken(), ApplicationKey()]
class TokenDetails:
token_details = None
is_load = False
def __init__(self, token: str):
self.token = token
def get_token_details(self):
if self.token_details is None and not self.is_load:
try:
self.token_details = signing.loads(self.token)
except Exception as e:
self.is_load = True
return self.token_details
class TokenAuth(TokenAuthentication):
# 重新 authenticate 方法,自定义认证规则
def authenticate(self, request):
@ -38,62 +60,11 @@ class TokenAuth(TokenAuthentication):
if auth is None:
raise AppAuthenticationFailed(1003, '未登录,请先登录')
try:
if str(auth).startswith("application-"):
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
if application_api_key is None:
raise AppAuthenticationFailed(500, "secret_key 无效")
if not application_api_key.is_active:
raise AppAuthenticationFailed(500, "secret_key 无效")
permission_list = [Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_api_key.application_id)),
Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=str(
application_api_key.application_id))
]
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
permission_list=permission_list,
application_id=application_api_key.application_id)
# 解析 token
auth_details = signing.loads(auth)
cache_token = token_cache.get(auth)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
user = QuerySet(User).get(id=auth_details['id'])
# 续期
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
rule = RoleConstants[user.role]
permission_list = get_permission_list_by_role(RoleConstants[user.role])
# 获取用户的应用和知识库的权限
permission_list += get_user_dynamics_permission(str(user.id))
return user, Auth(role_list=[rule],
permission_list=permission_list)
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
if application_access_token is None:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.is_active:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.access_token == auth_details.get('access_token'):
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
return application_access_token.application.user, Auth(
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_access_token.application_id))],
application_id=application_access_token.application_id
)
else:
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
token_details = TokenDetails(auth)
for handle in handles:
if handle.support(request, auth, token_details.get_token_details):
return handle.handle(request, auth, token_details.get_token_details)
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
except Exception as e:
traceback.format_exc()
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 认证处理器
"""
from abc import ABC, abstractmethod
class AuthBaseHandle(ABC):
@abstractmethod
def support(self, request, token: str, get_token_details):
pass
@abstractmethod
def handle(self, request, token: str, get_token_details):
pass

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 应用api key认证
"""
from django.db.models import QuerySet
from application.models.api_key_model import ApplicationApiKey
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Permission, Group, Operate, RoleConstants, Auth
from common.exception.app_exception import AppAuthenticationFailed
class ApplicationKey(AuthBaseHandle):
def handle(self, request, token: str, get_token_details):
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=token).first()
if application_api_key is None:
raise AppAuthenticationFailed(500, "secret_key 无效")
if not application_api_key.is_active:
raise AppAuthenticationFailed(500, "secret_key 无效")
permission_list = [Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_api_key.application_id)),
Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=str(
application_api_key.application_id))
]
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
permission_list=permission_list,
application_id=application_api_key.application_id,
client_id=token,
client_type=AuthenticationType.API_KEY.value)
def support(self, request, token: str, get_token_details):
return str(token).startswith("application-")

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 公共访问连接认证
"""
from django.db.models import QuerySet
from application.models.api_key_model import ApplicationAccessToken
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth
from common.exception.app_exception import AppAuthenticationFailed
class PublicAccessToken(AuthBaseHandle):
def support(self, request, token: str, get_token_details):
token_details = get_token_details()
if token_details is None:
return False
return (
'application_id' in token_details and
'access_token' in token_details and
token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value)
def handle(self, request, token: str, get_token_details):
auth_details = get_token_details()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
if application_access_token is None:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.is_active:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.access_token == auth_details.get('access_token'):
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
return application_access_token.application.user, Auth(
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_access_token.application_id))],
application_id=application_access_token.application_id,
client_id=auth_details.get('client_id'),
client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value
)

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 用户认证
"""
from django.db.models import QuerySet
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth
from common.exception.app_exception import AppAuthenticationFailed
from smartdoc.settings import JWT_AUTH
from users.models import User
from django.core import cache
from users.models.user import get_user_dynamics_permission
token_cache = cache.caches['token_cache']
class UserToken(AuthBaseHandle):
def support(self, request, token: str, get_token_details):
auth_details = get_token_details()
if auth_details is None:
return False
return 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value
def handle(self, request, token: str, get_token_details):
cache_token = token_cache.get(token)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
auth_details = get_token_details()
user = QuerySet(User).get(id=auth_details['id'])
# 续期
token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
rule = RoleConstants[user.role]
permission_list = get_permission_list_by_role(RoleConstants[user.role])
# 获取用户的应用和知识库的权限
permission_list += get_user_dynamics_permission(str(user.id))
return user, Auth(role_list=[rule],
permission_list=permission_list,
client_id=str(user.id),
client_type=AuthenticationType.USER.value)

View File

@ -10,7 +10,9 @@ from enum import Enum
class AuthenticationType(Enum):
# 或者
# 普通用户
USER = "USER"
# 并且
# 公共访问链接
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
# key API
API_KEY = "API_KEY"

View File

@ -151,10 +151,12 @@ class Auth:
用于存储当前用户的角色和权限
"""
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission],
**keywords):
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission]
, client_id, client_type, **keywords):
self.role_list = role_list
self.permission_list = permission_list
self.client_id = client_id
self.client_type = client_type
self.keywords = keywords

View File

@ -1,66 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_cookie_middleware.py
@date2024/3/13 20:13
@desc:
"""
from django.core import cache
from django.core import signing
from django.db.models import QuerySet
from django.utils.deprecation import MiddlewareMixin
from application.models.api_key_model import ApplicationAccessToken
from common.exception.app_exception import AppEmbedIdentityFailed
from common.response import result
from common.util.common import set_embed_identity_cookie, getRestSeconds
from common.util.rsa_util import decrypt
chat_cache = cache.caches['chat_cache']
class ChatCookieMiddleware(MiddlewareMixin):
def process_response(self, request, response):
if request.path.startswith('/api/application/chat_message') or request.path.startswith(
'/api/application/authentication') or request.path.startswith('/api/application/profile'):
set_embed_identity_cookie(request, response)
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
embed_identity = request.COOKIES['embed_identity']
try:
# 如果无法解密 说明embed_identity并非系统颁发
value = decrypt(embed_identity)
except Exception as e:
raise AppEmbedIdentityFailed(1004, '嵌入cookie不正确')
# 对话次数+1
try:
if not chat_cache.incr(value):
# 如果修改失败则设置为1
chat_cache.set(value, 1,
timeout=getRestSeconds())
except Exception as e:
# 如果修改失败则设置为1 证明 key不存在
chat_cache.set(value, 1,
timeout=getRestSeconds())
return response
def process_request(self, request):
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
auth = request.META.get('HTTP_AUTHORIZATION', None
)
auth_details = signing.loads(auth)
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
embed_identity = request.COOKIES['embed_identity']
try:
# 如果无法解密 说明embed_identity并非系统颁发
value = decrypt(embed_identity)
except Exception as e:
return result.Result(1003,
message='访问次数超过今日访问量', response_status=460)
embed_identity_number = chat_cache.get(value)
if embed_identity_number is not None:
if application_access_token.access_num <= embed_identity_number:
return result.Result(1003,
message='访问次数超过今日访问量', response_status=461)

View File

@ -6,38 +6,10 @@
@date2023/10/16 16:42
@desc:
"""
import datetime
import importlib
import uuid
from functools import reduce
from typing import Dict, List
from django.core import cache
from .rsa_util import encrypt
chat_cache = cache.caches['chat_cache']
def set_embed_identity_cookie(request, response):
if 'embed_identity' in request.COOKIES:
embed_identity = request.COOKIES['embed_identity']
else:
value = str(uuid.uuid1())
embed_identity = encrypt(value)
chat_cache.set(value, 0, timeout=getRestSeconds())
response.set_cookie("embed_identity", embed_identity, max_age=3600 * 24 * 100, samesite='None',
secure=True)
return response
def getRestSeconds():
now = datetime.datetime.now()
today_begin = datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
tomorrow_begin = today_begin + datetime.timedelta(days=1)
rest_seconds = (tomorrow_begin - now).seconds
return rest_seconds
def sub_array(array: List, item_num=10):
result = []

View File

@ -46,8 +46,7 @@ MIDDLEWARE = [
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'common.middleware.static_headers_middleware.StaticHeadersMiddleware',
'common.middleware.chat_cookie_middleware.ChatCookieMiddleware'
'common.middleware.static_headers_middleware.StaticHeadersMiddleware'
]

View File

@ -17,7 +17,7 @@ from common.db.sql_execute import select_list
from common.util.file_util import get_file_content
from smartdoc.conf import PROJECT_DIR
__all__ = ["User", "password_encrypt"]
__all__ = ["User", "password_encrypt", 'get_user_dynamics_permission']
def password_encrypt(raw_password):