From 210e09681fe2ccc1e11f05c35fa94419336652d0 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Fri, 11 Jul 2025 11:15:21 +0800
Subject: [PATCH] feat: Apikey call supports cross domain and application whitelist (#3556)

---
 .../application_access_token_cache.py         | 43 +++++++++
 apps/common/middleware/__init__.py            |  8 ++
 .../middleware/chat_headers_middleware.py     | 56 +++++++++++
 .../middleware/cross_domain_middleware.py     | 67 ++++++++++++++
 apps/common/middleware/gzip.py                | 92 +++++++++++++++++++
 apps/maxkb/settings/base.py                   |  3 +
 6 files changed, 269 insertions(+)
 create mode 100644 apps/common/cache_data/application_access_token_cache.py
 create mode 100644 apps/common/middleware/__init__.py
 create mode 100644 apps/common/middleware/chat_headers_middleware.py
 create mode 100644 apps/common/middleware/cross_domain_middleware.py
 create mode 100644 apps/common/middleware/gzip.py

diff --git a/apps/common/cache_data/application_access_token_cache.py b/apps/common/cache_data/application_access_token_cache.py
new file mode 100644
index 000000000..69a456f5a
--- /dev/null
+++ b/apps/common/cache_data/application_access_token_cache.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: application_access_token_cache.py
+    @date:2024/7/25 11:34
+    @desc: Cached lookup of the settings attached to an application access token
+"""
+from django.core.cache import cache
+from django.db.models import QuerySet
+
+from application.models import ApplicationAccessToken
+from common.utils.cache_util import get_cache
+
+
+@get_cache(cache_key=lambda access_token, use_get_data: access_token,
+           use_get_data=lambda access_token, use_get_data: use_get_data,
+           version='APPLICATION_ACCESS_TOKEN_CACHE')
+def get_application_access_token(access_token, use_get_data):
+    """
+    Look up the whitelist and branding settings stored for an access token.
+
+    :param access_token: access token string; also used as the cache key
+    :param use_get_data: handed to the ``get_cache`` decorator, which defines its meaning
+    :return: dict with ``white_active``, ``white_list``, ``application_icon`` and
+             ``application_name``, or ``None`` when no row matches the token
+    """
+    application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
+    if application_access_token is None:
+        return None
+    return {'white_active': application_access_token.white_active,
+            'white_list': application_access_token.white_list,
+            'application_icon': application_access_token.application.icon,
+            'application_name': application_access_token.application.name}
+
+
+def del_application_access_token(access_token):
+    """
+    Delete the cache entry stored under ``access_token``.
+
+    :param access_token: access token whose cached settings are removed
+    """
+    cache.delete(access_token, version='APPLICATION_ACCESS_TOKEN_CACHE')
diff --git a/apps/common/middleware/__init__.py b/apps/common/middleware/__init__.py
new file mode 100644
index 000000000..7bbffa278
--- /dev/null
+++ b/apps/common/middleware/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎虎
+    @file: __init__.py
+    @date:2025/7/11 10:43
+    @desc:
+"""
diff --git a/apps/common/middleware/chat_headers_middleware.py b/apps/common/middleware/chat_headers_middleware.py
new file mode 100644
index 000000000..b1d1d8608
--- /dev/null
+++ b/apps/common/middleware/chat_headers_middleware.py
@@ -0,0 +1,56 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: chat_headers_middleware.py
+    @date:2024/3/13 18:26
+    @desc: Per-application headers and branding for chat page responses
+"""
+from django.utils.deprecation import MiddlewareMixin
+
+from common.cache_data.application_access_token_cache import get_application_access_token
+from maxkb.const import CONFIG
+
+
+class ChatHeadersMiddleware(MiddlewareMixin):
+    """
+    Apply per-application settings to responses served under the chat path.
+
+    Requests of the form ``<chat path>/<access_token>`` (the ``/api`` sub-path is
+    skipped) get a ``frame-ancestors`` policy built from the application whitelist
+    and have the application name substituted into the response body.
+    """
+
+    def process_response(self, request, response):
+        """
+        :param request: current ``HttpRequest``
+        :param response: response produced further down the middleware chain
+        :return: ``response``, with headers/body adjusted when the path names a known token
+        """
+        chat_path = CONFIG.get_chat_path()
+        if not request.path.startswith(chat_path) or request.path.startswith(chat_path + '/api'):
+            return response
+        access_token = request.path.replace(chat_path + '/', '')
+        if '/' in access_token or access_token == 'undefined':
+            return response
+        application_access_token = get_application_access_token(access_token, True)
+        if application_access_token is None:
+            return response
+        white_active = application_access_token.get('white_active', False)
+        # Treat a missing or None whitelist as empty so join() cannot raise TypeError
+        white_list = application_access_token.get('white_list') or []
+        application_icon = application_access_token.get('application_icon')
+        application_name = application_access_token.get('application_name')
+        if white_active:
+            # Add the custom response header limiting which origins may embed the page
+            response['Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}'
+        # A streaming response has no ``content`` attribute, so its body is left untouched
+        if getattr(response, 'streaming', False):
+            return response
+        # NOTE(review): the first replace() has empty search and replacement strings, so it is
+        # a no-op and ``application_icon`` ends up unused -- confirm the intended markup.
+        response.content = (response.content.decode('utf-8')
+                            .replace('', f'')
+                            .replace('MaxKB', f'{application_name}')
+                            .encode("utf-8"))
+        return response
diff --git a/apps/common/middleware/cross_domain_middleware.py b/apps/common/middleware/cross_domain_middleware.py
new file mode 100644
index 000000000..6dcbaf40c
--- /dev/null
+++ b/apps/common/middleware/cross_domain_middleware.py
@@ -0,0 +1,67 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎虎
+    @file: cross_domain_middleware.py
+    @date:2024/5/8 13:36
+    @desc: CORS handling for requests authenticated with an application API key
+"""
+from django.http import HttpResponse
+from django.utils.deprecation import MiddlewareMixin
+
+from common.cache_data.application_api_key_cache import get_application_api_key
+
+# CORS header values shared by the preflight reply and by approved responses
+_ALLOW_METHODS = "GET,POST,DELETE,PUT"
+_ALLOW_HEADERS = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"
+
+
+class CrossDomainMiddleware(MiddlewareMixin):
+    """
+    Answer CORS preflights and add CORS headers for application API key calls.
+    """
+
+    def process_request(self, request):
+        """
+        Reply to every ``OPTIONS`` request directly with permissive CORS headers.
+
+        :param request: current ``HttpRequest``
+        :return: a 200 ``HttpResponse`` for ``OPTIONS``, otherwise ``None`` so handling continues
+        """
+        if request.method == 'OPTIONS':
+            return HttpResponse(status=200,
+                                headers={
+                                    "Access-Control-Allow-Origin": "*",
+                                    "Access-Control-Allow-Methods": _ALLOW_METHODS,
+                                    "Access-Control-Allow-Headers": _ALLOW_HEADERS})
+        return None
+
+    def process_response(self, request, response):
+        """
+        Add CORS headers when the caller's API key allows cross-domain access.
+
+        :param request: current ``HttpRequest``; ``Authorization`` and ``Origin`` headers are read
+        :param response: response produced further down the middleware chain
+        :return: ``response``, with ``Access-Control-*`` headers set when permitted
+        """
+        auth = request.META.get('HTTP_AUTHORIZATION')
+        origin = request.META.get('HTTP_ORIGIN')
+        if auth is None or origin is None or not str(auth).startswith("application-"):
+            return response
+        application_api_key = get_application_api_key(str(auth), True)
+        # NOTE(review): assumes the lookup yields None for an unknown key, as
+        # get_application_access_token does -- without this guard ``.get`` would raise
+        if application_api_key is None:
+            return response
+        if not application_api_key.get('allow_cross_domain', False):
+            return response
+        cross_domain_list = application_api_key.get('cross_domain_list', [])
+        response['Access-Control-Allow-Methods'] = _ALLOW_METHODS
+        response['Access-Control-Allow-Headers'] = _ALLOW_HEADERS
+        if not cross_domain_list:
+            # No explicit list configured: any origin is allowed
+            response['Access-Control-Allow-Origin'] = "*"
+        elif origin in cross_domain_list:
+            # Echo the caller's origin only when it is on the configured list
+            response['Access-Control-Allow-Origin'] = origin
+        return response
diff --git a/apps/common/middleware/gzip.py b/apps/common/middleware/gzip.py
new file mode 100644
index 000000000..92c7cea38
--- /dev/null
+++ b/apps/common/middleware/gzip.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: gzip.py
+    @date:2025/2/27 10:03
+    @desc: Gzip middleware restricted to non-API GET requests
+"""
+from django.utils.cache import patch_vary_headers
+from django.utils.deprecation import MiddlewareMixin
+from django.utils.regex_helper import _lazy_re_compile
+from django.utils.text import compress_sequence, compress_string
+
+re_accepts_gzip = _lazy_re_compile(r"\bgzip\b")
+
+
+class GZipMiddleware(MiddlewareMixin):
+    """
+    Compress content if the browser allows gzip compression.
+    Set the Vary header accordingly, so that caches will base their storage
+    on the Accept-Encoding header.
+    """
+
+    max_random_bytes = 100
+
+    def process_response(self, request, response):
+        """
+        Gzip the body of non-API GET responses when the client accepts gzip.
+
+        :param request: current ``HttpRequest``
+        :param response: response to compress in place when worthwhile
+        :return: ``response``, compressed or untouched
+        """
+        # Only GET requests outside the ``/api`` prefix are considered for compression.
+        if request.method != 'GET' or request.path.startswith('/api'):
+            return response
+        # It's not worth attempting to compress really short responses.
+        if not response.streaming and len(response.content) < 200:
+            return response
+
+        # Avoid gzipping if we've already got a content-encoding.
+        if response.has_header("Content-Encoding"):
+            return response
+
+        patch_vary_headers(response, ("Accept-Encoding",))
+
+        ae = request.META.get("HTTP_ACCEPT_ENCODING", "")
+        if not re_accepts_gzip.search(ae):
+            return response
+
+        if response.streaming:
+            if response.is_async:
+                # pull to lexical scope to capture fixed reference in case
+                # streaming_content is set again later.
+                original_iterator = response.streaming_content
+
+                async def gzip_wrapper():
+                    async for chunk in original_iterator:
+                        yield compress_string(
+                            chunk,
+                            max_random_bytes=self.max_random_bytes,
+                        )
+
+                response.streaming_content = gzip_wrapper()
+            else:
+                response.streaming_content = compress_sequence(
+                    response.streaming_content,
+                    max_random_bytes=self.max_random_bytes,
+                )
+            # Delete the `Content-Length` header for streaming content, because
+            # we won't know the compressed size until we stream it.
+            del response.headers["Content-Length"]
+        else:
+            # Return the compressed content only if it's actually shorter.
+            compressed_content = compress_string(
+                response.content,
+                max_random_bytes=self.max_random_bytes,
+            )
+            if len(compressed_content) >= len(response.content):
+                return response
+            response.content = compressed_content
+            response.headers["Content-Length"] = str(len(response.content))
+
+        # If there is a strong ETag, make it weak to fulfill the requirements
+        # of RFC 9110 Section 8.8.1 while also allowing conditional request
+        # matches on ETags.
+        etag = response.get("ETag")
+        if etag and etag.startswith('"'):
+            response.headers["ETag"] = "W/" + etag
+        response.headers["Content-Encoding"] = "gzip"
+
+        return response
diff --git a/apps/maxkb/settings/base.py b/apps/maxkb/settings/base.py
index 2aba55aab..a0f5f6cd3 100644
--- a/apps/maxkb/settings/base.py
+++ b/apps/maxkb/settings/base.py
@@ -56,6 +56,9 @@ MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
     'django.contrib.sessions.middleware.SessionMiddleware',
     'django.middleware.common.CommonMiddleware',
+    'common.middleware.gzip.GZipMiddleware',
+    'common.middleware.chat_headers_middleware.ChatHeadersMiddleware',
+    'common.middleware.cross_domain_middleware.CrossDomainMiddleware',
 ]