From fbc13a0bc26877e7c994081f77863dd292d054ac Mon Sep 17 00:00:00 2001 From: CaptainB Date: Fri, 14 Nov 2025 18:32:59 +0800 Subject: [PATCH] feat: optimize RSA encryption and decryption processes with caching and memory improvements --- apps/application/serializers/common.py | 31 ++++++++++++++++ apps/common/utils/rsa_util.py | 50 ++++++++++++++++---------- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index 10d7efb2e..f8cfcaf12 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -233,3 +233,34 @@ class ChatInfo: @staticmethod def get_cache(chat_id): return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) + + def __getstate__(self): + state = self.__dict__.copy() + state['application'] = None + state['chat_user'] = None + + # 将 ChatRecord ORM 对象转为轻量字典 + if not self.debug and len(self.chat_record_list) > 0: + state['chat_record_list'] = [ + { + 'id': str(record.id), + } + for record in self.chat_record_list + ] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + # 恢复 application + if self.application is None and self.application_id: + self.get_application() + + # 如果需要完整的 ChatRecord 对象,从数据库重新加载 + if not self.debug and len(self.chat_record_list) > 0: + record_ids = [record['id'] for record in self.chat_record_list if isinstance(record, dict)] + if record_ids: + self.chat_record_list = list( + QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time') + ) + diff --git a/apps/common/utils/rsa_util.py b/apps/common/utils/rsa_util.py index 5631be50e..1eba2389a 100644 --- a/apps/common/utils/rsa_util.py +++ b/apps/common/utils/rsa_util.py @@ -8,6 +8,7 @@ """ import base64 import threading +from functools import lru_cache from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher from Crypto.PublicKey import RSA @@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None): """ if public_key is None: public_key = get_key_pair().get('key') - cipher = PKCS1_cipher.new(RSA.importKey(public_key)) + cipher = _get_encrypt_cipher(public_key) encrypt_msg = cipher.encrypt(msg.encode("utf-8")) return base64.b64encode(encrypt_msg).decode() @@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None): """ if pri_key is None: pri_key = get_key_pair().get('value') - cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + cipher = _get_cipher(pri_key) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) return decrypt_data.decode("utf-8") + +@lru_cache(maxsize=2) +def _get_encrypt_cipher(public_key: str): + """缓存加密 cipher 对象""" + return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code)) + + def rsa_long_encrypt(message, public_key: str | None = None, length=200): """ 超长文本加密 :param message: 需要加密的字符串 :param public_key 公钥 - :param length: 1024bit的证书用100, 2048bit的证书用 200 + :param length: 1024bit的证书用100, 2048bit的证书用 200 :return: 加密后的数据 """ - # 读取公钥 if public_key is None: public_key = get_key_pair().get('key') - cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, - passphrase=secret_code)) - # 处理:Plaintext is too long. 分段加密 + + cipher = _get_encrypt_cipher(public_key) + if len(message) <= length: - # 对编码的数据进行加密,并通过base64进行编码 result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) else: rsa_text = [] - # 对编码后的数据进行切片,原因:加密长度不能过长 for i in range(0, len(message), length): cont = message[i:i + length] - # 对切片后的数据进行加密,并新增到text后面 rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) - # 加密完进行拼接 cipher_text = b''.join(rsa_text) - # base64进行编码 result = base64.b64encode(cipher_text) + return result.decode() +@lru_cache(maxsize=2) +def _get_cipher(pri_key: str): + """缓存 cipher 对象,避免重复创建""" + return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + + def rsa_long_decrypt(message, pri_key: str | None = None, length=256): """ - 超长文本解密,默认不加密 + 超长文本解密,优化内存使用 :param message: 需要解密的数据 :param pri_key: 秘钥 - :param length : 1024bit的证书用128,2048bit证书用256位 + :param length : 1024bit的证书用128,2048bit证书用256位 :return: 解密后的数据 """ if pri_key is None: pri_key = get_key_pair().get('value') - cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + + cipher = _get_cipher(pri_key) base64_de = base64.b64decode(message) - res = [] + + # 使用 bytearray 减少内存分配 + result = bytearray() for i in range(0, len(base64_de), length): - res.append(cipher.decrypt(base64_de[i:i + length], 0)) - return b"".join(res).decode() + result.extend(cipher.decrypt(base64_de[i:i + length], 0)) + + return result.decode() +