From 210e09681fe2ccc1e11f05c35fa94419336652d0 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Fri, 11 Jul 2025 11:15:21 +0800
Subject: [PATCH] feat: Apikey call supports cross domain and application
whitelist (#3556)
---
.../application_access_token_cache.py | 30 +++++++
apps/common/middleware/__init__.py | 8 ++
.../middleware/chat_headers_middleware.py | 38 +++++++++
.../middleware/cross_domain_middleware.py | 40 +++++++++
apps/common/middleware/gzip.py | 84 +++++++++++++++++++
apps/maxkb/settings/base.py | 3 +
6 files changed, 203 insertions(+)
create mode 100644 apps/common/cache_data/application_access_token_cache.py
create mode 100644 apps/common/middleware/__init__.py
create mode 100644 apps/common/middleware/chat_headers_middleware.py
create mode 100644 apps/common/middleware/cross_domain_middleware.py
create mode 100644 apps/common/middleware/gzip.py
diff --git a/apps/common/cache_data/application_access_token_cache.py b/apps/common/cache_data/application_access_token_cache.py
new file mode 100644
index 000000000..69a456f5a
--- /dev/null
+++ b/apps/common/cache_data/application_access_token_cache.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: application_access_token_cache.py
+ @date:2024/7/25 11:34
+ @desc:
+"""
+from django.core.cache import cache
+from django.db.models import QuerySet
+
+from application.models import ApplicationAccessToken
+from common.utils.cache_util import get_cache
+
+
+@get_cache(cache_key=lambda access_token, use_get_data: access_token,
+ use_get_data=lambda access_token, use_get_data: use_get_data,
+ version='APPLICATION_ACCESS_TOKEN_CACHE')
+def get_application_access_token(access_token, use_get_data):
+ application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
+ if application_access_token is None:
+ return None
+ return {'white_active': application_access_token.white_active,
+ 'white_list': application_access_token.white_list,
+ 'application_icon': application_access_token.application.icon,
+ 'application_name': application_access_token.application.name}
+
+
+def del_application_access_token(access_token):
+ cache.delete(access_token, version='APPLICATION_ACCESS_TOKEN_CACHE')
diff --git a/apps/common/middleware/__init__.py b/apps/common/middleware/__init__.py
new file mode 100644
index 000000000..7bbffa278
--- /dev/null
+++ b/apps/common/middleware/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎虎
+ @file: __init__.py
+ @date:2025/7/11 10:43
+ @desc:
+"""
diff --git a/apps/common/middleware/chat_headers_middleware.py b/apps/common/middleware/chat_headers_middleware.py
new file mode 100644
index 000000000..b1d1d8608
--- /dev/null
+++ b/apps/common/middleware/chat_headers_middleware.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: static_headers_middleware.py
+ @date:2024/3/13 18:26
+ @desc:
+"""
+from django.utils.deprecation import MiddlewareMixin
+
+from common.cache_data.application_access_token_cache import get_application_access_token
+from maxkb.const import CONFIG
+
+
+class ChatHeadersMiddleware(MiddlewareMixin):
+ def process_response(self, request, response):
+
+ if request.path.startswith(CONFIG.get_chat_path()) and not request.path.startswith(
+ CONFIG.get_chat_path() + '/api'):
+ access_token = request.path.replace(CONFIG.get_chat_path() + '/', '')
+ if access_token.__contains__('/') or access_token == 'undefined':
+ return response
+ application_access_token = get_application_access_token(access_token, True)
+ if application_access_token is not None:
+ white_active = application_access_token.get('white_active', False)
+ white_list = application_access_token.get('white_list', [])
+ application_icon = application_access_token.get('application_icon')
+ application_name = application_access_token.get('application_name')
+ if white_active:
+ # 添加自定义的响应头
+ response[
+ 'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}'
+ response.content = (response.content.decode('utf-8').replace(
+ '',
+ f'')
+ .replace('
MaxKB', f'{application_name}').encode(
+ "utf-8"))
+ return response
diff --git a/apps/common/middleware/cross_domain_middleware.py b/apps/common/middleware/cross_domain_middleware.py
new file mode 100644
index 000000000..6dcbaf40c
--- /dev/null
+++ b/apps/common/middleware/cross_domain_middleware.py
@@ -0,0 +1,40 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎虎
+ @file: cross_domain_middleware.py
+ @date:2024/5/8 13:36
+ @desc:
+"""
+from django.http import HttpResponse
+from django.utils.deprecation import MiddlewareMixin
+
+from common.cache_data.application_api_key_cache import get_application_api_key
+
+
+class CrossDomainMiddleware(MiddlewareMixin):
+
+ def process_request(self, request):
+ if request.method == 'OPTIONS':
+ return HttpResponse(status=200,
+ headers={
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Methods": "GET,POST,DELETE,PUT",
+ "Access-Control-Allow-Headers": "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"})
+
+ def process_response(self, request, response):
+ auth = request.META.get('HTTP_AUTHORIZATION')
+ origin = request.META.get('HTTP_ORIGIN')
+ if auth is not None and str(auth).startswith("application-") and origin is not None:
+ application_api_key = get_application_api_key(str(auth), True)
+ cross_domain_list = application_api_key.get('cross_domain_list', [])
+ allow_cross_domain = application_api_key.get('allow_cross_domain', False)
+ if allow_cross_domain:
+ response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT'
+ response[
+ 'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"
+ if cross_domain_list is None or len(cross_domain_list) == 0:
+ response['Access-Control-Allow-Origin'] = "*"
+ elif cross_domain_list.__contains__(origin):
+ response['Access-Control-Allow-Origin'] = origin
+ return response
diff --git a/apps/common/middleware/gzip.py b/apps/common/middleware/gzip.py
new file mode 100644
index 000000000..92c7cea38
--- /dev/null
+++ b/apps/common/middleware/gzip.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: gzip.py
+ @date:2025/2/27 10:03
+ @desc:
+"""
+from django.utils.cache import patch_vary_headers
+from django.utils.deprecation import MiddlewareMixin
+from django.utils.regex_helper import _lazy_re_compile
+from django.utils.text import compress_sequence, compress_string
+
+re_accepts_gzip = _lazy_re_compile(r"\bgzip\b")
+
+
+class GZipMiddleware(MiddlewareMixin):
+ """
+ Compress content if the browser allows gzip compression.
+ Set the Vary header accordingly, so that caches will base their storage
+ on the Accept-Encoding header.
+ """
+
+ max_random_bytes = 100
+
+ def process_response(self, request, response):
+ if request.method != 'GET' or request.path.startswith('/api'):
+ return response
+ # It's not worth attempting to compress really short responses.
+ if not response.streaming and len(response.content) < 200:
+ return response
+
+ # Avoid gzipping if we've already got a content-encoding.
+ if response.has_header("Content-Encoding"):
+ return response
+
+ patch_vary_headers(response, ("Accept-Encoding",))
+
+ ae = request.META.get("HTTP_ACCEPT_ENCODING", "")
+ if not re_accepts_gzip.search(ae):
+ return response
+
+ if response.streaming:
+ if response.is_async:
+ # pull to lexical scope to capture fixed reference in case
+ # streaming_content is set again later.
+ original_iterator = response.streaming_content
+
+ async def gzip_wrapper():
+ async for chunk in original_iterator:
+ yield compress_string(
+ chunk,
+ max_random_bytes=self.max_random_bytes,
+ )
+
+ response.streaming_content = gzip_wrapper()
+ else:
+ response.streaming_content = compress_sequence(
+ response.streaming_content,
+ max_random_bytes=self.max_random_bytes,
+ )
+ # Delete the `Content-Length` header for streaming content, because
+ # we won't know the compressed size until we stream it.
+ del response.headers["Content-Length"]
+ else:
+ # Return the compressed content only if it's actually shorter.
+ compressed_content = compress_string(
+ response.content,
+ max_random_bytes=self.max_random_bytes,
+ )
+ if len(compressed_content) >= len(response.content):
+ return response
+ response.content = compressed_content
+ response.headers["Content-Length"] = str(len(response.content))
+
+ # If there is a strong ETag, make it weak to fulfill the requirements
+ # of RFC 9110 Section 8.8.1 while also allowing conditional request
+ # matches on ETags.
+ etag = response.get("ETag")
+ if etag and etag.startswith('"'):
+ response.headers["ETag"] = "W/" + etag
+ response.headers["Content-Encoding"] = "gzip"
+
+ return response
diff --git a/apps/maxkb/settings/base.py b/apps/maxkb/settings/base.py
index 2aba55aab..a0f5f6cd3 100644
--- a/apps/maxkb/settings/base.py
+++ b/apps/maxkb/settings/base.py
@@ -56,6 +56,9 @@ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
+ 'common.middleware.gzip.GZipMiddleware',
+ 'common.middleware.chat_headers_middleware.ChatHeadersMiddleware',
+ 'common.middleware.cross_domain_middleware.CrossDomainMiddleware',
]