From 85868c13ce1fdf52055794595d0bbbbe23b73c37 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Fri, 27 Dec 2024 18:46:09 +0800 Subject: [PATCH] fix: aws credentials_profile_name --- .../credential/embedding.py | 18 --------------- .../credential/llm.py | 19 +--------------- .../model/embedding.py | 6 ++++- .../aws_bedrock_model_provider/model/llm.py | 22 +++++++++++++++++-- .../serializers/provider_serializers.py | 2 +- 5 files changed, 27 insertions(+), 40 deletions(-) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py index 766c505cf..7e2bb6cac 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py @@ -11,21 +11,6 @@ from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding imp class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): - @staticmethod - def _update_aws_credentials(profile_name, access_key_id, secret_access_key): - credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") - os.makedirs(os.path.dirname(credentials_path), exist_ok=True) - - content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' - pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*' - content = re.sub(pattern, '', content, flags=re.DOTALL) - - if not re.search(rf'\[{profile_name}\]', content): - content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" - - with open(credentials_path, 'w') as file: - file.write(content) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() @@ -41,9 +26,6 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): return False try: - self._update_aws_credentials('aws-profile', model_credential['access_key_id'], - model_credential['secret_access_key']) - model_credential['credentials_profile_name'] = 'aws-profile' model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential) aa = model.embed_query('你好') print(aa) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index 9d1cd4210..3474fab7f 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -1,5 +1,4 @@ -import os -import re + from typing import Dict from langchain_core.messages import HumanMessage @@ -29,20 +28,7 @@ class BedrockLLMModelParams(BaseForm): class BedrockLLMModelCredential(BaseForm, BaseModelCredential): - @staticmethod - def _update_aws_credentials(profile_name, access_key_id, secret_access_key): - credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") - os.makedirs(os.path.dirname(credentials_path), exist_ok=True) - content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' - pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*' - content = re.sub(pattern, '', content, flags=re.DOTALL) - - if not re.search(rf'\[{profile_name}\]', content): - content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" - - with open(credentials_path, 'w') as file: - file.write(content) def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): @@ -59,9 +45,6 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): return False try: - self._update_aws_credentials('aws-profile', model_credential['access_key_id'], - model_credential['secret_access_key']) - model_credential['credentials_profile_name'] = 'aws-profile' model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except AppApiException: diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py index d08f62c62..137542252 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py @@ -3,6 +3,8 @@ from langchain_community.embeddings import BedrockEmbeddings from setting.models_provider.base_model_provider import MaxKBBaseModel from typing import Dict, List +from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials + class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings): def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, @@ -13,10 +15,12 @@ class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings): @classmethod def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs) -> 'BedrockModel': + _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'], + model_credential['secret_access_key']) return cls( model_id=model_name, region_name=model_credential['region_name'], - credentials_profile_name=model_credential['credentials_profile_name'], + credentials_profile_name=model_credential['access_key_id'], ) def embed_documents(self, texts: List[str]) -> List[List[float]]: diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py index dda406963..c561131ea 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -1,5 +1,6 @@ from typing import List, Dict - +import os +import re from botocore.config import Config from langchain_community.chat_models import BedrockChat from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -57,12 +58,29 @@ class BedrockModel(MaxKBBaseModel, BedrockChat): connect_timeout=60, read_timeout=60 ) + _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'], + model_credential['secret_access_key']) return cls( model_id=model_name, region_name=model_credential['region_name'], - credentials_profile_name=model_credential['credentials_profile_name'], + credentials_profile_name=model_credential['access_key_id'], streaming=model_kwargs.pop('streaming', True), model_kwargs=optional_params, config=config ) + + +def _update_aws_credentials(profile_name, access_key_id, secret_access_key): + credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") + os.makedirs(os.path.dirname(credentials_path), exist_ok=True) + + content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' + pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*' + content = re.sub(pattern, '', content, flags=re.DOTALL) + + if not re.search(rf'\[{profile_name}\]', content): + content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" + + with open(credentials_path, 'w') as file: + file.write(content) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index b966e331f..25f464bcb 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -160,7 +160,7 @@ class ModelSerializer(serializers.Serializer): source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: for k in source_encryption_model_credential.keys(): - if credential[k] == source_encryption_model_credential[k]: + if k in credential and credential[k] == source_encryption_model_credential[k]: credential[k] = source_model_credential[k] return credential, model_credential, provider_handler