From 0d2524df1d02f52f10fa7e290771e88165e83158 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:41:38 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=83=A8?= =?UTF-8?q?=E5=88=86=E4=BB=A3=E7=A0=81=20(#864)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 2 +- .../serializers/application_serializers.py | 24 +++++-- .../serializers/chat_message_serializers.py | 5 +- .../serializers/chat_serializers.py | 2 +- apps/application/views/application_views.py | 2 +- apps/common/constants/cache_code_constants.py | 18 +++++ .../middleware/cross_domain_middleware.py | 21 ++++-- .../middleware/static_headers_middleware.py | 29 ++++++-- apps/common/util/cache_util.py | 66 +++++++++++++++++++ apps/smartdoc/settings/base.py | 16 ++++- apps/smartdoc/urls.py | 15 ++++- main.py | 1 + 12 files changed, 176 insertions(+), 25 deletions(-) create mode 100644 apps/common/constants/cache_code_constants.py create mode 100644 apps/common/util/cache_util.py diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 0d741bd39..9c2e1c544 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -20,7 +20,7 @@ from common.field.common import InstanceField from common.util.field_message import ErrMessage from django.core import cache -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] def write_context(step_variable: Dict, global_variable: Dict, node, workflow): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 87da221eb..503313ec1 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -28,10 +28,13 @@ from application.models import Application, ApplicationDatasetMapping, Applicati from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from common.config.embedding_config import VectorStore from common.constants.authentication_type import AuthenticationType +from common.constants.cache_code_constants import CacheCodeConstants from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField +from common.middleware.cross_domain_middleware import get_application_api_key +from common.middleware.static_headers_middleware import get_application_access_token from common.models.db_model_manage import DBModelManage from common.util.common import valid_license from common.util.field_message import ErrMessage @@ -44,8 +47,7 @@ from setting.models.model_management import Model from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR -token_cache = cache.caches['token_cache'] -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] class ModelDatasetAssociation(serializers.Serializer): @@ -251,6 +253,8 @@ class ApplicationSerializer(serializers.Serializer): if 'is_active' in instance: application_access_token.is_active = instance.get("is_active") if 'access_token_reset' in instance and instance.get('access_token_reset'): + cache.cache.delete(application_access_token.access_token, + version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value) application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24] if 'access_num' in instance and instance.get('access_num') is not None: application_access_token.access_num = instance.get("access_num") @@ -261,6 +265,7 @@ class ApplicationSerializer(serializers.Serializer): if 'show_source' in instance and instance.get('show_source') is not None: application_access_token.show_source = instance.get('show_source') application_access_token.save() + get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) def one(self, with_valid=True): @@ -526,6 +531,9 @@ class ApplicationSerializer(serializers.Serializer): image.save() application.icon = f'/api/image/{image_id}' application.save() + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + get_application_access_token(application_access_token.access_token, False) return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} class Operate(serializers.Serializer): @@ -686,6 +694,9 @@ class ApplicationSerializer(serializers.Serializer): self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id) chat_cache.clear_by_application_id(application_id) + application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first() + # 更新缓存数据 + get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) @staticmethod @@ -767,8 +778,11 @@ class ApplicationSerializer(serializers.Serializer): self.is_valid(raise_exception=True) api_key_id = self.data.get("api_key_id") application_id = self.data.get('application_id') - QuerySet(ApplicationApiKey).filter(id=api_key_id, - application_id=application_id).delete() + application_api_key = QuerySet(ApplicationApiKey).filter(id=api_key_id, + application_id=application_id).first() + cache.cache.delete(application_api_key.secret_key, + version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value) + application_api_key.delete() def edit(self, instance, with_valid=True): if with_valid: @@ -787,3 +801,5 @@ class ApplicationSerializer(serializers.Serializer): if 'cross_domain_list' in instance and instance.get('cross_domain_list') is not None: application_api_key.cross_domain_list = instance.get('cross_domain_list') application_api_key.save() + # 写入缓存 + get_application_api_key(application_api_key.secret_key, False) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index ecfd30070..b19a8ed5b 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -6,7 +6,6 @@ @date:2023/11/14 13:51 @desc: """ -import json import uuid from typing import List from uuid import UUID @@ -31,13 +30,11 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage -from common.util.rsa_util import rsa_long_decrypt from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status -from setting.models_provider.constants.model_provider_constants import ModelProvideConstants -chat_cache = caches['model_cache'] +chat_cache = caches['chat_cache'] class ChatInfo: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 0398fac13..8692dc328 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -44,7 +44,7 @@ from setting.models import Model from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR -chat_cache = caches['model_cache'] +chat_cache = caches['chat_cache'] class WorkFlowSerializers(serializers.Serializer): diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 28fe412ad..e6fe7191e 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -28,7 +28,7 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] class ApplicationStatistics(APIView): diff --git a/apps/common/constants/cache_code_constants.py b/apps/common/constants/cache_code_constants.py new file mode 100644 index 000000000..dd64805f0 --- /dev/null +++ b/apps/common/constants/cache_code_constants.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: cache_code_constants.py + @date:2024/7/24 18:20 + @desc: +""" +from enum import Enum + + +class CacheCodeConstants(Enum): + # 应用ACCESS_TOKEN缓存 + APPLICATION_ACCESS_TOKEN_CACHE = 'APPLICATION_ACCESS_TOKEN_CACHE' + # 静态资源缓存 + STATIC_RESOURCE_CACHE = 'STATIC_RESOURCE_CACHE' + # 应用API_KEY缓存 + APPLICATION_API_KEY_CACHE = 'APPLICATION_API_KEY_CACHE' diff --git a/apps/common/middleware/cross_domain_middleware.py b/apps/common/middleware/cross_domain_middleware.py index d116dd7b7..f7bde6f94 100644 --- a/apps/common/middleware/cross_domain_middleware.py +++ b/apps/common/middleware/cross_domain_middleware.py @@ -11,6 +11,17 @@ from django.http import HttpResponse from django.utils.deprecation import MiddlewareMixin from application.models.api_key_model import ApplicationApiKey +from common.constants.cache_code_constants import CacheCodeConstants +from common.util.cache_util import get_cache + + +@get_cache(cache_key=lambda secret_key, use_get_data: secret_key, + use_get_data=lambda secret_key, use_get_data: use_get_data, + version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value) +def get_application_api_key(secret_key, use_get_data): + application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=secret_key).first() + return {'allow_cross_domain': application_api_key.allow_cross_domain, + 'cross_domain_list': application_api_key.cross_domain_list} class CrossDomainMiddleware(MiddlewareMixin): @@ -27,13 +38,15 @@ class CrossDomainMiddleware(MiddlewareMixin): auth = request.META.get('HTTP_AUTHORIZATION') origin = request.META.get('HTTP_ORIGIN') if auth is not None and str(auth).startswith("application-") and origin is not None: - application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() - if application_api_key.allow_cross_domain: + application_api_key = get_application_api_key(str(auth), True) + cross_domain_list = application_api_key.get('cross_domain_list', []) + allow_cross_domain = application_api_key.get('allow_cross_domain', False) + if allow_cross_domain: response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT' response[ 'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token" - if application_api_key.cross_domain_list is None or len(application_api_key.cross_domain_list) == 0: + if cross_domain_list is None or len(cross_domain_list) == 0: response['Access-Control-Allow-Origin'] = "*" - elif application_api_key.cross_domain_list.__contains__(origin): + elif cross_domain_list.__contains__(origin): response['Access-Control-Allow-Origin'] = origin return response diff --git a/apps/common/middleware/static_headers_middleware.py b/apps/common/middleware/static_headers_middleware.py index 79b799a70..f835296ac 100644 --- a/apps/common/middleware/static_headers_middleware.py +++ b/apps/common/middleware/static_headers_middleware.py @@ -10,21 +10,40 @@ from django.db.models import QuerySet from django.utils.deprecation import MiddlewareMixin from application.models.api_key_model import ApplicationAccessToken +from common.constants.cache_code_constants import CacheCodeConstants +from common.util.cache_util import get_cache + + +@get_cache(cache_key=lambda access_token, use_get_data: access_token, + use_get_data=lambda access_token, use_get_data: use_get_data, + version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value) +def get_application_access_token(access_token, use_get_data): + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + if application_access_token is None: + return None + return {'white_active': application_access_token.white_active, + 'white_list': application_access_token.white_list, + 'application_icon': application_access_token.application.icon, + 'application_name': application_access_token.application.name} class StaticHeadersMiddleware(MiddlewareMixin): def process_response(self, request, response): if request.path.startswith('/ui/chat/'): access_token = request.path.replace('/ui/chat/', '') - application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + application_access_token = get_application_access_token(access_token, True) if application_access_token is not None: - if application_access_token.white_active: + white_active = application_access_token.get('white_active', False) + white_list = application_access_token.get('white_list', []) + application_icon = application_access_token.get('application_icon') + application_name = application_access_token.get('application_name') + if white_active: # 添加自定义的响应头 response[ - 'Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}' + 'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}' response.content = (response.content.decode('utf-8').replace( '', - f'') - .replace('