From cd472b47f7f1326046986f9b3046231d6076ab98 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:28:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=A8=A1=E5=9E=8B=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E9=95=BF=E5=AD=97=E7=AC=A6=E7=9A=84=E5=8A=A0=E5=AF=86=E8=A7=A3?= =?UTF-8?q?=E5=AF=86=E6=96=B9=E5=BC=8F=20(#310)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_message_serializers.py | 4 +- .../serializers/chat_serializers.py | 8 +-- apps/common/util/rsa_util.py | 52 +++++++++++++++++++ .../serializers/provider_serializers.py | 10 ++-- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 7c89e9de4..3f45a934d 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status @@ -225,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer): # 对话模型 chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt(model.credential)), streaming=True) # 数据集id列表 dataset_id_list = [str(row.dataset_id) for row in diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 7476219dd..d8a3e648b 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -35,7 +35,7 @@ from common.util.common import post from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model @@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer): if model is not None: chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt( + model.credential)), streaming=True) chat_id = str(uuid.uuid1()) @@ -252,7 +253,8 @@ class ChatSerializers(serializers.Serializer): model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt( + model.credential)), streaming=True) else: model = None diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index ee93bf499..b808548b1 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -100,3 +100,55 @@ def decrypt(msg, pri_key: str | None = None): cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() + + +def rsa_long_decrypt(message, pri_key: str | None = None, length=256): + """ + 超长文本解密,默认不加密 + :param message: 需要解密的数据 + :param pri_key: 秘钥 + :param length : 1024bit的证书用128,2048bit证书用256位 + :return: 解密后的数据 + """ + + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 2ac28343a..351a98c8a 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -18,7 +18,7 @@ from rest_framework import serializers from application.models import Application from common.exception.app_exception import AppApiException from common.util.field_message import ErrMessage -from common.util.rsa_util import encrypt, decrypt +from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt from setting.models.model_management import Model, Status from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer): model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name) - source_model_credential = json.loads(decrypt(model.credential)) + source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: for k in source_encryption_model_credential.keys(): @@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer): model_name = self.data.get('model_name') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, - credential=encrypt(model_credential_str), + credential=rsa_long_encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name) model.save() if status == Status.DOWNLOAD: @@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer): @staticmethod def model_to_dict(model: Model): - credential = json.loads(decrypt(model.credential)) + credential = json.loads(rsa_long_decrypt(model.credential)) return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, @@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer): if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': model_credential_str = json.dumps(credential) - model.__setattr__(update_key, encrypt(model_credential_str)) + model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: model.__setattr__(update_key, instance.get(update_key)) model.save()