refactor: 重构部分代码 (#864)

This commit is contained in:
shaohuzhang1 2024-07-25 10:41:38 +08:00 committed by GitHub
parent 0e3acd33a5
commit 0d2524df1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 176 additions and 25 deletions

View File

@ -20,7 +20,7 @@ from common.field.common import InstanceField
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from django.core import cache from django.core import cache
chat_cache = cache.caches['model_cache'] chat_cache = cache.caches['chat_cache']
def write_context(step_variable: Dict, global_variable: Dict, node, workflow): def write_context(step_variable: Dict, global_variable: Dict, node, workflow):

View File

@ -28,10 +28,13 @@ from application.models import Application, ApplicationDatasetMapping, Applicati
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.constants.cache_code_constants import CacheCodeConstants
from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField from common.field.common import UploadedImageField
from common.middleware.cross_domain_middleware import get_application_api_key
from common.middleware.static_headers_middleware import get_application_access_token
from common.models.db_model_manage import DBModelManage from common.models.db_model_manage import DBModelManage
from common.util.common import valid_license from common.util.common import valid_license
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
@ -44,8 +47,7 @@ from setting.models.model_management import Model
from setting.serializers.provider_serializers import ModelSerializer from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
token_cache = cache.caches['token_cache'] chat_cache = cache.caches['chat_cache']
chat_cache = cache.caches['model_cache']
class ModelDatasetAssociation(serializers.Serializer): class ModelDatasetAssociation(serializers.Serializer):
@ -251,6 +253,8 @@ class ApplicationSerializer(serializers.Serializer):
if 'is_active' in instance: if 'is_active' in instance:
application_access_token.is_active = instance.get("is_active") application_access_token.is_active = instance.get("is_active")
if 'access_token_reset' in instance and instance.get('access_token_reset'): if 'access_token_reset' in instance and instance.get('access_token_reset'):
cache.cache.delete(application_access_token.access_token,
version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value)
application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24] application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]
if 'access_num' in instance and instance.get('access_num') is not None: if 'access_num' in instance and instance.get('access_num') is not None:
application_access_token.access_num = instance.get("access_num") application_access_token.access_num = instance.get("access_num")
@ -261,6 +265,7 @@ class ApplicationSerializer(serializers.Serializer):
if 'show_source' in instance and instance.get('show_source') is not None: if 'show_source' in instance and instance.get('show_source') is not None:
application_access_token.show_source = instance.get('show_source') application_access_token.show_source = instance.get('show_source')
application_access_token.save() application_access_token.save()
get_application_access_token(application_access_token.access_token, False)
return self.one(with_valid=False) return self.one(with_valid=False)
def one(self, with_valid=True): def one(self, with_valid=True):
@ -526,6 +531,9 @@ class ApplicationSerializer(serializers.Serializer):
image.save() image.save()
application.icon = f'/api/image/{image_id}' application.icon = f'/api/image/{image_id}'
application.save() application.save()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
get_application_access_token(application_access_token.access_token, False)
return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)}
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
@ -686,6 +694,9 @@ class ApplicationSerializer(serializers.Serializer):
self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id) self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id)
chat_cache.clear_by_application_id(application_id) chat_cache.clear_by_application_id(application_id)
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first()
# 更新缓存数据
get_application_access_token(application_access_token.access_token, False)
return self.one(with_valid=False) return self.one(with_valid=False)
@staticmethod @staticmethod
@ -767,8 +778,11 @@ class ApplicationSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
api_key_id = self.data.get("api_key_id") api_key_id = self.data.get("api_key_id")
application_id = self.data.get('application_id') application_id = self.data.get('application_id')
QuerySet(ApplicationApiKey).filter(id=api_key_id, application_api_key = QuerySet(ApplicationApiKey).filter(id=api_key_id,
application_id=application_id).delete() application_id=application_id).first()
cache.cache.delete(application_api_key.secret_key,
version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value)
application_api_key.delete()
def edit(self, instance, with_valid=True): def edit(self, instance, with_valid=True):
if with_valid: if with_valid:
@ -787,3 +801,5 @@ class ApplicationSerializer(serializers.Serializer):
if 'cross_domain_list' in instance and instance.get('cross_domain_list') is not None: if 'cross_domain_list' in instance and instance.get('cross_domain_list') is not None:
application_api_key.cross_domain_list = instance.get('cross_domain_list') application_api_key.cross_domain_list = instance.get('cross_domain_list')
application_api_key.save() application_api_key.save()
# 写入缓存
get_application_api_key(application_api_key.secret_key, False)

View File

@ -6,7 +6,6 @@
@date2023/11/14 13:51 @date2023/11/14 13:51
@desc: @desc:
""" """
import json
import uuid import uuid
from typing import List from typing import List
from uuid import UUID from uuid import UUID
@ -31,13 +30,11 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.rsa_util import rsa_long_decrypt
from common.util.split_model import flat_map from common.util.split_model import flat_map
from dataset.models import Paragraph, Document from dataset.models import Paragraph, Document
from setting.models import Model, Status from setting.models import Model, Status
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = caches['model_cache'] chat_cache = caches['chat_cache']
class ChatInfo: class ChatInfo:

View File

@ -44,7 +44,7 @@ from setting.models import Model
from setting.models_provider import get_model from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
chat_cache = caches['model_cache'] chat_cache = caches['chat_cache']
class WorkFlowSerializers(serializers.Serializer): class WorkFlowSerializers(serializers.Serializer):

View File

@ -28,7 +28,7 @@ from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers from dataset.serializers.dataset_serializers import DataSetSerializers
chat_cache = cache.caches['model_cache'] chat_cache = cache.caches['chat_cache']
class ApplicationStatistics(APIView): class ApplicationStatistics(APIView):

View File

@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file cache_code_constants.py
@date2024/7/24 18:20
@desc:
"""
from enum import Enum
class CacheCodeConstants(Enum):
# 应用ACCESS_TOKEN缓存
APPLICATION_ACCESS_TOKEN_CACHE = 'APPLICATION_ACCESS_TOKEN_CACHE'
# 静态资源缓存
STATIC_RESOURCE_CACHE = 'STATIC_RESOURCE_CACHE'
# 应用API_KEY缓存
APPLICATION_API_KEY_CACHE = 'APPLICATION_API_KEY_CACHE'

View File

@ -11,6 +11,17 @@ from django.http import HttpResponse
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
from application.models.api_key_model import ApplicationApiKey from application.models.api_key_model import ApplicationApiKey
from common.constants.cache_code_constants import CacheCodeConstants
from common.util.cache_util import get_cache
@get_cache(cache_key=lambda secret_key, use_get_data: secret_key,
use_get_data=lambda secret_key, use_get_data: use_get_data,
version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value)
def get_application_api_key(secret_key, use_get_data):
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=secret_key).first()
return {'allow_cross_domain': application_api_key.allow_cross_domain,
'cross_domain_list': application_api_key.cross_domain_list}
class CrossDomainMiddleware(MiddlewareMixin): class CrossDomainMiddleware(MiddlewareMixin):
@ -27,13 +38,15 @@ class CrossDomainMiddleware(MiddlewareMixin):
auth = request.META.get('HTTP_AUTHORIZATION') auth = request.META.get('HTTP_AUTHORIZATION')
origin = request.META.get('HTTP_ORIGIN') origin = request.META.get('HTTP_ORIGIN')
if auth is not None and str(auth).startswith("application-") and origin is not None: if auth is not None and str(auth).startswith("application-") and origin is not None:
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() application_api_key = get_application_api_key(str(auth), True)
if application_api_key.allow_cross_domain: cross_domain_list = application_api_key.get('cross_domain_list', [])
allow_cross_domain = application_api_key.get('allow_cross_domain', False)
if allow_cross_domain:
response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT' response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT'
response[ response[
'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token" 'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"
if application_api_key.cross_domain_list is None or len(application_api_key.cross_domain_list) == 0: if cross_domain_list is None or len(cross_domain_list) == 0:
response['Access-Control-Allow-Origin'] = "*" response['Access-Control-Allow-Origin'] = "*"
elif application_api_key.cross_domain_list.__contains__(origin): elif cross_domain_list.__contains__(origin):
response['Access-Control-Allow-Origin'] = origin response['Access-Control-Allow-Origin'] = origin
return response return response

View File

@ -10,21 +10,40 @@ from django.db.models import QuerySet
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
from application.models.api_key_model import ApplicationAccessToken from application.models.api_key_model import ApplicationAccessToken
from common.constants.cache_code_constants import CacheCodeConstants
from common.util.cache_util import get_cache
@get_cache(cache_key=lambda access_token, use_get_data: access_token,
use_get_data=lambda access_token, use_get_data: use_get_data,
version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value)
def get_application_access_token(access_token, use_get_data):
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
if application_access_token is None:
return None
return {'white_active': application_access_token.white_active,
'white_list': application_access_token.white_list,
'application_icon': application_access_token.application.icon,
'application_name': application_access_token.application.name}
class StaticHeadersMiddleware(MiddlewareMixin): class StaticHeadersMiddleware(MiddlewareMixin):
def process_response(self, request, response): def process_response(self, request, response):
if request.path.startswith('/ui/chat/'): if request.path.startswith('/ui/chat/'):
access_token = request.path.replace('/ui/chat/', '') access_token = request.path.replace('/ui/chat/', '')
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() application_access_token = get_application_access_token(access_token, True)
if application_access_token is not None: if application_access_token is not None:
if application_access_token.white_active: white_active = application_access_token.get('white_active', False)
white_list = application_access_token.get('white_list', [])
application_icon = application_access_token.get('application_icon')
application_name = application_access_token.get('application_name')
if white_active:
# 添加自定义的响应头 # 添加自定义的响应头
response[ response[
'Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}' 'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}'
response.content = (response.content.decode('utf-8').replace( response.content = (response.content.decode('utf-8').replace(
'<link rel="icon" href="/ui/favicon.ico" />', '<link rel="icon" href="/ui/favicon.ico" />',
f'<link rel="icon" href="{application_access_token.application.icon}" />') f'<link rel="icon" href="{application_icon}" />')
.replace('<title>MaxKB</title>', f'<title>{application_access_token.application.name}</title>').encode( .replace('<title>MaxKB</title>', f'<title>{application_name}</title>').encode(
"utf-8")) "utf-8"))
return response return response

View File

@ -0,0 +1,66 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file cache_util.py
@date2024/7/24 19:23
@desc:
"""
from django.core.cache import cache
def get_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None, kwargs=None):
"""
获取数据, 先从缓存中获取,如果获取不到再调用get_data 获取数据
@param kwargs: get_data所需参数
@param key: key
@param get_data: 获取数据函数
@param cache_instance: cache实例
@param version: 版本用于隔离
@return:
"""
if kwargs is None:
kwargs = {}
if cache_instance.has_key(key, version=version):
return cache_instance.get(key, version=version)
data = get_data(**kwargs)
cache_instance.add(key, data, version=version)
return data
def set_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None):
data = get_data()
cache_instance.set(key, data, version=version)
return data
def get_cache(cache_key, use_get_data: any = True, cache_instance=cache, version=None):
def inner(get_data):
def run(*args, **kwargs):
key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key
is_use_get_data = use_get_data(*args, **kwargs) if callable(use_get_data) else use_get_data
if is_use_get_data:
if cache_instance.has_key(key, version=version):
return cache_instance.get(key, version=version)
data = get_data(*args, **kwargs)
cache_instance.add(key, data, version=version)
return data
data = get_data(*args, **kwargs)
cache_instance.set(key, data, version=version)
return data
return run
return inner
def del_cache(cache_key, cache_instance=cache, version=None):
def inner(func):
def run(*args, **kwargs):
key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key
func(*args, **kwargs)
cache_instance.delete(key, version=version)
return run
return inner

View File

@ -92,9 +92,21 @@ SWAGGER_SETTINGS = {
CACHES = { CACHES = {
"default": { "default": {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
'LOCATION': 'unique-snowflake',
'TIMEOUT': 60 * 30,
'OPTIONS': {
'MAX_ENTRIES': 150,
'CULL_FREQUENCY': 5,
}
}, },
'model_cache': { 'chat_cache': {
'BACKEND': 'common.cache.mem_cache.MemCache' 'BACKEND': 'common.cache.mem_cache.MemCache',
'LOCATION': 'unique-snowflake',
'TIMEOUT': 60 * 30,
'OPTIONS': {
'MAX_ENTRIES': 150,
'CULL_FREQUENCY': 5,
}
}, },
# 存储用户信息 # 存储用户信息
'user_cache': { 'user_cache': {

View File

@ -22,8 +22,10 @@ from django.views import static
from rest_framework import status from rest_framework import status
from application.urls import urlpatterns as application_urlpatterns from application.urls import urlpatterns as application_urlpatterns
from common.constants.cache_code_constants import CacheCodeConstants
from common.init.init_doc import init_doc from common.init.init_doc import init_doc
from common.response.result import Result from common.response.result import Result
from common.util.cache_util import get_cache
from smartdoc import settings from smartdoc import settings
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -51,6 +53,15 @@ if not settings.DEBUG:
pro() pro()
@get_cache(cache_key=lambda index_path: index_path,
version=CacheCodeConstants.STATIC_RESOURCE_CACHE.value)
def get_index_html(index_path):
file = open(index_path, "r", encoding='utf-8')
content = file.read()
file.close()
return content
def page_not_found(request, exception): def page_not_found(request, exception):
""" """
页面不存在处理 页面不存在处理
@ -60,9 +71,7 @@ def page_not_found(request, exception):
index_path = os.path.join(PROJECT_DIR, 'apps', "static", 'ui', 'index.html') index_path = os.path.join(PROJECT_DIR, 'apps', "static", 'ui', 'index.html')
if not os.path.exists(index_path): if not os.path.exists(index_path):
return HttpResponse("页面不存在", status=404) return HttpResponse("页面不存在", status=404)
file = open(index_path, "r", encoding='utf-8') content = get_index_html(index_path)
content = file.read()
file.close()
if request.path.startswith('/ui/chat/'): if request.path.startswith('/ui/chat/'):
return HttpResponse(content, status=200) return HttpResponse(content, status=200)
return HttpResponse(content, status=200, headers={'X-Frame-Options': 'DENY'}) return HttpResponse(content, status=200, headers={'X-Frame-Options': 'DENY'})

View File

@ -74,6 +74,7 @@ if __name__ == '__main__':
elif action == "collect_static": elif action == "collect_static":
collect_static() collect_static()
elif action == 'dev': elif action == 'dev':
collect_static()
perform_db_migrate() perform_db_migrate()
runserver() runserver()
else: else: