diff --git a/apps/oss/serializers/file.py b/apps/oss/serializers/file.py index 2dcdb7b51..91f1201a9 100644 --- a/apps/oss/serializers/file.py +++ b/apps/oss/serializers/file.py @@ -1,14 +1,20 @@ # coding=utf-8 +import base64 +import ipaddress import re +import socket import urllib +from urllib.parse import urlparse +import requests import uuid_utils.compat as uuid from django.db.models import QuerySet from django.http import HttpResponse from django.utils.translation import gettext_lazy as _ from rest_framework import serializers -from common.exception.app_exception import NotFound404 +from application.models import Application +from common.exception.app_exception import NotFound404, AppApiException from knowledge.models import File, FileSourceType from tools.serializers.tool import UploadedFileField @@ -158,3 +164,80 @@ class FileSerializer(serializers.Serializer): if file is not None: file.delete() return True + + +def get_url_content(url, application_id: str): + application = Application.objects.filter(id=application_id).first() + if application is None: + return AppApiException(500, _('Application does not exist')) + if not application.file_upload_enable: + return AppApiException(500, _('File upload is not enabled')) + file_limit = 50 * 1024 * 1024 + if application.file_upload_setting and application.file_upload_setting.file_limit: + file_limit = application.file_upload_setting.file_limit * 1024 * 1024 + parsed = validate_url(url) + + response = requests.get( + url, + timeout=3, + allow_redirects=False + ) + final_host = urlparse(response.url).hostname + if is_private_ip(final_host): + raise ValueError("Blocked unsafe redirect to internal host") + # 判断文件大小 + if response.headers.get('Content-Length', 0) > file_limit: + return AppApiException(500, _('File size exceeds limit')) + # 返回状态码 响应内容大小 响应的contenttype 还有字节流 + content_type = response.headers.get('Content-Type', '') + # 根据内容类型决定如何处理 + if 'text' in content_type or 'json' in content_type: + content = response.text + else: + # 二进制内容使用Base64编码 + content = base64.b64encode(response.content).decode('utf-8') + + return { + 'status_code': response.status_code, + 'Content-Length': response.headers.get('Content-Length', 0), + 'Content-Type': content_type, + 'content': content, + } + + +def is_private_ip(host: str) -> bool: + """检测 IP 是否属于内网、环回、云 metadata 的危险地址""" + try: + ip = ipaddress.ip_address(socket.gethostbyname(host)) + return ( + ip.is_private or + ip.is_loopback or + ip.is_reserved or + ip.is_link_local or + ip.is_multicast + ) + except Exception: + return True + + +def validate_url(url: str): + """验证 URL 是否安全""" + if not url: + raise ValueError("URL is required") + + parsed = urlparse(url) + + # 仅允许 http / https + if parsed.scheme not in ("http", "https"): + raise ValueError("Only http and https are allowed") + + host = parsed.hostname + # 域名不能为空 + if not host: + raise ValueError("Invalid URL") + + # 禁止访问内部、保留、环回、云 metadata + if is_private_ip(host): + raise ValueError("Access to internal IP addresses is blocked") + + return parsed diff --git a/apps/oss/views/file.py b/apps/oss/views/file.py index 185c76c19..e50ad05e5 100644 --- a/apps/oss/views/file.py +++ b/apps/oss/views/file.py @@ -1,21 +1,14 @@ # coding=utf-8 -import base64 -import ipaddress -import socket -from urllib.parse import urlparse - -import requests from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema from rest_framework.parsers import MultiPartParser from rest_framework.views import APIView from rest_framework.views import Request - from common.auth import TokenAuth from common.log.log import log from common.result import result from knowledge.api.file import FileUploadAPI, FileGetAPI -from oss.serializers.file import FileSerializer +from oss.serializers.file import FileSerializer, get_url_content class FileRetrievalView(APIView): @@ -84,71 +77,7 @@ class GetUrlView(APIView): operation_id=_('Get url'), # type: ignore tags=[_('Chat')] # type: ignore ) - def get(self, request: Request): + def get(self, request: Request, application_id: str): url = request.query_params.get('url') - parsed = validate_url(url) - - response = requests.get( - url, - timeout=3, - allow_redirects=False - ) - final_host = urlparse(response.url).hostname - if is_private_ip(final_host): - raise ValueError("Blocked unsafe redirect to internal host") - - # 返回状态码 响应内容大小 响应的contenttype 还有字节流 - content_type = response.headers.get('Content-Type', '') - # 根据内容类型决定如何处理 - if 'text' in content_type or 'json' in content_type: - content = response.text - else: - # 二进制内容使用Base64编码 - content = base64.b64encode(response.content).decode('utf-8') - - return result.success({ - 'status_code': response.status_code, - 'Content-Length': response.headers.get('Content-Length', 0), - 'Content-Type': content_type, - 'content': content, - }) - - -def is_private_ip(host: str) -> bool: - """检测 IP 是否属于内网、环回、云 metadata 的危险地址""" - try: - ip = ipaddress.ip_address(socket.gethostbyname(host)) - return ( - ip.is_private or - ip.is_loopback or - ip.is_reserved or - ip.is_link_local or - ip.is_multicast - ) - except Exception: - return True - - -def validate_url(url: str): - """验证 URL 是否安全""" - if not url: - raise ValueError("URL is required") - - parsed = urlparse(url) - - # 仅允许 http / https - if parsed.scheme not in ("http", "https"): - raise ValueError("Only http and https are allowed") - - host = parsed.hostname - path = parsed.path - - # 域名不能为空 - if not host: - raise ValueError("Invalid URL") - - # 禁止访问内部、保留、环回、云 metadata - if is_private_ip(host): - raise ValueError("Access to internal IP addresses is blocked") - - return parsed + result_data = get_url_content(url, application_id) + return result.success(result_data) \ No newline at end of file