mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: optimize RSA encryption and decryption processes with caching and memory improvements
This commit is contained in:
parent
3838f68fa9
commit
fbc13a0bc2
|
|
@ -233,3 +233,34 @@ class ChatInfo:
|
|||
@staticmethod
|
||||
def get_cache(chat_id):
|
||||
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state['application'] = None
|
||||
state['chat_user'] = None
|
||||
|
||||
# 将 ChatRecord ORM 对象转为轻量字典
|
||||
if not self.debug and len(self.chat_record_list) > 0:
|
||||
state['chat_record_list'] = [
|
||||
{
|
||||
'id': str(record.id),
|
||||
}
|
||||
for record in self.chat_record_list
|
||||
]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
# 恢复 application
|
||||
if self.application is None and self.application_id:
|
||||
self.get_application()
|
||||
|
||||
# 如果需要完整的 ChatRecord 对象,从数据库重新加载
|
||||
if not self.debug and len(self.chat_record_list) > 0:
|
||||
record_ids = [record['id'] for record in self.chat_record_list if isinstance(record, dict)]
|
||||
if record_ids:
|
||||
self.chat_record_list = list(
|
||||
QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time')
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
import base64
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
|
||||
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||
from Crypto.PublicKey import RSA
|
||||
|
|
@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
|
|||
"""
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
|
||||
cipher = _get_encrypt_cipher(public_key)
|
||||
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
|
||||
return base64.b64encode(encrypt_msg).decode()
|
||||
|
||||
|
|
@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
|
|||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
cipher = _get_cipher(pri_key)
|
||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||
return decrypt_data.decode("utf-8")
|
||||
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_encrypt_cipher(public_key: str):
|
||||
"""缓存加密 cipher 对象"""
|
||||
return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
|
||||
|
||||
|
||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
||||
"""
|
||||
超长文本加密
|
||||
|
||||
:param message: 需要加密的字符串
|
||||
:param public_key 公钥
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:return: 加密后的数据
|
||||
"""
|
||||
# 读取公钥
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
|
||||
passphrase=secret_code))
|
||||
# 处理:Plaintext is too long. 分段加密
|
||||
|
||||
cipher = _get_encrypt_cipher(public_key)
|
||||
|
||||
if len(message) <= length:
|
||||
# 对编码的数据进行加密,并通过base64进行编码
|
||||
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
||||
else:
|
||||
rsa_text = []
|
||||
# 对编码后的数据进行切片,原因:加密长度不能过长
|
||||
for i in range(0, len(message), length):
|
||||
cont = message[i:i + length]
|
||||
# 对切片后的数据进行加密,并新增到text后面
|
||||
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
||||
# 加密完进行拼接
|
||||
cipher_text = b''.join(rsa_text)
|
||||
# base64进行编码
|
||||
result = base64.b64encode(cipher_text)
|
||||
|
||||
return result.decode()
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_cipher(pri_key: str):
|
||||
"""缓存 cipher 对象,避免重复创建"""
|
||||
return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
|
||||
|
||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
||||
"""
|
||||
超长文本解密,默认不加密
|
||||
超长文本解密,优化内存使用
|
||||
:param message: 需要解密的数据
|
||||
:param pri_key: 秘钥
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:return: 解密后的数据
|
||||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
|
||||
cipher = _get_cipher(pri_key)
|
||||
base64_de = base64.b64decode(message)
|
||||
res = []
|
||||
|
||||
# 使用 bytearray 减少内存分配
|
||||
result = bytearray()
|
||||
for i in range(0, len(base64_de), length):
|
||||
res.append(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
return b"".join(res).decode()
|
||||
result.extend(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
|
||||
return result.decode()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue