fix: aws credentials_profile_name

This commit is contained in:
wxg0103 2024-12-27 18:46:09 +08:00
parent de97de64b2
commit 85868c13ce
5 changed files with 27 additions and 40 deletions

View File

@ -11,21 +11,6 @@ from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding imp
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
@staticmethod
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
content = re.sub(pattern, '', content, flags=re.DOTALL)
if not re.search(rf'\[{profile_name}\]', content):
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
with open(credentials_path, 'w') as file:
file.write(content)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
@ -41,9 +26,6 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
return False
try:
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
model_credential['secret_access_key'])
model_credential['credentials_profile_name'] = 'aws-profile'
model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
aa = model.embed_query('你好')
print(aa)

View File

@ -1,5 +1,4 @@
import os
import re
from typing import Dict
from langchain_core.messages import HumanMessage
@ -29,20 +28,7 @@ class BedrockLLMModelParams(BaseForm):
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
@staticmethod
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
content = re.sub(pattern, '', content, flags=re.DOTALL)
if not re.search(rf'\[{profile_name}\]', content):
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
with open(credentials_path, 'w') as file:
file.write(content)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
@ -59,9 +45,6 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
return False
try:
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
model_credential['secret_access_key'])
model_credential['credentials_profile_name'] = 'aws-profile'
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except AppApiException:

View File

@ -3,6 +3,8 @@ from langchain_community.embeddings import BedrockEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict, List
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
@ -13,10 +15,12 @@ class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
model_credential['secret_access_key'])
return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
credentials_profile_name=model_credential['access_key_id'],
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:

View File

@ -1,5 +1,6 @@
from typing import List, Dict
import os
import re
from botocore.config import Config
from langchain_community.chat_models import BedrockChat
from setting.models_provider.base_model_provider import MaxKBBaseModel
@ -57,12 +58,29 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
connect_timeout=60,
read_timeout=60
)
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
model_credential['secret_access_key'])
return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
credentials_profile_name=model_credential['access_key_id'],
streaming=model_kwargs.pop('streaming', True),
model_kwargs=optional_params,
config=config
)
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
content = re.sub(pattern, '', content, flags=re.DOTALL)
if not re.search(rf'\[{profile_name}\]', content):
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
with open(credentials_path, 'w') as file:
file.write(content)

View File

@ -160,7 +160,7 @@ class ModelSerializer(serializers.Serializer):
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None:
for k in source_encryption_model_credential.keys():
if credential[k] == source_encryption_model_credential[k]:
if k in credential and credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k]
return credential, model_credential, provider_handler