diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 7ea5b011c..82a6f1ff1 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -39,7 +39,7 @@ from common.exception.app_exception import AppApiException, NotFound404, AppUnau from common.field.common import UploadedImageField, UploadedFileField from common.models.db_model_manage import DBModelManage from common.response import result -from common.util.common import valid_license, password_encrypt +from common.util.common import valid_license, password_encrypt, restricted_loads from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document, Image @@ -60,6 +60,7 @@ chat_cache = cache.caches['chat_cache'] class MKInstance: + def __init__(self, application: dict, function_lib_list: List[dict], version: str): self.application = application self.function_lib_list = function_lib_list @@ -727,7 +728,7 @@ class ApplicationSerializer(serializers.Serializer): user_id = self.data.get('user_id') mk_instance_bytes = self.data.get('file').read() try: - mk_instance = pickle.loads(mk_instance_bytes) + mk_instance = restricted_loads(mk_instance_bytes) except Exception as e: raise AppApiException(1001, _("Unsupported file format")) application = mk_instance.application @@ -813,7 +814,7 @@ class ApplicationSerializer(serializers.Serializer): return FunctionLibSerializer.Query( data={'user_id': application.user_id, 'is_active': True, 'function_type': FunctionType.PUBLIC} - ).list(with_valid=True) + ).list(with_valid=True) def get_function_lib(self, function_lib_id, with_valid=True): if with_valid: diff --git a/apps/common/util/common.py b/apps/common/util/common.py index 2e8481c57..54baa5c45 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -10,6 +10,7 @@ import hashlib import importlib import io import mimetypes +import pickle import re import shutil from functools import reduce @@ -23,6 +24,30 @@ from pydub import AudioSegment from ..exception.app_exception import AppApiException from ..models.db_model_manage import DBModelManage +safe_builtins = { + 'MKInstance' +} + +ALLOWED_CLASSES = { + ("builtins", "dict"), + ('uuid', 'UUID'), + ("application.serializers.application_serializers", "MKInstance") +} + + +class RestrictedUnpickler(pickle.Unpickler): + + def find_class(self, module, name): + if (module, name) in ALLOWED_CLASSES: + return super().find_class(module, name) + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % + (module, name)) + + +def restricted_loads(s): + """Helper function analogous to pickle.loads().""" + return RestrictedUnpickler(io.BytesIO(s)).load() + def encryption(message: str): """