diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 0d741bd39..9c2e1c544 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -20,7 +20,7 @@ from common.field.common import InstanceField from common.util.field_message import ErrMessage from django.core import cache -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] def write_context(step_variable: Dict, global_variable: Dict, node, workflow): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 87da221eb..503313ec1 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -28,10 +28,13 @@ from application.models import Application, ApplicationDatasetMapping, Applicati from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from common.config.embedding_config import VectorStore from common.constants.authentication_type import AuthenticationType +from common.constants.cache_code_constants import CacheCodeConstants from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField +from common.middleware.cross_domain_middleware import get_application_api_key +from common.middleware.static_headers_middleware import get_application_access_token from common.models.db_model_manage import DBModelManage from common.util.common import valid_license from common.util.field_message import ErrMessage @@ -44,8 +47,7 @@ from setting.models.model_management import Model from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR -token_cache = cache.caches['token_cache'] -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] class ModelDatasetAssociation(serializers.Serializer): @@ -251,6 +253,8 @@ class ApplicationSerializer(serializers.Serializer): if 'is_active' in instance: application_access_token.is_active = instance.get("is_active") if 'access_token_reset' in instance and instance.get('access_token_reset'): + cache.cache.delete(application_access_token.access_token, + version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value) application_access_token.access_token = hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24] if 'access_num' in instance and instance.get('access_num') is not None: application_access_token.access_num = instance.get("access_num") @@ -261,6 +265,7 @@ class ApplicationSerializer(serializers.Serializer): if 'show_source' in instance and instance.get('show_source') is not None: application_access_token.show_source = instance.get('show_source') application_access_token.save() + get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) def one(self, with_valid=True): @@ -526,6 +531,9 @@ class ApplicationSerializer(serializers.Serializer): image.save() application.icon = f'/api/image/{image_id}' application.save() + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + get_application_access_token(application_access_token.access_token, False) return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} class Operate(serializers.Serializer): @@ -686,6 +694,9 @@ class ApplicationSerializer(serializers.Serializer): self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id) chat_cache.clear_by_application_id(application_id) + application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first() + # 更新缓存数据 + get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) @staticmethod @@ -767,8 +778,11 @@ class ApplicationSerializer(serializers.Serializer): self.is_valid(raise_exception=True) api_key_id = self.data.get("api_key_id") application_id = self.data.get('application_id') - QuerySet(ApplicationApiKey).filter(id=api_key_id, - application_id=application_id).delete() + application_api_key = QuerySet(ApplicationApiKey).filter(id=api_key_id, + application_id=application_id).first() + cache.cache.delete(application_api_key.secret_key, + version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value) + application_api_key.delete() def edit(self, instance, with_valid=True): if with_valid: @@ -787,3 +801,5 @@ class ApplicationSerializer(serializers.Serializer): if 'cross_domain_list' in instance and instance.get('cross_domain_list') is not None: application_api_key.cross_domain_list = instance.get('cross_domain_list') application_api_key.save() + # 写入缓存 + get_application_api_key(application_api_key.secret_key, False) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index ecfd30070..b19a8ed5b 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -6,7 +6,6 @@ @date:2023/11/14 13:51 @desc: """ -import json import uuid from typing import List from uuid import UUID @@ -31,13 +30,11 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage -from common.util.rsa_util import rsa_long_decrypt from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status -from setting.models_provider.constants.model_provider_constants import ModelProvideConstants -chat_cache = caches['model_cache'] +chat_cache = caches['chat_cache'] class ChatInfo: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 0398fac13..8692dc328 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -44,7 +44,7 @@ from setting.models import Model from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR -chat_cache = caches['model_cache'] +chat_cache = caches['chat_cache'] class WorkFlowSerializers(serializers.Serializer): diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 28fe412ad..e6fe7191e 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -28,7 +28,7 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers -chat_cache = cache.caches['model_cache'] +chat_cache = cache.caches['chat_cache'] class ApplicationStatistics(APIView): diff --git a/apps/common/constants/cache_code_constants.py b/apps/common/constants/cache_code_constants.py new file mode 100644 index 000000000..dd64805f0 --- /dev/null +++ b/apps/common/constants/cache_code_constants.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: cache_code_constants.py + @date:2024/7/24 18:20 + @desc: +""" +from enum import Enum + + +class CacheCodeConstants(Enum): + # 应用ACCESS_TOKEN缓存 + APPLICATION_ACCESS_TOKEN_CACHE = 'APPLICATION_ACCESS_TOKEN_CACHE' + # 静态资源缓存 + STATIC_RESOURCE_CACHE = 'STATIC_RESOURCE_CACHE' + # 应用API_KEY缓存 + APPLICATION_API_KEY_CACHE = 'APPLICATION_API_KEY_CACHE' diff --git a/apps/common/middleware/cross_domain_middleware.py b/apps/common/middleware/cross_domain_middleware.py index d116dd7b7..f7bde6f94 100644 --- a/apps/common/middleware/cross_domain_middleware.py +++ b/apps/common/middleware/cross_domain_middleware.py @@ -11,6 +11,17 @@ from django.http import HttpResponse from django.utils.deprecation import MiddlewareMixin from application.models.api_key_model import ApplicationApiKey +from common.constants.cache_code_constants import CacheCodeConstants +from common.util.cache_util import get_cache + + +@get_cache(cache_key=lambda secret_key, use_get_data: secret_key, + use_get_data=lambda secret_key, use_get_data: use_get_data, + version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value) +def get_application_api_key(secret_key, use_get_data): + application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=secret_key).first() + return {'allow_cross_domain': application_api_key.allow_cross_domain, + 'cross_domain_list': application_api_key.cross_domain_list} class CrossDomainMiddleware(MiddlewareMixin): @@ -27,13 +38,15 @@ class CrossDomainMiddleware(MiddlewareMixin): auth = request.META.get('HTTP_AUTHORIZATION') origin = request.META.get('HTTP_ORIGIN') if auth is not None and str(auth).startswith("application-") and origin is not None: - application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() - if application_api_key.allow_cross_domain: + application_api_key = get_application_api_key(str(auth), True) + cross_domain_list = application_api_key.get('cross_domain_list', []) + allow_cross_domain = application_api_key.get('allow_cross_domain', False) + if allow_cross_domain: response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT' response[ 'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token" - if application_api_key.cross_domain_list is None or len(application_api_key.cross_domain_list) == 0: + if cross_domain_list is None or len(cross_domain_list) == 0: response['Access-Control-Allow-Origin'] = "*" - elif application_api_key.cross_domain_list.__contains__(origin): + elif cross_domain_list.__contains__(origin): response['Access-Control-Allow-Origin'] = origin return response diff --git a/apps/common/middleware/static_headers_middleware.py b/apps/common/middleware/static_headers_middleware.py index 79b799a70..f835296ac 100644 --- a/apps/common/middleware/static_headers_middleware.py +++ b/apps/common/middleware/static_headers_middleware.py @@ -10,21 +10,40 @@ from django.db.models import QuerySet from django.utils.deprecation import MiddlewareMixin from application.models.api_key_model import ApplicationAccessToken +from common.constants.cache_code_constants import CacheCodeConstants +from common.util.cache_util import get_cache + + +@get_cache(cache_key=lambda access_token, use_get_data: access_token, + use_get_data=lambda access_token, use_get_data: use_get_data, + version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value) +def get_application_access_token(access_token, use_get_data): + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + if application_access_token is None: + return None + return {'white_active': application_access_token.white_active, + 'white_list': application_access_token.white_list, + 'application_icon': application_access_token.application.icon, + 'application_name': application_access_token.application.name} class StaticHeadersMiddleware(MiddlewareMixin): def process_response(self, request, response): if request.path.startswith('/ui/chat/'): access_token = request.path.replace('/ui/chat/', '') - application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + application_access_token = get_application_access_token(access_token, True) if application_access_token is not None: - if application_access_token.white_active: + white_active = application_access_token.get('white_active', False) + white_list = application_access_token.get('white_list', []) + application_icon = application_access_token.get('application_icon') + application_name = application_access_token.get('application_name') + if white_active: # 添加自定义的响应头 response[ - 'Content-Security-Policy'] = f'frame-ancestors {" ".join(application_access_token.white_list)}' + 'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}' response.content = (response.content.decode('utf-8').replace( '', - f'') - .replace('MaxKB', f'{application_access_token.application.name}').encode( + f'') + .replace('MaxKB', f'{application_name}').encode( "utf-8")) return response diff --git a/apps/common/util/cache_util.py b/apps/common/util/cache_util.py new file mode 100644 index 000000000..a206b5603 --- /dev/null +++ b/apps/common/util/cache_util.py @@ -0,0 +1,66 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: cache_util.py + @date:2024/7/24 19:23 + @desc: +""" +from django.core.cache import cache + + +def get_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None, kwargs=None): + """ + 获取数据, 先从缓存中获取,如果获取不到再调用get_data 获取数据 + @param kwargs: get_data所需参数 + @param key: key + @param get_data: 获取数据函数 + @param cache_instance: cache实例 + @param version: 版本用于隔离 + @return: + """ + if kwargs is None: + kwargs = {} + if cache_instance.has_key(key, version=version): + return cache_instance.get(key, version=version) + data = get_data(**kwargs) + cache_instance.add(key, data, version=version) + return data + + +def set_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None): + data = get_data() + cache_instance.set(key, data, version=version) + return data + + +def get_cache(cache_key, use_get_data: any = True, cache_instance=cache, version=None): + def inner(get_data): + def run(*args, **kwargs): + key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key + is_use_get_data = use_get_data(*args, **kwargs) if callable(use_get_data) else use_get_data + if is_use_get_data: + if cache_instance.has_key(key, version=version): + return cache_instance.get(key, version=version) + data = get_data(*args, **kwargs) + cache_instance.add(key, data, version=version) + return data + data = get_data(*args, **kwargs) + cache_instance.set(key, data, version=version) + return data + + return run + + return inner + + +def del_cache(cache_key, cache_instance=cache, version=None): + def inner(func): + def run(*args, **kwargs): + key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key + func(*args, **kwargs) + cache_instance.delete(key, version=version) + + return run + + return inner diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index 6deee63f2..5cba1e8e1 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -92,9 +92,21 @@ SWAGGER_SETTINGS = { CACHES = { "default": { 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + 'LOCATION': 'unique-snowflake', + 'TIMEOUT': 60 * 30, + 'OPTIONS': { + 'MAX_ENTRIES': 150, + 'CULL_FREQUENCY': 5, + } }, - 'model_cache': { - 'BACKEND': 'common.cache.mem_cache.MemCache' + 'chat_cache': { + 'BACKEND': 'common.cache.mem_cache.MemCache', + 'LOCATION': 'unique-snowflake', + 'TIMEOUT': 60 * 30, + 'OPTIONS': { + 'MAX_ENTRIES': 150, + 'CULL_FREQUENCY': 5, + } }, # 存储用户信息 'user_cache': { diff --git a/apps/smartdoc/urls.py b/apps/smartdoc/urls.py index 4458aba8a..85d61a954 100644 --- a/apps/smartdoc/urls.py +++ b/apps/smartdoc/urls.py @@ -22,8 +22,10 @@ from django.views import static from rest_framework import status from application.urls import urlpatterns as application_urlpatterns +from common.constants.cache_code_constants import CacheCodeConstants from common.init.init_doc import init_doc from common.response.result import Result +from common.util.cache_util import get_cache from smartdoc import settings from smartdoc.conf import PROJECT_DIR @@ -51,6 +53,15 @@ if not settings.DEBUG: pro() +@get_cache(cache_key=lambda index_path: index_path, + version=CacheCodeConstants.STATIC_RESOURCE_CACHE.value) +def get_index_html(index_path): + file = open(index_path, "r", encoding='utf-8') + content = file.read() + file.close() + return content + + def page_not_found(request, exception): """ 页面不存在处理 @@ -60,9 +71,7 @@ def page_not_found(request, exception): index_path = os.path.join(PROJECT_DIR, 'apps', "static", 'ui', 'index.html') if not os.path.exists(index_path): return HttpResponse("页面不存在", status=404) - file = open(index_path, "r", encoding='utf-8') - content = file.read() - file.close() + content = get_index_html(index_path) if request.path.startswith('/ui/chat/'): return HttpResponse(content, status=200) return HttpResponse(content, status=200, headers={'X-Frame-Options': 'DENY'}) diff --git a/main.py b/main.py index 1a6ff2896..5f3df7965 100644 --- a/main.py +++ b/main.py @@ -74,6 +74,7 @@ if __name__ == '__main__': elif action == "collect_static": collect_static() elif action == 'dev': + collect_static() perform_db_migrate() runserver() else: