mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-28 05:42:51 +00:00
feat: refactor URL content retrieval to use a dedicated function with application checks
This commit is contained in:
parent
94e60b073f
commit
8c3caa27dd
|
|
@ -1,14 +1,20 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import ipaddress
|
||||
import re
|
||||
import socket
|
||||
import urllib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import uuid_utils.compat as uuid
|
||||
from django.db.models import QuerySet
|
||||
from django.http import HttpResponse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.exception.app_exception import NotFound404
|
||||
from application.models import Application
|
||||
from common.exception.app_exception import NotFound404, AppApiException
|
||||
from knowledge.models import File, FileSourceType
|
||||
from tools.serializers.tool import UploadedFileField
|
||||
|
||||
|
|
@ -158,3 +164,80 @@ class FileSerializer(serializers.Serializer):
|
|||
if file is not None:
|
||||
file.delete()
|
||||
return True
|
||||
|
||||
|
||||
def get_url_content(url, application_id: str):
|
||||
application = Application.objects.filter(id=application_id).first()
|
||||
if application is None:
|
||||
return AppApiException(500, _('Application does not exist'))
|
||||
if not application.file_upload_enable:
|
||||
return AppApiException(500, _('File upload is not enabled'))
|
||||
file_limit = 50 * 1024 * 1024
|
||||
if application.file_upload_setting and application.file_upload_setting.file_limit:
|
||||
file_limit = application.file_upload_setting.file_limit * 1024 * 1024
|
||||
parsed = validate_url(url)
|
||||
|
||||
response = requests.get(
|
||||
url,
|
||||
timeout=3,
|
||||
allow_redirects=False
|
||||
)
|
||||
final_host = urlparse(response.url).hostname
|
||||
if is_private_ip(final_host):
|
||||
raise ValueError("Blocked unsafe redirect to internal host")
|
||||
# 判断文件大小
|
||||
if response.headers.get('Content-Length', 0) > file_limit:
|
||||
return AppApiException(500, _('File size exceeds limit'))
|
||||
# 返回状态码 响应内容大小 响应的contenttype 还有字节流
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
# 根据内容类型决定如何处理
|
||||
if 'text' in content_type or 'json' in content_type:
|
||||
content = response.text
|
||||
else:
|
||||
# 二进制内容使用Base64编码
|
||||
content = base64.b64encode(response.content).decode('utf-8')
|
||||
|
||||
return {
|
||||
'status_code': response.status_code,
|
||||
'Content-Length': response.headers.get('Content-Length', 0),
|
||||
'Content-Type': content_type,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
|
||||
def is_private_ip(host: str) -> bool:
|
||||
"""检测 IP 是否属于内网、环回、云 metadata 的危险地址"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(socket.gethostbyname(host))
|
||||
return (
|
||||
ip.is_private or
|
||||
ip.is_loopback or
|
||||
ip.is_reserved or
|
||||
ip.is_link_local or
|
||||
ip.is_multicast
|
||||
)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def validate_url(url: str):
|
||||
"""验证 URL 是否安全"""
|
||||
if not url:
|
||||
raise ValueError("URL is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# 仅允许 http / https
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Only http and https are allowed")
|
||||
|
||||
host = parsed.hostname
|
||||
# 域名不能为空
|
||||
if not host:
|
||||
raise ValueError("Invalid URL")
|
||||
|
||||
# 禁止访问内部、保留、环回、云 metadata
|
||||
if is_private_ip(host):
|
||||
raise ValueError("Access to internal IP addresses is blocked")
|
||||
|
||||
return parsed
|
||||
|
|
|
|||
|
|
@ -1,21 +1,14 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth
|
||||
from common.log.log import log
|
||||
from common.result import result
|
||||
from knowledge.api.file import FileUploadAPI, FileGetAPI
|
||||
from oss.serializers.file import FileSerializer
|
||||
from oss.serializers.file import FileSerializer, get_url_content
|
||||
|
||||
|
||||
class FileRetrievalView(APIView):
|
||||
|
|
@ -84,71 +77,7 @@ class GetUrlView(APIView):
|
|||
operation_id=_('Get url'), # type: ignore
|
||||
tags=[_('Chat')] # type: ignore
|
||||
)
|
||||
def get(self, request: Request):
|
||||
def get(self, request: Request, application_id: str):
|
||||
url = request.query_params.get('url')
|
||||
parsed = validate_url(url)
|
||||
|
||||
response = requests.get(
|
||||
url,
|
||||
timeout=3,
|
||||
allow_redirects=False
|
||||
)
|
||||
final_host = urlparse(response.url).hostname
|
||||
if is_private_ip(final_host):
|
||||
raise ValueError("Blocked unsafe redirect to internal host")
|
||||
|
||||
# 返回状态码 响应内容大小 响应的contenttype 还有字节流
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
# 根据内容类型决定如何处理
|
||||
if 'text' in content_type or 'json' in content_type:
|
||||
content = response.text
|
||||
else:
|
||||
# 二进制内容使用Base64编码
|
||||
content = base64.b64encode(response.content).decode('utf-8')
|
||||
|
||||
return result.success({
|
||||
'status_code': response.status_code,
|
||||
'Content-Length': response.headers.get('Content-Length', 0),
|
||||
'Content-Type': content_type,
|
||||
'content': content,
|
||||
})
|
||||
|
||||
|
||||
def is_private_ip(host: str) -> bool:
|
||||
"""检测 IP 是否属于内网、环回、云 metadata 的危险地址"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(socket.gethostbyname(host))
|
||||
return (
|
||||
ip.is_private or
|
||||
ip.is_loopback or
|
||||
ip.is_reserved or
|
||||
ip.is_link_local or
|
||||
ip.is_multicast
|
||||
)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def validate_url(url: str):
|
||||
"""验证 URL 是否安全"""
|
||||
if not url:
|
||||
raise ValueError("URL is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# 仅允许 http / https
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Only http and https are allowed")
|
||||
|
||||
host = parsed.hostname
|
||||
path = parsed.path
|
||||
|
||||
# 域名不能为空
|
||||
if not host:
|
||||
raise ValueError("Invalid URL")
|
||||
|
||||
# 禁止访问内部、保留、环回、云 metadata
|
||||
if is_private_ip(host):
|
||||
raise ValueError("Access to internal IP addresses is blocked")
|
||||
|
||||
return parsed
|
||||
result_data = get_url_content(url, application_id)
|
||||
return result.success(result_data)
|
||||
Loading…
Reference in New Issue