mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: 模型添加长字符的加密解密方式 (#310)
This commit is contained in:
parent
c71a1ae79b
commit
cd472b47f7
|
|
@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
|
|||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Paragraph, Document
|
||||
from setting.models import Model, Status
|
||||
|
|
@ -225,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
# 对话模型
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
# 数据集id列表
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from common.util.common import post
|
|||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
|
|
@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
if model is not None:
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
|
||||
chat_id = str(uuid.uuid1())
|
||||
|
|
@ -252,7 +253,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
else:
|
||||
model = None
|
||||
|
|
|
|||
|
|
@ -100,3 +100,55 @@ def decrypt(msg, pri_key: str | None = None):
|
|||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||
return decrypt_data.decode("utf-8")
|
||||
|
||||
|
||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
||||
"""
|
||||
超长文本加密
|
||||
|
||||
:param message: 需要加密的字符串
|
||||
:param public_key 公钥
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:return: 加密后的数据
|
||||
"""
|
||||
|
||||
# 读取公钥
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
|
||||
passphrase=secret_code))
|
||||
# 处理:Plaintext is too long. 分段加密
|
||||
if len(message) <= length:
|
||||
# 对编码的数据进行加密,并通过base64进行编码
|
||||
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
||||
else:
|
||||
rsa_text = []
|
||||
# 对编码后的数据进行切片,原因:加密长度不能过长
|
||||
for i in range(0, len(message), length):
|
||||
cont = message[i:i + length]
|
||||
# 对切片后的数据进行加密,并新增到text后面
|
||||
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
||||
# 加密完进行拼接
|
||||
cipher_text = b''.join(rsa_text)
|
||||
# base64进行编码
|
||||
result = base64.b64encode(cipher_text)
|
||||
return result.decode()
|
||||
|
||||
|
||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
||||
"""
|
||||
超长文本解密,默认不加密
|
||||
:param message: 需要解密的数据
|
||||
:param pri_key: 秘钥
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:return: 解密后的数据
|
||||
"""
|
||||
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
base64_de = base64.b64decode(message)
|
||||
res = []
|
||||
for i in range(0, len(base64_de), length):
|
||||
res.append(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
return b"".join(res).decode()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from rest_framework import serializers
|
|||
from application.models import Application
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import encrypt, decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
|
||||
from setting.models.model_management import Model, Status
|
||||
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
|
@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
||||
model_name)
|
||||
source_model_credential = json.loads(decrypt(model.credential))
|
||||
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
||||
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
|
||||
if credential is not None:
|
||||
for k in source_encryption_model_credential.keys():
|
||||
|
|
@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
model_name = self.data.get('model_name')
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||
credential=encrypt(model_credential_str),
|
||||
credential=rsa_long_encrypt(model_credential_str),
|
||||
provider=provider, model_type=model_type, model_name=model_name)
|
||||
model.save()
|
||||
if status == Status.DOWNLOAD:
|
||||
|
|
@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
@staticmethod
|
||||
def model_to_dict(model: Model):
|
||||
credential = json.loads(decrypt(model.credential))
|
||||
credential = json.loads(rsa_long_decrypt(model.credential))
|
||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'status': model.status,
|
||||
|
|
@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'credential':
|
||||
model_credential_str = json.dumps(credential)
|
||||
model.__setattr__(update_key, encrypt(model_credential_str))
|
||||
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
|
||||
else:
|
||||
model.__setattr__(update_key, instance.get(update_key))
|
||||
model.save()
|
||||
|
|
|
|||
Loading…
Reference in New Issue