mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
refactor: 重构部分代码 (#864)
This commit is contained in:
parent
0e3acd33a5
commit
0d2524df1d
|
|
@ -20,7 +20,7 @@ from common.field.common import InstanceField
|
|||
from common.util.field_message import ErrMessage
|
||||
from django.core import cache
|
||||
|
||||
chat_cache = cache.caches['model_cache']
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
|
|
|
|||
|
|
@ -28,10 +28,13 @@ from application.models import Application, ApplicationDatasetMapping, Applicati
|
|||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.constants.cache_code_constants import CacheCodeConstants
|
||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
|
||||
from common.field.common import UploadedImageField
|
||||
from common.middleware.cross_domain_middleware import get_application_api_key
|
||||
from common.middleware.static_headers_middleware import get_application_access_token
|
||||
from common.models.db_model_manage import DBModelManage
|
||||
from common.util.common import valid_license
|
||||
from common.util.field_message import ErrMessage
|
||||
|
|
@ -44,8 +47,7 @@ from setting.models.model_management import Model
|
|||
from setting.serializers.provider_serializers import ModelSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
token_cache = cache.caches['token_cache']
|
||||
chat_cache = cache.caches['model_cache']
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
class ModelDatasetAssociation(serializers.Serializer):
|
||||
|
|
@ -251,6 +253,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
if 'is_active' in instance:
|
||||
application_access_token.is_active = instance.get("is_active")
|
||||
if 'access_token_reset' in instance and instance.get('access_token_reset'):
|
||||
cache.cache.delete(application_access_token.access_token,
|
||||
version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value)
|
||||
application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]
|
||||
if 'access_num' in instance and instance.get('access_num') is not None:
|
||||
application_access_token.access_num = instance.get("access_num")
|
||||
|
|
@ -261,6 +265,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
if 'show_source' in instance and instance.get('show_source') is not None:
|
||||
application_access_token.show_source = instance.get('show_source')
|
||||
application_access_token.save()
|
||||
get_application_access_token(application_access_token.access_token, False)
|
||||
return self.one(with_valid=False)
|
||||
|
||||
def one(self, with_valid=True):
|
||||
|
|
@ -526,6 +531,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
image.save()
|
||||
application.icon = f'/api/image/{image_id}'
|
||||
application.save()
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=self.data.get('application_id')).first()
|
||||
get_application_access_token(application_access_token.access_token, False)
|
||||
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
|
|
@ -686,6 +694,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
|
||||
self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id)
|
||||
chat_cache.clear_by_application_id(application_id)
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first()
|
||||
# 更新缓存数据
|
||||
get_application_access_token(application_access_token.access_token, False)
|
||||
return self.one(with_valid=False)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -767,8 +778,11 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
self.is_valid(raise_exception=True)
|
||||
api_key_id = self.data.get("api_key_id")
|
||||
application_id = self.data.get('application_id')
|
||||
QuerySet(ApplicationApiKey).filter(id=api_key_id,
|
||||
application_id=application_id).delete()
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(id=api_key_id,
|
||||
application_id=application_id).first()
|
||||
cache.cache.delete(application_api_key.secret_key,
|
||||
version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value)
|
||||
application_api_key.delete()
|
||||
|
||||
def edit(self, instance, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -787,3 +801,5 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
if 'cross_domain_list' in instance and instance.get('cross_domain_list') is not None:
|
||||
application_api_key.cross_domain_list = instance.get('cross_domain_list')
|
||||
application_api_key.save()
|
||||
# 写入缓存
|
||||
get_application_api_key(application_api_key.secret_key, False)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2023/11/14 13:51
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
|
@ -31,13 +30,11 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
|
|||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Paragraph, Document
|
||||
from setting.models import Model, Status
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
chat_cache = caches['model_cache']
|
||||
chat_cache = caches['chat_cache']
|
||||
|
||||
|
||||
class ChatInfo:
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from setting.models import Model
|
|||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
chat_cache = caches['model_cache']
|
||||
chat_cache = caches['chat_cache']
|
||||
|
||||
|
||||
class WorkFlowSerializers(serializers.Serializer):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from common.swagger_api.common_api import CommonApi
|
|||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.dataset_serializers import DataSetSerializers
|
||||
|
||||
chat_cache = cache.caches['model_cache']
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
class ApplicationStatistics(APIView):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: cache_code_constants.py
|
||||
@date:2024/7/24 18:20
|
||||
@desc:
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheCodeConstants(Enum):
|
||||
# 应用ACCESS_TOKEN缓存
|
||||
APPLICATION_ACCESS_TOKEN_CACHE = 'APPLICATION_ACCESS_TOKEN_CACHE'
|
||||
# 静态资源缓存
|
||||
STATIC_RESOURCE_CACHE = 'STATIC_RESOURCE_CACHE'
|
||||
# 应用API_KEY缓存
|
||||
APPLICATION_API_KEY_CACHE = 'APPLICATION_API_KEY_CACHE'
|
||||
|
|
@ -11,6 +11,17 @@ from django.http import HttpResponse
|
|||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
from application.models.api_key_model import ApplicationApiKey
|
||||
from common.constants.cache_code_constants import CacheCodeConstants
|
||||
from common.util.cache_util import get_cache
|
||||
|
||||
|
||||
@get_cache(cache_key=lambda secret_key, use_get_data: secret_key,
|
||||
use_get_data=lambda secret_key, use_get_data: use_get_data,
|
||||
version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value)
|
||||
def get_application_api_key(secret_key, use_get_data):
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=secret_key).first()
|
||||
return {'allow_cross_domain': application_api_key.allow_cross_domain,
|
||||
'cross_domain_list': application_api_key.cross_domain_list}
|
||||
|
||||
|
||||
class CrossDomainMiddleware(MiddlewareMixin):
|
||||
|
|
@ -27,13 +38,15 @@ class CrossDomainMiddleware(MiddlewareMixin):
|
|||
auth = request.META.get('HTTP_AUTHORIZATION')
|
||||
origin = request.META.get('HTTP_ORIGIN')
|
||||
if auth is not None and str(auth).startswith("application-") and origin is not None:
|
||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
|
||||
if application_api_key.allow_cross_domain:
|
||||
application_api_key = get_application_api_key(str(auth), True)
|
||||
cross_domain_list = application_api_key.get('cross_domain_list', [])
|
||||
allow_cross_domain = application_api_key.get('allow_cross_domain', False)
|
||||
if allow_cross_domain:
|
||||
response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT'
|
||||
response[
|
||||
'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"
|
||||
if application_api_key.cross_domain_list is None or len(application_api_key.cross_domain_list) == 0:
|
||||
if cross_domain_list is None or len(cross_domain_list) == 0:
|
||||
response['Access-Control-Allow-Origin'] = "*"
|
||||
elif application_api_key.cross_domain_list.__contains__(origin):
|
||||
elif cross_domain_list.__contains__(origin):
|
||||
response['Access-Control-Allow-Origin'] = origin
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -10,21 +10,40 @@ from django.db.models import QuerySet
|
|||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
from application.models.api_key_model import ApplicationAccessToken
|
||||
from common.constants.cache_code_constants import CacheCodeConstants
|
||||
from common.util.cache_util import get_cache
|
||||
|
||||
|
||||
@get_cache(cache_key=lambda access_token, use_get_data: access_token,
|
||||
use_get_data=lambda access_token, use_get_data: use_get_data,
|
||||
version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value)
|
||||
def get_application_access_token(access_token, use_get_data):
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
||||
if application_access_token is None:
|
||||
return None
|
||||
return {'white_active': application_access_token.white_active,
|
||||
'white_list': application_access_token.white_list,
|
||||
'application_icon': application_access_token.application.icon,
|
||||
'application_name': application_access_token.application.name}
|
||||
|
||||
|
||||
class StaticHeadersMiddleware(MiddlewareMixin):
|
||||
def process_response(self, request, response):
|
||||
if request.path.startswith('/ui/chat/'):
|
||||
access_token = request.path.replace('/ui/chat/', '')
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
||||
application_access_token = get_application_access_token(access_token, True)
|
||||
if application_access_token is not None:
|
||||
if application_access_token.white_active:
|
||||
white_active = application_access_token.get('white_active', False)
|
||||
white_list = application_access_token.get('white_list', [])
|
||||
application_icon = application_access_token.get('application_icon')
|
||||
application_name = application_access_token.get('application_name')
|
||||
if white_active:
|
||||
# 添加自定义的响应头
|
||||
response[
|
||||
'Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}'
|
||||
'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}'
|
||||
response.content = (response.content.decode('utf-8').replace(
|
||||
'<link rel="icon" href="/ui/favicon.ico" />',
|
||||
f'<link rel="icon" href="{application_access_token.application.icon}" />')
|
||||
.replace('<title>MaxKB</title>', f'<title>{application_access_token.application.name}</title>').encode(
|
||||
f'<link rel="icon" href="{application_icon}" />')
|
||||
.replace('<title>MaxKB</title>', f'<title>{application_name}</title>').encode(
|
||||
"utf-8"))
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -0,0 +1,66 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: cache_util.py
|
||||
@date:2024/7/24 19:23
|
||||
@desc:
|
||||
"""
|
||||
from django.core.cache import cache
|
||||
|
||||
|
||||
def get_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None, kwargs=None):
|
||||
"""
|
||||
获取数据, 先从缓存中获取,如果获取不到再调用get_data 获取数据
|
||||
@param kwargs: get_data所需参数
|
||||
@param key: key
|
||||
@param get_data: 获取数据函数
|
||||
@param cache_instance: cache实例
|
||||
@param version: 版本用于隔离
|
||||
@return:
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if cache_instance.has_key(key, version=version):
|
||||
return cache_instance.get(key, version=version)
|
||||
data = get_data(**kwargs)
|
||||
cache_instance.add(key, data, version=version)
|
||||
return data
|
||||
|
||||
|
||||
def set_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None):
|
||||
data = get_data()
|
||||
cache_instance.set(key, data, version=version)
|
||||
return data
|
||||
|
||||
|
||||
def get_cache(cache_key, use_get_data: any = True, cache_instance=cache, version=None):
|
||||
def inner(get_data):
|
||||
def run(*args, **kwargs):
|
||||
key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key
|
||||
is_use_get_data = use_get_data(*args, **kwargs) if callable(use_get_data) else use_get_data
|
||||
if is_use_get_data:
|
||||
if cache_instance.has_key(key, version=version):
|
||||
return cache_instance.get(key, version=version)
|
||||
data = get_data(*args, **kwargs)
|
||||
cache_instance.add(key, data, version=version)
|
||||
return data
|
||||
data = get_data(*args, **kwargs)
|
||||
cache_instance.set(key, data, version=version)
|
||||
return data
|
||||
|
||||
return run
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def del_cache(cache_key, cache_instance=cache, version=None):
|
||||
def inner(func):
|
||||
def run(*args, **kwargs):
|
||||
key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key
|
||||
func(*args, **kwargs)
|
||||
cache_instance.delete(key, version=version)
|
||||
|
||||
return run
|
||||
|
||||
return inner
|
||||
|
|
@ -92,9 +92,21 @@ SWAGGER_SETTINGS = {
|
|||
CACHES = {
|
||||
"default": {
|
||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||
'LOCATION': 'unique-snowflake',
|
||||
'TIMEOUT': 60 * 30,
|
||||
'OPTIONS': {
|
||||
'MAX_ENTRIES': 150,
|
||||
'CULL_FREQUENCY': 5,
|
||||
}
|
||||
},
|
||||
'model_cache': {
|
||||
'BACKEND': 'common.cache.mem_cache.MemCache'
|
||||
'chat_cache': {
|
||||
'BACKEND': 'common.cache.mem_cache.MemCache',
|
||||
'LOCATION': 'unique-snowflake',
|
||||
'TIMEOUT': 60 * 30,
|
||||
'OPTIONS': {
|
||||
'MAX_ENTRIES': 150,
|
||||
'CULL_FREQUENCY': 5,
|
||||
}
|
||||
},
|
||||
# 存储用户信息
|
||||
'user_cache': {
|
||||
|
|
|
|||
|
|
@ -22,8 +22,10 @@ from django.views import static
|
|||
from rest_framework import status
|
||||
|
||||
from application.urls import urlpatterns as application_urlpatterns
|
||||
from common.constants.cache_code_constants import CacheCodeConstants
|
||||
from common.init.init_doc import init_doc
|
||||
from common.response.result import Result
|
||||
from common.util.cache_util import get_cache
|
||||
from smartdoc import settings
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -51,6 +53,15 @@ if not settings.DEBUG:
|
|||
pro()
|
||||
|
||||
|
||||
@get_cache(cache_key=lambda index_path: index_path,
|
||||
version=CacheCodeConstants.STATIC_RESOURCE_CACHE.value)
|
||||
def get_index_html(index_path):
|
||||
file = open(index_path, "r", encoding='utf-8')
|
||||
content = file.read()
|
||||
file.close()
|
||||
return content
|
||||
|
||||
|
||||
def page_not_found(request, exception):
|
||||
"""
|
||||
页面不存在处理
|
||||
|
|
@ -60,9 +71,7 @@ def page_not_found(request, exception):
|
|||
index_path = os.path.join(PROJECT_DIR, 'apps', "static", 'ui', 'index.html')
|
||||
if not os.path.exists(index_path):
|
||||
return HttpResponse("页面不存在", status=404)
|
||||
file = open(index_path, "r", encoding='utf-8')
|
||||
content = file.read()
|
||||
file.close()
|
||||
content = get_index_html(index_path)
|
||||
if request.path.startswith('/ui/chat/'):
|
||||
return HttpResponse(content, status=200)
|
||||
return HttpResponse(content, status=200, headers={'X-Frame-Options': 'DENY'})
|
||||
|
|
|
|||
Loading…
Reference in New Issue