mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 客户端不使用cookie存储改为localstore,优化认证代码
This commit is contained in:
parent
21a557ef43
commit
0fbd5873f7
|
|
@ -68,6 +68,8 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
|
||||
# 是否使用流的形式输出
|
||||
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
|
||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -90,5 +92,5 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None, stream: bool = True, **kwargs):
|
||||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import traceback
|
|||
import uuid
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
|
|
@ -21,9 +22,20 @@ from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
|
|||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.response import result
|
||||
|
||||
|
||||
def add_access_num(client_id=None, client_type=None):
|
||||
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||
application_public_access_client.save()
|
||||
|
||||
|
||||
def event_content(response,
|
||||
chat_id,
|
||||
chat_record_id,
|
||||
|
|
@ -34,7 +46,8 @@ def event_content(response,
|
|||
chat_model,
|
||||
message_list: List[BaseMessage],
|
||||
problem_text: str,
|
||||
padding_problem_text: str = None):
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None):
|
||||
all_text = ''
|
||||
try:
|
||||
for chunk in response:
|
||||
|
|
@ -57,6 +70,7 @@ def event_content(response,
|
|||
all_text, manage, step, padding_problem_text)
|
||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': '', 'is_end': True}) + "\n\n"
|
||||
add_access_num(client_id, client_type)
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
|
|
@ -73,15 +87,16 @@ class BaseChatStep(IChatStep):
|
|||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
stream: bool = True,
|
||||
client_id=None, client_type=None,
|
||||
**kwargs):
|
||||
if stream:
|
||||
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text)
|
||||
manage, padding_problem_text, client_id, client_type)
|
||||
else:
|
||||
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text)
|
||||
manage, padding_problem_text, client_id, client_type)
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
return {
|
||||
|
|
@ -111,7 +126,8 @@ class BaseChatStep(IChatStep):
|
|||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None):
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None):
|
||||
# 调用模型
|
||||
if chat_model is None:
|
||||
chat_result = iter(
|
||||
|
|
@ -123,7 +139,7 @@ class BaseChatStep(IChatStep):
|
|||
r = StreamingHttpResponse(
|
||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||
padding_problem_text),
|
||||
padding_problem_text, client_id, client_type),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
|
|
@ -136,7 +152,8 @@ class BaseChatStep(IChatStep):
|
|||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None):
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None):
|
||||
# 调用模型
|
||||
if chat_model is None:
|
||||
chat_result = AIMessage(
|
||||
|
|
@ -156,5 +173,6 @@ class BaseChatStep(IChatStep):
|
|||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
chat_result.content, manage, self, padding_problem_text)
|
||||
add_access_num(client_id, client_type)
|
||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': chat_result.content, 'is_end': True})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
# Generated by Django 4.1.10 on 2024-03-14 05:03
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0008_applicationaccesstoken_access_num_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='ApplicationPublicAccessClient',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')),
|
||||
('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')),
|
||||
('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')),
|
||||
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'application_public_access_client',
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@ -42,3 +42,13 @@ class ApplicationAccessToken(AppModelMixin):
|
|||
|
||||
class Meta:
|
||||
db_table = "application_access_token"
|
||||
|
||||
|
||||
class ApplicationPublicAccessClient(AppModelMixin):
|
||||
id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id")
|
||||
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
|
||||
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
|
||||
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
|
||||
|
||||
class Meta:
|
||||
db_table = "application_public_access_client"
|
||||
|
|
|
|||
|
|
@ -28,10 +28,8 @@ from common.constants.authentication_type import AuthenticationType
|
|||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404
|
||||
from common.util.common import getRestSeconds, set_embed_identity_cookie
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.rsa_util import encrypt
|
||||
from dataset.models import DataSet, Document
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
from setting.models import AuthOperate
|
||||
|
|
@ -39,7 +37,6 @@ from setting.models.model_management import Model
|
|||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from setting.serializers.provider_serializers import ModelSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
|
||||
token_cache = cache.caches['token_cache']
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
|
@ -114,7 +111,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议"))
|
||||
token = serializers.CharField(required=True, error_messages=ErrMessage.char("token"))
|
||||
|
||||
def get_embed(self, request, with_valid=True):
|
||||
def get_embed(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js')
|
||||
|
|
@ -136,7 +133,6 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
application_access_token.white_list),
|
||||
'white_active': 'true' if application_access_token.white_active else 'false'}))
|
||||
response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'})
|
||||
set_embed_identity_cookie(request, response)
|
||||
return response
|
||||
|
||||
class AccessTokenSerializer(serializers.Serializer):
|
||||
|
|
@ -197,17 +193,27 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
class Authentication(serializers.Serializer):
|
||||
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
|
||||
|
||||
def auth(self, with_valid=True):
|
||||
def auth(self, request, with_valid=True):
|
||||
token = request.META.get('HTTP_AUTHORIZATION', None)
|
||||
token_details = None
|
||||
try:
|
||||
# 校验token
|
||||
if token is not None:
|
||||
token_details = signing.loads(token)
|
||||
except Exception as e:
|
||||
token = None
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
access_token = self.data.get("access_token")
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
||||
if application_access_token is not None and application_access_token.is_active:
|
||||
token = signing.dumps({'application_id': str(application_access_token.application_id),
|
||||
'user_id': str(application_access_token.application.user.id),
|
||||
'access_token': application_access_token.access_token,
|
||||
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value})
|
||||
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'])
|
||||
if token is None or (token_details is not None and 'client_id' not in token_details):
|
||||
client_id = str(uuid.uuid1())
|
||||
token = signing.dumps({'application_id': str(application_access_token.application_id),
|
||||
'user_id': str(application_access_token.application.user.id),
|
||||
'access_token': application_access_token.access_token,
|
||||
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value,
|
||||
'client_id': client_id})
|
||||
return token
|
||||
else:
|
||||
raise NotFound404(404, "无效的access_token")
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
|
|||
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
||||
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
||||
from common.exception.app_exception import AppApiException
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.split_model import flat_map
|
||||
|
|
@ -32,7 +34,6 @@ from setting.models import Model
|
|||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
chat_cache = caches['model_cache']
|
||||
chat_embed_identity_cache = caches['chat_cache']
|
||||
|
||||
|
||||
class ChatInfo:
|
||||
|
|
@ -75,15 +76,16 @@ class ChatInfo:
|
|||
'chat_model': self.chat_model,
|
||||
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||
'problem_optimization': self.application.problem_optimization,
|
||||
'stream': True
|
||||
'stream': True,
|
||||
|
||||
}
|
||||
|
||||
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
||||
exclude_paragraph_id_list, stream=True):
|
||||
exclude_paragraph_id_list, client_id: str, client_type, stream=True):
|
||||
params = self.to_base_pipeline_manage_params()
|
||||
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
||||
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}
|
||||
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
|
||||
'client_type': client_type}
|
||||
|
||||
def append_chat_record(self, chat_record: ChatRecord):
|
||||
# 存入缓存中
|
||||
|
|
@ -127,9 +129,37 @@ def get_post_handler(chat_info: ChatInfo):
|
|||
|
||||
|
||||
class ChatMessageSerializer(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
|
||||
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
|
||||
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
|
||||
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
|
||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
def chat(self, message, re_chat: bool, stream: bool):
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
|
||||
if access_client is None:
|
||||
access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'),
|
||||
application_id=self.data.get('application_id'),
|
||||
access_num=0,
|
||||
intraday_access_num=0)
|
||||
access_client.save()
|
||||
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=self.data.get('application_id')).first()
|
||||
if application_access_token.access_num <= access_client.intraday_access_num:
|
||||
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
|
||||
|
||||
def chat(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
message = self.data.get('message')
|
||||
re_chat = self.data.get('re_chat')
|
||||
stream = self.data.get('stream')
|
||||
client_id = self.data.get('client_id')
|
||||
client_type = self.data.get('client_type')
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
|
|
@ -156,7 +186,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||
# 构建运行参数
|
||||
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
|
||||
stream)
|
||||
client_id, client_type, stream)
|
||||
# 运行流水线作业
|
||||
pipline_message.run(params)
|
||||
return pipline_message.context['chat_result']
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: application_api.py
|
||||
@file: application_key.py
|
||||
@date:2023/11/7 10:50
|
||||
@desc:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class Application(APIView):
|
|||
def get(self, request: Request):
|
||||
return ApplicationSerializer.Embed(
|
||||
data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'),
|
||||
'host': request.query_params.get('host'), }).get_embed(request)
|
||||
'host': request.query_params.get('host'), }).get_embed()
|
||||
|
||||
class Model(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
@ -192,7 +192,8 @@ class Application(APIView):
|
|||
security=[])
|
||||
def post(self, request: Request):
|
||||
return result.success(
|
||||
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(),
|
||||
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(
|
||||
request),
|
||||
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "POST",
|
||||
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from application.serializers.chat_message_serializers import ChatMessageSerializ
|
|||
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
||||
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import Permission, Group, Operate, \
|
||||
RoleConstants, ViewPermission, CompareConstants
|
||||
from common.response import result
|
||||
|
|
@ -71,11 +72,15 @@ class ChatView(APIView):
|
|||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def post(self, request: Request, chat_id: str):
|
||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'),
|
||||
request.data.get(
|
||||
're_chat') if 're_chat' in request.data else False,
|
||||
request.data.get(
|
||||
'stream') if 'stream' in request.data else True)
|
||||
return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'),
|
||||
're_chat': (request.data.get(
|
||||
're_chat') if 're_chat' in request.data else False),
|
||||
'stream': (request.data.get(
|
||||
'stream') if 'stream' in request.data else True),
|
||||
'application_id': (request.auth.keywords.get(
|
||||
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
|
||||
'client_id': request.auth.client_id,
|
||||
'client_type': request.auth.client_type}).chat()
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ from django.db.models import QuerySet
|
|||
from rest_framework.authentication import TokenAuthentication
|
||||
|
||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.auth.handle.impl.application_key import ApplicationKey
|
||||
from common.auth.handle.impl.public_access_token import PublicAccessToken
|
||||
from common.auth.handle.impl.user_token import UserToken
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
|
||||
Operate
|
||||
|
|
@ -29,6 +32,25 @@ class AnonymousAuthentication(TokenAuthentication):
|
|||
return None, None
|
||||
|
||||
|
||||
handles = [UserToken(), PublicAccessToken(), ApplicationKey()]
|
||||
|
||||
|
||||
class TokenDetails:
|
||||
token_details = None
|
||||
is_load = False
|
||||
|
||||
def __init__(self, token: str):
|
||||
self.token = token
|
||||
|
||||
def get_token_details(self):
|
||||
if self.token_details is None and not self.is_load:
|
||||
try:
|
||||
self.token_details = signing.loads(self.token)
|
||||
except Exception as e:
|
||||
self.is_load = True
|
||||
return self.token_details
|
||||
|
||||
|
||||
class TokenAuth(TokenAuthentication):
|
||||
# 重新 authenticate 方法,自定义认证规则
|
||||
def authenticate(self, request):
|
||||
|
|
@ -38,62 +60,11 @@ class TokenAuth(TokenAuthentication):
|
|||
if auth is None:
|
||||
raise AppAuthenticationFailed(1003, '未登录,请先登录')
|
||||
try:
|
||||
if str(auth).startswith("application-"):
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
|
||||
if application_api_key is None:
|
||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||
if not application_api_key.is_active:
|
||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||
permission_list = [Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id)),
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.MANAGE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id))
|
||||
]
|
||||
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
|
||||
permission_list=permission_list,
|
||||
application_id=application_api_key.application_id)
|
||||
# 解析 token
|
||||
auth_details = signing.loads(auth)
|
||||
cache_token = token_cache.get(auth)
|
||||
if cache_token is None:
|
||||
raise AppAuthenticationFailed(1002, "登录过期")
|
||||
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
|
||||
user = QuerySet(User).get(id=auth_details['id'])
|
||||
# 续期
|
||||
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
|
||||
rule = RoleConstants[user.role]
|
||||
permission_list = get_permission_list_by_role(RoleConstants[user.role])
|
||||
# 获取用户的应用和知识库的权限
|
||||
permission_list += get_user_dynamics_permission(str(user.id))
|
||||
return user, Auth(role_list=[rule],
|
||||
permission_list=permission_list)
|
||||
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
|
||||
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=auth_details.get('application_id')).first()
|
||||
if application_access_token is None:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.is_active:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.access_token == auth_details.get('access_token'):
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
return application_access_token.application.user, Auth(
|
||||
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
permission_list=[
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_access_token.application_id))],
|
||||
application_id=application_access_token.application_id
|
||||
)
|
||||
|
||||
else:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
|
||||
|
||||
token_details = TokenDetails(auth)
|
||||
for handle in handles:
|
||||
if handle.support(request, auth, token_details.get_token_details):
|
||||
return handle.handle(request, auth, token_details.get_token_details)
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
|
||||
except Exception as e:
|
||||
traceback.format_exc()
|
||||
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
@file: authenticate.py
|
||||
@date:2024/3/14 03:02
|
||||
@desc: 认证处理器
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AuthBaseHandle(ABC):
|
||||
@abstractmethod
|
||||
def support(self, request, token: str, get_token_details):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def handle(self, request, token: str, get_token_details):
|
||||
pass
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
@file: authenticate.py
|
||||
@date:2024/3/14 03:02
|
||||
@desc: 应用api key认证
|
||||
"""
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.models.api_key_model import ApplicationApiKey
|
||||
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import Permission, Group, Operate, RoleConstants, Auth
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
|
||||
|
||||
class ApplicationKey(AuthBaseHandle):
|
||||
def handle(self, request, token: str, get_token_details):
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=token).first()
|
||||
if application_api_key is None:
|
||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||
if not application_api_key.is_active:
|
||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||
permission_list = [Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id)),
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.MANAGE,
|
||||
dynamic_tag=str(
|
||||
application_api_key.application_id))
|
||||
]
|
||||
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
|
||||
permission_list=permission_list,
|
||||
application_id=application_api_key.application_id,
|
||||
client_id=token,
|
||||
client_type=AuthenticationType.API_KEY.value)
|
||||
|
||||
def support(self, request, token: str, get_token_details):
|
||||
return str(token).startswith("application-")
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
@file: authenticate.py
|
||||
@date:2024/3/14 03:02
|
||||
@desc: 公共访问连接认证
|
||||
"""
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.models.api_key_model import ApplicationAccessToken
|
||||
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
|
||||
|
||||
class PublicAccessToken(AuthBaseHandle):
|
||||
def support(self, request, token: str, get_token_details):
|
||||
token_details = get_token_details()
|
||||
if token_details is None:
|
||||
return False
|
||||
return (
|
||||
'application_id' in token_details and
|
||||
'access_token' in token_details and
|
||||
token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value)
|
||||
|
||||
def handle(self, request, token: str, get_token_details):
|
||||
auth_details = get_token_details()
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=auth_details.get('application_id')).first()
|
||||
if application_access_token is None:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.is_active:
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
if not application_access_token.access_token == auth_details.get('access_token'):
|
||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||
|
||||
return application_access_token.application.user, Auth(
|
||||
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
permission_list=[
|
||||
Permission(group=Group.APPLICATION,
|
||||
operate=Operate.USE,
|
||||
dynamic_tag=str(
|
||||
application_access_token.application_id))],
|
||||
application_id=application_access_token.application_id,
|
||||
client_id=auth_details.get('client_id'),
|
||||
client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value
|
||||
)
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
@file: authenticate.py
|
||||
@date:2024/3/14 03:02
|
||||
@desc: 用户认证
|
||||
"""
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from smartdoc.settings import JWT_AUTH
|
||||
from users.models import User
|
||||
from django.core import cache
|
||||
|
||||
from users.models.user import get_user_dynamics_permission
|
||||
|
||||
token_cache = cache.caches['token_cache']
|
||||
|
||||
|
||||
class UserToken(AuthBaseHandle):
|
||||
def support(self, request, token: str, get_token_details):
|
||||
auth_details = get_token_details()
|
||||
if auth_details is None:
|
||||
return False
|
||||
return 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value
|
||||
|
||||
def handle(self, request, token: str, get_token_details):
|
||||
cache_token = token_cache.get(token)
|
||||
if cache_token is None:
|
||||
raise AppAuthenticationFailed(1002, "登录过期")
|
||||
auth_details = get_token_details()
|
||||
user = QuerySet(User).get(id=auth_details['id'])
|
||||
# 续期
|
||||
token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
|
||||
rule = RoleConstants[user.role]
|
||||
permission_list = get_permission_list_by_role(RoleConstants[user.role])
|
||||
# 获取用户的应用和知识库的权限
|
||||
permission_list += get_user_dynamics_permission(str(user.id))
|
||||
return user, Auth(role_list=[rule],
|
||||
permission_list=permission_list,
|
||||
client_id=str(user.id),
|
||||
client_type=AuthenticationType.USER.value)
|
||||
|
|
@ -10,7 +10,9 @@ from enum import Enum
|
|||
|
||||
|
||||
class AuthenticationType(Enum):
|
||||
# 或者
|
||||
# 普通用户
|
||||
USER = "USER"
|
||||
# 并且
|
||||
# 公共访问链接
|
||||
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
|
||||
# key API
|
||||
API_KEY = "API_KEY"
|
||||
|
|
|
|||
|
|
@ -151,10 +151,12 @@ class Auth:
|
|||
用于存储当前用户的角色和权限
|
||||
"""
|
||||
|
||||
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission],
|
||||
**keywords):
|
||||
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission]
|
||||
, client_id, client_type, **keywords):
|
||||
self.role_list = role_list
|
||||
self.permission_list = permission_list
|
||||
self.client_id = client_id
|
||||
self.client_type = client_type
|
||||
self.keywords = keywords
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,66 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: chat_cookie_middleware.py
|
||||
@date:2024/3/13 20:13
|
||||
@desc:
|
||||
"""
|
||||
from django.core import cache
|
||||
from django.core import signing
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
from application.models.api_key_model import ApplicationAccessToken
|
||||
from common.exception.app_exception import AppEmbedIdentityFailed
|
||||
from common.response import result
|
||||
from common.util.common import set_embed_identity_cookie, getRestSeconds
|
||||
from common.util.rsa_util import decrypt
|
||||
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
class ChatCookieMiddleware(MiddlewareMixin):
|
||||
|
||||
def process_response(self, request, response):
|
||||
if request.path.startswith('/api/application/chat_message') or request.path.startswith(
|
||||
'/api/application/authentication') or request.path.startswith('/api/application/profile'):
|
||||
set_embed_identity_cookie(request, response)
|
||||
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
|
||||
embed_identity = request.COOKIES['embed_identity']
|
||||
try:
|
||||
# 如果无法解密 说明embed_identity并非系统颁发
|
||||
value = decrypt(embed_identity)
|
||||
except Exception as e:
|
||||
raise AppEmbedIdentityFailed(1004, '嵌入cookie不正确')
|
||||
# 对话次数+1
|
||||
try:
|
||||
if not chat_cache.incr(value):
|
||||
# 如果修改失败则设置为1
|
||||
chat_cache.set(value, 1,
|
||||
timeout=getRestSeconds())
|
||||
except Exception as e:
|
||||
# 如果修改失败则设置为1 证明 key不存在
|
||||
chat_cache.set(value, 1,
|
||||
timeout=getRestSeconds())
|
||||
return response
|
||||
|
||||
def process_request(self, request):
|
||||
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
|
||||
auth = request.META.get('HTTP_AUTHORIZATION', None
|
||||
)
|
||||
auth_details = signing.loads(auth)
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=auth_details.get('application_id')).first()
|
||||
embed_identity = request.COOKIES['embed_identity']
|
||||
try:
|
||||
# 如果无法解密 说明embed_identity并非系统颁发
|
||||
value = decrypt(embed_identity)
|
||||
except Exception as e:
|
||||
return result.Result(1003,
|
||||
message='访问次数超过今日访问量', response_status=460)
|
||||
embed_identity_number = chat_cache.get(value)
|
||||
if embed_identity_number is not None:
|
||||
if application_access_token.access_num <= embed_identity_number:
|
||||
return result.Result(1003,
|
||||
message='访问次数超过今日访问量', response_status=461)
|
||||
|
|
@ -6,38 +6,10 @@
|
|||
@date:2023/10/16 16:42
|
||||
@desc:
|
||||
"""
|
||||
import datetime
|
||||
import importlib
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from django.core import cache
|
||||
|
||||
from .rsa_util import encrypt
|
||||
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
def set_embed_identity_cookie(request, response):
|
||||
if 'embed_identity' in request.COOKIES:
|
||||
embed_identity = request.COOKIES['embed_identity']
|
||||
else:
|
||||
value = str(uuid.uuid1())
|
||||
embed_identity = encrypt(value)
|
||||
chat_cache.set(value, 0, timeout=getRestSeconds())
|
||||
response.set_cookie("embed_identity", embed_identity, max_age=3600 * 24 * 100, samesite='None',
|
||||
secure=True)
|
||||
return response
|
||||
|
||||
|
||||
def getRestSeconds():
|
||||
now = datetime.datetime.now()
|
||||
today_begin = datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
|
||||
tomorrow_begin = today_begin + datetime.timedelta(days=1)
|
||||
rest_seconds = (tomorrow_begin - now).seconds
|
||||
return rest_seconds
|
||||
|
||||
|
||||
def sub_array(array: List, item_num=10):
|
||||
result = []
|
||||
|
|
|
|||
|
|
@ -46,8 +46,7 @@ MIDDLEWARE = [
|
|||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'common.middleware.static_headers_middleware.StaticHeadersMiddleware',
|
||||
'common.middleware.chat_cookie_middleware.ChatCookieMiddleware'
|
||||
'common.middleware.static_headers_middleware.StaticHeadersMiddleware'
|
||||
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from common.db.sql_execute import select_list
|
|||
from common.util.file_util import get_file_content
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
__all__ = ["User", "password_encrypt"]
|
||||
__all__ = ["User", "password_encrypt", 'get_user_dynamics_permission']
|
||||
|
||||
|
||||
def password_encrypt(raw_password):
|
||||
|
|
|
|||
Loading…
Reference in New Issue