feat: refactor URL content retrieval to use a dedicated function with application checks

This commit is contained in:
wxg0103 2025-11-27 17:37:00 +08:00
parent 94e60b073f
commit 8c3caa27dd
2 changed files with 88 additions and 76 deletions

View File

@ -1,14 +1,20 @@
# coding=utf-8
import base64
import ipaddress
import re
import socket
import urllib
from urllib.parse import urlparse
import requests
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.exception.app_exception import NotFound404
from application.models import Application
from common.exception.app_exception import NotFound404, AppApiException
from knowledge.models import File, FileSourceType
from tools.serializers.tool import UploadedFileField
@ -158,3 +164,80 @@ class FileSerializer(serializers.Serializer):
if file is not None:
file.delete()
return True
def get_url_content(url, application_id: str):
    """Fetch ``url`` on behalf of an application and return its payload.

    The application must exist and have file upload enabled. The URL is
    checked with ``validate_url`` before any request is made, the request
    does not follow redirects, and the response is rejected when it is
    larger than the application's file limit.

    Args:
        url: Address to fetch.
        application_id: Primary key of the ``Application`` whose upload
            settings govern the request.

    Returns:
        A dict with ``status_code``, ``Content-Length`` (the raw header
        value, or ``0`` when the header is absent), ``Content-Type`` and
        ``content``. ``content`` is the decoded text for text/JSON
        responses and a Base64 string for everything else.

    Raises:
        AppApiException: The application does not exist, file upload is
            disabled, or the response exceeds the size limit.
        ValueError: ``validate_url`` rejected the URL, or the host of the
            response URL resolves to an internal address.
    """
    application = Application.objects.filter(id=application_id).first()
    # These are error conditions, so they are raised rather than returned:
    # a returned exception object would be handed to the caller as if it
    # were a successful result.
    if application is None:
        raise AppApiException(500, _('Application does not exist'))
    if not application.file_upload_enable:
        raise AppApiException(500, _('File upload is not enabled'))

    # Default cap of 50 MiB, overridden by the per-application setting.
    # NOTE(review): attribute access assumes file_upload_setting is an
    # object exposing `file_limit` (presumably in MiB) -- confirm against
    # the model definition.
    file_limit = 50 * 1024 * 1024
    upload_setting = application.file_upload_setting
    if upload_setting and upload_setting.file_limit:
        file_limit = upload_setting.file_limit * 1024 * 1024

    # Called for its checks only; it raises ValueError on an unsafe URL.
    validate_url(url)

    response = requests.get(
        url,
        timeout=3,
        allow_redirects=False,
    )
    # Re-check the host of the URL the response actually came from.
    final_host = urlparse(response.url).hostname
    if is_private_ip(final_host):
        raise ValueError("Blocked unsafe redirect to internal host")

    # Header values are strings, so convert before comparing with the int
    # limit; a missing or malformed header counts as 0 here and the real
    # body length is checked as well.
    raw_length = response.headers.get('Content-Length', 0)
    try:
        declared_size = int(raw_length)
    except (TypeError, ValueError):
        declared_size = 0
    if declared_size > file_limit or len(response.content) > file_limit:
        raise AppApiException(500, _('File size exceeds limit'))

    content_type = response.headers.get('Content-Type', '')
    # Text-like bodies are returned as text; anything else is treated as
    # binary and Base64-encoded so it survives JSON serialisation.
    if 'text' in content_type or 'json' in content_type:
        content = response.text
    else:
        content = base64.b64encode(response.content).decode('utf-8')
    return {
        'status_code': response.status_code,
        'Content-Length': raw_length,
        'Content-Type': content_type,
        'content': content,
    }
def is_private_ip(host: str) -> bool:
    """Report whether ``host`` resolves to an address that must not be fetched.

    Every address the resolver returns for the host (IPv4 and IPv6) is
    inspected. The host is rejected as soon as one of them is private,
    loopback, reserved, link-local or multicast, so a host that mixes a
    public record with an internal one does not slip through.

    Args:
        host: Hostname or IP literal, typically taken from a parsed URL.

    Returns:
        ``True`` when the host is empty, cannot be resolved or parsed, or
        has at least one address in the categories above; ``False``
        otherwise.
    """
    if not host:
        return True
    try:
        resolved = socket.getaddrinfo(host, None)
        if not resolved:
            return True
        for entry in resolved:
            # entry[4] is the socket address; its first item is the IP text.
            ip = ipaddress.ip_address(entry[4][0])
            if (
                ip.is_private
                or ip.is_loopback
                or ip.is_reserved
                or ip.is_link_local
                or ip.is_multicast
            ):
                return True
        return False
    except Exception:
        # Fail closed: anything that prevents a clear answer (resolution
        # failure, unparsable address, wrong argument type) is unsafe.
        return True
def validate_url(url: str):
    """Check that ``url`` is safe to request and return its parsed form.

    Args:
        url: The address supplied by the caller.

    Returns:
        The result of ``urlparse(url)``.

    Raises:
        ValueError: The URL is empty, uses a scheme other than http/https,
            has no host, or its host is reported unsafe by
            ``is_private_ip``.
    """
    if not url:
        raise ValueError("URL is required")

    target = urlparse(url)

    # Only plain web schemes are accepted.
    allowed_schemes = ("http", "https")
    if target.scheme not in allowed_schemes:
        raise ValueError("Only http and https are allowed")

    # A URL without a host part cannot be fetched.
    hostname = target.hostname
    if hostname in (None, ""):
        raise ValueError("Invalid URL")

    # Refuse internal, reserved, loopback and metadata-style destinations.
    if is_private_ip(hostname):
        raise ValueError("Access to internal IP addresses is blocked")

    return target

View File

@ -1,21 +1,14 @@
# coding=utf-8
import base64
import ipaddress
import socket
from urllib.parse import urlparse
import requests
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth
from common.log.log import log
from common.result import result
from knowledge.api.file import FileUploadAPI, FileGetAPI
from oss.serializers.file import FileSerializer
from oss.serializers.file import FileSerializer, get_url_content
class FileRetrievalView(APIView):
@ -84,71 +77,7 @@ class GetUrlView(APIView):
operation_id=_('Get url'), # type: ignore
tags=[_('Chat')] # type: ignore
)
def get(self, request: Request):
def get(self, request: Request, application_id: str):
# NOTE(review): two `def get` lines appear back to back; this looks like a
# diff rendered without +/- markers (first line = signature before the
# change, second = signature after it) -- confirm against the repository.
# Handler body as shown: read the `url` query parameter, validate it, fetch
# it without following redirects, and return status, headers and content.
url = request.query_params.get('url')
# `parsed` is not used below; validate_url is called for the ValueError it
# raises on a rejected URL.
parsed = validate_url(url)
response = requests.get(
url,
timeout=3,
allow_redirects=False
)
# Check the host of the response URL against the internal-address filter.
final_host = urlparse(response.url).hostname
if is_private_ip(final_host):
raise ValueError("Blocked unsafe redirect to internal host")
# Return the status code, the response size, the response Content-Type and the byte stream.
content_type = response.headers.get('Content-Type', '')
# Decide how to handle the body based on its content type.
if 'text' in content_type or 'json' in content_type:
content = response.text
else:
# Binary content is Base64-encoded.
content = base64.b64encode(response.content).decode('utf-8')
return result.success({
'status_code': response.status_code,
'Content-Length': response.headers.get('Content-Length', 0),
'Content-Type': content_type,
'content': content,
})
def is_private_ip(host: str) -> bool:
    """Report whether ``host`` resolves to an address that must not be fetched.

    All addresses returned by the resolver (IPv4 and IPv6) are checked,
    and one private, loopback, reserved, link-local or multicast address
    is enough to reject the host.

    Args:
        host: Hostname or IP literal, typically taken from a parsed URL.

    Returns:
        ``True`` when the host is empty, cannot be resolved or parsed, or
        has at least one address in the categories above; ``False``
        otherwise.
    """
    if not host:
        return True
    try:
        resolved = socket.getaddrinfo(host, None)
        if not resolved:
            return True
        for entry in resolved:
            # entry[4] is the socket address; its first item is the IP text.
            ip = ipaddress.ip_address(entry[4][0])
            if (
                ip.is_private
                or ip.is_loopback
                or ip.is_reserved
                or ip.is_link_local
                or ip.is_multicast
            ):
                return True
        return False
    except Exception:
        # Fail closed: if no clear answer can be obtained, treat the host
        # as unsafe rather than letting the request through.
        return True
def validate_url(url: str):
    """Validate that ``url`` is safe to request and return the parsed URL.

    Args:
        url: The address supplied by the caller.

    Returns:
        The result of ``urlparse(url)``.

    Raises:
        ValueError: The URL is empty, its scheme is not http/https, it has
            no host, or ``is_private_ip`` reports the host as unsafe.
    """
    if not url:
        raise ValueError("URL is required")
    parsed = urlparse(url)
    # Only http / https are allowed.
    if parsed.scheme not in ("http", "https"):
        raise ValueError("Only http and https are allowed")
    host = parsed.hostname
    # The host must not be empty.
    if not host:
        raise ValueError("Invalid URL")
    # Block internal, reserved, loopback and cloud-metadata destinations.
    if is_private_ip(host):
        raise ValueError("Access to internal IP addresses is blocked")
    return parsed
# Hand the URL and application id to get_url_content and pass whatever it
# returns to result.success.
# NOTE(review): these two lines appear to be the post-change body of `get`,
# separated from its `def` by the diff rendering -- confirm.
result_data = get_url_content(url, application_id)
return result.success(result_data)